diff --git a/circuitmatter/__init__.py b/circuitmatter/__init__.py index bcfeed1..25e848c 100644 --- a/circuitmatter/__init__.py +++ b/circuitmatter/__init__.py @@ -2,11 +2,11 @@ import binascii import hashlib -import json import time from . import case from . import interaction_model +from . import nonvolatile from .message import Message from .protocol import InteractionModelOpcode, ProtocolId, SecureProtocolOpcode from . import session @@ -29,15 +29,12 @@ class CircuitMatter: self.mdns_server = mdns_server self.random = random_source - with open(state_filename, "r") as state_file: - self.nonvolatile = json.load(state_file) + self.nonvolatile = nonvolatile.PersistentDictionary(state_filename) for key in ["discriminator", "salt", "iteration-count", "verifier"]: if key not in self.nonvolatile: raise RuntimeError(f"Missing key {key} in state file") - commission = "fabrics" not in self.nonvolatile - self.packet_buffer = memoryview(bytearray(1280)) # Define the UDP IP address and port @@ -51,6 +48,7 @@ class CircuitMatter: # Bind the socket to the IP and port self.socket.bind((UDP_IP, self.UDP_PORT)) + print(f"Listening on UDP port {self.UDP_PORT}") self.socket.setblocking(False) self._endpoints = {} @@ -62,14 +60,11 @@ class CircuitMatter: self.vendor_id = vendor_id self.product_id = product_id - self.manager = session.SessionManager( self.random, self.socket, self.root_node.noc ) - print(f"Listening on UDP port {self.UDP_PORT}") - - if commission: + if self.root_node.fabric_count == 0: self.start_commissioning() def start_commissioning(self): @@ -118,6 +113,12 @@ class CircuitMatter: device.descriptor.ServerList.append(server.CLUSTER_ID) self.add_cluster(self._next_endpoint, server) self.add_cluster(self._next_endpoint, device.descriptor) + + if "devices" not in self.nonvolatile: + self.nonvolatile["devices"] = {} + if device.name not in self.nonvolatile["devices"]: + self.nonvolatile["devices"][device.name] = {} + device.restore(self.nonvolatile["devices"][device.name]) self._next_endpoint += 1 def process_packets(self): @@ -249,9 +250,7 @@ class CircuitMatter: from . import pase # This is Section 4.14.1.2 - request, _ = pase.PBKDFParamRequest.decode( - message.application_payload[0], message.application_payload[1:] - ) + request = pase.PBKDFParamRequest.decode(message.application_payload) exchange.commissioning_hash = hashlib.sha256( b"CHIP PAKE V1 Commissioning" ) @@ -287,9 +286,7 @@ class CircuitMatter: from . import pase print("Received PASE PAKE1") - pake1, _ = pase.PAKE1.decode( - message.application_payload[0], message.application_payload[1:] - ) + pake1 = pase.PAKE1.decode(message.application_payload) pake2 = pase.PAKE2() verifier = binascii.a2b_base64(self.nonvolatile["verifier"]) context = exchange.commissioning_hash.digest() @@ -308,9 +305,7 @@ class CircuitMatter: from . import pase print("Received PASE PAKE3") - pake3, _ = pase.PAKE3.decode( - message.application_payload[0], message.application_payload[1:] - ) + pake3 = pase.PAKE3.decode(message.application_payload) if pake3.cA != exchange.cA: del exchange.cA del exchange.Ke @@ -341,9 +336,7 @@ class CircuitMatter: print("PASE succeeded") elif protocol_opcode == SecureProtocolOpcode.CASE_SIGMA1: print("Received CASE Sigma1") - sigma1, _ = case.Sigma1.decode( - message.application_payload[0], message.application_payload[1:] - ) + sigma1 = case.Sigma1.decode(message.application_payload) response = self.manager.reply_to_sigma1(exchange, sigma1) exchange.send(response) @@ -351,9 +344,7 @@ class CircuitMatter: print("Received CASE Sigma2") elif protocol_opcode == SecureProtocolOpcode.CASE_SIGMA3: print("Received CASE Sigma3") - sigma3, _ = case.Sigma3.decode( - message.application_payload[0], message.application_payload[1:] - ) + sigma3 = case.Sigma3.decode(message.application_payload) protocol_code = self.manager.reply_to_sigma3(exchange, sigma3) error_status = session.StatusReport() @@ -390,8 +381,8 @@ class CircuitMatter: message.session_id ] if protocol_opcode == InteractionModelOpcode.READ_REQUEST: - read_request, _ = interaction_model.ReadRequestMessage.decode( - message.application_payload[0], message.application_payload[1:] + read_request = interaction_model.ReadRequestMessage.decode( + message.application_payload ) attribute_reports = [] for path in read_request.AttributeRequests: @@ -404,8 +395,8 @@ class CircuitMatter: exchange.send(response) elif protocol_opcode == InteractionModelOpcode.WRITE_REQUEST: print("Received Write Request") - write_request, _ = interaction_model.WriteRequestMessage.decode( - message.application_payload[0], message.application_payload[1:] + write_request = interaction_model.WriteRequestMessage.decode( + message.application_payload ) write_responses = [] for request in write_request.WriteRequests: @@ -421,8 +412,8 @@ class CircuitMatter: elif protocol_opcode == InteractionModelOpcode.INVOKE_REQUEST: print("Received Invoke Request") - invoke_request, _ = interaction_model.InvokeRequestMessage.decode( - message.application_payload[0], message.application_payload[1:] + invoke_request = interaction_model.InvokeRequestMessage.decode( + message.application_payload ) for invoke in invoke_request.InvokeRequests: path = invoke.CommandPath @@ -460,14 +451,13 @@ class CircuitMatter: response = interaction_model.InvokeResponseMessage() response.SuppressResponse = False response.InvokeResponses = invoke_responses - print("sending invoke response", response) exchange.send(response) elif protocol_opcode == InteractionModelOpcode.INVOKE_RESPONSE: print("Received Invoke Response") elif protocol_opcode == InteractionModelOpcode.SUBSCRIBE_REQUEST: print("Received Subscribe Request") - subscribe_request, _ = interaction_model.SubscribeRequestMessage.decode( - message.application_payload[0], message.application_payload[1:] + subscribe_request = interaction_model.SubscribeRequestMessage.decode( + message.application_payload ) print(subscribe_request) attribute_reports = [] @@ -484,8 +474,8 @@ class CircuitMatter: final_response.MaxInterval = subscribe_request.MaxIntervalCeiling exchange.queue(final_response) elif protocol_opcode == InteractionModelOpcode.STATUS_RESPONSE: - status_response, _ = interaction_model.StatusResponseMessage.decode( - message.application_payload[0], message.application_payload[1:] + status_response = interaction_model.StatusResponseMessage.decode( + message.application_payload ) print( f"Received Status Response on {message.session_id}/{message.exchange_id} ack {message.acknowledged_message_counter}: {status_response.Status!r}" @@ -502,3 +492,6 @@ class CircuitMatter: else: print("Unknown protocol", message.protocol_id, message.protocol_opcode) print() + + self.nonvolatile.commit() + # TODO: Rollback on error? diff --git a/circuitmatter/__main__.py b/circuitmatter/__main__.py index f246e11..5ae0674 100644 --- a/circuitmatter/__main__.py +++ b/circuitmatter/__main__.py @@ -3,6 +3,7 @@ import binascii import json import os +import pathlib import secrets import socket import subprocess @@ -114,13 +115,8 @@ class MDNSServer(DummyMDNS): subtypes=[], instance_name="", ): - for active_service in self.active_services.values(): - active_service.kill() subtypes = [f"--subtype={subtype}" for subtype in subtypes] txt_records = [f"{key}={value}" for key, value in txt_records.items()] - if service_type in self.active_services: - self.active_services[service_type].kill() - del self.active_services[service_type] command = [ "avahi-publish-service", *subtypes, @@ -130,7 +126,7 @@ class MDNSServer(DummyMDNS): *txt_records, ] print("running avahi", command) - self.active_services[service_type] = subprocess.Popen(command) + self.active_services[service_type + instance_name] = subprocess.Popen(command) if self.publish_address is None: command = [ "avahi-publish-address", @@ -226,23 +222,33 @@ class NeoPixel(on_off.OnOffLight): def run(replay_file=None): + device_state = pathlib.Path("test_data/device_state.json") + replay_device_state = pathlib.Path("test_data/replay_device_state.json") if replay_file: replay_lines = [] with open(replay_file, "r") as f: + device_state_fn = f.readline().strip() for line in f: replay_lines.append(json.loads(line)) socketpool = ReplaySocketPool(replay_lines) mdns_server = DummyMDNS() random_source = ReplayRandom(replay_lines) + # Reset device state to before the captured run + device_state.write_text(pathlib.Path(device_state_fn).read_text()) else: - record_file = open("test_data/recorded_packets.jsonl", "w") + timestamp = time.strftime("%Y%m%d-%H%M%S") + record_file = open(f"test_data/recorded_packets-{timestamp}.jsonl", "w") + device_state_fn = f"test_data/device_state-{timestamp}.json" + record_file.write(f"{device_state_fn}\n") socketpool = RecordingSocketPool(record_file) mdns_server = MDNSServer() random_source = RecordingRandom(record_file) - matter = cm.CircuitMatter( - socketpool, mdns_server, random_source, "test_data/device_state.json" - ) - led = NeoPixel() + # Save device state before we run so replays can use it. + replay_device_state = pathlib.Path(device_state_fn) + replay_device_state.write_text(device_state.read_text()) + + matter = cm.CircuitMatter(socketpool, mdns_server, random_source, device_state) + led = NeoPixel("neopixel1") matter.add_device(led) while True: matter.process_packets() diff --git a/circuitmatter/clusters/device_management/group_key_management.py b/circuitmatter/clusters/device_management/group_key_management.py index af6c8d8..473d819 100644 --- a/circuitmatter/clusters/device_management/group_key_management.py +++ b/circuitmatter/clusters/device_management/group_key_management.py @@ -24,13 +24,15 @@ class GroupKeyMulticastPolicyEnum(Enum8): class GroupKeySetStruct(tlv.Structure): GroupKeySetID = tlv.IntMember(0, signed=False, octets=2) GroupKeySecurityPolicy = tlv.EnumMember(1, GroupKeySetSecurityPolicyEnum) - EpochKey0 = tlv.OctetStringMember(2, 16) - EpochStartTime0 = tlv.IntMember(3, signed=False, octets=8) - EpochKey1 = tlv.OctetStringMember(4, 16) - EpochStartTime1 = tlv.IntMember(5, signed=False, octets=8) - EpochKey2 = tlv.OctetStringMember(6, 16) - EpochStartTime2 = tlv.IntMember(7, signed=False, octets=8) - GroupKeyMulticastPolicy = tlv.EnumMember(8, GroupKeyMulticastPolicyEnum) + EpochKey0 = tlv.OctetStringMember(2, 16, nullable=True) + EpochStartTime0 = tlv.IntMember(3, signed=False, octets=8, nullable=True) + EpochKey1 = tlv.OctetStringMember(4, 16, nullable=True) + EpochStartTime1 = tlv.IntMember(5, signed=False, octets=8, nullable=True) + EpochKey2 = tlv.OctetStringMember(6, 16, nullable=True) + EpochStartTime2 = tlv.IntMember(7, signed=False, octets=8, nullable=True) + GroupKeyMulticastPolicy = tlv.EnumMember( + 8, GroupKeyMulticastPolicyEnum, nullable=True + ) class GroupKeyManagementCluster(Cluster): @@ -48,7 +50,7 @@ class GroupKeyManagementCluster(Cluster): class KeySetWrite(tlv.Structure): GroupKeySet = tlv.StructMember(0, GroupKeySetStruct) - group_key_map = ListAttribute(0, GroupKeyMapStruct, default=[]) + group_key_map = ListAttribute(0, GroupKeyMapStruct, default=[], N_nonvolatile=True) group_table = ListAttribute(1, GroupInfoMapStruct, default=[]) max_groups_per_fabric = NumberAttribute(2, signed=False, bits=16, default=0) max_group_keys_per_fabric = NumberAttribute(3, signed=False, bits=16, default=1) diff --git a/circuitmatter/clusters/device_management/node_operational_credentials.py b/circuitmatter/clusters/device_management/node_operational_credentials.py index 97c480d..f32816d 100644 --- a/circuitmatter/clusters/device_management/node_operational_credentials.py +++ b/circuitmatter/clusters/device_management/node_operational_credentials.py @@ -99,12 +99,20 @@ class NodeOperationalCredentialsCluster(Cluster): class AddTrustedRootCertificate(tlv.Structure): RootCACertificate = tlv.OctetStringMember(0, 400) - nocs = ListAttribute(0, NOCStruct, N_nonvolatile=True, C_changes_omitted=True) - fabrics = ListAttribute(1, FabricDescriptorStruct, N_nonvolatile=True) + nocs = ListAttribute( + 0, NOCStruct, N_nonvolatile=True, C_changes_omitted=True, default=[] + ) + fabrics = ListAttribute(1, FabricDescriptorStruct, N_nonvolatile=True, default=[]) supported_fabrics = NumberAttribute(2, signed=False, bits=8, F_fixed=True) - commissioned_fabrics = NumberAttribute(3, signed=False, bits=8, N_nonvolatile=True) + commissioned_fabrics = NumberAttribute( + 3, signed=False, bits=8, N_nonvolatile=True, default=0 + ) trusted_root_certificates = ListAttribute( - 4, tlv.OctetStringMember(None, 400), N_nonvolatile=True, C_changes_omitted=True + 4, + tlv.OctetStringMember(None, 400), + N_nonvolatile=True, + C_changes_omitted=True, + default=[], ) # This attribute is weird because it is fabric sensitive but not marked as such. # Cluster sets current_fabric_index for use in fabric sensitive attributes and diff --git a/circuitmatter/data_model.py b/circuitmatter/data_model.py index e36b639..ad9ac3e 100644 --- a/circuitmatter/data_model.py +++ b/circuitmatter/data_model.py @@ -1,3 +1,4 @@ +import binascii import enum import inspect import random @@ -9,6 +10,8 @@ from typing import Iterable, Union from . import interaction_model from . import tlv +ATTRIBUTES_KEY = "a" + class Enum8(enum.IntEnum): pass @@ -80,6 +83,7 @@ class Attribute: self.optional = optional self.feature = feature self.nullable = X_nullable + self.nonvolatile = N_nonvolatile def __get__(self, instance, cls): v = instance._attribute_values.get(self.id, None) @@ -92,8 +96,16 @@ class Attribute: if old_value == value: return instance._attribute_values[self.id] = value + if self.nonvolatile: + instance._nonvolatile[ATTRIBUTES_KEY][hex(self.id)] = self.to_json(value) instance.data_version += 1 + def to_json(self, value): + return value + + def from_json(self, value): + return value + def encode(self, value) -> bytes: if value is None and self.nullable: return b"\x14" # No tag, NULL @@ -145,6 +157,35 @@ class EnumAttribute(NumberAttribute): super().__init__(_id, signed=False, bits=bits, **kwargs) +class _PersistentList: + def __init__(self, wrapped_list, attribute, instance): + self._list = wrapped_list + self._instance = instance + self._attribute = attribute + + def append(self, value): + self._list.append(value) + self._instance._nonvolatile[ATTRIBUTES_KEY][hex(self._attribute.id)] = ( + self._attribute.to_json(self._list) + ) + + def __getitem__(self, index): + return self._list[index] + + def __setitem__(self, index, value): + self._list[index] = value + self._dirty = True + + def __iter__(self): + return iter(self._list) + + def __len__(self): + return len(self._list) + + def __str__(self): + return "persistent" + str(self._list) + + class ListAttribute(Attribute): def __init__(self, _id, element_type, **kwargs): if inspect.isclass(element_type) and issubclass(element_type, enum.Enum): @@ -157,6 +198,28 @@ class ListAttribute(Attribute): kwargs["default"] = list(kwargs["default"]) super().__init__(_id, **kwargs) + def __get__(self, instance, cls): + v = super().__get__(instance, cls) + if self.nonvolatile and v is not None and not isinstance(v, _PersistentList): + # Wrap the list in an object that tracks changes and writes them to nonvolatile. + p = _PersistentList(v, self, instance) + instance._attribute_values[self.id] = p + return p + return v + + def to_json(self, value): + return [ + binascii.b2a_base64(self._element_type.encode(v), newline=False).decode( + "utf-8" + ) + for v in value + ] + + def from_json(self, value): + return [ + self._element_type.decode(memoryview(binascii.a2b_base64(v))) for v in value + ] + def _encode(self, value) -> bytes: return self.tlv_type.encode(value) @@ -243,6 +306,23 @@ class Cluster: if not field_name.startswith("_") and isinstance(descriptor, Attribute): yield field_name, descriptor + def restore(self, nonvolatile): + self._nonvolatile = nonvolatile + + if ATTRIBUTES_KEY not in nonvolatile: + nonvolatile[ATTRIBUTES_KEY] = {} + for field_name, descriptor in self._attributes(): + if descriptor.nonvolatile: + print(field_name, nonvolatile[ATTRIBUTES_KEY]) + if hex(descriptor.id) in nonvolatile[ATTRIBUTES_KEY]: + # Update our live value + self._attribute_values[descriptor.id] = descriptor.from_json( + nonvolatile[ATTRIBUTES_KEY][hex(descriptor.id)] + ) + else: + # Store the default + nonvolatile[ATTRIBUTES_KEY][hex(descriptor.id)] = descriptor.default + def get_attribute_data( self, session, path ) -> typing.List[interaction_model.AttributeDataIB]: diff --git a/circuitmatter/device_types/lighting/on_off.py b/circuitmatter/device_types/lighting/on_off.py index 6769936..0f31370 100644 --- a/circuitmatter/device_types/lighting/on_off.py +++ b/circuitmatter/device_types/lighting/on_off.py @@ -8,8 +8,8 @@ class OnOffLight(simple_device.SimpleDevice): DEVICE_TYPE_ID = 0x0100 REVISION = 3 - def __init__(self): - super().__init__() + def __init__(self, name): + super().__init__(name) self._identify = Identify() self.servers.append(self._identify) diff --git a/circuitmatter/device_types/simple_device.py b/circuitmatter/device_types/simple_device.py index 05fb590..f710007 100644 --- a/circuitmatter/device_types/simple_device.py +++ b/circuitmatter/device_types/simple_device.py @@ -2,7 +2,8 @@ from circuitmatter.clusters.system_model import binding, descriptor, user_label class SimpleDevice: - def __init__(self): + def __init__(self, name): + self.name = name self.servers = [] self.descriptor = descriptor.DescriptorCluster() device_type = descriptor.DescriptorCluster.DeviceTypeStruct() @@ -19,3 +20,12 @@ class SimpleDevice: self.user_label = user_label.UserLabelCluster() self.servers.append(self.user_label) + + def restore(self, nonvolatile): + """Restore device state from the nonvolatile dictionary and hang onto it for any updates.""" + self.nonvolatile = nonvolatile + for server in self.servers: + cluster_hex = hex(server.CLUSTER_ID) + if cluster_hex not in nonvolatile: + nonvolatile[cluster_hex] = {} + server.restore(nonvolatile[cluster_hex]) diff --git a/circuitmatter/device_types/utility/root_node.py b/circuitmatter/device_types/utility/root_node.py index fbe5ca2..6e6af29 100644 --- a/circuitmatter/device_types/utility/root_node.py +++ b/circuitmatter/device_types/utility/root_node.py @@ -1,3 +1,4 @@ +import binascii import ecdsa from ecdsa import der import hashlib @@ -118,20 +119,58 @@ class _NodeOperationalCredentialsCluster(NodeOperationalCredentialsCluster): self.pending_root_cert = None self.pending_signing_key = None - self.nocs = [] - self.fabrics = [] self.supported_fabrics = 10 - self.commissioned_fabrics = 0 - self.trusted_root_certificates = [] self.root_certs = [] self.compressed_fabric_ids = [] self.noc_keys = [] + self.encoded_noc_keys = [] self.mdns_server = mdns_server self.port = port self.random = random_source + def restore(self, nonvolatile): + super().restore(nonvolatile) + + if "pk" not in nonvolatile: + return + + self.root_certs = [] + self.compressed_fabric_ids = [] + self.noc_keys = [] + self.encoded_noc_keys = nonvolatile["pk"] + + for i, encoded_root_cert in enumerate(self.trusted_root_certificates): + root_cert = crypto.MatterCertificate.decode(encoded_root_cert) + + self.root_certs.append(root_cert) + fabric = self.fabrics[i] + fabric_id = struct.pack(">Q", fabric.FabricID) + compressed_fabric_id = crypto.KDF( + root_cert.ec_pub_key[1:], fabric_id, b"CompressedFabric", 64 + ) + self.compressed_fabric_ids.append(compressed_fabric_id) + signing_key = ecdsa.keys.SigningKey.from_string( + binascii.a2b_base64(self.encoded_noc_keys[i]), + curve=ecdsa.NIST256p, + hashfunc=hashlib.sha256, + ) + self.noc_keys.append(signing_key) + + node_id = struct.pack(">Q", fabric.NodeID).hex().upper() + compressed_fabric_id = compressed_fabric_id.hex().upper() + instance_name = f"{compressed_fabric_id}-{node_id}" + self.mdns_server.advertise_service( + "_matter", + "_tcp", + self.port, + instance_name=instance_name, + subtypes=[ + f"_I{compressed_fabric_id}._sub._matter._tcp", + ], + ) + def certificate_chain_request( self, session, @@ -260,13 +299,7 @@ class _NodeOperationalCredentialsCluster(NodeOperationalCredentialsCluster): self, session, args: NodeOperationalCredentialsCluster.AddNOC ) -> NodeOperationalCredentialsCluster.NOCResponse: # Section 11.18.6.8 - noc, _ = crypto.MatterCertificate.decode( - args.NOCValue[0], memoryview(args.NOCValue)[1:] - ) - if args.ICACValue: - icac, _ = crypto.MatterCertificate.decode( - args.ICACValue[0], memoryview(args.ICACValue)[1:] - ) + noc = crypto.MatterCertificate.decode(args.NOCValue) response = NodeOperationalCredentialsCluster.NOCResponse() @@ -291,9 +324,7 @@ class _NodeOperationalCredentialsCluster(NodeOperationalCredentialsCluster): self.nocs.append(noc_struct) # Get the root cert public key so we can create the compressed fabric id. - root_cert, _ = crypto.MatterCertificate.decode( - self.pending_root_cert[0], memoryview(self.pending_root_cert)[1:] - ) + root_cert = crypto.MatterCertificate.decode(self.pending_root_cert) # Store the fabric new_fabric = NodeOperationalCredentialsCluster.FabricDescriptorStruct() @@ -317,6 +348,12 @@ class _NodeOperationalCredentialsCluster(NodeOperationalCredentialsCluster): self.commissioned_fabrics += 1 self.noc_keys.append(self.pending_signing_key) + self.encoded_noc_keys.append( + binascii.b2a_base64( + self.pending_signing_key.to_string(), newline=False + ).decode("utf-8") + ) + self._nonvolatile["pk"] = self.encoded_noc_keys self.trusted_root_certificates.append(self.pending_root_cert) @@ -365,11 +402,29 @@ class _GroupKeyManagementCluster(GroupKeyManagementCluster): def __init__(self): super().__init__() self.key_sets = [] + self._encoded_key_sets = [] + + def restore(self, nonvolatile): + super().restore(nonvolatile) + + if "gks" not in nonvolatile: + return + + self._encoded_key_sets = nonvolatile["gks"] + self.key_sets = [ + GroupKeySetStruct.decode(binascii.a2b_base64(v)) for v in nonvolatile["gks"] + ] def key_set_write( self, session, args: GroupKeyManagementCluster.KeySetWrite ) -> interaction_model.StatusCode: self.key_sets.append(args.GroupKeySet) + self._encoded_key_sets.append( + binascii.b2a_base64(args.GroupKeySet.encode(), newline=False).decode( + "utf-8" + ) + ) + self._nonvolatile["gks"] = self._encoded_key_sets return interaction_model.StatusCode.SUCCESS @@ -378,7 +433,7 @@ class RootNode(simple_device.SimpleDevice): REVISION = 2 def __init__(self, random_source, mdns_server, port, vendor_id, product_id): - super().__init__() + super().__init__("root") basic_info = BasicInformationCluster() basic_info.vendor_id = vendor_id @@ -421,3 +476,7 @@ class RootNode(simple_device.SimpleDevice): self.user_label = user_label.UserLabelCluster() self.servers.append(self.user_label) + + @property + def fabric_count(self): + return self.noc.commissioned_fabrics diff --git a/circuitmatter/nonvolatile.py b/circuitmatter/nonvolatile.py new file mode 100644 index 0000000..52b93c9 --- /dev/null +++ b/circuitmatter/nonvolatile.py @@ -0,0 +1,69 @@ +import json + + +class PersistentDictionary: + """This acts like a dictionary and is persisted when values change.""" + + def __init__(self, filename=None, root=None, state=None): + self.filename = filename + self.root = root + self.dirty = False + self.persisted = {} + self._state: dict + if self.root is None and filename: + self.rollback() + elif state is not None: + self._state = state + else: + raise ValueError("Provide filename or (root and state)") + + def wrap(self, value): + return value + + def __setitem__(self, key, value): + self._state[key] = value + if self.root: + self.root.dirty = True + else: + self.dirty = True + + def __getitem__(self, key): + value = self._state[key] + if isinstance(value, dict): + if key not in self.persisted: + root = self.root if self.root else self + self.persisted[key] = PersistentDictionary(root=root, state=value) + return self.persisted[key] + return value + + def __delitem__(self, key): + del self._state[key] + if self.root: + self.root.dirty = True + else: + self.dirty = True + + def keys(self): + return self._state.keys() + + def __iter__(self): + return iter(self._state) + + def commit(self): + if not self.dirty: + print("not dirty") + return + if self.root: + print("root commit") + self.root.commit() + return + print("commit") + print(self._state) + with open(self.filename, "w") as state_file: + json.dump(self._state, state_file, indent=1) + self.dirty = False + + def rollback(self): + print("rollback") + with open(self.filename, "r") as state_file: + self._state = json.load(state_file) diff --git a/circuitmatter/session.py b/circuitmatter/session.py index 7751e31..0e05adb 100644 --- a/circuitmatter/session.py +++ b/circuitmatter/session.py @@ -629,7 +629,7 @@ class SessionManager: decrypted = s3k_cipher.decrypt(b"NCASE_Sigma3N", sigma3.encrypted3, b"") except cryptography.exceptions.InvalidTag: return SecureChannelProtocolCode.INVALID_PARAMETER - sigma3_tbe, _ = case.Sigma3TbeData.decode(decrypted[0], decrypted[1:]) + sigma3_tbe = case.Sigma3TbeData.decode(decrypted) # TODO: Implement checks 4a-4d. INVALID_PARAMETER if they fail. @@ -640,9 +640,7 @@ class SessionManager: secure_session_context = exchange.secure_session_context peer_noc = sigma3_tbe.initiatorNOC - peer_noc, _ = crypto.MatterCertificate.decode( - peer_noc[0], memoryview(peer_noc)[1:] - ) + peer_noc = crypto.MatterCertificate.decode(peer_noc) secure_session_context.peer_node_id = peer_noc.subject.matter_node_id exchange.transcript_hash.update(sigma3.encode()) diff --git a/circuitmatter/tlv.py b/circuitmatter/tlv.py index 1471168..124eb6a 100644 --- a/circuitmatter/tlv.py +++ b/circuitmatter/tlv.py @@ -93,7 +93,7 @@ def decode_element(control_octet, buffer, offset, depth): value = None offset = offset else: - result = member_class.decode(control_octet, buffer, offset, depth) + result = member_class.decode_member(control_octet, buffer, offset, depth) value, offset = result return value, offset @@ -165,7 +165,15 @@ class Structure(Container): return offset + 1 @classmethod - def decode(cls, control_octet, buffer, offset=0, depth=0) -> tuple[dict, int]: + def decode(cls, buffer: memoryview, offset=0) -> Structure: + control_octet = buffer[offset] + values, offset = cls.decode_member(control_octet, buffer, offset + 1) + return values + + @classmethod + def decode_member( + cls, control_octet, buffer, offset=0, depth=0 + ) -> tuple[dict, int]: values = {} buffer = memoryview(buffer) while offset < len(buffer) and buffer[offset] != ElementType.END_OF_CONTAINER: @@ -351,8 +359,12 @@ class Member(ABC, Generic[_T, _OPT, _NULLABLE]): return new_offset return offset + def decode(self, buffer: memoryview, offset: int = 0) -> _T: + "Return the decoded value at `offset` in `buffer`" + return self.decode_member(buffer[offset], buffer, offset + 1)[0] + @abstractmethod - def decode( + def decode_member( self, control_octet: int, buffer: memoryview, offset: int = 0 ) -> (_T, int): "Return the decoded value at `offset` in `buffer`. `offset` is after the tag (but before any length)" @@ -455,7 +467,7 @@ class NumberMember(Member[_NT, _OPT, _NULLABLE], Generic[_NT, _OPT, _NULLABLE]): super().__set__(obj, value) # type: ignore # self inference issues @staticmethod - def decode(control_octet, buffer, offset=0, depth=0) -> tuple[_NT, int]: + def decode_member(control_octet, buffer, offset=0, depth=0) -> tuple[_NT, int]: element_type = control_octet & 0x1F element_category = element_type >> 2 if element_category == 0 or element_category == 1: @@ -595,7 +607,7 @@ class BoolMember(Member[bool, _OPT, _NULLABLE]): max_value_length = 0 @staticmethod - def decode(control_octet, buffer, offset=0, depth=0): + def decode_member(control_octet, buffer, offset=0, depth=0): return (control_octet & 1 == 1, offset) def print(self, value): @@ -682,7 +694,7 @@ class OctetStringMember(StringMember[bytes, _OPT, _NULLABLE]): _base_element_type: ElementType = ElementType.OCTET_STRING @staticmethod - def decode(control_octet, buffer, offset=0, depth=0): + def decode_member(control_octet, buffer, offset=0, depth=0): length, offset = StringMember.parse_length(control_octet, buffer, offset) return (buffer[offset : offset + length].tobytes(), offset + length) @@ -691,7 +703,7 @@ class UTF8StringMember(StringMember[str, _OPT, _NULLABLE]): _base_element_type = ElementType.UTF8_STRING @staticmethod - def decode(control_octet, buffer, offset=0, depth=0): + def decode_member(control_octet, buffer, offset=0, depth=0): length, offset = StringMember.parse_length(control_octet, buffer, offset) return ( buffer[offset : offset + length].tobytes().decode("utf-8"), @@ -723,8 +735,8 @@ class StructMember(Member[_TLVStruct, _OPT, _NULLABLE]): super().__init__(tag, optional=optional, nullable=nullable, **kwargs) @staticmethod - def decode(control_octet, buffer, offset=0, depth=0): - value, offset = Structure.decode(control_octet, buffer, offset, depth) + def decode_member(control_octet, buffer, offset=0, depth=0): + value, offset = Structure.decode_member(control_octet, buffer, offset, depth) return value, offset + 1 def print(self, value): @@ -765,7 +777,7 @@ class ArrayMember(Member[_TLVStruct, _OPT, _NULLABLE]): super().__init__(tag, optional=optional, nullable=nullable, **kwargs) @staticmethod - def decode(control_octet, buffer, offset=0, depth=0): + def decode_member(control_octet, buffer, offset=0, depth=0): entries = [] while buffer[offset] != ElementType.END_OF_CONTAINER: control_octet = buffer[offset] @@ -931,7 +943,7 @@ class ListMember(Member): super().__init__(tag, optional=optional, nullable=nullable, **kwargs) @staticmethod - def decode(control_octet, buffer, offset=0, depth=0): + def decode_member(control_octet, buffer, offset=0, depth=0): raw_list = [] while buffer[offset] != ElementType.END_OF_CONTAINER: control_octet = buffer[offset] @@ -962,7 +974,7 @@ class ListMember(Member): class AnythingMember(Member): """Stores a TLV encoded value.""" - def decode(self, control_octet, buffer, offset=0): + def decode_member(self, control_octet, buffer, offset=0): return None def print(self, value): diff --git a/test_data/device_state-uncommissioned.json b/test_data/device_state-uncommissioned.json new file mode 100644 index 0000000..8e79d9c --- /dev/null +++ b/test_data/device_state-uncommissioned.json @@ -0,0 +1,7 @@ +{ + "discriminator": 3840, + "passcode": 67202583, + "iteration-count": 10000, + "salt": "5uCP0ITHYzI9qBEe6hfU4HfY3y7VopSk0qNvhvznhiQ=", + "verifier": "0xGqxJFBr/ViQt3lv1Yw5F0GcPBAtFFvXB+EcIIjH5cEsjkPZHDQyFWjA6Ide+2gafYnZgIy6gJBgdJOlD8htAZKe0i6nIhT/ADsBWH4CvZcl37n/ofEEECWSEBV4vy/0A==" +} diff --git a/tests/test_tlv.py b/tests/test_tlv.py index 74132d5..8423d27 100644 --- a/tests/test_tlv.py +++ b/tests/test_tlv.py @@ -24,12 +24,12 @@ class Bool(tlv.Structure): class TestBool: def test_bool_false_decode(self): - s, _ = Bool.decode(0x15, b"\x28\x00\x18") + s = Bool.decode(b"\x15\x28\x00\x18") assert str(s) == "{\n b = false\n}" assert s.b is False def test_bool_true_decode(self): - s, _ = Bool.decode(0x15, b"\x29\x00\x18") + s = Bool.decode(b"\x15\x29\x00\x18") assert str(s) == "{\n b = true\n}" assert s.b is True @@ -72,29 +72,27 @@ class SignedIntEightOctet(tlv.Structure): # 03 00 90 2f 50 09 00 00 00 class TestSignedInt: def test_signed_int_42_decode(self): - s, _ = SignedIntOneOctet.decode(0x15, b"\x20\x00\x2a") + s = SignedIntOneOctet.decode(b"\x15\x20\x00\x2a") assert str(s) == "{\n i = 42\n}" assert s.i == 42 def test_signed_int_negative_17_decode(self): - s, _ = SignedIntOneOctet.decode(0x15, b"\x20\x00\xef") + s = SignedIntOneOctet.decode(b"\x15\x20\x00\xef") assert str(s) == "{\n i = -17\n}" assert s.i == -17 def test_signed_int_42_two_octet_decode(self): - s, _ = SignedIntTwoOctet.decode(0x15, b"\x21\x00\x2a\x00") + s = SignedIntTwoOctet.decode(b"\x15\x21\x00\x2a\x00") assert str(s) == "{\n i = 42\n}" assert s.i == 42 def test_signed_int_negative_170000_decode(self): - s, _ = SignedIntFourOctet.decode(0x15, b"\x22\x00\xf0\x67\xfd\xff") + s = SignedIntFourOctet.decode(b"\x15\x22\x00\xf0\x67\xfd\xff") assert str(s) == "{\n i = -170000\n}" assert s.i == -170000 def test_signed_int_40000000000_decode(self): - s, _ = SignedIntEightOctet.decode( - 0x15, b"\x23\x00\x00\x90\x2f\x50\x09\x00\x00\x00" - ) + s = SignedIntEightOctet.decode(b"\x15\x23\x00\x00\x90\x2f\x50\x09\x00\x00\x00") assert str(s) == "{\n i = 40000000000\n}" assert s.i == 40000000000 @@ -158,7 +156,7 @@ class UnsignedIntOneOctet(tlv.Structure): # 04 2a class TestUnsignedInt: def test_unsigned_int_42_decode(self): - s, _ = UnsignedIntOneOctet.decode(0x15, b"\x24\x00\x2a\x18") + s = UnsignedIntOneOctet.decode(b"\x15\x24\x00\x2a\x18") assert str(s) == "{\n i = 42U\n}" assert s.i == 42 @@ -197,7 +195,7 @@ class TestUnsignedInt: s.i = v buffer = s.encode().tobytes() - s2, _ = UnsignedIntOneOctet.decode(0x15, buffer[1:]) + s2 = UnsignedIntOneOctet.decode(buffer) assert s2.i == s.i assert str(s2) == str(s) @@ -228,12 +226,12 @@ class UTF8StringOneOctet(tlv.Structure): class TestUTF8String: def test_utf8_string_hello_decode(self): - s, _ = UTF8StringOneOctet.decode(0x15, b"\x2c\x00\x06Hello!") + s = UTF8StringOneOctet.decode(b"\x15\x2c\x00\x06Hello!") assert str(s) == '{\n s = "Hello!"\n}' assert s.s == "Hello!" def test_utf8_string_tschs_decode(self): - s, _ = UTF8StringOneOctet.decode(0x15, b"\x2c\x00\x07Tsch\xc3\xbcs") + s = UTF8StringOneOctet.decode(b"\x15\x2c\x00\x07Tsch\xc3\xbcs") assert str(s) == '{\n s = "Tschüs"\n}' assert s.s == "Tschüs" @@ -254,7 +252,7 @@ class TestUTF8String: s.s = v buffer = s.encode().tobytes() - s2, _ = UTF8StringOneOctet.decode(0x15, buffer[1:]) + s2 = UTF8StringOneOctet.decode(buffer) assert s2.s == s.s assert str(s2) == str(s) @@ -268,7 +266,7 @@ class OctetStringOneOctet(tlv.Structure): class TestOctetString: def test_octet_string_decode(self): - s, _ = OctetStringOneOctet.decode(0x15, b"\x30\x00\x05\x00\x01\x02\x03\x04\x18") + s = OctetStringOneOctet.decode(b"\x15\x30\x00\x05\x00\x01\x02\x03\x04\x18") assert str(s) == "{\n s = 00 01 02 03 04\n}" assert s.s == b"\x00\x01\x02\x03\x04" @@ -283,7 +281,7 @@ class TestOctetString: s.s = v buffer = s.encode().tobytes() - s2, _ = OctetStringOneOctet.decode(0x15, buffer[1:]) + s2 = OctetStringOneOctet.decode(buffer) assert s2.s == s.s assert str(s2) == str(s) @@ -304,7 +302,7 @@ class NotNull(tlv.Structure): class TestNull: def test_null_decode(self): - s, _ = Null.decode(0x15, b"\x34\x00\x18") + s = Null.decode(b"\x15\x34\x00\x18") assert str(s) == "{\n n = null\n}" assert s.n is None @@ -352,28 +350,28 @@ class FloatDouble(tlv.Structure): class TestFloatSingle: def test_precision_float_0_0_decode(self): - s, _ = FloatSingle.decode(0x15, b"\x2a\x00\x00\x00\x00\x00\x18") + s = FloatSingle.decode(b"\x15\x2a\x00\x00\x00\x00\x00\x18") assert str(s) == "{\n f = 0.0\n}" assert s.f == 0.0 def test_precision_float_1_3_decode(self): - s, _ = FloatSingle.decode(0x15, b"\x2a\x00\xab\xaa\xaa\x3e\x18") + s = FloatSingle.decode(b"\x15\x2a\x00\xab\xaa\xaa\x3e\x18") # assert str(s) == "{\n f = 0.3333333432674408\n}" f = s.f assert math.isclose(f, 1.0 / 3.0, rel_tol=1e-06) def test_precision_float_17_9_decode(self): - s, _ = FloatSingle.decode(0x15, b"\x2a\x00\x33\x33\x8f\x41\x18") + s = FloatSingle.decode(b"\x15\x2a\x00\x33\x33\x8f\x41\x18") assert str(s) == "{\n f = 17.899999618530273\n}" assert math.isclose(s.f, 17.9, rel_tol=1e-06) def test_precision_float_infinity_decode(self): - s, _ = FloatSingle.decode(0x15, b"\x2a\x00\x00\x00\x80\x7f\x18") + s = FloatSingle.decode(b"\x15\x2a\x00\x00\x00\x80\x7f\x18") assert str(s) == "{\n f = inf\n}" assert math.isinf(s.f) def test_precision_float_negative_infinity_decode(self): - s, _ = FloatSingle.decode(0x15, b"\x2a\x00\x00\x00\x80\xff\x18") + s = FloatSingle.decode(b"\x15\x2a\x00\x00\x00\x80\xff\x18") assert str(s) == "{\n f = -inf\n}" assert math.isinf(s.f) @@ -408,7 +406,7 @@ class TestFloatSingle: s.f = v buffer = s.encode().tobytes() - s2, _ = FloatDouble.decode(0x15, buffer[1:]) + s2 = FloatDouble.decode(buffer) assert ( (math.isnan(s.f) and math.isnan(s2.f)) @@ -431,7 +429,7 @@ class TestFloatSingle: buffer = s.encode().tobytes() print("Buffer", buffer.hex(" ")) - s2, _ = FloatSingle.decode(0x15, buffer[1:]) + s2 = FloatSingle.decode(buffer) assert (math.isnan(s.f) and math.isnan(s2.f)) or math.isclose( s2.f, s.f, rel_tol=1e-7, abs_tol=1e-9 @@ -440,28 +438,28 @@ class TestFloatSingle: class TestFloatDouble: def test_precision_float_0_0_decode(self): - s, _ = FloatDouble.decode(0x15, b"\x2b\x00\x00\x00\x00\x00\x00\x00\x00\x00") + s = FloatDouble.decode(b"\x15\x2b\x00\x00\x00\x00\x00\x00\x00\x00\x00") assert str(s) == "{\n f = 0.0\n}" assert s.f == 0.0 def test_precision_float_1_3_decode(self): - s, _ = FloatDouble.decode(0x15, b"\x2b\x00\x55\x55\x55\x55\x55\x55\xd5\x3f") + s = FloatDouble.decode(b"\x15\x2b\x00\x55\x55\x55\x55\x55\x55\xd5\x3f") # assert str(s) == "{\n f = 0.3333333333333333\n}" f = s.f assert math.isclose(f, 1.0 / 3.0, rel_tol=1e-06) def test_precision_float_17_9_decode(self): - s, _ = FloatDouble.decode(0x15, b"\x2b\x00\x66\x66\x66\x66\x66\xe6\x31\x40") + s = FloatDouble.decode(b"\x15\x2b\x00\x66\x66\x66\x66\x66\xe6\x31\x40") assert str(s) == "{\n f = 17.9\n}" assert math.isclose(s.f, 17.9, rel_tol=1e-06) def test_precision_float_infinity_decode(self): - s, _ = FloatDouble.decode(0x15, b"\x2b\x00\x00\x00\x00\x00\x00\x00\xf0\x7f") + s = FloatDouble.decode(b"\x15\x2b\x00\x00\x00\x00\x00\x00\x00\xf0\x7f") assert str(s) == "{\n f = inf\n}" assert math.isinf(s.f) def test_precision_float_negative_infinity_decode(self): - s, _ = FloatDouble.decode(0x15, b"\x2b\x00\x00\x00\x00\x00\x00\x00\xf0\xff") + s = FloatDouble.decode(b"\x15\x2b\x00\x00\x00\x00\x00\x00\x00\xf0\xff") assert str(s) == "{\n f = -inf\n}" assert math.isinf(s.f) @@ -506,7 +504,7 @@ class TestFloatDouble: s.f = v buffer = s.encode().tobytes() - s2, _ = FloatDouble.decode(0x15, buffer[1:]) + s2 = FloatDouble.decode(buffer) assert ( (math.isnan(s.f) and math.isnan(s2.f)) @@ -527,7 +525,7 @@ class OuterStruct(tlv.Structure): class TestStruct: def test_inner_struct_decode(self): - s, _ = OuterStruct.decode(0x15, b"\x35\x00\x20\x00\x2a\x20\x01\xef\x18\x18") + s = OuterStruct.decode(b"\x15\x35\x00\x20\x00\x2a\x20\x01\xef\x18\x18") assert_type(s, OuterStruct) assert_type(s.s, InnerStruct) assert_type(s.s.a, Optional[int]) @@ -536,7 +534,7 @@ class TestStruct: assert s.s.b == -17 def test_inner_struct_decode_empty(self): - s, _ = OuterStruct.decode(0x15, b"\x35\x00\x18\x18") + s = OuterStruct.decode(b"\x15\x35\x00\x18\x18") assert str(s) == "{\n s = {\n \n }\n}" assert s.s.a is None assert s.s.b is None @@ -562,9 +560,8 @@ class FullyQualified(tlv.Structure): class TestFullyQualifiedTags: def test_decode(self): - s, _ = FullyQualified.decode( - 0x15, - b"\xc2\xda\x0a\x00\x0f\x23\x01\x2a\x00\x00\x00\xe2\xda\x0a\x00\x0f\x45\x23\x01\x00\xef\xff\xff\xff\x18", + s = FullyQualified.decode( + b"\x15\xc2\xda\x0a\x00\x0f\x23\x01\x2a\x00\x00\x00\xe2\xda\x0a\x00\x0f\x45\x23\x01\x00\xef\xff\xff\xff\x18", ) assert_type(s, FullyQualified) assert_type(s.a, Optional[int])