Add message reception state for counter verification

This commit is contained in:
Scott Shawcroft 2024-07-17 11:23:03 -07:00
parent 6f28636869
commit 9d2687b784
No known key found for this signature in database
GPG key ID: 0DFD512649C052DA
3 changed files with 422 additions and 66 deletions

View file

@ -1,6 +1,10 @@
"""Pure Python implementation of the Matter IOT protocol."""
import enum
import pathlib
import json
import struct
import time
from . import tlv
@ -21,7 +25,10 @@ __version__ = "0.0.0"
# print(f"Listening on UDP port {UDP_PORT}")
unsecured_session_context = {}
# Section 4.11.2
MSG_COUNTER_WINDOW_SIZE = 32
MSG_COUNTER_SYNC_REQ_JITTER_MS = 500
MSG_COUNTER_SYNC_TIMEOUT_MS = 400
class ProtocolId(enum.Enum):
@ -36,6 +43,8 @@ class SecurityFlags(enum.Flag):
P = 1 << 7
C = 1 << 6
MX = 1 << 5
# This is actually 2 bits but the top bit is reserved and always zero.
GROUP = 1 << 0
class ExchangeFlags(enum.Flag):
@ -173,3 +182,264 @@ class PBKDFParamResponse(tlv.TLVStructure):
responderSessionId = tlv.NumberMember(3, "<H")
pbkdf_parameters = tlv.StructMember(4, Crypto_PBKDFParameterSet)
responderSessionParams = tlv.StructMember(5, SessionParameterStruct, optional=True)
class MessageReceptionState:
def __init__(self, starting_value, rollover=True, encrypted=False):
"""Implements 4.6.5.1"""
self.message_counter = starting_value
self.window_bitmap = (1 << MSG_COUNTER_WINDOW_SIZE) - 1
self.mask = self.window_bitmap
self.encrypted = encrypted
self.rollover = rollover
def process_counter(self, counter) -> bool:
"""Returns True if the counter number is a duplicate"""
# Process the current window first. Behavior outside the window varies.
if counter == self.message_counter:
return True
if self.message_counter <= MSG_COUNTER_WINDOW_SIZE < counter:
# Window wraps
bit_position = 0xFFFFFFFF - counter + self.message_counter
else:
bit_position = self.message_counter - counter - 1
if 0 <= bit_position < MSG_COUNTER_WINDOW_SIZE:
if self.window_bitmap & (1 << bit_position) != 0:
# This is a duplicate message
return True
self.window_bitmap |= 1 << bit_position
return False
new_start = (self.message_counter + 1) & self.mask # Inclusive
new_end = (
self.message_counter - MSG_COUNTER_WINDOW_SIZE
) & self.mask # Exclusive
if not self.rollover:
new_end = (1 << MSG_COUNTER_WINDOW_SIZE) - 1
elif self.encrypted:
new_end = (
self.message_counter + (1 << (MSG_COUNTER_WINDOW_SIZE - 1))
) & self.mask
if new_start <= new_end:
if not (new_start <= counter < new_end):
return True
else:
if not (counter < new_end or new_start <= counter):
return True
# This is a new message
shift = counter - self.message_counter
if counter < self.message_counter:
shift += 0x100000000
if shift > MSG_COUNTER_WINDOW_SIZE:
self.window_bitmap = 0
else:
new_bitmap = (self.window_bitmap << shift) & self.mask
self.window_bitmap = new_bitmap
if 1 < shift < MSG_COUNTER_WINDOW_SIZE:
self.window_bitmap |= 1 << (shift - 1)
self.message_counter = counter
return False
class UnsecuredSessionContext:
def __init__(self, initiator, ephemeral_initiator_node_id):
self.initiator = initiator
self.ephemeral_initiator_node_id = ephemeral_initiator_node_id
self.message_reception_state = None
class SecureSessionContext:
def __init__(self, local_session_id):
self.session_type = None
"""Records whether the session was established using CASE or PASE."""
self.session_role = None
"""Records whether the node is the session initiator or responder."""
self.local_session_id = local_session_id
"""Individually selected by each participant in secure unicast communication during session establishment and used as a unique identifier to recover encryption keys, authenticate incoming messages and associate them to existing sessions."""
self.peer_session_id = None
"""Assigned by the peer during session establishment"""
self.i2r_key = None
"""Encrypts data in messages sent from the initiator of session establishment to the responder."""
self.r2i_key = None
"""Encrypts data in messages sent from the session establishment responder to the initiator."""
self.shared_secret = None
"""Computed during the CASE protocol execution and re-used when CASE session resumption is implemented."""
self.local_message_counter = None
"""Secure Session Message Counter for outbound messages."""
self.message_reception_state = None
"""Provides tracking for the Secure Session Message Counter of the remote"""
self.local_fabric_index = None
"""Records the local Index for the sessions Fabric, which MAY be used to look up Fabric metadata related to the Fabric for which this session context applies."""
self.peer_node_id = None
"""Records the authenticated node ID of the remote peer, when available."""
self.resumption_id = None
"""The ID used when resuming a session between the local and remote peer."""
self.session_timestamp = None
"""A timestamp indicating the time at which the last message was sent or received. This timestamp SHALL be initialized with the time the session was created."""
self.active_timestamp = None
"""A timestamp indicating the time at which the last message was received. This timestamp SHALL be initialized with the time the session was created."""
self.session_idle_interval = None
self.session_active_interval = None
self.session_active_threshold = None
@property
def peer_active(self):
return (time.monotonic() - self.active_timestamp) < self.session_active_interval
class Message:
def __init__(self, buffer):
self.buffer = buffer
self.flags, self.session_id, self.security_flags, self.message_counter = (
struct.unpack_from("<BHBI", buffer)
)
offset = 8
self.source_node_id = None
if self.flags & (1 << 2):
self.source_node_id = struct.unpack_from("<Q", buffer, 8)[0]
offset += 8
if (self.flags >> 4) != 0:
raise RuntimeError("Incorrect version")
self.secure_session = self.security_flags & 0x3 != 0 or self.session_id != 0
if not self.secure_session:
self.payload = memoryview(buffer)[offset:]
context = UnsecuredSessionContext(False, self.source_node_id)
self.unsecured_session_context[self.source_node_id] = context
else:
self.payload = None
def _parse_protocol_header(self):
self.exchange_flags, self.protocol_opcode, self.exchange_id = (
struct.unpack_from("<BBH", self.payload)
)
self.exchange_flags = ExchangeFlags(self.exchange_flags)
decrypted_offset = 4
self.protocol_vendor_id = 0
if self.exchange_flags & ExchangeFlags.V:
self.protocol_vendor_id = struct.unpack_from(
"<H", self.payload, decrypted_offset
)[0]
decrypted_offset += 2
protocol_id = struct.unpack_from("<H", self.payload, decrypted_offset)[0]
decrypted_offset += 2
self.protocol_id = ProtocolId(protocol_id)
self.protocol_opcode = PROTOCOL_OPCODES[protocol_id](self.protocol_opcode)
self.acknowledged_message_counter = None
if self.exchange_flags & ExchangeFlags.A:
self.acknowledged_message_counter = struct.unpack_from(
"<I", self.payload, decrypted_offset
)[0]
decrypted_offset += 4
def reply(self, payload, protocol_id=None, protocol_opcode=None) -> memoryview:
reply = bytearray(1280)
offset = 0
# struct.pack_into(
# "<BHBI", reply, offset, flags, session_id, security_flags, message_counter
# )
# offset += 8
return memoryview(reply)[:offset]
class SessionManager:
def __init__(self):
persist_path = pathlib.Path("counters.json")
if persist_path.exists():
self.nonvolatile = json.loads(persist_path.read_text())
else:
self.nonvolatile = {}
self.nonvolatile["unencrypted_message_counter"] = 0
self.nonvolatile["group_encrypted_data_message_counter"] = 0
self.nonvolatile["group_encrypted_control_message_counter"] = 0
self.unencrypted_message_counter = self.nonvolatile[
"unencrypted_message_counter"
]
self.group_encrypted_data_message_counter = self.nonvolatile[
"group_encrypted_data_message_counter"
]
self.group_encrypted_control_message_counter = self.nonvolatile[
"group_encrypted_control_message_counter"
]
self.check_in_counter = 0
self.unsecured_session_context = {}
self.secure_session_contexts = ["reserved"]
def _increment(self, value):
return (value + 1) % 0xFFFFFFFF
def counter_ok(self, message):
"""Implements 4.6.7"""
if message.secure_session:
if message.security_flags & SecurityFlags.GROUP:
if message.source_node_id is None:
return False
# TODO: Get MRS for source node id and message type
else:
session_context = self.secure_session_contexts[message.session_id]
else:
if message.source_node_id not in self.unsecured_session_context:
self.unsecured_session_context[message.source_node_id] = (
UnsecuredSessionContext(
initiator=False,
ephemeral_initiator_node_id=message.source_node_id,
)
)
session_context = self.unsecured_session_context[message.source_node_id]
if session_context.message_reception_state is None:
session_context.message_reception_state = MessageReceptionState(
message.message_counter,
rollover=False,
encrypted=message.secure_session,
)
return True
return session_context.message_reception_state.process_counter(
message.message_counter
)
def next_message_counter(self, message):
"""Implements 4.6.6"""
if not message.secure_session:
value = self.unencrypted_message_counter
self.unencrypted_message_counter = self._increment(
self.unencrypted_message_counter
)
return value
elif message.security_flags & SecurityFlags.GROUP:
if message.security_flags & SecurityFlags.C:
value = self.group_encrypted_control_message_counter
self.group_encrypted_control_message_counter = self._increment(
self.group_encrypted_control_message_counter
)
return value
else:
value = self.group_encrypted_data_message_counter
self.group_encrypted_data_message_counter = self._increment(
self.group_encrypted_data_message_counter
)
return value
session = self.secure_session_contexts[message.session_id]
value = session.local_message_counter
next_value = self._increment(value)
session.local_message_counter = next_value
if next_value == 0:
# TODO expire the encryption key
raise NotImplementedError("Expire the encryption key 4.6.6")
return next_value
def new_context(self):
if None not in self.secure_session_contexts:
self.secure_session_contexts.append(None)
session_id = self.secure_session_contexts.index(None)
self.secure_session_contexts[session_id] = SecureSessionContext(session_id)
return self.secure_session_contexts[session_id]

