squinnondation/squinnondation/squinnondation.py

594 lines
22 KiB
Python

# Copyright (C) 2020 by eichhornchen, ÿnérant
# SPDX-License-Identifier: GPL-3.0-or-later
import curses
import re
import socket
from argparse import ArgumentParser
from enum import Enum
from ipaddress import IPv6Address
from threading import Thread
from typing import Any, List, Optional, Tuple
import emoji
from squinnondation.term_manager import TermManager
class Squinnondation:
    """
    Configuration of the chat client and entry point of the program.

    It parses the command line, then runs the curses user interface loop.
    """
    args: Any
    bind_address: str
    bind_port: int
    no_emoji: bool
    no_markdown: bool
    screen: Any

    def parse_arguments(self, argv: Optional[List[str]] = None) -> None:
        """
        Parse the command line and store the options on the instance.

        :param argv: the arguments to parse; when None, argparse reads ``sys.argv[1:]``.
        :raises ValueError: if the bind port, or a given client port, is not between 1024 and 65535.
        """
        parser = ArgumentParser(description="MIRC client.")
        # argparse only uses the default of a positional argument when nargs allows it to be omitted.
        parser.add_argument('bind_address', type=str, nargs='?', default="localhost",
                            help="Address of the client.")
        parser.add_argument('bind_port', type=int, nargs='?', default=2500,
                            help="Port of the client. Must be between 1024 and 65535.")
        parser.add_argument('--client_address', type=str, default=None,
                            help="Address of the first neighbour.")
        parser.add_argument('--client_port', type=int, default=0,
                            help="Port of the first neighbour. Must be between 1024 and 65535.")
        parser.add_argument('--no-emoji', '-ne', action='store_true',
                            help="Don't replace emojis.")
        parser.add_argument('--no-markdown', '-nm', action='store_true',
                            help="Don't parse markdown.")
        self.args = parser.parse_args(argv)

        # A client port of 0 means "no first neighbour" and is not range-checked.
        if not (1024 <= self.args.bind_port <= 65535) or \
                not (not self.args.client_port or 1024 <= self.args.client_port <= 65535):
            raise ValueError("Ports must be between 1024 and 65535.")

        self.bind_address = self.args.bind_address
        self.bind_port = self.args.bind_port
        self.no_emoji = self.args.no_emoji
        self.no_markdown = self.args.no_markdown

    @staticmethod
    def main() -> None:  # pragma: no cover
        """
        Run the client.

        Ask for a nickname, start the listening thread, then loop on key presses.
        Each validated line is broadcast to every known neighbour in a Data TLV.
        """
        instance = Squinnondation()
        instance.parse_arguments()

        with TermManager() as term_manager:
            screen = term_manager.screen
            instance.screen = screen
            screen.addstr(0, 0, "Enter your nickname: ")
            curses.echo()
            nickname = screen.getstr().decode("UTF-8")
            curses.noecho()

            squirrel = Squirrel(instance, nickname)
            squirrel.refresh_history()
            squirrel.refresh_input()

            if instance.args.client_address and instance.args.client_port:
                hazelnut = Hazelnut(address=instance.args.client_address, port=instance.args.client_port)
                squirrel.hazelnuts[(instance.args.client_address, instance.args.client_port)] = hazelnut

            Worm(squirrel).start()

            # The TLV length byte counts sender_id (8 bytes) + nonce (4 bytes) + data, so it caps the data size.
            max_data_length = 255 - 8 - 4

            while True:
                squirrel.refresh_history()
                squirrel.refresh_input()
                # getkey raises curses.error if the cursor is moved outside the screen.
                cursor_x = min(3 + len(squirrel.nickname) + len(squirrel.input_buffer), curses.COLS - 1)
                key = screen.getkey(curses.LINES - 1, cursor_x)
                if key in ("\x7f", "\b", "KEY_BACKSPACE"):  # backspace, depending on the terminal
                    squirrel.input_buffer = squirrel.input_buffer[:-1]
                    continue
                elif len(key) > 1:
                    squirrel.history.append(f"<system> *unmanaged key press: {key}*")
                    continue
                elif key != "\n":
                    squirrel.input_buffer += key
                    continue

                msg = squirrel.input_buffer
                squirrel.input_buffer = ""
                if not msg:
                    continue

                msg = f"<{squirrel.nickname}> {msg}"
                data = msg.encode("UTF-8")
                if len(data) > max_data_length:
                    squirrel.history.append(f"<system> *Message too long: {len(data)} bytes, "
                                            f"the maximum is {max_data_length}*")
                    continue
                squirrel.history.append(msg)

                # Iterate over a copy: the listening thread may register new hazelnuts meanwhile.
                for hazelnut in list(squirrel.hazelnuts.values()):
                    pkt = Packet()
                    pkt.magic = 95
                    pkt.version = 0
                    tlv = DataTLV()
                    tlv.data = data
                    # NOTE(review): fixed placeholder sender id and nonce — presumably to be replaced.
                    tlv.sender_id = 42
                    tlv.nonce = 18
                    # The length field covers the payload only, not the type and length bytes.
                    tlv.length = len(tlv.data) + 8 + 4
                    pkt.body = [tlv]
                    # The body contains the TLV payload plus its type and length bytes.
                    pkt.body_length = tlv.length + 2
                    squirrel.send_packet(hazelnut, pkt)
