More abstraction on packet building

Signed-off-by: Yohann D'ANELLO <ynerant@crans.org>
This commit is contained in:
Yohann D'ANELLO 2020-12-21 16:04:10 +01:00
parent 592cbc3792
commit e8fa0ece22
Signed by: ynerant
GPG Key ID: 3A75C55819C8CF85
1 changed files with 42 additions and 14 deletions

View File

@ -167,16 +167,7 @@ class Squinnondation:
squirrel.add_message(msg)
for hazelnut in list(squirrel.hazelnuts.values()):
pkt = Packet()
pkt.magic = 95
pkt.version = 0
tlv = DataTLV()
tlv.data = msg.encode("UTF-8")
tlv.sender_id = 42
tlv.nonce = 18
tlv.length = len(tlv.data) + 1 + 1 + 8 + 4
pkt.body = [tlv]
pkt.body_length = tlv.length + 2
pkt = Packet.construct(DataTLV.construct(msg))
squirrel.send_packet(hazelnut, pkt)
@ -186,6 +177,7 @@ class TLV:
TODO: add subclasses for each type of TLV
"""
type: int
length: int
def unmarshal(self, raw_data: bytes) -> None:
"""
@ -213,6 +205,14 @@ class TLV:
It is ensured that the data is valid.
"""
@property
def tlv_length(self) -> int:
"""
Returns the total length (in bytes) of the TLV, including the type and the length.
Except for Pad1, this is 2 plus the length of the body of the TLV.
"""
return 2 + self.length
@staticmethod
def tlv_classes():
return [Pad1TLV, PadNTLV, HelloTLV, NeighbourTLV, DataTLV, AckTLV, GoAwayTLV, WarningTLV]
@ -241,6 +241,13 @@ class Pad1TLV(TLV):
squirrel.add_system_message(f"For each byte in the packet that I received, you will die today. "
"And eat cookies.")
@property
def tlv_length(self) -> int:
"""
A Pad1 has always a length of 1.
"""
return 1
class PadNTLV(TLV):
"""
@ -261,7 +268,7 @@ class PadNTLV(TLV):
"""
self.type = raw_data[0]
self.length = raw_data[1]
self.mbz = raw_data[2:2 + self.length]
self.mbz = raw_data[2:self.tlv_length]
def marshal(self) -> bytes:
"""
@ -345,7 +352,7 @@ class DataTLV(TLV):
self.length = raw_data[1]
self.sender_id = int.from_bytes(raw_data[2:10], "big")
self.nonce = int.from_bytes(raw_data[10:14], "big")
self.data = raw_data[14:2 + self.length]
self.data = raw_data[14:self.tlv_length]
def marshal(self) -> bytes:
return self.type.to_bytes(1, "big") + \
@ -361,6 +368,16 @@ class DataTLV(TLV):
"""
squirrel.add_message(self.data.decode('UTF-8'))
@staticmethod
def construct(message: str) -> "DataTLV":
tlv = DataTLV()
tlv.type = 4
tlv.sender_id = 42 # FIXME Use the good sender id
tlv.nonce = 42 # FIXME Use an incremental nonce
tlv.data = message.encode("UTF-8")
tlv.length = 12 + len(tlv.data)
return tlv
class AckTLV(TLV):
type: int = 5
@ -476,8 +493,7 @@ class Packet:
tlv = TLV.tlv_classes()[tlv_type]()
tlv.unmarshal(data[4:4 + pkt.body_length])
pkt.body.append(tlv)
# Pad1TLV has no length
read_bytes += 1 if tlv_type == 0 else tlv.length + 2
read_bytes += tlv.tlv_length
pkt.validate_data()
@ -493,6 +509,18 @@ class Packet:
data += b"".join(tlv.marshal() for tlv in self.body)
return data
@staticmethod
def construct(*tlvs: TLV) -> "Packet":
"""
Construct a new packet from the given TLVs and calculate the good lengths
"""
pkt = Packet()
pkt.magic = 95
pkt.version = 0
pkt.body = tlvs
pkt.body_length = sum(tlv.tlv_length for tlv in tlvs)
return pkt
class Hazelnut:
"""