class MatrixSource(Component):
    """Stream source that emits the entries of a fixed matrix one at a time.

    Every payload on the ``matrix`` stream is a struct with three fields:
    ``element`` (the matrix value), ``col`` and ``row`` (its position).
    Entries are emitted in row-major order.

    Ports:
        start:  input; when high, ``active`` is set and the read index is
                loaded with 0 on the next ``sync`` edge.
        active: output; drives the stream's ``valid`` and is cleared on the
                edge that transfers the last entry.
        matrix: output stream carrying the entries.

    Args:
        element_shape: shape used for the ``element`` field of the payload.
        elements: sequence of rows, all of the same length. It must contain
            at least one row because the first row's length is read.
    """

    def __init__(self, element_shape, elements):
        self._elements = elements

        row_count = len(elements)
        col_count = len(elements[0])
        rectangular = all(len(row) == col_count for row in elements)
        assert rectangular, 'All rows must have the same length'

        # `col` / `row` are sized from ranges so they are just wide enough to
        # hold every valid index.
        payload_layout = data.StructLayout({
            'element': element_shape,
            'col': range(col_count),
            'row': range(row_count),
        })
        ports = {
            'start': In(1),
            'active': Out(1),
            'matrix': Out(StreamSignature(payload_layout)),
        }
        super().__init__(ports)

    def elaborate(self, platform):
        m = Module()

        # Flatten the matrix row by row into one payload dict per entry.
        entries = []
        for row_index, row_values in enumerate(self._elements):
            for col_index, value in enumerate(row_values):
                entries.append({'element': value, 'col': col_index, 'row': row_index})
        entry_count = len(entries)

        # NOTE: the variable names of these two signals are kept on purpose;
        # Amaranth infers a Signal's name from the variable it is assigned to.
        element_array = Signal(data.ArrayLayout(self.matrix.data.shape(), entry_count))
        idx = Signal(range(entry_count))

        # Tie every field of every table slot to its constant value.
        for slot, entry in enumerate(entries):
            for field, value in entry.items():
                m.d.comb += element_array[slot][field].eq(value)

        # The stream is valid exactly while active and presents the slot
        # selected by the read index.
        m.d.comb += self.matrix.valid.eq(self.active)
        m.d.comb += self.matrix.data.eq(element_array[idx])

        with m.If(self.start):
            m.d.sync += self.active.eq(1)
            m.d.sync += idx.eq(0)

        # Deliberately a separate `If` (not an `Elif`) placed after the start
        # handling: statement order decides which assignment applies when
        # both conditions hold in the same cycle.
        transfer = self.matrix.valid & self.matrix.ready
        at_last_entry = idx == entry_count - 1
        with m.If(transfer):
            m.d.sync += idx.eq(idx + 1)
            with m.If(at_last_entry):
                m.d.sync += self.active.eq(0)

        return m
class MatrixMultiplier(Component):
    """Accumulates a matrix-vector product from a stream of matrix entries.

    For every valid entry on the ``matrix`` stream, the product
    ``input[col] * element`` is added to ``output[row]`` on the next ``sync``
    edge. Asserting ``reset`` clears the whole ``output`` vector to 0.

    Ports:
        reset:  input; clears ``output``.
        input:  input vector (``input_shape``).
        output: output vector (``output_shape``), used as the accumulator.
        matrix: input stream of ``element`` / ``col`` / ``row`` structs,
                declared with ``backpressure=False``.

    Args:
        input_shape: ``data.ArrayLayout`` of the input vector; its length
            sizes the ``col`` field.
        output_shape: ``data.ArrayLayout`` of the output vector; its length
            sizes the ``row`` field.
        element_shape: shape of the ``element`` field of the stream payload.
    """

    def __init__(self, input_shape, output_shape, element_shape):
        assert isinstance(input_shape, data.ArrayLayout), 'Input shape must be an ArrayLayout'
        assert isinstance(output_shape, data.ArrayLayout), 'Output shape must be an ArrayLayout'

        # Index fields are sized from the vector lengths.
        col_range = range(input_shape.length)
        row_range = range(output_shape.length)
        payload_layout = data.StructLayout({
            'element': element_shape,
            'col': col_range,
            'row': row_range,
        })
        ports = {
            'reset': In(1),
            'input': In(input_shape),
            'output': Out(output_shape),
            'matrix': In(StreamSignature(payload_layout, backpressure=False)),
        }
        super().__init__(ports)

    def elaborate(self, platform):
        m = Module()

        entry = self.matrix.data
        accumulator = self.output[entry.row]
        product = self.input[entry.col] * entry.element

        with m.If(self.reset):
            m.d.sync += self.output.eq(0)

        # A separate `If` after the reset handling (not an `Elif`): statement
        # order decides the result for the addressed row when `reset` and a
        # valid entry coincide.
        with m.If(self.matrix.valid):
            m.d.sync += accumulator.eq(accumulator + product)

        return m
