diff --git a/squinnondation/hazel.py b/squinnondation/hazel.py index 7b173f0..59f03d2 100644 --- a/squinnondation/hazel.py +++ b/squinnondation/hazel.py @@ -11,7 +11,7 @@ import re import socket import time -from .messages import Packet, DataTLV, HelloTLV, GoAwayTLV, GoAwayType, NeighbourTLV +from .messages import Packet, DataTLV, HelloTLV, GoAwayTLV, GoAwayType, NeighbourTLV, WarningTLV class Hazelnut: @@ -21,6 +21,11 @@ class Hazelnut: def __init__(self, nickname: str = None, address: str = "localhost", port: int = 2500): self.nickname = nickname self.id = -1 + self.last_hello_time = 0 + self.last_long_hello_time = 0 + self.symmetric = False + self.active = False + self.errors = 0 try: # Resolve DNS as an IPv6 @@ -31,8 +36,50 @@ class Hazelnut: # See https://fr.wikipedia.org/wiki/Adresse_IPv6_mappant_IPv4 address = "::ffff:" + socket.getaddrinfo(address, None, socket.AF_INET)[0][4][0] - self.address = address # IPv6Address(address) - self.port = port + self.addresses = set() + self.addresses.add((address, port)) + + @property + def potential(self) -> bool: + return not self.active and not self.banned + + @potential.setter + def potential(self, value: bool) -> None: + self.active = not value + + @property + def main_address(self) -> Tuple[str, int]: + """ + A client can have multiple addresses. + We contact it only on one of them. + """ + return list(self.addresses)[0] + + @property + def banned(self) -> bool: + """ + If a client send more than 5 invalid packets, we don't trust it anymore. + """ + return self.errors >= 5 + + def __repr__(self): + return self.nickname or str(self.id) or str(self.main_address) + + def __str__(self): + return repr(self) + + def merge(self, other: "Hazelnut") -> "Hazelnut": + """ + Merge the hazelnut data with one other. + The symmetric and active properties are kept from the original client. + """ + self.errors = max(self.errors, other.errors) + self.last_hello_time = max(self.last_hello_time, other.last_hello_time) + self.last_long_hello_time = max(self.last_hello_time, other.last_long_hello_time) + self.addresses.update(self.addresses) + self.addresses.update(other.addresses) + self.id = self.id if self.id > 0 else other.id + return self class Squirrel(Hazelnut): @@ -49,7 +96,7 @@ class Squirrel(Hazelnut): # Create UDP socket self.socket = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) # Bind the socket - self.socket.bind((self.address, self.port)) + self.socket.bind(self.main_address) self.squinnondation = instance @@ -74,13 +121,15 @@ class Squirrel(Hazelnut): curses.init_pair(i + 1, i, curses.COLOR_BLACK) # dictionnaries of neighbours - self.potentialhazelnuts = dict() - self.activehazelnuts = dict() # of the form [hazelnut, time of last hello, - # time of last long hello, is symmetric] + self.hazelnuts = dict() self.nbNS = 0 self.minNS = 3 # minimal number of symmetric neighbours a squirrel needs to have. - self.add_system_message(f"Listening on {self.address}:{self.port}") + self.worm = Worm(self) + self.hazel_manager = HazelManager(self) + self.inondator = Inondator(self) + + self.add_system_message(f"Listening on {self.main_address[0]}:{self.main_address[1]}") self.add_system_message(f"I am {self.id}") def new_hazel(self, address: str, port: int) -> Hazelnut: @@ -90,45 +139,50 @@ class Squirrel(Hazelnut): hazelnut = Hazelnut(address=address, port=port) return hazelnut - def is_active(self, hazel: Hazelnut) -> bool: - return (hazel.address, hazel.port) in self.activehazelnuts + @property + def active_hazelnuts(self) -> set: + return set(hazelnut for hazelnut in self.hazelnuts.values() if hazelnut.active) - def is_potential(self, hazel: Hazelnut) -> bool: - return (hazel.address, hazel.port) in self.potentialhazelnuts - - def remove_from_potential(self, address: str, port: int) -> None: - self.potentialhazelnuts.pop((address, port), None) + @property + def potential_hazelnuts(self) -> set: + return set(hazelnut for hazelnut in self.hazelnuts.values() if hazelnut.potential) def find_hazelnut(self, address: str, port: int) -> Hazelnut: """ Translate an address into a hazelnut. If this hazelnut does not exist, creates a new hazelnut. """ - if (address, port) in self.activehazelnuts: - return self.activehazelnuts[(address, port)][0] + if (address, port) in self.hazelnuts: + return self.hazelnuts[(address, port)] hazelnut = Hazelnut(address=address, port=port) + self.hazelnuts[(address, port)] = hazelnut return hazelnut + def find_hazelnut_by_id(self, hazel_id: int) -> Hazelnut: + """ + Retrieve the hazelnut that is known by its id. Return None if it is unknown. + The given identifier must be positive. + """ + if hazel_id > 0: + for hazelnut in self.hazelnuts.values(): + if hazelnut.id == hazel_id: + return hazelnut + def send_packet(self, client: Hazelnut, pkt: Packet) -> int: """ Send a formatted packet to a client. """ - self.refresh_lock.acquire() if len(pkt) > 1024: # The packet is too large to be sent by the protocol. We split the packet in subpackets. return sum(self.send_packet(client, subpkt) for subpkt in pkt.split(1024)) res = self.send_raw_data(client, pkt.marshal()) - self.refresh_lock.release() return res def send_raw_data(self, client: Hazelnut, data: bytes) -> int: """ Send a raw packet to a client. """ - self.refresh_lock.acquire() - res = self.socket.sendto(data, (client.address, client.port)) - self.refresh_lock.release() - return res + return self.socket.sendto(data, client.main_address) def receive_packet(self) -> Tuple[Packet, Hazelnut]: """ @@ -136,7 +190,24 @@ class Squirrel(Hazelnut): Warning: the process is blocking, it should be ran inside a dedicated thread. """ data, addr = self.receive_raw_data() - return Packet.unmarshal(data), self.find_hazelnut(addr[0], addr[1]) + hazelnut = self.find_hazelnut(addr[0], addr[1]) + if hazelnut.banned: + # The client already sent errored packets + return Packet.construct(), hazelnut + try: + pkt = Packet.unmarshal(data) + except ValueError as error: + # The packet contains an error. We memorize it and warn the other user. + hazelnut.errors += 1 + self.send_packet(hazelnut, Packet.construct(WarningTLV.construct( + f"An error occured while reading your packet: {error}"))) + if hazelnut.banned: + self.send_packet(hazelnut, Packet.construct(WarningTLV.construct( + "You got banned since you sent too much errored packets."))) + raise ValueError("Client is banned since there were too many errors.", error) + raise error + else: + return pkt, hazelnut def receive_raw_data(self) -> Tuple[bytes, Any]: """ @@ -148,10 +219,6 @@ class Squirrel(Hazelnut): """ Start asynchronous threads. """ - self.worm = Worm(self) - self.hazel_manager = HazelManager(self) - self.inondator = Inondator(self) - # Kill subthreads when exitting the program self.worm.setDaemon(True) self.hazel_manager.setDaemon(True) @@ -259,8 +326,8 @@ class Squirrel(Hazelnut): self.add_message(msg) pkt = Packet.construct(DataTLV.construct(msg, self)) - for hazelnut in list(self.activehazelnuts.values()): - self.send_packet(hazelnut[0], pkt) + for hazelnut in self.active_hazelnuts: + self.send_packet(hazelnut, pkt) def handle_mouse_click(self, y: int, x: int, attr: int) -> None: """ @@ -341,47 +408,40 @@ class Squirrel(Hazelnut): def make_inundation_dict(self) -> dict: """ - Takes the activehazels dictionnary and returns a list of [hazel, date+random, 0] + Takes the active hazels dictionnary and returns a list of [hazel, date+random, 0] """ res = dict() - hazels = list(self.activehazelnuts.items()) - for key, hazel in hazels: - if hazel[3]: # Only if the neighbour is symmetric + hazels = self.active_hazelnuts + for hazel in hazels: + if hazel.symmetric: next_send = uniform(1, 2) - res[key] = [hazel[0], time.time() + next_send, 0] + res[hazel.main_address] = [hazel, time.time() + next_send, 0] return res def remove_from_inundation(self, hazel: Hazelnut, sender_id: int, nonce: int) -> None: """ Remove the sender from the list of neighbours to be inundated """ - self.refresh_lock.acquire() if (sender_id, nonce) in self.recent_messages: # If a peer is late in its acknowledgement, the absence of the previous if causes an error. - self.recent_messages[(sender_id, nonce)][2].pop((hazel.address, hazel.port), None) + for addr in hazel.addresses: + self.recent_messages[(sender_id, nonce)][2].pop(addr, None) if not self.recent_messages[(sender_id, nonce)][2]: # If dictionnary is empty, remove the message self.recent_messages.pop((sender_id, nonce), None) - self.refresh_lock.release() def clean_inundation(self) -> None: """ Remove messages which are overdue (older than 2 minutes) from the inundation dictionnary. """ - self.refresh_lock.acquire() - for key in self.recent_messages: if time.time() - self.recent_messages[key][1] > 120: self.recent_messages.pop(key) - self.refresh_lock.release() - def main_inundation(self) -> None: """ The main inundation function. """ - self.refresh_lock.acquire() - for key in self.recent_messages: k = list(self.recent_messages[key][2].keys()) for key2 in k: @@ -400,13 +460,10 @@ class Squirrel(Hazelnut): if self.recent_messages[key][2][key2][2] >= 5: # the neighbour is not reactive enough gatlv = GoAwayTLV().construct(GoAwayType.TIMEOUT, f"{self.id} No acknowledge") pkt = Packet().construct(gatlv) - self.send_packet(self.recent_messages[key][2][key2][0], pkt) - self.activehazelnuts.pop(key2) - self.potentialhazelnuts[key2] = self.recent_messages[key][2][key2][0] + hazelnut = self.recent_messages[key][2][key2][0] + self.send_packet(hazelnut, pkt) self.recent_messages[key][2].pop(key2) - self.refresh_lock.release() - def add_system_message(self, msg: str) -> None: """ Add a new system log message. @@ -606,82 +663,81 @@ class Squirrel(Hazelnut): Returns a list of hazelnuts the squirrel should contact if it does not have enough symmetric neighbours. """ - self.refresh_lock.acquire() - res = [] - lp = len(self.potentialhazelnuts) - val = list(self.potentialhazelnuts.values()) + val = list(self.potential_hazelnuts) + lp = len(val) for i in range(min(lp, max(0, self.minNS - self.nbNS))): a = randint(0, lp - 1) res.append(val[a]) - - self.refresh_lock.release() return res def send_hello(self) -> None: """ Sends a long HelloTLV to all active neighbours. """ - self.refresh_lock.acquire() - - for hazelnut in self.activehazelnuts.values(): - htlv = HelloTLV().construct(16, self, hazelnut[0]) + for hazelnut in self.active_hazelnuts: + htlv = HelloTLV().construct(16, self, hazelnut) pkt = Packet().construct(htlv) - self.send_packet(hazelnut[0], pkt) - - self.refresh_lock.release() + self.send_packet(hazelnut, pkt) def verify_activity(self) -> None: """ All neighbours that have not sent a HelloTLV in the last 2 minutes are considered not active. """ - self.refresh_lock.acquire() - - val = list(self.activehazelnuts.values()) # create a copy because the dict size will change + val = list(self.active_hazelnuts) # create a copy because the dict size will change for hazelnut in val: - if time.time() - hazelnut[1] > 2 * 60: + if time.time() - hazelnut.last_hello_time > 2 * 60: gatlv = GoAwayTLV().construct(GoAwayType.TIMEOUT, "you did not talk to me") pkt = Packet().construct(gatlv) - self.send_packet(hazelnut[0], pkt) - self.activehazelnuts.pop((hazelnut[0].address, hazelnut[0].port)) - self.potentialhazelnuts[(hazelnut[0].address, hazelnut[0].port)] = hazelnut[0] + self.send_packet(hazelnut, pkt) + hazelnut.active = False + self.update_hazelnut_table(hazelnut) - self.refresh_lock.release() + def update_hazelnut_table(self, hazelnut: Hazelnut) -> None: + """ + We insert the hazelnut into our table of clients. + If there is a collision with the address / the ID, then we merge clients into a unique one. + """ + for addr in hazelnut.addresses: + if addr in self.hazelnuts: + # Merge with the previous hazel + old_hazel = self.hazelnuts[addr] + hazelnut.merge(old_hazel) + self.hazelnuts[addr] = hazelnut + + for other_hazel in list(self.hazelnuts.values()): + if other_hazel.id == hazelnut.id > 0 and other_hazel != hazelnut: + # The hazelnut with the same id is known as a different address. We merge everything + hazelnut.merge(other_hazel) def send_neighbours(self) -> None: """ Update the number of symmetric neighbours and send all neighbours NeighbourTLV """ - self.refresh_lock.acquire() - nb_ns = 0 # could send the same to all neighbour, but it means that neighbour # A could receive a message with itself in it -> if the others do not pay attention, trouble - for key, hazelnut in self.activehazelnuts.items(): - if time.time() - hazelnut[2] <= 2 * 60: + for hazelnut in self.active_hazelnuts: + if time.time() - hazelnut.last_long_hello_time <= 2 * 60: nb_ns += 1 - self.activehazelnuts[key][3] = True - ntlv = NeighbourTLV().construct(hazelnut[0].address, hazelnut[0].port) + hazelnut.symmetric = True + ntlv = NeighbourTLV().construct(*hazelnut.main_address) pkt = Packet().construct(ntlv) - for destination in self.activehazelnuts.values(): - if destination[0].id != hazelnut[0].id: - self.send_packet(destination[0], pkt) + for destination in self.active_hazelnuts: + if destination.id != hazelnut.id: + self.send_packet(destination, pkt) else: - self.activehazelnuts[key][3] = False + hazelnut.symmetric = False self.nbNS = nb_ns - self.refresh_lock.release() - def leave(self) -> None: """ The program is exited. We send a GoAway to our neighbours, then close the program. """ - self.refresh_lock.acquire() - # Last inundation self.main_inundation() self.clean_inundation() @@ -689,8 +745,8 @@ class Squirrel(Hazelnut): # Broadcast a GoAway gatlv = GoAwayTLV().construct(GoAwayType.EXIT, "I am leaving! Good bye!") pkt = Packet.construct(gatlv) - for hazelnut in self.activehazelnuts.values(): - self.send_packet(hazelnut[0], pkt) + for hazelnut in self.active_hazelnuts: + self.send_packet(hazelnut, pkt) exit(0) @@ -711,7 +767,13 @@ class Worm(Thread): pkt, hazelnut = self.squirrel.receive_packet() except ValueError as error: self.squirrel.add_system_message("An error occurred while receiving a packet: {}".format(error)) + self.squirrel.refresh_history() + self.squirrel.refresh_input() else: + if hazelnut.banned: + # Ignore banned hazelnuts + continue + for tlv in pkt.body: tlv.handle(self.squirrel, hazelnut) self.squirrel.refresh_history() @@ -746,7 +808,7 @@ class HazelManager(Thread): # Second part: send long HelloTLVs to neighbours every 30 seconds if time.time() - self.last_check > 30: - self.squirrel.add_system_message(f"I have {len(list(self.squirrel.activehazelnuts.values()))} friends") + self.squirrel.add_system_message(f"I have {len(self.squirrel.active_hazelnuts)} friends") self.squirrel.send_hello() self.last_check = time.time() diff --git a/squinnondation/messages.py b/squinnondation/messages.py index d6157a4..bf00e23 100644 --- a/squinnondation/messages.py +++ b/squinnondation/messages.py @@ -12,7 +12,6 @@ import time class TLV: """ The Tag-Length-Value contains the different type of data that can be sent. - TODO: add subclasses for each type of TLV """ type: int length: int @@ -74,7 +73,12 @@ class Pad1TLV(TLV): return self.type.to_bytes(1, sys.byteorder) def handle(self, squirrel: Any, sender: Any) -> None: - # TODO Add some easter eggs + if not sender.active or not sender.symmetric or not sender.id: + # It doesn't say hello, we don't listen to it + squirrel.send_packet(sender, Packet.construct(WarningTLV.construct( + "You are not my neighbour, I don't listen to your Pad1TLV. Please say me Hello before."))) + return + squirrel.add_system_message("I received a Pad1TLV, how disapointing.") def __len__(self) -> int: @@ -119,7 +123,12 @@ class PadNTLV(TLV): + self.mbz[:self.length] def handle(self, squirrel: Any, sender: Any) -> None: - # TODO Add some easter eggs + if not sender.active or not sender.symmetric or not sender.id: + # It doesn't say hello, we don't listen to it + squirrel.send_packet(sender, Packet.construct(WarningTLV.construct( + "You are not my neighbour, I don't listen to your PadNTLV. Please say me Hello before."))) + return + squirrel.add_system_message(f"I received {self.length} zeros.") @staticmethod @@ -159,22 +168,33 @@ class HelloTLV(TLV): def handle(self, squirrel: Any, sender: Any) -> None: time_h = time.time() - if not squirrel.is_active(sender): + + if sender.id > 0 and sender.id != self.source_id: + squirrel.send_packet(sender, Packet.construct(WarningTLV.construct( + f"You were known as the ID {sender.id}, but you declared that you have the ID {self.source_id}."))) + squirrel.add_system_message(f"A client known as the id {sender.id} declared that it uses " + f"the id {self.source_id}.") + sender.id = self.source_id + + if not sender.active: sender.id = self.source_id # The sender we are given misses an id time_hl = time.time() else: - time_hl = squirrel.activehazelnuts[(sender.address, sender.port)][2] + time_hl = sender.last_long_hello_time if self.is_long and self.dest_id == squirrel.id: time_hl = time.time() - # Make sure the sender is not in the potential hazelnuts - squirrel.remove_from_potential(sender.address, sender.port) - # Add entry to/actualize the active hazelnuts dictionnary - squirrel.activehazelnuts[(sender.address, sender.port)] = [sender, time_h, time_hl, True] + sender.last_hello_time = time_h + sender.last_long_hello_time = time_hl + sender.symmetric = True + sender.active = True + squirrel.update_hazelnut_table(sender) squirrel.nbNS += 1 - # squirrel.add_system_message(f"Aaaawwww, {self.source_id} spoke to me and said Hello " - # + ("long" if self.is_long else "short")) + squirrel.add_system_message(f"{self.source_id} sent me a Hello " + ("long" if self.is_long else "short")) + + if not self.is_long: + squirrel.send_packet(sender, Packet.construct(HelloTLV.construct(16, squirrel, sender))) @property def is_long(self) -> bool: @@ -219,14 +239,20 @@ class NeighbourTLV(TLV): self.port.to_bytes(2, sys.byteorder) def handle(self, squirrel: Any, sender: Any) -> None: - if squirrel.address == str(self.ip_address) and squirrel.port == self.port: + if not sender.active or not sender.symmetric or not sender.id: + # It doesn't say hello, we don't listen to it + squirrel.send_packet(sender, Packet.construct(WarningTLV.construct( + "You are not my neighbour, I don't listen to your NeighbourTLV. Please say me Hello before."))) + return + + if (self.ip_address, self.port) in squirrel.addresses: # This case should never happen (and in our protocol it is not possible), # but we include this test as a security measure. return - if not (str(self.ip_address), self.port) in squirrel.activehazelnuts \ - and not (str(self.ip_address), self.port) in squirrel.potentialhazelnuts: - squirrel.potentialhazelnuts[(str(self.ip_address), self.port)] = \ - squirrel.new_hazel(str(self.ip_address), self.port) + if not (str(self.ip_address), self.port) in squirrel.hazelnuts: + hazelnut = squirrel.new_hazel(str(self.ip_address), self.port) + hazelnut.potential = True + squirrel.update_hazelnut_table(hazelnut) # squirrel.add_system_message(f"New potential friend {self.ip_address}:{self.port}!") @staticmethod @@ -269,6 +295,12 @@ class DataTLV(TLV): """ A message has been sent. We log it. """ + if not sender.active or not sender.symmetric or not sender.id: + # It doesn't say hello, we don't listen to it + squirrel.send_packet(sender, Packet.construct(WarningTLV.construct( + "You are not my neighbour, I don't listen to your DataTLV. Please say me Hello before."))) + return + msg = self.data.decode('UTF-8') # Acknowledge the packet @@ -285,13 +317,15 @@ class DataTLV(TLV): "Unable to retrieve your username. Please use the syntax 'nickname: message'"))) else: nickname = nickname_match.group(1) - if sender.nickname is None: - sender.nickname = nickname - elif sender.nickname != nickname: - squirrel.send_packet(sender, Packet.construct(WarningTLV.construct( - "It seems that you used two different nicknames. " - f"Known nickname: {sender.nickname}, found: {nickname}"))) - sender.nickname = nickname + author = squirrel.find_hazelnut_by_id(self.sender_id) + if author: + if author.nickname is None: + author.nickname = nickname + elif author.nickname != nickname: + squirrel.send_packet(author, Packet.construct(WarningTLV.construct( + "It seems that you used two different nicknames. " + f"Known nickname: {author.nickname}, found: {nickname}"))) + author.nickname = nickname @staticmethod def construct(message: str, squirrel: Any) -> "DataTLV": @@ -330,7 +364,13 @@ class AckTLV(TLV): """ When an AckTLV is received, we know that we do not have to inundate that neighbour anymore. """ - squirrel.add_system_message("I received an AckTLV") + if not sender.active or not sender.symmetric or not sender.id: + # It doesn't say hello, we don't listen to it + squirrel.send_packet(sender, Packet.construct(WarningTLV.construct( + "You are not my neighbour, I don't listen to your AckTLV. Please say me Hello before."))) + return + + squirrel.add_system_message(f"I received an AckTLV from {sender}") squirrel.remove_from_inundation(sender, self.sender_id, self.nonce) @staticmethod @@ -369,9 +409,15 @@ class GoAwayTLV(TLV): self.message.encode("UTF-8")[:self.length - 1] def handle(self, squirrel: Any, sender: Any) -> None: - if squirrel.is_active(sender): - squirrel.activehazelnuts.pop((sender.address, sender.port)) - squirrel.potentialhazelnuts[(sender.address, sender.port)] = sender + if not sender.active or not sender.symmetric or not sender.id: + # It doesn't say hello, we don't listen to it + squirrel.send_packet(sender, Packet.construct(WarningTLV.construct( + "You are not my neighbour, I don't listen to your GoAwayTLV. Please say me Hello before."))) + return + + if sender.active: + sender.active = False + squirrel.update_hazelnut_table(sender) squirrel.add_system_message("Some told me that he went away : " + self.message) @staticmethod @@ -446,14 +492,21 @@ class Packet: pkt.magic = data[0] pkt.version = data[1] pkt.body_length = socket.ntohs(int.from_bytes(data[2:4], sys.byteorder)) + if len(data) != 4 + pkt.body_length: + raise ValueError(f"Invalid packet length: " + f"declared body length is {pkt.body_length} while {len(data) - 4} bytes are avalaible") pkt.body = [] read_bytes = 0 while read_bytes < min(len(data) - 4, pkt.body_length): tlv_type = data[4 + read_bytes] if not (0 <= tlv_type < len(TLV.tlv_classes())): raise ValueError(f"TLV type is not supported: {tlv_type}") + tlv_length = data[4 + read_bytes + 1] if tlv_type > 0 else -1 + if 2 + tlv_length > pkt.body_length - read_bytes: + raise ValueError(f"TLV length is too long: requesting {tlv_length} bytes, " + f"remaining {pkt.body_length - read_bytes}") tlv = TLV.tlv_classes()[tlv_type]() - tlv.unmarshal(data[4 + read_bytes:4 + read_bytes + pkt.body_length]) + tlv.unmarshal(data[4 + read_bytes:4 + read_bytes + 2 + tlv_length]) pkt.body.append(tlv) read_bytes += len(tlv) diff --git a/squinnondation/squinnondation.py b/squinnondation/squinnondation.py index 22c9999..f17da92 100644 --- a/squinnondation/squinnondation.py +++ b/squinnondation/squinnondation.py @@ -82,9 +82,5 @@ class Squinnondation: pkt = Packet().construct(htlv) squirrel.send_packet(hazelnut, pkt) - # if squirrel.port != 8082: - # hazelnut = Hazelnut(address='::1', port=8082) - # squirrel.potentialhazelnuts['::1', 8082] = hazelnut - squirrel.start_threads() squirrel.wait_for_key()