Convert host to network byte order if necessary

Signed-off-by: Yohann D'ANELLO <ynerant@crans.org>
This commit is contained in:
Yohann D'ANELLO 2020-12-21 16:04:13 +01:00
parent 833c56755a
commit bda4860aca
Signed by: ynerant
GPG Key ID: 3A75C55819C8CF85
1 changed files with 34 additions and 40 deletions

View File

@ -4,6 +4,7 @@
from typing import Any, List, Optional from typing import Any, List, Optional
from ipaddress import IPv6Address from ipaddress import IPv6Address
from enum import Enum from enum import Enum
import socket
import sys import sys
@ -53,13 +54,6 @@ class TLV:
def tlv_classes() -> list: def tlv_classes() -> list:
return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV] return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV]
@staticmethod
def network_order() -> str:
"""
The network byte order is always inverted as the host network byte order.
"""
return "little" if sys.byteorder == "big" else "big"
class Pad1TLV(TLV): class Pad1TLV(TLV):
""" """
@ -77,7 +71,7 @@ class Pad1TLV(TLV):
""" """
The TLV is empty. The TLV is empty.
""" """
return self.type.to_bytes(1, TLV.network_order()) return self.type.to_bytes(1, sys.byteorder)
def handle(self, squirrel: Any, sender: Any) -> None: def handle(self, squirrel: Any, sender: Any) -> None:
# TODO Add some easter eggs # TODO Add some easter eggs
@ -100,7 +94,7 @@ class PadNTLV(TLV):
mbz: bytes mbz: bytes
def validate_data(self) -> bool: def validate_data(self) -> bool:
if self.mbz != int(0).to_bytes(self.length, TLV.network_order()): if self.mbz != int(0).to_bytes(self.length, sys.byteorder):
raise ValueError("The body of a PadN TLV is not filled with zeros.") raise ValueError("The body of a PadN TLV is not filled with zeros.")
return True return True
@ -116,7 +110,7 @@ class PadNTLV(TLV):
""" """
Construct the byte array filled by zeros. Construct the byte array filled by zeros.
""" """
return self.type.to_bytes(1, TLV.network_order()) + self.length.to_bytes(1, TLV.network_order()) \ return self.type.to_bytes(1, sys.byteorder) + self.length.to_bytes(1, sys.byteorder) \
+ self.mbz[:self.length] + self.mbz[:self.length]
def handle(self, squirrel: Any, sender: Any) -> None: def handle(self, squirrel: Any, sender: Any) -> None:
@ -139,15 +133,15 @@ class HelloTLV(TLV):
def unmarshal(self, raw_data: bytes) -> None: def unmarshal(self, raw_data: bytes) -> None:
self.type = raw_data[0] self.type = raw_data[0]
self.length = raw_data[1] self.length = raw_data[1]
self.source_id = int.from_bytes(raw_data[2:10], TLV.network_order()) self.source_id = int.from_bytes(raw_data[2:10], sys.byteorder)
if self.is_long: if self.is_long:
self.dest_id = int.from_bytes(raw_data[10:18], TLV.network_order()) self.dest_id = int.from_bytes(raw_data[10:18], sys.byteorder)
def marshal(self) -> bytes: def marshal(self) -> bytes:
data = self.type.to_bytes(1, TLV.network_order()) + self.length.to_bytes(1, TLV.network_order()) \ data = self.type.to_bytes(1, sys.byteorder) + self.length.to_bytes(1, sys.byteorder) \
+ self.source_id.to_bytes(8, TLV.network_order()) + self.source_id.to_bytes(8, sys.byteorder)
if self.dest_id: if self.dest_id:
data += self.dest_id.to_bytes(8, TLV.network_order()) data += self.dest_id.to_bytes(8, sys.byteorder)
return data return data
def handle(self, squirrel: Any, sender: Any) -> None: def handle(self, squirrel: Any, sender: Any) -> None:
@ -170,13 +164,13 @@ class NeighbourTLV(TLV):
self.type = raw_data[0] self.type = raw_data[0]
self.length = raw_data[1] self.length = raw_data[1]
self.ip_address = IPv6Address(raw_data[2:18]) self.ip_address = IPv6Address(raw_data[2:18])
self.port = int.from_bytes(raw_data[18:20], TLV.network_order()) self.port = int.from_bytes(raw_data[18:20], sys.byteorder)
def marshal(self) -> bytes: def marshal(self) -> bytes:
return self.type.to_bytes(1, TLV.network_order()) + \ return self.type.to_bytes(1, sys.byteorder) + \
self.length.to_bytes(1, TLV.network_order()) + \ self.length.to_bytes(1, sys.byteorder) + \
self.ip_address.packed + \ self.ip_address.packed + \
self.port.to_bytes(2, TLV.network_order()) self.port.to_bytes(2, sys.byteorder)
def handle(self, squirrel: Any, sender: Any) -> None: def handle(self, squirrel: Any, sender: Any) -> None:
# TODO Implement NeighbourTLV # TODO Implement NeighbourTLV
@ -194,15 +188,15 @@ class DataTLV(TLV):
def unmarshal(self, raw_data: bytes) -> None: def unmarshal(self, raw_data: bytes) -> None:
self.type = raw_data[0] self.type = raw_data[0]
self.length = raw_data[1] self.length = raw_data[1]
self.sender_id = int.from_bytes(raw_data[2:10], TLV.network_order()) self.sender_id = int.from_bytes(raw_data[2:10], sys.byteorder)
self.nonce = int.from_bytes(raw_data[10:14], TLV.network_order()) self.nonce = socket.ntohl(int.from_bytes(raw_data[10:14], sys.byteorder))
self.data = raw_data[14:self.tlv_length] self.data = raw_data[14:self.tlv_length]
def marshal(self) -> bytes: def marshal(self) -> bytes:
return self.type.to_bytes(1, TLV.network_order()) + \ return self.type.to_bytes(1, sys.byteorder) + \
self.length.to_bytes(1, TLV.network_order()) + \ self.length.to_bytes(1, sys.byteorder) + \
self.sender_id.to_bytes(8, TLV.network_order()) + \ self.sender_id.to_bytes(8, sys.byteorder) + \
self.nonce.to_bytes(4, TLV.network_order()) + \ socket.htonl(self.nonce).to_bytes(4, sys.byteorder) + \
self.data self.data
def handle(self, squirrel: Any, sender: Any) -> None: def handle(self, squirrel: Any, sender: Any) -> None:
@ -232,14 +226,14 @@ class AckTLV(TLV):
def unmarshal(self, raw_data: bytes) -> None: def unmarshal(self, raw_data: bytes) -> None:
self.type = raw_data[0] self.type = raw_data[0]
self.length = raw_data[1] self.length = raw_data[1]
self.sender_id = int.from_bytes(raw_data[2:10], TLV.network_order()) self.sender_id = int.from_bytes(raw_data[2:10], sys.byteorder)
self.nonce = int.from_bytes(raw_data[10:14], TLV.network_order()) self.nonce = socket.ntohl(int.from_bytes(raw_data[10:14], sys.byteorder))
def marshal(self) -> bytes: def marshal(self) -> bytes:
return self.type.to_bytes(1, TLV.network_order()) + \ return self.type.to_bytes(1, sys.byteorder) + \
self.length.to_bytes(1, TLV.network_order()) + \ self.length.to_bytes(1, sys.byteorder) + \
self.sender_id.to_bytes(8, TLV.network_order()) + \ self.sender_id.to_bytes(8, sys.byteorder) + \
self.nonce.to_bytes(4, TLV.network_order()) socket.htonl(self.nonce).to_bytes(4, sys.byteorder)
def handle(self, squirrel: Any, sender: Any) -> None: def handle(self, squirrel: Any, sender: Any) -> None:
# TODO Implement AckTLV # TODO Implement AckTLV
@ -265,9 +259,9 @@ class GoAwayTLV(TLV):
self.message = raw_data[3:self.length - 1].decode("UTF-8") self.message = raw_data[3:self.length - 1].decode("UTF-8")
def marshal(self) -> bytes: def marshal(self) -> bytes:
return self.type.to_bytes(1, TLV.network_order()) + \ return self.type.to_bytes(1, sys.byteorder) + \
self.length.to_bytes(1, TLV.network_order()) + \ self.length.to_bytes(1, sys.byteorder) + \
self.code.value.to_bytes(1, TLV.network_order()) + \ self.code.value.to_bytes(1, sys.byteorder) + \
self.message.encode("UTF-8")[:self.length - 1] self.message.encode("UTF-8")[:self.length - 1]
def handle(self, squirrel: Any, sender: Any) -> None: def handle(self, squirrel: Any, sender: Any) -> None:
@ -287,8 +281,8 @@ class WarningTLV(TLV):
self.message = raw_data[2:self.length].decode("UTF-8") self.message = raw_data[2:self.length].decode("UTF-8")
def marshal(self) -> bytes: def marshal(self) -> bytes:
return self.type.to_bytes(1, TLV.network_order()) + \ return self.type.to_bytes(1, sys.byteorder) + \
self.length.to_bytes(1, TLV.network_order()) + \ self.length.to_bytes(1, sys.byteorder) + \
self.message.encode("UTF-8")[:self.length] self.message.encode("UTF-8")[:self.length]
def handle(self, squirrel: Any, sender: Any) -> None: def handle(self, squirrel: Any, sender: Any) -> None:
@ -327,7 +321,7 @@ class Packet:
pkt = Packet() pkt = Packet()
pkt.magic = data[0] pkt.magic = data[0]
pkt.version = data[1] pkt.version = data[1]
pkt.body_length = int.from_bytes(data[2:4], byteorder=TLV.network_order()) pkt.body_length = socket.ntohs(int.from_bytes(data[2:4], sys.byteorder))
pkt.body = [] pkt.body = []
read_bytes = 0 read_bytes = 0
while read_bytes <= min(len(data) - 4, pkt.body_length): while read_bytes <= min(len(data) - 4, pkt.body_length):
@ -347,9 +341,9 @@ class Packet:
""" """
Compute the byte array data associated to the packet. Compute the byte array data associated to the packet.
""" """
data = self.magic.to_bytes(1, TLV.network_order()) data = self.magic.to_bytes(1, sys.byteorder)
data += self.version.to_bytes(1, TLV.network_order()) data += self.version.to_bytes(1, sys.byteorder)
data += self.body_length.to_bytes(2, TLV.network_order()) data += socket.htons(self.body_length).to_bytes(2, sys.byteorder)
data += b"".join(tlv.marshal() for tlv in self.body) data += b"".join(tlv.marshal() for tlv in self.body)
return data return data