class DUT(Elaboratable):
    """Top-level test design: a MatrixSource feeding a MatrixMultiplier.

    The sub-components are created once in ``__init__`` and exposed as
    ``matrix_source`` and ``matrix_multiplier``. They are therefore available
    as soon as the object is constructed, and repeated elaboration reuses the
    same instances instead of replacing them.
    """

    def __init__(self):
        super().__init__()

        # NOTE(review): SQ is defined elsewhere; presumably a signed
        # fixed-point shape with 7 integer and 8 fractional bits - confirm.
        Scalar = SQ(7, 8)
        Vector = data.ArrayLayout(Scalar, 3)

        # 3x3 diagonal test matrix.
        self.matrix_source = MatrixSource(Scalar, [
            [1, 0, 0],
            [0, 1.1, 0],
            [0, 0, -0.5],
        ])
        self.matrix_multiplier = MatrixMultiplier(Vector, Vector, Scalar)

    def elaborate(self, platform):
        m = Module()

        m.submodules.matrix_source = self.matrix_source
        m.submodules.matrix_multiplier = self.matrix_multiplier

        # Route the source's entry stream into the multiplier.
        connect(m, self.matrix_source.matrix, self.matrix_multiplier.matrix)

        return m
# Single design instance shared by the testbench coroutines below.
dut = DUT()
async def test_matrix(sim, input):
    """Run one matrix-vector multiplication on ``dut`` and print the result.

    Sequence: pulse the multiplier's ``reset`` for one tick, apply ``input``
    to the multiplier, pulse the source's ``start`` for one tick, wait until
    the source's ``active`` reads 0, then read and print the multiplier's
    ``output``.

    Args:
        sim: simulator handle providing ``tick()``.
        input: values driven onto the multiplier's ``input`` port. The name
            shadows the builtin but is kept as part of the signature.
    """
    source = dut.matrix_source
    multiplier = dut.matrix_multiplier

    # Clear the accumulator.
    await multiplier.reset.set(1)
    await sim.tick()
    await multiplier.reset.set(0)

    # Apply the vector and kick off one pass over the matrix.
    await multiplier.input.set(input)
    await source.start.set(1)
    await sim.tick()
    await source.start.set(0)

    # Wait for the source to finish emitting entries.
    await sim.tick().until(source.active == 0)

    output = await multiplier.output.get()
    print(f'{input} => {output}')
async def testbench(sim):
    """Run ``test_matrix`` on three input vectors, one after another."""
    stimuli = (
        [1, 2, 3],
        [0, 0, 0],
        [0.5, 0.5, 0.5],
    )
    for vector in stimuli:
        await test_matrix(sim, vector)
'''
% ./sim_matrix.py
[1, 2, 3] => [1.0, 2.203125, -1.5]
[0, 0, 0] => [0.0, 0.0, 0.0]
[0.5, 0.5, 0.5] => [0.5, 0.55078125, -0.25]
'''
|