371 lines
12 KiB
Python
371 lines
12 KiB
Python
# Copyright (C) 2020 by eichhornchen, ÿnérant
|
|
# SPDX-License-Identifier: GPL-3.0-or-later
|
|
|
|
import socket
from argparse import ArgumentParser
from enum import Enum
from typing import Any, Optional, Sequence, Tuple
|
|
|
|
|
|
class Squinnondation:
    """
    Command-line front end: reads the arguments, then runs the chat client.
    """
    args: Any
    bind_address: str
    bind_port: int
    client_address: str
    client_port: int

    def parse_arguments(self, argv: Optional[Sequence[str]] = None) -> None:
        """
        Parse the command-line arguments and store the addresses and ports to use.

        :param argv: The arguments to parse. When None (the default), ArgumentParser reads sys.argv[1:].
        :raises ValueError: If the bind port or the neighbour port is not between 1024 and 65535.
        """
        parser = ArgumentParser(description="MIRC client.")
        # nargs="?" makes the positional arguments optional; without it their defaults could never apply.
        parser.add_argument('bind_address', type=str, nargs="?", default="localhost",
                            help="Address of the client.")
        parser.add_argument('bind_port', type=int, nargs="?", default=2500,
                            help="Port of the client. Must be between 1024 and 65535.")
        parser.add_argument('--client_address', type=str, default="localhost",
                            help="Address of the first neighbour.")
        parser.add_argument('--client_port', type=int, default=2500,
                            help="Port of the first neighbour. Must be between 1024 and 65535.")
        parser.add_argument('--bind-only', '-b', action='store_true',
                            help="Don't connect to another client, only listen to connections.")
        self.args = parser.parse_args(argv)

        # The negation must cover both range checks, so that either invalid port is rejected.
        if not (1024 <= self.args.bind_port <= 65535 and 1024 <= self.args.client_port <= 65535):
            raise ValueError("Ports must be between 1024 and 65535.")

        self.bind_address = self.args.bind_address
        self.bind_port = self.args.bind_port
        self.client_address = self.args.client_address
        self.client_port = self.args.client_port

    @staticmethod
    def main() -> None:  # pragma: no cover
        """
        Run the client.

        Unless --bind-only is given, send a greeting Data TLV to the first neighbour.
        Then print every Data message received.
        """
        instance = Squinnondation()
        instance.parse_arguments()

        squirrel = Squirrel(input("Enter your nickname: "), instance.bind_address, instance.bind_port)

        if not instance.args.bind_only:
            hazelnut = Hazelnut(address=instance.client_address, port=instance.client_port)
            msg = f"Hello world, my name is {squirrel.nickname}!"
            # Lengths are counted in encoded bytes: non-ASCII characters take several bytes.
            data = msg.encode("UTF-8")

            pkt = Packet()
            pkt.magic = 95
            pkt.version = 0
            pkt.body = DataTLV()
            pkt.body.data = data
            pkt.body.sender_id = 42
            pkt.body.nonce = 18
            # The TLV length covers its body only: sender_id (8 bytes) + nonce (4 bytes) + data.
            # It is serialized on a single byte, so data must not exceed 255 - 12 = 243 bytes.
            pkt.body.length = 8 + 4 + len(data)
            # The packet body is the whole TLV: type byte + length byte + TLV body.
            pkt.body_length = pkt.body.length + 2
            squirrel.send_packet(hazelnut, pkt)

        while True:
            try:
                pkt, _addr = squirrel.receive_packet()
            except ValueError as error:
                # A malformed packet from a peer must not stop the client.
                print(f"Ignoring an invalid packet: {error}")
                continue
            # Only Data TLVs carry a printable message.
            if isinstance(pkt.body, DataTLV):
                print(f"received message: {pkt.body.data.decode('UTF-8', errors='replace')}")