View file

@ -1,6 +1,6 @@
"""Pure Python implementation of the Matter IOT protocol."""
import struct
import os
import circuitmatter as cm
@ -21,6 +21,7 @@ import circuitmatter as cm
# print(f"Listening on UDP port {UDP_PORT}")
unsecured_session_context = {}
secure_session_contexts = ["reserved"]
# while True:
# # Receive data from the socket (1280 is the minimum ipv6 MTU and the max UDP matter packet size.)
@ -50,66 +51,23 @@ def add_bookmark(start, length, name, color=0x0000FF):
def run():
manager = cm.SessionManager()
# Print the received data and the address of the sender
# This is section 4.7.2
print(f"Received packet from {addr}: {data}")
print(f"Data length: {len(data)} bytes")
flags, session_id, security_flags, message_counter = struct.unpack_from(
"<BHBI", data
)
add_bookmark(0, 8, "Header")
print(
f"Flags: {flags:x} Session ID: {session_id:x} Security Flags: {cm.SecurityFlags(security_flags)} Message Counter: {message_counter}"
)
offset = 8
if flags & (1 << 2):
source_node_id = struct.unpack_from("<Q", data, 8)[0]
add_bookmark(8, 8, "Source Node ID")
print(source_node_id)
offset += 8
print(f"DSIZ {flags & (0x3)}")
if (flags >> 4) != 0:
print("Incorrect version")
# continue
secure_session = security_flags & 0x3 != 0 or session_id != 0
message = cm.Message(data)
if message.secure_session:
# Decrypt the payload
pass
if not manager.counter_ok(message):
print("Dropping message due to counter error")
return
# if not manager.rmp_ok(message):
# print("Dropping message due to RMP")
# continue
if not secure_session:
print("Unsecured session")
print(data[offset : offset + 8])
decrypted_message = memoryview(data)[offset:]
context = {"role": "responder", "node_id": source_node_id}
unsecured_session_context[source_node_id] = context
exchange_flags, protocol_opcode, exchange_id = struct.unpack_from(
"<BBH", decrypted_message
)
add_bookmark(offset, 4, "Protocol header")
exchange_flags = cm.ExchangeFlags(exchange_flags)
print(f"Exchange Flags: {exchange_flags} Exchange ID: {exchange_id}")
decrypted_offset = 4
protocol_vendor_id = 0
if exchange_flags & cm.ExchangeFlags.V:
protocol_vendor_id = struct.unpack_from(
"<H", decrypted_message, decrypted_offset
)[0]
add_bookmark(offset + decrypted_offset, 2, "Protocol Vendor ID")
decrypted_offset += 2
protocol_id = struct.unpack_from("<H", decrypted_message, decrypted_offset)[0]
add_bookmark(offset + decrypted_offset, 2, "Protocol ID")
decrypted_offset += 2
protocol_id = cm.ProtocolId(protocol_id)
protocol_opcode = cm.PROTOCOL_OPCODES[protocol_id](protocol_opcode)
print(
f"Protocol Vendor ID: {protocol_vendor_id} Protocol ID: {protocol_id} Protocol Opcode: {protocol_opcode}"
)
acknowledged_message_counter = None
if exchange_flags & cm.ExchangeFlags.A:
acknowledged_message_counter = struct.unpack_from(
"<I", decrypted_message, decrypted_offset
)[0]
decrypted_offset += 4
print(f"Acknowledged Message Counter: {acknowledged_message_counter}")
protocol_id = message.protocol_id
protocol_opcode = message.protocol_opcode
if protocol_id == cm.ProtocolId.SECURE_CHANNEL:
if protocol_opcode == cm.SecureProtocolOpcode.MSG_COUNTER_SYNC_REQ:
@ -118,16 +76,27 @@ def run():
print("Received Message Counter Synchronization Response")
elif protocol_opcode == cm.SecureProtocolOpcode.PBKDF_PARAM_REQUEST:
print("Received PBKDF Parameter Request")
request = cm.PBKDFParamRequest(decrypted_message[decrypted_offset + 1 :])
# This is Section 4.14.1.2
request = cm.PBKDFParamRequest(message.payload)
if request.passcodeID == 0:
pass
# Send back failure
# response = StatusReport()
# response.GeneralCode
print(request)
response = cm.PBKDFParamResponse()
response.initiatorRandom = request.initiatorRandom
response.responderRandom = b"\x00" * 32
response.responderSessionId = 0
params = cm.Crypto_PBKDFParameterSet()
params.iterations = 1000
params.salt = b"\x00" * 32
response.pbkdf_parameters = params
# Generate a random number
response.responderRandom = os.urandom(32)
session_context = manager.new_context(response.responderSessionId)
session_context.peer_session_id = request.initiatorSessionId
if not request.hasPBKDFParameters:
params = cm.Crypto_PBKDFParameterSet()
params.iterations = 1000
params.salt = b"\x00" * 32
response.pbkdf_parameters = params
print(response)
elif protocol_opcode == cm.SecureProtocolOpcode.PBKDF_PARAM_RESPONSE:

