# Copyright (C) 2020 by eichhornchen, ÿnérant
# SPDX-License-Identifier: GPL-3.0-or-later

import socket
from argparse import ArgumentParser
from enum import Enum
from ipaddress import IPv6Address
from typing import Any, Optional, Tuple


class Squinnondation:
    """
    Command-line front-end: parses the arguments and runs the client loop.
    """
    args: Any
    bind_address: str
    bind_port: int
    client_address: str
    client_port: int

    def parse_arguments(self) -> None:
        """
        Parse the command-line arguments and store the local and neighbour addresses/ports.

        Raises a ValueError if the bind port or the client port is outside [1024, 65535].
        """
        parser = ArgumentParser(description="MIRC client.")
        # nargs='?' makes the positional arguments optional, so that their declared defaults are reachable.
        parser.add_argument('bind_address', type=str, nargs='?', default="localhost",
                            help="Address of the client.")
        parser.add_argument('bind_port', type=int, nargs='?', default=2500,
                            help="Port of the client. Must be between 1024 and 65535.")
        parser.add_argument('--client_address', type=str, default="localhost",
                            help="Address of the first neighbour.")
        parser.add_argument('--client_port', type=int, default=2500,
                            help="Port of the first neighbour. Must be between 1024 and 65535.")
        parser.add_argument('--bind-only', '-b', action='store_true',
                            help="Don't connect to another client, only listen to connections.")

        self.args = parser.parse_args()

        # The negation must cover both range checks: either port being out of range is an error.
        if not (1024 <= self.args.bind_port <= 65535 and 1024 <= self.args.client_port <= 65535):
            raise ValueError("Ports must be between 1024 and 65535.")

        self.bind_address = self.args.bind_address
        self.bind_port = self.args.bind_port
        self.client_address = self.args.client_address
        self.client_port = self.args.client_port

    @staticmethod
    def main() -> None:  # pragma: no cover
        """
        Run the client.

        Parses the arguments, binds the local socket, optionally sends a greeting
        Data TLV to the first neighbour, then prints every received Data TLV forever.
        """
        instance = Squinnondation()
        instance.parse_arguments()

        squirrel = Squirrel(input("Enter your nickname: "), instance.bind_address, instance.bind_port)

        if not instance.args.bind_only:
            hazelnut = Hazelnut(address=instance.client_address, port=instance.client_port)
            pkt = Packet()
            pkt.magic = 95
            pkt.version = 0
            pkt.body = DataTLV()
            msg = f"Hello world, my name is {squirrel.nickname}!"
            pkt.body.data = msg.encode("UTF-8")
            pkt.body.sender_id = 42
            pkt.body.nonce = 18
            # The TLV length byte counts the payload only (sender_id: 8 bytes, nonce: 4 bytes, data),
            # not the type and length bytes. Use the encoded size, not the number of characters.
            pkt.body.length = 8 + 4 + len(pkt.body.data)
            # The packet body is the whole TLV, i.e. the payload plus its type and length bytes.
            pkt.body_length = pkt.body.length + 2
            squirrel.send_packet(hazelnut, pkt)

        while True:
            try:
                pkt, addr = squirrel.receive_packet()
            except ValueError as error:
                # A malformed datagram from a peer must not stop the client.
                print(f"received an invalid packet: {error}")
                continue
            if isinstance(pkt.body, DataTLV):
                print(f"received message: {pkt.body.data.decode('UTF-8', errors='replace')}")