|
|
|
|
|
|
class TLV:
    """
    Common base of every Tag-Length-Value record that a packet body can carry.

    Each concrete subclass sets its own ``type`` code and implements the (de)serialization.
    """
    type: int

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Fill this TLV from its raw byte representation.

        Concrete subclasses must override this method.
        """
        raise NotImplementedError

    def marshal(self) -> bytes:
        """
        Serialize this TLV into its byte representation.

        Concrete subclasses must override this method.
        """
        raise NotImplementedError

    def validate_data(self) -> bool:
        """
        Check that the TLV is well-formed.

        Subclasses raise a ValueError when it is not; this base version accepts everything.
        TODO: Make some tests
        """
        is_valid = True
        return is_valid

    @staticmethod
    def tlv_classes():
        """
        Return the TLV classes, positioned so that the index of each class is its type code.
        """
        classes = [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV]
        return classes
|
|
|
|
|
|
class Pad1TLV(TLV):
    """
    One-byte padding TLV, made of its type byte alone; it is ignored on reception.
    """
    type: int = 0

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Read the type byte only: this TLV carries nothing else and is ignored.
        """
        type_byte = raw_data[0]
        self.type = type_byte

    def marshal(self) -> bytes:
        """
        Serialize the TLV as a single byte holding its type.
        """
        encoded = self.type.to_bytes(1, byteorder="big")
        return encoded
|
|
|
|
|
|
class PadNTLV(TLV):
    """
    Variable-size padding TLV whose body is ``length`` zero bytes; it is ignored on reception.
    """
    type: int = 1
    length: int
    # "Must be zero": the padding bytes of the body.
    mbz: bytes

    def validate_data(self) -> bool:
        """
        Raise a ValueError unless the padding is exactly ``length`` zero bytes.
        """
        expected_padding = (0).to_bytes(self.length, "big")
        if self.mbz == expected_padding:
            return True
        raise ValueError("The body of a PadN TLV is not filled with zeros.")

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Keep the padding bytes so that they can be validated; the TLV itself is ignored.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        body_end = 2 + self.length
        self.mbz = raw_data[2:body_end]

    def marshal(self) -> bytes:
        """
        Serialize the header followed by at most ``length`` bytes of padding.
        """
        header = self.type.to_bytes(1, "big") + self.length.to_bytes(1, "big")
        padding = self.mbz[:self.length]
        return header + padding
|
|
|
|
|
|
class HelloTLV(TLV):
    """
    Hello TLV.

    A short Hello (length 8) carries the source id only. A long Hello (length 16)
    also carries a destination id. Both ids are 8 bytes, big-endian.
    """
    type: int = 2
    length: int
    source_id: int
    # None for a short Hello; a real default so that marshal() never hits a missing attribute.
    dest_id: Optional[int] = None

    def validate_data(self) -> bool:
        """
        Raise a ValueError unless the length is 8 (short Hello) or 16 (long Hello).
        """
        if self.length not in (8, 16):
            raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello, "
                             f"found {self.length}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Parse the header, the source id and, for a long Hello, the destination id.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.source_id = int.from_bytes(raw_data[2:10], "big")
        # Set explicitly in both cases so that a short Hello never keeps a stale destination id.
        self.dest_id = int.from_bytes(raw_data[10:18], "big") if self.length == 16 else None

    def marshal(self) -> bytes:
        """
        Serialize the TLV; the destination id is appended only when it is set.
        """
        data = self.type.to_bytes(1, "big") + self.length.to_bytes(1, "big") + self.source_id.to_bytes(8, "big")
        # Compare against None: 0 is a valid identifier and must still be serialized.
        if self.dest_id is not None:
            data += self.dest_id.to_bytes(8, "big")
        return data
|
|
|
|
|
|
class NeighbourTLV(TLV):
    """
    Neighbour TLV: a 16-byte IP address followed by a 2-byte port, both big-endian.
    """
    type: int = 3
    length: int
    # The 16-byte address stored as an integer.
    ip_address: int
    port: int

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Parse the header, the 16-byte address and the 2-byte port.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        # Decode to an int, matching the annotation and what marshal() serializes.
        self.ip_address = int.from_bytes(raw_data[2:18], "big")
        self.port = int.from_bytes(raw_data[18:20], "big")

    def marshal(self) -> bytes:
        """
        Serialize the header, the address on 16 bytes and the port on 2 bytes.
        """
        ip_address = self.ip_address
        # Also accept an address given as raw bytes (big-endian).
        if isinstance(ip_address, (bytes, bytearray)):
            ip_address = int.from_bytes(ip_address, "big")
        return self.type.to_bytes(1, "big") + \
            self.length.to_bytes(1, "big") + \
            ip_address.to_bytes(16, "big") + \
            self.port.to_bytes(2, "big")
