From 0b52d602ef3cd4e1e83105c109997ded3be81bd2 Mon Sep 17 00:00:00 2001 From: stcb <21@stcb.cc> Date: Sat, 29 Mar 2025 22:59:15 +0200 Subject: [PATCH] Fix + enhancement --- protocol_prototype/cli.py | 244 +++++++++++++-------- protocol_prototype/crypto_utils.py | 108 ++++++++-- protocol_prototype/encryption.py | 291 +++++++++++++++++++------ protocol_prototype/messages.py | 332 +++++++++++++++++------------ protocol_prototype/protocol.py | 320 +++++++++++++++++---------- 5 files changed, 870 insertions(+), 425 deletions(-) diff --git a/protocol_prototype/cli.py b/protocol_prototype/cli.py index c65bdd2..63c84d6 100644 --- a/protocol_prototype/cli.py +++ b/protocol_prototype/cli.py @@ -12,24 +12,25 @@ def main(): protocol = IcingProtocol() print(f"{YELLOW}\n======================================") - print(" Icing Protocol - Manual CLI Demo ") + print(" Icing Protocol - Secure Communication ") print("======================================\n" + RESET) print(f"Listening on port: {protocol.local_port}") print(f"Your identity public key (hex): {protocol.identity_pubkey.hex()}") print("\nAvailable commands:") - print(" peer_id ") - print(" connect ") - print(" generate_ephemeral_keys") - print(" send_ping") - print(" send_handshake") - print(" respond_ping <0|1>") - print(" generate_ecdhe ") - print(" derive_hkdf") - print(" send_encrypted ") - print(" decrypt_received <index>") - print(" auto_responder <on|off>") - print(" show_state") - print(" exit\n") + print(" help - Show this help message") + print(" peer_id <hex_pubkey> - Set peer identity public key") + print(" connect <port> - Connect to a peer at the specified port") + print(" generate_ephemeral_keys - Generate ephemeral ECDH keys") + print(" send_ping [cipher] - Send PING request (cipher: 0=AES-GCM, 1=ChaCha20-Poly1305, default: 0)") + print(" respond_ping <index> <0|1> - Respond to a PING (0=reject, 1=accept)") + print(" send_handshake - Send handshake with ephemeral keys") + print(" generate_ecdhe <index> - Process handshake at specified index") + print(" derive_hkdf - Derive encryption key using HKDF") + print(" send_encrypted <plaintext> - Encrypt and send a message") + print(" decrypt <index> - Decrypt received message at index") + print(" auto_responder <on|off> - Enable/disable automatic responses") + print(" show_state - Display current protocol state") + print(" exit - Exit the program\n") while True: try: @@ -38,76 +39,153 @@ def main(): break if not line: continue + parts = line.split() cmd = parts[0].lower() - if cmd == "exit": - protocol.stop() - break - elif cmd == "show_state": - protocol.show_state() - elif cmd == "peer_id": - if len(parts) != 2: - print("Usage: peer_id <hex_pubkey>") - continue - protocol.set_peer_identity(parts[1]) - elif cmd == "connect": - if len(parts) != 2: - print("Usage: connect <port>") - continue - try: - port = int(parts[1]) - protocol.connect_to_peer(port) - except ValueError: - print("Invalid port.") - elif cmd == "generate_ephemeral_keys": - protocol.generate_ephemeral_keys() - elif cmd == "send_ping": - protocol.send_ping_request() - elif cmd == "send_handshake": - protocol.send_handshake() - elif cmd == "respond_ping": - if len(parts) != 3: - print("Usage: respond_ping <index> <0|1>") - continue - try: - idx = int(parts[1]) - ac = int(parts[2]) - protocol.respond_to_ping(idx, ac) - except ValueError: - print("Index and answer must be integers.") - elif cmd == "generate_ecdhe": - if len(parts) != 2: - print("Usage: generate_ecdhe <index>") - continue - try: - idx = int(parts[1]) - protocol.generate_ecdhe(idx) - except ValueError: - print("Index must be an integer.") - elif cmd == "derive_hkdf": - protocol.derive_hkdf() - elif cmd == "send_encrypted": - if len(parts) < 2: - print("Usage: send_encrypted <plaintext>") - continue - plaintext = " ".join(parts[1:]) - protocol.send_encrypted_message(plaintext) - elif cmd == "decrypt_received": - if len(parts) != 2: - print("Usage: decrypt_received <index>") - continue - try: - idx = int(parts[1]) - protocol.decrypt_received_message(idx) - except ValueError: - print("Index must be an integer.") - elif cmd == "auto_responder": - if len(parts) != 2: - print("Usage: auto_responder <on|off>") - continue - protocol.enable_auto_responder(parts[1].lower() == "on") - else: - print(f"{RED}[ERROR]{RESET} Unknown command: {cmd}") + + try: + if cmd == "exit": + protocol.stop() + break + + elif cmd == "help": + print("\nAvailable commands:") + print(" help - Show this help message") + print(" peer_id <hex_pubkey> - Set peer identity public key") + print(" connect <port> - Connect to a peer at the specified port") + print(" generate_ephemeral_keys - Generate ephemeral ECDH keys") + print(" send_ping [cipher] - Send PING request (cipher: 0=AES-GCM, 1=ChaCha20-Poly1305, default: 0)") + print(" respond_ping <index> <0|1> - Respond to a PING (0=reject, 1=accept)") + print(" send_handshake - Send handshake with ephemeral keys") + print(" generate_ecdhe <index> - Process handshake at specified index") + print(" derive_hkdf - Derive encryption key using HKDF") + print(" send_encrypted <plaintext> - Encrypt and send a message") + print(" decrypt <index> - Decrypt received message at index") + print(" auto_responder <on|off> - Enable/disable automatic responses") + print(" show_state - Display current protocol state") + print(" exit - Exit the program") + + elif cmd == "show_state": + protocol.show_state() + + elif cmd == "peer_id": + if len(parts) != 2: + print(f"{RED}[ERROR]{RESET} Usage: peer_id <hex_pubkey>") + continue + try: + protocol.set_peer_identity(parts[1]) + except ValueError as e: + print(f"{RED}[ERROR]{RESET} Invalid public key: {e}") + + elif cmd == "connect": + if len(parts) != 2: + print(f"{RED}[ERROR]{RESET} Usage: connect <port>") + continue + try: + port = int(parts[1]) + protocol.connect_to_peer(port) + except ValueError: + print(f"{RED}[ERROR]{RESET} Invalid port number.") + except Exception as e: + print(f"{RED}[ERROR]{RESET} Connection failed: {e}") + + elif cmd == "generate_ephemeral_keys": + protocol.generate_ephemeral_keys() + + elif cmd == "send_ping": + # Optional cipher parameter (0 = AES-GCM, 1 = ChaCha20-Poly1305) + cipher = 0 # Default to AES-GCM + if len(parts) >= 2: + try: + cipher = int(parts[1]) + if cipher not in (0, 1): + print(f"{YELLOW}[WARNING]{RESET} Unsupported cipher code {cipher}. Using AES-GCM (0).") + cipher = 0 + except ValueError: + print(f"{YELLOW}[WARNING]{RESET} Invalid cipher code. Using AES-GCM (0).") + protocol.send_ping_request() + + elif cmd == "send_handshake": + protocol.send_handshake() + + elif cmd == "respond_ping": + if len(parts) != 3: + print(f"{RED}[ERROR]{RESET} Usage: respond_ping <index> <0|1>") + continue + try: + idx = int(parts[1]) + answer = int(parts[2]) + if answer not in (0, 1): + print(f"{RED}[ERROR]{RESET} Answer must be 0 (reject) or 1 (accept).") + continue + protocol.respond_to_ping(idx, answer) + except ValueError: + print(f"{RED}[ERROR]{RESET} Index and answer must be integers.") + except Exception as e: + print(f"{RED}[ERROR]{RESET} Failed to respond to ping: {e}") + + elif cmd == "generate_ecdhe": + if len(parts) != 2: + print(f"{RED}[ERROR]{RESET} Usage: generate_ecdhe <index>") + continue + try: + idx = int(parts[1]) + protocol.generate_ecdhe(idx) + except ValueError: + print(f"{RED}[ERROR]{RESET} Index must be an integer.") + except Exception as e: + print(f"{RED}[ERROR]{RESET} Failed to process handshake: {e}") + + elif cmd == "derive_hkdf": + try: + protocol.derive_hkdf() + except Exception as e: + print(f"{RED}[ERROR]{RESET} Failed to derive HKDF key: {e}") + + elif cmd == "send_encrypted": + if len(parts) < 2: + print(f"{RED}[ERROR]{RESET} Usage: send_encrypted <plaintext>") + continue + plaintext = " ".join(parts[1:]) + try: + protocol.send_encrypted_message(plaintext) + except Exception as e: + print(f"{RED}[ERROR]{RESET} Failed to send encrypted message: {e}") + + elif cmd == "decrypt": + if len(parts) != 2: + print(f"{RED}[ERROR]{RESET} Usage: decrypt <index>") + continue + try: + idx = int(parts[1]) + protocol.decrypt_received_message(idx) + except ValueError: + print(f"{RED}[ERROR]{RESET} Index must be an integer.") + except Exception as e: + print(f"{RED}[ERROR]{RESET} Failed to decrypt message: {e}") + + elif cmd == "auto_responder": + if len(parts) != 2: + print(f"{RED}[ERROR]{RESET} Usage: auto_responder <on|off>") + continue + val = parts[1].lower() + if val not in ("on", "off"): + print(f"{RED}[ERROR]{RESET} Value must be 'on' or 'off'.") + continue + protocol.enable_auto_responder(val == "on") + + else: + print(f"{RED}[ERROR]{RESET} Unknown command: {cmd}") + print("Type 'help' for a list of available commands.") + + except Exception as e: + print(f"{RED}[ERROR]{RESET} Command failed: {e}") if __name__ == "__main__": - main() + try: + main() + except KeyboardInterrupt: + print("\nExiting...") + except Exception as e: + print(f"{RED}[FATAL ERROR]{RESET} {e}") + sys.exit(1) diff --git a/protocol_prototype/crypto_utils.py b/protocol_prototype/crypto_utils.py index 07c5330..8c2e110 100644 --- a/protocol_prototype/crypto_utils.py +++ b/protocol_prototype/crypto_utils.py @@ -1,50 +1,78 @@ import os +from typing import Tuple from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import ec, utils from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature, encode_dss_signature -def generate_identity_keys(): +def generate_identity_keys() -> Tuple[ec.EllipticCurvePrivateKey, bytes]: """ Generate an ECDSA (P-256) identity key pair. - Return (private_key, public_key_bytes). - public_key_bytes is raw x||y each 32 bytes (uncompressed minus the 0x04 prefix). + + Returns: + Tuple containing: + - private_key: EllipticCurvePrivateKey object + - public_key_bytes: Raw x||y format (64 bytes, 512 bits) """ private_key = ec.generate_private_key(ec.SECP256R1()) public_numbers = private_key.public_key().public_numbers() x_bytes = public_numbers.x.to_bytes(32, byteorder='big') y_bytes = public_numbers.y.to_bytes(32, byteorder='big') - pubkey_bytes = x_bytes + y_bytes # 64 bytes + pubkey_bytes = x_bytes + y_bytes # 64 bytes total return private_key, pubkey_bytes -def load_peer_identity_key(pubkey_bytes: bytes): +def load_peer_identity_key(pubkey_bytes: bytes) -> ec.EllipticCurvePublicKey: """ - Given 64 bytes (x||y) for P-256, return a cryptography public key object. + Convert a raw public key (64 bytes, x||y format) to a cryptography public key object. + + Args: + pubkey_bytes: Raw 64-byte public key (x||y format) + + Returns: + EllipticCurvePublicKey object + + Raises: + ValueError: If the pubkey_bytes is not exactly 64 bytes """ if len(pubkey_bytes) != 64: raise ValueError("Peer identity pubkey must be exactly 64 bytes (x||y).") + x_int = int.from_bytes(pubkey_bytes[:32], byteorder='big') y_int = int.from_bytes(pubkey_bytes[32:], byteorder='big') + public_numbers = ec.EllipticCurvePublicNumbers(x_int, y_int, ec.SECP256R1()) return public_numbers.public_key() -def sign_data(private_key, data: bytes) -> bytes: +def sign_data(private_key: ec.EllipticCurvePrivateKey, data: bytes) -> bytes: """ - Sign 'data' with ECDSA using P-256 private key. - Returns DER-encoded signature (variable length, up to ~70-72 bytes). + Sign data with ECDSA using a P-256 private key. + + Args: + private_key: EllipticCurvePrivateKey for signing + data: Bytes to sign + + Returns: + DER-encoded signature (variable length, up to ~70-72 bytes) """ signature = private_key.sign(data, ec.ECDSA(hashes.SHA256())) return signature -def verify_signature(public_key, signature: bytes, data: bytes) -> bool: +def verify_signature(public_key: ec.EllipticCurvePublicKey, signature: bytes, data: bytes) -> bool: """ - Verify DER-encoded ECDSA signature with the given public key. - Return True if valid, False otherwise. + Verify a DER-encoded ECDSA signature. + + Args: + public_key: EllipticCurvePublicKey for verification + signature: DER-encoded signature + data: Original signed data + + Returns: + True if signature is valid, False otherwise """ try: public_key.verify(signature, data, ec.ECDSA(hashes.SHA256())) @@ -53,35 +81,62 @@ def verify_signature(public_key, signature: bytes, data: bytes) -> bool: return False -def get_ephemeral_keypair(): +def get_ephemeral_keypair() -> Tuple[ec.EllipticCurvePrivateKey, bytes]: """ - Generate ephemeral ECDH keypair (P-256). - Return (private_key, pubkey_bytes). + Generate an ephemeral ECDH key pair (P-256). + + Returns: + Tuple containing: + - private_key: EllipticCurvePrivateKey object + - pubkey_bytes: Raw x||y format (64 bytes, 512 bits) """ private_key = ec.generate_private_key(ec.SECP256R1()) numbers = private_key.public_key().public_numbers() + x_bytes = numbers.x.to_bytes(32, 'big') y_bytes = numbers.y.to_bytes(32, 'big') - return private_key, x_bytes + y_bytes # 64 bytes + + return private_key, x_bytes + y_bytes # 64 bytes total -def compute_ecdh_shared_key(private_key, peer_pubkey_bytes: bytes) -> bytes: +def compute_ecdh_shared_key(private_key: ec.EllipticCurvePrivateKey, peer_pubkey_bytes: bytes) -> bytes: """ - Given a local ECDH private_key and the peer's ephemeral pubkey (64 bytes), - compute the shared secret. + Compute a shared secret using ECDH. + + Args: + private_key: Local ECDH private key + peer_pubkey_bytes: Peer's ephemeral public key (64 bytes, raw x||y format) + + Returns: + Shared secret bytes + + Raises: + ValueError: If peer_pubkey_bytes is not 64 bytes """ + if len(peer_pubkey_bytes) != 64: + raise ValueError("Peer public key must be 64 bytes (x||y format)") + x_int = int.from_bytes(peer_pubkey_bytes[:32], 'big') y_int = int.from_bytes(peer_pubkey_bytes[32:], 'big') + + # Create public key object from raw components peer_public_numbers = ec.EllipticCurvePublicNumbers(x_int, y_int, ec.SECP256R1()) peer_public_key = peer_public_numbers.public_key() + + # Perform key exchange shared_key = private_key.exchange(ec.ECDH(), peer_public_key) return shared_key def der_to_raw(der_sig: bytes) -> bytes: """ - Convert a DER-encoded ECDSA signature to a raw 64-byte signature (r||s), - where each component is padded to 32 bytes. + Convert a DER-encoded ECDSA signature to a raw 64-byte signature (r||s). + + Args: + der_sig: DER-encoded signature + + Returns: + Raw 64-byte signature (r||s format), with each component padded to 32 bytes """ r, s = decode_dss_signature(der_sig) r_bytes = r.to_bytes(32, byteorder='big') @@ -92,10 +147,19 @@ def der_to_raw(der_sig: bytes) -> bytes: def raw_signature_to_der(raw_sig: bytes) -> bytes: """ Convert a raw signature (64 bytes, concatenated r||s) to DER-encoded signature. + + Args: + raw_sig: Raw 64-byte signature (r||s format) + + Returns: + DER-encoded signature + + Raises: + ValueError: If raw_sig is not 64 bytes """ if len(raw_sig) != 64: raise ValueError("Raw signature must be 64 bytes (r||s).") - from cryptography.hazmat.primitives.asymmetric.utils import encode_dss_signature + r = int.from_bytes(raw_sig[:32], 'big') s = int.from_bytes(raw_sig[32:], 'big') return encode_dss_signature(r, s) diff --git a/protocol_prototype/encryption.py b/protocol_prototype/encryption.py index ebd922f..9aa3730 100644 --- a/protocol_prototype/encryption.py +++ b/protocol_prototype/encryption.py @@ -1,94 +1,263 @@ import os import struct -from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from typing import Optional, Tuple +from cryptography.hazmat.primitives.ciphers.aead import AESGCM, ChaCha20Poly1305 class MessageHeader: """ - Represents the header of an encrypted message. - - flag (16 bits) - - data_len (16 bits): length in bytes of the encrypted payload (excluding tag) - - Associated Data (AD): - * retry (8 bits) - * connexion_status (4 bits) + 4 bits padding (packed in one byte) - * iv/messageID (96 bits / 12 bytes) - Total header size: 2 + 2 + 1 + 1 + 12 = 18 bytes. + Header of an encrypted message (18 bytes total): + + Clear Text Section (4 bytes): + - flag: 16 bits (0xBEEF by default) + - data_len: 16 bits (length of encrypted payload excluding tag) + + Associated Data (14 bytes): + - retry: 8 bits (retry counter) + - connection_status: 4 bits (e.g., CRC required) + 4 bits padding + - iv/messageID: 96 bits (12 bytes) """ - def __init__(self, flag: int, data_len: int, retry: int, connexion_status: int, iv: bytes): - self.flag = flag # 16 bits - self.data_len = data_len # 16 bits - self.retry = retry # 8 bits - self.connexion_status = connexion_status # 4 bits - self.iv = iv # 96 bits (12 bytes) + def __init__(self, flag: int, data_len: int, retry: int, connection_status: int, iv: bytes): + if not (0 <= flag < 65536): + raise ValueError("Flag must fit in 16 bits (0..65535)") + if not (0 <= data_len < 65536): + raise ValueError("Data length must fit in 16 bits (0..65535)") + if not (0 <= retry < 256): + raise ValueError("Retry must fit in 8 bits (0..255)") + if not (0 <= connection_status < 16): + raise ValueError("Connection status must fit in 4 bits (0..15)") + if len(iv) != 12: + raise ValueError("IV must be 12 bytes (96 bits)") + + self.flag = flag # 16 bits + self.data_len = data_len # 16 bits + self.retry = retry # 8 bits + self.connection_status = connection_status # 4 bits + self.iv = iv # 96 bits (12 bytes) def pack(self) -> bytes: - # Pack flag and data_len as unsigned shorts (2 bytes each) + """Pack header into 18 bytes.""" + # Pack flag and data_len (4 bytes) header = struct.pack('>H H', self.flag, self.data_len) - # Pack retry (1 byte) and connexion_status (4 bits in high nibble, 4 bits padding as zero) - ad_byte = (self.connexion_status & 0x0F) << 4 + + # Pack retry and connection_status (2 bytes) + # connection_status in high 4 bits of second byte, 4 bits padding as zero + ad_byte = (self.connection_status & 0x0F) << 4 ad_packed = struct.pack('>B B', self.retry, ad_byte) + # Append IV (12 bytes) return header + ad_packed + self.iv - + + def get_associated_data(self) -> bytes: + """Get the associated data for AEAD encryption (retry, conn_status, iv).""" + # Pack retry and connection_status + ad_byte = (self.connection_status & 0x0F) << 4 + ad_packed = struct.pack('>B B', self.retry, ad_byte) + + # Append IV + return ad_packed + self.iv + @classmethod def unpack(cls, data: bytes) -> 'MessageHeader': - # Expect exactly 18 bytes + """Unpack 18 bytes into a MessageHeader object.""" + if len(data) < 18: + raise ValueError(f"Header data too short: {len(data)} bytes, expected 18") + flag, data_len = struct.unpack('>H H', data[:4]) retry, ad_byte = struct.unpack('>B B', data[4:6]) - connexion_status = (ad_byte >> 4) & 0x0F + connection_status = (ad_byte >> 4) & 0x0F iv = data[6:18] - return cls(flag, data_len, retry, connexion_status, iv) + + return cls(flag, data_len, retry, connection_status, iv) -def generate_iv(initial: bool, previous_iv: bytes = None) -> bytes: +class EncryptedMessage: + """ + Encrypted message packet format: + + - Header (18 bytes): + * flag: 16 bits + * data_len: 16 bits + * retry: 8 bits + * connection_status: 4 bits (+ 4 bits padding) + * iv/messageID: 96 bits (12 bytes) + + - Payload: variable length encrypted data + + - Footer: + * Authentication tag: 128 bits (16 bytes) + * CRC32: 32 bits (4 bytes) - optional, based on connection_status + """ + def __init__(self, plaintext: bytes, key: bytes, flag: int = 0xBEEF, + retry: int = 0, connection_status: int = 0, iv: bytes = None, + cipher_type: int = 0): + self.plaintext = plaintext + self.key = key + self.flag = flag + self.retry = retry + self.connection_status = connection_status + self.iv = iv or generate_iv(initial=True) + self.cipher_type = cipher_type # 0 = AES-256-GCM, 1 = ChaCha20-Poly1305 + + # Will be set after encryption + self.ciphertext = None + self.tag = None + self.header = None + + def encrypt(self) -> bytes: + """Encrypt the plaintext and return the full encrypted message.""" + # Create header with correct data_len (which will be set after encryption) + self.header = MessageHeader( + flag=self.flag, + data_len=0, # Will be updated after encryption + retry=self.retry, + connection_status=self.connection_status, + iv=self.iv + ) + + # Get associated data for AEAD + aad = self.header.get_associated_data() + + # Encrypt using the appropriate cipher + if self.cipher_type == 0: # AES-256-GCM + cipher = AESGCM(self.key) + ciphertext_with_tag = cipher.encrypt(self.iv, self.plaintext, aad) + elif self.cipher_type == 1: # ChaCha20-Poly1305 + cipher = ChaCha20Poly1305(self.key) + ciphertext_with_tag = cipher.encrypt(self.iv, self.plaintext, aad) + else: + raise ValueError(f"Unsupported cipher type: {self.cipher_type}") + + # Extract ciphertext and tag + self.tag = ciphertext_with_tag[-16:] + self.ciphertext = ciphertext_with_tag[:-16] + + # Update header with actual data length + self.header.data_len = len(self.ciphertext) + + # Pack everything together + packed_header = self.header.pack() + + # Check if CRC is required (based on connection_status) + if self.connection_status & 0x01: # Lowest bit indicates CRC required + import zlib + # Compute CRC32 of header + ciphertext + tag + crc = zlib.crc32(packed_header + self.ciphertext + self.tag) & 0xffffffff + crc_bytes = struct.pack('>I', crc) + return packed_header + self.ciphertext + self.tag + crc_bytes + else: + return packed_header + self.ciphertext + self.tag + + @classmethod + def decrypt(cls, data: bytes, key: bytes, cipher_type: int = 0) -> Tuple[bytes, MessageHeader]: + """ + Decrypt an encrypted message and return the plaintext and header. + + Args: + data: The full encrypted message + key: The encryption key + cipher_type: 0 for AES-256-GCM, 1 for ChaCha20-Poly1305 + + Returns: + Tuple of (plaintext, header) + """ + if len(data) < 18 + 16: # Header + minimum tag size + raise ValueError("Message too short") + + # Extract header + header_bytes = data[:18] + header = MessageHeader.unpack(header_bytes) + + # Get ciphertext and tag + data_len = header.data_len + ciphertext_start = 18 + ciphertext_end = ciphertext_start + data_len + + if ciphertext_end + 16 > len(data): + raise ValueError("Message length does not match header's data_len") + + ciphertext = data[ciphertext_start:ciphertext_end] + tag = data[ciphertext_end:ciphertext_end + 16] + + # Get associated data for AEAD + aad = header.get_associated_data() + + # Combine ciphertext and tag for decryption + ciphertext_with_tag = ciphertext + tag + + # Decrypt using the appropriate cipher + try: + if cipher_type == 0: # AES-256-GCM + cipher = AESGCM(key) + plaintext = cipher.decrypt(header.iv, ciphertext_with_tag, aad) + elif cipher_type == 1: # ChaCha20-Poly1305 + cipher = ChaCha20Poly1305(key) + plaintext = cipher.decrypt(header.iv, ciphertext_with_tag, aad) + else: + raise ValueError(f"Unsupported cipher type: {cipher_type}") + + return plaintext, header + except Exception as e: + raise ValueError(f"Decryption failed: {e}") + +def generate_iv(initial: bool = False, previous_iv: bytes = None) -> bytes: """ Generate a 96-bit IV (12 bytes). - - If 'initial' is True, return a random IV. - - Otherwise, increment the previous IV by 1 modulo 2^96. + + Args: + initial: If True, return a random IV + previous_iv: The previous IV to increment + + Returns: + A new IV """ if initial or previous_iv is None: - return os.urandom(12) + return os.urandom(12) # 96 bits else: + # Increment the previous IV by 1 modulo 2^96 iv_int = int.from_bytes(previous_iv, 'big') iv_int = (iv_int + 1) % (1 << 96) return iv_int.to_bytes(12, 'big') -def encrypt_message(plaintext: bytes, key: bytes, flag: int = 0xBEEF, retry: int = 0, connexion_status: int = 0) -> bytes: +# Convenience functions to match original API +def encrypt_message(plaintext: bytes, key: bytes, flag: int = 0xBEEF, + retry: int = 0, connection_status: int = 0, + iv: bytes = None, cipher_type: int = 0) -> bytes: """ - Encrypts a plaintext using AES-256-GCM. - - Generates a random 96-bit IV. - - Encrypts the plaintext with AESGCM. - - Builds a MessageHeader with the provided flag, the data_len (length of ciphertext excluding tag), - retry, connexion_status, and the IV. - - Returns the full encrypted message: header (18 bytes) || ciphertext || tag (16 bytes). + Encrypt a message using the specified parameters. + + Args: + plaintext: The data to encrypt + key: The encryption key (32 bytes for AES-256-GCM, 32 bytes for ChaCha20-Poly1305) + flag: 16-bit flag value (default: 0xBEEF) + retry: 8-bit retry counter + connection_status: 4-bit connection status + iv: Optional 96-bit IV (if None, a random one will be generated) + cipher_type: 0 for AES-256-GCM, 1 for ChaCha20-Poly1305 + + Returns: + The full encrypted message """ - aesgcm = AESGCM(key) - iv = generate_iv(initial=True) - # Encrypt with no associated data (you may later use the header as AD if needed) - ciphertext_with_tag = aesgcm.encrypt(iv, plaintext, None) - tag_length = 16 # default tag size - ciphertext = ciphertext_with_tag[:-tag_length] - tag = ciphertext_with_tag[-tag_length:] - data_len = len(ciphertext) - header = MessageHeader(flag=flag, data_len=data_len, retry=retry, connexion_status=connexion_status, iv=iv) - packed_header = header.pack() - return packed_header + ciphertext + tag + message = EncryptedMessage( + plaintext=plaintext, + key=key, + flag=flag, + retry=retry, + connection_status=connection_status, + iv=iv, + cipher_type=cipher_type + ) + return message.encrypt() -def decrypt_message(message: bytes, key: bytes) -> bytes: +def decrypt_message(message: bytes, key: bytes, cipher_type: int = 0) -> bytes: """ - Decrypts a message that was encrypted with encrypt_message. - Expects message format: header (18 bytes) || ciphertext || tag (16 bytes). - Returns the decrypted plaintext. + Decrypt a message. + + Args: + message: The full encrypted message + key: The encryption key + cipher_type: 0 for AES-256-GCM, 1 for ChaCha20-Poly1305 + + Returns: + The decrypted plaintext """ - if len(message) < 18 + 16: - raise ValueError("Message too short.") - header_bytes = message[:18] - header = MessageHeader.unpack(header_bytes) - data_len = header.data_len - expected_len = 18 + data_len + 16 - if len(message) != expected_len: - raise ValueError("Message length does not match header's data_len.") - ciphertext = message[18:18+data_len] - tag = message[18+data_len:] - ciphertext_with_tag = ciphertext + tag - aesgcm = AESGCM(key) - plaintext = aesgcm.decrypt(header.iv, ciphertext_with_tag, None) + plaintext, _ = EncryptedMessage.decrypt(message, key, cipher_type) return plaintext diff --git a/protocol_prototype/messages.py b/protocol_prototype/messages.py index b835003..98151ab 100644 --- a/protocol_prototype/messages.py +++ b/protocol_prototype/messages.py @@ -3,6 +3,7 @@ import struct import time import zlib import hashlib +from typing import Tuple, Optional def crc32_of(data: bytes) -> int: """ @@ -21,67 +22,78 @@ def crc32_of(data: bytes) -> int: # # Total bits: 129 + 7 + 4 + 32 = 172 bits. We pack into 22 bytes (176 bits) with 4 spare bits. # --------------------------------------------------------------------------- -def build_ping_request(version: int, cipher: int, nonce_full: bytes = None) -> bytes: +class PingRequest: """ - Build a Ping request with: - - session_nonce: 129 bits (derived from 17 random bytes by discarding the lowest 7 bits) - - version: 7 bits - - cipher: 4 bits (0 = AES-256-GCM, 1 = ChaCha20-poly1305; we use 0 for now) - - CRC: 32 bits - - Total = 129 + 7 + 4 + 32 = 172 bits. - We pack into 22 bytes (176 bits) leaving 4 unused bits. + PING REQUEST format (172 bits / 22 bytes): + - session_nonce: 129 bits (from top 129 bits of 17 random bytes) + - version: 7 bits + - cipher: 4 bits (0 = AES-256-GCM, 1 = ChaCha20-poly1305) + - CRC: 32 bits """ - if not (0 <= version < 128): - raise ValueError("Version must fit in 7 bits (0..127)") - if not (0 <= cipher < 16): - raise ValueError("Cipher must fit in 4 bits (0..15)") - # Generate 17 random bytes if none provided (17 bytes = 136 bits) - if nonce_full is None: - nonce_full = os.urandom(17) - if len(nonce_full) < 17: - raise ValueError("nonce_full must be at least 17 bytes") - # Use the top 129 bits of the 136 bits: - nonce_int_full = int.from_bytes(nonce_full, 'big') # 136 bits - nonce_129_int = nonce_int_full >> 7 # drop the lowest 7 bits => 129 bits - # Convert the derived 129-bit nonce to 17 bytes. - # Since 129 bits < 17*8=136 bits, the top 7 bits of the result will be 0. - nonce_129_bytes = nonce_129_int.to_bytes(17, 'big') - - # Pack the fields: shift the 129-bit nonce left by (7+4)=11 bits, then add version (7 bits) and cipher (4 bits). - partial_int = (nonce_129_int << 11) | (version << 4) | (cipher & 0x0F) - # This partial data is 129+7+4 = 140 bits; we pack into 18 bytes (144 bits) with 4 spare bits. - partial_bytes = partial_int.to_bytes(18, 'big') - # Compute CRC over these 18 bytes. - cval = crc32_of(partial_bytes) - # Combine the partial data with the 32-bit CRC. - final_int = (int.from_bytes(partial_bytes, 'big') << 32) | cval # 140 + 32 = 172 bits - final_bytes = final_int.to_bytes(22, 'big') - # Optionally, store or print nonce_129_bytes (the session nonce) rather than the original nonce_full. - return final_bytes - -def parse_ping_request(data: bytes): - """ - Parse a Ping request (22 bytes = 172 bits). - Returns (session_nonce_bytes, version, cipher) or None if invalid. - The session_nonce_bytes will be 17 bytes, representing the 129-bit value. - """ - if len(data) != 22: - return None - final_int = int.from_bytes(data, 'big') # 176 bits integer; lower 32 bits = CRC, higher 140 bits = partial - crc_in = final_int & 0xffffffff - partial_int = final_int >> 32 # 140 bits - partial_bytes = partial_int.to_bytes(18, 'big') - crc_calc = crc32_of(partial_bytes) - if crc_calc != crc_in: - return None - # Extract fields: lowest 4 bits: cipher; next 7 bits: version; remaining 129 bits: session_nonce. - cipher = partial_int & 0x0F - version = (partial_int >> 4) & 0x7F - nonce_129_int = partial_int >> 11 # 140 - 11 = 129 bits - # Convert to 17 bytes. Since the number is < 2^129, the top 7 bits will be zero. - session_nonce_bytes = nonce_129_int.to_bytes(17, 'big') - return (session_nonce_bytes, version, cipher) + def __init__(self, version: int, cipher: int, session_nonce: bytes = None): + if not (0 <= version < 128): + raise ValueError("Version must fit in 7 bits (0..127)") + if not (0 <= cipher < 16): + raise ValueError("Cipher must fit in 4 bits (0..15)") + + self.version = version + self.cipher = cipher + + # Generate session nonce if not provided + if session_nonce is None: + # Generate 17 random bytes + nonce_full = os.urandom(17) + # Use top 129 bits + nonce_int_full = int.from_bytes(nonce_full, 'big') + nonce_129_int = nonce_int_full >> 7 # drop lowest 7 bits + self.session_nonce = nonce_129_int.to_bytes(17, 'big') + else: + if len(session_nonce) != 17: + raise ValueError("Session nonce must be 17 bytes (136 bits)") + self.session_nonce = session_nonce + + def serialize(self) -> bytes: + """Serialize the ping request into a 22-byte packet.""" + # Convert session_nonce to integer (129 bits) + nonce_int = int.from_bytes(self.session_nonce, 'big') + + # Pack fields: shift nonce left by 11 bits, add version and cipher + partial_int = (nonce_int << 11) | (self.version << 4) | (self.cipher & 0x0F) + # This creates 129+7+4 = 140 bits; pack into 18 bytes + partial_bytes = partial_int.to_bytes(18, 'big') + + # Compute CRC over these 18 bytes + cval = crc32_of(partial_bytes) + + # Combine partial data with 32-bit CRC + final_int = (int.from_bytes(partial_bytes, 'big') << 32) | cval + return final_int.to_bytes(22, 'big') + + @classmethod + def deserialize(cls, data: bytes) -> Optional['PingRequest']: + """Deserialize a 22-byte packet into a PingRequest object.""" + if len(data) != 22: + return None + + # Extract 176-bit integer + final_int = int.from_bytes(data, 'big') + + # Extract CRC and verify + crc_in = final_int & 0xffffffff + partial_int = final_int >> 32 # 140 bits + partial_bytes = partial_int.to_bytes(18, 'big') + crc_calc = crc32_of(partial_bytes) + + if crc_calc != crc_in: + return None + + # Extract fields + cipher = partial_int & 0x0F + version = (partial_int >> 4) & 0x7F + nonce_129_int = partial_int >> 11 # 129 bits + session_nonce = nonce_129_int.to_bytes(17, 'big') + + return cls(version, cipher, session_nonce) @@ -96,41 +108,67 @@ def parse_ping_request(data: bytes): # # Total bits: 32 + 7 + 4 + 1 + 32 = 76 bits; pack into 10 bytes (80 bits) with 4 spare bits. # --------------------------------------------------------------------------- -def build_ping_response(version: int, cipher: int, answer: int) -> bytes: - if not (0 <= version < 128): - raise ValueError("Version must fit in 7 bits") - if not (0 <= cipher < 16): - raise ValueError("Cipher must fit in 4 bits") - if answer not in (0, 1): - raise ValueError("Answer must be 0 or 1") - t_ms = int(time.time() * 1000) & 0xffffffff # 32 bits - # Pack timestamp (32 bits), then version (7 bits), cipher (4 bits), answer (1 bit): total 32+7+4+1 = 44 bits. - partial_val = (t_ms << (7+4+1)) | (version << (4+1)) | (cipher << 1) | answer - partial_bytes = partial_val.to_bytes(6, 'big') # 6 bytes = 48 bits, 4 spare bits. - cval = crc32_of(partial_bytes) - final_val = (int.from_bytes(partial_bytes, 'big') << 32) | cval # 44+32 = 76 bits. - final_bytes = final_val.to_bytes(10, 'big') # 10 bytes = 80 bits. - return final_bytes - -def parse_ping_response(data: bytes): - if len(data) != 10: - return None - final_int = int.from_bytes(data, 'big') # 80 bits - crc_in = final_int & 0xffffffff - partial_int = final_int >> 32 # 48 bits - partial_bytes = partial_int.to_bytes(6, 'big') - crc_calc = crc32_of(partial_bytes) - if crc_calc != crc_in: - return None - # Extract fields: partial_int has 48 bits. We only used 44 bits for the fields. - # Discard the lower 4 spare bits. - partial_int >>= 4 # now 44 bits. - # Now fields: timestamp: 32 bits, version: 7 bits, cipher: 4 bits, answer: 1 bit. - answer = partial_int & 0x01 - cipher = (partial_int >> 1) & 0x0F - version = (partial_int >> (1+4)) & 0x7F - timestamp = partial_int >> (1+4+7) - return (timestamp, version, cipher, answer) +class PingResponse: + """ + PING RESPONSE format (76 bits / 10 bytes): + - timestamp: 32 bits (milliseconds since epoch, lower 32 bits) + - version: 7 bits + - cipher: 4 bits + - answer: 1 bit (0 = no, 1 = yes) + - CRC: 32 bits + """ + def __init__(self, version: int, cipher: int, answer: int, timestamp: int = None): + if not (0 <= version < 128): + raise ValueError("Version must fit in 7 bits") + if not (0 <= cipher < 16): + raise ValueError("Cipher must fit in 4 bits") + if answer not in (0, 1): + raise ValueError("Answer must be 0 or 1") + + self.version = version + self.cipher = cipher + self.answer = answer + self.timestamp = timestamp or (int(time.time() * 1000) & 0xffffffff) + + def serialize(self) -> bytes: + """Serialize the ping response into a 10-byte packet.""" + # Pack timestamp, version, cipher, answer: 32+7+4+1 = 44 bits + partial_val = (self.timestamp << (7+4+1)) | (self.version << (4+1)) | (self.cipher << 1) | self.answer + partial_bytes = partial_val.to_bytes(6, 'big') # 6 bytes = 48 bits, 4 spare bits + + # Compute CRC + cval = crc32_of(partial_bytes) + + # Combine with CRC + final_val = (int.from_bytes(partial_bytes, 'big') << 32) | cval + return final_val.to_bytes(10, 'big') + + @classmethod + def deserialize(cls, data: bytes) -> Optional['PingResponse']: + """Deserialize a 10-byte packet into a PingResponse object.""" + if len(data) != 10: + return None + + # Extract 80-bit integer + final_int = int.from_bytes(data, 'big') + + # Extract CRC and verify + crc_in = final_int & 0xffffffff + partial_int = final_int >> 32 # 48 bits + partial_bytes = partial_int.to_bytes(6, 'big') + crc_calc = crc32_of(partial_bytes) + + if crc_calc != crc_in: + return None + + # Extract fields (discard 4 spare bits) + partial_int >>= 4 # now 44 bits + answer = partial_int & 0x01 + cipher = (partial_int >> 1) & 0x0F + version = (partial_int >> (1+4)) & 0x7F + timestamp = partial_int >> (1+4+7) + + return cls(version, cipher, answer, timestamp) # ============================================================================= @@ -143,53 +181,60 @@ def parse_ping_response(data: bytes): # => total 4 + 64 + 64 + 32 + 4 = 168 bytes = 1344 bits # ============================================================================= -def build_handshake_message(timestamp: int, - ephemeral_pubkey: bytes, - ephemeral_signature: bytes, - pfs_hash: bytes) -> bytes: +class Handshake: """ - Build handshake: - - 4 bytes: timestamp - - 64 bytes: ephemeral_pubkey (x||y, raw) - - 64 bytes: ephemeral_signature (r||s, raw) - - 32 bytes: pfs_hash - - 4 bytes: CRC-32 - => 168 bytes total + HANDSHAKE format (1344 bits / 168 bytes): + - timestamp: 32 bits + - ephemeral_pubkey: 512 bits (64 bytes, raw x||y format) + - ephemeral_signature: 512 bits (64 bytes, raw r||s format) + - pfs_hash: 256 bits (32 bytes) + - CRC: 32 bits """ - if len(ephemeral_pubkey) != 64: - raise ValueError("ephemeral_pubkey must be 64 bytes (raw x||y).") - if len(ephemeral_signature) != 64: - raise ValueError("ephemeral_signature must be 64 bytes (raw r||s).") - if len(pfs_hash) != 32: - raise ValueError("pfs_hash must be 32 bytes.") - - partial = struct.pack("!I", timestamp) \ - + ephemeral_pubkey \ - + ephemeral_signature \ - + pfs_hash - cval = crc32_of(partial) - return partial + struct.pack("!I", cval) - - -def parse_handshake_message(data: bytes): - """ - Parse handshake message (168 bytes). - Return (timestamp, ephemeral_pub, ephemeral_sig, pfs_hash) or None if invalid. - """ - if len(data) != 168: - return None - partial = data[:-4] # first 164 bytes - crc_in = struct.unpack("!I", data[-4:])[0] - crc_calc = crc32_of(partial) - if crc_calc != crc_in: - return None - - # Now parse fields - timestamp = struct.unpack("!I", partial[:4])[0] - ephemeral_pub = partial[4:4+64] - ephemeral_sig = partial[68:68+64] - pfs_hash = partial[132:132+32] - return (timestamp, ephemeral_pub, ephemeral_sig, pfs_hash) + def __init__(self, ephemeral_pubkey: bytes, ephemeral_signature: bytes, pfs_hash: bytes, timestamp: int = None): + if len(ephemeral_pubkey) != 64: + raise ValueError("ephemeral_pubkey must be 64 bytes (raw x||y)") + if len(ephemeral_signature) != 64: + raise ValueError("ephemeral_signature must be 64 bytes (raw r||s)") + if len(pfs_hash) != 32: + raise ValueError("pfs_hash must be 32 bytes") + + self.ephemeral_pubkey = ephemeral_pubkey + self.ephemeral_signature = ephemeral_signature + self.pfs_hash = pfs_hash + self.timestamp = timestamp or (int(time.time() * 1000) & 0xffffffff) + + def serialize(self) -> bytes: + """Serialize the handshake into a 168-byte packet.""" + # Pack timestamp and other fields + partial = struct.pack("!I", self.timestamp) + self.ephemeral_pubkey + self.ephemeral_signature + self.pfs_hash + + # Compute CRC + cval = crc32_of(partial) + + # Append CRC + return partial + struct.pack("!I", cval) + + @classmethod + def deserialize(cls, data: bytes) -> Optional['Handshake']: + """Deserialize a 168-byte packet into a Handshake object.""" + if len(data) != 168: + return None + + # Extract and verify CRC + partial = data[:-4] + crc_in = struct.unpack("!I", data[-4:])[0] + crc_calc = crc32_of(partial) + + if crc_calc != crc_in: + return None + + # Extract fields + timestamp = struct.unpack("!I", partial[:4])[0] + ephemeral_pubkey = partial[4:4+64] + ephemeral_signature = partial[68:68+64] + pfs_hash = partial[132:132+32] + + return cls(ephemeral_pubkey, ephemeral_signature, pfs_hash, timestamp) # ============================================================================= @@ -200,15 +245,18 @@ def parse_handshake_message(data: bytes): def compute_pfs_hash(session_number: int, shared_secret_hex: str) -> bytes: """ - Return 32 bytes (256 bits) for the PFS field. - If session_number < 0 => means no previous session => 32 zero bytes. - Otherwise => sha256( session_number (4 bytes) || shared_secret ). + Compute the PFS hash field for handshake messages: + - If no previous session (session_number < 0), return 32 zero bytes + - Otherwise, compute sha256(session_number || shared_secret) """ if session_number < 0: return b"\x00" * 32 # Convert shared_secret_hex to raw bytes secret_bytes = bytes.fromhex(shared_secret_hex) + # Pack session_number as 4 bytes sn_bytes = struct.pack("!I", session_number) + + # Compute hash return hashlib.sha256(sn_bytes + secret_bytes).digest() diff --git a/protocol_prototype/protocol.py b/protocol_prototype/protocol.py index f8087b1..6007ce5 100644 --- a/protocol_prototype/protocol.py +++ b/protocol_prototype/protocol.py @@ -2,7 +2,7 @@ import random import os import time import threading -from typing import List, Dict, Any +from typing import List, Dict, Any, Optional, Tuple from crypto_utils import ( generate_identity_keys, @@ -10,17 +10,19 @@ from crypto_utils import ( sign_data, verify_signature, get_ephemeral_keypair, - compute_ecdh_shared_key + compute_ecdh_shared_key, + der_to_raw, + raw_signature_to_der ) from messages import ( - build_ping_request, parse_ping_request, - build_ping_response, parse_ping_response, - build_handshake_message, parse_handshake_message, + PingRequest, PingResponse, Handshake, compute_pfs_hash ) import transmission - -from encryption import encrypt_message, decrypt_message, MessageHeader, generate_iv +from encryption import ( + EncryptedMessage, MessageHeader, + generate_iv, encrypt_message, decrypt_message +) # ANSI colors RED = "\033[91m" @@ -48,9 +50,12 @@ class IcingProtocol: # Derived HKDF key (hex string, 256 bits) self.hkdf_key = None + + # Negotiated cipher (0 = AES-256-GCM, 1 = ChaCha20-Poly1305) + self.cipher_type = 0 # For PFS: track per-peer session info (session number and last shared secret) - self.pfs_history: Dict[bytes, (int, str)] = {} + self.pfs_history: Dict[bytes, Tuple[int, str]] = {} # Protocol flags self.state = { @@ -69,8 +74,11 @@ class IcingProtocol: # Inbound messages (each message is a dict with keys: type, raw, parsed, connection) self.inbound_messages: List[Dict[str, Any]] = [] - # Store the session nonce (32 bytes) from our first sent or received PING + # Store the session nonce (17 bytes but only 129 bits are valid) from first sent or received PING self.session_nonce: bytes = None + + # Last used IV for encrypted messages + self.last_iv: bytes = None self.local_port = random.randint(30000, 40000) self.server_listener = transmission.ServerListener( @@ -94,49 +102,53 @@ class IcingProtocol: print( f"{GREEN}[RECV]{RESET} {bits_count} bits from peer: {data.hex()[:60]}{'...' if len(data.hex()) > 60 else ''}") - # New-format PING REQUEST (22 bytes) + # PING REQUEST (22 bytes) if len(data) == 22: - parsed = parse_ping_request(data) - if parsed: - nonce_full, version, cipher = parsed + ping_request = PingRequest.deserialize(data) + if ping_request: self.state["ping_received"] = True - # If the received cipher field is not 0, we force it to 0 in our response. - if cipher != 0: - print( - f"{YELLOW}[NOTICE]{RESET} Received PING with unsupported cipher ({cipher}); forcing cipher to 0 in response.") - cipher = 0 - # Store session nonce if not already set: + + # If received cipher is not supported, force to 0 (AES-256-GCM) + if ping_request.cipher != 0 and ping_request.cipher != 1: + print(f"{YELLOW}[NOTICE]{RESET} Received PING with unsupported cipher ({ping_request.cipher}); forcing cipher to 0 in response.") + ping_request.cipher = 0 + + # Store cipher type for future encrypted messages + self.cipher_type = ping_request.cipher + + # Store session nonce if not already set if self.session_nonce is None: - # Here, we already generated 17 bytes (136 bits) and only the top 129 bits are valid. - nonce_int = int.from_bytes(nonce_full, 'big') >> 7 - self.session_nonce = nonce_int.to_bytes(17, 'big') + self.session_nonce = ping_request.session_nonce print(f"{YELLOW}[NOTICE]{RESET} Stored session nonce from received PING.") + index = len(self.inbound_messages) msg = { "type": "PING_REQUEST", "raw": data, - "parsed": {"nonce_full": nonce_full, "version": version, "cipher": cipher}, + "parsed": ping_request, "connection": conn } self.inbound_messages.append(msg) - # (Optional auto-responder code could go here) + + # Auto-respond if enabled + if self.auto_responder: + timer = threading.Timer(2.0, self._auto_respond_ping, args=[index]) + timer.daemon = True + timer.start() return - # New-format PING RESPONSE (10 bytes) + # PING RESPONSE (10 bytes) elif len(data) == 10: - parsed = parse_ping_response(data) - if parsed: - timestamp, version, cipher, answer = parsed - # If cipher is not 0 (AES-256-GCM), override it. - if cipher != 0: - print( - f"{YELLOW}[NOTICE]{RESET} Received PING RESPONSE with unsupported cipher; treating as AES (cipher=0).") - cipher = 0 + ping_response = PingResponse.deserialize(data) + if ping_response: + # Store negotiated cipher type + self.cipher_type = ping_response.cipher + index = len(self.inbound_messages) msg = { "type": "PING_RESPONSE", "raw": data, - "parsed": {"timestamp": timestamp, "version": version, "cipher": cipher, "answer": answer}, + "parsed": ping_response, "connection": conn } self.inbound_messages.append(msg) @@ -144,42 +156,51 @@ class IcingProtocol: # HANDSHAKE message (168 bytes) elif len(data) == 168: - parsed = parse_handshake_message(data) - if parsed: - timestamp, ephemeral_pub, ephemeral_sig, pfs_hash = parsed + handshake = Handshake.deserialize(data) + if handshake: self.state["handshake_received"] = True index = len(self.inbound_messages) msg = { "type": "HANDSHAKE", "raw": data, - "parsed": (timestamp, ephemeral_pub, ephemeral_sig, pfs_hash), + "parsed": handshake, "connection": conn } self.inbound_messages.append(msg) - # (Optional auto-responder for handshake could go here) + + # Auto-respond if enabled + if self.auto_responder: + timer = threading.Timer(2.0, self._auto_respond_handshake, args=[index]) + timer.daemon = True + timer.start() return # Check if the message might be an encrypted message (e.g. header of 18 bytes at start) - elif len(data) > 18: - # Try to parse header from encryption module + elif len(data) >= 18: + # Try to parse header try: - from encryption import MessageHeader header = MessageHeader.unpack(data[:18]) - # If header unpacking is successful and data length fits header.data_len + header size + tag size: - expected_len = 18 + header.data_len + 16 # tag is 16 bytes - if len(data) == expected_len: + # If header unpacking is successful and data length matches header expectations + expected_len = 18 + header.data_len + 16 # Header + payload + tag + + # Check if CRC is included + has_crc = (header.connection_status & 0x01) != 0 + if has_crc: + expected_len += 4 # Add CRC32 length + + if len(data) >= expected_len: index = len(self.inbound_messages) msg = { "type": "ENCRYPTED_MESSAGE", "raw": data, - "parsed": header, # we can store header for further processing + "parsed": header, "connection": conn } self.inbound_messages.append(msg) print(f"{YELLOW}[NOTICE]{RESET} Stored inbound ENCRYPTED_MESSAGE at index={index}.") return - except Exception: - pass + except Exception as e: + print(f"{RED}[ERROR]{RESET} Failed to parse message header: {e}") # Otherwise, unrecognized/malformed message. index = len(self.inbound_messages) @@ -193,16 +214,16 @@ class IcingProtocol: print(f"{RED}[WARNING]{RESET} Unrecognized or malformed message stored at index={index}.") - # ------------------------------------------------------------------------- - # HKDF Derivation - # ------------------------------------------------------------------------- + # ------------------------------------------------------------------------- + # HKDF Derivation + # ------------------------------------------------------------------------- def derive_hkdf(self): """ Derives a 256-bit key using HKDF. Uses as input keying material (IKM) the shared secret from ECDH. The salt is computed as SHA256(session_nonce || pfs_param), where: - - session_nonce is taken from self.session_nonce (32 bytes) or defaults to zeros. + - session_nonce is taken from self.session_nonce (17 bytes, 129 bits) or defaults to zeros. - pfs_param is taken from the first inbound HANDSHAKE's pfs_hash field (32 bytes) or zeros. """ if not self.shared_secret: @@ -212,15 +233,15 @@ class IcingProtocol: # IKM: shared secret converted from hex to bytes. ikm = bytes.fromhex(self.shared_secret) # Use stored session_nonce if available; otherwise default to zeros. - session_nonce = self.session_nonce if self.session_nonce is not None else (b"\x00" * 32) + session_nonce = self.session_nonce if self.session_nonce is not None else (b"\x00" * 17) # Determine pfs_param from first HANDSHAKE message (if any) pfs_param = None for msg in self.inbound_messages: if msg["type"] == "HANDSHAKE": - # Expect parsed handshake as tuple: (timestamp, ephemeral_pub, ephemeral_sig, pfs_hash) try: - _, _, _, pfs_param = msg["parsed"] + handshake = msg["parsed"] + pfs_param = handshake.pfs_hash except Exception: pfs_param = None break @@ -259,7 +280,7 @@ class IcingProtocol: Called by a Timer to respond automatically to a PING_REQUEST after 2s. """ print(f"{BLUE}[AUTO]{RESET} Delayed responding to PING at index={index}") - self.respond_to_ping(index, answer_code=0) + self.respond_to_ping(index, answer=1) # Accept by default self.show_state() def _auto_respond_handshake(self, index: int): @@ -304,30 +325,33 @@ class IcingProtocol: self.ephemeral_privkey, self.ephemeral_pubkey = get_ephemeral_keypair() print(f"{GREEN}[IcingProtocol]{RESET} Generated ephemeral key pair: pubkey={self.ephemeral_pubkey.hex()[:16]}...") - # Updated send_ping_request: generate a 17-byte nonce, extract top 129 bits, and store that (as bytes) + # Send PING (session discovery and cipher negotiation) def send_ping_request(self): if not self.connections: print(f"{RED}[ERROR]{RESET} No active connections.") return - nonce_full = os.urandom(17) - # Compute the 129-bit session nonce: take the top 129 bits. - nonce_int = int.from_bytes(nonce_full, 'big') >> 7 # 129 bits - session_nonce_bytes = nonce_int.to_bytes(17, 'big') # still 17 bytes but only 129 bits are meaningful + + # Create ping request with our default cipher preference (AES-256-GCM = 0) + ping_request = PingRequest(version=0, cipher=0) + + # Store session nonce if not already set if self.session_nonce is None: - self.session_nonce = session_nonce_bytes + self.session_nonce = ping_request.session_nonce print(f"{YELLOW}[NOTICE]{RESET} Stored session nonce from sent PING.") - pkt = build_ping_request(version=0, cipher=0, nonce_full=nonce_full) + + # Serialize and send + pkt = ping_request.serialize() self._send_packet(self.connections[0], pkt, "PING_REQUEST") self.state["ping_sent"] = True def send_handshake(self): """ Build and send handshake: - - 32-bit timestamp - ephemeral_pubkey (64 bytes, raw x||y) - ephemeral_signature (64 bytes, raw r||s) - pfs_hash (32 bytes) - - 32-bit CRC + - timestamp (32 bits) + - CRC (32 bits) """ if not self.connections: print(f"{RED}[ERROR]{RESET} No active connections.") @@ -339,31 +363,24 @@ class IcingProtocol: print(f"{RED}[ERROR]{RESET} Peer identity not set; needed for PFS tracking.") return - # 1) Sign ephemeral_pubkey as r||s - # Instead of DER, we do raw r||s each 32 bytes + # 1) Sign ephemeral_pubkey using identity key sig_der = sign_data(self.identity_privkey, self.ephemeral_pubkey) - # Convert DER -> (r, s) -> raw 64 bytes - # Quick approach to parse DER using cryptography, or do a custom parse - from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature - r_int, s_int = decode_dss_signature(sig_der) - r_bytes = r_int.to_bytes(32, 'big') - s_bytes = s_int.to_bytes(32, 'big') - raw_signature = r_bytes + s_bytes # 64 bytes + # Convert DER signature to raw r||s format (64 bytes) + raw_signature = der_to_raw(sig_der) - # 2) PFS hash + # 2) Compute PFS hash session_number, last_secret_hex = self.pfs_history.get(self.peer_identity_pubkey_bytes, (-1, "")) pfs = compute_pfs_hash(session_number, last_secret_hex) - # 3) Build handshake - timestamp_32 = int(time.time() * 1000) & 0xffffffff - pkt = build_handshake_message( - timestamp_32, - self.ephemeral_pubkey, # 64 bytes raw - raw_signature, # 64 bytes raw - pfs # 32 bytes + # 3) Create handshake object + handshake = Handshake( + ephemeral_pubkey=self.ephemeral_pubkey, + ephemeral_signature=raw_signature, + pfs_hash=pfs ) - # 4) Send + # 4) Serialize and send + pkt = handshake.serialize() self._send_packet(self.connections[0], pkt, "HANDSHAKE") self.state["handshake_sent"] = True @@ -376,6 +393,10 @@ class IcingProtocol: # ------------------------------------------------------------------------- def respond_to_ping(self, index: int, answer: int): + """ + Respond to a ping request with the specified answer (0 = no, 1 = yes). + If answer is 1, we accept the connection and use the cipher specified in the request. + """ if index < 0 or index >= len(self.inbound_messages): print(f"{RED}[ERROR]{RESET} Invalid index {index}.") return @@ -384,23 +405,32 @@ class IcingProtocol: print(f"{RED}[ERROR]{RESET} inbound_messages[{index}] is not a PING_REQUEST.") return - version = msg["parsed"]["version"] - cipher = msg["parsed"]["cipher"] - # Force cipher to 0 if it's not 0 (only AES-256-GCM is supported) - if cipher != 0: - print( - f"{YELLOW}[NOTICE]{RESET} Received PING with unsupported cipher ({cipher}); forcing cipher to 0 in response.") + ping_request = msg["parsed"] + version = ping_request.version + cipher = ping_request.cipher + + # Force cipher to 0 or 1 (only AES-256-GCM and ChaCha20-Poly1305 are supported) + if cipher != 0 and cipher != 1: + print(f"{YELLOW}[NOTICE]{RESET} Received PING with unsupported cipher ({cipher}); forcing cipher to 0 in response.") cipher = 0 + + # Store the negotiated cipher type if we're accepting + if answer == 1: + self.cipher_type = cipher conn = msg["connection"] - resp = build_ping_response(version, cipher, answer) + # Create ping response + ping_response = PingResponse(version, cipher, answer) + resp = ping_response.serialize() self._send_packet(conn, resp, "PING_RESPONSE") print(f"{BLUE}[MANUAL]{RESET} Responded to ping with answer={answer}.") def generate_ecdhe(self, index: int): """ - Formerly 'respond_to_handshake'. Verifies the inbound ephemeral signature - and computes the ECDH shared secret, updating PFS history. + Process a handshake message: + 1. Verify the ephemeral signature + 2. Compute the ECDH shared secret + 3. Update PFS history """ if index < 0 or index >= len(self.inbound_messages): print(f"{RED}[ERROR]{RESET} Invalid index {index}.") @@ -410,30 +440,30 @@ class IcingProtocol: print(f"{RED}[ERROR]{RESET} inbound_messages[{index}] is not a HANDSHAKE.") return - # Unpack the tuple directly: - timestamp, ephemeral_pub, ephemeral_sig, pfs_hash = msg["parsed"] + handshake = msg["parsed"] + + # Convert raw signature to DER for verification + raw_sig = handshake.ephemeral_signature + sig_der = raw_signature_to_der(raw_sig) - # Use our raw_signature_to_der wrapper only if signature is 64 bytes. - # Otherwise, assume the signature is already DER-encoded. - from crypto_utils import raw_signature_to_der - if len(ephemeral_sig) == 64: - sig_der = raw_signature_to_der(ephemeral_sig) - else: - sig_der = ephemeral_sig - - ok = verify_signature(self.peer_identity_pubkey_obj, sig_der, ephemeral_pub) + # Verify signature + ok = verify_signature(self.peer_identity_pubkey_obj, sig_der, handshake.ephemeral_pubkey) if not ok: print(f"{RED}[ERROR]{RESET} Ephemeral signature invalid.") return print(f"{GREEN}[OK]{RESET} Ephemeral signature verified.") + # Check if we have ephemeral keys if not self.ephemeral_privkey: print(f"{YELLOW}[WARN]{RESET} No ephemeral_privkey available, cannot compute shared secret.") return - shared = compute_ecdh_shared_key(self.ephemeral_privkey, ephemeral_pub) + + # Compute ECDH shared secret + shared = compute_ecdh_shared_key(self.ephemeral_privkey, handshake.ephemeral_pubkey) self.shared_secret = shared.hex() print(f"{GREEN}[OK]{RESET} Computed ECDH shared key = {self.shared_secret}") + # Update PFS history old_session, _ = self.pfs_history.get(self.peer_identity_pubkey_bytes, (-1, "")) new_session = 1 if old_session < 0 else old_session + 1 self.pfs_history[self.peer_identity_pubkey_bytes] = (new_session, self.shared_secret) @@ -469,11 +499,19 @@ class IcingProtocol: print(f"HKDF Derived Key: {self.hkdf_key} (size: {len(self.hkdf_key)*8} bits)") else: print("HKDF Derived Key: [None]") + + print(f"Negotiated Cipher: {'AES-256-GCM' if self.cipher_type == 0 else 'ChaCha20-Poly1305'} (code: {self.cipher_type})") + if self.session_nonce: - # session_nonce now contains 17 bytes, but only the top 129 bits are used. print(f"Session Nonce: {self.session_nonce.hex()} (129 bits)") else: print("Session Nonce: [None]") + + if self.last_iv: + print(f"Last IV: {self.last_iv.hex()} (96 bits)") + else: + print("Last IV: [None]") + print("\nProtocol Flags:") for k, v in self.state.items(): print(f" {k}: {v}") @@ -500,9 +538,11 @@ class IcingProtocol: # New method: Send an encrypted message over the first active connection. def send_encrypted_message(self, plaintext: str): """ - Encrypts the provided plaintext (a UTF-8 string) using the derived HKDF key (AES-256), - and sends the encrypted message over the first active connection. - The message format is: header (18 bytes) || ciphertext || tag (16 bytes). + Encrypts and sends a message using the derived HKDF key and negotiated cipher. + The message format is: + - Header (18 bytes): flag, data_len, retry, connection_status, IV + - Payload: variable length encrypted data + - Footer: Authentication tag (16 bytes) + optional CRC32 (4 bytes) """ if not self.connections: print(f"{RED}[ERROR]{RESET} No active connections.") @@ -510,30 +550,76 @@ class IcingProtocol: if not self.hkdf_key: print(f"{RED}[ERROR]{RESET} No HKDF key derived. Cannot encrypt message.") return + + # Get the encryption key key = bytes.fromhex(self.hkdf_key) + + # Convert plaintext to bytes plaintext_bytes = plaintext.encode('utf-8') - encrypted = encrypt_message(plaintext_bytes, key) - # Send the encrypted message over the first connection. + + # Generate or increment the IV + if self.last_iv is None: + # First message, generate random IV + iv = generate_iv(initial=True) + else: + # Subsequent message, increment previous IV + iv = generate_iv(initial=False, previous_iv=self.last_iv) + + # Store the new IV + self.last_iv = iv + + # Create encrypted message (connection_status 0 = no CRC) + encrypted = encrypt_message( + plaintext=plaintext_bytes, + key=key, + flag=0xBEEF, # Default flag + retry=0, + connection_status=0, # No CRC + iv=iv, + cipher_type=self.cipher_type + ) + + # Send the encrypted message self._send_packet(self.connections[0], encrypted, "ENCRYPTED_MESSAGE") print(f"{GREEN}[SEND_ENCRYPTED]{RESET} Encrypted message sent.") - # New method: Decrypt an encrypted message provided as a hex string. -# New command: decrypt a received encrypted message from the inbound queue. + # New method: Decrypt an encrypted message from the inbound queue. def decrypt_received_message(self, index: int): + """ + Decrypt a received encrypted message using the HKDF key and negotiated cipher. + """ if index < 0 or index >= len(self.inbound_messages): print(f"{RED}[ERROR]{RESET} Invalid message index.") return + msg = self.inbound_messages[index] - # Expect the message to be an encrypted transmission (we assume it is in the proper format). + if msg["type"] != "ENCRYPTED_MESSAGE": + print(f"{RED}[ERROR]{RESET} Message at index {index} is not an ENCRYPTED_MESSAGE.") + return + + # Get the encrypted message encrypted = msg["raw"] + if not self.hkdf_key: print(f"{RED}[ERROR]{RESET} No HKDF key derived. Cannot decrypt message.") return + + # Get the encryption key key = bytes.fromhex(self.hkdf_key) + try: - plaintext_bytes = decrypt_message(encrypted, key) - plaintext = plaintext_bytes.decode('utf-8') - print(f"{GREEN}[DECRYPTED]{RESET} Decrypted message: {plaintext}") - return plaintext + # Decrypt the message + plaintext = decrypt_message(encrypted, key, self.cipher_type) + + # Convert to string + plaintext_str = plaintext.decode('utf-8') + + # Update last IV from the header + header = MessageHeader.unpack(encrypted[:18]) + self.last_iv = header.iv + + print(f"{GREEN}[DECRYPTED]{RESET} Decrypted message: {plaintext_str}") + return plaintext_str except Exception as e: print(f"{RED}[ERROR]{RESET} Decryption failed: {e}") + return None