class MultilaneStreamInterface(PureInterface):
    """Stream interface whose ``data`` (and optional ``first``/``last``) members are per-lane arrays.

    Returned by ``StreamSignature.create`` when the signature has ``lanes`` set.
    Besides the signals it provides simulation helpers that send/receive a whole
    packet, spreading consecutive values across the lanes of each transaction.
    """

    def __init__(self, signature, *, path, src_loc_at = 0):
        super().__init__(signature, path = path, src_loc_at = src_loc_at + 1)

        # Without backpressure `ready` is not a driven signal: it is tied to a
        # constant 1 so producers always see the sink as ready.
        if not signature.backpressure:
            self.ready = C(1)

    async def recv_packet(self, sim):
        """Receive one packet and return its values as a list.

        Lanes are read in index order within each transaction; reception stops at
        the first lane whose ``last`` flag is set (that lane's value is included).
        Requires the signature to have ``last`` enabled.
        """
        assert 'last' in self.signature.members

        # `ready` is the constant C(1) when there is no backpressure (see
        # __init__); only drive it when it is an actual signal.
        backpressure = self.signature.backpressure
        if backpressure:
            await self.ready.set(1)

        values = []
        done = False

        while not done:
            await sim.tick().until(self.valid)

            for i in range(self.signature.lanes):
                values.append(await self.data[i].get())

                if await self.last[i].get():
                    done = True
                    break

            await sim.tick()

        if backpressure:
            await self.ready.set(0)

        return values

    async def send_packet(self, sim, values):
        """Send ``values`` as one packet, up to ``signature.lanes`` values per transaction.

        If present, ``first`` is asserted on lane 0 of the first transaction and
        ``last`` on the lane carrying the final value. In a short final transaction
        the unused lanes get ``first``/``last`` cleared, so flags left over from a
        previous packet are not presented; their ``data`` is left untouched.
        An empty ``values`` sends nothing.
        """
        lanes = self.signature.lanes
        transactions = list(itertools.batched(values, lanes))
        if not transactions:
            # Nothing to transfer; avoid a zero-length pulse on `valid`.
            return

        has_first = 'first' in self.signature.members
        has_last = 'last' in self.signature.members
        final_num = len(transactions) - 1

        await self.valid.set(1)

        for transaction_num, data in enumerate(transactions):
            for i, value in enumerate(data):
                await self.data[i].set(value)

                if has_first:
                    await self.first[i].set(transaction_num == 0 and i == 0)
                if has_last:
                    await self.last[i].set(transaction_num == final_num and i == len(data) - 1)

            # Lanes not carrying data in a short final transaction: clear the
            # framing flags so stale values from an earlier packet don't leak.
            for i in range(len(data), lanes):
                if has_first:
                    await self.first[i].set(False)
                if has_last:
                    await self.last[i].set(False)

            # NOTE(review): an extra tick follows the handshake edge; presumably
            # this mirrors the single-lane StreamInterface — confirm it cannot cause
            # a duplicate transfer against a DUT that holds `ready` high.
            await sim.tick().until(self.ready)
            await sim.tick()

        await self.valid.set(0)


class StreamSignature(Signature):
    """Signature of a valid/ready stream with optional framing flags and lanes.

    Members, in insertion order: ``data`` (Out, ``data_shape``), ``valid`` (Out, 1),
    ``ready`` (In, 1), then ``first`` and ``last`` (Out, 1) when enabled. When
    ``lanes`` is given, ``data``/``first``/``last`` become arrays of ``lanes``
    elements.

    ``backpressure`` and ``lanes`` are stored as attributes; ``create`` uses
    ``lanes`` to pick a single-lane or multi-lane interface.
    """

    def __init__(self, data_shape, *, backpressure = True, first = False, last = False, lanes = None):
        members = {
            'data': Out(data_shape),
            'valid': Out(1),
            'ready': In(1),
        }
        optional_flags = (('first', first), ('last', last))
        members.update((flag, Out(1)) for flag, enabled in optional_flags if enabled)

        if lanes is not None:
            # Replicate the payload-carrying members per lane, keeping the order.
            per_lane = ('data', 'first', 'last')
            members = {
                name: member.array(lanes) if name in per_lane else member
                for name, member in members.items()
            }

        super().__init__(members)

        self.backpressure = backpressure
        self.lanes = lanes

    def create(self, *, path = (), src_loc_at = 0):
        """Build the interface: multi-lane when ``lanes`` is set, single-lane otherwise."""
        interface_class = MultilaneStreamInterface if self.lanes is not None else StreamInterface
        return interface_class(self, path = path, src_loc_at = src_loc_at + 1)