Merge branch 'security' into 'master'

Security

See merge request ynerant/squinnondation!6
This commit is contained in:
ynerant 2021-01-05 22:40:54 +01:00
commit a92df73a55
3 changed files with 230 additions and 119 deletions

View File

@ -11,7 +11,7 @@ import re
import socket import socket
import time import time
from .messages import Packet, DataTLV, HelloTLV, GoAwayTLV, GoAwayType, NeighbourTLV from .messages import Packet, DataTLV, HelloTLV, GoAwayTLV, GoAwayType, NeighbourTLV, WarningTLV
class Hazelnut: class Hazelnut:
@ -21,6 +21,11 @@ class Hazelnut:
def __init__(self, nickname: str = None, address: str = "localhost", port: int = 2500): def __init__(self, nickname: str = None, address: str = "localhost", port: int = 2500):
self.nickname = nickname self.nickname = nickname
self.id = -1 self.id = -1
self.last_hello_time = 0
self.last_long_hello_time = 0
self.symmetric = False
self.active = False
self.errors = 0
try: try:
# Resolve DNS as an IPv6 # Resolve DNS as an IPv6
@ -31,8 +36,50 @@ class Hazelnut:
# See https://fr.wikipedia.org/wiki/Adresse_IPv6_mappant_IPv4 # See https://fr.wikipedia.org/wiki/Adresse_IPv6_mappant_IPv4
address = "::ffff:" + socket.getaddrinfo(address, None, socket.AF_INET)[0][4][0] address = "::ffff:" + socket.getaddrinfo(address, None, socket.AF_INET)[0][4][0]
self.address = address # IPv6Address(address) self.addresses = set()
self.port = port self.addresses.add((address, port))
@property
def potential(self) -> bool:
return not self.active and not self.banned
@potential.setter
def potential(self, value: bool) -> None:
self.active = not value
@property
def main_address(self) -> Tuple[str, int]:
"""
A client can have multiple addresses.
We contact it only on one of them.
"""
return list(self.addresses)[0]
@property
def banned(self) -> bool:
"""
If a client send more than 5 invalid packets, we don't trust it anymore.
"""
return self.errors >= 5
def __repr__(self):
return self.nickname or str(self.id) or str(self.main_address)
def __str__(self):
return repr(self)
def merge(self, other: "Hazelnut") -> "Hazelnut":
"""
Merge the hazelnut data with one other.
The symmetric and active properties are kept from the original client.
"""
self.errors = max(self.errors, other.errors)
self.last_hello_time = max(self.last_hello_time, other.last_hello_time)
self.last_long_hello_time = max(self.last_hello_time, other.last_long_hello_time)
self.addresses.update(self.addresses)
self.addresses.update(other.addresses)
self.id = self.id if self.id > 0 else other.id
return self
class Squirrel(Hazelnut): class Squirrel(Hazelnut):
@ -49,7 +96,7 @@ class Squirrel(Hazelnut):
# Create UDP socket # Create UDP socket
self.socket = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) self.socket = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
# Bind the socket # Bind the socket
self.socket.bind((self.address, self.port)) self.socket.bind(self.main_address)
self.squinnondation = instance self.squinnondation = instance
@ -74,13 +121,15 @@ class Squirrel(Hazelnut):
curses.init_pair(i + 1, i, curses.COLOR_BLACK) curses.init_pair(i + 1, i, curses.COLOR_BLACK)
# dictionnaries of neighbours # dictionnaries of neighbours
self.potentialhazelnuts = dict() self.hazelnuts = dict()
self.activehazelnuts = dict() # of the form [hazelnut, time of last hello,
# time of last long hello, is symmetric]
self.nbNS = 0 self.nbNS = 0
self.minNS = 3 # minimal number of symmetric neighbours a squirrel needs to have. self.minNS = 3 # minimal number of symmetric neighbours a squirrel needs to have.
self.add_system_message(f"Listening on {self.address}:{self.port}") self.worm = Worm(self)
self.hazel_manager = HazelManager(self)
self.inondator = Inondator(self)
self.add_system_message(f"Listening on {self.main_address[0]}:{self.main_address[1]}")
self.add_system_message(f"I am {self.id}") self.add_system_message(f"I am {self.id}")
def new_hazel(self, address: str, port: int) -> Hazelnut: def new_hazel(self, address: str, port: int) -> Hazelnut:
@ -90,45 +139,50 @@ class Squirrel(Hazelnut):
hazelnut = Hazelnut(address=address, port=port) hazelnut = Hazelnut(address=address, port=port)
return hazelnut return hazelnut
def is_active(self, hazel: Hazelnut) -> bool: @property
return (hazel.address, hazel.port) in self.activehazelnuts def active_hazelnuts(self) -> set:
return set(hazelnut for hazelnut in self.hazelnuts.values() if hazelnut.active)
def is_potential(self, hazel: Hazelnut) -> bool: @property
return (hazel.address, hazel.port) in self.potentialhazelnuts def potential_hazelnuts(self) -> set:
return set(hazelnut for hazelnut in self.hazelnuts.values() if hazelnut.potential)
def remove_from_potential(self, address: str, port: int) -> None:
self.potentialhazelnuts.pop((address, port), None)
def find_hazelnut(self, address: str, port: int) -> Hazelnut: def find_hazelnut(self, address: str, port: int) -> Hazelnut:
""" """
Translate an address into a hazelnut. If this hazelnut does not exist, Translate an address into a hazelnut. If this hazelnut does not exist,
creates a new hazelnut. creates a new hazelnut.
""" """
if (address, port) in self.activehazelnuts: if (address, port) in self.hazelnuts:
return self.activehazelnuts[(address, port)][0] return self.hazelnuts[(address, port)]
hazelnut = Hazelnut(address=address, port=port) hazelnut = Hazelnut(address=address, port=port)
self.hazelnuts[(address, port)] = hazelnut
return hazelnut
def find_hazelnut_by_id(self, hazel_id: int) -> Hazelnut:
"""
Retrieve the hazelnut that is known by its id. Return None if it is unknown.
The given identifier must be positive.
"""
if hazel_id > 0:
for hazelnut in self.hazelnuts.values():
if hazelnut.id == hazel_id:
return hazelnut return hazelnut
def send_packet(self, client: Hazelnut, pkt: Packet) -> int: def send_packet(self, client: Hazelnut, pkt: Packet) -> int:
""" """
Send a formatted packet to a client. Send a formatted packet to a client.
""" """
self.refresh_lock.acquire()
if len(pkt) > 1024: if len(pkt) > 1024:
# The packet is too large to be sent by the protocol. We split the packet in subpackets. # The packet is too large to be sent by the protocol. We split the packet in subpackets.
return sum(self.send_packet(client, subpkt) for subpkt in pkt.split(1024)) return sum(self.send_packet(client, subpkt) for subpkt in pkt.split(1024))
res = self.send_raw_data(client, pkt.marshal()) res = self.send_raw_data(client, pkt.marshal())
self.refresh_lock.release()
return res return res
def send_raw_data(self, client: Hazelnut, data: bytes) -> int: def send_raw_data(self, client: Hazelnut, data: bytes) -> int:
""" """
Send a raw packet to a client. Send a raw packet to a client.
""" """
self.refresh_lock.acquire() return self.socket.sendto(data, client.main_address)
res = self.socket.sendto(data, (client.address, client.port))
self.refresh_lock.release()
return res
def receive_packet(self) -> Tuple[Packet, Hazelnut]: def receive_packet(self) -> Tuple[Packet, Hazelnut]:
""" """
@ -136,7 +190,24 @@ class Squirrel(Hazelnut):
Warning: the process is blocking, it should be ran inside a dedicated thread. Warning: the process is blocking, it should be ran inside a dedicated thread.
""" """
data, addr = self.receive_raw_data() data, addr = self.receive_raw_data()
return Packet.unmarshal(data), self.find_hazelnut(addr[0], addr[1]) hazelnut = self.find_hazelnut(addr[0], addr[1])
if hazelnut.banned:
# The client already sent errored packets
return Packet.construct(), hazelnut
try:
pkt = Packet.unmarshal(data)
except ValueError as error:
# The packet contains an error. We memorize it and warn the other user.
hazelnut.errors += 1
self.send_packet(hazelnut, Packet.construct(WarningTLV.construct(
f"An error occured while reading your packet: {error}")))
if hazelnut.banned:
self.send_packet(hazelnut, Packet.construct(WarningTLV.construct(
"You got banned since you sent too much errored packets.")))
raise ValueError("Client is banned since there were too many errors.", error)
raise error
else:
return pkt, hazelnut
def receive_raw_data(self) -> Tuple[bytes, Any]: def receive_raw_data(self) -> Tuple[bytes, Any]:
""" """
@ -148,10 +219,6 @@ class Squirrel(Hazelnut):
""" """
Start asynchronous threads. Start asynchronous threads.
""" """
self.worm = Worm(self)
self.hazel_manager = HazelManager(self)
self.inondator = Inondator(self)
# Kill subthreads when exitting the program # Kill subthreads when exitting the program
self.worm.setDaemon(True) self.worm.setDaemon(True)
self.hazel_manager.setDaemon(True) self.hazel_manager.setDaemon(True)
@ -259,8 +326,8 @@ class Squirrel(Hazelnut):
self.add_message(msg) self.add_message(msg)
pkt = Packet.construct(DataTLV.construct(msg, self)) pkt = Packet.construct(DataTLV.construct(msg, self))
for hazelnut in list(self.activehazelnuts.values()): for hazelnut in self.active_hazelnuts:
self.send_packet(hazelnut[0], pkt) self.send_packet(hazelnut, pkt)
def handle_mouse_click(self, y: int, x: int, attr: int) -> None: def handle_mouse_click(self, y: int, x: int, attr: int) -> None:
""" """
@ -344,44 +411,37 @@ class Squirrel(Hazelnut):
Takes the active hazels dictionnary and returns a list of [hazel, date+random, 0] Takes the active hazels dictionnary and returns a list of [hazel, date+random, 0]
""" """
res = dict() res = dict()
hazels = list(self.activehazelnuts.items()) hazels = self.active_hazelnuts
for key, hazel in hazels: for hazel in hazels:
if hazel[3]: # Only if the neighbour is symmetric if hazel.symmetric:
next_send = uniform(1, 2) next_send = uniform(1, 2)
res[key] = [hazel[0], time.time() + next_send, 0] res[hazel.main_address] = [hazel, time.time() + next_send, 0]
return res return res
def remove_from_inundation(self, hazel: Hazelnut, sender_id: int, nonce: int) -> None: def remove_from_inundation(self, hazel: Hazelnut, sender_id: int, nonce: int) -> None:
""" """
Remove the sender from the list of neighbours to be inundated Remove the sender from the list of neighbours to be inundated
""" """
self.refresh_lock.acquire()
if (sender_id, nonce) in self.recent_messages: if (sender_id, nonce) in self.recent_messages:
# If a peer is late in its acknowledgement, the absence of the previous if causes an error. # If a peer is late in its acknowledgement, the absence of the previous if causes an error.
self.recent_messages[(sender_id, nonce)][2].pop((hazel.address, hazel.port), None) for addr in hazel.addresses:
self.recent_messages[(sender_id, nonce)][2].pop(addr, None)
if not self.recent_messages[(sender_id, nonce)][2]: # If dictionnary is empty, remove the message if not self.recent_messages[(sender_id, nonce)][2]: # If dictionnary is empty, remove the message
self.recent_messages.pop((sender_id, nonce), None) self.recent_messages.pop((sender_id, nonce), None)
self.refresh_lock.release()
def clean_inundation(self) -> None: def clean_inundation(self) -> None:
""" """
Remove messages which are overdue (older than 2 minutes) from the inundation dictionnary. Remove messages which are overdue (older than 2 minutes) from the inundation dictionnary.
""" """
self.refresh_lock.acquire()
for key in self.recent_messages: for key in self.recent_messages:
if time.time() - self.recent_messages[key][1] > 120: if time.time() - self.recent_messages[key][1] > 120:
self.recent_messages.pop(key) self.recent_messages.pop(key)
self.refresh_lock.release()
def main_inundation(self) -> None: def main_inundation(self) -> None:
""" """
The main inundation function. The main inundation function.
""" """
self.refresh_lock.acquire()
for key in self.recent_messages: for key in self.recent_messages:
k = list(self.recent_messages[key][2].keys()) k = list(self.recent_messages[key][2].keys())
for key2 in k: for key2 in k:
@ -400,13 +460,10 @@ class Squirrel(Hazelnut):
if self.recent_messages[key][2][key2][2] >= 5: # the neighbour is not reactive enough if self.recent_messages[key][2][key2][2] >= 5: # the neighbour is not reactive enough
gatlv = GoAwayTLV().construct(GoAwayType.TIMEOUT, f"{self.id} No acknowledge") gatlv = GoAwayTLV().construct(GoAwayType.TIMEOUT, f"{self.id} No acknowledge")
pkt = Packet().construct(gatlv) pkt = Packet().construct(gatlv)
self.send_packet(self.recent_messages[key][2][key2][0], pkt) hazelnut = self.recent_messages[key][2][key2][0]
self.activehazelnuts.pop(key2) self.send_packet(hazelnut, pkt)
self.potentialhazelnuts[key2] = self.recent_messages[key][2][key2][0]
self.recent_messages[key][2].pop(key2) self.recent_messages[key][2].pop(key2)
self.refresh_lock.release()
def add_system_message(self, msg: str) -> None: def add_system_message(self, msg: str) -> None:
""" """
Add a new system log message. Add a new system log message.
@ -606,82 +663,81 @@ class Squirrel(Hazelnut):
Returns a list of hazelnuts the squirrel should contact if it does Returns a list of hazelnuts the squirrel should contact if it does
not have enough symmetric neighbours. not have enough symmetric neighbours.
""" """
self.refresh_lock.acquire()
res = [] res = []
lp = len(self.potentialhazelnuts) val = list(self.potential_hazelnuts)
val = list(self.potentialhazelnuts.values()) lp = len(val)
for i in range(min(lp, max(0, self.minNS - self.nbNS))): for i in range(min(lp, max(0, self.minNS - self.nbNS))):
a = randint(0, lp - 1) a = randint(0, lp - 1)
res.append(val[a]) res.append(val[a])
self.refresh_lock.release()
return res return res
def send_hello(self) -> None: def send_hello(self) -> None:
""" """
Sends a long HelloTLV to all active neighbours. Sends a long HelloTLV to all active neighbours.
""" """
self.refresh_lock.acquire() for hazelnut in self.active_hazelnuts:
htlv = HelloTLV().construct(16, self, hazelnut)
for hazelnut in self.activehazelnuts.values():
htlv = HelloTLV().construct(16, self, hazelnut[0])
pkt = Packet().construct(htlv) pkt = Packet().construct(htlv)
self.send_packet(hazelnut[0], pkt) self.send_packet(hazelnut, pkt)
self.refresh_lock.release()
def verify_activity(self) -> None: def verify_activity(self) -> None:
""" """
All neighbours that have not sent a HelloTLV in the last 2 All neighbours that have not sent a HelloTLV in the last 2
minutes are considered not active. minutes are considered not active.
""" """
self.refresh_lock.acquire() val = list(self.active_hazelnuts) # create a copy because the dict size will change
val = list(self.activehazelnuts.values()) # create a copy because the dict size will change
for hazelnut in val: for hazelnut in val:
if time.time() - hazelnut[1] > 2 * 60: if time.time() - hazelnut.last_hello_time > 2 * 60:
gatlv = GoAwayTLV().construct(GoAwayType.TIMEOUT, "you did not talk to me") gatlv = GoAwayTLV().construct(GoAwayType.TIMEOUT, "you did not talk to me")
pkt = Packet().construct(gatlv) pkt = Packet().construct(gatlv)
self.send_packet(hazelnut[0], pkt) self.send_packet(hazelnut, pkt)
self.activehazelnuts.pop((hazelnut[0].address, hazelnut[0].port)) hazelnut.active = False
self.potentialhazelnuts[(hazelnut[0].address, hazelnut[0].port)] = hazelnut[0] self.update_hazelnut_table(hazelnut)
self.refresh_lock.release() def update_hazelnut_table(self, hazelnut: Hazelnut) -> None:
"""
We insert the hazelnut into our table of clients.
If there is a collision with the address / the ID, then we merge clients into a unique one.
"""
for addr in hazelnut.addresses:
if addr in self.hazelnuts:
# Merge with the previous hazel
old_hazel = self.hazelnuts[addr]
hazelnut.merge(old_hazel)
self.hazelnuts[addr] = hazelnut
for other_hazel in list(self.hazelnuts.values()):
if other_hazel.id == hazelnut.id > 0 and other_hazel != hazelnut:
# The hazelnut with the same id is known as a different address. We merge everything
hazelnut.merge(other_hazel)
def send_neighbours(self) -> None: def send_neighbours(self) -> None:
""" """
Update the number of symmetric neighbours and Update the number of symmetric neighbours and
send all neighbours NeighbourTLV send all neighbours NeighbourTLV
""" """
self.refresh_lock.acquire()
nb_ns = 0 nb_ns = 0
# could send the same to all neighbour, but it means that neighbour # could send the same to all neighbour, but it means that neighbour
# A could receive a message with itself in it -> if the others do not pay attention, trouble # A could receive a message with itself in it -> if the others do not pay attention, trouble
for key, hazelnut in self.activehazelnuts.items(): for hazelnut in self.active_hazelnuts:
if time.time() - hazelnut[2] <= 2 * 60: if time.time() - hazelnut.last_long_hello_time <= 2 * 60:
nb_ns += 1 nb_ns += 1
self.activehazelnuts[key][3] = True hazelnut.symmetric = True
ntlv = NeighbourTLV().construct(hazelnut[0].address, hazelnut[0].port) ntlv = NeighbourTLV().construct(*hazelnut.main_address)
pkt = Packet().construct(ntlv) pkt = Packet().construct(ntlv)
for destination in self.activehazelnuts.values(): for destination in self.active_hazelnuts:
if destination[0].id != hazelnut[0].id: if destination.id != hazelnut.id:
self.send_packet(destination[0], pkt) self.send_packet(destination, pkt)
else: else:
self.activehazelnuts[key][3] = False hazelnut.symmetric = False
self.nbNS = nb_ns self.nbNS = nb_ns
self.refresh_lock.release()
def leave(self) -> None: def leave(self) -> None:
""" """
The program is exited. We send a GoAway to our neighbours, then close the program. The program is exited. We send a GoAway to our neighbours, then close the program.
""" """
self.refresh_lock.acquire()
# Last inundation # Last inundation
self.main_inundation() self.main_inundation()
self.clean_inundation() self.clean_inundation()
@ -689,8 +745,8 @@ class Squirrel(Hazelnut):
# Broadcast a GoAway # Broadcast a GoAway
gatlv = GoAwayTLV().construct(GoAwayType.EXIT, "I am leaving! Good bye!") gatlv = GoAwayTLV().construct(GoAwayType.EXIT, "I am leaving! Good bye!")
pkt = Packet.construct(gatlv) pkt = Packet.construct(gatlv)
for hazelnut in self.activehazelnuts.values(): for hazelnut in self.active_hazelnuts:
self.send_packet(hazelnut[0], pkt) self.send_packet(hazelnut, pkt)
exit(0) exit(0)
@ -711,7 +767,13 @@ class Worm(Thread):
pkt, hazelnut = self.squirrel.receive_packet() pkt, hazelnut = self.squirrel.receive_packet()
except ValueError as error: except ValueError as error:
self.squirrel.add_system_message("An error occurred while receiving a packet: {}".format(error)) self.squirrel.add_system_message("An error occurred while receiving a packet: {}".format(error))
self.squirrel.refresh_history()
self.squirrel.refresh_input()
else: else:
if hazelnut.banned:
# Ignore banned hazelnuts
continue
for tlv in pkt.body: for tlv in pkt.body:
tlv.handle(self.squirrel, hazelnut) tlv.handle(self.squirrel, hazelnut)
self.squirrel.refresh_history() self.squirrel.refresh_history()
@ -746,7 +808,7 @@ class HazelManager(Thread):
# Second part: send long HelloTLVs to neighbours every 30 seconds # Second part: send long HelloTLVs to neighbours every 30 seconds
if time.time() - self.last_check > 30: if time.time() - self.last_check > 30:
self.squirrel.add_system_message(f"I have {len(list(self.squirrel.activehazelnuts.values()))} friends") self.squirrel.add_system_message(f"I have {len(self.squirrel.active_hazelnuts)} friends")
self.squirrel.send_hello() self.squirrel.send_hello()
self.last_check = time.time() self.last_check = time.time()

