# squinnondation/squinnondation/squinnondation.py (462 lines, 15 KiB, Python)

# Copyright (C) 2020 by eichhornchen, ÿnérant
# SPDX-License-Identifier: GPL-3.0-or-later
import curses
import socket
from argparse import ArgumentParser
from enum import Enum
from ipaddress import IPv6Address
from threading import Thread
from typing import Any, List, Optional, Tuple
from squinnondation.term_manager import TermManager
class Squinnondation:
    """
    Command-line entry point of the chat client: parses the arguments, then runs
    the curses chat loop.
    """
    args: Any
    bind_address: str
    bind_port: int
    client_address: str
    client_port: int

    # The Data TLV length field is a single byte. It covers the 8-byte sender id,
    # the 4-byte nonce and the message, so at most 255 - 12 message bytes fit.
    MAX_MESSAGE_BYTES = 255 - 8 - 4

    def parse_arguments(self) -> None:
        """
        Parse the command line into ``self.args``, then copy the address and port
        fields onto the instance.

        Raises a ValueError if the bind port or the neighbour port is outside
        [1024, 65535].
        """
        parser = ArgumentParser(description="MIRC client.")
        # NOTE(review): positional arguments without nargs='?' are mandatory, so the
        # defaults of bind_address / bind_port below are never used.
        parser.add_argument('bind_address', type=str, default="localhost",
                            help="Address of the client.")
        parser.add_argument('bind_port', type=int, default=2500,
                            help="Port of the client. Must be between 1024 and 65535.")
        parser.add_argument('--client_address', type=str, default="localhost",
                            help="Address of the first neighbour.")
        parser.add_argument('--client_port', type=int, default=2500,
                            help="Port of the first neighbour. Must be between 1024 and 65535.")
        parser.add_argument('--bind-only', '-b', action='store_true',
                            help="Don't connect to another client, only listen to connections.")
        self.args = parser.parse_args()

        # Both ports must be in range. The negation covers the whole conjunction,
        # not only the first comparison.
        ports_valid = (1024 <= self.args.bind_port <= 65535) and (1024 <= self.args.client_port <= 65535)
        if not ports_valid:
            raise ValueError("Ports must be between 1024 and 65535.")

        self.bind_address = self.args.bind_address
        self.bind_port = self.args.bind_port
        self.client_address = self.args.client_address
        self.client_port = self.args.client_port

    @staticmethod
    def main() -> None:  # pragma: no cover
        """
        Run the client.

        Asks for a nickname and binds the local squirrel. Unless --bind-only is
        given, registers the first neighbour. Then starts the listener thread and
        sends every typed line to all known hazelnuts as a single Data TLV.
        """
        instance = Squinnondation()
        instance.parse_arguments()
        with TermManager() as term_manager:
            screen = term_manager.screen
            screen.addstr(0, 0, "Enter your nickname: ")
            nickname = screen.getstr().decode("UTF-8")
            squirrel = Squirrel(nickname, instance.bind_address, instance.bind_port)
            squirrel.refresh_history()
            squirrel.refresh_input()
            if not instance.args.bind_only:
                hazelnut = Hazelnut(address=instance.client_address, port=instance.client_port)
                # Key by the resolved address rather than the typed hostname.
                # In the common case (e.g. "::1") this is the host string recvfrom()
                # reports, so find_hazelnut() will not register the peer a second time.
                squirrel.hazelnuts[(str(hazelnut.address), hazelnut.port)] = hazelnut
            Worm(squirrel).start()
            while True:
                squirrel.refresh_history()
                squirrel.refresh_input()
                msg = screen.getstr(curses.LINES - 1, 3 + len(squirrel.nickname)).decode("UTF-8")
                msg = f"<{squirrel.nickname}> {msg}"
                data = msg.encode("UTF-8")
                if len(data) > Squinnondation.MAX_MESSAGE_BYTES:
                    # The one-byte length field cannot describe this message.
                    # Report it instead of crashing with OverflowError in marshal().
                    squirrel.history.append(f"<system> Message too long ({len(data)} bytes, "
                                            f"at most {Squinnondation.MAX_MESSAGE_BYTES}), not sent.")
                    continue
                squirrel.history.append(msg)
                for hazelnut in list(squirrel.hazelnuts.values()):
                    pkt = Packet()
                    pkt.magic = 95
                    pkt.version = 0
                    tlv = DataTLV()
                    tlv.data = data
                    # NOTE(review): fixed placeholder sender id and nonce; confirm intent.
                    tlv.sender_id = 42
                    tlv.nonce = 18
                    # The TLV length excludes its own type and length bytes:
                    # 8 (sender id) + 4 (nonce) + the message.
                    tlv.length = 8 + 4 + len(tlv.data)
                    pkt.body = [tlv]
                    # The body is this single TLV: its 2 header bytes plus its length.
                    pkt.body_length = tlv.length + 2
                    squirrel.send_packet(hazelnut, pkt)