View file

@ -0,0 +1,117 @@
from circuitmatter import MessageReceptionState
def test_basics():
"""These test the common window behavior"""
state = MessageReceptionState(123)
assert state.message_counter == 123
# Older messages are not ok
assert state.process_counter(122)
# The current max is not ok
assert state.process_counter(123)
# A new value is ok
assert not state.process_counter(126)
#
assert state.process_counter(123)
assert not state.process_counter(124)
assert not state.process_counter(125)
assert state.process_counter(124)
def test_window_wrapping():
"""Test wrapping the window data across a rollover"""
state = MessageReceptionState(123, rollover=True)
assert state.message_counter == 123
# Move to the end of the range
assert not state.process_counter(0xFFFFFFFF)
# Older is ok when in the window.
assert not state.process_counter(0xFFFFFFF0)
# A new value is ok. Window is now 0xFFFFFFF0 to 15
assert not state.process_counter(16)
assert state.process_counter(0xFFFFFFF0)
assert state.process_counter(0xFFFFFFFF)
assert not state.process_counter(1)
assert not state.process_counter(0xFFFFFFF8)
def test_unencrypted():
"""These test the common window behavior"""
state = MessageReceptionState(123, rollover=True, encrypted=False)
assert state.message_counter == 123
# Older messages are not ok
assert state.process_counter(123 - 32)
# Older messages outside the window are ok
assert not state.process_counter(123 - 32 - 1)
def test_encrypted_no_rollover():
"""These test the common window behavior"""
state = MessageReceptionState(123, rollover=False, encrypted=True)
assert state.message_counter == 123
# Older messages are not ok
assert state.process_counter(123 - 32)
# Older messages outside the window are not ok
assert state.process_counter(123 - 32 - 1)
# Older messages outside the window are not ok
assert state.process_counter(0)
# All newer numbers are ok
assert not state.process_counter(0xFFFFFFFE)
# Ok because it is in the window
assert not state.process_counter(0xFFFFFFFD)
# All older messages outside the window are not ok
assert state.process_counter(0)
def test_encrypted_with_rollover():
"""These test the common window behavior"""
state = MessageReceptionState(123, rollover=True, encrypted=True)
assert state.message_counter == 123
# Older messages are not ok
assert state.process_counter(123 - 32)
# Older messages outside the window are not ok
assert state.process_counter(123 - 32 - 1)
# Older messages outside the window are not ok
assert state.process_counter(0)
# Numbers wrapped back within the 2**31 window are not ok
assert state.process_counter(0xFFFFFFFE)
assert not state.process_counter(0x80000000)
assert not state.process_counter(0xFFFFFFFE)
# Ok because it is in the window
assert not state.process_counter(0xFFFFFFFD)
# All older messages outside the window are not ok
assert state.process_counter(0xFFFFFFFE - 32 - 32)
assert state.process_counter(0xFFFFFFFE - 0x80000000)
# It is ok to wrap back around outside the 2**31 window.
assert not state.process_counter(0xFFFFFFFE - 0x80000000 - 1)