# Copyright (C) 2020 by eichhornchen, ÿnérant
# SPDX-License-Identifier: GPL-3.0-or-later

import curses
import re
import socket
from argparse import ArgumentParser
from enum import Enum
from ipaddress import IPv6Address
from threading import Thread
from typing import Any, List, Optional, Tuple

import emoji

from squinnondation.term_manager import TermManager


class Squinnondation:
    """
    Entry point of the program: parse the command line, then run the chat loop.
    """
    args: Any
    bind_address: str
    bind_port: int
    no_emoji: bool
    no_markdown: bool
    screen: Any

    def parse_arguments(self) -> None:
        """
        Parse the command line arguments and store them in the instance.
        Raises a ValueError if a port is not between 1024 and 65535.
        """
        parser = ArgumentParser(description="MIRC client.")
        # nargs='?' is needed for argparse to actually use the default value of a positional argument.
        parser.add_argument('bind_address', type=str, nargs='?', default="localhost",
                            help="Address of the client.")
        parser.add_argument('bind_port', type=int, nargs='?', default=2500,
                            help="Port of the client. Must be between 1024 and 65535.")
        parser.add_argument('--client_address', type=str, default=None,
                            help="Address of the first neighbour.")
        parser.add_argument('--client_port', type=int, default=0,
                            help="Port of the first neighbour. Must be between 1024 and 65535.")
        parser.add_argument('--no-emoji', '-ne', action='store_true',
                            help="Don't replace emojis.")
        parser.add_argument('--no-markdown', '-nm', action='store_true',
                            help="Don't apply markdown formatting.")
        self.args = parser.parse_args()

        # A client port of 0 (the default) means that no first neighbour was given.
        if not (1024 <= self.args.bind_port <= 65535) or \
                not (not self.args.client_port or 1024 <= self.args.client_port <= 65535):
            raise ValueError("Ports must be between 1024 and 65535.")

        self.bind_address = self.args.bind_address
        self.bind_port = self.args.bind_port
        self.no_emoji = self.args.no_emoji
        self.no_markdown = self.args.no_markdown

    @staticmethod
    def main() -> None:  # pragma: no cover
        """
        Run the client: ask for a nickname, start listening for incoming packets,
        then read the keyboard and send every typed line to all the known neighbours.
        """
        instance = Squinnondation()
        instance.parse_arguments()

        with TermManager() as term_manager:
            screen = term_manager.screen
            instance.screen = screen
            screen.addstr(0, 0, "Enter your nickname: ")
            curses.echo()
            nickname = screen.getstr().decode("UTF-8")
            curses.noecho()

            squirrel = Squirrel(instance, nickname)
            squirrel.refresh_history()
            squirrel.refresh_input()

            if instance.args.client_address and instance.args.client_port:
                hazelnut = Hazelnut(address=instance.args.client_address, port=instance.args.client_port)
                # Key the neighbour by its resolved address, as Squirrel.find_hazelnut does, so that
                # the packets received from it do not register the same neighbour a second time.
                squirrel.hazelnuts[(str(hazelnut.address), hazelnut.port)] = hazelnut

            # The listener blocks on the socket forever: as a daemon thread, it does not prevent
            # the process from exiting once this loop ends.
            Worm(squirrel, daemon=True).start()

            while True:
                squirrel.refresh_history()
                squirrel.refresh_input()

                key = screen.getkey(curses.LINES - 1, 3 + len(squirrel.nickname) + len(squirrel.input_buffer))
                if key in ("\x7f", "\b", "KEY_BACKSPACE"):  # backspace, depending on the terminal
                    squirrel.input_buffer = squirrel.input_buffer[:-1]
                    continue
                if key == "KEY_RESIZE":
                    # The new size of the terminal is handled by refresh_history on the next iteration.
                    continue
                if len(key) > 1:
                    squirrel.history.append(f" *unmanaged key press: {key}*")
                    continue
                if key != "\n":
                    squirrel.input_buffer += key
                    continue

                text = squirrel.input_buffer
                if not text:
                    continue

                msg = f"<{squirrel.nickname}> {text}"
                data = msg.encode("UTF-8")
                if len(data) > DataTLV.MAX_DATA_LENGTH:
                    # The TLV length is stored on a single byte: refuse the message instead of crashing
                    # while marshalling it. The input is kept so that it can be shortened.
                    squirrel.history.append(f" *The message is too long: {len(data)} bytes, "
                                            f"the maximum is {DataTLV.MAX_DATA_LENGTH}.*")
                    continue

                squirrel.input_buffer = ""
                squirrel.history.append(msg)

                tlv = DataTLV()
                tlv.data = data
                tlv.sender_id = 42
                tlv.nonce = 18
                # The length counts the sender id (8 bytes), the nonce (4 bytes) and the data,
                # but not the type and length bytes themselves.
                tlv.length = 8 + 4 + len(tlv.data)

                pkt = Packet()
                pkt.magic = 95
                pkt.version = 0
                pkt.body = [tlv]
                # The body contains the whole TLV, type and length bytes included.
                pkt.body_length = tlv.length + 2

                for hazelnut in list(squirrel.hazelnuts.values()):
                    squirrel.send_packet(hazelnut, pkt)