class TLV:
    """
    Base class of the Tag-Length-Value records carried in a packet body.

    Each subclass sets its own ``type`` tag and implements ``unmarshal`` and ``marshal``.
    TODO: add subclasses for each type of TLV
    """
    type: int  # one-byte tag identifying the kind of TLV

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Fill the attributes of this TLV from raw bytes that start at its type byte.
        Must be overridden by subclasses.
        """
        raise NotImplementedError()

    def marshal(self) -> bytes:
        """
        Serialize this TLV to bytes, type byte first.
        Must be overridden by subclasses.
        """
        raise NotImplementedError()

    def validate_data(self) -> bool:
        """
        Check that the TLV is well-formed.
        Subclasses raise a ValueError on malformed data; the base implementation accepts everything.
        TODO: Make some tests
        """
        return True

    @staticmethod
    def tlv_classes():
        """
        Return the list of TLV classes, positioned so that index i holds the class whose type tag is i.
        """
        classes = [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV]
        return classes
class Pad1TLV(TLV):
    """
    One-byte padding: a type byte with no length and no payload. It is ignored.
    """
    type: int = 0

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Read the type byte; a Pad1 TLV has nothing else to parse.
        """
        first_byte = raw_data[0]
        self.type = first_byte

    def marshal(self) -> bytes:
        """
        Serialize the TLV as its single type byte.
        """
        return int.to_bytes(self.type, 1, "big")
class PadNTLV(TLV):
    """
    Variable-length padding: a type byte, a length byte, then ``length`` zero bytes. It is ignored.
    """
    type: int = 1
    length: int
    mbz: bytes  # payload that must be zero

    def validate_data(self) -> bool:
        """
        Check that the payload is made of ``length`` zero bytes.

        :raises ValueError: if it is not the case.
        """
        zeros = bytes(self.length)
        if self.mbz == zeros:
            return True
        raise ValueError("The body of a PadN TLV is not filled with zeros.")

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Read the header and keep the zero payload, which is otherwise ignored.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        payload_end = 2 + self.length
        self.mbz = raw_data[2:payload_end]

    def marshal(self) -> bytes:
        """
        Serialize the header followed by the payload truncated to ``length`` bytes.
        """
        header = self.type.to_bytes(1, "big") + self.length.to_bytes(1, "big")
        return header + self.mbz[:self.length]