class TLV:
    """
    Base class of the Tag-Length-Value records carried in a packet body.

    Each concrete kind of record is a subclass listed by ``tlv_classes``, at the
    index equal to its type tag.
    """
    type: int

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Fill this TLV from its raw bytes. Subclasses must override it.
        """
        raise NotImplementedError

    def marshal(self) -> bytes:
        """
        Serialize this TLV into bytes. Subclasses must override it.
        """
        raise NotImplementedError

    def validate_data(self) -> bool:
        """
        Check that the TLV is well-formed, raising a ValueError otherwise.
        The base implementation accepts everything.
        TODO: Make some tests
        """
        return True

    @staticmethod
    def tlv_classes():
        """
        Return the concrete TLV classes; the list index is the TLV type tag.
        """
        by_type = [
            Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV,
            DataTLV, AckTLV, GoAwayTLV, WarningTLV,
        ]
        return by_type
class Pad1TLV(TLV):
    """
    Pad1 TLV (type 0): a lone type byte with no length field and no payload.
    Packet parsing steps over it as a single byte.
    """
    type: int = 0

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Keep the type byte; a Pad1 carries nothing else to read.
        """
        first_byte = raw_data[0]
        self.type = first_byte

    def marshal(self) -> bytes:
        """
        Serialize as the single type byte.
        """
        return int.to_bytes(self.type, 1, byteorder="big")
class PadNTLV(TLV):
    """
    PadN TLV (type 1): ``length`` bytes of padding that must all be zero.
    Its content is otherwise ignored.
    """
    type: int = 1
    length: int
    # Padding bytes; validate_data requires them to be exactly `length` zeros.
    mbz: bytes

    def validate_data(self) -> bool:
        """
        Raise a ValueError unless the padding is exactly ``length`` zero bytes.
        """
        zeros = bytes(self.length)
        if self.mbz == zeros:
            return True
        raise ValueError("The body of a PadN TLV is not filled with zeros.")

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Read the header and keep the padding bytes for later validation.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        padding_end = 2 + self.length
        self.mbz = raw_data[2:padding_end]

    def marshal(self) -> bytes:
        """
        Emit the type and length bytes, followed by the first ``length`` padding bytes.
        """
        header = self.type.to_bytes(1, "big") + self.length.to_bytes(1, "big")
        return header + self.mbz[:self.length]
class HelloTLV(TLV):
    """
    Hello TLV (type 2).

    A short Hello (length 8) carries only the 8-byte big-endian ``source_id``.
    A long Hello (length 16) also carries an 8-byte big-endian ``dest_id``.
    """
    type: int = 2
    length: int
    source_id: int
    # None for a short Hello. The class-level default keeps marshal() working on
    # instances that never set it.
    dest_id: Optional[int] = None

    def validate_data(self) -> bool:
        """
        Raise a ValueError unless the length is 8 (short Hello) or 16 (long Hello).
        """
        if self.length != 8 and self.length != 16:
            raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello, "
                             f"found {self.length}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Parse the header, the source id and, for a long Hello, the destination id.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.source_id = int.from_bytes(raw_data[2:10], "big")
        # Always assign dest_id so a short Hello does not leave it unset or stale.
        self.dest_id = int.from_bytes(raw_data[10:18], "big") if self.length == 16 else None

    def marshal(self) -> bytes:
        """
        Serialize the TLV. The destination id is appended whenever it is set,
        including a destination id of 0.
        """
        data = self.type.to_bytes(1, "big") + self.length.to_bytes(1, "big") + self.source_id.to_bytes(8, "big")
        if self.dest_id is not None:
            data += self.dest_id.to_bytes(8, "big")
        return data
