Split the file into more readable-sized files
This commit is contained in:
parent
63407461fe
commit
0c4ef9da5a
|
@ -0,0 +1,428 @@
|
||||||
|
# Copyright (C) 2020 by eichhornchen, ÿnérant
|
||||||
|
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||||
|
|
||||||
|
from typing import Any, List, Optional, Tuple
|
||||||
|
from ipaddress import IPv6Address
|
||||||
|
from enum import Enum
|
||||||
|
from messages import Packet, TLV, HelloTLV, NeighbourTLV, Pad1TLV, PadNTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
class Hazelnut:
|
||||||
|
"""
|
||||||
|
A hazelnut is a connected client, with its socket.
|
||||||
|
"""
|
||||||
|
def __init__(self, nickname: str = "anonymous", address: str = "localhost", port: int = 2500):
|
||||||
|
self.nickname = nickname
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Resolve DNS as an IPv6
|
||||||
|
address = socket.getaddrinfo(address, None, socket.AF_INET6)[0][4][0]
|
||||||
|
except socket.gaierror:
|
||||||
|
# This is not a valid IPv6. Assume it can be resolved as an IPv4, and we use IPv4-mapping
|
||||||
|
# to compute a valid IPv6.
|
||||||
|
# See https://fr.wikipedia.org/wiki/Adresse_IPv6_mappant_IPv4
|
||||||
|
address = "::ffff:" + socket.getaddrinfo(address, None, socket.AF_INET)[0][4][0]
|
||||||
|
|
||||||
|
self.address = IPv6Address(address)
|
||||||
|
self.port = port
|
||||||
|
|
||||||
|
|
||||||
|
class Squirrel(Hazelnut):
|
||||||
|
"""
|
||||||
|
The squirrel is the user of the program. It can speak with other clients, that are called hazelnuts.
|
||||||
|
"""
|
||||||
|
def __init__(self, instance: Any, nickname: str):
|
||||||
|
super().__init__(nickname, instance.bind_address, instance.bind_port)
|
||||||
|
# Create UDP socket
|
||||||
|
self.socket = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
|
||||||
|
# Bind the socket
|
||||||
|
self.socket.bind((str(self.address), self.port))
|
||||||
|
|
||||||
|
self.squinnondation = instance
|
||||||
|
|
||||||
|
self.input_buffer = ""
|
||||||
|
self.input_index = 0
|
||||||
|
self.last_line = -1
|
||||||
|
|
||||||
|
self.history = []
|
||||||
|
self.history_pad = curses.newpad(curses.LINES - 2, curses.COLS)
|
||||||
|
self.input_pad = curses.newpad(1, curses.COLS)
|
||||||
|
self.emoji_pad = curses.newpad(18, 12)
|
||||||
|
self.emoji_panel_page = -1
|
||||||
|
|
||||||
|
curses.init_color(curses.COLOR_WHITE, 1000, 1000, 1000)
|
||||||
|
for i in range(curses.COLOR_BLACK + 1, curses.COLOR_WHITE):
|
||||||
|
curses.init_pair(i + 1, i, curses.COLOR_BLACK)
|
||||||
|
|
||||||
|
self.hazelnuts = dict()
|
||||||
|
self.add_system_message(f"Listening on {self.address}:{self.port}")
|
||||||
|
|
||||||
|
def find_hazelnut(self, address: str, port: int) -> Hazelnut:
|
||||||
|
"""
|
||||||
|
Translate an address into a hazelnut, and store it in the list of the hazelnuts, ie. the neighbours.
|
||||||
|
"""
|
||||||
|
if (address, port) in self.hazelnuts:
|
||||||
|
return self.hazelnuts[(address, port)]
|
||||||
|
hazelnut = Hazelnut(address=address, port=port)
|
||||||
|
self.hazelnuts[(address, port)] = hazelnut
|
||||||
|
return hazelnut
|
||||||
|
|
||||||
|
def send_packet(self, client: Hazelnut, pkt: Packet) -> int:
|
||||||
|
"""
|
||||||
|
Send a formatted packet to a client.
|
||||||
|
"""
|
||||||
|
return self.send_raw_data(client, pkt.marshal())
|
||||||
|
|
||||||
|
def send_raw_data(self, client: Hazelnut, data: bytes) -> int:
|
||||||
|
"""
|
||||||
|
Send a raw packet to a client.
|
||||||
|
"""
|
||||||
|
return self.socket.sendto(data, (str(client.address), client.port))
|
||||||
|
|
||||||
|
def receive_packet(self) -> Tuple[Packet, Hazelnut]:
|
||||||
|
"""
|
||||||
|
Receive a packet from the socket and translate it into a Python object.
|
||||||
|
Warning: the process is blocking, it should be ran inside a dedicated thread.
|
||||||
|
"""
|
||||||
|
data, addr = self.receive_raw_data()
|
||||||
|
return Packet.unmarshal(data), self.find_hazelnut(addr[0], addr[1])
|
||||||
|
|
||||||
|
def receive_raw_data(self) -> Tuple[bytes, Any]:
|
||||||
|
"""
|
||||||
|
Receive a packet from the socket.
|
||||||
|
"""
|
||||||
|
return self.socket.recvfrom(1024)
|
||||||
|
|
||||||
|
def wait_for_key(self) -> None:
|
||||||
|
"""
|
||||||
|
Infinite loop where we are waiting for a key of the user.
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
self.refresh_history()
|
||||||
|
self.refresh_input()
|
||||||
|
if not self.squinnondation.no_emoji:
|
||||||
|
self.refresh_emoji_pad()
|
||||||
|
key = self.squinnondation.screen.getkey(curses.LINES - 1, 3 + len(self.nickname) + self.input_index)
|
||||||
|
|
||||||
|
if key == "KEY_MOUSE":
|
||||||
|
try:
|
||||||
|
_, x, y, _, attr = curses.getmouse()
|
||||||
|
self.handle_mouse_click(y, x, attr)
|
||||||
|
continue
|
||||||
|
except curses.error:
|
||||||
|
# This is not a valid click
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.handle_key_pressed(key)
|
||||||
|
|
||||||
|
def handle_key_pressed(self, key: str) -> None:
|
||||||
|
"""
|
||||||
|
Process the key press from the user.
|
||||||
|
"""
|
||||||
|
if key == "\x7f": # backspace
|
||||||
|
# delete character at the good position
|
||||||
|
if self.input_index:
|
||||||
|
self.input_index -= 1
|
||||||
|
self.input_buffer = self.input_buffer[:self.input_index] + self.input_buffer[self.input_index + 1:]
|
||||||
|
return
|
||||||
|
elif key == "KEY_LEFT":
|
||||||
|
# Navigate in the message to the left
|
||||||
|
self.input_index = max(0, self.input_index - 1)
|
||||||
|
return
|
||||||
|
elif key == "KEY_RIGHT":
|
||||||
|
# Navigate in the message to the right
|
||||||
|
self.input_index = min(len(self.input_buffer), self.input_index + 1)
|
||||||
|
return
|
||||||
|
elif key == "KEY_UP":
|
||||||
|
# Scroll up in the history
|
||||||
|
self.last_line = min(max(curses.LINES - 3, self.last_line - 1), len(self.history) - 1)
|
||||||
|
return
|
||||||
|
elif key == "KEY_DOWN":
|
||||||
|
# Scroll down in the history
|
||||||
|
self.last_line = min(len(self.history) - 1, self.last_line + 1)
|
||||||
|
return
|
||||||
|
elif key == "KEY_PPAGE":
|
||||||
|
# Page up in the history
|
||||||
|
self.last_line = min(max(curses.LINES - 3, self.last_line - (curses.LINES - 3)), len(self.history) - 1)
|
||||||
|
return
|
||||||
|
elif key == "KEY_NPAGE":
|
||||||
|
# Page down in the history
|
||||||
|
self.last_line = min(len(self.history) - 1, self.last_line + (curses.LINES - 3))
|
||||||
|
return
|
||||||
|
elif key == "KEY_HOME":
|
||||||
|
# Place the cursor at the beginning of the typing word
|
||||||
|
self.input_index = 0
|
||||||
|
return
|
||||||
|
elif key == "KEY_END":
|
||||||
|
# Place the cursor at the end of the typing word
|
||||||
|
self.input_index = len(self.input_buffer)
|
||||||
|
return
|
||||||
|
elif len(key) > 1:
|
||||||
|
# Unmanaged complex key
|
||||||
|
return
|
||||||
|
elif key != "\n":
|
||||||
|
# Insert the pressed key in the current message
|
||||||
|
self.input_buffer = self.input_buffer[:self.input_index] + key + self.input_buffer[self.input_index:]
|
||||||
|
self.input_index += 1
|
||||||
|
return
|
||||||
|
|
||||||
|
# Send message to neighbours
|
||||||
|
msg = self.input_buffer
|
||||||
|
self.input_buffer = ""
|
||||||
|
self.input_index = 0
|
||||||
|
|
||||||
|
if not msg:
|
||||||
|
return
|
||||||
|
|
||||||
|
msg = f"<{self.nickname}> {msg}"
|
||||||
|
self.add_message(msg)
|
||||||
|
|
||||||
|
pkt = Packet.construct(DataTLV.construct(msg))
|
||||||
|
for hazelnut in list(self.hazelnuts.values()):
|
||||||
|
self.send_packet(hazelnut, pkt)
|
||||||
|
|
||||||
|
def handle_mouse_click(self, y: int, x: int, attr: int) -> None:
|
||||||
|
"""
|
||||||
|
The user clicks on the screen, at coordinates (y, x).
|
||||||
|
According to the position, we can indicate what can be done.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not self.squinnondation.no_emoji:
|
||||||
|
if y == curses.LINES - 1 and x >= curses.COLS - 3:
|
||||||
|
# Click on the emoji, open or close the emoji pad
|
||||||
|
self.emoji_panel_page *= -1
|
||||||
|
elif self.emoji_panel_page > 0 and y == curses.LINES - 4 and x >= curses.COLS - 5:
|
||||||
|
# Open next emoji page
|
||||||
|
self.emoji_panel_page += 1
|
||||||
|
elif self.emoji_panel_page > 1 and y == curses.LINES - curses.LINES // 2 - 1 \
|
||||||
|
and x >= curses.COLS - 5:
|
||||||
|
# Open previous emoji page
|
||||||
|
self.emoji_panel_page -= 1
|
||||||
|
elif self.emoji_panel_page > 0 and y >= curses.LINES // 2 - 1 and x >= curses.COLS // 2 - 1:
|
||||||
|
pad_y, pad_x = y - (curses.LINES - curses.LINES // 2) + 1, \
|
||||||
|
(x - (curses.COLS - curses.COLS // 3) + 1) // 2
|
||||||
|
# Click on an emoji on the pad to autocomplete an emoji
|
||||||
|
self.click_on_emoji_pad(pad_y, pad_x)
|
||||||
|
|
||||||
|
def click_on_emoji_pad(self, pad_y: int, pad_x: int) -> None:
|
||||||
|
"""
|
||||||
|
The emoji pad contains the list of all available emojis.
|
||||||
|
Clicking on a emoji auto-complete the emoji in the input pad.
|
||||||
|
"""
|
||||||
|
import emoji
|
||||||
|
from emoji import unicode_codes
|
||||||
|
|
||||||
|
height, width = self.emoji_pad.getmaxyx()
|
||||||
|
height -= 1
|
||||||
|
width -= 1
|
||||||
|
|
||||||
|
emojis = list(unicode_codes.UNICODE_EMOJI)
|
||||||
|
emojis = [c for c in emojis if len(c) == 1]
|
||||||
|
size = (height - 2) * (width - 4) // 2
|
||||||
|
page = emojis[(self.emoji_panel_page - 1) * size:self.emoji_panel_page * size]
|
||||||
|
index = pad_y * (width - 4) // 2 + pad_x
|
||||||
|
char = page[index]
|
||||||
|
if char:
|
||||||
|
demojized = emoji.demojize(char)
|
||||||
|
if char != demojized:
|
||||||
|
for c in reversed(demojized):
|
||||||
|
curses.ungetch(c)
|
||||||
|
|
||||||
|
def add_message(self, msg: str) -> None:
|
||||||
|
"""
|
||||||
|
Store a new message into the history.
|
||||||
|
"""
|
||||||
|
self.history.append(msg)
|
||||||
|
if self.last_line == len(self.history) - 2:
|
||||||
|
self.last_line += 1
|
||||||
|
|
||||||
|
def add_system_message(self, msg: str) -> None:
|
||||||
|
"""
|
||||||
|
Add a new system log message.
|
||||||
|
TODO: Configure logging levels to ignore some messages.
|
||||||
|
"""
|
||||||
|
return self.add_message(f"<system> *{msg}*")
|
||||||
|
|
||||||
|
def print_markdown(self, pad: Any, y: int, x: int, msg: str,
|
||||||
|
bold: bool = False, italic: bool = False, underline: bool = False, strike: bool = False) -> int:
|
||||||
|
"""
|
||||||
|
Parse a markdown-formatted text and format the text as bold, italic or text text.
|
||||||
|
***text***: bold, italic
|
||||||
|
**text**: bold
|
||||||
|
*text*: italic
|
||||||
|
__text__: underline
|
||||||
|
_text_: italic
|
||||||
|
~~text~~: strikethrough
|
||||||
|
"""
|
||||||
|
# Replace :emoji_name: by the good emoji
|
||||||
|
if not self.squinnondation.no_emoji:
|
||||||
|
import emoji
|
||||||
|
msg = emoji.emojize(msg, use_aliases=True)
|
||||||
|
|
||||||
|
if self.squinnondation.no_markdown:
|
||||||
|
pad.addstr(y, x, msg)
|
||||||
|
return len(msg)
|
||||||
|
|
||||||
|
underline_match = re.match("(.*)__(.*)__(.*)", msg)
|
||||||
|
if underline_match:
|
||||||
|
before, text, after = underline_match.group(1), underline_match.group(2), underline_match.group(3)
|
||||||
|
len_before = self.print_markdown(pad, y, x, before, bold, italic, underline)
|
||||||
|
len_mid = self.print_markdown(pad, y, x + len_before, text, bold, italic, not underline)
|
||||||
|
len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline)
|
||||||
|
return len_before + len_mid + len_after
|
||||||
|
|
||||||
|
italic_match = re.match("(.*)_(.*)_(.*)", msg)
|
||||||
|
if italic_match:
|
||||||
|
before, text, after = italic_match.group(1), italic_match.group(2), italic_match.group(3)
|
||||||
|
len_before = self.print_markdown(pad, y, x, before, bold, italic, underline)
|
||||||
|
len_mid = self.print_markdown(pad, y, x + len_before, text, bold, not italic, underline)
|
||||||
|
len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline)
|
||||||
|
return len_before + len_mid + len_after
|
||||||
|
|
||||||
|
bold_italic_match = re.match("(.*)\\*\\*\\*(.*)\\*\\*\\*(.*)", msg)
|
||||||
|
if bold_italic_match:
|
||||||
|
before, text, after = bold_italic_match.group(1), bold_italic_match.group(2),\
|
||||||
|
bold_italic_match.group(3)
|
||||||
|
len_before = self.print_markdown(pad, y, x, before, bold, italic, underline, strike)
|
||||||
|
len_mid = self.print_markdown(pad, y, x + len_before, text, not bold, not italic, underline, strike)
|
||||||
|
len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline, strike)
|
||||||
|
return len_before + len_mid + len_after
|
||||||
|
|
||||||
|
bold_match = re.match("(.*)\\*\\*(.*)\\*\\*(.*)", msg)
|
||||||
|
if bold_match:
|
||||||
|
before, text, after = bold_match.group(1), bold_match.group(2), bold_match.group(3)
|
||||||
|
len_before = self.print_markdown(pad, y, x, before, bold, italic, underline, strike)
|
||||||
|
len_mid = self.print_markdown(pad, y, x + len_before, text, not bold, italic, underline, strike)
|
||||||
|
len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline, strike)
|
||||||
|
return len_before + len_mid + len_after
|
||||||
|
|
||||||
|
italic_match = re.match("(.*)\\*(.*)\\*(.*)", msg)
|
||||||
|
if italic_match:
|
||||||
|
before, text, after = italic_match.group(1), italic_match.group(2), italic_match.group(3)
|
||||||
|
len_before = self.print_markdown(pad, y, x, before, bold, italic, underline, strike)
|
||||||
|
len_mid = self.print_markdown(pad, y, x + len_before, text, bold, not italic, underline, strike)
|
||||||
|
len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline, strike)
|
||||||
|
return len_before + len_mid + len_after
|
||||||
|
|
||||||
|
strike_match = re.match("(.*)~~(.*)~~(.*)", msg)
|
||||||
|
if strike_match:
|
||||||
|
before, text, after = strike_match.group(1), strike_match.group(2), strike_match.group(3)
|
||||||
|
len_before = self.print_markdown(pad, y, x, before, bold, italic, underline, strike)
|
||||||
|
len_mid = self.print_markdown(pad, y, x + len_before, text, bold, italic, underline, not strike)
|
||||||
|
len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline, strike)
|
||||||
|
return len_before + len_mid + len_after
|
||||||
|
|
||||||
|
size = len(msg)
|
||||||
|
|
||||||
|
attrs = 0
|
||||||
|
attrs |= curses.A_BOLD if bold else 0
|
||||||
|
attrs |= curses.A_ITALIC if italic else 0
|
||||||
|
attrs |= curses.A_UNDERLINE if underline else 0
|
||||||
|
if strike:
|
||||||
|
msg = "".join(c + "\u0336" for c in msg)
|
||||||
|
pad.addstr(y, x, msg, attrs)
|
||||||
|
return size
|
||||||
|
|
||||||
|
def refresh_history(self) -> None:
|
||||||
|
"""
|
||||||
|
Rewrite the history of the messages.
|
||||||
|
"""
|
||||||
|
y, x = self.squinnondation.screen.getmaxyx()
|
||||||
|
if curses.is_term_resized(curses.LINES, curses.COLS):
|
||||||
|
curses.resizeterm(y, x)
|
||||||
|
self.history_pad.resize(curses.LINES - 2, curses.COLS - 1)
|
||||||
|
self.input_pad.resize(1, curses.COLS - 1)
|
||||||
|
|
||||||
|
self.history_pad.erase()
|
||||||
|
for i, msg in enumerate(self.history[max(0, self.last_line - curses.LINES + 3):self.last_line + 1]):
|
||||||
|
if not re.match("<.*> .*", msg):
|
||||||
|
msg = "<unknown> " + msg
|
||||||
|
match = re.match("<(.*)> (.*)", msg)
|
||||||
|
nickname = match.group(1)
|
||||||
|
msg = match.group(2)
|
||||||
|
color_id = sum(ord(c) for c in nickname) % 6 + 1
|
||||||
|
self.history_pad.addstr(i, 0, "<")
|
||||||
|
self.history_pad.addstr(i, 1, nickname, curses.A_BOLD | curses.color_pair(color_id + 1))
|
||||||
|
self.history_pad.addstr(i, 1 + len(nickname), "> ")
|
||||||
|
self.print_markdown(self.history_pad, i, 3 + len(nickname), msg)
|
||||||
|
self.history_pad.refresh(0, 0, 0, 0, curses.LINES - 2, curses.COLS)
|
||||||
|
|
||||||
|
def refresh_input(self) -> None:
|
||||||
|
"""
|
||||||
|
Redraw input line. Must not be called while the message is not sent.
|
||||||
|
"""
|
||||||
|
self.input_pad.erase()
|
||||||
|
color_id = sum(ord(c) for c in self.nickname) % 6 + 1
|
||||||
|
self.input_pad.addstr(0, 0, "<")
|
||||||
|
self.input_pad.addstr(0, 1, self.nickname, curses.A_BOLD | curses.color_pair(color_id + 1))
|
||||||
|
self.input_pad.addstr(0, 1 + len(self.nickname), "> ")
|
||||||
|
msg = self.input_buffer
|
||||||
|
if len(msg) + len(self.nickname) + 3 >= curses.COLS:
|
||||||
|
msg = ""
|
||||||
|
self.input_pad.addstr(0, 3 + len(self.nickname), self.input_buffer)
|
||||||
|
if not self.squinnondation.no_emoji:
|
||||||
|
self.input_pad.addstr(0, self.input_pad.getmaxyx()[1] - 3, "😀")
|
||||||
|
self.input_pad.refresh(0, 0, curses.LINES - 1, 0, curses.LINES - 1, curses.COLS - 1)
|
||||||
|
|
||||||
|
def refresh_emoji_pad(self) -> None:
|
||||||
|
"""
|
||||||
|
Display the emoji pad if necessary.
|
||||||
|
"""
|
||||||
|
if self.squinnondation.no_emoji:
|
||||||
|
return
|
||||||
|
|
||||||
|
from emoji import unicode_codes
|
||||||
|
|
||||||
|
self.emoji_pad.erase()
|
||||||
|
|
||||||
|
if self.emoji_panel_page > 0:
|
||||||
|
height, width = curses.LINES // 2, curses.COLS // 3
|
||||||
|
self.emoji_pad.resize(height + 1, width + 1)
|
||||||
|
self.emoji_pad.addstr(0, 0, "┏" + (width - 2) * "━" + "┓")
|
||||||
|
self.emoji_pad.addstr(0, (width - 14) // 2, " == EMOJIS == ")
|
||||||
|
for i in range(1, height):
|
||||||
|
self.emoji_pad.addstr(i, 0, "┃" + (width - 2) * " " + "┃")
|
||||||
|
self.emoji_pad.addstr(height - 1, 0, "┗" + (width - 2) * "━" + "┛")
|
||||||
|
|
||||||
|
emojis = list(unicode_codes.UNICODE_EMOJI)
|
||||||
|
emojis = [c for c in emojis if len(c) == 1]
|
||||||
|
size = (height - 2) * (width - 4) // 2
|
||||||
|
page = emojis[(self.emoji_panel_page - 1) * size:self.emoji_panel_page * size]
|
||||||
|
|
||||||
|
if self.emoji_panel_page != 1:
|
||||||
|
self.emoji_pad.addstr(1, width - 2, "⬆")
|
||||||
|
if len(page) == size:
|
||||||
|
self.emoji_pad.addstr(height - 2, width - 2, "⬇")
|
||||||
|
|
||||||
|
for i in range(height - 2):
|
||||||
|
for j in range((width - 4) // 2 + 1):
|
||||||
|
index = i * (width - 4) // 2 + j
|
||||||
|
if index < len(page):
|
||||||
|
self.emoji_pad.addstr(i + 1, 2 * j + 1, page[index])
|
||||||
|
|
||||||
|
self.emoji_pad.refresh(0, 0, curses.LINES - height - 2, curses.COLS - width - 2,
|
||||||
|
curses.LINES - 2, curses.COLS - 2)
|
||||||
|
|
||||||
|
|
||||||
|
class Worm(Thread):
|
||||||
|
"""
|
||||||
|
The worm is the hazel listener.
|
||||||
|
It always waits for an incoming packet, then it treats it, and continues to wait.
|
||||||
|
It is in a dedicated thread.
|
||||||
|
"""
|
||||||
|
def __init__(self, squirrel: Squirrel, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.squirrel = squirrel
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
pkt, hazelnut = self.squirrel.receive_packet()
|
||||||
|
pkt.validate_data()
|
||||||
|
except ValueError as error:
|
||||||
|
self.squirrel.add_system_message("An error occurred while receiving a packet: {}".format(error))
|
||||||
|
else:
|
||||||
|
for tlv in pkt.body:
|
||||||
|
tlv.handle(self.squirrel, hazelnut)
|
||||||
|
self.squirrel.refresh_history()
|
||||||
|
self.squirrel.refresh_input()
|
|
@ -0,0 +1,364 @@
|
||||||
|
# Copyright (C) 2020 by eichhornchen, ÿnérant
|
||||||
|
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||||
|
|
||||||
|
from typing import Any, List, Optional, Tuple
|
||||||
|
from ipaddress import IPv6Address
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
class TLV:
|
||||||
|
"""
|
||||||
|
The Tag-Length-Value contains the different type of data that can be sent.
|
||||||
|
TODO: add subclasses for each type of TLV
|
||||||
|
"""
|
||||||
|
type: int
|
||||||
|
length: int
|
||||||
|
|
||||||
|
def unmarshal(self, raw_data: bytes) -> None:
|
||||||
|
"""
|
||||||
|
Parse data and construct TLV.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def marshal(self) -> bytes:
|
||||||
|
"""
|
||||||
|
Translate the TLV into a byte array.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def validate_data(self) -> bool:
|
||||||
|
"""
|
||||||
|
Ensure that the TLV is well-formed.
|
||||||
|
Raises a ValueError if it is not the case.
|
||||||
|
TODO: Make some tests
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def handle(self, squirrel: Any, sender: Any) -> None:
|
||||||
|
"""
|
||||||
|
Indicates what to do when this TLV is received from a given hazel.
|
||||||
|
It is ensured that the data is valid.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tlv_length(self) -> int:
|
||||||
|
"""
|
||||||
|
Returns the total length (in bytes) of the TLV, including the type and the length.
|
||||||
|
Except for Pad1, this is 2 plus the length of the body of the TLV.
|
||||||
|
"""
|
||||||
|
return 2 + self.length
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def tlv_classes() -> list:
|
||||||
|
return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def network_order() -> str:
|
||||||
|
"""
|
||||||
|
The network byte order is always inverted as the host network byte order.
|
||||||
|
"""
|
||||||
|
return "little" if sys.byteorder == "big" else "big"
|
||||||
|
|
||||||
|
|
||||||
|
class Pad1TLV(TLV):
|
||||||
|
"""
|
||||||
|
This TLV is simply ignored.
|
||||||
|
"""
|
||||||
|
type: int = 0
|
||||||
|
|
||||||
|
def unmarshal(self, raw_data: bytes) -> None:
|
||||||
|
"""
|
||||||
|
There is nothing to do. We ignore the packet.
|
||||||
|
"""
|
||||||
|
self.type = raw_data[0]
|
||||||
|
|
||||||
|
def marshal(self) -> bytes:
|
||||||
|
"""
|
||||||
|
The TLV is empty.
|
||||||
|
"""
|
||||||
|
return self.type.to_bytes(1, TLV.network_order())
|
||||||
|
|
||||||
|
def handle(self, squirrel: Any, sender: Any) -> None:
|
||||||
|
# TODO Add some easter eggs
|
||||||
|
squirrel.add_system_message("For each byte in the packet that I received, you will die today. And eat cookies.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tlv_length(self) -> int:
|
||||||
|
"""
|
||||||
|
A Pad1 has always a length of 1.
|
||||||
|
"""
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
class PadNTLV(TLV):
|
||||||
|
"""
|
||||||
|
This TLV is filled with zeros. It is ignored.
|
||||||
|
"""
|
||||||
|
type: int = 1
|
||||||
|
length: int
|
||||||
|
mbz: bytes
|
||||||
|
|
||||||
|
def validate_data(self) -> bool:
|
||||||
|
if self.mbz != int(0).to_bytes(self.length, TLV.network_order()):
|
||||||
|
raise ValueError("The body of a PadN TLV is not filled with zeros.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def unmarshal(self, raw_data: bytes) -> None:
|
||||||
|
"""
|
||||||
|
Store the zero-array, then ignore the packet.
|
||||||
|
"""
|
||||||
|
self.type = raw_data[0]
|
||||||
|
self.length = raw_data[1]
|
||||||
|
self.mbz = raw_data[2:self.tlv_length]
|
||||||
|
|
||||||
|
def marshal(self) -> bytes:
|
||||||
|
"""
|
||||||
|
Construct the byte array filled by zeros.
|
||||||
|
"""
|
||||||
|
return self.type.to_bytes(1, TLV.network_order()) + self.length.to_bytes(1, TLV.network_order()) \
|
||||||
|
+ self.mbz[:self.length]
|
||||||
|
|
||||||
|
def handle(self, squirrel: Any, sender: Any) -> None:
|
||||||
|
# TODO Add some easter eggs
|
||||||
|
squirrel.add_system_message(f"I received {self.length} zeros, am I so a bag guy ? :cold_sweat:")
|
||||||
|
|
||||||
|
|
||||||
|
class HelloTLV(TLV):
|
||||||
|
type: int = 2
|
||||||
|
length: int
|
||||||
|
source_id: int
|
||||||
|
dest_id: Optional[int]
|
||||||
|
|
||||||
|
def validate_data(self) -> bool:
|
||||||
|
if self.length != 8 and self.length != 16:
|
||||||
|
raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello,"
|
||||||
|
f"found {self.length}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def unmarshal(self, raw_data: bytes) -> None:
|
||||||
|
self.type = raw_data[0]
|
||||||
|
self.length = raw_data[1]
|
||||||
|
self.source_id = int.from_bytes(raw_data[2:10], TLV.network_order())
|
||||||
|
if self.is_long:
|
||||||
|
self.dest_id = int.from_bytes(raw_data[10:18], TLV.network_order())
|
||||||
|
|
||||||
|
def marshal(self) -> bytes:
|
||||||
|
data = self.type.to_bytes(1, TLV.network_order()) + self.length.to_bytes(1, TLV.network_order()) \
|
||||||
|
+ self.source_id.to_bytes(8, TLV.network_order())
|
||||||
|
if self.dest_id:
|
||||||
|
data += self.dest_id.to_bytes(8, TLV.network_order())
|
||||||
|
return data
|
||||||
|
|
||||||
|
def handle(self, squirrel: Any, sender: Any) -> None:
|
||||||
|
# TODO Implement HelloTLV
|
||||||
|
squirrel.add_system_message("Aaaawwww, someone spoke to me and said me Hello smiling_face_with_"
|
||||||
|
+ (":3_hearts:" if self.is_long else "smiling_eyes:"))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_long(self) -> bool:
|
||||||
|
return self.length == 16
|
||||||
|
|
||||||
|
|
||||||
|
class NeighbourTLV(TLV):
|
||||||
|
type: int = 3
|
||||||
|
length: int
|
||||||
|
ip_address: IPv6Address
|
||||||
|
port: int
|
||||||
|
|
||||||
|
def unmarshal(self, raw_data: bytes) -> None:
|
||||||
|
self.type = raw_data[0]
|
||||||
|
self.length = raw_data[1]
|
||||||
|
self.ip_address = IPv6Address(raw_data[2:18])
|
||||||
|
self.port = int.from_bytes(raw_data[18:20], TLV.network_order())
|
||||||
|
|
||||||
|
def marshal(self) -> bytes:
|
||||||
|
return self.type.to_bytes(1, TLV.network_order()) + \
|
||||||
|
self.length.to_bytes(1, TLV.network_order()) + \
|
||||||
|
self.ip_address.packed + \
|
||||||
|
self.port.to_bytes(2, TLV.network_order())
|
||||||
|
|
||||||
|
def handle(self, squirrel: Any, sender: Any) -> None:
|
||||||
|
# TODO Implement NeighbourTLV
|
||||||
|
squirrel.add_system_message("I have a friend!")
|
||||||
|
squirrel.add_system_message(f"Welcome {self.ip_address}:{self.port}!")
|
||||||
|
|
||||||
|
|
||||||
|
class DataTLV(TLV):
|
||||||
|
type: int = 4
|
||||||
|
length: int
|
||||||
|
sender_id: int
|
||||||
|
nonce: int
|
||||||
|
data: bytes
|
||||||
|
|
||||||
|
def unmarshal(self, raw_data: bytes) -> None:
|
||||||
|
self.type = raw_data[0]
|
||||||
|
self.length = raw_data[1]
|
||||||
|
self.sender_id = int.from_bytes(raw_data[2:10], TLV.network_order())
|
||||||
|
self.nonce = int.from_bytes(raw_data[10:14], TLV.network_order())
|
||||||
|
self.data = raw_data[14:self.tlv_length]
|
||||||
|
|
||||||
|
def marshal(self) -> bytes:
|
||||||
|
return self.type.to_bytes(1, TLV.network_order()) + \
|
||||||
|
self.length.to_bytes(1, TLV.network_order()) + \
|
||||||
|
self.sender_id.to_bytes(8, TLV.network_order()) + \
|
||||||
|
self.nonce.to_bytes(4, TLV.network_order()) + \
|
||||||
|
self.data
|
||||||
|
|
||||||
|
def handle(self, squirrel: Any, sender: Any) -> None:
|
||||||
|
"""
|
||||||
|
A message has been sent. We log it.
|
||||||
|
TODO: Check that the tuple (sender_id, nonce) is unique to avoid duplicates.
|
||||||
|
"""
|
||||||
|
squirrel.add_message(self.data.decode('UTF-8'))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def construct(message: str) -> "DataTLV":
|
||||||
|
tlv = DataTLV()
|
||||||
|
tlv.type = 4
|
||||||
|
tlv.sender_id = 42 # FIXME Use the good sender id
|
||||||
|
tlv.nonce = 42 # FIXME Use an incremental nonce
|
||||||
|
tlv.data = message.encode("UTF-8")
|
||||||
|
tlv.length = 12 + len(tlv.data)
|
||||||
|
return tlv
|
||||||
|
|
||||||
|
|
||||||
|
class AckTLV(TLV):
|
||||||
|
type: int = 5
|
||||||
|
length: int
|
||||||
|
sender_id: int
|
||||||
|
nonce: int
|
||||||
|
|
||||||
|
def unmarshal(self, raw_data: bytes) -> None:
|
||||||
|
self.type = raw_data[0]
|
||||||
|
self.length = raw_data[1]
|
||||||
|
self.sender_id = int.from_bytes(raw_data[2:10], TLV.network_order())
|
||||||
|
self.nonce = int.from_bytes(raw_data[10:14], TLV.network_order())
|
||||||
|
|
||||||
|
def marshal(self) -> bytes:
|
||||||
|
return self.type.to_bytes(1, TLV.network_order()) + \
|
||||||
|
self.length.to_bytes(1, TLV.network_order()) + \
|
||||||
|
self.sender_id.to_bytes(8, TLV.network_order()) + \
|
||||||
|
self.nonce.to_bytes(4, TLV.network_order())
|
||||||
|
|
||||||
|
def handle(self, squirrel: Any, sender: Any) -> None:
|
||||||
|
# TODO Implement AckTLV
|
||||||
|
squirrel.add_system_message("I received an AckTLV. I don't know what to do with it. Please implement me!")
|
||||||
|
|
||||||
|
|
||||||
|
class GoAwayTLV(TLV):
|
||||||
|
class GoAwayType(Enum):
|
||||||
|
UNKNOWN = 0
|
||||||
|
EXIT = 1
|
||||||
|
TIMEOUT = 2
|
||||||
|
PROTOCOL_VIOLATION = 3
|
||||||
|
|
||||||
|
type: int = 6
|
||||||
|
length: int
|
||||||
|
code: GoAwayType
|
||||||
|
message: str
|
||||||
|
|
||||||
|
def unmarshal(self, raw_data: bytes) -> None:
|
||||||
|
self.type = raw_data[0]
|
||||||
|
self.length = raw_data[1]
|
||||||
|
self.code = GoAwayTLV.GoAwayType(raw_data[2])
|
||||||
|
self.message = raw_data[3:self.length - 1].decode("UTF-8")
|
||||||
|
|
||||||
|
def marshal(self) -> bytes:
|
||||||
|
return self.type.to_bytes(1, TLV.network_order()) + \
|
||||||
|
self.length.to_bytes(1, TLV.network_order()) + \
|
||||||
|
self.code.value.to_bytes(1, TLV.network_order()) + \
|
||||||
|
self.message.encode("UTF-8")[:self.length - 1]
|
||||||
|
|
||||||
|
def handle(self, squirrel: Any, sender: Any) -> None:
|
||||||
|
# TODO Implement GoAwayTLV
|
||||||
|
squirrel.add_system_message("Some told me that he went away. That's not very nice :( "
|
||||||
|
"I should send him some cake.")
|
||||||
|
|
||||||
|
|
||||||
|
class WarningTLV(TLV):
|
||||||
|
type: int = 7
|
||||||
|
length: int
|
||||||
|
message: str
|
||||||
|
|
||||||
|
def unmarshal(self, raw_data: bytes) -> None:
|
||||||
|
self.type = raw_data[0]
|
||||||
|
self.length = raw_data[1]
|
||||||
|
self.message = raw_data[2:self.length].decode("UTF-8")
|
||||||
|
|
||||||
|
def marshal(self) -> bytes:
|
||||||
|
return self.type.to_bytes(1, TLV.network_order()) + \
|
||||||
|
self.length.to_bytes(1, TLV.network_order()) + \
|
||||||
|
self.message.encode("UTF-8")[:self.length]
|
||||||
|
|
||||||
|
def handle(self, squirrel: Any, sender: Any) -> None:
|
||||||
|
squirrel.add_message(f"<warning> *A client warned you: {self.message}*")
|
||||||
|
|
||||||
|
|
||||||
|
class Packet:
|
||||||
|
"""
|
||||||
|
A Packet is a wrapper around the
|
||||||
|
"""
|
||||||
|
magic: int
|
||||||
|
version: int
|
||||||
|
body_length: int
|
||||||
|
body: List[TLV]
|
||||||
|
|
||||||
|
def validate_data(self) -> bool:
|
||||||
|
"""
|
||||||
|
Ensure that the packet is well-formed.
|
||||||
|
Raises a ValueError if the packet contains bad data.
|
||||||
|
"""
|
||||||
|
if self.magic != 95:
|
||||||
|
raise ValueError("The magic code of the packet must be 95, found: {:d}".format(self.magic))
|
||||||
|
if self.version != 0:
|
||||||
|
raise ValueError("The version of the packet is not supported: {:d}".format(self.version))
|
||||||
|
if not (0 <= self.body_length <= 120):
|
||||||
|
raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1020,"
|
||||||
|
"found: {:d}".format(self.body_length))
|
||||||
|
return all(tlv.validate_data() for tlv in self.body)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def unmarshal(data: bytes) -> "Packet":
|
||||||
|
"""
|
||||||
|
Read raw data and build the packet wrapper.
|
||||||
|
Raises a ValueError whenever the data is invalid.
|
||||||
|
"""
|
||||||
|
pkt = Packet()
|
||||||
|
pkt.magic = data[0]
|
||||||
|
pkt.version = data[1]
|
||||||
|
pkt.body_length = int.from_bytes(data[2:4], byteorder=TLV.network_order())
|
||||||
|
pkt.body = []
|
||||||
|
read_bytes = 0
|
||||||
|
while read_bytes <= min(len(data) - 4, pkt.body_length):
|
||||||
|
tlv_type = data[4]
|
||||||
|
if not (0 <= tlv_type < len(TLV.tlv_classes())):
|
||||||
|
raise ValueError(f"TLV type is not supported: {tlv_type}")
|
||||||
|
tlv = TLV.tlv_classes()[tlv_type]()
|
||||||
|
tlv.unmarshal(data[4:4 + pkt.body_length])
|
||||||
|
pkt.body.append(tlv)
|
||||||
|
read_bytes += tlv.tlv_length
|
||||||
|
|
||||||
|
pkt.validate_data()
|
||||||
|
|
||||||
|
return pkt
|
||||||
|
|
||||||
|
def marshal(self) -> bytes:
|
||||||
|
"""
|
||||||
|
Compute the byte array data associated to the packet.
|
||||||
|
"""
|
||||||
|
data = self.magic.to_bytes(1, TLV.network_order())
|
||||||
|
data += self.version.to_bytes(1, TLV.network_order())
|
||||||
|
data += self.body_length.to_bytes(2, TLV.network_order())
|
||||||
|
data += b"".join(tlv.marshal() for tlv in self.body)
|
||||||
|
return data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def construct(*tlvs: TLV) -> "Packet":
|
||||||
|
"""
|
||||||
|
Construct a new packet from the given TLVs and calculate the good lengths
|
||||||
|
"""
|
||||||
|
pkt = Packet()
|
||||||
|
pkt.magic = 95
|
||||||
|
pkt.version = 0
|
||||||
|
pkt.body = tlvs
|
||||||
|
pkt.body_length = sum(tlv.tlv_length for tlv in tlvs)
|
||||||
|
return pkt
|
|
@ -12,6 +12,7 @@ from threading import Thread
|
||||||
from typing import Any, List, Optional, Tuple
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
from squinnondation.term_manager import TermManager
|
from squinnondation.term_manager import TermManager
|
||||||
|
from hazel import Hazelnut, Squirrel
|
||||||
|
|
||||||
|
|
||||||
class Squinnondation:
|
class Squinnondation:
|
||||||
|
@ -83,783 +84,3 @@ class Squinnondation:
|
||||||
|
|
||||||
Worm(squirrel).start()
|
Worm(squirrel).start()
|
||||||
squirrel.wait_for_key()
|
squirrel.wait_for_key()
|
||||||
|
|
||||||
|
|
||||||
class TLV:
|
|
||||||
"""
|
|
||||||
The Tag-Length-Value contains the different type of data that can be sent.
|
|
||||||
TODO: add subclasses for each type of TLV
|
|
||||||
"""
|
|
||||||
type: int
|
|
||||||
length: int
|
|
||||||
|
|
||||||
def unmarshal(self, raw_data: bytes) -> None:
|
|
||||||
"""
|
|
||||||
Parse data and construct TLV.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def marshal(self) -> bytes:
|
|
||||||
"""
|
|
||||||
Translate the TLV into a byte array.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def validate_data(self) -> bool:
|
|
||||||
"""
|
|
||||||
Ensure that the TLV is well-formed.
|
|
||||||
Raises a ValueError if it is not the case.
|
|
||||||
TODO: Make some tests
|
|
||||||
"""
|
|
||||||
return True
|
|
||||||
|
|
||||||
def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None:
|
|
||||||
"""
|
|
||||||
Indicates what to do when this TLV is received from a given hazel.
|
|
||||||
It is ensured that the data is valid.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def tlv_length(self) -> int:
|
|
||||||
"""
|
|
||||||
Returns the total length (in bytes) of the TLV, including the type and the length.
|
|
||||||
Except for Pad1, this is 2 plus the length of the body of the TLV.
|
|
||||||
"""
|
|
||||||
return 2 + self.length
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def tlv_classes() -> list:
|
|
||||||
return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV]
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def network_order() -> str:
|
|
||||||
"""
|
|
||||||
The network byte order is always inverted as the host network byte order.
|
|
||||||
"""
|
|
||||||
return "little" if sys.byteorder == "big" else "big"
|
|
||||||
|
|
||||||
|
|
||||||
class Pad1TLV(TLV):
|
|
||||||
"""
|
|
||||||
This TLV is simply ignored.
|
|
||||||
"""
|
|
||||||
type: int = 0
|
|
||||||
|
|
||||||
def unmarshal(self, raw_data: bytes) -> None:
|
|
||||||
"""
|
|
||||||
There is nothing to do. We ignore the packet.
|
|
||||||
"""
|
|
||||||
self.type = raw_data[0]
|
|
||||||
|
|
||||||
def marshal(self) -> bytes:
|
|
||||||
"""
|
|
||||||
The TLV is empty.
|
|
||||||
"""
|
|
||||||
return self.type.to_bytes(1, TLV.network_order())
|
|
||||||
|
|
||||||
def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None:
|
|
||||||
# TODO Add some easter eggs
|
|
||||||
squirrel.add_system_message("For each byte in the packet that I received, you will die today. And eat cookies.")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def tlv_length(self) -> int:
|
|
||||||
"""
|
|
||||||
A Pad1 has always a length of 1.
|
|
||||||
"""
|
|
||||||
return 1
|
|
||||||
|
|
||||||
|
|
||||||
class PadNTLV(TLV):
|
|
||||||
"""
|
|
||||||
This TLV is filled with zeros. It is ignored.
|
|
||||||
"""
|
|
||||||
type: int = 1
|
|
||||||
length: int
|
|
||||||
mbz: bytes
|
|
||||||
|
|
||||||
def validate_data(self) -> bool:
|
|
||||||
if self.mbz != int(0).to_bytes(self.length, TLV.network_order()):
|
|
||||||
raise ValueError("The body of a PadN TLV is not filled with zeros.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def unmarshal(self, raw_data: bytes) -> None:
|
|
||||||
"""
|
|
||||||
Store the zero-array, then ignore the packet.
|
|
||||||
"""
|
|
||||||
self.type = raw_data[0]
|
|
||||||
self.length = raw_data[1]
|
|
||||||
self.mbz = raw_data[2:self.tlv_length]
|
|
||||||
|
|
||||||
def marshal(self) -> bytes:
|
|
||||||
"""
|
|
||||||
Construct the byte array filled by zeros.
|
|
||||||
"""
|
|
||||||
return self.type.to_bytes(1, TLV.network_order()) + self.length.to_bytes(1, TLV.network_order()) \
|
|
||||||
+ self.mbz[:self.length]
|
|
||||||
|
|
||||||
def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None:
|
|
||||||
# TODO Add some easter eggs
|
|
||||||
squirrel.add_system_message(f"I received {self.length} zeros, am I so a bag guy ? :cold_sweat:")
|
|
||||||
|
|
||||||
|
|
||||||
class HelloTLV(TLV):
|
|
||||||
type: int = 2
|
|
||||||
length: int
|
|
||||||
source_id: int
|
|
||||||
dest_id: Optional[int]
|
|
||||||
|
|
||||||
def validate_data(self) -> bool:
|
|
||||||
if self.length != 8 and self.length != 16:
|
|
||||||
raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello,"
|
|
||||||
f"found {self.length}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def unmarshal(self, raw_data: bytes) -> None:
|
|
||||||
self.type = raw_data[0]
|
|
||||||
self.length = raw_data[1]
|
|
||||||
self.source_id = int.from_bytes(raw_data[2:10], TLV.network_order())
|
|
||||||
if self.is_long:
|
|
||||||
self.dest_id = int.from_bytes(raw_data[10:18], TLV.network_order())
|
|
||||||
|
|
||||||
def marshal(self) -> bytes:
|
|
||||||
data = self.type.to_bytes(1, TLV.network_order()) + self.length.to_bytes(1, TLV.network_order()) \
|
|
||||||
+ self.source_id.to_bytes(8, TLV.network_order())
|
|
||||||
if self.dest_id:
|
|
||||||
data += self.dest_id.to_bytes(8, TLV.network_order())
|
|
||||||
return data
|
|
||||||
|
|
||||||
def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None:
|
|
||||||
# TODO Implement HelloTLV
|
|
||||||
squirrel.add_system_message("Aaaawwww, someone spoke to me and said me Hello smiling_face_with_"
|
|
||||||
+ (":3_hearts:" if self.is_long else "smiling_eyes:"))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_long(self) -> bool:
|
|
||||||
return self.length == 16
|
|
||||||
|
|
||||||
|
|
||||||
class NeighbourTLV(TLV):
|
|
||||||
type: int = 3
|
|
||||||
length: int
|
|
||||||
ip_address: IPv6Address
|
|
||||||
port: int
|
|
||||||
|
|
||||||
def unmarshal(self, raw_data: bytes) -> None:
|
|
||||||
self.type = raw_data[0]
|
|
||||||
self.length = raw_data[1]
|
|
||||||
self.ip_address = IPv6Address(raw_data[2:18])
|
|
||||||
self.port = int.from_bytes(raw_data[18:20], TLV.network_order())
|
|
||||||
|
|
||||||
def marshal(self) -> bytes:
|
|
||||||
return self.type.to_bytes(1, TLV.network_order()) + \
|
|
||||||
self.length.to_bytes(1, TLV.network_order()) + \
|
|
||||||
self.ip_address.packed + \
|
|
||||||
self.port.to_bytes(2, TLV.network_order())
|
|
||||||
|
|
||||||
def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None:
|
|
||||||
# TODO Implement NeighbourTLV
|
|
||||||
squirrel.add_system_message("I have a friend!")
|
|
||||||
squirrel.add_system_message(f"Welcome {self.ip_address}:{self.port}!")
|
|
||||||
|
|
||||||
|
|
||||||
class DataTLV(TLV):
|
|
||||||
type: int = 4
|
|
||||||
length: int
|
|
||||||
sender_id: int
|
|
||||||
nonce: int
|
|
||||||
data: bytes
|
|
||||||
|
|
||||||
def unmarshal(self, raw_data: bytes) -> None:
|
|
||||||
self.type = raw_data[0]
|
|
||||||
self.length = raw_data[1]
|
|
||||||
self.sender_id = int.from_bytes(raw_data[2:10], TLV.network_order())
|
|
||||||
self.nonce = int.from_bytes(raw_data[10:14], TLV.network_order())
|
|
||||||
self.data = raw_data[14:self.tlv_length]
|
|
||||||
|
|
||||||
def marshal(self) -> bytes:
|
|
||||||
return self.type.to_bytes(1, TLV.network_order()) + \
|
|
||||||
self.length.to_bytes(1, TLV.network_order()) + \
|
|
||||||
self.sender_id.to_bytes(8, TLV.network_order()) + \
|
|
||||||
self.nonce.to_bytes(4, TLV.network_order()) + \
|
|
||||||
self.data
|
|
||||||
|
|
||||||
def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None:
|
|
||||||
"""
|
|
||||||
A message has been sent. We log it.
|
|
||||||
TODO: Check that the tuple (sender_id, nonce) is unique to avoid duplicates.
|
|
||||||
"""
|
|
||||||
squirrel.add_message(self.data.decode('UTF-8'))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def construct(message: str) -> "DataTLV":
|
|
||||||
tlv = DataTLV()
|
|
||||||
tlv.type = 4
|
|
||||||
tlv.sender_id = 42 # FIXME Use the good sender id
|
|
||||||
tlv.nonce = 42 # FIXME Use an incremental nonce
|
|
||||||
tlv.data = message.encode("UTF-8")
|
|
||||||
tlv.length = 12 + len(tlv.data)
|
|
||||||
return tlv
|
|
||||||
|
|
||||||
|
|
||||||
class AckTLV(TLV):
|
|
||||||
type: int = 5
|
|
||||||
length: int
|
|
||||||
sender_id: int
|
|
||||||
nonce: int
|
|
||||||
|
|
||||||
def unmarshal(self, raw_data: bytes) -> None:
|
|
||||||
self.type = raw_data[0]
|
|
||||||
self.length = raw_data[1]
|
|
||||||
self.sender_id = int.from_bytes(raw_data[2:10], TLV.network_order())
|
|
||||||
self.nonce = int.from_bytes(raw_data[10:14], TLV.network_order())
|
|
||||||
|
|
||||||
def marshal(self) -> bytes:
|
|
||||||
return self.type.to_bytes(1, TLV.network_order()) + \
|
|
||||||
self.length.to_bytes(1, TLV.network_order()) + \
|
|
||||||
self.sender_id.to_bytes(8, TLV.network_order()) + \
|
|
||||||
self.nonce.to_bytes(4, TLV.network_order())
|
|
||||||
|
|
||||||
def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None:
|
|
||||||
# TODO Implement AckTLV
|
|
||||||
squirrel.add_system_message("I received an AckTLV. I don't know what to do with it. Please implement me!")
|
|
||||||
|
|
||||||
|
|
||||||
class GoAwayTLV(TLV):
|
|
||||||
class GoAwayType(Enum):
|
|
||||||
UNKNOWN = 0
|
|
||||||
EXIT = 1
|
|
||||||
TIMEOUT = 2
|
|
||||||
PROTOCOL_VIOLATION = 3
|
|
||||||
|
|
||||||
type: int = 6
|
|
||||||
length: int
|
|
||||||
code: GoAwayType
|
|
||||||
message: str
|
|
||||||
|
|
||||||
def unmarshal(self, raw_data: bytes) -> None:
|
|
||||||
self.type = raw_data[0]
|
|
||||||
self.length = raw_data[1]
|
|
||||||
self.code = GoAwayTLV.GoAwayType(raw_data[2])
|
|
||||||
self.message = raw_data[3:self.length - 1].decode("UTF-8")
|
|
||||||
|
|
||||||
def marshal(self) -> bytes:
|
|
||||||
return self.type.to_bytes(1, TLV.network_order()) + \
|
|
||||||
self.length.to_bytes(1, TLV.network_order()) + \
|
|
||||||
self.code.value.to_bytes(1, TLV.network_order()) + \
|
|
||||||
self.message.encode("UTF-8")[:self.length - 1]
|
|
||||||
|
|
||||||
def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None:
|
|
||||||
# TODO Implement GoAwayTLV
|
|
||||||
squirrel.add_system_message("Some told me that he went away. That's not very nice :( "
|
|
||||||
"I should send him some cake.")
|
|
||||||
|
|
||||||
|
|
||||||
class WarningTLV(TLV):
|
|
||||||
type: int = 7
|
|
||||||
length: int
|
|
||||||
message: str
|
|
||||||
|
|
||||||
def unmarshal(self, raw_data: bytes) -> None:
|
|
||||||
self.type = raw_data[0]
|
|
||||||
self.length = raw_data[1]
|
|
||||||
self.message = raw_data[2:self.length].decode("UTF-8")
|
|
||||||
|
|
||||||
def marshal(self) -> bytes:
|
|
||||||
return self.type.to_bytes(1, TLV.network_order()) + \
|
|
||||||
self.length.to_bytes(1, TLV.network_order()) + \
|
|
||||||
self.message.encode("UTF-8")[:self.length]
|
|
||||||
|
|
||||||
def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None:
|
|
||||||
squirrel.add_message(f"<warning> *A client warned you: {self.message}*")
|
|
||||||
|
|
||||||
|
|
||||||
class Packet:
|
|
||||||
"""
|
|
||||||
A Packet is a wrapper around the
|
|
||||||
"""
|
|
||||||
magic: int
|
|
||||||
version: int
|
|
||||||
body_length: int
|
|
||||||
body: List[TLV]
|
|
||||||
|
|
||||||
def validate_data(self) -> bool:
|
|
||||||
"""
|
|
||||||
Ensure that the packet is well-formed.
|
|
||||||
Raises a ValueError if the packet contains bad data.
|
|
||||||
"""
|
|
||||||
if self.magic != 95:
|
|
||||||
raise ValueError("The magic code of the packet must be 95, found: {:d}".format(self.magic))
|
|
||||||
if self.version != 0:
|
|
||||||
raise ValueError("The version of the packet is not supported: {:d}".format(self.version))
|
|
||||||
if not (0 <= self.body_length <= 120):
|
|
||||||
raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1020,"
|
|
||||||
"found: {:d}".format(self.body_length))
|
|
||||||
return all(tlv.validate_data() for tlv in self.body)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def unmarshal(data: bytes) -> "Packet":
|
|
||||||
"""
|
|
||||||
Read raw data and build the packet wrapper.
|
|
||||||
Raises a ValueError whenever the data is invalid.
|
|
||||||
"""
|
|
||||||
pkt = Packet()
|
|
||||||
pkt.magic = data[0]
|
|
||||||
pkt.version = data[1]
|
|
||||||
pkt.body_length = int.from_bytes(data[2:4], byteorder=TLV.network_order())
|
|
||||||
pkt.body = []
|
|
||||||
read_bytes = 0
|
|
||||||
while read_bytes <= min(len(data) - 4, pkt.body_length):
|
|
||||||
tlv_type = data[4]
|
|
||||||
if not (0 <= tlv_type < len(TLV.tlv_classes())):
|
|
||||||
raise ValueError(f"TLV type is not supported: {tlv_type}")
|
|
||||||
tlv = TLV.tlv_classes()[tlv_type]()
|
|
||||||
tlv.unmarshal(data[4:4 + pkt.body_length])
|
|
||||||
pkt.body.append(tlv)
|
|
||||||
read_bytes += tlv.tlv_length
|
|
||||||
|
|
||||||
pkt.validate_data()
|
|
||||||
|
|
||||||
return pkt
|
|
||||||
|
|
||||||
def marshal(self) -> bytes:
|
|
||||||
"""
|
|
||||||
Compute the byte array data associated to the packet.
|
|
||||||
"""
|
|
||||||
data = self.magic.to_bytes(1, TLV.network_order())
|
|
||||||
data += self.version.to_bytes(1, TLV.network_order())
|
|
||||||
data += self.body_length.to_bytes(2, TLV.network_order())
|
|
||||||
data += b"".join(tlv.marshal() for tlv in self.body)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def construct(*tlvs: TLV) -> "Packet":
|
|
||||||
"""
|
|
||||||
Construct a new packet from the given TLVs and calculate the good lengths
|
|
||||||
"""
|
|
||||||
pkt = Packet()
|
|
||||||
pkt.magic = 95
|
|
||||||
pkt.version = 0
|
|
||||||
pkt.body = tlvs
|
|
||||||
pkt.body_length = sum(tlv.tlv_length for tlv in tlvs)
|
|
||||||
return pkt
|
|
||||||
|
|
||||||
|
|
||||||
class Hazelnut:
|
|
||||||
"""
|
|
||||||
A hazelnut is a connected client, with its socket.
|
|
||||||
"""
|
|
||||||
def __init__(self, nickname: str = "anonymous", address: str = "localhost", port: int = 2500):
|
|
||||||
self.nickname = nickname
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Resolve DNS as an IPv6
|
|
||||||
address = socket.getaddrinfo(address, None, socket.AF_INET6)[0][4][0]
|
|
||||||
except socket.gaierror:
|
|
||||||
# This is not a valid IPv6. Assume it can be resolved as an IPv4, and we use IPv4-mapping
|
|
||||||
# to compute a valid IPv6.
|
|
||||||
# See https://fr.wikipedia.org/wiki/Adresse_IPv6_mappant_IPv4
|
|
||||||
address = "::ffff:" + socket.getaddrinfo(address, None, socket.AF_INET)[0][4][0]
|
|
||||||
|
|
||||||
self.address = IPv6Address(address)
|
|
||||||
self.port = port
|
|
||||||
|
|
||||||
|
|
||||||
class Squirrel(Hazelnut):
|
|
||||||
"""
|
|
||||||
The squirrel is the user of the program. It can speak with other clients, that are called hazelnuts.
|
|
||||||
"""
|
|
||||||
def __init__(self, instance: Squinnondation, nickname: str):
|
|
||||||
super().__init__(nickname, instance.bind_address, instance.bind_port)
|
|
||||||
# Create UDP socket
|
|
||||||
self.socket = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
|
|
||||||
# Bind the socket
|
|
||||||
self.socket.bind((str(self.address), self.port))
|
|
||||||
|
|
||||||
self.squinnondation = instance
|
|
||||||
|
|
||||||
self.input_buffer = ""
|
|
||||||
self.input_index = 0
|
|
||||||
self.last_line = -1
|
|
||||||
|
|
||||||
self.history = []
|
|
||||||
self.history_pad = curses.newpad(curses.LINES - 2, curses.COLS)
|
|
||||||
self.input_pad = curses.newpad(1, curses.COLS)
|
|
||||||
self.emoji_pad = curses.newpad(18, 12)
|
|
||||||
self.emoji_panel_page = -1
|
|
||||||
|
|
||||||
curses.init_color(curses.COLOR_WHITE, 1000, 1000, 1000)
|
|
||||||
for i in range(curses.COLOR_BLACK + 1, curses.COLOR_WHITE):
|
|
||||||
curses.init_pair(i + 1, i, curses.COLOR_BLACK)
|
|
||||||
|
|
||||||
self.hazelnuts = dict()
|
|
||||||
self.add_system_message(f"Listening on {self.address}:{self.port}")
|
|
||||||
|
|
||||||
def find_hazelnut(self, address: str, port: int) -> Hazelnut:
|
|
||||||
"""
|
|
||||||
Translate an address into a hazelnut, and store it in the list of the hazelnuts, ie. the neighbours.
|
|
||||||
"""
|
|
||||||
if (address, port) in self.hazelnuts:
|
|
||||||
return self.hazelnuts[(address, port)]
|
|
||||||
hazelnut = Hazelnut(address=address, port=port)
|
|
||||||
self.hazelnuts[(address, port)] = hazelnut
|
|
||||||
return hazelnut
|
|
||||||
|
|
||||||
def send_packet(self, client: Hazelnut, pkt: Packet) -> int:
|
|
||||||
"""
|
|
||||||
Send a formatted packet to a client.
|
|
||||||
"""
|
|
||||||
return self.send_raw_data(client, pkt.marshal())
|
|
||||||
|
|
||||||
def send_raw_data(self, client: Hazelnut, data: bytes) -> int:
|
|
||||||
"""
|
|
||||||
Send a raw packet to a client.
|
|
||||||
"""
|
|
||||||
return self.socket.sendto(data, (str(client.address), client.port))
|
|
||||||
|
|
||||||
def receive_packet(self) -> Tuple[Packet, Hazelnut]:
|
|
||||||
"""
|
|
||||||
Receive a packet from the socket and translate it into a Python object.
|
|
||||||
Warning: the process is blocking, it should be ran inside a dedicated thread.
|
|
||||||
"""
|
|
||||||
data, addr = self.receive_raw_data()
|
|
||||||
return Packet.unmarshal(data), self.find_hazelnut(addr[0], addr[1])
|
|
||||||
|
|
||||||
def receive_raw_data(self) -> Tuple[bytes, Any]:
|
|
||||||
"""
|
|
||||||
Receive a packet from the socket.
|
|
||||||
"""
|
|
||||||
return self.socket.recvfrom(1024)
|
|
||||||
|
|
||||||
def wait_for_key(self) -> None:
|
|
||||||
"""
|
|
||||||
Infinite loop where we are waiting for a key of the user.
|
|
||||||
"""
|
|
||||||
while True:
|
|
||||||
self.refresh_history()
|
|
||||||
self.refresh_input()
|
|
||||||
if not self.squinnondation.no_emoji:
|
|
||||||
self.refresh_emoji_pad()
|
|
||||||
key = self.squinnondation.screen.getkey(curses.LINES - 1, 3 + len(self.nickname) + self.input_index)
|
|
||||||
|
|
||||||
if key == "KEY_MOUSE":
|
|
||||||
try:
|
|
||||||
_, x, y, _, attr = curses.getmouse()
|
|
||||||
self.handle_mouse_click(y, x, attr)
|
|
||||||
continue
|
|
||||||
except curses.error:
|
|
||||||
# This is not a valid click
|
|
||||||
continue
|
|
||||||
|
|
||||||
self.handle_key_pressed(key)
|
|
||||||
|
|
||||||
def handle_key_pressed(self, key: str) -> None:
|
|
||||||
"""
|
|
||||||
Process the key press from the user.
|
|
||||||
"""
|
|
||||||
if key == "\x7f": # backspace
|
|
||||||
# delete character at the good position
|
|
||||||
if self.input_index:
|
|
||||||
self.input_index -= 1
|
|
||||||
self.input_buffer = self.input_buffer[:self.input_index] + self.input_buffer[self.input_index + 1:]
|
|
||||||
return
|
|
||||||
elif key == "KEY_LEFT":
|
|
||||||
# Navigate in the message to the left
|
|
||||||
self.input_index = max(0, self.input_index - 1)
|
|
||||||
return
|
|
||||||
elif key == "KEY_RIGHT":
|
|
||||||
# Navigate in the message to the right
|
|
||||||
self.input_index = min(len(self.input_buffer), self.input_index + 1)
|
|
||||||
return
|
|
||||||
elif key == "KEY_UP":
|
|
||||||
# Scroll up in the history
|
|
||||||
self.last_line = min(max(curses.LINES - 3, self.last_line - 1), len(self.history) - 1)
|
|
||||||
return
|
|
||||||
elif key == "KEY_DOWN":
|
|
||||||
# Scroll down in the history
|
|
||||||
self.last_line = min(len(self.history) - 1, self.last_line + 1)
|
|
||||||
return
|
|
||||||
elif key == "KEY_PPAGE":
|
|
||||||
# Page up in the history
|
|
||||||
self.last_line = min(max(curses.LINES - 3, self.last_line - (curses.LINES - 3)), len(self.history) - 1)
|
|
||||||
return
|
|
||||||
elif key == "KEY_NPAGE":
|
|
||||||
# Page down in the history
|
|
||||||
self.last_line = min(len(self.history) - 1, self.last_line + (curses.LINES - 3))
|
|
||||||
return
|
|
||||||
elif key == "KEY_HOME":
|
|
||||||
# Place the cursor at the beginning of the typing word
|
|
||||||
self.input_index = 0
|
|
||||||
return
|
|
||||||
elif key == "KEY_END":
|
|
||||||
# Place the cursor at the end of the typing word
|
|
||||||
self.input_index = len(self.input_buffer)
|
|
||||||
return
|
|
||||||
elif len(key) > 1:
|
|
||||||
# Unmanaged complex key
|
|
||||||
return
|
|
||||||
elif key != "\n":
|
|
||||||
# Insert the pressed key in the current message
|
|
||||||
self.input_buffer = self.input_buffer[:self.input_index] + key + self.input_buffer[self.input_index:]
|
|
||||||
self.input_index += 1
|
|
||||||
return
|
|
||||||
|
|
||||||
# Send message to neighbours
|
|
||||||
msg = self.input_buffer
|
|
||||||
self.input_buffer = ""
|
|
||||||
self.input_index = 0
|
|
||||||
|
|
||||||
if not msg:
|
|
||||||
return
|
|
||||||
|
|
||||||
msg = f"<{self.nickname}> {msg}"
|
|
||||||
self.add_message(msg)
|
|
||||||
|
|
||||||
pkt = Packet.construct(DataTLV.construct(msg))
|
|
||||||
for hazelnut in list(self.hazelnuts.values()):
|
|
||||||
self.send_packet(hazelnut, pkt)
|
|
||||||
|
|
||||||
def handle_mouse_click(self, y: int, x: int, attr: int) -> None:
|
|
||||||
"""
|
|
||||||
The user clicks on the screen, at coordinates (y, x).
|
|
||||||
According to the position, we can indicate what can be done.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not self.squinnondation.no_emoji:
|
|
||||||
if y == curses.LINES - 1 and x >= curses.COLS - 3:
|
|
||||||
# Click on the emoji, open or close the emoji pad
|
|
||||||
self.emoji_panel_page *= -1
|
|
||||||
elif self.emoji_panel_page > 0 and y == curses.LINES - 4 and x >= curses.COLS - 5:
|
|
||||||
# Open next emoji page
|
|
||||||
self.emoji_panel_page += 1
|
|
||||||
elif self.emoji_panel_page > 1 and y == curses.LINES - curses.LINES // 2 - 1 \
|
|
||||||
and x >= curses.COLS - 5:
|
|
||||||
# Open previous emoji page
|
|
||||||
self.emoji_panel_page -= 1
|
|
||||||
elif self.emoji_panel_page > 0 and y >= curses.LINES // 2 - 1 and x >= curses.COLS // 2 - 1:
|
|
||||||
pad_y, pad_x = y - (curses.LINES - curses.LINES // 2) + 1, \
|
|
||||||
(x - (curses.COLS - curses.COLS // 3) + 1) // 2
|
|
||||||
# Click on an emoji on the pad to autocomplete an emoji
|
|
||||||
self.click_on_emoji_pad(pad_y, pad_x)
|
|
||||||
|
|
||||||
def click_on_emoji_pad(self, pad_y: int, pad_x: int) -> None:
|
|
||||||
"""
|
|
||||||
The emoji pad contains the list of all available emojis.
|
|
||||||
Clicking on a emoji auto-complete the emoji in the input pad.
|
|
||||||
"""
|
|
||||||
import emoji
|
|
||||||
from emoji import unicode_codes
|
|
||||||
|
|
||||||
height, width = self.emoji_pad.getmaxyx()
|
|
||||||
height -= 1
|
|
||||||
width -= 1
|
|
||||||
|
|
||||||
emojis = list(unicode_codes.UNICODE_EMOJI)
|
|
||||||
emojis = [c for c in emojis if len(c) == 1]
|
|
||||||
size = (height - 2) * (width - 4) // 2
|
|
||||||
page = emojis[(self.emoji_panel_page - 1) * size:self.emoji_panel_page * size]
|
|
||||||
index = pad_y * (width - 4) // 2 + pad_x
|
|
||||||
char = page[index]
|
|
||||||
if char:
|
|
||||||
demojized = emoji.demojize(char)
|
|
||||||
if char != demojized:
|
|
||||||
for c in reversed(demojized):
|
|
||||||
curses.ungetch(c)
|
|
||||||
|
|
||||||
def add_message(self, msg: str) -> None:
|
|
||||||
"""
|
|
||||||
Store a new message into the history.
|
|
||||||
"""
|
|
||||||
self.history.append(msg)
|
|
||||||
if self.last_line == len(self.history) - 2:
|
|
||||||
self.last_line += 1
|
|
||||||
|
|
||||||
def add_system_message(self, msg: str) -> None:
|
|
||||||
"""
|
|
||||||
Add a new system log message.
|
|
||||||
TODO: Configure logging levels to ignore some messages.
|
|
||||||
"""
|
|
||||||
return self.add_message(f"<system> *{msg}*")
|
|
||||||
|
|
||||||
def print_markdown(self, pad: Any, y: int, x: int, msg: str,
|
|
||||||
bold: bool = False, italic: bool = False, underline: bool = False, strike: bool = False) -> int:
|
|
||||||
"""
|
|
||||||
Parse a markdown-formatted text and format the text as bold, italic or text text.
|
|
||||||
***text***: bold, italic
|
|
||||||
**text**: bold
|
|
||||||
*text*: italic
|
|
||||||
__text__: underline
|
|
||||||
_text_: italic
|
|
||||||
~~text~~: strikethrough
|
|
||||||
"""
|
|
||||||
# Replace :emoji_name: by the good emoji
|
|
||||||
if not self.squinnondation.no_emoji:
|
|
||||||
import emoji
|
|
||||||
msg = emoji.emojize(msg, use_aliases=True)
|
|
||||||
|
|
||||||
if self.squinnondation.no_markdown:
|
|
||||||
pad.addstr(y, x, msg)
|
|
||||||
return len(msg)
|
|
||||||
|
|
||||||
underline_match = re.match("(.*)__(.*)__(.*)", msg)
|
|
||||||
if underline_match:
|
|
||||||
before, text, after = underline_match.group(1), underline_match.group(2), underline_match.group(3)
|
|
||||||
len_before = self.print_markdown(pad, y, x, before, bold, italic, underline)
|
|
||||||
len_mid = self.print_markdown(pad, y, x + len_before, text, bold, italic, not underline)
|
|
||||||
len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline)
|
|
||||||
return len_before + len_mid + len_after
|
|
||||||
|
|
||||||
italic_match = re.match("(.*)_(.*)_(.*)", msg)
|
|
||||||
if italic_match:
|
|
||||||
before, text, after = italic_match.group(1), italic_match.group(2), italic_match.group(3)
|
|
||||||
len_before = self.print_markdown(pad, y, x, before, bold, italic, underline)
|
|
||||||
len_mid = self.print_markdown(pad, y, x + len_before, text, bold, not italic, underline)
|
|
||||||
len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline)
|
|
||||||
return len_before + len_mid + len_after
|
|
||||||
|
|
||||||
bold_italic_match = re.match("(.*)\\*\\*\\*(.*)\\*\\*\\*(.*)", msg)
|
|
||||||
if bold_italic_match:
|
|
||||||
before, text, after = bold_italic_match.group(1), bold_italic_match.group(2),\
|
|
||||||
bold_italic_match.group(3)
|
|
||||||
len_before = self.print_markdown(pad, y, x, before, bold, italic, underline, strike)
|
|
||||||
len_mid = self.print_markdown(pad, y, x + len_before, text, not bold, not italic, underline, strike)
|
|
||||||
len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline, strike)
|
|
||||||
return len_before + len_mid + len_after
|
|
||||||
|
|
||||||
bold_match = re.match("(.*)\\*\\*(.*)\\*\\*(.*)", msg)
|
|
||||||
if bold_match:
|
|
||||||
before, text, after = bold_match.group(1), bold_match.group(2), bold_match.group(3)
|
|
||||||
len_before = self.print_markdown(pad, y, x, before, bold, italic, underline, strike)
|
|
||||||
len_mid = self.print_markdown(pad, y, x + len_before, text, not bold, italic, underline, strike)
|
|
||||||
len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline, strike)
|
|
||||||
return len_before + len_mid + len_after
|
|
||||||
|
|
||||||
italic_match = re.match("(.*)\\*(.*)\\*(.*)", msg)
|
|
||||||
if italic_match:
|
|
||||||
before, text, after = italic_match.group(1), italic_match.group(2), italic_match.group(3)
|
|
||||||
len_before = self.print_markdown(pad, y, x, before, bold, italic, underline, strike)
|
|
||||||
len_mid = self.print_markdown(pad, y, x + len_before, text, bold, not italic, underline, strike)
|
|
||||||
len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline, strike)
|
|
||||||
return len_before + len_mid + len_after
|
|
||||||
|
|
||||||
strike_match = re.match("(.*)~~(.*)~~(.*)", msg)
|
|
||||||
if strike_match:
|
|
||||||
before, text, after = strike_match.group(1), strike_match.group(2), strike_match.group(3)
|
|
||||||
len_before = self.print_markdown(pad, y, x, before, bold, italic, underline, strike)
|
|
||||||
len_mid = self.print_markdown(pad, y, x + len_before, text, bold, italic, underline, not strike)
|
|
||||||
len_after = self.print_markdown(pad, y, x + len_before + len_mid, after, bold, italic, underline, strike)
|
|
||||||
return len_before + len_mid + len_after
|
|
||||||
|
|
||||||
size = len(msg)
|
|
||||||
|
|
||||||
attrs = 0
|
|
||||||
attrs |= curses.A_BOLD if bold else 0
|
|
||||||
attrs |= curses.A_ITALIC if italic else 0
|
|
||||||
attrs |= curses.A_UNDERLINE if underline else 0
|
|
||||||
if strike:
|
|
||||||
msg = "".join(c + "\u0336" for c in msg)
|
|
||||||
pad.addstr(y, x, msg, attrs)
|
|
||||||
return size
|
|
||||||
|
|
||||||
def refresh_history(self) -> None:
|
|
||||||
"""
|
|
||||||
Rewrite the history of the messages.
|
|
||||||
"""
|
|
||||||
y, x = self.squinnondation.screen.getmaxyx()
|
|
||||||
if curses.is_term_resized(curses.LINES, curses.COLS):
|
|
||||||
curses.resizeterm(y, x)
|
|
||||||
self.history_pad.resize(curses.LINES - 2, curses.COLS - 1)
|
|
||||||
self.input_pad.resize(1, curses.COLS - 1)
|
|
||||||
|
|
||||||
self.history_pad.erase()
|
|
||||||
for i, msg in enumerate(self.history[max(0, self.last_line - curses.LINES + 3):self.last_line + 1]):
|
|
||||||
if not re.match("<.*> .*", msg):
|
|
||||||
msg = "<unknown> " + msg
|
|
||||||
match = re.match("<(.*)> (.*)", msg)
|
|
||||||
nickname = match.group(1)
|
|
||||||
msg = match.group(2)
|
|
||||||
color_id = sum(ord(c) for c in nickname) % 6 + 1
|
|
||||||
self.history_pad.addstr(i, 0, "<")
|
|
||||||
self.history_pad.addstr(i, 1, nickname, curses.A_BOLD | curses.color_pair(color_id + 1))
|
|
||||||
self.history_pad.addstr(i, 1 + len(nickname), "> ")
|
|
||||||
self.print_markdown(self.history_pad, i, 3 + len(nickname), msg)
|
|
||||||
self.history_pad.refresh(0, 0, 0, 0, curses.LINES - 2, curses.COLS)
|
|
||||||
|
|
||||||
def refresh_input(self) -> None:
|
|
||||||
"""
|
|
||||||
Redraw input line. Must not be called while the message is not sent.
|
|
||||||
"""
|
|
||||||
self.input_pad.erase()
|
|
||||||
color_id = sum(ord(c) for c in self.nickname) % 6 + 1
|
|
||||||
self.input_pad.addstr(0, 0, "<")
|
|
||||||
self.input_pad.addstr(0, 1, self.nickname, curses.A_BOLD | curses.color_pair(color_id + 1))
|
|
||||||
self.input_pad.addstr(0, 1 + len(self.nickname), "> ")
|
|
||||||
msg = self.input_buffer
|
|
||||||
if len(msg) + len(self.nickname) + 3 >= curses.COLS:
|
|
||||||
msg = ""
|
|
||||||
self.input_pad.addstr(0, 3 + len(self.nickname), self.input_buffer)
|
|
||||||
if not self.squinnondation.no_emoji:
|
|
||||||
self.input_pad.addstr(0, self.input_pad.getmaxyx()[1] - 3, "😀")
|
|
||||||
self.input_pad.refresh(0, 0, curses.LINES - 1, 0, curses.LINES - 1, curses.COLS - 1)
|
|
||||||
|
|
||||||
def refresh_emoji_pad(self) -> None:
|
|
||||||
"""
|
|
||||||
Display the emoji pad if necessary.
|
|
||||||
"""
|
|
||||||
if self.squinnondation.no_emoji:
|
|
||||||
return
|
|
||||||
|
|
||||||
from emoji import unicode_codes
|
|
||||||
|
|
||||||
self.emoji_pad.erase()
|
|
||||||
|
|
||||||
if self.emoji_panel_page > 0:
|
|
||||||
height, width = curses.LINES // 2, curses.COLS // 3
|
|
||||||
self.emoji_pad.resize(height + 1, width + 1)
|
|
||||||
self.emoji_pad.addstr(0, 0, "┏" + (width - 2) * "━" + "┓")
|
|
||||||
self.emoji_pad.addstr(0, (width - 14) // 2, " == EMOJIS == ")
|
|
||||||
for i in range(1, height):
|
|
||||||
self.emoji_pad.addstr(i, 0, "┃" + (width - 2) * " " + "┃")
|
|
||||||
self.emoji_pad.addstr(height - 1, 0, "┗" + (width - 2) * "━" + "┛")
|
|
||||||
|
|
||||||
emojis = list(unicode_codes.UNICODE_EMOJI)
|
|
||||||
emojis = [c for c in emojis if len(c) == 1]
|
|
||||||
size = (height - 2) * (width - 4) // 2
|
|
||||||
page = emojis[(self.emoji_panel_page - 1) * size:self.emoji_panel_page * size]
|
|
||||||
|
|
||||||
if self.emoji_panel_page != 1:
|
|
||||||
self.emoji_pad.addstr(1, width - 2, "⬆")
|
|
||||||
if len(page) == size:
|
|
||||||
self.emoji_pad.addstr(height - 2, width - 2, "⬇")
|
|
||||||
|
|
||||||
for i in range(height - 2):
|
|
||||||
for j in range((width - 4) // 2 + 1):
|
|
||||||
index = i * (width - 4) // 2 + j
|
|
||||||
if index < len(page):
|
|
||||||
self.emoji_pad.addstr(i + 1, 2 * j + 1, page[index])
|
|
||||||
|
|
||||||
self.emoji_pad.refresh(0, 0, curses.LINES - height - 2, curses.COLS - width - 2,
|
|
||||||
curses.LINES - 2, curses.COLS - 2)
|
|
||||||
|
|
||||||
|
|
||||||
class Worm(Thread):
|
|
||||||
"""
|
|
||||||
The worm is the hazel listener.
|
|
||||||
It always waits for an incoming packet, then it treats it, and continues to wait.
|
|
||||||
It is in a dedicated thread.
|
|
||||||
"""
|
|
||||||
def __init__(self, squirrel: Squirrel, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.squirrel = squirrel
|
|
||||||
|
|
||||||
def run(self) -> None:
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
pkt, hazelnut = self.squirrel.receive_packet()
|
|
||||||
pkt.validate_data()
|
|
||||||
except ValueError as error:
|
|
||||||
self.squirrel.add_system_message("An error occurred while receiving a packet: {}".format(error))
|
|
||||||
else:
|
|
||||||
for tlv in pkt.body:
|
|
||||||
tlv.handle(self.squirrel, hazelnut)
|
|
||||||
self.squirrel.refresh_history()
|
|
||||||
self.squirrel.refresh_input()
|
|
||||||
|
|
Loading…
Reference in New Issue