diff --git a/protocol_prototype/IcingProtocol.drawio b/protocol_prototype/IcingProtocol.drawio index 683237e..8f46988 100644 --- a/protocol_prototype/IcingProtocol.drawio +++ b/protocol_prototype/IcingProtocol.drawio @@ -260,7 +260,7 @@ - + @@ -403,7 +403,7 @@ - + @@ -456,53 +456,55 @@ - - + + - + - - + + - - + + - + - - + + - - + + + + + + + + + + - - - - - - - - + + - + - + - - + + - + - - + + @@ -522,6 +524,42 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/protocol_prototype/cli.py b/protocol_prototype/cli.py index 48894ee..d9aeeae 100644 --- a/protocol_prototype/cli.py +++ b/protocol_prototype/cli.py @@ -14,7 +14,6 @@ def main(): print(f"{YELLOW}\n======================================") print(" Icing Protocol - Manual CLI Demo ") print("======================================\n" + RESET) - print(f"Listening on port: {protocol.local_port}") print(f"Your identity public key (hex): {protocol.identity_pubkey.hex()}") print("\nAvailable commands:") @@ -22,91 +21,90 @@ def main(): print(" connect ") print(" generate_ephemeral_keys") print(" send_ping") - print(" respond_ping <0|1>") print(" send_handshake") + print(" respond_ping <0|1>") print(" generate_ecdhe ") print(" derive_hkdf") + print(" send_encrypted ") + print(" decrypt_message <hex_message>") print(" auto_responder <on|off>") print(" show_state") print(" exit\n") while True: - while True: - try: - line = input("Cmd> ").strip() - except EOFError: - break - if not line: + try: + line = input("Cmd> ").strip() + except EOFError: + break + if not line: + continue + parts = line.split() + cmd = parts[0].lower() + if cmd == "exit": + protocol.stop() + break + elif cmd == "show_state": + protocol.show_state() + elif cmd == "set_peer_identity": + if len(parts) != 2: + print("Usage: set_peer_identity <hex_pubkey>") continue - parts = line.split() - cmd = parts[0].lower() - - if cmd == "exit": - protocol.stop() - sys.exit(0) - - elif cmd == "show_state": - protocol.show_state() - - elif cmd == "set_peer_identity": - if len(parts) != 2: - print("Usage: set_peer_identity <hex_pubkey>") - continue - protocol.set_peer_identity(parts[1]) - - elif cmd == "connect": - if len(parts) != 2: - print("Usage: connect <port>") - continue - try: - port = int(parts[1]) - protocol.connect_to_peer(port) - except ValueError: - print("Invalid port.") - - elif cmd == "generate_ephemeral_keys": - protocol.generate_ephemeral_keys() - - elif cmd == "send_ping": - protocol.send_ping_request() - - elif cmd == "send_handshake": - protocol.send_handshake() - - elif cmd == "respond_ping": - if len(parts) != 3: - print("Usage: respond_ping <index> <0|1>") - continue - try: - idx = int(parts[1]) - ac = int(parts[2]) - protocol.respond_to_ping(idx, ac) - except ValueError: - print("Index and answer must be integers.") - - elif cmd == "generate_ecdhe": - if len(parts) != 2: - print("Usage: generate_ecdhe <index>") - continue - try: - idx = int(parts[1]) - protocol.generate_ecdhe(idx) - except ValueError: - print("Index must be an integer.") - - elif cmd == "derive_hkdf": - protocol.derive_hkdf() - - elif cmd == "auto_responder": - if len(parts) != 2: - print("Usage: auto_responder <on|off>") - continue - arg = parts[1].lower() - protocol.enable_auto_responder(arg == "on") - - else: - print(f"{RED}[ERROR]{RESET} Unknown command: {cmd}") - + protocol.set_peer_identity(parts[1]) + elif cmd == "connect": + if len(parts) != 2: + print("Usage: connect <port>") + continue + try: + port = int(parts[1]) + protocol.connect_to_peer(port) + except ValueError: + print("Invalid port.") + elif cmd == "generate_ephemeral_keys": + protocol.generate_ephemeral_keys() + elif cmd == "send_ping": + protocol.send_ping_request() + elif cmd == "send_handshake": + protocol.send_handshake() + elif cmd == "respond_ping": + if len(parts) != 3: + print("Usage: respond_ping <index> <0|1>") + continue + try: + idx = int(parts[1]) + ac = int(parts[2]) + protocol.respond_to_ping(idx, ac) + except ValueError: + print("Index and answer must be integers.") + elif cmd == "generate_ecdhe": + if len(parts) != 2: + print("Usage: generate_ecdhe <index>") + continue + try: + idx = int(parts[1]) + protocol.generate_ecdhe(idx) + except ValueError: + print("Index must be an integer.") + elif cmd == "derive_hkdf": + protocol.derive_hkdf() + elif cmd == "send_encrypted": + if len(parts) < 2: + print("Usage: send_encrypted <plaintext>") + continue + # Join the rest of the line as plaintext + plaintext = " ".join(parts[1:]) + protocol.send_encrypted_message(plaintext) + elif cmd == "decrypt_message": + if len(parts) != 2: + print("Usage: decrypt_message <hex_message>") + continue + protocol.decrypt_encrypted_message(parts[1]) + elif cmd == "auto_responder": + if len(parts) != 2: + print("Usage: auto_responder <on|off>") + continue + protocol.enable_auto_responder(parts[1].lower() == "on") + else: + print(f"{RED}[ERROR]{RESET} Unknown command: {cmd}") if __name__ == "__main__": main() diff --git a/protocol_prototype/encryption.py b/protocol_prototype/encryption.py new file mode 100644 index 0000000..ebd922f --- /dev/null +++ b/protocol_prototype/encryption.py @@ -0,0 +1,94 @@ +import os +import struct +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +class MessageHeader: + """ + Represents the header of an encrypted message. + - flag (16 bits) + - data_len (16 bits): length in bytes of the encrypted payload (excluding tag) + - Associated Data (AD): + * retry (8 bits) + * connexion_status (4 bits) + 4 bits padding (packed in one byte) + * iv/messageID (96 bits / 12 bytes) + Total header size: 2 + 2 + 1 + 1 + 12 = 18 bytes. + """ + def __init__(self, flag: int, data_len: int, retry: int, connexion_status: int, iv: bytes): + self.flag = flag # 16 bits + self.data_len = data_len # 16 bits + self.retry = retry # 8 bits + self.connexion_status = connexion_status # 4 bits + self.iv = iv # 96 bits (12 bytes) + + def pack(self) -> bytes: + # Pack flag and data_len as unsigned shorts (2 bytes each) + header = struct.pack('>H H', self.flag, self.data_len) + # Pack retry (1 byte) and connexion_status (4 bits in high nibble, 4 bits padding as zero) + ad_byte = (self.connexion_status & 0x0F) << 4 + ad_packed = struct.pack('>B B', self.retry, ad_byte) + # Append IV (12 bytes) + return header + ad_packed + self.iv + + @classmethod + def unpack(cls, data: bytes) -> 'MessageHeader': + # Expect exactly 18 bytes + flag, data_len = struct.unpack('>H H', data[:4]) + retry, ad_byte = struct.unpack('>B B', data[4:6]) + connexion_status = (ad_byte >> 4) & 0x0F + iv = data[6:18] + return cls(flag, data_len, retry, connexion_status, iv) + +def generate_iv(initial: bool, previous_iv: bytes = None) -> bytes: + """ + Generate a 96-bit IV (12 bytes). + - If 'initial' is True, return a random IV. + - Otherwise, increment the previous IV by 1 modulo 2^96. + """ + if initial or previous_iv is None: + return os.urandom(12) + else: + iv_int = int.from_bytes(previous_iv, 'big') + iv_int = (iv_int + 1) % (1 << 96) + return iv_int.to_bytes(12, 'big') + +def encrypt_message(plaintext: bytes, key: bytes, flag: int = 0xBEEF, retry: int = 0, connexion_status: int = 0) -> bytes: + """ + Encrypts a plaintext using AES-256-GCM. + - Generates a random 96-bit IV. + - Encrypts the plaintext with AESGCM. + - Builds a MessageHeader with the provided flag, the data_len (length of ciphertext excluding tag), + retry, connexion_status, and the IV. + - Returns the full encrypted message: header (18 bytes) || ciphertext || tag (16 bytes). + """ + aesgcm = AESGCM(key) + iv = generate_iv(initial=True) + # Encrypt with no associated data (you may later use the header as AD if needed) + ciphertext_with_tag = aesgcm.encrypt(iv, plaintext, None) + tag_length = 16 # default tag size + ciphertext = ciphertext_with_tag[:-tag_length] + tag = ciphertext_with_tag[-tag_length:] + data_len = len(ciphertext) + header = MessageHeader(flag=flag, data_len=data_len, retry=retry, connexion_status=connexion_status, iv=iv) + packed_header = header.pack() + return packed_header + ciphertext + tag + +def decrypt_message(message: bytes, key: bytes) -> bytes: + """ + Decrypts a message that was encrypted with encrypt_message. + Expects message format: header (18 bytes) || ciphertext || tag (16 bytes). + Returns the decrypted plaintext. + """ + if len(message) < 18 + 16: + raise ValueError("Message too short.") + header_bytes = message[:18] + header = MessageHeader.unpack(header_bytes) + data_len = header.data_len + expected_len = 18 + data_len + 16 + if len(message) != expected_len: + raise ValueError("Message length does not match header's data_len.") + ciphertext = message[18:18+data_len] + tag = message[18+data_len:] + ciphertext_with_tag = ciphertext + tag + aesgcm = AESGCM(key) + plaintext = aesgcm.decrypt(header.iv, ciphertext_with_tag, None) + return plaintext diff --git a/protocol_prototype/protocol.py b/protocol_prototype/protocol.py index 74c71db..b1c2521 100644 --- a/protocol_prototype/protocol.py +++ b/protocol_prototype/protocol.py @@ -20,8 +20,7 @@ from messages import ( ) import transmission -from cryptography.hazmat.primitives.kdf.hkdf import HKDF -from cryptography.hazmat.primitives import hashes +from encryption import encrypt_message, decrypt_message # ANSI colors RED = "\033[91m" @@ -468,3 +467,46 @@ class IcingProtocol: self.connections.clear() self.inbound_messages.clear() print(f"{RED}[STOP]{RESET} Protocol stopped.") + + # New method: Send an encrypted message over the first active connection. + def send_encrypted_message(self, plaintext: str): + """ + Encrypts the provided plaintext (a UTF-8 string) using the derived HKDF key (AES-256), + and sends the encrypted message over the first active connection. + The message format is: header (18 bytes) || ciphertext || tag (16 bytes). + """ + if not self.connections: + print(f"{RED}[ERROR]{RESET} No active connections.") + return + if not self.hkdf_key: + print(f"{RED}[ERROR]{RESET} No HKDF key derived. Cannot encrypt message.") + return + key = bytes.fromhex(self.hkdf_key) + plaintext_bytes = plaintext.encode('utf-8') + encrypted = encrypt_message(plaintext_bytes, key) + # Send the encrypted message over the first connection. + self._send_packet(self.connections[0], encrypted, "ENCRYPTED_MESSAGE") + print(f"{GREEN}[SEND_ENCRYPTED]{RESET} Encrypted message sent.") + + # New method: Decrypt an encrypted message provided as a hex string. + def decrypt_encrypted_message(self, hex_message: str): + """ + Decrypts an encrypted message (given as a hex string) using the HKDF key. + Returns the plaintext (UTF-8 string) and prints it. + """ + if not self.hkdf_key: + print(f"{RED}[ERROR]{RESET} No HKDF key derived. Cannot decrypt message.") + return + try: + message_bytes = bytes.fromhex(hex_message) + except Exception as e: + print(f"{RED}[ERROR]{RESET} Invalid hex input.") + return + key = bytes.fromhex(self.hkdf_key) + try: + plaintext_bytes = decrypt_message(message_bytes, key) + plaintext = plaintext_bytes.decode('utf-8') + print(f"{GREEN}[DECRYPTED]{RESET} Decrypted message: {plaintext}") + return plaintext + except Exception as e: + print(f"{RED}[ERROR]{RESET} Decryption failed: {e}")