ZSSP ratcheting works, key life cycle works.

This commit is contained in:
Adam Ierymenko 2022-09-09 21:22:04 -04:00
parent 98c0575a00
commit 4a1f2db54e
No known key found for this signature in database
GPG key ID: C8877CF2D7A5D7F3

View file

@ -3,7 +3,6 @@
// ZSSP: ZeroTier Secure Session Protocol
// FIPS compliant Noise_IK with Jedi powers and built-in attack-resistant large payload (fragmentation) support.
use std::collections::LinkedList;
use std::io::{Read, Write};
use std::num::NonZeroU64;
use std::ops::Deref;
@ -45,19 +44,24 @@ pub const SERVICE_INTERVAL: u64 = 10000;
const JEDI: bool = true;
/// Start attempting to rekey after a key has been used to send packets this many times.
///
/// This is 1/4 the NIST recommended maximum and 1/8 the absolute limit where u32 wraps.
const REKEY_AFTER_USES: u64 = 536870912;
/// Maximum random jitter to add to rekey-after usage count.
const REKEY_AFTER_USES_MAX_JITTER: u32 = 1048576;
/// Hard expiration after this many uses.
///
/// Use of the key beyond this point is prohibited. This is the point where u32 wraps minus
/// a little bit of margin. We should never get here under ordinary circumstances.
const EXPIRE_AFTER_USES: u64 = (u32::MAX - 1024) as u64;
/// Start attempting to rekey after a key has been in use for this many milliseconds.
const REKEY_AFTER_TIME_MS: i64 = 1000 * 60 * 60; // 1 hour
/// Maximum random jitter to add to rekey-after time.
const REKEY_AFTER_TIME_MS_MAX_JITTER: u32 = 1000 * 60 * 5;
const REKEY_AFTER_TIME_MS_MAX_JITTER: u32 = 1000 * 60 * 10;
/// Rate limit for sending new offers to attempt to re-key.
const OFFER_RATE_LIMIT_MS: i64 = 2000;
@ -96,7 +100,7 @@ const HMAC_SIZE: usize = 48;
const SESSION_ID_SIZE: usize = 6;
/// Maximum number of present and future keys to hold at any given time.
const KEY_HISTORY_SIZE_MAX: usize = 3;
const KEY_HISTORY_SIZE: usize = 3;
// Key usage labels for sub-key derivation using kbkdf (HMAC).
const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M';
@ -294,7 +298,8 @@ pub struct Session<H: Host> {
struct SessionMutableState {
remote_session_id: Option<SessionId>,
keys: LinkedList<SessionKey>,
keys: [Option<SessionKey>; KEY_HISTORY_SIZE],
key_ptr: usize,
offer: Option<Box<EphemeralOffer>>,
}
@ -358,7 +363,8 @@ impl<H: Host> Session<H> {
header_check_cipher,
state: RwLock::new(SessionMutableState {
remote_session_id: None,
keys: LinkedList::new(),
keys: [None, None, None],
key_ptr: 0,
offer: Some(offer),
}),
remote_s_public_hash,
@ -379,7 +385,7 @@ impl<H: Host> Session<H> {
debug_assert!(mtu_buffer.len() >= MIN_MTU);
let state = self.state.read();
if let Some(remote_session_id) = state.remote_session_id {
if let Some(key) = state.keys.front() {
if let Some(key) = state.keys[state.key_ptr].as_ref() {
let mut packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE;
let counter = self.send_counter.next();
@ -433,17 +439,17 @@ impl<H: Host> Session<H> {
/// Check whether this session is established.
pub fn established(&self) -> bool {
let state = self.state.read();
state.remote_session_id.is_some() && !state.keys.is_empty()
state.remote_session_id.is_some() && state.keys[state.key_ptr].is_some()
}
/// Get information about this session's security state.
///
/// This returns a tuple of: the time at which the current key was established, the length of its ratchet chain,
/// This returns a tuple of: the key fingerprint, the time it was established, the length of its ratchet chain,
/// and whether Kyber1024 was used. None is returned if the session isn't established.
pub fn security_info(&self) -> Option<(i64, u64, bool)> {
pub fn security_info(&self) -> Option<([u8; 16], i64, u64, bool)> {
let state = self.state.read();
if let Some(key) = state.keys.front() {
Some((key.establish_time, key.ratchet_count, key.jedi))
if let Some(key) = state.keys[state.key_ptr].as_ref() {
Some((key.fingerprint, key.establish_time, key.ratchet_count, key.jedi))
} else {
None
}
@ -454,9 +460,13 @@ impl<H: Host> Session<H> {
/// * `offer_metadata' - Any meta-data to include with initial key offers sent.
/// * `mtu` - Physical MTU for sent packets
/// * `current_time` - Current monotonic time in milliseconds
pub fn service<SendFunction: FnMut(&mut [u8])>(&self, host: &H, mut send: SendFunction, offer_metadata: &[u8], mtu: usize, current_time: i64) {
/// * `force_rekey` - Re-key the session now regardless of key aging (still subject to rate limiting)
pub fn service<SendFunction: FnMut(&mut [u8])>(&self, host: &H, mut send: SendFunction, offer_metadata: &[u8], mtu: usize, current_time: i64, force_rekey: bool) {
let state = self.state.upgradable_read();
if state.keys.front().map_or(true, |key| key.lifetime.should_rekey(self.send_counter.current(), current_time))
if (force_rekey
|| state.keys[state.key_ptr]
.as_ref()
.map_or(true, |key| key.lifetime.should_rekey(self.send_counter.current(), current_time)))
&& state.offer.as_ref().map_or(true, |o| (current_time - o.creation_time) > OFFER_RATE_LIMIT_MS)
{
if let Some(remote_s_public_p384) = P384PublicKey::from_bytes(&self.remote_s_public_p384) {
@ -471,7 +481,7 @@ impl<H: Host> Session<H> {
&remote_s_public_p384,
&self.remote_s_public_hash,
&self.ss,
state.keys.front(),
state.keys[state.key_ptr].as_ref(),
if state.remote_session_id.is_some() {
&self.header_check_cipher
} else {
@ -612,61 +622,56 @@ impl<H: Host> ReceiveContext<H> {
if packet_type <= PACKET_TYPE_NOP {
if let Some(session) = session {
let state = session.state.read();
let key_count = state.keys.len();
for (key_index, key) in state.keys.iter().enumerate() {
let tail = fragments.last().unwrap().as_ref();
if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) {
unlikely_branch();
return Err(Error::InvalidPacket);
}
for p in 0..KEY_HISTORY_SIZE {
let key_ptr = (state.key_ptr + p) % KEY_HISTORY_SIZE;
if let Some(key) = state.keys[key_ptr].as_ref() {
let tail = fragments.last().unwrap().as_ref();
if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) {
unlikely_branch();
return Err(Error::InvalidPacket);
}
let mut c = key.get_receive_cipher();
c.init(pseudoheader);
let mut c = key.get_receive_cipher();
c.init(pseudoheader);
let mut data_len = 0;
let mut data_len = 0;
for f in fragments[..(fragments.len() - 1)].iter() {
let f = f.as_ref();
debug_assert!(f.len() >= HEADER_SIZE);
let current_frag_data_start = data_len;
data_len += f.len() - HEADER_SIZE;
if data_len > data_buf.len() {
unlikely_branch();
key.return_receive_cipher(c);
return Err(Error::DataBufferTooSmall);
}
c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]);
}
for f in fragments[..(fragments.len() - 1)].iter() {
let f = f.as_ref();
debug_assert!(f.len() >= HEADER_SIZE);
let current_frag_data_start = data_len;
data_len += f.len() - HEADER_SIZE;
data_len += tail.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
if data_len > data_buf.len() {
unlikely_branch();
key.return_receive_cipher(c);
return Err(Error::DataBufferTooSmall);
}
c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]);
}
c.crypt(&tail[HEADER_SIZE..(tail.len() - AES_GCM_TAG_SIZE)], &mut data_buf[current_frag_data_start..data_len]);
let current_frag_data_start = data_len;
data_len += tail.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
if data_len > data_buf.len() {
unlikely_branch();
let ok = c.finish_decrypt(&tail[(tail.len() - AES_GCM_TAG_SIZE)..]);
key.return_receive_cipher(c);
return Err(Error::DataBufferTooSmall);
}
c.crypt(&tail[HEADER_SIZE..(tail.len() - AES_GCM_TAG_SIZE)], &mut data_buf[current_frag_data_start..data_len]);
let ok = c.finish_decrypt(&tail[(tail.len() - AES_GCM_TAG_SIZE)..]);
key.return_receive_cipher(c);
if ok {
// Drop obsolete keys if we had to iterate past the first key to get here.
if key_index > 0 {
unlikely_branch();
drop(state);
let mut state = session.state.write();
if state.keys.len() == key_count {
for _ in 0..key_index {
let _ = state.keys.pop_front();
}
if ok {
// Select this key as the new default if it's newer than the current key.
if p > 0 && state.keys[state.key_ptr].as_ref().map_or(true, |old| old.establish_counter < key.establish_counter) {
drop(state);
session.state.write().key_ptr = key_ptr;
}
if packet_type == PACKET_TYPE_DATA {
return Ok(ReceiveResult::OkData(&mut data_buf[..data_len]));
} else {
unlikely_branch();
return Ok(ReceiveResult::Ok);
}
}
if packet_type == PACKET_TYPE_DATA {
return Ok(ReceiveResult::OkData(&mut data_buf[..data_len]));
} else {
unlikely_branch();
return Ok(ReceiveResult::Ok);
}
}
}
@ -756,9 +761,11 @@ impl<H: Host> ReceiveContext<H> {
let mut ratchet_count = 0;
let state = session.state.read();
for k in state.keys.iter() {
if SHA384::hash(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_id) {
ratchet_key = Some(k.ratchet_key.clone());
ratchet_count = k.ratchet_count;
if let Some(k) = k.as_ref() {
if SHA384::hash(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_id) {
ratchet_key = Some(k.ratchet_key.clone());
ratchet_count = k.ratchet_count;
}
}
}
(ratchet_key, ratchet_count)
@ -815,7 +822,8 @@ impl<H: Host> ReceiveContext<H> {
header_check_cipher,
state: RwLock::new(SessionMutableState {
remote_session_id: Some(alice_session_id),
keys: LinkedList::new(),
keys: [None, None, None],
key_ptr: 0,
offer: None,
}),
remote_s_public_hash: SHA384::hash(&alice_s_public),
@ -914,9 +922,12 @@ impl<H: Host> ReceiveContext<H> {
reply_buf[reply_len..(reply_len + HMAC_SIZE)].copy_from_slice(&hmac);
reply_len += HMAC_SIZE;
let key = SessionKey::new(key, Role::Bob, current_time, reply_counter, ratchet_count + 1, e1e1.is_some());
let mut state = session.state.write();
let _ = state.remote_session_id.replace(alice_session_id);
add_session_key(&mut state.keys, SessionKey::new(key, Role::Bob, current_time, reply_counter, ratchet_count + 1, e1e1.is_some()));
let next_key_ptr = (state.key_ptr + 1) % KEY_HISTORY_SIZE;
let _ = state.keys[next_key_ptr].replace(key);
drop(state);
// Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it.
@ -1015,8 +1026,9 @@ impl<H: Host> ReceiveContext<H> {
let mut state = RwLockUpgradableReadGuard::upgrade(state);
let _ = state.remote_session_id.replace(bob_session_id);
let next_key_ptr = (state.key_ptr + 1) % KEY_HISTORY_SIZE;
let _ = state.keys[next_key_ptr].replace(key);
let _ = state.offer.take();
add_session_key(&mut state.keys, key);
return Ok(ReceiveResult::Ok);
}
@ -1281,23 +1293,6 @@ fn dearmor_header(packet: &[u8], header_check_cipher: &Aes) -> Option<(u8, u8, u
}
}
fn add_session_key(keys: &mut LinkedList<SessionKey>, key: SessionKey) {
// Sanity check to make sure duplicates can't get in here. Should be impossible.
for k in keys.iter() {
if k.receive_key.eq(&key.receive_key) {
return;
}
}
debug_assert!(KEY_HISTORY_SIZE_MAX >= 2);
while keys.len() >= KEY_HISTORY_SIZE_MAX {
let current = keys.pop_front().unwrap();
let _ = keys.pop_front();
keys.push_front(current);
}
keys.push_back(key);
}
fn parse_key_offer_after_header(incoming_packet: &[u8], packet_type: u8) -> Result<([u8; 16], SessionId, &[u8], &[u8], &[u8], Option<[u8; 16]>), Error> {
let mut p = &incoming_packet[..];
let mut offer_id = [0_u8; 16];
@ -1328,12 +1323,16 @@ fn parse_key_offer_after_header(incoming_packet: &[u8], packet_type: u8) -> Resu
if p.len() < (pqc_kyber::KYBER_PUBLICKEYBYTES + 1) {
return Err(Error::InvalidPacket);
}
&p[1..(pqc_kyber::KYBER_PUBLICKEYBYTES + 1)]
let e1p = &p[1..(pqc_kyber::KYBER_PUBLICKEYBYTES + 1)];
p = &p[(pqc_kyber::KYBER_PUBLICKEYBYTES + 1)..];
e1p
} else {
if p.len() < (pqc_kyber::KYBER_CIPHERTEXTBYTES + 1) {
return Err(Error::InvalidPacket);
}
&p[1..(pqc_kyber::KYBER_CIPHERTEXTBYTES + 1)]
let e1p = &p[1..(pqc_kyber::KYBER_CIPHERTEXTBYTES + 1)];
p = &p[(pqc_kyber::KYBER_CIPHERTEXTBYTES + 1)..];
e1p
}
}
_ => &[],
@ -1387,7 +1386,9 @@ impl KeyLifetime {
#[allow(unused)]
struct SessionKey {
fingerprint: [u8; 16],
establish_time: i64,
establish_counter: u64,
lifetime: KeyLifetime,
ratchet_key: Secret<64>,
receive_key: Secret<32>,
@ -1409,7 +1410,9 @@ impl SessionKey {
Role::Bob => (a2b, b2a),
};
Self {
fingerprint: SHA384::hash(key.as_bytes())[..16].try_into().unwrap(),
establish_time: current_time,
establish_counter: current_counter.0,
lifetime: KeyLifetime::new(current_counter, current_time),
ratchet_key: kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_RATCHETING),
receive_key,
@ -1465,6 +1468,7 @@ mod tests {
use parking_lot::Mutex;
use std::collections::LinkedList;
use std::sync::Arc;
use zerotier_utils::hex;
#[allow(unused_imports)]
use super::*;
@ -1475,9 +1479,10 @@ mod tests {
psk: Secret<64>,
session: Mutex<Option<Arc<Session<Box<TestHost>>>>>,
session_id_counter: Mutex<u64>,
pub queue: Mutex<LinkedList<Vec<u8>>>,
pub this_name: &'static str,
pub other_name: &'static str,
queue: Mutex<LinkedList<Vec<u8>>>,
key_id: Mutex<[u8; 16]>,
this_name: &'static str,
other_name: &'static str,
}
impl TestHost {
@ -1491,6 +1496,7 @@ mod tests {
session: Mutex::new(None),
session_id_counter: Mutex::new(random::next_u64_secure().wrapping_shr(16) | 1),
queue: Mutex::new(LinkedList::new()),
key_id: Mutex::new([0; 16]),
this_name,
other_name,
}
@ -1568,7 +1574,7 @@ mod tests {
));
let mut ts = 0;
for _ in 0..3 {
for test_loop in 0..128 {
for host in [&alice_host, &bob_host] {
let send_to_other = |data: &mut [u8]| {
if std::ptr::eq(host, &alice_host) {
@ -1621,11 +1627,28 @@ mod tests {
data_buf.fill(0x12);
if let Some(session) = host.session.lock().as_ref().cloned() {
if session.established() {
for _ in 0..16 {
{
let mut key_id = host.key_id.lock();
let security_info = session.security_info().unwrap();
if !security_info.0.eq(key_id.as_ref()) {
*key_id = security_info.0;
println!(
"zssp: new key at {}: fingerprint {} ratchet {} kyber {}",
host.this_name,
hex::to_string(key_id.as_ref()),
security_info.2,
security_info.3
);
}
}
for _ in 0..32 {
assert!(session
.send(send_to_other, &mut mtu_buffer, &data_buf[..((random::xorshift64_random() as usize) % data_buf.len())])
.is_ok());
}
if (test_loop % 8) == 0 && test_loop >= 8 && host.this_name.eq("alice") {
session.service(host, send_to_other, &[], mtu_buffer.len(), test_loop as i64, true);
}
}
}
}