circuitmatter/circuitmatter/pase.py
2024-10-10 16:26:49 -07:00

299 lines
9.2 KiB
Python

from . import crypto
from . import protocol
from . import tlv
from . import session
import hashlib
import struct
from cryptography.hazmat.primitives.ciphers.aead import AESCCM
from ecdsa.ellipticcurve import AbstractPoint, Point, PointJacobi
from ecdsa.curves import NIST256p
# Base class for all PASE (Passcode-Authenticated Session Establishment)
# messages defined in this module. Subclasses add a PROTOCOL_OPCODE and their
# TLV members; they all share the Secure Channel protocol id below.
class PASEMessage(tlv.Structure):
    # Protocol identifier shared by every PASE message.
    PROTOCOL_ID = protocol.ProtocolId.SECURE_CHANNEL
# PBKDFParamRequest: the initiator's request for PBKDF parameters.
#
# pbkdfparamreq-struct => STRUCTURE [ tag-order ]
# {
#   initiatorRandom        [1] : OCTET STRING [ length 32 ],
#   initiatorSessionId     [2] : UNSIGNED INTEGER [ range 16-bits ],
#   passcodeId             [3] : UNSIGNED INTEGER [ length 16-bits ],
#   hasPBKDFParameters     [4] : BOOLEAN,
#   initiatorSessionParams [5, optional] : session-parameter-struct
# }
class PBKDFParamRequest(PASEMessage):
    PROTOCOL_OPCODE = protocol.SecureProtocolOpcode.PBKDF_PARAM_REQUEST
    # Each member's first argument is its context tag from the schema above.
    initiatorRandom = tlv.OctetStringMember(1, 32)
    # 16-bit unsigned fields are declared as 2-octet unsigned ints.
    initiatorSessionId = tlv.IntMember(2, signed=False, octets=2)
    passcodeId = tlv.IntMember(3, signed=False, octets=2)
    hasPBKDFParameters = tlv.BoolMember(4)
    initiatorSessionParams = tlv.StructMember(
        5, session.SessionParameterStruct, optional=True
    )
# PBKDF parameters (iteration count and salt); embedded as tag 4 of
# PBKDFParamResponse.
#
# Crypto_PBKDFParameterSet => STRUCTURE [ tag-order ]
# {
#   iterations [1] : UNSIGNED INTEGER [ range 32-bits ],
#   salt       [2] : OCTET STRING [ length 16..32 ],
# }
class Crypto_PBKDFParameterSet(tlv.Structure):
    # 32-bit unsigned iteration count (4 octets).
    iterations = tlv.IntMember(1, signed=False, octets=4)
    # NOTE(review): the schema allows 16..32 bytes; the 32 here is presumably
    # only the maximum length — confirm the 16-byte minimum is enforced
    # wherever salts are accepted.
    salt = tlv.OctetStringMember(2, 32)
# PBKDFParamResponse: the responder's reply to PBKDFParamRequest.
#
# pbkdfparamresp-struct => STRUCTURE [ tag-order ]
# {
#   initiatorRandom        [1] : OCTET STRING [ length 32 ],
#   responderRandom        [2] : OCTET STRING [ length 32 ],
#   responderSessionId     [3] : UNSIGNED INTEGER [ range 16-bits ],
#   pbkdf_parameters       [4] : Crypto_PBKDFParameterSet,
#   responderSessionParams [5, optional] : session-parameter-struct
# }
class PBKDFParamResponse(PASEMessage):
    PROTOCOL_OPCODE = protocol.SecureProtocolOpcode.PBKDF_PARAM_RESPONSE
    # Presumably echoes the request's initiatorRandom — confirm with caller.
    initiatorRandom = tlv.OctetStringMember(1, 32)
    responderRandom = tlv.OctetStringMember(2, 32)
    responderSessionId = tlv.IntMember(3, signed=False, octets=2)
    # NOTE(review): declared as required; if the request's
    # hasPBKDFParameters flag is meant to let the responder omit this,
    # it would need optional=True — confirm against the spec.
    pbkdf_parameters = tlv.StructMember(4, Crypto_PBKDFParameterSet)
    responderSessionParams = tlv.StructMember(
        5, session.SessionParameterStruct, optional=True
    )
# PAKE1: carries pA, the initiator's SPAKE2+ share (the point X).
class PAKE1(PASEMessage):
    PROTOCOL_OPCODE = protocol.SecureProtocolOpcode.PASE_PAKE1
    # Uncompressed encoding of X; parsed with Point.from_bytes by the verifier.
    pA = tlv.OctetStringMember(1, crypto.PUBLIC_KEY_SIZE_BYTES)
