437 lines
14 KiB
Python
437 lines
14 KiB
Python
# Copyright (C) 2020 by eichhornchen, ÿnérant
|
|
# SPDX-License-Identifier: GPL-3.0-or-later
|
|
import re
|
|
from typing import Any, List, Optional
|
|
from ipaddress import IPv6Address
|
|
from enum import Enum
|
|
import socket
|
|
import sys
|
|
|
|
|
|
class TLV:
    """
    Base class of every Tag-Length-Value record carried in a packet body.

    On the wire a TLV is one type byte, one length byte, then ``length`` body
    bytes (Pad1 is the exception: it has no length byte and no body).
    Subclasses override the (un)marshalling hooks below.

    TODO: add subclasses for each type of TLV
    """
    type: int  # type tag, first byte on the wire
    length: int  # number of body bytes, header excluded

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Fill this TLV's fields from ``raw_data``, which starts at the type byte.

        Must be overridden: the base implementation raises NotImplementedError.
        """
        raise NotImplementedError

    def marshal(self) -> bytes:
        """
        Serialize this TLV, header included.

        Must be overridden: the base implementation raises NotImplementedError.
        """
        raise NotImplementedError

    def validate_data(self) -> bool:
        """
        Check that this TLV is well-formed.

        Returns True when it is; subclasses raise ValueError when it is not.
        The base implementation accepts everything.

        TODO: Make some tests
        """
        return True

    def handle(self, squirrel: Any, sender: Any) -> None:
        """
        React to this TLV having been received from the hazel ``sender``.

        Only called on data that is ensured to be valid. The base
        implementation does nothing.
        """
        return None

    def __len__(self) -> int:
        """
        Total size in bytes of the encoded TLV: the two header bytes (type and
        length) plus the body. Pad1 overrides this.
        """
        header_size = 2
        return self.length + header_size

    @staticmethod
    def tlv_classes() -> list:
        """
        Return the TLV classes, in the order of their wire type number (0 to 7).
        """
        ordered = (Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV)
        return list(ordered)
|
|
|
|
|
|
class Pad1TLV(TLV):
    """
    A single padding byte: no length field, no body, no data.
    """
    type: int = 0

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Only the type byte exists, so that is all there is to read.
        """
        first_byte = raw_data[0]
        self.type = first_byte

    def marshal(self) -> bytes:
        """
        Encode the lone type byte; there is no body.
        """
        encoded_type = self.type.to_bytes(1, sys.byteorder)
        return encoded_type

    def handle(self, squirrel: Any, sender: Any) -> None:
        # TODO Add some easter eggs
        easter_egg = "For each byte in the packet that I received, you will die today. And eat cookies."
        squirrel.add_system_message(easter_egg)

    def __len__(self) -> int:
        """
        A Pad1 is always exactly one byte long.
        """
        return 1
|
|
|
|
|
|
class PadNTLV(TLV):
    """
    Padding TLV whose body is ``length`` bytes that must all be zero.
    Its content is otherwise ignored.
    """
    type: int = 1
    length: int
    mbz: bytes  # "must be zero": the body bytes as read from the wire

    def validate_data(self) -> bool:
        """
        Accept the TLV only if its body is exactly ``length`` zero bytes.

        Raises ValueError otherwise.
        """
        zeros = bytes(self.length)
        if self.mbz == zeros:
            return True
        raise ValueError("The body of a PadN TLV is not filled with zeros.")

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Keep the padding bytes so that validate_data() can check them.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        body_end = 2 + self.length
        self.mbz = raw_data[2:body_end]

    def marshal(self) -> bytes:
        """
        Emit the header followed by the first ``length`` bytes of ``mbz``.
        """
        pieces = [
            self.type.to_bytes(1, sys.byteorder),
            self.length.to_bytes(1, sys.byteorder),
            self.mbz[:self.length],
        ]
        return b"".join(pieces)

    def handle(self, squirrel: Any, sender: Any) -> None:
        # TODO Add some easter eggs
        text = f"I received {self.length} zeros, am I so a bag guy ? :cold_sweat:"
        squirrel.add_system_message(text)