class TLV:
    """
    The Tag-Length-Value contains the different type of data that can be sent.
    Each subclass handles one type of TLV; its index in TLV.tlv_classes() is its type code.
    """
    type: int

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Parse data and construct TLV.
        raw_data starts with the type byte of this TLV, and may contain trailing bytes after it.
        """
        raise NotImplementedError

    def marshal(self) -> bytes:
        """
        Translate the TLV into a byte array.
        """
        raise NotImplementedError

    def validate_data(self) -> bool:
        """
        Ensure that the TLV is well-formed.
        Raises a ValueError if it is not the case.
        TODO: Make some tests
        """
        return True

    @staticmethod
    def tlv_classes():
        """
        The TLV classes, indexed by their type code.
        """
        return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV]


class Pad1TLV(TLV):
    """
    This TLV is a single type byte, without length nor body. It is simply ignored.
    """
    type: int = 0

    def unmarshal(self, raw_data: bytes) -> None:
        """
        There is nothing to do. We ignore the packet.
        """
        self.type = raw_data[0]

    def marshal(self) -> bytes:
        """
        The TLV is empty: only the type byte is sent.
        """
        return self.type.to_bytes(1, "big")


class PadNTLV(TLV):
    """
    This TLV is filled with zeros. It is ignored.
    """
    type: int = 1
    length: int
    mbz: bytes

    def validate_data(self) -> bool:
        if self.mbz != bytes(self.length):
            raise ValueError("The body of a PadN TLV is not filled with zeros.")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Store the zero-array, then ignore the packet.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.mbz = raw_data[2:2 + self.length]

    def marshal(self) -> bytes:
        """
        Construct the byte array filled by zeros.
        """
        return self.type.to_bytes(1, "big") + self.length.to_bytes(1, "big") + self.mbz[:self.length]


class HelloTLV(TLV):
    """
    A short Hello (length 8) carries the source id only,
    a long Hello (length 16) carries the source id and the destination id.
    """
    type: int = 2
    length: int
    source_id: int
    # None for a short Hello.
    dest_id: Optional[int] = None

    def validate_data(self) -> bool:
        if self.length != 8 and self.length != 16:
            raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello, "
                             f"found {self.length}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.source_id = int.from_bytes(raw_data[2:10], "big")
        self.dest_id = int.from_bytes(raw_data[10:18], "big") if self.length == 16 else None

    def marshal(self) -> bytes:
        data = self.type.to_bytes(1, "big") + self.length.to_bytes(1, "big") + self.source_id.to_bytes(8, "big")
        # A destination id of 0 is still a valid id: only None means a short Hello.
        if self.dest_id is not None:
            data += self.dest_id.to_bytes(8, "big")
        return data


class NeighbourTLV(TLV):
    """
    Address (16-byte IPv6) and port (2 bytes) of a neighbour.
    """
    type: int = 3
    length: int
    ip_address: IPv6Address
    port: int

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.ip_address = IPv6Address(raw_data[2:18])
        self.port = int.from_bytes(raw_data[18:20], "big")

    def marshal(self) -> bytes:
        return self.type.to_bytes(1, "big") + \
            self.length.to_bytes(1, "big") + \
            self.ip_address.packed + \
            self.port.to_bytes(2, "big")


class DataTLV(TLV):
    """
    TLV carrying a chat message.
    The length counts the sender id (8 bytes), the nonce (4 bytes) and the data.
    """
    type: int = 4
    length: int
    sender_id: int
    nonce: int
    data: bytes

    # The length is stored on a single byte, and also counts the sender id and the nonce.
    MAX_DATA_LENGTH: int = 255 - 8 - 4

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.sender_id = int.from_bytes(raw_data[2:10], "big")
        self.nonce = int.from_bytes(raw_data[10:14], "big")
        self.data = raw_data[14:2 + self.length]

    def marshal(self) -> bytes:
        return self.type.to_bytes(1, "big") + \
            self.length.to_bytes(1, "big") + \
            self.sender_id.to_bytes(8, "big") + \
            self.nonce.to_bytes(4, "big") + \
            self.data


class AckTLV(TLV):
    """
    Acknowledgement of a Data TLV, identified by its sender id and its nonce.
    """
    type: int = 5
    length: int
    sender_id: int
    nonce: int

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.sender_id = int.from_bytes(raw_data[2:10], "big")
        self.nonce = int.from_bytes(raw_data[10:14], "big")

    def marshal(self) -> bytes:
        return self.type.to_bytes(1, "big") + \
            self.length.to_bytes(1, "big") + \
            self.sender_id.to_bytes(8, "big") + \
            self.nonce.to_bytes(4, "big")


class GoAwayTLV(TLV):
    """
    The length counts the code byte and the message.
    """
    class GoAwayType(Enum):
        UNKNOWN = 0
        EXIT = 1
        TIMEOUT = 2
        PROTOCOL_VIOLATION = 3

    type: int = 6
    length: int
    code: GoAwayType
    message: str

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.code = GoAwayTLV.GoAwayType(raw_data[2])
        # The message starts after the code byte and ends after the `length` bytes following the header.
        self.message = raw_data[3:2 + self.length].decode("UTF-8")

    def marshal(self) -> bytes:
        return self.type.to_bytes(1, "big") + \
            self.length.to_bytes(1, "big") + \
            self.code.value.to_bytes(1, "big") + \
            self.message.encode("UTF-8")[:self.length - 1]


class WarningTLV(TLV):
    """
    The length counts the bytes of the message.
    """
    type: int = 7
    length: int
    message: str

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.message = raw_data[2:2 + self.length].decode("UTF-8")

    def marshal(self) -> bytes:
        return self.type.to_bytes(1, "big") + \
            self.length.to_bytes(1, "big") + \
            self.message.encode("UTF-8")[:self.length]


class Packet:
    """
    A Packet is a wrapper around the TLVs: a 4-byte header (magic, version, 2-byte body length)
    followed by a body made of TLVs.
    """
    magic: int
    version: int
    body_length: int
    body: List[TLV]

    # Received datagrams are at most 1024 bytes long, including the 4-byte header.
    MAX_BODY_LENGTH: int = 1020

    def validate_data(self) -> bool:
        """
        Ensure that the packet is well-formed.
        Raises a ValueError if the packet contains bad data.
        """
        if self.magic != 95:
            raise ValueError("The magic code of the packet must be 95, found: {:d}".format(self.magic))
        if self.version != 0:
            raise ValueError("The version of the packet is not supported: {:d}".format(self.version))
        if not (0 <= self.body_length <= Packet.MAX_BODY_LENGTH):
            raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1020, "
                             "found: {:d}".format(self.body_length))
        return all(tlv.validate_data() for tlv in self.body)

    @staticmethod
    def unmarshal(data: bytes) -> "Packet":
        """
        Read raw data and build the packet wrapper.
        Raises a ValueError whenever the data is invalid.
        """
        if len(data) < 4:
            raise ValueError(f"The packet is too short to contain a header: {len(data)} bytes")

        pkt = Packet()
        pkt.magic = data[0]
        pkt.version = data[1]
        pkt.body_length = int.from_bytes(data[2:4], byteorder="big")
        pkt.body = []

        tlv_classes = TLV.tlv_classes()
        # Never read past the announced body, nor past the received data.
        body_end = min(len(data) - 4, pkt.body_length)
        read_bytes = 0
        while read_bytes < body_end:
            tlv_type = data[4 + read_bytes]
            if not (0 <= tlv_type < len(tlv_classes)):
                raise ValueError(f"TLV type is not supported: {tlv_type}")
            tlv = tlv_classes[tlv_type]()
            try:
                tlv.unmarshal(data[4 + read_bytes:4 + body_end])
            except IndexError as error:
                raise ValueError(f"The TLV of type {tlv_type} is truncated.") from error
            pkt.body.append(tlv)
            # Pad1TLV has no length
            read_bytes += 1 if tlv_type == 0 else tlv.length + 2

        pkt.validate_data()

        return pkt

    def marshal(self) -> bytes:
        """
        Compute the byte array data associated to the packet.
        """
        data = self.magic.to_bytes(1, "big")
        data += self.version.to_bytes(1, "big")
        data += self.body_length.to_bytes(2, "big")
        data += b"".join(tlv.marshal() for tlv in self.body)
        return data


class Hazelnut:
    """
    A hazelnut is another client of the network, identified by its IPv6 address and its port.
    """
    def __init__(self, nickname: str = "anonymous", address: str = "localhost", port: int = 2500):
        self.nickname = nickname

        try:
            # Resolve DNS as an IPv6
            address = socket.getaddrinfo(address, None, socket.AF_INET6)[0][4][0]
        except socket.gaierror:
            # This is not a valid IPv6. Assume it can be resolved as an IPv4, and we use IPv4-mapping
            # to compute a valid IPv6.
            # See https://fr.wikipedia.org/wiki/Adresse_IPv6_mappant_IPv4
            address = "::ffff:" + socket.getaddrinfo(address, None, socket.AF_INET)[0][4][0]

        self.address = IPv6Address(address)
        self.port = port


class Squirrel(Hazelnut):
    """
    The squirrel is the user of the program. It can speak with other clients, that are called hazelnuts.
    """

    # Markdown delimiters, in the order in which they are tried, with the styles
    # (bold, italic, underline, strike) that they toggle.
    MARKDOWN_STYLES = (
        (re.compile(r"(.*)__(.*)__(.*)"), (False, False, True, False)),
        (re.compile(r"(.*)_(.*)_(.*)"), (False, True, False, False)),
        (re.compile(r"(.*)\*\*\*(.*)\*\*\*(.*)"), (True, True, False, False)),
        (re.compile(r"(.*)\*\*(.*)\*\*(.*)"), (True, False, False, False)),
        (re.compile(r"(.*)\*(.*)\*(.*)"), (False, True, False, False)),
        (re.compile(r"(.*)~~(.*)~~(.*)"), (False, False, False, True)),
    )

    def __init__(self, instance: Squinnondation, nickname: str):
        super().__init__(nickname, instance.bind_address, instance.bind_port)

        # Create UDP socket
        self.socket = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
        # Bind the socket
        self.socket.bind((str(self.address), self.port))

        self.squinnondation = instance

        self.input_buffer = ""
        self.history = []

        self.history_pad = curses.newpad(curses.LINES - 2, curses.COLS)
        self.input_pad = curses.newpad(1, curses.COLS)

        # Colors can only be redefined on terminals that support it.
        if curses.can_change_color():
            curses.init_color(curses.COLOR_WHITE, 1000, 1000, 1000)
        # Pair i + 1 draws color i on black, for each color strictly between black and white.
        for i in range(curses.COLOR_BLACK + 1, curses.COLOR_WHITE):
            curses.init_pair(i + 1, i, curses.COLOR_BLACK)

        self.hazelnuts = dict()
        self.history.append(f" *Listening on {self.address}:{self.port}*")

    def find_hazelnut(self, address: str, port: int) -> Hazelnut:
        """
        Translate an address into a hazelnut, and store it in the list of the hazelnuts, ie. the neighbours.
        New hazelnuts are keyed by their resolved IPv6 address, so that several spellings of the same
        address are mapped to a single neighbour.
        """
        if (address, port) in self.hazelnuts:
            return self.hazelnuts[(address, port)]
        hazelnut = Hazelnut(address=address, port=port)
        return self.hazelnuts.setdefault((str(hazelnut.address), hazelnut.port), hazelnut)

    def send_packet(self, client: Hazelnut, pkt: Packet) -> int:
        """
        Send a formatted packet to a client.
        """
        return self.send_raw_data(client, pkt.marshal())

    def send_raw_data(self, client: Hazelnut, data: bytes) -> int:
        """
        Send a raw packet to a client.
        """
        return self.socket.sendto(data, (str(client.address), client.port))

    def receive_packet(self) -> Tuple[Packet, Hazelnut]:
        """
        Receive a packet from the socket and translate it into a Python object.
        Warning: the process is blocking, it should be ran inside a dedicated thread.
        """
        data, addr = self.receive_raw_data()
        return Packet.unmarshal(data), self.find_hazelnut(addr[0], addr[1])

    def receive_raw_data(self) -> Tuple[bytes, Any]:
        """
        Receive a packet from the socket.
        """
        return self.socket.recvfrom(1024)

    def print_markdown(self, pad: Any, y: int, x: int, msg: str,
                       bold: bool = False, italic: bool = False, underline: bool = False, strike: bool = False) -> int:
        """
        Parse a markdown-formatted text and print it at the position (y, x) of the pad,
        formatted as bold, italic, underlined or struck-through text.
        Returns the number of printed characters, markdown delimiters excluded.

        ***text***: bold, italic
        **text**: bold
        *text*: italic
        __text__: underline
        _text_: italic
        ~~text~~: strikethrough
        """
        # Replace :emoji_name: by the good emoji, once for the whole message
        if not self.squinnondation.no_emoji:
            msg = emoji.emojize(msg, use_aliases=True)

        if self.squinnondation.no_markdown:
            pad.addstr(y, x, msg)
            return len(msg)

        return self._print_markdown_rec(pad, y, x, msg, bold, italic, underline, strike)

    def _print_markdown_rec(self, pad: Any, y: int, x: int, msg: str,
                            bold: bool, italic: bool, underline: bool, strike: bool) -> int:
        """
        Recursive part of print_markdown: split msg around a pair of delimiters of the first kind
        (in MARKDOWN_STYLES order) that it contains, toggle the matching styles for the delimited text,
        and print the three parts one after the other. Returns the number of printed characters.
        """
        for pattern, (toggle_bold, toggle_italic, toggle_underline, toggle_strike) in self.MARKDOWN_STYLES:
            match = pattern.match(msg)
            if match:
                before, text, after = match.groups()
                len_before = self._print_markdown_rec(pad, y, x, before, bold, italic, underline, strike)
                len_mid = self._print_markdown_rec(pad, y, x + len_before, text,
                                                   bold ^ toggle_bold, italic ^ toggle_italic,
                                                   underline ^ toggle_underline, strike ^ toggle_strike)
                len_after = self._print_markdown_rec(pad, y, x + len_before + len_mid, after,
                                                     bold, italic, underline, strike)
                return len_before + len_mid + len_after

        size = len(msg)

        attrs = 0
        attrs |= curses.A_BOLD if bold else 0
        attrs |= curses.A_ITALIC if italic else 0
        attrs |= curses.A_UNDERLINE if underline else 0
        if strike:
            # A combining stroke is added after each character; it does not count in the printed size.
            msg = "".join(c + "\u0336" for c in msg)

        pad.addstr(y, x, msg, attrs)
        return size

    def refresh_history(self) -> None:
        """
        Rewrite the history of the messages, resizing the pads first if the terminal was resized.
        Messages of the form "<nickname> text" get a colored nickname; other entries (status and
        error messages) are printed as they are.
        """
        y, x = self.squinnondation.screen.getmaxyx()
        if curses.is_term_resized(curses.LINES, curses.COLS):
            curses.resizeterm(y, x)
            self.history_pad.resize(curses.LINES - 2, curses.COLS - 1)
            self.input_pad.resize(1, curses.COLS - 1)

        self.history_pad.erase()
        # Only the most recent messages fit in the history pad.
        for i, msg in enumerate(self.history[max(0, len(self.history) - curses.LINES + 2):]):
            match = re.match("<(.*)> (.*)", msg)
            if not match:
                # Entries without an author are displayed without a nickname.
                self.print_markdown(self.history_pad, i, 0, msg)
                continue

            nickname, text = match.group(1), match.group(2)
            color_id = sum(ord(c) for c in nickname) % 6 + 1
            self.history_pad.addstr(i, 0, "<")
            self.history_pad.addstr(i, 1, nickname, curses.A_BOLD | curses.color_pair(color_id + 1))
            self.history_pad.addstr(i, 1 + len(nickname), "> ")
            self.print_markdown(self.history_pad, i, 3 + len(nickname), text)

        self.history_pad.refresh(0, 0, 0, 0, curses.LINES - 2, curses.COLS)

    def refresh_input(self) -> None:
        """
        Redraw the input line: the colored nickname of the user, followed by the message being typed.
        """
        self.input_pad.erase()
        color_id = sum(ord(c) for c in self.nickname) % 6 + 1
        self.input_pad.addstr(0, 0, "<")
        self.input_pad.addstr(0, 1, self.nickname, curses.A_BOLD | curses.color_pair(color_id + 1))
        self.input_pad.addstr(0, 1 + len(self.nickname), "> ")
        self.input_pad.addstr(0, 3 + len(self.nickname), self.input_buffer)
        self.input_pad.refresh(0, 0, curses.LINES - 1, 0, curses.LINES - 1, curses.COLS - 1)


class Worm(Thread):
    """
    The worm is the hazel listener.
    It always waits for an incoming packet, then it treats it, and continues to wait.
    It is in a dedicated thread.
    """

    def __init__(self, squirrel: Squirrel, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.squirrel = squirrel

    def run(self) -> None:
        while True:
            try:
                # receive_packet also registers the sender as a neighbour.
                pkt, _ = self.squirrel.receive_packet()
            except ValueError as error:
                self.squirrel.history.append(" *An error occurred while receiving a packet: {}*".format(error))
            else:
                # Only Data TLVs carry messages to display; the other TLVs are not handled yet.
                for tlv in pkt.body:
                    if isinstance(tlv, DataTLV):
                        # Invalid UTF-8 must not kill the listener thread.
                        self.squirrel.history.append(tlv.data.decode("UTF-8", errors="replace"))
            self.squirrel.refresh_history()
            self.squirrel.refresh_input()