diff --git a/protocol_prototype/cli.py b/protocol_prototype/cli.py index d9aeeae..c65bdd2 100644 --- a/protocol_prototype/cli.py +++ b/protocol_prototype/cli.py @@ -17,7 +17,7 @@ def main(): print(f"Listening on port: {protocol.local_port}") print(f"Your identity public key (hex): {protocol.identity_pubkey.hex()}") print("\nAvailable commands:") - print(" set_peer_identity ") + print(" peer_id ") print(" connect ") print(" generate_ephemeral_keys") print(" send_ping") @@ -26,7 +26,7 @@ def main(): print(" generate_ecdhe ") print(" derive_hkdf") print(" send_encrypted ") - print(" decrypt_message <hex_message>") + print(" decrypt_received <index>") print(" auto_responder <on|off>") print(" show_state") print(" exit\n") @@ -45,9 +45,9 @@ def main(): break elif cmd == "show_state": protocol.show_state() - elif cmd == "set_peer_identity": + elif cmd == "peer_id": if len(parts) != 2: - print("Usage: set_peer_identity <hex_pubkey>") + print("Usage: peer_id <hex_pubkey>") continue protocol.set_peer_identity(parts[1]) elif cmd == "connect": @@ -90,14 +90,17 @@ def main(): if len(parts) < 2: print("Usage: send_encrypted <plaintext>") continue - # Join the rest of the line as plaintext plaintext = " ".join(parts[1:]) protocol.send_encrypted_message(plaintext) - elif cmd == "decrypt_message": + elif cmd == "decrypt_received": if len(parts) != 2: - print("Usage: decrypt_message <hex_message>") + print("Usage: decrypt_received <index>") continue - protocol.decrypt_encrypted_message(parts[1]) + try: + idx = int(parts[1]) + protocol.decrypt_received_message(idx) + except ValueError: + print("Index must be an integer.") elif cmd == "auto_responder": if len(parts) != 2: print("Usage: auto_responder <on|off>") diff --git a/protocol_prototype/messages.py b/protocol_prototype/messages.py index 55f5f5c..b835003 100644 --- a/protocol_prototype/messages.py +++ b/protocol_prototype/messages.py @@ -11,145 +11,126 @@ def crc32_of(data: bytes) -> int: return zlib.crc32(data) & 0xffffffff -# ============================================================================= -# 1) Ping Request (295 bits) -# - 256-bit nonce -# - 7-bit version -# - 32-bit CRC -# = 295 bits total -# In practice, we store 37 bytes (296 bits); 1 bit is unused. -# ============================================================================= - -def build_ping_request(version: int, nonce: bytes = None) -> bytes: +# --------------------------------------------------------------------------- +# PING REQUEST (new format) +# Fields (in order): +# - session_nonce: 129 bits (from the top 129 bits of 17 random bytes) +# - version: 7 bits +# - cipher: 4 bits (0 = AES-256-GCM, 1 = ChaCha20-poly1305; for now only 0 is used) +# - CRC: 32 bits +# +# Total bits: 129 + 7 + 4 + 32 = 172 bits. We pack into 22 bytes (176 bits) with 4 spare bits. +# --------------------------------------------------------------------------- +def build_ping_request(version: int, cipher: int, nonce_full: bytes = None) -> bytes: """ Build a Ping request with: - - 256-bit nonce (32 bytes) - - 7-bit version - - 32-bit CRC - Total = 295 bits logically. - Since 295 bits do not fill an integer number of bytes, we pack into 37 bytes (296 bits), - with one unused bit. + - session_nonce: 129 bits (derived from 17 random bytes by discarding the lowest 7 bits) + - version: 7 bits + - cipher: 4 bits (0 = AES-256-GCM, 1 = ChaCha20-poly1305; we use 0 for now) + - CRC: 32 bits + + Total = 129 + 7 + 4 + 32 = 172 bits. + We pack into 22 bytes (176 bits) leaving 4 unused bits. """ - if nonce is None: - nonce = os.urandom(32) # 32 bytes = 256 bits - if len(nonce) != 32: - raise ValueError("Nonce must be exactly 32 bytes.") if not (0 <= version < 128): raise ValueError("Version must fit in 7 bits (0..127)") + if not (0 <= cipher < 16): + raise ValueError("Cipher must fit in 4 bits (0..15)") + # Generate 17 random bytes if none provided (17 bytes = 136 bits) + if nonce_full is None: + nonce_full = os.urandom(17) + if len(nonce_full) < 17: + raise ValueError("nonce_full must be at least 17 bytes") + # Use the top 129 bits of the 136 bits: + nonce_int_full = int.from_bytes(nonce_full, 'big') # 136 bits + nonce_129_int = nonce_int_full >> 7 # drop the lowest 7 bits => 129 bits + # Convert the derived 129-bit nonce to 17 bytes. + # Since 129 bits < 17*8=136 bits, the top 7 bits of the result will be 0. + nonce_129_bytes = nonce_129_int.to_bytes(17, 'big') - # Build the partial integer: - # Shift the nonce (256 bits) left by 7 bits and then OR with the 7-bit version. - partial_int = int.from_bytes(nonce, 'big') << 7 - partial_int |= version # version occupies the lower 7 bits - - # Convert to 33 bytes (263 bits needed) - partial_bytes = partial_int.to_bytes(33, 'big') - - # Compute CRC over these 33 bytes + # Pack the fields: shift the 129-bit nonce left by (7+4)=11 bits, then add version (7 bits) and cipher (4 bits). + partial_int = (nonce_129_int << 11) | (version << 4) | (cipher & 0x0F) + # This partial data is 129+7+4 = 140 bits; we pack into 18 bytes (144 bits) with 4 spare bits. + partial_bytes = partial_int.to_bytes(18, 'big') + # Compute CRC over these 18 bytes. cval = crc32_of(partial_bytes) - - # Combine partial data (263 bits) with the 32-bit CRC => 295 bits total. - final_int = (int.from_bytes(partial_bytes, 'big') << 32) | cval - final_bytes = final_int.to_bytes(37, 'big') + # Combine the partial data with the 32-bit CRC. + final_int = (int.from_bytes(partial_bytes, 'big') << 32) | cval # 140 + 32 = 172 bits + final_bytes = final_int.to_bytes(22, 'big') + # Optionally, store or print nonce_129_bytes (the session nonce) rather than the original nonce_full. return final_bytes - - def parse_ping_request(data: bytes): """ - Parse a Ping request (37 bytes = 295 bits). - Returns (nonce, version) or None if invalid. + Parse a Ping request (22 bytes = 172 bits). + Returns (session_nonce_bytes, version, cipher) or None if invalid. + The session_nonce_bytes will be 17 bytes, representing the 129-bit value. """ - if len(data) != 37: + if len(data) != 22: return None - - # Convert to int - val_295 = int.from_bytes(data, 'big') # 295 bits in a 37-byte integer - # Extract CRC (lowest 32 bits) - crc_in = val_295 & 0xffffffff - # Then shift right 32 bits to get partial_data - partial_val = val_295 >> 32 # 263 bits - - # Convert partial_val back to bytes - partial_bytes = partial_val.to_bytes(33, 'big') - - # Recompute CRC + final_int = int.from_bytes(data, 'big') # 176 bits integer; lower 32 bits = CRC, higher 140 bits = partial + crc_in = final_int & 0xffffffff + partial_int = final_int >> 32 # 140 bits + partial_bytes = partial_int.to_bytes(18, 'big') crc_calc = crc32_of(partial_bytes) if crc_calc != crc_in: return None - - # Now parse out nonce (256 bits) and version (7 bits) - # partial_val is 263 bits - version = partial_val & 0x7f # low 7 bits - nonce_val = partial_val >> 7 # high 256 bits - nonce_bytes = nonce_val.to_bytes(32, 'big') - - return (nonce_bytes, version) + # Extract fields: lowest 4 bits: cipher; next 7 bits: version; remaining 129 bits: session_nonce. + cipher = partial_int & 0x0F + version = (partial_int >> 4) & 0x7F + nonce_129_int = partial_int >> 11 # 140 - 11 = 129 bits + # Convert to 17 bytes. Since the number is < 2^129, the top 7 bits will be zero. + session_nonce_bytes = nonce_129_int.to_bytes(17, 'big') + return (session_nonce_bytes, version, cipher) -# ============================================================================= -# 2) Ping Response (72 bits) -# - 32-bit timestamp -# - 7-bit version + 1-bit answer => 8 bits -# - 32-bit CRC -# = 72 bits total => 9 bytes -# ============================================================================= -def build_ping_response(version: int, answer: int) -> bytes: - """ - Build a Ping response: - - 32-bit timestamp (lowest 32 bits of current time in ms) - - 7-bit version + 1-bit answer - - 32-bit CRC - => 72 bits = 9 bytes - """ +# --------------------------------------------------------------------------- +# PING RESPONSE (new format) +# Fields: +# - timestamp: 32 bits (we take the lower 32 bits of the time in ms) +# - version: 7 bits +# - cipher: 4 bits +# - answer: 1 bit +# - CRC: 32 bits +# +# Total bits: 32 + 7 + 4 + 1 + 32 = 76 bits; pack into 10 bytes (80 bits) with 4 spare bits. +# --------------------------------------------------------------------------- +def build_ping_response(version: int, cipher: int, answer: int) -> bytes: if not (0 <= version < 128): - raise ValueError("Version must fit in 7 bits.") + raise ValueError("Version must fit in 7 bits") + if not (0 <= cipher < 16): + raise ValueError("Cipher must fit in 4 bits") if answer not in (0, 1): - raise ValueError("Answer must be 0 or 1.") - - # 32-bit timestamp = current time in ms, truncated to 32 bits - t_ms = int(time.time() * 1000) & 0xffffffff - - # partial = [timestamp (32 bits), version (7 bits), answer (1 bit)] => 40 bits - partial_val = (t_ms << 8) | ((version << 1) & 0xfe) | (answer & 0x01) - # partial_val is 40 bits => 5 bytes - partial_bytes = partial_val.to_bytes(5, 'big') - - # CRC over these 5 bytes + raise ValueError("Answer must be 0 or 1") + t_ms = int(time.time() * 1000) & 0xffffffff # 32 bits + # Pack timestamp (32 bits), then version (7 bits), cipher (4 bits), answer (1 bit): total 32+7+4+1 = 44 bits. + partial_val = (t_ms << (7+4+1)) | (version << (4+1)) | (cipher << 1) | answer + partial_bytes = partial_val.to_bytes(6, 'big') # 6 bytes = 48 bits, 4 spare bits. cval = crc32_of(partial_bytes) - - # Combine partial (40 bits) with 32 bits of CRC => 72 bits total - final_val = (int.from_bytes(partial_bytes, 'big') << 32) | cval - final_bytes = final_val.to_bytes(9, 'big') + final_val = (int.from_bytes(partial_bytes, 'big') << 32) | cval # 44+32 = 76 bits. + final_bytes = final_val.to_bytes(10, 'big') # 10 bytes = 80 bits. return final_bytes - def parse_ping_response(data: bytes): - """ - Parse a Ping response (72 bits = 9 bytes). - Return (timestamp_ms, version, answer) or None if invalid. - """ - if len(data) != 9: + if len(data) != 10: return None - - val_72 = int.from_bytes(data, 'big') # 72 bits - crc_in = val_72 & 0xffffffff - partial_val = val_72 >> 32 # 40 bits - - partial_bytes = partial_val.to_bytes(5, 'big') + final_int = int.from_bytes(data, 'big') # 80 bits + crc_in = final_int & 0xffffffff + partial_int = final_int >> 32 # 48 bits + partial_bytes = partial_int.to_bytes(6, 'big') crc_calc = crc32_of(partial_bytes) if crc_calc != crc_in: return None - - # Now parse partial_val - # partial_val = [timestamp(32 bits), version(7 bits), answer(1 bit)] - t_ms = (partial_val >> 8) & 0xffffffff - va = partial_val & 0xff # 8 bits = [7 bits version, 1 bit answer] - version = (va >> 1) & 0x7f - answer = va & 0x01 - - return (t_ms, version, answer) + # Extract fields: partial_int has 48 bits. We only used 44 bits for the fields. + # Discard the lower 4 spare bits. + partial_int >>= 4 # now 44 bits. + # Now fields: timestamp: 32 bits, version: 7 bits, cipher: 4 bits, answer: 1 bit. + answer = partial_int & 0x01 + cipher = (partial_int >> 1) & 0x0F + version = (partial_int >> (1+4)) & 0x7F + timestamp = partial_int >> (1+4+7) + return (timestamp, version, cipher, answer) # ============================================================================= diff --git a/protocol_prototype/protocol.py b/protocol_prototype/protocol.py index b1c2521..f8087b1 100644 --- a/protocol_prototype/protocol.py +++ b/protocol_prototype/protocol.py @@ -20,7 +20,7 @@ from messages import ( ) import transmission -from encryption import encrypt_message, decrypt_message +from encryption import encrypt_message, decrypt_message, MessageHeader, generate_iv # ANSI colors RED = "\033[91m" @@ -91,77 +91,97 @@ class IcingProtocol: def on_data_received(self, conn: transmission.PeerConnection, data: bytes): bits_count = len(data) * 8 - print(f"{GREEN}[RECV]{RESET} {bits_count} bits from peer: {data.hex()[:60]}{'...' if len(data.hex())>60 else ''}") - # For a PING_REQUEST, parse and store the session nonce if not already set. - if len(data) == 37: + print( + f"{GREEN}[RECV]{RESET} {bits_count} bits from peer: {data.hex()[:60]}{'...' if len(data.hex()) > 60 else ''}") + + # New-format PING REQUEST (22 bytes) + if len(data) == 22: parsed = parse_ping_request(data) if parsed: - nonce, version = parsed + nonce_full, version, cipher = parsed self.state["ping_received"] = True - # Store session nonce if not already set + # If the received cipher field is not 0, we force it to 0 in our response. + if cipher != 0: + print( + f"{YELLOW}[NOTICE]{RESET} Received PING with unsupported cipher ({cipher}); forcing cipher to 0 in response.") + cipher = 0 + # Store session nonce if not already set: if self.session_nonce is None: - self.session_nonce = nonce + # Here, we already generated 17 bytes (136 bits) and only the top 129 bits are valid. + nonce_int = int.from_bytes(nonce_full, 'big') >> 7 + self.session_nonce = nonce_int.to_bytes(17, 'big') print(f"{YELLOW}[NOTICE]{RESET} Stored session nonce from received PING.") index = len(self.inbound_messages) msg = { "type": "PING_REQUEST", "raw": data, - "parsed": {"nonce": nonce, "version": version}, + "parsed": {"nonce_full": nonce_full, "version": version, "cipher": cipher}, "connection": conn } self.inbound_messages.append(msg) - print(f"{YELLOW}[NOTICE]{RESET} Stored inbound PING request (nonce={nonce.hex()}) at index={index}.") - - if self.auto_responder: - # Schedule an automatic response after 2 seconds - threading.Timer(2.0, self._auto_respond_ping, args=(index,)).start() - + # (Optional auto-responder code could go here) return - # Attempt to parse Ping response - if len(data) == 9: + # New-format PING RESPONSE (10 bytes) + elif len(data) == 10: parsed = parse_ping_response(data) if parsed: - ts, version, answer_code = parsed + timestamp, version, cipher, answer = parsed + # If cipher is not 0 (AES-256-GCM), override it. + if cipher != 0: + print( + f"{YELLOW}[NOTICE]{RESET} Received PING RESPONSE with unsupported cipher; treating as AES (cipher=0).") + cipher = 0 index = len(self.inbound_messages) msg = { "type": "PING_RESPONSE", "raw": data, - "parsed": {"timestamp": ts, "version": version, "answer_code": answer_code}, + "parsed": {"timestamp": timestamp, "version": version, "cipher": cipher, "answer": answer}, "connection": conn } self.inbound_messages.append(msg) - print(f"{YELLOW}[NOTICE]{RESET} Stored inbound PING response (answer_code={answer_code}) at index={index}.") return - # Attempt to parse handshake - if len(data) == 168: + # HANDSHAKE message (168 bytes) + elif len(data) == 168: parsed = parse_handshake_message(data) if parsed: - ts, ephemeral_pub, ephemeral_sig, pfs_hash = parsed + timestamp, ephemeral_pub, ephemeral_sig, pfs_hash = parsed self.state["handshake_received"] = True index = len(self.inbound_messages) msg = { "type": "HANDSHAKE", "raw": data, - "parsed": { - "ephemeral_pub": ephemeral_pub, - "ephemeral_sig": ephemeral_sig, - "timestamp": ts, - "pfs hash": pfs_hash - }, + "parsed": (timestamp, ephemeral_pub, ephemeral_sig, pfs_hash), "connection": conn } self.inbound_messages.append(msg) - print(f"{YELLOW}[NOTICE]{RESET} Stored inbound HANDSHAKE at index={index}. ephemeral_pub={ephemeral_pub.hex()[:20]}...") - - if self.auto_responder: - # Schedule an automatic handshake "response" after 2 seconds - threading.Timer(2.0, self._auto_respond_handshake, args=(index,)).start() - + # (Optional auto-responder for handshake could go here) return - # Otherwise, unrecognized + # Check if the message might be an encrypted message (e.g. header of 18 bytes at start) + elif len(data) > 18: + # Try to parse header from encryption module + try: + from encryption import MessageHeader + header = MessageHeader.unpack(data[:18]) + # If header unpacking is successful and data length fits header.data_len + header size + tag size: + expected_len = 18 + header.data_len + 16 # tag is 16 bytes + if len(data) == expected_len: + index = len(self.inbound_messages) + msg = { + "type": "ENCRYPTED_MESSAGE", + "raw": data, + "parsed": header, # we can store header for further processing + "connection": conn + } + self.inbound_messages.append(msg) + print(f"{YELLOW}[NOTICE]{RESET} Stored inbound ENCRYPTED_MESSAGE at index={index}.") + return + except Exception: + pass + + # Otherwise, unrecognized/malformed message. index = len(self.inbound_messages) msg = { "type": "UNKNOWN", @@ -284,16 +304,19 @@ class IcingProtocol: self.ephemeral_privkey, self.ephemeral_pubkey = get_ephemeral_keypair() print(f"{GREEN}[IcingProtocol]{RESET} Generated ephemeral key pair: pubkey={self.ephemeral_pubkey.hex()[:16]}...") + # Updated send_ping_request: generate a 17-byte nonce, extract top 129 bits, and store that (as bytes) def send_ping_request(self): if not self.connections: print(f"{RED}[ERROR]{RESET} No active connections.") return - # Generate a new nonce for this ping and store it as session_nonce if not already set. - nonce = os.urandom(32) + nonce_full = os.urandom(17) + # Compute the 129-bit session nonce: take the top 129 bits. + nonce_int = int.from_bytes(nonce_full, 'big') >> 7 # 129 bits + session_nonce_bytes = nonce_int.to_bytes(17, 'big') # still 17 bytes but only 129 bits are meaningful if self.session_nonce is None: - self.session_nonce = nonce + self.session_nonce = session_nonce_bytes print(f"{YELLOW}[NOTICE]{RESET} Stored session nonce from sent PING.") - pkt = build_ping_request(version=0, nonce=nonce) + pkt = build_ping_request(version=0, cipher=0, nonce_full=nonce_full) self._send_packet(self.connections[0], pkt, "PING_REQUEST") self.state["ping_sent"] = True @@ -352,10 +375,7 @@ class IcingProtocol: # Manual Responses # ------------------------------------------------------------------------- - def respond_to_ping(self, index: int, answer_code: int): - """ - Manually respond to an inbound PING_REQUEST in inbound_messages[index]. - """ + def respond_to_ping(self, index: int, answer: int): if index < 0 or index >= len(self.inbound_messages): print(f"{RED}[ERROR]{RESET} Invalid index {index}.") return @@ -365,10 +385,17 @@ class IcingProtocol: return version = msg["parsed"]["version"] + cipher = msg["parsed"]["cipher"] + # Force cipher to 0 if it's not 0 (only AES-256-GCM is supported) + if cipher != 0: + print( + f"{YELLOW}[NOTICE]{RESET} Received PING with unsupported cipher ({cipher}); forcing cipher to 0 in response.") + cipher = 0 + conn = msg["connection"] - resp = build_ping_response(version, answer_code) + resp = build_ping_response(version, cipher, answer) self._send_packet(conn, resp, "PING_RESPONSE") - print(f"{BLUE}[MANUAL]{RESET} Responded to ping with answer_code={answer_code}.") + print(f"{BLUE}[MANUAL]{RESET} Responded to ping with answer={answer}.") def generate_ecdhe(self, index: int): """ @@ -383,8 +410,8 @@ class IcingProtocol: print(f"{RED}[ERROR]{RESET} inbound_messages[{index}] is not a HANDSHAKE.") return - ephemeral_pub = msg["parsed"]["ephemeral_pub"] - ephemeral_sig = msg["parsed"]["ephemeral_sig"] + # Unpack the tuple directly: + timestamp, ephemeral_pub, ephemeral_sig, pfs_hash = msg["parsed"] # Use our raw_signature_to_der wrapper only if signature is 64 bytes. # Otherwise, assume the signature is already DER-encoded. @@ -439,12 +466,14 @@ class IcingProtocol: print(f"\nShared Secret: {self.shared_secret if self.shared_secret else '[None]'}") if self.hkdf_key: - print(f"HKDF Derived Key: {self.hkdf_key.hex()} (size: {len(self.hkdf_key)*8} bits)") + print(f"HKDF Derived Key: {self.hkdf_key} (size: {len(self.hkdf_key)*8} bits)") else: print("HKDF Derived Key: [None]") - - print("\nSession Nonce: " + (self.session_nonce.hex() if self.session_nonce else "[None]")) - + if self.session_nonce: + # session_nonce now contains 17 bytes, but only the top 129 bits are used. + print(f"Session Nonce: {self.session_nonce.hex()} (129 bits)") + else: + print("Session Nonce: [None]") print("\nProtocol Flags:") for k, v in self.state.items(): print(f" {k}: {v}") @@ -489,22 +518,20 @@ class IcingProtocol: print(f"{GREEN}[SEND_ENCRYPTED]{RESET} Encrypted message sent.") # New method: Decrypt an encrypted message provided as a hex string. - def decrypt_encrypted_message(self, hex_message: str): - """ - Decrypts an encrypted message (given as a hex string) using the HKDF key. - Returns the plaintext (UTF-8 string) and prints it. - """ +# New command: decrypt a received encrypted message from the inbound queue. + def decrypt_received_message(self, index: int): + if index < 0 or index >= len(self.inbound_messages): + print(f"{RED}[ERROR]{RESET} Invalid message index.") + return + msg = self.inbound_messages[index] + # Expect the message to be an encrypted transmission (we assume it is in the proper format). + encrypted = msg["raw"] if not self.hkdf_key: print(f"{RED}[ERROR]{RESET} No HKDF key derived. Cannot decrypt message.") return - try: - message_bytes = bytes.fromhex(hex_message) - except Exception as e: - print(f"{RED}[ERROR]{RESET} Invalid hex input.") - return key = bytes.fromhex(self.hkdf_key) try: - plaintext_bytes = decrypt_message(message_bytes, key) + plaintext_bytes = decrypt_message(encrypted, key) plaintext = plaintext_bytes.decode('utf-8') print(f"{GREEN}[DECRYPTED]{RESET} Decrypted message: {plaintext}") return plaintext