class HelloTLV(TLV):
    """
    Hello TLV: the source id (8 bytes), followed in a long hello by a destination id (8 bytes).

    A short hello has length 8, a long hello has length 16.
    """
    type: int = 2
    length: int
    source_id: int
    dest_id: Optional[int] = None  # only set for a long hello

    def validate_data(self) -> bool:
        """
        :raises ValueError: if the length is neither 8 (short hello) nor 16 (long hello).
        """
        if self.length != 8 and self.length != 16:
            raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello, "
                             f"found {self.length}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Read the header, the source id and, for a long hello, the destination id.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.source_id = int.from_bytes(raw_data[2:10], "big")
        # Always assign dest_id so that marshal never reads a missing or stale attribute.
        self.dest_id = int.from_bytes(raw_data[10:18], "big") if self.length == 16 else None

    def marshal(self) -> bytes:
        """
        Serialize the TLV; the destination id is appended only when it is set.
        """
        data = self.type.to_bytes(1, "big") + self.length.to_bytes(1, "big") + self.source_id.to_bytes(8, "big")
        # Compare to None: 0 is a valid destination id.
        if self.dest_id is not None:
            data += self.dest_id.to_bytes(8, "big")
        return data
class NeighbourTLV(TLV):
    """
    Neighbour TLV: an IPv6 address (16 bytes) followed by a port (2 bytes).
    """
    type: int = 3
    length: int
    ip_address: IPv6Address
    port: int

    def validate_data(self) -> bool:
        """
        :raises ValueError: if the payload is not exactly an address and a port (18 bytes).
        """
        if self.length != 18:
            raise ValueError(f"The length of a Neighbour TLV must be 18, found {self.length}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Read the header, the packed IPv6 address and the big-endian port.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.ip_address = IPv6Address(raw_data[2:18])
        self.port = int.from_bytes(raw_data[18:20], "big")

    def marshal(self) -> bytes:
        """
        Serialize the header, the packed address and the port.
        """
        return self.type.to_bytes(1, "big") + \
            self.length.to_bytes(1, "big") + \
            self.ip_address.packed + \
            self.port.to_bytes(2, "big")
class DataTLV(TLV):
    """
    Data TLV: a sender id (8 bytes), a nonce (4 bytes), then the message bytes.

    The length field counts those 12 bytes plus the message, not the type and length bytes.
    """
    type: int = 4
    length: int
    sender_id: int
    nonce: int
    data: bytes

    def validate_data(self) -> bool:
        """
        :raises ValueError: if the payload is too short to hold the sender id and the nonce.
        """
        if self.length < 12:
            raise ValueError(f"The length of a Data TLV must be at least 12, found {self.length}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Read the header, the sender id, the nonce and the message bytes.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.sender_id = int.from_bytes(raw_data[2:10], "big")
        self.nonce = int.from_bytes(raw_data[10:14], "big")
        # The message ends where the payload ends: 2 header bytes + length.
        self.data = raw_data[14:2 + self.length]

    def marshal(self) -> bytes:
        """
        Serialize the header, the sender id, the nonce and the message bytes.
        """
        return self.type.to_bytes(1, "big") + \
            self.length.to_bytes(1, "big") + \
            self.sender_id.to_bytes(8, "big") + \
            self.nonce.to_bytes(4, "big") + \
            self.data
class AckTLV(TLV):
    """
    Ack TLV: a sender id (8 bytes) and a nonce (4 bytes).

    NOTE(review): presumably acknowledges the Data TLV carrying the same sender id and nonce — confirm.
    """
    type: int = 5
    length: int
    sender_id: int
    nonce: int

    def validate_data(self) -> bool:
        """
        :raises ValueError: if the payload is not exactly a sender id and a nonce (12 bytes).
        """
        if self.length != 12:
            raise ValueError(f"The length of an Ack TLV must be 12, found {self.length}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Read the header, the sender id and the nonce.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.sender_id = int.from_bytes(raw_data[2:10], "big")
        self.nonce = int.from_bytes(raw_data[10:14], "big")

    def marshal(self) -> bytes:
        """
        Serialize the header, the sender id and the nonce.
        """
        return self.type.to_bytes(1, "big") + \
            self.length.to_bytes(1, "big") + \
            self.sender_id.to_bytes(8, "big") + \
            self.nonce.to_bytes(4, "big")
class GoAwayTLV(TLV):
    """
    GoAway TLV: a one-byte reason code followed by a UTF-8 message.

    The length field counts the code byte plus the message bytes.
    """
    class GoAwayType(Enum):
        UNKNOWN = 0
        EXIT = 1
        TIMEOUT = 2
        PROTOCOL_VIOLATION = 3

    type: int = 6
    length: int
    code: GoAwayType
    message: str

    def validate_data(self) -> bool:
        """
        :raises ValueError: if the payload does not even contain the code byte.
        """
        if self.length < 1:
            raise ValueError(f"The length of a GoAway TLV must be at least 1, found {self.length}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Read the header, the reason code and the message.

        :raises ValueError: if the code is unknown or the message is not valid UTF-8.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.code = GoAwayTLV.GoAwayType(raw_data[2])
        # The message is the length - 1 bytes following the code byte, i.e. up to offset 2 + length.
        self.message = raw_data[3:2 + self.length].decode("UTF-8")

    def marshal(self) -> bytes:
        """
        Serialize the header, the code and the message truncated to length - 1 bytes.
        """
        return self.type.to_bytes(1, "big") + \
            self.length.to_bytes(1, "big") + \
            self.code.value.to_bytes(1, "big") + \
            self.message.encode("UTF-8")[:self.length - 1]
class WarningTLV(TLV):
    """
    Warning TLV: a UTF-8 message filling the whole payload.
    """
    type: int = 7
    length: int
    message: str

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Read the header and the message.

        :raises ValueError: if the message is not valid UTF-8.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        # The payload starts after the 2 header bytes and is length bytes long.
        self.message = raw_data[2:2 + self.length].decode("UTF-8")

    def marshal(self) -> bytes:
        """
        Serialize the header and the message truncated to length bytes.
        """
        return self.type.to_bytes(1, "big") + \
            self.length.to_bytes(1, "big") + \
            self.message.encode("UTF-8")[:self.length]
class Packet:
    """
    A Packet is a wrapper around the list of TLVs sent in one datagram.

    Wire format: magic (1 byte), version (1 byte), body length (2 bytes, big-endian), then the TLVs.
    """
    magic: int
    version: int
    body_length: int
    body: List[TLV]

    def validate_data(self) -> bool:
        """
        Ensure that the packet is well-formed.
        Raises a ValueError if the packet contains bad data.
        """
        if self.magic != 95:
            raise ValueError("The magic code of the packet must be 95, found: {:d}".format(self.magic))
        if self.version != 0:
            raise ValueError("The version of the packet is not supported: {:d}".format(self.version))
        # Datagrams are read 1024 bytes at a time, 4 of which are the header.
        if not (0 <= self.body_length <= 1020):
            raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1020, "
                             "found: {:d}".format(self.body_length))
        return all(tlv.validate_data() for tlv in self.body)

    @staticmethod
    def unmarshal(data: bytes) -> "Packet":
        """
        Read raw data and build the packet wrapper.
        Raises a ValueError whenever the data is invalid.
        """
        if len(data) < 4:
            raise ValueError(f"The packet is too short to contain a header: {len(data)} bytes")

        pkt = Packet()
        pkt.magic = data[0]
        pkt.version = data[1]
        pkt.body_length = int.from_bytes(data[2:4], byteorder="big")
        pkt.body = []

        tlv_classes = TLV.tlv_classes()
        # Bytes after the announced body are ignored; a body announced longer than the datagram is clamped.
        end = 4 + min(pkt.body_length, len(data) - 4)
        offset = 4
        while offset < end:
            tlv_type = data[offset]
            if not (0 <= tlv_type < len(tlv_classes)):
                raise ValueError(f"TLV type is not supported: {tlv_type}")

            if tlv_type == 0:
                # Pad1TLV has no length: it is a single byte.
                tlv_size = 1
            else:
                if offset + 1 >= end:
                    raise ValueError(f"The length of the TLV of type {tlv_type} is missing")
                tlv_size = data[offset + 1] + 2
                if offset + tlv_size > end:
                    raise ValueError(f"The TLV of type {tlv_type} exceeds the body of the packet")

            tlv = tlv_classes[tlv_type]()
            try:
                tlv.unmarshal(data[offset:offset + tlv_size])
            except IndexError as error:
                # A payload too short for its fields is malformed data, reported like any other.
                raise ValueError(f"The TLV of type {tlv_type} is truncated") from error
            pkt.body.append(tlv)
            offset += tlv_size

        pkt.validate_data()
        return pkt

    def marshal(self) -> bytes:
        """
        Compute the byte array data associated to the packet.
        """
        data = self.magic.to_bytes(1, "big")
        data += self.version.to_bytes(1, "big")
        data += self.body_length.to_bytes(2, "big")
        data += b"".join(tlv.marshal() for tlv in self.body)
        return data
class Hazelnut:
    """
    A hazelnut is a remote client, identified by its nickname, IPv6 address and port.
    """
    def __init__(self, nickname: str = "anonymous", address: str = "localhost", port: int = 2500):
        """
        Resolve the address to an IPv6 address.

        If it has no IPv6 resolution, the IPv4 resolution is used through an IPv4-mapped IPv6 address
        (::ffff:a.b.c.d).
        """
        self.nickname = nickname
        try:
            ipv6_infos = socket.getaddrinfo(address, None, socket.AF_INET6)
            resolved = ipv6_infos[0][4][0]
        except socket.gaierror:
            # See https://fr.wikipedia.org/wiki/Adresse_IPv6_mappant_IPv4
            ipv4_infos = socket.getaddrinfo(address, None, socket.AF_INET)
            resolved = "::ffff:" + ipv4_infos[0][4][0]
        self.address = IPv6Address(resolved)
        self.port = port
class Squirrel(Hazelnut):
    """
    The squirrel is the user of the program. It can speak with other clients, that are called hazelnuts.

    It owns the bound UDP socket, the message history and the two curses pads drawing the history and
    the input line.
    """
    # Inline markdown rules, tried in this order by print_markdown.
    # Each regex splits the text into (before, enclosed text, after); the listed flags are toggled for the
    # enclosed text only.
    MARKDOWN_RULES = [
        ("(.*)__(.*)__(.*)", ("underline",)),
        ("(.*)_(.*)_(.*)", ("italic",)),
        ("(.*)\\*\\*\\*(.*)\\*\\*\\*(.*)", ("bold", "italic")),
        ("(.*)\\*\\*(.*)\\*\\*(.*)", ("bold",)),
        ("(.*)\\*(.*)\\*(.*)", ("italic",)),
        ("(.*)~~(.*)~~(.*)", ("strike",)),
    ]

    def __init__(self, instance: Squinnondation, nickname: str):
        """
        Bind the UDP socket on the configured address and port, and prepare the curses pads and colors.

        :param instance: the parsed configuration (bind address and port, emoji/markdown switches, screen).
        :param nickname: the nickname of the user.
        """
        super().__init__(nickname, instance.bind_address, instance.bind_port)

        # Create the UDP socket and bind it to the resolved IPv6 address.
        self.socket = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
        self.socket.bind((str(self.address), self.port))

        self.squinnondation = instance
        self.input_buffer = ""
        self.history: List[str] = []
        self.history_pad = curses.newpad(curses.LINES - 2, curses.COLS)
        self.input_pad = curses.newpad(1, curses.COLS)

        # init_color raises curses.error on terminals that cannot redefine colors.
        if curses.can_change_color():
            curses.init_color(curses.COLOR_WHITE, 1000, 1000, 1000)
        # One pair (numbered color + 1) per color strictly between black and white, on black: nickname colors.
        for color in range(curses.COLOR_BLACK + 1, curses.COLOR_WHITE):
            curses.init_pair(color + 1, color, curses.COLOR_BLACK)

        # Known neighbours, keyed by (address, port).
        self.hazelnuts = {}
        self.history.append(f"<system> *Listening on {self.address}:{self.port}*")

    def find_hazelnut(self, address: str, port: int) -> Hazelnut:
        """
        Translate an address into a hazelnut, and store it in the list of the hazelnuts, ie. the neighbours.
        """
        if (address, port) in self.hazelnuts:
            return self.hazelnuts[(address, port)]
        hazelnut = Hazelnut(address=address, port=port)
        self.hazelnuts[(address, port)] = hazelnut
        return hazelnut

    def send_packet(self, client: Hazelnut, pkt: Packet) -> int:
        """
        Send a formatted packet to a client. Return the number of bytes sent.
        """
        return self.send_raw_data(client, pkt.marshal())

    def send_raw_data(self, client: Hazelnut, data: bytes) -> int:
        """
        Send a raw packet to a client. Return the number of bytes sent.
        """
        return self.socket.sendto(data, (str(client.address), client.port))

    def receive_packet(self) -> Tuple[Packet, Hazelnut]:
        """
        Receive a packet from the socket and translate it into a Python object.
        Warning: the process is blocking, it should be ran inside a dedicated thread.

        :raises ValueError: if the received data is not a valid packet.
        """
        data, addr = self.receive_raw_data()
        return Packet.unmarshal(data), self.find_hazelnut(addr[0], addr[1])

    def receive_raw_data(self) -> Tuple[bytes, Any]:
        """
        Receive a datagram (at most 1024 bytes) from the socket, with the address of its sender.
        """
        return self.socket.recvfrom(1024)

    def print_markdown(self, pad: Any, y: int, x: int, msg: str,
                       bold: bool = False, italic: bool = False, underline: bool = False, strike: bool = False) -> int:
        """
        Parse a markdown-formatted text and draw it on the pad at (y, x) with the matching attributes.
        ***text***: bold, italic
        **text**: bold
        *text*: italic
        __text__: underline
        _text_: italic
        ~~text~~: strikethrough

        Markers are applied recursively, so styles combine; a marker inside text that already has the style
        toggles it off.

        :return: the number of characters drawn (strike-through combining marks excluded), used by callers
                 as the x offset of the following text.
        """
        # Replace :emoji_name: by the good emoji
        # NOTE(review): recent releases of the emoji package replaced use_aliases=True by language="alias";
        # confirm the installed version still accepts this keyword.
        if not self.squinnondation.no_emoji:
            msg = emoji.emojize(msg, use_aliases=True)

        if self.squinnondation.no_markdown:
            pad.addstr(y, x, msg)
            return len(msg)

        for pattern, toggled in self.MARKDOWN_RULES:
            match = re.match(pattern, msg)
            if not match:
                continue
            before, text, after = match.group(1), match.group(2), match.group(3)
            # The surrounding text keeps the current style (strike included); the enclosed text gets the
            # toggled one.
            outer = {"bold": bold, "italic": italic, "underline": underline, "strike": strike}
            inner = {name: (not value if name in toggled else value) for name, value in outer.items()}
            len_before = self.print_markdown(pad, y, x, before, **outer)
            len_mid = self.print_markdown(pad, y, x + len_before, text, **inner)
            len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, **outer)
            return len_before + len_mid + len_after

        size = len(msg)
        attrs = 0
        attrs |= curses.A_BOLD if bold else 0
        attrs |= curses.A_ITALIC if italic else 0
        attrs |= curses.A_UNDERLINE if underline else 0
        if strike:
            # Strike-through is rendered with a combining overlay character after each character.
            msg = "".join(c + "\u0336" for c in msg)
        pad.addstr(y, x, msg, attrs)
        return size

    @staticmethod
    def _draw_nickname(pad: Any, y: int, nickname: str) -> int:
        """
        Draw "<nickname> " at the start of row y, the nickname in bold with a color derived from its letters.
        Return the column where the message starts.
        """
        color_id = sum(ord(c) for c in nickname) % 6 + 1
        pad.addstr(y, 0, "<")
        pad.addstr(y, 1, nickname, curses.A_BOLD | curses.color_pair(color_id + 1))
        pad.addstr(y, 1 + len(nickname), "> ")
        return 3 + len(nickname)

    def refresh_history(self) -> None:
        """
        Rewrite the history of the messages, resizing the pads first if the terminal was resized.
        """
        rows, cols = self.squinnondation.screen.getmaxyx()
        if curses.is_term_resized(curses.LINES, curses.COLS):
            curses.resizeterm(rows, cols)
            self.history_pad.resize(curses.LINES - 2, curses.COLS - 1)
            self.input_pad.resize(1, curses.COLS - 1)

        self.history_pad.erase()
        # Only the most recent messages fit in the LINES - 2 rows of the history pad.
        for i, msg in enumerate(self.history[max(0, len(self.history) - curses.LINES + 2):]):
            if not re.match("<.*> .*", msg):
                msg = "<unknown> " + msg
            match = re.match("<(.*)> (.*)", msg)
            nickname = match.group(1)
            msg = match.group(2)
            try:
                start_x = self._draw_nickname(self.history_pad, i, nickname)
                self.print_markdown(self.history_pad, i, start_x, msg)
            except curses.error:
                # addstr raises once text reaches the lower-right cell of the pad; the overflow is clipped.
                pass
        self.history_pad.refresh(0, 0, 0, 0, curses.LINES - 2, curses.COLS)

    def refresh_input(self) -> None:
        """
        Redraw the input line: the nickname prompt followed by the end of the message being typed.
        """
        self.input_pad.erase()
        start_x = self._draw_nickname(self.input_pad, 0, self.nickname)
        # addstr raises when text reaches the lower-right cell of the pad: only show the tail that fits.
        room = self.input_pad.getmaxyx()[1] - 1 - start_x
        visible = self.input_buffer[-room:] if room > 0 else ""
        if visible:
            self.input_pad.addstr(0, start_x, visible)
        self.input_pad.refresh(0, 0, curses.LINES - 1, 0, curses.LINES - 1, curses.COLS - 1)
class Worm(Thread):
    """
    The worm is the hazel listener.
    It always waits for an incoming packet, then it treats it, and continues to wait.
    It is in a dedicated thread.
    """
    def __init__(self, squirrel: Squirrel, *args, **kwargs):
        # Daemon by default: a thread blocked in recvfrom would otherwise keep the process alive after
        # the user interface loop has exited (e.g. on Ctrl-C).
        kwargs.setdefault("daemon", True)
        super().__init__(*args, **kwargs)
        self.squirrel = squirrel

    def run(self) -> None:
        """
        Receive packets forever, add the content of their Data TLVs to the history and redraw the screen.
        Invalid packets are reported in the history instead of stopping the thread.
        """
        while True:
            try:
                pkt, _sender = self.squirrel.receive_packet()
            except ValueError as error:
                self.squirrel.history.append("<system> *An error occured while receiving a packet: {}*".format(error))
            else:
                # Only Data TLVs carry messages; a packet may contain none, or several.
                for tlv in pkt.body:
                    if isinstance(tlv, DataTLV):
                        # Invalid UTF-8 from a peer must not kill the listening thread.
                        self.squirrel.history.append(tlv.data.decode("UTF-8", errors="replace"))
            self.squirrel.refresh_history()
            self.squirrel.refresh_input()