If the total packet size is larger than 1024, then it is splitted into subpackets
Signed-off-by: Yohann D'ANELLO <ynerant@crans.org>
This commit is contained in:
parent
bda4860aca
commit
b96ff488e7
|
@ -75,6 +75,9 @@ class Squirrel(Hazelnut):
|
||||||
"""
|
"""
|
||||||
Send a formatted packet to a client.
|
Send a formatted packet to a client.
|
||||||
"""
|
"""
|
||||||
|
if len(pkt) > 1024:
|
||||||
|
# The packet is too large to be sent by the protocol. We split the packet in subpackets.
|
||||||
|
return sum(self.send_packet(client, subpkt) for subpkt in pkt.split(1024))
|
||||||
return self.send_raw_data(client, pkt.marshal())
|
return self.send_raw_data(client, pkt.marshal())
|
||||||
|
|
||||||
def send_raw_data(self, client: Hazelnut, data: bytes) -> int:
|
def send_raw_data(self, client: Hazelnut, data: bytes) -> int:
|
||||||
|
|
|
@ -42,8 +42,7 @@ class TLV:
|
||||||
It is ensured that the data is valid.
|
It is ensured that the data is valid.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@property
|
def __len__(self) -> int:
|
||||||
def tlv_length(self) -> int:
|
|
||||||
"""
|
"""
|
||||||
Returns the total length (in bytes) of the TLV, including the type and the length.
|
Returns the total length (in bytes) of the TLV, including the type and the length.
|
||||||
Except for Pad1, this is 2 plus the length of the body of the TLV.
|
Except for Pad1, this is 2 plus the length of the body of the TLV.
|
||||||
|
@ -77,8 +76,7 @@ class Pad1TLV(TLV):
|
||||||
# TODO Add some easter eggs
|
# TODO Add some easter eggs
|
||||||
squirrel.add_system_message("For each byte in the packet that I received, you will die today. And eat cookies.")
|
squirrel.add_system_message("For each byte in the packet that I received, you will die today. And eat cookies.")
|
||||||
|
|
||||||
@property
|
def __len__(self) -> int:
|
||||||
def tlv_length(self) -> int:
|
|
||||||
"""
|
"""
|
||||||
A Pad1 has always a length of 1.
|
A Pad1 has always a length of 1.
|
||||||
"""
|
"""
|
||||||
|
@ -104,7 +102,7 @@ class PadNTLV(TLV):
|
||||||
"""
|
"""
|
||||||
self.type = raw_data[0]
|
self.type = raw_data[0]
|
||||||
self.length = raw_data[1]
|
self.length = raw_data[1]
|
||||||
self.mbz = raw_data[2:self.tlv_length]
|
self.mbz = raw_data[2:len(self)]
|
||||||
|
|
||||||
def marshal(self) -> bytes:
|
def marshal(self) -> bytes:
|
||||||
"""
|
"""
|
||||||
|
@ -190,7 +188,7 @@ class DataTLV(TLV):
|
||||||
self.length = raw_data[1]
|
self.length = raw_data[1]
|
||||||
self.sender_id = int.from_bytes(raw_data[2:10], sys.byteorder)
|
self.sender_id = int.from_bytes(raw_data[2:10], sys.byteorder)
|
||||||
self.nonce = socket.ntohl(int.from_bytes(raw_data[10:14], sys.byteorder))
|
self.nonce = socket.ntohl(int.from_bytes(raw_data[10:14], sys.byteorder))
|
||||||
self.data = raw_data[14:self.tlv_length]
|
self.data = raw_data[14:len(self)]
|
||||||
|
|
||||||
def marshal(self) -> bytes:
|
def marshal(self) -> bytes:
|
||||||
return self.type.to_bytes(1, sys.byteorder) + \
|
return self.type.to_bytes(1, sys.byteorder) + \
|
||||||
|
@ -291,7 +289,7 @@ class WarningTLV(TLV):
|
||||||
|
|
||||||
class Packet:
|
class Packet:
|
||||||
"""
|
"""
|
||||||
A Packet is a wrapper around the
|
A Packet is a wrapper around the raw data that it sent and received to other clients.
|
||||||
"""
|
"""
|
||||||
magic: int
|
magic: int
|
||||||
version: int
|
version: int
|
||||||
|
@ -331,7 +329,7 @@ class Packet:
|
||||||
tlv = TLV.tlv_classes()[tlv_type]()
|
tlv = TLV.tlv_classes()[tlv_type]()
|
||||||
tlv.unmarshal(data[4:4 + pkt.body_length])
|
tlv.unmarshal(data[4:4 + pkt.body_length])
|
||||||
pkt.body.append(tlv)
|
pkt.body.append(tlv)
|
||||||
read_bytes += tlv.tlv_length
|
read_bytes += len(tlv)
|
||||||
|
|
||||||
pkt.validate_data()
|
pkt.validate_data()
|
||||||
|
|
||||||
|
@ -347,6 +345,36 @@ class Packet:
|
||||||
data += b"".join(tlv.marshal() for tlv in self.body)
|
data += b"".join(tlv.marshal() for tlv in self.body)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""
|
||||||
|
Calculates the length, in bytes, of the packet.
|
||||||
|
"""
|
||||||
|
return 4 + sum(len(tlv) for tlv in self.body)
|
||||||
|
|
||||||
|
def split(self, pkt_size: int) -> List["Packet"]:
|
||||||
|
"""
|
||||||
|
If the packet is too large, ie. larger that pkt_size (with pkt_size = 1024),
|
||||||
|
then we split the packet in sub-packets.
|
||||||
|
Since 1024 - 4 >> 256 + 2, that ensures that we can have at least one TLV per packet,
|
||||||
|
then we don't need to split TLVs in smaller TLVs.
|
||||||
|
"""
|
||||||
|
packets = []
|
||||||
|
|
||||||
|
current_size = 4 # Packet header length
|
||||||
|
body = []
|
||||||
|
for tlv in self.body:
|
||||||
|
if current_size + len(tlv) > pkt_size:
|
||||||
|
packets.append(Packet.construct(*body))
|
||||||
|
body.clear()
|
||||||
|
current_size = 4
|
||||||
|
body.append(tlv)
|
||||||
|
current_size += len(tlv)
|
||||||
|
|
||||||
|
if body:
|
||||||
|
packets.append(Packet.construct(*body))
|
||||||
|
|
||||||
|
return packets
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def construct(*tlvs: TLV) -> "Packet":
|
def construct(*tlvs: TLV) -> "Packet":
|
||||||
"""
|
"""
|
||||||
|
@ -356,5 +384,5 @@ class Packet:
|
||||||
pkt.magic = 95
|
pkt.magic = 95
|
||||||
pkt.version = 0
|
pkt.version = 0
|
||||||
pkt.body = tlvs
|
pkt.body = tlvs
|
||||||
pkt.body_length = sum(tlv.tlv_length for tlv in tlvs)
|
pkt.body_length = sum(len(tlv) for tlv in tlvs)
|
||||||
return pkt
|
return pkt
|
||||||
|
|
Loading…
Reference in New Issue