squinnondation/squinnondation/messages.py

576 lines
20 KiB
Python

# Copyright (C) 2020 by eichhornchen, ÿnérant
# SPDX-License-Identifier: GPL-3.0-or-later
import re
from typing import Any, List, Optional
from ipaddress import IPv6Address
from enum import Enum
import socket
import sys
import time
class TLV:
    """
    Base class of every Tag-Length-Value record carried inside a Packet.

    Subclasses provide the wire format (``unmarshal`` / ``marshal``), their own
    sanity checks (``validate_data``) and the reaction to a received record
    (``handle``).
    """
    type: int    # one-byte tag identifying the kind of TLV
    length: int  # number of body bytes that follow the type and length bytes

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Fill this TLV from its raw bytes, type byte first.
        Must be provided by subclasses.
        """
        raise NotImplementedError

    def marshal(self) -> bytes:
        """
        Serialize this TLV to its byte representation.
        Must be provided by subclasses.
        """
        raise NotImplementedError

    def validate_data(self) -> bool:
        """
        Check that the TLV is well-formed.

        The base implementation accepts everything and returns True; subclasses
        raise a ValueError when their content is malformed.
        """
        return True

    def handle(self, user: Any, sender: Any) -> None:
        """
        React to this TLV having been received from ``sender``.
        The data is expected to have been validated beforehand.
        The base implementation does nothing.
        """

    def __len__(self) -> int:
        """
        Size in bytes of the whole TLV on the wire: one type byte, one length
        byte, then ``self.length`` body bytes. Pad1 overrides this since it has
        no length byte.
        """
        header_size = 2
        return header_size + self.length

    @staticmethod
    def tlv_classes() -> list:
        """
        The TLV classes, positioned so that index ``i`` is the class of type ``i``.
        """
        return [
            Pad1TLV,
            PadNTLV,
            HelloTLV,
            NeighbourTLV,
            DataTLV,
            AckTLV,
            GoAwayTLV,
            WarningTLV,
        ]
class Pad1TLV(TLV):
    """
    One-byte padding record (type 0): no length byte and no body.
    It carries no information.
    """
    type: int = 0

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Only the type byte exists, so it is the only thing recorded.
        """
        first_byte = raw_data[0]
        self.type = first_byte

    def marshal(self) -> bytes:
        """
        Emit the lone type byte.
        """
        encoded_type = self.type.to_bytes(1, sys.byteorder)
        return encoded_type

    def handle(self, user: Any, sender: Any) -> None:
        """
        Log the padding when it comes from an active, symmetric peer with a known
        id; otherwise warn the sender that it must say Hello first.
        """
        is_known_neighbour = sender.active and sender.symmetric and sender.id
        if is_known_neighbour:
            user.add_system_message("I received a Pad1TLV, how disapointing.")
            return
        # It doesn't say hello, we don't listen to it
        user.send_packet(sender, Packet.construct(WarningTLV.construct(
            "You are not my neighbour, I won't listen to your Pad1TLV. Please say Hello to me before.")))

    def __len__(self) -> int:
        """
        A Pad1 is exactly one byte long: it has neither a length byte nor a body.
        """
        return 1

    @staticmethod
    def construct() -> "Pad1TLV":
        """
        Build a new Pad1 record.
        """
        pad = Pad1TLV()
        pad.type = 0
        return pad
class PadNTLV(TLV):
    """
    Padding record (type 1) whose body is ``length`` zero bytes. It is ignored.
    """
    type: int = 1
    length: int
    mbz: bytes  # "must be zero": the padding body

    def validate_data(self) -> bool:
        """
        Ensure the body consists of exactly ``length`` zero bytes.
        Raises a ValueError otherwise.
        """
        if self.mbz != bytes(self.length):
            raise ValueError("The body of a PadN TLV is not filled with zeros.")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Store the zero-array, then ignore the packet.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.mbz = raw_data[2:len(self)]

    def marshal(self) -> bytes:
        """
        Construct the byte array filled by zeros.
        """
        return self.type.to_bytes(1, sys.byteorder) + self.length.to_bytes(1, sys.byteorder) \
            + self.mbz[:self.length]

    def handle(self, user: Any, sender: Any) -> None:
        """
        Log the padding when it comes from an active, symmetric peer with a known
        id; otherwise warn the sender that it must say Hello first.
        """
        if not sender.active or not sender.symmetric or not sender.id:
            # It doesn't say hello, we don't listen to it
            user.send_packet(sender, Packet.construct(WarningTLV.construct(
                "You are not my neighbour, I won't listen to your PadNTLV. Please say Hello to me before.")))
            return
        user.add_system_message(f"I received {self.length} zeros.")

    @staticmethod
    def construct(length: int) -> "PadNTLV":
        """
        Build a PadN record whose body is ``length`` zero bytes.
        """
        tlv = PadNTLV()
        tlv.type = 1
        tlv.length = length
        # Real zero bytes (0x00), not the ASCII digit "0" (0x30): the receiver's
        # validate_data compares the body against zero bytes.
        tlv.mbz = bytes(length)
        return tlv