|
|
|
|
|
|
class HelloTLV(TLV):
    """
    Hello TLV.

    A short Hello (length 8) carries the 8-byte id of its sender. A long Hello
    (length 16) also carries the 8-byte id of its destination.
    """
    type: int = 2
    length: int
    source_id: int
    dest_id: Optional[int] = None  # only set for a long Hello

    def validate_data(self) -> bool:
        """
        Raise ValueError unless the length is 8 (short Hello) or 16 (long Hello).
        """
        if self.length not in (8, 16):
            raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello, "
                             f"found {self.length}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Read the header, the source id and, for a long Hello, the destination id.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.source_id = int.from_bytes(raw_data[2:10], sys.byteorder)
        # Always (re)set dest_id: a short Hello must not leave it missing or stale.
        self.dest_id = int.from_bytes(raw_data[10:18], sys.byteorder) if self.is_long else None

    def marshal(self) -> bytes:
        """
        Encode the TLV.

        The length field decides whether dest_id is written, so the body always
        matches the announced length (and a destination id of 0 is not dropped).

        Raises ValueError for a long Hello without dest_id.
        """
        data = self.type.to_bytes(1, sys.byteorder) + self.length.to_bytes(1, sys.byteorder) \
            + self.source_id.to_bytes(8, sys.byteorder)
        if self.is_long:
            if self.dest_id is None:
                raise ValueError("A long Hello TLV must have a dest_id.")
            data += self.dest_id.to_bytes(8, sys.byteorder)
        return data

    def handle(self, squirrel: Any, sender: Any) -> None:
        # TODO Implement HelloTLV
        squirrel.add_system_message("Aaaawwww, someone spoke to me and said me Hello smiling_face_with_"
                                    + (":3_hearts:" if self.is_long else "smiling_eyes:"))

    @property
    def is_long(self) -> bool:
        """
        True for a long Hello, i.e. one whose body also holds a destination id.
        """
        return self.length == 16
