#!/usr/bin/env python3

from amaranth import *
from amaranth.lib.wiring import connect, Component, Signature, In, Out
from amaranth.lib.memory import Memory

from amaranth_soc import wishbone
from amaranth_soc.wishbone.sram import WishboneSRAM
from amaranth_soc.memory import MemoryMap

from sentinel.top import Top as Sentinel

class WBMemory(Component):
    """Block RAM exposed as a 32-bit Wishbone (classic) slave with byte lanes.

    Parameters
    ----------
    sim : bool
        When true, an extra ``ctrl`` port carrying a ``force_ws`` signal is
        added. While ``force_ws`` is high the RAM withholds ``ack``, which
        lets a testbench inject wait states.
    num_bytes : int
        Size of the backing storage in bytes. The RAM holds
        ``num_bytes // 4`` 32-bit words.

    The bus memory map covers a full 32-bit byte address space, but only a
    ``num_bytes``-sized ``"ram"`` resource is registered in it. ``bus.adr``
    is assigned directly to the (narrower) memory address ports, so upper
    address bits are truncated and accesses past ``num_bytes`` alias.
    """

    def __init__(self, *, sim=False, num_bytes=0x800):
        members = {
            "bus": In(wishbone.Signature(addr_width=30, data_width=32,
                                         granularity=8)),
        }
        if sim:
            # Testbench control: force_ws=1 suppresses ack.
            # NOTE(review): declared Out/Out, so the component nominally
            # drives force_ws although it only reads it (the testbench sets
            # it directly). Confirm the intended direction before connect()ing
            # this port to real logic.
            members["ctrl"] = Out(Signature({"force_ws": Out(1)}))

        self.sim = sim
        self.num_bytes = num_bytes
        self._mem_set = False

        super().__init__(members)

        # Reserve the whole 32-bit byte address space for this RAM...
        self.bus.memory_map = MemoryMap(addr_width=32, data_width=8)
        # ...while only registering a num_bytes-sized resource inside it.
        self.bus.memory_map.add_resource(self, name=("ram",), size=num_bytes)

        word_count = self.num_bytes // 4
        self.mem = Memory(shape=32, depth=word_count, init=[])
        # Aliases used by outside code (testbench pokes _mem, DUT uses wb_bus).
        self._mem = self.mem
        self.wb_bus = self.bus

    @property
    def init(self):
        """Initial contents of the backing memory, one 32-bit value per word."""
        return self.mem.init

    @init.setter
    def init(self, words):
        # Replace only the leading len(words) entries; later words are kept.
        self.mem.init[:len(words)] = words

    def elaborate(self, plat):
        m = Module()

        m.submodules.mem = self.mem
        write = self.mem.write_port(granularity=8)
        read = self.mem.read_port(transparent_for=(write,))

        bus = self.bus
        request = bus.stb & bus.cyc

        # Both ports follow the bus address. Read data comes from the port's
        # registered output, so it is valid the cycle after the request,
        # together with ack.
        m.d.comb += [
            read.addr.eq(bus.adr),
            read.en.eq(request & ~bus.we),
            bus.dat_r.eq(read.data),
            write.addr.eq(bus.adr),
            write.data.eq(bus.dat_w),
        ]

        # Byte-lane write enables mirror sel, but only during a write request.
        with m.If(request & bus.we):
            m.d.comb += write.en.eq(bus.sel)

        # Raise ack the cycle after a request; the ~ack term forces it low
        # again on the following cycle, so ack is never high twice in a row.
        ack_next = request & ~bus.ack
        if self.sim:
            ack_next = ack_next & ~self.ctrl.force_ws
        m.d.sync += bus.ack.eq(ack_next)

        return m