class TLV:
    """
    The Tag-Length-Value contains the different type of data that can be sent.

    Each subclass sets `type` to its tag. For every TLV except Pad1, the second
    byte is a length that counts the payload bytes following the type and length bytes.
    """
    type: int

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Parse data and construct TLV.
        """
        raise NotImplementedError

    def marshal(self) -> bytes:
        """
        Translate the TLV into a byte array.
        """
        raise NotImplementedError

    def validate_data(self) -> bool:
        """
        Ensure that the TLV is well-formed.
        Raises a ValueError if it is not the case.
        TODO: Make some tests
        """
        return True

    @staticmethod
    def tlv_classes():
        """
        Return the TLV classes, indexed by their type tag.
        """
        return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV]


class Pad1TLV(TLV):
    """
    This TLV is simply ignored. It is a single type byte, without length.
    """
    type: int = 0

    def unmarshal(self, raw_data: bytes) -> None:
        """
        There is nothing to do. We ignore the packet.
        """
        self.type = raw_data[0]

    def marshal(self) -> bytes:
        """
        The TLV is empty.
        """
        return self.type.to_bytes(1, "big")


class PadNTLV(TLV):
    """
    This TLV is filled with zeros. It is ignored.
    """
    type: int = 1
    length: int
    mbz: bytes  # "must be zero" padding, `length` bytes long

    def validate_data(self) -> bool:
        if self.mbz != int(0).to_bytes(self.length, "big"):
            raise ValueError("The body of a PadN TLV is not filled with zeros.")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Store the zero-array, then ignore the packet.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.mbz = raw_data[2:2 + self.length]

    def marshal(self) -> bytes:
        """
        Construct the byte array filled by zeros.
        """
        return self.type.to_bytes(1, "big") + self.length.to_bytes(1, "big") + self.mbz[:self.length]


class HelloTLV(TLV):
    """
    A short Hello carries the 8-byte source id (length 8); a long Hello also
    carries the 8-byte destination id (length 16).
    """
    type: int = 2
    length: int
    source_id: int
    dest_id: Optional[int] = None  # None for a short Hello

    def validate_data(self) -> bool:
        if self.length != 8 and self.length != 16:
            raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello, "
                             f"found {self.length}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.source_id = int.from_bytes(raw_data[2:10], "big")
        # Always assign dest_id, so a short Hello has an explicit None instead of a missing attribute.
        self.dest_id = int.from_bytes(raw_data[10:18], "big") if self.length == 16 else None

    def marshal(self) -> bytes:
        data = self.type.to_bytes(1, "big") + self.length.to_bytes(1, "big") + self.source_id.to_bytes(8, "big")
        # Compare to None: 0 is a valid destination id and must still be serialized.
        if self.dest_id is not None:
            data += self.dest_id.to_bytes(8, "big")
        return data


class NeighbourTLV(TLV):
    """
    Advertises a neighbour: a 16-byte IPv6 address followed by a 2-byte port.
    """
    type: int = 3
    length: int
    ip_address: IPv6Address
    port: int

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.ip_address = IPv6Address(raw_data[2:18])
        self.port = int.from_bytes(raw_data[18:20], "big")

    def marshal(self) -> bytes:
        return self.type.to_bytes(1, "big") + \
            self.length.to_bytes(1, "big") + \
            self.ip_address.packed + \
            self.port.to_bytes(2, "big")


class DataTLV(TLV):
    """
    Carries a message: 8-byte sender id, 4-byte nonce, then `length - 12` bytes of data.
    """
    type: int = 4
    length: int
    sender_id: int
    nonce: int
    data: bytes

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.sender_id = int.from_bytes(raw_data[2:10], "big")
        self.nonce = int.from_bytes(raw_data[10:14], "big")
        self.data = raw_data[14:2 + self.length]

    def marshal(self) -> bytes:
        return self.type.to_bytes(1, "big") + \
            self.length.to_bytes(1, "big") + \
            self.sender_id.to_bytes(8, "big") + \
            self.nonce.to_bytes(4, "big") + \
            self.data


class AckTLV(TLV):
    """
    Acknowledgement: an 8-byte sender id and a 4-byte nonce.
    """
    type: int = 5
    length: int
    sender_id: int
    nonce: int

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.sender_id = int.from_bytes(raw_data[2:10], "big")
        self.nonce = int.from_bytes(raw_data[10:14], "big")

    def marshal(self) -> bytes:
        return self.type.to_bytes(1, "big") + \
            self.length.to_bytes(1, "big") + \
            self.sender_id.to_bytes(8, "big") + \
            self.nonce.to_bytes(4, "big")


class GoAwayTLV(TLV):
    """
    Announces a disconnection: a 1-byte reason code followed by a UTF-8 message
    of `length - 1` bytes.
    """

    class GoAwayType(Enum):
        UNKNOWN = 0
        EXIT = 1
        TIMEOUT = 2
        PROTOCOL_VIOLATION = 3

    type: int = 6
    length: int
    code: GoAwayType
    message: str

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.code = GoAwayTLV.GoAwayType(raw_data[2])
        # The payload starts at offset 2 and is `length` bytes long; the code takes its first byte.
        self.message = raw_data[3:2 + self.length].decode("UTF-8")

    def marshal(self) -> bytes:
        return self.type.to_bytes(1, "big") + \
            self.length.to_bytes(1, "big") + \
            self.code.value.to_bytes(1, "big") + \
            self.message.encode("UTF-8")[:self.length - 1]


class WarningTLV(TLV):
    """
    Carries a UTF-8 warning message of `length` bytes.
    """
    type: int = 7
    length: int
    message: str

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        # The payload starts after the type and length bytes.
        self.message = raw_data[2:2 + self.length].decode("UTF-8")

    def marshal(self) -> bytes:
        return self.type.to_bytes(1, "big") + \
            self.length.to_bytes(1, "big") + \
            self.message.encode("UTF-8")[:self.length]


class Packet:
    """
    A Packet is a wrapper around a single TLV.

    Wire format: magic (1 byte), version (1 byte), body length (2 bytes, big-endian),
    then the TLV body.
    """
    magic: int
    version: int
    body_length: int
    body: TLV

    def validate_data(self) -> bool:
        """
        Ensure that the packet is well-formed.
        Raises a ValueError if the packet contains bad data.
        """
        if self.magic != 95:
            raise ValueError("The magic code of the packet must be 95, found: {:d}".format(self.magic))
        if self.version != 0:
            raise ValueError("The version of the packet is not supported: {:d}".format(self.version))
        # 1020 = 1024-byte receive buffer minus the 4-byte packet header.
        if not (0 <= self.body_length <= 1020):
            raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1020, "
                             "found: {:d}".format(self.body_length))
        return self.body.validate_data()

    @staticmethod
    def unmarshal(data: bytes) -> "Packet":
        """
        Read raw data and build the packet wrapper.
        Raises a ValueError whenever the data is invalid.
        """
        # 4 header bytes plus at least the TLV type byte.
        if len(data) < 5:
            raise ValueError(f"The packet is too short: {len(data)} bytes.")
        pkt = Packet()
        pkt.magic = data[0]
        pkt.version = data[1]
        pkt.body_length = int.from_bytes(data[2:4], byteorder="big")
        tlv_type = data[4]
        if not (0 <= tlv_type < len(TLV.tlv_classes())):
            raise ValueError(f"TLV type is not supported: {tlv_type}")
        pkt.body = TLV.tlv_classes()[tlv_type]()
        try:
            pkt.body.unmarshal(data[4:4 + pkt.body_length])
        except IndexError as error:
            # Report truncated TLVs with the documented exception type.
            raise ValueError("The TLV is truncated.") from error
        pkt.validate_data()
        return pkt

    def marshal(self) -> bytes:
        """
        Compute the byte array data associated to the packet.
        """
        data = self.magic.to_bytes(1, "big")
        data += self.version.to_bytes(1, "big")
        data += self.body_length.to_bytes(2, "big")
        data += self.body.marshal()
        return data


class Hazelnut:
    """
    A hazelnut is a connected client, with its socket.
    """

    def __init__(self, nickname: str = "anonymous", address: str = "localhost", port: int = 2500):
        self.nickname = nickname
        # Resolve DNS as an IPv6 address, keeping the first result returned by getaddrinfo.
        address = socket.getaddrinfo(address, None, socket.AF_INET6)[0][4][0]
        self.address = IPv6Address(address)
        self.port = port


class Squirrel(Hazelnut):
    """
    The squirrel is the user of the program. It can speak with other clients, that are called hazelnuts.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # IPv6 UDP socket bound to the resolved local address and port.
        self.socket = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
        self.socket.bind((str(self.address), self.port))
        print(f"Listening on {self.address}:{self.port}")

    def send_packet(self, client: Hazelnut, pkt: Packet) -> int:
        """
        Send a formatted packet to a client.
        Returns the number of bytes sent.
        """
        return self.send_raw_data(client, pkt.marshal())

    def send_raw_data(self, client: Hazelnut, data: bytes) -> int:
        """
        Send a raw packet to a client.
        Returns the number of bytes sent.
        """
        return self.socket.sendto(data, (str(client.address), client.port))

    def receive_packet(self) -> Tuple[Packet, Any]:
        """
        Receive a packet from the socket and translate it into a Python object.
        Raises a ValueError if the received data is not a valid packet.
        TODO: Translate the address into the correct hazelnut.
        """
        data, addr = self.receive_raw_data()
        return Packet.unmarshal(data), addr

    def receive_raw_data(self) -> Tuple[bytes, Any]:
        """
        Receive a packet from the socket, reading at most 1024 bytes.
        TODO: Translate the address into the correct hazelnut.
        """
        return self.socket.recvfrom(1024)