From 0c4ef9da5a73dee28d07bbd3e43920142b917ea7 Mon Sep 17 00:00:00 2001 From: eichhornchen Date: Wed, 16 Dec 2020 17:51:01 +0100 Subject: [PATCH] Split the file into more readable-sized files --- squinnondation/hazel.py | 428 +++++++++++++++++ squinnondation/messages.py | 364 ++++++++++++++ squinnondation/squinnondation.py | 781 +------------------------------ 3 files changed, 793 insertions(+), 780 deletions(-) create mode 100644 squinnondation/hazel.py create mode 100644 squinnondation/messages.py diff --git a/squinnondation/hazel.py b/squinnondation/hazel.py new file mode 100644 index 0000000..ea6843f --- /dev/null +++ b/squinnondation/hazel.py @@ -0,0 +1,428 @@ +# Copyright (C) 2020 by eichhornchen, ΓΏnΓ©rant +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Any, List, Optional, Tuple +from ipaddress import IPv6Address +from enum import Enum +from messages import Packet, TLV, HelloTLV, NeighbourTLV, Pad1TLV, PadNTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV +from threading import Thread + +class Hazelnut: + """ + A hazelnut is a connected client, with its socket. + """ + def __init__(self, nickname: str = "anonymous", address: str = "localhost", port: int = 2500): + self.nickname = nickname + + try: + # Resolve DNS as an IPv6 + address = socket.getaddrinfo(address, None, socket.AF_INET6)[0][4][0] + except socket.gaierror: + # This is not a valid IPv6. Assume it can be resolved as an IPv4, and we use IPv4-mapping + # to compute a valid IPv6. + # See https://fr.wikipedia.org/wiki/Adresse_IPv6_mappant_IPv4 + address = "::ffff:" + socket.getaddrinfo(address, None, socket.AF_INET)[0][4][0] + + self.address = IPv6Address(address) + self.port = port + + +class Squirrel(Hazelnut): + """ + The squirrel is the user of the program. It can speak with other clients, that are called hazelnuts. + """ + def __init__(self, instance: Any, nickname: str): + super().__init__(nickname, instance.bind_address, instance.bind_port) + # Create UDP socket + self.socket = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) + # Bind the socket + self.socket.bind((str(self.address), self.port)) + + self.squinnondation = instance + + self.input_buffer = "" + self.input_index = 0 + self.last_line = -1 + + self.history = [] + self.history_pad = curses.newpad(curses.LINES - 2, curses.COLS) + self.input_pad = curses.newpad(1, curses.COLS) + self.emoji_pad = curses.newpad(18, 12) + self.emoji_panel_page = -1 + + curses.init_color(curses.COLOR_WHITE, 1000, 1000, 1000) + for i in range(curses.COLOR_BLACK + 1, curses.COLOR_WHITE): + curses.init_pair(i + 1, i, curses.COLOR_BLACK) + + self.hazelnuts = dict() + self.add_system_message(f"Listening on {self.address}:{self.port}") + + def find_hazelnut(self, address: str, port: int) -> Hazelnut: + """ + Translate an address into a hazelnut, and store it in the list of the hazelnuts, ie. the neighbours. + """ + if (address, port) in self.hazelnuts: + return self.hazelnuts[(address, port)] + hazelnut = Hazelnut(address=address, port=port) + self.hazelnuts[(address, port)] = hazelnut + return hazelnut + + def send_packet(self, client: Hazelnut, pkt: Packet) -> int: + """ + Send a formatted packet to a client. + """ + return self.send_raw_data(client, pkt.marshal()) + + def send_raw_data(self, client: Hazelnut, data: bytes) -> int: + """ + Send a raw packet to a client. + """ + return self.socket.sendto(data, (str(client.address), client.port)) + + def receive_packet(self) -> Tuple[Packet, Hazelnut]: + """ + Receive a packet from the socket and translate it into a Python object. + Warning: the process is blocking, it should be ran inside a dedicated thread. + """ + data, addr = self.receive_raw_data() + return Packet.unmarshal(data), self.find_hazelnut(addr[0], addr[1]) + + def receive_raw_data(self) -> Tuple[bytes, Any]: + """ + Receive a packet from the socket. + """ + return self.socket.recvfrom(1024) + + def wait_for_key(self) -> None: + """ + Infinite loop where we are waiting for a key of the user. + """ + while True: + self.refresh_history() + self.refresh_input() + if not self.squinnondation.no_emoji: + self.refresh_emoji_pad() + key = self.squinnondation.screen.getkey(curses.LINES - 1, 3 + len(self.nickname) + self.input_index) + + if key == "KEY_MOUSE": + try: + _, x, y, _, attr = curses.getmouse() + self.handle_mouse_click(y, x, attr) + continue + except curses.error: + # This is not a valid click + continue + + self.handle_key_pressed(key) + + def handle_key_pressed(self, key: str) -> None: + """ + Process the key press from the user. + """ + if key == "\x7f": # backspace + # delete character at the good position + if self.input_index: + self.input_index -= 1 + self.input_buffer = self.input_buffer[:self.input_index] + self.input_buffer[self.input_index + 1:] + return + elif key == "KEY_LEFT": + # Navigate in the message to the left + self.input_index = max(0, self.input_index - 1) + return + elif key == "KEY_RIGHT": + # Navigate in the message to the right + self.input_index = min(len(self.input_buffer), self.input_index + 1) + return + elif key == "KEY_UP": + # Scroll up in the history + self.last_line = min(max(curses.LINES - 3, self.last_line - 1), len(self.history) - 1) + return + elif key == "KEY_DOWN": + # Scroll down in the history + self.last_line = min(len(self.history) - 1, self.last_line + 1) + return + elif key == "KEY_PPAGE": + # Page up in the history + self.last_line = min(max(curses.LINES - 3, self.last_line - (curses.LINES - 3)), len(self.history) - 1) + return + elif key == "KEY_NPAGE": + # Page down in the history + self.last_line = min(len(self.history) - 1, self.last_line + (curses.LINES - 3)) + return + elif key == "KEY_HOME": + # Place the cursor at the beginning of the typing word + self.input_index = 0 + return + elif key == "KEY_END": + # Place the cursor at the end of the typing word + self.input_index = len(self.input_buffer) + return + elif len(key) > 1: + # Unmanaged complex key + return + elif key != "\n": + # Insert the pressed key in the current message + self.input_buffer = self.input_buffer[:self.input_index] + key + self.input_buffer[self.input_index:] + self.input_index += 1 + return + + # Send message to neighbours + msg = self.input_buffer + self.input_buffer = "" + self.input_index = 0 + + if not msg: + return + + msg = f"<{self.nickname}> {msg}" + self.add_message(msg) + + pkt = Packet.construct(DataTLV.construct(msg)) + for hazelnut in list(self.hazelnuts.values()): + self.send_packet(hazelnut, pkt) + + def handle_mouse_click(self, y: int, x: int, attr: int) -> None: + """ + The user clicks on the screen, at coordinates (y, x). + According to the position, we can indicate what can be done. + """ + + if not self.squinnondation.no_emoji: + if y == curses.LINES - 1 and x >= curses.COLS - 3: + # Click on the emoji, open or close the emoji pad + self.emoji_panel_page *= -1 + elif self.emoji_panel_page > 0 and y == curses.LINES - 4 and x >= curses.COLS - 5: + # Open next emoji page + self.emoji_panel_page += 1 + elif self.emoji_panel_page > 1 and y == curses.LINES - curses.LINES // 2 - 1 \ + and x >= curses.COLS - 5: + # Open previous emoji page + self.emoji_panel_page -= 1 + elif self.emoji_panel_page > 0 and y >= curses.LINES // 2 - 1 and x >= curses.COLS // 2 - 1: + pad_y, pad_x = y - (curses.LINES - curses.LINES // 2) + 1, \ + (x - (curses.COLS - curses.COLS // 3) + 1) // 2 + # Click on an emoji on the pad to autocomplete an emoji + self.click_on_emoji_pad(pad_y, pad_x) + + def click_on_emoji_pad(self, pad_y: int, pad_x: int) -> None: + """ + The emoji pad contains the list of all available emojis. + Clicking on a emoji auto-complete the emoji in the input pad. + """ + import emoji + from emoji import unicode_codes + + height, width = self.emoji_pad.getmaxyx() + height -= 1 + width -= 1 + + emojis = list(unicode_codes.UNICODE_EMOJI) + emojis = [c for c in emojis if len(c) == 1] + size = (height - 2) * (width - 4) // 2 + page = emojis[(self.emoji_panel_page - 1) * size:self.emoji_panel_page * size] + index = pad_y * (width - 4) // 2 + pad_x + char = page[index] + if char: + demojized = emoji.demojize(char) + if char != demojized: + for c in reversed(demojized): + curses.ungetch(c) + + def add_message(self, msg: str) -> None: + """ + Store a new message into the history. + """ + self.history.append(msg) + if self.last_line == len(self.history) - 2: + self.last_line += 1 + + def add_system_message(self, msg: str) -> None: + """ + Add a new system log message. + TODO: Configure logging levels to ignore some messages. + """ + return self.add_message(f" *{msg}*") + + def print_markdown(self, pad: Any, y: int, x: int, msg: str, + bold: bool = False, italic: bool = False, underline: bool = False, strike: bool = False) -> int: + """ + Parse a markdown-formatted text and format the text as bold, italic or text text. + ***text***: bold, italic + **text**: bold + *text*: italic + __text__: underline + _text_: italic + ~~text~~: strikethrough + """ + # Replace :emoji_name: by the good emoji + if not self.squinnondation.no_emoji: + import emoji + msg = emoji.emojize(msg, use_aliases=True) + + if self.squinnondation.no_markdown: + pad.addstr(y, x, msg) + return len(msg) + + underline_match = re.match("(.*)__(.*)__(.*)", msg) + if underline_match: + before, text, after = underline_match.group(1), underline_match.group(2), underline_match.group(3) + len_before = self.print_markdown(pad, y, x, before, bold, italic, underline) + len_mid = self.print_markdown(pad, y, x + len_before, text, bold, italic, not underline) + len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline) + return len_before + len_mid + len_after + + italic_match = re.match("(.*)_(.*)_(.*)", msg) + if italic_match: + before, text, after = italic_match.group(1), italic_match.group(2), italic_match.group(3) + len_before = self.print_markdown(pad, y, x, before, bold, italic, underline) + len_mid = self.print_markdown(pad, y, x + len_before, text, bold, not italic, underline) + len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline) + return len_before + len_mid + len_after + + bold_italic_match = re.match("(.*)\\*\\*\\*(.*)\\*\\*\\*(.*)", msg) + if bold_italic_match: + before, text, after = bold_italic_match.group(1), bold_italic_match.group(2),\ + bold_italic_match.group(3) + len_before = self.print_markdown(pad, y, x, before, bold, italic, underline, strike) + len_mid = self.print_markdown(pad, y, x + len_before, text, not bold, not italic, underline, strike) + len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline, strike) + return len_before + len_mid + len_after + + bold_match = re.match("(.*)\\*\\*(.*)\\*\\*(.*)", msg) + if bold_match: + before, text, after = bold_match.group(1), bold_match.group(2), bold_match.group(3) + len_before = self.print_markdown(pad, y, x, before, bold, italic, underline, strike) + len_mid = self.print_markdown(pad, y, x + len_before, text, not bold, italic, underline, strike) + len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline, strike) + return len_before + len_mid + len_after + + italic_match = re.match("(.*)\\*(.*)\\*(.*)", msg) + if italic_match: + before, text, after = italic_match.group(1), italic_match.group(2), italic_match.group(3) + len_before = self.print_markdown(pad, y, x, before, bold, italic, underline, strike) + len_mid = self.print_markdown(pad, y, x + len_before, text, bold, not italic, underline, strike) + len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline, strike) + return len_before + len_mid + len_after + + strike_match = re.match("(.*)~~(.*)~~(.*)", msg) + if strike_match: + before, text, after = strike_match.group(1), strike_match.group(2), strike_match.group(3) + len_before = self.print_markdown(pad, y, x, before, bold, italic, underline, strike) + len_mid = self.print_markdown(pad, y, x + len_before, text, bold, italic, underline, not strike) + len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline, strike) + return len_before + len_mid + len_after + + size = len(msg) + + attrs = 0 + attrs |= curses.A_BOLD if bold else 0 + attrs |= curses.A_ITALIC if italic else 0 + attrs |= curses.A_UNDERLINE if underline else 0 + if strike: + msg = "".join(c + "\u0336" for c in msg) + pad.addstr(y, x, msg, attrs) + return size + + def refresh_history(self) -> None: + """ + Rewrite the history of the messages. + """ + y, x = self.squinnondation.screen.getmaxyx() + if curses.is_term_resized(curses.LINES, curses.COLS): + curses.resizeterm(y, x) + self.history_pad.resize(curses.LINES - 2, curses.COLS - 1) + self.input_pad.resize(1, curses.COLS - 1) + + self.history_pad.erase() + for i, msg in enumerate(self.history[max(0, self.last_line - curses.LINES + 3):self.last_line + 1]): + if not re.match("<.*> .*", msg): + msg = " " + msg + match = re.match("<(.*)> (.*)", msg) + nickname = match.group(1) + msg = match.group(2) + color_id = sum(ord(c) for c in nickname) % 6 + 1 + self.history_pad.addstr(i, 0, "<") + self.history_pad.addstr(i, 1, nickname, curses.A_BOLD | curses.color_pair(color_id + 1)) + self.history_pad.addstr(i, 1 + len(nickname), "> ") + self.print_markdown(self.history_pad, i, 3 + len(nickname), msg) + self.history_pad.refresh(0, 0, 0, 0, curses.LINES - 2, curses.COLS) + + def refresh_input(self) -> None: + """ + Redraw input line. Must not be called while the message is not sent. + """ + self.input_pad.erase() + color_id = sum(ord(c) for c in self.nickname) % 6 + 1 + self.input_pad.addstr(0, 0, "<") + self.input_pad.addstr(0, 1, self.nickname, curses.A_BOLD | curses.color_pair(color_id + 1)) + self.input_pad.addstr(0, 1 + len(self.nickname), "> ") + msg = self.input_buffer + if len(msg) + len(self.nickname) + 3 >= curses.COLS: + msg = "" + self.input_pad.addstr(0, 3 + len(self.nickname), self.input_buffer) + if not self.squinnondation.no_emoji: + self.input_pad.addstr(0, self.input_pad.getmaxyx()[1] - 3, "πŸ˜€") + self.input_pad.refresh(0, 0, curses.LINES - 1, 0, curses.LINES - 1, curses.COLS - 1) + + def refresh_emoji_pad(self) -> None: + """ + Display the emoji pad if necessary. + """ + if self.squinnondation.no_emoji: + return + + from emoji import unicode_codes + + self.emoji_pad.erase() + + if self.emoji_panel_page > 0: + height, width = curses.LINES // 2, curses.COLS // 3 + self.emoji_pad.resize(height + 1, width + 1) + self.emoji_pad.addstr(0, 0, "┏" + (width - 2) * "━" + "β”“") + self.emoji_pad.addstr(0, (width - 14) // 2, " == EMOJIS == ") + for i in range(1, height): + self.emoji_pad.addstr(i, 0, "┃" + (width - 2) * " " + "┃") + self.emoji_pad.addstr(height - 1, 0, "β”—" + (width - 2) * "━" + "β”›") + + emojis = list(unicode_codes.UNICODE_EMOJI) + emojis = [c for c in emojis if len(c) == 1] + size = (height - 2) * (width - 4) // 2 + page = emojis[(self.emoji_panel_page - 1) * size:self.emoji_panel_page * size] + + if self.emoji_panel_page != 1: + self.emoji_pad.addstr(1, width - 2, "⬆") + if len(page) == size: + self.emoji_pad.addstr(height - 2, width - 2, "⬇") + + for i in range(height - 2): + for j in range((width - 4) // 2 + 1): + index = i * (width - 4) // 2 + j + if index < len(page): + self.emoji_pad.addstr(i + 1, 2 * j + 1, page[index]) + + self.emoji_pad.refresh(0, 0, curses.LINES - height - 2, curses.COLS - width - 2, + curses.LINES - 2, curses.COLS - 2) + + +class Worm(Thread): + """ + The worm is the hazel listener. + It always waits for an incoming packet, then it treats it, and continues to wait. + It is in a dedicated thread. + """ + def __init__(self, squirrel: Squirrel, *args, **kwargs): + super().__init__(*args, **kwargs) + self.squirrel = squirrel + + def run(self) -> None: + while True: + try: + pkt, hazelnut = self.squirrel.receive_packet() + pkt.validate_data() + except ValueError as error: + self.squirrel.add_system_message("An error occurred while receiving a packet: {}".format(error)) + else: + for tlv in pkt.body: + tlv.handle(self.squirrel, hazelnut) + self.squirrel.refresh_history() + self.squirrel.refresh_input() diff --git a/squinnondation/messages.py b/squinnondation/messages.py new file mode 100644 index 0000000..9fabdc2 --- /dev/null +++ b/squinnondation/messages.py @@ -0,0 +1,364 @@ +# Copyright (C) 2020 by eichhornchen, ΓΏnΓ©rant +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Any, List, Optional, Tuple +from ipaddress import IPv6Address +from enum import Enum + +class TLV: + """ + The Tag-Length-Value contains the different type of data that can be sent. + TODO: add subclasses for each type of TLV + """ + type: int + length: int + + def unmarshal(self, raw_data: bytes) -> None: + """ + Parse data and construct TLV. + """ + raise NotImplementedError + + def marshal(self) -> bytes: + """ + Translate the TLV into a byte array. + """ + raise NotImplementedError + + def validate_data(self) -> bool: + """ + Ensure that the TLV is well-formed. + Raises a ValueError if it is not the case. + TODO: Make some tests + """ + return True + + def handle(self, squirrel: Any, sender: Any) -> None: + """ + Indicates what to do when this TLV is received from a given hazel. + It is ensured that the data is valid. + """ + + @property + def tlv_length(self) -> int: + """ + Returns the total length (in bytes) of the TLV, including the type and the length. + Except for Pad1, this is 2 plus the length of the body of the TLV. + """ + return 2 + self.length + + @staticmethod + def tlv_classes() -> list: + return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV] + + @staticmethod + def network_order() -> str: + """ + The network byte order is always inverted as the host network byte order. + """ + return "little" if sys.byteorder == "big" else "big" + + +class Pad1TLV(TLV): + """ + This TLV is simply ignored. + """ + type: int = 0 + + def unmarshal(self, raw_data: bytes) -> None: + """ + There is nothing to do. We ignore the packet. + """ + self.type = raw_data[0] + + def marshal(self) -> bytes: + """ + The TLV is empty. + """ + return self.type.to_bytes(1, TLV.network_order()) + + def handle(self, squirrel: Any, sender: Any) -> None: + # TODO Add some easter eggs + squirrel.add_system_message("For each byte in the packet that I received, you will die today. And eat cookies.") + + @property + def tlv_length(self) -> int: + """ + A Pad1 has always a length of 1. + """ + return 1 + + +class PadNTLV(TLV): + """ + This TLV is filled with zeros. It is ignored. + """ + type: int = 1 + length: int + mbz: bytes + + def validate_data(self) -> bool: + if self.mbz != int(0).to_bytes(self.length, TLV.network_order()): + raise ValueError("The body of a PadN TLV is not filled with zeros.") + return True + + def unmarshal(self, raw_data: bytes) -> None: + """ + Store the zero-array, then ignore the packet. + """ + self.type = raw_data[0] + self.length = raw_data[1] + self.mbz = raw_data[2:self.tlv_length] + + def marshal(self) -> bytes: + """ + Construct the byte array filled by zeros. + """ + return self.type.to_bytes(1, TLV.network_order()) + self.length.to_bytes(1, TLV.network_order()) \ + + self.mbz[:self.length] + + def handle(self, squirrel: Any, sender: Any) -> None: + # TODO Add some easter eggs + squirrel.add_system_message(f"I received {self.length} zeros, am I so a bag guy ? :cold_sweat:") + + +class HelloTLV(TLV): + type: int = 2 + length: int + source_id: int + dest_id: Optional[int] + + def validate_data(self) -> bool: + if self.length != 8 and self.length != 16: + raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello," + f"found {self.length}") + return True + + def unmarshal(self, raw_data: bytes) -> None: + self.type = raw_data[0] + self.length = raw_data[1] + self.source_id = int.from_bytes(raw_data[2:10], TLV.network_order()) + if self.is_long: + self.dest_id = int.from_bytes(raw_data[10:18], TLV.network_order()) + + def marshal(self) -> bytes: + data = self.type.to_bytes(1, TLV.network_order()) + self.length.to_bytes(1, TLV.network_order()) \ + + self.source_id.to_bytes(8, TLV.network_order()) + if self.dest_id: + data += self.dest_id.to_bytes(8, TLV.network_order()) + return data + + def handle(self, squirrel: Any, sender: Any) -> None: + # TODO Implement HelloTLV + squirrel.add_system_message("Aaaawwww, someone spoke to me and said me Hello smiling_face_with_" + + (":3_hearts:" if self.is_long else "smiling_eyes:")) + + @property + def is_long(self) -> bool: + return self.length == 16 + + +class NeighbourTLV(TLV): + type: int = 3 + length: int + ip_address: IPv6Address + port: int + + def unmarshal(self, raw_data: bytes) -> None: + self.type = raw_data[0] + self.length = raw_data[1] + self.ip_address = IPv6Address(raw_data[2:18]) + self.port = int.from_bytes(raw_data[18:20], TLV.network_order()) + + def marshal(self) -> bytes: + return self.type.to_bytes(1, TLV.network_order()) + \ + self.length.to_bytes(1, TLV.network_order()) + \ + self.ip_address.packed + \ + self.port.to_bytes(2, TLV.network_order()) + + def handle(self, squirrel: Any, sender: Any) -> None: + # TODO Implement NeighbourTLV + squirrel.add_system_message("I have a friend!") + squirrel.add_system_message(f"Welcome {self.ip_address}:{self.port}!") + + +class DataTLV(TLV): + type: int = 4 + length: int + sender_id: int + nonce: int + data: bytes + + def unmarshal(self, raw_data: bytes) -> None: + self.type = raw_data[0] + self.length = raw_data[1] + self.sender_id = int.from_bytes(raw_data[2:10], TLV.network_order()) + self.nonce = int.from_bytes(raw_data[10:14], TLV.network_order()) + self.data = raw_data[14:self.tlv_length] + + def marshal(self) -> bytes: + return self.type.to_bytes(1, TLV.network_order()) + \ + self.length.to_bytes(1, TLV.network_order()) + \ + self.sender_id.to_bytes(8, TLV.network_order()) + \ + self.nonce.to_bytes(4, TLV.network_order()) + \ + self.data + + def handle(self, squirrel: Any, sender: Any) -> None: + """ + A message has been sent. We log it. + TODO: Check that the tuple (sender_id, nonce) is unique to avoid duplicates. + """ + squirrel.add_message(self.data.decode('UTF-8')) + + @staticmethod + def construct(message: str) -> "DataTLV": + tlv = DataTLV() + tlv.type = 4 + tlv.sender_id = 42 # FIXME Use the good sender id + tlv.nonce = 42 # FIXME Use an incremental nonce + tlv.data = message.encode("UTF-8") + tlv.length = 12 + len(tlv.data) + return tlv + + +class AckTLV(TLV): + type: int = 5 + length: int + sender_id: int + nonce: int + + def unmarshal(self, raw_data: bytes) -> None: + self.type = raw_data[0] + self.length = raw_data[1] + self.sender_id = int.from_bytes(raw_data[2:10], TLV.network_order()) + self.nonce = int.from_bytes(raw_data[10:14], TLV.network_order()) + + def marshal(self) -> bytes: + return self.type.to_bytes(1, TLV.network_order()) + \ + self.length.to_bytes(1, TLV.network_order()) + \ + self.sender_id.to_bytes(8, TLV.network_order()) + \ + self.nonce.to_bytes(4, TLV.network_order()) + + def handle(self, squirrel: Any, sender: Any) -> None: + # TODO Implement AckTLV + squirrel.add_system_message("I received an AckTLV. I don't know what to do with it. Please implement me!") + + +class GoAwayTLV(TLV): + class GoAwayType(Enum): + UNKNOWN = 0 + EXIT = 1 + TIMEOUT = 2 + PROTOCOL_VIOLATION = 3 + + type: int = 6 + length: int + code: GoAwayType + message: str + + def unmarshal(self, raw_data: bytes) -> None: + self.type = raw_data[0] + self.length = raw_data[1] + self.code = GoAwayTLV.GoAwayType(raw_data[2]) + self.message = raw_data[3:self.length - 1].decode("UTF-8") + + def marshal(self) -> bytes: + return self.type.to_bytes(1, TLV.network_order()) + \ + self.length.to_bytes(1, TLV.network_order()) + \ + self.code.value.to_bytes(1, TLV.network_order()) + \ + self.message.encode("UTF-8")[:self.length - 1] + + def handle(self, squirrel: Any, sender: Any) -> None: + # TODO Implement GoAwayTLV + squirrel.add_system_message("Some told me that he went away. That's not very nice :( " + "I should send him some cake.") + + +class WarningTLV(TLV): + type: int = 7 + length: int + message: str + + def unmarshal(self, raw_data: bytes) -> None: + self.type = raw_data[0] + self.length = raw_data[1] + self.message = raw_data[2:self.length].decode("UTF-8") + + def marshal(self) -> bytes: + return self.type.to_bytes(1, TLV.network_order()) + \ + self.length.to_bytes(1, TLV.network_order()) + \ + self.message.encode("UTF-8")[:self.length] + + def handle(self, squirrel: Any, sender: Any) -> None: + squirrel.add_message(f" *A client warned you: {self.message}*") + + +class Packet: + """ + A Packet is a wrapper around the + """ + magic: int + version: int + body_length: int + body: List[TLV] + + def validate_data(self) -> bool: + """ + Ensure that the packet is well-formed. + Raises a ValueError if the packet contains bad data. + """ + if self.magic != 95: + raise ValueError("The magic code of the packet must be 95, found: {:d}".format(self.magic)) + if self.version != 0: + raise ValueError("The version of the packet is not supported: {:d}".format(self.version)) + if not (0 <= self.body_length <= 120): + raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1020," + "found: {:d}".format(self.body_length)) + return all(tlv.validate_data() for tlv in self.body) + + @staticmethod + def unmarshal(data: bytes) -> "Packet": + """ + Read raw data and build the packet wrapper. + Raises a ValueError whenever the data is invalid. + """ + pkt = Packet() + pkt.magic = data[0] + pkt.version = data[1] + pkt.body_length = int.from_bytes(data[2:4], byteorder=TLV.network_order()) + pkt.body = [] + read_bytes = 0 + while read_bytes <= min(len(data) - 4, pkt.body_length): + tlv_type = data[4] + if not (0 <= tlv_type < len(TLV.tlv_classes())): + raise ValueError(f"TLV type is not supported: {tlv_type}") + tlv = TLV.tlv_classes()[tlv_type]() + tlv.unmarshal(data[4:4 + pkt.body_length]) + pkt.body.append(tlv) + read_bytes += tlv.tlv_length + + pkt.validate_data() + + return pkt + + def marshal(self) -> bytes: + """ + Compute the byte array data associated to the packet. + """ + data = self.magic.to_bytes(1, TLV.network_order()) + data += self.version.to_bytes(1, TLV.network_order()) + data += self.body_length.to_bytes(2, TLV.network_order()) + data += b"".join(tlv.marshal() for tlv in self.body) + return data + + @staticmethod + def construct(*tlvs: TLV) -> "Packet": + """ + Construct a new packet from the given TLVs and calculate the good lengths + """ + pkt = Packet() + pkt.magic = 95 + pkt.version = 0 + pkt.body = tlvs + pkt.body_length = sum(tlv.tlv_length for tlv in tlvs) + return pkt diff --git a/squinnondation/squinnondation.py b/squinnondation/squinnondation.py index 52457f7..1539fd4 100644 --- a/squinnondation/squinnondation.py +++ b/squinnondation/squinnondation.py @@ -12,6 +12,7 @@ from threading import Thread from typing import Any, List, Optional, Tuple from squinnondation.term_manager import TermManager +from hazel import Hazelnut, Squirrel class Squinnondation: @@ -83,783 +84,3 @@ class Squinnondation: Worm(squirrel).start() squirrel.wait_for_key() - - -class TLV: - """ - The Tag-Length-Value contains the different type of data that can be sent. - TODO: add subclasses for each type of TLV - """ - type: int - length: int - - def unmarshal(self, raw_data: bytes) -> None: - """ - Parse data and construct TLV. - """ - raise NotImplementedError - - def marshal(self) -> bytes: - """ - Translate the TLV into a byte array. - """ - raise NotImplementedError - - def validate_data(self) -> bool: - """ - Ensure that the TLV is well-formed. - Raises a ValueError if it is not the case. - TODO: Make some tests - """ - return True - - def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: - """ - Indicates what to do when this TLV is received from a given hazel. - It is ensured that the data is valid. - """ - - @property - def tlv_length(self) -> int: - """ - Returns the total length (in bytes) of the TLV, including the type and the length. - Except for Pad1, this is 2 plus the length of the body of the TLV. - """ - return 2 + self.length - - @staticmethod - def tlv_classes() -> list: - return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV] - - @staticmethod - def network_order() -> str: - """ - The network byte order is always inverted as the host network byte order. - """ - return "little" if sys.byteorder == "big" else "big" - - -class Pad1TLV(TLV): - """ - This TLV is simply ignored. - """ - type: int = 0 - - def unmarshal(self, raw_data: bytes) -> None: - """ - There is nothing to do. We ignore the packet. - """ - self.type = raw_data[0] - - def marshal(self) -> bytes: - """ - The TLV is empty. - """ - return self.type.to_bytes(1, TLV.network_order()) - - def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: - # TODO Add some easter eggs - squirrel.add_system_message("For each byte in the packet that I received, you will die today. And eat cookies.") - - @property - def tlv_length(self) -> int: - """ - A Pad1 has always a length of 1. - """ - return 1 - - -class PadNTLV(TLV): - """ - This TLV is filled with zeros. It is ignored. - """ - type: int = 1 - length: int - mbz: bytes - - def validate_data(self) -> bool: - if self.mbz != int(0).to_bytes(self.length, TLV.network_order()): - raise ValueError("The body of a PadN TLV is not filled with zeros.") - return True - - def unmarshal(self, raw_data: bytes) -> None: - """ - Store the zero-array, then ignore the packet. - """ - self.type = raw_data[0] - self.length = raw_data[1] - self.mbz = raw_data[2:self.tlv_length] - - def marshal(self) -> bytes: - """ - Construct the byte array filled by zeros. - """ - return self.type.to_bytes(1, TLV.network_order()) + self.length.to_bytes(1, TLV.network_order()) \ - + self.mbz[:self.length] - - def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: - # TODO Add some easter eggs - squirrel.add_system_message(f"I received {self.length} zeros, am I so a bag guy ? :cold_sweat:") - - -class HelloTLV(TLV): - type: int = 2 - length: int - source_id: int - dest_id: Optional[int] - - def validate_data(self) -> bool: - if self.length != 8 and self.length != 16: - raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello," - f"found {self.length}") - return True - - def unmarshal(self, raw_data: bytes) -> None: - self.type = raw_data[0] - self.length = raw_data[1] - self.source_id = int.from_bytes(raw_data[2:10], TLV.network_order()) - if self.is_long: - self.dest_id = int.from_bytes(raw_data[10:18], TLV.network_order()) - - def marshal(self) -> bytes: - data = self.type.to_bytes(1, TLV.network_order()) + self.length.to_bytes(1, TLV.network_order()) \ - + self.source_id.to_bytes(8, TLV.network_order()) - if self.dest_id: - data += self.dest_id.to_bytes(8, TLV.network_order()) - return data - - def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: - # TODO Implement HelloTLV - squirrel.add_system_message("Aaaawwww, someone spoke to me and said me Hello smiling_face_with_" - + (":3_hearts:" if self.is_long else "smiling_eyes:")) - - @property - def is_long(self) -> bool: - return self.length == 16 - - -class NeighbourTLV(TLV): - type: int = 3 - length: int - ip_address: IPv6Address - port: int - - def unmarshal(self, raw_data: bytes) -> None: - self.type = raw_data[0] - self.length = raw_data[1] - self.ip_address = IPv6Address(raw_data[2:18]) - self.port = int.from_bytes(raw_data[18:20], TLV.network_order()) - - def marshal(self) -> bytes: - return self.type.to_bytes(1, TLV.network_order()) + \ - self.length.to_bytes(1, TLV.network_order()) + \ - self.ip_address.packed + \ - self.port.to_bytes(2, TLV.network_order()) - - def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: - # TODO Implement NeighbourTLV - squirrel.add_system_message("I have a friend!") - squirrel.add_system_message(f"Welcome {self.ip_address}:{self.port}!") - - -class DataTLV(TLV): - type: int = 4 - length: int - sender_id: int - nonce: int - data: bytes - - def unmarshal(self, raw_data: bytes) -> None: - self.type = raw_data[0] - self.length = raw_data[1] - self.sender_id = int.from_bytes(raw_data[2:10], TLV.network_order()) - self.nonce = int.from_bytes(raw_data[10:14], TLV.network_order()) - self.data = raw_data[14:self.tlv_length] - - def marshal(self) -> bytes: - return self.type.to_bytes(1, TLV.network_order()) + \ - self.length.to_bytes(1, TLV.network_order()) + \ - self.sender_id.to_bytes(8, TLV.network_order()) + \ - self.nonce.to_bytes(4, TLV.network_order()) + \ - self.data - - def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: - """ - A message has been sent. We log it. - TODO: Check that the tuple (sender_id, nonce) is unique to avoid duplicates. - """ - squirrel.add_message(self.data.decode('UTF-8')) - - @staticmethod - def construct(message: str) -> "DataTLV": - tlv = DataTLV() - tlv.type = 4 - tlv.sender_id = 42 # FIXME Use the good sender id - tlv.nonce = 42 # FIXME Use an incremental nonce - tlv.data = message.encode("UTF-8") - tlv.length = 12 + len(tlv.data) - return tlv - - -class AckTLV(TLV): - type: int = 5 - length: int - sender_id: int - nonce: int - - def unmarshal(self, raw_data: bytes) -> None: - self.type = raw_data[0] - self.length = raw_data[1] - self.sender_id = int.from_bytes(raw_data[2:10], TLV.network_order()) - self.nonce = int.from_bytes(raw_data[10:14], TLV.network_order()) - - def marshal(self) -> bytes: - return self.type.to_bytes(1, TLV.network_order()) + \ - self.length.to_bytes(1, TLV.network_order()) + \ - self.sender_id.to_bytes(8, TLV.network_order()) + \ - self.nonce.to_bytes(4, TLV.network_order()) - - def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: - # TODO Implement AckTLV - squirrel.add_system_message("I received an AckTLV. I don't know what to do with it. Please implement me!") - - -class GoAwayTLV(TLV): - class GoAwayType(Enum): - UNKNOWN = 0 - EXIT = 1 - TIMEOUT = 2 - PROTOCOL_VIOLATION = 3 - - type: int = 6 - length: int - code: GoAwayType - message: str - - def unmarshal(self, raw_data: bytes) -> None: - self.type = raw_data[0] - self.length = raw_data[1] - self.code = GoAwayTLV.GoAwayType(raw_data[2]) - self.message = raw_data[3:self.length - 1].decode("UTF-8") - - def marshal(self) -> bytes: - return self.type.to_bytes(1, TLV.network_order()) + \ - self.length.to_bytes(1, TLV.network_order()) + \ - self.code.value.to_bytes(1, TLV.network_order()) + \ - self.message.encode("UTF-8")[:self.length - 1] - - def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: - # TODO Implement GoAwayTLV - squirrel.add_system_message("Some told me that he went away. That's not very nice :( " - "I should send him some cake.") - - -class WarningTLV(TLV): - type: int = 7 - length: int - message: str - - def unmarshal(self, raw_data: bytes) -> None: - self.type = raw_data[0] - self.length = raw_data[1] - self.message = raw_data[2:self.length].decode("UTF-8") - - def marshal(self) -> bytes: - return self.type.to_bytes(1, TLV.network_order()) + \ - self.length.to_bytes(1, TLV.network_order()) + \ - self.message.encode("UTF-8")[:self.length] - - def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: - squirrel.add_message(f" *A client warned you: {self.message}*") - - -class Packet: - """ - A Packet is a wrapper around the - """ - magic: int - version: int - body_length: int - body: List[TLV] - - def validate_data(self) -> bool: - """ - Ensure that the packet is well-formed. - Raises a ValueError if the packet contains bad data. - """ - if self.magic != 95: - raise ValueError("The magic code of the packet must be 95, found: {:d}".format(self.magic)) - if self.version != 0: - raise ValueError("The version of the packet is not supported: {:d}".format(self.version)) - if not (0 <= self.body_length <= 120): - raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1020," - "found: {:d}".format(self.body_length)) - return all(tlv.validate_data() for tlv in self.body) - - @staticmethod - def unmarshal(data: bytes) -> "Packet": - """ - Read raw data and build the packet wrapper. - Raises a ValueError whenever the data is invalid. - """ - pkt = Packet() - pkt.magic = data[0] - pkt.version = data[1] - pkt.body_length = int.from_bytes(data[2:4], byteorder=TLV.network_order()) - pkt.body = [] - read_bytes = 0 - while read_bytes <= min(len(data) - 4, pkt.body_length): - tlv_type = data[4] - if not (0 <= tlv_type < len(TLV.tlv_classes())): - raise ValueError(f"TLV type is not supported: {tlv_type}") - tlv = TLV.tlv_classes()[tlv_type]() - tlv.unmarshal(data[4:4 + pkt.body_length]) - pkt.body.append(tlv) - read_bytes += tlv.tlv_length - - pkt.validate_data() - - return pkt - - def marshal(self) -> bytes: - """ - Compute the byte array data associated to the packet. - """ - data = self.magic.to_bytes(1, TLV.network_order()) - data += self.version.to_bytes(1, TLV.network_order()) - data += self.body_length.to_bytes(2, TLV.network_order()) - data += b"".join(tlv.marshal() for tlv in self.body) - return data - - @staticmethod - def construct(*tlvs: TLV) -> "Packet": - """ - Construct a new packet from the given TLVs and calculate the good lengths - """ - pkt = Packet() - pkt.magic = 95 - pkt.version = 0 - pkt.body = tlvs - pkt.body_length = sum(tlv.tlv_length for tlv in tlvs) - return pkt - - -class Hazelnut: - """ - A hazelnut is a connected client, with its socket. - """ - def __init__(self, nickname: str = "anonymous", address: str = "localhost", port: int = 2500): - self.nickname = nickname - - try: - # Resolve DNS as an IPv6 - address = socket.getaddrinfo(address, None, socket.AF_INET6)[0][4][0] - except socket.gaierror: - # This is not a valid IPv6. Assume it can be resolved as an IPv4, and we use IPv4-mapping - # to compute a valid IPv6. - # See https://fr.wikipedia.org/wiki/Adresse_IPv6_mappant_IPv4 - address = "::ffff:" + socket.getaddrinfo(address, None, socket.AF_INET)[0][4][0] - - self.address = IPv6Address(address) - self.port = port - - -class Squirrel(Hazelnut): - """ - The squirrel is the user of the program. It can speak with other clients, that are called hazelnuts. - """ - def __init__(self, instance: Squinnondation, nickname: str): - super().__init__(nickname, instance.bind_address, instance.bind_port) - # Create UDP socket - self.socket = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) - # Bind the socket - self.socket.bind((str(self.address), self.port)) - - self.squinnondation = instance - - self.input_buffer = "" - self.input_index = 0 - self.last_line = -1 - - self.history = [] - self.history_pad = curses.newpad(curses.LINES - 2, curses.COLS) - self.input_pad = curses.newpad(1, curses.COLS) - self.emoji_pad = curses.newpad(18, 12) - self.emoji_panel_page = -1 - - curses.init_color(curses.COLOR_WHITE, 1000, 1000, 1000) - for i in range(curses.COLOR_BLACK + 1, curses.COLOR_WHITE): - curses.init_pair(i + 1, i, curses.COLOR_BLACK) - - self.hazelnuts = dict() - self.add_system_message(f"Listening on {self.address}:{self.port}") - - def find_hazelnut(self, address: str, port: int) -> Hazelnut: - """ - Translate an address into a hazelnut, and store it in the list of the hazelnuts, ie. the neighbours. - """ - if (address, port) in self.hazelnuts: - return self.hazelnuts[(address, port)] - hazelnut = Hazelnut(address=address, port=port) - self.hazelnuts[(address, port)] = hazelnut - return hazelnut - - def send_packet(self, client: Hazelnut, pkt: Packet) -> int: - """ - Send a formatted packet to a client. - """ - return self.send_raw_data(client, pkt.marshal()) - - def send_raw_data(self, client: Hazelnut, data: bytes) -> int: - """ - Send a raw packet to a client. - """ - return self.socket.sendto(data, (str(client.address), client.port)) - - def receive_packet(self) -> Tuple[Packet, Hazelnut]: - """ - Receive a packet from the socket and translate it into a Python object. - Warning: the process is blocking, it should be ran inside a dedicated thread. - """ - data, addr = self.receive_raw_data() - return Packet.unmarshal(data), self.find_hazelnut(addr[0], addr[1]) - - def receive_raw_data(self) -> Tuple[bytes, Any]: - """ - Receive a packet from the socket. - """ - return self.socket.recvfrom(1024) - - def wait_for_key(self) -> None: - """ - Infinite loop where we are waiting for a key of the user. - """ - while True: - self.refresh_history() - self.refresh_input() - if not self.squinnondation.no_emoji: - self.refresh_emoji_pad() - key = self.squinnondation.screen.getkey(curses.LINES - 1, 3 + len(self.nickname) + self.input_index) - - if key == "KEY_MOUSE": - try: - _, x, y, _, attr = curses.getmouse() - self.handle_mouse_click(y, x, attr) - continue - except curses.error: - # This is not a valid click - continue - - self.handle_key_pressed(key) - - def handle_key_pressed(self, key: str) -> None: - """ - Process the key press from the user. - """ - if key == "\x7f": # backspace - # delete character at the good position - if self.input_index: - self.input_index -= 1 - self.input_buffer = self.input_buffer[:self.input_index] + self.input_buffer[self.input_index + 1:] - return - elif key == "KEY_LEFT": - # Navigate in the message to the left - self.input_index = max(0, self.input_index - 1) - return - elif key == "KEY_RIGHT": - # Navigate in the message to the right - self.input_index = min(len(self.input_buffer), self.input_index + 1) - return - elif key == "KEY_UP": - # Scroll up in the history - self.last_line = min(max(curses.LINES - 3, self.last_line - 1), len(self.history) - 1) - return - elif key == "KEY_DOWN": - # Scroll down in the history - self.last_line = min(len(self.history) - 1, self.last_line + 1) - return - elif key == "KEY_PPAGE": - # Page up in the history - self.last_line = min(max(curses.LINES - 3, self.last_line - (curses.LINES - 3)), len(self.history) - 1) - return - elif key == "KEY_NPAGE": - # Page down in the history - self.last_line = min(len(self.history) - 1, self.last_line + (curses.LINES - 3)) - return - elif key == "KEY_HOME": - # Place the cursor at the beginning of the typing word - self.input_index = 0 - return - elif key == "KEY_END": - # Place the cursor at the end of the typing word - self.input_index = len(self.input_buffer) - return - elif len(key) > 1: - # Unmanaged complex key - return - elif key != "\n": - # Insert the pressed key in the current message - self.input_buffer = self.input_buffer[:self.input_index] + key + self.input_buffer[self.input_index:] - self.input_index += 1 - return - - # Send message to neighbours - msg = self.input_buffer - self.input_buffer = "" - self.input_index = 0 - - if not msg: - return - - msg = f"<{self.nickname}> {msg}" - self.add_message(msg) - - pkt = Packet.construct(DataTLV.construct(msg)) - for hazelnut in list(self.hazelnuts.values()): - self.send_packet(hazelnut, pkt) - - def handle_mouse_click(self, y: int, x: int, attr: int) -> None: - """ - The user clicks on the screen, at coordinates (y, x). - According to the position, we can indicate what can be done. - """ - - if not self.squinnondation.no_emoji: - if y == curses.LINES - 1 and x >= curses.COLS - 3: - # Click on the emoji, open or close the emoji pad - self.emoji_panel_page *= -1 - elif self.emoji_panel_page > 0 and y == curses.LINES - 4 and x >= curses.COLS - 5: - # Open next emoji page - self.emoji_panel_page += 1 - elif self.emoji_panel_page > 1 and y == curses.LINES - curses.LINES // 2 - 1 \ - and x >= curses.COLS - 5: - # Open previous emoji page - self.emoji_panel_page -= 1 - elif self.emoji_panel_page > 0 and y >= curses.LINES // 2 - 1 and x >= curses.COLS // 2 - 1: - pad_y, pad_x = y - (curses.LINES - curses.LINES // 2) + 1, \ - (x - (curses.COLS - curses.COLS // 3) + 1) // 2 - # Click on an emoji on the pad to autocomplete an emoji - self.click_on_emoji_pad(pad_y, pad_x) - - def click_on_emoji_pad(self, pad_y: int, pad_x: int) -> None: - """ - The emoji pad contains the list of all available emojis. - Clicking on a emoji auto-complete the emoji in the input pad. - """ - import emoji - from emoji import unicode_codes - - height, width = self.emoji_pad.getmaxyx() - height -= 1 - width -= 1 - - emojis = list(unicode_codes.UNICODE_EMOJI) - emojis = [c for c in emojis if len(c) == 1] - size = (height - 2) * (width - 4) // 2 - page = emojis[(self.emoji_panel_page - 1) * size:self.emoji_panel_page * size] - index = pad_y * (width - 4) // 2 + pad_x - char = page[index] - if char: - demojized = emoji.demojize(char) - if char != demojized: - for c in reversed(demojized): - curses.ungetch(c) - - def add_message(self, msg: str) -> None: - """ - Store a new message into the history. - """ - self.history.append(msg) - if self.last_line == len(self.history) - 2: - self.last_line += 1 - - def add_system_message(self, msg: str) -> None: - """ - Add a new system log message. - TODO: Configure logging levels to ignore some messages. - """ - return self.add_message(f" *{msg}*") - - def print_markdown(self, pad: Any, y: int, x: int, msg: str, - bold: bool = False, italic: bool = False, underline: bool = False, strike: bool = False) -> int: - """ - Parse a markdown-formatted text and format the text as bold, italic or text text. - ***text***: bold, italic - **text**: bold - *text*: italic - __text__: underline - _text_: italic - ~~text~~: strikethrough - """ - # Replace :emoji_name: by the good emoji - if not self.squinnondation.no_emoji: - import emoji - msg = emoji.emojize(msg, use_aliases=True) - - if self.squinnondation.no_markdown: - pad.addstr(y, x, msg) - return len(msg) - - underline_match = re.match("(.*)__(.*)__(.*)", msg) - if underline_match: - before, text, after = underline_match.group(1), underline_match.group(2), underline_match.group(3) - len_before = self.print_markdown(pad, y, x, before, bold, italic, underline) - len_mid = self.print_markdown(pad, y, x + len_before, text, bold, italic, not underline) - len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline) - return len_before + len_mid + len_after - - italic_match = re.match("(.*)_(.*)_(.*)", msg) - if italic_match: - before, text, after = italic_match.group(1), italic_match.group(2), italic_match.group(3) - len_before = self.print_markdown(pad, y, x, before, bold, italic, underline) - len_mid = self.print_markdown(pad, y, x + len_before, text, bold, not italic, underline) - len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline) - return len_before + len_mid + len_after - - bold_italic_match = re.match("(.*)\\*\\*\\*(.*)\\*\\*\\*(.*)", msg) - if bold_italic_match: - before, text, after = bold_italic_match.group(1), bold_italic_match.group(2),\ - bold_italic_match.group(3) - len_before = self.print_markdown(pad, y, x, before, bold, italic, underline, strike) - len_mid = self.print_markdown(pad, y, x + len_before, text, not bold, not italic, underline, strike) - len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline, strike) - return len_before + len_mid + len_after - - bold_match = re.match("(.*)\\*\\*(.*)\\*\\*(.*)", msg) - if bold_match: - before, text, after = bold_match.group(1), bold_match.group(2), bold_match.group(3) - len_before = self.print_markdown(pad, y, x, before, bold, italic, underline, strike) - len_mid = self.print_markdown(pad, y, x + len_before, text, not bold, italic, underline, strike) - len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline, strike) - return len_before + len_mid + len_after - - italic_match = re.match("(.*)\\*(.*)\\*(.*)", msg) - if italic_match: - before, text, after = italic_match.group(1), italic_match.group(2), italic_match.group(3) - len_before = self.print_markdown(pad, y, x, before, bold, italic, underline, strike) - len_mid = self.print_markdown(pad, y, x + len_before, text, bold, not italic, underline, strike) - len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline, strike) - return len_before + len_mid + len_after - - strike_match = re.match("(.*)~~(.*)~~(.*)", msg) - if strike_match: - before, text, after = strike_match.group(1), strike_match.group(2), strike_match.group(3) - len_before = self.print_markdown(pad, y, x, before, bold, italic, underline, strike) - len_mid = self.print_markdown(pad, y, x + len_before, text, bold, italic, underline, not strike) - len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline, strike) - return len_before + len_mid + len_after - - size = len(msg) - - attrs = 0 - attrs |= curses.A_BOLD if bold else 0 - attrs |= curses.A_ITALIC if italic else 0 - attrs |= curses.A_UNDERLINE if underline else 0 - if strike: - msg = "".join(c + "\u0336" for c in msg) - pad.addstr(y, x, msg, attrs) - return size - - def refresh_history(self) -> None: - """ - Rewrite the history of the messages. - """ - y, x = self.squinnondation.screen.getmaxyx() - if curses.is_term_resized(curses.LINES, curses.COLS): - curses.resizeterm(y, x) - self.history_pad.resize(curses.LINES - 2, curses.COLS - 1) - self.input_pad.resize(1, curses.COLS - 1) - - self.history_pad.erase() - for i, msg in enumerate(self.history[max(0, self.last_line - curses.LINES + 3):self.last_line + 1]): - if not re.match("<.*> .*", msg): - msg = " " + msg - match = re.match("<(.*)> (.*)", msg) - nickname = match.group(1) - msg = match.group(2) - color_id = sum(ord(c) for c in nickname) % 6 + 1 - self.history_pad.addstr(i, 0, "<") - self.history_pad.addstr(i, 1, nickname, curses.A_BOLD | curses.color_pair(color_id + 1)) - self.history_pad.addstr(i, 1 + len(nickname), "> ") - self.print_markdown(self.history_pad, i, 3 + len(nickname), msg) - self.history_pad.refresh(0, 0, 0, 0, curses.LINES - 2, curses.COLS) - - def refresh_input(self) -> None: - """ - Redraw input line. Must not be called while the message is not sent. - """ - self.input_pad.erase() - color_id = sum(ord(c) for c in self.nickname) % 6 + 1 - self.input_pad.addstr(0, 0, "<") - self.input_pad.addstr(0, 1, self.nickname, curses.A_BOLD | curses.color_pair(color_id + 1)) - self.input_pad.addstr(0, 1 + len(self.nickname), "> ") - msg = self.input_buffer - if len(msg) + len(self.nickname) + 3 >= curses.COLS: - msg = "" - self.input_pad.addstr(0, 3 + len(self.nickname), self.input_buffer) - if not self.squinnondation.no_emoji: - self.input_pad.addstr(0, self.input_pad.getmaxyx()[1] - 3, "πŸ˜€") - self.input_pad.refresh(0, 0, curses.LINES - 1, 0, curses.LINES - 1, curses.COLS - 1) - - def refresh_emoji_pad(self) -> None: - """ - Display the emoji pad if necessary. - """ - if self.squinnondation.no_emoji: - return - - from emoji import unicode_codes - - self.emoji_pad.erase() - - if self.emoji_panel_page > 0: - height, width = curses.LINES // 2, curses.COLS // 3 - self.emoji_pad.resize(height + 1, width + 1) - self.emoji_pad.addstr(0, 0, "┏" + (width - 2) * "━" + "β”“") - self.emoji_pad.addstr(0, (width - 14) // 2, " == EMOJIS == ") - for i in range(1, height): - self.emoji_pad.addstr(i, 0, "┃" + (width - 2) * " " + "┃") - self.emoji_pad.addstr(height - 1, 0, "β”—" + (width - 2) * "━" + "β”›") - - emojis = list(unicode_codes.UNICODE_EMOJI) - emojis = [c for c in emojis if len(c) == 1] - size = (height - 2) * (width - 4) // 2 - page = emojis[(self.emoji_panel_page - 1) * size:self.emoji_panel_page * size] - - if self.emoji_panel_page != 1: - self.emoji_pad.addstr(1, width - 2, "⬆") - if len(page) == size: - self.emoji_pad.addstr(height - 2, width - 2, "⬇") - - for i in range(height - 2): - for j in range((width - 4) // 2 + 1): - index = i * (width - 4) // 2 + j - if index < len(page): - self.emoji_pad.addstr(i + 1, 2 * j + 1, page[index]) - - self.emoji_pad.refresh(0, 0, curses.LINES - height - 2, curses.COLS - width - 2, - curses.LINES - 2, curses.COLS - 2) - - -class Worm(Thread): - """ - The worm is the hazel listener. - It always waits for an incoming packet, then it treats it, and continues to wait. - It is in a dedicated thread. - """ - def __init__(self, squirrel: Squirrel, *args, **kwargs): - super().__init__(*args, **kwargs) - self.squirrel = squirrel - - def run(self) -> None: - while True: - try: - pkt, hazelnut = self.squirrel.receive_packet() - pkt.validate_data() - except ValueError as error: - self.squirrel.add_system_message("An error occurred while receiving a packet: {}".format(error)) - else: - for tlv in pkt.body: - tlv.handle(self.squirrel, hazelnut) - self.squirrel.refresh_history() - self.squirrel.refresh_input()