class MatrixSource(Component):
    """Stream the entries of a constant matrix, one entry per transfer.

    Each stream item is a struct with three fields: ``element`` (the matrix
    entry), ``col`` and ``row`` (the entry's position). Entries are emitted
    in row-major order.

    Ports:
        start:  input; asserting it (re)starts the sequence at the first entry.
        active: output; high while the sequence is being emitted.
        matrix: output stream of ``{element, col, row}`` items.

    Args:
        element_shape: Shape used for the ``element`` field of each item.
        elements: Non-empty sequence of equally long rows of matrix entries.
    """

    def __init__(self, element_shape, elements):
        self._elements = elements

        n_rows = len(elements)
        n_cols = len(elements[0])
        assert all(len(row) == n_cols for row in elements), \
            'All rows must have the same length'

        # Field order is part of the layout; keep element, col, row.
        item_layout = data.StructLayout({
            'element': element_shape,
            'col': range(n_cols),
            'row': range(n_rows),
        })

        super().__init__({
            'start': In(1),
            'active': Out(1),
            'matrix': Out(StreamSignature(item_layout)),
        })

    def elaborate(self, platform):
        """Build the lookup table and the counter that walks through it."""
        m = Module()

        # Flatten the matrix row by row into one record per entry.
        entries = []
        for row, elements_row in enumerate(self._elements):
            for col, element in enumerate(elements_row):
                entries.append({'element': element, 'col': col, 'row': row})
        entry_count = len(entries)

        # NOTE: the names `element_array` and `idx` are kept as-is because the
        # HDL toolkit appears to derive signal names from the assignment target.
        element_array = Signal(
            data.ArrayLayout(self.matrix.data.shape(), entry_count))
        idx = Signal(range(entry_count))

        # Drive every field of every table slot with its constant value.
        for slot, entry in enumerate(entries):
            for field, value in entry.items():
                m.d.comb += element_array[slot][field].eq(value)

        # The stream is valid exactly while active and presents the entry
        # selected by the counter.
        m.d.comb += self.matrix.valid.eq(self.active)
        m.d.comb += self.matrix.data.eq(element_array[idx])

        with m.If(self.start):
            m.d.sync += self.active.eq(1)
            m.d.sync += idx.eq(0)

        # Deliberately placed after the `start` block: statement order decides
        # which assignment wins if both conditions hold in the same cycle.
        transferred = self.matrix.valid & self.matrix.ready
        with m.If(transferred):
            m.d.sync += idx.eq(idx + 1)
            # Once the final entry has been accepted the sequence is done.
            with m.If(idx == entry_count - 1):
                m.d.sync += self.active.eq(0)

        return m


class MatrixMultiplier(Component):
    """Accumulate a matrix-vector product from a stream of matrix entries.

    For every valid stream item, ``input[col] * element`` is added to
    ``output[row]`` on the next clock edge.

    Ports:
        reset:  input; clears the whole ``output`` vector.
        input:  input vector (``input_shape``).
        output: accumulated result vector (``output_shape``).
        matrix: input stream of ``{element, col, row}`` items, declared with
                ``backpressure = False``.

    Args:
        input_shape: ``data.ArrayLayout`` of the input vector; its length
            bounds the ``col`` field.
        output_shape: ``data.ArrayLayout`` of the output vector; its length
            bounds the ``row`` field.
        element_shape: Shape of the ``element`` field of each stream item.
    """

    def __init__(self, input_shape, output_shape, element_shape):
        assert isinstance(input_shape, data.ArrayLayout), \
            'Input shape must be an ArrayLayout'
        assert isinstance(output_shape, data.ArrayLayout), \
            'Output shape must be an ArrayLayout'

        # Field order is part of the layout; keep element, col, row.
        item_layout = data.StructLayout({
            'element': element_shape,
            'col': range(input_shape.length),
            'row': range(output_shape.length),
        })

        super().__init__({
            'reset': In(1),
            'input': In(input_shape),
            'output': Out(output_shape),
            'matrix': In(StreamSignature(item_layout, backpressure = False)),
        })

    def elaborate(self, platform):
        """Build the reset and multiply-accumulate logic."""
        m = Module()

        with m.If(self.reset):
            m.d.sync += self.output.eq(0)

        # Kept after the reset block so the relative priority of the two
        # assignments to `output` is unchanged.
        with m.If(self.matrix.valid):
            item = self.matrix.data
            accumulator = self.output[item.row]
            product = self.input[item.col] * item.element
            m.d.sync += accumulator.eq(accumulator + product)

        return m


class DUT(Elaboratable):
    """Test harness wiring a 3x3 ``MatrixSource`` into a ``MatrixMultiplier``.

    The submodules are created inside ``elaborate`` and stored on ``self`` as
    ``matrix_source`` and ``matrix_multiplier``; those attributes therefore
    only exist once the design has been elaborated.
    """

    def elaborate(self, platform):
        m = Module()

        scalar_shape = SQ(7, 8)
        vector_shape = data.ArrayLayout(scalar_shape, 3)

        coefficients = [
            [1, 0, 0],
            [0, 1.1, 0],
            [0, 0, -0.5],
        ]

        self.matrix_source = MatrixSource(scalar_shape, coefficients)
        m.submodules.matrix_source = self.matrix_source

        self.matrix_multiplier = MatrixMultiplier(
            vector_shape, vector_shape, scalar_shape)
        m.submodules.matrix_multiplier = self.matrix_multiplier

        connect(m, self.matrix_source.matrix, self.matrix_multiplier.matrix)

        return m


dut = DUT()


async def test_matrix(sim, input):
    """Multiply ``input`` by the DUT's matrix and print the result.

    Sequence: pulse ``reset`` on the multiplier for one tick, apply the input
    vector, pulse ``start`` on the source for one tick, wait until the source
    is no longer active, then read and print the output vector.
    """
    multiplier = dut.matrix_multiplier
    source = dut.matrix_source

    # Clear the accumulator.
    await multiplier.reset.set(1)
    await sim.tick()
    await multiplier.reset.set(0)

    # Present the vector and kick off the matrix stream.
    await multiplier.input.set(input)
    await source.start.set(1)
    await sim.tick()
    await source.start.set(0)

    # Wait for the source to finish emitting every entry.
    await sim.tick().until(source.active == 0)

    output = await multiplier.output.get()
    print(f'{input} => {output}')


async def testbench(sim):
    """Run ``test_matrix`` over three sample input vectors, in order."""
    for vector in ([1, 2, 3], [0, 0, 0], [0.5, 0.5, 0.5]):
        await test_matrix(sim, vector)


'''
% ./sim_matrix.py
[1, 2, 3] => [1.0, 2.203125, -1.5]
[0, 0, 0] => [0.0, 0.0, 0.0]
[0.5, 0.5, 0.5] => [0.5, 0.55078125, -0.25]
'''