class NeighbourTLV(TLV):
    """
    Neighbour TLV (type 3): a peer's address and port.

    After the type and length bytes come the 16-byte IPv6 address, then the
    2-byte big-endian port.
    """
    type: int = 3
    length: int
    ip_address: IPv6Address
    port: int

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Parse the header, the packed IPv6 address and the port.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        packed_address = raw_data[2:18]
        self.ip_address = IPv6Address(packed_address)
        self.port = int.from_bytes(raw_data[18:20], byteorder="big")

    def marshal(self) -> bytes:
        """
        Serialize as type, length, packed IPv6 address, then port.
        """
        fields = (
            self.type.to_bytes(1, "big"),
            self.length.to_bytes(1, "big"),
            self.ip_address.packed,
            self.port.to_bytes(2, "big"),
        )
        return b"".join(fields)
class DataTLV(TLV):
    """
    Data TLV (type 4): carries a chat message.

    After the type and length bytes come an 8-byte big-endian sender id, a 4-byte
    big-endian nonce, and then ``length - 12`` bytes of message data.
    """
    type: int = 4
    length: int
    sender_id: int
    nonce: int
    data: bytes

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Parse the header, sender id, nonce and message bytes.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.sender_id = int.from_bytes(raw_data[2:10], byteorder="big")
        self.nonce = int.from_bytes(raw_data[10:14], byteorder="big")
        # The announced length starts counting after the 2 header bytes.
        payload_end = 2 + self.length
        self.data = raw_data[14:payload_end]

    def marshal(self) -> bytes:
        """
        Serialize as type, length, sender id, nonce, then the message bytes.
        """
        fields = (
            self.type.to_bytes(1, "big"),
            self.length.to_bytes(1, "big"),
            self.sender_id.to_bytes(8, "big"),
            self.nonce.to_bytes(4, "big"),
            self.data,
        )
        return b"".join(fields)
class AckTLV(TLV):
    """
    Ack TLV (type 5): an 8-byte big-endian sender id followed by a 4-byte
    big-endian nonce.
    NOTE(review): presumably acknowledges the Data TLV with the same
    (sender_id, nonce) pair; confirm against the protocol.
    """
    type: int = 5
    length: int
    sender_id: int
    nonce: int

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Parse the header, sender id and nonce.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        sender_bytes = raw_data[2:10]
        nonce_bytes = raw_data[10:14]
        self.sender_id = int.from_bytes(sender_bytes, byteorder="big")
        self.nonce = int.from_bytes(nonce_bytes, byteorder="big")

    def marshal(self) -> bytes:
        """
        Serialize as type, length, sender id, then nonce.
        """
        fields = (
            self.type.to_bytes(1, "big"),
            self.length.to_bytes(1, "big"),
            self.sender_id.to_bytes(8, "big"),
            self.nonce.to_bytes(4, "big"),
        )
        return b"".join(fields)
class GoAwayTLV(TLV):
    """
    GoAway TLV (type 6): a one-byte reason code (``GoAwayType``) followed by a
    UTF-8 message.

    The length covers the code byte and the message, so the message is
    ``length - 1`` bytes long.
    """
    class GoAwayType(Enum):
        UNKNOWN = 0
        EXIT = 1
        TIMEOUT = 2
        PROTOCOL_VIOLATION = 3

    type: int = 6
    length: int
    code: GoAwayType
    message: str

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Parse the header, the reason code and the message.

        Raises a ValueError for an unknown code, and a UnicodeDecodeError (itself a
        ValueError) when the message is not valid UTF-8.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.code = GoAwayTLV.GoAwayType(raw_data[2])
        # The message runs from just after the code byte to the end of the TLV,
        # which is 2 header bytes + length.
        self.message = raw_data[3:2 + self.length].decode("UTF-8")

    def marshal(self) -> bytes:
        """
        Serialize as type, length, code, then at most ``length - 1`` message bytes.
        """
        return self.type.to_bytes(1, "big") + \
            self.length.to_bytes(1, "big") + \
            self.code.value.to_bytes(1, "big") + \
            self.message.encode("UTF-8")[:self.length - 1]
class WarningTLV(TLV):
    """
    Warning TLV (type 7): a UTF-8 message whose byte length is ``length``.
    """
    type: int = 7
    length: int
    message: str

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Parse the header and the message.

        Raises a UnicodeDecodeError (a ValueError) when the message is not valid UTF-8.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        # The message occupies the `length` bytes after the 2 header bytes.
        self.message = raw_data[2:2 + self.length].decode("UTF-8")

    def marshal(self) -> bytes:
        """
        Serialize as type, length, then at most ``length`` message bytes.
        """
        return self.type.to_bytes(1, "big") + \
            self.length.to_bytes(1, "big") + \
            self.message.encode("UTF-8")[:self.length]
