diff --git a/squinnondation/hazel.py b/squinnondation/hazel.py index e85c281..06d1024 100644 --- a/squinnondation/hazel.py +++ b/squinnondation/hazel.py @@ -75,6 +75,9 @@ class Squirrel(Hazelnut): """ Send a formatted packet to a client. """ + if len(pkt) > 1024: + # The packet is too large to be sent by the protocol. We split the packet in subpackets. + return sum(self.send_packet(client, subpkt) for subpkt in pkt.split(1024)) return self.send_raw_data(client, pkt.marshal()) def send_raw_data(self, client: Hazelnut, data: bytes) -> int: diff --git a/squinnondation/messages.py b/squinnondation/messages.py index 6c50f29..1fe9d39 100644 --- a/squinnondation/messages.py +++ b/squinnondation/messages.py @@ -42,8 +42,7 @@ class TLV: It is ensured that the data is valid. """ - @property - def tlv_length(self) -> int: + def __len__(self) -> int: """ Returns the total length (in bytes) of the TLV, including the type and the length. Except for Pad1, this is 2 plus the length of the body of the TLV. @@ -77,8 +76,7 @@ class Pad1TLV(TLV): # TODO Add some easter eggs squirrel.add_system_message("For each byte in the packet that I received, you will die today. And eat cookies.") - @property - def tlv_length(self) -> int: + def __len__(self) -> int: """ A Pad1 has always a length of 1. """ @@ -104,7 +102,7 @@ class PadNTLV(TLV): """ self.type = raw_data[0] self.length = raw_data[1] - self.mbz = raw_data[2:self.tlv_length] + self.mbz = raw_data[2:len(self)] def marshal(self) -> bytes: """ @@ -190,7 +188,7 @@ class DataTLV(TLV): self.length = raw_data[1] self.sender_id = int.from_bytes(raw_data[2:10], sys.byteorder) self.nonce = socket.ntohl(int.from_bytes(raw_data[10:14], sys.byteorder)) - self.data = raw_data[14:self.tlv_length] + self.data = raw_data[14:len(self)] def marshal(self) -> bytes: return self.type.to_bytes(1, sys.byteorder) + \ @@ -291,7 +289,7 @@ class WarningTLV(TLV): class Packet: """ - A Packet is a wrapper around the + A Packet is a wrapper around the raw data that it sent and received to other clients. """ magic: int version: int @@ -331,7 +329,7 @@ class Packet: tlv = TLV.tlv_classes()[tlv_type]() tlv.unmarshal(data[4:4 + pkt.body_length]) pkt.body.append(tlv) - read_bytes += tlv.tlv_length + read_bytes += len(tlv) pkt.validate_data() @@ -347,6 +345,36 @@ class Packet: data += b"".join(tlv.marshal() for tlv in self.body) return data + def __len__(self) -> int: + """ + Calculates the length, in bytes, of the packet. + """ + return 4 + sum(len(tlv) for tlv in self.body) + + def split(self, pkt_size: int) -> List["Packet"]: + """ + If the packet is too large, ie. larger that pkt_size (with pkt_size = 1024), + then we split the packet in sub-packets. + Since 1024 - 4 >> 256 + 2, that ensures that we can have at least one TLV per packet, + then we don't need to split TLVs in smaller TLVs. + """ + packets = [] + + current_size = 4 # Packet header length + body = [] + for tlv in self.body: + if current_size + len(tlv) > pkt_size: + packets.append(Packet.construct(*body)) + body.clear() + current_size = 4 + body.append(tlv) + current_size += len(tlv) + + if body: + packets.append(Packet.construct(*body)) + + return packets + @staticmethod def construct(*tlvs: TLV) -> "Packet": """ @@ -356,5 +384,5 @@ class Packet: pkt.magic = 95 pkt.version = 0 pkt.body = tlvs - pkt.body_length = sum(tlv.tlv_length for tlv in tlvs) + pkt.body_length = sum(len(tlv) for tlv in tlvs) return pkt