More ZSSP cleanup, docs, renaming.

This commit is contained in:
Adam Ierymenko 2022-12-16 10:41:38 -05:00
parent e8bffbd44d
commit a22bf51b7c
3 changed files with 93 additions and 106 deletions

View file

@ -16,7 +16,7 @@ use crate::{
/// and use case independent.
pub trait ApplicationLayer: Sized {
/// Arbitrary opaque object associated with a session, such as a connection state object.
type SessionUserData;
type Data;
/// Arbitrary object that dereferences to the session, such as Arc<Session<Self>>.
type SessionRef: Deref<Target = Session<Self>>;
@ -73,5 +73,5 @@ pub trait ApplicationLayer: Sized {
remote_address: &Self::RemoteAddress,
remote_static_public: &[u8],
remote_metadata: &[u8],
) -> Option<(SessionId, Secret<64>, Self::SessionUserData)>;
) -> Option<(SessionId, Secret<64>, Self::Data)>;
}

View file

@ -43,7 +43,7 @@ mod tests {
}
impl ApplicationLayer for Box<TestHost> {
type SessionUserData = u32;
type Data = u32;
type SessionRef = Arc<Session<Box<TestHost>>>;
type IncomingPacketBuffer = Vec<u8>;
type RemoteAddress = u32;
@ -80,13 +80,7 @@ mod tests {
true
}
fn accept_new_session(
&self,
_: &ReceiveContext<Self>,
_: &u32,
_: &[u8],
_: &[u8],
) -> Option<(SessionId, Secret<64>, Self::SessionUserData)> {
fn accept_new_session(&self, _: &ReceiveContext<Self>, _: &u32, _: &[u8], _: &[u8]) -> Option<(SessionId, Secret<64>, Self::Data)> {
loop {
let mut new_id = self.session_id_counter.lock().unwrap();
*new_id += 1;

View file

@ -103,19 +103,19 @@ pub struct ReceiveContext<H: ApplicationLayer> {
/// A FIPS compliant variant of Noise_IK with hybrid Kyber1024 PQ data forward secrecy.
pub struct Session<Application: ApplicationLayer> {
/// This side's session ID (unique on this side)
/// This side's locally unique session ID
pub id: SessionId,
/// An arbitrary application defined object associated with each session
pub application_data: Application::SessionUserData,
pub application_data: Application::Data,
send_counter: Counter, // Outgoing packet counter and nonce state
psk: Secret<64>, // Arbitrary PSK provided by external code
noise_ss: Secret<48>, // Static raw shared ECDH NIST P-384 key
header_check_cipher: Aes, // Cipher used for header MAC (fragmentation)
header_check_cipher: Aes, // Cipher used for header check codes (not Noise related)
state: RwLock<SessionMutableState>, // Mutable parts of state (other than defrag buffers)
remote_s_public_blob_hash: [u8; 48], // SHA384(remote static public key blob)
remote_s_public_raw: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key
remote_s_public_p384_bytes: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key
defrag: Mutex<RingBufferMap<u32, GatherArray<Application::IncomingPacketBuffer, MAX_FRAGMENTS>, 8, 8>>,
}
@ -131,34 +131,34 @@ struct SessionMutableState {
/// A shared symmetric session key.
struct SessionKey {
secret_fingerprint: [u8; 16], // First 128 bits of a SHA384 computed from the secret
establish_time: i64, // Time session key was established
establish_counter: CounterValue, // Counter value at which session was established
creation_time: i64, // Time session key was established
creation_counter: CounterValue, // Counter value at which session was established
lifetime: KeyLifetime, // Key expiration time and counter
ratchet_key: Secret<64>, // Ratchet key for deriving the next session key
receive_key: Secret<AES_KEY_SIZE>, // Receive side AES-GCM key
send_key: Secret<AES_KEY_SIZE>, // Send side AES-GCM key
receive_cipher_pool: Mutex<Vec<Box<AesGcm>>>, // Pool of initialized sending ciphers
send_cipher_pool: Mutex<Vec<Box<AesGcm>>>, // Pool of initialized receiving ciphers
ratchet_count: u64, // Number of new keys negotiated in this session
receive_cipher_pool: Mutex<Vec<Box<AesGcm>>>, // Pool of reusable sending ciphers
send_cipher_pool: Mutex<Vec<Box<AesGcm>>>, // Pool of reusable receiving ciphers
ratchet_count: u64, // Number of preceding session keys in ratchet
jedi: bool, // True if Kyber1024 was used (both sides enabled)
}
/// Key lifetime state
struct KeyLifetime {
rekey_at_or_after_counter: CounterValue,
hard_expire_at_counter: CounterValue,
rekey_at_or_after_timestamp: i64,
}
/// Alice's KEY_OFFER, remembered so Noise agreement process can resume on KEY_COUNTER_OFFER.
struct EphemeralOffer {
id: [u8; 16], // Arbitrary random offer ID
creation_time: i64, // Local time when offer was created
ratchet_count: u64, // Ratchet count starting at zero for initial offer
ratchet_key: Option<Secret<64>>, // Ratchet key from previous offer
ss_key: Secret<64>, // Shared secret in-progress, at state after offer sent
ratchet_count: u64, // Ratchet count (starting at zero) for initial offer
ratchet_key: Option<Secret<64>>, // Ratchet key from previous offer or None if first offer
ss_key: Secret<64>, // Noise session key "under construction" at state after offer sent
alice_e_keypair: P384KeyPair, // NIST P-384 key pair (Noise ephemeral key for Alice)
alice_hk_keypair: Option<pqc_kyber::Keypair>, // Kyber1024 key pair (agreement result mixed post-Noise)
}
/// Key lifetime manager state and logic (separate to spotlight and keep clean)
struct KeyLifetime {
rekey_at_or_after_counter: CounterValue,
hard_expire_at_counter: CounterValue,
rekey_at_or_after_timestamp: i64,
alice_hk_keypair: Option<pqc_kyber::Keypair>, // Kyber1024 key pair (PQ hybrid ephemeral key for Alice)
}
/// "Canonical header" for generating 96-bit AES-GCM nonce and for inclusion in key exchange HMACs.
@ -198,55 +198,17 @@ impl std::fmt::Debug for Error {
}
}
/// Write src into buffer starting at the index idx. If buffer cannot fit src at that location, nothing at all is written and Error::UnexpectedBufferOverrun is returned. No other errors can be returned by this function. An idx incremented by the amount written is returned.
fn safe_write_all(buffer: &mut [u8], idx: usize, src: &[u8]) -> Result<usize, Error> {
let dest = &mut buffer[idx..];
let amt = src.len();
if dest.len() >= amt {
dest[..amt].copy_from_slice(src);
Ok(idx + amt)
} else {
unlikely_branch();
Err(Error::UnexpectedBufferOverrun)
}
}
/// Write a variable length integer, which can consume up to 10 bytes. Uses safe_write_all to do so.
fn varint_safe_write(buffer: &mut [u8], idx: usize, v: u64) -> Result<usize, Error> {
let mut b = [0_u8; varint::VARINT_MAX_SIZE_BYTES];
let i = varint::encode(&mut b, v);
safe_write_all(buffer, idx, &b[0..i])
}
/// Read exactly amt bytes from src and return the slice those bytes reside in. If src is smaller than amt, Error::InvalidPacket is returned. if the read was successful src is incremented to start at the first unread byte.
fn safe_read_exact<'a>(src: &mut &'a [u8], amt: usize) -> Result<&'a [u8], Error> {
if src.len() >= amt {
let (a, b) = src.split_at(amt);
*src = b;
Ok(a)
} else {
unlikely_branch();
Err(Error::InvalidPacket)
}
}
/// Read a variable length integer, which can consume up to 10 bytes. Uses varint_safe_read to do so.
fn varint_safe_read(src: &mut &[u8]) -> Result<u64, Error> {
let (v, amt) = varint::decode(*src).ok_or(Error::InvalidPacket)?;
let (_, b) = src.split_at(amt);
*src = b;
Ok(v)
}
impl<Application: ApplicationLayer> Session<Application> {
/// Create a new session and send an initial key offer message to the other end.
///
/// * `host` - Interface to application using ZSSP
/// * `app` - Interface to application using ZSSP
/// * `local_session_id` - ID for this side (Alice) of the session, must be locally unique
/// * `remote_s_public_raw` - Remote side's (Bob's) public key/identity
/// * `remote_s_public_blob` - Remote side's (Bob's) public key/identity
/// * `offer_metadata` - Arbitrary meta-data to send with key offer (empty if none)
/// * `psk` - Arbitrary pre-shared key to include as initial key material (use all zero secret if none)
/// * `user_data` - Arbitrary object to put into session
/// * `mtu` - Physical wire maximum transmition unit
/// * `current_time` - Current monotonic time in milliseconds
/// * `psk` - Arbitrary pre-shared key to include as initial key material (use all zeroes if none)
/// * `application_data` - Arbitrary object to put into session
/// * `mtu` - Physical wire maximum transmission unit (current value, can change through the course of a session)
/// * `current_time` - Current monotonic time in milliseconds since an arbitrary time in the past
pub fn start_new<SendFunction: FnMut(&mut [u8])>(
app: &Application,
mut send: SendFunction,
@ -254,7 +216,7 @@ impl<Application: ApplicationLayer> Session<Application> {
remote_s_public_blob: &[u8],
offer_metadata: &[u8],
psk: &Secret<64>,
application_data: Application::SessionUserData,
application_data: Application::Data,
mtu: usize,
current_time: i64,
) -> Result<Self, Error> {
@ -299,7 +261,7 @@ impl<Application: ApplicationLayer> Session<Application> {
last_remote_offer: i64::MIN,
}),
remote_s_public_blob_hash: bob_s_public_blob_hash,
remote_s_public_raw: bob_s_public.as_bytes().clone(),
remote_s_public_p384_bytes: bob_s_public.as_bytes().clone(),
defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)),
});
}
@ -410,7 +372,7 @@ impl<Application: ApplicationLayer> Session<Application> {
pub fn status(&self) -> Option<([u8; 16], i64, u64, bool)> {
let state = self.state.read().unwrap();
if let Some(key) = state.session_keys[state.cur_session_key_idx].as_ref() {
Some((key.secret_fingerprint, key.establish_time, key.ratchet_count, key.jedi))
Some((key.secret_fingerprint, key.creation_time, key.ratchet_count, key.jedi))
} else {
None
}
@ -418,15 +380,15 @@ impl<Application: ApplicationLayer> Session<Application> {
/// This function needs to be called on each session at least every SERVICE_INTERVAL milliseconds.
///
/// * `host` - Interface to application using ZSSP
/// * `app` - Interface to application using ZSSP
/// * `send` - Function to call to send physical packet(s)
/// * `offer_metadata' - Any meta-data to include with initial key offers sent.
/// * `mtu` - Physical MTU for sent packets
/// * `mtu` - Current physical transport MTU
/// * `current_time` - Current monotonic time in milliseconds
/// * `force_rekey` - Re-key the session now regardless of key aging (still subject to rate limiting)
pub fn service<SendFunction: FnMut(&mut [u8])>(
&self,
host: &Application,
app: &Application,
mut send: SendFunction,
offer_metadata: &[u8],
mtu: usize,
@ -443,14 +405,14 @@ impl<Application: ApplicationLayer> Session<Application> {
.as_ref()
.map_or(true, |o| (current_time - o.creation_time) > Application::REKEY_RATE_LIMIT_MS)
{
if let Some(remote_s_public) = P384PublicKey::from_bytes(&self.remote_s_public_raw) {
if let Some(remote_s_public) = P384PublicKey::from_bytes(&self.remote_s_public_p384_bytes) {
let mut offer = None;
if send_ephemeral_offer(
&mut send,
self.send_counter.next(),
self.id,
state.remote_session_id,
host.get_local_s_public_blob(),
app.get_local_s_public_blob(),
offer_metadata,
&remote_s_public,
&self.remote_s_public_blob_hash,
@ -476,18 +438,18 @@ impl<Application: ApplicationLayer> Session<Application> {
}
impl<Application: ApplicationLayer> ReceiveContext<Application> {
pub fn new(host: &Application) -> Self {
pub fn new(app: &Application) -> Self {
Self {
initial_offer_defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)),
initial_offer_defrag: Mutex::new(RingBufferMap::new(random::next_u32_secure())),
incoming_init_header_check_cipher: Aes::new(
kbkdf512(host.get_local_s_public_blob_hash(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<HEADER_CHECK_AES_KEY_SIZE>(),
kbkdf512(app.get_local_s_public_blob_hash(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<HEADER_CHECK_AES_KEY_SIZE>(),
),
}
}
/// Receive, authenticate, decrypt, and process a physical wire packet.
///
/// * `host` - Interface to application using ZSSP
/// * `app` - Interface to application using ZSSP
/// * `remote_address` - Remote physical address of source endpoint
/// * `data_buf` - Buffer to receive decrypted and authenticated object data (an error is returned if too small)
/// * `incoming_packet_buf` - Buffer containing incoming wire packet (receive() takes ownership)
@ -616,7 +578,6 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
///
/// NOTE: header check codes will already have been validated on receipt of each fragment. AEAD authentication
/// and decryption has NOT yet been performed, and is done here.
#[inline]
fn receive_complete<'a, SendFunction: FnMut(&mut [u8])>(
&self,
app: &Application,
@ -692,7 +653,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
if p > 0
&& state.session_keys[state.cur_session_key_idx]
.as_ref()
.map_or(true, |old| old.establish_counter < session_key.establish_counter)
.map_or(true, |old| old.creation_counter < session_key.creation_counter)
{
drop(state);
let mut state = session.state.write().unwrap();
@ -910,7 +871,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
last_remote_offer: current_time,
}),
remote_s_public_blob_hash: SHA384::hash(&alice_s_public_blob),
remote_s_public_raw: alice_s_public.as_bytes().clone(),
remote_s_public_p384_bytes: alice_s_public.as_bytes().clone(),
defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)),
}),
None,
@ -1207,7 +1168,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
}
}
/// Create and send an ephemeral offer, returning the EphemeralOffer part that must be saved.
/// Create an send an ephemeral offer, populating ret_ephemeral_offer on success.
fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
send: &mut SendFunction,
counter: CounterValue,
@ -1361,7 +1322,6 @@ fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
}
/// Populate all but the header check code in the first 16 bytes of a packet or fragment.
#[inline(always)]
fn create_packet_header(
header_destination_buffer: &mut [u8],
packet_len: usize,
@ -1429,7 +1389,6 @@ fn send_with_fragmentation<SendFunction: FnMut(&mut [u8])>(
}
/// Set 32-bit header check code, used to make fragmentation mechanism robust.
#[inline]
fn set_header_check_code(packet: &mut [u8], header_check_cipher: &Aes) {
debug_assert!(packet.len() >= MIN_PACKET_SIZE);
let mut check_code = 0u128.to_ne_bytes();
@ -1438,7 +1397,6 @@ fn set_header_check_code(packet: &mut [u8], header_check_cipher: &Aes) {
}
/// Verify 32-bit header check code.
#[inline]
fn verify_header_check_code(packet: &[u8], header_check_cipher: &Aes) -> bool {
debug_assert!(packet.len() >= MIN_PACKET_SIZE);
let mut header_mac = 0u128.to_ne_bytes();
@ -1505,8 +1463,8 @@ impl SessionKey {
};
Self {
secret_fingerprint: public_fingerprint_of_secret(key.as_bytes())[..16].try_into().unwrap(),
establish_time: current_time,
establish_counter: current_counter,
creation_time: current_time,
creation_counter: current_counter,
lifetime: KeyLifetime::new(current_counter, current_time),
ratchet_key: kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_RATCHETING),
receive_key,
@ -1518,7 +1476,6 @@ impl SessionKey {
}
}
#[inline]
fn get_send_cipher(&self, counter: CounterValue) -> Result<Box<AesGcm>, Error> {
if !self.lifetime.expired(counter) {
Ok(self
@ -1537,12 +1494,10 @@ impl SessionKey {
}
}
#[inline]
fn return_send_cipher(&self, c: Box<AesGcm>) {
self.send_cipher_pool.lock().unwrap().push(c);
}
#[inline]
fn get_receive_cipher(&self) -> Box<AesGcm> {
self.receive_cipher_pool
.lock()
@ -1551,7 +1506,6 @@ impl SessionKey {
.unwrap_or_else(|| Box::new(AesGcm::new(self.receive_key.as_bytes(), false)))
}
#[inline]
fn return_receive_cipher(&self, c: Box<AesGcm>) {
self.receive_cipher_pool.lock().unwrap().push(c);
}
@ -1580,7 +1534,6 @@ impl KeyLifetime {
}
impl CanonicalHeader {
#[inline(always)]
pub fn make(session_id: SessionId, packet_type: u8, counter: u32) -> Self {
CanonicalHeader(
(u64::from(session_id) | (packet_type as u64).wrapping_shl(48)).to_le(),
@ -1594,6 +1547,46 @@ impl CanonicalHeader {
}
}
/// Write src into buffer starting at the index idx. If buffer cannot fit src at that location, nothing at all is written and Error::UnexpectedBufferOverrun is returned. No other errors can be returned by this function. An idx incremented by the amount written is returned.
fn safe_write_all(buffer: &mut [u8], idx: usize, src: &[u8]) -> Result<usize, Error> {
let dest = &mut buffer[idx..];
let amt = src.len();
if dest.len() >= amt {
dest[..amt].copy_from_slice(src);
Ok(idx + amt)
} else {
unlikely_branch();
Err(Error::UnexpectedBufferOverrun)
}
}
/// Write a variable length integer, which can consume up to 10 bytes. Uses safe_write_all to do so.
fn varint_safe_write(buffer: &mut [u8], idx: usize, v: u64) -> Result<usize, Error> {
let mut b = [0_u8; varint::VARINT_MAX_SIZE_BYTES];
let i = varint::encode(&mut b, v);
safe_write_all(buffer, idx, &b[0..i])
}
/// Read exactly amt bytes from src and return the slice those bytes reside in. If src is smaller than amt, Error::InvalidPacket is returned. if the read was successful src is incremented to start at the first unread byte.
fn safe_read_exact<'a>(src: &mut &'a [u8], amt: usize) -> Result<&'a [u8], Error> {
if src.len() >= amt {
let (a, b) = src.split_at(amt);
*src = b;
Ok(a)
} else {
unlikely_branch();
Err(Error::InvalidPacket)
}
}
/// Read a variable length integer, which can consume up to 10 bytes. Uses varint_safe_read to do so.
fn varint_safe_read(src: &mut &[u8]) -> Result<u64, Error> {
let (v, amt) = varint::decode(*src).ok_or(Error::InvalidPacket)?;
let (_, b) = src.split_at(amt);
*src = b;
Ok(v)
}
/// Shortcut to HMAC data split into two slices.
fn hmac_sha384_2(key: &[u8], a: &[u8], b: &[u8]) -> [u8; 48] {
let mut hmac = HMACSHA384::new(key);