diff --git a/squinnondation/squinnondation.py b/squinnondation/squinnondation.py index d916c5f..23bb01c 100644 --- a/squinnondation/squinnondation.py +++ b/squinnondation/squinnondation.py @@ -167,16 +167,7 @@ class Squinnondation: squirrel.add_message(msg) for hazelnut in list(squirrel.hazelnuts.values()): - pkt = Packet() - pkt.magic = 95 - pkt.version = 0 - tlv = DataTLV() - tlv.data = msg.encode("UTF-8") - tlv.sender_id = 42 - tlv.nonce = 18 - tlv.length = len(tlv.data) + 1 + 1 + 8 + 4 - pkt.body = [tlv] - pkt.body_length = tlv.length + 2 + pkt = Packet.construct(DataTLV.construct(msg)) squirrel.send_packet(hazelnut, pkt) @@ -186,6 +177,7 @@ class TLV: TODO: add subclasses for each type of TLV """ type: int + length: int def unmarshal(self, raw_data: bytes) -> None: """ @@ -213,6 +205,14 @@ class TLV: It is ensured that the data is valid. """ + @property + def tlv_length(self) -> int: + """ + Returns the total length (in bytes) of the TLV, including the type and the length. + Except for Pad1, this is 2 plus the length of the body of the TLV. + """ + return 2 + self.length + @staticmethod def tlv_classes(): return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV] @@ -241,6 +241,13 @@ class Pad1TLV(TLV): squirrel.add_system_message(f"For each byte in the packet that I received, you will die today. " "And eat cookies.") + @property + def tlv_length(self) -> int: + """ + A Pad1 has always a length of 1. + """ + return 1 + class PadNTLV(TLV): """ @@ -261,7 +268,7 @@ class PadNTLV(TLV): """ self.type = raw_data[0] self.length = raw_data[1] - self.mbz = raw_data[2:2 + self.length] + self.mbz = raw_data[2:self.tlv_length] def marshal(self) -> bytes: """ @@ -345,7 +352,7 @@ class DataTLV(TLV): self.length = raw_data[1] self.sender_id = int.from_bytes(raw_data[2:10], "big") self.nonce = int.from_bytes(raw_data[10:14], "big") - self.data = raw_data[14:2 + self.length] + self.data = raw_data[14:self.tlv_length] def marshal(self) -> bytes: return self.type.to_bytes(1, "big") + \ @@ -361,6 +368,16 @@ class DataTLV(TLV): """ squirrel.add_message(self.data.decode('UTF-8')) + @staticmethod + def construct(message: str) -> "DataTLV": + tlv = DataTLV() + tlv.type = 4 + tlv.sender_id = 42 # FIXME Use the good sender id + tlv.nonce = 42 # FIXME Use an incremental nonce + tlv.data = message.encode("UTF-8") + tlv.length = 12 + len(tlv.data) + return tlv + class AckTLV(TLV): type: int = 5 @@ -476,8 +493,7 @@ class Packet: tlv = TLV.tlv_classes()[tlv_type]() tlv.unmarshal(data[4:4 + pkt.body_length]) pkt.body.append(tlv) - # Pad1TLV has no length - read_bytes += 1 if tlv_type == 0 else tlv.length + 2 + read_bytes += tlv.tlv_length pkt.validate_data() @@ -493,6 +509,18 @@ class Packet: data += b"".join(tlv.marshal() for tlv in self.body) return data + @staticmethod + def construct(*tlvs: TLV) -> "Packet": + """ + Construct a new packet from the given TLVs and calculate the good lengths + """ + pkt = Packet() + pkt.magic = 95 + pkt.version = 0 + pkt.body = tlvs + pkt.body_length = sum(tlv.tlv_length for tlv in tlvs) + return pkt + class Hazelnut: """