If the total packet size is larger than 1024, then it is splitted into subpackets

Signed-off-by: Yohann D'ANELLO <ynerant@crans.org>
This commit is contained in:
Yohann D'ANELLO 2020-12-21 16:04:13 +01:00
parent bda4860aca
commit b96ff488e7
Signed by: ynerant
GPG Key ID: 3A75C55819C8CF85
2 changed files with 40 additions and 9 deletions

View File

@ -75,6 +75,9 @@ class Squirrel(Hazelnut):
"""
Send a formatted packet to a client.
"""
if len(pkt) > 1024:
# The packet is too large to be sent by the protocol. We split the packet in subpackets.
return sum(self.send_packet(client, subpkt) for subpkt in pkt.split(1024))
return self.send_raw_data(client, pkt.marshal())
def send_raw_data(self, client: Hazelnut, data: bytes) -> int:

View File

@ -42,8 +42,7 @@ class TLV:
It is ensured that the data is valid.
"""
@property
def tlv_length(self) -> int:
def __len__(self) -> int:
"""
Returns the total length (in bytes) of the TLV, including the type and the length.
Except for Pad1, this is 2 plus the length of the body of the TLV.
@ -77,8 +76,7 @@ class Pad1TLV(TLV):
# TODO Add some easter eggs
squirrel.add_system_message("For each byte in the packet that I received, you will die today. And eat cookies.")
@property
def tlv_length(self) -> int:
def __len__(self) -> int:
"""
A Pad1 has always a length of 1.
"""
@ -104,7 +102,7 @@ class PadNTLV(TLV):
"""
self.type = raw_data[0]
self.length = raw_data[1]
self.mbz = raw_data[2:self.tlv_length]
self.mbz = raw_data[2:len(self)]
def marshal(self) -> bytes:
"""
@ -190,7 +188,7 @@ class DataTLV(TLV):
self.length = raw_data[1]
self.sender_id = int.from_bytes(raw_data[2:10], sys.byteorder)
self.nonce = socket.ntohl(int.from_bytes(raw_data[10:14], sys.byteorder))
self.data = raw_data[14:self.tlv_length]
self.data = raw_data[14:len(self)]
def marshal(self) -> bytes:
return self.type.to_bytes(1, sys.byteorder) + \
@ -291,7 +289,7 @@ class WarningTLV(TLV):
class Packet:
"""
A Packet is a wrapper around the
A Packet is a wrapper around the raw data that it sent and received to other clients.
"""
magic: int
version: int
@ -331,7 +329,7 @@ class Packet:
tlv = TLV.tlv_classes()[tlv_type]()
tlv.unmarshal(data[4:4 + pkt.body_length])
pkt.body.append(tlv)
read_bytes += tlv.tlv_length
read_bytes += len(tlv)
pkt.validate_data()
@ -347,6 +345,36 @@ class Packet:
data += b"".join(tlv.marshal() for tlv in self.body)
return data
def __len__(self) -> int:
"""
Calculates the length, in bytes, of the packet.
"""
return 4 + sum(len(tlv) for tlv in self.body)
def split(self, pkt_size: int) -> List["Packet"]:
"""
If the packet is too large, ie. larger that pkt_size (with pkt_size = 1024),
then we split the packet in sub-packets.
Since 1024 - 4 >> 256 + 2, that ensures that we can have at least one TLV per packet,
then we don't need to split TLVs in smaller TLVs.
"""
packets = []
current_size = 4 # Packet header length
body = []
for tlv in self.body:
if current_size + len(tlv) > pkt_size:
packets.append(Packet.construct(*body))
body.clear()
current_size = 4
body.append(tlv)
current_size += len(tlv)
if body:
packets.append(Packet.construct(*body))
return packets
@staticmethod
def construct(*tlvs: TLV) -> "Packet":
"""
@ -356,5 +384,5 @@ class Packet:
pkt.magic = 95
pkt.version = 0
pkt.body = tlvs
pkt.body_length = sum(tlv.tlv_length for tlv in tlvs)
pkt.body_length = sum(len(tlv) for tlv in tlvs)
return pkt