If the total packet size is larger than 1024, then it is splitted into subpackets
Signed-off-by: Yohann D'ANELLO <ynerant@crans.org>
This commit is contained in:
parent
bda4860aca
commit
b96ff488e7
|
@ -75,6 +75,9 @@ class Squirrel(Hazelnut):
|
|||
"""
|
||||
Send a formatted packet to a client.
|
||||
"""
|
||||
if len(pkt) > 1024:
|
||||
# The packet is too large to be sent by the protocol. We split the packet in subpackets.
|
||||
return sum(self.send_packet(client, subpkt) for subpkt in pkt.split(1024))
|
||||
return self.send_raw_data(client, pkt.marshal())
|
||||
|
||||
def send_raw_data(self, client: Hazelnut, data: bytes) -> int:
|
||||
|
|
|
@ -42,8 +42,7 @@ class TLV:
|
|||
It is ensured that the data is valid.
|
||||
"""
|
||||
|
||||
@property
|
||||
def tlv_length(self) -> int:
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Returns the total length (in bytes) of the TLV, including the type and the length.
|
||||
Except for Pad1, this is 2 plus the length of the body of the TLV.
|
||||
|
@ -77,8 +76,7 @@ class Pad1TLV(TLV):
|
|||
# TODO Add some easter eggs
|
||||
squirrel.add_system_message("For each byte in the packet that I received, you will die today. And eat cookies.")
|
||||
|
||||
@property
|
||||
def tlv_length(self) -> int:
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
A Pad1 has always a length of 1.
|
||||
"""
|
||||
|
@ -104,7 +102,7 @@ class PadNTLV(TLV):
|
|||
"""
|
||||
self.type = raw_data[0]
|
||||
self.length = raw_data[1]
|
||||
self.mbz = raw_data[2:self.tlv_length]
|
||||
self.mbz = raw_data[2:len(self)]
|
||||
|
||||
def marshal(self) -> bytes:
|
||||
"""
|
||||
|
@ -190,7 +188,7 @@ class DataTLV(TLV):
|
|||
self.length = raw_data[1]
|
||||
self.sender_id = int.from_bytes(raw_data[2:10], sys.byteorder)
|
||||
self.nonce = socket.ntohl(int.from_bytes(raw_data[10:14], sys.byteorder))
|
||||
self.data = raw_data[14:self.tlv_length]
|
||||
self.data = raw_data[14:len(self)]
|
||||
|
||||
def marshal(self) -> bytes:
|
||||
return self.type.to_bytes(1, sys.byteorder) + \
|
||||
|
@ -291,7 +289,7 @@ class WarningTLV(TLV):
|
|||
|
||||
class Packet:
|
||||
"""
|
||||
A Packet is a wrapper around the
|
||||
A Packet is a wrapper around the raw data that it sent and received to other clients.
|
||||
"""
|
||||
magic: int
|
||||
version: int
|
||||
|
@ -331,7 +329,7 @@ class Packet:
|
|||
tlv = TLV.tlv_classes()[tlv_type]()
|
||||
tlv.unmarshal(data[4:4 + pkt.body_length])
|
||||
pkt.body.append(tlv)
|
||||
read_bytes += tlv.tlv_length
|
||||
read_bytes += len(tlv)
|
||||
|
||||
pkt.validate_data()
|
||||
|
||||
|
@ -347,6 +345,36 @@ class Packet:
|
|||
data += b"".join(tlv.marshal() for tlv in self.body)
|
||||
return data
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Calculates the length, in bytes, of the packet.
|
||||
"""
|
||||
return 4 + sum(len(tlv) for tlv in self.body)
|
||||
|
||||
def split(self, pkt_size: int) -> List["Packet"]:
|
||||
"""
|
||||
If the packet is too large, ie. larger that pkt_size (with pkt_size = 1024),
|
||||
then we split the packet in sub-packets.
|
||||
Since 1024 - 4 >> 256 + 2, that ensures that we can have at least one TLV per packet,
|
||||
then we don't need to split TLVs in smaller TLVs.
|
||||
"""
|
||||
packets = []
|
||||
|
||||
current_size = 4 # Packet header length
|
||||
body = []
|
||||
for tlv in self.body:
|
||||
if current_size + len(tlv) > pkt_size:
|
||||
packets.append(Packet.construct(*body))
|
||||
body.clear()
|
||||
current_size = 4
|
||||
body.append(tlv)
|
||||
current_size += len(tlv)
|
||||
|
||||
if body:
|
||||
packets.append(Packet.construct(*body))
|
||||
|
||||
return packets
|
||||
|
||||
@staticmethod
|
||||
def construct(*tlvs: TLV) -> "Packet":
|
||||
"""
|
||||
|
@ -356,5 +384,5 @@ class Packet:
|
|||
pkt.magic = 95
|
||||
pkt.version = 0
|
||||
pkt.body = tlvs
|
||||
pkt.body_length = sum(tlv.tlv_length for tlv in tlvs)
|
||||
pkt.body_length = sum(len(tlv) for tlv in tlvs)
|
||||
return pkt
|
||||
|
|
Loading…
Reference in New Issue