# PAKE2: the verifier's reply, carrying its share pB and confirmation cB.
class PAKE2(PASEMessage):
    PROTOCOL_OPCODE = protocol.SecureProtocolOpcode.PASE_PAKE2
    # Uncompressed encoding of the verifier's point Y.
    pB = tlv.OctetStringMember(1, crypto.PUBLIC_KEY_SIZE_BYTES)
    # Key-confirmation MAC, HASH_LEN_BYTES long (set from Crypto_P2's cB).
    cB = tlv.OctetStringMember(2, crypto.HASH_LEN_BYTES)
# PAKE3: carries the initiator's key-confirmation MAC cA.
class PAKE3(PASEMessage):
    PROTOCOL_OPCODE = protocol.SecureProtocolOpcode.PASE_PAKE3
    # HASH_LEN_BYTES-long MAC; compute_verification returns the expected cA.
    cA = tlv.OctetStringMember(1, crypto.HASH_LEN_BYTES)
# SPAKE2+ fixed points M and N on P-256, given as 33-byte compressed SEC1
# encodings (0x02/0x03 prefix + x coordinate). NOTE(review): these are
# expected to equal the published SPAKE2+ P-256 constants (RFC 9383) —
# verify against the spec before changing.
M = PointJacobi.from_bytes(
    NIST256p.curve,
    b"\x02\x88\x6e\x2f\x97\xac\xe4\x6e\x55\xba\x9d\xd7\x24\x25\x79\xf2\x99\x3b\x64\xe1\x6e\xf3\xdc\xab\x95\xaf\xd4\x97\x33\x3d\x8f\xa1\x2f",
)
N = PointJacobi.from_bytes(
    NIST256p.curve,
    b"\x03\xd8\xbb\xd6\xc6\x39\xc6\x29\x37\xb0\x4d\x99\x7f\x38\xc3\x77\x07\x19\xc6\x29\xd7\x01\x4d\x49\xa2\x4b\x4f\x98\xba\xa1\x29\x2b\x49",
)
# Byte width of each PBKDF2 output half (w0s / w1s) used by _pbkdf2. The
# extra 8 bytes over the group size are presumably there to reduce modulo
# bias when each half is reduced mod the curve order — TODO confirm.
# NOTE(review): this sets an attribute on the imported crypto module at
# import time, so every user of crypto.W_SIZE_BYTES sees this value.
crypto.W_SIZE_BYTES = crypto.GROUP_SIZE_BYTES + 8
# in the spake2p math P is NIST256p.generator
# in the spake2p math p is NIST256p.order
def _pbkdf2(passcode, salt, iterations):
    """Stretch a passcode into the two SPAKE2+ scalars ``(w0, w1)``.

    The passcode is serialized as a 4-byte little-endian unsigned integer
    and fed to PBKDF2-HMAC-SHA256, producing two halves of
    ``crypto.W_SIZE_BYTES`` bytes each. Each half is read as a big-endian
    integer and reduced modulo the P-256 group order.

    Returns:
        ``(w0, w1)`` as Python ints.
    """
    width = crypto.W_SIZE_BYTES
    password = struct.pack("<I", passcode)
    stretched = hashlib.pbkdf2_hmac("sha256", password, salt, iterations, 2 * width)
    order = NIST256p.order
    w0, w1 = (
        int.from_bytes(stretched[start : start + width], byteorder="big") % order
        for start in (0, width)
    )
    return w0, w1
def initiator_values(passcode, salt, iterations) -> tuple[bytes, bytes]:
    """Return the initiator's SPAKE2+ secrets ``(w0, w1)`` as bytes.

    Both scalars come from ``_pbkdf2`` and are encoded big-endian at a fixed
    width of ``NIST256p.baselen`` bytes.
    """
    width = NIST256p.baselen
    return tuple(
        scalar.to_bytes(width, byteorder="big")
        for scalar in _pbkdf2(passcode, salt, iterations)
    )
def verifier_values(passcode: int, salt: bytes, iterations: int) -> tuple[bytes, bytes]:
    """Return the verifier record ``(w0, L)`` for a passcode.

    ``w0`` is encoded big-endian at ``NIST256p.baselen`` bytes. ``L`` is the
    uncompressed encoding of ``w1 * G`` (G = P-256 generator); ``w1`` itself
    is not returned.
    """
    w0_scalar, w1_scalar = _pbkdf2(passcode, salt, iterations)
    point_L = NIST256p.generator * w1_scalar
    encoded_w0 = w0_scalar.to_bytes(NIST256p.baselen, byteorder="big")
    return encoded_w0, point_L.to_bytes("uncompressed")
