diff --git a/squinnondation/hazel.py b/squinnondation/hazel.py index 2e02863..8ebec48 100644 --- a/squinnondation/hazel.py +++ b/squinnondation/hazel.py @@ -707,10 +707,11 @@ class Worm(Thread): def run(self) -> None: while True: - pkt, hazelnut = self.squirrel.receive_packet() - correct = pkt.validate_data() - if not correct : - self.squirrel.add_system_message("I received a incorrect packet") + try: + pkt, hazelnut = self.squirrel.receive_packet() + pkt.validate_data() + except ValueError as error: + self.squirrel.add_system_message("An error occurred while receiving a packet: {}".format(error)) else: for tlv in pkt.body: tlv.handle(self.squirrel, hazelnut) diff --git a/squinnondation/messages.py b/squinnondation/messages.py index 6503a1c..2cf0fb4 100644 --- a/squinnondation/messages.py +++ b/squinnondation/messages.py @@ -74,7 +74,8 @@ class Pad1TLV(TLV): return self.type.to_bytes(1, sys.byteorder) def handle(self, squirrel: Any, sender: Any) -> None: - squirrel.add_system_message("I received a Pad1TLV, how disapointing") + # TODO Add some easter eggs + squirrel.add_system_message("For each byte in the packet that I received, you will die today. And eat cookies.") def __len__(self) -> int: """ @@ -99,8 +100,7 @@ class PadNTLV(TLV): def validate_data(self) -> bool: if self.mbz != int(0).to_bytes(self.length, sys.byteorder): - return False - #raise ValueError("The body of a PadN TLV is not filled with zeros.") + raise ValueError("The body of a PadN TLV is not filled with zeros.") return True def unmarshal(self, raw_data: bytes) -> None: @@ -119,8 +119,8 @@ class PadNTLV(TLV): + self.mbz[:self.length] def handle(self, squirrel: Any, sender: Any) -> None: - if self.validate_data(): - squirrel.add_system_message(f"I received {self.length} zeros") + # TODO Add some easter eggs + squirrel.add_system_message(f"I received {self.length} zeros, am I so a bad guy ? :cold_sweat:") @staticmethod def construct(length: int) -> "PadNTLV": @@ -139,9 +139,8 @@ class HelloTLV(TLV): def validate_data(self) -> bool: if self.length != 8 and self.length != 16: - return False - #raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello," - # f"found {self.length}") + raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello," + f"found {self.length}") return True def unmarshal(self, raw_data: bytes) -> None: @@ -159,8 +158,6 @@ class HelloTLV(TLV): return data def handle(self, squirrel: Any, sender: Any) -> None: - if not self.validate_data(self): - return None time_h = time.time() if not squirrel.is_active(sender): sender.id = self.source_id # The sender we are given misses an id @@ -424,19 +421,16 @@ class Packet: def validate_data(self) -> bool: """ Ensure that the packet is well-formed. - Returns False if the packet contains bad data. + Raises a ValueError if the packet contains bad data. """ if self.magic != 95: - return False - #raise ValueError("The magic code of the packet must be 95, found: {:d}".format(self.magic)) + raise ValueError("The magic code of the packet must be 95, found: {:d}".format(self.magic)) if self.version != 0: - return False - #raise ValueError("The version of the packet is not supported: {:d}".format(self.version)) + raise ValueError("The version of the packet is not supported: {:d}".format(self.version)) if not (0 <= self.body_length <= 1200): - return False - #raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1020," - # "found: {:d}".format(self.body_length)) - return True #all(tlv.validate_data() for tlv in self.body) + raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1020," + "found: {:d}".format(self.body_length)) + return all(tlv.validate_data() for tlv in self.body) @staticmethod def unmarshal(data: bytes) -> "Packet": @@ -452,13 +446,15 @@ class Packet: read_bytes = 0 while read_bytes < min(len(data) - 4, pkt.body_length): tlv_type = data[4 + read_bytes] - if (0 <= tlv_type < len(TLV.tlv_classes())): - tlv = TLV.tlv_classes()[tlv_type]() - tlv.unmarshal(data[4 + read_bytes:4 + read_bytes + pkt.body_length]) - pkt.body.append(tlv) - read_bytes += len(tlv) - # Other TLV types are ignored - + if not (0 <= tlv_type < len(TLV.tlv_classes())): + raise ValueError(f"TLV type is not supported: {tlv_type}") + tlv = TLV.tlv_classes()[tlv_type]() + tlv.unmarshal(data[4 + read_bytes:4 + read_bytes + pkt.body_length]) + pkt.body.append(tlv) + read_bytes += len(tlv) + + pkt.validate_data() + return pkt def marshal(self) -> bytes: