Connect back up to real networking, record packets and enable replay

This commit is contained in:
Scott Shawcroft 2024-07-18 14:39:21 -07:00
parent e59044735f
commit 6078dbe887
No known key found for this signature in database
GPG key ID: 0DFD512649C052DA
5 changed files with 378 additions and 152 deletions

View file

@ -1,8 +1,10 @@
"""Pure Python implementation of the Matter IOT protocol."""
import binascii
import enum
import pathlib
import json
import os
import struct
import time
@ -10,21 +12,6 @@ from . import tlv
__version__ = "0.0.0"
# descriminator = 3840
# avahi = subprocess.Popen(["avahi-publish-service", "-v", f"--subtype=_L{descriminator}._sub._matterc._udp", "--subtype=_CM._sub._matterc._udp", "FA93546B21F5FB54", "_matterc._udp", "5540", "PI=", "PH=33", "CM=1", f"D={descriminator}", "CRI=3000", "CRA=4000", "T=1", "VP=65521+32769"])
# # Define the UDP IP address and port
# UDP_IP = "::" # Listen on all available network interfaces
# UDP_PORT = 5540
# # Create the UDP socket
# sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
# # Bind the socket to the IP and port
# sock.bind((UDP_IP, UDP_PORT))
# print(f"Listening on UDP port {UDP_PORT}")
# Section 4.11.2
MSG_COUNTER_WINDOW_SIZE = 32
MSG_COUNTER_SYNC_REQ_JITTER_MS = 500
@ -243,11 +230,54 @@ class MessageReceptionState:
return False
class Exchange:
def __init__(self, initiator: bool, exchange_id: int, protocols):
self.initiator = initiator
self.exchange_id = exchange_id
self.protocols = protocols
self.pending_acknowledgement = None
self.next_retransmission_time = None
self.pending_retransmission = None
def send(self, message):
pass
def receive(self, message) -> bool:
"""Process the message and return if the packet should be dropped."""
if message.protocol_id not in self.protocols:
# Drop messages that don't match the protocols we're waiting for.
return True
# Section 4.10.5.2.1
if message.exchange_flags & ExchangeFlags.A:
if message.acknowledged_message_counter is None:
# Drop messages that are missing an acknowledgement counter.
return True
if self.pending_acknowledgement is None:
# Drop messages that are not waiting for an acknowledgement.
return True
if message.acknowledged_message_counter != self.pending_acknowledgement:
# Drop messages that have the wrong acknowledgement counter.
return True
self.pending_acknowledgement = None
self.pending_retransmission = None
self.next_retransmission_time = None
# Section 4.10.5.2.2
# if message.exchange_flags & ExchangeFlags.R:
# if message
if message.duplicate:
return True
return False
class UnsecuredSessionContext:
def __init__(self, initiator, ephemeral_initiator_node_id):
self.initiator = initiator
self.ephemeral_initiator_node_id = ephemeral_initiator_node_id
self.message_reception_state = None
self.exchanges = {}
class SecureSessionContext:
@ -283,6 +313,7 @@ class SecureSessionContext:
self.session_idle_interval = None
self.session_active_interval = None
self.session_active_threshold = None
self.exchanges = {}
@property
def peer_active(self):
@ -295,6 +326,7 @@ class Message:
self.flags, self.session_id, self.security_flags, self.message_counter = (
struct.unpack_from("<BHBI", buffer)
)
self.security_flags = SecurityFlags(self.security_flags)
offset = 8
self.source_node_id = None
if self.flags & (1 << 2):
@ -303,17 +335,18 @@ class Message:
if (self.flags >> 4) != 0:
raise RuntimeError("Incorrect version")
self.secure_session = self.security_flags & 0x3 != 0 or self.session_id != 0
self.secure_session = not (
not (self.security_flags & SecurityFlags.GROUP) and self.session_id == 0
)
if not self.secure_session:
self.payload = memoryview(buffer)[offset:]
context = UnsecuredSessionContext(False, self.source_node_id)
self.unsecured_session_context[self.source_node_id] = context
else:
self.payload = None
def _parse_protocol_header(self):
self.duplicate = None
def parse_protocol_header(self):
self.exchange_flags, self.protocol_opcode, self.exchange_id = (
struct.unpack_from("<BBH", self.payload)
)
@ -329,7 +362,7 @@ class Message:
protocol_id = struct.unpack_from("<H", self.payload, decrypted_offset)[0]
decrypted_offset += 2
self.protocol_id = ProtocolId(protocol_id)
self.protocol_opcode = PROTOCOL_OPCODES[protocol_id](self.protocol_opcode)
self.protocol_opcode = PROTOCOL_OPCODES[self.protocol_id](self.protocol_opcode)
self.acknowledged_message_counter = None
if self.exchange_flags & ExchangeFlags.A:
@ -338,6 +371,8 @@ class Message:
)[0]
decrypted_offset += 4
self.application_payload = self.payload[decrypted_offset:]
def reply(self, payload, protocol_id=None, protocol_opcode=None) -> memoryview:
reply = bytearray(1280)
offset = 0
@ -375,12 +410,11 @@ class SessionManager:
def _increment(self, value):
return (value + 1) % 0xFFFFFFFF
def counter_ok(self, message):
"""Implements 4.6.7"""
def get_session(self, message):
if message.secure_session:
if message.security_flags & SecurityFlags.GROUP:
if message.source_node_id is None:
return False
return None
# TODO: Get MRS for source node id and message type
else:
session_context = self.secure_session_contexts[message.session_id]
@ -393,6 +427,11 @@ class SessionManager:
)
)
session_context = self.unsecured_session_context[message.source_node_id]
return session_context
def mark_duplicate(self, message):
"""Implements 4.6.7"""
session_context = self.get_session(message)
if session_context.message_reception_state is None:
session_context.message_reception_state = MessageReceptionState(
@ -400,9 +439,10 @@ class SessionManager:
rollover=False,
encrypted=message.secure_session,
)
return True
message.duplicate = False
return
return session_context.message_reception_state.process_counter(
message.duplicate = session_context.message_reception_state.process_counter(
message.message_counter
)
@ -443,3 +483,207 @@ class SessionManager:
self.secure_session_contexts[session_id] = SecureSessionContext(session_id)
return self.secure_session_contexts[session_id]
def process_exchange(self, message):
session = self.get_session(message)
if session is None:
return None
# Step 1 of 4.12.5.2
if (
message.exchange_flags & (ExchangeFlags.R | ExchangeFlags.A)
and not message.security_flags & SecurityFlags.C
and message.security_flags & SecurityFlags.GROUP
):
# Drop illegal combination of flags.
return None
if message.exchange_id not in session.exchanges:
# Section 4.10.5.2
initiator = message.exchange_flags & ExchangeFlags.I
if initiator and not message.duplicate:
session.exchanges[message.exchange_id] = Exchange(
not initiator, message.exchange_id, [message.protocol_id]
)
# Drop because the message isn't from an initiator.
elif message.exchange_flags & ExchangeFlags.R:
# Send a bare acknowledgement back.
raise NotImplementedError("Send a bare acknowledgement back")
return None
else:
# Just drop it.
return None
exchange = session.exchanges[message.exchange_id]
if exchange.receive(message):
# If we want to drop the message, then return None.
return None
return exchange
class CircuitMatter:
def __init__(self, socketpool, mdns_server, state_filename, record_to=None):
self.socketpool = socketpool
self.mdns_server = mdns_server
self.avahi = None
self.record_to = record_to
if self.record_to:
self.recorded_packets = []
else:
self.recorded_packets = None
self.manager = SessionManager()
with open(state_filename, "r") as state_file:
self.nonvolatile = json.load(state_file)
for key in ["descriminator", "salt", "iteration-count"]:
if key not in self.nonvolatile:
raise RuntimeError(f"Missing key {key} in state file")
commission = "fabrics" not in self.nonvolatile
self.packet_buffer = memoryview(bytearray(1280))
# Define the UDP IP address and port
UDP_IP = "::" # Listen on all available network interfaces
self.UDP_PORT = 5540
# Create the UDP socket
self.socket = self.socketpool.socket(
self.socketpool.AF_INET6, self.socketpool.SOCK_DGRAM
)
# Bind the socket to the IP and port
self.socket.bind((UDP_IP, self.UDP_PORT))
self.socket.setblocking(False)
print(f"Listening on UDP port {self.UDP_PORT}")
if commission:
self.start_commissioning()
def start_commissioning(self):
descriminator = self.nonvolatile["descriminator"]
txt_records = {
"PI": "",
"PH": "33",
"CM": "1",
"D": str(descriminator),
"CRI": "3000",
"CRA": "4000",
"T": "1",
"VP": "65521+32769",
}
self.mdns_server.advertise_service(
"_matterc",
"_udp",
self.UDP_PORT,
txt_records=txt_records,
instance_name="FA93546B21F5FB54",
subtypes=[
f"_L{descriminator}._sub._matterc._udp",
"_CM._sub._matterc._udp",
],
)
def process_packets(self):
while True:
try:
nbytes, addr = self.socket.recvfrom_into(
self.packet_buffer, len(self.packet_buffer)
)
except BlockingIOError:
break
if nbytes == 0:
break
if self.recorded_packets is not None:
self.recorded_packets.append(
(
"receive",
time.monotonic_ns(),
addr,
binascii.b2a_base64(
self.packet_buffer[:nbytes], newline=False
).decode("utf-8"),
)
)
self.process_packet(addr, self.packet_buffer[:nbytes])
def process_packet(self, address, data):
# Print the received data and the address of the sender
# This is section 4.7.2
message = Message(data)
if message.secure_session:
# Decrypt the payload
pass
message.parse_protocol_header()
self.manager.mark_duplicate(message)
exchange = self.manager.process_exchange(message)
if exchange is None:
print(f"Dropping message {message.message_counter}")
return
print(f"Received packet from {address}:")
print(f"{data.hex(' ')}")
print(f"Message counter {message.message_counter}")
protocol_id = message.protocol_id
protocol_opcode = message.protocol_opcode
if protocol_id == ProtocolId.SECURE_CHANNEL:
if protocol_opcode == SecureProtocolOpcode.MSG_COUNTER_SYNC_REQ:
print("Received Message Counter Synchronization Request")
elif protocol_opcode == SecureProtocolOpcode.MSG_COUNTER_SYNC_RSP:
print("Received Message Counter Synchronization Response")
elif protocol_opcode == SecureProtocolOpcode.PBKDF_PARAM_REQUEST:
print("Received PBKDF Parameter Request")
# This is Section 4.14.1.2
request = PBKDFParamRequest(message.application_payload[1:-1])
if request.passcodeId == 0:
pass
# Send back failure
# response = StatusReport()
# response.GeneralCode
print(request)
response = PBKDFParamResponse()
response.initiatorRandom = request.initiatorRandom
# Generate a random number
response.responderRandom = os.urandom(32)
session_context = self.manager.new_context()
response.responderSessionId = session_context.local_session_id
session_context.peer_session_id = request.initiatorSessionId
if not request.hasPBKDFParameters:
params = Crypto_PBKDFParameterSet()
params.iterations = self.nonvolatile["iteration-count"]
params.salt = binascii.a2b_base64(self.nonvolatile["salt"])
response.pbkdf_parameters = params
print(response)
elif protocol_opcode == SecureProtocolOpcode.PBKDF_PARAM_RESPONSE:
print("Received PBKDF Parameter Response")
elif protocol_opcode == SecureProtocolOpcode.PASE_PAKE1:
print("Received PASE PAKE1")
elif protocol_opcode == SecureProtocolOpcode.PASE_PAKE2:
print("Received PASE PAKE2")
elif protocol_opcode == SecureProtocolOpcode.PASE_PAKE3:
print("Received PASE PAKE3")
elif protocol_opcode == SecureProtocolOpcode.CASE_SIGMA1:
print("Received CASE Sigma1")
elif protocol_opcode == SecureProtocolOpcode.CASE_SIGMA2:
print("Received CASE Sigma2")
elif protocol_opcode == SecureProtocolOpcode.CASE_SIGMA3:
print("Received CASE Sigma3")
elif protocol_opcode == SecureProtocolOpcode.CASE_SIGMA2_RESUME:
print("Received CASE Sigma2 Resume")
elif protocol_opcode == SecureProtocolOpcode.STATUS_REPORT:
print("Received Status Report")
elif protocol_opcode == SecureProtocolOpcode.ICD_CHECK_IN:
print("Received ICD Check-in")
def __del__(self):
if self.avahi:
self.avahi.kill()
if self.recorded_packets and self.record_to:
with open(self.record_to, "w") as record_file:
json.dump(self.recorded_packets, record_file)

View file

@ -1,127 +1,120 @@
"""Pure Python implementation of the Matter IOT protocol."""
import os
import binascii
import json
import socket
import subprocess
import circuitmatter as cm
# descriminator = 3840
# avahi = subprocess.Popen(["avahi-publish-service", "-v", f"--subtype=_L{descriminator}._sub._matterc._udp", "--subtype=_CM._sub._matterc._udp", "FA93546B21F5FB54", "_matterc._udp", "5540", "PI=", "PH=33", "CM=1", f"D={descriminator}", "CRI=3000", "CRA=4000", "T=1", "VP=65521+32769"])
class ReplaySocket:
def __init__(self, replay_data):
self.replay_data = replay_data
# # Define the UDP IP address and port
# UDP_IP = "::" # Listen on all available network interfaces
# UDP_PORT = 5540
def bind(self, address):
print("bind to", address)
# # Create the UDP socket
# sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
def setblocking(self, value):
print("setblocking", value)
# # Bind the socket to the IP and port
# sock.bind((UDP_IP, UDP_PORT))
# print(f"Listening on UDP port {UDP_PORT}")
unsecured_session_context = {}
secure_session_contexts = ["reserved"]
# while True:
# # Receive data from the socket (1280 is the minimum ipv6 MTU and the max UDP matter packet size.)
# data, addr = sock.recvfrom(1280)
data = b"\x04\x00\x00\x00\x0b\x06\xb7\t)\xad\x07\xd9\xae\xa1\xee\xa0\x05 j\x15\x00\x00\x150\x01 \x97\x064#\x1c\xd1E7H\x0b|\xc2G\xa7\xc38\xe9\xce3\x11\xb2@M\x86\xd7\xb5{)\xaa`\xddb%\x02\xc2\x86$\x03\x00(\x045\x05%\x01\xf4\x01%\x02,\x01%\x03\xa0\x0f$\x04\x11$\x05\x0b&\x06\x00\x00\x03\x01$\x07\x01\x18\x18"
addr = None
def recvfrom_into(self, buffer, nbytes=None):
if nbytes is None:
nbytes = len(buffer)
direction = "send"
while direction == "send":
direction, _, address, data_b64 = self.replay_data.pop(0)
decoded = binascii.a2b_base64(data_b64)
if len(decoded) > nbytes:
raise RuntimeError("Next replay packet is larger than buffer to read into")
buffer[: len(decoded)] = decoded
return len(decoded), address
# pathlib.Path("data.bin").write_bytes(data)
class ReplaySocketPool:
AF_INET6 = 0
SOCK_DGRAM = 1
bookmarks = []
def __init__(self, replay_file):
with open(replay_file, "r") as f:
self.replay_data = json.load(f)
self._socket_created = False
def socket(self, *args, **kwargs):
if self._socket_created:
raise RuntimeError("Only one socket can be created")
self._socket_created = True
return ReplaySocket(self.replay_data)
def add_bookmark(start, length, name, color=0x0000FF):
bookmarks.append(
{
"color": 0x4F000000 | color,
"comment": "\n",
"id": len(bookmarks),
"locked": True,
"name": name,
"region": {"address": start, "size": length},
}
class DummyMDNS:
def advertise_service(
self,
service_type,
protocol,
port,
txt_records=[],
subtypes=[],
instance_name="",
):
print(f"Advertise service {service_type} {protocol} {port} {txt_records}")
class MDNSServer(DummyMDNS):
def __init__(self):
self.active_services = {}
def advertise_service(
self,
service_type,
protocol,
port,
txt_records={},
subtypes=[],
instance_name="",
):
subtypes = [f"--subtype={subtype}" for subtype in subtypes]
txt_records = [f"{key}={value}" for key, value in txt_records.items()]
if service_type in self.active_services:
self.active_services[service_type].kill()
del self.active_services[service_type]
self.active_services[service_type] = subprocess.Popen(
[
"avahi-publish-service",
*subtypes,
instance_name,
f"{service_type}.{protocol}",
str(port),
*txt_records,
]
)
def __del__(self):
for active_service in self.active_services.values():
active_service.kill()
def run(replay_file=None):
if replay_file:
socketpool = ReplaySocketPool(replay_file)
mdns_server = DummyMDNS()
record_file = None
else:
socketpool = socket
mdns_server = MDNSServer()
record_file = "test_data/recorded_packets.json"
matter = cm.CircuitMatter(
socketpool, mdns_server, "test_data/device_state.json", record_file
)
# Write every time in case we crash
# pathlib.Path("parsed.hexbm").write_text(json.dumps({"bookmarks": bookmarks}))
def run():
manager = cm.SessionManager()
# Print the received data and the address of the sender
# This is section 4.7.2
print(f"Received packet from {addr}: {data}")
message = cm.Message(data)
if message.secure_session:
# Decrypt the payload
pass
if not manager.counter_ok(message):
print("Dropping message due to counter error")
return
# if not manager.rmp_ok(message):
# print("Dropping message due to RMP")
# continue
protocol_id = message.protocol_id
protocol_opcode = message.protocol_opcode
if protocol_id == cm.ProtocolId.SECURE_CHANNEL:
if protocol_opcode == cm.SecureProtocolOpcode.MSG_COUNTER_SYNC_REQ:
print("Received Message Counter Synchronization Request")
elif protocol_opcode == cm.SecureProtocolOpcode.MSG_COUNTER_SYNC_RSP:
print("Received Message Counter Synchronization Response")
elif protocol_opcode == cm.SecureProtocolOpcode.PBKDF_PARAM_REQUEST:
print("Received PBKDF Parameter Request")
# This is Section 4.14.1.2
request = cm.PBKDFParamRequest(message.payload)
if request.passcodeID == 0:
pass
# Send back failure
# response = StatusReport()
# response.GeneralCode
print(request)
response = cm.PBKDFParamResponse()
response.initiatorRandom = request.initiatorRandom
# Generate a random number
response.responderRandom = os.urandom(32)
session_context = manager.new_context(response.responderSessionId)
session_context.peer_session_id = request.initiatorSessionId
if not request.hasPBKDFParameters:
params = cm.Crypto_PBKDFParameterSet()
params.iterations = 1000
params.salt = b"\x00" * 32
response.pbkdf_parameters = params
print(response)
elif protocol_opcode == cm.SecureProtocolOpcode.PBKDF_PARAM_RESPONSE:
print("Received PBKDF Parameter Response")
elif protocol_opcode == cm.SecureProtocolOpcode.PASE_PAKE1:
print("Received PASE PAKE1")
elif protocol_opcode == cm.SecureProtocolOpcode.PASE_PAKE2:
print("Received PASE PAKE2")
elif protocol_opcode == cm.SecureProtocolOpcode.PASE_PAKE3:
print("Received PASE PAKE3")
elif protocol_opcode == cm.SecureProtocolOpcode.CASE_SIGMA1:
print("Received CASE Sigma1")
elif protocol_opcode == cm.SecureProtocolOpcode.CASE_SIGMA2:
print("Received CASE Sigma2")
elif protocol_opcode == cm.SecureProtocolOpcode.CASE_SIGMA3:
print("Received CASE Sigma3")
elif protocol_opcode == cm.SecureProtocolOpcode.CASE_SIGMA2_RESUME:
print("Received CASE Sigma2 Resume")
elif protocol_opcode == cm.SecureProtocolOpcode.STATUS_REPORT:
print("Received Status Report")
elif protocol_opcode == cm.SecureProtocolOpcode.ICD_CHECK_IN:
print("Received ICD Check-in")
# avahi.kill()
while True:
matter.process_packets()
if __name__ == "__main__":
run()
import sys
print(sys.argv)
replay_file = None
if len(sys.argv) > 1:
replay_file = sys.argv[1]
run(replay_file=replay_file)

View file

@ -74,21 +74,15 @@ class TLVStructure:
def scan_until(self, tag):
if self.buffer is None:
return
print(bytes(self.buffer[self._offset :]))
print(f"Looking for {tag}")
while self._offset < len(self.buffer):
control_octet = self.buffer[self._offset]
tag_control = control_octet >> 5
element_type = control_octet & 0x1F
print(
f"Control 0x{control_octet:x} tag_control {tag_control} element_type {element_type:x}"
)
this_tag = None
if tag_control == 0: # Anonymous
this_tag = None
elif tag_control == 1: # Context specific
print("context specific tag")
this_tag = self.buffer[self._offset + 1]
else:
vendor_id = None
@ -113,11 +107,9 @@ class TLVStructure:
this_tag = (vendor_id, profile_number, tag_number)
else:
this_tag = tag_number
print(f"found tag {this_tag}")
length_offset = self._offset + 1 + TAG_LENGTH[tag_control]
element_category = element_type >> 2
print(f"element_category {element_category}")
if element_category == 0 or element_category == 1: # ints
value_offset = length_offset
value_length = 1 << (element_type & 0x3)
@ -133,11 +125,8 @@ class TLVStructure:
elif (
element_category == 3 or element_category == 4
): # UTF-8 String or Octet String
print(f"element_type {element_type:x}", bin(element_type))
power_of_two = element_type & 0x3
print(f"power_of_two {power_of_two}")
length_length = 1 << power_of_two
print(f"length_length {length_length}")
value_offset = length_offset + length_length
value_length = struct.unpack_from(
INT_SIZE[power_of_two], self.buffer, length_offset
@ -238,9 +227,7 @@ class Member:
buffer[offset] = self.tag
offset += 1
if value is not None:
print("enconding value into", offset)
new_offset = self.encode_value_into(value, buffer, offset)
print("new offset", new_offset)
return new_offset
return offset
@ -262,13 +249,10 @@ class NumberMember(Member):
ElementType.SIGNED_INT if self.signed else ElementType.UNSIGNED_INT
)
self._element_type |= int(math.log(self.max_value_length, 2))
print(f"{self._element_type:x}")
else:
print("float")
self._element_type = ElementType.FLOAT
if self.max_value_length == 8:
self._element_type |= 1
print(f"{self._element_type:x}")
super().__init__(tag, optional)
def __set__(self, obj, value):
@ -303,7 +287,6 @@ class NumberMember(Member):
def encode_element_type(self, value):
# We don't adjust our encoding based on value size. We always use the bytes needed for the
# format.
print("encode", self._element_type)
return self._element_type
def encode_value_into(self, value, buffer, offset) -> int:

View file

@ -0,0 +1,5 @@
{
"descriminator": 2207,
"iteration-count": 10000,
"salt": "5uCP0ITHYzI9qBEe6hfU4HfY3y7VopSk0qNvhvznhiQ="
}

File diff suppressed because one or more lines are too long