Tests pass (more to come)

This commit is contained in:
Scott Shawcroft 2024-07-15 13:11:09 -07:00
parent c65dfaa44e
commit c2eefe2d0b
No known key found for this signature in database
GPG key ID: 0DFD512649C052DA
5 changed files with 33 additions and 28 deletions

View file

@ -11,6 +11,6 @@ repos:
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
args: [ "--fix", "--output-format=github" ]
# Run the formatter.
- id: ruff-format

View file

@ -9,6 +9,7 @@ INT_SIZE = "BHIQ"
class ElementType(enum.IntEnum):
NULL = 0b10100
STRUCTURE = 0b10101
ARRAY = 0b10110
LIST = 0b10111
@ -20,6 +21,7 @@ class TLVStructure:
self.buffer: memoryview = buffer
# These three dicts are keyed by tag.
self.tag_value_offset = {}
self.null_tags = set()
self.tag_value_length = {}
self.cached_values = {}
self._offset = 0 # Stopped at the next control octet
@ -46,7 +48,7 @@ class TLVStructure:
tag_control = control_octet >> 5
element_type = control_octet & 0x1F
print(
f"Control 0x{control_octet:x} tag_control {tag_control} element_type {element_type}"
f"Control 0x{control_octet:x} tag_control {tag_control} element_type {element_type:x}"
)
this_tag = None
@ -82,6 +84,7 @@ class TLVStructure:
length_offset = self._offset + 1 + TAG_LENGTH[tag_control]
element_category = element_type >> 2
print(f"element_category {element_category}")
if element_category == 0 or element_category == 1: # ints
value_offset = length_offset
value_length = 1 << (element_type & 0x3)
@ -97,8 +100,11 @@ class TLVStructure:
elif (
element_category == 3 or element_category == 4
): # UTF-8 String or Octet String
print(f"element_type {element_type:x}", bin(element_type))
power_of_two = element_type & 0x3
print(f"power_of_two {power_of_two}")
length_length = 1 << power_of_two
print(f"length_length {length_length}")
value_offset = length_offset + length_length
value_length = struct.unpack_from(
INT_SIZE[power_of_two], self.buffer, length_offset
@ -106,6 +112,7 @@ class TLVStructure:
elif element_type == 0b10100: # Null
value_offset = self._offset
value_length = 1
self.null_tags.add(this_tag)
else: # Container
value_offset = length_offset
value_length = 1
@ -161,7 +168,7 @@ class Member:
return obj.cached_values[self.tag]
if self.tag not in obj.tag_value_offset:
obj.scan_until(self.tag)
if self.tag not in obj.tag_value_offset:
if self.tag not in obj.tag_value_offset or self.tag in obj.null_tags:
return None
value = self.decode(
@ -175,6 +182,12 @@ class Member:
def __set__(self, obj: TLVStructure, value: Any) -> None:
obj.cached_values[self.tag] = value
def print(self, obj):
value = self.__get__(obj)
if value is None:
return "null"
return self._print(value)
class IntegerMember(Member):
def __init__(self, tag, _format, optional=False):
@ -191,8 +204,7 @@ class IntegerMember(Member):
encoded_format = self.format
return struct.unpack_from(encoded_format, buffer, offset=offset)[0]
def print(self, obj):
value = self.__get__(obj)
def _print(self, value):
unsigned = "U" if self.format.isupper() else ""
return f"{value}{unsigned}"
@ -205,8 +217,7 @@ class FloatMember(Member):
encoded_format = "<d"
return struct.unpack_from(encoded_format, buffer, offset=offset)[0]
def print(self, obj):
value = self.__get__(obj)
def _print(self, value):
return f"{value}"
@ -215,8 +226,8 @@ class BoolMember(Member):
octet = buffer[offset]
return octet & 1 == 1
def print(self, obj):
if self.__get__(obj):
def _print(self, value):
if value:
return "true"
return "false"
@ -229,8 +240,7 @@ class OctetStringMember(Member):
def decode(self, buffer, length, offset=0):
return buffer[offset : offset + length]
def print(self, obj):
value = self.__get__(obj)
def _print(self, value):
return " ".join((f"{byte:02x}" for byte in value))
@ -242,8 +252,7 @@ class UTF8StringMember(Member):
def decode(self, buffer, length, offset=0):
return buffer[offset : offset + length].decode("utf-8")
def print(self, obj):
value = self.__get__(obj)
def _print(self, value):
return f'"{value}"'
@ -255,8 +264,5 @@ class StructMember(Member):
def decode(self, buffer, length, offset=0) -> TLVStructure:
return self.substruct_class(buffer[offset : offset + length])
def print(self, obj):
value = self.__get__(obj)
if value is None:
return "null"
def _print(self, value):
return str(value)

View file

@ -11,3 +11,8 @@ dynamic = ["version", "description"]
[project.urls]
Home = "https://github.com/adafruit/circuitmatter"
[tool.pytest.ini_options]
pythonpath = [
"."
]

0
tests/__init__.py Normal file
View file

View file

@ -134,23 +134,17 @@ class TestUTF8String:
# assert bytes(s) == b"\x0c\x06Hello!"
# Octet String, 1-octet length, octets 00 01 02 03 04 10 05 00 01 02 03 04
# Octet String, 1-octet length, octets 00 01 02 03 04
# encoded: 10 05 00 01 02 03 04
class OctetStringOneOctet(tlv.TLVStructure):
s = tlv.OctetStringMember(None, 16)
class TestOctetString:
def test_octet_string_decode(self):
s = OctetStringOneOctet(
b"\x0d\x0c\x00\x01\x02\x03\x04\x10\x05\x00\x01\x02\x03\x04"
)
assert str(s) == "{\n s = 00 01 02 03 04 10 05 00 01 02 03 04\n}"
assert s.s == b"\x00\x01\x02\x03\x04\x10\x05\x00\x01\x02\x03\x04"
# def test_octet_string_encode(self):
# s = OctetString()
# s.s = b"\x00\x01\x02\x03\x04\x10\x05\x00\x01\x02\x03\x04"
# assert bytes(s) == b"\x0d\x0c\x00\x01\x02\x03\x04\x10\x05\x00\x01\x02\x03\x04"
s = OctetStringOneOctet(b"\x10\x05\x00\x01\x02\x03\x04")
assert str(s) == "{\n s = 00 01 02 03 04\n}"
assert s.s == b"\x00\x01\x02\x03\x04"
# Null