class DUT(Elaboratable):
    """Simulation top level: Sentinel CPU -> Wishbone decoder -> RAM.

    The RAM (``WBMemory``) is mapped at address 0 of the decoder. The
    window is registered at construction time. This keeps
    ``self.decoder.bus.memory_map`` complete before anything inspects it,
    and it stops ``elaborate`` from re-adding the window if the design is
    elaborated more than once.
    """

    def __init__(self):
        self.cpu = Sentinel()
        self.decoder = wishbone.Decoder(addr_width = 30, data_width = 32, granularity = 8)
        # Alternative RAM implementation from amaranth-soc:
        #self.mem = WishboneSRAM(size = 0x800, data_width = 32, granularity = 8)
        self.mem = WBMemory()
        # Address map is a construction-time property; populate it here.
        self.decoder.add(self.mem.wb_bus, addr = 0x0)

    def elaborate(self, platform):
        m = Module()

        # Named submodules give readable scopes in the VCD trace.
        m.submodules.cpu = self.cpu
        m.submodules.mem = self.mem
        m.submodules.decoder = self.decoder

        connect(m, self.cpu.bus, self.decoder.bus)

        return m

from amaranth.sim import Simulator, SimulatorContext

def _load_elf_words(filename):
    """Return the loadable contents of an ELF image as 32-bit RAM words.

    Every ``PT_LOAD`` segment with file-backed data becomes a run of
    ``(word_index, word)`` pairs. ``word_index`` is the physical byte
    address divided by 4, because the RAM is word-organised, and ``word``
    is a little-endian 32-bit value. A trailing partial word is zero-padded.

    Raises:
        ValueError: if a segment's load address is not 4-byte aligned.
    """
    import struct
    from elftools.elf.elffile import ELFFile

    words = []
    # ELFFile reads from the stream on demand (e.g. segment.data()), so the
    # file must stay open for the whole walk.
    with open(filename, 'rb') as f:
        elf = ELFFile(f)
        for segment in sorted(elf.iter_segments(), key=lambda seg: seg.header.p_paddr):
            hdr = segment.header
            if hdr.p_type != 'PT_LOAD' or hdr.p_filesz == 0:
                continue

            # Segments may have padding before the first section; strip it.
            # With no sections mapped into the segment, load it whole.
            first_offset = min(
                (s.header.sh_offset for s in elf.iter_sections()
                 if segment.section_in_segment(s)),
                default=hdr.p_offset,
            )
            skip = first_offset - hdr.p_offset

            addr = hdr.p_paddr + skip
            data = segment.data()[skip:]
            if not data:
                continue

            if addr & 3:
                raise ValueError(f'{filename}: segment address {addr:#x} is not 32b aligned')
            # Zero-pad up to a whole number of 32-bit words.
            data += b'\0' * (-len(data) % 4)

            # addr is a byte address; the memory is indexed by word.
            base = addr // 4
            words.extend((base + i, w)
                         for i, (w,) in enumerate(struct.iter_unpack('<I', data)))
    return words


def main():
    """Simulate the CPU + RAM DUT running ``firmware/foo.elf``.

    The firmware is preloaded into RAM at its physical load addresses. The
    simulation then runs for 2 x 1000 clock cycles and prints RAM word
    0x100 (byte address 0x400) at the start, midway, and at the end.
    Waveforms are written to ``foo.vcd``.
    """
    dut = DUT()
    ram = dut.mem._mem

    # Parse the image before simulating so a missing/bad ELF fails early.
    image = _load_elf_words('firmware/foo.elf')

    async def testbench(ctx: SimulatorContext):
        for index, word in image:
            ctx.set(ram.data[index], word)

        # Hand-assembled alternative program: loops incrementing the word at
        # byte address 0x20 (lw a5,32(x0); addi a5,a5,1; sw a5,32(x0); j).
        #ctx.set(dut.mem._mem.data[0], 0x02002783)
        #ctx.set(dut.mem._mem.data[1], 0x00178793)
        #ctx.set(dut.mem._mem.data[2], 0x02f02023)
        #ctx.set(dut.mem._mem.data[3], 0xff5ff06f)

        print(ctx.get(ram.data[0x100]))

        for _ in range(1000):
            await ctx.tick()

        print(ctx.get(ram.data[0x100]))

        for _ in range(1000):
            await ctx.tick()

        print(ctx.get(ram.data[0x100]))

    sim = Simulator(dut)
    sim.add_clock(1e-6)
    sim.add_testbench(testbench)
    with sim.write_vcd('foo.vcd'):
        sim.run()