# Copyright (C) 2020 by eichhornchen, ÿnérant
# SPDX-License-Identifier: GPL-3.0-or-later

import re
from typing import Any, List, Optional
from ipaddress import IPv6Address
from enum import Enum
import socket
import sys


class TLV:
    """
    The Tag-Length-Value contains the different type of data that can be sent.
    TODO: add subclasses for each type of TLV
    """
    type: int
    length: int

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Parse data and construct TLV.
        """
        raise NotImplementedError

    def marshal(self) -> bytes:
        """
        Translate the TLV into a byte array.
        """
        raise NotImplementedError

    def validate_data(self) -> bool:
        """
        Ensure that the TLV is well-formed.
        Raises a ValueError if it is not the case.
        TODO: Make some tests
        """
        return True

    def handle(self, squirrel: Any, sender: Any) -> None:
        """
        Indicates what to do when this TLV is received from a given hazel.
        It is ensured that the data is valid.
        """

    def __len__(self) -> int:
        """
        Returns the total length (in bytes) of the TLV, including the type and the length.
        Except for Pad1, this is 2 plus the length of the body of the TLV.
        """
        return 2 + self.length

    @staticmethod
    def tlv_classes() -> list:
        """
        The TLV classes, indexed by their type code (the list index is the type byte).
        """
        return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV]


class Pad1TLV(TLV):
    """
    This TLV is simply ignored.
    """
    type: int = 0

    def unmarshal(self, raw_data: bytes) -> None:
        """
        There is nothing to do. We ignore the packet.
        """
        self.type = raw_data[0]

    def marshal(self) -> bytes:
        """
        The TLV is empty.
        """
        return self.type.to_bytes(1, sys.byteorder)

    def handle(self, squirrel: Any, sender: Any) -> None:
        # TODO Add some easter eggs
        squirrel.add_system_message("For each byte in the packet that I received, you will die today. And eat cookies.")

    def __len__(self) -> int:
        """
        A Pad1 has always a length of 1.
        """
        return 1


class PadNTLV(TLV):
    """
    This TLV is filled with zeros. It is ignored.
    """
    type: int = 1
    length: int
    mbz: bytes

    def validate_data(self) -> bool:
        """
        The body must be exactly `length` zero bytes (a truncated body also fails).
        """
        if self.mbz != bytes(self.length):
            raise ValueError("The body of a PadN TLV is not filled with zeros.")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Store the zero-array, then ignore the packet.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.mbz = raw_data[2:len(self)]

    def marshal(self) -> bytes:
        """
        Construct the byte array filled by zeros.
        """
        return self.type.to_bytes(1, sys.byteorder) + self.length.to_bytes(1, sys.byteorder) \
            + self.mbz[:self.length]

    def handle(self, squirrel: Any, sender: Any) -> None:
        # TODO Add some easter eggs
        squirrel.add_system_message(f"I received {self.length} zeros, am I so a bag guy ? :cold_sweat:")


class HelloTLV(TLV):
    """
    A short Hello (length 8) carries the source id; a long Hello (length 16)
    additionally carries a destination id.
    """
    type: int = 2
    length: int
    source_id: int
    # Only meaningful for a long Hello; None for a short Hello.
    dest_id: Optional[int] = None

    def validate_data(self) -> bool:
        if self.length not in (8, 16):
            raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello, "
                             f"found {self.length}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.source_id = int.from_bytes(raw_data[2:10], sys.byteorder)
        # Always assign dest_id so a short Hello never exposes a stale or missing value.
        self.dest_id = int.from_bytes(raw_data[10:18], sys.byteorder) if self.is_long else None

    def marshal(self) -> bytes:
        """
        Serialize the Hello. The length field decides whether the destination id is written,
        so the output always agrees with the announced length.
        Raises a ValueError for a long Hello without a destination id.
        """
        data = self.type.to_bytes(1, sys.byteorder) + self.length.to_bytes(1, sys.byteorder) \
            + self.source_id.to_bytes(8, sys.byteorder)
        if self.is_long:
            if self.dest_id is None:
                raise ValueError("A long Hello TLV requires a destination id.")
            data += self.dest_id.to_bytes(8, sys.byteorder)
        return data

    def handle(self, squirrel: Any, sender: Any) -> None:
        # TODO Implement HelloTLV
        emoji = ":smiling_face_with_3_hearts:" if self.is_long else ":smiling_face_with_smiling_eyes:"
        squirrel.add_system_message("Aaaawwww, someone spoke to me and said me Hello " + emoji)

    @property
    def is_long(self) -> bool:
        return self.length == 16


class NeighbourTLV(TLV):
    """
    Advertises a peer: a 16-byte IPv6 address followed by a 2-byte port.
    """
    type: int = 3
    length: int
    ip_address: IPv6Address
    port: int

    def validate_data(self) -> bool:
        # The parser uses fixed offsets, so any other length would desynchronize the packet.
        if self.length != 18:
            raise ValueError(f"The length of a Neighbour TLV must be 18, found {self.length}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.ip_address = IPv6Address(raw_data[2:18])
        # The port is in network byte order, like the other multi-byte protocol fields.
        self.port = socket.ntohs(int.from_bytes(raw_data[18:20], sys.byteorder))

    def marshal(self) -> bytes:
        return self.type.to_bytes(1, sys.byteorder) + \
            self.length.to_bytes(1, sys.byteorder) + \
            self.ip_address.packed + \
            socket.htons(self.port).to_bytes(2, sys.byteorder)

    def handle(self, squirrel: Any, sender: Any) -> None:
        # TODO Implement NeighbourTLV
        squirrel.add_system_message("I have a friend!")
        squirrel.add_system_message(f"Welcome {self.ip_address}:{self.port}!")


class DataTLV(TLV):
    """
    A chat message: 8-byte sender id, 4-byte nonce (network order), then UTF-8 text.
    """
    type: int = 4
    length: int
    sender_id: int
    nonce: int
    data: bytes

    def validate_data(self) -> bool:
        # The length byte covers sender_id (8) + nonce (4) + data, so data fits in at most 255 - 12 bytes.
        if len(self.data) >= 256 - 4 - 8:
            raise ValueError("The data is too long, the length is larger that one byte.")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.sender_id = int.from_bytes(raw_data[2:10], sys.byteorder)
        self.nonce = socket.ntohl(int.from_bytes(raw_data[10:14], sys.byteorder))
        self.data = raw_data[14:len(self)]

    def marshal(self) -> bytes:
        return self.type.to_bytes(1, sys.byteorder) + \
            self.length.to_bytes(1, sys.byteorder) + \
            self.sender_id.to_bytes(8, sys.byteorder) + \
            socket.htonl(self.nonce).to_bytes(4, sys.byteorder) + \
            self.data

    def handle(self, squirrel: Any, sender: Any) -> None:
        """
        A message has been sent. We log it.
        TODO: Check that the tuple (sender_id, nonce) is unique to avoid duplicates.
        """
        msg = self.data.decode('UTF-8')

        if not squirrel.receive_message_from(msg, self.sender_id, self.nonce):
            # The message was already received
            return

        # Acknowledge the packet
        squirrel.send_packet(sender, Packet.construct(AckTLV.construct(self.sender_id, self.nonce)))

        nickname_match = re.match("(.*): (.*)", msg)
        if nickname_match is None:
            squirrel.send_packet(sender, Packet.construct(WarningTLV.construct(
                "Unable to retrieve your username. Please use the syntax 'nickname: message'")))
        else:
            nickname = nickname_match.group(1)
            if sender.nickname is None:
                sender.nickname = nickname
            elif sender.nickname != nickname:
                squirrel.send_packet(sender, Packet.construct(WarningTLV.construct(
                    "It seems that you used two different nicknames. "
                    f"Known nickname: {sender.nickname}, found: {nickname}")))
                sender.nickname = nickname

    @staticmethod
    def construct(message: str, squirrel: Any) -> "DataTLV":
        """
        Build a DataTLV for `message`. When a squirrel is given, its id and current nonce
        are used and its nonce counter is incremented; otherwise both default to 0.
        """
        tlv = DataTLV()
        tlv.type = 4
        tlv.sender_id = squirrel.id if squirrel else 0
        tlv.nonce = squirrel.incr_nonce if squirrel else 0
        tlv.data = message.encode("UTF-8")
        tlv.length = 12 + len(tlv.data)

        if squirrel:
            squirrel.incr_nonce += 1

        return tlv


class AckTLV(TLV):
    """
    Acknowledges a DataTLV identified by (sender_id, nonce).
    """
    type: int = 5
    length: int
    sender_id: int
    nonce: int

    def validate_data(self) -> bool:
        # sender_id (8) + nonce (4): the parser relies on this fixed layout.
        if self.length != 12:
            raise ValueError(f"The length of an Ack TLV must be 12, found {self.length}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.sender_id = int.from_bytes(raw_data[2:10], sys.byteorder)
        self.nonce = socket.ntohl(int.from_bytes(raw_data[10:14], sys.byteorder))

    def marshal(self) -> bytes:
        return self.type.to_bytes(1, sys.byteorder) + \
            self.length.to_bytes(1, sys.byteorder) + \
            self.sender_id.to_bytes(8, sys.byteorder) + \
            socket.htonl(self.nonce).to_bytes(4, sys.byteorder)

    def handle(self, squirrel: Any, sender: Any) -> None:
        # TODO Implement AckTLV
        squirrel.add_system_message("I received an AckTLV. I don't know what to do with it. Please implement me!")

    @staticmethod
    def construct(sender_id: int, nonce: int) -> "AckTLV":
        tlv = AckTLV()
        tlv.type = 5
        tlv.length = 12
        tlv.sender_id = sender_id
        tlv.nonce = nonce
        return tlv


class GoAwayTLV(TLV):
    """
    Announces that a peer leaves: a 1-byte reason code followed by a UTF-8 message.
    """
    class GoAwayType(Enum):
        UNKNOWN = 0
        EXIT = 1
        TIMEOUT = 2
        PROTOCOL_VIOLATION = 3

    type: int = 6
    length: int
    code: GoAwayType
    message: str

    def validate_data(self) -> bool:
        # The body must at least contain the code byte.
        if self.length < 1:
            raise ValueError(f"The length of a GoAway TLV must be at least 1, found {self.length}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.code = GoAwayTLV.GoAwayType(raw_data[2])
        # The body spans bytes [2, 2 + length): one code byte, then length - 1 message bytes.
        self.message = raw_data[3:2 + self.length].decode("UTF-8")

    def marshal(self) -> bytes:
        return self.type.to_bytes(1, sys.byteorder) + \
            self.length.to_bytes(1, sys.byteorder) + \
            self.code.value.to_bytes(1, sys.byteorder) + \
            self.message.encode("UTF-8")[:self.length - 1]

    def handle(self, squirrel: Any, sender: Any) -> None:
        # TODO Implement GoAwayTLV
        squirrel.add_system_message("Some told me that he went away. That's not very nice :( "
                                    "I should send him some cake.")


class WarningTLV(TLV):
    """
    A human-readable warning sent by a peer; the body is a UTF-8 message.
    """
    type: int = 7
    length: int
    message: str

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.message = raw_data[2:self.length + 2].decode("UTF-8")

    def marshal(self) -> bytes:
        return self.type.to_bytes(1, sys.byteorder) + \
            self.length.to_bytes(1, sys.byteorder) + \
            self.message.encode("UTF-8")[:self.length]

    def handle(self, squirrel: Any, sender: Any) -> None:
        squirrel.add_message(f"warning: *A client warned you: {self.message}*"
                             if not squirrel.squinnondation.no_markdown else
                             f"warning: A client warned you: {self.message}")

    @staticmethod
    def construct(message: str) -> "WarningTLV":
        tlv = WarningTLV()
        tlv.type = 7
        tlv.message = message
        tlv.length = len(tlv.message.encode("UTF-8"))
        return tlv


class Packet:
    """
    A Packet is a wrapper around the raw data that it sent and received to other clients.
    """
    magic: int
    version: int
    body_length: int
    body: List[TLV]

    def validate_data(self) -> bool:
        """
        Ensure that the packet is well-formed.
        Raises a ValueError if the packet contains bad data.
        """
        if self.magic != 95:
            raise ValueError("The magic code of the packet must be 95, found: {:d}".format(self.magic))
        if self.version != 0:
            raise ValueError("The version of the packet is not supported: {:d}".format(self.version))
        # NOTE(review): split() assumes 1024-byte packets (a 1020-byte body);
        # confirm whether this looser 1200 bound is intentional.
        if not (0 <= self.body_length <= 1200):
            raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1200, "
                             "found: {:d}".format(self.body_length))
        return all(tlv.validate_data() for tlv in self.body)

    @staticmethod
    def unmarshal(data: bytes) -> "Packet":
        """
        Read raw data and build the packet wrapper.
        Raises a ValueError whenever the data is invalid, including a header shorter than
        4 bytes and a TLV that is truncated or overflows the packet body.
        """
        if len(data) < 4:
            raise ValueError(f"A packet needs a 4-byte header, found only {len(data)} bytes")

        pkt = Packet()
        pkt.magic = data[0]
        pkt.version = data[1]
        pkt.body_length = socket.ntohs(int.from_bytes(data[2:4], sys.byteorder))
        pkt.body = []

        tlv_classes = TLV.tlv_classes()
        # Never read past the announced body nor past the bytes actually received.
        body_end = 4 + min(len(data) - 4, pkt.body_length)
        offset = 4
        while offset < body_end:
            tlv_type = data[offset]
            if not (0 <= tlv_type < len(tlv_classes)):
                raise ValueError(f"TLV type is not supported: {tlv_type}")
            tlv = tlv_classes[tlv_type]()
            try:
                tlv.unmarshal(data[offset:body_end])
            except IndexError as err:
                raise ValueError(f"The TLV of type {tlv_type} is truncated") from err
            if offset + len(tlv) > body_end:
                raise ValueError(f"The TLV of type {tlv_type} overflows the packet body")
            pkt.body.append(tlv)
            offset += len(tlv)

        pkt.validate_data()

        return pkt

    def marshal(self) -> bytes:
        """
        Compute the byte array data associated to the packet.
        """
        data = self.magic.to_bytes(1, sys.byteorder)
        data += self.version.to_bytes(1, sys.byteorder)
        data += socket.htons(self.body_length).to_bytes(2, sys.byteorder)
        data += b"".join(tlv.marshal() for tlv in self.body)
        return data

    def __len__(self) -> int:
        """
        Calculates the length, in bytes, of the packet.
        """
        return 4 + sum(len(tlv) for tlv in self.body)

    def split(self, pkt_size: int) -> List["Packet"]:
        """
        If the packet is too large, ie. larger that pkt_size (with pkt_size = 1024),
        then we split the packet in sub-packets.
        Since 1024 - 4 >> 256 + 2, that ensures that we can have at least one TLV per packet,
        then we don't need to split TLVs in smaller TLVs.
        A TLV that alone exceeds pkt_size is still emitted, in a packet of its own;
        no empty packet is ever produced.
        """
        packets = []
        current_size = 4  # Packet header length
        body = []
        for tlv in self.body:
            # Only flush a non-empty body: flushing an empty one would emit an empty packet.
            if body and current_size + len(tlv) > pkt_size:
                packets.append(Packet.construct(*body))
                body = []
                current_size = 4
            body.append(tlv)
            current_size += len(tlv)
        if body:
            packets.append(Packet.construct(*body))
        return packets

    @staticmethod
    def construct(*tlvs: TLV) -> "Packet":
        """
        Construct a new packet from the given TLVs and calculate the good lengths
        """
        pkt = Packet()
        pkt.magic = 95
        pkt.version = 0
        pkt.body = list(tlvs)
        pkt.body_length = sum(len(tlv) for tlv in tlvs)
        return pkt