WIP interaction model parsing and plumbing
This commit is contained in:
parent
7e235cd49a
commit
49d090df42
5 changed files with 535 additions and 86 deletions
|
|
@ -7,7 +7,6 @@ import hmac
|
|||
import pathlib
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import struct
|
||||
import time
|
||||
from ecdsa.ellipticcurve import AbstractPoint, Point, PointJacobi
|
||||
|
|
@ -16,7 +15,7 @@ from ecdsa.curves import NIST256p
|
|||
import cryptography
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESCCM
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, Iterable
|
||||
|
||||
from . import tlv
|
||||
|
||||
|
|
@ -245,19 +244,19 @@ class PAKE3(tlv.TLVStructure):
|
|||
cA = tlv.OctetStringMember(1, CRYPTO_HASH_LEN_BYTES)
|
||||
|
||||
|
||||
class AttributePathIB(tlv.TLVList):
|
||||
class AttributePathIB(tlv.TLVStructure):
|
||||
"""Section 10.6.2"""
|
||||
|
||||
EnableTagCompression = tlv.BoolMember(0)
|
||||
Node = tlv.IntMember(1, signed=False, octets=8)
|
||||
Endpoint = tlv.IntMember(2, signed=False, octets=2)
|
||||
Cluster = tlv.IntMember(3, signed=False, octets=4)
|
||||
Attribute = tlv.IntMember(4, signed=False, octets=4)
|
||||
ListIndex = tlv.IntMember(5, signed=False, octets=2, nullable=True)
|
||||
WildcardPathFlags = tlv.IntMember(6, signed=False, octets=4)
|
||||
EnableTagCompression = tlv.BoolMember(0, optional=True)
|
||||
Node = tlv.IntMember(1, signed=False, octets=8, optional=True)
|
||||
Endpoint = tlv.IntMember(2, signed=False, octets=2, optional=True)
|
||||
Cluster = tlv.IntMember(3, signed=False, octets=4, optional=True)
|
||||
Attribute = tlv.IntMember(4, signed=False, octets=4, optional=True)
|
||||
ListIndex = tlv.IntMember(5, signed=False, octets=2, nullable=True, optional=True)
|
||||
WildcardPathFlags = tlv.IntMember(6, signed=False, octets=4, optional=True)
|
||||
|
||||
|
||||
class EventPathIB(tlv.TLVList):
|
||||
class EventPathIB(tlv.TLVStructure):
|
||||
"""Section 10.6.8"""
|
||||
|
||||
Node = tlv.IntMember(0, signed=False, octets=8)
|
||||
|
|
@ -274,25 +273,80 @@ class EventFilterIB(tlv.TLVStructure):
|
|||
EventMinimumInterval = tlv.IntMember(1, signed=False, octets=8)
|
||||
|
||||
|
||||
class ClusterPathIB(tlv.TLVList):
|
||||
class ClusterPathIB(tlv.TLVStructure):
|
||||
Node = tlv.IntMember(0, signed=False, octets=8)
|
||||
Endpoint = tlv.IntMember(1, signed=False, octets=2)
|
||||
Cluster = tlv.IntMember(2, signed=False, octets=4)
|
||||
|
||||
|
||||
class DataVersionFilterIB(tlv.TLVStructure):
|
||||
Path = tlv.ContainerMember(0, ClusterPathIB)
|
||||
Path = tlv.StructMember(0, ClusterPathIB)
|
||||
DataVersion = tlv.IntMember(1, signed=False, octets=4)
|
||||
|
||||
|
||||
class ReadRequestMessage(tlv.TLVStructure):
|
||||
FabricFiltered = tlv.BoolMember(3)
|
||||
class StatusIB(tlv.TLVStructure):
|
||||
Status = tlv.IntMember(0, signed=False, octets=1)
|
||||
ClusterStatus = tlv.IntMember(1, signed=False, octets=1)
|
||||
|
||||
def __init__(self):
|
||||
self.AttributeRequests = tlv.ArrayMember(0, AttributePathIB)
|
||||
self.EventRequests = tlv.ArrayMember(1, EventPathIB)
|
||||
self.EventFilters = tlv.ArrayMember(2, EventFilterIB)
|
||||
self.DataVersionFilters = tlv.ArrayMember(4, DataVersionFilterIB)
|
||||
|
||||
class AttributeDataIB(tlv.TLVStructure):
|
||||
DataVersion = tlv.IntMember(0, signed=False, octets=4)
|
||||
Path = tlv.StructMember(1, AttributePathIB)
|
||||
Data = tlv.AnythingMember(
|
||||
2
|
||||
) # This is a weird one because the TLV type can be anything.
|
||||
|
||||
|
||||
class AttributeStatusIB(tlv.TLVStructure):
|
||||
Path = tlv.StructMember(0, AttributePathIB)
|
||||
Status = tlv.StructMember(1, StatusIB)
|
||||
|
||||
|
||||
class AttributeReportIB(tlv.TLVStructure):
|
||||
AttributeStatus = tlv.StructMember(0, AttributeStatusIB)
|
||||
AttributeData = tlv.StructMember(1, AttributeDataIB)
|
||||
|
||||
|
||||
class ReadRequestMessage(tlv.TLVStructure):
|
||||
AttributeRequests = tlv.ArrayMember(0, tlv.List(AttributePathIB))
|
||||
EventRequests = tlv.ArrayMember(1, EventPathIB)
|
||||
EventFilters = tlv.ArrayMember(2, EventFilterIB)
|
||||
FabricFiltered = tlv.BoolMember(3)
|
||||
DataVersionFilters = tlv.ArrayMember(4, DataVersionFilterIB)
|
||||
|
||||
|
||||
class EventStatusIB(tlv.TLVStructure):
|
||||
Path = tlv.StructMember(0, EventPathIB)
|
||||
Status = tlv.StructMember(1, StatusIB)
|
||||
|
||||
|
||||
class EventDataIB(tlv.TLVStructure):
|
||||
Path = tlv.StructMember(0, EventPathIB)
|
||||
EventNumber = tlv.IntMember(1, signed=False, octets=8)
|
||||
PriorityLevel = tlv.IntMember(2, signed=False, octets=1)
|
||||
|
||||
# Only one of the below values
|
||||
EpochTimestamp = tlv.IntMember(3, signed=False, octets=8, optional=True)
|
||||
SystemTimestamp = tlv.IntMember(4, signed=False, octets=8, optional=True)
|
||||
DeltaEpochTimestamp = tlv.IntMember(5, signed=True, octets=8, optional=True)
|
||||
DeltaSystemTimestamp = tlv.IntMember(6, signed=True, octets=8, optional=True)
|
||||
|
||||
Data = tlv.AnythingMember(
|
||||
7
|
||||
) # This is a weird one because the TLV type can be anything.
|
||||
|
||||
|
||||
class EventReportIB(tlv.TLVStructure):
|
||||
EventStatus = tlv.StructMember(0, EventStatusIB)
|
||||
EventData = tlv.StructMember(1, EventDataIB)
|
||||
|
||||
|
||||
class ReportDataMessage(tlv.TLVStructure):
|
||||
SubscriptionId = tlv.IntMember(0, signed=False, octets=4)
|
||||
AttributeReports = tlv.ArrayMember(1, AttributeReportIB)
|
||||
EventReports = tlv.ArrayMember(2, EventReportIB)
|
||||
MoreChunkedMessages = tlv.BoolMember(3, optional=True)
|
||||
SuppressResponse = tlv.BoolMember(4, optional=True)
|
||||
|
||||
|
||||
class MessageReceptionState:
|
||||
|
|
@ -1017,8 +1071,8 @@ def Crypto_pA(w0, w1) -> bytes:
|
|||
return b""
|
||||
|
||||
|
||||
def Crypto_pB(w0: int, L: Point) -> tuple[int, AbstractPoint]:
|
||||
y = secrets.randbelow(NIST256p.order)
|
||||
def Crypto_pB(random_source, w0: int, L: Point) -> tuple[int, AbstractPoint]:
|
||||
y = random_source.randbelow(NIST256p.order)
|
||||
Y = y * NIST256p.generator + w0 * N
|
||||
return y, Y
|
||||
|
||||
|
|
@ -1109,15 +1163,169 @@ def Crypto_P2(tt, pA, pB) -> tuple[bytes, bytes, bytes]:
|
|||
return (cA, cB, Ke)
|
||||
|
||||
|
||||
class Attribute:
|
||||
def __init__(self, _id):
|
||||
self.id = _id
|
||||
|
||||
|
||||
class FeatureMap(Attribute):
|
||||
def __init__(self):
|
||||
super().__init__(0xFFFC)
|
||||
|
||||
|
||||
class NumberAttribute(Attribute):
|
||||
pass
|
||||
|
||||
|
||||
class ListAttribute(Attribute):
|
||||
pass
|
||||
|
||||
|
||||
class BoolAttribute(Attribute):
|
||||
pass
|
||||
|
||||
|
||||
class StructAttribute(Attribute):
|
||||
def __init__(self, _id, struct_type):
|
||||
self.struct_type = struct_type
|
||||
super().__init__(_id)
|
||||
|
||||
|
||||
class EnumAttribute(Attribute):
|
||||
def __init__(self, _id, enum_type):
|
||||
self.enum_type = enum_type
|
||||
super().__init__(_id)
|
||||
|
||||
|
||||
class OctetStringAttribute(Attribute):
|
||||
def __init__(self, _id, min_length, max_length):
|
||||
self.min_length = min_length
|
||||
self.max_length = max_length
|
||||
super().__init__(_id)
|
||||
|
||||
|
||||
class BitmapAttribute(Attribute):
|
||||
pass
|
||||
|
||||
|
||||
class Cluster:
|
||||
feature_map = FeatureMap()
|
||||
|
||||
@classmethod
|
||||
def _attributes(cls) -> Iterable[tuple[str, Attribute]]:
|
||||
for field_name, descriptor in vars(cls).items():
|
||||
if not field_name.startswith("_") and isinstance(descriptor, Attribute):
|
||||
yield field_name, descriptor
|
||||
for field_name, descriptor in vars(Cluster).items():
|
||||
if not field_name.startswith("_") and isinstance(descriptor, Attribute):
|
||||
yield field_name, descriptor
|
||||
|
||||
def get_attribute_data(self, path) -> AttributeDataIB:
|
||||
print("get_attribute_data", path.Attribute)
|
||||
data = AttributeDataIB()
|
||||
data.Path = path
|
||||
found = False
|
||||
for field_name, descriptor in self._attributes():
|
||||
if descriptor.id != path.Attribute:
|
||||
continue
|
||||
print("read", field_name)
|
||||
data.Data = getattr(self, field_name)
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
print("not found", path.Attribute)
|
||||
return data
|
||||
|
||||
|
||||
class BasicInformationCluster(Cluster):
|
||||
CLUSTER_ID = 0x0028
|
||||
|
||||
|
||||
class GeneralCommissioningCluster(Cluster):
|
||||
CLUSTER_ID = 0x0030
|
||||
|
||||
class BasicCommissioningInfo(tlv.TLVStructure):
|
||||
FailSafeExpiryLengthSeconds = tlv.IntMember(0, signed=False, octets=2)
|
||||
MaxCumulativeFailsafeSeconds = tlv.IntMember(1, signed=False, octets=2)
|
||||
|
||||
class RegulatoryLocationType(enum.IntEnum):
|
||||
INDOOR = 0
|
||||
OUTDOOR = 1
|
||||
INDOOR_OUTDOOR = 2
|
||||
|
||||
breadcrumb = NumberAttribute(0)
|
||||
basic_commissioning_info = StructAttribute(1, BasicCommissioningInfo)
|
||||
regulatory_config = EnumAttribute(2, RegulatoryLocationType)
|
||||
location_capability = EnumAttribute(3, RegulatoryLocationType)
|
||||
support_concurrent_connection = BoolAttribute(4)
|
||||
|
||||
|
||||
class NetworkComissioningCluster(Cluster):
|
||||
CLUSTER_ID = 0x0031
|
||||
|
||||
class FeatureBitmap(enum.IntFlag):
|
||||
WIFI_NETWORK_INTERFACE = 0b001
|
||||
THREAD_NETWORK_INTERFACE = 0b010
|
||||
ETHERNET_NETWORK_INTERFACE = 0b100
|
||||
|
||||
class NetworkCommissioningStatus(enum.IntEnum):
|
||||
SUCCESS = 0
|
||||
"""Ok, no error"""
|
||||
|
||||
OUT_OF_RANGE = 1
|
||||
"""Value Outside Range"""
|
||||
|
||||
BOUNDS_EXCEEDED = 2
|
||||
"""A collection would exceed its size limit"""
|
||||
|
||||
NETWORK_ID_NOT_FOUND = 3
|
||||
"""The NetworkID is not among the collection of added networks"""
|
||||
|
||||
DUPLICATE_NETWORK_ID = 4
|
||||
"""The NetworkID is already among the collection of added networks"""
|
||||
|
||||
NETWORK_NOT_FOUND = 5
|
||||
"""Cannot find AP: SSID Not found"""
|
||||
|
||||
REGULATORY_ERROR = 6
|
||||
"""Cannot find AP: Mismatch on band/channels/regulatory domain / 2.4GHz vs 5GHz"""
|
||||
|
||||
AUTH_FAILURE = 7
|
||||
"""Cannot associate due to authentication failure"""
|
||||
|
||||
UNSUPPORTED_SECURITY = 8
|
||||
"""Cannot associate due to unsupported security mode"""
|
||||
|
||||
OTHER_CONNECTION_FAILURE = 9
|
||||
"""Other association failure"""
|
||||
|
||||
IPV6_FAILED = 10
|
||||
"""Failure to generate an IPv6 address"""
|
||||
|
||||
IP_BIND_FAILED = 11
|
||||
"""Failure to bind Wi-Fi <-> IP interfaces"""
|
||||
|
||||
UNKNOWN_ERROR = 12
|
||||
"""Unknown error"""
|
||||
|
||||
max_networks = NumberAttribute(0)
|
||||
networks = ListAttribute(1)
|
||||
scan_max_time_seconds = NumberAttribute(2)
|
||||
connect_max_time_seconds = NumberAttribute(3)
|
||||
interface_enabled = BoolAttribute(4)
|
||||
last_network_status = EnumAttribute(5, NetworkCommissioningStatus)
|
||||
last_network_id = OctetStringAttribute(6, min_length=1, max_length=32)
|
||||
last_connect_error_value = NumberAttribute(7)
|
||||
supported_wifi_bands = ListAttribute(8)
|
||||
supported_thread_features = BitmapAttribute(9)
|
||||
thread_version = NumberAttribute(10)
|
||||
|
||||
|
||||
class CircuitMatter:
|
||||
def __init__(self, socketpool, mdns_server, state_filename, record_to=None):
|
||||
def __init__(self, socketpool, mdns_server, random_source, state_filename):
|
||||
self.socketpool = socketpool
|
||||
self.mdns_server = mdns_server
|
||||
self.record_to = record_to
|
||||
if self.record_to:
|
||||
self.recorded_packets = []
|
||||
else:
|
||||
self.recorded_packets = None
|
||||
self.random = random_source
|
||||
|
||||
with open(state_filename, "r") as state_file:
|
||||
self.nonvolatile = json.load(state_file)
|
||||
|
|
@ -1150,6 +1358,11 @@ class CircuitMatter:
|
|||
if commission:
|
||||
self.start_commissioning()
|
||||
|
||||
self._endpoints = {}
|
||||
self.add_cluster(0, BasicInformationCluster())
|
||||
self.add_cluster(0, NetworkComissioningCluster())
|
||||
self.add_cluster(0, GeneralCommissioningCluster())
|
||||
|
||||
def start_commissioning(self):
|
||||
descriminator = self.nonvolatile["descriminator"]
|
||||
txt_records = {
|
||||
|
|
@ -1162,7 +1375,7 @@ class CircuitMatter:
|
|||
"T": "1",
|
||||
"VP": "65521+32769",
|
||||
}
|
||||
instance_name = os.urandom(8).hex().upper()
|
||||
instance_name = self.random.urandom(8).hex().upper()
|
||||
self.mdns_server.advertise_service(
|
||||
"_matterc",
|
||||
"_udp",
|
||||
|
|
@ -1175,6 +1388,11 @@ class CircuitMatter:
|
|||
],
|
||||
)
|
||||
|
||||
def add_cluster(self, endpoint, cluster):
|
||||
if endpoint not in self._endpoints:
|
||||
self._endpoints[endpoint] = {}
|
||||
self._endpoints[endpoint][cluster.CLUSTER_ID] = cluster
|
||||
|
||||
def process_packets(self):
|
||||
while True:
|
||||
try:
|
||||
|
|
@ -1185,20 +1403,19 @@ class CircuitMatter:
|
|||
break
|
||||
if nbytes == 0:
|
||||
break
|
||||
if self.recorded_packets is not None:
|
||||
self.recorded_packets.append(
|
||||
(
|
||||
"receive",
|
||||
time.monotonic_ns(),
|
||||
addr,
|
||||
binascii.b2a_base64(
|
||||
self.packet_buffer[:nbytes], newline=False
|
||||
).decode("utf-8"),
|
||||
)
|
||||
)
|
||||
|
||||
self.process_packet(addr, self.packet_buffer[:nbytes])
|
||||
|
||||
def get_report(self, cluster, path):
|
||||
report = AttributeReportIB()
|
||||
astatus = AttributeStatusIB()
|
||||
astatus.Path = path
|
||||
status = StatusIB()
|
||||
astatus.Status = status
|
||||
report.AttributeStatus = astatus
|
||||
report.AttributeData = cluster.get_attribute_data(path)
|
||||
return report
|
||||
|
||||
def process_packet(self, address, data):
|
||||
# Print the received data and the address of the sender
|
||||
# This is section 4.7.2
|
||||
|
|
@ -1207,15 +1424,12 @@ class CircuitMatter:
|
|||
message.source_ipaddress = address
|
||||
if message.secure_session:
|
||||
# Decrypt the payload
|
||||
print("decrypt message", message.session_id)
|
||||
secure_session_context = self.manager.secure_session_contexts[
|
||||
message.session_id
|
||||
]
|
||||
print(secure_session_context)
|
||||
print(message)
|
||||
print(message.payload.hex(" "))
|
||||
ok = secure_session_context.decrypt_and_verify(message)
|
||||
print("decrypt ok?", ok)
|
||||
if not ok:
|
||||
raise RuntimeError("Failed to decrypt message")
|
||||
message.parse_protocol_header()
|
||||
self.manager.mark_duplicate(message)
|
||||
|
||||
|
|
@ -1224,9 +1438,6 @@ class CircuitMatter:
|
|||
print(f"Dropping message {message.message_counter}")
|
||||
return
|
||||
|
||||
# print(f"Received packet from {address}:")
|
||||
# print(f"{data.hex(' ')}")
|
||||
# print(f"Message counter {message.message_counter}")
|
||||
protocol_id = message.protocol_id
|
||||
protocol_opcode = message.protocol_opcode
|
||||
|
||||
|
|
@ -1253,7 +1464,7 @@ class CircuitMatter:
|
|||
response.initiatorRandom = request.initiatorRandom
|
||||
|
||||
# Generate a random number
|
||||
response.responderRandom = os.urandom(32)
|
||||
response.responderRandom = self.random.urandom(32)
|
||||
session_context = self.manager.new_context()
|
||||
response.responderSessionId = session_context.local_session_id
|
||||
exchange.secure_session_context = session_context
|
||||
|
|
@ -1283,7 +1494,7 @@ class CircuitMatter:
|
|||
L = memoryview(verifier)[CRYPTO_GROUP_SIZE_BYTES:]
|
||||
L = Point.from_bytes(NIST256p.curve, L)
|
||||
w0 = int.from_bytes(w0, byteorder="big")
|
||||
y, Y = Crypto_pB(w0, L)
|
||||
y, Y = Crypto_pB(self.random, w0, L)
|
||||
# pB is Y encoded uncompressed
|
||||
# pA is X encoded uncompressed
|
||||
pake2.pB = Y.to_bytes("uncompressed")
|
||||
|
|
@ -1392,16 +1603,40 @@ class CircuitMatter:
|
|||
print("Received ICD Check-in")
|
||||
elif message.protocol_id == ProtocolId.INTERACTION_MODEL:
|
||||
print(message)
|
||||
print("application payload", message.application_payload.hex(" "))
|
||||
if protocol_opcode == InteractionModelOpcode.READ_REQUEST:
|
||||
print("Received Read Request")
|
||||
read_request = ReadRequestMessage(message.application_payload[1:-1])
|
||||
attribute_reports = []
|
||||
for attribute in read_request.AttributeRequests:
|
||||
for path in attribute:
|
||||
attribute = (
|
||||
"*" if path.Attribute is None else f"0x{path.Attribute:04x}"
|
||||
)
|
||||
print(
|
||||
f"Endpoint: {path.Endpoint}, Cluster: 0x{path.Cluster:02x}, Attribute: {attribute}"
|
||||
)
|
||||
if path.Endpoint is None:
|
||||
# Wildcard so we get it from every endpoint.
|
||||
for endpoint in self._endpoints:
|
||||
if path.Cluster in self._endpoints[endpoint]:
|
||||
cluster = self._endpoints[endpoint][path.Cluster]
|
||||
path.Endpoint = endpoint
|
||||
attribute_reports.append(
|
||||
self.get_report(cluster, path)
|
||||
)
|
||||
else:
|
||||
print(f"Cluster 0x{path.Cluster:02x} not found")
|
||||
else:
|
||||
if path.Cluster in self._endpoints[path.Endpoint]:
|
||||
cluster = self._endpoints[path.Endpoint][path.Cluster]
|
||||
attribute_reports.append(self.get_report(cluster, path))
|
||||
else:
|
||||
print(f"Cluster 0x{path.Cluster:02x} not found")
|
||||
response = ReportDataMessage()
|
||||
response.AttributeReports = attribute_reports
|
||||
print(read_request)
|
||||
if protocol_opcode == InteractionModelOpcode.INVOKE_REQUEST:
|
||||
print("Received Invoke Request")
|
||||
elif protocol_opcode == InteractionModelOpcode.INVOKE_RESPONSE:
|
||||
print("Received Invoke Response")
|
||||
|
||||
def __del__(self):
|
||||
if self.recorded_packets and self.record_to:
|
||||
with open(self.record_to, "w") as record_file:
|
||||
json.dump(self.recorded_packets, record_file)
|
||||
|
|
|
|||
|
|
@ -2,8 +2,11 @@
|
|||
|
||||
import binascii
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import socket
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import circuitmatter as cm
|
||||
|
||||
|
|
@ -36,13 +39,34 @@ class ReplaySocket:
|
|||
return len(data)
|
||||
|
||||
|
||||
class ReplayRandom:
|
||||
def __init__(self, replay_data):
|
||||
self.replay_data = replay_data
|
||||
|
||||
def urandom(self, nbytes):
|
||||
direction = None
|
||||
while direction != "urandom":
|
||||
direction, _, recorded_nbytes, data_b64 = self.replay_data.pop(0)
|
||||
if recorded_nbytes != nbytes:
|
||||
raise RuntimeError("Next replay random data is not the expected length")
|
||||
decoded = binascii.a2b_base64(data_b64)
|
||||
return decoded
|
||||
|
||||
def randbelow(self, n):
|
||||
direction = None
|
||||
while direction != "randbelow":
|
||||
direction, _, recorded_n, value = self.replay_data.pop(0)
|
||||
if recorded_n != n:
|
||||
raise RuntimeError("Next replay randbelow is not the expected length")
|
||||
return value
|
||||
|
||||
|
||||
class ReplaySocketPool:
|
||||
AF_INET6 = 0
|
||||
SOCK_DGRAM = 1
|
||||
|
||||
def __init__(self, replay_file):
|
||||
with open(replay_file, "r") as f:
|
||||
self.replay_data = json.load(f)
|
||||
def __init__(self, replay_lines):
|
||||
self.replay_data = replay_lines
|
||||
self._socket_created = False
|
||||
|
||||
def socket(self, *args, **kwargs):
|
||||
|
|
@ -99,17 +123,96 @@ class MDNSServer(DummyMDNS):
|
|||
active_service.kill()
|
||||
|
||||
|
||||
class RecordingRandom:
|
||||
def __init__(self, record_file):
|
||||
self.record_file = record_file
|
||||
|
||||
def urandom(self, nbytes):
|
||||
data = os.urandom(nbytes)
|
||||
entry = (
|
||||
"urandom",
|
||||
time.monotonic_ns(),
|
||||
nbytes,
|
||||
binascii.b2a_base64(data, newline=False).decode("utf-8"),
|
||||
)
|
||||
json.dump(entry, self.record_file)
|
||||
self.record_file.write("\n")
|
||||
return data
|
||||
|
||||
def randbelow(self, n):
|
||||
value = secrets.randbelow(n)
|
||||
entry = ("randbelow", time.monotonic_ns(), n, value)
|
||||
json.dump(entry, self.record_file)
|
||||
self.record_file.write("\n")
|
||||
return value
|
||||
|
||||
|
||||
class RecordingSocket:
|
||||
def __init__(self, record_file, socket):
|
||||
self.record_file = record_file
|
||||
self.socket = socket
|
||||
|
||||
def bind(self, address):
|
||||
self.socket.bind(address)
|
||||
|
||||
def setblocking(self, value):
|
||||
self.socket.setblocking(value)
|
||||
|
||||
def recvfrom_into(self, buffer, nbytes=None):
|
||||
nbytes, addr = self.socket.recvfrom_into(buffer, nbytes)
|
||||
entry = (
|
||||
"receive",
|
||||
time.monotonic_ns(),
|
||||
addr,
|
||||
binascii.b2a_base64(buffer[:nbytes], newline=False).decode("utf-8"),
|
||||
)
|
||||
json.dump(entry, self.record_file)
|
||||
self.record_file.write("\n")
|
||||
return nbytes, addr
|
||||
|
||||
def sendto(self, data, address):
|
||||
entry = (
|
||||
"send",
|
||||
time.monotonic_ns(),
|
||||
address,
|
||||
binascii.b2a_base64(data, newline=False).decode("utf-8"),
|
||||
)
|
||||
json.dump(entry, self.record_file)
|
||||
self.record_file.write("\n")
|
||||
return self.socket.sendto(data, address)
|
||||
|
||||
|
||||
class RecordingSocketPool:
|
||||
AF_INET6 = socket.AF_INET6
|
||||
SOCK_DGRAM = socket.SOCK_DGRAM
|
||||
|
||||
def __init__(self, record_file):
|
||||
self.record_file = record_file
|
||||
self._socket_created = False
|
||||
|
||||
def socket(self, *args, **kwargs):
|
||||
if self._socket_created:
|
||||
raise RuntimeError("Only one socket can be created")
|
||||
self._socket_created = True
|
||||
return RecordingSocket(self.record_file, socket.socket(*args, **kwargs))
|
||||
|
||||
|
||||
def run(replay_file=None):
|
||||
if replay_file:
|
||||
socketpool = ReplaySocketPool(replay_file)
|
||||
replay_lines = []
|
||||
with open(replay_file, "r") as f:
|
||||
for line in f:
|
||||
replay_lines.append(json.loads(line))
|
||||
socketpool = ReplaySocketPool(replay_lines)
|
||||
mdns_server = DummyMDNS()
|
||||
record_file = None
|
||||
random_source = ReplayRandom(replay_lines)
|
||||
else:
|
||||
socketpool = socket
|
||||
record_file = open("test_data/recorded_packets.jsonl", "w")
|
||||
socketpool = RecordingSocketPool(record_file)
|
||||
mdns_server = MDNSServer()
|
||||
record_file = "test_data/recorded_packets.json"
|
||||
random_source = RecordingRandom(record_file)
|
||||
matter = cm.CircuitMatter(
|
||||
socketpool, mdns_server, "test_data/device_state.json", record_file
|
||||
socketpool, mdns_server, random_source, "test_data/device_state.json"
|
||||
)
|
||||
while True:
|
||||
matter.process_packets()
|
||||
|
|
|
|||
|
|
@ -37,6 +37,23 @@ class ElementType(enum.IntEnum):
|
|||
END_OF_CONTAINER = 0b11000
|
||||
|
||||
|
||||
def find_container_end(buffer, start):
|
||||
nesting = 0
|
||||
end = start
|
||||
while buffer[end] != ElementType.END_OF_CONTAINER or nesting > 0:
|
||||
octet = buffer[end]
|
||||
if octet == ElementType.END_OF_CONTAINER:
|
||||
nesting -= 1
|
||||
elif (octet & 0x1F) in (
|
||||
ElementType.STRUCTURE,
|
||||
ElementType.ARRAY,
|
||||
ElementType.LIST,
|
||||
):
|
||||
nesting += 1
|
||||
end += 1
|
||||
return end + 1
|
||||
|
||||
|
||||
class TLVStructure:
|
||||
_max_length = None
|
||||
|
||||
|
|
@ -59,6 +76,8 @@ class TLVStructure:
|
|||
members = []
|
||||
for field, descriptor_class in self._members():
|
||||
value = descriptor_class.print(self)
|
||||
if not value:
|
||||
continue
|
||||
if isinstance(descriptor_class, StructMember):
|
||||
value = value.replace("\n", "\n ")
|
||||
members.append(f"{field} = {value}")
|
||||
|
|
@ -144,23 +163,8 @@ class TLVStructure:
|
|||
self.null_tags.add(this_tag)
|
||||
else: # Container
|
||||
value_offset = length_offset
|
||||
value_length = 0
|
||||
nesting = 0
|
||||
while (
|
||||
self.buffer[value_offset + value_length]
|
||||
!= ElementType.END_OF_CONTAINER
|
||||
or nesting > 0
|
||||
):
|
||||
octet = self.buffer[value_offset + value_length]
|
||||
if octet == ElementType.END_OF_CONTAINER:
|
||||
nesting -= 1
|
||||
elif (octet & 0x1F) in (
|
||||
ElementType.STRUCTURE,
|
||||
ElementType.ARRAY,
|
||||
ElementType.LIST,
|
||||
):
|
||||
nesting += 1
|
||||
value_length += 1
|
||||
end = find_container_end(self.buffer, value_offset)
|
||||
value_length = end - value_offset - 1
|
||||
|
||||
self.tag_value_offset[this_tag] = value_offset
|
||||
self.tag_value_length[this_tag] = value_length
|
||||
|
|
@ -172,6 +176,14 @@ class TLVStructure:
|
|||
else:
|
||||
self._offset = value_offset + value_length
|
||||
|
||||
if element_type in (
|
||||
ElementType.STRUCTURE,
|
||||
ElementType.ARRAY,
|
||||
ElementType.LIST,
|
||||
):
|
||||
# One more for the trailing 0x18
|
||||
self._offset += 1
|
||||
|
||||
if tag == this_tag:
|
||||
break
|
||||
|
||||
|
|
@ -304,9 +316,11 @@ class Member(ABC, Generic[_T, _OPT, _NULLABLE]):
|
|||
return new_offset
|
||||
return offset
|
||||
|
||||
def print(self, obj: TLVStructure) -> str:
|
||||
def print(self, obj: TLVStructure) -> Optional[str]:
|
||||
value = self.__get__(obj) # type: ignore # self inference issues
|
||||
if value is None:
|
||||
if self.optional:
|
||||
return None
|
||||
return "null"
|
||||
return self._print(value)
|
||||
|
||||
|
|
@ -587,15 +601,99 @@ class ArrayMember(Member[_TLVStruct, _OPT, _NULLABLE]):
|
|||
super().__init__(tag, optional=optional, nullable=nullable, **kwargs)
|
||||
|
||||
def decode(self, buffer, length, offset=0):
|
||||
return self.substruct_class(buffer[offset : offset + length])
|
||||
entries = []
|
||||
if isinstance(self.substruct_class, List):
|
||||
i = 0
|
||||
while i < length:
|
||||
if buffer[offset + i] != ElementType.LIST:
|
||||
raise RuntimeError("Expected list start")
|
||||
start = offset + i
|
||||
end = start + 1
|
||||
while buffer[end] != ElementType.END_OF_CONTAINER:
|
||||
end += 1
|
||||
entries.append(self.substruct_class(buffer[start + 1 : end]))
|
||||
|
||||
i = (end + 1) - offset
|
||||
return entries
|
||||
|
||||
def _print(self, value):
|
||||
return str(value)
|
||||
s = ["[["]
|
||||
items = []
|
||||
for v in value:
|
||||
items.append(str(v))
|
||||
s.append(", ".join(items))
|
||||
s.append("]]")
|
||||
return "".join(s)
|
||||
|
||||
def encode_element_type(self, value):
|
||||
return ElementType.STRUCTURE
|
||||
return ElementType.ARRAY
|
||||
|
||||
def encode_value_into(self, value, buffer: bytearray, offset: int) -> int:
|
||||
offset = value.encode_into(buffer, offset)
|
||||
for v in value:
|
||||
offset = v.encode_into(buffer, offset)
|
||||
buffer[offset] = ElementType.END_OF_CONTAINER
|
||||
return offset + 1
|
||||
|
||||
|
||||
class ListIterator:
|
||||
def __init__(self, tlv_list: List):
|
||||
self.list = tlv_list
|
||||
self._offset = 0
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self._offset >= len(self.list.buffer):
|
||||
raise StopIteration
|
||||
|
||||
next_item = self.list.substruct_class(self.list.buffer)
|
||||
self._offset = len(self.list.buffer)
|
||||
return next_item
|
||||
|
||||
|
||||
class List:
|
||||
def __init__(self, substruct_class: Type[_TLVStruct], buffer=None):
|
||||
self.buffer = buffer
|
||||
self.substruct_class = substruct_class
|
||||
|
||||
def __call__(self, buffer):
|
||||
return List(self.substruct_class, buffer)
|
||||
|
||||
def _print_struct_members(self, struct):
|
||||
members = []
|
||||
for field, descriptor_class in struct._members():
|
||||
value = descriptor_class.print(struct)
|
||||
if not value:
|
||||
continue
|
||||
if isinstance(descriptor_class, StructMember):
|
||||
value = value.replace("\n ", " ")
|
||||
members.append(f"{field} = {value}")
|
||||
return ", ".join(members)
|
||||
|
||||
def __str__(self):
|
||||
items = []
|
||||
for v in self:
|
||||
items.append(self._print_struct_members(v))
|
||||
return "[[" + ", ".join(items) + "]]"
|
||||
|
||||
def __iter__(self):
|
||||
return ListIterator(self)
|
||||
|
||||
|
||||
class AnythingMember(Member):
|
||||
def __init__(self, tag):
|
||||
self.element_type = ElementType.NULL
|
||||
super().__init__(tag, optional=False, nullable=True)
|
||||
|
||||
def decode(self, buffer, length, offset=0):
|
||||
return None
|
||||
|
||||
def _print(self, value):
|
||||
return "???"
|
||||
|
||||
def encode_element_type(self, value):
|
||||
return self.element_type
|
||||
|
||||
def encode_value_into(self, value, buffer: bytearray, offset: int) -> int:
|
||||
return offset
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
[["receive", 22053600090440, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 44432, 0, 0], "BAAAALGoxwBQJ8u6A08zlwUgUU0AABUwASAsETJy1MZI35zvWSjBtjdwFK1FwzEpKJCepZW/hKLqLCUCzAUkAwAoBDUFJQH0ASUCLAElA6APJAQRJAULJgYAAAMBJAcBGBg="], ["receive", 22053606380662, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 44432, 0, 0], "BAAAALKoxwBQJ8u6A08zlwUiUU0AABUwAUEELPUQAb7V2HvHRBNMyJiVQrGlHjw9FIz+41h5wxnvYuZhg2wGT+hqfyP7RVOH4QcdAAFlIAdemfr8c3bRVKoIkRg="], ["receive", 22053621334402, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 44432, 0, 0], "BAAAALOoxwBQJ8u6A08zlwUkUU0AABUwASDjA9fk82ToT3rKrxDQMOK+7Pp1APSodFRzvO6UuIJzOhg="], ["receive", 22053621532375, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 44432, 0, 0], "AAEAAALzMgF457821uGoHhEhDzUhrzWXXEH8HBRfHNlrizMRs0CHkhx7+odFx/bY8o8JFB1FPlM8SFCCdzcEVtwm03ncTQIYw/SYTFXpkkY/wJNjEndAfBVm1eJ1wmE6QkWVlAOsXg4/Ix4aCOR8PE6xx2E17/pSaHstrRh4rw+V73o="], ["receive", 22054010148564, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 44432, 0, 0], "AAEAAALzMgF457821uGoHhEhDzUhrzWXXEH8HBRfHNlrizMRs0CHkhx7+odFx/bY8o8JFB1FPlM8SFCCdzcEVtwm03ncTQIYw/SYTFXpkkY/wJNjEndAfBVm1eJ1wmE6QkWVlAOsXg4/Ix4aCOR8PE6xx2E17/pSaHstrRh4rw+V73o="], ["receive", 22054413609739, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 44432, 0, 0], "AAEAAALzMgF457821uGoHhEhDzUhrzWXXEH8HBRfHNlrizMRs0CHkhx7+odFx/bY8o8JFB1FPlM8SFCCdzcEVtwm03ncTQIYw/SYTFXpkkY/wJNjEndAfBVm1eJ1wmE6QkWVlAOsXg4/Ix4aCOR8PE6xx2E17/pSaHstrRh4rw+V73o="], ["receive", 22054982238204, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 44432, 0, 0], "AAEAAALzMgF457821uGoHhEhDzUhrzWXXEH8HBRfHNlrizMRs0CHkhx7+odFx/bY8o8JFB1FPlM8SFCCdzcEVtwm03ncTQIYw/SYTFXpkkY/wJNjEndAfBVm1eJ1wmE6QkWVlAOsXg4/Ix4aCOR8PE6xx2E17/pSaHstrRh4rw+V73o="]]
|
||||
14
test_data/recorded_packets.jsonl
Normal file
14
test_data/recorded_packets.jsonl
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
["urandom", 104579896047969, 8, "C1VTrQNfuy8="]
|
||||
["receive", 104582529638004, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 38269, 0, 0], "BAAAAIS10Q+tnzF8Kv2/NQUgCG8AABUwASANZtmRFTwd2GhsllMTm0UyMHBkypLyNl1B1LjMXQ//2CUCv0kkAwAoBDUFJQH0ASUCLAElA6APJAQRJAULJgYAAAMBJAcBGBg="]
|
||||
["urandom", 104582529881604, 32, "Sqkt6+E937dbYLMD1Vo6WDUzBCc3A1sYebc7y1i6D3s="]
|
||||
["send", 104582530009855, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 38269, 0, 0], "AQAAAHmvTQqtnzF8Kv2/NQIhCG8AAIS10Q8VMAEgDWbZkRU8HdhobJZTE5tFMjBwZMqS8jZdQdS4zF0P/9gwAiBKqS3r4T3ft1tgswPVWjpYNTMEJzcDWxh5tzvLWLoPeyUDAQA1BCYBECcAADACIObgj9CEx2MyPagRHuoX1OB32N8u1aKUpNKjb4b854YkGBg="]
|
||||
["receive", 104582536515794, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 38269, 0, 0], "BAAAAIW10Q+tnzF8Kv2/NQUiCG8AABUwAUEEKlVOokxuEfcuW87SxBheW741rIVFlbZXHo2OXy5MV7L0Kw1E5YP909h62AI4R3dQGBz7K5mdIaeKs6/krcdIoRg="]
|
||||
["randbelow", 104582536608839, 115792089210356248762697446949407573529996955224135760342422259061068512044369, 37290483390246755562588887411941405615320357838675862131986900152226936389112]
|
||||
["send", 104582552142683, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 38269, 0, 0], "AQAAAHqvTQqtnzF8Kv2/NQIjCG8AAIW10Q8VMAFBBBpE90nK9kI0BzmVcsJGT98Uzbs98RclP7gXpusB33iJBxP3+L8rLS7fNaG6Hh0UaSZWqfDJMIVLrhvlQu5+1IgwAiBc8cC7YaaljZBTgf9t3jIP3xjGPL8Lpvav+cGczItQ4Rg="]
|
||||
["receive", 104582552664477, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 38269, 0, 0], "BAAAAIa10Q+tnzF8Kv2/NQUkCG8AABUwASCI0PP7dtDbz85Kfg7nZAKG/MfwIoXLdEsRyJkT+U1gHBg="]
|
||||
["send", 104582552742905, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 38269, 0, 0], "AQAAAHuvTQqtnzF8Kv2/NQJACG8AAIa10Q8AAAAAAAAAAA=="]
|
||||
["receive", 104582552923846, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 38269, 0, 0], "AAEAAA6nxww9xhjZmrHkONd0G9BPNoY42+U1+0H81vB65JcnFdLSR9VAXsosYDSoOtxRelT2AKoXT+l82Tz8aL+eZZZgjVQPx9OsEtdFqXBoU3YaIhhK9lfhSlMlNfmn8nzlgFS9cuWYXsxaA3YI/HpjzJk/TQAEk9dASl1yI6mp94M="]
|
||||
["receive", 104582961443495, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 38269, 0, 0], "AAEAAA6nxww9xhjZmrHkONd0G9BPNoY42+U1+0H81vB65JcnFdLSR9VAXsosYDSoOtxRelT2AKoXT+l82Tz8aL+eZZZgjVQPx9OsEtdFqXBoU3YaIhhK9lfhSlMlNfmn8nzlgFS9cuWYXsxaA3YI/HpjzJk/TQAEk9dASl1yI6mp94M="]
|
||||
["receive", 104583317899582, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 38269, 0, 0], "AAEAAA6nxww9xhjZmrHkONd0G9BPNoY42+U1+0H81vB65JcnFdLSR9VAXsosYDSoOtxRelT2AKoXT+l82Tz8aL+eZZZgjVQPx9OsEtdFqXBoU3YaIhhK9lfhSlMlNfmn8nzlgFS9cuWYXsxaA3YI/HpjzJk/TQAEk9dASl1yI6mp94M="]
|
||||
["receive", 104583886314469, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 38269, 0, 0], "AAEAAA6nxww9xhjZmrHkONd0G9BPNoY42+U1+0H81vB65JcnFdLSR9VAXsosYDSoOtxRelT2AKoXT+l82Tz8aL+eZZZgjVQPx9OsEtdFqXBoU3YaIhhK9lfhSlMlNfmn8nzlgFS9cuWYXsxaA3YI/HpjzJk/TQAEk9dASl1yI6mp94M="]
|
||||
["receive", 104584805340349, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 38269, 0, 0], "AAEAAA6nxww9xhjZmrHkONd0G9BPNoY42+U1+0H81vB65JcnFdLSR9VAXsosYDSoOtxRelT2AKoXT+l82Tz8aL+eZZZgjVQPx9OsEtdFqXBoU3YaIhhK9lfhSlMlNfmn8nzlgFS9cuWYXsxaA3YI/HpjzJk/TQAEk9dASl1yI6mp94M="]
|
||||
Loading…
Reference in a new issue