diff --git a/squinnondation/squinnondation.py b/squinnondation/squinnondation.py index 4ea5c5f..740c94b 100644 --- a/squinnondation/squinnondation.py +++ b/squinnondation/squinnondation.py @@ -6,7 +6,7 @@ from argparse import ArgumentParser from enum import Enum from ipaddress import IPv6Address from threading import Thread -from typing import Any, Optional, Tuple +from typing import Any, List, Optional, Tuple class Squinnondation: @@ -50,13 +50,14 @@ class Squinnondation: pkt = Packet() pkt.magic = 95 pkt.version = 0 - pkt.body = DataTLV() + tlv = DataTLV() msg = f"Hello world, my name is {squirrel.nickname}!" - pkt.body.data = msg.encode("UTF-8") - pkt.body.sender_id = 42 - pkt.body.nonce = 18 - pkt.body.length = len(msg) + 1 + 1 + 8 + 4 - pkt.body_length = pkt.body.length + 2 + tlv.data = msg.encode("UTF-8") + tlv.sender_id = 42 + tlv.nonce = 18 + tlv.length = len(msg) + 1 + 1 + 8 + 4 + pkt.body = [tlv] + pkt.body_length = tlv.length + 2 squirrel.send_packet(hazelnut, pkt) Worm(squirrel).start() @@ -275,7 +276,7 @@ class Packet: magic: int version: int body_length: int - body: TLV + body: List[TLV] def validate_data(self) -> bool: """ @@ -289,7 +290,7 @@ class Packet: if not (0 <= self.body_length <= 120): raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1020," "found: {:d}".format(self.body_length)) - return self.body.validate_data() + return all(tlv.validate_data() for tlv in self.body) @staticmethod def unmarshal(data: bytes) -> "Packet": @@ -301,11 +302,17 @@ class Packet: pkt.magic = data[0] pkt.version = data[1] pkt.body_length = int.from_bytes(data[2:4], byteorder="big") - tlv_type = data[4] - if not (0 <= tlv_type < len(TLV.tlv_classes())): - raise ValueError(f"TLV type is not supported: {tlv_type}") - pkt.body = TLV.tlv_classes()[tlv_type]() - pkt.body.unmarshal(data[4:4 + pkt.body_length]) + pkt.body = [] + read_bytes = 0 + while read_bytes <= min(len(data) - 4, pkt.body_length): + tlv_type = data[4] + if not (0 <= tlv_type < len(TLV.tlv_classes())): + raise ValueError(f"TLV type is not supported: {tlv_type}") + tlv = TLV.tlv_classes()[tlv_type]() + tlv.unmarshal(data[4:4 + pkt.body_length]) + pkt.body.append(tlv) + # Pad1TLV has no length + read_bytes += 1 if tlv_type == 0 else tlv.length + 2 pkt.validate_data() @@ -318,7 +325,7 @@ class Packet: data = self.magic.to_bytes(1, "big") data += self.version.to_bytes(1, "big") data += self.body_length.to_bytes(2, "big") - data += self.body.marshal() + data += b"".join(tlv.marshal() for tlv in self.body) return data @@ -410,15 +417,17 @@ class Worm(Thread): except ValueError as error: print("An error occured while receiving a packet: ", error) else: - print(pkt.body.data.decode('UTF-8')) + print(pkt.body[0].data.decode('UTF-8')) pkt = Packet() pkt.magic = 95 pkt.version = 0 - pkt.body = DataTLV() + pkt.body = [] + tlv = DataTLV() msg = f"Hello my dear hazelnut, I am {self.squirrel.nickname}!" - pkt.body.data = msg.encode("UTF-8") - pkt.body.sender_id = 42 - pkt.body.nonce = 18 - pkt.body.length = len(msg) + 1 + 1 + 8 + 4 - pkt.body_length = pkt.body.length + 2 + tlv.data = msg.encode("UTF-8") + tlv.sender_id = 42 + tlv.nonce = 18 + tlv.length = len(msg) + 1 + 1 + 8 + 4 + pkt.body.append(tlv) + pkt.body_length = tlv.length + 2 self.squirrel.send_packet(hazelnut, pkt)