Add message reception state for counter verification
This commit is contained in:
parent
6f28636869
commit
9d2687b784
3 changed files with 422 additions and 66 deletions
|
|
@ -1,6 +1,10 @@
|
|||
"""Pure Python implementation of the Matter IOT protocol."""
|
||||
|
||||
import enum
|
||||
import pathlib
|
||||
import json
|
||||
import struct
|
||||
import time
|
||||
|
||||
from . import tlv
|
||||
|
||||
|
|
@ -21,7 +25,10 @@ __version__ = "0.0.0"
|
|||
|
||||
# print(f"Listening on UDP port {UDP_PORT}")
|
||||
|
||||
unsecured_session_context = {}
|
||||
# Section 4.11.2
|
||||
MSG_COUNTER_WINDOW_SIZE = 32
|
||||
MSG_COUNTER_SYNC_REQ_JITTER_MS = 500
|
||||
MSG_COUNTER_SYNC_TIMEOUT_MS = 400
|
||||
|
||||
|
||||
class ProtocolId(enum.Enum):
|
||||
|
|
@ -36,6 +43,8 @@ class SecurityFlags(enum.Flag):
|
|||
P = 1 << 7
|
||||
C = 1 << 6
|
||||
MX = 1 << 5
|
||||
# This is actually 2 bits but the top bit is reserved and always zero.
|
||||
GROUP = 1 << 0
|
||||
|
||||
|
||||
class ExchangeFlags(enum.Flag):
|
||||
|
|
@ -173,3 +182,264 @@ class PBKDFParamResponse(tlv.TLVStructure):
|
|||
responderSessionId = tlv.NumberMember(3, "<H")
|
||||
pbkdf_parameters = tlv.StructMember(4, Crypto_PBKDFParameterSet)
|
||||
responderSessionParams = tlv.StructMember(5, SessionParameterStruct, optional=True)
|
||||
|
||||
|
||||
class MessageReceptionState:
|
||||
def __init__(self, starting_value, rollover=True, encrypted=False):
|
||||
"""Implements 4.6.5.1"""
|
||||
self.message_counter = starting_value
|
||||
self.window_bitmap = (1 << MSG_COUNTER_WINDOW_SIZE) - 1
|
||||
self.mask = self.window_bitmap
|
||||
self.encrypted = encrypted
|
||||
self.rollover = rollover
|
||||
|
||||
def process_counter(self, counter) -> bool:
|
||||
"""Returns True if the counter number is a duplicate"""
|
||||
# Process the current window first. Behavior outside the window varies.
|
||||
if counter == self.message_counter:
|
||||
return True
|
||||
if self.message_counter <= MSG_COUNTER_WINDOW_SIZE < counter:
|
||||
# Window wraps
|
||||
bit_position = 0xFFFFFFFF - counter + self.message_counter
|
||||
else:
|
||||
bit_position = self.message_counter - counter - 1
|
||||
if 0 <= bit_position < MSG_COUNTER_WINDOW_SIZE:
|
||||
if self.window_bitmap & (1 << bit_position) != 0:
|
||||
# This is a duplicate message
|
||||
return True
|
||||
self.window_bitmap |= 1 << bit_position
|
||||
return False
|
||||
|
||||
new_start = (self.message_counter + 1) & self.mask # Inclusive
|
||||
new_end = (
|
||||
self.message_counter - MSG_COUNTER_WINDOW_SIZE
|
||||
) & self.mask # Exclusive
|
||||
if not self.rollover:
|
||||
new_end = (1 << MSG_COUNTER_WINDOW_SIZE) - 1
|
||||
elif self.encrypted:
|
||||
new_end = (
|
||||
self.message_counter + (1 << (MSG_COUNTER_WINDOW_SIZE - 1))
|
||||
) & self.mask
|
||||
|
||||
if new_start <= new_end:
|
||||
if not (new_start <= counter < new_end):
|
||||
return True
|
||||
else:
|
||||
if not (counter < new_end or new_start <= counter):
|
||||
return True
|
||||
|
||||
# This is a new message
|
||||
shift = counter - self.message_counter
|
||||
if counter < self.message_counter:
|
||||
shift += 0x100000000
|
||||
if shift > MSG_COUNTER_WINDOW_SIZE:
|
||||
self.window_bitmap = 0
|
||||
else:
|
||||
new_bitmap = (self.window_bitmap << shift) & self.mask
|
||||
self.window_bitmap = new_bitmap
|
||||
if 1 < shift < MSG_COUNTER_WINDOW_SIZE:
|
||||
self.window_bitmap |= 1 << (shift - 1)
|
||||
self.message_counter = counter
|
||||
return False
|
||||
|
||||
|
||||
class UnsecuredSessionContext:
|
||||
def __init__(self, initiator, ephemeral_initiator_node_id):
|
||||
self.initiator = initiator
|
||||
self.ephemeral_initiator_node_id = ephemeral_initiator_node_id
|
||||
self.message_reception_state = None
|
||||
|
||||
|
||||
class SecureSessionContext:
|
||||
def __init__(self, local_session_id):
|
||||
self.session_type = None
|
||||
"""Records whether the session was established using CASE or PASE."""
|
||||
self.session_role = None
|
||||
"""Records whether the node is the session initiator or responder."""
|
||||
self.local_session_id = local_session_id
|
||||
"""Individually selected by each participant in secure unicast communication during session establishment and used as a unique identifier to recover encryption keys, authenticate incoming messages and associate them to existing sessions."""
|
||||
self.peer_session_id = None
|
||||
"""Assigned by the peer during session establishment"""
|
||||
self.i2r_key = None
|
||||
"""Encrypts data in messages sent from the initiator of session establishment to the responder."""
|
||||
self.r2i_key = None
|
||||
"""Encrypts data in messages sent from the session establishment responder to the initiator."""
|
||||
self.shared_secret = None
|
||||
"""Computed during the CASE protocol execution and re-used when CASE session resumption is implemented."""
|
||||
self.local_message_counter = None
|
||||
"""Secure Session Message Counter for outbound messages."""
|
||||
self.message_reception_state = None
|
||||
"""Provides tracking for the Secure Session Message Counter of the remote"""
|
||||
self.local_fabric_index = None
|
||||
"""Records the local Index for the session’s Fabric, which MAY be used to look up Fabric metadata related to the Fabric for which this session context applies."""
|
||||
self.peer_node_id = None
|
||||
"""Records the authenticated node ID of the remote peer, when available."""
|
||||
self.resumption_id = None
|
||||
"""The ID used when resuming a session between the local and remote peer."""
|
||||
self.session_timestamp = None
|
||||
"""A timestamp indicating the time at which the last message was sent or received. This timestamp SHALL be initialized with the time the session was created."""
|
||||
self.active_timestamp = None
|
||||
"""A timestamp indicating the time at which the last message was received. This timestamp SHALL be initialized with the time the session was created."""
|
||||
self.session_idle_interval = None
|
||||
self.session_active_interval = None
|
||||
self.session_active_threshold = None
|
||||
|
||||
@property
|
||||
def peer_active(self):
|
||||
return (time.monotonic() - self.active_timestamp) < self.session_active_interval
|
||||
|
||||
|
||||
class Message:
|
||||
def __init__(self, buffer):
|
||||
self.buffer = buffer
|
||||
self.flags, self.session_id, self.security_flags, self.message_counter = (
|
||||
struct.unpack_from("<BHBI", buffer)
|
||||
)
|
||||
offset = 8
|
||||
self.source_node_id = None
|
||||
if self.flags & (1 << 2):
|
||||
self.source_node_id = struct.unpack_from("<Q", buffer, 8)[0]
|
||||
offset += 8
|
||||
|
||||
if (self.flags >> 4) != 0:
|
||||
raise RuntimeError("Incorrect version")
|
||||
self.secure_session = self.security_flags & 0x3 != 0 or self.session_id != 0
|
||||
|
||||
if not self.secure_session:
|
||||
self.payload = memoryview(buffer)[offset:]
|
||||
|
||||
context = UnsecuredSessionContext(False, self.source_node_id)
|
||||
self.unsecured_session_context[self.source_node_id] = context
|
||||
else:
|
||||
self.payload = None
|
||||
|
||||
def _parse_protocol_header(self):
|
||||
self.exchange_flags, self.protocol_opcode, self.exchange_id = (
|
||||
struct.unpack_from("<BBH", self.payload)
|
||||
)
|
||||
|
||||
self.exchange_flags = ExchangeFlags(self.exchange_flags)
|
||||
decrypted_offset = 4
|
||||
self.protocol_vendor_id = 0
|
||||
if self.exchange_flags & ExchangeFlags.V:
|
||||
self.protocol_vendor_id = struct.unpack_from(
|
||||
"<H", self.payload, decrypted_offset
|
||||
)[0]
|
||||
decrypted_offset += 2
|
||||
protocol_id = struct.unpack_from("<H", self.payload, decrypted_offset)[0]
|
||||
decrypted_offset += 2
|
||||
self.protocol_id = ProtocolId(protocol_id)
|
||||
self.protocol_opcode = PROTOCOL_OPCODES[protocol_id](self.protocol_opcode)
|
||||
|
||||
self.acknowledged_message_counter = None
|
||||
if self.exchange_flags & ExchangeFlags.A:
|
||||
self.acknowledged_message_counter = struct.unpack_from(
|
||||
"<I", self.payload, decrypted_offset
|
||||
)[0]
|
||||
decrypted_offset += 4
|
||||
|
||||
def reply(self, payload, protocol_id=None, protocol_opcode=None) -> memoryview:
|
||||
reply = bytearray(1280)
|
||||
offset = 0
|
||||
|
||||
# struct.pack_into(
|
||||
# "<BHBI", reply, offset, flags, session_id, security_flags, message_counter
|
||||
# )
|
||||
# offset += 8
|
||||
return memoryview(reply)[:offset]
|
||||
|
||||
|
||||
class SessionManager:
|
||||
def __init__(self):
|
||||
persist_path = pathlib.Path("counters.json")
|
||||
if persist_path.exists():
|
||||
self.nonvolatile = json.loads(persist_path.read_text())
|
||||
else:
|
||||
self.nonvolatile = {}
|
||||
self.nonvolatile["unencrypted_message_counter"] = 0
|
||||
self.nonvolatile["group_encrypted_data_message_counter"] = 0
|
||||
self.nonvolatile["group_encrypted_control_message_counter"] = 0
|
||||
self.unencrypted_message_counter = self.nonvolatile[
|
||||
"unencrypted_message_counter"
|
||||
]
|
||||
self.group_encrypted_data_message_counter = self.nonvolatile[
|
||||
"group_encrypted_data_message_counter"
|
||||
]
|
||||
self.group_encrypted_control_message_counter = self.nonvolatile[
|
||||
"group_encrypted_control_message_counter"
|
||||
]
|
||||
self.check_in_counter = 0
|
||||
self.unsecured_session_context = {}
|
||||
self.secure_session_contexts = ["reserved"]
|
||||
|
||||
def _increment(self, value):
|
||||
return (value + 1) % 0xFFFFFFFF
|
||||
|
||||
def counter_ok(self, message):
|
||||
"""Implements 4.6.7"""
|
||||
if message.secure_session:
|
||||
if message.security_flags & SecurityFlags.GROUP:
|
||||
if message.source_node_id is None:
|
||||
return False
|
||||
# TODO: Get MRS for source node id and message type
|
||||
else:
|
||||
session_context = self.secure_session_contexts[message.session_id]
|
||||
else:
|
||||
if message.source_node_id not in self.unsecured_session_context:
|
||||
self.unsecured_session_context[message.source_node_id] = (
|
||||
UnsecuredSessionContext(
|
||||
initiator=False,
|
||||
ephemeral_initiator_node_id=message.source_node_id,
|
||||
)
|
||||
)
|
||||
session_context = self.unsecured_session_context[message.source_node_id]
|
||||
|
||||
if session_context.message_reception_state is None:
|
||||
session_context.message_reception_state = MessageReceptionState(
|
||||
message.message_counter,
|
||||
rollover=False,
|
||||
encrypted=message.secure_session,
|
||||
)
|
||||
return True
|
||||
|
||||
return session_context.message_reception_state.process_counter(
|
||||
message.message_counter
|
||||
)
|
||||
|
||||
def next_message_counter(self, message):
|
||||
"""Implements 4.6.6"""
|
||||
if not message.secure_session:
|
||||
value = self.unencrypted_message_counter
|
||||
self.unencrypted_message_counter = self._increment(
|
||||
self.unencrypted_message_counter
|
||||
)
|
||||
return value
|
||||
elif message.security_flags & SecurityFlags.GROUP:
|
||||
if message.security_flags & SecurityFlags.C:
|
||||
value = self.group_encrypted_control_message_counter
|
||||
self.group_encrypted_control_message_counter = self._increment(
|
||||
self.group_encrypted_control_message_counter
|
||||
)
|
||||
return value
|
||||
else:
|
||||
value = self.group_encrypted_data_message_counter
|
||||
self.group_encrypted_data_message_counter = self._increment(
|
||||
self.group_encrypted_data_message_counter
|
||||
)
|
||||
return value
|
||||
session = self.secure_session_contexts[message.session_id]
|
||||
value = session.local_message_counter
|
||||
next_value = self._increment(value)
|
||||
session.local_message_counter = next_value
|
||||
if next_value == 0:
|
||||
# TODO expire the encryption key
|
||||
raise NotImplementedError("Expire the encryption key 4.6.6")
|
||||
return next_value
|
||||
|
||||
def new_context(self):
|
||||
if None not in self.secure_session_contexts:
|
||||
self.secure_session_contexts.append(None)
|
||||
session_id = self.secure_session_contexts.index(None)
|
||||
|
||||
self.secure_session_contexts[session_id] = SecureSessionContext(session_id)
|
||||
return self.secure_session_contexts[session_id]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Pure Python implementation of the Matter IOT protocol."""
|
||||
|
||||
import struct
|
||||
import os
|
||||
|
||||
import circuitmatter as cm
|
||||
|
||||
|
|
@ -21,6 +21,7 @@ import circuitmatter as cm
|
|||
# print(f"Listening on UDP port {UDP_PORT}")
|
||||
|
||||
unsecured_session_context = {}
|
||||
secure_session_contexts = ["reserved"]
|
||||
|
||||
# while True:
|
||||
# # Receive data from the socket (1280 is the minimum ipv6 MTU and the max UDP matter packet size.)
|
||||
|
|
@ -50,66 +51,23 @@ def add_bookmark(start, length, name, color=0x0000FF):
|
|||
|
||||
|
||||
def run():
|
||||
manager = cm.SessionManager()
|
||||
# Print the received data and the address of the sender
|
||||
# This is section 4.7.2
|
||||
print(f"Received packet from {addr}: {data}")
|
||||
print(f"Data length: {len(data)} bytes")
|
||||
flags, session_id, security_flags, message_counter = struct.unpack_from(
|
||||
"<BHBI", data
|
||||
)
|
||||
add_bookmark(0, 8, "Header")
|
||||
print(
|
||||
f"Flags: {flags:x} Session ID: {session_id:x} Security Flags: {cm.SecurityFlags(security_flags)} Message Counter: {message_counter}"
|
||||
)
|
||||
offset = 8
|
||||
if flags & (1 << 2):
|
||||
source_node_id = struct.unpack_from("<Q", data, 8)[0]
|
||||
add_bookmark(8, 8, "Source Node ID")
|
||||
print(source_node_id)
|
||||
offset += 8
|
||||
print(f"DSIZ {flags & (0x3)}")
|
||||
if (flags >> 4) != 0:
|
||||
print("Incorrect version")
|
||||
# continue
|
||||
secure_session = security_flags & 0x3 != 0 or session_id != 0
|
||||
message = cm.Message(data)
|
||||
if message.secure_session:
|
||||
# Decrypt the payload
|
||||
pass
|
||||
if not manager.counter_ok(message):
|
||||
print("Dropping message due to counter error")
|
||||
return
|
||||
# if not manager.rmp_ok(message):
|
||||
# print("Dropping message due to RMP")
|
||||
# continue
|
||||
|
||||
if not secure_session:
|
||||
print("Unsecured session")
|
||||
print(data[offset : offset + 8])
|
||||
decrypted_message = memoryview(data)[offset:]
|
||||
|
||||
context = {"role": "responder", "node_id": source_node_id}
|
||||
unsecured_session_context[source_node_id] = context
|
||||
|
||||
exchange_flags, protocol_opcode, exchange_id = struct.unpack_from(
|
||||
"<BBH", decrypted_message
|
||||
)
|
||||
add_bookmark(offset, 4, "Protocol header")
|
||||
exchange_flags = cm.ExchangeFlags(exchange_flags)
|
||||
print(f"Exchange Flags: {exchange_flags} Exchange ID: {exchange_id}")
|
||||
decrypted_offset = 4
|
||||
protocol_vendor_id = 0
|
||||
if exchange_flags & cm.ExchangeFlags.V:
|
||||
protocol_vendor_id = struct.unpack_from(
|
||||
"<H", decrypted_message, decrypted_offset
|
||||
)[0]
|
||||
add_bookmark(offset + decrypted_offset, 2, "Protocol Vendor ID")
|
||||
decrypted_offset += 2
|
||||
protocol_id = struct.unpack_from("<H", decrypted_message, decrypted_offset)[0]
|
||||
add_bookmark(offset + decrypted_offset, 2, "Protocol ID")
|
||||
decrypted_offset += 2
|
||||
protocol_id = cm.ProtocolId(protocol_id)
|
||||
protocol_opcode = cm.PROTOCOL_OPCODES[protocol_id](protocol_opcode)
|
||||
print(
|
||||
f"Protocol Vendor ID: {protocol_vendor_id} Protocol ID: {protocol_id} Protocol Opcode: {protocol_opcode}"
|
||||
)
|
||||
|
||||
acknowledged_message_counter = None
|
||||
if exchange_flags & cm.ExchangeFlags.A:
|
||||
acknowledged_message_counter = struct.unpack_from(
|
||||
"<I", decrypted_message, decrypted_offset
|
||||
)[0]
|
||||
decrypted_offset += 4
|
||||
print(f"Acknowledged Message Counter: {acknowledged_message_counter}")
|
||||
protocol_id = message.protocol_id
|
||||
protocol_opcode = message.protocol_opcode
|
||||
|
||||
if protocol_id == cm.ProtocolId.SECURE_CHANNEL:
|
||||
if protocol_opcode == cm.SecureProtocolOpcode.MSG_COUNTER_SYNC_REQ:
|
||||
|
|
@ -118,16 +76,27 @@ def run():
|
|||
print("Received Message Counter Synchronization Response")
|
||||
elif protocol_opcode == cm.SecureProtocolOpcode.PBKDF_PARAM_REQUEST:
|
||||
print("Received PBKDF Parameter Request")
|
||||
request = cm.PBKDFParamRequest(decrypted_message[decrypted_offset + 1 :])
|
||||
# This is Section 4.14.1.2
|
||||
request = cm.PBKDFParamRequest(message.payload)
|
||||
if request.passcodeID == 0:
|
||||
pass
|
||||
# Send back failure
|
||||
# response = StatusReport()
|
||||
# response.GeneralCode
|
||||
print(request)
|
||||
response = cm.PBKDFParamResponse()
|
||||
response.initiatorRandom = request.initiatorRandom
|
||||
response.responderRandom = b"\x00" * 32
|
||||
response.responderSessionId = 0
|
||||
params = cm.Crypto_PBKDFParameterSet()
|
||||
params.iterations = 1000
|
||||
params.salt = b"\x00" * 32
|
||||
response.pbkdf_parameters = params
|
||||
|
||||
# Generate a random number
|
||||
response.responderRandom = os.urandom(32)
|
||||
session_context = manager.new_context(response.responderSessionId)
|
||||
|
||||
session_context.peer_session_id = request.initiatorSessionId
|
||||
if not request.hasPBKDFParameters:
|
||||
params = cm.Crypto_PBKDFParameterSet()
|
||||
params.iterations = 1000
|
||||
params.salt = b"\x00" * 32
|
||||
response.pbkdf_parameters = params
|
||||
print(response)
|
||||
|
||||
elif protocol_opcode == cm.SecureProtocolOpcode.PBKDF_PARAM_RESPONSE:
|
||||
|
|
|
|||
117
tests/test_message_reception_state.py
Normal file
117
tests/test_message_reception_state.py
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
from circuitmatter import MessageReceptionState
|
||||
|
||||
|
||||
def test_basics():
|
||||
"""These test the common window behavior"""
|
||||
state = MessageReceptionState(123)
|
||||
assert state.message_counter == 123
|
||||
|
||||
# Older messages are not ok
|
||||
assert state.process_counter(122)
|
||||
|
||||
# The current max is not ok
|
||||
assert state.process_counter(123)
|
||||
|
||||
# A new value is ok
|
||||
assert not state.process_counter(126)
|
||||
|
||||
#
|
||||
assert state.process_counter(123)
|
||||
|
||||
assert not state.process_counter(124)
|
||||
|
||||
assert not state.process_counter(125)
|
||||
|
||||
assert state.process_counter(124)
|
||||
|
||||
|
||||
def test_window_wrapping():
|
||||
"""Test wrapping the window data across a rollover"""
|
||||
state = MessageReceptionState(123, rollover=True)
|
||||
assert state.message_counter == 123
|
||||
|
||||
# Move to the end of the range
|
||||
assert not state.process_counter(0xFFFFFFFF)
|
||||
|
||||
# Older is ok when in the window.
|
||||
assert not state.process_counter(0xFFFFFFF0)
|
||||
|
||||
# A new value is ok. Window is now 0xFFFFFFF0 to 15
|
||||
assert not state.process_counter(16)
|
||||
|
||||
assert state.process_counter(0xFFFFFFF0)
|
||||
|
||||
assert state.process_counter(0xFFFFFFFF)
|
||||
|
||||
assert not state.process_counter(1)
|
||||
|
||||
assert not state.process_counter(0xFFFFFFF8)
|
||||
|
||||
|
||||
def test_unencrypted():
|
||||
"""These test the common window behavior"""
|
||||
state = MessageReceptionState(123, rollover=True, encrypted=False)
|
||||
assert state.message_counter == 123
|
||||
|
||||
# Older messages are not ok
|
||||
assert state.process_counter(123 - 32)
|
||||
|
||||
# Older messages outside the window are ok
|
||||
assert not state.process_counter(123 - 32 - 1)
|
||||
|
||||
|
||||
def test_encrypted_no_rollover():
|
||||
"""These test the common window behavior"""
|
||||
state = MessageReceptionState(123, rollover=False, encrypted=True)
|
||||
assert state.message_counter == 123
|
||||
|
||||
# Older messages are not ok
|
||||
assert state.process_counter(123 - 32)
|
||||
|
||||
# Older messages outside the window are not ok
|
||||
assert state.process_counter(123 - 32 - 1)
|
||||
|
||||
# Older messages outside the window are not ok
|
||||
assert state.process_counter(0)
|
||||
|
||||
# All newer numbers are ok
|
||||
assert not state.process_counter(0xFFFFFFFE)
|
||||
|
||||
# Ok because it is in the window
|
||||
assert not state.process_counter(0xFFFFFFFD)
|
||||
|
||||
# All older messages outside the window are not ok
|
||||
assert state.process_counter(0)
|
||||
|
||||
|
||||
def test_encrypted_with_rollover():
|
||||
"""These test the common window behavior"""
|
||||
state = MessageReceptionState(123, rollover=True, encrypted=True)
|
||||
assert state.message_counter == 123
|
||||
|
||||
# Older messages are not ok
|
||||
assert state.process_counter(123 - 32)
|
||||
|
||||
# Older messages outside the window are not ok
|
||||
assert state.process_counter(123 - 32 - 1)
|
||||
|
||||
# Older messages outside the window are not ok
|
||||
assert state.process_counter(0)
|
||||
|
||||
# Numbers wrapped back within the 2**31 window are not ok
|
||||
assert state.process_counter(0xFFFFFFFE)
|
||||
|
||||
assert not state.process_counter(0x80000000)
|
||||
|
||||
assert not state.process_counter(0xFFFFFFFE)
|
||||
|
||||
# Ok because it is in the window
|
||||
assert not state.process_counter(0xFFFFFFFD)
|
||||
|
||||
# All older messages outside the window are not ok
|
||||
assert state.process_counter(0xFFFFFFFE - 32 - 32)
|
||||
|
||||
assert state.process_counter(0xFFFFFFFE - 0x80000000)
|
||||
|
||||
# It is ok to wrap back around outside the 2**31 window.
|
||||
assert not state.process_counter(0xFFFFFFFE - 0x80000000 - 1)
|
||||
Loading…
Reference in a new issue