From bda4860aca7d1fe3eacd7aad9459201ce59646ce Mon Sep 17 00:00:00 2001 From: Yohann D'ANELLO Date: Mon, 21 Dec 2020 16:04:13 +0100 Subject: [PATCH] Convert host to network byte order if necessary Signed-off-by: Yohann D'ANELLO --- squinnondation/messages.py | 74 ++++++++++++++++++-------------------- 1 file changed, 34 insertions(+), 40 deletions(-) diff --git a/squinnondation/messages.py b/squinnondation/messages.py index 67f1f52..6c50f29 100644 --- a/squinnondation/messages.py +++ b/squinnondation/messages.py @@ -4,6 +4,7 @@ from typing import Any, List, Optional from ipaddress import IPv6Address from enum import Enum +import socket import sys @@ -53,13 +54,6 @@ class TLV: def tlv_classes() -> list: return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV] - @staticmethod - def network_order() -> str: - """ - The network byte order is always inverted as the host network byte order. - """ - return "little" if sys.byteorder == "big" else "big" - class Pad1TLV(TLV): """ @@ -77,7 +71,7 @@ class Pad1TLV(TLV): """ The TLV is empty. """ - return self.type.to_bytes(1, TLV.network_order()) + return self.type.to_bytes(1, sys.byteorder) def handle(self, squirrel: Any, sender: Any) -> None: # TODO Add some easter eggs @@ -100,7 +94,7 @@ class PadNTLV(TLV): mbz: bytes def validate_data(self) -> bool: - if self.mbz != int(0).to_bytes(self.length, TLV.network_order()): + if self.mbz != int(0).to_bytes(self.length, sys.byteorder): raise ValueError("The body of a PadN TLV is not filled with zeros.") return True @@ -116,7 +110,7 @@ class PadNTLV(TLV): """ Construct the byte array filled by zeros. """ - return self.type.to_bytes(1, TLV.network_order()) + self.length.to_bytes(1, TLV.network_order()) \ + return self.type.to_bytes(1, sys.byteorder) + self.length.to_bytes(1, sys.byteorder) \ + self.mbz[:self.length] def handle(self, squirrel: Any, sender: Any) -> None: @@ -139,15 +133,15 @@ class HelloTLV(TLV): def unmarshal(self, raw_data: bytes) -> None: self.type = raw_data[0] self.length = raw_data[1] - self.source_id = int.from_bytes(raw_data[2:10], TLV.network_order()) + self.source_id = int.from_bytes(raw_data[2:10], sys.byteorder) if self.is_long: - self.dest_id = int.from_bytes(raw_data[10:18], TLV.network_order()) + self.dest_id = int.from_bytes(raw_data[10:18], sys.byteorder) def marshal(self) -> bytes: - data = self.type.to_bytes(1, TLV.network_order()) + self.length.to_bytes(1, TLV.network_order()) \ - + self.source_id.to_bytes(8, TLV.network_order()) + data = self.type.to_bytes(1, sys.byteorder) + self.length.to_bytes(1, sys.byteorder) \ + + self.source_id.to_bytes(8, sys.byteorder) if self.dest_id: - data += self.dest_id.to_bytes(8, TLV.network_order()) + data += self.dest_id.to_bytes(8, sys.byteorder) return data def handle(self, squirrel: Any, sender: Any) -> None: @@ -170,13 +164,13 @@ class NeighbourTLV(TLV): self.type = raw_data[0] self.length = raw_data[1] self.ip_address = IPv6Address(raw_data[2:18]) - self.port = int.from_bytes(raw_data[18:20], TLV.network_order()) + self.port = int.from_bytes(raw_data[18:20], sys.byteorder) def marshal(self) -> bytes: - return self.type.to_bytes(1, TLV.network_order()) + \ - self.length.to_bytes(1, TLV.network_order()) + \ + return self.type.to_bytes(1, sys.byteorder) + \ + self.length.to_bytes(1, sys.byteorder) + \ self.ip_address.packed + \ - self.port.to_bytes(2, TLV.network_order()) + self.port.to_bytes(2, sys.byteorder) def handle(self, squirrel: Any, sender: Any) -> None: # TODO Implement NeighbourTLV @@ -194,15 +188,15 @@ class DataTLV(TLV): def unmarshal(self, raw_data: bytes) -> None: self.type = raw_data[0] self.length = raw_data[1] - self.sender_id = int.from_bytes(raw_data[2:10], TLV.network_order()) - self.nonce = int.from_bytes(raw_data[10:14], TLV.network_order()) + self.sender_id = int.from_bytes(raw_data[2:10], sys.byteorder) + self.nonce = socket.ntohl(int.from_bytes(raw_data[10:14], sys.byteorder)) self.data = raw_data[14:self.tlv_length] def marshal(self) -> bytes: - return self.type.to_bytes(1, TLV.network_order()) + \ - self.length.to_bytes(1, TLV.network_order()) + \ - self.sender_id.to_bytes(8, TLV.network_order()) + \ - self.nonce.to_bytes(4, TLV.network_order()) + \ + return self.type.to_bytes(1, sys.byteorder) + \ + self.length.to_bytes(1, sys.byteorder) + \ + self.sender_id.to_bytes(8, sys.byteorder) + \ + socket.htonl(self.nonce).to_bytes(4, sys.byteorder) + \ self.data def handle(self, squirrel: Any, sender: Any) -> None: @@ -232,14 +226,14 @@ class AckTLV(TLV): def unmarshal(self, raw_data: bytes) -> None: self.type = raw_data[0] self.length = raw_data[1] - self.sender_id = int.from_bytes(raw_data[2:10], TLV.network_order()) - self.nonce = int.from_bytes(raw_data[10:14], TLV.network_order()) + self.sender_id = int.from_bytes(raw_data[2:10], sys.byteorder) + self.nonce = socket.ntohl(int.from_bytes(raw_data[10:14], sys.byteorder)) def marshal(self) -> bytes: - return self.type.to_bytes(1, TLV.network_order()) + \ - self.length.to_bytes(1, TLV.network_order()) + \ - self.sender_id.to_bytes(8, TLV.network_order()) + \ - self.nonce.to_bytes(4, TLV.network_order()) + return self.type.to_bytes(1, sys.byteorder) + \ + self.length.to_bytes(1, sys.byteorder) + \ + self.sender_id.to_bytes(8, sys.byteorder) + \ + socket.htonl(self.nonce).to_bytes(4, sys.byteorder) def handle(self, squirrel: Any, sender: Any) -> None: # TODO Implement AckTLV @@ -265,9 +259,9 @@ class GoAwayTLV(TLV): self.message = raw_data[3:self.length - 1].decode("UTF-8") def marshal(self) -> bytes: - return self.type.to_bytes(1, TLV.network_order()) + \ - self.length.to_bytes(1, TLV.network_order()) + \ - self.code.value.to_bytes(1, TLV.network_order()) + \ + return self.type.to_bytes(1, sys.byteorder) + \ + self.length.to_bytes(1, sys.byteorder) + \ + self.code.value.to_bytes(1, sys.byteorder) + \ self.message.encode("UTF-8")[:self.length - 1] def handle(self, squirrel: Any, sender: Any) -> None: @@ -287,8 +281,8 @@ class WarningTLV(TLV): self.message = raw_data[2:self.length].decode("UTF-8") def marshal(self) -> bytes: - return self.type.to_bytes(1, TLV.network_order()) + \ - self.length.to_bytes(1, TLV.network_order()) + \ + return self.type.to_bytes(1, sys.byteorder) + \ + self.length.to_bytes(1, sys.byteorder) + \ self.message.encode("UTF-8")[:self.length] def handle(self, squirrel: Any, sender: Any) -> None: @@ -327,7 +321,7 @@ class Packet: pkt = Packet() pkt.magic = data[0] pkt.version = data[1] - pkt.body_length = int.from_bytes(data[2:4], byteorder=TLV.network_order()) + pkt.body_length = socket.ntohs(int.from_bytes(data[2:4], sys.byteorder)) pkt.body = [] read_bytes = 0 while read_bytes <= min(len(data) - 4, pkt.body_length): @@ -347,9 +341,9 @@ class Packet: """ Compute the byte array data associated to the packet. """ - data = self.magic.to_bytes(1, TLV.network_order()) - data += self.version.to_bytes(1, TLV.network_order()) - data += self.body_length.to_bytes(2, TLV.network_order()) + data = self.magic.to_bytes(1, sys.byteorder) + data += self.version.to_bytes(1, sys.byteorder) + data += socket.htons(self.body_length).to_bytes(2, sys.byteorder) data += b"".join(tlv.marshal() for tlv in self.body) return data