ZeroTierOne/zssp/src/zssp.rs
2023-03-10 07:54:51 -05:00

1697 lines
84 KiB
Rust

/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at https://mozilla.org/MPL/2.0/.
*
* (c) ZeroTier, Inc.
* https://www.zerotier.com/
*/
// ZSSP: ZeroTier Secure Session Protocol
// FIPS compliant Noise_XK with Jedi powers (Kyber1024) and built-in attack-resistant large payload (fragmentation) support.
use std::collections::{HashMap, HashSet};
use std::num::NonZeroU64;
use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, MutexGuard, RwLock, Weak};
use zerotier_crypto::aes::{Aes, AesGcm};
use zerotier_crypto::hash::{hmac_sha512_secret, SHA384, SHA384_HASH_SIZE};
use zerotier_crypto::p384::{P384KeyPair, P384PublicKey, P384_ECDH_SHARED_SECRET_SIZE};
use zerotier_crypto::secret::Secret;
use zerotier_crypto::{random, secure_eq};
use pqc_kyber::{KYBER_SECRETKEYBYTES, KYBER_SSBYTES};
use crate::applicationlayer::ApplicationLayer;
use crate::error::Error;
use crate::fragged::Fragged;
use crate::proto::*;
use crate::sessionid::SessionId;
/// Session context for local application.
///
/// Each application using ZSSP must create an instance of this to own sessions and
/// defragment incoming packets that are not yet associated with a session.
pub struct Context<Application: ApplicationLayer> {
    /// Size limit that triggers eviction from both the context-level defragmentation queue and
    /// the table of incoming (Bob side) incomplete sessions.
    max_incomplete_session_queue_size: usize,
    /// Reassembly state for fragmented packets that are not associated with an active session
    /// (no session at all, or a session still mid-negotiation on this side), keyed by
    /// (physical source path, packet counter).
    defrag: Mutex<
        HashMap<
            (Application::PhysicalPath, u64),
            Arc<(
                Mutex<Fragged<Application::IncomingPacketBuffer, MAX_NOISE_HANDSHAKE_FRAGMENTS>>,
                i64, // creation timestamp
            )>,
        >,
    >,
    /// Active and mid-negotiation sessions, indexed by local session ID.
    sessions: RwLock<SessionsById<Application>>,
}
/// Lookup maps for sessions within a session context.
///
/// A local session ID is reserved in at most one of these maps at a time; new IDs are chosen so
/// that they collide with neither map.
struct SessionsById<Application: ApplicationLayer> {
    /// Active sessions, automatically closed if the application no longer holds their Arc<>.
    /// Entries whose Weak can no longer be upgraded are removed by `Context::service()`.
    active: HashMap<SessionId, Weak<Session<Application>>>,
    /// Incomplete sessions in the middle of three-phase Noise_XK negotiation (this side is Bob),
    /// expired after timeout by `Context::service()` and evicted when the queue is full.
    incoming: HashMap<SessionId, Arc<BobIncomingIncompleteSessionState>>,
}
/// Result generated by the context packet receive function, with possible payloads.
///
/// Borrowed payloads point into the `data_buf` supplied to `Context::receive()`.
pub enum ReceiveResult<'b, Application: ApplicationLayer> {
    /// Packet was valid, but no action needs to be taken and no payload was delivered.
    /// Carries the session the packet was associated with, if any (e.g. None while a
    /// handshake is still in progress or a fragment is still awaiting reassembly).
    Ok(Option<Arc<Session<Application>>>),
    /// Packet was valid and a data payload was decoded and authenticated.
    /// The slice is the decrypted prefix of `data_buf`.
    OkData(Arc<Session<Application>>, &'b mut [u8]),
    /// Packet was valid and a new session was created, with optional attached meta-data.
    OkNewSession(Arc<Session<Application>>, Option<&'b mut [u8]>),
    /// Packet appears valid but was rejected by the application layer, e.g. a rejected new session attempt.
    Rejected,
}
/// ZeroTier Secure Session Protocol (ZSSP) Session
///
/// A FIPS/NIST compliant variant of Noise_XK with hybrid Kyber1024 PQ data forward secrecy.
pub struct Session<Application: ApplicationLayer> {
    /// This side's locally unique session ID
    pub id: SessionId,
    /// An arbitrary application defined object associated with each session
    pub application_data: Application::Data,
    /// Pre-shared key (all zero if none); mixed with the Kyber shared secret during the handshake.
    psk: Secret<BASE_KEY_SIZE>,
    /// Next outgoing packet counter. For sessions opened locally it starts at 3 because counters
    /// 1 and 2 are used by the init and final ACK handshake packets.
    send_counter: AtomicU64,
    /// Receive counter window consulted by check_receive_window() / update_receive_window()
    /// (NOTE(review): per-slot semantics are defined in those methods, not visible here).
    receive_window: [AtomicU64; COUNTER_WINDOW_MAX_OOO],
    /// AES cipher used to protect/unprotect the header block of packets on this session.
    header_protection_cipher: Aes,
    /// Mutable session state (keys, remote ID, outstanding offer).
    state: RwLock<State>,
    /// Per-slot fragment reassembly for session packets, indexed by counter % COUNTER_WINDOW_MAX_OOO.
    defrag: [Mutex<Fragged<Application::IncomingPacketBuffer, MAX_FRAGMENTS>>; COUNTER_WINDOW_MAX_OOO],
}
/// Most of the mutable parts of a session state.
struct State {
    /// The other side's session ID, used to address outgoing packets; None until it is known.
    remote_session_id: Option<SessionId>,
    /// Two key slots; the key index carried in each packet header selects which one decrypts it.
    keys: [Option<SessionKey>; 2],
    /// Index into `keys` of the key currently considered current.
    current_key: usize,
    /// Outstanding handshake or rekey offer, if any.
    current_offer: Offer,
}
/// State recorded by Bob after answering Alice's Noise_XK init, kept in `SessionsById::incoming`
/// until negotiation completes or times out.
struct BobIncomingIncompleteSessionState {
    /// Time (ms) this state was created; compared against the negotiation timeout.
    timestamp: i64,
    /// Alice's session ID, taken from her init packet.
    alice_session_id: SessionId,
    /// Session ID Bob reserved locally for this negotiation.
    bob_session_id: SessionId,
    /// Running Noise handshake hash after mixing in Bob's ephemeral key and ACK ciphertext.
    noise_h: [u8; SHA384_HASH_SIZE],
    /// Chaining key derived from the es and ee agreements.
    noise_es_ee: Secret<BASE_KEY_SIZE>,
    /// Kyber1024 shared secret produced by encapsulating to Alice's Kyber public key.
    hk: Secret<KYBER_SSBYTES>,
    /// Header protection key chosen by Alice and sent in her init packet.
    header_protection_key: Secret<AES_HEADER_PROTECTION_KEY_SIZE>,
    /// Bob's ephemeral P-384 key pair for this negotiation.
    bob_noise_e_secret: P384KeyPair,
}
/// State kept by Alice (the initiator) between sending her Noise_XK init and receiving Bob's ACK.
struct AliceOutgoingIncompleteSessionState {
    /// Last time (ms) the init packet was (re)sent; used by `Context::service()` for retransmission.
    last_retry_time: AtomicI64,
    /// Running Noise handshake hash.
    noise_h: [u8; SHA384_HASH_SIZE],
    /// Result of agreeing Alice's ephemeral key with Bob's static P-384 key.
    noise_es: Secret<P384_ECDH_SHARED_SECRET_SIZE>,
    /// Alice's ephemeral P-384 key pair.
    alice_noise_e_secret: P384KeyPair,
    /// Alice's Kyber1024 secret key, used to decapsulate Bob's Kyber ciphertext.
    alice_hk_secret: Secret<KYBER_SECRETKEYBYTES>,
    /// Optional application metadata, sent along with Alice's final ACK.
    metadata: Option<Vec<u8>>,
    /// The fully encrypted init packet, kept so it can be retransmitted verbatim.
    init_packet: [u8; AliceNoiseXKInit::SIZE],
}
/// Alice's final Noise_XK ACK, kept for retransmission until Bob proves he received it.
struct OutgoingSessionAck {
    /// Last time (ms) the ACK was (re)sent.
    last_retry_time: AtomicI64,
    /// Packet buffer; only the first `ack_size` bytes are valid.
    ack: [u8; MAX_NOISE_HANDSHAKE_SIZE],
    /// Number of valid bytes in `ack`.
    ack_size: usize,
}
/// The handshake or rekey message a session currently has outstanding.
enum Offer {
    /// Nothing outstanding.
    None,
    /// Alice has sent her init and awaits Bob's ACK; the init is retransmitted periodically.
    NoiseXKInit(Box<AliceOutgoingIncompleteSessionState>),
    /// Alice has sent her final ACK; it is retransmitted until an authenticated DATA/NOP packet
    /// arrives from Bob.
    NoiseXKAck(Box<OutgoingSessionAck>),
    /// A rekey was initiated: key pair (presumably the ephemeral rekey key — confirm in
    /// initiate_rekey) and the time (ms) of the last rekey attempt.
    RekeyInit(P384KeyPair, i64),
}
/// Number of reusable AES-GCM cipher instances per direction held by each `SessionKey`.
const AES_POOL_SIZE: usize = 4;
/// One generation of session key material, with its rekey/expiry limits.
struct SessionKey {
    /// Key used in derivation of the next session key.
    ratchet_key: Secret<BASE_KEY_SIZE>,
    /// Pool of reusable receiving (decrypting) AES-GCM ciphers.
    receive_cipher_pool: [Mutex<AesGcm<false>>; AES_POOL_SIZE],
    /// Pool of reusable sending (encrypting) AES-GCM ciphers.
    send_cipher_pool: [Mutex<AesGcm<true>>; AES_POOL_SIZE],
    /// Rekey at or after this time (milliseconds, same monotonic clock as `current_time`).
    rekey_at_time: i64,
    /// Counter at which this key was created; used to decide which of two keys is newer.
    created_at_counter: u64,
    /// Rekey at or after this send counter value.
    rekey_at_counter: u64,
    /// Hard error when this counter value is reached or exceeded.
    expire_at_counter: u64,
    /// Number of rekey events.
    ratchet_count: u64,
    /// Was this side "Bob" in this exchange? Only that side initiates rekeying.
    bob: bool,
    /// Set once an authenticated packet encrypted with this key has been received from the other side.
    confirmed: bool,
}
impl<Application: ApplicationLayer> Context<Application> {
/// Construct an empty session context.
///
/// Calls `zerotier_crypto::init()` first, then creates empty session tables (pre-sized for
/// 64 entries each) and an empty handshake defragmentation queue.
///
/// * `max_incomplete_session_queue_size` - Maximum number of incomplete sessions in negotiation phase
pub fn new(max_incomplete_session_queue_size: usize) -> Self {
    zerotier_crypto::init();
    let session_tables = SessionsById {
        active: HashMap::with_capacity(64),
        incoming: HashMap::with_capacity(64),
    };
    Self {
        sessions: RwLock::new(session_tables),
        defrag: Mutex::new(HashMap::new()),
        max_incomplete_session_queue_size,
    }
}
/// Run periodic maintenance for this context.
///
/// For each live session this resends an outstanding Noise_XK init or final ACK whose last
/// transmission is older than `Application::RETRY_INTERVAL`. It then starts a rekey when this
/// side was Bob for the current key and its time or send-counter threshold has been reached.
/// The rekey check only happens when no offer is pending or the last rekey offer is stale.
///
/// It also purges sessions the application has dropped, incoming negotiations older than
/// `Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS`, and stale defragmentation entries.
///
/// Returns the number of milliseconds until this should be called again.
///
/// * `send` - Function to send packets to remote sessions
/// * `mtu` - Physical MTU
/// * `current_time` - Current monotonic time in milliseconds
pub fn service<SendFunction: FnMut(&Arc<Session<Application>>, &mut [u8])>(&self, mut send: SendFunction, mtu: usize, current_time: i64) -> i64 {
    let retry_cutoff = current_time - Application::RETRY_INTERVAL;
    let negotiation_timeout_cutoff = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS;
    let mut expired_active_ids = Vec::new();
    let mut expired_incoming_ids = Vec::new();

    // Pass 1: walk the tables under the shared lock, doing per-session work and remembering dead
    // entries so they can be removed under a short exclusive lock afterwards.
    {
        let sessions = self.sessions.read().unwrap();
        for (id, weak_session) in sessions.active.iter() {
            let session = match weak_session.upgrade() {
                Some(session) => session,
                None => {
                    // The application no longer holds an Arc, so this session is closed.
                    expired_active_ids.push(*id);
                    continue;
                }
            };
            let state = session.state.read().unwrap();

            // Retransmit due handshake messages and decide whether a rekey check applies.
            let may_rekey = match &state.current_offer {
                Offer::None => true,
                Offer::NoiseXKInit(init_offer) => {
                    // Only the init (and the final ACK below) is retransmitted; intermediate steps are
                    // not, so under heavy loss a new session can still fail and the application layer
                    // is expected to keep retrying.
                    if init_offer.last_retry_time.load(Ordering::Relaxed) < retry_cutoff {
                        init_offer.last_retry_time.store(current_time, Ordering::Relaxed);
                        let _ = send_with_fragmentation(
                            |b| send(&session, b),
                            &mut (init_offer.init_packet.clone()),
                            mtu,
                            PACKET_TYPE_ALICE_NOISE_XK_INIT,
                            None,
                            0,
                            1,
                            None,
                        );
                    }
                    false
                }
                Offer::NoiseXKAck(ack_offer) => {
                    // Keep resending the final ACK until Bob sends a valid DATA or NOP packet, so the
                    // session cannot end up half open.
                    if ack_offer.last_retry_time.load(Ordering::Relaxed) < retry_cutoff {
                        ack_offer.last_retry_time.store(current_time, Ordering::Relaxed);
                        let _ = send_with_fragmentation(
                            |b| send(&session, b),
                            &mut (ack_offer.ack.clone())[..ack_offer.ack_size],
                            mtu,
                            PACKET_TYPE_ALICE_NOISE_XK_ACK,
                            state.remote_session_id,
                            0,
                            2,
                            Some(&session.header_protection_cipher),
                        );
                    }
                    false
                }
                // A previous rekey offer is considered lost once it is older than the retry cutoff.
                Offer::RekeyInit(_, last_rekey_attempt_time) => *last_rekey_attempt_time < retry_cutoff,
            };
            if !may_rekey {
                continue;
            }

            // Only the side that was Bob for the current key initiates rekeying.
            let rekey_due = state.keys[state.current_key].as_ref().map_or(false, |key| {
                key.bob && (current_time >= key.rekey_at_time || session.send_counter.load(Ordering::Relaxed) >= key.rekey_at_counter)
            });
            if rekey_due {
                // Release the state read lock before initiate_rekey(), which presumably locks the state itself.
                drop(state);
                session.initiate_rekey(|b| send(&session, b), current_time);
            }
        }
        expired_incoming_ids.extend(
            sessions
                .incoming
                .iter()
                .filter(|(_, pending)| pending.timestamp <= negotiation_timeout_cutoff)
                .map(|(id, _)| *id),
        );
    }

    // Pass 2: remove the dead entries found above, taking the write lock only if needed.
    if !expired_active_ids.is_empty() || !expired_incoming_ids.is_empty() {
        let mut sessions = self.sessions.write().unwrap();
        for id in &expired_active_ids {
            sessions.active.remove(id);
        }
        for id in &expired_incoming_ids {
            sessions.incoming.remove(id);
        }
    }

    // Pass 3: drop handshake defragmentation entries that have outlived the negotiation timeout.
    self.defrag.lock().unwrap().retain(|_, entry| entry.1 > negotiation_timeout_cutoff);

    i64::min(Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS, Application::RETRY_INTERVAL)
}
/// Create a new session and send initial packet(s) to other side.
///
/// This will return Error::DataTooLarge if the combined size of the metadata and the local static public
/// blob (as retrieved from the application layer) exceed MAX_INIT_PAYLOAD_SIZE.
///
/// * `app` - Application layer instance
/// * `send` - User-supplied packet sending function
/// * `mtu` - Physical MTU for calls to send()
/// * `remote_s_public_blob` - Remote side's opaque static public blob (which must contain remote_s_public_p384)
/// * `remote_s_public_p384` - Remote side's static public NIST P-384 key
/// * `psk` - Pre-shared key (use all zero if none)
/// * `metadata` - Optional metadata to be included in initial handshake
/// * `application_data` - Arbitrary opaque data to include with session object
/// * `current_time` - Current monotonic time in milliseconds
pub fn open<SendFunction: FnMut(&mut [u8])>(
&self,
app: &Application,
mut send: SendFunction,
mtu: usize,
remote_s_public_blob: &[u8],
remote_s_public_p384: &P384PublicKey,
psk: Secret<BASE_KEY_SIZE>,
metadata: Option<Vec<u8>>,
application_data: Application::Data,
current_time: i64,
) -> Result<Arc<Session<Application>>, Error> {
if (metadata.as_ref().map(|md| md.len()).unwrap_or(0) + app.get_local_s_public_blob().len()) > MAX_INIT_PAYLOAD_SIZE {
return Err(Error::DataTooLarge);
}
let alice_noise_e_secret = P384KeyPair::generate();
let alice_noise_e = alice_noise_e_secret.public_key_bytes().clone();
let noise_es = alice_noise_e_secret.agree(&remote_s_public_p384).ok_or(Error::InvalidParameter)?;
let alice_hk_secret = pqc_kyber::keypair(&mut random::SecureRandom::default());
let header_protection_key: Secret<AES_HEADER_PROTECTION_KEY_SIZE> = Secret(random::get_bytes_secure());
let (local_session_id, session) = {
let mut sessions = self.sessions.write().unwrap();
let mut local_session_id;
loop {
local_session_id = SessionId::random();
if !sessions.active.contains_key(&local_session_id) && !sessions.incoming.contains_key(&local_session_id) {
break;
}
}
let session = Arc::new(Session {
id: local_session_id,
application_data,
psk,
send_counter: AtomicU64::new(3), // 1 and 2 are reserved for init and final ack
receive_window: std::array::from_fn(|_| AtomicU64::new(0)),
header_protection_cipher: Aes::new(&header_protection_key),
state: RwLock::new(State {
remote_session_id: None,
keys: [None, None],
current_key: 0,
current_offer: Offer::NoiseXKInit(Box::new(AliceOutgoingIncompleteSessionState {
last_retry_time: AtomicI64::new(current_time),
noise_h: mix_hash(&mix_hash(&INITIAL_H, remote_s_public_blob), &alice_noise_e),
noise_es: noise_es.clone(),
alice_noise_e_secret,
alice_hk_secret: Secret(alice_hk_secret.secret),
metadata,
init_packet: [0u8; AliceNoiseXKInit::SIZE],
})),
}),
defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())),
});
sessions.active.insert(local_session_id, Arc::downgrade(&session));
(local_session_id, session)
};
{
let mut state = session.state.write().unwrap();
let offer = if let Offer::NoiseXKInit(offer) = &mut state.current_offer {
offer
} else {
panic!(); // should be impossible as this is what we initialized with
};
// Create Alice's initial outgoing state message.
let init_packet = &mut offer.init_packet;
{
let init: &mut AliceNoiseXKInit = byte_array_as_proto_buffer_mut(init_packet).unwrap();
init.session_protocol_version = SESSION_PROTOCOL_VERSION;
init.alice_noise_e = alice_noise_e;
init.alice_session_id = *local_session_id.as_bytes();
init.alice_hk_public = alice_hk_secret.public;
init.header_protection_key = header_protection_key.0;
}
// Encrypt and add authentication tag.
let mut gcm = AesGcm::new(&kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_ES>(noise_es.as_bytes()));
gcm.reset_init_gcm(&create_message_nonce(PACKET_TYPE_ALICE_NOISE_XK_INIT, 1));
gcm.aad(&offer.noise_h);
gcm.crypt_in_place(&mut init_packet[AliceNoiseXKInit::ENC_START..AliceNoiseXKInit::AUTH_START]);
init_packet[AliceNoiseXKInit::AUTH_START..AliceNoiseXKInit::AUTH_START + AES_GCM_TAG_SIZE].copy_from_slice(&gcm.finish_encrypt());
// Update ongoing state hash with Alice's outgoing init ciphertext.
offer.noise_h = mix_hash(&offer.noise_h, &init_packet[HEADER_SIZE..]);
send_with_fragmentation(
&mut send,
&mut (init_packet.clone()),
mtu,
PACKET_TYPE_ALICE_NOISE_XK_INIT,
None,
0,
1,
None,
)?;
}
return Ok(session);
}
/// Receive, authenticate, decrypt, and process a physical wire packet.
///
/// The send function may be called one or more times to send packets. If the packet is associated
/// with an active session this session is supplied, otherwise this parameter is None and the packet
/// should be a reply to the current incoming packet. The size of packets to be sent will not exceed
/// the supplied mtu.
///
/// The check_allow_incoming_session function is called when an initial Noise_XK init message is
/// received. This is before anything is known about the caller. A return value of true proceeds
/// with negotiation. False drops the packet (receive() then returns `ReceiveResult::Rejected`).
///
/// The check_accept_session function is called at the end of negotiation for an incoming session
/// with the caller's static public blob. It must return the P-384 static public key extracted from
/// the supplied blob, a PSK (or all zeroes if none), and application data to associate with the new
/// session. A return of None rejects and abandons the session.
///
/// Note that if check_accept_session accepts and returns Some() the session could still fail with
/// receive() returning an error. A Some() return from check_accept_session doesn't guarantee
/// successful new session init, only that the application has authorized it.
///
/// Finally, note that the check_X() functions can end up getting called more than once for a given
/// incoming attempt from a given node if the network quality is poor. That's because the caller may
/// have to retransmit init packets causing repetition of parts of the exchange.
///
/// * `app` - Interface to application using ZSSP
/// * `check_allow_incoming_session` - Function to call to check whether an unidentified new session should be accepted
/// * `check_accept_session` - Function to accept sessions after final negotiation, or returns None if rejected
/// * `send` - Function to call to send packets
/// * `source` - Physical path the packet arrived on (keys defragmentation of packets not tied to an active session)
/// * `data_buf` - Buffer to receive decrypted and authenticated object data (an error is returned if too small)
/// * `incoming_packet_buf` - Buffer containing incoming wire packet (receive() takes ownership)
/// * `mtu` - Physical wire MTU for sending packets
/// * `current_time` - Current monotonic time in milliseconds
pub fn receive<
    'b,
    SendFunction: FnMut(Option<&Arc<Session<Application>>>, &mut [u8]),
    CheckAllowIncomingSession: FnMut() -> bool,
    CheckAcceptSession: FnMut(&[u8]) -> Option<(P384PublicKey, Secret<64>, Application::Data)>,
