diff --git a/squinnondation/hazel.py b/squinnondation/hazel.py index 8ebec48..2e02863 100644 --- a/squinnondation/hazel.py +++ b/squinnondation/hazel.py @@ -707,11 +707,10 @@ class Worm(Thread): def run(self) -> None: while True: - try: - pkt, hazelnut = self.squirrel.receive_packet() - pkt.validate_data() - except ValueError as error: - self.squirrel.add_system_message("An error occurred while receiving a packet: {}".format(error)) + pkt, hazelnut = self.squirrel.receive_packet() + correct = pkt.validate_data() + if not correct : + self.squirrel.add_system_message("I received a incorrect packet") else: for tlv in pkt.body: tlv.handle(self.squirrel, hazelnut) diff --git a/squinnondation/messages.py b/squinnondation/messages.py index 2cf0fb4..6503a1c 100644 --- a/squinnondation/messages.py +++ b/squinnondation/messages.py @@ -74,8 +74,7 @@ class Pad1TLV(TLV): return self.type.to_bytes(1, sys.byteorder) def handle(self, squirrel: Any, sender: Any) -> None: - # TODO Add some easter eggs - squirrel.add_system_message("For each byte in the packet that I received, you will die today. And eat cookies.") + squirrel.add_system_message("I received a Pad1TLV, how disapointing") def __len__(self) -> int: """ @@ -100,7 +99,8 @@ class PadNTLV(TLV): def validate_data(self) -> bool: if self.mbz != int(0).to_bytes(self.length, sys.byteorder): - raise ValueError("The body of a PadN TLV is not filled with zeros.") + return False + #raise ValueError("The body of a PadN TLV is not filled with zeros.") return True def unmarshal(self, raw_data: bytes) -> None: @@ -119,8 +119,8 @@ class PadNTLV(TLV): + self.mbz[:self.length] def handle(self, squirrel: Any, sender: Any) -> None: - # TODO Add some easter eggs - squirrel.add_system_message(f"I received {self.length} zeros, am I so a bad guy ? :cold_sweat:") + if self.validate_data(): + squirrel.add_system_message(f"I received {self.length} zeros") @staticmethod def construct(length: int) -> "PadNTLV": @@ -139,8 +139,9 @@ class HelloTLV(TLV): def validate_data(self) -> bool: if self.length != 8 and self.length != 16: - raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello," - f"found {self.length}") + return False + #raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello," + # f"found {self.length}") return True def unmarshal(self, raw_data: bytes) -> None: @@ -158,6 +159,8 @@ class HelloTLV(TLV): return data def handle(self, squirrel: Any, sender: Any) -> None: + if not self.validate_data(self): + return None time_h = time.time() if not squirrel.is_active(sender): sender.id = self.source_id # The sender we are given misses an id @@ -421,16 +424,19 @@ class Packet: def validate_data(self) -> bool: """ Ensure that the packet is well-formed. - Raises a ValueError if the packet contains bad data. + Returns False if the packet contains bad data. """ if self.magic != 95: - raise ValueError("The magic code of the packet must be 95, found: {:d}".format(self.magic)) + return False + #raise ValueError("The magic code of the packet must be 95, found: {:d}".format(self.magic)) if self.version != 0: - raise ValueError("The version of the packet is not supported: {:d}".format(self.version)) + return False + #raise ValueError("The version of the packet is not supported: {:d}".format(self.version)) if not (0 <= self.body_length <= 1200): - raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1020," - "found: {:d}".format(self.body_length)) - return all(tlv.validate_data() for tlv in self.body) + return False + #raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1020," + # "found: {:d}".format(self.body_length)) + return True #all(tlv.validate_data() for tlv in self.body) @staticmethod def unmarshal(data: bytes) -> "Packet": @@ -446,15 +452,13 @@ class Packet: read_bytes = 0 while read_bytes < min(len(data) - 4, pkt.body_length): tlv_type = data[4 + read_bytes] - if not (0 <= tlv_type < len(TLV.tlv_classes())): - raise ValueError(f"TLV type is not supported: {tlv_type}") - tlv = TLV.tlv_classes()[tlv_type]() - tlv.unmarshal(data[4 + read_bytes:4 + read_bytes + pkt.body_length]) - pkt.body.append(tlv) - read_bytes += len(tlv) - - pkt.validate_data() - + if (0 <= tlv_type < len(TLV.tlv_classes())): + tlv = TLV.tlv_classes()[tlv_type]() + tlv.unmarshal(data[4 + read_bytes:4 + read_bytes + pkt.body_length]) + pkt.body.append(tlv) + read_bytes += len(tlv) + # Other TLV types are ignored + return pkt def marshal(self) -> bytes: