import deps
import os

import luna.usb2
from luna.gateware.architecture.car import PHYResetController

import nmigen
from nmigen.back import verilog

import migen
from litex.soc.interconnect.stream import Endpoint

class NMigenWrapper(nmigen.Elaboratable):
    """Accumulate nmigen logic and expose it to migen as one Verilog module.

    Callers populate the public containers created in ``__init__``;
    ``generate_verilog`` converts the collected design to Verilog text and
    ``get_instance`` returns a migen ``Instance`` that refers to that module
    by ``name``.

    Args:
        name: Verilog module name, also used as the ``Instance`` type.
    """

    def __init__(self, name = None):
        self.name = name
        # nmigen elaboratables added to the top-level module.
        self.submodules = []
        # Signals/Records whose underlying signals become module ports.
        self.signals = []
        # Combinational statements; each entry may itself be a list of statements.
        self.comb = []
        # Keyword arguments for migen.Instance, keyed 'i_<port>' / 'o_<port>'.
        self.connections = {}

    def elaborate(self, platform):
        """Build and return the top-level nmigen module (``platform`` is unused)."""
        top = nmigen.Module()

        # Only the 'sync' domain is declared explicitly. Other domains referenced
        # by submodules/statements (e.g. 'usb') are not declared here.
        # NOTE(review): this relies on nmigen creating missing domains at the
        # top level during conversion — confirm for the nmigen version in use.
        top.domains += nmigen.ClockDomain('sync')

        top.submodules += self.submodules
        top.d.comb += self.comb

        return top

    def get_instance(self):
        """Return a migen ``Instance`` of the generated module wired per ``connections``."""
        return migen.Instance(self.name, **self.connections)

    def generate_verilog(self):
        """Convert this wrapper to Verilog and return the source text.

        Every value in ``signals`` is flattened to its individual nmigen
        signals, and each of those becomes a port of the generated module.
        """
        port_signals = [
            sig
            for value in self.signals
            for sig in value._lhs_signals()
        ]
        return verilog.convert(self, name = self.name, ports = port_signals)

class USBDevice(migen.Module):
    """LiteX/migen wrapper around LUNA's nmigen ``USBDevice`` on a ULPI PHY.

    The LUNA device, a PHY reset controller and any endpoints added through
    ``add_endpoint`` are collected in an ``NMigenWrapper``. At finalization
    time the wrapper is converted to Verilog, registered as a platform source
    and instantiated as a migen ``Instance``.

    Args:
        platform: LiteX platform; ``platform.request('ulpi')`` must return pads
            with ``data``, ``clk``, ``stp``, ``nxt``, ``dir`` and ``rst``.
        *args, **kwargs: forwarded to ``luna.usb2.USBDevice``.
    """

    def __init__(self, platform, *args, **kwargs):
        self.platform = platform

        self._wrapper = NMigenWrapper('USBDevice')

        # Module ports that carry migen's sys clock/reset into nmigen's 'sync' domain.
        sync_clk = nmigen.Signal(name = 'sync_clk')
        sync_rst = nmigen.Signal(name = 'sync_rst')

        self._wrapper.signals.extend([
            sync_clk,
            sync_rst,
        ])

        self._wrapper.comb.append([
            nmigen.ClockSignal('sync').eq(sync_clk),
            nmigen.ResetSignal('sync').eq(sync_rst),
        ])

        # The 'usb' domain reset is driven by the PHY reset controller.
        # NOTE(review): no clock is assigned to 'usb' in this wrapper; presumably
        # LUNA derives it from the ULPI clock — confirm.
        phy_reset_controller = PHYResetController()
        self._wrapper.submodules.append(phy_reset_controller)

        self._wrapper.comb.append([
            nmigen.ResetSignal('usb').eq(phy_reset_controller.phy_reset),
        ])

        # ULPI data is bidirectional: the tristate buffer lives on the migen
        # side and its i/o/oe halves are passed through the instance ports.
        ulpi_pads = platform.request('ulpi')
        ulpi_data = migen.TSTriple(8)

        self.specials += ulpi_data.get_tristate(ulpi_pads.data)

        # Port names match the nmigen Record below ('<record>__<field>[__<sub>]').
        self._wrapper.connections.update(dict(
            i_sync_clk = migen.ClockSignal(),
            i_sync_rst = migen.ResetSignal(),

            o_ulpi__data__o = ulpi_data.o,
            o_ulpi__data__oe = ulpi_data.oe,
            i_ulpi__data__i = ulpi_data.i,
            i_ulpi__clk__i = ulpi_pads.clk,
            o_ulpi__stp = ulpi_pads.stp,
            i_ulpi__nxt__i = ulpi_pads.nxt,
            i_ulpi__dir__i = ulpi_pads.dir,
            # NOTE(review): rst is passed through uninverted; confirm the PHY
            # reset polarity for the target board.
            o_ulpi__rst = ulpi_pads.rst,
        ))

        from nmigen.hdl.rec import DIR_FANIN, DIR_FANOUT
        self._ulpi = nmigen.Record(
            [
                ('data', [('i', 8, DIR_FANIN), ('o', 8, DIR_FANOUT), ('oe', 1, DIR_FANOUT)]),
                ('clk', [('i', 1, DIR_FANIN)]),
                ('stp', 1, DIR_FANOUT),
                ('nxt', [('i', 1, DIR_FANIN)]),
                ('dir', [('i', 1, DIR_FANIN)]),
                ('rst', 1, DIR_FANOUT),
            ],
            name = 'ulpi',
        )

        self._wrapper.signals.append(self._ulpi)

        # Positional args first: star-unpacking after a keyword argument binds
        # the same way but reads as if it came after 'bus'.
        self.usb = luna.usb2.USBDevice(*args, bus = self._ulpi, **kwargs)
        self._wrapper.submodules.append(self.usb)

        # Keep the device permanently connected to the bus.
        self._wrapper.comb.append([
            self.usb.connect.eq(1),
        ])

    def add_endpoint(self, ep):
        """Register ``ep`` with the LUNA device and export its stream through the wrapper.

        Args:
            ep: one of this module's endpoint wrappers; it must provide ``_ep``
                (the LUNA endpoint) and ``wrap(wrapper)``.
        """
        self.usb.add_endpoint(ep._ep)
        ep.wrap(self._wrapper)

    def do_finalize(self):
        """Emit the wrapped design as Verilog and instantiate it.

        Writes ``<platform.output_dir>/gateware/<module name>.v``, creating the
        directory if needed, adds the file to the platform sources, and adds
        the matching migen ``Instance`` to this module's specials.
        """
        gateware_dir = os.path.join(self.platform.output_dir, 'gateware')
        # Do not depend on the build flow having created the directory already.
        os.makedirs(gateware_dir, exist_ok = True)
        # Derive the filename from the module name so the two cannot drift apart.
        verilog_filename = os.path.join(gateware_dir, f'{self._wrapper.name}.v')

        # Explicit encoding: the locale default may not be able to encode the output.
        with open(verilog_filename, 'w', encoding = 'utf-8') as f:
            f.write(self._wrapper.generate_verilog())

        self.platform.add_source(verilog_filename)

        self.specials += self._wrapper.get_instance()

class USBStreamOutEndpoint:
    """Bridge a LUNA OUT stream endpoint to a LiteX stream ``source``.

    Keyword Args:
        endpoint_number: USB endpoint number; also embedded in the port prefix.
        **kwargs: forwarded to ``luna.usb2.USBStreamOutEndpoint``.

    Attributes:
        source: LiteX ``Endpoint`` with an 8-bit ``data`` field, hooked to the
            wrapper's instance ports registered by ``wrap``.
    """

    def __init__(self, *, endpoint_number, **kwargs):
        self._ep = luna.usb2.USBStreamOutEndpoint(
            endpoint_number = endpoint_number,
            **kwargs,
        )
        self.prefix = f'ep_{endpoint_number}_out'

        self.source = Endpoint([('data', 8)])

    def wrap(self, wrapper):
        """Export the endpoint stream as ports of ``wrapper`` and map them onto ``source``.

        Args:
            wrapper: the ``NMigenWrapper`` that hosts this endpoint.
        """
        inner = self._ep.stream
        port = nmigen.Record(inner.layout, name = self.prefix)

        # Data and framing flow out of the endpoint; ready flows back in.
        outgoing = [
            getattr(port, field).eq(getattr(inner, field))
            for field in ('payload', 'first', 'last', 'valid')
        ]
        wrapper.comb.append(outgoing + [inner.ready.eq(port.ready)])
        wrapper.signals.append(port)

        p = self.prefix
        src = self.source
        wrapper.connections.update({
            f'o_{p}__payload': src.data,
            f'o_{p}__first': src.first,
            f'o_{p}__last': src.last,
            f'o_{p}__valid': src.valid,
            f'i_{p}__ready': src.ready,
        })

