diff --git a/zssp/src/zssp (copy).rs b/zssp/src/zssp (copy).rs new file mode 100644 index 000000000..4af45e97a --- /dev/null +++ b/zssp/src/zssp (copy).rs @@ -0,0 +1,1701 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + * + * (c) ZeroTier, Inc. + * https://www.zerotier.com/ + */ + +// ZSSP: ZeroTier Secure Session Protocol +// FIPS compliant Noise_XK with Jedi powers (Kyber1024) and built-in attack-resistant large payload (fragmentation) support. + +use std::collections::{HashMap, HashSet}; +use std::num::NonZeroU64; +use std::sync::atomic::{AtomicI64, AtomicU64, Ordering}; +use std::sync::{Arc, Mutex, RwLock, Weak, MutexGuard}; + +use zerotier_crypto::aes::{Aes, AesGcm}; +use zerotier_crypto::hash::{hmac_sha512, SHA384, SHA384_HASH_SIZE, hmac_sha512_secret}; +use zerotier_crypto::p384::{P384KeyPair, P384PublicKey, P384_ECDH_SHARED_SECRET_SIZE}; +use zerotier_crypto::secret::Secret; +use zerotier_crypto::{random, secure_eq}; + +use pqc_kyber::{KYBER_SECRETKEYBYTES, KYBER_SSBYTES}; + +use crate::applicationlayer::ApplicationLayer; +use crate::error::Error; +use crate::fragged::Fragged; +use crate::proto::*; +use crate::sessionid::SessionId; + +/// Session context for local application. +/// +/// Each application using ZSSP must create an instance of this to own sessions and +/// defragment incoming packets that are not yet associated with a session. +pub struct Context { + max_incomplete_session_queue_size: usize, + defrag: Mutex< + HashMap< + (Application::PhysicalPath, u64), + Arc<( + Mutex>, + i64, // creation timestamp + )>, + >, + >, + sessions: RwLock>, +} + +/// Lookup maps for sessions within a session context. +struct SessionsById { + // Active sessions, automatically closed if the application no longer holds their Arc<>. + active: HashMap>>, + + // Incomplete sessions in the middle of three-phase Noise_XK negotiation, expired after timeout. + incoming: HashMap>, +} + +/// Result generated by the context packet receive function, with possible payloads. +pub enum ReceiveResult<'b, Application: ApplicationLayer> { + /// Packet was valid, but no action needs to be taken and no payload was delivered. + Ok(Option>>), + + /// Packet was valid and a data payload was decoded and authenticated. + OkData(Arc>, &'b mut [u8]), + + /// Packet was valid and a new session was created, with optional attached meta-data. + OkNewSession(Arc>, Option<&'b mut [u8]>), + + /// Packet appears valid but was rejected by the application layer, e.g. a rejected new session attempt. + Rejected, +} + +/// ZeroTier Secure Session Protocol (ZSSP) Session +/// +/// A FIPS/NIST compliant variant of Noise_XK with hybrid Kyber1024 PQ data forward secrecy. +pub struct Session { + /// This side's locally unique session ID + pub id: SessionId, + + /// An arbitrary application defined object associated with each session + pub application_data: Application::Data, + + psk: Secret, + send_counter: AtomicU64, + receive_window: [AtomicU64; COUNTER_WINDOW_MAX_OOO], + header_protection_cipher: Mutex, + state: RwLock, + defrag: [Mutex>; COUNTER_WINDOW_MAX_OOO], +} + +/// Most of the mutable parts of a session state. +struct State { + remote_session_id: Option, + keys: [Option; 2], + current_key: usize, + current_offer: Offer, +} + +struct BobIncomingIncompleteSessionState { + timestamp: i64, + alice_session_id: SessionId, + bob_session_id: SessionId, + noise_h: [u8; SHA384_HASH_SIZE], + noise_es_ee: Secret, + hk: Secret, + header_protection_key: Secret, + bob_noise_e_secret: P384KeyPair, +} + +struct AliceOutgoingIncompleteSessionState { + last_retry_time: AtomicI64, + noise_h: [u8; SHA384_HASH_SIZE], + noise_es: Secret, + alice_noise_e_secret: P384KeyPair, + alice_hk_secret: Secret, + metadata: Option>, + init_packet: [u8; AliceNoiseXKInit::SIZE], +} + +struct OutgoingSessionAck { + last_retry_time: AtomicI64, + ack: [u8; MAX_NOISE_HANDSHAKE_SIZE], + ack_size: usize, +} + +enum Offer { + None, + NoiseXKInit(Box), + NoiseXKAck(Box), + RekeyInit(P384KeyPair, i64), +} + +struct SessionKey { + ratchet_key: Secret, // Key used in derivation of the next session key + //receive_key: Secret, // Receive side AES-GCM key + //send_key: Secret, // Send side AES-GCM key + receive_cipher_pool: [Mutex>; 4], // Pool of reusable sending ciphers + send_cipher_pool: [Mutex>; 4], // Pool of reusable receiving ciphers + rekey_at_time: i64, // Rekey at or after this time (ticks) + created_at_counter: u64, // Counter at which session was created + rekey_at_counter: u64, // Rekey at or after this counter + expire_at_counter: u64, // Hard error when this counter value is reached or exceeded + ratchet_count: u64, // Number of rekey events + bob: bool, // Was this side "Bob" in this exchange? + confirmed: bool, // Is this key confirmed by the other side? +} + +impl Context { + /// Create a new session context. + /// + /// * `max_incomplete_session_queue_size` - Maximum number of incomplete sessions in negotiation phase + pub fn new(max_incomplete_session_queue_size: usize) -> Self { + zerotier_crypto::init(); + Self { + max_incomplete_session_queue_size, + defrag: Mutex::new(HashMap::new()), + sessions: RwLock::new(SessionsById { + active: HashMap::with_capacity(64), + incoming: HashMap::with_capacity(64), + }), + } + } + + /// Perform periodic background service and cleanup tasks. + /// + /// This returns the number of milliseconds until it should be called again. + /// + /// * `send` - Function to send packets to remote sessions + /// * `mtu` - Physical MTU + /// * `current_time` - Current monotonic time in milliseconds + pub fn service>, &mut [u8])>(&self, mut send: SendFunction, mtu: usize, current_time: i64) -> i64 { + let mut dead_active = Vec::new(); + let mut dead_pending = Vec::new(); + let retry_cutoff = current_time - Application::RETRY_INTERVAL; + let negotiation_timeout_cutoff = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS; + + // Scan sessions in read lock mode, then lock more briefly in write mode to delete any dead entries that we found. + { + let sessions = self.sessions.read().unwrap(); + for (id, s) in sessions.active.iter() { + if let Some(session) = s.upgrade() { + let state = session.state.read().unwrap(); + if match &state.current_offer { + Offer::None => true, + Offer::NoiseXKInit(offer) => { + // If there's an outstanding attempt to open a session, retransmit this periodically + // in case the initial packet doesn't make it. Note that we currently don't have + // retransmission for the intermediate steps, so a new session may still fail if the + // packet loss rate is huge. The application layer has its own logic to keep trying + // under those conditions. + if offer.last_retry_time.load(Ordering::Relaxed) < retry_cutoff { + offer.last_retry_time.store(current_time, Ordering::Relaxed); + let _ = send_with_fragmentation( + |b| send(&session, b), + &mut (offer.init_packet.clone()), + mtu, + PACKET_TYPE_ALICE_NOISE_XK_INIT, + None, + 0, + 1, + None, + ); + } + false + } + Offer::NoiseXKAck(ack) => { + // We also keep retransmitting the final ACK until we get a valid DATA or NOP packet + // from Bob, otherwise we could get a half open session. + if ack.last_retry_time.load(Ordering::Relaxed) < retry_cutoff { + ack.last_retry_time.store(current_time, Ordering::Relaxed); + let _ = send_with_fragmentation( + |b| send(&session, b), + &mut (ack.ack.clone())[..ack.ack_size], + mtu, + PACKET_TYPE_ALICE_NOISE_XK_ACK, + state.remote_session_id, + 0, + 2, + Some(&mut *session.get_header_cipher()), + ); + } + false + } + Offer::RekeyInit(_, last_rekey_attempt_time) => *last_rekey_attempt_time < retry_cutoff, + } { + // Check whether we need to rekey if there is no pending offer or if the last rekey + // offer was before retry_cutoff (checked in the 'match' above). + if let Some(key) = state.keys[state.current_key].as_ref() { + if key.bob && (current_time >= key.rekey_at_time || session.send_counter.load(Ordering::Relaxed) >= key.rekey_at_counter) + { + drop(state); + session.initiate_rekey(|b| send(&session, b), current_time); + } + } + } + } else { + dead_active.push(*id); + } + } + + for (id, incoming) in sessions.incoming.iter() { + if incoming.timestamp <= negotiation_timeout_cutoff { + dead_pending.push(*id); + } + } + } + + if !dead_active.is_empty() || !dead_pending.is_empty() { + let mut sessions = self.sessions.write().unwrap(); + for id in dead_active.iter() { + sessions.active.remove(id); + } + for id in dead_pending.iter() { + sessions.incoming.remove(id); + } + } + + // Delete any expired defragmentation queue items not associated with a session. + self.defrag.lock().unwrap().retain(|_, fragged| fragged.1 > negotiation_timeout_cutoff); + + Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS.min(Application::RETRY_INTERVAL) + } + + /// Create a new session and send initial packet(s) to other side. + /// + /// This will return Error::DataTooLarge if the combined size of the metadata and the local static public + /// blob (as retrieved from the application layer) exceed MAX_INIT_PAYLOAD_SIZE. + /// + /// * `app` - Application layer instance + /// * `send` - User-supplied packet sending function + /// * `mtu` - Physical MTU for calls to send() + /// * `remote_s_public_blob` - Remote side's opaque static public blob (which must contain remote_s_public_p384) + /// * `remote_s_public_p384` - Remote side's static public NIST P-384 key + /// * `psk` - Pre-shared key (use all zero if none) + /// * `metadata` - Optional metadata to be included in initial handshake + /// * `application_data` - Arbitrary opaque data to include with session object + /// * `current_time` - Current monotonic time in milliseconds + pub fn open( + &self, + app: &Application, + mut send: SendFunction, + mtu: usize, + remote_s_public_blob: &[u8], + remote_s_public_p384: &P384PublicKey, + psk: Secret, + metadata: Option>, + application_data: Application::Data, + current_time: i64, + ) -> Result>, Error> { + if (metadata.as_ref().map(|md| md.len()).unwrap_or(0) + app.get_local_s_public_blob().len()) > MAX_INIT_PAYLOAD_SIZE { + return Err(Error::DataTooLarge); + } + + let alice_noise_e_secret = P384KeyPair::generate(); + let alice_noise_e = alice_noise_e_secret.public_key_bytes().clone(); + let noise_es = alice_noise_e_secret.agree(&remote_s_public_p384).ok_or(Error::InvalidParameter)?; + let alice_hk_secret = pqc_kyber::keypair(&mut random::SecureRandom::default()); + let header_protection_key: Secret = Secret(random::get_bytes_secure()); + + let (local_session_id, session) = { + let mut sessions = self.sessions.write().unwrap(); + + let mut local_session_id; + loop { + local_session_id = SessionId::random(); + if !sessions.active.contains_key(&local_session_id) && !sessions.incoming.contains_key(&local_session_id) { + break; + } + } + + let session = Arc::new(Session { + id: local_session_id, + application_data, + psk, + send_counter: AtomicU64::new(3), // 1 and 2 are reserved for init and final ack + receive_window: std::array::from_fn(|_| AtomicU64::new(0)), + header_protection_cipher: Mutex::new(Aes::new(&header_protection_key)), + state: RwLock::new(State { + remote_session_id: None, + keys: [None, None], + current_key: 0, + current_offer: Offer::NoiseXKInit(Box::new(AliceOutgoingIncompleteSessionState { + last_retry_time: AtomicI64::new(current_time), + noise_h: mix_hash(&mix_hash(&INITIAL_H, remote_s_public_blob), &alice_noise_e), + noise_es: noise_es.clone(), + alice_noise_e_secret, + alice_hk_secret: Secret(alice_hk_secret.secret), + metadata, + init_packet: [0u8; AliceNoiseXKInit::SIZE], + })), + }), + defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())), + }); + + sessions.active.insert(local_session_id, Arc::downgrade(&session)); + + (local_session_id, session) + }; + + { + let mut state = session.state.write().unwrap(); + let offer = if let Offer::NoiseXKInit(offer) = &mut state.current_offer { + offer + } else { + panic!(); // should be impossible as this is what we initialized with + }; + + // Create Alice's initial outgoing state message. + let init_packet = &mut offer.init_packet; + { + let init: &mut AliceNoiseXKInit = byte_array_as_proto_buffer_mut(init_packet).unwrap(); + init.session_protocol_version = SESSION_PROTOCOL_VERSION; + init.alice_noise_e = alice_noise_e; + init.alice_session_id = *local_session_id.as_bytes(); + init.alice_hk_public = alice_hk_secret.public; + init.header_protection_key = header_protection_key.0; + } + + // Encrypt and add authentication tag. + let mut gcm = AesGcm::new( + &kbkdf::(noise_es.as_bytes()) + ); + gcm.reset_init_gcm(&create_message_nonce(PACKET_TYPE_ALICE_NOISE_XK_INIT, 1)); + gcm.aad(&offer.noise_h); + gcm.crypt_in_place(&mut init_packet[AliceNoiseXKInit::ENC_START..AliceNoiseXKInit::AUTH_START]); + init_packet[AliceNoiseXKInit::AUTH_START..AliceNoiseXKInit::AUTH_START + AES_GCM_TAG_SIZE].copy_from_slice(&gcm.finish_encrypt()); + + // Update ongoing state hash with Alice's outgoing init ciphertext. + offer.noise_h = mix_hash(&offer.noise_h, &init_packet[HEADER_SIZE..]); + + send_with_fragmentation( + &mut send, + &mut (init_packet.clone()), + mtu, + PACKET_TYPE_ALICE_NOISE_XK_INIT, + None, + 0, + 1, + None, + )?; + } + + return Ok(session); + } + + /// Receive, authenticate, decrypt, and process a physical wire packet. + /// + /// The send function may be called one or more times to send packets. If the packet is associated + /// wtth an active session this session is supplied, otherwise this parameter is None and the packet + /// should be a reply to the current incoming packet. The size of packets to be sent will not exceed + /// the supplied mtu. + /// + /// The check_allow_incoming_session function is called when an initial Noise_XK init message is + /// received. This is before anything is known about the caller. A return value of true proceeds + /// with negotiation. False drops the packet. + /// + /// The check_accept_session function is called at the end of negotiation for an incoming session + /// with the caller's static public blob. It must return the P-384 static public key extracted from + /// the supplied blob, a PSK (or all zeroes if none), and application data to associate with the new + /// session. A return of None rejects and abandons the session. + /// + /// Note that if check_accept_session accepts and returns Some() the session could still fail with + /// receive() returning an error. A Some() return from check_accept_sesion doesn't guarantee + /// successful new session init, only that the application has authorized it. + /// + /// Finally, note that the check_X() functions can end up getting called more than once for a given + /// incoming attempt from a given node if the network quality is poor. That's because the caller may + /// have to retransmit init packets causing repetition of parts of the exchange. + /// + /// * `app` - Interface to application using ZSSP + /// * `check_allow_incoming_session` - Function to call to check whether an unidentified new session should be accepted + /// * `check_accept_session` - Function to accept sessions after final negotiation, or returns None if rejected + /// * `send` - Function to call to send packets + /// * `data_buf` - Buffer to receive decrypted and authenticated object data (an error is returned if too small) + /// * `incoming_packet_buf` - Buffer containing incoming wire packet (receive() takes ownership) + /// * `mtu` - Physical wire MTU for sending packets + /// * `current_time` - Current monotonic time in milliseconds + pub fn receive< + 'b, + SendFunction: FnMut(Option<&Arc>>, &mut [u8]), + CheckAllowIncomingSession: FnMut() -> bool, + CheckAcceptSession: FnMut(&[u8]) -> Option<(P384PublicKey, Secret<64>, Application::Data)>, + >( + &self, + app: &Application, + mut check_allow_incoming_session: CheckAllowIncomingSession, + mut check_accept_session: CheckAcceptSession, + mut send: SendFunction, + source: &Application::PhysicalPath, + data_buf: &'b mut [u8], + mut incoming_packet_buf: Application::IncomingPacketBuffer, + mtu: usize, + current_time: i64, + ) -> Result, Error> { + let incoming_packet: &mut [u8] = incoming_packet_buf.as_mut(); + if incoming_packet.len() < MIN_PACKET_SIZE { + return Err(Error::InvalidPacket); + } + + let mut incoming = None; + if let Some(local_session_id) = SessionId::new_from_u64_le(u64::from_le_bytes(incoming_packet[0..8].try_into().unwrap())) { + if let Some(session) = self.sessions.read().unwrap().active.get(&local_session_id).and_then(|s| s.upgrade()) { + debug_assert!(!self.sessions.read().unwrap().incoming.contains_key(&local_session_id)); + + session + .get_header_cipher() + .decrypt_block_in_place(&mut incoming_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); + let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_packet); + + if session.check_receive_window(incoming_counter) { + if fragment_count > 1 { + let mut fragged = session.defrag[(incoming_counter as usize) % COUNTER_WINDOW_MAX_OOO].lock().unwrap(); + if let Some(assembled_packet) = fragged.assemble(incoming_counter, incoming_packet_buf, fragment_no, fragment_count) { + drop(fragged); + return self.process_complete_incoming_packet( + app, + &mut send, + &mut check_allow_incoming_session, + &mut check_accept_session, + data_buf, + incoming_counter, + assembled_packet.as_ref(), + packet_type, + Some(session), + None, + key_index, + mtu, + current_time, + ); + } else { + drop(fragged); + return Ok(ReceiveResult::Ok(Some(session))); + } + } else { + return self.process_complete_incoming_packet( + app, + &mut send, + &mut check_allow_incoming_session, + &mut check_accept_session, + data_buf, + incoming_counter, + &[incoming_packet_buf], + packet_type, + Some(session), + None, + key_index, + mtu, + current_time, + ); + } + } else { + return Err(Error::OutOfSequence); + } + } else { + if let Some(i) = self.sessions.read().unwrap().incoming.get(&local_session_id).cloned() { + Aes::new(&i.header_protection_key) + .decrypt_block_in_place(&mut incoming_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); + incoming = Some(i); + } else { + return Err(Error::UnknownLocalSessionId); + } + } + } + + // If we make it here the packet is not associated with a session or is associated with an + // incoming session (Noise_XK mid-negotiation). + + let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_packet); + if fragment_count > 1 { + let f = { + let mut defrag = self.defrag.lock().unwrap(); + let f = defrag + .entry((source.clone(), incoming_counter)) + .or_insert_with(|| Arc::new((Mutex::new(Fragged::new()), current_time))) + .clone(); + + // Anti-DOS overflow purge of the incoming defragmentation queue for packets not associated with known sessions. + if defrag.len() >= self.max_incomplete_session_queue_size { + // First, drop all entries that are timed out or whose physical source duplicates another entry. + let mut sources = HashSet::with_capacity(defrag.len()); + let negotiation_timeout_cutoff = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS; + defrag.retain(|k, fragged| (fragged.1 > negotiation_timeout_cutoff && sources.insert(k.0.clone())) || Arc::ptr_eq(fragged, &f)); + + // Then, if we are still at or over the limit, drop 10% of remaining entries at random. + if defrag.len() >= self.max_incomplete_session_queue_size { + let mut rn = random::next_u32_secure(); + defrag.retain(|_, fragged| { + rn = prng32(rn); + rn > (u32::MAX / 10) || Arc::ptr_eq(fragged, &f) + }); + } + } + + f + }; + let mut fragged = f.0.lock().unwrap(); + + if let Some(assembled_packet) = fragged.assemble(incoming_counter, incoming_packet_buf, fragment_no, fragment_count) { + self.defrag.lock().unwrap().remove(&(source.clone(), incoming_counter)); + return self.process_complete_incoming_packet( + app, + &mut send, + &mut check_allow_incoming_session, + &mut check_accept_session, + data_buf, + incoming_counter, + assembled_packet.as_ref(), + packet_type, + None, + incoming, + key_index, + mtu, + current_time, + ); + } + } else { + return self.process_complete_incoming_packet( + app, + &mut send, + &mut check_allow_incoming_session, + &mut check_accept_session, + data_buf, + incoming_counter, + &[incoming_packet_buf], + packet_type, + None, + incoming, + key_index, + mtu, + current_time, + ); + } + + return Ok(ReceiveResult::Ok(None)); + } + + fn process_complete_incoming_packet< + 'b, + SendFunction: FnMut(Option<&Arc>>, &mut [u8]), + CheckAllowIncomingSession: FnMut() -> bool, + CheckAcceptSession: FnMut(&[u8]) -> Option<(P384PublicKey, Secret<64>, Application::Data)>, + >( + &self, + app: &Application, + send: &mut SendFunction, + check_allow_incoming_session: &mut CheckAllowIncomingSession, + check_accept_session: &mut CheckAcceptSession, + data_buf: &'b mut [u8], + incoming_counter: u64, + fragments: &[Application::IncomingPacketBuffer], + packet_type: u8, + session: Option>>, + incoming: Option>, + key_index: usize, + mtu: usize, + current_time: i64, + ) -> Result, Error> { + debug_assert!(fragments.len() >= 1); + + // Generate incoming message nonce for decryption and authentication. + let incoming_message_nonce = create_message_nonce(packet_type, incoming_counter); + + if packet_type <= PACKET_TYPE_DATA { + if let Some(session) = session { + let state = session.state.read().unwrap(); + if let Some(key) = state.keys[key_index].as_ref() { + let mut c = key.get_receive_cipher(); + c.reset_init_gcm(&incoming_message_nonce); + + let mut data_len = 0; + + // Decrypt fragments 0..N-1 where N is the number of fragments. + for f in fragments[..(fragments.len() - 1)].iter() { + let f: &[u8] = f.as_ref(); + debug_assert!(f.len() >= HEADER_SIZE); + let current_frag_data_start = data_len; + data_len += f.len() - HEADER_SIZE; + if data_len > data_buf.len() { + drop(c); + return Err(Error::DataBufferTooSmall); + } + c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]); + } + + // Decrypt final fragment (or only fragment if not fragmented) + let current_frag_data_start = data_len; + let last_fragment = fragments.last().unwrap().as_ref(); + if last_fragment.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) { + return Err(Error::InvalidPacket); + } + data_len += last_fragment.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE); + if data_len > data_buf.len() { + drop(c); + return Err(Error::DataBufferTooSmall); + } + let payload_end = last_fragment.len() - AES_GCM_TAG_SIZE; + c.crypt(&last_fragment[HEADER_SIZE..payload_end], &mut data_buf[current_frag_data_start..data_len]); + + let aead_authentication_ok = c.finish_decrypt(&last_fragment[payload_end..]); + drop(c); + + if aead_authentication_ok { + if session.update_receive_window(incoming_counter) { + // Update the current key to point to this key if it's newer, since having received + // a packet encrypted with it proves that the other side has successfully derived it + // as well. + if state.current_key == key_index && key.confirmed { + drop(state); + } else { + let current_key_created_at_counter = key.created_at_counter; + + drop(state); + let mut state = session.state.write().unwrap(); + + if state.current_key != key_index { + if let Some(other_session_key) = state.keys[state.current_key].as_ref() { + if other_session_key.created_at_counter < current_key_created_at_counter { + state.current_key = key_index; + } + } else { + state.current_key = key_index; + } + } + state.keys[key_index].as_mut().unwrap().confirmed = true; + + // If we got a valid data packet from Bob, this means we can cancel any offers + // that are still oustanding for initialization. + match &state.current_offer { + Offer::NoiseXKInit(_) | Offer::NoiseXKAck(_) => { + state.current_offer = Offer::None; + } + _ => {} + } + } + + if packet_type == PACKET_TYPE_DATA { + return Ok(ReceiveResult::OkData(session, &mut data_buf[..data_len])); + } else { + return Ok(ReceiveResult::Ok(Some(session))); + } + } else { + return Err(Error::OutOfSequence); + } + } + } + + return Err(Error::FailedAuthentication); + } else { + return Err(Error::UnknownLocalSessionId); + } + } else { + // For Noise setup/KEX packets go ahead and pre-assemble all fragments to simplify the code below. + let mut pkt_assembly_buffer = [0u8; MAX_NOISE_HANDSHAKE_SIZE]; + let pkt_assembled_size = assemble_fragments_into::(fragments, &mut pkt_assembly_buffer)?; + if pkt_assembled_size < MIN_PACKET_SIZE { + return Err(Error::InvalidPacket); + } + let pkt_assembled = &mut pkt_assembly_buffer[..pkt_assembled_size]; + if pkt_assembled[HEADER_SIZE] != SESSION_PROTOCOL_VERSION { + return Err(Error::UnknownProtocolVersion); + } + + match packet_type { + PACKET_TYPE_ALICE_NOISE_XK_INIT => { + // Alice (remote) --> Bob (local) + + /* + * This is the first message Bob receives from Alice, the initiator. It contains + * Alice's ephemeral keys but not her identity. Alice will not reveal her identity + * until forward secrecy is established and she's authenticated Bob. + * + * Bob authenticates the message and confirms that Alice indeed knows Bob's + * identity, then responds with his ephemeral keys. + */ + + if incoming_counter != 1 || session.is_some() || incoming.is_some() { + return Err(Error::OutOfSequence); + } + if pkt_assembled.len() != AliceNoiseXKInit::SIZE { + return Err(Error::InvalidPacket); + } + + // Otherwise parse the packet, authenticate, generate keys, etc. and record state in an + // incoming state object until this phase of the negotiation is done. + let pkt: &AliceNoiseXKInit = byte_array_as_proto_buffer(pkt_assembled)?; + let alice_noise_e = P384PublicKey::from_bytes(&pkt.alice_noise_e).ok_or(Error::FailedAuthentication)?; + let noise_es = app.get_local_s_keypair().agree(&alice_noise_e).ok_or(Error::FailedAuthentication)?; + + let noise_h = mix_hash(&mix_hash(&INITIAL_H, app.get_local_s_public_blob()), alice_noise_e.as_bytes()); + let noise_h_next = mix_hash(&noise_h, &pkt_assembled[HEADER_SIZE..]); + + // Decrypt and authenticate init packet, also proving that caller knows our static identity. + let mut gcm = AesGcm::new( + &kbkdf::(noise_es.as_bytes()) + ); + gcm.reset_init_gcm(&incoming_message_nonce); + gcm.aad(&noise_h); + gcm.crypt_in_place(&mut pkt_assembled[AliceNoiseXKInit::ENC_START..AliceNoiseXKInit::AUTH_START]); + if !gcm.finish_decrypt(&pkt_assembled[AliceNoiseXKInit::AUTH_START..AliceNoiseXKInit::AUTH_START + AES_GCM_TAG_SIZE]) { + return Err(Error::FailedAuthentication); + } + + // Let application filter incoming connection attempt by whatever criteria it wants. + if !check_allow_incoming_session() { + return Ok(ReceiveResult::Rejected); + } + + let pkt: &AliceNoiseXKInit = byte_array_as_proto_buffer(pkt_assembled)?; + let alice_session_id = SessionId::new_from_bytes(&pkt.alice_session_id).ok_or(Error::InvalidPacket)?; + let header_protection_key = Secret(pkt.header_protection_key); + + // Create Bob's ephemeral keys and derive noise_es_ee by agreeing with Alice's. Also create + // a Kyber ciphertext to send back to Alice. + let bob_noise_e_secret = P384KeyPair::generate(); + let bob_noise_e = bob_noise_e_secret.public_key_bytes().clone(); + let noise_es_ee = Secret(hmac_sha512( + noise_es.as_bytes(), + bob_noise_e_secret.agree(&alice_noise_e).ok_or(Error::FailedAuthentication)?.as_bytes(), + )); + let (bob_hk_ciphertext, hk) = pqc_kyber::encapsulate(&pkt.alice_hk_public, &mut random::SecureRandom::default()) + .map_err(|_| Error::FailedAuthentication) + .map(|(ct, hk)| (ct, Secret(hk)))?; + + let mut sessions = self.sessions.write().unwrap(); + + let mut bob_session_id; + loop { + bob_session_id = SessionId::random(); + if !sessions.active.contains_key(&bob_session_id) && !sessions.incoming.contains_key(&bob_session_id) { + break; + } + } + + // Create Bob's ephemeral counter-offer reply. + let mut ack_packet = [0u8; BobNoiseXKAck::SIZE]; + let ack: &mut BobNoiseXKAck = byte_array_as_proto_buffer_mut(&mut ack_packet)?; + ack.session_protocol_version = SESSION_PROTOCOL_VERSION; + ack.bob_noise_e = bob_noise_e; + ack.bob_session_id = *bob_session_id.as_bytes(); + ack.bob_hk_ciphertext = bob_hk_ciphertext; + + // Encrypt main section of reply and attach tag. + let mut gcm = AesGcm::new( + &kbkdf::(noise_es_ee.as_bytes()) + ); + gcm.reset_init_gcm(&create_message_nonce(PACKET_TYPE_BOB_NOISE_XK_ACK, 1)); + gcm.aad(&noise_h_next); + gcm.crypt_in_place(&mut ack_packet[BobNoiseXKAck::ENC_START..BobNoiseXKAck::AUTH_START]); + ack_packet[BobNoiseXKAck::AUTH_START..BobNoiseXKAck::AUTH_START + AES_GCM_TAG_SIZE].copy_from_slice(&gcm.finish_encrypt()); + + // If this queue is too big, we remove the latest entry and replace it. The latest + // is used because under flood conditions this is most likely to be another bogus + // entry. If we find one that is actually timed out, that one is replaced instead. + if sessions.incoming.len() >= self.max_incomplete_session_queue_size { + let mut newest = i64::MIN; + let mut replace_id = None; + let cutoff_time = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS; + for (id, s) in sessions.incoming.iter() { + if s.timestamp <= cutoff_time { + replace_id = Some(*id); + break; + } else if s.timestamp >= newest { + newest = s.timestamp; + replace_id = Some(*id); + } + } + let _ = sessions.incoming.remove(replace_id.as_ref().unwrap()); + } + + // Reserve session ID on this side and record incomplete session state. + sessions.incoming.insert( + bob_session_id, + Arc::new(BobIncomingIncompleteSessionState { + timestamp: current_time, + alice_session_id, + bob_session_id, + noise_h: mix_hash(&mix_hash(&noise_h_next, &bob_noise_e), &ack_packet[HEADER_SIZE..]), + noise_es_ee: noise_es_ee.clone(), + hk, + bob_noise_e_secret, + header_protection_key: Secret(pkt.header_protection_key), + }), + ); + debug_assert!(!sessions.active.contains_key(&bob_session_id)); + + // Release lock + drop(sessions); + + send_with_fragmentation( + |b| send(None, b), + &mut ack_packet, + mtu, + PACKET_TYPE_BOB_NOISE_XK_ACK, + Some(alice_session_id), + 0, + 1, + Some(&mut Aes::new(&header_protection_key)), + )?; + + return Ok(ReceiveResult::Ok(session)); + } + + PACKET_TYPE_BOB_NOISE_XK_ACK => { + // Bob (remote) --> Alice (local) + + /* + * This is Bob's reply to Alice's first message, allowing Alice to verify Bob's + * identity. Once this is done Alice can send her identity (encrypted) to complete + * the negotiation. + */ + + if incoming_counter != 1 || incoming.is_some() { + return Err(Error::OutOfSequence); + } + if pkt_assembled.len() != BobNoiseXKAck::SIZE { + return Err(Error::InvalidPacket); + } + + if let Some(session) = session { + let state = session.state.read().unwrap(); + + // This doesn't make sense if the session is up. + if state.keys[state.current_key].is_some() { + return Err(Error::OutOfSequence); + } + + if let Offer::NoiseXKInit(outgoing_offer) = &state.current_offer { + let pkt: &BobNoiseXKAck = byte_array_as_proto_buffer(pkt_assembled)?; + + // Derive noise_es_ee from Bob's ephemeral public key. + let bob_noise_e = P384PublicKey::from_bytes(&pkt.bob_noise_e).ok_or(Error::FailedAuthentication)?; + let noise_es_ee = Secret(hmac_sha512( + outgoing_offer.noise_es.as_bytes(), + outgoing_offer + .alice_noise_e_secret + .agree(&bob_noise_e) + .ok_or(Error::FailedAuthentication)? + .as_bytes(), + )); + + // Go ahead and compute the next 'h' state before we lose the ciphertext in decrypt. + let noise_h_next = mix_hash(&mix_hash(&outgoing_offer.noise_h, bob_noise_e.as_bytes()), &pkt_assembled[HEADER_SIZE..]); + + // Decrypt and authenticate Bob's reply. + let mut gcm = AesGcm::new( + &kbkdf::(noise_es_ee.as_bytes()) + ); + gcm.reset_init_gcm(&incoming_message_nonce); + gcm.aad(&outgoing_offer.noise_h); + gcm.crypt_in_place(&mut pkt_assembled[BobNoiseXKAck::ENC_START..BobNoiseXKAck::AUTH_START]); + if !gcm.finish_decrypt(&pkt_assembled[BobNoiseXKAck::AUTH_START..BobNoiseXKAck::AUTH_START + AES_GCM_TAG_SIZE]) { + return Err(Error::FailedAuthentication); + } + + let pkt: &BobNoiseXKAck = byte_array_as_proto_buffer(pkt_assembled)?; + + if let Some(bob_session_id) = SessionId::new_from_bytes(&pkt.bob_session_id) { + // Complete Noise_XKpsk3 by mixing in noise_se followed by the PSK. The PSK as far as + // the Noise pattern is concerned is the result of mixing the externally supplied PSK + // with the Kyber1024 shared secret (hk). Kyber is treated as part of the PSK because + // it's an external add-on beyond the Noise spec. + let hk = pqc_kyber::decapsulate(&pkt.bob_hk_ciphertext, outgoing_offer.alice_hk_secret.as_bytes()) + .map_err(|_| Error::FailedAuthentication) + .map(|k| Secret(k))?; + let noise_es_ee_se_hk_psk = Secret(hmac_sha512( + &hmac_sha512( + noise_es_ee.as_bytes(), + app.get_local_s_keypair() + .agree(&bob_noise_e) + .ok_or(Error::FailedAuthentication)? + .as_bytes(), + ), + &hmac_sha512(session.psk.as_bytes(), hk.as_bytes()), + )); + + let reply_message_nonce = create_message_nonce(PACKET_TYPE_ALICE_NOISE_XK_ACK, 2); + + // Create reply informing Bob of our static identity now that we've verified Bob and set + // up forward secrecy. Also return Bob's opaque note. + let mut reply_buffer = [0u8; MAX_NOISE_HANDSHAKE_SIZE]; + reply_buffer[HEADER_SIZE] = SESSION_PROTOCOL_VERSION; + let mut reply_len = HEADER_SIZE + 1; + + let alice_s_public_blob = app.get_local_s_public_blob(); + assert!(alice_s_public_blob.len() <= (u16::MAX as usize)); + reply_len = append_to_slice(&mut reply_buffer, reply_len, &(alice_s_public_blob.len() as u16).to_le_bytes())?; + let mut enc_start = reply_len; + reply_len = append_to_slice(&mut reply_buffer, reply_len, alice_s_public_blob)?; + + let mut gcm = AesGcm::new( + &kbkdf::(&hmac_sha512( + noise_es_ee.as_bytes(), + hk.as_bytes(), + )) + ); + gcm.reset_init_gcm(&reply_message_nonce); + gcm.aad(&noise_h_next); + gcm.crypt_in_place(&mut reply_buffer[enc_start..reply_len]); + reply_len = append_to_slice(&mut reply_buffer, reply_len, &gcm.finish_encrypt())?; + + let metadata = outgoing_offer.metadata.as_ref().map_or(&[][..0], |md| md.as_slice()); + + assert!(metadata.len() <= (u16::MAX as usize)); + reply_len = append_to_slice(&mut reply_buffer, reply_len, &(metadata.len() as u16).to_le_bytes())?; + + let noise_h_next = mix_hash(&mix_hash(&noise_h_next, &reply_buffer[HEADER_SIZE..reply_len]), session.psk.as_bytes()); + + enc_start = reply_len; + reply_len = append_to_slice(&mut reply_buffer, reply_len, metadata)?; + + let mut gcm = AesGcm::new( + &kbkdf::(noise_es_ee_se_hk_psk.as_bytes()) + ); + gcm.reset_init_gcm(&reply_message_nonce); + gcm.aad(&noise_h_next); + gcm.crypt_in_place(&mut reply_buffer[enc_start..reply_len]); + reply_len = append_to_slice(&mut reply_buffer, reply_len, &gcm.finish_encrypt())?; + + drop(state); + { + let mut state = session.state.write().unwrap(); + let _ = state.remote_session_id.insert(bob_session_id); + let _ = + state.keys[0].insert(SessionKey::new::(noise_es_ee_se_hk_psk, 1, current_time, 2, false, false)); + debug_assert!(state.keys[1].is_none()); + state.current_key = 0; + state.current_offer = Offer::NoiseXKAck(Box::new(OutgoingSessionAck { + last_retry_time: AtomicI64::new(current_time), + ack: reply_buffer, + ack_size: reply_len, + })); + } + + send_with_fragmentation( + |b| send(Some(&session), b), + &mut reply_buffer[..reply_len], + mtu, + PACKET_TYPE_ALICE_NOISE_XK_ACK, + Some(bob_session_id), + 0, + 2, + Some(&mut *session.get_header_cipher()), + )?; + + return Ok(ReceiveResult::Ok(Some(session))); + } else { + return Err(Error::InvalidPacket); + } + } else { + return Err(Error::OutOfSequence); + } + } else { + return Err(Error::UnknownLocalSessionId); + } + } + + PACKET_TYPE_ALICE_NOISE_XK_ACK => { + // Alice (remote) --> Bob (local) + + /* + * After negotiating a keyed session and Alice has had the opportunity to + * verify Bob, this is when Bob gets to learn who Alice is. At this point + * Bob can make a final decision about whether to keep talking to Alice + * and can create an actual session using the state memo-ized in the memo + * that Alice must return. + */ + + if incoming_counter != 2 || session.is_some() { + return Err(Error::OutOfSequence); + } + if pkt_assembled.len() < ALICE_NOISE_XK_ACK_MIN_SIZE { + return Err(Error::InvalidPacket); + } + + if let Some(incoming) = incoming { + let mut r = PktReader(pkt_assembled, HEADER_SIZE + 1); + + let alice_static_public_blob_size = r.read_u16()? as usize; + + let ciphertext_up_to_metadata_size = r.1 + alice_static_public_blob_size + AES_GCM_TAG_SIZE + 2; + if r.0.len() < ciphertext_up_to_metadata_size { + return Err(Error::InvalidPacket); + } + let noise_h_next = mix_hash(&incoming.noise_h, &r.0[HEADER_SIZE..ciphertext_up_to_metadata_size]); + + let alice_static_public_blob = r.read_decrypt_auth( + alice_static_public_blob_size, + kbkdf::(&hmac_sha512( + incoming.noise_es_ee.as_bytes(), + incoming.hk.as_bytes(), + )), + &incoming.noise_h, + &incoming_message_nonce, + )?; + + // Check session acceptance and fish Alice's NIST P-384 static public key out of her static public blob. + let check_result = check_accept_session(alice_static_public_blob); + if check_result.is_none() { + self.sessions.write().unwrap().incoming.remove(&incoming.bob_session_id); + return Ok(ReceiveResult::Rejected); + } + let (alice_noise_s, psk, application_data) = check_result.unwrap(); + + let noise_h_next = mix_hash(&noise_h_next, psk.as_bytes()); + + // Complete Noise_XKpsk3 on Bob's side. + let noise_es_ee_se_hk_psk = Secret(hmac_sha512( + &hmac_sha512( + incoming.noise_es_ee.as_bytes(), + incoming + .bob_noise_e_secret + .agree(&alice_noise_s) + .ok_or(Error::FailedAuthentication)? + .as_bytes(), + ), + &hmac_sha512(psk.as_bytes(), incoming.hk.as_bytes()), + )); + + // Decrypt meta-data and verify the final key in the process. Copy meta-data + // into the temporary data buffer to return. + let alice_meta_data_size = r.read_u16()? as usize; + let alice_meta_data = r.read_decrypt_auth( + alice_meta_data_size, + kbkdf::(noise_es_ee_se_hk_psk.as_bytes()), + &noise_h_next, + &incoming_message_nonce, + )?; + if alice_meta_data.len() > data_buf.len() { + return Err(Error::DataTooLarge); + } + data_buf[..alice_meta_data.len()].copy_from_slice(alice_meta_data); + + let session = Arc::new(Session { + id: incoming.bob_session_id, + application_data, + psk, + send_counter: AtomicU64::new(2), // 1 was already used during negotiation + receive_window: std::array::from_fn(|_| AtomicU64::new(0)), + header_protection_cipher: Mutex::new(Aes::new(&incoming.header_protection_key)), + state: RwLock::new(State { + remote_session_id: Some(incoming.alice_session_id), + keys: [ + Some(SessionKey::new::(noise_es_ee_se_hk_psk, 1, current_time, 2, true, true)), + None, + ], + current_key: 0, + current_offer: Offer::None, + }), + defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())), + }); + + // Promote incoming session to active. + { + let mut sessions = self.sessions.write().unwrap(); + sessions.incoming.remove(&incoming.bob_session_id); + sessions.active.insert(incoming.bob_session_id, Arc::downgrade(&session)); + } + + let _ = session.send_nop(|b| send(Some(&session), b)); + + return Ok(ReceiveResult::OkNewSession( + session, + if alice_meta_data.is_empty() { + None + } else { + Some(&mut data_buf[..alice_meta_data.len()]) + }, + )); + } else { + return Err(Error::UnknownLocalSessionId); + } + } + + PACKET_TYPE_REKEY_INIT => { + if pkt_assembled.len() != RekeyInit::SIZE { + return Err(Error::InvalidPacket); + } + if incoming.is_some() { + return Err(Error::OutOfSequence); + } + + if let Some(session) = session { + let state = session.state.read().unwrap(); + if let Some(remote_session_id) = state.remote_session_id { + if let Some(key) = state.keys[key_index].as_ref() { + // Only the current "Alice" accepts rekeys initiated by the current "Bob." These roles + // flip with each rekey event. + if !key.bob { + let mut c = key.get_receive_cipher(); + c.reset_init_gcm(&incoming_message_nonce); + c.crypt_in_place(&mut pkt_assembled[RekeyInit::ENC_START..RekeyInit::AUTH_START]); + let aead_authentication_ok = c.finish_decrypt(&pkt_assembled[RekeyInit::AUTH_START..]); + drop(c); + + if aead_authentication_ok { + let pkt: &RekeyInit = byte_array_as_proto_buffer(&pkt_assembled).unwrap(); + if let Some(alice_e) = P384PublicKey::from_bytes(&pkt.alice_e) { + let bob_e_secret = P384KeyPair::generate(); + let next_session_key = Secret(hmac_sha512( + key.ratchet_key.as_bytes(), + bob_e_secret.agree(&alice_e).ok_or(Error::FailedAuthentication)?.as_bytes(), + )); + + let mut reply_buf = [0u8; RekeyAck::SIZE]; + let reply: &mut RekeyAck = byte_array_as_proto_buffer_mut(&mut reply_buf).unwrap(); + reply.session_protocol_version = SESSION_PROTOCOL_VERSION; + reply.bob_e = *bob_e_secret.public_key_bytes(); + reply.next_key_fingerprint = SHA384::hash(next_session_key.as_bytes()); + + let counter = session.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); + set_packet_header( + &mut reply_buf, + 1, + 0, + PACKET_TYPE_REKEY_ACK, + u64::from(remote_session_id), + state.current_key, + counter, + ); + + let mut c = key.get_send_cipher(counter)?; + c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_REKEY_ACK, counter)); + c.crypt_in_place(&mut reply_buf[RekeyAck::ENC_START..RekeyAck::AUTH_START]); + reply_buf[RekeyAck::AUTH_START..].copy_from_slice(&c.finish_encrypt()); + drop(c); + + session + .get_header_cipher() + .encrypt_block_in_place(&mut reply_buf[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); + send(Some(&session), &mut reply_buf); + + // The new "Bob" doesn't know yet if Alice has received the new key, so the + // new key is recorded as the "alt" (key_index ^ 1) but the current key is + // not advanced yet. This happens automatically the first time we receive a + // valid packet with the new key. + let next_ratchet_count = key.ratchet_count + 1; + drop(state); + let mut state = session.state.write().unwrap(); + let _ = state.keys[key_index ^ 1].replace(SessionKey::new::( + next_session_key, + next_ratchet_count, + current_time, + counter, + false, + false, + )); + + drop(state); + return Ok(ReceiveResult::Ok(Some(session))); + } + } + return Err(Error::FailedAuthentication); + } + } + } + return Err(Error::OutOfSequence); + } else { + return Err(Error::UnknownLocalSessionId); + } + } + + PACKET_TYPE_REKEY_ACK => { + if pkt_assembled.len() != RekeyAck::SIZE { + return Err(Error::InvalidPacket); + } + if incoming.is_some() { + return Err(Error::OutOfSequence); + } + + if let Some(session) = session { + let state = session.state.read().unwrap(); + if let Offer::RekeyInit(alice_e_secret, _) = &state.current_offer { + if let Some(key) = state.keys[key_index].as_ref() { + // Only the current "Bob" initiates rekeys and expects this ACK. + if key.bob { + let mut c = key.get_receive_cipher(); + c.reset_init_gcm(&incoming_message_nonce); + c.crypt_in_place(&mut pkt_assembled[RekeyAck::ENC_START..RekeyAck::AUTH_START]); + let aead_authentication_ok = c.finish_decrypt(&pkt_assembled[RekeyAck::AUTH_START..]); + drop(c); + + if aead_authentication_ok { + let pkt: &RekeyAck = byte_array_as_proto_buffer(&pkt_assembled).unwrap(); + if let Some(bob_e) = P384PublicKey::from_bytes(&pkt.bob_e) { + let next_session_key = Secret(hmac_sha512( + key.ratchet_key.as_bytes(), + alice_e_secret.agree(&bob_e).ok_or(Error::FailedAuthentication)?.as_bytes(), + )); + + if secure_eq(&pkt.next_key_fingerprint, &SHA384::hash(next_session_key.as_bytes())) { + // The new "Alice" knows Bob has the key since this is an ACK, so she can go + // ahead and set current_key to the new key. Then when she sends something + // to Bob the other side will automatically advance to the new key as well. + let next_ratchet_count = key.ratchet_count + 1; + drop(state); + let next_key_index = key_index ^ 1; + let mut state = session.state.write().unwrap(); + let _ = state.keys[next_key_index].replace(SessionKey::new::( + next_session_key, + next_ratchet_count, + current_time, + session.send_counter.load(Ordering::Acquire), + true, + true, + )); + state.current_key = next_key_index; // this is an ACK so it's confirmed + state.current_offer = Offer::None; + + drop(state); + return Ok(ReceiveResult::Ok(Some(session))); + } + } + } + return Err(Error::FailedAuthentication); + } + } + } + return Err(Error::OutOfSequence); + } else { + return Err(Error::UnknownLocalSessionId); + } + } + + _ => { + return Err(Error::InvalidPacket); + } + } + } + } +} + +impl Session { + /// Send data over the session. + /// + /// * `send` - Function to call to send physical packet(s) + /// * `mtu_sized_buffer` - A writable work buffer whose size also specifies the physical MTU + /// * `data` - Data to send + #[inline] + pub fn send(&self, mut send: SendFunction, mtu_sized_buffer: &mut [u8], mut data: &[u8]) -> Result<(), Error> { + debug_assert!(mtu_sized_buffer.len() >= MIN_TRANSPORT_MTU); + let state = self.state.read().unwrap(); + if let Some(remote_session_id) = state.remote_session_id { + if let Some(session_key) = state.keys[state.current_key].as_ref() { + let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); + + let mut c = session_key.get_send_cipher(counter)?; + c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_DATA, counter)); + + let fragment_count = (((data.len() + AES_GCM_TAG_SIZE) as f32) / (mtu_sized_buffer.len() - HEADER_SIZE) as f32).ceil() as usize; + let fragment_max_chunk_size = mtu_sized_buffer.len() - HEADER_SIZE; + let last_fragment_no = fragment_count - 1; + + for fragment_no in 0..fragment_count { + let chunk_size = fragment_max_chunk_size.min(data.len()); + let mut fragment_size = chunk_size + HEADER_SIZE; + + set_packet_header( + mtu_sized_buffer, + fragment_count, + fragment_no, + PACKET_TYPE_DATA, + u64::from(remote_session_id), + state.current_key, + counter, + ); + + c.crypt(&data[..chunk_size], &mut mtu_sized_buffer[HEADER_SIZE..fragment_size]); + data = &data[chunk_size..]; + + if fragment_no == last_fragment_no { + debug_assert!(data.is_empty()); + let tagged_fragment_size = fragment_size + AES_GCM_TAG_SIZE; + mtu_sized_buffer[fragment_size..tagged_fragment_size].copy_from_slice(&c.finish_encrypt()); + fragment_size = tagged_fragment_size; + } + + self.get_header_cipher() + .encrypt_block_in_place(&mut mtu_sized_buffer[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); + send(&mut mtu_sized_buffer[..fragment_size]); + } + debug_assert!(data.is_empty()); + + drop(c); + + return Ok(()); + } + } + return Err(Error::SessionNotEstablished); + } + + /// Send a NOP to the other side (e.g. for keep alive). + pub fn send_nop(&self, mut send: SendFunction) -> Result<(), Error> { + let state = self.state.read().unwrap(); + if let Some(remote_session_id) = state.remote_session_id { + if let Some(session_key) = state.keys[state.current_key].as_ref() { + let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); + let mut nop = [0u8; HEADER_SIZE + AES_GCM_TAG_SIZE]; + let mut c = session_key.get_send_cipher(counter)?; + c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_NOP, counter)); + nop[HEADER_SIZE..].copy_from_slice(&c.finish_encrypt()); + drop(c); + set_packet_header(&mut nop, 1, 0, PACKET_TYPE_NOP, u64::from(remote_session_id), state.current_key, counter); + self.get_header_cipher() + .encrypt_block_in_place(&mut nop[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); + send(&mut nop); + } + } + return Err(Error::SessionNotEstablished); + } + + /// Check whether this session is established. + pub fn established(&self) -> bool { + let state = self.state.read().unwrap(); + state.keys[state.current_key].as_ref().map_or(false, |k| k.confirmed) + } + + /// Get the ratchet count and a hash fingerprint of the current active key. + pub fn key_info(&self) -> Option<(u64, [u8; 48])> { + let state = self.state.read().unwrap(); + if let Some(key) = state.keys[state.current_key].as_ref() { + Some((key.ratchet_count, SHA384::hash(key.ratchet_key.as_bytes()))) + } else { + None + } + } + + /// Send a rekey init message. + /// + /// This is called from the session context's service() method when it's time to rekey. + /// It should only be called when the current key was established in the 'bob' role. This + /// is checked when rekey time is checked. + fn initiate_rekey(&self, mut send: SendFunction, current_time: i64) { + let rekey_e = P384KeyPair::generate(); + + let mut rekey_buf = [0u8; RekeyInit::SIZE]; + let pkt: &mut RekeyInit = byte_array_as_proto_buffer_mut(&mut rekey_buf).unwrap(); + pkt.session_protocol_version = SESSION_PROTOCOL_VERSION; + pkt.alice_e = *rekey_e.public_key_bytes(); + + let state = self.state.read().unwrap(); + if let Some(remote_session_id) = state.remote_session_id { + if let Some(key) = state.keys[state.current_key].as_ref() { + if let Some(counter) = self.get_next_outgoing_counter() { + if let Ok(mut gcm) = key.get_send_cipher(counter.get()) { + gcm.reset_init_gcm(&create_message_nonce(PACKET_TYPE_REKEY_INIT, counter.get())); + gcm.crypt_in_place(&mut rekey_buf[RekeyInit::ENC_START..RekeyInit::AUTH_START]); + rekey_buf[RekeyInit::AUTH_START..].copy_from_slice(&gcm.finish_encrypt()); + drop(gcm); + + debug_assert!(rekey_buf.len() <= MIN_TRANSPORT_MTU); + set_packet_header( + &mut rekey_buf, + 1, + 0, + PACKET_TYPE_REKEY_INIT, + u64::from(remote_session_id), + state.current_key, + counter.get(), + ); + + //drop(key); + //drop(gcm); + //drop(state); + + self.get_header_cipher() + .encrypt_block_in_place(&mut rekey_buf[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); + send(&mut rekey_buf); + + self.state.write().unwrap().current_offer = Offer::RekeyInit(rekey_e, current_time); + } + } + } + } + } + + /// Get the next outgoing counter value. + #[inline(always)] + fn get_next_outgoing_counter(&self) -> Option { + NonZeroU64::new(self.send_counter.fetch_add(1, Ordering::SeqCst)) + } + + /// Check the receive window without mutating state. + #[inline(always)] + fn check_receive_window(&self, counter: u64) -> bool { + let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].load(Ordering::Acquire); + prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD + } + + /// Update the receive window, returning true if the packet is still valid. + /// This should only be called after the packet is authenticated. + #[inline(always)] + fn update_receive_window(&self, counter: u64) -> bool { + let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].fetch_max(counter, Ordering::AcqRel); + prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD + } + + #[inline(always)] + fn get_header_cipher<'a>(&'a self) -> MutexGuard<'a, Aes>{ + self.header_protection_cipher.lock().unwrap() + } +} + +#[inline(always)] +fn set_packet_header( + packet: &mut [u8], + fragment_count: usize, + fragment_no: usize, + packet_type: u8, + remote_session_id: u64, + key_index: usize, + counter: u64, +) { + debug_assert!(packet.len() >= MIN_PACKET_SIZE); + debug_assert!(fragment_count > 0); + debug_assert!(fragment_count <= MAX_FRAGMENTS); + debug_assert!(fragment_no < MAX_FRAGMENTS); + debug_assert!(packet_type <= 0x0f); // packet type is 4 bits + + // [0-47] recipient session ID + // -- start of header check cipher single block encrypt -- + // [48-48] key index (least significant bit) + // [49-51] packet type (0-15) + // [52-57] fragment count (1..64 - 1, so 0 means 1 fragment) + // [58-63] fragment number (0..63) + // [64-127] 64-bit counter + assert!(packet.len() >= 16); + packet[0..8].copy_from_slice( + &(remote_session_id + | ((key_index & 1) as u64).wrapping_shl(48) + | (packet_type as u64).wrapping_shl(49) + | ((fragment_count - 1) as u64).wrapping_shl(52) + | (fragment_no as u64).wrapping_shl(58)) + .to_le_bytes(), + ); + packet[8..16].copy_from_slice(&counter.to_le_bytes()); +} + +#[inline(always)] +fn parse_packet_header(incoming_packet: &[u8]) -> (usize, u8, u8, u8, u64) { + let raw_header_a = u16::from_le_bytes(incoming_packet[6..8].try_into().unwrap()); + ( + (raw_header_a & 1) as usize, + (raw_header_a.wrapping_shr(1) & 7) as u8, + ((raw_header_a.wrapping_shr(4) & 63) + 1) as u8, + raw_header_a.wrapping_shr(10) as u8, + u64::from_le_bytes(incoming_packet[8..16].try_into().unwrap()), + ) +} + +/// Break a packet into fragments and send them all. +/// +/// The contents of packet[] are mangled during this operation, so it should be discarded after. +/// This is only used for key exchange and control packets. For data packets this is done inline +/// for better performance with encryption and fragmentation happening at the same time. +fn send_with_fragmentation( + mut send: SendFunction, + packet: &mut [u8], + mtu: usize, + packet_type: u8, + remote_session_id: Option, + key_index: usize, + counter: u64, + mut header_protect_cipher: Option<&mut Aes>, +) -> Result<(), Error> { + let packet_len = packet.len(); + let recipient_session_id = remote_session_id.map_or(SessionId::NONE, |s| u64::from(s)); + let fragment_count = ((packet_len as f32) / (mtu as f32)).ceil() as usize; + let mut fragment_start = 0; + let mut fragment_end = packet_len.min(mtu); + for fragment_no in 0..fragment_count { + let fragment = &mut packet[fragment_start..fragment_end]; + set_packet_header( + fragment, + fragment_count, + fragment_no, + packet_type, + recipient_session_id, + key_index, + counter, + ); + if let Some(hcc) = header_protect_cipher.take() { + hcc.encrypt_block_in_place(&mut fragment[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); + } + send(fragment); + fragment_start = fragment_end - HEADER_SIZE; + fragment_end = (fragment_start + mtu).min(packet_len); + } + Ok(()) +} + +/// Assemble a series of fragments into a buffer and return the length of the assembled packet in bytes. +/// +/// This is also only used for key exchange and control packets. For data packets decryption and assembly +/// happen in one pass for better performance. +fn assemble_fragments_into(fragments: &[A::IncomingPacketBuffer], d: &mut [u8]) -> Result { + let mut l = 0; + for i in 0..fragments.len() { + let mut ff = fragments[i].as_ref(); + if ff.len() <= MIN_PACKET_SIZE { + return Err(Error::InvalidPacket); + } + if i > 0 { + ff = &ff[HEADER_SIZE..]; + } + let j = l + ff.len(); + if j > d.len() { + return Err(Error::InvalidPacket); + } + d[l..j].copy_from_slice(ff); + l = j; + } + return Ok(l); +} + +impl SessionKey { + fn new( + key: Secret, + ratchet_count: u64, + current_time: i64, + current_counter: u64, + bob: bool, + confirmed: bool, + ) -> Self { + let a2b = kbkdf::(key.as_bytes()); + let b2a = kbkdf::(key.as_bytes()); + let (receive_key, send_key) = if bob { + (a2b, b2a) + } else { + (b2a, a2b) + }; + let receive_cipher_pool = std::array::from_fn(|_| Mutex::new(AesGcm::new(&receive_key))); + let send_cipher_pool = std::array::from_fn(|_| Mutex::new(AesGcm::new(&send_key))); + Self { + ratchet_key: kbkdf::(key.as_bytes()), + //receive_key, + //send_key, + receive_cipher_pool, + send_cipher_pool, + rekey_at_time: current_time + .checked_add( + Application::REKEY_AFTER_TIME_MS + ((random::xorshift64_random() as u32) % Application::REKEY_AFTER_TIME_MS_MAX_JITTER) as i64, + ) + .unwrap(), + created_at_counter: current_counter, + rekey_at_counter: current_counter.checked_add(Application::REKEY_AFTER_USES).unwrap(), + expire_at_counter: current_counter.checked_add(Application::EXPIRE_AFTER_USES).unwrap(), + ratchet_count, + bob, + confirmed, + } + } + + fn get_send_cipher<'a>(&'a self, counter: u64) -> Result>, Error> { + if counter < self.expire_at_counter { + for mutex in &self.send_cipher_pool { + if let Ok(guard) = mutex.try_lock() { + return Ok(guard) + } + } + Ok(self.send_cipher_pool[0].lock().unwrap()) + } else { + Err(Error::MaxKeyLifetimeExceeded) + } + } + + fn get_receive_cipher<'a>(&'a self) -> MutexGuard<'a, AesGcm> { + for mutex in &self.receive_cipher_pool { + if let Ok(guard) = mutex.try_lock() { + return guard + } + } + self.receive_cipher_pool[0].lock().unwrap() + } +} + +/// Helper code for parsing variable length ALICE_NOISE_XK_ACK during negotiation. +struct PktReader<'a>(&'a mut [u8], usize); + +impl<'a> PktReader<'a> { + fn read_u16(&mut self) -> Result { + let tmp = self.1 + 2; + if tmp <= self.0.len() { + let n = u16::from_le_bytes(self.0[self.1..tmp].try_into().unwrap()); + self.1 = tmp; + Ok(n) + } else { + Err(Error::InvalidPacket) + } + } + + fn read_decrypt_auth<'b>(&'b mut self, l: usize, k: Secret, gcm_aad: &[u8], nonce: &[u8]) -> Result<&'b [u8], Error> { + let mut tmp = self.1 + l; + if (tmp + AES_GCM_TAG_SIZE) <= self.0.len() { + let mut gcm = AesGcm::new(&k); + gcm.reset_init_gcm(nonce); + gcm.aad(gcm_aad); + gcm.crypt_in_place(&mut self.0[self.1..tmp]); + let s = &self.0[self.1..tmp]; + self.1 = tmp; + tmp += AES_GCM_TAG_SIZE; + if !gcm.finish_decrypt(&self.0[self.1..tmp]) { + Err(Error::FailedAuthentication) + } else { + self.1 = tmp; + Ok(s) + } + } else { + Err(Error::InvalidPacket) + } + } +} + +/// Helper function to append to a slice when we still want to be able to look back at it. +fn append_to_slice(s: &mut [u8], p: usize, d: &[u8]) -> Result { + let tmp = p + d.len(); + if tmp <= s.len() { + s[p..tmp].copy_from_slice(d); + Ok(tmp) + } else { + Err(Error::UnexpectedBufferOverrun) + } +} + +/// MixHash to update 'h' during negotiation. +fn mix_hash(h: &[u8; SHA384_HASH_SIZE], m: &[u8]) -> [u8; SHA384_HASH_SIZE] { + let mut hasher = SHA384::new(); + hasher.update(h); + hasher.update(m); + hasher.finish() +} + +/// HMAC-SHA512 key derivation based on: https://csrc.nist.gov/publications/detail/sp/800-108/final (page 7) +/// Cryptographically this isn't meaningfully different from HMAC(key, [label]) but this is how NIST rolls. +fn kbkdf(key: &[u8]) -> Secret { + //These are the values we have assigned to the 5 variables involved in https://csrc.nist.gov/publications/detail/sp/800-108/final: + // K_in = key, i = 0x01, Label = 'Z'||'T'||LABEL, Context = 0x00, L = (OUTPUT_BYTES * 8) + hmac_sha512_secretZ( + key, + &[ + 1, + b'Z', + b'T', + LABEL, + 0x00, + 0, + (((OUTPUT_BYTES * 8) >> 8) & 0xff) as u8, + ((OUTPUT_BYTES * 8) & 0xff) as u8, + ], + ) +} + +fn prng32(mut x: u32) -> u32 { + // based on lowbias32 from https://nullprogram.com/blog/2018/07/31/ + x = x.wrapping_add(1); // don't get stuck on 0 + x ^= x.wrapping_shr(16); + x = x.wrapping_mul(0x7feb352d); + x ^= x.wrapping_shr(15); + x = x.wrapping_mul(0x846ca68b); + x ^= x.wrapping_shr(16); + x +} diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index 0b0905a0c..e4f6815d8 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -12,7 +12,7 @@ use std::collections::{HashMap, HashSet}; use std::num::NonZeroU64; use std::sync::atomic::{AtomicI64, AtomicU64, Ordering}; -use std::sync::{Arc, Mutex, RwLock, Weak}; +use std::sync::{Arc, Mutex, RwLock, Weak, MutexGuard}; use zerotier_crypto::aes::{Aes, AesGcm}; use zerotier_crypto::hash::{hmac_sha512, SHA384, SHA384_HASH_SIZE}; @@ -83,7 +83,7 @@ pub struct Session { psk: Secret, send_counter: AtomicU64, receive_window: [AtomicU64; COUNTER_WINDOW_MAX_OOO], - header_protection_cipher: Aes, + header_protection_cipher: Mutex, state: RwLock, defrag: [Mutex>; COUNTER_WINDOW_MAX_OOO], } @@ -216,7 +216,7 @@ impl Context { state.remote_session_id, 0, 2, - Some(&session.header_protection_cipher), + Some(&mut *session.get_header_cipher()), ); } false @@ -314,7 +314,7 @@ impl Context { psk, send_counter: AtomicU64::new(3), // 1 and 2 are reserved for init and final ack receive_window: std::array::from_fn(|_| AtomicU64::new(0)), - header_protection_cipher: Aes::new(&header_protection_key), + header_protection_cipher: Mutex::new(Aes::new(&header_protection_key)), state: RwLock::new(State { remote_session_id: None, keys: [None, None], @@ -357,7 +357,7 @@ impl Context { } // Encrypt and add authentication tag. - let gcm = AesGcm::new( + let mut gcm = AesGcm::new( &kbkdf::(noise_es.as_bytes()) ); gcm.reset_init_gcm(&create_message_nonce(PACKET_TYPE_ALICE_NOISE_XK_INIT, 1)); @@ -443,7 +443,7 @@ impl Context { debug_assert!(!self.sessions.read().unwrap().incoming.contains_key(&local_session_id)); session - .header_protection_cipher + .get_header_cipher() .decrypt_block_in_place(&mut incoming_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_packet); @@ -604,7 +604,7 @@ impl Context { if let Some(session) = session { let state = session.state.read().unwrap(); if let Some(key) = state.keys[key_index].as_ref() { - let c = key.get_receive_cipher(); + let mut c = key.get_receive_cipher(); c.reset_init_gcm(&incoming_message_nonce); let mut data_len = 0; @@ -730,7 +730,7 @@ impl Context { let noise_h_next = mix_hash(&noise_h, &pkt_assembled[HEADER_SIZE..]); // Decrypt and authenticate init packet, also proving that caller knows our static identity. - let gcm = AesGcm::new( + let mut gcm = AesGcm::new( &kbkdf::(noise_es.as_bytes()) ); gcm.reset_init_gcm(&incoming_message_nonce); @@ -780,7 +780,7 @@ impl Context { ack.bob_hk_ciphertext = bob_hk_ciphertext; // Encrypt main section of reply and attach tag. - let gcm = AesGcm::new( + let mut gcm = AesGcm::new( &kbkdf::(noise_es_ee.as_bytes()) ); gcm.reset_init_gcm(&create_message_nonce(PACKET_TYPE_BOB_NOISE_XK_ACK, 1)); @@ -834,7 +834,7 @@ impl Context { Some(alice_session_id), 0, 1, - Some(&Aes::new(&header_protection_key)), + Some(&mut Aes::new(&header_protection_key)), )?; return Ok(ReceiveResult::Ok(session)); @@ -882,7 +882,7 @@ impl Context { let noise_h_next = mix_hash(&mix_hash(&outgoing_offer.noise_h, bob_noise_e.as_bytes()), &pkt_assembled[HEADER_SIZE..]); // Decrypt and authenticate Bob's reply. - let gcm = AesGcm::new( + let mut gcm = AesGcm::new( &kbkdf::(noise_es_ee.as_bytes()) ); gcm.reset_init_gcm(&incoming_message_nonce); @@ -927,7 +927,7 @@ impl Context { let mut enc_start = reply_len; reply_len = append_to_slice(&mut reply_buffer, reply_len, alice_s_public_blob)?; - let gcm = AesGcm::new( + let mut gcm = AesGcm::new( &kbkdf::(&hmac_sha512( noise_es_ee.as_bytes(), hk.as_bytes(), @@ -948,7 +948,7 @@ impl Context { enc_start = reply_len; reply_len = append_to_slice(&mut reply_buffer, reply_len, metadata)?; - let gcm = AesGcm::new( + let mut gcm = AesGcm::new( &kbkdf::(noise_es_ee_se_hk_psk.as_bytes()) ); gcm.reset_init_gcm(&reply_message_nonce); @@ -979,7 +979,7 @@ impl Context { Some(bob_session_id), 0, 2, - Some(&session.header_protection_cipher), + Some(&mut *session.get_header_cipher()), )?; return Ok(ReceiveResult::Ok(Some(session))); @@ -1076,7 +1076,7 @@ impl Context { psk, send_counter: AtomicU64::new(2), // 1 was already used during negotiation receive_window: std::array::from_fn(|_| AtomicU64::new(0)), - header_protection_cipher: Aes::new(&incoming.header_protection_key), + header_protection_cipher: Mutex::new(Aes::new(&incoming.header_protection_key)), state: RwLock::new(State { remote_session_id: Some(incoming.alice_session_id), keys: [ @@ -1126,7 +1126,7 @@ impl Context { // Only the current "Alice" accepts rekeys initiated by the current "Bob." These roles // flip with each rekey event. if !key.bob { - let c = key.get_receive_cipher(); + let mut c = key.get_receive_cipher(); c.reset_init_gcm(&incoming_message_nonce); c.crypt_in_place(&mut pkt_assembled[RekeyInit::ENC_START..RekeyInit::AUTH_START]); let aead_authentication_ok = c.finish_decrypt(&pkt_assembled[RekeyInit::AUTH_START..]); @@ -1158,14 +1158,14 @@ impl Context { counter, ); - let c = key.get_send_cipher(counter)?; + let mut c = key.get_send_cipher(counter)?; c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_REKEY_ACK, counter)); c.crypt_in_place(&mut reply_buf[RekeyAck::ENC_START..RekeyAck::AUTH_START]); reply_buf[RekeyAck::AUTH_START..].copy_from_slice(&c.finish_encrypt()); key.return_send_cipher(c); session - .header_protection_cipher + .get_header_cipher() .encrypt_block_in_place(&mut reply_buf[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); send(Some(&session), &mut reply_buf); @@ -1213,7 +1213,7 @@ impl Context { if let Some(key) = state.keys[key_index].as_ref() { // Only the current "Bob" initiates rekeys and expects this ACK. if key.bob { - let c = key.get_receive_cipher(); + let mut c = key.get_receive_cipher(); c.reset_init_gcm(&incoming_message_nonce); c.crypt_in_place(&mut pkt_assembled[RekeyAck::ENC_START..RekeyAck::AUTH_START]); let aead_authentication_ok = c.finish_decrypt(&pkt_assembled[RekeyAck::AUTH_START..]); @@ -1283,7 +1283,7 @@ impl Session { if let Some(session_key) = state.keys[state.current_key].as_ref() { let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); - let c = session_key.get_send_cipher(counter)?; + let mut c = session_key.get_send_cipher(counter)?; c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_DATA, counter)); let fragment_count = (((data.len() + AES_GCM_TAG_SIZE) as f32) / (mtu_sized_buffer.len() - HEADER_SIZE) as f32).ceil() as usize; @@ -1314,7 +1314,7 @@ impl Session { fragment_size = tagged_fragment_size; } - self.header_protection_cipher + self.get_header_cipher() .encrypt_block_in_place(&mut mtu_sized_buffer[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); send(&mut mtu_sized_buffer[..fragment_size]); } @@ -1335,12 +1335,12 @@ impl Session { if let Some(session_key) = state.keys[state.current_key].as_ref() { let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); let mut nop = [0u8; HEADER_SIZE + AES_GCM_TAG_SIZE]; - let c = session_key.get_send_cipher(counter)?; + let mut c = session_key.get_send_cipher(counter)?; c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_NOP, counter)); nop[HEADER_SIZE..].copy_from_slice(&c.finish_encrypt()); session_key.return_send_cipher(c); set_packet_header(&mut nop, 1, 0, PACKET_TYPE_NOP, u64::from(remote_session_id), state.current_key, counter); - self.header_protection_cipher + self.get_header_cipher() .encrypt_block_in_place(&mut nop[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); send(&mut nop); } @@ -1381,7 +1381,7 @@ impl Session { if let Some(remote_session_id) = state.remote_session_id { if let Some(key) = state.keys[state.current_key].as_ref() { if let Some(counter) = self.get_next_outgoing_counter() { - if let Ok(gcm) = key.get_send_cipher(counter.get()) { + if let Ok(mut gcm) = key.get_send_cipher(counter.get()) { gcm.reset_init_gcm(&create_message_nonce(PACKET_TYPE_REKEY_INIT, counter.get())); gcm.crypt_in_place(&mut rekey_buf[RekeyInit::ENC_START..RekeyInit::AUTH_START]); rekey_buf[RekeyInit::AUTH_START..].copy_from_slice(&gcm.finish_encrypt()); @@ -1400,7 +1400,7 @@ impl Session { drop(state); - self.header_protection_cipher + self.get_header_cipher() .encrypt_block_in_place(&mut rekey_buf[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); send(&mut rekey_buf); @@ -1431,6 +1431,12 @@ impl Session { let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].fetch_max(counter, Ordering::AcqRel); prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD } + + + #[inline(always)] + fn get_header_cipher<'a>(&'a self) -> MutexGuard<'a, Aes>{ + self.header_protection_cipher.lock().unwrap() + } } #[inline(always)] @@ -1493,7 +1499,7 @@ fn send_with_fragmentation( remote_session_id: Option, key_index: usize, counter: u64, - header_protect_cipher: Option<&Aes>, + mut header_protect_cipher: Option<&mut Aes>, ) -> Result<(), Error> { let packet_len = packet.len(); let recipient_session_id = remote_session_id.map_or(SessionId::NONE, |s| u64::from(s)); @@ -1511,7 +1517,7 @@ fn send_with_fragmentation( key_index, counter, ); - if let Some(hcc) = header_protect_cipher { + if let Some(hcc) = header_protect_cipher.take() { hcc.encrypt_block_in_place(&mut fragment[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); } send(fragment); @@ -1634,7 +1640,7 @@ impl<'a> PktReader<'a> { fn read_decrypt_auth<'b>(&'b mut self, l: usize, k: Secret, gcm_aad: &[u8], nonce: &[u8]) -> Result<&'b [u8], Error> { let mut tmp = self.1 + l; if (tmp + AES_GCM_TAG_SIZE) <= self.0.len() { - let gcm = AesGcm::new(&k); + let mut gcm = AesGcm::new(&k); gcm.reset_init_gcm(nonce); gcm.aad(gcm_aad); gcm.crypt_in_place(&mut self.0[self.1..tmp]);