|
|
|
|
|
|
class NeighbourTLV(TLV):
    """
    Neighbour TLV: advertises a peer by its IPv6 address (16 bytes) and its
    port (2 bytes), so its body is always 18 bytes long.
    """
    type: int = 3
    length: int
    ip_address: IPv6Address
    port: int

    def validate_data(self) -> bool:
        """
        Raise ValueError unless the body is 18 bytes long.

        unmarshal() always reads 18 body bytes, so any other announced length
        would desynchronize the parsing of the following TLVs.
        """
        if self.length != 18:
            raise ValueError(f"The length of a Neighbour TLV must be 18, found {self.length}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.ip_address = IPv6Address(raw_data[2:18])
        # NOTE(review): the port uses the host byte order (sys.byteorder), whereas the
        # nonce and the packet body length use network byte order — confirm against the protocol.
        self.port = int.from_bytes(raw_data[18:20], sys.byteorder)

    def marshal(self) -> bytes:
        return self.type.to_bytes(1, sys.byteorder) + \
            self.length.to_bytes(1, sys.byteorder) + \
            self.ip_address.packed + \
            self.port.to_bytes(2, sys.byteorder)

    def handle(self, squirrel: Any, sender: Any) -> None:
        # TODO Implement NeighbourTLV
        squirrel.add_system_message("I have a friend!")
        squirrel.add_system_message(f"Welcome {self.ip_address}:{self.port}!")
|
|
|
|
|
|
class DataTLV(TLV):
    """
    Data TLV: a chat message.

    The body is the 8-byte sender id, a 4-byte nonce (network byte order) and
    the UTF-8 encoded message, so ``length == 12 + len(data)``.
    """
    type: int = 4
    length: int
    sender_id: int
    nonce: int
    data: bytes

    def validate_data(self) -> bool:
        """
        Ensure the TLV is well-formed. Raise ValueError if:

        - the data is too long for the one-byte length field,
        - the length field differs from 12 + len(data) (a shorter value would
          desynchronize the parsing of the following TLVs), or
        - the data is not valid UTF-8 (handle() decodes it).
        """
        if len(self.data) >= 256 - 4 - 8:
            raise ValueError("The data is too long, the length is larger that one byte.")
        if len(self.data) != self.length - 12:
            raise ValueError("The length of a Data TLV must be 12 plus the length of its data "
                             f"({len(self.data)}), found {self.length}")
        try:
            self.data.decode("UTF-8")
        except UnicodeDecodeError as err:
            raise ValueError("The data of a Data TLV is not valid UTF-8.") from err
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.sender_id = int.from_bytes(raw_data[2:10], sys.byteorder)
        # ntohl applied to a host-order read amounts to reading the nonce big-endian.
        self.nonce = socket.ntohl(int.from_bytes(raw_data[10:14], sys.byteorder))
        self.data = raw_data[14:len(self)]

    def marshal(self) -> bytes:
        return self.type.to_bytes(1, sys.byteorder) + \
            self.length.to_bytes(1, sys.byteorder) + \
            self.sender_id.to_bytes(8, sys.byteorder) + \
            socket.htonl(self.nonce).to_bytes(4, sys.byteorder) + \
            self.data

    def handle(self, squirrel: Any, sender: Any) -> None:
        """
        A message has been received.

        1. Hand it to the squirrel; stop if receive_message_from() reports it
           as already received.
        2. Acknowledge it to the sender.
        3. Check the 'nickname: message' syntax and the sender's nickname,
           sending a Warning TLV on problems.

        TODO: Check that the tuple (sender_id, nonce) is unique to avoid duplicates
        (receive_message_from may already do this — confirm).
        """
        msg = self.data.decode('UTF-8')
        if not squirrel.receive_message_from(msg, self.sender_id, self.nonce):
            # The message was already received
            return

        # Acknowledge the packet
        squirrel.send_packet(sender, Packet.construct(AckTLV.construct(self.sender_id, self.nonce)))

        # Non-greedy: the nickname stops at the FIRST ": ", so a message that itself
        # contains ": " is not taken as part of the nickname.
        nickname_match = re.match("(.*?): (.*)", msg)
        if nickname_match is None:
            squirrel.send_packet(sender, Packet.construct(WarningTLV.construct(
                "Unable to retrieve your username. Please use the syntax 'nickname: message'")))
        else:
            nickname = nickname_match.group(1)
            if sender.nickname is None:
                sender.nickname = nickname
            elif sender.nickname != nickname:
                squirrel.send_packet(sender, Packet.construct(WarningTLV.construct(
                    "It seems that you used two different nicknames. "
                    f"Known nickname: {sender.nickname}, found: {nickname}")))
                sender.nickname = nickname

    @staticmethod
    def construct(message: str, squirrel: Any) -> "DataTLV":
        """
        Build a Data TLV carrying ``message``.

        :param message: text to send, encoded as UTF-8.
        :param squirrel: the local node. Its ``id`` is the sender id and its
            ``incr_nonce`` is used as the nonce, then incremented. When falsy,
            both the sender id and the nonce are 0.
        """
        tlv = DataTLV()
        tlv.type = 4
        tlv.sender_id = squirrel.id if squirrel else 0
        tlv.nonce = squirrel.incr_nonce if squirrel else 0
        tlv.data = message.encode("UTF-8")
        tlv.length = 12 + len(tlv.data)

        if squirrel:
            squirrel.incr_nonce += 1

        return tlv
|
|
|
|
|
|
class AckTLV(TLV):
    """
    Acknowledgement of a Data TLV.

    The Data TLV is identified by its sender id (8 bytes) and nonce (4 bytes,
    network byte order), so the body is always 12 bytes long.
    """
    type: int = 5
    length: int
    sender_id: int
    nonce: int

    def validate_data(self) -> bool:
        """
        Raise ValueError unless the body is 12 bytes long.

        unmarshal() always reads 12 body bytes, so any other announced length
        would desynchronize the parsing of the following TLVs.
        """
        if self.length != 12:
            raise ValueError(f"The length of an Ack TLV must be 12, found {self.length}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.sender_id = int.from_bytes(raw_data[2:10], sys.byteorder)
        self.nonce = socket.ntohl(int.from_bytes(raw_data[10:14], sys.byteorder))

    def marshal(self) -> bytes:
        return self.type.to_bytes(1, sys.byteorder) + \
            self.length.to_bytes(1, sys.byteorder) + \
            self.sender_id.to_bytes(8, sys.byteorder) + \
            socket.htonl(self.nonce).to_bytes(4, sys.byteorder)

    def handle(self, squirrel: Any, sender: Any) -> None:
        # TODO Implement AckTLV
        squirrel.add_system_message("I received an AckTLV. I don't know what to do with it. Please implement me!")

    @staticmethod
    def construct(sender_id: int, nonce: int) -> "AckTLV":
        """
        Build the acknowledgement of the Data TLV identified by (sender_id, nonce).
        """
        tlv = AckTLV()
        tlv.type = 5
        tlv.length = 12
        tlv.sender_id = sender_id
        tlv.nonce = nonce
        return tlv
|
|
|
|
|
|
class GoAwayTLV(TLV):
    """
    GoAway TLV: a one-byte code followed by ``length - 1`` bytes of UTF-8 message.
    """
    class GoAwayType(Enum):
        UNKNOWN = 0
        EXIT = 1
        TIMEOUT = 2
        PROTOCOL_VIOLATION = 3

    type: int = 6
    length: int
    code: GoAwayType
    message: str

    def validate_data(self) -> bool:
        """
        Raise ValueError if the body is too short to hold the code byte.
        """
        if self.length < 1:
            raise ValueError(f"The length of a GoAway TLV must be at least 1, found {self.length}")
        return True

    def unmarshal(self, raw_data: bytes) -> None:
        """
        Parse the header, the code and the message.

        Raises ValueError for an unknown code or a message that is not UTF-8.
        """
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.code = GoAwayTLV.GoAwayType(raw_data[2])
        # The body ends at 2 + length; the message is everything after the code byte,
        # i.e. the length - 1 bytes that marshal() writes.
        self.message = raw_data[3:2 + self.length].decode("UTF-8")

    def marshal(self) -> bytes:
        return self.type.to_bytes(1, sys.byteorder) + \
            self.length.to_bytes(1, sys.byteorder) + \
            self.code.value.to_bytes(1, sys.byteorder) + \
            self.message.encode("UTF-8")[:self.length - 1]

    def handle(self, squirrel: Any, sender: Any) -> None:
        # TODO Implement GoAwayTLV
        squirrel.add_system_message("Some told me that he went away. That's not very nice :( "
                                    "I should send him some cake.")
|
|
|
|
|
|
class WarningTLV(TLV):
    """
    Warning TLV: the body is a UTF-8 message of ``length`` bytes.
    """
    type: int = 7
    length: int
    message: str

    def unmarshal(self, raw_data: bytes) -> None:
        self.type = raw_data[0]
        self.length = raw_data[1]
        self.message = raw_data[2:self.length + 2].decode("UTF-8")

    def marshal(self) -> bytes:
        return self.type.to_bytes(1, sys.byteorder) + \
            self.length.to_bytes(1, sys.byteorder) + \
            self.message.encode("UTF-8")[:self.length]

    def handle(self, squirrel: Any, sender: Any) -> None:
        squirrel.add_message(f"warning: *A client warned you: {self.message}*"
                             if not squirrel.squinnondation.no_markdown else
                             f"warning: A client warned you: {self.message}")

    @staticmethod
    def construct(message: str) -> "WarningTLV":
        """
        Build a Warning TLV carrying ``message``.

        The length field is a single byte, so the encoded message is truncated
        to 255 bytes; a UTF-8 sequence cut in half by the truncation is dropped
        so the stored message stays consistent with what is sent. Without this,
        marshal() would raise OverflowError for long messages.
        """
        tlv = WarningTLV()
        tlv.type = 7
        encoded = message.encode("UTF-8")[:255]
        tlv.message = encoded.decode("UTF-8", errors="ignore")
        tlv.length = len(tlv.message.encode("UTF-8"))
        return tlv
|
|
|
|
|
|
class Packet:
    """
    A Packet is a wrapper around the raw data that is sent to and received from other clients.

    On the wire: a magic byte (95), a version byte (0), the body length on two
    bytes in network byte order, then the body, a sequence of TLVs.
    """
    magic: int
    version: int
    body_length: int
    body: List["TLV"]

    # Largest accepted body: a 1024-byte packet minus its 4-byte header.
    MAX_BODY_LENGTH = 1020

    def validate_data(self) -> bool:
        """
        Ensure that the packet is well-formed.

        Raises a ValueError if the packet contains bad data (wrong magic,
        unsupported version, out-of-range body length, or an invalid TLV).
        """
        if self.magic != 95:
            raise ValueError("The magic code of the packet must be 95, found: {:d}".format(self.magic))
        if self.version != 0:
            raise ValueError("The version of the packet is not supported: {:d}".format(self.version))
        if not (0 <= self.body_length <= Packet.MAX_BODY_LENGTH):
            raise ValueError("The body length of the packet is negative or too high. "
                             f"It must be between 0 and {Packet.MAX_BODY_LENGTH}, "
                             "found: {:d}".format(self.body_length))
        return all(tlv.validate_data() for tlv in self.body)

    @staticmethod
    def unmarshal(data: bytes) -> "Packet":
        """
        Read raw data and build the packet wrapper.

        Only the announced body is parsed: trailing bytes are ignored, and a
        datagram shorter than announced is parsed as far as it goes. Each TLV
        only sees the bytes of the body, never what follows it.

        Raises a ValueError whenever the data is invalid, including a header
        shorter than 4 bytes and a TLV truncated or overflowing the body.
        """
        if len(data) < 4:
            raise ValueError(f"A packet must start with a 4-byte header, found only {len(data)} bytes")

        pkt = Packet()
        pkt.magic = data[0]
        pkt.version = data[1]
        pkt.body_length = socket.ntohs(int.from_bytes(data[2:4], sys.byteorder))
        pkt.body = []

        body = data[4:4 + pkt.body_length]
        read_bytes = 0
        while read_bytes < len(body):
            tlv_type = body[read_bytes]
            tlv_classes = TLV.tlv_classes()
            if not (0 <= tlv_type < len(tlv_classes)):
                raise ValueError(f"TLV type is not supported: {tlv_type}")
            tlv = tlv_classes[tlv_type]()
            try:
                tlv.unmarshal(body[read_bytes:])
            except IndexError as err:
                raise ValueError(f"A TLV of type {tlv_type} is truncated") from err
            if read_bytes + len(tlv) > len(body):
                raise ValueError(f"A TLV of type {tlv_type} overflows the body of the packet")
            pkt.body.append(tlv)
            read_bytes += len(tlv)

        pkt.validate_data()

        return pkt

    def marshal(self) -> bytes:
        """
        Compute the byte array data associated to the packet.
        """
        data = self.magic.to_bytes(1, sys.byteorder)
        data += self.version.to_bytes(1, sys.byteorder)
        data += socket.htons(self.body_length).to_bytes(2, sys.byteorder)
        data += b"".join(tlv.marshal() for tlv in self.body)
        return data

    def __len__(self) -> int:
        """
        Calculates the length, in bytes, of the packet: the 4-byte header plus its TLVs.
        """
        return 4 + sum(len(tlv) for tlv in self.body)

    def split(self, pkt_size: int) -> List["Packet"]:
        """
        Split the packet into sub-packets of at most ``pkt_size`` bytes each
        (typically pkt_size = 1024).

        TLVs are never split. Since 1024 - 4 is much larger than the largest
        TLV (2 + 255 bytes), each sub-packet holds at least one TLV. A TLV
        larger than ``pkt_size`` gets a packet of its own rather than producing
        an empty packet.
        """
        packets = []

        current_size = 4  # Packet header length
        body = []
        for tlv in self.body:
            # Only flush a non-empty body: flushing an empty one would emit an empty packet.
            if body and current_size + len(tlv) > pkt_size:
                packets.append(Packet.construct(*body))
                body = []
                current_size = 4
            body.append(tlv)
            current_size += len(tlv)

        if body:
            packets.append(Packet.construct(*body))

        return packets

    @staticmethod
    def construct(*tlvs: "TLV") -> "Packet":
        """
        Construct a new packet from the given TLVs and calculate the body length.
        """
        pkt = Packet()
        pkt.magic = 95
        pkt.version = 0
        pkt.body = list(tlvs)
        pkt.body_length = sum(len(tlv) for tlv in tlvs)
        return pkt
|