diff --git a/squinnondation/squinnondation.py b/squinnondation/squinnondation.py index 0182ba7..dec84df 100644 --- a/squinnondation/squinnondation.py +++ b/squinnondation/squinnondation.py @@ -4,6 +4,7 @@ import curses import re import socket +import sys from argparse import ArgumentParser from enum import Enum from ipaddress import IPv6Address @@ -119,9 +120,16 @@ class TLV: return 2 + self.length @staticmethod - def tlv_classes(): + def tlv_classes() -> list: return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV] + @staticmethod + def network_order() -> str: + """ + The network byte order is always inverted as the host network byte order. + """ + return "little" if sys.byteorder == "big" else "big" + class Pad1TLV(TLV): """ @@ -139,7 +147,7 @@ class Pad1TLV(TLV): """ The TLV is empty. """ - return self.type.to_bytes(1, "big") + return self.type.to_bytes(1, TLV.network_order()) def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: # TODO Add some easter eggs @@ -162,7 +170,7 @@ class PadNTLV(TLV): mbz: bytes def validate_data(self) -> bool: - if self.mbz != int(0).to_bytes(self.length, "big"): + if self.mbz != int(0).to_bytes(self.length, TLV.network_order()): raise ValueError("The body of a PadN TLV is not filled with zeros.") return True @@ -178,7 +186,8 @@ class PadNTLV(TLV): """ Construct the byte array filled by zeros. """ - return self.type.to_bytes(1, "big") + self.length.to_bytes(1, "big") + self.mbz[:self.length] + return self.type.to_bytes(1, TLV.network_order()) + self.length.to_bytes(1, TLV.network_order()) \ + + self.mbz[:self.length] def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: # TODO Add some easter eggs @@ -200,14 +209,15 @@ class HelloTLV(TLV): def unmarshal(self, raw_data: bytes) -> None: self.type = raw_data[0] self.length = raw_data[1] - self.source_id = int.from_bytes(raw_data[2:10], "big") + self.source_id = int.from_bytes(raw_data[2:10], TLV.network_order()) if self.is_long: - self.dest_id = int.from_bytes(raw_data[10:18], "big") + self.dest_id = int.from_bytes(raw_data[10:18], TLV.network_order()) def marshal(self) -> bytes: - data = self.type.to_bytes(1, "big") + self.length.to_bytes(1, "big") + self.source_id.to_bytes(8, "big") + data = self.type.to_bytes(1, TLV.network_order()) + self.length.to_bytes(1, TLV.network_order()) \ + + self.source_id.to_bytes(8, TLV.network_order()) if self.dest_id: - data += self.dest_id.to_bytes(8, "big") + data += self.dest_id.to_bytes(8, TLV.network_order()) return data def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: @@ -230,13 +240,13 @@ class NeighbourTLV(TLV): self.type = raw_data[0] self.length = raw_data[1] self.ip_address = IPv6Address(raw_data[2:18]) - self.port = int.from_bytes(raw_data[18:20], "big") + self.port = int.from_bytes(raw_data[18:20], TLV.network_order()) def marshal(self) -> bytes: - return self.type.to_bytes(1, "big") + \ - self.length.to_bytes(1, "big") + \ + return self.type.to_bytes(1, TLV.network_order()) + \ + self.length.to_bytes(1, TLV.network_order()) + \ self.ip_address.packed + \ - self.port.to_bytes(2, "big") + self.port.to_bytes(2, TLV.network_order()) def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: # TODO Implement NeighbourTLV @@ -254,15 +264,15 @@ class DataTLV(TLV): def unmarshal(self, raw_data: bytes) -> None: self.type = raw_data[0] self.length = raw_data[1] - self.sender_id = int.from_bytes(raw_data[2:10], "big") - self.nonce = int.from_bytes(raw_data[10:14], "big") + self.sender_id = int.from_bytes(raw_data[2:10], TLV.network_order()) + self.nonce = int.from_bytes(raw_data[10:14], TLV.network_order()) self.data = raw_data[14:self.tlv_length] def marshal(self) -> bytes: - return self.type.to_bytes(1, "big") + \ - self.length.to_bytes(1, "big") + \ - self.sender_id.to_bytes(8, "big") + \ - self.nonce.to_bytes(4, "big") + \ + return self.type.to_bytes(1, TLV.network_order()) + \ + self.length.to_bytes(1, TLV.network_order()) + \ + self.sender_id.to_bytes(8, TLV.network_order()) + \ + self.nonce.to_bytes(4, TLV.network_order()) + \ self.data def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: @@ -292,14 +302,14 @@ class AckTLV(TLV): def unmarshal(self, raw_data: bytes) -> None: self.type = raw_data[0] self.length = raw_data[1] - self.sender_id = int.from_bytes(raw_data[2:10], "big") - self.nonce = int.from_bytes(raw_data[10:14], "big") + self.sender_id = int.from_bytes(raw_data[2:10], TLV.network_order()) + self.nonce = int.from_bytes(raw_data[10:14], TLV.network_order()) def marshal(self) -> bytes: - return self.type.to_bytes(1, "big") + \ - self.length.to_bytes(1, "big") + \ - self.sender_id.to_bytes(8, "big") + \ - self.nonce.to_bytes(4, "big") + return self.type.to_bytes(1, TLV.network_order()) + \ + self.length.to_bytes(1, TLV.network_order()) + \ + self.sender_id.to_bytes(8, TLV.network_order()) + \ + self.nonce.to_bytes(4, TLV.network_order()) def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: # TODO Implement AckTLV @@ -325,9 +335,9 @@ class GoAwayTLV(TLV): self.message = raw_data[3:self.length - 1].decode("UTF-8") def marshal(self) -> bytes: - return self.type.to_bytes(1, "big") + \ - self.length.to_bytes(1, "big") + \ - self.code.value.to_bytes(1, "big") + \ + return self.type.to_bytes(1, TLV.network_order()) + \ + self.length.to_bytes(1, TLV.network_order()) + \ + self.code.value.to_bytes(1, TLV.network_order()) + \ self.message.encode("UTF-8")[:self.length - 1] def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: @@ -347,8 +357,8 @@ class WarningTLV(TLV): self.message = raw_data[2:self.length].decode("UTF-8") def marshal(self) -> bytes: - return self.type.to_bytes(1, "big") + \ - self.length.to_bytes(1, "big") + \ + return self.type.to_bytes(1, TLV.network_order()) + \ + self.length.to_bytes(1, TLV.network_order()) + \ self.message.encode("UTF-8")[:self.length] def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: @@ -387,7 +397,7 @@ class Packet: pkt = Packet() pkt.magic = data[0] pkt.version = data[1] - pkt.body_length = int.from_bytes(data[2:4], byteorder="big") + pkt.body_length = int.from_bytes(data[2:4], byteorder=TLV.network_order()) pkt.body = [] read_bytes = 0 while read_bytes <= min(len(data) - 4, pkt.body_length): @@ -407,9 +417,9 @@ class Packet: """ Compute the byte array data associated to the packet. """ - data = self.magic.to_bytes(1, "big") - data += self.version.to_bytes(1, "big") - data += self.body_length.to_bytes(2, "big") + data = self.magic.to_bytes(1, TLV.network_order()) + data += self.version.to_bytes(1, TLV.network_order()) + data += self.body_length.to_bytes(2, TLV.network_order()) data += b"".join(tlv.marshal() for tlv in self.body) return data