Convert host to network byte order if necessary
Signed-off-by: Yohann D'ANELLO <ynerant@crans.org>
This commit is contained in:
		@@ -4,6 +4,7 @@
 | 
			
		||||
from typing import Any, List, Optional
 | 
			
		||||
from ipaddress import IPv6Address
 | 
			
		||||
from enum import Enum
 | 
			
		||||
import socket
 | 
			
		||||
import sys
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -53,13 +54,6 @@ class TLV:
 | 
			
		||||
    def tlv_classes() -> list:
 | 
			
		||||
        return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV]
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def network_order() -> str:
 | 
			
		||||
        """
 | 
			
		||||
        The network byte order is always inverted as the host network byte order.
 | 
			
		||||
        """
 | 
			
		||||
        return "little" if sys.byteorder == "big" else "big"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Pad1TLV(TLV):
 | 
			
		||||
    """
 | 
			
		||||
@@ -77,7 +71,7 @@ class Pad1TLV(TLV):
 | 
			
		||||
        """
 | 
			
		||||
        The TLV is empty.
 | 
			
		||||
        """
 | 
			
		||||
        return self.type.to_bytes(1, TLV.network_order())
 | 
			
		||||
        return self.type.to_bytes(1, sys.byteorder)
 | 
			
		||||
 | 
			
		||||
    def handle(self, squirrel: Any, sender: Any) -> None:
 | 
			
		||||
        # TODO Add some easter eggs
 | 
			
		||||
@@ -100,7 +94,7 @@ class PadNTLV(TLV):
 | 
			
		||||
    mbz: bytes
 | 
			
		||||
 | 
			
		||||
    def validate_data(self) -> bool:
 | 
			
		||||
        if self.mbz != int(0).to_bytes(self.length, TLV.network_order()):
 | 
			
		||||
        if self.mbz != int(0).to_bytes(self.length, sys.byteorder):
 | 
			
		||||
            raise ValueError("The body of a PadN TLV is not filled with zeros.")
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
@@ -116,7 +110,7 @@ class PadNTLV(TLV):
 | 
			
		||||
        """
 | 
			
		||||
        Construct the byte array filled by zeros.
 | 
			
		||||
        """
 | 
			
		||||
        return self.type.to_bytes(1, TLV.network_order()) + self.length.to_bytes(1, TLV.network_order()) \
 | 
			
		||||
        return self.type.to_bytes(1, sys.byteorder) + self.length.to_bytes(1, sys.byteorder) \
 | 
			
		||||
            + self.mbz[:self.length]
 | 
			
		||||
 | 
			
		||||
    def handle(self, squirrel: Any, sender: Any) -> None:
 | 
			
		||||
@@ -139,15 +133,15 @@ class HelloTLV(TLV):
 | 
			
		||||
    def unmarshal(self, raw_data: bytes) -> None:
 | 
			
		||||
        self.type = raw_data[0]
 | 
			
		||||
        self.length = raw_data[1]
 | 
			
		||||
        self.source_id = int.from_bytes(raw_data[2:10], TLV.network_order())
 | 
			
		||||
        self.source_id = int.from_bytes(raw_data[2:10], sys.byteorder)
 | 
			
		||||
        if self.is_long:
 | 
			
		||||
            self.dest_id = int.from_bytes(raw_data[10:18], TLV.network_order())
 | 
			
		||||
            self.dest_id = int.from_bytes(raw_data[10:18], sys.byteorder)
 | 
			
		||||
 | 
			
		||||
    def marshal(self) -> bytes:
 | 
			
		||||
        data = self.type.to_bytes(1, TLV.network_order()) + self.length.to_bytes(1, TLV.network_order()) \
 | 
			
		||||
            + self.source_id.to_bytes(8, TLV.network_order())
 | 
			
		||||
        data = self.type.to_bytes(1, sys.byteorder) + self.length.to_bytes(1, sys.byteorder) \
 | 
			
		||||
            + self.source_id.to_bytes(8, sys.byteorder)
 | 
			
		||||
        if self.dest_id:
 | 
			
		||||
            data += self.dest_id.to_bytes(8, TLV.network_order())
 | 
			
		||||
            data += self.dest_id.to_bytes(8, sys.byteorder)
 | 
			
		||||
        return data
 | 
			
		||||
 | 
			
		||||
    def handle(self, squirrel: Any, sender: Any) -> None:
 | 
			
		||||
@@ -170,13 +164,13 @@ class NeighbourTLV(TLV):
 | 
			
		||||
        self.type = raw_data[0]
 | 
			
		||||
        self.length = raw_data[1]
 | 
			
		||||
        self.ip_address = IPv6Address(raw_data[2:18])
 | 
			
		||||
        self.port = int.from_bytes(raw_data[18:20], TLV.network_order())
 | 
			
		||||
        self.port = int.from_bytes(raw_data[18:20], sys.byteorder)
 | 
			
		||||
 | 
			
		||||
    def marshal(self) -> bytes:
 | 
			
		||||
        return self.type.to_bytes(1, TLV.network_order()) + \
 | 
			
		||||
            self.length.to_bytes(1, TLV.network_order()) + \
 | 
			
		||||
        return self.type.to_bytes(1, sys.byteorder) + \
 | 
			
		||||
            self.length.to_bytes(1, sys.byteorder) + \
 | 
			
		||||
            self.ip_address.packed + \
 | 
			
		||||
            self.port.to_bytes(2, TLV.network_order())
 | 
			
		||||
            self.port.to_bytes(2, sys.byteorder)
 | 
			
		||||
 | 
			
		||||
    def handle(self, squirrel: Any, sender: Any) -> None:
 | 
			
		||||
        # TODO Implement NeighbourTLV
 | 
			
		||||
@@ -194,15 +188,15 @@ class DataTLV(TLV):
 | 
			
		||||
    def unmarshal(self, raw_data: bytes) -> None:
 | 
			
		||||
        self.type = raw_data[0]
 | 
			
		||||
        self.length = raw_data[1]
 | 
			
		||||
        self.sender_id = int.from_bytes(raw_data[2:10], TLV.network_order())
 | 
			
		||||
        self.nonce = int.from_bytes(raw_data[10:14], TLV.network_order())
 | 
			
		||||
        self.sender_id = int.from_bytes(raw_data[2:10], sys.byteorder)
 | 
			
		||||
        self.nonce = socket.ntohl(int.from_bytes(raw_data[10:14], sys.byteorder))
 | 
			
		||||
        self.data = raw_data[14:self.tlv_length]
 | 
			
		||||
 | 
			
		||||
    def marshal(self) -> bytes:
 | 
			
		||||
        return self.type.to_bytes(1, TLV.network_order()) + \
 | 
			
		||||
            self.length.to_bytes(1, TLV.network_order()) + \
 | 
			
		||||
            self.sender_id.to_bytes(8, TLV.network_order()) + \
 | 
			
		||||
            self.nonce.to_bytes(4, TLV.network_order()) + \
 | 
			
		||||
        return self.type.to_bytes(1, sys.byteorder) + \
 | 
			
		||||
            self.length.to_bytes(1, sys.byteorder) + \
 | 
			
		||||
            self.sender_id.to_bytes(8, sys.byteorder) + \
 | 
			
		||||
            socket.htonl(self.nonce).to_bytes(4, sys.byteorder) + \
 | 
			
		||||
            self.data
 | 
			
		||||
 | 
			
		||||
    def handle(self, squirrel: Any, sender: Any) -> None:
 | 
			
		||||
@@ -232,14 +226,14 @@ class AckTLV(TLV):
 | 
			
		||||
    def unmarshal(self, raw_data: bytes) -> None:
 | 
			
		||||
        self.type = raw_data[0]
 | 
			
		||||
        self.length = raw_data[1]
 | 
			
		||||
        self.sender_id = int.from_bytes(raw_data[2:10], TLV.network_order())
 | 
			
		||||
        self.nonce = int.from_bytes(raw_data[10:14], TLV.network_order())
 | 
			
		||||
        self.sender_id = int.from_bytes(raw_data[2:10], sys.byteorder)
 | 
			
		||||
        self.nonce = socket.ntohl(int.from_bytes(raw_data[10:14], sys.byteorder))
 | 
			
		||||
 | 
			
		||||
    def marshal(self) -> bytes:
 | 
			
		||||
        return self.type.to_bytes(1, TLV.network_order()) + \
 | 
			
		||||
            self.length.to_bytes(1, TLV.network_order()) + \
 | 
			
		||||
            self.sender_id.to_bytes(8, TLV.network_order()) + \
 | 
			
		||||
            self.nonce.to_bytes(4, TLV.network_order())
 | 
			
		||||
        return self.type.to_bytes(1, sys.byteorder) + \
 | 
			
		||||
            self.length.to_bytes(1, sys.byteorder) + \
 | 
			
		||||
            self.sender_id.to_bytes(8, sys.byteorder) + \
 | 
			
		||||
            socket.htonl(self.nonce).to_bytes(4, sys.byteorder)
 | 
			
		||||
 | 
			
		||||
    def handle(self, squirrel: Any, sender: Any) -> None:
 | 
			
		||||
        # TODO Implement AckTLV
 | 
			
		||||
@@ -265,9 +259,9 @@ class GoAwayTLV(TLV):
 | 
			
		||||
        self.message = raw_data[3:self.length - 1].decode("UTF-8")
 | 
			
		||||
 | 
			
		||||
    def marshal(self) -> bytes:
 | 
			
		||||
        return self.type.to_bytes(1, TLV.network_order()) + \
 | 
			
		||||
            self.length.to_bytes(1, TLV.network_order()) + \
 | 
			
		||||
            self.code.value.to_bytes(1, TLV.network_order()) + \
 | 
			
		||||
        return self.type.to_bytes(1, sys.byteorder) + \
 | 
			
		||||
            self.length.to_bytes(1, sys.byteorder) + \
 | 
			
		||||
            self.code.value.to_bytes(1, sys.byteorder) + \
 | 
			
		||||
            self.message.encode("UTF-8")[:self.length - 1]
 | 
			
		||||
 | 
			
		||||
    def handle(self, squirrel: Any, sender: Any) -> None:
 | 
			
		||||
@@ -287,8 +281,8 @@ class WarningTLV(TLV):
 | 
			
		||||
        self.message = raw_data[2:self.length].decode("UTF-8")
 | 
			
		||||
 | 
			
		||||
    def marshal(self) -> bytes:
 | 
			
		||||
        return self.type.to_bytes(1, TLV.network_order()) + \
 | 
			
		||||
            self.length.to_bytes(1, TLV.network_order()) + \
 | 
			
		||||
        return self.type.to_bytes(1, sys.byteorder) + \
 | 
			
		||||
            self.length.to_bytes(1, sys.byteorder) + \
 | 
			
		||||
            self.message.encode("UTF-8")[:self.length]
 | 
			
		||||
 | 
			
		||||
    def handle(self, squirrel: Any, sender: Any) -> None:
 | 
			
		||||
@@ -327,7 +321,7 @@ class Packet:
 | 
			
		||||
        pkt = Packet()
 | 
			
		||||
        pkt.magic = data[0]
 | 
			
		||||
        pkt.version = data[1]
 | 
			
		||||
        pkt.body_length = int.from_bytes(data[2:4], byteorder=TLV.network_order())
 | 
			
		||||
        pkt.body_length = socket.ntohs(int.from_bytes(data[2:4], sys.byteorder))
 | 
			
		||||
        pkt.body = []
 | 
			
		||||
        read_bytes = 0
 | 
			
		||||
        while read_bytes <= min(len(data) - 4, pkt.body_length):
 | 
			
		||||
@@ -347,9 +341,9 @@ class Packet:
 | 
			
		||||
        """
 | 
			
		||||
        Compute the byte array data associated to the packet.
 | 
			
		||||
        """
 | 
			
		||||
        data = self.magic.to_bytes(1, TLV.network_order())
 | 
			
		||||
        data += self.version.to_bytes(1, TLV.network_order())
 | 
			
		||||
        data += self.body_length.to_bytes(2, TLV.network_order())
 | 
			
		||||
        data = self.magic.to_bytes(1, sys.byteorder)
 | 
			
		||||
        data += self.version.to_bytes(1, sys.byteorder)
 | 
			
		||||
        data += socket.htons(self.body_length).to_bytes(2, sys.byteorder)
 | 
			
		||||
        data += b"".join(tlv.marshal() for tlv in self.body)
 | 
			
		||||
        return data
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user