from amaranth import *

from amaranth.lib.wiring import Component, Signature, In, Out, connect, flipped
from amaranth.lib.fixed import SQ
from amaranth.lib import data

from signal_types import DQEnc, ABCEnc, ABCVector, DQVector, Current

from zyp_amaranth_libs.stream import StreamSignature

class TorquePI(Component):
    """PI controller with setpoint feed-forward, operating on two-element D/Q vectors.

    For every sample accepted on ``input`` the component computes, per element
    ``k`` in (0, 1)::

        error[k]  = setpoint[k] - input[k]
        i_acc[k]  = clamp(i_acc[k] + gain_ki * error[k])
        out[k]    = clamp(i_acc[k] + gain_kp * error[k])
        out[k]    = clamp(out[k]   + gain_ff0 * setpoint[k])

    where ``clamp`` saturates symmetrically to ``+/- out_max``. The result is
    presented on ``output`` until the downstream side asserts ``ready``.

    All products share a single multiply-accumulate unit, so one sample takes
    several cycles. ``input.ready`` is only asserted while the FSM is idle.

    Ports:
        enable: while low, the integrator state and the output data register
            are forced to zero every cycle.
        setpoint: target vector. It is sampled in the same cycle an input sample
            is accepted, and that snapshot is used for both the error and the
            feed-forward term.
        input: measurement stream.
        output: controller output stream.
        gain_kp, gain_ki, gain_ff0: proportional, integral and feed-forward gains.
        out_max: saturation limit applied to every MAC result (integrator and
            output alike). The clamp logic assumes it is non-negative.
    """

    enable: In(1)
    setpoint: In(DQVector)
    input: In(StreamSignature(DQVector))
    output: Out(StreamSignature(DQVector))

    gain_kp: In(SQ(8, 8))
    gain_ki: In(SQ(8, 8))
    gain_ff0: In(SQ(8, 8))
    out_max: In(SQ(16, 0))

    def elaborate(self, platform):
        m = Module()

        # Per-sample working registers, all loaded on the input handshake.
        error = Signal(DQVector)             # setpoint - measurement
        setpoint_latched = Signal(DQVector)  # setpoint snapshot used for feed-forward
        # Integrator state; persists across samples and is cleared only when disabled.
        i_acc = Signal(DQVector)

        # Single shared multiply-accumulate unit with a one-cycle registered result:
        #     mul_res <= mul_gain * mul_sig + mul_add
        # A state that drives the MAC inputs sees the result (via mul_res_clamped)
        # in the following state.
        mul_gain = Signal(SQ(8, 8))
        mul_sig = Signal(SQ(16, 0))
        mul_add = Signal(SQ(16, 0))
        mul_res = Signal(SQ(24, 0))
        mul_res_clamped = Signal(SQ(16, 0))

        m.d.sync += mul_res.eq(mul_gain * mul_sig + mul_add)

        # Symmetric saturation of the MAC result to +/- out_max.
        m.d.comb += mul_res_clamped.eq(mul_res)

        with m.If(mul_res > self.out_max):
            m.d.comb += mul_res_clamped.eq(self.out_max)

        with m.If(mul_res < -self.out_max):
            m.d.comb += mul_res_clamped.eq(-self.out_max)

        def issue_mac(gain, sig, add):
            # Drive the shared MAC inputs in the currently active FSM state; the
            # clamped result is available in the next state.
            m.d.comb += [
                mul_gain.eq(gain),
                mul_sig.eq(sig),
                mul_add.eq(add),
            ]

        # Sequencer: each state writes back the MAC result issued by the previous
        # state and issues the next MAC operation.
        with m.FSM():
            with m.State('IDLE'):
                m.d.comb += self.input.ready.eq(1)
                with m.If(self.input.valid):
                    # Latch the setpoint together with the error, so the error
                    # and the feed-forward term come from the same setpoint sample.
                    # NOTE(review): the difference may be wider than a DQVector
                    # element; storing it can wrap for very large errors —
                    # confirm the element range or consider saturating here.
                    m.d.sync += [
                        setpoint_latched.eq(self.setpoint),
                        error[0].eq(self.setpoint[0] - self.input.data[0]),
                        error[1].eq(self.setpoint[1] - self.input.data[1]),
                    ]
                    m.next = 'START'

            with m.State('START'):
                m.next = 'I_D'
                issue_mac(self.gain_ki, error[0], i_acc[0])

            with m.State('I_D'):
                m.d.sync += i_acc[0].eq(mul_res_clamped)

                m.next = 'I_Q'
                issue_mac(self.gain_ki, error[1], i_acc[1])

            with m.State('I_Q'):
                m.d.sync += i_acc[1].eq(mul_res_clamped)

                m.next = 'P_D'
                # i_acc[0] already holds the value updated in I_D.
                issue_mac(self.gain_kp, error[0], i_acc[0])

            with m.State('P_D'):
                m.d.sync += self.output.data[0].eq(mul_res_clamped)

                m.next = 'P_Q'
                issue_mac(self.gain_kp, error[1], i_acc[1])

            with m.State('P_Q'):
                m.d.sync += self.output.data[1].eq(mul_res_clamped)

                m.next = 'FF0_D'
                # Accumulate feed-forward onto the P+I result written in P_D.
                issue_mac(self.gain_ff0, setpoint_latched[0], self.output.data[0])

            with m.State('FF0_D'):
                m.d.sync += self.output.data[0].eq(mul_res_clamped)

                m.next = 'FF0_Q'
                issue_mac(self.gain_ff0, setpoint_latched[1], self.output.data[1])

            with m.State('FF0_Q'):
                m.d.sync += self.output.data[1].eq(mul_res_clamped)

                m.next = 'DONE'

            with m.State('DONE'):
                m.d.comb += self.output.valid.eq(1)
                with m.If(self.output.ready):
                    m.next = 'IDLE'

        # Placed after the FSM on purpose: in Amaranth the last assignment wins,
        # so this overrides any FSM writeback to i_acc / output.data while disabled.
        with m.If(self.enable == 0):
            m.d.sync += [
                i_acc.eq(0),
                self.output.data.eq(0),
            ]

        return m