|
|
|
|
|
|
class DataTLV(TLV):
    """
    Data TLV: a sender id (8 bytes), a nonce (4 bytes) and the message bytes.

    ``length`` counts the sender id, the nonce and the data, but not the two header bytes.
    """
    type: int = 4
    length: int
    sender_id: int
    nonce: int
    data: bytes

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Parse the header, the sender id, the nonce and the data that ends the TLV body.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        body_end = 2 + self.length
        self.sender_id = int.from_bytes(raw_data[2:10], "big")
        self.nonce = int.from_bytes(raw_data[10:14], "big")
        self.data = raw_data[14:body_end]

    def marshal(self) -> bytes:
        """
        Serialize the header, the sender id, the nonce and the data, in this order.
        """
        fields = (
            self.type.to_bytes(1, "big"),
            self.length.to_bytes(1, "big"),
            self.sender_id.to_bytes(8, "big"),
            self.nonce.to_bytes(4, "big"),
            self.data,
        )
        return b"".join(fields)
|
|
|
|
|
|
class AckTLV(TLV):
    """
    Ack TLV, carrying a sender id (8 bytes) and a nonce (4 bytes), both big-endian.
    """
    type: int = 5
    length: int
    sender_id: int
    nonce: int

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Parse the header, the sender id and the nonce.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        sender_bytes = raw_data[2:10]
        self.sender_id = int.from_bytes(sender_bytes, "big")
        nonce_bytes = raw_data[10:14]
        self.nonce = int.from_bytes(nonce_bytes, "big")

    def marshal(self) -> bytes:
        """
        Serialize the header, the sender id and the nonce.
        """
        parts = [
            self.type.to_bytes(1, "big"),
            self.length.to_bytes(1, "big"),
            self.sender_id.to_bytes(8, "big"),
            self.nonce.to_bytes(4, "big"),
        ]
        return b"".join(parts)
|
|
|
|
|
|
class GoAwayTLV(TLV):
    """
    GoAway TLV: a 1-byte reason code (see GoAwayType) followed by a UTF-8 message.

    ``length`` counts the code byte and the encoded message.
    """
    class GoAwayType(Enum):
        UNKNOWN = 0
        EXIT = 1
        TIMEOUT = 2
        PROTOCOL_VIOLATION = 3

    type: int = 6
    length: int
    code: GoAwayType
    message: str

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Parse the header, the reason code and the message.

        Raises a ValueError for an unknown code or a message that is not valid UTF-8.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.code = GoAwayTLV.GoAwayType(raw_data[2])
        # The TLV body spans raw_data[2:2 + length]; the message follows the 1-byte code.
        self.message = raw_data[3:2 + self.length].decode("UTF-8")

    def marshal(self) -> bytes:
        """
        Serialize the header, the code and at most ``length - 1`` bytes of the encoded message.
        """
        return self.type.to_bytes(1, "big") + \
            self.length.to_bytes(1, "big") + \
            self.code.value.to_bytes(1, "big") + \
            self.message.encode("UTF-8")[:self.length - 1]
|
|
|
|
|
|
class WarningTLV(TLV):
    """
    Warning TLV: its body is a UTF-8 message of ``length`` bytes.
    """
    type: int = 7
    length: int
    message: str

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Parse the header and the message.

        Raises a ValueError (UnicodeDecodeError) when the message is not valid UTF-8.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        # The body starts after the 2 header bytes, so it ends at 2 + length.
        self.message = raw_data[2:2 + self.length].decode("UTF-8")

    def marshal(self) -> bytes:
        """
        Serialize the header and at most ``length`` bytes of the encoded message.
        """
        return self.type.to_bytes(1, "big") + \
            self.length.to_bytes(1, "big") + \
            self.message.encode("UTF-8")[:self.length]