# w0 and w1 are big-endian encoded
def Crypto_pA(w0, w1) -> bytes:
    """Return the initiator's SPAKE2+ share pA.

    NOTE(review): placeholder — the prover side is not implemented yet. Both
    arguments are ignored and an empty byte string is returned.
    """
    return bytes()
def Crypto_pB(random_source, w0: int, L: Point) -> tuple[int, AbstractPoint]:
    """Draw the verifier's ephemeral scalar y and compute Y = y*G + w0*N.

    Args:
        random_source: any object providing ``randbelow(n)``.
        w0: the verifier's secret scalar.
        L: the verifier's stored point; not used by this computation.

    Returns:
        ``(y, Y)``: the ephemeral scalar and the resulting curve point.
    """
    ephemeral = random_source.randbelow(NIST256p.order)
    share = ephemeral * NIST256p.generator + w0 * N
    return ephemeral, share
def Crypto_Transcript(context, pA, pB, Z, V, w0) -> bytes:
    """Build the SPAKE2+ transcript TT.

    Each element is written as an 8-byte little-endian length followed by
    the element's bytes, in this order: context, two empty elements
    (presumably the unused prover/verifier identities — confirm), M and N
    in uncompressed form, pA, pB, Z, V, w0. The result is a ``bytearray``.
    """
    fields = (
        context,
        b"",
        b"",
        M.to_bytes("uncompressed"),
        N.to_bytes("uncompressed"),
        pA,
        pB,
        Z,
        V,
        w0,
    )
    pieces = []
    for field in fields:
        pieces.append(struct.pack("<Q", len(field)))
        pieces.append(field)
    return bytearray(b"".join(pieces))
def KDF(salt, key, info):
    """Derive ``crypto.HASH_LEN_BITS`` bits from ``key`` with ``salt`` and ``info``.

    This wrapper takes ``(salt, key, info)``, whereas ``crypto.KDF`` is
    called with the input key first and the salt second.
    """
    # Section 3.10 defines the mapping from KDF to Crypto_KDF but it is wrong!
    # The argument order of this wrapper is the correct one.
    output_bits = crypto.HASH_LEN_BITS
    return crypto.KDF(key, salt, info, output_bits)
def Crypto_P2(tt, pA, pB) -> tuple[bytes, bytes, bytes]:
    """Derive the key-confirmation MACs and shared key from transcript ``tt``.

    ``Hash(tt)`` is split into halves Ka || Ke. Ka is expanded with the
    "ConfirmationKeys" KDF into KcA || KcB.

    Returns:
        ``(cA, cB, Ke)`` with ``cA = HMAC(KcA, pB)`` and ``cB = HMAC(KcB, pA)``.
    """
    half = crypto.HASH_LEN_BYTES // 2
    digest = crypto.Hash(tt)
    Ka, Ke = digest[:half], digest[half:]
    # https://github.com/project-chip/connectedhomeip/blob/c88d5cf83cd3e3323ac196630acc34f196a2f405/src/crypto/CHIPCryptoPAL.cpp#L458-L468
    confirmation_keys = KDF(None, Ka, b"ConfirmationKeys")
    KcA, KcB = confirmation_keys[:half], confirmation_keys[half:]
    return (crypto.HMAC(KcA, pB), crypto.HMAC(KcB, pA), Ke)
def compute_session_keys(Ke, secure_session_context):
    """Derive session keys from ``Ke`` and install them on the context.

    A single "SessionKeys" KDF call yields three consecutive chunks of
    ``crypto.SYMMETRIC_KEY_LENGTH_BYTES`` each: the initiator-to-responder
    key, the responder-to-initiator key and the attestation challenge. An
    AES-CCM cipher is created for each direction.
    """
    key_len = crypto.SYMMETRIC_KEY_LENGTH_BYTES
    material = crypto.KDF(
        Ke, b"", b"SessionKeys", 3 * crypto.SYMMETRIC_KEY_LENGTH_BITS
    )
    ctx = secure_session_context
    ctx.i2r_key = material[:key_len]
    ctx.i2r = AESCCM(ctx.i2r_key, tag_length=crypto.AEAD_MIC_LENGTH_BYTES)
    ctx.r2i_key = material[key_len : 2 * key_len]
    ctx.r2i = AESCCM(ctx.r2i_key, tag_length=crypto.AEAD_MIC_LENGTH_BYTES)
    ctx.attestation_challenge = material[2 * key_len : 3 * key_len]
