diff --git a/protocol_prototype/DryBox/UI/main.py b/protocol_prototype/DryBox/UI/main.py index f9a45fc..2ed1d01 100644 --- a/protocol_prototype/DryBox/UI/main.py +++ b/protocol_prototype/DryBox/UI/main.py @@ -1,5 +1,5 @@ import sys -import random +import secrets from PyQt5.QtWidgets import ( QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, QFrame, QSizePolicy, QStyle @@ -9,6 +9,7 @@ from PyQt5.QtGui import QFont from phone_client import PhoneClient from waveform_widget import WaveformWidget from phone_state import PhoneState +from session import NoiseXKSession class PhoneUI(QMainWindow): def __init__(self): @@ -68,13 +69,17 @@ class PhoneUI(QMainWindow): # Initialize phones self.phones = [] + self.handshake_done_count = 0 for i in range(2): client = PhoneClient("localhost", 12345, i) client.data_received.connect(lambda data, cid=i: self.update_waveform(cid, data)) client.state_changed.connect(lambda state, num, cid=i: self.set_phone_state(cid, self.map_state(state), num)) client.start() - phone_widget_container, phone_display, phone_button, phone_waveform, phone_status_label = self._create_phone_ui( + # Generate keypair for each phone + keypair = NoiseXKSession.generate_keypair() + + phone_container_widget, phone_display_frame, phone_button, waveform_widget, phone_status_label = self._create_phone_ui( f"Phone {i+1}", lambda checked, phone_id=i: self.phone_action(phone_id) ) self.phones.append({ @@ -82,12 +87,19 @@ class PhoneUI(QMainWindow): 'client': client, 'state': PhoneState.IDLE, 'button': phone_button, - 'waveform': phone_waveform, + 'waveform': waveform_widget, 'number': "123-4567" if i == 0 else "987-6543", 'audio_timer': None, - 'status_label': phone_status_label + 'status_label': phone_status_label, + 'keypair': keypair, + 'public_key': keypair.public, + 'is_initiator': False }) - phone_controls_layout.addWidget(phone_widget_container) + phone_controls_layout.addWidget(phone_container_widget) + + # Share public key between phones + self.phones[0]['peer_public_key'] = self.phones[1]['public_key'] + self.phones[1]['peer_public_key'] = self.phones[0]['public_key'] # Spacer main_layout.addStretch(1) @@ -175,12 +187,15 @@ class PhoneUI(QMainWindow): def phone_action(self, phone_id): phone = self.phones[phone_id] other_phone = self.phones[1 - phone_id] - print(f"Phone {phone_id + 1} Action, current state: {phone['state']}") + print(f"Phone {phone_id + 1} Action, current state: {phone['state']}, is_initiator: {phone['is_initiator']}") if phone['state'] == PhoneState.IDLE: # Initiate a call phone['state'] = PhoneState.CALLING other_phone['state'] = PhoneState.RINGING + # Set init/resp + phone['is_initiator'] = True + other_phone['is_initiator'] = False self._update_phone_button_ui(phone['button'], phone['status_label'], phone['state'], other_phone['number']) self._update_phone_button_ui(other_phone['button'], other_phone['status_label'], other_phone['state'], phone['number']) phone['client'].send("RINGING") @@ -192,33 +207,49 @@ class PhoneUI(QMainWindow): self._update_phone_button_ui(phone['button'], phone['status_label'], phone['state'], other_phone['number']) self._update_phone_button_ui(other_phone['button'], other_phone['status_label'], other_phone['state'], phone['number']) phone['client'].send("IN_CALL") - # Start audio timers for both phones - for p in [phone, other_phone]: - if not p['audio_timer'] or not p['audio_timer'].isActive(): - p['audio_timer'] = QTimer(self) - p['audio_timer'].timeout.connect(lambda pid=p['id']: self.send_audio(pid)) - p['audio_timer'].start(1000) elif phone['state'] == PhoneState.IN_CALL or phone['state'] == PhoneState.CALLING: # Hang up or cancel - phone['state'] = PhoneState.IDLE - other_phone['state'] = PhoneState.IDLE - self._update_phone_button_ui(phone['button'], phone['status_label'], phone['state'], "") - self._update_phone_button_ui(other_phone['button'], other_phone['status_label'], other_phone['state'], "") - phone['client'].send("CALL_END") - # Stop audio timers for both phones - for p in [phone, other_phone]: - if p['audio_timer']: - p['audio_timer'].stop() + if not phone['client'].handshake_in_progress and phone['state'] != PhoneState.CALLING: + phone['state'] = PhoneState.IDLE + other_phone['state'] = PhoneState.IDLE + self._update_phone_button_ui(phone['button'], phone['status_label'], phone['state'], "") + self._update_phone_button_ui(other_phone['button'], other_phone['status_label'], other_phone['state'], "") + phone['client'].send("CALL_END") + # Stop audio timers for both phones + for p in [phone, other_phone]: + if p['audio_timer']: + p['audio_timer'].stop() + else: + print(f"Phone {phone_id + 1} cannot hang up during handshake or call setup") + + def start_audio(self, client_id): + """Start audio timer after both clients send HANDSHAKE_DONE.""" + self.handshake_done_count += 1 + print(f"HANDSHAKE_DONE received for client {client_id}, count: {self.handshake_done_count}") + if self.handshake_done_count == 2: + for phone in self.phones: + if phone['state'] == PhoneState.IN_CALL: + if not phone['audio_timer'] or not phone['audio_timer'].isActive(): + phone['audio_timer'] = QTimer(self) + phone['audio_timer'].timeout.connect(lambda pid=phone['id']: self.send_audio(pid)) + phone['audio_timer'].start(100) # 100ms for smoother updates + self.handshake_done_count = 0 def send_audio(self, phone_id): phone = self.phones[phone_id] - if phone['state'] == PhoneState.IN_CALL: - message = f"Audio packet {random.randint(1, 1000)}" - phone['client'].send(message) + if phone['state'] == PhoneState.IN_CALL and phone['client'].session and phone['client'].sock: + # Generate mock 16-byte audio data + mock_audio = secrets.token_bytes(16) + try: + # Encrypt with Noise session, send over socket + phone['client'].session.send(phone['client'].sock, mock_audio) + print(f"Client {phone_id} sent encrypted audio packet, length=32") + except Exception as e: + print(f"Client {phone_id} failed to send audio: {e}") def update_waveform(self, client_id, data): - print(f"Updating waveform for client_id {client_id}") + print(f"Updating waveform for client_id {client_id}, data_length={len(data)}") waveform = self.phones[client_id]['waveform'] waveform.set_data(data) @@ -229,24 +260,41 @@ class PhoneUI(QMainWindow): return PhoneState.IDLE elif state_str == "IN_CALL": return PhoneState.IN_CALL + elif state_str == "HANDSHAKE": + return PhoneState.IN_CALL # Stay in IN_CALL, trigger handshake + elif state_str == "HANDSHAKE_DONE": + return PhoneState.IN_CALL # Stay in IN_CALL, start audio return PhoneState.IDLE def set_phone_state(self, client_id, state, number=""): phone = self.phones[client_id] other_phone = self.phones[1 - client_id] + print(f"Setting state for Phone {client_id + 1}: {state}, number: {number}, is_initiator: {phone['is_initiator']}") phone['state'] = state if state == PhoneState.RINGING: self._update_phone_button_ui(phone['button'], phone['status_label'], state, other_phone['number']) elif state == PhoneState.IN_CALL: + print(f"Phone {client_id + 1} confirmed in IN_CALL state") self._update_phone_button_ui(phone['button'], phone['status_label'], state, other_phone['number']) + if number == "IN_CALL" and phone['is_initiator']: + # Initiator starts handshake after receiving IN_CALL + print(f"Phone {client_id + 1} (initiator) starting handshake") + phone['client'].send("HANDSHAKE") + phone['client'].start_handshake(initiator=True, keypair=phone['keypair'], peer_pubkey=other_phone['public_key']) + elif number == "HANDSHAKE" and not phone['is_initiator']: + # Responder starts handshake after receiving HANDSHAKE + print(f"Phone {client_id + 1} (responder) starting handshake") + phone['client'].start_handshake(initiator=False, keypair=phone['keypair'], peer_pubkey=other_phone['public_key']) + elif number == "HANDSHAKE_DONE": + # Start audio after HANDSHAKE_DONE + self.start_audio(client_id) else: + # Handle disconnect gracefully self._update_phone_button_ui(phone['button'], phone['status_label'], state, "") + if state == PhoneState.IDLE and number == "CALL_END": + print(f"Phone {client_id + 1} resetting due to disconnect") if state == PhoneState.IDLE and phone['audio_timer']: phone['audio_timer'].stop() - elif state == PhoneState.IN_CALL and (not phone['audio_timer'] or not phone['audio_timer'].isActive()): - phone['audio_timer'] = QTimer(self) - phone['audio_timer'].timeout.connect(lambda: self.send_audio(client_id)) - phone['audio_timer'].start(1000) def settings_action(self): print("Settings clicked") diff --git a/protocol_prototype/DryBox/UI/phone_client.py b/protocol_prototype/DryBox/UI/phone_client.py index 18e398c..fac7d27 100644 --- a/protocol_prototype/DryBox/UI/phone_client.py +++ b/protocol_prototype/DryBox/UI/phone_client.py @@ -1,5 +1,9 @@ import socket +import time +import select from PyQt5.QtCore import QThread, pyqtSignal +from queue import Queue +from session import NoiseXKSession class PhoneClient(QThread): data_received = pyqtSignal(bytes, int) # Include client_id @@ -12,50 +16,162 @@ class PhoneClient(QThread): self.client_id = client_id self.sock = None self.running = True + self.command_queue = Queue() + self.initiator = None + self.keypair = None + self.peer_pubkey = None + self.session = None + self.handshake_in_progress = False + self.handshake_start_time = None + self.call_active = False # Track active call after HANDSHAKE_DONE + + def connect_socket(self): + """Attempt to connect to the server with retries.""" + retries = 3 + for attempt in range(retries): + try: + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + self.sock.settimeout(120) # 120s for socket operations + self.sock.connect((self.host, self.port)) + print(f"Client {self.client_id} connected to {self.host}:{self.port}") + return True + except Exception as e: + print(f"Client {self.client_id} connection attempt {attempt + 1} failed: {e}") + if attempt < retries - 1: + time.sleep(1) # Wait before retrying + self.sock = None + return False def run(self): - try: - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - self.sock.settimeout(15) - self.sock.connect((self.host, self.port)) - print(f"Client {self.client_id} connected to {self.host}:{self.port}") - while self.running: - try: - data = self.sock.recv(1024) - if not data: - print(f"Client {self.client_id} disconnected") - self.state_changed.emit("CALL_END", "", self.client_id) - break - decoded_data = data.decode('utf-8', errors='ignore').strip() - print(f"Client {self.client_id} received raw: {decoded_data}") - if decoded_data in ["RINGING", "CALL_END", "CALL_DROPPED", "IN_CALL"]: - self.state_changed.emit(decoded_data, "", self.client_id) - else: - self.data_received.emit(data, self.client_id) - print(f"Client {self.client_id} received audio: {decoded_data}") - except socket.timeout: - print(f"Client {self.client_id} timed out waiting for data") - continue - except Exception as e: - print(f"Client {self.client_id} error: {e}") + while self.running: + if not self.sock: + if not self.connect_socket(): + print(f"Client {self.client_id} failed to connect after retries") self.state_changed.emit("CALL_END", "", self.client_id) break - except Exception as e: - print(f"Client {self.client_id} connection failed: {e}") - finally: - if self.sock: - self.sock.close() + try: + while self.running: + # print(f"Client {self.client_id} run loop iteration") + # Check command queue first + if not self.command_queue.empty(): + print(f"Client {self.client_id} processing command queue, size: {self.command_queue.qsize()}") + command = self.command_queue.get() + if command == "handshake": + try: + print(f"Client {self.client_id} starting handshake, initiator: {self.initiator}") + self.session = NoiseXKSession(self.keypair, self.peer_pubkey) + self.session.handshake(self.sock, self.initiator) + print(f"Client {self.client_id} handshake complete") + self.send("HANDSHAKE_DONE") + except socket.timeout: + print(f"Client {self.client_id} handshake timed out") + self.state_changed.emit("CALL_END", "", self.client_id) + break + except Exception as e: + print(f"Client {self.client_id} handshake failed: {e}") + self.state_changed.emit("CALL_END", "", self.client_id) + break + finally: + self.handshake_in_progress = False + self.handshake_start_time = None + else: + # Check for handshake timeout + if self.handshake_in_progress and self.handshake_start_time: + if time.time() - self.handshake_start_time > 30: # 30s handshake timeout + print(f"Client {self.client_id} handshake timeout after 30s") + self.state_changed.emit("CALL_END", "", self.client_id) + self.handshake_in_progress = False + self.handshake_start_time = None + break + # Only read socket if not in handshake + if not self.handshake_in_progress: + # Use select to check if data is available + readable, _, _ = select.select([self.sock], [], [], 0.01) # 10ms timeout + if readable: + try: + print(f"Client {self.client_id} attempting sock.recv") + data = self.sock.recv(1024) + if not data: + print(f"Client {self.client_id} disconnected") + self.state_changed.emit("CALL_END", "", self.client_id) + break + # Handle control messages (UTF-8) + try: + decoded_data = data.decode('utf-8').strip() + print(f"Client {self.client_id} received raw: {decoded_data}") + if decoded_data in ["RINGING", "CALL_END", "CALL_DROPPED", "IN_CALL", "HANDSHAKE", "HANDSHAKE_DONE"]: + self.state_changed.emit(decoded_data, decoded_data, self.client_id) + if decoded_data == "HANDSHAKE": + self.handshake_in_progress = True # Block further reads + elif decoded_data == "HANDSHAKE_DONE": + self.call_active = True # Enable audio processing + else: + print(f"Client {self.client_id} ignored unexpected text message: {decoded_data}") + except UnicodeDecodeError: + # Handle binary data (audio packets) + if self.call_active and self.session: + try: + print(f"Client {self.client_id} received audio packet, length={len(data)}") + decrypted_data = self.session.decrypt(data) + print(f"Client {self.client_id} decrypted audio packet, length={len(decrypted_data)}") + self.data_received.emit(decrypted_data, self.client_id) + except Exception as e: + print(f"Client {self.client_id} failed to process audio packet: {e}") + else: + print(f"Client {self.client_id} ignored non-text message: {data.hex()}") + except socket.timeout: + print(f"Client {self.client_id} timed out waiting for data") + continue + except socket.error as e: + print(f"Client {self.client_id} socket error: {e}") + self.state_changed.emit("CALL_END", "", self.client_id) + break + except Exception as e: + print(f"Client {self.client_id} error: {e}") + self.state_changed.emit("CALL_END", "", self.client_id) + break + else: + # print(f"Client {self.client_id} no data available, skipping recv") + pass + else: + # Yield during handshake + self.msleep(20) # 20ms sleep to yield CPU + print(f"Client {self.client_id} yielding during handshake") + # Short sleep to yield Qt event loop + self.msleep(1) # 1ms sleep + finally: + if self.sock: + self.sock.close() + self.sock = None def send(self, message): if self.sock and self.running: try: - self.sock.send(message.encode()) - print(f"Client {self.client_id} sent: {message}") - except Exception as e: + if isinstance(message, str): + data = message.encode('utf-8') + self.sock.send(data) + print(f"Client {self.client_id} sent: {message}, length={len(data)}") + else: + # Send binary data (audio) + self.sock.send(message) + print(f"Client {self.client_id} sent binary data, length={len(message)}") + except socket.error as e: print(f"Client {self.client_id} send error: {e}") + self.state_changed.emit("CALL_END", "", self.client_id) def stop(self): self.running = False if self.sock: - self.sock.close() \ No newline at end of file + self.sock.close() + self.sock = None + + def start_handshake(self, initiator, keypair, peer_pubkey): + """Queue the handshake command with necessary parameters.""" + self.initiator = initiator + self.keypair = keypair + self.peer_pubkey = peer_pubkey + print(f"Client {self.client_id} queuing handshake, initiator: {initiator}") + self.handshake_in_progress = True # Block recv before handshake starts + self.handshake_start_time = time.time() + self.command_queue.put("handshake") \ No newline at end of file diff --git a/protocol_prototype/DryBox/UI/session.py b/protocol_prototype/DryBox/UI/session.py new file mode 100644 index 0000000..a4833a4 --- /dev/null +++ b/protocol_prototype/DryBox/UI/session.py @@ -0,0 +1,196 @@ +import socket +import logging +from dissononce.processing.impl.handshakestate import HandshakeState +from dissononce.processing.impl.symmetricstate import SymmetricState +from dissononce.processing.impl.cipherstate import CipherState +from dissononce.processing.handshakepatterns.interactive.XK import XKHandshakePattern +from dissononce.cipher.chachapoly import ChaChaPolyCipher +from dissononce.dh.x25519.x25519 import X25519DH +from dissononce.dh.keypair import KeyPair +from dissononce.dh.x25519.public import PublicKey +from dissononce.hash.sha256 import SHA256Hash + +# Configure root logger for debug output +logging.basicConfig(level=logging.DEBUG, format="%(message)s") + +class NoiseXKSession: + @staticmethod + def generate_keypair() -> KeyPair: + """ + Generate a static X25519 KeyPair. + Returns: + KeyPair object with .private and .public attributes. + """ + return X25519DH().generate_keypair() + + def __init__(self, local_kp: KeyPair, peer_pubkey: PublicKey): + """ + Initialize with our KeyPair and the peer's PublicKey. + """ + self.local_kp: KeyPair = local_kp + self.peer_pubkey: PublicKey = peer_pubkey + + # Build the Noise handshake state (X25519 DH, ChaChaPoly cipher, SHA256 hash) + cipher = ChaChaPolyCipher() + dh = X25519DH() + hshash = SHA256Hash() + symmetric = SymmetricState(CipherState(cipher), hshash) + self._hs = HandshakeState(symmetric, dh) + + self._send_cs = None # type: CipherState + self._recv_cs = None + + def handshake(self, sock: socket.socket, initiator: bool) -> None: + """ + Perform the XK handshake over the socket. Branches on initiator/responder + so that each side reads or writes in the correct message order. + On completion, self._send_cs and self._recv_cs hold the two CipherStates. + """ + logging.debug(f"[handshake] start (initiator={initiator})") + # initialize with our KeyPair and their PublicKey + if initiator: + # initiator knows peer’s static out-of-band + self._hs.initialize( + XKHandshakePattern(), + True, + b'', + s=self.local_kp, + rs=self.peer_pubkey + ) + else: + logging.debug("[handshake] responder initializing without rs") + # responder must NOT supply rs here + self._hs.initialize( + XKHandshakePattern(), + False, + b'', + s=self.local_kp + ) + + cs_pair = None + if initiator: + # 1) -> e + buf1 = bytearray() + cs_pair = self._hs.write_message(b'', buf1) + logging.debug(f"[-> e] {buf1.hex()}") + self._send_all(sock, buf1) + + # 2) <- e, es, s, ss + msg2 = self._recv_all(sock) + logging.debug(f"[<- msg2] {msg2.hex()}") + self._hs.read_message(msg2, bytearray()) + + # 3) -> se (final) + buf3 = bytearray() + cs_pair = self._hs.write_message(b'', buf3) + logging.debug(f"[-> se] {buf3.hex()}") + self._send_all(sock, buf3) + else: + # 1) <- e + msg1 = self._recv_all(sock) + logging.debug(f"[<- e] {msg1.hex()}") + self._hs.read_message(msg1, bytearray()) + + # 2) -> e, es, s, ss + buf2 = bytearray() + cs_pair = self._hs.write_message(b'', buf2) + logging.debug(f"[-> msg2] {buf2.hex()}") + self._send_all(sock, buf2) + + # 3) <- se (final) + msg3 = self._recv_all(sock) + logging.debug(f"[<- se] {msg3.hex()}") + cs_pair = self._hs.read_message(msg3, bytearray()) + + # on the final step, we must get exactly two CipherStates + if not cs_pair or len(cs_pair) != 2: + raise RuntimeError("Handshake did not complete properly") + cs0, cs1 = cs_pair + # the library returns (cs_encrypt_for_initiator, cs_decrypt_for_initiator) + if initiator: + # initiator: cs0 encrypts, cs1 decrypts + self._send_cs, self._recv_cs = cs0, cs1 + else: + # responder must swap + self._send_cs, self._recv_cs = cs1, cs0 + + # dump the raw symmetric keys & nonces (if available) + self._dump_cipherstate("HANDSHAKE→ SEND", self._send_cs) + self._dump_cipherstate("HANDSHAKE→ RECV", self._recv_cs) + + def send(self, sock: socket.socket, plaintext: bytes) -> None: + """ + Encrypt and send a message. + """ + if self._send_cs is None: + raise RuntimeError("Handshake not complete") + ct = self._send_cs.encrypt_with_ad(b'', plaintext) + logging.debug(f"[ENCRYPT] {ct.hex()}") + self._dump_cipherstate("SEND→ after encrypt", self._send_cs) + self._send_all(sock, ct) + + def receive(self, sock: socket.socket) -> bytes: + """ + Receive and decrypt a message. + """ + if self._recv_cs is None: + raise RuntimeError("Handshake not complete") + ct = self._recv_all(sock) + logging.debug(f"[CIPHERTEXT] {ct.hex()}") + self._dump_cipherstate("RECV→ before decrypt", self._recv_cs) + pt = self._recv_cs.decrypt_with_ad(b'', ct) + logging.debug(f"[DECRYPT] {pt!r}") + return pt + + def decrypt(self, ciphertext: bytes) -> bytes: + """ + Decrypt a ciphertext received as bytes. + """ + if self._recv_cs is None: + raise RuntimeError("Handshake not complete") + # Remove 2-byte length prefix if present + if len(ciphertext) >= 2 and int.from_bytes(ciphertext[:2], 'big') == len(ciphertext) - 2: + logging.debug(f"[DECRYPT] Stripping 2-byte length prefix from {len(ciphertext)}-byte input") + ciphertext = ciphertext[2:] + logging.debug(f"[CIPHERTEXT] {ciphertext.hex()}") + self._dump_cipherstate("DECRYPT→ before decrypt", self._recv_cs) + pt = self._recv_cs.decrypt_with_ad(b'', ciphertext) + logging.debug(f"[DECRYPT] {pt!r}") + return pt + + def _send_all(self, sock: socket.socket, data: bytes) -> None: + # Length-prefix (2 bytes big-endian) + data + length = len(data).to_bytes(2, 'big') + logging.debug(f"[SEND] length={length.hex()}, data={data.hex()}") + sock.sendall(length + data) + + def _recv_all(self, sock: socket.socket) -> bytes: + # Read 2-byte length prefix, then the payload + hdr = self._read_exact(sock, 2) + length = int.from_bytes(hdr, 'big') + logging.debug(f"[RECV] length={length} ({hdr.hex()})") + data = self._read_exact(sock, length) + logging.debug(f"[RECV] data={data.hex()}") + return data + + @staticmethod + def _read_exact(sock: socket.socket, n: int) -> bytes: + buf = bytearray() + while len(buf) < n: + chunk = sock.recv(n - len(buf)) + if not chunk: + raise ConnectionError("Socket closed during read") + buf.extend(chunk) + return bytes(buf) + + def _dump_cipherstate(self, label: str, cs: CipherState) -> None: + """ + Print the symmetric key (cs._k) and nonce counter (cs._n) for inspection. + """ + key = cs._key + nonce = getattr(cs, "_n", None) + if isinstance(key, (bytes, bytearray)): + key_hex = key.hex() + else: + key_hex = repr(key) + logging.debug(f"[{label}] key={key_hex}") \ No newline at end of file diff --git a/protocol_prototype/requirements.txt b/protocol_prototype/requirements.txt index 731ba4c..14c2100 100644 --- a/protocol_prototype/requirements.txt +++ b/protocol_prototype/requirements.txt @@ -3,4 +3,5 @@ Docker Python3 # Venv install -PyQt5 \ No newline at end of file +PyQt5 +dissononce \ No newline at end of file