A packet can have multiple TLVs
Signed-off-by: Yohann D'ANELLO <ynerant@crans.org>
This commit is contained in:
parent
a796bed259
commit
9561912ac6
|
@ -6,7 +6,7 @@ from argparse import ArgumentParser
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from ipaddress import IPv6Address
|
from ipaddress import IPv6Address
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Any, Optional, Tuple
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
class Squinnondation:
|
class Squinnondation:
|
||||||
|
@ -50,13 +50,14 @@ class Squinnondation:
|
||||||
pkt = Packet()
|
pkt = Packet()
|
||||||
pkt.magic = 95
|
pkt.magic = 95
|
||||||
pkt.version = 0
|
pkt.version = 0
|
||||||
pkt.body = DataTLV()
|
tlv = DataTLV()
|
||||||
msg = f"Hello world, my name is {squirrel.nickname}!"
|
msg = f"Hello world, my name is {squirrel.nickname}!"
|
||||||
pkt.body.data = msg.encode("UTF-8")
|
tlv.data = msg.encode("UTF-8")
|
||||||
pkt.body.sender_id = 42
|
tlv.sender_id = 42
|
||||||
pkt.body.nonce = 18
|
tlv.nonce = 18
|
||||||
pkt.body.length = len(msg) + 1 + 1 + 8 + 4
|
tlv.length = len(msg) + 1 + 1 + 8 + 4
|
||||||
pkt.body_length = pkt.body.length + 2
|
pkt.body = [tlv]
|
||||||
|
pkt.body_length = tlv.length + 2
|
||||||
squirrel.send_packet(hazelnut, pkt)
|
squirrel.send_packet(hazelnut, pkt)
|
||||||
|
|
||||||
Worm(squirrel).start()
|
Worm(squirrel).start()
|
||||||
|
@ -275,7 +276,7 @@ class Packet:
|
||||||
magic: int
|
magic: int
|
||||||
version: int
|
version: int
|
||||||
body_length: int
|
body_length: int
|
||||||
body: TLV
|
body: List[TLV]
|
||||||
|
|
||||||
def validate_data(self) -> bool:
|
def validate_data(self) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -289,7 +290,7 @@ class Packet:
|
||||||
if not (0 <= self.body_length <= 120):
|
if not (0 <= self.body_length <= 120):
|
||||||
raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1020,"
|
raise ValueError("The body length of the packet is negative or too high. It must be between 0 and 1020,"
|
||||||
"found: {:d}".format(self.body_length))
|
"found: {:d}".format(self.body_length))
|
||||||
return self.body.validate_data()
|
return all(tlv.validate_data() for tlv in self.body)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def unmarshal(data: bytes) -> "Packet":
|
def unmarshal(data: bytes) -> "Packet":
|
||||||
|
@ -301,11 +302,17 @@ class Packet:
|
||||||
pkt.magic = data[0]
|
pkt.magic = data[0]
|
||||||
pkt.version = data[1]
|
pkt.version = data[1]
|
||||||
pkt.body_length = int.from_bytes(data[2:4], byteorder="big")
|
pkt.body_length = int.from_bytes(data[2:4], byteorder="big")
|
||||||
|
pkt.body = []
|
||||||
|
read_bytes = 0
|
||||||
|
while read_bytes <= min(len(data) - 4, pkt.body_length):
|
||||||
tlv_type = data[4]
|
tlv_type = data[4]
|
||||||
if not (0 <= tlv_type < len(TLV.tlv_classes())):
|
if not (0 <= tlv_type < len(TLV.tlv_classes())):
|
||||||
raise ValueError(f"TLV type is not supported: {tlv_type}")
|
raise ValueError(f"TLV type is not supported: {tlv_type}")
|
||||||
pkt.body = TLV.tlv_classes()[tlv_type]()
|
tlv = TLV.tlv_classes()[tlv_type]()
|
||||||
pkt.body.unmarshal(data[4:4 + pkt.body_length])
|
tlv.unmarshal(data[4:4 + pkt.body_length])
|
||||||
|
pkt.body.append(tlv)
|
||||||
|
# Pad1TLV has no length
|
||||||
|
read_bytes += 1 if tlv_type == 0 else tlv.length + 2
|
||||||
|
|
||||||
pkt.validate_data()
|
pkt.validate_data()
|
||||||
|
|
||||||
|
@ -318,7 +325,7 @@ class Packet:
|
||||||
data = self.magic.to_bytes(1, "big")
|
data = self.magic.to_bytes(1, "big")
|
||||||
data += self.version.to_bytes(1, "big")
|
data += self.version.to_bytes(1, "big")
|
||||||
data += self.body_length.to_bytes(2, "big")
|
data += self.body_length.to_bytes(2, "big")
|
||||||
data += self.body.marshal()
|
data += b"".join(tlv.marshal() for tlv in self.body)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
@ -410,15 +417,17 @@ class Worm(Thread):
|
||||||
except ValueError as error:
|
except ValueError as error:
|
||||||
print("An error occured while receiving a packet: ", error)
|
print("An error occured while receiving a packet: ", error)
|
||||||
else:
|
else:
|
||||||
print(pkt.body.data.decode('UTF-8'))
|
print(pkt.body[0].data.decode('UTF-8'))
|
||||||
pkt = Packet()
|
pkt = Packet()
|
||||||
pkt.magic = 95
|
pkt.magic = 95
|
||||||
pkt.version = 0
|
pkt.version = 0
|
||||||
pkt.body = DataTLV()
|
pkt.body = []
|
||||||
|
tlv = DataTLV()
|
||||||
msg = f"Hello my dear hazelnut, I am {self.squirrel.nickname}!"
|
msg = f"Hello my dear hazelnut, I am {self.squirrel.nickname}!"
|
||||||
pkt.body.data = msg.encode("UTF-8")
|
tlv.data = msg.encode("UTF-8")
|
||||||
pkt.body.sender_id = 42
|
tlv.sender_id = 42
|
||||||
pkt.body.nonce = 18
|
tlv.nonce = 18
|
||||||
pkt.body.length = len(msg) + 1 + 1 + 8 + 4
|
tlv.length = len(msg) + 1 + 1 + 8 + 4
|
||||||
pkt.body_length = pkt.body.length + 2
|
pkt.body.append(tlv)
|
||||||
|
pkt.body_length = tlv.length + 2
|
||||||
self.squirrel.send_packet(hazelnut, pkt)
|
self.squirrel.send_packet(hazelnut, pkt)
|
||||||
|
|
Loading…
Reference in New Issue