class HelloTLV(TLV):
    """
    Hello record (type 2).

    A short Hello (length 8) carries only the sender id; a long Hello
    (length 16) also carries the id of the peer it is addressed to.
    """
    type: int = 2
    length: int
    source_id: int
    dest_id: Optional[int]  # None for a short Hello

    def validate_data(self) -> bool:
        """
        Raises a ValueError unless the length is 8 (short) or 16 (long).
        """
        if self.length != 8 and self.length != 16:
            raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello, "
                             f"found {self.length}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Decode the type, length, source id and, for a long Hello, the destination id.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.source_id = int.from_bytes(raw_data[2:10], sys.byteorder)
        # Always define dest_id, so marshal() on a parsed short Hello does not
        # raise an AttributeError.
        self.dest_id = int.from_bytes(raw_data[10:18], sys.byteorder) if self.is_long else None

    def marshal(self) -> bytes:
        """
        Serialize the Hello; the destination id is written only for a long Hello.
        """
        data = self.type.to_bytes(1, sys.byteorder) + self.length.to_bytes(1, sys.byteorder) \
            + self.source_id.to_bytes(8, sys.byteorder)
        # Follow the declared length rather than the truthiness of dest_id:
        # a destination id of 0 must still be written for a long Hello.
        if self.is_long:
            data += self.dest_id.to_bytes(8, sys.byteorder)
        return data

    def handle(self, user: Any, sender: Any) -> None:
        """
        Record that ``sender`` said Hello.

        If the sender was known with a different positive id, it is warned and
        the new id is adopted. The hello timestamps are refreshed (the long-hello
        time when the sender was inactive, or when this is a long Hello
        addressed to us). The sender is marked symmetric and active, and a
        short Hello is answered with a long Hello.
        """
        time_h = time.time()
        if sender.id > 0 and sender.id != self.source_id:
            user.send_packet(sender, Packet.construct(WarningTLV.construct(
                f"You were known as the ID {sender.id}, but you declared that you have the ID {self.source_id}.")))
            user.add_system_message(f"A client known as the id {sender.id} declared that it uses "
                                    f"the id {self.source_id}.")
            sender.id = self.source_id
        if not sender.active:
            sender.id = self.source_id  # The sender we are given misses an id
            time_hl = time.time()
        else:
            time_hl = sender.last_long_hello_time
        if self.is_long and self.dest_id == user.id:
            time_hl = time.time()
        # Add entry to/actualize the active peers dictionnary
        sender.last_hello_time = time_h
        sender.last_long_hello_time = time_hl
        sender.symmetric = True
        sender.active = True
        user.update_peer_table(sender)
        user.nbNS += 1
        user.add_system_message(f"{self.source_id} sent me a Hello " + ("long" if self.is_long else "short"))
        if not self.is_long:
            user.send_packet(sender, Packet.construct(HelloTLV.construct(16, user, sender)))

    @property
    def is_long(self) -> bool:
        """
        True for a long Hello (length 16), which carries a destination id.
        """
        return self.length == 16

    @staticmethod
    def construct(length: int, user: Any, destination: Any = None) -> "HelloTLV":
        """
        Build a Hello from ``user``.

        A short Hello is produced when ``length`` is 8, when no destination is
        given, or when the destination id is unknown (-1); otherwise a long
        Hello addressed to ``destination.id``.
        """
        tlv = HelloTLV()
        tlv.type = 2
        tlv.source_id = user.id if user else 0
        if (destination is None) or destination.id == -1 or length == 8:
            tlv.length = 8
            tlv.dest_id = None
        else:
            tlv.length = 16
            tlv.dest_id = destination.id
        return tlv
