diff --git a/squinnondation/squinnondation.py b/squinnondation/squinnondation.py index 7946ef8..3ef127c 100644 --- a/squinnondation/squinnondation.py +++ b/squinnondation/squinnondation.py @@ -3,7 +3,8 @@ import socket from argparse import ArgumentParser -from typing import Any, Tuple +from enum import Enum +from typing import Any, Optional, Tuple class Squinnondation: @@ -47,15 +48,18 @@ class Squinnondation: pkt = Packet() pkt.magic = 95 pkt.version = 0 - pkt.body = TLV() + pkt.body = DataTLV() msg = f"Hello world, my name is {squirrel.nickname}!" - pkt.body.raw_data = msg.encode("UTF-8") - pkt.body_length = len(pkt.body.raw_data) + pkt.body.data = msg.encode("UTF-8") + pkt.body.sender_id = 42 + pkt.body.nonce = 18 + pkt.body.length = len(msg) + 1 + 1 + 8 + 4 + pkt.body_length = pkt.body.length + 2 squirrel.send_packet(hazelnut, pkt) while True: pkt, addr = squirrel.receive_packet() - print(f"received message: {pkt.body.raw_data.decode('UTF-8')}") + print(f"received message: {pkt.body.data.decode('UTF-8')}") class TLV: @@ -64,7 +68,18 @@ class TLV: TODO: add subclasses for each type of TLV """ type: int - raw_data: bytes + + def unmarshal(self, raw_data: bytes) -> None: + """ + Parse data and construct TLV. + """ + raise NotImplementedError + + def marshal(self) -> bytes: + """ + Translate the TLV into a byte array. + """ + raise NotImplementedError def validate_data(self) -> bool: """ @@ -74,6 +89,184 @@ class TLV: """ return True + @staticmethod + def tlv_classes(): + return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV] + + +class Pad1TLV(TLV): + """ + This TLV is simply ignored. + """ + type: int = 0 + + def unmarshal(self, raw_data: bytes) -> None: + """ + There is nothing to do. We ignore the packet. + """ + self.type = raw_data[0] + + def marshal(self) -> bytes: + """ + The TLV is empty. + """ + return self.type.to_bytes(1, "big") + + +class PadNTLV(TLV): + """ + This TLV is filled with zeros. It is ignored. + """ + type: int = 1 + length: int + mbz: bytes + + def validate_data(self) -> bool: + if self.mbz != int(0).to_bytes(self.length, "big"): + raise ValueError("The body of a PadN TLV is not filled with zeros.") + return True + + def unmarshal(self, raw_data: bytes) -> None: + """ + Store the zero-array, then ignore the packet. + """ + self.type = raw_data[0] + self.length = raw_data[1] + self.mbz = raw_data[2:2 + self.length] + + def marshal(self) -> bytes: + """ + Construct the byte array filled by zeros. + """ + return self.type.to_bytes(1, "big") + self.length.to_bytes(1, "big") + self.mbz[:self.length] + + +class HelloTLV(TLV): + type: int = 2 + length: int + source_id: int + dest_id: Optional[int] + + def validate_data(self) -> bool: + if self.length != 8 and self.length != 16: + raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello," + f"found {self.length}") + return True + + def unmarshal(self, raw_data: bytes) -> None: + self.type = raw_data[0] + self.length = raw_data[1] + self.source_id = int.from_bytes(raw_data[2:10], "big") + if self.length == 16: + self.dest_id = int.from_bytes(raw_data[10:18], "big") + + def marshal(self) -> bytes: + data = self.type.to_bytes(1, "big") + self.length.to_bytes(1, "big") + self.source_id.to_bytes(8, "big") + if self.dest_id: + data += self.dest_id.to_bytes(8, "big") + return data + + +class NeighbourTLV(TLV): + type: int = 3 + length: int + ip_address: int + port: int + + def unmarshal(self, raw_data: bytes) -> None: + self.type = raw_data[0] + self.length = raw_data[1] + self.ip_address = raw_data[2:18] + self.port = int.from_bytes(raw_data[18:20], "big") + + def marshal(self) -> bytes: + return self.type.to_bytes(1, "big") + \ + self.length.to_bytes(1, "big") + \ + self.ip_address.to_bytes(16, "big") + \ + self.port.to_bytes(2, "big") + + +class DataTLV(TLV): + type: int = 4 + length: int + sender_id: int + nonce: int + data: bytes + + def unmarshal(self, raw_data: bytes) -> None: + self.type = raw_data[0] + self.length = raw_data[1] + self.sender_id = int.from_bytes(raw_data[2:10], "big") + self.nonce = int.from_bytes(raw_data[10:14], "big") + self.data = raw_data[14:2 + self.length] + + def marshal(self) -> bytes: + return self.type.to_bytes(1, "big") + \ + self.length.to_bytes(1, "big") + \ + self.sender_id.to_bytes(8, "big") + \ + self.nonce.to_bytes(4, "big") + \ + self.data + + +class AckTLV(TLV): + type: int = 5 + length: int + sender_id: int + nonce: int + + def unmarshal(self, raw_data: bytes) -> None: + self.type = raw_data[0] + self.length = raw_data[1] + self.sender_id = int.from_bytes(raw_data[2:10], "big") + self.nonce = int.from_bytes(raw_data[10:14], "big") + + def marshal(self) -> bytes: + return self.type.to_bytes(1, "big") + \ + self.length.to_bytes(1, "big") + \ + self.sender_id.to_bytes(8, "big") + \ + self.nonce.to_bytes(4, "big") + + +class GoAwayTLV(TLV): + class GoAwayType(Enum): + UNKNOWN = 0 + EXIT = 1 + TIMEOUT = 2 + PROTOCOL_VIOLATION = 3 + + type: int = 6 + length: int + code: GoAwayType + message: str + + def unmarshal(self, raw_data: bytes) -> None: + self.type = raw_data[0] + self.length = raw_data[1] + self.code = GoAwayTLV.GoAwayType(raw_data[2]) + self.message = raw_data[3:self.length - 1].decode("UTF-8") + + def marshal(self) -> bytes: + return self.type.to_bytes(1, "big") + \ + self.length.to_bytes(1, "big") + \ + self.code.value.to_bytes(1, "big") + \ + self.message.encode("UTF-8")[:self.length - 1] + + +class WarningTLV(TLV): + type: int = 7 + length: int + message: str + + def unmarshal(self, raw_data: bytes) -> None: + self.type = raw_data[0] + self.length = raw_data[1] + self.message = raw_data[2:self.length].decode("UTF-8") + + def marshal(self) -> bytes: + return self.type.to_bytes(1, "big") + \ + self.length.to_bytes(1, "big") + \ + self.message.encode("UTF-8")[:self.length] + class Packet: """ @@ -108,8 +301,11 @@ class Packet: pkt.magic = data[0] pkt.version = data[1] pkt.body_length = int.from_bytes(data[2:4], byteorder="big") - pkt.body = TLV() - pkt.body.raw_data = data[4:4+pkt.body_length] + tlv_type = data[4] + if not (0 <= tlv_type < len(TLV.tlv_classes())): + raise ValueError(f"TLV type is not supported: {tlv_type}") + pkt.body = TLV.tlv_classes()[tlv_type]() + pkt.body.unmarshal(data[4:4+pkt.body_length]) pkt.validate_data() @@ -122,7 +318,7 @@ class Packet: data = self.magic.to_bytes(1, "big") data += self.version.to_bytes(1, "big") data += self.body_length.to_bytes(2, "big") - data += self.body.raw_data + data += self.body.marshal() return data