def compute_verification(random_source, pake1, pake2, context, verifier):
    """Run the verifier side of SPAKE2+ for one PASE exchange.

    Args:
        random_source: object providing ``randbelow(n)``; Crypto_pB uses it
            to draw the ephemeral scalar y.
        pake1: received PAKE1; ``pake1.pA`` is decoded as the peer point X.
        pake2: PAKE2 being built; this function sets ``pake2.pB`` and
            ``pake2.cB``.
        context: context bytes placed first in the transcript.
        verifier: ``w0 || L``; the first ``crypto.GROUP_SIZE_BYTES`` bytes are
            w0 (big-endian), the remainder is the encoded point L. Presumably
            the concatenation of verifier_values() outputs — confirm.

    Returns:
        ``(cA, Ke)``: the confirmation value expected in PAKE3 (presumably
        compared by the caller) and the key passed on for session keys.
    """
    w0 = memoryview(verifier)[: crypto.GROUP_SIZE_BYTES]
    L = memoryview(verifier)[crypto.GROUP_SIZE_BYTES :]
    L = Point.from_bytes(NIST256p.curve, L)
    w0 = int.from_bytes(w0, byteorder="big")
    y, Y = Crypto_pB(random_source, w0, L)
    # pB is Y encoded uncompressed
    # pA is X encoded uncompressed
    # pake2.pB must be set here: the transcript below reads it back.
    pake2.pB = Y.to_bytes("uncompressed")
    h = NIST256p.curve.cofactor()
    # Use negation because the class doesn't support subtraction. 🤦
    X = Point.from_bytes(NIST256p.curve, pake1.pA)
    # Z = h*y*(X - w0*M)
    Z = h * y * (X + (-(w0 * M)))
    # NOTE(review): an earlier note here read "Z is wrong. V is right". The
    # expression above matches Z = h*y*(X - w0*M); whether the mismatch was
    # ever resolved is unknown — verify against known test vectors.
    # V = h*y*L
    V = h * y * L
    tt = Crypto_Transcript(
        context,
        pake1.pA,
        pake2.pB,
        Z.to_bytes("uncompressed"),
        V.to_bytes("uncompressed"),
        w0.to_bytes(NIST256p.baselen, byteorder="big"),
    )
    cA, cB, Ke = Crypto_P2(tt, pake1.pA, pake2.pB)
    pake2.cB = cB
    return cA, Ke
def _write_bits(buf, offset, bits, value) -> int:
while bits > 0:
bits_remaining = 8 - offset % 8
write_bits = min(bits, bits_remaining)
mask = (1 << write_bits) - 1
buf[offset // 8] |= (value & mask) << (offset % 8)
offset += write_bits
bits -= write_bits
value >>= write_bits
return offset
def _base38_encode(buf) -> str:
alphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ-."
encoded = []
for i in range(0, len(buf), 3):
value = 0
remaining = min(3, len(buf) - i)
print("remaining", remaining)
for j in range(remaining):
value |= buf[i + j] << (j * 8)
outputs = 5
if remaining == 2:
outputs = 4
elif remaining == 1:
outputs = 2
for j in range(outputs):
encoded.append(alphabet[value % 38])
value //= 38
print(encoded)
return "".join(encoded)
def show_qr_code(vendor_id, product_id, discriminator, passcode):
    """Print a Matter onboarding QR code (``MT:`` payload) as ASCII art.

    The 88-bit payload is packed least-significant bit first: version
    (3 bits, 0), vendor id (16), product id (16), custom flow (2, 0),
    discovery capabilities (8), discriminator (12), passcode (27), then 4
    zero padding bits. It is base-38 encoded and prefixed with ``MT:``.

    Raises:
        ValueError: if vendor_id, product_id, discriminator or passcode is
            negative or does not fit in its bit width (it would otherwise be
            silently truncated into an unusable code).
    """
    fields = (
        # (name, width in bits, value)
        ("version", 3, 0),
        ("vendor_id", 16, vendor_id),
        ("product_id", 16, product_id),
        ("custom_flow", 2, 0),
        ("discovery", 8, 1 << 2),  # On network already
        ("discriminator", 12, discriminator),
        ("passcode", 27, passcode),
    )
    padding_bits = 4
    total_bits = sum(width for _, width, _ in fields) + padding_bits
    buf = bytearray(total_bits // 8)
    offset = 0
    for name, width, value in fields:
        if not 0 <= value < (1 << width):
            raise ValueError(f"{name} {value!r} does not fit in {width} bits")
        offset = _write_bits(buf, offset, width, value)
    encoded = _base38_encode(buf)

    # Imported lazily so the rest of the module works without qrcode.
    import qrcode

    qr = qrcode.QRCode(
        version=1,
        error_correction=qrcode.constants.ERROR_CORRECT_L,
        box_size=10,
        border=4,
    )
    qr.add_data("MT:" + encoded)
    qr.print_ascii()