Revert "Removing the error raising in validate_data"

This reverts commit bb62669722.
This commit is contained in:
eichhornchen 2021-01-04 11:02:58 +01:00
parent bb62669722
commit 122932e289
2 changed files with 27 additions and 30 deletions

View File

@ -707,10 +707,11 @@ class Worm(Thread):
def run(self) -> None: def run(self) -> None:
while True: while True:
try:
pkt, hazelnut = self.squirrel.receive_packet() pkt, hazelnut = self.squirrel.receive_packet()
correct = pkt.validate_data() pkt.validate_data()
if not correct : except ValueError as error:
self.squirrel.add_system_message("I received a incorrect packet") self.squirrel.add_system_message("An error occurred while receiving a packet: {}".format(error))
else: else:
for tlv in pkt.body: for tlv in pkt.body:
tlv.handle(self.squirrel, hazelnut) tlv.handle(self.squirrel, hazelnut)

View File

@ -74,7 +74,8 @@ class Pad1TLV(TLV):
return self.type.to_bytes(1, sys.byteorder) return self.type.to_bytes(1, sys.byteorder)
def handle(self, squirrel: Any, sender: Any) -> None: def handle(self, squirrel: Any, sender: Any) -> None:
squirrel.add_system_message("I received a Pad1TLV, how disapointing") # TODO Add some easter eggs
squirrel.add_system_message("For each byte in the packet that I received, you will die today. And eat cookies.")
def __len__(self) -> int: def __len__(self) -> int:
""" """
@ -99,8 +100,7 @@ class PadNTLV(TLV):
def validate_data(self) -> bool: def validate_data(self) -> bool:
if self.mbz != int(0).to_bytes(self.length, sys.byteorder): if self.mbz != int(0).to_bytes(self.length, sys.byteorder):
return False raise ValueError("The body of a PadN TLV is not filled with zeros.")
#raise ValueError("The body of a PadN TLV is not filled with zeros.")
return True return True
def unmarshal(self, raw_data: bytes) -> None: def unmarshal(self, raw_data: bytes) -> None:
@ -119,8 +119,8 @@ class PadNTLV(TLV):
+ self.mbz[:self.length] + self.mbz[:self.length]
def handle(self, squirrel: Any, sender: Any) -> None: def handle(self, squirrel: Any, sender: Any) -> None:
if self.validate_data(): # TODO Add some easter eggs
squirrel.add_system_message(f"I received {self.length} zeros") squirrel.add_system_message(f"I received {self.length} zeros, am I so a bad guy ? :cold_sweat:")
@staticmethod @staticmethod
def construct(length: int) -> "PadNTLV": def construct(length: int) -> "PadNTLV":
@ -139,9 +139,8 @@ class HelloTLV(TLV):
def validate_data(self) -> bool: def validate_data(self) -> bool:
if self.length != 8 and self.length != 16: if self.length != 8 and self.length != 16:
return False raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello,"
#raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello," f"found {self.length}")
# f"found {self.length}")
return True return True
def unmarshal(self, raw_data: bytes) -> None: def unmarshal(self, raw_data: bytes) -> None:
@ -159,8 +158,6 @@ class HelloTLV(TLV):
return data return data
def handle(self, squirrel: Any, sender: Any) -> None: def handle(self, squirrel: Any, sender: Any) -> None:
if not self.validate_data(self):
return None
time_h = time.time() time_h = time.time()
if not squirrel.is_active(sender): if not squirrel.is_active(sender):
sender.id = self.source_id # The sender we are given misses an id sender.id = self.source_id # The sender we are given misses an id
@ -424,19 +421,16 @@ class Packet:
def validate_data(self) -> bool: def validate_data(self) -> bool:
""" """
Ensure that the packet is well-formed. Ensure that the packet is well-formed.
Returns False if the packet contains bad data. Raises a ValueError if the packet contains bad data.
""" """
if self.magic != 95: if self.magic != 95:
return False raise ValueError("The magic code of the packet must be 95, found: {:d}".format(self.magic))
#raise ValueError("The magic code of the packet must be 95, found: {:d}".format(self.magic))
if self.version != 0: if self.version != 0:
return False raise ValueError("The version of the packet is not supported: {:d}".format(self.version))
#raise ValueError("The version of the packet is not supported: {:d}".format(self.version))
if not (0 <= self.body_length <= 1200): if not (0 <= self.body_length <= 1200):
return False raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1020,"
#raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1020," "found: {:d}".format(self.body_length))
# "found: {:d}".format(self.body_length)) return all(tlv.validate_data() for tlv in self.body)
return True #all(tlv.validate_data() for tlv in self.body)
@staticmethod @staticmethod
def unmarshal(data: bytes) -> "Packet": def unmarshal(data: bytes) -> "Packet":
@ -452,12 +446,14 @@ class Packet:
read_bytes = 0 read_bytes = 0
while read_bytes < min(len(data) - 4, pkt.body_length): while read_bytes < min(len(data) - 4, pkt.body_length):
tlv_type = data[4 + read_bytes] tlv_type = data[4 + read_bytes]
if (0 <= tlv_type < len(TLV.tlv_classes())): if not (0 <= tlv_type < len(TLV.tlv_classes())):
raise ValueError(f"TLV type is not supported: {tlv_type}")
tlv = TLV.tlv_classes()[tlv_type]() tlv = TLV.tlv_classes()[tlv_type]()
tlv.unmarshal(data[4 + read_bytes:4 + read_bytes + pkt.body_length]) tlv.unmarshal(data[4 + read_bytes:4 + read_bytes + pkt.body_length])
pkt.body.append(tlv) pkt.body.append(tlv)
read_bytes += len(tlv) read_bytes += len(tlv)
# Other TLV types are ignored
pkt.validate_data()
return pkt return pkt