Construct TLVs
Signed-off-by: Yohann D'ANELLO <ynerant@crans.org>
This commit is contained in:
parent
1caf06bf24
commit
abbcbbc3b1
|
@ -3,7 +3,8 @@
|
|||
|
||||
import socket
|
||||
from argparse import ArgumentParser
|
||||
from typing import Any, Tuple
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
|
||||
class Squinnondation:
|
||||
|
@ -47,15 +48,18 @@ class Squinnondation:
|
|||
pkt = Packet()
|
||||
pkt.magic = 95
|
||||
pkt.version = 0
|
||||
pkt.body = TLV()
|
||||
pkt.body = DataTLV()
|
||||
msg = f"Hello world, my name is {squirrel.nickname}!"
|
||||
pkt.body.raw_data = msg.encode("UTF-8")
|
||||
pkt.body_length = len(pkt.body.raw_data)
|
||||
pkt.body.data = msg.encode("UTF-8")
|
||||
pkt.body.sender_id = 42
|
||||
pkt.body.nonce = 18
|
||||
pkt.body.length = len(msg) + 1 + 1 + 8 + 4
|
||||
pkt.body_length = pkt.body.length + 2
|
||||
squirrel.send_packet(hazelnut, pkt)
|
||||
|
||||
while True:
|
||||
pkt, addr = squirrel.receive_packet()
|
||||
print(f"received message: {pkt.body.raw_data.decode('UTF-8')}")
|
||||
print(f"received message: {pkt.body.data.decode('UTF-8')}")
|
||||
|
||||
|
||||
class TLV:
|
||||
|
@ -64,7 +68,18 @@ class TLV:
|
|||
TODO: add subclasses for each type of TLV
|
||||
"""
|
||||
type: int
|
||||
raw_data: bytes
|
||||
|
||||
def unmarshal(self, raw_data: bytes) -> None:
|
||||
"""
|
||||
Parse data and construct TLV.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def marshal(self) -> bytes:
|
||||
"""
|
||||
Translate the TLV into a byte array.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def validate_data(self) -> bool:
|
||||
"""
|
||||
|
@ -74,6 +89,184 @@ class TLV:
|
|||
"""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def tlv_classes():
|
||||
return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV]
|
||||
|
||||
|
||||
class Pad1TLV(TLV):
|
||||
"""
|
||||
This TLV is simply ignored.
|
||||
"""
|
||||
type: int = 0
|
||||
|
||||
def unmarshal(self, raw_data: bytes) -> None:
|
||||
"""
|
||||
There is nothing to do. We ignore the packet.
|
||||
"""
|
||||
self.type = raw_data[0]
|
||||
|
||||
def marshal(self) -> bytes:
|
||||
"""
|
||||
The TLV is empty.
|
||||
"""
|
||||
return self.type.to_bytes(1, "big")
|
||||
|
||||
|
||||
class PadNTLV(TLV):
|
||||
"""
|
||||
This TLV is filled with zeros. It is ignored.
|
||||
"""
|
||||
type: int = 1
|
||||
length: int
|
||||
mbz: bytes
|
||||
|
||||
def validate_data(self) -> bool:
|
||||
if self.mbz != int(0).to_bytes(self.length, "big"):
|
||||
raise ValueError("The body of a PadN TLV is not filled with zeros.")
|
||||
return True
|
||||
|
||||
def unmarshal(self, raw_data: bytes) -> None:
|
||||
"""
|
||||
Store the zero-array, then ignore the packet.
|
||||
"""
|
||||
self.type = raw_data[0]
|
||||
self.length = raw_data[1]
|
||||
self.mbz = raw_data[2:2 + self.length]
|
||||
|
||||
def marshal(self) -> bytes:
|
||||
"""
|
||||
Construct the byte array filled by zeros.
|
||||
"""
|
||||
return self.type.to_bytes(1, "big") + self.length.to_bytes(1, "big") + self.mbz[:self.length]
|
||||
|
||||
|
||||
class HelloTLV(TLV):
|
||||
type: int = 2
|
||||
length: int
|
||||
source_id: int
|
||||
dest_id: Optional[int]
|
||||
|
||||
def validate_data(self) -> bool:
|
||||
if self.length != 8 and self.length != 16:
|
||||
raise ValueError("The length of a Hello TLV must be 8 for a short Hello, or 16 for a long Hello,"
|
||||
f"found {self.length}")
|
||||
return True
|
||||
|
||||
def unmarshal(self, raw_data: bytes) -> None:
|
||||
self.type = raw_data[0]
|
||||
self.length = raw_data[1]
|
||||
self.source_id = int.from_bytes(raw_data[2:10], "big")
|
||||
if self.length == 16:
|
||||
self.dest_id = int.from_bytes(raw_data[10:18], "big")
|
||||
|
||||
def marshal(self) -> bytes:
|
||||
data = self.type.to_bytes(1, "big") + self.length.to_bytes(1, "big") + self.source_id.to_bytes(8, "big")
|
||||
if self.dest_id:
|
||||
data += self.dest_id.to_bytes(8, "big")
|
||||
return data
|
||||
|
||||
|
||||
class NeighbourTLV(TLV):
|
||||
type: int = 3
|
||||
length: int
|
||||
ip_address: int
|
||||
port: int
|
||||
|
||||
def unmarshal(self, raw_data: bytes) -> None:
|
||||
self.type = raw_data[0]
|
||||
self.length = raw_data[1]
|
||||
self.ip_address = raw_data[2:18]
|
||||
self.port = int.from_bytes(raw_data[18:20], "big")
|
||||
|
||||
def marshal(self) -> bytes:
|
||||
return self.type.to_bytes(1, "big") + \
|
||||
self.length.to_bytes(1, "big") + \
|
||||
self.ip_address.to_bytes(16, "big") + \
|
||||
self.port.to_bytes(2, "big")
|
||||
|
||||
|
||||
class DataTLV(TLV):
|
||||
type: int = 4
|
||||
length: int
|
||||
sender_id: int
|
||||
nonce: int
|
||||
data: bytes
|
||||
|
||||
def unmarshal(self, raw_data: bytes) -> None:
|
||||
self.type = raw_data[0]
|
||||
self.length = raw_data[1]
|
||||
self.sender_id = int.from_bytes(raw_data[2:10], "big")
|
||||
self.nonce = int.from_bytes(raw_data[10:14], "big")
|
||||
self.data = raw_data[14:2 + self.length]
|
||||
|
||||
def marshal(self) -> bytes:
|
||||
return self.type.to_bytes(1, "big") + \
|
||||
self.length.to_bytes(1, "big") + \
|
||||
self.sender_id.to_bytes(8, "big") + \
|
||||
self.nonce.to_bytes(4, "big") + \
|
||||
self.data
|
||||
|
||||
|
||||
class AckTLV(TLV):
|
||||
type: int = 5
|
||||
length: int
|
||||
sender_id: int
|
||||
nonce: int
|
||||
|
||||
def unmarshal(self, raw_data: bytes) -> None:
|
||||
self.type = raw_data[0]
|
||||
self.length = raw_data[1]
|
||||
self.sender_id = int.from_bytes(raw_data[2:10], "big")
|
||||
self.nonce = int.from_bytes(raw_data[10:14], "big")
|
||||
|
||||
def marshal(self) -> bytes:
|
||||
return self.type.to_bytes(1, "big") + \
|
||||
self.length.to_bytes(1, "big") + \
|
||||
self.sender_id.to_bytes(8, "big") + \
|
||||
self.nonce.to_bytes(4, "big")
|
||||
|
||||
|
||||
class GoAwayTLV(TLV):
|
||||
class GoAwayType(Enum):
|
||||
UNKNOWN = 0
|
||||
EXIT = 1
|
||||
TIMEOUT = 2
|
||||
PROTOCOL_VIOLATION = 3
|
||||
|
||||
type: int = 6
|
||||
length: int
|
||||
code: GoAwayType
|
||||
message: str
|
||||
|
||||
def unmarshal(self, raw_data: bytes) -> None:
|
||||
self.type = raw_data[0]
|
||||
self.length = raw_data[1]
|
||||
self.code = GoAwayTLV.GoAwayType(raw_data[2])
|
||||
self.message = raw_data[3:self.length - 1].decode("UTF-8")
|
||||
|
||||
def marshal(self) -> bytes:
|
||||
return self.type.to_bytes(1, "big") + \
|
||||
self.length.to_bytes(1, "big") + \
|
||||
self.code.value.to_bytes(1, "big") + \
|
||||
self.message.encode("UTF-8")[:self.length - 1]
|
||||
|
||||
|
||||
class WarningTLV(TLV):
|
||||
type: int = 7
|
||||
length: int
|
||||
message: str
|
||||
|
||||
def unmarshal(self, raw_data: bytes) -> None:
|
||||
self.type = raw_data[0]
|
||||
self.length = raw_data[1]
|
||||
self.message = raw_data[2:self.length].decode("UTF-8")
|
||||
|
||||
def marshal(self) -> bytes:
|
||||
return self.type.to_bytes(1, "big") + \
|
||||
self.length.to_bytes(1, "big") + \
|
||||
self.message.encode("UTF-8")[:self.length]
|
||||
|
||||
|
||||
class Packet:
|
||||
"""
|
||||
|
@ -108,8 +301,11 @@ class Packet:
|
|||
pkt.magic = data[0]
|
||||
pkt.version = data[1]
|
||||
pkt.body_length = int.from_bytes(data[2:4], byteorder="big")
|
||||
pkt.body = TLV()
|
||||
pkt.body.raw_data = data[4:4+pkt.body_length]
|
||||
tlv_type = data[4]
|
||||
if not (0 <= tlv_type < len(TLV.tlv_classes())):
|
||||
raise ValueError(f"TLV type is not supported: {tlv_type}")
|
||||
pkt.body = TLV.tlv_classes()[tlv_type]()
|
||||
pkt.body.unmarshal(data[4:4+pkt.body_length])
|
||||
|
||||
pkt.validate_data()
|
||||
|
||||
|
@ -122,7 +318,7 @@ class Packet:
|
|||
data = self.magic.to_bytes(1, "big")
|
||||
data += self.version.to_bytes(1, "big")
|
||||
data += self.body_length.to_bytes(2, "big")
|
||||
data += self.body.raw_data
|
||||
data += self.body.marshal()
|
||||
return data
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue