Merge branch 'security' into 'master'

Security

See merge request ynerant/squinnondation!6
This commit is contained in:
ynerant 2021-01-05 22:40:54 +01:00
commit a92df73a55
3 changed files with 230 additions and 119 deletions

View File

@ -11,7 +11,7 @@ import re
import socket
import time
from .messages import Packet, DataTLV, HelloTLV, GoAwayTLV, GoAwayType, NeighbourTLV
from .messages import Packet, DataTLV, HelloTLV, GoAwayTLV, GoAwayType, NeighbourTLV, WarningTLV
class Hazelnut:
@ -21,6 +21,11 @@ class Hazelnut:
def __init__(self, nickname: str = None, address: str = "localhost", port: int = 2500):
self.nickname = nickname
self.id = -1
self.last_hello_time = 0
self.last_long_hello_time = 0
self.symmetric = False
self.active = False
self.errors = 0
try:
# Resolve DNS as an IPv6
@ -31,8 +36,50 @@ class Hazelnut:
# See https://fr.wikipedia.org/wiki/Adresse_IPv6_mappant_IPv4
address = "::ffff:" + socket.getaddrinfo(address, None, socket.AF_INET)[0][4][0]
self.address = address # IPv6Address(address)
self.port = port
self.addresses = set()
self.addresses.add((address, port))
@property
def potential(self) -> bool:
return not self.active and not self.banned
@potential.setter
def potential(self, value: bool) -> None:
self.active = not value
@property
def main_address(self) -> Tuple[str, int]:
"""
A client can have multiple addresses.
We contact it only on one of them.
"""
return list(self.addresses)[0]
@property
def banned(self) -> bool:
"""
If a client send more than 5 invalid packets, we don't trust it anymore.
"""
return self.errors >= 5
def __repr__(self):
return self.nickname or str(self.id) or str(self.main_address)
def __str__(self):
return repr(self)
def merge(self, other: "Hazelnut") -> "Hazelnut":
"""
Merge the hazelnut data with one other.
The symmetric and active properties are kept from the original client.
"""
self.errors = max(self.errors, other.errors)
self.last_hello_time = max(self.last_hello_time, other.last_hello_time)
self.last_long_hello_time = max(self.last_hello_time, other.last_long_hello_time)
self.addresses.update(self.addresses)
self.addresses.update(other.addresses)
self.id = self.id if self.id > 0 else other.id
return self
class Squirrel(Hazelnut):
@ -49,7 +96,7 @@ class Squirrel(Hazelnut):
# Create UDP socket
self.socket = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
# Bind the socket
self.socket.bind((self.address, self.port))
self.socket.bind(self.main_address)
self.squinnondation = instance
@ -74,13 +121,15 @@ class Squirrel(Hazelnut):
curses.init_pair(i + 1, i, curses.COLOR_BLACK)
# dictionnaries of neighbours
self.potentialhazelnuts = dict()
self.activehazelnuts = dict() # of the form [hazelnut, time of last hello,
# time of last long hello, is symmetric]
self.hazelnuts = dict()
self.nbNS = 0
self.minNS = 3 # minimal number of symmetric neighbours a squirrel needs to have.
self.add_system_message(f"Listening on {self.address}:{self.port}")
self.worm = Worm(self)
self.hazel_manager = HazelManager(self)
self.inondator = Inondator(self)
self.add_system_message(f"Listening on {self.main_address[0]}:{self.main_address[1]}")
self.add_system_message(f"I am {self.id}")
def new_hazel(self, address: str, port: int) -> Hazelnut:
@ -90,45 +139,50 @@ class Squirrel(Hazelnut):
hazelnut = Hazelnut(address=address, port=port)
return hazelnut
def is_active(self, hazel: Hazelnut) -> bool:
return (hazel.address, hazel.port) in self.activehazelnuts
@property
def active_hazelnuts(self) -> set:
return set(hazelnut for hazelnut in self.hazelnuts.values() if hazelnut.active)
def is_potential(self, hazel: Hazelnut) -> bool:
return (hazel.address, hazel.port) in self.potentialhazelnuts
def remove_from_potential(self, address: str, port: int) -> None:
self.potentialhazelnuts.pop((address, port), None)
@property
def potential_hazelnuts(self) -> set:
return set(hazelnut for hazelnut in self.hazelnuts.values() if hazelnut.potential)
def find_hazelnut(self, address: str, port: int) -> Hazelnut:
"""
Translate an address into a hazelnut. If this hazelnut does not exist,
creates a new hazelnut.
"""
if (address, port) in self.activehazelnuts:
return self.activehazelnuts[(address, port)][0]
if (address, port) in self.hazelnuts:
return self.hazelnuts[(address, port)]
hazelnut = Hazelnut(address=address, port=port)
self.hazelnuts[(address, port)] = hazelnut
return hazelnut
def find_hazelnut_by_id(self, hazel_id: int) -> Hazelnut:
"""
Retrieve the hazelnut that is known by its id. Return None if it is unknown.
The given identifier must be positive.
"""
if hazel_id > 0:
for hazelnut in self.hazelnuts.values():
if hazelnut.id == hazel_id:
return hazelnut
def send_packet(self, client: Hazelnut, pkt: Packet) -> int:
"""
Send a formatted packet to a client.
"""
self.refresh_lock.acquire()
if len(pkt) > 1024:
# The packet is too large to be sent by the protocol. We split the packet in subpackets.
return sum(self.send_packet(client, subpkt) for subpkt in pkt.split(1024))
res = self.send_raw_data(client, pkt.marshal())
self.refresh_lock.release()
return res
def send_raw_data(self, client: Hazelnut, data: bytes) -> int:
"""
Send a raw packet to a client.
"""
self.refresh_lock.acquire()
res = self.socket.sendto(data, (client.address, client.port))
self.refresh_lock.release()
return res
return self.socket.sendto(data, client.main_address)
def receive_packet(self) -> Tuple[Packet, Hazelnut]:
"""
@ -136,7 +190,24 @@ class Squirrel(Hazelnut):
Warning: the process is blocking, it should be ran inside a dedicated thread.
"""
data, addr = self.receive_raw_data()
return Packet.unmarshal(data), self.find_hazelnut(addr[0], addr[1])
hazelnut = self.find_hazelnut(addr[0], addr[1])
if hazelnut.banned:
# The client already sent errored packets
return Packet.construct(), hazelnut
try:
pkt = Packet.unmarshal(data)
except ValueError as error:
# The packet contains an error. We memorize it and warn the other user.
hazelnut.errors += 1
self.send_packet(hazelnut, Packet.construct(WarningTLV.construct(
f"An error occured while reading your packet: {error}")))
if hazelnut.banned:
self.send_packet(hazelnut, Packet.construct(WarningTLV.construct(
"You got banned since you sent too much errored packets.")))
raise ValueError("Client is banned since there were too many errors.", error)
raise error
else:
return pkt, hazelnut
def receive_raw_data(self) -> Tuple[bytes, Any]:
"""
@ -148,10 +219,6 @@ class Squirrel(Hazelnut):
"""
Start asynchronous threads.
"""
self.worm = Worm(self)
self.hazel_manager = HazelManager(self)
self.inondator = Inondator(self)
# Kill subthreads when exitting the program
self.worm.setDaemon(True)
self.hazel_manager.setDaemon(True)
@ -259,8 +326,8 @@ class Squirrel(Hazelnut):
self.add_message(msg)
pkt = Packet.construct(DataTLV.construct(msg, self))
for hazelnut in list(self.activehazelnuts.values()):
self.send_packet(hazelnut[0], pkt)
for hazelnut in self.active_hazelnuts:
self.send_packet(hazelnut, pkt)
def handle_mouse_click(self, y: int, x: int, attr: int) -> None:
"""
@ -344,44 +411,37 @@ class Squirrel(Hazelnut):
Takes the active hazels dictionnary and returns a list of [hazel, date+random, 0]
"""
res = dict()
hazels = list(self.activehazelnuts.items())
for key, hazel in hazels:
if hazel[3]: # Only if the neighbour is symmetric
hazels = self.active_hazelnuts
for hazel in hazels:
if hazel.symmetric:
next_send = uniform(1, 2)
res[key] = [hazel[0], time.time() + next_send, 0]
res[hazel.main_address] = [hazel, time.time() + next_send, 0]
return res
def remove_from_inundation(self, hazel: Hazelnut, sender_id: int, nonce: int) -> None:
"""
Remove the sender from the list of neighbours to be inundated
"""
self.refresh_lock.acquire()
if (sender_id, nonce) in self.recent_messages:
# If a peer is late in its acknowledgement, the absence of the previous if causes an error.
self.recent_messages[(sender_id, nonce)][2].pop((hazel.address, hazel.port), None)
for addr in hazel.addresses:
self.recent_messages[(sender_id, nonce)][2].pop(addr, None)
if not self.recent_messages[(sender_id, nonce)][2]: # If dictionnary is empty, remove the message
self.recent_messages.pop((sender_id, nonce), None)
self.refresh_lock.release()
def clean_inundation(self) -> None:
"""
Remove messages which are overdue (older than 2 minutes) from the inundation dictionnary.
"""
self.refresh_lock.acquire()
for key in self.recent_messages:
if time.time() - self.recent_messages[key][1] > 120:
self.recent_messages.pop(key)
self.refresh_lock.release()
def main_inundation(self) -> None:
"""
The main inundation function.
"""
self.refresh_lock.acquire()
for key in self.recent_messages:
k = list(self.recent_messages[key][2].keys())
for key2 in k:
@ -400,13 +460,10 @@ class Squirrel(Hazelnut):
if self.recent_messages[key][2][key2][2] >= 5: # the neighbour is not reactive enough
gatlv = GoAwayTLV().construct(GoAwayType.TIMEOUT, f"{self.id} No acknowledge")
pkt = Packet().construct(gatlv)
self.send_packet(self.recent_messages[key][2][key2][0], pkt)
self.activehazelnuts.pop(key2)
self.potentialhazelnuts[key2] = self.recent_messages[key][2][key2][0]
hazelnut = self.recent_messages[key][2][key2][0]
self.send_packet(hazelnut, pkt)
self.recent_messages[key][2].pop(key2)
self.refresh_lock.release()
def add_system_message(self, msg: str) -> None:
"""
Add a new system log message.
@ -606,82 +663,81 @@ class Squirrel(Hazelnut):
Returns a list of hazelnuts the squirrel should contact if it does
not have enough symmetric neighbours.
"""
self.refresh_lock.acquire()
res = []
lp = len(self.potentialhazelnuts)
val = list(self.potentialhazelnuts.values())
val = list(self.potential_hazelnuts)
lp = len(val)
for i in range(min(lp, max(0, self.minNS - self.nbNS))):
a = randint(0, lp - 1)
res.append(val[a])
self.refresh_lock.release()
return res
def send_hello(self) -> None:
"""
Sends a long HelloTLV to all active neighbours.
"""
self.refresh_lock.acquire()
for hazelnut in self.activehazelnuts.values():
htlv = HelloTLV().construct(16, self, hazelnut[0])
for hazelnut in self.active_hazelnuts:
htlv = HelloTLV().construct(16, self, hazelnut)
pkt = Packet().construct(htlv)
self.send_packet(hazelnut[0], pkt)
self.refresh_lock.release()
self.send_packet(hazelnut, pkt)
def verify_activity(self) -> None:
"""
All neighbours that have not sent a HelloTLV in the last 2
minutes are considered not active.
"""
self.refresh_lock.acquire()
val = list(self.activehazelnuts.values()) # create a copy because the dict size will change
val = list(self.active_hazelnuts) # create a copy because the dict size will change
for hazelnut in val:
if time.time() - hazelnut[1] > 2 * 60:
if time.time() - hazelnut.last_hello_time > 2 * 60:
gatlv = GoAwayTLV().construct(GoAwayType.TIMEOUT, "you did not talk to me")
pkt = Packet().construct(gatlv)
self.send_packet(hazelnut[0], pkt)
self.activehazelnuts.pop((hazelnut[0].address, hazelnut[0].port))
self.potentialhazelnuts[(hazelnut[0].address, hazelnut[0].port)] = hazelnut[0]
self.send_packet(hazelnut, pkt)
hazelnut.active = False
self.update_hazelnut_table(hazelnut)
self.refresh_lock.release()
def update_hazelnut_table(self, hazelnut: Hazelnut) -> None:
"""
We insert the hazelnut into our table of clients.
If there is a collision with the address / the ID, then we merge clients into a unique one.
"""
for addr in hazelnut.addresses:
if addr in self.hazelnuts:
# Merge with the previous hazel
old_hazel = self.hazelnuts[addr]
hazelnut.merge(old_hazel)
self.hazelnuts[addr] = hazelnut
for other_hazel in list(self.hazelnuts.values()):
if other_hazel.id == hazelnut.id > 0 and other_hazel != hazelnut:
# The hazelnut with the same id is known as a different address. We merge everything
hazelnut.merge(other_hazel)
def send_neighbours(self) -> None:
"""
Update the number of symmetric neighbours and
send all neighbours NeighbourTLV
"""
self.refresh_lock.acquire()
nb_ns = 0
# could send the same to all neighbour, but it means that neighbour
# A could receive a message with itself in it -> if the others do not pay attention, trouble
for key, hazelnut in self.activehazelnuts.items():
if time.time() - hazelnut[2] <= 2 * 60:
for hazelnut in self.active_hazelnuts:
if time.time() - hazelnut.last_long_hello_time <= 2 * 60:
nb_ns += 1
self.activehazelnuts[key][3] = True
ntlv = NeighbourTLV().construct(hazelnut[0].address, hazelnut[0].port)
hazelnut.symmetric = True
ntlv = NeighbourTLV().construct(*hazelnut.main_address)
pkt = Packet().construct(ntlv)
for destination in self.activehazelnuts.values():
if destination[0].id != hazelnut[0].id:
self.send_packet(destination[0], pkt)
for destination in self.active_hazelnuts:
if destination.id != hazelnut.id:
self.send_packet(destination, pkt)
else:
self.activehazelnuts[key][3] = False
hazelnut.symmetric = False
self.nbNS = nb_ns
self.refresh_lock.release()
def leave(self) -> None:
"""
The program is exited. We send a GoAway to our neighbours, then close the program.
"""
self.refresh_lock.acquire()
# Last inundation
self.main_inundation()
self.clean_inundation()
@ -689,8 +745,8 @@ class Squirrel(Hazelnut):
# Broadcast a GoAway
gatlv = GoAwayTLV().construct(GoAwayType.EXIT, "I am leaving! Good bye!")
pkt = Packet.construct(gatlv)
for hazelnut in self.activehazelnuts.values():
self.send_packet(hazelnut[0], pkt)
for hazelnut in self.active_hazelnuts:
self.send_packet(hazelnut, pkt)
exit(0)
@ -711,7 +767,13 @@ class Worm(Thread):
pkt, hazelnut = self.squirrel.receive_packet()
except ValueError as error:
self.squirrel.add_system_message("An error occurred while receiving a packet: {}".format(error))
self.squirrel.refresh_history()
self.squirrel.refresh_input()
else:
if hazelnut.banned:
# Ignore banned hazelnuts
continue
for tlv in pkt.body:
tlv.handle(self.squirrel, hazelnut)
self.squirrel.refresh_history()
@ -746,7 +808,7 @@ class HazelManager(Thread):
# Second part: send long HelloTLVs to neighbours every 30 seconds
if time.time() - self.last_check > 30:
self.squirrel.add_system_message(f"I have {len(list(self.squirrel.activehazelnuts.values()))} friends")
self.squirrel.add_system_message(f"I have {len(self.squirrel.active_hazelnuts)} friends")
self.squirrel.send_hello()
self.last_check = time.time()

View File

@ -12,7 +12,6 @@ import time
class TLV:
"""
The Tag-Length-Value contains the different type of data that can be sent.
TODO: add subclasses for each type of TLV
"""
type: int
length: int
@ -74,7 +73,12 @@ class Pad1TLV(TLV):
return self.type.to_bytes(1, sys.byteorder)
def handle(self, squirrel: Any, sender: Any) -> None:
# TODO Add some easter eggs
if not sender.active or not sender.symmetric or not sender.id:
# It doesn't say hello, we don't listen to it
squirrel.send_packet(sender, Packet.construct(WarningTLV.construct(
"You are not my neighbour, I don't listen to your Pad1TLV. Please say me Hello before.")))
return
squirrel.add_system_message("I received a Pad1TLV, how disapointing.")
def __len__(self) -> int:
@ -119,7 +123,12 @@ class PadNTLV(TLV):
+ self.mbz[:self.length]
def handle(self, squirrel: Any, sender: Any) -> None:
# TODO Add some easter eggs
if not sender.active or not sender.symmetric or not sender.id:
# It doesn't say hello, we don't listen to it
squirrel.send_packet(sender, Packet.construct(WarningTLV.construct(
"You are not my neighbour, I don't listen to your PadNTLV. Please say me Hello before.")))
return
squirrel.add_system_message(f"I received {self.length} zeros.")
@staticmethod
@ -159,22 +168,33 @@ class HelloTLV(TLV):
def handle(self, squirrel: Any, sender: Any) -> None:
time_h = time.time()
if not squirrel.is_active(sender):
if sender.id > 0 and sender.id != self.source_id:
squirrel.send_packet(sender, Packet.construct(WarningTLV.construct(
f"You were known as the ID {sender.id}, but you declared that you have the ID {self.source_id}.")))
squirrel.add_system_message(f"A client known as the id {sender.id} declared that it uses "
f"the id {self.source_id}.")
sender.id = self.source_id
if not sender.active:
sender.id = self.source_id # The sender we are given misses an id
time_hl = time.time()
else:
time_hl = squirrel.activehazelnuts[(sender.address, sender.port)][2]
time_hl = sender.last_long_hello_time
if self.is_long and self.dest_id == squirrel.id:
time_hl = time.time()
# Make sure the sender is not in the potential hazelnuts
squirrel.remove_from_potential(sender.address, sender.port)
# Add entry to/actualize the active hazelnuts dictionnary
squirrel.activehazelnuts[(sender.address, sender.port)] = [sender, time_h, time_hl, True]
sender.last_hello_time = time_h
sender.last_long_hello_time = time_hl
sender.symmetric = True
sender.active = True
squirrel.update_hazelnut_table(sender)
squirrel.nbNS += 1
# squirrel.add_system_message(f"Aaaawwww, {self.source_id} spoke to me and said Hello "
# + ("long" if self.is_long else "short"))
squirrel.add_system_message(f"{self.source_id} sent me a Hello " + ("long" if self.is_long else "short"))
if not self.is_long:
squirrel.send_packet(sender, Packet.construct(HelloTLV.construct(16, squirrel, sender)))
@property
def is_long(self) -> bool:
@ -219,14 +239,20 @@ class NeighbourTLV(TLV):
self.port.to_bytes(2, sys.byteorder)
def handle(self, squirrel: Any, sender: Any) -> None:
if squirrel.address == str(self.ip_address) and squirrel.port == self.port:
if not sender.active or not sender.symmetric or not sender.id:
# It doesn't say hello, we don't listen to it
squirrel.send_packet(sender, Packet.construct(WarningTLV.construct(
"You are not my neighbour, I don't listen to your NeighbourTLV. Please say me Hello before.")))
return
if (self.ip_address, self.port) in squirrel.addresses:
# This case should never happen (and in our protocol it is not possible),
# but we include this test as a security measure.
return
if not (str(self.ip_address), self.port) in squirrel.activehazelnuts \
and not (str(self.ip_address), self.port) in squirrel.potentialhazelnuts:
squirrel.potentialhazelnuts[(str(self.ip_address), self.port)] = \
squirrel.new_hazel(str(self.ip_address), self.port)
if not (str(self.ip_address), self.port) in squirrel.hazelnuts:
hazelnut = squirrel.new_hazel(str(self.ip_address), self.port)
hazelnut.potential = True
squirrel.update_hazelnut_table(hazelnut)
# squirrel.add_system_message(f"New potential friend {self.ip_address}:{self.port}!")
@staticmethod
@ -269,6 +295,12 @@ class DataTLV(TLV):
"""
A message has been sent. We log it.
"""
if not sender.active or not sender.symmetric or not sender.id:
# It doesn't say hello, we don't listen to it
squirrel.send_packet(sender, Packet.construct(WarningTLV.construct(
"You are not my neighbour, I don't listen to your DataTLV. Please say me Hello before.")))
return
msg = self.data.decode('UTF-8')
# Acknowledge the packet
@ -285,13 +317,15 @@ class DataTLV(TLV):
"Unable to retrieve your username. Please use the syntax 'nickname: message'")))
else:
nickname = nickname_match.group(1)
if sender.nickname is None:
sender.nickname = nickname
elif sender.nickname != nickname:
squirrel.send_packet(sender, Packet.construct(WarningTLV.construct(
author = squirrel.find_hazelnut_by_id(self.sender_id)
if author:
if author.nickname is None:
author.nickname = nickname
elif author.nickname != nickname:
squirrel.send_packet(author, Packet.construct(WarningTLV.construct(
"It seems that you used two different nicknames. "
f"Known nickname: {sender.nickname}, found: {nickname}")))
sender.nickname = nickname
f"Known nickname: {author.nickname}, found: {nickname}")))
author.nickname = nickname
@staticmethod
def construct(message: str, squirrel: Any) -> "DataTLV":
@ -330,7 +364,13 @@ class AckTLV(TLV):
"""
When an AckTLV is received, we know that we do not have to inundate that neighbour anymore.
"""
squirrel.add_system_message("I received an AckTLV")
if not sender.active or not sender.symmetric or not sender.id:
# It doesn't say hello, we don't listen to it
squirrel.send_packet(sender, Packet.construct(WarningTLV.construct(
"You are not my neighbour, I don't listen to your AckTLV. Please say me Hello before.")))
return
squirrel.add_system_message(f"I received an AckTLV from {sender}")
squirrel.remove_from_inundation(sender, self.sender_id, self.nonce)
@staticmethod
@ -369,9 +409,15 @@ class GoAwayTLV(TLV):
self.message.encode("UTF-8")[:self.length - 1]
def handle(self, squirrel: Any, sender: Any) -> None:
if squirrel.is_active(sender):
squirrel.activehazelnuts.pop((sender.address, sender.port))
squirrel.potentialhazelnuts[(sender.address, sender.port)] = sender
if not sender.active or not sender.symmetric or not sender.id:
# It doesn't say hello, we don't listen to it
squirrel.send_packet(sender, Packet.construct(WarningTLV.construct(
"You are not my neighbour, I don't listen to your GoAwayTLV. Please say me Hello before.")))
return
if sender.active:
sender.active = False
squirrel.update_hazelnut_table(sender)
squirrel.add_system_message("Some told me that he went away : " + self.message)
@staticmethod
@ -446,14 +492,21 @@ class Packet:
pkt.magic = data[0]
pkt.version = data[1]
pkt.body_length = socket.ntohs(int.from_bytes(data[2:4], sys.byteorder))
if len(data) != 4 + pkt.body_length:
raise ValueError(f"Invalid packet length: "
f"declared body length is {pkt.body_length} while {len(data) - 4} bytes are avalaible")
pkt.body = []
read_bytes = 0
while read_bytes < min(len(data) - 4, pkt.body_length):
tlv_type = data[4 + read_bytes]
if not (0 <= tlv_type < len(TLV.tlv_classes())):
raise ValueError(f"TLV type is not supported: {tlv_type}")
tlv_length = data[4 + read_bytes + 1] if tlv_type > 0 else -1
if 2 + tlv_length > pkt.body_length - read_bytes:
raise ValueError(f"TLV length is too long: requesting {tlv_length} bytes, "
f"remaining {pkt.body_length - read_bytes}")
tlv = TLV.tlv_classes()[tlv_type]()
tlv.unmarshal(data[4 + read_bytes:4 + read_bytes + pkt.body_length])
tlv.unmarshal(data[4 + read_bytes:4 + read_bytes + 2 + tlv_length])
pkt.body.append(tlv)
read_bytes += len(tlv)

View File

@ -82,9 +82,5 @@ class Squinnondation:
pkt = Packet().construct(htlv)
squirrel.send_packet(hazelnut, pkt)
# if squirrel.port != 8082:
# hazelnut = Hazelnut(address='::1', port=8082)
# squirrel.potentialhazelnuts['::1', 8082] = hazelnut
squirrel.start_threads()
squirrel.wait_for_key()