|
|
|
|
|
|
class Packet:
    """
    A Packet is a wrapper around a single TLV.

    It has a 4-byte header (magic byte, version byte, 2-byte big-endian body length)
    followed by the TLV itself.
    """
    # Expected magic byte of every packet.
    MAGIC = 95
    # Only protocol version supported.
    VERSION = 0
    # 1024-byte receive buffer minus the 4-byte header.
    MAX_BODY_LENGTH = 1020

    magic: int
    version: int
    body_length: int
    # String annotation: a forward reference that does not depend on definition order.
    body: "TLV"

    def validate_data(self) -> bool:
        """
        Ensure that the packet is well-formed.

        Raises a ValueError if the packet contains bad data.
        """
        if self.magic != Packet.MAGIC:
            raise ValueError(f"The magic code of the packet must be {Packet.MAGIC}, found: {self.magic:d}")
        if self.version != Packet.VERSION:
            raise ValueError(f"The version of the packet is not supported: {self.version:d}")
        if not (0 <= self.body_length <= Packet.MAX_BODY_LENGTH):
            raise ValueError("The body length of the packet is negative or too high. "
                             f"It must be between 0 and {Packet.MAX_BODY_LENGTH}, found: {self.body_length:d}")
        return self.body.validate_data()

    @staticmethod
    def unmarshal(data: bytes) -> "Packet":
        """
        Read raw data and build the packet wrapper.

        Raises a ValueError whenever the data is invalid, including when it is too short.
        """
        # 4-byte header plus at least the TLV type byte.
        if len(data) < 5:
            raise ValueError(f"The packet is too short to contain a TLV: {len(data)} bytes")

        pkt = Packet()
        pkt.magic = data[0]
        pkt.version = data[1]
        pkt.body_length = int.from_bytes(data[2:4], byteorder="big")

        tlv_classes = TLV.tlv_classes()
        tlv_type = data[4]
        # A byte is never negative, so only the upper bound needs checking.
        if tlv_type >= len(tlv_classes):
            raise ValueError(f"TLV type is not supported: {tlv_type}")
        pkt.body = tlv_classes[tlv_type]()
        try:
            pkt.body.unmarshal(data[4:4 + pkt.body_length])
        except IndexError as error:
            # A truncated TLV is invalid data: report it as documented, not as an IndexError.
            raise ValueError(f"The TLV of type {tlv_type} is truncated.") from error

        pkt.validate_data()

        return pkt

    def marshal(self) -> bytes:
        """
        Compute the byte array data associated to the packet.
        """
        data = self.magic.to_bytes(1, "big")
        data += self.version.to_bytes(1, "big")
        data += self.body_length.to_bytes(2, "big")
        data += self.body.marshal()
        return data
|
|
|
|
|
|
class Hazelnut:
    """
    A hazelnut is another client of the network, known by its nickname and by the
    address and port it can be reached at.
    """

    def __init__(self, nickname: str = "anonymous", address: str = "localhost", port: int = 2500):
        self.nickname, self.address, self.port = nickname, address, port
|
|
|
|
|
|
class Squirrel(Hazelnut):
    """
    The squirrel is the user of the program. It can speak with other clients, that are called hazelnuts.

    It owns a UDP socket bound to its own address and port. Call close() when done,
    or use it as a context manager.
    """
    # Maximum number of bytes read by a single recvfrom() call.
    RECEIVE_BUFFER_SIZE = 1024

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            self.socket.bind((self.address, self.port))
        except Exception:
            # Do not leak the socket when the address cannot be bound.
            self.socket.close()
            raise
        print(f"Listening on {self.address}:{self.port}")

    def close(self) -> None:
        """
        Close the underlying socket.
        """
        self.socket.close()

    def __enter__(self) -> "Squirrel":
        return self

    def __exit__(self, *exc_info) -> None:
        self.close()

    def send_packet(self, client: Hazelnut, pkt: Packet) -> int:
        """
        Send a formatted packet to a client.

        Returns the number of bytes sent.
        """
        return self.send_raw_data(client, pkt.marshal())

    def send_raw_data(self, client: Hazelnut, data: bytes) -> int:
        """
        Send a raw packet to a client.

        Returns the number of bytes sent.
        """
        return self.socket.sendto(data, (client.address, client.port))

    def receive_packet(self) -> Tuple[Packet, Any]:
        """
        Receive a packet from the socket and translate it into a Python object.

        Raises a ValueError if the received data is not a valid packet.
        TODO: Translate the address into the correct hazelnut.
        """
        data, addr = self.receive_raw_data()
        return Packet.unmarshal(data), addr

    def receive_raw_data(self) -> Tuple[bytes, Any]:
        """
        Receive a packet from the socket, as raw bytes along with the sender address.

        TODO: Translate the address into the correct hazelnut.
        """
        return self.socket.recvfrom(Squirrel.RECEIVE_BUFFER_SIZE)
|