class JTAGTransport(wiring.Component):
    """Byte-stream transport over the Lattice ``JTAGG`` primitive's user register.

    Bytes accepted on ``input`` (``sync`` domain) are shifted out on TDO via
    ``JTDO1``. Bytes shifted in on TDI (``JTDI``) are delivered on ``output``
    (``sync`` domain). The JTAG side runs in a reset-less ``jtag`` clock domain
    clocked by ``JTCK``. It is bridged to ``sync`` with two 4-entry asynchronous
    FIFOs.

    On the wire, each byte in either direction is framed as a ``1`` start bit
    followed by its eight data bits, least significant bit first. Bits are only
    shifted while ``JCE1`` is asserted.
    """

    # Bytes to transmit over JTAG (sync domain).
    input: wiring.In(stream.Signature(8))
    # Bytes received over JTAG (sync domain).
    output: wiring.Out(stream.Signature(8))

    def elaborate(self, platform):
        m = Module()

        m.domains += ClockDomain('jtag', reset_less = True)

        # sync -> jtag: bytes waiting to be shifted out on TDO.
        m.submodules.output_fifo = output_fifo = fifo.AsyncFIFO(width = 8, depth = 4, w_domain = 'sync', r_domain = 'jtag')
        wiring.connect(m, wiring.flipped(self.input), output_fifo.w_stream)

        # jtag -> sync: bytes shifted in on TDI.
        m.submodules.input_fifo = input_fifo = fifo.AsyncFIFO(width = 8, depth = 4, w_domain = 'jtag', r_domain = 'sync')
        wiring.connect(m, input_fifo.r_stream, wiring.flipped(self.output))

        jtdi = Signal()
        jce1 = Signal()
        jtdo1 = Signal()

        self._elaborate_receiver(m, jtdi, jce1, input_fifo.w_stream)
        self._elaborate_transmitter(m, jce1, jtdo1, output_fifo.r_stream)

        m.submodules.jtagg = Instance('JTAGG',
            o_JTCK = ClockSignal('jtag'),
            o_JTDI = jtdi,
            o_JCE1 = jce1,
            i_JTDO1 = jtdo1,
        )

        return m

    @staticmethod
    def _elaborate_receiver(m, jtdi, jce1, w_stream):
        """Deserialize framed bytes from TDI into ``w_stream`` (``jtag`` domain).

        Waits for a ``1`` start bit, then shifts the next eight bits into
        ``w_stream.payload`` LSB first. It then pulses ``w_stream.valid`` for one
        cycle. ``w_stream.ready`` is never consulted, so a byte completed while
        the FIFO is full is lost. While ``jce1`` is low the FSM is held idle, so
        a partially received byte is discarded.
        """
        # NOTE(review): TDI is delayed by one JTCK cycle here while JCE1 is used
        # undelayed; presumably this aligns sampling with JTAGG's JCE1 timing —
        # confirm against the Lattice JTAGG documentation before changing it.
        jtdi_reg = Signal()
        m.d.jtag += jtdi_reg.eq(jtdi)

        # State 0: waiting for the start bit; states 1..8: receiving data bit (state - 1).
        input_state = Signal(range(9))
        # One-cycle strobe: default low, raised only in the cycle a byte completes.
        m.d.jtag += w_stream.valid.eq(0)
        with m.If(jce1):
            with m.Switch(input_state):
                with m.Case(0):
                    with m.If(jtdi_reg):
                        m.d.jtag += input_state.eq(1)

                for i in range(8):
                    with m.Case(i + 1):
                        m.d.jtag += w_stream.payload[i].eq(jtdi_reg)
                        if i == 7:
                            m.d.jtag += w_stream.valid.eq(1)
                            m.d.jtag += input_state.eq(0)
                        else:
                            m.d.jtag += input_state.eq(input_state + 1)

        with m.Else():
            m.d.jtag += input_state.eq(0)

    @staticmethod
    def _elaborate_transmitter(m, jce1, jtdo1, r_stream):
        """Serialize bytes from ``r_stream`` onto TDO (``jtag`` domain).

        For each byte, drives a ``1`` start bit followed by the eight payload
        bits, LSB first. ``r_stream.ready`` is asserted only on the last bit.
        If ``jce1`` drops mid-byte, the FSM resets without consuming the byte. It
        stays at the FIFO head and is resent from its start bit during the next
        ``jce1`` window. ``jtdo1`` is 0 whenever no byte is being sent.
        """
        # State 0: first cycle after JCE1 rises (preparing); state 1: idle,
        # awaiting a byte; states 2..9: shifting out data bit (state - 2).
        output_state = Signal(range(10))
        with m.If(jce1):
            with m.Switch(output_state):
                with m.Case(0):
                    m.d.jtag += output_state.eq(1)

                with m.Case(1):
                    with m.If(r_stream.valid):
                        # Start bit.
                        m.d.comb += jtdo1.eq(1)
                        m.d.jtag += output_state.eq(2)

                for i in range(8):
                    with m.Case(i + 2):
                        m.d.comb += jtdo1.eq(r_stream.payload[i])
                        if i == 7:
                            # Pop the byte only once its last bit is on TDO, so an
                            # interrupted transfer leaves it queued for a retry.
                            m.d.comb += r_stream.ready.eq(1)
                            m.d.jtag += output_state.eq(1)
                        else:
                            m.d.jtag += output_state.eq(output_state + 1)

        with m.Else():
            m.d.jtag += output_state.eq(0)