class Packet:
    """
    A Packet is the wrapper around the TLVs exchanged in one UDP datagram.

    Wire format: a 1-byte magic (95), a 1-byte version (0), a 2-byte big-endian
    body length, then the body, which is a sequence of TLVs.
    """
    # Largest body that fits in the 1024-byte receive buffer after the 4-byte header.
    MAX_BODY_LENGTH = 1020

    magic: int
    version: int
    body_length: int
    body: List["TLV"]

    def validate_data(self) -> bool:
        """
        Ensure that the packet is well-formed.
        Raises a ValueError if the packet contains bad data.
        """
        if self.magic != 95:
            raise ValueError("The magic code of the packet must be 95, found: {:d}".format(self.magic))
        if self.version != 0:
            raise ValueError("The version of the packet is not supported: {:d}".format(self.version))
        if not (0 <= self.body_length <= self.MAX_BODY_LENGTH):
            raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1020, "
                             "found: {:d}".format(self.body_length))
        return all(tlv.validate_data() for tlv in self.body)

    @staticmethod
    def unmarshal(data: bytes) -> "Packet":
        """
        Read raw data and build the packet wrapper.

        Each TLV is parsed at its own offset in the body. At most ``body_length``
        bytes are parsed, and never more than were actually received.

        Raises a ValueError whenever the data is invalid: a header shorter than
        4 bytes, an unsupported TLV type, a TLV running past the end of the body, a
        malformed TLV, or any check done by ``validate_data``.
        """
        if len(data) < 4:
            raise ValueError(f"The packet is too short to contain a header, found {len(data)} bytes")
        pkt = Packet()
        pkt.magic = data[0]
        pkt.version = data[1]
        pkt.body_length = int.from_bytes(data[2:4], byteorder="big")
        pkt.body = []
        # Never read past the received bytes, even if the header announces more.
        body_end = 4 + min(len(data) - 4, pkt.body_length)
        offset = 4
        while offset < body_end:
            tlv_classes = TLV.tlv_classes()
            tlv_type = data[offset]
            if not (0 <= tlv_type < len(tlv_classes)):
                raise ValueError(f"TLV type is not supported: {tlv_type}")
            if tlv_type == 0:
                # Pad1TLV is a single byte and has no length field.
                tlv_size = 1
            else:
                if offset + 2 > body_end:
                    raise ValueError(f"TLV of type {tlv_type} is truncated: its length byte is missing")
                tlv_size = 2 + data[offset + 1]
                if offset + tlv_size > body_end:
                    raise ValueError(f"TLV of type {tlv_type} is truncated: it announces {tlv_size - 2} bytes "
                                     f"but only {body_end - offset - 2} remain in the body")
            tlv = tlv_classes[tlv_type]()
            try:
                tlv.unmarshal(data[offset:offset + tlv_size])
            except IndexError as error:
                # e.g. a TLV whose length is too small for its fixed fields.
                raise ValueError(f"TLV of type {tlv_type} is malformed: {error}") from error
            pkt.body.append(tlv)
            offset += tlv_size
        pkt.validate_data()
        return pkt

    def marshal(self) -> bytes:
        """
        Compute the byte array data associated to the packet.
        ``body_length`` is written as stored; it is not recomputed from ``body``.
        """
        data = self.magic.to_bytes(1, "big")
        data += self.version.to_bytes(1, "big")
        data += self.body_length.to_bytes(2, "big")
        data += b"".join(tlv.marshal() for tlv in self.body)
        return data
class Hazelnut:
    """
    A hazelnut is a remote client that packets can be sent to.

    Its address is always stored as an IPv6Address. A name that only resolves to
    IPv4 is stored as an IPv4-mapped IPv6 address (::ffff:a.b.c.d).
    """
    def __init__(self, nickname: str = "anonymous", address: str = "localhost", port: int = 2500):
        self.nickname = nickname
        self.address = IPv6Address(self._resolve_ipv6(address))
        self.port = port

    @staticmethod
    def _resolve_ipv6(host: str) -> str:
        """
        Resolve ``host`` to an IPv6 address string, falling back to IPv4 mapping.
        """
        try:
            ipv6_infos = socket.getaddrinfo(host, None, socket.AF_INET6)
            return ipv6_infos[0][4][0]
        except socket.gaierror:
            # No IPv6 resolution: resolve as IPv4 and build the IPv4-mapped IPv6 address.
            # See https://fr.wikipedia.org/wiki/Adresse_IPv6_mappant_IPv4
            ipv4 = socket.getaddrinfo(host, None, socket.AF_INET)[0][4][0]
            return "::ffff:" + ipv4