>(
    &self,
    app: &Application,
    mut check_allow_incoming_session: CheckAllowIncomingSession,
    mut check_accept_session: CheckAcceptSession,
    mut send: SendFunction,
    source: &Application::PhysicalPath,
    data_buf: &'b mut [u8],
    mut incoming_packet_buf: Application::IncomingPacketBuffer,
    mtu: usize,
    current_time: i64,
) -> Result<ReceiveResult<'b, Application>, Error> {
    let incoming_packet: &mut [u8] = incoming_packet_buf.as_mut();
    if incoming_packet.len() < MIN_PACKET_SIZE {
        return Err(Error::InvalidPacket);
    }
    let mut incoming = None;
    if let Some(local_session_id) = SessionId::new_from_u64_le(u64::from_le_bytes(incoming_packet[0..8].try_into().unwrap())) {
        // Look the session up in a statement of its own so the read guard on `self.sessions` is
        // released immediately. As a temporary in an `if let` scrutinee it would otherwise live
        // until the end of the whole `if let ... else` (Rust editions before 2024). The
        // debug_assert and the else branch would then take the same read lock recursively, which
        // std's RwLock may panic or deadlock on (e.g. when a writer is queued). The lock would
        // also stay held while the packet is processed and application callbacks run.
        let active_session = self.sessions.read().unwrap().active.get(&local_session_id).and_then(|s| s.upgrade());
        if let Some(session) = active_session {
            debug_assert!(!self.sessions.read().unwrap().incoming.contains_key(&local_session_id));
            session
                .header_protection_cipher
                .decrypt_block_in_place(&mut incoming_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
            let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_packet);
            if session.check_receive_window(incoming_counter) {
                if fragment_count > 1 {
                    let mut fragged = session.defrag[(incoming_counter as usize) % COUNTER_WINDOW_MAX_OOO].lock().unwrap();
                    if let Some(assembled_packet) = fragged.assemble(incoming_counter, incoming_packet_buf, fragment_no, fragment_count) {
                        drop(fragged);
                        return self.process_complete_incoming_packet(
                            app,
                            &mut send,
                            &mut check_allow_incoming_session,
                            &mut check_accept_session,
                            data_buf,
                            incoming_counter,
                            assembled_packet.as_ref(),
                            packet_type,
                            Some(session),
                            None,
                            key_index,
                            mtu,
                            current_time,
                        );
                    } else {
                        drop(fragged);
                        return Ok(ReceiveResult::Ok(Some(session)));
                    }
                } else {
                    return self.process_complete_incoming_packet(
                        app,
                        &mut send,
                        &mut check_allow_incoming_session,
                        &mut check_accept_session,
                        data_buf,
                        incoming_counter,
                        &[incoming_packet_buf],
                        packet_type,
                        Some(session),
                        None,
                        key_index,
                        mtu,
                        current_time,
                    );
                }
            } else {
                return Err(Error::OutOfSequence);
            }
        } else {
            // Same reasoning as above: do not keep the read guard alive across the body.
            let incoming_state = self.sessions.read().unwrap().incoming.get(&local_session_id).cloned();
            if let Some(i) = incoming_state {
                Aes::new(&i.header_protection_key)
                    .decrypt_block_in_place(&mut incoming_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
                incoming = Some(i);
            } else {
                return Err(Error::UnknownLocalSessionId);
            }
        }
    }
    // If we make it here the packet is not associated with a session or is associated with an
    // incoming session (Noise_XK mid-negotiation).
    let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_packet);
    if fragment_count > 1 {
        let f = {
            let mut defrag = self.defrag.lock().unwrap();
            let f = defrag
                .entry((source.clone(), incoming_counter))
                .or_insert_with(|| Arc::new((Mutex::new(Fragged::new()), current_time)))
                .clone();
            // Anti-DOS overflow purge of the incoming defragmentation queue for packets not associated with known sessions.
            if defrag.len() >= self.max_incomplete_session_queue_size {
                // First, drop all entries that are timed out or whose physical source duplicates another entry.
                let mut sources = HashSet::with_capacity(defrag.len());
                let negotiation_timeout_cutoff = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS;
                defrag.retain(|k, fragged| (fragged.1 > negotiation_timeout_cutoff && sources.insert(k.0.clone())) || Arc::ptr_eq(fragged, &f));
                // Then, if we are still at or over the limit, drop 10% of remaining entries at random.
                if defrag.len() >= self.max_incomplete_session_queue_size {
                    let mut rn = random::next_u32_secure();
                    defrag.retain(|_, fragged| {
                        rn = prng32(rn);
                        rn > (u32::MAX / 10) || Arc::ptr_eq(fragged, &f)
                    });
                }
            }
            f
        };
        let mut fragged = f.0.lock().unwrap();
        if let Some(assembled_packet) = fragged.assemble(incoming_counter, incoming_packet_buf, fragment_no, fragment_count) {
            self.defrag.lock().unwrap().remove(&(source.clone(), incoming_counter));
            return self.process_complete_incoming_packet(
                app,
                &mut send,
                &mut check_allow_incoming_session,
                &mut check_accept_session,
                data_buf,
                incoming_counter,
                assembled_packet.as_ref(),
                packet_type,
                None,
                incoming,
                key_index,
                mtu,
                current_time,
            );
        }
    } else {
        return self.process_complete_incoming_packet(
            app,
            &mut send,
            &mut check_allow_incoming_session,
            &mut check_accept_session,
            data_buf,
            incoming_counter,
            &[incoming_packet_buf],
            packet_type,
            None,
            incoming,
            key_index,
            mtu,
            current_time,
        );
    }
    return Ok(ReceiveResult::Ok(None));
}
fn process_complete_incoming_packet<
'b,
SendFunction: FnMut(Option<&Arc<Session<Application>>>, &mut [u8]),
CheckAllowIncomingSession: FnMut() -> bool,
CheckAcceptSession: FnMut(&[u8]) -> Option<(P384PublicKey, Secret<64>, Application::Data)>,
>(
&self,
app: &Application,
send: &mut SendFunction,
check_allow_incoming_session: &mut CheckAllowIncomingSession,
check_accept_session: &mut CheckAcceptSession,
data_buf: &'b mut [u8],
incoming_counter: u64,
fragments: &[Application::IncomingPacketBuffer],
packet_type: u8,
session: Option<Arc<Session<Application>>>,
incoming: Option<Arc<BobIncomingIncompleteSessionState>>,
key_index: usize,
mtu: usize,
current_time: i64,
) -> Result<ReceiveResult<'b, Application>, Error> {
debug_assert!(fragments.len() >= 1);
// Generate incoming message nonce for decryption and authentication.
let incoming_message_nonce = create_message_nonce(packet_type, incoming_counter);
if packet_type <= PACKET_TYPE_DATA {
if let Some(session) = session {
let state = session.state.read().unwrap();
if let Some(key) = state.keys[key_index].as_ref() {
let mut c = key.get_receive_cipher(incoming_counter);
c.reset_init_gcm(&incoming_message_nonce);
let mut data_len = 0;
// Decrypt fragments 0..N-1 where N is the number of fragments.
for f in fragments[..(fragments.len() - 1)].iter() {
let f: &[u8] = f.as_ref();
debug_assert!(f.len() >= HEADER_SIZE);
let current_frag_data_start = data_len;
data_len += f.len() - HEADER_SIZE;
if data_len > data_buf.len() {
drop(c);
return Err(Error::DataBufferTooSmall);
}
c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]);
}
// Decrypt final fragment (or only fragment if not fragmented)
let current_frag_data_start = data_len;
let last_fragment = fragments.last().unwrap().as_ref();
if last_fragment.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) {
return Err(Error::InvalidPacket);
}
data_len += last_fragment.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
if data_len > data_buf.len() {
drop(c);
return Err(Error::DataBufferTooSmall);
}
let payload_end = last_fragment.len() - AES_GCM_TAG_SIZE;
c.crypt(&last_fragment[HEADER_SIZE..payload_end], &mut data_buf[current_frag_data_start..data_len]);
let aead_authentication_ok = c.finish_decrypt(&last_fragment[payload_end..]);
drop(c);
if aead_authentication_ok {
// Packet fully authenticated
if session.update_receive_window(incoming_counter) {
// Update the current key to point to this key if it's newer, since having received
// a packet encrypted with it proves that the other side has successfully derived it
// as well.
if state.current_key == key_index && key.confirmed {
drop(state);
} else {
let current_key_created_at_counter = key.created_at_counter;
drop(state);
let mut state = session.state.write().unwrap();
if state.current_key != key_index {
if let Some(other_session_key) = state.keys[state.current_key].as_ref() {
if other_session_key.created_at_counter < current_key_created_at_counter {
state.current_key = key_index;
}
} else {
state.current_key = key_index;
}
}
state.keys[key_index].as_mut().unwrap().confirmed = true;
// If we got a valid data packet from Bob, this means we can cancel any offers
// that are still oustanding for initialization.
match &state.current_offer {
Offer::NoiseXKInit(_) | Offer::NoiseXKAck(_) => {
state.current_offer = Offer::None;
}
_ => {}
}
}
if packet_type == PACKET_TYPE_DATA {
return Ok(ReceiveResult::OkData(session, &mut data_buf[..data_len]));
} else {
return Ok(ReceiveResult::Ok(Some(session)));
}
} else {
return Err(Error::OutOfSequence);
}
}
}
return Err(Error::FailedAuthentication);
} else {
return Err(Error::UnknownLocalSessionId);
}
} else {
// For Noise setup/KEX packets go ahead and pre-assemble all fragments to simplify the code below.
let mut pkt_assembly_buffer = [0u8; MAX_NOISE_HANDSHAKE_SIZE];
let pkt_assembled_size = assemble_fragments_into::<Application>(fragments, &mut pkt_assembly_buffer)?;
if pkt_assembled_size < MIN_PACKET_SIZE {
return Err(Error::InvalidPacket);
}
let pkt_assembled = &mut pkt_assembly_buffer[..pkt_assembled_size];
if pkt_assembled[HEADER_SIZE] != SESSION_PROTOCOL_VERSION {
return Err(Error::UnknownProtocolVersion);
}
match packet_type {
PACKET_TYPE_ALICE_NOISE_XK_INIT => {
// Alice (remote) --> Bob (local)
/*
* This is the first message Bob receives from Alice, the initiator. It contains
* Alice's ephemeral keys but not her identity. Alice will not reveal her identity
* until forward secrecy is established and she's authenticated Bob.
*
* Bob authenticates the message and confirms that Alice indeed knows Bob's
* identity, then responds with his ephemeral keys.
*/
if incoming_counter != 1 || session.is_some() || incoming.is_some() {
return Err(Error::OutOfSequence);
}
if pkt_assembled.len() != AliceNoiseXKInit::SIZE {
return Err(Error::InvalidPacket);
}
// Otherwise parse the packet, authenticate, generate keys, etc. and record state in an
// incoming state object until this phase of the negotiation is done.
let pkt: &AliceNoiseXKInit = byte_array_as_proto_buffer(pkt_assembled)?;
let alice_noise_e = P384PublicKey::from_bytes(&pkt.alice_noise_e).ok_or(Error::FailedAuthentication)?;
let noise_es = app.get_local_s_keypair().agree(&alice_noise_e).ok_or(Error::FailedAuthentication)?;
let noise_h = mix_hash(&mix_hash(&INITIAL_H, app.get_local_s_public_blob()), alice_noise_e.as_bytes());
let noise_h_next = mix_hash(&noise_h, &pkt_assembled[HEADER_SIZE..]);
// Decrypt and authenticate init packet, also proving that caller knows our static identity.
let mut gcm = AesGcm::new(&kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_ES>(noise_es.as_bytes()));
gcm.reset_init_gcm(&incoming_message_nonce);
gcm.aad(&noise_h);
gcm.crypt_in_place(&mut pkt_assembled[AliceNoiseXKInit::ENC_START..AliceNoiseXKInit::AUTH_START]);
if !gcm.finish_decrypt(&pkt_assembled[AliceNoiseXKInit::AUTH_START..AliceNoiseXKInit::AUTH_START + AES_GCM_TAG_SIZE]) {
return Err(Error::FailedAuthentication);
}
// Let application filter incoming connection attempt by whatever criteria it wants.
if !check_allow_incoming_session() {
return Ok(ReceiveResult::Rejected);
}
let pkt: &AliceNoiseXKInit = byte_array_as_proto_buffer(pkt_assembled)?;
let alice_session_id = SessionId::new_from_bytes(&pkt.alice_session_id).ok_or(Error::InvalidPacket)?;
let header_protection_key = Secret(pkt.header_protection_key);
// Create Bob's ephemeral keys and derive noise_es_ee by agreeing with Alice's. Also create
// a Kyber ciphertext to send back to Alice.
let bob_noise_e_secret = P384KeyPair::generate();
let bob_noise_e = bob_noise_e_secret.public_key_bytes().clone();
let noise_es_ee = hmac_sha512_secret(
noise_es.as_bytes(),
bob_noise_e_secret.agree(&alice_noise_e).ok_or(Error::FailedAuthentication)?.as_bytes(),
);
let (bob_hk_ciphertext, hk) = pqc_kyber::encapsulate(&pkt.alice_hk_public, &mut random::SecureRandom::default())
.map_err(|_| Error::FailedAuthentication)
.map(|(ct, hk)| (ct, Secret(hk)))?;
let mut sessions = self.sessions.write().unwrap();
let mut bob_session_id;
loop {
bob_session_id = SessionId::random();
if !sessions.active.contains_key(&bob_session_id) && !sessions.incoming.contains_key(&bob_session_id) {
break;
}
}
// Create Bob's ephemeral counter-offer reply.
let mut ack_packet = [0u8; BobNoiseXKAck::SIZE];
let ack: &mut BobNoiseXKAck = byte_array_as_proto_buffer_mut(&mut ack_packet)?;
ack.session_protocol_version = SESSION_PROTOCOL_VERSION;
ack.bob_noise_e = bob_noise_e;
ack.bob_session_id = *bob_session_id.as_bytes();
ack.bob_hk_ciphertext = bob_hk_ciphertext;
// Encrypt main section of reply and attach tag.
let mut gcm = AesGcm::new(&kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_ES_EE>(noise_es_ee.as_bytes()));
gcm.reset_init_gcm(&create_message_nonce(PACKET_TYPE_BOB_NOISE_XK_ACK, 1));
gcm.aad(&noise_h_next);
gcm.crypt_in_place(&mut ack_packet[BobNoiseXKAck::ENC_START..BobNoiseXKAck::AUTH_START]);
ack_packet[BobNoiseXKAck::AUTH_START..BobNoiseXKAck::AUTH_START + AES_GCM_TAG_SIZE].copy_from_slice(&gcm.finish_encrypt());
// If this queue is too big, we remove the latest entry and replace it. The latest
// is used because under flood conditions this is most likely to be another bogus
// entry. If we find one that is actually timed out, that one is replaced instead.
if sessions.incoming.len() >= self.max_incomplete_session_queue_size {
let mut newest = i64::MIN;
let mut replace_id = None;
let cutoff_time = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS;
for (id, s) in sessions.incoming.iter() {
if s.timestamp <= cutoff_time {
replace_id = Some(*id);
break;
} else if s.timestamp >= newest {
newest = s.timestamp;
replace_id = Some(*id);
}
}
let _ = sessions.incoming.remove(replace_id.as_ref().unwrap());
}
// Reserve session ID on this side and record incomplete session state.
sessions.incoming.insert(
bob_session_id,
Arc::new(BobIncomingIncompleteSessionState {
timestamp: current_time,
alice_session_id,
bob_session_id,
noise_h: mix_hash(&mix_hash(&noise_h_next, &bob_noise_e), &ack_packet[HEADER_SIZE..]),
noise_es_ee: noise_es_ee.clone(),
hk,
bob_noise_e_secret,
header_protection_key: Secret(pkt.header_protection_key),
}),
);
debug_assert!(!sessions.active.contains_key(&bob_session_id));
// Release lock
drop(sessions);
send_with_fragmentation(
|b| send(None, b),
&mut ack_packet,
mtu,
PACKET_TYPE_BOB_NOISE_XK_ACK,
Some(alice_session_id),
0,
1,
Some(&Aes::new(&header_protection_key)),
)?;
return Ok(ReceiveResult::Ok(session));
}
PACKET_TYPE_BOB_NOISE_XK_ACK => {
// Bob (remote) --> Alice (local)
/*
* This is Bob's reply to Alice's first message, allowing Alice to verify Bob's
* identity. Once this is done Alice can send her identity (encrypted) to complete
* the negotiation.
*/
if incoming_counter != 1 || incoming.is_some() {
return Err(Error::OutOfSequence);
}
if pkt_assembled.len() != BobNoiseXKAck::SIZE {
return Err(Error::InvalidPacket);
}
if let Some(session) = session {
let state = session.state.read().unwrap();
// This doesn't make sense if the session is up.
if state.keys[state.current_key].is_some() {
return Err(Error::OutOfSequence);
}
if let Offer::NoiseXKInit(outgoing_offer) = &state.current_offer {
let pkt: &BobNoiseXKAck = byte_array_as_proto_buffer(pkt_assembled)?;
// Derive noise_es_ee from Bob's ephemeral public key.
let bob_noise_e = P384PublicKey::from_bytes(&pkt.bob_noise_e).ok_or(Error::FailedAuthentication)?;
let noise_es_ee = hmac_sha512_secret::<BASE_KEY_SIZE>(
outgoing_offer.noise_es.as_bytes(),
outgoing_offer
.alice_noise_e_secret
.agree(&bob_noise_e)
.ok_or(Error::FailedAuthentication)?
.as_bytes(),
);
// Go ahead and compute the next 'h' state before we lose the ciphertext in decrypt.
let noise_h_next = mix_hash(&mix_hash(&outgoing_offer.noise_h, bob_noise_e.as_bytes()), &pkt_assembled[HEADER_SIZE..]);
// Decrypt and authenticate Bob's reply.
let mut gcm = AesGcm::new(&kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_ES_EE>(noise_es_ee.as_bytes()));
gcm.reset_init_gcm(&incoming_message_nonce);
gcm.aad(&outgoing_offer.noise_h);
gcm.crypt_in_place(&mut pkt_assembled[BobNoiseXKAck::ENC_START..BobNoiseXKAck::AUTH_START]);
if !gcm.finish_decrypt(&pkt_assembled[BobNoiseXKAck::AUTH_START..BobNoiseXKAck::AUTH_START + AES_GCM_TAG_SIZE]) {
return Err(Error::FailedAuthentication);
}
let pkt: &BobNoiseXKAck = byte_array_as_proto_buffer(pkt_assembled)?;
if let Some(bob_session_id) = SessionId::new_from_bytes(&pkt.bob_session_id) {
// Complete Noise_XKpsk3 by mixing in noise_se followed by the PSK. The PSK as far as
// the Noise pattern is concerned is the result of mixing the externally supplied PSK
// with the Kyber1024 shared secret (hk). Kyber is treated as part of the PSK because
// it's an external add-on beyond the Noise spec.
let hk = pqc_kyber::decapsulate(&pkt.bob_hk_ciphertext, outgoing_offer.alice_hk_secret.as_bytes())
.map_err(|_| Error::FailedAuthentication)
.map(|k| Secret(k))?;
let noise_se = app.get_local_s_keypair().agree(&bob_noise_e).ok_or(Error::FailedAuthentication)?;
// Packet fully authenticated
if session.update_receive_window(incoming_counter) {
let noise_es_ee_se_hk_psk = hmac_sha512_secret::<BASE_KEY_SIZE>(
hmac_sha512_secret::<BASE_KEY_SIZE>(noise_es_ee.as_bytes(), noise_se.as_bytes()).as_bytes(),
hmac_sha512_secret::<BASE_KEY_SIZE>(session.psk.as_bytes(), hk.as_bytes()).as_bytes(),
);
let reply_message_nonce = create_message_nonce(PACKET_TYPE_ALICE_NOISE_XK_ACK, 2);
// Create reply informing Bob of our static identity now that we've verified Bob and set
// up forward secrecy. Also return Bob's opaque note.
let mut reply_buffer = [0u8; MAX_NOISE_HANDSHAKE_SIZE];
reply_buffer[HEADER_SIZE] = SESSION_PROTOCOL_VERSION;
let mut reply_len = HEADER_SIZE + 1;
let alice_s_public_blob = app.get_local_s_public_blob();
assert!(alice_s_public_blob.len() <= (u16::MAX as usize));
reply_len = append_to_slice(&mut reply_buffer, reply_len, &(alice_s_public_blob.len() as u16).to_le_bytes())?;
let mut enc_start = reply_len;
reply_len = append_to_slice(&mut reply_buffer, reply_len, alice_s_public_blob)?;
let mut gcm = AesGcm::new(&kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_ES_EE_HK>(
hmac_sha512_secret::<BASE_KEY_SIZE>(noise_es_ee.as_bytes(), hk.as_bytes()).as_bytes(),
));
gcm.reset_init_gcm(&reply_message_nonce);
gcm.aad(&noise_h_next);
gcm.crypt_in_place(&mut reply_buffer[enc_start..reply_len]);
reply_len = append_to_slice(&mut reply_buffer, reply_len, &gcm.finish_encrypt())?;
let metadata = outgoing_offer.metadata.as_ref().map_or(&[][..0], |md| md.as_slice());
assert!(metadata.len() <= (u16::MAX as usize));
reply_len = append_to_slice(&mut reply_buffer, reply_len, &(metadata.len() as u16).to_le_bytes())?;
let noise_h_next = mix_hash(&mix_hash(&noise_h_next, &reply_buffer[HEADER_SIZE..reply_len]), session.psk.as_bytes());
enc_start = reply_len;
reply_len = append_to_slice(&mut reply_buffer, reply_len, metadata)?;
let mut gcm = AesGcm::new(&kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_ES_EE_SE_HK_PSK>(
noise_es_ee_se_hk_psk.as_bytes(),
));
gcm.reset_init_gcm(&reply_message_nonce);
gcm.aad(&noise_h_next);
gcm.crypt_in_place(&mut reply_buffer[enc_start..reply_len]);
reply_len = append_to_slice(&mut reply_buffer, reply_len, &gcm.finish_encrypt())?;
drop(state);
{
let mut state = session.state.write().unwrap();
let _ = state.remote_session_id.insert(bob_session_id);
let _ =
state.keys[0].insert(SessionKey::new::<Application>(noise_es_ee_se_hk_psk, 1, current_time, 2, false, false));
debug_assert!(state.keys[1].is_none());
state.current_key = 0;
state.current_offer = Offer::NoiseXKAck(Box::new(OutgoingSessionAck {
last_retry_time: AtomicI64::new(current_time),
ack: reply_buffer,
ack_size: reply_len,
}));
}
send_with_fragmentation(
|b| send(Some(&session), b),
&mut reply_buffer[..reply_len],
mtu,
PACKET_TYPE_ALICE_NOISE_XK_ACK,
Some(bob_session_id),
0,
2,
Some(&session.header_protection_cipher),
)?;
return Ok(ReceiveResult::Ok(Some(session)));
} else {
return Err(Error::OutOfSequence);
}
} else {
return Err(Error::InvalidPacket);
}
} else {
return Err(Error::OutOfSequence);
}
} else {
return Err(Error::UnknownLocalSessionId);
}
}
PACKET_TYPE_ALICE_NOISE_XK_ACK => {
// Alice (remote) --> Bob (local)
/*
* After negotiating a keyed session and Alice has had the opportunity to
* verify Bob, this is when Bob gets to learn who Alice is. At this point
* Bob can make a final decision about whether to keep talking to Alice
* and can create an actual session using the state memo-ized in the memo
* that Alice must return.
*/
if incoming_counter != 2 || session.is_some() {
return Err(Error::OutOfSequence);
}
if pkt_assembled.len() < ALICE_NOISE_XK_ACK_MIN_SIZE {
return Err(Error::InvalidPacket);
}
if let Some(incoming) = incoming {
let mut r = PktReader(pkt_assembled, HEADER_SIZE + 1);
let alice_static_public_blob_size = r.read_u16()? as usize;
let ciphertext_up_to_metadata_size = r.1 + alice_static_public_blob_size + AES_GCM_TAG_SIZE + 2;
if r.0.len() < ciphertext_up_to_metadata_size {
return Err(Error::InvalidPacket);
}
let noise_h_next = mix_hash(&incoming.noise_h, &r.0[HEADER_SIZE..ciphertext_up_to_metadata_size]);
let alice_static_public_blob = r.read_decrypt_auth(
alice_static_public_blob_size,
kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_ES_EE_HK>(
hmac_sha512_secret::<BASE_KEY_SIZE>(incoming.noise_es_ee.as_bytes(), incoming.hk.as_bytes()).as_bytes(),
),
&incoming.noise_h,
&incoming_message_nonce,
)?;
// Check session acceptance and fish Alice's NIST P-384 static public key out of her static public blob.
let check_result = check_accept_session(alice_static_public_blob);
if check_result.is_none() {
self.sessions.write().unwrap().incoming.remove(&incoming.bob_session_id);
return Ok(ReceiveResult::Rejected);
}
let (alice_noise_s, psk, application_data) = check_result.unwrap();
let noise_h_next = mix_hash(&noise_h_next, psk.as_bytes());
// Complete Noise_XKpsk3 on Bob's side.
let noise_es_ee_se_hk_psk = hmac_sha512_secret::<BASE_KEY_SIZE>(
hmac_sha512_secret::<BASE_KEY_SIZE>(
incoming.noise_es_ee.as_bytes(),
incoming
.bob_noise_e_secret
.agree(&alice_noise_s)
.ok_or(Error::FailedAuthentication)?
.as_bytes(),
)
.as_bytes(),
hmac_sha512_secret::<BASE_KEY_SIZE>(psk.as_bytes(), incoming.hk.as_bytes()).as_bytes(),
);
// Decrypt meta-data and verify the final key in the process. Copy meta-data
// into the temporary data buffer to return.
let alice_meta_data_size = r.read_u16()? as usize;
let alice_meta_data = r.read_decrypt_auth(
alice_meta_data_size,
kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_ES_EE_SE_HK_PSK>(noise_es_ee_se_hk_psk.as_bytes()),
&noise_h_next,
&incoming_message_nonce,
)?;
if alice_meta_data.len() > data_buf.len() {
return Err(Error::DataTooLarge);
}
data_buf[..alice_meta_data.len()].copy_from_slice(alice_meta_data);
let session = Arc::new(Session {
id: incoming.bob_session_id,
application_data,
psk,
send_counter: AtomicU64::new(2), // 1 was already used during negotiation
receive_window: std::array::from_fn(|_| AtomicU64::new(incoming_counter)),
header_protection_cipher: Aes::new(&incoming.header_protection_key),
state: RwLock::new(State {
remote_session_id: Some(incoming.alice_session_id),
keys: [
Some(SessionKey::new::<Application>(noise_es_ee_se_hk_psk, 1, current_time, 2, true, true)),
None,
],
current_key: 0,
current_offer: Offer::None,
}),
defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())),
});
// Promote incoming session to active.
{
let mut sessions = self.sessions.write().unwrap();
sessions.incoming.remove(&incoming.bob_session_id);
sessions.active.insert(incoming.bob_session_id, Arc::downgrade(&session));
}
let _ = session.send_nop(|b| send(Some(&session), b));
return Ok(ReceiveResult::OkNewSession(
session,
if alice_meta_data.is_empty() {
None
} else {
Some(&mut data_buf[..alice_meta_data.len()])
},
));
} else {
return Err(Error::UnknownLocalSessionId);
}
}
PACKET_TYPE_REKEY_INIT => {
if pkt_assembled.len() != RekeyInit::SIZE {
return Err(Error::InvalidPacket);
}
if incoming.is_some() {
return Err(Error::OutOfSequence);
}
if let Some(session) = session {
let state = session.state.read().unwrap();
if let Some(remote_session_id) = state.remote_session_id {
if let Some(key) = state.keys[key_index].as_ref() {
// Only the current "Alice" accepts rekeys initiated by the current "Bob." These roles
// flip with each rekey event.
if !key.bob {
let mut c = key.get_receive_cipher(incoming_counter);
c.reset_init_gcm(&incoming_message_nonce);
c.crypt_in_place(&mut pkt_assembled[RekeyInit::ENC_START..RekeyInit::AUTH_START]);
let aead_authentication_ok = c.finish_decrypt(&pkt_assembled[RekeyInit::AUTH_START..]);
drop(c);
if aead_authentication_ok {
let pkt: &RekeyInit = byte_array_as_proto_buffer(&pkt_assembled).unwrap();
if let Some(alice_e) = P384PublicKey::from_bytes(&pkt.alice_e) {
let bob_e_secret = P384KeyPair::generate();
let next_session_key = hmac_sha512_secret(
key.ratchet_key.as_bytes(),
bob_e_secret.agree(&alice_e).ok_or(Error::FailedAuthentication)?.as_bytes(),
);
// Packet fully authenticated
if session.update_receive_window(incoming_counter) {
let mut reply_buf = [0u8; RekeyAck::SIZE];
let reply: &mut RekeyAck = byte_array_as_proto_buffer_mut(&mut reply_buf).unwrap();
reply.session_protocol_version = SESSION_PROTOCOL_VERSION;
reply.bob_e = *bob_e_secret.public_key_bytes();
reply.next_key_fingerprint = SHA384::hash(next_session_key.as_bytes());
let counter = session.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get();
set_packet_header(
&mut reply_buf,
1,
0,
PACKET_TYPE_REKEY_ACK,
u64::from(remote_session_id),
state.current_key,
counter,
);
let mut c = key.get_send_cipher(counter)?;
c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_REKEY_ACK, counter));
c.crypt_in_place(&mut reply_buf[RekeyAck::ENC_START..RekeyAck::AUTH_START]);
reply_buf[RekeyAck::AUTH_START..].copy_from_slice(&c.finish_encrypt());
drop(c);
session
.header_protection_cipher
.encrypt_block_in_place(&mut reply_buf[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
send(Some(&session), &mut reply_buf);
// The new "Bob" doesn't know yet if Alice has received the new key, so the
// new key is recorded as the "alt" (key_index ^ 1) but the current key is
// not advanced yet. This happens automatically the first time we receive a
// valid packet with the new key.
let next_ratchet_count = key.ratchet_count + 1;
drop(state);
let mut state = session.state.write().unwrap();
let _ = state.keys[key_index ^ 1].replace(SessionKey::new::<Application>(
next_session_key,
next_ratchet_count,
current_time,
counter,
false,
false,
));
drop(state);
return Ok(ReceiveResult::Ok(Some(session)));
} else {
return Err(Error::OutOfSequence);
}
}
}
return Err(Error::FailedAuthentication);
}
}
}
return Err(Error::OutOfSequence);
} else {
return Err(Error::UnknownLocalSessionId);
}
}
PACKET_TYPE_REKEY_ACK => {
if pkt_assembled.len() != RekeyAck::SIZE {
return Err(Error::InvalidPacket);
}
if incoming.is_some() {
return Err(Error::OutOfSequence);
}
if let Some(session) = session {
let state = session.state.read().unwrap();
if let Offer::RekeyInit(alice_e_secret, _) = &state.current_offer {
if let Some(key) = state.keys[key_index].as_ref() {
// Only the current "Bob" initiates rekeys and expects this ACK.
if key.bob {
let mut c = key.get_receive_cipher(incoming_counter);
c.reset_init_gcm(&incoming_message_nonce);
c.crypt_in_place(&mut pkt_assembled[RekeyAck::ENC_START..RekeyAck::AUTH_START]);
let aead_authentication_ok = c.finish_decrypt(&pkt_assembled[RekeyAck::AUTH_START..]);
drop(c);
if aead_authentication_ok {
// Packet fully authenticated
let pkt: &RekeyAck = byte_array_as_proto_buffer(&pkt_assembled).unwrap();
if let Some(bob_e) = P384PublicKey::from_bytes(&pkt.bob_e) {
let next_session_key = hmac_sha512_secret(
key.ratchet_key.as_bytes(),
alice_e_secret.agree(&bob_e).ok_or(Error::FailedAuthentication)?.as_bytes(),
);
if secure_eq(&pkt.next_key_fingerprint, &SHA384::hash(next_session_key.as_bytes())) {
if session.update_receive_window(incoming_counter) {
// The new "Alice" knows Bob has the key since this is an ACK, so she can go
// ahead and set current_key to the new key. Then when she sends something
// to Bob the other side will automatically advance to the new key as well.
let next_ratchet_count = key.ratchet_count + 1;
drop(state);
let next_key_index = key_index ^ 1;
let mut state = session.state.write().unwrap();
let _ = state.keys[next_key_index].replace(SessionKey::new::<Application>(
next_session_key,
next_ratchet_count,
current_time,
session.send_counter.load(Ordering::Relaxed),
true,
true,
));
state.current_key = next_key_index; // this is an ACK so it's confirmed
state.current_offer = Offer::None;
drop(state);
return Ok(ReceiveResult::Ok(Some(session)));
} else {
return Err(Error::OutOfSequence);
}
}
}
}
return Err(Error::FailedAuthentication);
}
}
}
return Err(Error::OutOfSequence);
} else {
return Err(Error::UnknownLocalSessionId);
}
}
_ => {
return Err(Error::InvalidPacket);
}
}
}
}
}
impl<Application: ApplicationLayer> Session<Application> {
/// Send data over the session.
///
/// * `send` - Function to call to send physical packet(s)
/// * `mtu_sized_buffer` - A writable work buffer whose size also specifies the physical MTU
/// * `data` - Data to send
#[inline]
pub fn send<SendFunction: FnMut(&mut [u8])>(&self, mut send: SendFunction, mtu_sized_buffer: &mut [u8], mut data: &[u8]) -> Result<(), Error> {
debug_assert!(mtu_sized_buffer.len() >= MIN_TRANSPORT_MTU);
let state = self.state.read().unwrap();
if let Some(remote_session_id) = state.remote_session_id {
if let Some(session_key) = state.keys[state.current_key].as_ref() {
let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get();
let mut c = session_key.get_send_cipher(counter)?;
c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_DATA, counter));
let fragment_count = (((data.len() + AES_GCM_TAG_SIZE) as f32) / (mtu_sized_buffer.len() - HEADER_SIZE) as f32).ceil() as usize;
let fragment_max_chunk_size = mtu_sized_buffer.len() - HEADER_SIZE;
let last_fragment_no = fragment_count - 1;
for fragment_no in 0..fragment_count {
let chunk_size = fragment_max_chunk_size.min(data.len());
let mut fragment_size = chunk_size + HEADER_SIZE;
set_packet_header(
mtu_sized_buffer,
fragment_count,
fragment_no,
PACKET_TYPE_DATA,
u64::from(remote_session_id),
state.current_key,
counter,
);
c.crypt(&data[..chunk_size], &mut mtu_sized_buffer[HEADER_SIZE..fragment_size]);
data = &data[chunk_size..];
if fragment_no == last_fragment_no {
debug_assert!(data.is_empty());
let tagged_fragment_size = fragment_size + AES_GCM_TAG_SIZE;
mtu_sized_buffer[fragment_size..tagged_fragment_size].copy_from_slice(&c.finish_encrypt());
fragment_size = tagged_fragment_size;
}
self.header_protection_cipher
.encrypt_block_in_place(&mut mtu_sized_buffer[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
send(&mut mtu_sized_buffer[..fragment_size]);
}
debug_assert!(data.is_empty());
drop(c);
return Ok(());
}
}
return Err(Error::SessionNotEstablished);
}
/// Send a NOP to the other side (e.g. for keep alive).
pub fn send_nop<SendFunction: FnMut(&mut [u8])>(&self, mut send: SendFunction) -> Result<(), Error> {
let state = self.state.read().unwrap();
if let Some(remote_session_id) = state.remote_session_id {
if let Some(session_key) = state.keys[state.current_key].as_ref() {
let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get();
let mut nop = [0u8; HEADER_SIZE + AES_GCM_TAG_SIZE];
let mut c = session_key.get_send_cipher(counter)?;
c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_NOP, counter));
nop[HEADER_SIZE..].copy_from_slice(&c.finish_encrypt());
drop(c);
set_packet_header(&mut nop, 1, 0, PACKET_TYPE_NOP, u64::from(remote_session_id), state.current_key, counter);
self.header_protection_cipher
.encrypt_block_in_place(&mut nop[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
send(&mut nop);
}
}
return Err(Error::SessionNotEstablished);
}
/// Check whether this session is established.
pub fn established(&self) -> bool {
let state = self.state.read().unwrap();
state.keys[state.current_key].as_ref().map_or(false, |k| k.confirmed)
}
/// Get the ratchet count and a hash fingerprint of the current active key.
pub fn key_info(&self) -> Option<(u64, [u8; 48])> {
let state = self.state.read().unwrap();
if let Some(key) = state.keys[state.current_key].as_ref() {
Some((key.ratchet_count, SHA384::hash(key.ratchet_key.as_bytes())))
} else {
None
}
}
/// Send a rekey init message.
///
/// This is called from the session context's service() method when it's time to rekey.
/// It should only be called when the current key was established in the 'bob' role. This
/// is checked when rekey time is checked.
fn initiate_rekey<SendFunction: FnMut(&mut [u8])>(&self, mut send: SendFunction, current_time: i64) {
let rekey_e = P384KeyPair::generate();
let mut rekey_buf = [0u8; RekeyInit::SIZE];
let pkt: &mut RekeyInit = byte_array_as_proto_buffer_mut(&mut rekey_buf).unwrap();
pkt.session_protocol_version = SESSION_PROTOCOL_VERSION;
pkt.alice_e = *rekey_e.public_key_bytes();
let state = self.state.read().unwrap();
if let Some(remote_session_id) = state.remote_session_id {
if let Some(key) = state.keys[state.current_key].as_ref() {
if let Some(counter) = self.get_next_outgoing_counter() {
if let Ok(mut gcm) = key.get_send_cipher(counter.get()) {
gcm.reset_init_gcm(&create_message_nonce(PACKET_TYPE_REKEY_INIT, counter.get()));
gcm.crypt_in_place(&mut rekey_buf[RekeyInit::ENC_START..RekeyInit::AUTH_START]);
rekey_buf[RekeyInit::AUTH_START..].copy_from_slice(&gcm.finish_encrypt());
} else {
return;
};
debug_assert!(rekey_buf.len() <= MIN_TRANSPORT_MTU);
set_packet_header(
&mut rekey_buf,
1,
0,
PACKET_TYPE_REKEY_INIT,
u64::from(remote_session_id),
state.current_key,
counter.get(),
);
self.header_protection_cipher
.encrypt_block_in_place(&mut rekey_buf[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
send(&mut rekey_buf);
drop(state);
self.state.write().unwrap().current_offer = Offer::RekeyInit(rekey_e, current_time);
}
}
}
}
/// Get the next outgoing counter value.
#[inline(always)]
fn get_next_outgoing_counter(&self) -> Option<NonZeroU64> {
NonZeroU64::new(self.send_counter.fetch_add(1, Ordering::Relaxed))
}
/// Check the receive window without mutating state.
#[inline(always)]
fn check_receive_window(&self, counter: u64) -> bool {
let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].load(Ordering::Relaxed);
prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD
}
/// Update the receive window, returning true if the packet is still valid.
/// This should only be called after the packet is authenticated.
#[inline(always)]
fn update_receive_window(&self, counter: u64) -> bool {
let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].fetch_max(counter, Ordering::Relaxed);
prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD
}
}
/// Write the 16-byte ZSSP packet header into the first 16 bytes of `packet`.
///
/// Bytes 0..8 hold one little-endian u64 packing the recipient session ID with the
/// per-packet fields. Bytes 8..16 hold the little-endian 64-bit counter. Bit layout:
///
/// * `[0-47]`   recipient session ID
/// * -- start of header check cipher single block encrypt --
/// * `[48-48]`  key index (least significant bit)
/// * `[49-51]`  packet type (0-7, three bits; the same mask `parse_packet_header` applies)
/// * `[52-57]`  fragment count minus one (so 0 means 1 fragment)
/// * `[58-63]`  fragment number (0..63)
/// * `[64-127]` 64-bit counter
#[inline(always)]
fn set_packet_header(
    packet: &mut [u8],
    fragment_count: usize,
    fragment_no: usize,
    packet_type: u8,
    remote_session_id: u64,
    key_index: usize,
    counter: u64,
) {
    debug_assert!(packet.len() >= MIN_PACKET_SIZE);
    debug_assert!(fragment_count > 0);
    debug_assert!(fragment_count <= MAX_FRAGMENTS);
    debug_assert!(fragment_no < MAX_FRAGMENTS);
    debug_assert!(fragment_no < fragment_count);
    // Only bits 49-51 are available: a type of 8 or more would spill into the fragment count.
    debug_assert!(packet_type <= 0x07);
    // The session ID must fit in bits 0-47 or it would clobber the fields ORed in above it.
    debug_assert!(remote_session_id <= 0xffff_ffff_ffff);
    // One up-front check lets the compiler drop the bounds checks on the slices below.
    assert!(packet.len() >= 16);
    packet[0..8].copy_from_slice(
        &(remote_session_id
            | ((key_index & 1) as u64).wrapping_shl(48)
            | (packet_type as u64).wrapping_shl(49)
            | ((fragment_count - 1) as u64).wrapping_shl(52)
            | (fragment_no as u64).wrapping_shl(58))
        .to_le_bytes(),
    );
    packet[8..16].copy_from_slice(&counter.to_le_bytes());
}
/// Decode the non-session-ID fields of a (header-unprotected) packet header.
///
/// Returns `(key_index, packet_type, fragment_count, fragment_no, counter)`. The fragment
/// count is stored minus one on the wire, so the returned value is in `1..=64`.
/// Panics if `incoming_packet` is shorter than 16 bytes.
#[inline(always)]
fn parse_packet_header(incoming_packet: &[u8]) -> (usize, u8, u8, u8, u64) {
    // Bytes 6..8 are bits 48..63 of the first little-endian header word.
    let fields = u16::from_le_bytes([incoming_packet[6], incoming_packet[7]]);
    let counter = u64::from_le_bytes(incoming_packet[8..16].try_into().unwrap());

    let key_index = usize::from(fields & 0x1);
    let packet_type = ((fields >> 1) & 0x7) as u8;
    let fragment_count = (((fields >> 4) & 0x3f) + 1) as u8;
    let fragment_no = (fields >> 10) as u8;

    (key_index, packet_type, fragment_count, fragment_no, counter)
}
/// Break a packet into fragments and send them all.
///
/// The contents of packet[] are mangled during this operation, so it should be discarded after.
/// This is only used for key exchange and control packets. For data packets this is done inline
/// for better performance with encryption and fragmentation happening at the same time.
fn send_with_fragmentation<SendFunction: FnMut(&mut [u8])>(
mut send: SendFunction,
packet: &mut [u8],
mtu: usize,
packet_type: u8,
remote_session_id: Option<SessionId>,
key_index: usize,
counter: u64,
header_protect_cipher: Option<&Aes>,
) -> Result<(), Error> {
let packet_len = packet.len();
let recipient_session_id = remote_session_id.map_or(SessionId::NONE, |s| u64::from(s));
let fragment_count = ((packet_len as f32) / (mtu as f32)).ceil() as usize;
let mut fragment_start = 0;
let mut fragment_end = packet_len.min(mtu);
for fragment_no in 0..fragment_count {
let fragment = &mut packet[fragment_start..fragment_end];
set_packet_header(
fragment,
fragment_count,
fragment_no,
packet_type,
recipient_session_id,
key_index,
counter,
);
if let Some(hcc) = header_protect_cipher {
hcc.encrypt_block_in_place(&mut fragment[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
}
send(fragment);
fragment_start = fragment_end - HEADER_SIZE;
fragment_end = (fragment_start + mtu).min(packet_len);
}
Ok(())
}
/// Assemble a series of fragments into a buffer and return the length of the assembled packet in bytes.
///
/// This is also only used for key exchange and control packets. For data packets decryption and assembly
/// happen in one pass for better performance.
fn assemble_fragments_into<A: ApplicationLayer>(fragments: &[A::IncomingPacketBuffer], d: &mut [u8]) -> Result<usize, Error> {
let mut l = 0;
for i in 0..fragments.len() {
let mut ff = fragments[i].as_ref();
if ff.len() <= MIN_PACKET_SIZE {
return Err(Error::InvalidPacket);
}
if i > 0 {
ff = &ff[HEADER_SIZE..];
}
let j = l + ff.len();
if j > d.len() {
return Err(Error::InvalidPacket);
}
d[l..j].copy_from_slice(ff);
l = j;
}
return Ok(l);
}
impl SessionKey {
fn new<Application: ApplicationLayer>(
key: Secret<BASE_KEY_SIZE>,
ratchet_count: u64,
current_time: i64,
current_counter: u64,
bob: bool,
confirmed: bool,
) -> Self {
let a2b = kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB>(key.as_bytes());
let b2a = kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE>(key.as_bytes());
let (receive_key, send_key) = if bob {
(a2b, b2a)
} else {
(b2a, a2b)
};
let receive_cipher_pool = std::array::from_fn(|_| Mutex::new(AesGcm::new(&receive_key)));
let send_cipher_pool = std::array::from_fn(|_| Mutex::new(AesGcm::new(&send_key)));
Self {
ratchet_key: kbkdf::<BASE_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_RATCHET>(key.as_bytes()),
//receive_key,
//send_key,
receive_cipher_pool,
send_cipher_pool,
rekey_at_time: current_time
.checked_add(
Application::REKEY_AFTER_TIME_MS + ((random::xorshift64_random() as u32) % Application::REKEY_AFTER_TIME_MS_MAX_JITTER) as i64,
)
.unwrap(),
created_at_counter: current_counter,
rekey_at_counter: current_counter.checked_add(Application::REKEY_AFTER_USES).unwrap(),
expire_at_counter: current_counter.checked_add(Application::EXPIRE_AFTER_USES).unwrap(),
ratchet_count,
bob,
confirmed,
}
}
fn get_send_cipher<'a>(&'a self, counter: u64) -> Result<MutexGuard<'a, AesGcm<true>>, Error> {
if counter < self.expire_at_counter {
for i in 0..(AES_POOL_SIZE - 1) {
if let Ok(p) = self.send_cipher_pool[(counter as usize).wrapping_add(i) % AES_POOL_SIZE].try_lock() {
return Ok(p);
}
}
Ok(self.send_cipher_pool[(counter as usize) % AES_POOL_SIZE].lock().unwrap())
} else {
Err(Error::MaxKeyLifetimeExceeded)
}
}
fn get_receive_cipher<'a>(&'a self, counter: u64) -> MutexGuard<'a, AesGcm<false>> {
for i in 0..(AES_POOL_SIZE - 1) {
if let Ok(p) = self.receive_cipher_pool[(counter as usize).wrapping_add(i) % AES_POOL_SIZE].try_lock() {
return p;
}
}
self.receive_cipher_pool[(counter as usize) % AES_POOL_SIZE].lock().unwrap()
}
}
/// Helper code for parsing variable length ALICE_NOISE_XK_ACK during negotiation.
struct PktReader<'a>(&'a mut [u8], usize);
impl<'a> PktReader<'a> {
fn read_u16(&mut self) -> Result<u16, Error> {
let tmp = self.1 + 2;
if tmp <= self.0.len() {
let n = u16::from_le_bytes(self.0[self.1..tmp].try_into().unwrap());
self.1 = tmp;
Ok(n)
} else {
Err(Error::InvalidPacket)
}
}
fn read_decrypt_auth<'b>(&'b mut self, l: usize, k: Secret<AES_256_KEY_SIZE>, gcm_aad: &[u8], nonce: &[u8]) -> Result<&'b [u8], Error> {
let mut tmp = self.1 + l;
if (tmp + AES_GCM_TAG_SIZE) <= self.0.len() {
let mut gcm = AesGcm::new(&k);
gcm.reset_init_gcm(nonce);
gcm.aad(gcm_aad);
gcm.crypt_in_place(&mut self.0[self.1..tmp]);
let s = &self.0[self.1..tmp];
self.1 = tmp;
tmp += AES_GCM_TAG_SIZE;
if !gcm.finish_decrypt(&self.0[self.1..tmp]) {
Err(Error::FailedAuthentication)
} else {
self.1 = tmp;
Ok(s)
}
} else {
Err(Error::InvalidPacket)
}
}
}
/// Helper function to append to a slice when we still want to be able to look back at it.
fn append_to_slice(s: &mut [u8], p: usize, d: &[u8]) -> Result<usize, Error> {
let tmp = p + d.len();
if tmp <= s.len() {
s[p..tmp].copy_from_slice(d);
Ok(tmp)
} else {
Err(Error::UnexpectedBufferOverrun)
}
}
/// Noise-style MixHash: returns SHA384(`h` || `m`), the next handshake hash 'h'.
fn mix_hash(h: &[u8; SHA384_HASH_SIZE], m: &[u8]) -> [u8; SHA384_HASH_SIZE] {
    let mut state = SHA384::new();
    for part in [&h[..], m] {
        state.update(part);
    }
    state.finish()
}
/// HMAC-SHA512 key derivation based on: https://csrc.nist.gov/publications/detail/sp/800-108/final (page 7)
/// Cryptographically this isn't meaningfully different from HMAC(key, [label]) but this is how NIST rolls.
///
/// The HMAC message is the 8-byte fixed input `[i] || Label || 0x00 || Context || [L]`, where
/// `i = 0x01`, `Label = 'Z' 'T' LABEL`, `Context = 0x00`, and `L` is the output length in
/// bits as a 16-bit big-endian value.
fn kbkdf<const OUTPUT_BYTES: usize, const LABEL: u8>(key: &[u8]) -> Secret<OUTPUT_BYTES> {
    let [bits_hi, bits_lo] = ((OUTPUT_BYTES * 8) as u16).to_be_bytes();
    let fixed_input: [u8; 8] = [
        0x01,      // i (counter)
        b'Z',      // Label...
        b'T',
        LABEL,
        0x00,      // separator
        0x00,      // Context
        bits_hi,   // L (output length in bits)...
        bits_lo,
    ];
    hmac_sha512_secret(key, &fixed_input)
}
/// Cheap non-cryptographic 32-bit mixing PRNG step.
///
/// Based on lowbias32 from https://nullprogram.com/blog/2018/07/31/. The input is incremented
/// first so that a zero seed does not map to itself.
fn prng32(mut x: u32) -> u32 {
    let mut mixed = x.wrapping_add(1);
    mixed = (mixed ^ (mixed >> 16)).wrapping_mul(0x7feb352d);
    mixed = (mixed ^ (mixed >> 15)).wrapping_mul(0x846ca68b);
    x = mixed ^ (mixed >> 16);
    x
}