docs, comments, readable code

This commit is contained in:
Adam Ierymenko 2022-11-18 13:41:44 -05:00
parent 1d23f40226
commit 7522282c2e
5 changed files with 208 additions and 131 deletions

View file

@ -426,7 +426,7 @@ impl Controller {
}
if member_changed {
self.database.save_member(member).await?;
self.database.save_member(member, false).await?;
}
Ok((authorization_result, network_config, revocations))

View file

@ -23,11 +23,11 @@ pub enum Change {
pub trait Database: Sync + Send + NodeStorage + 'static {
async fn list_networks(&self) -> Result<Vec<NetworkId>, Box<dyn Error + Send + Sync>>;
async fn get_network(&self, id: NetworkId) -> Result<Option<Network>, Box<dyn Error + Send + Sync>>;
async fn save_network(&self, obj: Network) -> Result<(), Box<dyn Error + Send + Sync>>;
async fn save_network(&self, obj: Network, generate_change_notification: bool) -> Result<(), Box<dyn Error + Send + Sync>>;
async fn list_members(&self, network_id: NetworkId) -> Result<Vec<Address>, Box<dyn Error + Send + Sync>>;
async fn get_member(&self, network_id: NetworkId, node_id: Address) -> Result<Option<Member>, Box<dyn Error + Send + Sync>>;
async fn save_member(&self, obj: Member) -> Result<(), Box<dyn Error + Send + Sync>>;
async fn save_member(&self, obj: Member, generate_change_notification: bool) -> Result<(), Box<dyn Error + Send + Sync>>;
/// Get a receiver that can be used to receive changes made to networks and members, if supported.
///

View file

@ -30,6 +30,9 @@ const EVENT_HANDLER_TASK_TIMEOUT: Duration = Duration::from_secs(5);
/// A cache is maintained that contains the actual objects. When an object is live edited,
/// once it successfully reads and loads it is merged with the cached object and saved to
/// the cache. The cache will also contain any ephemeral data, generated data, etc.
///
/// The file format is YAML instead of JSON for better human friendliness and the layout
/// is different from V1 so it'll need a converter to use with V1 FileDb controller data.
pub struct FileDatabase {
base_path: PathBuf,
controller_address: AtomicU64,
@ -332,14 +335,17 @@ impl Database for FileDatabase {
let network_id_should_be = network.id.change_network_controller(controller_address);
if network.id != network_id_should_be {
network.id = network_id_should_be;
let _ = self.save_network(network.clone()).await?;
let _ = self.save_network(network.clone(), false).await?;
}
}
}
Ok(network)
}
async fn save_network(&self, obj: Network) -> Result<(), Box<dyn Error + Send + Sync>> {
async fn save_network(&self, obj: Network, generate_change_notification: bool) -> Result<(), Box<dyn Error + Send + Sync>> {
if !generate_change_notification {
let _ = self.cache.on_network_updated(obj.clone());
}
let base_network_path = self.network_path(obj.id);
let _ = fs::create_dir_all(base_network_path.parent().unwrap()).await;
let _ = fs::write(base_network_path, serde_yaml::to_string(&obj)?.as_bytes()).await?;
@ -374,13 +380,16 @@ impl Database for FileDatabase {
if member.network_id != network_id {
// Also auto-update member network IDs, see get_network().
member.network_id = network_id;
self.save_member(member.clone()).await?;
self.save_member(member.clone(), false).await?;
}
}
Ok(member)
}
async fn save_member(&self, obj: Member) -> Result<(), Box<dyn Error + Send + Sync>> {
async fn save_member(&self, obj: Member, generate_change_notification: bool) -> Result<(), Box<dyn Error + Send + Sync>> {
if !generate_change_notification {
let _ = self.cache.on_member_updated(obj.clone());
}
let base_member_path = self.member_path(obj.network_id, obj.node_id);
let _ = fs::create_dir_all(base_member_path.parent().unwrap()).await;
let _ = fs::write(base_member_path, serde_yaml::to_string(&obj)?.as_bytes()).await?;
@ -401,6 +410,7 @@ impl Database for FileDatabase {
mod tests {
#[allow(unused_imports)]
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
#[allow(unused)]
#[test]
@ -420,11 +430,15 @@ mod tests {
db.save_node_identity(&controller_id);
assert!(db.load_node_identity().is_some());
let change_count = Arc::new(AtomicUsize::new(0));
let db2 = db.clone();
let change_count2 = change_count.clone();
tokio_runtime.spawn(async move {
let mut change_receiver = db2.changes().await.unwrap();
loop {
if let Ok(change) = change_receiver.recv().await {
change_count2.fetch_add(1, Ordering::SeqCst);
//println!("[FileDatabase] {:#?}", change);
} else {
break;
@ -433,14 +447,16 @@ mod tests {
});
let mut test_network = Network::new(network_id);
db.save_network(test_network.clone()).await.expect("network save error");
db.save_network(test_network.clone(), true).await.expect("network save error");
let mut test_member = Member::new_without_identity(node_id, network_id);
for x in 0..3 {
test_member.name = x.to_string();
db.save_member(test_member.clone()).await.expect("member save error");
db.save_member(test_member.clone(), true).await.expect("member save error");
zerotier_utils::tokio::task::yield_now().await;
sleep(Duration::from_millis(100)).await;
zerotier_utils::tokio::task::yield_now().await;
let test_member2 = db.get_member(network_id, node_id).await.unwrap().unwrap();
assert!(test_member == test_member2);

View file

@ -1,6 +1,8 @@
ZeroTier Secure Socket Protocol
======
**NOTE: this protocol and code have not yet been formally audited and should not be used in anything production.**
ZSSP (ZeroTier Secure Socket Protocol) is an implementation of the Noise_IK pattern using FIPS/NIST compliant primitives. After Noise_IK negotiation is complete ZSSP also adds key ratcheting and optional (enabled by default) support for quantum data forward secrecy with Kyber1024.
It's general purpose and could be used with any system but contains a few specific design choices to make it optimal for ZeroTier and easy to distinguish from legacy ZeroTier V1 traffic for backward compatibility.

View file

@ -20,10 +20,10 @@ use zerotier_utils::ringbuffermap::RingBufferMap;
use zerotier_utils::unlikely_branch;
use zerotier_utils::varint;
/// Minimum size of a valid packet.
/// Minimum size of a valid physical packet.
pub const MIN_PACKET_SIZE: usize = HEADER_SIZE + AES_GCM_TAG_SIZE;
/// Minimum wire MTU for ZSSP to function normally.
/// Minimum physical MTU for ZSSP to function.
pub const MIN_TRANSPORT_MTU: usize = 1280;
/// Minimum recommended interval between calls to service() on each session, in milliseconds.
@ -222,6 +222,10 @@ pub enum ReceiveResult<'a, H: Host> {
pub struct SessionId(u64);
impl SessionId {
/// The nil session ID used in messages initiating a new session.
///
/// This is all 1's so that ZeroTier can easily tell the difference between ZSSP init packets
/// and ZeroTier V1 packets.
pub const NIL: SessionId = SessionId(0xffffffffffff);
#[inline]
@ -242,7 +246,7 @@ impl SessionId {
#[inline]
pub fn new_random() -> Self {
Self(random::next_u64_secure() % (Self::NIL.0 - 1))
Self(random::next_u64_secure() % Self::NIL.0)
}
}
@ -331,13 +335,16 @@ pub trait Host: Sized {
/// ZSSP bi-directional packet transport channel.
pub struct Session<H: Host> {
/// This side's session ID (unique on this side)
pub id: SessionId,
/// An arbitrary object associated with session (type defined in Host trait)
pub associated_object: H::AssociatedObject,
send_counter: Counter,
send_counter: Counter, // Outgoing packet counter and nonce state
psk: Secret<64>, // Arbitrary PSK provided by external code
ss: Secret<48>, // NIST P-384 raw ECDH key agreement with peer
header_check_cipher: Aes, // Cipher used for fast 32-bit header MAC
ss: Secret<48>, // Static raw shared ECDH NIST P-384 key
header_check_cipher: Aes, // Cipher used for header MAC (fragmentation)
state: RwLock<SessionMutableState>, // Mutable parts of state (other than defrag buffers)
remote_s_public_hash: [u8; 48], // SHA384(remote static public key blob)
remote_s_public_p384: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key
@ -345,11 +352,11 @@ pub struct Session<H: Host> {
}
struct SessionMutableState {
remote_session_id: Option<SessionId>,
keys: [Option<SessionKey>; KEY_HISTORY_SIZE],
key_ptr: usize,
offer: Option<Box<EphemeralOffer>>,
last_remote_offer: i64,
remote_session_id: Option<SessionId>, // The other side's 48-bit session ID
keys: [Option<SessionKey>; KEY_HISTORY_SIZE], // Buffers to store current, next, and last active key
key_ptr: usize, // Pointer used for keys[] circular buffer
offer: Option<Box<EphemeralOffer>>, // Most recent ephemeral offer sent to remote
last_remote_offer: i64, // Time of most recent ephemeral offer (ms)
}
impl<H: Host> Session<H> {
@ -380,7 +387,8 @@ impl<H: Host> Session<H> {
let remote_s_public_hash = SHA384::hash(remote_s_public);
let outgoing_init_header_check_cipher =
Aes::new(kbkdf512(&remote_s_public_hash, KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<16>());
if let Ok(offer) = create_initial_offer(
if let Ok(offer) = send_ephemeral_offer(
&mut send,
send_counter.next(),
local_session_id,
@ -421,8 +429,9 @@ impl<H: Host> Session<H> {
/// Send data over the session.
///
/// * `mtu_buffer` - A writable work buffer whose size must be equal to the wire MTU
/// * `mtu_buffer` - A writable work buffer whose size must be equal to the physical MTU
/// * `data` - Data to send
#[inline]
pub fn send<SendFunction: FnMut(&mut [u8])>(
&self,
mut send: SendFunction,
@ -433,9 +442,13 @@ impl<H: Host> Session<H> {
let state = self.state.read().unwrap();
if let Some(remote_session_id) = state.remote_session_id {
if let Some(key) = state.keys[state.key_ptr].as_ref() {
// Total size of the armored packet we are going to send (may end up being fragmented)
let mut packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE;
// This outgoing packet's nonce counter value.
let counter = self.send_counter.next();
// Create initial header for first fragment of packet and place in first HEADER_SIZE bytes of buffer.
create_packet_header(
mtu_buffer,
packet_len,
@ -445,13 +458,12 @@ impl<H: Host> Session<H> {
counter,
)?;
// Get an initialized AES-GCM cipher and re-initialize with a 96-bit IV built from remote session ID,
// packet type, and counter.
let mut c = key.get_send_cipher(counter)?;
c.init(memory::as_byte_array::<Pseudoheader, 12>(&Pseudoheader::make(
remote_session_id.into(),
PACKET_TYPE_DATA,
counter.to_u32(),
)));
c.init(CanonicalHeader::make(remote_session_id, PACKET_TYPE_DATA, counter.to_u32()).as_bytes());
// Send first N-1 fragments of N total fragments.
if packet_len > mtu_buffer.len() {
let mut header: [u8; 16] = mtu_buffer[..HEADER_SIZE].try_into().unwrap();
let fragment_data_mtu = mtu_buffer.len() - HEADER_SIZE;
@ -461,7 +473,7 @@ impl<H: Host> Session<H> {
let fragment_size = fragment_data_size + HEADER_SIZE;
c.crypt(&data[..fragment_data_size], &mut mtu_buffer[HEADER_SIZE..fragment_size]);
data = &data[fragment_data_size..];
set_header_mac(mtu_buffer, &self.header_check_cipher);
set_header_check_code(mtu_buffer, &self.header_check_cipher);
send(&mut mtu_buffer[..fragment_size]);
debug_assert!(header[15].wrapping_shr(2) < 63);
@ -475,17 +487,22 @@ impl<H: Host> Session<H> {
packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE;
}
// Send final fragment (or only fragment if no fragmentation was needed)
let gcm_tag_idx = data.len() + HEADER_SIZE;
c.crypt(data, &mut mtu_buffer[HEADER_SIZE..gcm_tag_idx]);
mtu_buffer[gcm_tag_idx..packet_len].copy_from_slice(&c.finish_encrypt());
set_header_mac(mtu_buffer, &self.header_check_cipher);
set_header_check_code(mtu_buffer, &self.header_check_cipher);
send(&mut mtu_buffer[..packet_len]);
// Check reusable AES-GCM instance back into pool.
key.return_send_cipher(c);
return Ok(());
} else {
unlikely_branch();
}
} else {
unlikely_branch();
}
return Err(Error::SessionNotEstablished);
}
@ -503,7 +520,7 @@ impl<H: Host> Session<H> {
pub fn security_info(&self) -> Option<([u8; 16], i64, u64, bool)> {
let state = self.state.read().unwrap();
if let Some(key) = state.keys[state.key_ptr].as_ref() {
Some((key.fingerprint, key.establish_time, key.ratchet_count, key.jedi))
Some((key.secret_fingerprint, key.establish_time, key.ratchet_count, key.jedi))
} else {
None
}
@ -528,7 +545,7 @@ impl<H: Host> Session<H> {
if (force_rekey
|| state.keys[state.key_ptr]
.as_ref()
.map_or(true, |key| key.lifetime.should_rekey(self.send_counter.current(), current_time)))
.map_or(true, |key| key.lifetime.should_rekey(self.send_counter.previous(), current_time)))
&& state
.offer
.as_ref()
@ -536,7 +553,7 @@ impl<H: Host> Session<H> {
{
if let Some(remote_s_public_p384) = P384PublicKey::from_bytes(&self.remote_s_public_p384) {
let mut tmp_header_check_cipher = None;
if let Ok(offer) = create_initial_offer(
if let Ok(offer) = send_ephemeral_offer(
&mut send,
self.send_counter.next(),
self.id,
@ -602,13 +619,13 @@ impl<H: Host> ReceiveContext<H> {
let packet_type_fragment_info = u16::from_le(memory::load_raw(&incoming_packet[14..16]));
let packet_type = (packet_type_fragment_info & 0x0f) as u8;
let fragment_count = ((packet_type_fragment_info.wrapping_shr(4) + 1) as u8) & 63;
let fragment_no = packet_type_fragment_info.wrapping_shr(10) as u8;
let fragment_no = packet_type_fragment_info.wrapping_shr(10) as u8; // & 63 not needed
if let Some(local_session_id) = SessionId::new_from_u64(u64::from_le(memory::load_raw(&incoming_packet[8..16])) & 0xffffffffffffu64)
{
if let Some(session) = host.session_lookup(local_session_id) {
if check_header_mac(incoming_packet, &session.header_check_cipher) {
let pseudoheader = Pseudoheader::make(u64::from(local_session_id), packet_type, counter);
if verify_header_check_code(incoming_packet, &session.header_check_cipher) {
let canonical_header = CanonicalHeader::make(local_session_id, packet_type, counter);
if fragment_count > 1 {
if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count {
let mut defrag = session.defrag.lock().unwrap();
@ -620,7 +637,7 @@ impl<H: Host> ReceiveContext<H> {
remote_address,
&mut send,
data_buf,
memory::as_byte_array(&pseudoheader),
canonical_header.as_bytes(),
assembled_packet.as_ref(),
packet_type,
Some(session),
@ -638,7 +655,7 @@ impl<H: Host> ReceiveContext<H> {
remote_address,
&mut send,
data_buf,
memory::as_byte_array(&pseudoheader),
canonical_header.as_bytes(),
&[incoming_packet_buf],
packet_type,
Some(session),
@ -655,9 +672,10 @@ impl<H: Host> ReceiveContext<H> {
return Err(Error::UnknownLocalSessionId(local_session_id));
}
} else {
unlikely_branch();
if check_header_mac(incoming_packet, &self.incoming_init_header_check_cipher) {
let pseudoheader = Pseudoheader::make(SessionId::NIL.0, packet_type, counter);
unlikely_branch(); // we want data receive to be the priority branch, this is only occasionally used
if verify_header_check_code(incoming_packet, &self.incoming_init_header_check_cipher) {
let canonical_header = CanonicalHeader::make(SessionId::NIL, packet_type, counter);
if fragment_count > 1 {
let mut defrag = self.initial_offer_defrag.lock().unwrap();
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count));
@ -668,7 +686,7 @@ impl<H: Host> ReceiveContext<H> {
remote_address,
&mut send,
data_buf,
memory::as_byte_array(&pseudoheader),
canonical_header.as_bytes(),
assembled_packet.as_ref(),
packet_type,
None,
@ -682,7 +700,7 @@ impl<H: Host> ReceiveContext<H> {
remote_address,
&mut send,
data_buf,
memory::as_byte_array(&pseudoheader),
canonical_header.as_bytes(),
&[incoming_packet_buf],
packet_type,
None,
@ -699,13 +717,15 @@ impl<H: Host> ReceiveContext<H> {
return Ok(ReceiveResult::Ok);
}
/// Called internally when all fragments of a packet are received.
/// Header check codes will already have been validated for each fragment.
fn receive_complete<'a, SendFunction: FnMut(&mut [u8])>(
&self,
host: &H,
remote_address: &H::RemoteAddress,
send: &mut SendFunction,
data_buf: &'a mut [u8],
pseudoheader: &[u8; 12],
canonical_header_bytes: &[u8; 12],
fragments: &[H::IncomingPacketBuffer],
packet_type: u8,
session: Option<H::SessionRef>,
@ -714,8 +734,10 @@ impl<H: Host> ReceiveContext<H> {
) -> Result<ReceiveResult<'a, H>, Error> {
debug_assert!(fragments.len() >= 1);
// These just confirm that the first 'if' below does what it should do.
debug_assert_eq!(PACKET_TYPE_DATA, 0);
debug_assert_eq!(PACKET_TYPE_NOP, 1);
if packet_type <= PACKET_TYPE_NOP {
if let Some(session) = session {
let state = session.state.read().unwrap();
@ -729,7 +751,7 @@ impl<H: Host> ReceiveContext<H> {
}
let mut c = key.get_receive_cipher();
c.init(pseudoheader);
c.init(canonical_header_bytes);
let mut data_len = 0;
@ -797,7 +819,9 @@ impl<H: Host> ReceiveContext<H> {
} else {
unlikely_branch();
let mut incoming_packet_buf = [0_u8; 4096]; // big enough for key exchange packets
// To greatly simplify logic handling key exchange packets, assemble these first.
// This adds some extra memory copying but this is not the fast path.
let mut incoming_packet_buf = [0_u8; 4096];
let mut incoming_packet_len = 0;
for i in 0..fragments.len() {
let mut ff = fragments[i].as_ref();
@ -833,7 +857,7 @@ impl<H: Host> ReceiveContext<H> {
// Check that the sender knows this host's identity before doing anything else.
if !hmac_sha384_2(
host.get_local_s_public_hash(),
pseudoheader,
canonical_header_bytes,
&incoming_packet[HEADER_SIZE..hmac1_end],
)
.eq(&incoming_packet[hmac1_end..])
@ -866,7 +890,7 @@ impl<H: Host> ReceiveContext<H> {
kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(),
false,
);
c.init(pseudoheader);
c.init(canonical_header_bytes);
c.crypt_in_place(&mut incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]);
if !c.finish_decrypt(&incoming_packet[payload_end..aes_gcm_tag_end]) {
return Err(Error::FailedAuthentication);
@ -898,7 +922,7 @@ impl<H: Host> ReceiveContext<H> {
// just mixed into the key.
if !hmac_sha384_2(
kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(),
pseudoheader,
canonical_header_bytes,
&original_ciphertext[HEADER_SIZE..aes_gcm_tag_end],
)
.eq(&incoming_packet[aes_gcm_tag_end..hmac1_end])
@ -923,7 +947,7 @@ impl<H: Host> ReceiveContext<H> {
let state = session.state.read().unwrap();
for k in state.keys.iter() {
if let Some(k) = k.as_ref() {
if key_fingerprint(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_fingerprint) {
if secret_fingerprint(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_fingerprint) {
ratchet_key = Some(k.ratchet_key.clone());
ratchet_count = k.ratchet_count;
break;
@ -1044,8 +1068,8 @@ impl<H: Host> ReceiveContext<H> {
alice_session_id.into(),
reply_counter,
)?;
let reply_pseudoheader =
Pseudoheader::make(alice_session_id.into(), PACKET_TYPE_KEY_COUNTER_OFFER, reply_counter.to_u32());
let reply_canonical_header =
CanonicalHeader::make(alice_session_id.into(), PACKET_TYPE_KEY_COUNTER_OFFER, reply_counter.to_u32());
// Encrypt reply packet using final Noise_IK key BEFORE mixing hybrid or ratcheting, since the other side
// must decrypt before doing these things.
@ -1053,7 +1077,7 @@ impl<H: Host> ReceiveContext<H> {
kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(),
true,
);
c.init(memory::as_byte_array::<Pseudoheader, 12>(&reply_pseudoheader));
c.init(reply_canonical_header.as_bytes());
c.crypt_in_place(&mut reply_buf[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..reply_len]);
let c = c.finish_encrypt();
reply_buf[reply_len..(reply_len + AES_GCM_TAG_SIZE)].copy_from_slice(&c);
@ -1073,7 +1097,7 @@ impl<H: Host> ReceiveContext<H> {
// Kyber exchange, but you'd need a not-yet-existing quantum computer for that.
let hmac = hmac_sha384_2(
kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(),
memory::as_byte_array::<Pseudoheader, 12>(&reply_pseudoheader),
reply_canonical_header.as_bytes(),
&reply_buf[HEADER_SIZE..reply_len],
);
reply_buf[reply_len..(reply_len + HMAC_SIZE)].copy_from_slice(&hmac);
@ -1131,7 +1155,7 @@ impl<H: Host> ReceiveContext<H> {
kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(),
false,
);
c.init(pseudoheader);
c.init(canonical_header_bytes);
c.crypt_in_place(&mut incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]);
if !c.finish_decrypt(&incoming_packet[payload_end..aes_gcm_tag_end]) {
return Err(Error::FailedAuthentication);
@ -1167,7 +1191,7 @@ impl<H: Host> ReceiveContext<H> {
if !hmac_sha384_2(
kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(),
pseudoheader,
canonical_header_bytes,
&original_ciphertext[HEADER_SIZE..aes_gcm_tag_end],
)
.eq(&incoming_packet[aes_gcm_tag_end..incoming_packet.len()])
@ -1191,15 +1215,11 @@ impl<H: Host> ReceiveContext<H> {
)?;
let mut c = key.get_send_cipher(counter)?;
c.init(memory::as_byte_array::<Pseudoheader, 12>(&Pseudoheader::make(
bob_session_id.into(),
PACKET_TYPE_NOP,
counter.to_u32(),
)));
c.init(CanonicalHeader::make(bob_session_id.into(), PACKET_TYPE_NOP, counter.to_u32()).as_bytes());
reply_buf[HEADER_SIZE..].copy_from_slice(&c.finish_encrypt());
key.return_send_cipher(c);
set_header_mac(&mut reply_buf, &session.header_check_cipher);
set_header_check_code(&mut reply_buf, &session.header_check_cipher);
send(&mut reply_buf);
drop(state);
@ -1224,19 +1244,24 @@ impl<H: Host> ReceiveContext<H> {
}
}
/// Outgoing packet counter with strictly ordered atomic semantics.
struct Counter(AtomicU64);
impl Counter {
#[inline(always)]
fn new() -> Self {
// Using a random value has no security implication. Zero would be fine. This just
// helps randomize packet contents a bit.
Self(AtomicU64::new(random::next_u32_secure() as u64))
}
/// Get the value most recently used to send a packet.
#[inline(always)]
fn current(&self) -> CounterValue {
fn previous(&self) -> CounterValue {
CounterValue(self.0.load(Ordering::SeqCst))
}
/// Get a counter value for the next packet being sent.
#[inline(always)]
fn next(&self) -> CounterValue {
CounterValue(self.0.fetch_add(1, Ordering::SeqCst))
@ -1245,9 +1270,11 @@ impl Counter {
/// A value of the outgoing packet counter.
///
/// The counter is internally 64-bit so we can more easily track usage limits without
/// confusing logic to handle 32-bit wrapping. The least significant 32 bits are the
/// actual counter put in the packet.
/// The used portion of the packet counter is the least significant 32 bits, but the internal
/// counter state is kept as a 64-bit integer. This makes it easier to correctly handle
/// key expiration after usage limits are reached without complicated logic to handle 32-bit
/// wrapping. Usage limits are below 2^32 so the actual 32-bit counter will not wrap for a
/// given shared secret key.
#[repr(transparent)]
#[derive(Copy, Clone)]
struct CounterValue(u64);
@ -1259,30 +1286,44 @@ impl CounterValue {
}
}
/// Temporary object to construct a "pseudo-header" for AES-GCM nonce and HMAC calculation.
/// "Canonical header" for generating 96-bit AES-GCM nonce and for inclusion in HMACs.
///
/// This is basically the actual header but with fragment count and fragment total set to zero.
/// Fragmentation is not considered when authenticating the entire packet. A separate header
/// check code is used to make fragmentation itself more robust, but that's outside the scope
/// of AEAD authentication.
#[derive(Clone, Copy)]
#[repr(C, packed)]
struct Pseudoheader(u64, u32);
struct CanonicalHeader(u64, u32);
impl Pseudoheader {
impl CanonicalHeader {
#[inline(always)]
pub fn make(session_id: u64, packet_type: u8, counter: u32) -> Self {
Pseudoheader((session_id | (packet_type as u64)).to_le(), counter.to_le())
pub fn make(session_id: SessionId, packet_type: u8, counter: u32) -> Self {
CanonicalHeader(
(u64::from(session_id) | (packet_type as u64).wrapping_shl(48)).to_le(),
counter.to_le(),
)
}
#[inline(always)]
pub fn as_bytes(&self) -> &[u8; 12] {
memory::as_byte_array(self)
}
}
/// Ephemeral offer sent with KEY_OFFER and rememebered so state can be reconstructed on COUNTER_OFFER.
/// Alice's KEY_OFFER, remembered so Noise agreement process can resume on KEY_COUNTER_OFFER.
struct EphemeralOffer {
id: [u8; 16],
creation_time: i64,
ratchet_count: u64,
ratchet_key: Option<Secret<64>>,
key: Secret<64>,
alice_e0_keypair: P384KeyPair,
alice_e1_keypair: Option<pqc_kyber::Keypair>,
id: [u8; 16], // Arbitrary random offer ID
creation_time: i64, // Local time when offer was created
ratchet_count: u64, // Ratchet count starting at zero for initial offer
ratchet_key: Option<Secret<64>>, // Ratchet key from previous offer
key: Secret<64>, // Shared secret in-progress, at state after offer sent
alice_e0_keypair: P384KeyPair, // NIST P-384 key pair (Noise ephemeral key for Alice)
alice_e1_keypair: Option<pqc_kyber::Keypair>, // Kyber1024 key pair (agreement result mixed post-Noise)
}
fn create_initial_offer<SendFunction: FnMut(&mut [u8])>(
/// Create and send an ephemeral offer, returning the EphemeralOffer part that must be saved.
fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
send: &mut SendFunction,
counter: CounterValue,
alice_session_id: SessionId,
@ -1297,26 +1338,30 @@ fn create_initial_offer<SendFunction: FnMut(&mut [u8])>(
mtu: usize,
current_time: i64,
) -> Result<Box<EphemeralOffer>, Error> {
// Generate a NIST P-384 pair.
let alice_e0_keypair = P384KeyPair::generate();
let e0s = alice_e0_keypair.agree(bob_s_public_p384);
if e0s.is_none() {
return Err(Error::InvalidPacket);
}
// Perform key agreement with the other side's static P-384 public key.
let e0s = alice_e0_keypair.agree(bob_s_public_p384).ok_or(Error::InvalidPacket)?;
// Generate a Kyber1024 pair if enabled.
let alice_e1_keypair = if JEDI {
Some(pqc_kyber::keypair(&mut random::SecureRandom::get()))
} else {
None
};
// Get ratchet key for current key if one exists.
let (ratchet_key, ratchet_count) = if let Some(current_key) = current_key {
(Some(current_key.ratchet_key.clone()), current_key.ratchet_count)
} else {
(None, 0)
};
// Random ephemeral offer ID
let id: [u8; 16] = random::get_bytes_secure();
// Create ephemeral offer packet (not fragmented yet).
const PACKET_BUF_SIZE: usize = MIN_TRANSPORT_MTU * KEY_EXCHANGE_MAX_FRAGMENTS;
let mut packet_buf = [0_u8; PACKET_BUF_SIZE];
let mut packet_len = {
@ -1339,7 +1384,7 @@ fn create_initial_offer<SendFunction: FnMut(&mut [u8])>(
}
if let Some(ratchet_key) = ratchet_key.as_ref() {
p.write_all(&[0x01])?;
p.write_all(&key_fingerprint(ratchet_key.as_bytes())[..16])?;
p.write_all(&secret_fingerprint(ratchet_key.as_bytes())[..16])?;
} else {
p.write_all(&[0x00])?;
}
@ -1347,42 +1392,44 @@ fn create_initial_offer<SendFunction: FnMut(&mut [u8])>(
PACKET_BUF_SIZE - p.len()
};
let bob_session_id: u64 = bob_session_id.map_or(SessionId::NIL.0, |i| i.into());
create_packet_header(&mut packet_buf, packet_len, mtu, PACKET_TYPE_KEY_OFFER, bob_session_id, counter)?;
let pseudoheader = Pseudoheader::make(bob_session_id, PACKET_TYPE_KEY_OFFER, counter.to_u32());
// Create ephemeral agreement secret.
let key = Secret(hmac_sha512(
&hmac_sha512(&INITIAL_KEY, alice_e0_keypair.public_key_bytes()),
e0s.unwrap().as_bytes(),
e0s.as_bytes(),
));
let bob_session_id = bob_session_id.unwrap_or(SessionId::NIL);
create_packet_header(&mut packet_buf, packet_len, mtu, PACKET_TYPE_KEY_OFFER, bob_session_id, counter)?;
let canonical_header = CanonicalHeader::make(bob_session_id, PACKET_TYPE_KEY_OFFER, counter.to_u32());
// Encrypt packet and attach AES-GCM tag.
let gcm_tag = {
let mut c = AesGcm::new(
kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(),
true,
);
c.init(memory::as_byte_array::<Pseudoheader, 12>(&pseudoheader));
c.init(canonical_header.as_bytes());
c.crypt_in_place(&mut packet_buf[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..packet_len]);
c.finish_encrypt()
};
packet_buf[packet_len..(packet_len + AES_GCM_TAG_SIZE)].copy_from_slice(&gcm_tag);
packet_len += AES_GCM_TAG_SIZE;
// Mix in static secret.
let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes()));
// HMAC packet using static + ephemeral key.
let hmac = hmac_sha384_2(
kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(),
memory::as_byte_array::<Pseudoheader, 12>(&pseudoheader),
canonical_header.as_bytes(),
&packet_buf[HEADER_SIZE..packet_len],
);
packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac);
packet_len += HMAC_SIZE;
let hmac = hmac_sha384_2(
bob_s_public_hash,
memory::as_byte_array::<Pseudoheader, 12>(&pseudoheader),
&packet_buf[HEADER_SIZE..packet_len],
);
// Add secondary HMAC to verify that the caller knows the recipient's full static public identity.
let hmac = hmac_sha384_2(bob_s_public_hash, canonical_header.as_bytes(), &packet_buf[HEADER_SIZE..packet_len]);
packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac);
packet_len += HMAC_SIZE;
@ -1399,13 +1446,14 @@ fn create_initial_offer<SendFunction: FnMut(&mut [u8])>(
}))
}
/// Populate all but the header check code in the first 16 bytes of a packet or fragment.
#[inline(always)]
fn create_packet_header(
header: &mut [u8],
packet_len: usize,
mtu: usize,
packet_type: u8,
recipient_session_id: u64,
recipient_session_id: SessionId,
counter: CounterValue,
) -> Result<(), Error> {
let fragment_count = ((packet_len as f32) / (mtu - HEADER_SIZE) as f32).ceil() as usize;
@ -1414,14 +1462,21 @@ fn create_packet_header(
debug_assert!(mtu >= MIN_TRANSPORT_MTU);
debug_assert!(packet_len >= MIN_PACKET_SIZE);
debug_assert!(fragment_count > 0);
debug_assert!(fragment_count <= MAX_FRAGMENTS);
debug_assert!(packet_type <= 0x0f); // packet type is 4 bits
debug_assert!(recipient_session_id <= 0xffffffffffff); // session ID is 48 bits
if fragment_count <= MAX_FRAGMENTS {
// CCCC____IIIIIITF
// Header indexed by bit:
// [0-31] counter
// [32-63] header check code (computed later)
// [64-111] recipient's session ID (unique on their side)
// [112-115] packet type (0-15)
// [116-121] number of fragments (0..63 for 1..64 fragments total)
// [122-127] fragment number (0, 1, 2, ...)
memory::store_raw((counter.to_u32() as u64).to_le(), header);
memory::store_raw(
(recipient_session_id | (packet_type as u64).wrapping_shl(48) | ((fragment_count - 1) as u64).wrapping_shl(52)).to_le(),
(u64::from(recipient_session_id) | (packet_type as u64).wrapping_shl(48) | ((fragment_count - 1) as u64).wrapping_shl(52))
.to_le(),
&mut header[8..],
);
Ok(())
@ -1431,6 +1486,7 @@ fn create_packet_header(
}
}
/// Break a packet into fragments and send them all.
fn send_with_fragmentation<SendFunction: FnMut(&mut [u8])>(
send: &mut SendFunction,
packet: &mut [u8],
@ -1443,7 +1499,7 @@ fn send_with_fragmentation<SendFunction: FnMut(&mut [u8])>(
let mut header: [u8; 16] = packet[..HEADER_SIZE].try_into().unwrap();
loop {
let fragment = &mut packet[fragment_start..fragment_end];
set_header_mac(fragment, header_check_cipher);
set_header_check_code(fragment, header_check_cipher);
send(fragment);
if fragment_end < packet_len {
debug_assert!(header[15].wrapping_shr(2) < 63);
@ -1458,17 +1514,18 @@ fn send_with_fragmentation<SendFunction: FnMut(&mut [u8])>(
}
}
/// Set 32-bit header MAC.
#[inline(always)]
fn set_header_mac(packet: &mut [u8], header_check_cipher: &Aes) {
/// Set 32-bit header check code, used to make fragmentation mechanism robust.
#[inline]
fn set_header_check_code(packet: &mut [u8], header_check_cipher: &Aes) {
debug_assert!(packet.len() >= MIN_PACKET_SIZE);
let mut header_mac = 0u128.to_ne_bytes();
header_check_cipher.encrypt_block(&packet[8..24], &mut header_mac);
packet[4..8].copy_from_slice(&header_mac[..4]);
let mut check_code = 0u128.to_ne_bytes();
header_check_cipher.encrypt_block(&packet[8..24], &mut check_code);
packet[4..8].copy_from_slice(&check_code[..4]);
}
/// Check 32-bit header MAC on an incoming packet.
fn check_header_mac(packet: &[u8], header_check_cipher: &Aes) -> bool {
/// Verify 32-bit header check code.
#[inline]
fn verify_header_check_code(packet: &[u8], header_check_cipher: &Aes) -> bool {
debug_assert!(packet.len() >= MIN_PACKET_SIZE);
let mut header_mac = 0u128.to_ne_bytes();
header_check_cipher.encrypt_block(&packet[8..24], &mut header_mac);
@ -1544,12 +1601,15 @@ fn parse_key_offer_after_header(
))
}
/// Was this side the one who sent the first offer (Alice) or countered (Bob).
/// Note that role is not fixed. Either side can take either role. It's just who
/// initiated first.
enum Role {
Alice,
Bob,
}
/// Specialized class for the careful management of key lifetimes.
/// Key lifetime manager state and logic (separate to spotlight and keep clean)
struct KeyLifetime {
rekey_at_or_after_counter: u64,
hard_expire_at_counter: u64,
@ -1580,19 +1640,19 @@ impl KeyLifetime {
}
}
#[allow(unused)]
/// A shared symmetric session key.
struct SessionKey {
fingerprint: [u8; 16],
establish_time: i64,
establish_counter: u64,
lifetime: KeyLifetime,
ratchet_key: Secret<64>,
receive_key: Secret<32>,
send_key: Secret<32>,
receive_cipher_pool: Mutex<Vec<Box<AesGcm>>>,
send_cipher_pool: Mutex<Vec<Box<AesGcm>>>,
ratchet_count: u64,
jedi: bool, // true if kyber was enabled on both sides
secret_fingerprint: [u8; 16], // First 128 bits of a SHA384 computed from the secret
establish_time: i64, // Time session key was established
establish_counter: u64, // Counter value at which session was established
lifetime: KeyLifetime, // Key expiration time and counter
ratchet_key: Secret<64>, // Ratchet key for deriving the next session key
receive_key: Secret<32>, // Receive side AES-GCM key
send_key: Secret<32>, // Send side AES-GCM key
receive_cipher_pool: Mutex<Vec<Box<AesGcm>>>, // Pool of initialized sending ciphers
send_cipher_pool: Mutex<Vec<Box<AesGcm>>>, // Pool of initialized receiving ciphers
ratchet_count: u64, // Number of new keys negotiated in this session
jedi: bool, // True if Kyber1024 was used (both sides enabled)
}
impl SessionKey {
@ -1605,7 +1665,7 @@ impl SessionKey {
Role::Bob => (a2b, b2a),
};
Self {
fingerprint: key_fingerprint(key.as_bytes())[..16].try_into().unwrap(),
secret_fingerprint: secret_fingerprint(key.as_bytes())[..16].try_into().unwrap(),
establish_time: current_time,
establish_counter: current_counter.0,
lifetime: KeyLifetime::new(current_counter, current_time),
@ -1619,7 +1679,7 @@ impl SessionKey {
}
}
#[inline(always)]
#[inline]
fn get_send_cipher(&self, counter: CounterValue) -> Result<Box<AesGcm>, Error> {
if !self.lifetime.expired(counter) {
Ok(self
@ -1638,12 +1698,12 @@ impl SessionKey {
}
}
#[inline(always)]
#[inline]
fn return_send_cipher(&self, c: Box<AesGcm>) {
self.send_cipher_pool.lock().unwrap().push(c);
}
#[inline(always)]
#[inline]
fn get_receive_cipher(&self) -> Box<AesGcm> {
self.receive_cipher_pool
.lock()
@ -1652,7 +1712,7 @@ impl SessionKey {
.unwrap_or_else(|| Box::new(AesGcm::new(self.receive_key.as_bytes(), false)))
}
#[inline(always)]
#[inline]
fn return_receive_cipher(&self, c: Box<AesGcm>) {
self.receive_cipher_pool.lock().unwrap().push(c);
}
@ -1667,14 +1727,13 @@ fn hmac_sha384_2(key: &[u8], a: &[u8], b: &[u8]) -> [u8; 48] {
}
/// HMAC-SHA512 key derivation function modeled on: https://csrc.nist.gov/publications/detail/sp/800-108/final (page 12)
/// Cryptographically this isn't really different from HMAC(key, [label]) with just one byte.
fn kbkdf512(key: &[u8], label: u8) -> Secret<64> {
Secret(hmac_sha512(key, &[0, 0, 0, 0, b'Z', b'T', label, 0, 0, 0, 0, 0x02, 0x00]))
}
/// Get a hash of a secret key that can be used as a public fingerprint.
///
/// This just needs to be a hash that will never be the same as any hash or HMAC used for actual key derivation.
fn key_fingerprint(key: &[u8]) -> [u8; 48] {
fn secret_fingerprint(key: &[u8]) -> [u8; 48] {
let mut tmp = SHA384::new();
tmp.update("fp".as_bytes());
tmp.update(key);