If the total packet size is larger than 1024, then it is splitted into subpackets

Signed-off-by: Yohann D'ANELLO <ynerant@crans.org>
This commit is contained in:
Yohann D'ANELLO 2020-12-21 16:04:13 +01:00
parent bda4860aca
commit b96ff488e7
Signed by: ynerant
GPG Key ID: 3A75C55819C8CF85
2 changed files with 40 additions and 9 deletions

View File

@ -75,6 +75,9 @@ class Squirrel(Hazelnut):
""" """
Send a formatted packet to a client. Send a formatted packet to a client.
""" """
if len(pkt) > 1024:
# The packet is too large to be sent by the protocol. We split the packet in subpackets.
return sum(self.send_packet(client, subpkt) for subpkt in pkt.split(1024))
return self.send_raw_data(client, pkt.marshal()) return self.send_raw_data(client, pkt.marshal())
def send_raw_data(self, client: Hazelnut, data: bytes) -> int: def send_raw_data(self, client: Hazelnut, data: bytes) -> int:

View File

@ -42,8 +42,7 @@ class TLV:
It is ensured that the data is valid. It is ensured that the data is valid.
""" """
@property def __len__(self) -> int:
def tlv_length(self) -> int:
""" """
Returns the total length (in bytes) of the TLV, including the type and the length. Returns the total length (in bytes) of the TLV, including the type and the length.
Except for Pad1, this is 2 plus the length of the body of the TLV. Except for Pad1, this is 2 plus the length of the body of the TLV.
@ -77,8 +76,7 @@ class Pad1TLV(TLV):
# TODO Add some easter eggs # TODO Add some easter eggs
squirrel.add_system_message("For each byte in the packet that I received, you will die today. And eat cookies.") squirrel.add_system_message("For each byte in the packet that I received, you will die today. And eat cookies.")
@property def __len__(self) -> int:
def tlv_length(self) -> int:
""" """
A Pad1 has always a length of 1. A Pad1 has always a length of 1.
""" """
@ -104,7 +102,7 @@ class PadNTLV(TLV):
""" """
self.type = raw_data[0] self.type = raw_data[0]
self.length = raw_data[1] self.length = raw_data[1]
self.mbz = raw_data[2:self.tlv_length] self.mbz = raw_data[2:len(self)]
def marshal(self) -> bytes: def marshal(self) -> bytes:
""" """
@ -190,7 +188,7 @@ class DataTLV(TLV):
self.length = raw_data[1] self.length = raw_data[1]
self.sender_id = int.from_bytes(raw_data[2:10], sys.byteorder) self.sender_id = int.from_bytes(raw_data[2:10], sys.byteorder)
self.nonce = socket.ntohl(int.from_bytes(raw_data[10:14], sys.byteorder)) self.nonce = socket.ntohl(int.from_bytes(raw_data[10:14], sys.byteorder))
self.data = raw_data[14:self.tlv_length] self.data = raw_data[14:len(self)]
def marshal(self) -> bytes: def marshal(self) -> bytes:
return self.type.to_bytes(1, sys.byteorder) + \ return self.type.to_bytes(1, sys.byteorder) + \
@ -291,7 +289,7 @@ class WarningTLV(TLV):
class Packet: class Packet:
""" """
A Packet is a wrapper around the A Packet is a wrapper around the raw data that it sent and received to other clients.
""" """
magic: int magic: int
version: int version: int
@ -331,7 +329,7 @@ class Packet:
tlv = TLV.tlv_classes()[tlv_type]() tlv = TLV.tlv_classes()[tlv_type]()
tlv.unmarshal(data[4:4 + pkt.body_length]) tlv.unmarshal(data[4:4 + pkt.body_length])
pkt.body.append(tlv) pkt.body.append(tlv)
read_bytes += tlv.tlv_length read_bytes += len(tlv)
pkt.validate_data() pkt.validate_data()
@ -347,6 +345,36 @@ class Packet:
data += b"".join(tlv.marshal() for tlv in self.body) data += b"".join(tlv.marshal() for tlv in self.body)
return data return data
def __len__(self) -> int:
"""
Calculates the length, in bytes, of the packet.
"""
return 4 + sum(len(tlv) for tlv in self.body)
def split(self, pkt_size: int) -> List["Packet"]:
"""
If the packet is too large, ie. larger that pkt_size (with pkt_size = 1024),
then we split the packet in sub-packets.
Since 1024 - 4 >> 256 + 2, that ensures that we can have at least one TLV per packet,
then we don't need to split TLVs in smaller TLVs.
"""
packets = []
current_size = 4 # Packet header length
body = []
for tlv in self.body:
if current_size + len(tlv) > pkt_size:
packets.append(Packet.construct(*body))
body.clear()
current_size = 4
body.append(tlv)
current_size += len(tlv)
if body:
packets.append(Packet.construct(*body))
return packets
@staticmethod @staticmethod
def construct(*tlvs: TLV) -> "Packet": def construct(*tlvs: TLV) -> "Packet":
""" """
@ -356,5 +384,5 @@ class Packet:
pkt.magic = 95 pkt.magic = 95
pkt.version = 0 pkt.version = 0
pkt.body = tlvs pkt.body = tlvs
pkt.body_length = sum(tlv.tlv_length for tlv in tlvs) pkt.body_length = sum(len(tlv) for tlv in tlvs)
return pkt return pkt