Store nonvolatile state in json and restore. Improve decode too

This commit is contained in:
Scott Shawcroft 2024-10-16 15:04:18 -07:00
parent cdaa606808
commit 6b9090e2f4
No known key found for this signature in database
13 changed files with 368 additions and 127 deletions

View file

@ -2,11 +2,11 @@
import binascii
import hashlib
import json
import time
from . import case
from . import interaction_model
from . import nonvolatile
from .message import Message
from .protocol import InteractionModelOpcode, ProtocolId, SecureProtocolOpcode
from . import session
@ -29,15 +29,12 @@ class CircuitMatter:
self.mdns_server = mdns_server
self.random = random_source
with open(state_filename, "r") as state_file:
self.nonvolatile = json.load(state_file)
self.nonvolatile = nonvolatile.PersistentDictionary(state_filename)
for key in ["discriminator", "salt", "iteration-count", "verifier"]:
if key not in self.nonvolatile:
raise RuntimeError(f"Missing key {key} in state file")
commission = "fabrics" not in self.nonvolatile
self.packet_buffer = memoryview(bytearray(1280))
# Define the UDP IP address and port
@ -51,6 +48,7 @@ class CircuitMatter:
# Bind the socket to the IP and port
self.socket.bind((UDP_IP, self.UDP_PORT))
print(f"Listening on UDP port {self.UDP_PORT}")
self.socket.setblocking(False)
self._endpoints = {}
@ -62,14 +60,11 @@ class CircuitMatter:
self.vendor_id = vendor_id
self.product_id = product_id
self.manager = session.SessionManager(
self.random, self.socket, self.root_node.noc
)
print(f"Listening on UDP port {self.UDP_PORT}")
if commission:
if self.root_node.fabric_count == 0:
self.start_commissioning()
def start_commissioning(self):
@ -118,6 +113,12 @@ class CircuitMatter:
device.descriptor.ServerList.append(server.CLUSTER_ID)
self.add_cluster(self._next_endpoint, server)
self.add_cluster(self._next_endpoint, device.descriptor)
if "devices" not in self.nonvolatile:
self.nonvolatile["devices"] = {}
if device.name not in self.nonvolatile["devices"]:
self.nonvolatile["devices"][device.name] = {}
device.restore(self.nonvolatile["devices"][device.name])
self._next_endpoint += 1
def process_packets(self):
@ -249,9 +250,7 @@ class CircuitMatter:
from . import pase
# This is Section 4.14.1.2
request, _ = pase.PBKDFParamRequest.decode(
message.application_payload[0], message.application_payload[1:]
)
request = pase.PBKDFParamRequest.decode(message.application_payload)
exchange.commissioning_hash = hashlib.sha256(
b"CHIP PAKE V1 Commissioning"
)
@ -287,9 +286,7 @@ class CircuitMatter:
from . import pase
print("Received PASE PAKE1")
pake1, _ = pase.PAKE1.decode(
message.application_payload[0], message.application_payload[1:]
)
pake1 = pase.PAKE1.decode(message.application_payload)
pake2 = pase.PAKE2()
verifier = binascii.a2b_base64(self.nonvolatile["verifier"])
context = exchange.commissioning_hash.digest()
@ -308,9 +305,7 @@ class CircuitMatter:
from . import pase
print("Received PASE PAKE3")
pake3, _ = pase.PAKE3.decode(
message.application_payload[0], message.application_payload[1:]
)
pake3 = pase.PAKE3.decode(message.application_payload)
if pake3.cA != exchange.cA:
del exchange.cA
del exchange.Ke
@ -341,9 +336,7 @@ class CircuitMatter:
print("PASE succeeded")
elif protocol_opcode == SecureProtocolOpcode.CASE_SIGMA1:
print("Received CASE Sigma1")
sigma1, _ = case.Sigma1.decode(
message.application_payload[0], message.application_payload[1:]
)
sigma1 = case.Sigma1.decode(message.application_payload)
response = self.manager.reply_to_sigma1(exchange, sigma1)
exchange.send(response)
@ -351,9 +344,7 @@ class CircuitMatter:
print("Received CASE Sigma2")
elif protocol_opcode == SecureProtocolOpcode.CASE_SIGMA3:
print("Received CASE Sigma3")
sigma3, _ = case.Sigma3.decode(
message.application_payload[0], message.application_payload[1:]
)
sigma3 = case.Sigma3.decode(message.application_payload)
protocol_code = self.manager.reply_to_sigma3(exchange, sigma3)
error_status = session.StatusReport()
@ -390,8 +381,8 @@ class CircuitMatter:
message.session_id
]
if protocol_opcode == InteractionModelOpcode.READ_REQUEST:
read_request, _ = interaction_model.ReadRequestMessage.decode(
message.application_payload[0], message.application_payload[1:]
read_request = interaction_model.ReadRequestMessage.decode(
message.application_payload
)
attribute_reports = []
for path in read_request.AttributeRequests:
@ -404,8 +395,8 @@ class CircuitMatter:
exchange.send(response)
elif protocol_opcode == InteractionModelOpcode.WRITE_REQUEST:
print("Received Write Request")
write_request, _ = interaction_model.WriteRequestMessage.decode(
message.application_payload[0], message.application_payload[1:]
write_request = interaction_model.WriteRequestMessage.decode(
message.application_payload
)
write_responses = []
for request in write_request.WriteRequests:
@ -421,8 +412,8 @@ class CircuitMatter:
elif protocol_opcode == InteractionModelOpcode.INVOKE_REQUEST:
print("Received Invoke Request")
invoke_request, _ = interaction_model.InvokeRequestMessage.decode(
message.application_payload[0], message.application_payload[1:]
invoke_request = interaction_model.InvokeRequestMessage.decode(
message.application_payload
)
for invoke in invoke_request.InvokeRequests:
path = invoke.CommandPath
@ -460,14 +451,13 @@ class CircuitMatter:
response = interaction_model.InvokeResponseMessage()
response.SuppressResponse = False
response.InvokeResponses = invoke_responses
print("sending invoke response", response)
exchange.send(response)
elif protocol_opcode == InteractionModelOpcode.INVOKE_RESPONSE:
print("Received Invoke Response")
elif protocol_opcode == InteractionModelOpcode.SUBSCRIBE_REQUEST:
print("Received Subscribe Request")
subscribe_request, _ = interaction_model.SubscribeRequestMessage.decode(
message.application_payload[0], message.application_payload[1:]
subscribe_request = interaction_model.SubscribeRequestMessage.decode(
message.application_payload
)
print(subscribe_request)
attribute_reports = []
@ -484,8 +474,8 @@ class CircuitMatter:
final_response.MaxInterval = subscribe_request.MaxIntervalCeiling
exchange.queue(final_response)
elif protocol_opcode == InteractionModelOpcode.STATUS_RESPONSE:
status_response, _ = interaction_model.StatusResponseMessage.decode(
message.application_payload[0], message.application_payload[1:]
status_response = interaction_model.StatusResponseMessage.decode(
message.application_payload
)
print(
f"Received Status Response on {message.session_id}/{message.exchange_id} ack {message.acknowledged_message_counter}: {status_response.Status!r}"
@ -502,3 +492,6 @@ class CircuitMatter:
else:
print("Unknown protocol", message.protocol_id, message.protocol_opcode)
print()
self.nonvolatile.commit()
# TODO: Rollback on error?

View file

@ -3,6 +3,7 @@
import binascii
import json
import os
import pathlib
import secrets
import socket
import subprocess
@ -114,13 +115,8 @@ class MDNSServer(DummyMDNS):
subtypes=[],
instance_name="",
):
for active_service in self.active_services.values():
active_service.kill()
subtypes = [f"--subtype={subtype}" for subtype in subtypes]
txt_records = [f"{key}={value}" for key, value in txt_records.items()]
if service_type in self.active_services:
self.active_services[service_type].kill()
del self.active_services[service_type]
command = [
"avahi-publish-service",
*subtypes,
@ -130,7 +126,7 @@ class MDNSServer(DummyMDNS):
*txt_records,
]
print("running avahi", command)
self.active_services[service_type] = subprocess.Popen(command)
self.active_services[service_type + instance_name] = subprocess.Popen(command)
if self.publish_address is None:
command = [
"avahi-publish-address",
@ -226,23 +222,33 @@ class NeoPixel(on_off.OnOffLight):
def run(replay_file=None):
device_state = pathlib.Path("test_data/device_state.json")
replay_device_state = pathlib.Path("test_data/replay_device_state.json")
if replay_file:
replay_lines = []
with open(replay_file, "r") as f:
device_state_fn = f.readline().strip()
for line in f:
replay_lines.append(json.loads(line))
socketpool = ReplaySocketPool(replay_lines)
mdns_server = DummyMDNS()
random_source = ReplayRandom(replay_lines)
# Reset device state to before the captured run
device_state.write_text(pathlib.Path(device_state_fn).read_text())
else:
record_file = open("test_data/recorded_packets.jsonl", "w")
timestamp = time.strftime("%Y%m%d-%H%M%S")
record_file = open(f"test_data/recorded_packets-{timestamp}.jsonl", "w")
device_state_fn = f"test_data/device_state-{timestamp}.json"
record_file.write(f"{device_state_fn}\n")
socketpool = RecordingSocketPool(record_file)
mdns_server = MDNSServer()
random_source = RecordingRandom(record_file)
matter = cm.CircuitMatter(
socketpool, mdns_server, random_source, "test_data/device_state.json"
)
led = NeoPixel()
# Save device state before we run so replays can use it.
replay_device_state = pathlib.Path(device_state_fn)
replay_device_state.write_text(device_state.read_text())
matter = cm.CircuitMatter(socketpool, mdns_server, random_source, device_state)
led = NeoPixel("neopixel1")
matter.add_device(led)
while True:
matter.process_packets()

View file

@ -24,13 +24,15 @@ class GroupKeyMulticastPolicyEnum(Enum8):
class GroupKeySetStruct(tlv.Structure):
GroupKeySetID = tlv.IntMember(0, signed=False, octets=2)
GroupKeySecurityPolicy = tlv.EnumMember(1, GroupKeySetSecurityPolicyEnum)
EpochKey0 = tlv.OctetStringMember(2, 16)
EpochStartTime0 = tlv.IntMember(3, signed=False, octets=8)
EpochKey1 = tlv.OctetStringMember(4, 16)
EpochStartTime1 = tlv.IntMember(5, signed=False, octets=8)
EpochKey2 = tlv.OctetStringMember(6, 16)
EpochStartTime2 = tlv.IntMember(7, signed=False, octets=8)
GroupKeyMulticastPolicy = tlv.EnumMember(8, GroupKeyMulticastPolicyEnum)
EpochKey0 = tlv.OctetStringMember(2, 16, nullable=True)
EpochStartTime0 = tlv.IntMember(3, signed=False, octets=8, nullable=True)
EpochKey1 = tlv.OctetStringMember(4, 16, nullable=True)
EpochStartTime1 = tlv.IntMember(5, signed=False, octets=8, nullable=True)
EpochKey2 = tlv.OctetStringMember(6, 16, nullable=True)
EpochStartTime2 = tlv.IntMember(7, signed=False, octets=8, nullable=True)
GroupKeyMulticastPolicy = tlv.EnumMember(
8, GroupKeyMulticastPolicyEnum, nullable=True
)
class GroupKeyManagementCluster(Cluster):
@ -48,7 +50,7 @@ class GroupKeyManagementCluster(Cluster):
class KeySetWrite(tlv.Structure):
GroupKeySet = tlv.StructMember(0, GroupKeySetStruct)
group_key_map = ListAttribute(0, GroupKeyMapStruct, default=[])
group_key_map = ListAttribute(0, GroupKeyMapStruct, default=[], N_nonvolatile=True)
group_table = ListAttribute(1, GroupInfoMapStruct, default=[])
max_groups_per_fabric = NumberAttribute(2, signed=False, bits=16, default=0)
max_group_keys_per_fabric = NumberAttribute(3, signed=False, bits=16, default=1)

View file

@ -99,12 +99,20 @@ class NodeOperationalCredentialsCluster(Cluster):
class AddTrustedRootCertificate(tlv.Structure):
RootCACertificate = tlv.OctetStringMember(0, 400)
nocs = ListAttribute(0, NOCStruct, N_nonvolatile=True, C_changes_omitted=True)
fabrics = ListAttribute(1, FabricDescriptorStruct, N_nonvolatile=True)
nocs = ListAttribute(
0, NOCStruct, N_nonvolatile=True, C_changes_omitted=True, default=[]
)
fabrics = ListAttribute(1, FabricDescriptorStruct, N_nonvolatile=True, default=[])
supported_fabrics = NumberAttribute(2, signed=False, bits=8, F_fixed=True)
commissioned_fabrics = NumberAttribute(3, signed=False, bits=8, N_nonvolatile=True)
commissioned_fabrics = NumberAttribute(
3, signed=False, bits=8, N_nonvolatile=True, default=0
)
trusted_root_certificates = ListAttribute(
4, tlv.OctetStringMember(None, 400), N_nonvolatile=True, C_changes_omitted=True
4,
tlv.OctetStringMember(None, 400),
N_nonvolatile=True,
C_changes_omitted=True,
default=[],
)
# This attribute is weird because it is fabric sensitive but not marked as such.
# Cluster sets current_fabric_index for use in fabric sensitive attributes and

View file

@ -1,3 +1,4 @@
import binascii
import enum
import inspect
import random
@ -9,6 +10,8 @@ from typing import Iterable, Union
from . import interaction_model
from . import tlv
ATTRIBUTES_KEY = "a"
class Enum8(enum.IntEnum):
pass
@ -80,6 +83,7 @@ class Attribute:
self.optional = optional
self.feature = feature
self.nullable = X_nullable
self.nonvolatile = N_nonvolatile
def __get__(self, instance, cls):
v = instance._attribute_values.get(self.id, None)
@ -92,8 +96,16 @@ class Attribute:
if old_value == value:
return
instance._attribute_values[self.id] = value
if self.nonvolatile:
instance._nonvolatile[ATTRIBUTES_KEY][hex(self.id)] = self.to_json(value)
instance.data_version += 1
def to_json(self, value):
return value
def from_json(self, value):
return value
def encode(self, value) -> bytes:
if value is None and self.nullable:
return b"\x14" # No tag, NULL
@ -145,6 +157,35 @@ class EnumAttribute(NumberAttribute):
super().__init__(_id, signed=False, bits=bits, **kwargs)
class _PersistentList:
def __init__(self, wrapped_list, attribute, instance):
self._list = wrapped_list
self._instance = instance
self._attribute = attribute
def append(self, value):
self._list.append(value)
self._instance._nonvolatile[ATTRIBUTES_KEY][hex(self._attribute.id)] = (
self._attribute.to_json(self._list)
)
def __getitem__(self, index):
return self._list[index]
def __setitem__(self, index, value):
self._list[index] = value
self._dirty = True
def __iter__(self):
return iter(self._list)
def __len__(self):
return len(self._list)
def __str__(self):
return "persistent" + str(self._list)
class ListAttribute(Attribute):
def __init__(self, _id, element_type, **kwargs):
if inspect.isclass(element_type) and issubclass(element_type, enum.Enum):
@ -157,6 +198,28 @@ class ListAttribute(Attribute):
kwargs["default"] = list(kwargs["default"])
super().__init__(_id, **kwargs)
def __get__(self, instance, cls):
v = super().__get__(instance, cls)
if self.nonvolatile and v is not None and not isinstance(v, _PersistentList):
# Wrap the list in an object that tracks changes and writes them to nonvolatile.
p = _PersistentList(v, self, instance)
instance._attribute_values[self.id] = p
return p
return v
def to_json(self, value):
return [
binascii.b2a_base64(self._element_type.encode(v), newline=False).decode(
"utf-8"
)
for v in value
]
def from_json(self, value):
return [
self._element_type.decode(memoryview(binascii.a2b_base64(v))) for v in value
]
def _encode(self, value) -> bytes:
return self.tlv_type.encode(value)
@ -243,6 +306,23 @@ class Cluster:
if not field_name.startswith("_") and isinstance(descriptor, Attribute):
yield field_name, descriptor
def restore(self, nonvolatile):
self._nonvolatile = nonvolatile
if ATTRIBUTES_KEY not in nonvolatile:
nonvolatile[ATTRIBUTES_KEY] = {}
for field_name, descriptor in self._attributes():
if descriptor.nonvolatile:
print(field_name, nonvolatile[ATTRIBUTES_KEY])
if hex(descriptor.id) in nonvolatile[ATTRIBUTES_KEY]:
# Update our live value
self._attribute_values[descriptor.id] = descriptor.from_json(
nonvolatile[ATTRIBUTES_KEY][hex(descriptor.id)]
)
else:
# Store the default
nonvolatile[ATTRIBUTES_KEY][hex(descriptor.id)] = descriptor.default
def get_attribute_data(
self, session, path
) -> typing.List[interaction_model.AttributeDataIB]:

View file

@ -8,8 +8,8 @@ class OnOffLight(simple_device.SimpleDevice):
DEVICE_TYPE_ID = 0x0100
REVISION = 3
def __init__(self):
super().__init__()
def __init__(self, name):
super().__init__(name)
self._identify = Identify()
self.servers.append(self._identify)

View file

@ -2,7 +2,8 @@ from circuitmatter.clusters.system_model import binding, descriptor, user_label
class SimpleDevice:
def __init__(self):
def __init__(self, name):
self.name = name
self.servers = []
self.descriptor = descriptor.DescriptorCluster()
device_type = descriptor.DescriptorCluster.DeviceTypeStruct()
@ -19,3 +20,12 @@ class SimpleDevice:
self.user_label = user_label.UserLabelCluster()
self.servers.append(self.user_label)
def restore(self, nonvolatile):
"""Restore device state from the nonvolatile dictionary and hang onto it for any updates."""
self.nonvolatile = nonvolatile
for server in self.servers:
cluster_hex = hex(server.CLUSTER_ID)
if cluster_hex not in nonvolatile:
nonvolatile[cluster_hex] = {}
server.restore(nonvolatile[cluster_hex])

View file

@ -1,3 +1,4 @@
import binascii
import ecdsa
from ecdsa import der
import hashlib
@ -118,20 +119,58 @@ class _NodeOperationalCredentialsCluster(NodeOperationalCredentialsCluster):
self.pending_root_cert = None
self.pending_signing_key = None
self.nocs = []
self.fabrics = []
self.supported_fabrics = 10
self.commissioned_fabrics = 0
self.trusted_root_certificates = []
self.root_certs = []
self.compressed_fabric_ids = []
self.noc_keys = []
self.encoded_noc_keys = []
self.mdns_server = mdns_server
self.port = port
self.random = random_source
def restore(self, nonvolatile):
super().restore(nonvolatile)
if "pk" not in nonvolatile:
return
self.root_certs = []
self.compressed_fabric_ids = []
self.noc_keys = []
self.encoded_noc_keys = nonvolatile["pk"]
for i, encoded_root_cert in enumerate(self.trusted_root_certificates):
root_cert = crypto.MatterCertificate.decode(encoded_root_cert)
self.root_certs.append(root_cert)
fabric = self.fabrics[i]
fabric_id = struct.pack(">Q", fabric.FabricID)
compressed_fabric_id = crypto.KDF(
root_cert.ec_pub_key[1:], fabric_id, b"CompressedFabric", 64
)
self.compressed_fabric_ids.append(compressed_fabric_id)
signing_key = ecdsa.keys.SigningKey.from_string(
binascii.a2b_base64(self.encoded_noc_keys[i]),
curve=ecdsa.NIST256p,
hashfunc=hashlib.sha256,
)
self.noc_keys.append(signing_key)
node_id = struct.pack(">Q", fabric.NodeID).hex().upper()
compressed_fabric_id = compressed_fabric_id.hex().upper()
instance_name = f"{compressed_fabric_id}-{node_id}"
self.mdns_server.advertise_service(
"_matter",
"_tcp",
self.port,
instance_name=instance_name,
subtypes=[
f"_I{compressed_fabric_id}._sub._matter._tcp",
],
)
def certificate_chain_request(
self,
session,
@ -260,13 +299,7 @@ class _NodeOperationalCredentialsCluster(NodeOperationalCredentialsCluster):
self, session, args: NodeOperationalCredentialsCluster.AddNOC
) -> NodeOperationalCredentialsCluster.NOCResponse:
# Section 11.18.6.8
noc, _ = crypto.MatterCertificate.decode(
args.NOCValue[0], memoryview(args.NOCValue)[1:]
)
if args.ICACValue:
icac, _ = crypto.MatterCertificate.decode(
args.ICACValue[0], memoryview(args.ICACValue)[1:]
)
noc = crypto.MatterCertificate.decode(args.NOCValue)
response = NodeOperationalCredentialsCluster.NOCResponse()
@ -291,9 +324,7 @@ class _NodeOperationalCredentialsCluster(NodeOperationalCredentialsCluster):
self.nocs.append(noc_struct)
# Get the root cert public key so we can create the compressed fabric id.
root_cert, _ = crypto.MatterCertificate.decode(
self.pending_root_cert[0], memoryview(self.pending_root_cert)[1:]
)
root_cert = crypto.MatterCertificate.decode(self.pending_root_cert)
# Store the fabric
new_fabric = NodeOperationalCredentialsCluster.FabricDescriptorStruct()
@ -317,6 +348,12 @@ class _NodeOperationalCredentialsCluster(NodeOperationalCredentialsCluster):
self.commissioned_fabrics += 1
self.noc_keys.append(self.pending_signing_key)
self.encoded_noc_keys.append(
binascii.b2a_base64(
self.pending_signing_key.to_string(), newline=False
).decode("utf-8")
)
self._nonvolatile["pk"] = self.encoded_noc_keys
self.trusted_root_certificates.append(self.pending_root_cert)
@ -365,11 +402,29 @@ class _GroupKeyManagementCluster(GroupKeyManagementCluster):
def __init__(self):
super().__init__()
self.key_sets = []
self._encoded_key_sets = []
def restore(self, nonvolatile):
super().restore(nonvolatile)
if "gks" not in nonvolatile:
return
self._encoded_key_sets = nonvolatile["gks"]
self.key_sets = [
GroupKeySetStruct.decode(binascii.a2b_base64(v)) for v in nonvolatile["gks"]
]
def key_set_write(
self, session, args: GroupKeyManagementCluster.KeySetWrite
) -> interaction_model.StatusCode:
self.key_sets.append(args.GroupKeySet)
self._encoded_key_sets.append(
binascii.b2a_base64(args.GroupKeySet.encode(), newline=False).decode(
"utf-8"
)
)
self._nonvolatile["gks"] = self._encoded_key_sets
return interaction_model.StatusCode.SUCCESS
@ -378,7 +433,7 @@ class RootNode(simple_device.SimpleDevice):
REVISION = 2
def __init__(self, random_source, mdns_server, port, vendor_id, product_id):
super().__init__()
super().__init__("root")
basic_info = BasicInformationCluster()
basic_info.vendor_id = vendor_id
@ -421,3 +476,7 @@ class RootNode(simple_device.SimpleDevice):
self.user_label = user_label.UserLabelCluster()
self.servers.append(self.user_label)
@property
def fabric_count(self):
return self.noc.commissioned_fabrics

View file

@ -0,0 +1,69 @@
import json
class PersistentDictionary:
"""This acts like a dictionary and is persisted when values change."""
def __init__(self, filename=None, root=None, state=None):
self.filename = filename
self.root = root
self.dirty = False
self.persisted = {}
self._state: dict
if self.root is None and filename:
self.rollback()
elif state is not None:
self._state = state
else:
raise ValueError("Provide filename or (root and state)")
def wrap(self, value):
return value
def __setitem__(self, key, value):
self._state[key] = value
if self.root:
self.root.dirty = True
else:
self.dirty = True
def __getitem__(self, key):
value = self._state[key]
if isinstance(value, dict):
if key not in self.persisted:
root = self.root if self.root else self
self.persisted[key] = PersistentDictionary(root=root, state=value)
return self.persisted[key]
return value
def __delitem__(self, key):
del self._state[key]
if self.root:
self.root.dirty = True
else:
self.dirty = True
def keys(self):
return self._state.keys()
def __iter__(self):
return iter(self._state)
def commit(self):
if not self.dirty:
print("not dirty")
return
if self.root:
print("root commit")
self.root.commit()
return
print("commit")
print(self._state)
with open(self.filename, "w") as state_file:
json.dump(self._state, state_file, indent=1)
self.dirty = False
def rollback(self):
print("rollback")
with open(self.filename, "r") as state_file:
self._state = json.load(state_file)

View file

@ -629,7 +629,7 @@ class SessionManager:
decrypted = s3k_cipher.decrypt(b"NCASE_Sigma3N", sigma3.encrypted3, b"")
except cryptography.exceptions.InvalidTag:
return SecureChannelProtocolCode.INVALID_PARAMETER
sigma3_tbe, _ = case.Sigma3TbeData.decode(decrypted[0], decrypted[1:])
sigma3_tbe = case.Sigma3TbeData.decode(decrypted)
# TODO: Implement checks 4a-4d. INVALID_PARAMETER if they fail.
@ -640,9 +640,7 @@ class SessionManager:
secure_session_context = exchange.secure_session_context
peer_noc = sigma3_tbe.initiatorNOC
peer_noc, _ = crypto.MatterCertificate.decode(
peer_noc[0], memoryview(peer_noc)[1:]
)
peer_noc = crypto.MatterCertificate.decode(peer_noc)
secure_session_context.peer_node_id = peer_noc.subject.matter_node_id
exchange.transcript_hash.update(sigma3.encode())

View file

@ -93,7 +93,7 @@ def decode_element(control_octet, buffer, offset, depth):
value = None
offset = offset
else:
result = member_class.decode(control_octet, buffer, offset, depth)
result = member_class.decode_member(control_octet, buffer, offset, depth)
value, offset = result
return value, offset
@ -165,7 +165,15 @@ class Structure(Container):
return offset + 1
@classmethod
def decode(cls, control_octet, buffer, offset=0, depth=0) -> tuple[dict, int]:
def decode(cls, buffer: memoryview, offset=0) -> Structure:
control_octet = buffer[offset]
values, offset = cls.decode_member(control_octet, buffer, offset + 1)
return values
@classmethod
def decode_member(
cls, control_octet, buffer, offset=0, depth=0
) -> tuple[dict, int]:
values = {}
buffer = memoryview(buffer)
while offset < len(buffer) and buffer[offset] != ElementType.END_OF_CONTAINER:
@ -351,8 +359,12 @@ class Member(ABC, Generic[_T, _OPT, _NULLABLE]):
return new_offset
return offset
def decode(self, buffer: memoryview, offset: int = 0) -> _T:
"Return the decoded value at `offset` in `buffer`"
return self.decode_member(buffer[offset], buffer, offset + 1)[0]
@abstractmethod
def decode(
def decode_member(
self, control_octet: int, buffer: memoryview, offset: int = 0
) -> (_T, int):
"Return the decoded value at `offset` in `buffer`. `offset` is after the tag (but before any length)"
@ -455,7 +467,7 @@ class NumberMember(Member[_NT, _OPT, _NULLABLE], Generic[_NT, _OPT, _NULLABLE]):
super().__set__(obj, value) # type: ignore # self inference issues
@staticmethod
def decode(control_octet, buffer, offset=0, depth=0) -> tuple[_NT, int]:
def decode_member(control_octet, buffer, offset=0, depth=0) -> tuple[_NT, int]:
element_type = control_octet & 0x1F
element_category = element_type >> 2
if element_category == 0 or element_category == 1:
@ -595,7 +607,7 @@ class BoolMember(Member[bool, _OPT, _NULLABLE]):
max_value_length = 0
@staticmethod
def decode(control_octet, buffer, offset=0, depth=0):
def decode_member(control_octet, buffer, offset=0, depth=0):
return (control_octet & 1 == 1, offset)
def print(self, value):
@ -682,7 +694,7 @@ class OctetStringMember(StringMember[bytes, _OPT, _NULLABLE]):
_base_element_type: ElementType = ElementType.OCTET_STRING
@staticmethod
def decode(control_octet, buffer, offset=0, depth=0):
def decode_member(control_octet, buffer, offset=0, depth=0):
length, offset = StringMember.parse_length(control_octet, buffer, offset)
return (buffer[offset : offset + length].tobytes(), offset + length)
@ -691,7 +703,7 @@ class UTF8StringMember(StringMember[str, _OPT, _NULLABLE]):
_base_element_type = ElementType.UTF8_STRING
@staticmethod
def decode(control_octet, buffer, offset=0, depth=0):
def decode_member(control_octet, buffer, offset=0, depth=0):
length, offset = StringMember.parse_length(control_octet, buffer, offset)
return (
buffer[offset : offset + length].tobytes().decode("utf-8"),
@ -723,8 +735,8 @@ class StructMember(Member[_TLVStruct, _OPT, _NULLABLE]):
super().__init__(tag, optional=optional, nullable=nullable, **kwargs)
@staticmethod
def decode(control_octet, buffer, offset=0, depth=0):
value, offset = Structure.decode(control_octet, buffer, offset, depth)
def decode_member(control_octet, buffer, offset=0, depth=0):
value, offset = Structure.decode_member(control_octet, buffer, offset, depth)
return value, offset + 1
def print(self, value):
@ -765,7 +777,7 @@ class ArrayMember(Member[_TLVStruct, _OPT, _NULLABLE]):
super().__init__(tag, optional=optional, nullable=nullable, **kwargs)
@staticmethod
def decode(control_octet, buffer, offset=0, depth=0):
def decode_member(control_octet, buffer, offset=0, depth=0):
entries = []
while buffer[offset] != ElementType.END_OF_CONTAINER:
control_octet = buffer[offset]
@ -931,7 +943,7 @@ class ListMember(Member):
super().__init__(tag, optional=optional, nullable=nullable, **kwargs)
@staticmethod
def decode(control_octet, buffer, offset=0, depth=0):
def decode_member(control_octet, buffer, offset=0, depth=0):
raw_list = []
while buffer[offset] != ElementType.END_OF_CONTAINER:
control_octet = buffer[offset]
@ -962,7 +974,7 @@ class ListMember(Member):
class AnythingMember(Member):
"""Stores a TLV encoded value."""
def decode(self, control_octet, buffer, offset=0):
def decode_member(self, control_octet, buffer, offset=0):
return None
def print(self, value):

View file

@ -0,0 +1,7 @@
{
"discriminator": 3840,
"passcode": 67202583,
"iteration-count": 10000,
"salt": "5uCP0ITHYzI9qBEe6hfU4HfY3y7VopSk0qNvhvznhiQ=",
"verifier": "0xGqxJFBr/ViQt3lv1Yw5F0GcPBAtFFvXB+EcIIjH5cEsjkPZHDQyFWjA6Ide+2gafYnZgIy6gJBgdJOlD8htAZKe0i6nIhT/ADsBWH4CvZcl37n/ofEEECWSEBV4vy/0A=="
}

View file

@ -24,12 +24,12 @@ class Bool(tlv.Structure):
class TestBool:
def test_bool_false_decode(self):
s, _ = Bool.decode(0x15, b"\x28\x00\x18")
s = Bool.decode(b"\x15\x28\x00\x18")
assert str(s) == "{\n b = false\n}"
assert s.b is False
def test_bool_true_decode(self):
s, _ = Bool.decode(0x15, b"\x29\x00\x18")
s = Bool.decode(b"\x15\x29\x00\x18")
assert str(s) == "{\n b = true\n}"
assert s.b is True
@ -72,29 +72,27 @@ class SignedIntEightOctet(tlv.Structure):
# 03 00 90 2f 50 09 00 00 00
class TestSignedInt:
def test_signed_int_42_decode(self):
s, _ = SignedIntOneOctet.decode(0x15, b"\x20\x00\x2a")
s = SignedIntOneOctet.decode(b"\x15\x20\x00\x2a")
assert str(s) == "{\n i = 42\n}"
assert s.i == 42
def test_signed_int_negative_17_decode(self):
s, _ = SignedIntOneOctet.decode(0x15, b"\x20\x00\xef")
s = SignedIntOneOctet.decode(b"\x15\x20\x00\xef")
assert str(s) == "{\n i = -17\n}"
assert s.i == -17
def test_signed_int_42_two_octet_decode(self):
s, _ = SignedIntTwoOctet.decode(0x15, b"\x21\x00\x2a\x00")
s = SignedIntTwoOctet.decode(b"\x15\x21\x00\x2a\x00")
assert str(s) == "{\n i = 42\n}"
assert s.i == 42
def test_signed_int_negative_170000_decode(self):
s, _ = SignedIntFourOctet.decode(0x15, b"\x22\x00\xf0\x67\xfd\xff")
s = SignedIntFourOctet.decode(b"\x15\x22\x00\xf0\x67\xfd\xff")
assert str(s) == "{\n i = -170000\n}"
assert s.i == -170000
def test_signed_int_40000000000_decode(self):
s, _ = SignedIntEightOctet.decode(
0x15, b"\x23\x00\x00\x90\x2f\x50\x09\x00\x00\x00"
)
s = SignedIntEightOctet.decode(b"\x15\x23\x00\x00\x90\x2f\x50\x09\x00\x00\x00")
assert str(s) == "{\n i = 40000000000\n}"
assert s.i == 40000000000
@ -158,7 +156,7 @@ class UnsignedIntOneOctet(tlv.Structure):
# 04 2a
class TestUnsignedInt:
def test_unsigned_int_42_decode(self):
s, _ = UnsignedIntOneOctet.decode(0x15, b"\x24\x00\x2a\x18")
s = UnsignedIntOneOctet.decode(b"\x15\x24\x00\x2a\x18")
assert str(s) == "{\n i = 42U\n}"
assert s.i == 42
@ -197,7 +195,7 @@ class TestUnsignedInt:
s.i = v
buffer = s.encode().tobytes()
s2, _ = UnsignedIntOneOctet.decode(0x15, buffer[1:])
s2 = UnsignedIntOneOctet.decode(buffer)
assert s2.i == s.i
assert str(s2) == str(s)
@ -228,12 +226,12 @@ class UTF8StringOneOctet(tlv.Structure):
class TestUTF8String:
def test_utf8_string_hello_decode(self):
s, _ = UTF8StringOneOctet.decode(0x15, b"\x2c\x00\x06Hello!")
s = UTF8StringOneOctet.decode(b"\x15\x2c\x00\x06Hello!")
assert str(s) == '{\n s = "Hello!"\n}'
assert s.s == "Hello!"
def test_utf8_string_tschs_decode(self):
s, _ = UTF8StringOneOctet.decode(0x15, b"\x2c\x00\x07Tsch\xc3\xbcs")
s = UTF8StringOneOctet.decode(b"\x15\x2c\x00\x07Tsch\xc3\xbcs")
assert str(s) == '{\n s = "Tschüs"\n}'
assert s.s == "Tschüs"
@ -254,7 +252,7 @@ class TestUTF8String:
s.s = v
buffer = s.encode().tobytes()
s2, _ = UTF8StringOneOctet.decode(0x15, buffer[1:])
s2 = UTF8StringOneOctet.decode(buffer)
assert s2.s == s.s
assert str(s2) == str(s)
@ -268,7 +266,7 @@ class OctetStringOneOctet(tlv.Structure):
class TestOctetString:
def test_octet_string_decode(self):
s, _ = OctetStringOneOctet.decode(0x15, b"\x30\x00\x05\x00\x01\x02\x03\x04\x18")
s = OctetStringOneOctet.decode(b"\x15\x30\x00\x05\x00\x01\x02\x03\x04\x18")
assert str(s) == "{\n s = 00 01 02 03 04\n}"
assert s.s == b"\x00\x01\x02\x03\x04"
@ -283,7 +281,7 @@ class TestOctetString:
s.s = v
buffer = s.encode().tobytes()
s2, _ = OctetStringOneOctet.decode(0x15, buffer[1:])
s2 = OctetStringOneOctet.decode(buffer)
assert s2.s == s.s
assert str(s2) == str(s)
@ -304,7 +302,7 @@ class NotNull(tlv.Structure):
class TestNull:
def test_null_decode(self):
s, _ = Null.decode(0x15, b"\x34\x00\x18")
s = Null.decode(b"\x15\x34\x00\x18")
assert str(s) == "{\n n = null\n}"
assert s.n is None
@ -352,28 +350,28 @@ class FloatDouble(tlv.Structure):
class TestFloatSingle:
def test_precision_float_0_0_decode(self):
s, _ = FloatSingle.decode(0x15, b"\x2a\x00\x00\x00\x00\x00\x18")
s = FloatSingle.decode(b"\x15\x2a\x00\x00\x00\x00\x00\x18")
assert str(s) == "{\n f = 0.0\n}"
assert s.f == 0.0
def test_precision_float_1_3_decode(self):
s, _ = FloatSingle.decode(0x15, b"\x2a\x00\xab\xaa\xaa\x3e\x18")
s = FloatSingle.decode(b"\x15\x2a\x00\xab\xaa\xaa\x3e\x18")
# assert str(s) == "{\n f = 0.3333333432674408\n}"
f = s.f
assert math.isclose(f, 1.0 / 3.0, rel_tol=1e-06)
def test_precision_float_17_9_decode(self):
s, _ = FloatSingle.decode(0x15, b"\x2a\x00\x33\x33\x8f\x41\x18")
s = FloatSingle.decode(b"\x15\x2a\x00\x33\x33\x8f\x41\x18")
assert str(s) == "{\n f = 17.899999618530273\n}"
assert math.isclose(s.f, 17.9, rel_tol=1e-06)
def test_precision_float_infinity_decode(self):
s, _ = FloatSingle.decode(0x15, b"\x2a\x00\x00\x00\x80\x7f\x18")
s = FloatSingle.decode(b"\x15\x2a\x00\x00\x00\x80\x7f\x18")
assert str(s) == "{\n f = inf\n}"
assert math.isinf(s.f)
def test_precision_float_negative_infinity_decode(self):
s, _ = FloatSingle.decode(0x15, b"\x2a\x00\x00\x00\x80\xff\x18")
s = FloatSingle.decode(b"\x15\x2a\x00\x00\x00\x80\xff\x18")
assert str(s) == "{\n f = -inf\n}"
assert math.isinf(s.f)
@ -408,7 +406,7 @@ class TestFloatSingle:
s.f = v
buffer = s.encode().tobytes()
s2, _ = FloatDouble.decode(0x15, buffer[1:])
s2 = FloatDouble.decode(buffer)
assert (
(math.isnan(s.f) and math.isnan(s2.f))
@ -431,7 +429,7 @@ class TestFloatSingle:
buffer = s.encode().tobytes()
print("Buffer", buffer.hex(" "))
s2, _ = FloatSingle.decode(0x15, buffer[1:])
s2 = FloatSingle.decode(buffer)
assert (math.isnan(s.f) and math.isnan(s2.f)) or math.isclose(
s2.f, s.f, rel_tol=1e-7, abs_tol=1e-9
@ -440,28 +438,28 @@ class TestFloatSingle:
class TestFloatDouble:
def test_precision_float_0_0_decode(self):
s, _ = FloatDouble.decode(0x15, b"\x2b\x00\x00\x00\x00\x00\x00\x00\x00\x00")
s = FloatDouble.decode(b"\x15\x2b\x00\x00\x00\x00\x00\x00\x00\x00\x00")
assert str(s) == "{\n f = 0.0\n}"
assert s.f == 0.0
def test_precision_float_1_3_decode(self):
s, _ = FloatDouble.decode(0x15, b"\x2b\x00\x55\x55\x55\x55\x55\x55\xd5\x3f")
s = FloatDouble.decode(b"\x15\x2b\x00\x55\x55\x55\x55\x55\x55\xd5\x3f")
# assert str(s) == "{\n f = 0.3333333333333333\n}"
f = s.f
assert math.isclose(f, 1.0 / 3.0, rel_tol=1e-06)
def test_precision_float_17_9_decode(self):
s, _ = FloatDouble.decode(0x15, b"\x2b\x00\x66\x66\x66\x66\x66\xe6\x31\x40")
s = FloatDouble.decode(b"\x15\x2b\x00\x66\x66\x66\x66\x66\xe6\x31\x40")
assert str(s) == "{\n f = 17.9\n}"
assert math.isclose(s.f, 17.9, rel_tol=1e-06)
def test_precision_float_infinity_decode(self):
s, _ = FloatDouble.decode(0x15, b"\x2b\x00\x00\x00\x00\x00\x00\x00\xf0\x7f")
s = FloatDouble.decode(b"\x15\x2b\x00\x00\x00\x00\x00\x00\x00\xf0\x7f")
assert str(s) == "{\n f = inf\n}"
assert math.isinf(s.f)
def test_precision_float_negative_infinity_decode(self):
s, _ = FloatDouble.decode(0x15, b"\x2b\x00\x00\x00\x00\x00\x00\x00\xf0\xff")
s = FloatDouble.decode(b"\x15\x2b\x00\x00\x00\x00\x00\x00\x00\xf0\xff")
assert str(s) == "{\n f = -inf\n}"
assert math.isinf(s.f)
@ -506,7 +504,7 @@ class TestFloatDouble:
s.f = v
buffer = s.encode().tobytes()
s2, _ = FloatDouble.decode(0x15, buffer[1:])
s2 = FloatDouble.decode(buffer)
assert (
(math.isnan(s.f) and math.isnan(s2.f))
@ -527,7 +525,7 @@ class OuterStruct(tlv.Structure):
class TestStruct:
def test_inner_struct_decode(self):
s, _ = OuterStruct.decode(0x15, b"\x35\x00\x20\x00\x2a\x20\x01\xef\x18\x18")
s = OuterStruct.decode(b"\x15\x35\x00\x20\x00\x2a\x20\x01\xef\x18\x18")
assert_type(s, OuterStruct)
assert_type(s.s, InnerStruct)
assert_type(s.s.a, Optional[int])
@ -536,7 +534,7 @@ class TestStruct:
assert s.s.b == -17
def test_inner_struct_decode_empty(self):
s, _ = OuterStruct.decode(0x15, b"\x35\x00\x18\x18")
s = OuterStruct.decode(b"\x15\x35\x00\x18\x18")
assert str(s) == "{\n s = {\n \n }\n}"
assert s.s.a is None
assert s.s.b is None
@ -562,9 +560,8 @@ class FullyQualified(tlv.Structure):
class TestFullyQualifiedTags:
def test_decode(self):
s, _ = FullyQualified.decode(
0x15,
b"\xc2\xda\x0a\x00\x0f\x23\x01\x2a\x00\x00\x00\xe2\xda\x0a\x00\x0f\x45\x23\x01\x00\xef\xff\xff\xff\x18",
s = FullyQualified.decode(
b"\x15\xc2\xda\x0a\x00\x0f\x23\x01\x2a\x00\x00\x00\xe2\xda\x0a\x00\x0f\x45\x23\x01\x00\xef\xff\xff\xff\x18",
)
assert_type(s, FullyQualified)
assert_type(s.a, Optional[int])