Adapt the byte order to the system configuration

Signed-off-by: Yohann D'ANELLO <ynerant@crans.org>
This commit is contained in:
Yohann D'ANELLO 2020-12-21 16:04:11 +01:00
parent 00e24d74ee
commit 3d444f57f0
Signed by: ynerant
GPG Key ID: 3A75C55819C8CF85
1 changed files with 43 additions and 33 deletions

View File

@ -4,6 +4,7 @@
import curses import curses
import re import re
import socket import socket
import sys
from argparse import ArgumentParser from argparse import ArgumentParser
from enum import Enum from enum import Enum
from ipaddress import IPv6Address from ipaddress import IPv6Address
@ -119,9 +120,16 @@ class TLV:
return 2 + self.length return 2 + self.length
@staticmethod @staticmethod
def tlv_classes(): def tlv_classes() -> list:
return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV] return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV]
@staticmethod
def network_order() -> str:
"""
The network byte order is always inverted as the host network byte order.
"""
return "little" if sys.byteorder == "big" else "big"
class Pad1TLV(TLV): class Pad1TLV(TLV):
""" """
@ -139,7 +147,7 @@ class Pad1TLV(TLV):
""" """
The TLV is empty. The TLV is empty.
""" """
return self.type.to_bytes(1, "big") return self.type.to_bytes(1, TLV.network_order())
def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None:
# TODO Add some easter eggs # TODO Add some easter eggs
@ -162,7 +170,7 @@ class PadNTLV(TLV):
mbz: bytes mbz: bytes
def validate_data(self) -> bool: def validate_data(self) -> bool:
if self.mbz != int(0).to_bytes(self.length, "big"): if self.mbz != int(0).to_bytes(self.length, TLV.network_order()):
raise ValueError("The body of a PadN TLV is not filled with zeros.") raise ValueError("The body of a PadN TLV is not filled with zeros.")
return True return True
@ -178,7 +186,8 @@ class PadNTLV(TLV):
""" """
Construct the byte array filled by zeros. Construct the byte array filled by zeros.
""" """
return self.type.to_bytes(1, "big") + self.length.to_bytes(1, "big") + self.mbz[:self.length] return self.type.to_bytes(1, TLV.network_order()) + self.length.to_bytes(1, TLV.network_order()) \
+ self.mbz[:self.length]
def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None:
# TODO Add some easter eggs # TODO Add some easter eggs
@ -200,14 +209,15 @@ class HelloTLV(TLV):
def unmarshal(self, raw_data: bytes) -> None: def unmarshal(self, raw_data: bytes) -> None:
self.type = raw_data[0] self.type = raw_data[0]
self.length = raw_data[1] self.length = raw_data[1]
self.source_id = int.from_bytes(raw_data[2:10], "big") self.source_id = int.from_bytes(raw_data[2:10], TLV.network_order())
if self.is_long: if self.is_long:
self.dest_id = int.from_bytes(raw_data[10:18], "big") self.dest_id = int.from_bytes(raw_data[10:18], TLV.network_order())
def marshal(self) -> bytes: def marshal(self) -> bytes:
data = self.type.to_bytes(1, "big") + self.length.to_bytes(1, "big") + self.source_id.to_bytes(8, "big") data = self.type.to_bytes(1, TLV.network_order()) + self.length.to_bytes(1, TLV.network_order()) \
+ self.source_id.to_bytes(8, TLV.network_order())
if self.dest_id: if self.dest_id:
data += self.dest_id.to_bytes(8, "big") data += self.dest_id.to_bytes(8, TLV.network_order())
return data return data
def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None:
@ -230,13 +240,13 @@ class NeighbourTLV(TLV):
self.type = raw_data[0] self.type = raw_data[0]
self.length = raw_data[1] self.length = raw_data[1]
self.ip_address = IPv6Address(raw_data[2:18]) self.ip_address = IPv6Address(raw_data[2:18])
self.port = int.from_bytes(raw_data[18:20], "big") self.port = int.from_bytes(raw_data[18:20], TLV.network_order())
def marshal(self) -> bytes: def marshal(self) -> bytes:
return self.type.to_bytes(1, "big") + \ return self.type.to_bytes(1, TLV.network_order()) + \
self.length.to_bytes(1, "big") + \ self.length.to_bytes(1, TLV.network_order()) + \
self.ip_address.packed + \ self.ip_address.packed + \
self.port.to_bytes(2, "big") self.port.to_bytes(2, TLV.network_order())
def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None:
# TODO Implement NeighbourTLV # TODO Implement NeighbourTLV
@ -254,15 +264,15 @@ class DataTLV(TLV):
def unmarshal(self, raw_data: bytes) -> None: def unmarshal(self, raw_data: bytes) -> None:
self.type = raw_data[0] self.type = raw_data[0]
self.length = raw_data[1] self.length = raw_data[1]
self.sender_id = int.from_bytes(raw_data[2:10], "big") self.sender_id = int.from_bytes(raw_data[2:10], TLV.network_order())
self.nonce = int.from_bytes(raw_data[10:14], "big") self.nonce = int.from_bytes(raw_data[10:14], TLV.network_order())
self.data = raw_data[14:self.tlv_length] self.data = raw_data[14:self.tlv_length]
def marshal(self) -> bytes: def marshal(self) -> bytes:
return self.type.to_bytes(1, "big") + \ return self.type.to_bytes(1, TLV.network_order()) + \
self.length.to_bytes(1, "big") + \ self.length.to_bytes(1, TLV.network_order()) + \
self.sender_id.to_bytes(8, "big") + \ self.sender_id.to_bytes(8, TLV.network_order()) + \
self.nonce.to_bytes(4, "big") + \ self.nonce.to_bytes(4, TLV.network_order()) + \
self.data self.data
def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None:
@ -292,14 +302,14 @@ class AckTLV(TLV):
def unmarshal(self, raw_data: bytes) -> None: def unmarshal(self, raw_data: bytes) -> None:
self.type = raw_data[0] self.type = raw_data[0]
self.length = raw_data[1] self.length = raw_data[1]
self.sender_id = int.from_bytes(raw_data[2:10], "big") self.sender_id = int.from_bytes(raw_data[2:10], TLV.network_order())
self.nonce = int.from_bytes(raw_data[10:14], "big") self.nonce = int.from_bytes(raw_data[10:14], TLV.network_order())
def marshal(self) -> bytes: def marshal(self) -> bytes:
return self.type.to_bytes(1, "big") + \ return self.type.to_bytes(1, TLV.network_order()) + \
self.length.to_bytes(1, "big") + \ self.length.to_bytes(1, TLV.network_order()) + \
self.sender_id.to_bytes(8, "big") + \ self.sender_id.to_bytes(8, TLV.network_order()) + \
self.nonce.to_bytes(4, "big") self.nonce.to_bytes(4, TLV.network_order())
def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None:
# TODO Implement AckTLV # TODO Implement AckTLV
@ -325,9 +335,9 @@ class GoAwayTLV(TLV):
self.message = raw_data[3:self.length - 1].decode("UTF-8") self.message = raw_data[3:self.length - 1].decode("UTF-8")
def marshal(self) -> bytes: def marshal(self) -> bytes:
return self.type.to_bytes(1, "big") + \ return self.type.to_bytes(1, TLV.network_order()) + \
self.length.to_bytes(1, "big") + \ self.length.to_bytes(1, TLV.network_order()) + \
self.code.value.to_bytes(1, "big") + \ self.code.value.to_bytes(1, TLV.network_order()) + \
self.message.encode("UTF-8")[:self.length - 1] self.message.encode("UTF-8")[:self.length - 1]
def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None:
@ -347,8 +357,8 @@ class WarningTLV(TLV):
self.message = raw_data[2:self.length].decode("UTF-8") self.message = raw_data[2:self.length].decode("UTF-8")
def marshal(self) -> bytes: def marshal(self) -> bytes:
return self.type.to_bytes(1, "big") + \ return self.type.to_bytes(1, TLV.network_order()) + \
self.length.to_bytes(1, "big") + \ self.length.to_bytes(1, TLV.network_order()) + \
self.message.encode("UTF-8")[:self.length] self.message.encode("UTF-8")[:self.length]
def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None: def handle(self, squirrel: "Squirrel", sender: "Hazelnut") -> None:
@ -387,7 +397,7 @@ class Packet:
pkt = Packet() pkt = Packet()
pkt.magic = data[0] pkt.magic = data[0]
pkt.version = data[1] pkt.version = data[1]
pkt.body_length = int.from_bytes(data[2:4], byteorder="big") pkt.body_length = int.from_bytes(data[2:4], byteorder=TLV.network_order())
pkt.body = [] pkt.body = []
read_bytes = 0 read_bytes = 0
while read_bytes <= min(len(data) - 4, pkt.body_length): while read_bytes <= min(len(data) - 4, pkt.body_length):
@ -407,9 +417,9 @@ class Packet:
""" """
Compute the byte array data associated to the packet. Compute the byte array data associated to the packet.
""" """
data = self.magic.to_bytes(1, "big") data = self.magic.to_bytes(1, TLV.network_order())
data += self.version.to_bytes(1, "big") data += self.version.to_bytes(1, TLV.network_order())
data += self.body_length.to_bytes(2, "big") data += self.body_length.to_bytes(2, TLV.network_order())
data += b"".join(tlv.marshal() for tlv in self.body) data += b"".join(tlv.marshal() for tlv in self.body)
return data return data