class Squirrel(Hazelnut):
    """
    The squirrel is the local user of the program. It owns the bound UDP socket
    and the curses pads, and speaks with the other clients, called hazelnuts.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # IPv6 UDP socket, bound to our own resolved address and port.
        self.socket = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
        self.socket.bind((str(self.address), self.port))
        self.history = [f"<system> Listening on {self.address}:{self.port}"]
        # History pad: LINES - 2 rows. Input pad: a single row.
        self.history_pad = curses.newpad(curses.LINES - 2, curses.COLS)
        self.input_pad = curses.newpad(1, curses.COLS)
        # Known neighbours, keyed by (address, port).
        self.hazelnuts = {}

    def find_hazelnut(self, address: str, port: int) -> Hazelnut:
        """
        Return the neighbour at (address, port), creating and storing it if it is not known yet.
        """
        key = (address, port)
        try:
            return self.hazelnuts[key]
        except KeyError:
            pass
        newcomer = Hazelnut(address=address, port=port)
        self.hazelnuts[key] = newcomer
        return newcomer

    def send_packet(self, client: Hazelnut, pkt: Packet) -> int:
        """
        Serialize ``pkt`` and send it to ``client``; return the number of bytes sent.
        """
        raw = pkt.marshal()
        return self.send_raw_data(client, raw)

    def send_raw_data(self, client: Hazelnut, data: bytes) -> int:
        """
        Send raw bytes to ``client``; return the number of bytes sent.
        """
        destination = (str(client.address), client.port)
        return self.socket.sendto(data, destination)

    def receive_packet(self) -> Tuple[Packet, Hazelnut]:
        """
        Block until a datagram arrives, parse it, and return it with its sender.
        Warning: this blocks, so it should run inside a dedicated thread.
        """
        raw, sender = self.receive_raw_data()
        packet = Packet.unmarshal(raw)
        return packet, self.find_hazelnut(sender[0], sender[1])

    def receive_raw_data(self) -> Tuple[bytes, Any]:
        """
        Block until a datagram of at most 1024 bytes arrives; return (data, sender address).
        """
        return self.socket.recvfrom(1024)

    def refresh_history(self) -> None:
        """
        Redraw the most recent history lines that fit in the history pad.
        """
        self.history_pad.erase()
        capacity = curses.LINES - 2
        first_shown = max(0, len(self.history) - capacity)
        for row, line in enumerate(self.history[first_shown:]):
            self.history_pad.addstr(row, 0, line)
        self.history_pad.refresh(0, 0, 0, 0, curses.LINES - 2, curses.COLS)

    def refresh_input(self) -> None:
        """
        Redraw the input prompt on the last screen line. Must not be called while a message is being typed.
        """
        prompt = f"<{self.nickname}> "
        last_row = curses.LINES - 1
        self.input_pad.erase()
        self.input_pad.addstr(0, 0, prompt)
        self.input_pad.refresh(0, 0, last_row, 0, last_row, curses.COLS - 1)
class Worm(Thread):
    """
    The worm is the hazel listener.
    It always waits for an incoming packet, handles it, and continues to wait.
    It runs in a dedicated daemon thread, so it does not keep the process alive
    once the main loop has exited.
    """
    def __init__(self, squirrel: Squirrel, *args, **kwargs):
        # run() never returns; a non-daemon thread would block interpreter exit.
        kwargs.setdefault("daemon", True)
        super().__init__(*args, **kwargs)
        self.squirrel = squirrel

    def run(self) -> None:
        """
        Receive packets forever. For each packet, append either its first Data
        TLV's message or the parsing error to the history, then redraw the history.
        """
        while True:
            try:
                pkt, _hazelnut = self.squirrel.receive_packet()
            except ValueError as error:
                self.squirrel.history.append("<system> An error occured while receiving a packet: {}".format(error))
            else:
                # A packet may hold no Data TLV (e.g. only padding or Hello TLVs).
                # Skip it instead of crashing the thread with IndexError/AttributeError.
                # The client sends one Data TLV per packet, so only the first one is shown.
                data_tlv = next((tlv for tlv in pkt.body if isinstance(tlv, DataTLV)), None)
                if data_tlv is not None:
                    # Replace undecodable bytes rather than letting UnicodeDecodeError kill the listener.
                    self.squirrel.history.append(data_tlv.data.decode("UTF-8", errors="replace"))
            self.squirrel.refresh_history()