class Shifter(Module):
    """Shift stream words out on a multi-lane serial interface.

    Each word accepted on ``sink`` is driven onto ``pads.sdi`` (data bit ``i``
    on lane ``i``) for one system-clock cycle. During that cycle a single
    ``pads.sclk`` pulse is produced.

    The transfer sequence is:

    * Chip select is asserted one cycle before the first word (START).
    * Chip select is held for one cycle after the word flagged ``last`` (END).
    * Chip select is then released for ``conv_time`` (CONV) before the FSM
      goes back to IDLE and can start a new transfer.

    Incoming ``pads.sdo`` bits are captured through DDR inputs into an
    internal signal.
    """

    def __init__(self, pads, clk_freq, conv_time=1000e-9):
        """Build the shifter.

        Args:
            pads: Pad record providing ``sclk``, ``cs_n``, ``sdi`` and ``sdo``.
                The data width is ``len(pads.sdi)``, and ``pads.sdo`` is
                indexed over the same width.
            clk_freq: System clock frequency in Hz.
            conv_time: Time in seconds spent in the CONV state (chip select
                deasserted) after each transfer. Defaults to 1000 ns, the
                value that was previously hard-coded. Presumably this is the
                attached converter's conversion time; confirm against the
                part's datasheet.
        """
        width = len(pads.sdi)

        self.sink = sink = stream.Endpoint([('data', width)])
        # NOTE(review): ``source`` is declared but never driven in this module;
        # the captured ``sdo`` bits are not forwarded to it yet (see the
        # latency TODO on the DDR inputs below).
        self.source = stream.Endpoint([('data', width)])

        # Number of system-clock cycles spent in CONV. It is clamped to at
        # least one so that ``conv_cnt`` has a valid range and the reload
        # value ``conv_t - 1`` is never negative.
        conv_t = max(1, int(ceil(clk_freq * conv_time)))
        conv_cnt = Signal(max = conv_t)

        valid = Signal()       # high while a sink word is being shifted
        cs = Signal()          # active-high chip select (inverted onto cs_n)
        sdi = Signal(width)    # word presented to the sdi lanes
        sdo = Signal(width)    # bits captured from the sdo lanes

        # Serial clock: i1 = 0 and i2 = valid, so sclk only pulses during
        # cycles where ``valid`` is asserted.
        self.specials += DDROutput(
            i1 = 0,
            i2 = valid,
            o = pads.sclk,
            clk = ClockSignal(),
        )

        # Chip select: the same value is driven on both phases, so cs_n holds
        # for the whole cycle.
        self.specials += DDROutput(
            i1 = ~cs,
            i2 = ~cs,
            o = pads.cs_n,
            clk = ClockSignal(),
        )

        for i in range(width):
            # Data out: the same bit is driven on both phases, so each lane
            # holds its bit for the whole cycle.
            self.specials += DDROutput(
                i1 = sdi[i],
                i2 = sdi[i],
                o = pads.sdi[i],
                clk = ClockSignal(),
            )

            # Data in: only the o2 sample is kept; the o1 sample is discarded.
            self.specials += DDRInput(
                i = pads.sdo[i],
                o1 = Signal(), # TODO: Determine correct latency/phase
                o2 = sdo[i],
                clk = ClockSignal(),
            )

        self.submodules.fsm = fsm = FSM()

        # Wait for data. The word is not consumed here (sink.ready stays low).
        fsm.act('IDLE',
            If(sink.valid,
                NextState('START'),
            ),
        )

        # Assert chip select for one cycle before the first clock pulse.
        fsm.act('START',
            cs.eq(1),
            NextState('ACTIVE'),
        )

        # Shift one sink word per cycle until the word flagged ``last``.
        fsm.act('ACTIVE',
            cs.eq(1),
            valid.eq(sink.valid),
            sdi.eq(sink.data),
            sink.ready.eq(1),
            If(sink.valid & sink.last,
                NextState('END'),
            )
        )

        # Hold chip select one more cycle, then load the CONV wait counter.
        fsm.act('END',
            cs.eq(1),
            NextState('CONV'),
            NextValue(conv_cnt, conv_t - 1),
        )

        # Chip select released; stay here for conv_t cycles. The counter
        # wrapping on the exit cycle is harmless because END reloads it.
        fsm.act('CONV',
            NextValue(conv_cnt, conv_cnt - 1),
            If(conv_cnt == 0,
                NextState('IDLE'),
            )
        )