Store nonvolatile state in json and restore. Improve decode too
This commit is contained in:
parent
cdaa606808
commit
6b9090e2f4
13 changed files with 368 additions and 127 deletions
|
|
@ -2,11 +2,11 @@
|
|||
|
||||
import binascii
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
|
||||
from . import case
|
||||
from . import interaction_model
|
||||
from . import nonvolatile
|
||||
from .message import Message
|
||||
from .protocol import InteractionModelOpcode, ProtocolId, SecureProtocolOpcode
|
||||
from . import session
|
||||
|
|
@ -29,15 +29,12 @@ class CircuitMatter:
|
|||
self.mdns_server = mdns_server
|
||||
self.random = random_source
|
||||
|
||||
with open(state_filename, "r") as state_file:
|
||||
self.nonvolatile = json.load(state_file)
|
||||
self.nonvolatile = nonvolatile.PersistentDictionary(state_filename)
|
||||
|
||||
for key in ["discriminator", "salt", "iteration-count", "verifier"]:
|
||||
if key not in self.nonvolatile:
|
||||
raise RuntimeError(f"Missing key {key} in state file")
|
||||
|
||||
commission = "fabrics" not in self.nonvolatile
|
||||
|
||||
self.packet_buffer = memoryview(bytearray(1280))
|
||||
|
||||
# Define the UDP IP address and port
|
||||
|
|
@ -51,6 +48,7 @@ class CircuitMatter:
|
|||
|
||||
# Bind the socket to the IP and port
|
||||
self.socket.bind((UDP_IP, self.UDP_PORT))
|
||||
print(f"Listening on UDP port {self.UDP_PORT}")
|
||||
self.socket.setblocking(False)
|
||||
|
||||
self._endpoints = {}
|
||||
|
|
@ -62,14 +60,11 @@ class CircuitMatter:
|
|||
|
||||
self.vendor_id = vendor_id
|
||||
self.product_id = product_id
|
||||
|
||||
self.manager = session.SessionManager(
|
||||
self.random, self.socket, self.root_node.noc
|
||||
)
|
||||
|
||||
print(f"Listening on UDP port {self.UDP_PORT}")
|
||||
|
||||
if commission:
|
||||
if self.root_node.fabric_count == 0:
|
||||
self.start_commissioning()
|
||||
|
||||
def start_commissioning(self):
|
||||
|
|
@ -118,6 +113,12 @@ class CircuitMatter:
|
|||
device.descriptor.ServerList.append(server.CLUSTER_ID)
|
||||
self.add_cluster(self._next_endpoint, server)
|
||||
self.add_cluster(self._next_endpoint, device.descriptor)
|
||||
|
||||
if "devices" not in self.nonvolatile:
|
||||
self.nonvolatile["devices"] = {}
|
||||
if device.name not in self.nonvolatile["devices"]:
|
||||
self.nonvolatile["devices"][device.name] = {}
|
||||
device.restore(self.nonvolatile["devices"][device.name])
|
||||
self._next_endpoint += 1
|
||||
|
||||
def process_packets(self):
|
||||
|
|
@ -249,9 +250,7 @@ class CircuitMatter:
|
|||
from . import pase
|
||||
|
||||
# This is Section 4.14.1.2
|
||||
request, _ = pase.PBKDFParamRequest.decode(
|
||||
message.application_payload[0], message.application_payload[1:]
|
||||
)
|
||||
request = pase.PBKDFParamRequest.decode(message.application_payload)
|
||||
exchange.commissioning_hash = hashlib.sha256(
|
||||
b"CHIP PAKE V1 Commissioning"
|
||||
)
|
||||
|
|
@ -287,9 +286,7 @@ class CircuitMatter:
|
|||
from . import pase
|
||||
|
||||
print("Received PASE PAKE1")
|
||||
pake1, _ = pase.PAKE1.decode(
|
||||
message.application_payload[0], message.application_payload[1:]
|
||||
)
|
||||
pake1 = pase.PAKE1.decode(message.application_payload)
|
||||
pake2 = pase.PAKE2()
|
||||
verifier = binascii.a2b_base64(self.nonvolatile["verifier"])
|
||||
context = exchange.commissioning_hash.digest()
|
||||
|
|
@ -308,9 +305,7 @@ class CircuitMatter:
|
|||
from . import pase
|
||||
|
||||
print("Received PASE PAKE3")
|
||||
pake3, _ = pase.PAKE3.decode(
|
||||
message.application_payload[0], message.application_payload[1:]
|
||||
)
|
||||
pake3 = pase.PAKE3.decode(message.application_payload)
|
||||
if pake3.cA != exchange.cA:
|
||||
del exchange.cA
|
||||
del exchange.Ke
|
||||
|
|
@ -341,9 +336,7 @@ class CircuitMatter:
|
|||
print("PASE succeeded")
|
||||
elif protocol_opcode == SecureProtocolOpcode.CASE_SIGMA1:
|
||||
print("Received CASE Sigma1")
|
||||
sigma1, _ = case.Sigma1.decode(
|
||||
message.application_payload[0], message.application_payload[1:]
|
||||
)
|
||||
sigma1 = case.Sigma1.decode(message.application_payload)
|
||||
response = self.manager.reply_to_sigma1(exchange, sigma1)
|
||||
|
||||
exchange.send(response)
|
||||
|
|
@ -351,9 +344,7 @@ class CircuitMatter:
|
|||
print("Received CASE Sigma2")
|
||||
elif protocol_opcode == SecureProtocolOpcode.CASE_SIGMA3:
|
||||
print("Received CASE Sigma3")
|
||||
sigma3, _ = case.Sigma3.decode(
|
||||
message.application_payload[0], message.application_payload[1:]
|
||||
)
|
||||
sigma3 = case.Sigma3.decode(message.application_payload)
|
||||
protocol_code = self.manager.reply_to_sigma3(exchange, sigma3)
|
||||
|
||||
error_status = session.StatusReport()
|
||||
|
|
@ -390,8 +381,8 @@ class CircuitMatter:
|
|||
message.session_id
|
||||
]
|
||||
if protocol_opcode == InteractionModelOpcode.READ_REQUEST:
|
||||
read_request, _ = interaction_model.ReadRequestMessage.decode(
|
||||
message.application_payload[0], message.application_payload[1:]
|
||||
read_request = interaction_model.ReadRequestMessage.decode(
|
||||
message.application_payload
|
||||
)
|
||||
attribute_reports = []
|
||||
for path in read_request.AttributeRequests:
|
||||
|
|
@ -404,8 +395,8 @@ class CircuitMatter:
|
|||
exchange.send(response)
|
||||
elif protocol_opcode == InteractionModelOpcode.WRITE_REQUEST:
|
||||
print("Received Write Request")
|
||||
write_request, _ = interaction_model.WriteRequestMessage.decode(
|
||||
message.application_payload[0], message.application_payload[1:]
|
||||
write_request = interaction_model.WriteRequestMessage.decode(
|
||||
message.application_payload
|
||||
)
|
||||
write_responses = []
|
||||
for request in write_request.WriteRequests:
|
||||
|
|
@ -421,8 +412,8 @@ class CircuitMatter:
|
|||
|
||||
elif protocol_opcode == InteractionModelOpcode.INVOKE_REQUEST:
|
||||
print("Received Invoke Request")
|
||||
invoke_request, _ = interaction_model.InvokeRequestMessage.decode(
|
||||
message.application_payload[0], message.application_payload[1:]
|
||||
invoke_request = interaction_model.InvokeRequestMessage.decode(
|
||||
message.application_payload
|
||||
)
|
||||
for invoke in invoke_request.InvokeRequests:
|
||||
path = invoke.CommandPath
|
||||
|
|
@ -460,14 +451,13 @@ class CircuitMatter:
|
|||
response = interaction_model.InvokeResponseMessage()
|
||||
response.SuppressResponse = False
|
||||
response.InvokeResponses = invoke_responses
|
||||
print("sending invoke response", response)
|
||||
exchange.send(response)
|
||||
elif protocol_opcode == InteractionModelOpcode.INVOKE_RESPONSE:
|
||||
print("Received Invoke Response")
|
||||
elif protocol_opcode == InteractionModelOpcode.SUBSCRIBE_REQUEST:
|
||||
print("Received Subscribe Request")
|
||||
subscribe_request, _ = interaction_model.SubscribeRequestMessage.decode(
|
||||
message.application_payload[0], message.application_payload[1:]
|
||||
subscribe_request = interaction_model.SubscribeRequestMessage.decode(
|
||||
message.application_payload
|
||||
)
|
||||
print(subscribe_request)
|
||||
attribute_reports = []
|
||||
|
|
@ -484,8 +474,8 @@ class CircuitMatter:
|
|||
final_response.MaxInterval = subscribe_request.MaxIntervalCeiling
|
||||
exchange.queue(final_response)
|
||||
elif protocol_opcode == InteractionModelOpcode.STATUS_RESPONSE:
|
||||
status_response, _ = interaction_model.StatusResponseMessage.decode(
|
||||
message.application_payload[0], message.application_payload[1:]
|
||||
status_response = interaction_model.StatusResponseMessage.decode(
|
||||
message.application_payload
|
||||
)
|
||||
print(
|
||||
f"Received Status Response on {message.session_id}/{message.exchange_id} ack {message.acknowledged_message_counter}: {status_response.Status!r}"
|
||||
|
|
@ -502,3 +492,6 @@ class CircuitMatter:
|
|||
else:
|
||||
print("Unknown protocol", message.protocol_id, message.protocol_opcode)
|
||||
print()
|
||||
|
||||
self.nonvolatile.commit()
|
||||
# TODO: Rollback on error?
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import binascii
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import secrets
|
||||
import socket
|
||||
import subprocess
|
||||
|
|
@ -114,13 +115,8 @@ class MDNSServer(DummyMDNS):
|
|||
subtypes=[],
|
||||
instance_name="",
|
||||
):
|
||||
for active_service in self.active_services.values():
|
||||
active_service.kill()
|
||||
subtypes = [f"--subtype={subtype}" for subtype in subtypes]
|
||||
txt_records = [f"{key}={value}" for key, value in txt_records.items()]
|
||||
if service_type in self.active_services:
|
||||
self.active_services[service_type].kill()
|
||||
del self.active_services[service_type]
|
||||
command = [
|
||||
"avahi-publish-service",
|
||||
*subtypes,
|
||||
|
|
@ -130,7 +126,7 @@ class MDNSServer(DummyMDNS):
|
|||
*txt_records,
|
||||
]
|
||||
print("running avahi", command)
|
||||
self.active_services[service_type] = subprocess.Popen(command)
|
||||
self.active_services[service_type + instance_name] = subprocess.Popen(command)
|
||||
if self.publish_address is None:
|
||||
command = [
|
||||
"avahi-publish-address",
|
||||
|
|
@ -226,23 +222,33 @@ class NeoPixel(on_off.OnOffLight):
|
|||
|
||||
|
||||
def run(replay_file=None):
|
||||
device_state = pathlib.Path("test_data/device_state.json")
|
||||
replay_device_state = pathlib.Path("test_data/replay_device_state.json")
|
||||
if replay_file:
|
||||
replay_lines = []
|
||||
with open(replay_file, "r") as f:
|
||||
device_state_fn = f.readline().strip()
|
||||
for line in f:
|
||||
replay_lines.append(json.loads(line))
|
||||
socketpool = ReplaySocketPool(replay_lines)
|
||||
mdns_server = DummyMDNS()
|
||||
random_source = ReplayRandom(replay_lines)
|
||||
# Reset device state to before the captured run
|
||||
device_state.write_text(pathlib.Path(device_state_fn).read_text())
|
||||
else:
|
||||
record_file = open("test_data/recorded_packets.jsonl", "w")
|
||||
timestamp = time.strftime("%Y%m%d-%H%M%S")
|
||||
record_file = open(f"test_data/recorded_packets-{timestamp}.jsonl", "w")
|
||||
device_state_fn = f"test_data/device_state-{timestamp}.json"
|
||||
record_file.write(f"{device_state_fn}\n")
|
||||
socketpool = RecordingSocketPool(record_file)
|
||||
mdns_server = MDNSServer()
|
||||
random_source = RecordingRandom(record_file)
|
||||
matter = cm.CircuitMatter(
|
||||
socketpool, mdns_server, random_source, "test_data/device_state.json"
|
||||
)
|
||||
led = NeoPixel()
|
||||
# Save device state before we run so replays can use it.
|
||||
replay_device_state = pathlib.Path(device_state_fn)
|
||||
replay_device_state.write_text(device_state.read_text())
|
||||
|
||||
matter = cm.CircuitMatter(socketpool, mdns_server, random_source, device_state)
|
||||
led = NeoPixel("neopixel1")
|
||||
matter.add_device(led)
|
||||
while True:
|
||||
matter.process_packets()
|
||||
|
|
|
|||
|
|
@ -24,13 +24,15 @@ class GroupKeyMulticastPolicyEnum(Enum8):
|
|||
class GroupKeySetStruct(tlv.Structure):
|
||||
GroupKeySetID = tlv.IntMember(0, signed=False, octets=2)
|
||||
GroupKeySecurityPolicy = tlv.EnumMember(1, GroupKeySetSecurityPolicyEnum)
|
||||
EpochKey0 = tlv.OctetStringMember(2, 16)
|
||||
EpochStartTime0 = tlv.IntMember(3, signed=False, octets=8)
|
||||
EpochKey1 = tlv.OctetStringMember(4, 16)
|
||||
EpochStartTime1 = tlv.IntMember(5, signed=False, octets=8)
|
||||
EpochKey2 = tlv.OctetStringMember(6, 16)
|
||||
EpochStartTime2 = tlv.IntMember(7, signed=False, octets=8)
|
||||
GroupKeyMulticastPolicy = tlv.EnumMember(8, GroupKeyMulticastPolicyEnum)
|
||||
EpochKey0 = tlv.OctetStringMember(2, 16, nullable=True)
|
||||
EpochStartTime0 = tlv.IntMember(3, signed=False, octets=8, nullable=True)
|
||||
EpochKey1 = tlv.OctetStringMember(4, 16, nullable=True)
|
||||
EpochStartTime1 = tlv.IntMember(5, signed=False, octets=8, nullable=True)
|
||||
EpochKey2 = tlv.OctetStringMember(6, 16, nullable=True)
|
||||
EpochStartTime2 = tlv.IntMember(7, signed=False, octets=8, nullable=True)
|
||||
GroupKeyMulticastPolicy = tlv.EnumMember(
|
||||
8, GroupKeyMulticastPolicyEnum, nullable=True
|
||||
)
|
||||
|
||||
|
||||
class GroupKeyManagementCluster(Cluster):
|
||||
|
|
@ -48,7 +50,7 @@ class GroupKeyManagementCluster(Cluster):
|
|||
class KeySetWrite(tlv.Structure):
|
||||
GroupKeySet = tlv.StructMember(0, GroupKeySetStruct)
|
||||
|
||||
group_key_map = ListAttribute(0, GroupKeyMapStruct, default=[])
|
||||
group_key_map = ListAttribute(0, GroupKeyMapStruct, default=[], N_nonvolatile=True)
|
||||
group_table = ListAttribute(1, GroupInfoMapStruct, default=[])
|
||||
max_groups_per_fabric = NumberAttribute(2, signed=False, bits=16, default=0)
|
||||
max_group_keys_per_fabric = NumberAttribute(3, signed=False, bits=16, default=1)
|
||||
|
|
|
|||
|
|
@ -99,12 +99,20 @@ class NodeOperationalCredentialsCluster(Cluster):
|
|||
class AddTrustedRootCertificate(tlv.Structure):
|
||||
RootCACertificate = tlv.OctetStringMember(0, 400)
|
||||
|
||||
nocs = ListAttribute(0, NOCStruct, N_nonvolatile=True, C_changes_omitted=True)
|
||||
fabrics = ListAttribute(1, FabricDescriptorStruct, N_nonvolatile=True)
|
||||
nocs = ListAttribute(
|
||||
0, NOCStruct, N_nonvolatile=True, C_changes_omitted=True, default=[]
|
||||
)
|
||||
fabrics = ListAttribute(1, FabricDescriptorStruct, N_nonvolatile=True, default=[])
|
||||
supported_fabrics = NumberAttribute(2, signed=False, bits=8, F_fixed=True)
|
||||
commissioned_fabrics = NumberAttribute(3, signed=False, bits=8, N_nonvolatile=True)
|
||||
commissioned_fabrics = NumberAttribute(
|
||||
3, signed=False, bits=8, N_nonvolatile=True, default=0
|
||||
)
|
||||
trusted_root_certificates = ListAttribute(
|
||||
4, tlv.OctetStringMember(None, 400), N_nonvolatile=True, C_changes_omitted=True
|
||||
4,
|
||||
tlv.OctetStringMember(None, 400),
|
||||
N_nonvolatile=True,
|
||||
C_changes_omitted=True,
|
||||
default=[],
|
||||
)
|
||||
# This attribute is weird because it is fabric sensitive but not marked as such.
|
||||
# Cluster sets current_fabric_index for use in fabric sensitive attributes and
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import binascii
|
||||
import enum
|
||||
import inspect
|
||||
import random
|
||||
|
|
@ -9,6 +10,8 @@ from typing import Iterable, Union
|
|||
from . import interaction_model
|
||||
from . import tlv
|
||||
|
||||
ATTRIBUTES_KEY = "a"
|
||||
|
||||
|
||||
class Enum8(enum.IntEnum):
|
||||
pass
|
||||
|
|
@ -80,6 +83,7 @@ class Attribute:
|
|||
self.optional = optional
|
||||
self.feature = feature
|
||||
self.nullable = X_nullable
|
||||
self.nonvolatile = N_nonvolatile
|
||||
|
||||
def __get__(self, instance, cls):
|
||||
v = instance._attribute_values.get(self.id, None)
|
||||
|
|
@ -92,8 +96,16 @@ class Attribute:
|
|||
if old_value == value:
|
||||
return
|
||||
instance._attribute_values[self.id] = value
|
||||
if self.nonvolatile:
|
||||
instance._nonvolatile[ATTRIBUTES_KEY][hex(self.id)] = self.to_json(value)
|
||||
instance.data_version += 1
|
||||
|
||||
def to_json(self, value):
|
||||
return value
|
||||
|
||||
def from_json(self, value):
|
||||
return value
|
||||
|
||||
def encode(self, value) -> bytes:
|
||||
if value is None and self.nullable:
|
||||
return b"\x14" # No tag, NULL
|
||||
|
|
@ -145,6 +157,35 @@ class EnumAttribute(NumberAttribute):
|
|||
super().__init__(_id, signed=False, bits=bits, **kwargs)
|
||||
|
||||
|
||||
class _PersistentList:
|
||||
def __init__(self, wrapped_list, attribute, instance):
|
||||
self._list = wrapped_list
|
||||
self._instance = instance
|
||||
self._attribute = attribute
|
||||
|
||||
def append(self, value):
|
||||
self._list.append(value)
|
||||
self._instance._nonvolatile[ATTRIBUTES_KEY][hex(self._attribute.id)] = (
|
||||
self._attribute.to_json(self._list)
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self._list[index]
|
||||
|
||||
def __setitem__(self, index, value):
|
||||
self._list[index] = value
|
||||
self._dirty = True
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._list)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._list)
|
||||
|
||||
def __str__(self):
|
||||
return "persistent" + str(self._list)
|
||||
|
||||
|
||||
class ListAttribute(Attribute):
|
||||
def __init__(self, _id, element_type, **kwargs):
|
||||
if inspect.isclass(element_type) and issubclass(element_type, enum.Enum):
|
||||
|
|
@ -157,6 +198,28 @@ class ListAttribute(Attribute):
|
|||
kwargs["default"] = list(kwargs["default"])
|
||||
super().__init__(_id, **kwargs)
|
||||
|
||||
def __get__(self, instance, cls):
|
||||
v = super().__get__(instance, cls)
|
||||
if self.nonvolatile and v is not None and not isinstance(v, _PersistentList):
|
||||
# Wrap the list in an object that tracks changes and writes them to nonvolatile.
|
||||
p = _PersistentList(v, self, instance)
|
||||
instance._attribute_values[self.id] = p
|
||||
return p
|
||||
return v
|
||||
|
||||
def to_json(self, value):
|
||||
return [
|
||||
binascii.b2a_base64(self._element_type.encode(v), newline=False).decode(
|
||||
"utf-8"
|
||||
)
|
||||
for v in value
|
||||
]
|
||||
|
||||
def from_json(self, value):
|
||||
return [
|
||||
self._element_type.decode(memoryview(binascii.a2b_base64(v))) for v in value
|
||||
]
|
||||
|
||||
def _encode(self, value) -> bytes:
|
||||
return self.tlv_type.encode(value)
|
||||
|
||||
|
|
@ -243,6 +306,23 @@ class Cluster:
|
|||
if not field_name.startswith("_") and isinstance(descriptor, Attribute):
|
||||
yield field_name, descriptor
|
||||
|
||||
def restore(self, nonvolatile):
|
||||
self._nonvolatile = nonvolatile
|
||||
|
||||
if ATTRIBUTES_KEY not in nonvolatile:
|
||||
nonvolatile[ATTRIBUTES_KEY] = {}
|
||||
for field_name, descriptor in self._attributes():
|
||||
if descriptor.nonvolatile:
|
||||
print(field_name, nonvolatile[ATTRIBUTES_KEY])
|
||||
if hex(descriptor.id) in nonvolatile[ATTRIBUTES_KEY]:
|
||||
# Update our live value
|
||||
self._attribute_values[descriptor.id] = descriptor.from_json(
|
||||
nonvolatile[ATTRIBUTES_KEY][hex(descriptor.id)]
|
||||
)
|
||||
else:
|
||||
# Store the default
|
||||
nonvolatile[ATTRIBUTES_KEY][hex(descriptor.id)] = descriptor.default
|
||||
|
||||
def get_attribute_data(
|
||||
self, session, path
|
||||
) -> typing.List[interaction_model.AttributeDataIB]:
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ class OnOffLight(simple_device.SimpleDevice):
|
|||
DEVICE_TYPE_ID = 0x0100
|
||||
REVISION = 3
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self, name):
|
||||
super().__init__(name)
|
||||
|
||||
self._identify = Identify()
|
||||
self.servers.append(self._identify)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ from circuitmatter.clusters.system_model import binding, descriptor, user_label
|
|||
|
||||
|
||||
class SimpleDevice:
|
||||
def __init__(self):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.servers = []
|
||||
self.descriptor = descriptor.DescriptorCluster()
|
||||
device_type = descriptor.DescriptorCluster.DeviceTypeStruct()
|
||||
|
|
@ -19,3 +20,12 @@ class SimpleDevice:
|
|||
|
||||
self.user_label = user_label.UserLabelCluster()
|
||||
self.servers.append(self.user_label)
|
||||
|
||||
def restore(self, nonvolatile):
|
||||
"""Restore device state from the nonvolatile dictionary and hang onto it for any updates."""
|
||||
self.nonvolatile = nonvolatile
|
||||
for server in self.servers:
|
||||
cluster_hex = hex(server.CLUSTER_ID)
|
||||
if cluster_hex not in nonvolatile:
|
||||
nonvolatile[cluster_hex] = {}
|
||||
server.restore(nonvolatile[cluster_hex])
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import binascii
|
||||
import ecdsa
|
||||
from ecdsa import der
|
||||
import hashlib
|
||||
|
|
@ -118,20 +119,58 @@ class _NodeOperationalCredentialsCluster(NodeOperationalCredentialsCluster):
|
|||
self.pending_root_cert = None
|
||||
self.pending_signing_key = None
|
||||
|
||||
self.nocs = []
|
||||
self.fabrics = []
|
||||
self.supported_fabrics = 10
|
||||
self.commissioned_fabrics = 0
|
||||
self.trusted_root_certificates = []
|
||||
|
||||
self.root_certs = []
|
||||
self.compressed_fabric_ids = []
|
||||
self.noc_keys = []
|
||||
self.encoded_noc_keys = []
|
||||
|
||||
self.mdns_server = mdns_server
|
||||
self.port = port
|
||||
self.random = random_source
|
||||
|
||||
def restore(self, nonvolatile):
|
||||
super().restore(nonvolatile)
|
||||
|
||||
if "pk" not in nonvolatile:
|
||||
return
|
||||
|
||||
self.root_certs = []
|
||||
self.compressed_fabric_ids = []
|
||||
self.noc_keys = []
|
||||
self.encoded_noc_keys = nonvolatile["pk"]
|
||||
|
||||
for i, encoded_root_cert in enumerate(self.trusted_root_certificates):
|
||||
root_cert = crypto.MatterCertificate.decode(encoded_root_cert)
|
||||
|
||||
self.root_certs.append(root_cert)
|
||||
fabric = self.fabrics[i]
|
||||
fabric_id = struct.pack(">Q", fabric.FabricID)
|
||||
compressed_fabric_id = crypto.KDF(
|
||||
root_cert.ec_pub_key[1:], fabric_id, b"CompressedFabric", 64
|
||||
)
|
||||
self.compressed_fabric_ids.append(compressed_fabric_id)
|
||||
signing_key = ecdsa.keys.SigningKey.from_string(
|
||||
binascii.a2b_base64(self.encoded_noc_keys[i]),
|
||||
curve=ecdsa.NIST256p,
|
||||
hashfunc=hashlib.sha256,
|
||||
)
|
||||
self.noc_keys.append(signing_key)
|
||||
|
||||
node_id = struct.pack(">Q", fabric.NodeID).hex().upper()
|
||||
compressed_fabric_id = compressed_fabric_id.hex().upper()
|
||||
instance_name = f"{compressed_fabric_id}-{node_id}"
|
||||
self.mdns_server.advertise_service(
|
||||
"_matter",
|
||||
"_tcp",
|
||||
self.port,
|
||||
instance_name=instance_name,
|
||||
subtypes=[
|
||||
f"_I{compressed_fabric_id}._sub._matter._tcp",
|
||||
],
|
||||
)
|
||||
|
||||
def certificate_chain_request(
|
||||
self,
|
||||
session,
|
||||
|
|
@ -260,13 +299,7 @@ class _NodeOperationalCredentialsCluster(NodeOperationalCredentialsCluster):
|
|||
self, session, args: NodeOperationalCredentialsCluster.AddNOC
|
||||
) -> NodeOperationalCredentialsCluster.NOCResponse:
|
||||
# Section 11.18.6.8
|
||||
noc, _ = crypto.MatterCertificate.decode(
|
||||
args.NOCValue[0], memoryview(args.NOCValue)[1:]
|
||||
)
|
||||
if args.ICACValue:
|
||||
icac, _ = crypto.MatterCertificate.decode(
|
||||
args.ICACValue[0], memoryview(args.ICACValue)[1:]
|
||||
)
|
||||
noc = crypto.MatterCertificate.decode(args.NOCValue)
|
||||
|
||||
response = NodeOperationalCredentialsCluster.NOCResponse()
|
||||
|
||||
|
|
@ -291,9 +324,7 @@ class _NodeOperationalCredentialsCluster(NodeOperationalCredentialsCluster):
|
|||
self.nocs.append(noc_struct)
|
||||
|
||||
# Get the root cert public key so we can create the compressed fabric id.
|
||||
root_cert, _ = crypto.MatterCertificate.decode(
|
||||
self.pending_root_cert[0], memoryview(self.pending_root_cert)[1:]
|
||||
)
|
||||
root_cert = crypto.MatterCertificate.decode(self.pending_root_cert)
|
||||
|
||||
# Store the fabric
|
||||
new_fabric = NodeOperationalCredentialsCluster.FabricDescriptorStruct()
|
||||
|
|
@ -317,6 +348,12 @@ class _NodeOperationalCredentialsCluster(NodeOperationalCredentialsCluster):
|
|||
self.commissioned_fabrics += 1
|
||||
|
||||
self.noc_keys.append(self.pending_signing_key)
|
||||
self.encoded_noc_keys.append(
|
||||
binascii.b2a_base64(
|
||||
self.pending_signing_key.to_string(), newline=False
|
||||
).decode("utf-8")
|
||||
)
|
||||
self._nonvolatile["pk"] = self.encoded_noc_keys
|
||||
|
||||
self.trusted_root_certificates.append(self.pending_root_cert)
|
||||
|
||||
|
|
@ -365,11 +402,29 @@ class _GroupKeyManagementCluster(GroupKeyManagementCluster):
|
|||
def __init__(self):
|
||||
super().__init__()
|
||||
self.key_sets = []
|
||||
self._encoded_key_sets = []
|
||||
|
||||
def restore(self, nonvolatile):
|
||||
super().restore(nonvolatile)
|
||||
|
||||
if "gks" not in nonvolatile:
|
||||
return
|
||||
|
||||
self._encoded_key_sets = nonvolatile["gks"]
|
||||
self.key_sets = [
|
||||
GroupKeySetStruct.decode(binascii.a2b_base64(v)) for v in nonvolatile["gks"]
|
||||
]
|
||||
|
||||
def key_set_write(
|
||||
self, session, args: GroupKeyManagementCluster.KeySetWrite
|
||||
) -> interaction_model.StatusCode:
|
||||
self.key_sets.append(args.GroupKeySet)
|
||||
self._encoded_key_sets.append(
|
||||
binascii.b2a_base64(args.GroupKeySet.encode(), newline=False).decode(
|
||||
"utf-8"
|
||||
)
|
||||
)
|
||||
self._nonvolatile["gks"] = self._encoded_key_sets
|
||||
return interaction_model.StatusCode.SUCCESS
|
||||
|
||||
|
||||
|
|
@ -378,7 +433,7 @@ class RootNode(simple_device.SimpleDevice):
|
|||
REVISION = 2
|
||||
|
||||
def __init__(self, random_source, mdns_server, port, vendor_id, product_id):
|
||||
super().__init__()
|
||||
super().__init__("root")
|
||||
|
||||
basic_info = BasicInformationCluster()
|
||||
basic_info.vendor_id = vendor_id
|
||||
|
|
@ -421,3 +476,7 @@ class RootNode(simple_device.SimpleDevice):
|
|||
|
||||
self.user_label = user_label.UserLabelCluster()
|
||||
self.servers.append(self.user_label)
|
||||
|
||||
@property
|
||||
def fabric_count(self):
|
||||
return self.noc.commissioned_fabrics
|
||||
|
|
|
|||
69
circuitmatter/nonvolatile.py
Normal file
69
circuitmatter/nonvolatile.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
import json
|
||||
|
||||
|
||||
class PersistentDictionary:
|
||||
"""This acts like a dictionary and is persisted when values change."""
|
||||
|
||||
def __init__(self, filename=None, root=None, state=None):
|
||||
self.filename = filename
|
||||
self.root = root
|
||||
self.dirty = False
|
||||
self.persisted = {}
|
||||
self._state: dict
|
||||
if self.root is None and filename:
|
||||
self.rollback()
|
||||
elif state is not None:
|
||||
self._state = state
|
||||
else:
|
||||
raise ValueError("Provide filename or (root and state)")
|
||||
|
||||
def wrap(self, value):
|
||||
return value
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self._state[key] = value
|
||||
if self.root:
|
||||
self.root.dirty = True
|
||||
else:
|
||||
self.dirty = True
|
||||
|
||||
def __getitem__(self, key):
|
||||
value = self._state[key]
|
||||
if isinstance(value, dict):
|
||||
if key not in self.persisted:
|
||||
root = self.root if self.root else self
|
||||
self.persisted[key] = PersistentDictionary(root=root, state=value)
|
||||
return self.persisted[key]
|
||||
return value
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self._state[key]
|
||||
if self.root:
|
||||
self.root.dirty = True
|
||||
else:
|
||||
self.dirty = True
|
||||
|
||||
def keys(self):
|
||||
return self._state.keys()
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._state)
|
||||
|
||||
def commit(self):
|
||||
if not self.dirty:
|
||||
print("not dirty")
|
||||
return
|
||||
if self.root:
|
||||
print("root commit")
|
||||
self.root.commit()
|
||||
return
|
||||
print("commit")
|
||||
print(self._state)
|
||||
with open(self.filename, "w") as state_file:
|
||||
json.dump(self._state, state_file, indent=1)
|
||||
self.dirty = False
|
||||
|
||||
def rollback(self):
|
||||
print("rollback")
|
||||
with open(self.filename, "r") as state_file:
|
||||
self._state = json.load(state_file)
|
||||
|
|
@ -629,7 +629,7 @@ class SessionManager:
|
|||
decrypted = s3k_cipher.decrypt(b"NCASE_Sigma3N", sigma3.encrypted3, b"")
|
||||
except cryptography.exceptions.InvalidTag:
|
||||
return SecureChannelProtocolCode.INVALID_PARAMETER
|
||||
sigma3_tbe, _ = case.Sigma3TbeData.decode(decrypted[0], decrypted[1:])
|
||||
sigma3_tbe = case.Sigma3TbeData.decode(decrypted)
|
||||
|
||||
# TODO: Implement checks 4a-4d. INVALID_PARAMETER if they fail.
|
||||
|
||||
|
|
@ -640,9 +640,7 @@ class SessionManager:
|
|||
secure_session_context = exchange.secure_session_context
|
||||
|
||||
peer_noc = sigma3_tbe.initiatorNOC
|
||||
peer_noc, _ = crypto.MatterCertificate.decode(
|
||||
peer_noc[0], memoryview(peer_noc)[1:]
|
||||
)
|
||||
peer_noc = crypto.MatterCertificate.decode(peer_noc)
|
||||
secure_session_context.peer_node_id = peer_noc.subject.matter_node_id
|
||||
|
||||
exchange.transcript_hash.update(sigma3.encode())
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ def decode_element(control_octet, buffer, offset, depth):
|
|||
value = None
|
||||
offset = offset
|
||||
else:
|
||||
result = member_class.decode(control_octet, buffer, offset, depth)
|
||||
result = member_class.decode_member(control_octet, buffer, offset, depth)
|
||||
value, offset = result
|
||||
return value, offset
|
||||
|
||||
|
|
@ -165,7 +165,15 @@ class Structure(Container):
|
|||
return offset + 1
|
||||
|
||||
@classmethod
|
||||
def decode(cls, control_octet, buffer, offset=0, depth=0) -> tuple[dict, int]:
|
||||
def decode(cls, buffer: memoryview, offset=0) -> Structure:
|
||||
control_octet = buffer[offset]
|
||||
values, offset = cls.decode_member(control_octet, buffer, offset + 1)
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def decode_member(
|
||||
cls, control_octet, buffer, offset=0, depth=0
|
||||
) -> tuple[dict, int]:
|
||||
values = {}
|
||||
buffer = memoryview(buffer)
|
||||
while offset < len(buffer) and buffer[offset] != ElementType.END_OF_CONTAINER:
|
||||
|
|
@ -351,8 +359,12 @@ class Member(ABC, Generic[_T, _OPT, _NULLABLE]):
|
|||
return new_offset
|
||||
return offset
|
||||
|
||||
def decode(self, buffer: memoryview, offset: int = 0) -> _T:
|
||||
"Return the decoded value at `offset` in `buffer`"
|
||||
return self.decode_member(buffer[offset], buffer, offset + 1)[0]
|
||||
|
||||
@abstractmethod
|
||||
def decode(
|
||||
def decode_member(
|
||||
self, control_octet: int, buffer: memoryview, offset: int = 0
|
||||
) -> (_T, int):
|
||||
"Return the decoded value at `offset` in `buffer`. `offset` is after the tag (but before any length)"
|
||||
|
|
@ -455,7 +467,7 @@ class NumberMember(Member[_NT, _OPT, _NULLABLE], Generic[_NT, _OPT, _NULLABLE]):
|
|||
super().__set__(obj, value) # type: ignore # self inference issues
|
||||
|
||||
@staticmethod
|
||||
def decode(control_octet, buffer, offset=0, depth=0) -> tuple[_NT, int]:
|
||||
def decode_member(control_octet, buffer, offset=0, depth=0) -> tuple[_NT, int]:
|
||||
element_type = control_octet & 0x1F
|
||||
element_category = element_type >> 2
|
||||
if element_category == 0 or element_category == 1:
|
||||
|
|
@ -595,7 +607,7 @@ class BoolMember(Member[bool, _OPT, _NULLABLE]):
|
|||
max_value_length = 0
|
||||
|
||||
@staticmethod
|
||||
def decode(control_octet, buffer, offset=0, depth=0):
|
||||
def decode_member(control_octet, buffer, offset=0, depth=0):
|
||||
return (control_octet & 1 == 1, offset)
|
||||
|
||||
def print(self, value):
|
||||
|
|
@ -682,7 +694,7 @@ class OctetStringMember(StringMember[bytes, _OPT, _NULLABLE]):
|
|||
_base_element_type: ElementType = ElementType.OCTET_STRING
|
||||
|
||||
@staticmethod
|
||||
def decode(control_octet, buffer, offset=0, depth=0):
|
||||
def decode_member(control_octet, buffer, offset=0, depth=0):
|
||||
length, offset = StringMember.parse_length(control_octet, buffer, offset)
|
||||
return (buffer[offset : offset + length].tobytes(), offset + length)
|
||||
|
||||
|
|
@ -691,7 +703,7 @@ class UTF8StringMember(StringMember[str, _OPT, _NULLABLE]):
|
|||
_base_element_type = ElementType.UTF8_STRING
|
||||
|
||||
@staticmethod
|
||||
def decode(control_octet, buffer, offset=0, depth=0):
|
||||
def decode_member(control_octet, buffer, offset=0, depth=0):
|
||||
length, offset = StringMember.parse_length(control_octet, buffer, offset)
|
||||
return (
|
||||
buffer[offset : offset + length].tobytes().decode("utf-8"),
|
||||
|
|
@ -723,8 +735,8 @@ class StructMember(Member[_TLVStruct, _OPT, _NULLABLE]):
|
|||
super().__init__(tag, optional=optional, nullable=nullable, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def decode(control_octet, buffer, offset=0, depth=0):
|
||||
value, offset = Structure.decode(control_octet, buffer, offset, depth)
|
||||
def decode_member(control_octet, buffer, offset=0, depth=0):
|
||||
value, offset = Structure.decode_member(control_octet, buffer, offset, depth)
|
||||
return value, offset + 1
|
||||
|
||||
def print(self, value):
|
||||
|
|
@ -765,7 +777,7 @@ class ArrayMember(Member[_TLVStruct, _OPT, _NULLABLE]):
|
|||
super().__init__(tag, optional=optional, nullable=nullable, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def decode(control_octet, buffer, offset=0, depth=0):
|
||||
def decode_member(control_octet, buffer, offset=0, depth=0):
|
||||
entries = []
|
||||
while buffer[offset] != ElementType.END_OF_CONTAINER:
|
||||
control_octet = buffer[offset]
|
||||
|
|
@ -931,7 +943,7 @@ class ListMember(Member):
|
|||
super().__init__(tag, optional=optional, nullable=nullable, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def decode(control_octet, buffer, offset=0, depth=0):
|
||||
def decode_member(control_octet, buffer, offset=0, depth=0):
|
||||
raw_list = []
|
||||
while buffer[offset] != ElementType.END_OF_CONTAINER:
|
||||
control_octet = buffer[offset]
|
||||
|
|
@ -962,7 +974,7 @@ class ListMember(Member):
|
|||
class AnythingMember(Member):
|
||||
"""Stores a TLV encoded value."""
|
||||
|
||||
def decode(self, control_octet, buffer, offset=0):
|
||||
def decode_member(self, control_octet, buffer, offset=0):
|
||||
return None
|
||||
|
||||
def print(self, value):
|
||||
|
|
|
|||
7
test_data/device_state-uncommissioned.json
Normal file
7
test_data/device_state-uncommissioned.json
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"discriminator": 3840,
|
||||
"passcode": 67202583,
|
||||
"iteration-count": 10000,
|
||||
"salt": "5uCP0ITHYzI9qBEe6hfU4HfY3y7VopSk0qNvhvznhiQ=",
|
||||
"verifier": "0xGqxJFBr/ViQt3lv1Yw5F0GcPBAtFFvXB+EcIIjH5cEsjkPZHDQyFWjA6Ide+2gafYnZgIy6gJBgdJOlD8htAZKe0i6nIhT/ADsBWH4CvZcl37n/ofEEECWSEBV4vy/0A=="
|
||||
}
|
||||
|
|
@ -24,12 +24,12 @@ class Bool(tlv.Structure):
|
|||
|
||||
class TestBool:
|
||||
def test_bool_false_decode(self):
|
||||
s, _ = Bool.decode(0x15, b"\x28\x00\x18")
|
||||
s = Bool.decode(b"\x15\x28\x00\x18")
|
||||
assert str(s) == "{\n b = false\n}"
|
||||
assert s.b is False
|
||||
|
||||
def test_bool_true_decode(self):
|
||||
s, _ = Bool.decode(0x15, b"\x29\x00\x18")
|
||||
s = Bool.decode(b"\x15\x29\x00\x18")
|
||||
assert str(s) == "{\n b = true\n}"
|
||||
assert s.b is True
|
||||
|
||||
|
|
@ -72,29 +72,27 @@ class SignedIntEightOctet(tlv.Structure):
|
|||
# 03 00 90 2f 50 09 00 00 00
|
||||
class TestSignedInt:
|
||||
def test_signed_int_42_decode(self):
|
||||
s, _ = SignedIntOneOctet.decode(0x15, b"\x20\x00\x2a")
|
||||
s = SignedIntOneOctet.decode(b"\x15\x20\x00\x2a")
|
||||
assert str(s) == "{\n i = 42\n}"
|
||||
assert s.i == 42
|
||||
|
||||
def test_signed_int_negative_17_decode(self):
|
||||
s, _ = SignedIntOneOctet.decode(0x15, b"\x20\x00\xef")
|
||||
s = SignedIntOneOctet.decode(b"\x15\x20\x00\xef")
|
||||
assert str(s) == "{\n i = -17\n}"
|
||||
assert s.i == -17
|
||||
|
||||
def test_signed_int_42_two_octet_decode(self):
|
||||
s, _ = SignedIntTwoOctet.decode(0x15, b"\x21\x00\x2a\x00")
|
||||
s = SignedIntTwoOctet.decode(b"\x15\x21\x00\x2a\x00")
|
||||
assert str(s) == "{\n i = 42\n}"
|
||||
assert s.i == 42
|
||||
|
||||
def test_signed_int_negative_170000_decode(self):
|
||||
s, _ = SignedIntFourOctet.decode(0x15, b"\x22\x00\xf0\x67\xfd\xff")
|
||||
s = SignedIntFourOctet.decode(b"\x15\x22\x00\xf0\x67\xfd\xff")
|
||||
assert str(s) == "{\n i = -170000\n}"
|
||||
assert s.i == -170000
|
||||
|
||||
def test_signed_int_40000000000_decode(self):
|
||||
s, _ = SignedIntEightOctet.decode(
|
||||
0x15, b"\x23\x00\x00\x90\x2f\x50\x09\x00\x00\x00"
|
||||
)
|
||||
s = SignedIntEightOctet.decode(b"\x15\x23\x00\x00\x90\x2f\x50\x09\x00\x00\x00")
|
||||
assert str(s) == "{\n i = 40000000000\n}"
|
||||
assert s.i == 40000000000
|
||||
|
||||
|
|
@ -158,7 +156,7 @@ class UnsignedIntOneOctet(tlv.Structure):
|
|||
# 04 2a
|
||||
class TestUnsignedInt:
|
||||
def test_unsigned_int_42_decode(self):
|
||||
s, _ = UnsignedIntOneOctet.decode(0x15, b"\x24\x00\x2a\x18")
|
||||
s = UnsignedIntOneOctet.decode(b"\x15\x24\x00\x2a\x18")
|
||||
assert str(s) == "{\n i = 42U\n}"
|
||||
assert s.i == 42
|
||||
|
||||
|
|
@ -197,7 +195,7 @@ class TestUnsignedInt:
|
|||
s.i = v
|
||||
buffer = s.encode().tobytes()
|
||||
|
||||
s2, _ = UnsignedIntOneOctet.decode(0x15, buffer[1:])
|
||||
s2 = UnsignedIntOneOctet.decode(buffer)
|
||||
|
||||
assert s2.i == s.i
|
||||
assert str(s2) == str(s)
|
||||
|
|
@ -228,12 +226,12 @@ class UTF8StringOneOctet(tlv.Structure):
|
|||
|
||||
class TestUTF8String:
|
||||
def test_utf8_string_hello_decode(self):
|
||||
s, _ = UTF8StringOneOctet.decode(0x15, b"\x2c\x00\x06Hello!")
|
||||
s = UTF8StringOneOctet.decode(b"\x15\x2c\x00\x06Hello!")
|
||||
assert str(s) == '{\n s = "Hello!"\n}'
|
||||
assert s.s == "Hello!"
|
||||
|
||||
def test_utf8_string_tschs_decode(self):
|
||||
s, _ = UTF8StringOneOctet.decode(0x15, b"\x2c\x00\x07Tsch\xc3\xbcs")
|
||||
s = UTF8StringOneOctet.decode(b"\x15\x2c\x00\x07Tsch\xc3\xbcs")
|
||||
assert str(s) == '{\n s = "Tschüs"\n}'
|
||||
assert s.s == "Tschüs"
|
||||
|
||||
|
|
@ -254,7 +252,7 @@ class TestUTF8String:
|
|||
s.s = v
|
||||
buffer = s.encode().tobytes()
|
||||
|
||||
s2, _ = UTF8StringOneOctet.decode(0x15, buffer[1:])
|
||||
s2 = UTF8StringOneOctet.decode(buffer)
|
||||
|
||||
assert s2.s == s.s
|
||||
assert str(s2) == str(s)
|
||||
|
|
@ -268,7 +266,7 @@ class OctetStringOneOctet(tlv.Structure):
|
|||
|
||||
class TestOctetString:
|
||||
def test_octet_string_decode(self):
|
||||
s, _ = OctetStringOneOctet.decode(0x15, b"\x30\x00\x05\x00\x01\x02\x03\x04\x18")
|
||||
s = OctetStringOneOctet.decode(b"\x15\x30\x00\x05\x00\x01\x02\x03\x04\x18")
|
||||
assert str(s) == "{\n s = 00 01 02 03 04\n}"
|
||||
assert s.s == b"\x00\x01\x02\x03\x04"
|
||||
|
||||
|
|
@ -283,7 +281,7 @@ class TestOctetString:
|
|||
s.s = v
|
||||
buffer = s.encode().tobytes()
|
||||
|
||||
s2, _ = OctetStringOneOctet.decode(0x15, buffer[1:])
|
||||
s2 = OctetStringOneOctet.decode(buffer)
|
||||
|
||||
assert s2.s == s.s
|
||||
assert str(s2) == str(s)
|
||||
|
|
@ -304,7 +302,7 @@ class NotNull(tlv.Structure):
|
|||
|
||||
class TestNull:
|
||||
def test_null_decode(self):
|
||||
s, _ = Null.decode(0x15, b"\x34\x00\x18")
|
||||
s = Null.decode(b"\x15\x34\x00\x18")
|
||||
assert str(s) == "{\n n = null\n}"
|
||||
assert s.n is None
|
||||
|
||||
|
|
@ -352,28 +350,28 @@ class FloatDouble(tlv.Structure):
|
|||
|
||||
class TestFloatSingle:
|
||||
def test_precision_float_0_0_decode(self):
|
||||
s, _ = FloatSingle.decode(0x15, b"\x2a\x00\x00\x00\x00\x00\x18")
|
||||
s = FloatSingle.decode(b"\x15\x2a\x00\x00\x00\x00\x00\x18")
|
||||
assert str(s) == "{\n f = 0.0\n}"
|
||||
assert s.f == 0.0
|
||||
|
||||
def test_precision_float_1_3_decode(self):
|
||||
s, _ = FloatSingle.decode(0x15, b"\x2a\x00\xab\xaa\xaa\x3e\x18")
|
||||
s = FloatSingle.decode(b"\x15\x2a\x00\xab\xaa\xaa\x3e\x18")
|
||||
# assert str(s) == "{\n f = 0.3333333432674408\n}"
|
||||
f = s.f
|
||||
assert math.isclose(f, 1.0 / 3.0, rel_tol=1e-06)
|
||||
|
||||
def test_precision_float_17_9_decode(self):
|
||||
s, _ = FloatSingle.decode(0x15, b"\x2a\x00\x33\x33\x8f\x41\x18")
|
||||
s = FloatSingle.decode(b"\x15\x2a\x00\x33\x33\x8f\x41\x18")
|
||||
assert str(s) == "{\n f = 17.899999618530273\n}"
|
||||
assert math.isclose(s.f, 17.9, rel_tol=1e-06)
|
||||
|
||||
def test_precision_float_infinity_decode(self):
|
||||
s, _ = FloatSingle.decode(0x15, b"\x2a\x00\x00\x00\x80\x7f\x18")
|
||||
s = FloatSingle.decode(b"\x15\x2a\x00\x00\x00\x80\x7f\x18")
|
||||
assert str(s) == "{\n f = inf\n}"
|
||||
assert math.isinf(s.f)
|
||||
|
||||
def test_precision_float_negative_infinity_decode(self):
|
||||
s, _ = FloatSingle.decode(0x15, b"\x2a\x00\x00\x00\x80\xff\x18")
|
||||
s = FloatSingle.decode(b"\x15\x2a\x00\x00\x00\x80\xff\x18")
|
||||
assert str(s) == "{\n f = -inf\n}"
|
||||
assert math.isinf(s.f)
|
||||
|
||||
|
|
@ -408,7 +406,7 @@ class TestFloatSingle:
|
|||
s.f = v
|
||||
buffer = s.encode().tobytes()
|
||||
|
||||
s2, _ = FloatDouble.decode(0x15, buffer[1:])
|
||||
s2 = FloatDouble.decode(buffer)
|
||||
|
||||
assert (
|
||||
(math.isnan(s.f) and math.isnan(s2.f))
|
||||
|
|
@ -431,7 +429,7 @@ class TestFloatSingle:
|
|||
buffer = s.encode().tobytes()
|
||||
print("Buffer", buffer.hex(" "))
|
||||
|
||||
s2, _ = FloatSingle.decode(0x15, buffer[1:])
|
||||
s2 = FloatSingle.decode(buffer)
|
||||
|
||||
assert (math.isnan(s.f) and math.isnan(s2.f)) or math.isclose(
|
||||
s2.f, s.f, rel_tol=1e-7, abs_tol=1e-9
|
||||
|
|
@ -440,28 +438,28 @@ class TestFloatSingle:
|
|||
|
||||
class TestFloatDouble:
|
||||
def test_precision_float_0_0_decode(self):
|
||||
s, _ = FloatDouble.decode(0x15, b"\x2b\x00\x00\x00\x00\x00\x00\x00\x00\x00")
|
||||
s = FloatDouble.decode(b"\x15\x2b\x00\x00\x00\x00\x00\x00\x00\x00\x00")
|
||||
assert str(s) == "{\n f = 0.0\n}"
|
||||
assert s.f == 0.0
|
||||
|
||||
def test_precision_float_1_3_decode(self):
|
||||
s, _ = FloatDouble.decode(0x15, b"\x2b\x00\x55\x55\x55\x55\x55\x55\xd5\x3f")
|
||||
s = FloatDouble.decode(b"\x15\x2b\x00\x55\x55\x55\x55\x55\x55\xd5\x3f")
|
||||
# assert str(s) == "{\n f = 0.3333333333333333\n}"
|
||||
f = s.f
|
||||
assert math.isclose(f, 1.0 / 3.0, rel_tol=1e-06)
|
||||
|
||||
def test_precision_float_17_9_decode(self):
|
||||
s, _ = FloatDouble.decode(0x15, b"\x2b\x00\x66\x66\x66\x66\x66\xe6\x31\x40")
|
||||
s = FloatDouble.decode(b"\x15\x2b\x00\x66\x66\x66\x66\x66\xe6\x31\x40")
|
||||
assert str(s) == "{\n f = 17.9\n}"
|
||||
assert math.isclose(s.f, 17.9, rel_tol=1e-06)
|
||||
|
||||
def test_precision_float_infinity_decode(self):
|
||||
s, _ = FloatDouble.decode(0x15, b"\x2b\x00\x00\x00\x00\x00\x00\x00\xf0\x7f")
|
||||
s = FloatDouble.decode(b"\x15\x2b\x00\x00\x00\x00\x00\x00\x00\xf0\x7f")
|
||||
assert str(s) == "{\n f = inf\n}"
|
||||
assert math.isinf(s.f)
|
||||
|
||||
def test_precision_float_negative_infinity_decode(self):
|
||||
s, _ = FloatDouble.decode(0x15, b"\x2b\x00\x00\x00\x00\x00\x00\x00\xf0\xff")
|
||||
s = FloatDouble.decode(b"\x15\x2b\x00\x00\x00\x00\x00\x00\x00\xf0\xff")
|
||||
assert str(s) == "{\n f = -inf\n}"
|
||||
assert math.isinf(s.f)
|
||||
|
||||
|
|
@ -506,7 +504,7 @@ class TestFloatDouble:
|
|||
s.f = v
|
||||
buffer = s.encode().tobytes()
|
||||
|
||||
s2, _ = FloatDouble.decode(0x15, buffer[1:])
|
||||
s2 = FloatDouble.decode(buffer)
|
||||
|
||||
assert (
|
||||
(math.isnan(s.f) and math.isnan(s2.f))
|
||||
|
|
@ -527,7 +525,7 @@ class OuterStruct(tlv.Structure):
|
|||
|
||||
class TestStruct:
|
||||
def test_inner_struct_decode(self):
|
||||
s, _ = OuterStruct.decode(0x15, b"\x35\x00\x20\x00\x2a\x20\x01\xef\x18\x18")
|
||||
s = OuterStruct.decode(b"\x15\x35\x00\x20\x00\x2a\x20\x01\xef\x18\x18")
|
||||
assert_type(s, OuterStruct)
|
||||
assert_type(s.s, InnerStruct)
|
||||
assert_type(s.s.a, Optional[int])
|
||||
|
|
@ -536,7 +534,7 @@ class TestStruct:
|
|||
assert s.s.b == -17
|
||||
|
||||
def test_inner_struct_decode_empty(self):
|
||||
s, _ = OuterStruct.decode(0x15, b"\x35\x00\x18\x18")
|
||||
s = OuterStruct.decode(b"\x15\x35\x00\x18\x18")
|
||||
assert str(s) == "{\n s = {\n \n }\n}"
|
||||
assert s.s.a is None
|
||||
assert s.s.b is None
|
||||
|
|
@ -562,9 +560,8 @@ class FullyQualified(tlv.Structure):
|
|||
|
||||
class TestFullyQualifiedTags:
|
||||
def test_decode(self):
|
||||
s, _ = FullyQualified.decode(
|
||||
0x15,
|
||||
b"\xc2\xda\x0a\x00\x0f\x23\x01\x2a\x00\x00\x00\xe2\xda\x0a\x00\x0f\x45\x23\x01\x00\xef\xff\xff\xff\x18",
|
||||
s = FullyQualified.decode(
|
||||
b"\x15\xc2\xda\x0a\x00\x0f\x23\x01\x2a\x00\x00\x00\xe2\xda\x0a\x00\x0f\x45\x23\x01\x00\xef\xff\xff\xff\x18",
|
||||
)
|
||||
assert_type(s, FullyQualified)
|
||||
assert_type(s.a, Optional[int])
|
||||
|
|
|
|||
Loading…
Reference in a new issue