# Copyright (C) 2020 by eichhornchen, ÿnérant
# SPDX-License-Identifier: GPL-3.0-or-later

import re
from typing import Any, List, Optional
from ipaddress import IPv6Address
from enum import Enum
import socket
import sys
import time


class TLV:
    """
    The Tag-Length-Value contains the different type of data that can be sent.

    Every concrete TLV stores its wire type in ``type`` and, except for Pad1,
    the length of its body (in bytes) in ``length``.
    """
    type: int
    length: int

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Parse data and construct TLV.
        """
        raise NotImplementedError

    def marshal(self) -> bytes:
        """
        Translate the TLV into a byte array.
        """
        raise NotImplementedError

    def validate_data(self) -> bool:
        """
        Ensure that the TLV is well-formed.
        Raises a ValueError if it is not the case.
        """
        return True

    def handle(self, user: Any, sender: Any) -> None:
        """
        Indicates what to do when this TLV is received from a given peer.
        It is ensured that the data is valid.
        """

    def _refuse_stranger(self, user: Any, sender: Any) -> bool:
        """
        Warn the sender and return True when it is not an active, symmetric
        neighbour with a known id; return False (and send nothing) otherwise.

        The warning names the concrete TLV class that was refused.
        """
        if sender.active and sender.symmetric and sender.id:
            return False
        # It doesn't say hello, we don't listen to it
        user.send_packet(sender, Packet.construct(WarningTLV.construct(
            f"You are not my neighbour, I won't listen to your {type(self).__name__}. "
            "Please say Hello to me before.")))
        return True

    def __len__(self) -> int:
        """
        Returns the total length (in bytes) of the TLV, including the type and the length.
        Except for Pad1, this is 2 plus the length of the body of the TLV.
        """
        return 2 + self.length

    @staticmethod
    def tlv_classes() -> list:
        """
        Return the TLV classes, indexed by their wire type.
        """
        return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV]


class Pad1TLV(TLV):
    """
    This TLV is simply ignored.
    """
    type: int = 0

    def unmarshal(self, raw_data: bytes) -> None:
        """
        There is nothing to do. We ignore the packet.
        """
        self.type = raw_data[0]

    def marshal(self) -> bytes:
        """
        The TLV is empty.
        """
        return self.type.to_bytes(1, sys.byteorder)

    def handle(self, user: Any, sender: Any) -> None:
        """
        Log the reception of the padding byte, unless the sender is a stranger.
        """
        if self._refuse_stranger(user, sender):
            return
        user.add_system_message("I received a Pad1TLV, how disapointing.")

    def __len__(self) -> int:
        """
        A Pad1 has always a length of 1.
        """
        return 1

    @staticmethod
    def construct() -> "Pad1TLV":
        """
        Build a Pad1 TLV.
        """
        tlv = Pad1TLV()
        tlv.type = 0
        return tlv


class PadNTLV(TLV):
    """
    This TLV is filled with zeros. It is ignored.
    """
    type: int = 1
    length: int
    mbz: bytes

    def validate_data(self) -> bool:
        """
        Ensure that the body only contains zero bytes.
        Raises a ValueError if it is not the case.
        """
        if self.mbz != int(0).to_bytes(self.length, sys.byteorder):
            raise ValueError("The body of a PadN TLV is not filled with zeros.")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Store the zero-array, then ignore the packet.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.mbz = raw_data[2:len(self)]

    def marshal(self) -> bytes:
        """
        Construct the byte array filled by zeros.
        """
        return self.type.to_bytes(1, sys.byteorder) + self.length.to_bytes(1, sys.byteorder) \
            + self.mbz[:self.length]

    def handle(self, user: Any, sender: Any) -> None:
        """
        Log the amount of padding received, unless the sender is a stranger.
        """
        if self._refuse_stranger(user, sender):
            return
        user.add_system_message(f"I received {self.length} zeros.")

    @staticmethod
    def construct(length: int) -> "PadNTLV":
        """
        Build a PadN TLV whose body is made of `length` zero bytes.
        """
        tlv = PadNTLV()
        tlv.type = 1
        tlv.length = length
        # The body must be NUL bytes (not the ASCII digit '0'), otherwise
        # validate_data rejects the TLV we have just built.
        tlv.mbz = bytes(length)
        return tlv


class HelloTLV(TLV):
    """
    Announce ourselves to a peer. A short Hello (8 bytes) only carries the
    source id, a long Hello (16 bytes) also carries the destination id.
    """
    type: int = 2
    length: int
    source_id: int
    dest_id: Optional[int]

    def validate_data(self) -> bool:
        """
        Ensure that the length matches a short or a long Hello.
        Raises a ValueError if it is not the case.
        """
        if self.length != 8 and self.length != 16:
            raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello,"
                             f"found {self.length}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Read the source id and, for a long Hello, the destination id.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.source_id = int.from_bytes(raw_data[2:10], sys.byteorder)
        # Always define dest_id, so that a short Hello can be marshalled again.
        self.dest_id = int.from_bytes(raw_data[10:18], sys.byteorder) if self.is_long else None

    def marshal(self) -> bytes:
        """
        Serialize the Hello; the destination id is appended only when it is set.
        """
        data = self.type.to_bytes(1, sys.byteorder) + self.length.to_bytes(1, sys.byteorder) \
            + self.source_id.to_bytes(8, sys.byteorder)
        # Compare against None: a destination id of 0 must still be written.
        if self.dest_id is not None:
            data += self.dest_id.to_bytes(8, sys.byteorder)
        return data

    def handle(self, user: Any, sender: Any) -> None:
        """
        Register the sender as an active, symmetric peer and refresh its
        Hello timestamps. A short Hello is answered with a long Hello.
        """
        time_h = time.time()

        if sender.id > 0 and sender.id != self.source_id:
            user.send_packet(sender, Packet.construct(WarningTLV.construct(
                f"You were known as the ID {sender.id}, but you declared that you have the ID {self.source_id}.")))
            user.add_system_message(f"A client known as the id {sender.id} declared that it uses "
                                    f"the id {self.source_id}.")
            sender.id = self.source_id

        if not sender.active:
            sender.id = self.source_id  # The sender we are given misses an id
            time_hl = time.time()
        else:
            time_hl = sender.last_long_hello_time
        if self.is_long and self.dest_id == user.id:
            time_hl = time.time()

        # Add entry to/actualize the active peers dictionary
        sender.last_hello_time = time_h
        sender.last_long_hello_time = time_hl
        sender.symmetric = True
        sender.active = True
        user.update_peer_table(sender)
        user.nbNS += 1
        user.add_system_message(f"{self.source_id} sent me a Hello " + ("long" if self.is_long else "short"))

        if not self.is_long:
            user.send_packet(sender, Packet.construct(HelloTLV.construct(16, user, sender)))

    @property
    def is_long(self) -> bool:
        """
        Whether this Hello carries a destination id.
        """
        return self.length == 16

    @staticmethod
    def construct(length: int, user: Any, destination: Any = None) -> "HelloTLV":
        """
        Build a Hello from `user`. A short Hello is built when `length` is 8,
        when no destination is given, or when the destination id is -1.
        """
        tlv = HelloTLV()
        tlv.type = 2
        tlv.source_id = user.id if user else 0
        if (destination is None) or destination.id == -1 or length == 8:
            # if the destination id is not known, or
            # if the destination was not specified, send a short hello
            tlv.length = 8
            tlv.dest_id = None
        else:
            tlv.length = 16
            tlv.dest_id = destination.id
        return tlv


class NeighbourTLV(TLV):
    """
    Advertise the address (IPv6 address and port) of a peer.
    """
    type: int = 3
    length: int
    ip_address: IPv6Address
    port: int

    def validate_data(self) -> bool:
        """
        Ensure that the port is in the range 1..65535.
        Raises a ValueError if it is not the case.
        """
        if not (1 <= self.port <= 65535):
            raise ValueError(f"Invalid port received in NeighbourTLV: {self.port}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Read the 16-byte address and the 2-byte port.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.ip_address = IPv6Address(raw_data[2:18])
        # NOTE(review): the port is decoded with the host byte order, so two
        # hosts of different endianness would disagree - confirm with the protocol.
        self.port = int.from_bytes(raw_data[18:20], sys.byteorder)

    def marshal(self) -> bytes:
        """
        Serialize the type, the length, the packed address and the port.
        """
        return self.type.to_bytes(1, sys.byteorder) + \
            self.length.to_bytes(1, sys.byteorder) + \
            self.ip_address.packed + \
            self.port.to_bytes(2, sys.byteorder)

    def handle(self, user: Any, sender: Any) -> None:
        """
        Register the advertised address as a potential peer when it is
        neither one of our own addresses nor an already known neighbour.
        """
        if self._refuse_stranger(user, sender):
            return

        if (self.ip_address, self.port) in user.addresses:
            # This case should never happen (and in our protocol it is not possible),
            # but we include this test as a security measure.
            return
        if (str(self.ip_address), self.port) not in user.neighbours:
            peer = user.new_peer(str(self.ip_address), self.port)
            peer.potential = True
            user.update_peer_table(peer)
            # user.add_system_message(f"New potential friend {self.ip_address}:{self.port}!")

    @staticmethod
    def construct(address: str, port: int) -> "NeighbourTLV":
        """
        Build a Neighbour TLV from a textual IPv6 address and a port.
        """
        tlv = NeighbourTLV()
        tlv.type = 3
        tlv.length = 18
        tlv.ip_address = IPv6Address(address)
        tlv.port = port
        return tlv


class DataTLV(TLV):
    """
    Carry a message, identified by the id of its author and a nonce.
    """
    type: int = 4
    length: int
    sender_id: int
    nonce: int
    data: bytes

    def validate_data(self) -> bool:
        """
        Ensure that the payload fits in a TLV whose length is one byte.
        Raises a ValueError if it is not the case.
        """
        if len(self.data) >= 256 - 4 - 8:
            raise ValueError("The data is too long, the length is larger that one byte.")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Read the author id, the nonce and the payload.
        A single trailing NUL byte is stripped from the payload, and the
        length is decreased accordingly.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.sender_id = int.from_bytes(raw_data[2:10], sys.byteorder)
        self.nonce = socket.ntohl(int.from_bytes(raw_data[10:14], sys.byteorder))
        self.data = raw_data[14:len(self)]
        # The payload may be empty: check it before looking at its last byte.
        if self.data and self.data[-1] == 0:
            self.data = self.data[:-1]
            self.length -= 1

    def marshal(self) -> bytes:
        """
        Serialize the type, the length, the author id, the nonce and the payload.
        """
        return self.type.to_bytes(1, sys.byteorder) + \
            self.length.to_bytes(1, sys.byteorder) + \
            self.sender_id.to_bytes(8, sys.byteorder) + \
            socket.htonl(self.nonce).to_bytes(4, sys.byteorder) + \
            self.data

    def handle(self, user: Any, sender: Any) -> None:
        """
        A message has been sent. We log it.
        The message is acknowledged, handed over to the user, and the
        nickname of its author is extracted from the 'nickname: message' syntax.
        """
        if self._refuse_stranger(user, sender):
            return

        if 0 in self.data:
            user.send_packet(user.find_peer_by_id(self.sender_id) or sender, Packet.construct(WarningTLV.construct(
                f"The length of your DataTLV mismatches. You told me that the length is {len(self.data)} "
                f"while a zero was found at index {self.data.index(0)}.")))
            # Only keep what precedes the first NUL byte
            self.data = self.data[:self.data.index(0)]

        msg = self.data.decode('UTF-8')

        # Acknowledge the packet
        user.send_packet(sender, Packet.construct(AckTLV.construct(self.sender_id, self.nonce)))

        if not user.receive_message_from(self, msg, self.sender_id, self.nonce, sender):
            # The message was already received, do not print it on screen
            user.add_system_message(f"I was inundated a message which I already knew {self.sender_id, self.nonce}")
            return

        nickname_match = re.match("(.*): (.*)", msg)
        if nickname_match is None:
            user.send_packet(sender, Packet.construct(WarningTLV.construct(
                "Unable to retrieve your username. Please use the syntax 'nickname: message'")))
            return

        nickname = nickname_match.group(1)
        author = user.find_peer_by_id(self.sender_id)
        if not author:
            return
        if author.nickname is None:
            author.nickname = nickname
        elif author.nickname != nickname:
            user.send_packet(author, Packet.construct(WarningTLV.construct(
                "It seems that you used two different nicknames. "
                f"Known nickname: {author.nickname}, found: {nickname}")))
            author.nickname = nickname

    @staticmethod
    def construct(message: str, user: Any) -> "DataTLV":
        """
        Build a Data TLV from a message. When a user is given, its id and its
        current nonce are used, and its nonce counter is incremented.
        """
        tlv = DataTLV()
        tlv.type = 4
        tlv.sender_id = user.id if user else 0
        tlv.nonce = user.incr_nonce if user else 0
        tlv.data = message.encode("UTF-8")
        tlv.length = 12 + len(tlv.data)

        if user:
            user.incr_nonce += 1

        return tlv


class AckTLV(TLV):
    """
    Acknowledge the Data TLV identified by an author id and a nonce.
    """
    type: int = 5
    length: int
    sender_id: int
    nonce: int

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Read the author id and the nonce of the acknowledged message.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.sender_id = int.from_bytes(raw_data[2:10], sys.byteorder)
        self.nonce = socket.ntohl(int.from_bytes(raw_data[10:14], sys.byteorder))

    def marshal(self) -> bytes:
        """
        Serialize the type, the length, the author id and the nonce.
        """
        return self.type.to_bytes(1, sys.byteorder) + \
            self.length.to_bytes(1, sys.byteorder) + \
            self.sender_id.to_bytes(8, sys.byteorder) + \
            socket.htonl(self.nonce).to_bytes(4, sys.byteorder)

    def handle(self, user: Any, sender: Any) -> None:
        """
        When an AckTLV is received, we know that we do not have to inundate that neighbour anymore.
        """
        if self._refuse_stranger(user, sender):
            return
        user.add_system_message(f"I received an AckTLV from {sender}")
        user.remove_from_inundation(sender, self.sender_id, self.nonce)

    @staticmethod
    def construct(sender_id: int, nonce: int) -> "AckTLV":
        """
        Build the acknowledgement of the message (sender_id, nonce).
        """
        tlv = AckTLV()
        tlv.type = 5
        tlv.length = 12
        tlv.sender_id = sender_id
        tlv.nonce = nonce
        return tlv


class GoAwayType(Enum):
    """
    Codes that can be carried by a GoAway TLV.
    """
    UNKNOWN = 0
    EXIT = 1
    TIMEOUT = 2
    PROTOCOL_VIOLATION = 3


class GoAwayTLV(TLV):
    """
    Tell a peer that it should not consider us as a neighbour anymore.
    """
    type: int = 6
    length: int
    code: GoAwayType
    message: str

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Read the code and the UTF-8 message.
        GoAwayType raises a ValueError when the code is not a known one.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.code = GoAwayType(raw_data[2])
        self.message = raw_data[3:2 + self.length].decode("UTF-8")

    def marshal(self) -> bytes:
        """
        Serialize the type, the length, the code and the message, which is
        cut so that the body does not exceed the declared length.
        """
        return self.type.to_bytes(1, sys.byteorder) + \
            self.length.to_bytes(1, sys.byteorder) + \
            self.code.value.to_bytes(1, sys.byteorder) + \
            self.message.encode("UTF-8")[:self.length - 1]

    def handle(self, user: Any, sender: Any) -> None:
        """
        Mark the sender as inactive and log its farewell message.
        """
        if self._refuse_stranger(user, sender):
            return

        if sender.active:
            sender.active = False
            user.update_peer_table(sender)
        user.add_system_message("Someone told me that he went away : " + self.message)

    @staticmethod
    def construct(ga_type: GoAwayType, message: str) -> "GoAwayTLV":
        """
        Build a GoAway TLV; the length counts the code and the encoded message.
        """
        tlv = GoAwayTLV()
        tlv.type = 6
        tlv.code = ga_type
        tlv.message = message
        tlv.length = 1 + len(tlv.message.encode("UTF-8"))
        return tlv


class WarningTLV(TLV):
    """
    Carry a human-readable warning for the peer.
    """
    type: int = 7
    length: int
    message: str

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Read the UTF-8 message.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.message = raw_data[2:self.length + 2].decode("UTF-8")

    def marshal(self) -> bytes:
        """
        Serialize the type, the length and the message, which is cut so that
        the body does not exceed the declared length.
        """
        return self.type.to_bytes(1, sys.byteorder) + \
            self.length.to_bytes(1, sys.byteorder) + \
            self.message.encode("UTF-8")[:self.length]

    def handle(self, user: Any, sender: Any) -> None:
        """
        Display the warning to the user, emphasized unless markdown is disabled.
        """
        if user.squinnondation.no_markdown:
            # This branch must be an f-string too, otherwise the placeholder
            # is displayed instead of the warning.
            text = f"warning: A client warned you: {self.message}"
        else:
            text = f"warning: *A client warned you: {self.message}*"
        user.add_message(text)

    @staticmethod
    def construct(message: str) -> "WarningTLV":
        """
        Build a Warning TLV; the length is the size of the encoded message.
        """
        tlv = WarningTLV()
        tlv.type = 7
        tlv.message = message
        tlv.length = len(tlv.message.encode("UTF-8"))
        return tlv


class Packet:
    """
    A Packet is a wrapper around the raw data that it sent and received to other clients.
    It is made of a 4-byte header (magic, version, body length) followed by TLVs.
    """
    magic: int
    version: int
    body_length: int
    body: List[TLV]

    def validate_data(self) -> bool:
        """
        Ensure that the packet is well-formed.
        Raises a ValueError if the packet contains bad data.
        """
        if self.magic != 95:
            raise ValueError("The magic code of the packet must be 95, found: {:d}".format(self.magic))
        if self.version != 0:
            raise ValueError("The version of the packet is not supported: {:d}".format(self.version))
        # NOTE(review): the accepted upper bound (1200) differs from the one
        # announced in the message (1020) - confirm which one is intended.
        if not (0 <= self.body_length <= 1200):
            raise ValueError("The body length of the packet is negative or too high. "
                             "It must be between 0 and 1020,"
                             "found: {:d}".format(self.body_length))
        return all(tlv.validate_data() for tlv in self.body)

    @staticmethod
    def unmarshal(data: bytes) -> "Packet":
        """
        Read raw data and build the packet wrapper.
        Raises a ValueError whenever the data is invalid.
        """
        if len(data) < 4:
            # Without this check, reading the header raises an IndexError.
            raise ValueError(f"Invalid packet: the header is 4 bytes long, only {len(data)} bytes are available")

        pkt = Packet()
        pkt.magic = data[0]
        pkt.version = data[1]
        pkt.body_length = socket.ntohs(int.from_bytes(data[2:4], sys.byteorder))
        if len(data) != 4 + pkt.body_length:
            raise ValueError(f"Invalid packet length: "
                             f"declared body length is {pkt.body_length} while {len(data) - 4} bytes are avalaible")

        classes = TLV.tlv_classes()
        pkt.body = []
        read_bytes = 0
        while read_bytes < pkt.body_length:
            offset = 4 + read_bytes
            remaining = pkt.body_length - read_bytes
            tlv_type = data[offset]
            if not (0 <= tlv_type < len(classes)):
                raise ValueError(f"TLV type is not supported: {tlv_type}")
            if tlv_type > 0 and remaining < 2:
                # The type is the last byte of the packet: the length byte is missing.
                raise ValueError(f"TLV header is truncated: the length of the TLV of type {tlv_type} is missing")
            # A Pad1 has no length byte; -1 makes its total size (2 + length) equal to 1.
            tlv_length = data[offset + 1] if tlv_type > 0 else -1
            if 2 + tlv_length > remaining:
                raise ValueError(f"TLV length is too long: requesting {tlv_length} bytes, "
                                 f"remaining {remaining}")
            tlv = classes[tlv_type]()
            tlv.unmarshal(data[offset:offset + 2 + tlv_length])
            pkt.body.append(tlv)
            # A DataTLV that stripped a trailing NUL byte reports a shorter
            # length: that NUL byte is then read again as a Pad1, which keeps
            # the sum of the TLV lengths equal to body_length.
            read_bytes += len(tlv)

        pkt.validate_data()

        return pkt

    def marshal(self) -> bytes:
        """
        Compute the byte array data associated to the packet.
        """
        header = self.magic.to_bytes(1, sys.byteorder) \
            + self.version.to_bytes(1, sys.byteorder) \
            + socket.htons(self.body_length).to_bytes(2, sys.byteorder)
        return header + b"".join(tlv.marshal() for tlv in self.body)

    def __len__(self) -> int:
        """
        Calculates the length, in bytes, of the packet.
        """
        return 4 + sum(len(tlv) for tlv in self.body)

    def split(self, pkt_size: int) -> List["Packet"]:
        """
        If the packet is too large, ie. larger that pkt_size (with pkt_size = 1024),
        then we split the packet in sub-packets.
        Since 1024 - 4 >> 256 + 2, that ensures that we can have at least one TLV per packet,
        then we don't need to split TLVs in smaller TLVs.

        A TLV that is larger than pkt_size on its own is sent alone in a packet.
        """
        packets = []
        current_size = 4  # Packet header length
        body = []
        for tlv in self.body:
            # Never flush an empty body: it would produce a useless empty packet.
            if body and current_size + len(tlv) > pkt_size:
                packets.append(Packet.construct(*body))
                body = []
                current_size = 4
            body.append(tlv)
            current_size += len(tlv)
        if body:
            packets.append(Packet.construct(*body))
        return packets

    @staticmethod
    def construct(*tlvs: TLV) -> "Packet":
        """
        Construct a new packet from the given TLVs and calculate the good lengths
        """
        pkt = Packet()
        pkt.magic = 95
        pkt.version = 0
        # Store a list, as declared by the annotation of `body`.
        pkt.body = list(tlvs)
        pkt.body_length = sum(len(tlv) for tlv in tlvs)
        return pkt