class USBStreamInEndpoint:
    """Bridge a LiteX stream ``sink`` to a LUNA IN stream endpoint.

    Keyword Args:
        endpoint_number: USB endpoint number; also embedded in the port prefix.
        **kwargs: forwarded to ``luna.usb2.USBStreamInEndpoint``.

    Attributes:
        sink: LiteX ``Endpoint`` with an 8-bit ``data`` field, hooked to the
            wrapper's instance ports registered by ``wrap``.
    """

    def __init__(self, *, endpoint_number, **kwargs):
        self._ep = luna.usb2.USBStreamInEndpoint(
            endpoint_number = endpoint_number,
            **kwargs,
        )
        self.prefix = f'ep_{endpoint_number}_in'

        self.sink = Endpoint([('data', 8)])

    def wrap(self, wrapper):
        """Export the endpoint stream as ports of ``wrapper`` and map them onto ``sink``.

        Args:
            wrapper: the ``NMigenWrapper`` that hosts this endpoint.
        """
        inner = self._ep.stream
        port = nmigen.Record(inner.layout, name = self.prefix)

        # Data and framing flow into the endpoint; ready flows back out.
        incoming = [
            getattr(inner, field).eq(getattr(port, field))
            for field in ('payload', 'first', 'last', 'valid')
        ]
        wrapper.comb.append(incoming + [port.ready.eq(inner.ready)])
        wrapper.signals.append(port)

        p = self.prefix
        snk = self.sink
        wrapper.connections.update({
            f'i_{p}__payload': snk.data,
            f'i_{p}__first': snk.first,
            f'i_{p}__last': snk.last,
            f'i_{p}__valid': snk.valid,
            f'o_{p}__ready': snk.ready,
        })

class USBMultibyteStreamInEndpoint:
    """Bridge a multi-byte-wide LiteX stream ``sink`` to a LUNA multibyte IN endpoint.

    Keyword Args:
        endpoint_number: USB endpoint number; also embedded in the port prefix.
        byte_width: number of bytes per stream word; ``sink.data`` is
            ``8 * byte_width`` bits wide.
        **kwargs: forwarded to ``luna.usb2.USBMultibyteStreamInEndpoint``.

    Attributes:
        sink: LiteX ``Endpoint`` hooked to the wrapper's instance ports
            registered by ``wrap``.
    """

    def __init__(self, *, endpoint_number, byte_width, **kwargs):
        self._ep = luna.usb2.USBMultibyteStreamInEndpoint(
            endpoint_number = endpoint_number,
            byte_width = byte_width,
            **kwargs,
        )
        self.prefix = f'ep_{endpoint_number}_in'

        data_bits = 8 * byte_width
        self.sink = Endpoint([('data', data_bits)])

    def wrap(self, wrapper):
        """Export the endpoint stream as ports of ``wrapper`` and map them onto ``sink``.

        Args:
            wrapper: the ``NMigenWrapper`` that hosts this endpoint.
        """
        inner = self._ep.stream
        port = nmigen.Record(inner.layout, name = self.prefix)

        # Data and framing flow into the endpoint; ready flows back out.
        incoming = [
            getattr(inner, field).eq(getattr(port, field))
            for field in ('payload', 'first', 'last', 'valid')
        ]
        wrapper.comb.append(incoming + [port.ready.eq(inner.ready)])
        wrapper.signals.append(port)

        p = self.prefix
        snk = self.sink
        wrapper.connections.update({
            f'i_{p}__payload': snk.data,
            f'i_{p}__first': snk.first,
            f'i_{p}__last': snk.last,
            f'i_{p}__valid': snk.valid,
            f'o_{p}__ready': snk.ready,
        })