More abstraction on packet building
Signed-off-by: Yohann D'ANELLO <ynerant@crans.org>
This commit is contained in:
		@@ -167,16 +167,7 @@ class Squinnondation:
 | 
				
			|||||||
                squirrel.add_message(msg)
 | 
					                squirrel.add_message(msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                for hazelnut in list(squirrel.hazelnuts.values()):
 | 
					                for hazelnut in list(squirrel.hazelnuts.values()):
 | 
				
			||||||
                    pkt = Packet()
 | 
					                    pkt = Packet.construct(DataTLV.construct(msg))
 | 
				
			||||||
                    pkt.magic = 95
 | 
					 | 
				
			||||||
                    pkt.version = 0
 | 
					 | 
				
			||||||
                    tlv = DataTLV()
 | 
					 | 
				
			||||||
                    tlv.data = msg.encode("UTF-8")
 | 
					 | 
				
			||||||
                    tlv.sender_id = 42
 | 
					 | 
				
			||||||
                    tlv.nonce = 18
 | 
					 | 
				
			||||||
                    tlv.length = len(tlv.data) + 1 + 1 + 8 + 4
 | 
					 | 
				
			||||||
                    pkt.body = [tlv]
 | 
					 | 
				
			||||||
                    pkt.body_length = tlv.length + 2
 | 
					 | 
				
			||||||
                    squirrel.send_packet(hazelnut, pkt)
 | 
					                    squirrel.send_packet(hazelnut, pkt)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -186,6 +177,7 @@ class TLV:
 | 
				
			|||||||
    TODO: add subclasses for each type of TLV
 | 
					    TODO: add subclasses for each type of TLV
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    type: int
 | 
					    type: int
 | 
				
			||||||
 | 
					    length: int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def unmarshal(self, raw_data: bytes) -> None:
 | 
					    def unmarshal(self, raw_data: bytes) -> None:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
@@ -213,6 +205,14 @@ class TLV:
 | 
				
			|||||||
        It is ensured that the data is valid.
 | 
					        It is ensured that the data is valid.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def tlv_length(self) -> int:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Returns the total length (in bytes) of the TLV, including the type and the length.
 | 
				
			||||||
 | 
					        Except for Pad1, this is 2 plus the length of the body of the TLV.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        return 2 + self.length
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def tlv_classes():
 | 
					    def tlv_classes():
 | 
				
			||||||
        return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV]
 | 
					        return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV]
 | 
				
			||||||
@@ -241,6 +241,13 @@ class Pad1TLV(TLV):
 | 
				
			|||||||
        squirrel.add_system_message(f"For each byte in the packet that I received, you will die today. "
 | 
					        squirrel.add_system_message(f"For each byte in the packet that I received, you will die today. "
 | 
				
			||||||
                                    "And eat cookies.")
 | 
					                                    "And eat cookies.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def tlv_length(self) -> int:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        A Pad1 has always a length of 1.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        return 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class PadNTLV(TLV):
 | 
					class PadNTLV(TLV):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
@@ -261,7 +268,7 @@ class PadNTLV(TLV):
 | 
				
			|||||||
        """
 | 
					        """
 | 
				
			||||||
        self.type = raw_data[0]
 | 
					        self.type = raw_data[0]
 | 
				
			||||||
        self.length = raw_data[1]
 | 
					        self.length = raw_data[1]
 | 
				
			||||||
        self.mbz = raw_data[2:2 + self.length]
 | 
					        self.mbz = raw_data[2:self.tlv_length]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def marshal(self) -> bytes:
 | 
					    def marshal(self) -> bytes:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
@@ -345,7 +352,7 @@ class DataTLV(TLV):
 | 
				
			|||||||
        self.length = raw_data[1]
 | 
					        self.length = raw_data[1]
 | 
				
			||||||
        self.sender_id = int.from_bytes(raw_data[2:10], "big")
 | 
					        self.sender_id = int.from_bytes(raw_data[2:10], "big")
 | 
				
			||||||
        self.nonce = int.from_bytes(raw_data[10:14], "big")
 | 
					        self.nonce = int.from_bytes(raw_data[10:14], "big")
 | 
				
			||||||
        self.data = raw_data[14:2 + self.length]
 | 
					        self.data = raw_data[14:self.tlv_length]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def marshal(self) -> bytes:
 | 
					    def marshal(self) -> bytes:
 | 
				
			||||||
        return self.type.to_bytes(1, "big") + \
 | 
					        return self.type.to_bytes(1, "big") + \
 | 
				
			||||||
@@ -361,6 +368,16 @@ class DataTLV(TLV):
 | 
				
			|||||||
        """
 | 
					        """
 | 
				
			||||||
        squirrel.add_message(self.data.decode('UTF-8'))
 | 
					        squirrel.add_message(self.data.decode('UTF-8'))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def construct(message: str) -> "DataTLV":
 | 
				
			||||||
 | 
					        tlv = DataTLV()
 | 
				
			||||||
 | 
					        tlv.type = 4
 | 
				
			||||||
 | 
					        tlv.sender_id = 42  # FIXME Use the good sender id
 | 
				
			||||||
 | 
					        tlv.nonce = 42  # FIXME Use an incremental nonce
 | 
				
			||||||
 | 
					        tlv.data = message.encode("UTF-8")
 | 
				
			||||||
 | 
					        tlv.length = 12 + len(tlv.data)
 | 
				
			||||||
 | 
					        return tlv
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class AckTLV(TLV):
 | 
					class AckTLV(TLV):
 | 
				
			||||||
    type: int = 5
 | 
					    type: int = 5
 | 
				
			||||||
@@ -476,8 +493,7 @@ class Packet:
 | 
				
			|||||||
            tlv = TLV.tlv_classes()[tlv_type]()
 | 
					            tlv = TLV.tlv_classes()[tlv_type]()
 | 
				
			||||||
            tlv.unmarshal(data[4:4 + pkt.body_length])
 | 
					            tlv.unmarshal(data[4:4 + pkt.body_length])
 | 
				
			||||||
            pkt.body.append(tlv)
 | 
					            pkt.body.append(tlv)
 | 
				
			||||||
            # Pad1TLV has no length
 | 
					            read_bytes += tlv.tlv_length
 | 
				
			||||||
            read_bytes += 1 if tlv_type == 0 else tlv.length + 2
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        pkt.validate_data()
 | 
					        pkt.validate_data()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -493,6 +509,18 @@ class Packet:
 | 
				
			|||||||
        data += b"".join(tlv.marshal() for tlv in self.body)
 | 
					        data += b"".join(tlv.marshal() for tlv in self.body)
 | 
				
			||||||
        return data
 | 
					        return data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def construct(*tlvs: TLV) -> "Packet":
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Construct a new packet from the given TLVs and calculate the good lengths
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        pkt = Packet()
 | 
				
			||||||
 | 
					        pkt.magic = 95
 | 
				
			||||||
 | 
					        pkt.version = 0
 | 
				
			||||||
 | 
					        pkt.body = tlvs
 | 
				
			||||||
 | 
					        pkt.body_length = sum(tlv.tlv_length for tlv in tlvs)
 | 
				
			||||||
 | 
					        return pkt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Hazelnut:
 | 
					class Hazelnut:
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user