class NeighbourTLV(TLV):
    """
    Advertises a peer address (type 3): a 16-byte IPv6 address followed by a
    2-byte port in network byte order.
    """
    type: int = 3
    length: int
    ip_address: IPv6Address
    port: int

    def validate_data(self) -> bool:
        """
        Raises a ValueError unless the body is exactly 18 bytes and the port
        is in 1..65535.
        """
        if self.length != 18:
            raise ValueError(f"The length of a Neighbour TLV must be 18, found {self.length}")
        if not (1 <= self.port <= 65535):
            raise ValueError(f"Invalid port received in NeighbourTLV: {self.port}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Decode the IPv6 address and the port.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.ip_address = IPv6Address(raw_data[2:18])
        # The port travels in network byte order, like the other multi-byte
        # counters of the protocol (Packet body length, DataTLV nonce).
        self.port = socket.ntohs(int.from_bytes(raw_data[18:20], sys.byteorder))

    def marshal(self) -> bytes:
        """
        Serialize the address (already in network order via ``packed``) and the port.
        """
        return self.type.to_bytes(1, sys.byteorder) + \
            self.length.to_bytes(1, sys.byteorder) + \
            self.ip_address.packed + \
            socket.htons(self.port).to_bytes(2, sys.byteorder)

    def handle(self, user: Any, sender: Any) -> None:
        """
        Register the advertised address as a potential peer, unless the sender
        has not said Hello, the address is one of ours, or it is already a
        neighbour.
        """
        if not sender.active or not sender.symmetric or not sender.id:
            # It doesn't say hello, we don't listen to it
            user.send_packet(sender, Packet.construct(WarningTLV.construct(
                "You are not my neighbour, I won't listen to your NeighbourTLV. Please say Hello to me before.")))
            return
        if (self.ip_address, self.port) in user.addresses:
            # This case should never happen (and in our protocol it is not possible),
            # but we include this test as a security measure.
            return
        if not (str(self.ip_address), self.port) in user.neighbours:
            peer = user.new_peer(str(self.ip_address), self.port)
            peer.potential = True
            user.update_peer_table(peer)

    @staticmethod
    def construct(address: str, port: int) -> "NeighbourTLV":
        """
        Build a Neighbour record for the given IPv6 address (as text) and port.
        """
        tlv = NeighbourTLV()
        tlv.type = 3
        tlv.length = 18
        tlv.ip_address = IPv6Address(address)
        tlv.port = port
        return tlv
class DataTLV(TLV):
    """
    Chat message record (type 4).

    The body is an 8-byte sender id, a 4-byte nonce in network byte order,
    then the UTF-8 message.
    """
    type: int = 4
    length: int
    sender_id: int
    nonce: int
    data: bytes

    def validate_data(self) -> bool:
        """
        Raises a ValueError if the body is too short to hold the sender id and
        the nonce, if the message does not fit in a one-byte length, or if it
        is not valid UTF-8.
        """
        if self.length < 12:
            raise ValueError(f"The length of a Data TLV must be at least 12, found {self.length}")
        # The body length (12 + message) must fit in one byte: at most 255.
        if len(self.data) > 255 - 12:
            raise ValueError("The data is too long, the length is larger that one byte.")
        try:
            self.data.decode("UTF-8")
        except UnicodeDecodeError as error:
            # Reject it here rather than crashing later in handle().
            raise ValueError("The message of a Data TLV is not valid UTF-8.") from error
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Decode the sender id, the nonce and the message.
        A single trailing zero byte (C-style terminator) is dropped.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.sender_id = int.from_bytes(raw_data[2:10], sys.byteorder)
        self.nonce = socket.ntohl(int.from_bytes(raw_data[10:14], sys.byteorder))
        self.data = raw_data[14:len(self)]
        # Guard against an empty message, which would make data[-1] raise IndexError.
        if self.data and self.data[-1] == 0:
            self.data = self.data[:-1]
            self.length -= 1

    def marshal(self) -> bytes:
        """
        Serialize the record; the nonce is written in network byte order.
        """
        return self.type.to_bytes(1, sys.byteorder) + \
            self.length.to_bytes(1, sys.byteorder) + \
            self.sender_id.to_bytes(8, sys.byteorder) + \
            socket.htonl(self.nonce).to_bytes(4, sys.byteorder) + \
            self.data

    def handle(self, user: Any, sender: Any) -> None:
        """
        A message has been sent. We acknowledge and log it.

        Messages from peers that have not said Hello are refused with a
        warning. Data containing a zero byte is cut there (and the author
        warned). The message is acknowledged; when ``user.receive_message_from``
        reports it as already known, it is only logged. Otherwise the nickname
        is extracted from the ``nickname: message`` prefix and recorded on the
        author.
        """
        if not sender.active or not sender.symmetric or not sender.id:
            # It doesn't say hello, we don't listen to it
            user.send_packet(sender, Packet.construct(WarningTLV.construct(
                "You are not my neighbour, I won't listen to your DataTLV. Please say Hello to me before.")))
            return
        if 0 in self.data:
            user.send_packet(user.find_peer_by_id(self.sender_id) or sender, Packet.construct(WarningTLV.construct(
                f"The length of your DataTLV mismatches. You told me that the length is {len(self.data)} "
                f"while a zero was found at index {self.data.index(0)}.")))
            self.data = self.data[:self.data.index(0)]
        msg = self.data.decode('UTF-8')
        # Acknowledge the packet
        user.send_packet(sender, Packet.construct(AckTLV.construct(self.sender_id, self.nonce)))
        if not user.receive_message_from(self, msg, self.sender_id, self.nonce, sender):
            # The message was already received, do not print it on screen
            user.add_system_message(f"I was inundated a message which I already knew {self.sender_id, self.nonce}")
            return
        # Non-greedy: the nickname stops at the first ": ", so a message that
        # itself contains ": " does not leak into the nickname.
        nickname_match = re.match("(.*?): (.*)", msg)
        if nickname_match is None:
            user.send_packet(sender, Packet.construct(WarningTLV.construct(
                "Unable to retrieve your username. Please use the syntax 'nickname: message'")))
        else:
            nickname = nickname_match.group(1)
            author = user.find_peer_by_id(self.sender_id)
            if author:
                if author.nickname is None:
                    author.nickname = nickname
                elif author.nickname != nickname:
                    user.send_packet(author, Packet.construct(WarningTLV.construct(
                        "It seems that you used two different nicknames. "
                        f"Known nickname: {author.nickname}, found: {nickname}")))
                    author.nickname = nickname

    @staticmethod
    def construct(message: str, user: Any) -> "DataTLV":
        """
        Build a Data record for ``message``, stamped with the user's id and
        current nonce; the user's nonce counter is then incremented.
        """
        tlv = DataTLV()
        tlv.type = 4
        tlv.sender_id = user.id if user else 0
        tlv.nonce = user.incr_nonce if user else 0
        tlv.data = message.encode("UTF-8")
        tlv.length = 12 + len(tlv.data)
        if user:
            user.incr_nonce += 1
        return tlv
class AckTLV(TLV):
    """
    Acknowledgement (type 5) of the DataTLV identified by the pair
    (sender_id, nonce).
    """
    type: int = 5
    length: int
    sender_id: int
    nonce: int

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Decode the 8-byte sender id and the 4-byte nonce; the nonce is
        converted from network byte order.
        """
        self.type, self.length = raw_data[0], raw_data[1]
        id_bytes, nonce_bytes = raw_data[2:10], raw_data[10:14]
        self.sender_id = int.from_bytes(id_bytes, sys.byteorder)
        self.nonce = socket.ntohl(int.from_bytes(nonce_bytes, sys.byteorder))

    def marshal(self) -> bytes:
        """
        Serialize the acknowledgement, writing the nonce in network byte order.
        """
        fields = (
            self.type.to_bytes(1, sys.byteorder),
            self.length.to_bytes(1, sys.byteorder),
            self.sender_id.to_bytes(8, sys.byteorder),
            socket.htonl(self.nonce).to_bytes(4, sys.byteorder),
        )
        return b"".join(fields)

    def handle(self, user: Any, sender: Any) -> None:
        """
        Stop inundating ``sender`` with the acknowledged message, provided the
        sender is a known neighbour; otherwise ask it to say Hello first.
        """
        if sender.active and sender.symmetric and sender.id:
            user.add_system_message(f"I received an AckTLV from {sender}")
            user.remove_from_inundation(sender, self.sender_id, self.nonce)
            return
        # It doesn't say hello, we don't listen to it
        user.send_packet(sender, Packet.construct(WarningTLV.construct(
            "You are not my neighbour, I won't listen to your AckTLV. Please say Hello to me before.")))

    @staticmethod
    def construct(sender_id: int, nonce: int) -> "AckTLV":
        """
        Build the acknowledgement for the message (sender_id, nonce).
        """
        ack = AckTLV()
        ack.type, ack.length = 5, 12
        ack.sender_id, ack.nonce = sender_id, nonce
        return ack
class GoAwayType(Enum):
    """
    Code carried in the one-byte ``code`` field of a GoAwayTLV.

    The integer values are what is written on the wire and decoded with
    ``GoAwayType(byte)``; an unknown byte therefore raises a ValueError.
    """
    UNKNOWN = 0
    EXIT = 1
    TIMEOUT = 2
    PROTOCOL_VIOLATION = 3
class GoAwayTLV(TLV):
    """
    GoAway record (type 6): a one-byte GoAwayType code followed by a
    UTF-8 message (possibly empty).
    """
    type: int = 6
    length: int
    code: GoAwayType
    message: str

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Decode the code and the message.

        Raises a ValueError when the body is too short to hold the code, when
        the code is unknown, or when the message is not valid UTF-8.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        if self.length < 1 or len(raw_data) < 3:
            # The code byte is mandatory; without this check raw_data[2] would
            # raise an IndexError instead of the ValueError callers expect.
            raise ValueError(f"The length of a GoAway TLV must be at least 1, found {self.length}")
        self.code = GoAwayType(raw_data[2])
        self.message = raw_data[3:2 + self.length].decode("UTF-8")

    def marshal(self) -> bytes:
        """
        Serialize the code and the message (truncated to fit ``length``).
        """
        return self.type.to_bytes(1, sys.byteorder) + \
            self.length.to_bytes(1, sys.byteorder) + \
            self.code.value.to_bytes(1, sys.byteorder) + \
            self.message.encode("UTF-8")[:self.length - 1]

    def handle(self, user: Any, sender: Any) -> None:
        """
        Mark a known neighbour as inactive and log its message; peers that have
        not said Hello are warned instead.
        """
        if not sender.active or not sender.symmetric or not sender.id:
            # It doesn't say hello, we don't listen to it
            user.send_packet(sender, Packet.construct(WarningTLV.construct(
                "You are not my neighbour, I won't listen to your GoAwayTLV. Please say Hello to me before.")))
            return
        # The guard above guarantees that the sender is currently active.
        sender.active = False
        user.update_peer_table(sender)
        user.add_system_message("Someone told me that he went away : " + self.message)

    @staticmethod
    def construct(ga_type: GoAwayType, message: str) -> "GoAwayTLV":
        """
        Build a GoAway record with the given code and message.
        """
        tlv = GoAwayTLV()
        tlv.type = 6
        tlv.code = ga_type
        tlv.message = message
        tlv.length = 1 + len(tlv.message.encode("UTF-8"))
        return tlv
class WarningTLV(TLV):
    """
    Warning record (type 7): the whole body is a UTF-8 message.
    """
    type: int = 7
    length: int
    message: str

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Decode the message.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.message = raw_data[2:self.length + 2].decode("UTF-8")

    def marshal(self) -> bytes:
        """
        Serialize the message, truncated to ``length`` bytes.
        """
        return self.type.to_bytes(1, sys.byteorder) + \
            self.length.to_bytes(1, sys.byteorder) + \
            self.message.encode("UTF-8")[:self.length]

    def handle(self, user: Any, sender: Any) -> None:
        """
        Display the warning to the user, in markdown unless it is disabled.
        """
        if user.squinnondation.no_markdown:
            # This branch previously lacked the f-prefix and displayed the
            # literal text "{self.message}".
            user.add_message(f"warning: A client warned you: {self.message}")
        else:
            user.add_message(f"warning: *A client warned you: {self.message}*")

    @staticmethod
    def construct(message: str) -> "WarningTLV":
        """
        Build a Warning record carrying ``message``.
        """
        tlv = WarningTLV()
        tlv.type = 7
        tlv.message = message
        tlv.length = len(tlv.message.encode("UTF-8"))
        return tlv
class Packet:
    """
    A Packet is a wrapper around the raw data that is sent to and received
    from other clients.

    It has a 4-byte header (magic, version, 2-byte body length in network
    byte order) followed by a sequence of TLVs.
    """
    magic: int
    version: int
    body_length: int
    body: List["TLV"]

    def validate_data(self) -> bool:
        """
        Ensure that the packet is well-formed.
        Raises a ValueError if the packet contains bad data.
        """
        if self.magic != 95:
            raise ValueError("The magic code of the packet must be 95, found: {:d}".format(self.magic))
        if self.version != 0:
            raise ValueError("The version of the packet is not supported: {:d}".format(self.version))
        # A packet is at most 1024 bytes, 4 of which are the header (see split()).
        if not (0 <= self.body_length <= 1020):
            raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1020, "
                             "found: {:d}".format(self.body_length))
        return all(tlv.validate_data() for tlv in self.body)

    @staticmethod
    def unmarshal(data: bytes) -> "Packet":
        """
        Read raw data and build the packet wrapper.
        Raises a ValueError whenever the data is invalid, including truncated input.
        """
        if len(data) < 4:
            raise ValueError(f"Invalid packet: the header needs 4 bytes, found {len(data)}")
        pkt = Packet()
        pkt.magic = data[0]
        pkt.version = data[1]
        pkt.body_length = socket.ntohs(int.from_bytes(data[2:4], sys.byteorder))
        if len(data) != 4 + pkt.body_length:
            raise ValueError(f"Invalid packet length: "
                             f"declared body length is {pkt.body_length} while {len(data) - 4} bytes are avalaible")
        pkt.body = []
        read_bytes = 0
        while read_bytes < pkt.body_length:
            remaining = pkt.body_length - read_bytes
            tlv_classes = TLV.tlv_classes()
            tlv_type = data[4 + read_bytes]
            if tlv_type >= len(tlv_classes):
                raise ValueError(f"TLV type is not supported: {tlv_type}")
            if tlv_type == 0:
                # Pad1 has no length byte: its total size is 2 + (-1) = 1.
                tlv_length = -1
            else:
                if remaining < 2:
                    # Only the type byte is left; reading the length would overflow.
                    raise ValueError(f"TLV header is truncated: only {remaining} byte remaining")
                tlv_length = data[4 + read_bytes + 1]
            if 2 + tlv_length > remaining:
                raise ValueError(f"TLV length is too long: requesting {tlv_length} bytes, "
                                 f"remaining {remaining}")
            tlv = tlv_classes[tlv_type]()
            tlv.unmarshal(data[4 + read_bytes:4 + read_bytes + 2 + tlv_length])
            pkt.body.append(tlv)
            read_bytes += len(tlv)
        pkt.validate_data()
        return pkt

    def marshal(self) -> bytes:
        """
        Compute the byte array data associated to the packet.
        """
        data = self.magic.to_bytes(1, sys.byteorder)
        data += self.version.to_bytes(1, sys.byteorder)
        data += socket.htons(self.body_length).to_bytes(2, sys.byteorder)
        data += b"".join(tlv.marshal() for tlv in self.body)
        return data

    def __len__(self) -> int:
        """
        Calculates the length, in bytes, of the packet (header included).
        """
        return 4 + sum(len(tlv) for tlv in self.body)

    def split(self, pkt_size: int) -> List["Packet"]:
        """
        If the packet is too large, ie. larger than pkt_size (with pkt_size = 1024),
        then we split the packet in sub-packets.
        Since 1024 - 4 >> 256 + 2, that ensures that we can have at least one TLV per packet,
        then we don't need to split TLVs in smaller TLVs.
        A TLV that alone exceeds pkt_size is still emitted in its own packet;
        no empty packet is ever produced.
        """
        packets = []
        current_size = 4  # Packet header length
        body = []
        for tlv in self.body:
            if body and current_size + len(tlv) > pkt_size:
                packets.append(Packet.construct(*body))
                body = []
                current_size = 4
            body.append(tlv)
            current_size += len(tlv)
        if body:
            packets.append(Packet.construct(*body))
        return packets

    @staticmethod
    def construct(*tlvs: "TLV") -> "Packet":
        """
        Construct a new packet from the given TLVs and calculate the good lengths.
        """
        pkt = Packet()
        pkt.magic = 95
        pkt.version = 0
        pkt.body = list(tlvs)
        pkt.body_length = sum(len(tlv) for tlv in tlvs)
        return pkt