Removing the error raising in validate_data
This commit is contained in:
		@@ -707,11 +707,10 @@ class Worm(Thread):
 | 
			
		||||
 | 
			
		||||
    def run(self) -> None:
 | 
			
		||||
        while True:
 | 
			
		||||
            try:
 | 
			
		||||
                pkt, hazelnut = self.squirrel.receive_packet()
 | 
			
		||||
                pkt.validate_data()
 | 
			
		||||
            except ValueError as error:
 | 
			
		||||
                self.squirrel.add_system_message("An error occurred while receiving a packet: {}".format(error))
 | 
			
		||||
            pkt, hazelnut = self.squirrel.receive_packet()
 | 
			
		||||
            correct = pkt.validate_data()
 | 
			
		||||
            if not correct :
 | 
			
		||||
                self.squirrel.add_system_message("I received a incorrect packet")
 | 
			
		||||
            else:
 | 
			
		||||
                for tlv in pkt.body:
 | 
			
		||||
                    tlv.handle(self.squirrel, hazelnut)
 | 
			
		||||
 
 | 
			
		||||
@@ -74,8 +74,7 @@ class Pad1TLV(TLV):
 | 
			
		||||
        return self.type.to_bytes(1, sys.byteorder)
 | 
			
		||||
 | 
			
		||||
    def handle(self, squirrel: Any, sender: Any) -> None:
 | 
			
		||||
        # TODO Add some easter eggs
 | 
			
		||||
        squirrel.add_system_message("For each byte in the packet that I received, you will die today. And eat cookies.")
 | 
			
		||||
        squirrel.add_system_message("I received a Pad1TLV, how disapointing")
 | 
			
		||||
 | 
			
		||||
    def __len__(self) -> int:
 | 
			
		||||
        """
 | 
			
		||||
@@ -100,7 +99,8 @@ class PadNTLV(TLV):
 | 
			
		||||
 | 
			
		||||
    def validate_data(self) -> bool:
 | 
			
		||||
        if self.mbz != int(0).to_bytes(self.length, sys.byteorder):
 | 
			
		||||
            raise ValueError("The body of a PadN TLV is not filled with zeros.")
 | 
			
		||||
            return False
 | 
			
		||||
            #raise ValueError("The body of a PadN TLV is not filled with zeros.")
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    def unmarshal(self, raw_data: bytes) -> None:
 | 
			
		||||
@@ -119,8 +119,8 @@ class PadNTLV(TLV):
 | 
			
		||||
            + self.mbz[:self.length]
 | 
			
		||||
 | 
			
		||||
    def handle(self, squirrel: Any, sender: Any) -> None:
 | 
			
		||||
        # TODO Add some easter eggs
 | 
			
		||||
        squirrel.add_system_message(f"I received {self.length} zeros, am I so a bad guy ? :cold_sweat:")
 | 
			
		||||
        if self.validate_data():
 | 
			
		||||
            squirrel.add_system_message(f"I received {self.length} zeros")
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def construct(length: int) -> "PadNTLV":
 | 
			
		||||
@@ -139,8 +139,9 @@ class HelloTLV(TLV):
 | 
			
		||||
 | 
			
		||||
    def validate_data(self) -> bool:
 | 
			
		||||
        if self.length != 8 and self.length != 16:
 | 
			
		||||
            raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello,"
 | 
			
		||||
                             f"found {self.length}")
 | 
			
		||||
            return False
 | 
			
		||||
            #raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello,"
 | 
			
		||||
            #                 f"found {self.length}")
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    def unmarshal(self, raw_data: bytes) -> None:
 | 
			
		||||
@@ -158,6 +159,8 @@ class HelloTLV(TLV):
 | 
			
		||||
        return data
 | 
			
		||||
 | 
			
		||||
    def handle(self, squirrel: Any, sender: Any) -> None:
 | 
			
		||||
        if not self.validate_data(self):
 | 
			
		||||
            return None
 | 
			
		||||
        time_h = time.time()
 | 
			
		||||
        if not squirrel.is_active(sender):
 | 
			
		||||
            sender.id = self.source_id  # The sender we are given misses an id
 | 
			
		||||
@@ -421,16 +424,19 @@ class Packet:
 | 
			
		||||
    def validate_data(self) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        Ensure that the packet is well-formed.
 | 
			
		||||
        Raises a ValueError if the packet contains bad data.
 | 
			
		||||
        Returns False if the packet contains bad data.
 | 
			
		||||
        """
 | 
			
		||||
        if self.magic != 95:
 | 
			
		||||
            raise ValueError("The magic code of the packet must be 95, found: {:d}".format(self.magic))
 | 
			
		||||
            return False
 | 
			
		||||
            #raise ValueError("The magic code of the packet must be 95, found: {:d}".format(self.magic))
 | 
			
		||||
        if self.version != 0:
 | 
			
		||||
            raise ValueError("The version of the packet is not supported: {:d}".format(self.version))
 | 
			
		||||
            return False
 | 
			
		||||
            #raise ValueError("The version of the packet is not supported: {:d}".format(self.version))
 | 
			
		||||
        if not (0 <= self.body_length <= 1200):
 | 
			
		||||
            raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1020,"
 | 
			
		||||
                             "found: {:d}".format(self.body_length))
 | 
			
		||||
        return all(tlv.validate_data() for tlv in self.body)
 | 
			
		||||
            return False
 | 
			
		||||
            #raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1020,"
 | 
			
		||||
            #                 "found: {:d}".format(self.body_length))
 | 
			
		||||
        return True #all(tlv.validate_data() for tlv in self.body)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def unmarshal(data: bytes) -> "Packet":
 | 
			
		||||
@@ -446,15 +452,13 @@ class Packet:
 | 
			
		||||
        read_bytes = 0
 | 
			
		||||
        while read_bytes < min(len(data) - 4, pkt.body_length):
 | 
			
		||||
            tlv_type = data[4 + read_bytes]
 | 
			
		||||
            if not (0 <= tlv_type < len(TLV.tlv_classes())):
 | 
			
		||||
                raise ValueError(f"TLV type is not supported: {tlv_type}")
 | 
			
		||||
            tlv = TLV.tlv_classes()[tlv_type]()
 | 
			
		||||
            tlv.unmarshal(data[4 + read_bytes:4 + read_bytes + pkt.body_length])
 | 
			
		||||
            pkt.body.append(tlv)
 | 
			
		||||
            read_bytes += len(tlv)
 | 
			
		||||
 | 
			
		||||
        pkt.validate_data()
 | 
			
		||||
 | 
			
		||||
            if (0 <= tlv_type < len(TLV.tlv_classes())):
 | 
			
		||||
                tlv = TLV.tlv_classes()[tlv_type]()
 | 
			
		||||
                tlv.unmarshal(data[4 + read_bytes:4 + read_bytes + pkt.body_length])
 | 
			
		||||
                pkt.body.append(tlv)
 | 
			
		||||
                read_bytes += len(tlv)
 | 
			
		||||
            # Other TLV types are ignored
 | 
			
		||||
        
 | 
			
		||||
        return pkt
 | 
			
		||||
 | 
			
		||||
    def marshal(self) -> bytes:
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user