class MatrixSource(Component):
    """Stream the elements of a constant matrix, one element per transfer.

    Elements are emitted in row-major order. Every transfer on ``matrix``
    carries the ``element`` value together with the ``col`` and ``row``
    indices it was taken from.

    Ports:
        start: when high on a clock edge, ``active`` is set and the element
            index is reset to 0.
        active: high while the matrix is being streamed; it drives
            ``matrix.valid`` and is cleared when the last element is
            transferred.
        matrix: output stream of ``{element, col, row}`` records.

    Args:
        element_shape: shape of the ``element`` field of the stream record.
        elements: non-empty sequence of rows; every row must be non-empty
            and all rows must have the same length.

    Raises:
        ValueError: if ``elements`` has no rows, if its rows are empty, or
            if the rows differ in length.
    """

    def __init__(self, element_shape, elements):
        # Validate with explicit exceptions: ``assert`` is stripped under
        # ``python -O``, and an empty matrix would otherwise either fail with
        # an unrelated IndexError or elaborate into a source that never
        # finishes (there would be no "last element" to clear ``active``).
        rows = len(elements)
        if rows == 0:
            raise ValueError('Matrix must have at least one row')

        cols = len(elements[0])
        if cols == 0:
            raise ValueError('Matrix rows must have at least one element')
        if any(len(row) != cols for row in elements):
            raise ValueError('All rows must have the same length')

        self._elements = elements

        stream_shape = data.StructLayout({
            'element': element_shape,
            'col': range(cols),
            'row': range(rows),
        })

        super().__init__({
            'start': In(1),
            'active': Out(1),
            'matrix': Out(StreamSignature(stream_shape)),
        })

    def elaborate(self, platform):
        m = Module()

        # Flatten the matrix row by row into one stream record per element.
        records = [
            {'element': element, 'col': col, 'row': row}
            for row, elements_row in enumerate(self._elements)
            for col, element in enumerate(elements_row)
        ]

        # Lookup table with one entry per record, driven to constant values.
        element_array = Signal(data.ArrayLayout(self.matrix.data.shape(), len(records)))

        # Index of the record currently presented on the stream.
        idx = Signal(range(len(records)))

        m.d.comb += [
            element_array[i][field].eq(value)
            for i, record in enumerate(records)
            for field, value in record.items()
        ]

        m.d.comb += [
            self.matrix.valid.eq(self.active),
            self.matrix.data.eq(element_array[idx]),
        ]

        with m.If(self.start):
            m.d.sync += [
                self.active.eq(1),
                idx.eq(0),
            ]

        # These assignments come after the ``start`` handling, so when a
        # transfer happens in the same cycle as ``start`` they take
        # precedence: the index advances instead of restarting.
        with m.If(self.matrix.valid & self.matrix.ready):
            m.d.sync += idx.eq(idx + 1)

            # Last record transferred: stop streaming until the next start.
            with m.If(idx == len(records) - 1):
                m.d.sync += self.active.eq(0)

        return m

class MatrixMultiplier(Component):
    """Accumulate a matrix-vector product from a stream of matrix elements.

    For every valid transfer on ``matrix``, the product
    ``input[col] * element`` is added to ``output[row]``. Once every element
    of a matrix has been received, ``output`` holds the product of that
    matrix with ``input`` (provided ``output`` was cleared beforehand).

    Ports:
        reset: when high on a clock edge, ``output`` is cleared to 0. Reset
            takes priority over an element arriving in the same cycle.
        input: vector to multiply; indexed by the ``col`` field.
        output: accumulated result; indexed by the ``row`` field.
        matrix: input stream of ``{element, col, row}`` records. It is
            declared with ``backpressure = False`` (presumably meaning this
            component never stalls the producer; confirm against
            ``StreamSignature``).

    Args:
        input_shape: ``data.ArrayLayout`` of the input vector; its length
            sets the range of the ``col`` field.
        output_shape: ``data.ArrayLayout`` of the output vector; its length
            sets the range of the ``row`` field.
        element_shape: shape of the ``element`` field of the stream record.

    Raises:
        TypeError: if ``input_shape`` or ``output_shape`` is not a
            ``data.ArrayLayout``.
    """

    def __init__(self, input_shape, output_shape, element_shape):
        # Explicit exceptions rather than ``assert`` so that the checks
        # survive ``python -O``.
        if not isinstance(input_shape, data.ArrayLayout):
            raise TypeError('Input shape must be an ArrayLayout')
        if not isinstance(output_shape, data.ArrayLayout):
            raise TypeError('Output shape must be an ArrayLayout')

        stream_shape = data.StructLayout({
            'element': element_shape,
            'col': range(input_shape.length),
            'row': range(output_shape.length),
        })

        super().__init__({
            'reset': In(1),
            'input': In(input_shape),
            'output': Out(output_shape),
            'matrix': In(StreamSignature(stream_shape, backpressure = False)),
        })

    def elaborate(self, platform):
        m = Module()

        entry = self.matrix.data

        # Reset and accumulate are mutually exclusive. With two independent
        # ``If`` blocks, a reset coinciding with a valid element would clear
        # every row except the addressed one, which would keep its stale
        # partial sum plus the new product.
        with m.If(self.reset):
            m.d.sync += self.output.eq(0)
        with m.Elif(self.matrix.valid):
            m.d.sync += self.output[entry.row].eq(
                self.output[entry.row] + self.input[entry.col] * entry.element
            )

        return m

class DUT(Elaboratable):
    """Test design: a ``MatrixSource`` feeding a ``MatrixMultiplier``.

    The two components are created during elaboration and also stored on the
    instance as ``matrix_source`` and ``matrix_multiplier``, so those
    attributes exist only after ``elaborate`` has run.
    """

    def elaborate(self, platform):
        # Shapes shared by both components: one scalar shape and a vector of
        # three such scalars.
        scalar_shape = SQ(7, 8)
        vector_shape = data.ArrayLayout(scalar_shape, 3)

        # Diagonal 3x3 matrix streamed by the source.
        coefficients = [
            [1,   0,    0],
            [0, 1.1,    0],
            [0,   0, -0.5],
        ]

        self.matrix_source = MatrixSource(scalar_shape, coefficients)
        self.matrix_multiplier = MatrixMultiplier(vector_shape, vector_shape, scalar_shape)

        m = Module()
        m.submodules.matrix_source = self.matrix_source
        m.submodules.matrix_multiplier = self.matrix_multiplier

        # Route the source's element stream into the multiplier.
        connect(m, self.matrix_source.matrix, self.matrix_multiplier.matrix)

        return m

# Single design instance shared by the test coroutines in this file, which
# reach its submodules through this global. The ``matrix_source`` and
# ``matrix_multiplier`` attributes are assigned inside ``DUT.elaborate``, so
# they are available only once the design has been elaborated.
dut = DUT()

async def test_matrix(sim, input):
    """Run one matrix-vector multiplication on the shared ``dut`` and print it.

    Args:
        sim: simulator handle; only ``sim.tick()`` is used here.
        input: vector written to the multiplier's ``input`` port.
    """
    # Hold the multiplier's reset high across one clock tick so the
    # accumulated output is cleared before this run.
    await dut.matrix_multiplier.reset.set(1)
    await sim.tick()
    await dut.matrix_multiplier.reset.set(0)

    # Apply the vector before the source starts streaming elements.
    await dut.matrix_multiplier.input.set(input)

    # Pulse ``start`` for one tick; the source then raises ``active`` and
    # streams from its first element.
    await dut.matrix_source.start.set(1)
    await sim.tick()
    await dut.matrix_source.start.set(0)
    # The source clears ``active`` when its last element is transferred.
    await sim.tick().until(dut.matrix_source.active == 0)

    output = await dut.matrix_multiplier.output.get()

    print(f'{input} => {output}')

async def testbench(sim):
    """Run ``test_matrix`` on three input vectors, one after another."""
    stimulus = (
        [1, 2, 3],
        [0, 0, 0],
        [0.5, 0.5, 0.5],
    )
    for vector in stimulus:
        await test_matrix(sim, vector)

'''
% ./sim_matrix.py
[1, 2, 3] => [1.0, 2.203125, -1.5]
[0, 0, 0] => [0.0, 0.0, 0.0]
[0.5, 0.5, 0.5] => [0.5, 0.55078125, -0.25]
'''