View File

@ -12,7 +12,6 @@ import time
class TLV: class TLV:
""" """
The Tag-Length-Value contains the different type of data that can be sent. The Tag-Length-Value contains the different type of data that can be sent.
TODO: add subclasses for each type of TLV
""" """
type: int type: int
length: int length: int
@ -74,7 +73,12 @@ class Pad1TLV(TLV):
return self.type.to_bytes(1, sys.byteorder) return self.type.to_bytes(1, sys.byteorder)
def handle(self, squirrel: Any, sender: Any) -> None: def handle(self, squirrel: Any, sender: Any) -> None:
# TODO Add some easter eggs if not sender.active or not sender.symmetric or not sender.id:
# It doesn't say hello, we don't listen to it
squirrel.send_packet(sender, Packet.construct(WarningTLV.construct(
"You are not my neighbour, I don't listen to your Pad1TLV. Please say me Hello before.")))
return
squirrel.add_system_message("I received a Pad1TLV, how disapointing.") squirrel.add_system_message("I received a Pad1TLV, how disapointing.")
def __len__(self) -> int: def __len__(self) -> int:
@ -119,7 +123,12 @@ class PadNTLV(TLV):
+ self.mbz[:self.length] + self.mbz[:self.length]
def handle(self, squirrel: Any, sender: Any) -> None: def handle(self, squirrel: Any, sender: Any) -> None:
# TODO Add some easter eggs if not sender.active or not sender.symmetric or not sender.id:
# It doesn't say hello, we don't listen to it
squirrel.send_packet(sender, Packet.construct(WarningTLV.construct(
"You are not my neighbour, I don't listen to your PadNTLV. Please say me Hello before.")))
return
squirrel.add_system_message(f"I received {self.length} zeros.") squirrel.add_system_message(f"I received {self.length} zeros.")
@staticmethod @staticmethod
@ -159,22 +168,33 @@ class HelloTLV(TLV):
def handle(self, squirrel: Any, sender: Any) -> None: def handle(self, squirrel: Any, sender: Any) -> None:
time_h = time.time() time_h = time.time()
if not squirrel.is_active(sender):
if sender.id > 0 and sender.id != self.source_id:
squirrel.send_packet(sender, Packet.construct(WarningTLV.construct(
f"You were known as the ID {sender.id}, but you declared that you have the ID {self.source_id}.")))
squirrel.add_system_message(f"A client known as the id {sender.id} declared that it uses "
f"the id {self.source_id}.")
sender.id = self.source_id
if not sender.active:
sender.id = self.source_id # The sender we are given misses an id sender.id = self.source_id # The sender we are given misses an id
time_hl = time.time() time_hl = time.time()
else: else:
time_hl = squirrel.activehazelnuts[(sender.address, sender.port)][2] time_hl = sender.last_long_hello_time
if self.is_long and self.dest_id == squirrel.id: if self.is_long and self.dest_id == squirrel.id:
time_hl = time.time() time_hl = time.time()
# Make sure the sender is not in the potential hazelnuts
squirrel.remove_from_potential(sender.address, sender.port)
# Add entry to/actualize the active hazelnuts dictionnary # Add entry to/actualize the active hazelnuts dictionnary
squirrel.activehazelnuts[(sender.address, sender.port)] = [sender, time_h, time_hl, True] sender.last_hello_time = time_h
sender.last_long_hello_time = time_hl
sender.symmetric = True
sender.active = True
squirrel.update_hazelnut_table(sender)
squirrel.nbNS += 1 squirrel.nbNS += 1
# squirrel.add_system_message(f"Aaaawwww, {self.source_id} spoke to me and said Hello " squirrel.add_system_message(f"{self.source_id} sent me a Hello " + ("long" if self.is_long else "short"))
# + ("long" if self.is_long else "short"))
if not self.is_long:
squirrel.send_packet(sender, Packet.construct(HelloTLV.construct(16, squirrel, sender)))
@property @property
def is_long(self) -> bool: def is_long(self) -> bool:
@ -219,14 +239,20 @@ class NeighbourTLV(TLV):
self.port.to_bytes(2, sys.byteorder) self.port.to_bytes(2, sys.byteorder)
def handle(self, squirrel: Any, sender: Any) -> None: def handle(self, squirrel: Any, sender: Any) -> None:
if squirrel.address == str(self.ip_address) and squirrel.port == self.port: if not sender.active or not sender.symmetric or not sender.id:
# It doesn't say hello, we don't listen to it
squirrel.send_packet(sender, Packet.construct(WarningTLV.construct(
"You are not my neighbour, I don't listen to your NeighbourTLV. Please say me Hello before.")))
return
if (self.ip_address, self.port) in squirrel.addresses:
# This case should never happen (and in our protocol it is not possible), # This case should never happen (and in our protocol it is not possible),
# but we include this test as a security measure. # but we include this test as a security measure.
return return
if not (str(self.ip_address), self.port) in squirrel.activehazelnuts \ if not (str(self.ip_address), self.port) in squirrel.hazelnuts:
and not (str(self.ip_address), self.port) in squirrel.potentialhazelnuts: hazelnut = squirrel.new_hazel(str(self.ip_address), self.port)
squirrel.potentialhazelnuts[(str(self.ip_address), self.port)] = \ hazelnut.potential = True
squirrel.new_hazel(str(self.ip_address), self.port) squirrel.update_hazelnut_table(hazelnut)
# squirrel.add_system_message(f"New potential friend {self.ip_address}:{self.port}!") # squirrel.add_system_message(f"New potential friend {self.ip_address}:{self.port}!")
@staticmethod @staticmethod
@ -269,6 +295,12 @@ class DataTLV(TLV):
""" """
A message has been sent. We log it. A message has been sent. We log it.
""" """
if not sender.active or not sender.symmetric or not sender.id:
# It doesn't say hello, we don't listen to it
squirrel.send_packet(sender, Packet.construct(WarningTLV.construct(
"You are not my neighbour, I don't listen to your DataTLV. Please say me Hello before.")))
return
msg = self.data.decode('UTF-8') msg = self.data.decode('UTF-8')
# Acknowledge the packet # Acknowledge the packet
@ -285,13 +317,15 @@ class DataTLV(TLV):
"Unable to retrieve your username. Please use the syntax 'nickname: message'"))) "Unable to retrieve your username. Please use the syntax 'nickname: message'")))
else: else:
nickname = nickname_match.group(1) nickname = nickname_match.group(1)
if sender.nickname is None: author = squirrel.find_hazelnut_by_id(self.sender_id)
sender.nickname = nickname if author:
elif sender.nickname != nickname: if author.nickname is None:
squirrel.send_packet(sender, Packet.construct(WarningTLV.construct( author.nickname = nickname
elif author.nickname != nickname:
squirrel.send_packet(author, Packet.construct(WarningTLV.construct(
"It seems that you used two different nicknames. " "It seems that you used two different nicknames. "
f"Known nickname: {sender.nickname}, found: {nickname}"))) f"Known nickname: {author.nickname}, found: {nickname}")))
sender.nickname = nickname author.nickname = nickname
@staticmethod @staticmethod
def construct(message: str, squirrel: Any) -> "DataTLV": def construct(message: str, squirrel: Any) -> "DataTLV":
@ -330,7 +364,13 @@ class AckTLV(TLV):
""" """
When an AckTLV is received, we know that we do not have to inundate that neighbour anymore. When an AckTLV is received, we know that we do not have to inundate that neighbour anymore.
""" """
squirrel.add_system_message("I received an AckTLV") if not sender.active or not sender.symmetric or not sender.id:
# It doesn't say hello, we don't listen to it
squirrel.send_packet(sender, Packet.construct(WarningTLV.construct(
"You are not my neighbour, I don't listen to your AckTLV. Please say me Hello before.")))
return
squirrel.add_system_message(f"I received an AckTLV from {sender}")
squirrel.remove_from_inundation(sender, self.sender_id, self.nonce) squirrel.remove_from_inundation(sender, self.sender_id, self.nonce)
@staticmethod @staticmethod
@ -369,9 +409,15 @@ class GoAwayTLV(TLV):
self.message.encode("UTF-8")[:self.length - 1] self.message.encode("UTF-8")[:self.length - 1]
def handle(self, squirrel: Any, sender: Any) -> None: def handle(self, squirrel: Any, sender: Any) -> None:
if squirrel.is_active(sender): if not sender.active or not sender.symmetric or not sender.id:
squirrel.activehazelnuts.pop((sender.address, sender.port)) # It doesn't say hello, we don't listen to it
squirrel.potentialhazelnuts[(sender.address, sender.port)] = sender squirrel.send_packet(sender, Packet.construct(WarningTLV.construct(
"You are not my neighbour, I don't listen to your GoAwayTLV. Please say me Hello before.")))
return
if sender.active:
sender.active = False
squirrel.update_hazelnut_table(sender)
squirrel.add_system_message("Some told me that he went away : " + self.message) squirrel.add_system_message("Some told me that he went away : " + self.message)
@staticmethod @staticmethod
@ -446,14 +492,21 @@ class Packet:
pkt.magic = data[0] pkt.magic = data[0]
pkt.version = data[1] pkt.version = data[1]
pkt.body_length = socket.ntohs(int.from_bytes(data[2:4], sys.byteorder)) pkt.body_length = socket.ntohs(int.from_bytes(data[2:4], sys.byteorder))
if len(data) != 4 + pkt.body_length:
raise ValueError(f"Invalid packet length: "
f"declared body length is {pkt.body_length} while {len(data) - 4} bytes are avalaible")
pkt.body = [] pkt.body = []
read_bytes = 0 read_bytes = 0
while read_bytes < min(len(data) - 4, pkt.body_length): while read_bytes < min(len(data) - 4, pkt.body_length):
tlv_type = data[4 + read_bytes] tlv_type = data[4 + read_bytes]
if not (0 <= tlv_type < len(TLV.tlv_classes())): if not (0 <= tlv_type < len(TLV.tlv_classes())):
raise ValueError(f"TLV type is not supported: {tlv_type}") raise ValueError(f"TLV type is not supported: {tlv_type}")
tlv_length = data[4 + read_bytes + 1] if tlv_type > 0 else -1
if 2 + tlv_length > pkt.body_length - read_bytes:
raise ValueError(f"TLV length is too long: requesting {tlv_length} bytes, "
f"remaining {pkt.body_length - read_bytes}")
tlv = TLV.tlv_classes()[tlv_type]() tlv = TLV.tlv_classes()[tlv_type]()
tlv.unmarshal(data[4 + read_bytes:4 + read_bytes + pkt.body_length]) tlv.unmarshal(data[4 + read_bytes:4 + read_bytes + 2 + tlv_length])
pkt.body.append(tlv) pkt.body.append(tlv)
read_bytes += len(tlv) read_bytes += len(tlv)

View File

@ -82,9 +82,5 @@ class Squinnondation:
pkt = Packet().construct(htlv) pkt = Packet().construct(htlv)
squirrel.send_packet(hazelnut, pkt) squirrel.send_packet(hazelnut, pkt)
# if squirrel.port != 8082:
# hazelnut = Hazelnut(address='::1', port=8082)
# squirrel.potentialhazelnuts['::1', 8082] = hazelnut
squirrel.start_threads() squirrel.start_threads()
squirrel.wait_for_key() squirrel.wait_for_key()