mirror of
https://github.com/zerotier/ZeroTierOne.git
synced 2025-06-08 21:43:44 +02:00
1347 lines
59 KiB
Rust
1347 lines
59 KiB
Rust
// (c) 2020-2022 ZeroTier, Inc. -- currently proprietary pending actual release and licensing. See LICENSE.md.
|
|
|
|
use std::collections::LinkedList;
|
|
use std::io::{Read, Write};
|
|
use std::num::NonZeroU64;
|
|
use std::ops::Deref;
|
|
use std::sync::atomic::{AtomicU64, Ordering};
|
|
|
|
use crate::aes::{Aes, AesGcm};
|
|
use crate::hash::{hmac_sha384, hmac_sha512, SHA384};
|
|
use crate::p384::{P384KeyPair, P384PublicKey, P384_PUBLIC_KEY_SIZE};
|
|
use crate::random;
|
|
use crate::secret::Secret;
|
|
|
|
use zerotier_utils::gatherarray::GatherArray;
|
|
use zerotier_utils::memory;
|
|
use zerotier_utils::ringbuffermap::RingBufferMap;
|
|
use zerotier_utils::varint;
|
|
|
|
use parking_lot::{Mutex, RwLock, RwLockUpgradableReadGuard};
|
|
|
|
/// Smallest packet that can be parsed: it must contain at least a complete header.
pub const MIN_PACKET_SIZE: usize = HEADER_SIZE;

/// Minimum MTU; `Session::send()` debug-asserts that its buffer is at least this large.
pub const MIN_MTU: usize = 1280;

/// Start attempting to rekey after a key has been used to send packets this many times (2^29).
const REKEY_AFTER_USES: u64 = 1 << 29;

/// Maximum random jitter to add to rekey-after usage count (2^20).
const REKEY_AFTER_USES_MAX_JITTER: u32 = 1 << 20;

/// Hard expiration after this many uses.
const EXPIRE_AFTER_USES: u64 = u32::MAX as u64 - 1024;

/// Start attempting to rekey after a key has been in use for this many milliseconds (one hour).
const REKEY_AFTER_TIME_MS: i64 = 60 * 60 * 1000;

/// Maximum random jitter to add to rekey-after time (five minutes, in milliseconds).
const REKEY_AFTER_TIME_MS_MAX_JITTER: u32 = 5 * 60 * 1000;

/// Rate limit for sending new offers to attempt to re-key.
const OFFER_RATE_LIMIT_MS: i64 = 2000;

/// Version 0: NIST P-384 forward secrecy and authentication with optional Kyber1024 forward secrecy (but not authentication)
const SESSION_PROTOCOL_VERSION: u8 = 0x00;

// Packet types can range from 0 to 15 (4 bits) -- 0-3 are defined and 4-15 are reserved for future use
const PACKET_TYPE_DATA: u8 = 0;
const PACKET_TYPE_NOP: u8 = 1;
const PACKET_TYPE_KEY_OFFER: u8 = 2; // "alice"
const PACKET_TYPE_KEY_COUNTER_OFFER: u8 = 3; // "bob"

// Tags describing which (if any) secondary ephemeral key follows in a key exchange payload.
const E1_TYPE_NONE: u8 = 0;
const E1_TYPE_KYBER1024: u8 = 1;

const MAX_FRAGMENTS: usize = 48; // protocol max: 63
const KEY_EXCHANGE_MAX_FRAGMENTS: usize = 2; // enough room for p384 + ZT identity + kyber1024 + tag/hmac/etc.
const HEADER_SIZE: usize = 16;
const HEADER_CHECK_SIZE: usize = 4;
const AES_GCM_TAG_SIZE: usize = 16;
const AES_GCM_NONCE_SIZE: usize = 12;
// The GCM nonce is the part of the header that follows the header check field.
const AES_GCM_NONCE_START: usize = 4;
const AES_GCM_NONCE_END: usize = AES_GCM_NONCE_START + AES_GCM_NONCE_SIZE;
const HMAC_SIZE: usize = 48;
const SESSION_ID_SIZE: usize = 6;
const KEY_HISTORY_SIZE_MAX: usize = 3;

// Single-byte labels passed to the key derivation function to separate keys by purpose.
const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M';
const KBKDF_KEY_USAGE_LABEL_HEADER_MAC: u8 = b'H';
const KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB: u8 = b'A';
const KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE: u8 = b'B';

/// Arbitrary starting value for key derivation chain.
///
/// It doesn't matter very much what this is, but it's good for it to be unique.
const KEY_DERIVATION_CHAIN_STARTING_SALT: [u8; 64] = [
    // macOS command line to generate:
    // echo -n 'ZSSP_Noise_IKpsk2_NISTP384_?KYBER1024_AESGCM_SHA512' | shasum -a 512 | cut -d ' ' -f 1 | xxd -r -p | xxd -i
    0x35, 0x6a, 0x75, 0xc0, 0xbf, 0xbe, 0xc3, 0x59, 0x70, 0x94, 0x50, 0x69, 0x4c, 0xa2, 0x08, 0x40, 0xc7, 0xdf, 0x67, 0xa8, 0x68, 0x52, 0x6e, 0xd5, 0xdd, 0x77, 0xec, 0x59, 0x6f, 0x8e, 0xa1, 0x99,
    0xb4, 0x32, 0x85, 0xaf, 0x7f, 0x0d, 0xa9, 0x6c, 0x01, 0xfb, 0x72, 0x46, 0xc0, 0x09, 0x58, 0xb8, 0xe0, 0xa8, 0xcf, 0xb1, 0x58, 0x04, 0x6e, 0x32, 0xba, 0xa8, 0xb8, 0xf9, 0x0a, 0xa4, 0xbf, 0x36,
];
|
|
|
|
pub enum Error {
|
|
/// The packet was addressed to an unrecognized local session
|
|
UnknownLocalSessionId(SessionId),
|
|
|
|
/// Packet was not well formed
|
|
InvalidPacket,
|
|
|
|
/// An invalid paramter was supplied to the function
|
|
InvalidParameter,
|
|
|
|
/// Packet failed one or more authentication (MAC) checks
|
|
FailedAuthentication,
|
|
|
|
/// New session was rejected by caller's supplied authentication check function
|
|
NewSessionRejected,
|
|
|
|
/// Rekeying failed and session secret has reached its maximum usage count
|
|
MaxKeyLifetimeExceeded,
|
|
|
|
/// Attempt to send using session without established key.
|
|
SessionNotEstablished,
|
|
|
|
/// Packet ignored by rate limiter.
|
|
RateLimited,
|
|
|
|
/// Other end sent a protocol version we don't support.
|
|
UnknownProtocolVersion,
|
|
|
|
/// Supplied data buffer is too small to receive data.
|
|
DataBufferTooSmall,
|
|
|
|
/// An internal error occurred.
|
|
OtherError(Box<dyn std::error::Error>),
|
|
}
|
|
|
|
impl From<std::io::Error> for Error {
|
|
#[cold]
|
|
#[inline(never)]
|
|
fn from(e: std::io::Error) -> Self {
|
|
Self::OtherError(Box::new(e))
|
|
}
|
|
}
|
|
|
|
/// Empty function that is called at the start of branches expected to be rare.
///
/// It does nothing at runtime. It is marked `#[cold]` and is never inlined.
/// NOTE(review): presumably this serves as a branch-layout hint to the optimizer -- confirm
/// that this is the intent before removing or changing it.
#[cold]
#[inline(never)]
extern "C" fn unlikely_branch() {}
|
|
|
|
impl std::fmt::Display for Error {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
match self {
|
|
Self::UnknownLocalSessionId(id) => f.write_str(format!("UnknownLocalSessionId({})", id.0.get()).as_str()),
|
|
Self::InvalidPacket => f.write_str("InvalidPacket"),
|
|
Self::InvalidParameter => f.write_str("InvalidParameter"),
|
|
Self::FailedAuthentication => f.write_str("FailedAuthentication"),
|
|
Self::NewSessionRejected => f.write_str("NewSessionRejected"),
|
|
Self::MaxKeyLifetimeExceeded => f.write_str("MaxKeyLifetimeExceeded"),
|
|
Self::SessionNotEstablished => f.write_str("SessionNotEstablished"),
|
|
Self::RateLimited => f.write_str("RateLimited"),
|
|
Self::UnknownProtocolVersion => f.write_str("UnknownProtocolVersion"),
|
|
Self::DataBufferTooSmall => f.write_str("DataBufferTooSmall"),
|
|
Self::OtherError(e) => f.write_str(format!("OtherError({})", e.to_string()).as_str()),
|
|
}
|
|
}
|
|
}
|
|
|
|
// Marker impl: `Error` implements both Display and Debug in this file, so the trait's default methods are used as-is.
impl std::error::Error for Error {}
|
|
|
|
impl std::fmt::Debug for Error {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
std::fmt::Display::fmt(self, f)
|
|
}
|
|
}
|
|
|
|
pub enum ReceiveResult<'a, H: Host> {
|
|
/// Packet is valid, no action needs to be taken.
|
|
Ok,
|
|
|
|
/// Packet is valid and a data payload was decoded and authenticated.
|
|
///
|
|
/// The returned reference is to the filled parts of the data buffer supplied to receive.
|
|
OkData(&'a mut [u8]),
|
|
|
|
/// Packet is valid and a new session was created.
|
|
///
|
|
/// The session will have already been gated by the accept_new_session() method in the Host trait.
|
|
OkNewSession(Session<H>),
|
|
|
|
/// Packet apperas valid but was ignored e.g. as a duplicate.
|
|
Ignored,
|
|
}
|
|
|
|
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
|
#[repr(transparent)]
|
|
pub struct SessionId(NonZeroU64);
|
|
|
|
impl SessionId {
|
|
pub const MAX_BIT_MASK: u64 = 0xffffffffffff;
|
|
|
|
#[inline(always)]
|
|
pub fn new_from_u64(i: u64) -> Option<SessionId> {
|
|
debug_assert!(i <= Self::MAX_BIT_MASK);
|
|
NonZeroU64::new(i).map(|i| Self(i))
|
|
}
|
|
|
|
pub fn new_from_reader<R: Read>(r: &mut R) -> std::io::Result<Option<SessionId>> {
|
|
let mut tmp = [0_u8; 8];
|
|
r.read_exact(&mut tmp[..SESSION_ID_SIZE]).map(|_| NonZeroU64::new(u64::from_le_bytes(tmp)).map(|i| Self(i)))
|
|
}
|
|
|
|
pub fn new_random() -> Self {
|
|
Self(NonZeroU64::new((random::next_u64_secure() & Self::MAX_BIT_MASK).max(1)).unwrap())
|
|
}
|
|
|
|
#[inline(always)]
|
|
pub fn to_bytes(&self) -> [u8; SESSION_ID_SIZE] {
|
|
self.0.get().to_le_bytes()[0..6].try_into().unwrap()
|
|
}
|
|
}
|
|
|
|
impl From<SessionId> for u64 {
|
|
#[inline(always)]
|
|
fn from(sid: SessionId) -> Self {
|
|
sid.0.get()
|
|
}
|
|
}
|
|
|
|
/// Trait to implement to integrate the session into an application.
|
|
pub trait Host: Sized {
|
|
/// Arbitrary object that can be associated with sessions.
|
|
type AssociatedObject;
|
|
|
|
/// Arbitrary object that dereferences to the session, such as Arc<Session<Self>>.
|
|
type SessionRef: Deref<Target = Session<Self>>;
|
|
|
|
/// A buffer containing data read from the network that can be cached.
|
|
///
|
|
/// This can be e.g. a pooled buffer that automatically returns itself to the pool when dropped.
|
|
type IncomingPacketBuffer: AsRef<[u8]>;
|
|
|
|
/// Get a reference to this host's static public key blob.
|
|
///
|
|
/// This must contain a NIST P-384 public key but can contain other information.
|
|
fn get_local_s_public(&self) -> &[u8];
|
|
|
|
/// Get SHA384(this host's static public key blob), included here so we don't have to calculate it each time.
|
|
fn get_local_s_public_hash(&self) -> &[u8; 48];
|
|
|
|
/// Get a reference to this hosts' static public key's NIST P-384 secret key pair
|
|
fn get_local_s_keypair_p384(&self) -> &P384KeyPair;
|
|
|
|
/// Extract the NIST P-384 ECC public key component from a static public key blob or return None on failure.
|
|
fn extract_p384_static(static_public: &[u8]) -> Option<P384PublicKey>;
|
|
|
|
/// Look up a local session by local ID.
|
|
fn session_lookup(&self, local_session_id: SessionId) -> Option<Self::SessionRef>;
|
|
|
|
/// Check whether a new session should be accepted.
|
|
///
|
|
/// On success a tuple of local session ID, static secret, and associated object is returned. The
|
|
/// static secret is whatever results from agreement between the local and remote static public
|
|
/// keys.
|
|
fn accept_new_session(&self, remote_static_public: &[u8], remote_metadata: &[u8]) -> Option<(SessionId, Secret<64>, Self::AssociatedObject)>;
|
|
}
|
|
|
|
pub struct Session<H: Host> {
|
|
pub id: SessionId,
|
|
pub associated_object: H::AssociatedObject,
|
|
|
|
send_counter: Counter,
|
|
psk: Secret<64>, // Arbitrary PSK provided by external code
|
|
ss: Secret<48>, // NIST P-384 raw ECDH key agreement with peer
|
|
header_check_cipher: Aes, // Cipher used for fast 32-bit header MAC
|
|
state: RwLock<MutableState>, // Mutable parts of state (other than defrag buffers)
|
|
remote_s_public_hash: [u8; 48], // SHA384(remote static public key blob)
|
|
remote_s_public_p384: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key
|
|
defrag: Mutex<RingBufferMap<u32, GatherArray<H::IncomingPacketBuffer, MAX_FRAGMENTS>, 16, 4>>,
|
|
}
|
|
|
|
struct MutableState {
|
|
remote_session_id: Option<SessionId>,
|
|
keys: LinkedList<SessionKey>,
|
|
offer: Option<EphemeralOffer>,
|
|
}
|
|
|
|
/// State information to associate with receiving contexts such as sockets or remote paths/endpoints.
|
|
pub struct ReceiveContext<H: Host> {
|
|
initial_offer_defrag: Mutex<RingBufferMap<u32, GatherArray<H::IncomingPacketBuffer, KEY_EXCHANGE_MAX_FRAGMENTS>, 1024, 128>>,
|
|
incoming_init_header_check_cipher: Aes,
|
|
}
|
|
|
|
impl<H: Host> Session<H> {
|
|
#[inline]
|
|
pub fn new<SendFunction: FnMut(&mut [u8])>(
|
|
host: &H,
|
|
mut send: SendFunction,
|
|
local_session_id: SessionId,
|
|
remote_s_public: &[u8],
|
|
offer_metadata: &[u8],
|
|
psk: &Secret<64>,
|
|
associated_object: H::AssociatedObject,
|
|
mtu: usize,
|
|
current_time: i64,
|
|
jedi: bool,
|
|
) -> Result<Self, Error> {
|
|
if let Some(remote_s_public_p384) = H::extract_p384_static(remote_s_public) {
|
|
if let Some(ss) = host.get_local_s_keypair_p384().agree(&remote_s_public_p384) {
|
|
let send_counter = Counter::new();
|
|
let header_check_cipher = Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>());
|
|
let remote_s_public_hash = SHA384::hash(remote_s_public);
|
|
let outgoing_init_header_check_cipher = Aes::new(kbkdf512(&remote_s_public_hash, KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>());
|
|
if let Ok(offer) = EphemeralOffer::create_alice_offer(
|
|
&mut send,
|
|
send_counter.next(),
|
|
local_session_id,
|
|
None,
|
|
host.get_local_s_public(),
|
|
offer_metadata,
|
|
&remote_s_public_p384,
|
|
&remote_s_public_hash,
|
|
&ss,
|
|
&outgoing_init_header_check_cipher,
|
|
mtu,
|
|
current_time,
|
|
jedi,
|
|
) {
|
|
return Ok(Self {
|
|
id: local_session_id,
|
|
associated_object,
|
|
send_counter,
|
|
psk: psk.clone(),
|
|
ss,
|
|
header_check_cipher,
|
|
state: RwLock::new(MutableState {
|
|
remote_session_id: None,
|
|
keys: LinkedList::new(),
|
|
offer: Some(offer),
|
|
}),
|
|
remote_s_public_hash,
|
|
remote_s_public_p384: remote_s_public_p384.as_bytes().clone(),
|
|
defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)),
|
|
});
|
|
}
|
|
}
|
|
}
|
|
return Err(Error::InvalidParameter);
|
|
}
|
|
|
|
#[inline]
|
|
pub fn send<SendFunction: FnMut(&mut [u8])>(&self, mut send: SendFunction, mtu_buffer: &mut [u8], mut data: &[u8]) -> Result<(), Error> {
|
|
debug_assert!(mtu_buffer.len() >= MIN_MTU);
|
|
let state = self.state.read();
|
|
if let Some(remote_session_id) = state.remote_session_id {
|
|
if let Some(key) = state.keys.front() {
|
|
let mut packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE;
|
|
let counter = self.send_counter.next();
|
|
|
|
send_with_fragmentation_init_header(mtu_buffer, packet_len, mtu_buffer.len(), PACKET_TYPE_DATA, remote_session_id.into(), counter);
|
|
let mut c = key.get_send_cipher(counter)?;
|
|
c.init(mtu_buffer);
|
|
|
|
if packet_len > mtu_buffer.len() {
|
|
let mut header: [u8; HEADER_SIZE - HEADER_CHECK_SIZE] = mtu_buffer[HEADER_CHECK_SIZE..HEADER_SIZE].try_into().unwrap();
|
|
let fragment_data_mtu = mtu_buffer.len() - HEADER_SIZE;
|
|
let last_fragment_data_mtu = mtu_buffer.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
|
|
loop {
|
|
debug_assert!(data.len() > last_fragment_data_mtu);
|
|
c.crypt(&data[HEADER_CHECK_SIZE..fragment_data_mtu], &mut mtu_buffer[HEADER_SIZE..]);
|
|
let hc = header_check(mtu_buffer, &self.header_check_cipher);
|
|
mtu_buffer[..HEADER_CHECK_SIZE].copy_from_slice(&hc.to_ne_bytes());
|
|
send(mtu_buffer);
|
|
data = &data[fragment_data_mtu..];
|
|
debug_assert!(header[7].wrapping_shr(2) < 63);
|
|
header[7] += 0x04; // increment fragment number
|
|
mtu_buffer[HEADER_CHECK_SIZE..HEADER_SIZE].copy_from_slice(&header);
|
|
if data.len() <= last_fragment_data_mtu {
|
|
break;
|
|
}
|
|
}
|
|
packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE;
|
|
}
|
|
|
|
let gcm_tag_idx = data.len() + HEADER_SIZE;
|
|
c.crypt(data, &mut mtu_buffer[HEADER_SIZE..gcm_tag_idx]);
|
|
let hc = header_check(mtu_buffer, &self.header_check_cipher);
|
|
mtu_buffer[..HEADER_CHECK_SIZE].copy_from_slice(&hc.to_ne_bytes());
|
|
mtu_buffer[gcm_tag_idx..packet_len].copy_from_slice(&c.finish());
|
|
send(&mut mtu_buffer[..packet_len]);
|
|
|
|
key.return_send_cipher(c);
|
|
|
|
return Ok(());
|
|
}
|
|
}
|
|
return Err(Error::SessionNotEstablished);
|
|
}
|
|
|
|
#[inline]
|
|
pub fn rekey_check<SendFunction: FnMut(&mut [u8])>(&self, host: &H, mut send: SendFunction, offer_metadata: &[u8], mtu: usize, current_time: i64, force: bool, jedi: bool) {
|
|
let state = self.state.upgradable_read();
|
|
if let Some(key) = state.keys.front() {
|
|
if force || (key.lifetime.should_rekey(self.send_counter.current(), current_time) && state.offer.as_ref().map_or(true, |o| (current_time - o.creation_time) > OFFER_RATE_LIMIT_MS)) {
|
|
if let Some(remote_s_public_p384) = P384PublicKey::from_bytes(&self.remote_s_public_p384) {
|
|
if let Ok(offer) = EphemeralOffer::create_alice_offer(
|
|
&mut send,
|
|
self.send_counter.next(),
|
|
self.id,
|
|
state.remote_session_id,
|
|
host.get_local_s_public(),
|
|
offer_metadata,
|
|
&remote_s_public_p384,
|
|
&self.remote_s_public_hash,
|
|
&self.ss,
|
|
&self.header_check_cipher,
|
|
mtu,
|
|
current_time,
|
|
jedi,
|
|
) {
|
|
let _ = RwLockUpgradableReadGuard::upgrade(state).offer.replace(offer);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<H: Host> ReceiveContext<H> {
|
|
#[inline]
|
|
pub fn new(host: &H) -> Self {
|
|
Self {
|
|
initial_offer_defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)),
|
|
incoming_init_header_check_cipher: Aes::new(kbkdf512(host.get_local_s_public_hash(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>()),
|
|
}
|
|
}
|
|
|
|
#[inline]
|
|
pub fn receive<'a, SendFunction: FnMut(&mut [u8])>(
|
|
&self,
|
|
host: &H,
|
|
mut send: SendFunction,
|
|
data_buf: &'a mut [u8],
|
|
incoming_packet_buf: H::IncomingPacketBuffer,
|
|
mtu: usize,
|
|
jedi: bool,
|
|
current_time: i64,
|
|
) -> Result<ReceiveResult<'a, H>, Error> {
|
|
let incoming_packet = incoming_packet_buf.as_ref();
|
|
if incoming_packet.len() < MIN_PACKET_SIZE {
|
|
unlikely_branch();
|
|
return Err(Error::InvalidPacket);
|
|
}
|
|
|
|
let header_0_8 = memory::u64_from_le_bytes(&incoming_packet[HEADER_CHECK_SIZE..12]); // session ID, type, frag info
|
|
let counter = memory::u32_from_le_bytes(&incoming_packet[12..16]);
|
|
let local_session_id = SessionId::new_from_u64(header_0_8 & SessionId::MAX_BIT_MASK);
|
|
let packet_type = (header_0_8.wrapping_shr(48) as u8) & 15;
|
|
let fragment_count = ((header_0_8.wrapping_shr(52) as u8) & 63).wrapping_add(1);
|
|
let fragment_no = (header_0_8.wrapping_shr(58) as u8) & 63;
|
|
|
|
if let Some(local_session_id) = local_session_id {
|
|
if let Some(session) = host.session_lookup(local_session_id) {
|
|
if memory::u32_from_ne_bytes(incoming_packet) != header_check(incoming_packet, &session.header_check_cipher) {
|
|
unlikely_branch();
|
|
return Err(Error::FailedAuthentication);
|
|
}
|
|
|
|
if fragment_count > 1 {
|
|
if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count {
|
|
let mut defrag = session.defrag.lock();
|
|
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count));
|
|
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
|
|
drop(defrag); // release lock
|
|
return self.receive_complete(host, &mut send, data_buf, assembled_packet.as_ref(), packet_type, Some(session), mtu, jedi, current_time);
|
|
}
|
|
} else {
|
|
unlikely_branch();
|
|
return Err(Error::InvalidPacket);
|
|
}
|
|
} else {
|
|
return self.receive_complete(host, &mut send, data_buf, &[incoming_packet_buf], packet_type, Some(session), mtu, jedi, current_time);
|
|
}
|
|
} else {
|
|
unlikely_branch();
|
|
return Err(Error::UnknownLocalSessionId(local_session_id));
|
|
}
|
|
} else {
|
|
unlikely_branch();
|
|
|
|
if memory::u32_from_ne_bytes(incoming_packet) != header_check(incoming_packet, &self.incoming_init_header_check_cipher) {
|
|
unlikely_branch();
|
|
return Err(Error::FailedAuthentication);
|
|
}
|
|
|
|
let mut defrag = self.initial_offer_defrag.lock();
|
|
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count));
|
|
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
|
|
drop(defrag); // release lock
|
|
return self.receive_complete(host, &mut send, data_buf, assembled_packet.as_ref(), packet_type, None, mtu, jedi, current_time);
|
|
}
|
|
};
|
|
|
|
return Ok(ReceiveResult::Ok);
|
|
}
|
|
|
|
fn receive_complete<'a, SendFunction: FnMut(&mut [u8])>(
|
|
&self,
|
|
host: &H,
|
|
send: &mut SendFunction,
|
|
data_buf: &'a mut [u8],
|
|
fragments: &[H::IncomingPacketBuffer],
|
|
packet_type: u8,
|
|
session: Option<H::SessionRef>,
|
|
mtu: usize,
|
|
jedi: bool,
|
|
current_time: i64,
|
|
) -> Result<ReceiveResult<'a, H>, Error> {
|
|
debug_assert!(fragments.len() >= 1);
|
|
|
|
debug_assert_eq!(PACKET_TYPE_DATA, 0);
|
|
debug_assert_eq!(PACKET_TYPE_NOP, 1);
|
|
if packet_type <= PACKET_TYPE_NOP {
|
|
if let Some(session) = session {
|
|
let state = session.state.read();
|
|
let key_count = state.keys.len();
|
|
for (key_index, key) in state.keys.iter().enumerate() {
|
|
let tail = fragments.last().unwrap().as_ref();
|
|
if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) {
|
|
unlikely_branch();
|
|
return Err(Error::InvalidPacket);
|
|
}
|
|
|
|
let mut c = key.get_receive_cipher();
|
|
c.init(&fragments.first().unwrap().as_ref()[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
|
|
|
|
let mut data_len = 0;
|
|
|
|
for f in fragments[..(fragments.len() - 1)].iter() {
|
|
let f = f.as_ref();
|
|
debug_assert!(f.len() >= HEADER_SIZE);
|
|
let current_frag_data_start = data_len;
|
|
data_len += f.len() - HEADER_SIZE;
|
|
if data_len > data_buf.len() {
|
|
unlikely_branch();
|
|
key.return_receive_cipher(c);
|
|
return Err(Error::DataBufferTooSmall);
|
|
}
|
|
c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]);
|
|
}
|
|
|
|
let current_frag_data_start = data_len;
|
|
data_len += tail.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
|
|
if data_len > data_buf.len() {
|
|
unlikely_branch();
|
|
key.return_receive_cipher(c);
|
|
return Err(Error::DataBufferTooSmall);
|
|
}
|
|
c.crypt(&tail[HEADER_SIZE..(tail.len() - AES_GCM_TAG_SIZE)], &mut data_buf[current_frag_data_start..data_len]);
|
|
|
|
let tag = c.finish();
|
|
key.return_receive_cipher(c);
|
|
|
|
if tag.eq(&tail[(tail.len() - AES_GCM_TAG_SIZE)..]) {
|
|
// Drop obsolete keys if we had to iterate past the first key to get here.
|
|
if key_index > 0 {
|
|
unlikely_branch();
|
|
drop(state);
|
|
let mut state = session.state.write();
|
|
if state.keys.len() == key_count {
|
|
for _ in 0..key_index {
|
|
let _ = state.keys.pop_front();
|
|
}
|
|
}
|
|
}
|
|
|
|
if packet_type == PACKET_TYPE_DATA {
|
|
return Ok(ReceiveResult::OkData(&mut data_buf[..data_len]));
|
|
} else {
|
|
unlikely_branch();
|
|
return Ok(ReceiveResult::Ok);
|
|
}
|
|
}
|
|
}
|
|
return Err(Error::FailedAuthentication);
|
|
} else {
|
|
unlikely_branch();
|
|
return Err(Error::SessionNotEstablished);
|
|
}
|
|
} else {
|
|
unlikely_branch();
|
|
|
|
let mut incoming_packet_buf = [0_u8; MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS];
|
|
let mut incoming_packet_len = 0;
|
|
for i in 0..fragments.len() {
|
|
let mut ff = fragments[i].as_ref();
|
|
debug_assert!(ff.len() >= MIN_PACKET_SIZE);
|
|
if i > 0 {
|
|
ff = &ff[HEADER_SIZE..];
|
|
}
|
|
let j = incoming_packet_len + ff.len();
|
|
if j > incoming_packet_buf.len() {
|
|
return Err(Error::InvalidPacket);
|
|
}
|
|
incoming_packet_buf[incoming_packet_len..j].copy_from_slice(ff);
|
|
incoming_packet_len = j;
|
|
}
|
|
let original_ciphertext = incoming_packet_buf.clone();
|
|
let incoming_packet = &mut incoming_packet_buf[..incoming_packet_len];
|
|
|
|
if incoming_packet_len <= HEADER_SIZE {
|
|
return Err(Error::InvalidPacket);
|
|
}
|
|
if incoming_packet[HEADER_SIZE] != SESSION_PROTOCOL_VERSION {
|
|
return Err(Error::UnknownProtocolVersion);
|
|
}
|
|
|
|
match packet_type {
|
|
PACKET_TYPE_KEY_OFFER => {
|
|
// alice (remote) -> bob (local)
|
|
|
|
if incoming_packet_len < (HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE + AES_GCM_TAG_SIZE + HMAC_SIZE + HMAC_SIZE) {
|
|
return Err(Error::InvalidPacket);
|
|
}
|
|
let payload_end = incoming_packet_len - (AES_GCM_TAG_SIZE + HMAC_SIZE + HMAC_SIZE);
|
|
let aes_gcm_tag_end = incoming_packet_len - (HMAC_SIZE + HMAC_SIZE);
|
|
let hmac1_end = incoming_packet_len - HMAC_SIZE;
|
|
|
|
// Check that the sender knows this host's identity before doing anything else.
|
|
if !hmac_sha384(host.get_local_s_public_hash(), &incoming_packet[HEADER_CHECK_SIZE..hmac1_end]).eq(&incoming_packet[hmac1_end..]) {
|
|
return Err(Error::FailedAuthentication);
|
|
}
|
|
|
|
// Check rate limit if this session is known.
|
|
if let Some(session) = session.as_ref() {
|
|
if let Some(offer) = session.state.read().offer.as_ref() {
|
|
if (current_time - offer.creation_time) < OFFER_RATE_LIMIT_MS {
|
|
return Err(Error::RateLimited);
|
|
}
|
|
}
|
|
}
|
|
|
|
let (alice_e0_public, e0s) = P384PublicKey::from_bytes(&incoming_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)])
|
|
.and_then(|pk| host.get_local_s_keypair_p384().agree(&pk).map(move |s| (pk, s)))
|
|
.ok_or(Error::FailedAuthentication)?;
|
|
|
|
let key = Secret(hmac_sha512(&hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_public.as_bytes()), e0s.as_bytes()));
|
|
|
|
let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(), false);
|
|
c.init(&incoming_packet[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
|
|
c.crypt_in_place(&mut incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]);
|
|
let c = c.finish();
|
|
if !c.eq(&incoming_packet[payload_end..aes_gcm_tag_end]) {
|
|
return Err(Error::FailedAuthentication);
|
|
}
|
|
|
|
let (alice_session_id, alice_s_public, alice_metadata, alice_e1_public) = parse_key_offer_after_header(&incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..], packet_type)?;
|
|
|
|
if let Some(session) = session.as_ref() {
|
|
// Important! If there's already a session, make sure the caller is the same endpoint as that session!
|
|
if !session.remote_s_public_hash.eq(&SHA384::hash(&alice_s_public)) {
|
|
return Err(Error::FailedAuthentication);
|
|
}
|
|
}
|
|
|
|
let alice_s_public_p384 = H::extract_p384_static(alice_s_public).ok_or(Error::InvalidPacket)?;
|
|
let ss = host.get_local_s_keypair_p384().agree(&alice_s_public_p384).ok_or(Error::FailedAuthentication)?;
|
|
|
|
let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes()));
|
|
|
|
if !hmac_sha384(
|
|
kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(),
|
|
&original_ciphertext[HEADER_CHECK_SIZE..aes_gcm_tag_end],
|
|
)
|
|
.eq(&incoming_packet[aes_gcm_tag_end..hmac1_end])
|
|
{
|
|
return Err(Error::FailedAuthentication);
|
|
}
|
|
|
|
// Alice's offer has been verified and her current key state reconstructed.
|
|
|
|
let bob_e0_keypair = P384KeyPair::generate();
|
|
let e0e0 = bob_e0_keypair.agree(&alice_e0_public).ok_or(Error::FailedAuthentication)?;
|
|
let se0 = bob_e0_keypair.agree(&alice_s_public_p384).ok_or(Error::FailedAuthentication)?;
|
|
|
|
let new_session = if session.is_some() {
|
|
None
|
|
} else {
|
|
if let Some((new_session_id, psk, associated_object)) = host.accept_new_session(alice_s_public, alice_metadata) {
|
|
let header_check_cipher = Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>());
|
|
Some(Session::<H> {
|
|
id: new_session_id,
|
|
associated_object,
|
|
send_counter: Counter::new(),
|
|
psk,
|
|
ss,
|
|
header_check_cipher,
|
|
state: RwLock::new(MutableState {
|
|
remote_session_id: Some(alice_session_id),
|
|
keys: LinkedList::new(),
|
|
offer: None,
|
|
}),
|
|
remote_s_public_hash: SHA384::hash(&alice_s_public),
|
|
remote_s_public_p384: alice_s_public_p384.as_bytes().clone(),
|
|
defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)),
|
|
})
|
|
} else {
|
|
return Err(Error::NewSessionRejected);
|
|
}
|
|
};
|
|
let session_ref = session;
|
|
let session = session_ref.as_ref().map_or_else(|| new_session.as_ref().unwrap(), |s| &*s);
|
|
|
|
// FIPS note: the order of HMAC parameters are flipped here from the usual Noise HMAC(key, X). That's because
|
|
// NIST/FIPS allows HKDF with HMAC(salt, key) and salt is allowed to be anything. This way if the PSK is not
|
|
// FIPS compliant the compliance of the entire key derivation is not invalidated. Both inputs are secrets of
|
|
// fixed size so this shouldn't matter cryptographically.
|
|
let key = Secret(hmac_sha512(
|
|
session.psk.as_bytes(),
|
|
&hmac_sha512(&hmac_sha512(&hmac_sha512(key.as_bytes(), bob_e0_keypair.public_key_bytes()), e0e0.as_bytes()), se0.as_bytes()),
|
|
));
|
|
|
|
// At this point we've completed Noise_IK key derivation with NIST P-384 ECDH, but see final step below...
|
|
|
|
let (bob_e1_public, e1e1) = if jedi && alice_e1_public.len() > 0 {
|
|
if let Ok((bob_e1_public, e1e1)) = pqc_kyber::encapsulate(alice_e1_public, &mut random::SecureRandom::default()) {
|
|
(Some(bob_e1_public), Secret(e1e1))
|
|
} else {
|
|
return Err(Error::FailedAuthentication);
|
|
}
|
|
} else {
|
|
(None, Secret::default()) // use all zero Kyber secret if disabled
|
|
};
|
|
|
|
let mut reply_buf = [0_u8; MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS];
|
|
let reply_counter = session.send_counter.next();
|
|
let mut reply_len = {
|
|
let mut rp = &mut reply_buf[HEADER_SIZE..];
|
|
|
|
rp.write_all(&[SESSION_PROTOCOL_VERSION])?;
|
|
rp.write_all(bob_e0_keypair.public_key_bytes())?;
|
|
|
|
rp.write_all(&session.id.to_bytes())?;
|
|
varint::write(&mut rp, 0)?; // they don't need our static public; they have it
|
|
varint::write(&mut rp, 0)?; // no meta-data yet
|
|
if let Some(bob_e1_public) = bob_e1_public.as_ref() {
|
|
rp.write_all(&[E1_TYPE_KYBER1024])?;
|
|
rp.write_all(bob_e1_public)?;
|
|
} else {
|
|
rp.write_all(&[E1_TYPE_NONE])?;
|
|
}
|
|
|
|
(MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS) - rp.len()
|
|
};
|
|
|
|
send_with_fragmentation_init_header(&mut reply_buf, reply_len, mtu, PACKET_TYPE_KEY_COUNTER_OFFER, alice_session_id.into(), reply_counter);
|
|
|
|
let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), true);
|
|
c.init(&reply_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
|
|
c.crypt_in_place(&mut reply_buf[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..reply_len]);
|
|
let c = c.finish();
|
|
reply_buf[reply_len..(reply_len + AES_GCM_TAG_SIZE)].copy_from_slice(&c);
|
|
reply_len += AES_GCM_TAG_SIZE;
|
|
|
|
// Normal Noise_IK is done, but we have one more step: mix in the Kyber shared secret (or all zeroes if Kyber is
|
|
// disabled). We have to wait until this point because Kyber's keys are encrypted and can't be decrypted until
|
|
// the P-384 exchange is done. We also flip the HMAC parameter order here for the same reason we do in the previous
|
|
// key derivation step.
|
|
let key = Secret(hmac_sha512(e1e1.as_bytes(), key.as_bytes()));
|
|
|
|
let hmac = hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &reply_buf[HEADER_CHECK_SIZE..reply_len]);
|
|
reply_buf[reply_len..(reply_len + HMAC_SIZE)].copy_from_slice(&hmac);
|
|
reply_len += HMAC_SIZE;
|
|
|
|
let mut state = session.state.write();
|
|
let _ = state.remote_session_id.replace(alice_session_id);
|
|
add_key(&mut state.keys, SessionKey::new(key, Role::Bob, current_time, reply_counter, jedi));
|
|
drop(state);
|
|
|
|
// Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it.
|
|
|
|
send_with_fragmentation(send, &mut reply_buf[..reply_len], mtu, &session.header_check_cipher);
|
|
|
|
if new_session.is_some() {
|
|
return Ok(ReceiveResult::OkNewSession(new_session.unwrap()));
|
|
} else {
|
|
return Ok(ReceiveResult::Ok);
|
|
}
|
|
}
|
|
|
|
PACKET_TYPE_KEY_COUNTER_OFFER => {
|
|
// bob (remote) -> alice (local)
|
|
|
|
if incoming_packet_len < (HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE + AES_GCM_TAG_SIZE + HMAC_SIZE) {
|
|
return Err(Error::InvalidPacket);
|
|
}
|
|
let payload_end = incoming_packet_len - (AES_GCM_TAG_SIZE + HMAC_SIZE);
|
|
let aes_gcm_tag_end = incoming_packet_len - HMAC_SIZE;
|
|
|
|
if let Some(session) = session {
|
|
let state = session.state.upgradable_read();
|
|
if let Some(offer) = state.offer.as_ref() {
|
|
let (bob_e0_public, e0e0) = P384PublicKey::from_bytes(&incoming_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)])
|
|
.and_then(|pk| offer.alice_e0_keypair.agree(&pk).map(move |s| (pk, s)))
|
|
.ok_or(Error::FailedAuthentication)?;
|
|
let se0 = host.get_local_s_keypair_p384().agree(&bob_e0_public).ok_or(Error::FailedAuthentication)?;
|
|
|
|
let key = Secret(hmac_sha512(
|
|
session.psk.as_bytes(),
|
|
&hmac_sha512(&hmac_sha512(&hmac_sha512(offer.key.as_bytes(), bob_e0_public.as_bytes()), e0e0.as_bytes()), se0.as_bytes()),
|
|
));
|
|
|
|
let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), false);
|
|
c.init(&incoming_packet[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
|
|
c.crypt_in_place(&mut incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]);
|
|
let c = c.finish();
|
|
if !c.eq(&incoming_packet[payload_end..aes_gcm_tag_end]) {
|
|
return Err(Error::FailedAuthentication);
|
|
}
|
|
|
|
// Alice has now completed Noise_IK with NIST P-384 and verified with GCM auth, but now for hybrid...
|
|
|
|
let (bob_session_id, _, _, bob_e1_public) = parse_key_offer_after_header(&incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..], packet_type)?;
|
|
|
|
let e1e1 = if jedi && bob_e1_public.len() > 0 && offer.alice_e1_keypair.is_some() {
|
|
if let Ok(e1e1) = pqc_kyber::decapsulate(bob_e1_public, &offer.alice_e1_keypair.as_ref().unwrap().secret) {
|
|
Secret(e1e1)
|
|
} else {
|
|
return Err(Error::FailedAuthentication);
|
|
}
|
|
} else {
|
|
Secret::default()
|
|
};
|
|
|
|
let key = Secret(hmac_sha512(e1e1.as_bytes(), key.as_bytes()));
|
|
|
|
if !hmac_sha384(
|
|
kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(),
|
|
&original_ciphertext[HEADER_CHECK_SIZE..aes_gcm_tag_end],
|
|
)
|
|
.eq(&incoming_packet[aes_gcm_tag_end..incoming_packet.len()])
|
|
{
|
|
return Err(Error::FailedAuthentication);
|
|
}
|
|
|
|
// Alice has now completed and validated the full hybrid exchange.
|
|
|
|
let counter = session.send_counter.next();
|
|
let key = SessionKey::new(key, Role::Alice, current_time, counter, jedi);
|
|
|
|
let mut reply_buf = [0_u8; HEADER_SIZE + AES_GCM_TAG_SIZE];
|
|
send_with_fragmentation_init_header(&mut reply_buf, HEADER_SIZE + AES_GCM_TAG_SIZE, mtu, PACKET_TYPE_NOP, bob_session_id.into(), counter);
|
|
|
|
let mut c = key.get_send_cipher(counter)?;
|
|
c.init(&reply_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
|
|
reply_buf[HEADER_SIZE..].copy_from_slice(&c.finish());
|
|
key.return_send_cipher(c);
|
|
|
|
let hc = header_check(&reply_buf, &session.header_check_cipher);
|
|
reply_buf[..HEADER_CHECK_SIZE].copy_from_slice(&hc.to_ne_bytes());
|
|
|
|
send(&mut reply_buf);
|
|
|
|
let mut state = RwLockUpgradableReadGuard::upgrade(state);
|
|
let _ = state.remote_session_id.replace(bob_session_id);
|
|
let _ = state.offer.take();
|
|
add_key(&mut state.keys, key);
|
|
|
|
return Ok(ReceiveResult::Ok);
|
|
}
|
|
}
|
|
|
|
// Just ignore counter-offers that are out of place. They probably indicate that this side
|
|
// restarted and needs to establish a new session.
|
|
return Ok(ReceiveResult::Ignored);
|
|
}
|
|
|
|
_ => return Err(Error::InvalidPacket),
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
struct Counter(AtomicU64);
|
|
|
|
impl Counter {
|
|
#[inline(always)]
|
|
fn new() -> Self {
|
|
Self(AtomicU64::new(random::next_u32_secure() as u64))
|
|
}
|
|
|
|
#[inline(always)]
|
|
fn current(&self) -> CounterValue {
|
|
CounterValue(self.0.load(Ordering::SeqCst))
|
|
}
|
|
|
|
#[inline(always)]
|
|
fn next(&self) -> CounterValue {
|
|
CounterValue(self.0.fetch_add(1, Ordering::SeqCst))
|
|
}
|
|
}
|
|
|
|
/// A snapshot of the outgoing packet counter.
///
/// The value is kept as 64 bits internally so usage limits can be tracked without
/// having to reason about 32-bit wrap-around. Only the least significant 32 bits
/// are the counter that is actually placed in a packet.
#[repr(transparent)]
#[derive(Copy, Clone)]
struct CounterValue(u64);

impl CounterValue {
    /// Return the least significant 32 bits, i.e. the part of the counter that goes in the packet.
    #[inline(always)]
    pub fn to_u32(&self) -> u32 {
        let low_bits = self.0 & u64::from(u32::MAX);
        low_bits as u32
    }
}
|
|
|
|
struct KeyLifetime {
|
|
rekey_at_or_after_counter: u64,
|
|
hard_expire_at_counter: u64,
|
|
rekey_at_or_after_timestamp: i64,
|
|
}
|
|
|
|
impl KeyLifetime {
|
|
fn new(current_counter: CounterValue, current_time: i64) -> Self {
|
|
Self {
|
|
rekey_at_or_after_counter: current_counter.0 + REKEY_AFTER_USES + (random::next_u32_secure() % REKEY_AFTER_USES_MAX_JITTER) as u64,
|
|
hard_expire_at_counter: current_counter.0 + EXPIRE_AFTER_USES,
|
|
rekey_at_or_after_timestamp: current_time + REKEY_AFTER_TIME_MS + (random::next_u32_secure() % REKEY_AFTER_TIME_MS_MAX_JITTER) as i64,
|
|
}
|
|
}
|
|
|
|
#[inline(always)]
|
|
fn should_rekey(&self, counter: CounterValue, current_time: i64) -> bool {
|
|
counter.0 >= self.rekey_at_or_after_counter || current_time >= self.rekey_at_or_after_timestamp
|
|
}
|
|
|
|
#[inline(always)]
|
|
fn expired(&self, counter: CounterValue) -> bool {
|
|
counter.0 >= self.hard_expire_at_counter
|
|
}
|
|
}
|
|
|
|
/// Ephemeral offer sent with KEY_OFFER and rememebered so state can be reconstructed on COUNTER_OFFER.
|
|
struct EphemeralOffer {
|
|
creation_time: i64,
|
|
key: Secret<64>,
|
|
alice_e0_keypair: P384KeyPair,
|
|
alice_e1_keypair: Option<pqc_kyber::Keypair>,
|
|
}
|
|
|
|
impl EphemeralOffer {
|
|
fn create_alice_offer<SendFunction: FnMut(&mut [u8])>(
|
|
send: &mut SendFunction,
|
|
counter: CounterValue,
|
|
alice_session_id: SessionId,
|
|
bob_session_id: Option<SessionId>,
|
|
alice_s_public: &[u8],
|
|
alice_metadata: &[u8],
|
|
bob_s_public_p384: &P384PublicKey,
|
|
bob_s_public_hash: &[u8],
|
|
ss: &Secret<48>,
|
|
header_check_cipher: &Aes,
|
|
mtu: usize,
|
|
current_time: i64,
|
|
jedi: bool,
|
|
) -> Result<EphemeralOffer, Error> {
|
|
let alice_e0_keypair = P384KeyPair::generate();
|
|
let e0s = alice_e0_keypair.agree(bob_s_public_p384);
|
|
if e0s.is_none() {
|
|
return Err(Error::InvalidPacket);
|
|
}
|
|
|
|
let alice_e1_keypair = if jedi {
|
|
Some(pqc_kyber::keypair(&mut random::SecureRandom::get()))
|
|
} else {
|
|
None
|
|
};
|
|
|
|
const PACKET_BUF_SIZE: usize = MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS;
|
|
let mut packet_buf = [0_u8; PACKET_BUF_SIZE];
|
|
let mut packet_len = {
|
|
let mut p = &mut packet_buf[HEADER_SIZE..];
|
|
|
|
p.write_all(&[SESSION_PROTOCOL_VERSION])?;
|
|
p.write_all(alice_e0_keypair.public_key_bytes())?;
|
|
|
|
p.write_all(&alice_session_id.0.get().to_le_bytes()[..SESSION_ID_SIZE])?;
|
|
varint::write(&mut p, alice_s_public.len() as u64)?;
|
|
p.write_all(alice_s_public)?;
|
|
varint::write(&mut p, alice_metadata.len() as u64)?;
|
|
p.write_all(alice_metadata)?;
|
|
if let Some(e1kp) = alice_e1_keypair {
|
|
p.write_all(&[E1_TYPE_KYBER1024])?;
|
|
p.write_all(&e1kp.public)?;
|
|
} else {
|
|
p.write_all(&[E1_TYPE_NONE])?;
|
|
}
|
|
|
|
PACKET_BUF_SIZE - p.len()
|
|
};
|
|
|
|
send_with_fragmentation_init_header(&mut packet_buf, packet_len, mtu, PACKET_TYPE_KEY_OFFER, bob_session_id.map_or(0_u64, |i| i.into()), counter);
|
|
|
|
let key = Secret(hmac_sha512(
|
|
&hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_keypair.public_key_bytes()),
|
|
e0s.unwrap().as_bytes(),
|
|
));
|
|
|
|
let gcm_tag = {
|
|
let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(), true);
|
|
c.init(&packet_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
|
|
c.crypt_in_place(&mut packet_buf[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..packet_len]);
|
|
c.finish()
|
|
};
|
|
packet_buf[packet_len..(packet_len + AES_GCM_TAG_SIZE)].copy_from_slice(&gcm_tag);
|
|
packet_len += AES_GCM_TAG_SIZE;
|
|
|
|
let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes()));
|
|
|
|
let hmac = hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &packet_buf[HEADER_CHECK_SIZE..packet_len]);
|
|
packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac);
|
|
packet_len += HMAC_SIZE;
|
|
|
|
let hmac = hmac_sha384(bob_s_public_hash, &packet_buf[HEADER_CHECK_SIZE..packet_len]);
|
|
packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac);
|
|
packet_len += HMAC_SIZE;
|
|
|
|
send_with_fragmentation(send, &mut packet_buf[..packet_len], mtu, header_check_cipher);
|
|
|
|
Ok(EphemeralOffer {
|
|
creation_time: current_time,
|
|
key,
|
|
alice_e0_keypair,
|
|
alice_e1_keypair,
|
|
})
|
|
}
|
|
}
|
|
|
|
#[inline(always)]
|
|
fn send_with_fragmentation_init_header(header: &mut [u8], packet_len: usize, mtu: usize, packet_type: u8, recipient_session_id: u64, counter: CounterValue) {
|
|
let fragment_count = ((packet_len as f32) / (mtu as f32)).ceil() as usize;
|
|
debug_assert!(mtu >= MIN_MTU);
|
|
debug_assert!(packet_len >= HEADER_SIZE);
|
|
debug_assert!(fragment_count <= MAX_FRAGMENTS);
|
|
debug_assert!(fragment_count > 0);
|
|
debug_assert!(packet_type <= 0x0f); // packet type is 4 bits
|
|
debug_assert!(recipient_session_id <= 0xffffffffffff); // session ID is 48 bits
|
|
header[HEADER_CHECK_SIZE..12].copy_from_slice(&(recipient_session_id | (packet_type as u64).wrapping_shl(48) | ((fragment_count - 1) as u64).wrapping_shl(52)).to_le_bytes());
|
|
header[12..HEADER_SIZE].copy_from_slice(&counter.to_u32().to_le_bytes());
|
|
}
|
|
|
|
/// Send a packet in fragments (used for everything but DATA which has a hand-rolled version for performance).
|
|
fn send_with_fragmentation<SendFunction: FnMut(&mut [u8])>(send: &mut SendFunction, packet: &mut [u8], mtu: usize, header_check_cipher: &Aes) {
|
|
let packet_len = packet.len();
|
|
let mut fragment_start = 0;
|
|
let mut fragment_end = packet_len.min(mtu);
|
|
let mut header: [u8; HEADER_SIZE - HEADER_CHECK_SIZE] = packet[HEADER_CHECK_SIZE..HEADER_SIZE].try_into().unwrap();
|
|
loop {
|
|
let hc = header_check(&packet[fragment_start..], header_check_cipher);
|
|
packet[fragment_start..(fragment_start + HEADER_CHECK_SIZE)].copy_from_slice(&hc.to_ne_bytes());
|
|
send(&mut packet[fragment_start..fragment_end]);
|
|
if fragment_end < packet_len {
|
|
debug_assert!(header[7].wrapping_shr(2) < 63);
|
|
header[7] += 0x04; // increment fragment number
|
|
fragment_start = fragment_end - HEADER_SIZE;
|
|
fragment_end = (fragment_start + mtu).min(packet_len);
|
|
packet[(fragment_start + HEADER_CHECK_SIZE)..(fragment_start + HEADER_SIZE)].copy_from_slice(&header);
|
|
} else {
|
|
debug_assert_eq!(fragment_end, packet_len);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Compute the 32-bit header check code for a packet on receipt or right before send.
|
|
fn header_check(packet: &[u8], header_check_cipher: &Aes) -> u32 {
|
|
debug_assert!(packet.len() >= HEADER_SIZE);
|
|
let mut header_check = 0u128.to_ne_bytes();
|
|
if packet.len() >= (16 + HEADER_CHECK_SIZE) {
|
|
header_check_cipher.encrypt_block(&packet[HEADER_CHECK_SIZE..(16 + HEADER_CHECK_SIZE)], &mut header_check);
|
|
} else {
|
|
unlikely_branch();
|
|
header_check[..(packet.len() - HEADER_CHECK_SIZE)].copy_from_slice(&packet[HEADER_CHECK_SIZE..]);
|
|
header_check_cipher.encrypt_block_in_place(&mut header_check);
|
|
}
|
|
memory::u32_from_ne_bytes(&header_check)
|
|
}
|
|
|
|
/// Add a new session key to the key list, retiring older non-active keys if necessary.
|
|
fn add_key(keys: &mut LinkedList<SessionKey>, key: SessionKey) {
|
|
debug_assert!(KEY_HISTORY_SIZE_MAX >= 2);
|
|
while keys.len() >= KEY_HISTORY_SIZE_MAX {
|
|
let current = keys.pop_front().unwrap();
|
|
let _ = keys.pop_front();
|
|
keys.push_front(current);
|
|
}
|
|
keys.push_back(key);
|
|
}
|
|
|
|
fn parse_key_offer_after_header(incoming_packet: &[u8], packet_type: u8) -> Result<(SessionId, &[u8], &[u8], &[u8]), Error> {
|
|
let mut p = &incoming_packet[..];
|
|
let alice_session_id = SessionId::new_from_reader(&mut p)?;
|
|
if alice_session_id.is_none() {
|
|
return Err(Error::InvalidPacket);
|
|
}
|
|
let alice_session_id = alice_session_id.unwrap();
|
|
let alice_s_public_len = varint::read(&mut p)?.0;
|
|
if (p.len() as u64) < alice_s_public_len {
|
|
return Err(Error::InvalidPacket);
|
|
}
|
|
let alice_s_public = &p[..(alice_s_public_len as usize)];
|
|
p = &p[(alice_s_public_len as usize)..];
|
|
let alice_metadata_len = varint::read(&mut p)?.0;
|
|
if (p.len() as u64) < alice_metadata_len {
|
|
return Err(Error::InvalidPacket);
|
|
}
|
|
let alice_metadata = &p[..(alice_metadata_len as usize)];
|
|
p = &p[(alice_metadata_len as usize)..];
|
|
if p.is_empty() {
|
|
return Err(Error::InvalidPacket);
|
|
}
|
|
let alice_e1_public = match p[0] {
|
|
E1_TYPE_KYBER1024 => {
|
|
if packet_type == PACKET_TYPE_KEY_OFFER {
|
|
if p.len() < (pqc_kyber::KYBER_PUBLICKEYBYTES + 1) {
|
|
return Err(Error::InvalidPacket);
|
|
}
|
|
&p[1..(pqc_kyber::KYBER_PUBLICKEYBYTES + 1)]
|
|
} else {
|
|
if p.len() < (pqc_kyber::KYBER_CIPHERTEXTBYTES + 1) {
|
|
return Err(Error::InvalidPacket);
|
|
}
|
|
&p[1..(pqc_kyber::KYBER_CIPHERTEXTBYTES + 1)]
|
|
}
|
|
}
|
|
_ => &[],
|
|
};
|
|
|
|
Ok((alice_session_id, alice_s_public, alice_metadata, alice_e1_public))
|
|
}
|
|
|
|
/// Which side of the key exchange the local endpoint played for a given session key.
///
/// `SessionKey::new` uses this to decide which derived directional key is used for sending
/// and which for receiving.
enum Role {
    /// Local side sends with the Alice-to-Bob key and receives with the Bob-to-Alice key.
    Alice,
    /// Local side sends with the Bob-to-Alice key and receives with the Alice-to-Bob key.
    Bob,
}
|
|
|
|
#[allow(unused)]
|
|
struct SessionKey {
|
|
lifetime: KeyLifetime,
|
|
receive_key: Secret<32>,
|
|
send_key: Secret<32>,
|
|
receive_cipher_pool: Mutex<Vec<Box<AesGcm>>>,
|
|
send_cipher_pool: Mutex<Vec<Box<AesGcm>>>,
|
|
role: Role,
|
|
jedi: bool, // true if kyber was used
|
|
}
|
|
|
|
impl SessionKey {
|
|
/// Create a new symmetric shared session key and set its key expiration times, etc.
|
|
fn new(key: Secret<64>, role: Role, current_time: i64, current_counter: CounterValue, jedi: bool) -> Self {
|
|
let a2b: Secret<32> = kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n_clone();
|
|
let b2a: Secret<32> = kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n_clone();
|
|
let (receive_key, send_key) = match role {
|
|
Role::Alice => (b2a, a2b),
|
|
Role::Bob => (a2b, b2a),
|
|
};
|
|
Self {
|
|
lifetime: KeyLifetime::new(current_counter, current_time),
|
|
receive_key,
|
|
send_key,
|
|
receive_cipher_pool: Mutex::new(Vec::with_capacity(2)),
|
|
send_cipher_pool: Mutex::new(Vec::with_capacity(2)),
|
|
role,
|
|
jedi,
|
|
}
|
|
}
|
|
|
|
#[inline(always)]
|
|
fn get_send_cipher(&self, counter: CounterValue) -> Result<Box<AesGcm>, Error> {
|
|
if !self.lifetime.expired(counter) {
|
|
Ok(self.send_cipher_pool.lock().pop().unwrap_or_else(|| Box::new(AesGcm::new(self.send_key.as_bytes(), true))))
|
|
} else {
|
|
Err(Error::MaxKeyLifetimeExceeded)
|
|
}
|
|
}
|
|
|
|
#[inline(always)]
|
|
fn return_send_cipher(&self, c: Box<AesGcm>) {
|
|
self.send_cipher_pool.lock().push(c);
|
|
}
|
|
|
|
#[inline(always)]
|
|
fn get_receive_cipher(&self) -> Box<AesGcm> {
|
|
self.receive_cipher_pool.lock().pop().unwrap_or_else(|| Box::new(AesGcm::new(self.receive_key.as_bytes(), false)))
|
|
}
|
|
|
|
#[inline(always)]
|
|
fn return_receive_cipher(&self, c: Box<AesGcm>) {
|
|
self.receive_cipher_pool.lock().push(c);
|
|
}
|
|
}
|
|
|
|
/// HMAC-SHA512 key derivation function modeled on: https://csrc.nist.gov/publications/detail/sp/800-108/final (page 12)
|
|
fn kbkdf512(key: &[u8], label: u8) -> Secret<64> {
|
|
Secret(hmac_sha512(key, &[0, 0, 0, 0, b'Z', b'T', label, 0, 0, 0, 0, 0x02, 0x00]))
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use parking_lot::Mutex;
    use std::collections::LinkedList;
    use std::sync::Arc;

    #[allow(unused_imports)]
    use super::*;

    /// Minimal in-memory `Host` used to run two endpoints against each other.
    struct TestHost {
        local_s: P384KeyPair,
        local_s_hash: [u8; 48],
        psk: Secret<64>,
        /// The single session this host knows about, if any.
        session: Mutex<Option<Arc<Session<Box<TestHost>>>>>,
        session_id_counter: Mutex<u64>,
        /// Packets waiting to be received by this host (pushed at the front, popped from the back).
        pub queue: Mutex<LinkedList<Vec<u8>>>,
        pub this_name: &'static str,
        pub other_name: &'static str,
    }

    impl TestHost {
        /// Create a host with a fresh static key pair and a random odd starting session ID counter.
        fn new(psk: Secret<64>, this_name: &'static str, other_name: &'static str) -> Self {
            let local_s = P384KeyPair::generate();
            let local_s_hash = SHA384::hash(local_s.public_key_bytes());
            let first_session_id = random::next_u64_secure().wrapping_shr(16) | 1;
            Self {
                local_s,
                local_s_hash,
                psk,
                session: Mutex::new(None),
                session_id_counter: Mutex::new(first_session_id),
                queue: Mutex::new(LinkedList::new()),
                this_name,
                other_name,
            }
        }
    }

    impl Host for Box<TestHost> {
        type AssociatedObject = u32;
        type SessionRef = Arc<Session<Box<TestHost>>>;
        type IncomingPacketBuffer = Vec<u8>;

        fn get_local_s_public(&self) -> &[u8] {
            self.local_s.public_key_bytes()
        }

        fn get_local_s_public_hash(&self) -> &[u8; 48] {
            &self.local_s_hash
        }

        fn get_local_s_keypair_p384(&self) -> &P384KeyPair {
            &self.local_s
        }

        fn extract_p384_static(static_public: &[u8]) -> Option<P384PublicKey> {
            P384PublicKey::from_bytes(static_public)
        }

        /// Return the stored session only if its ID matches the requested one.
        fn session_lookup(&self, local_session_id: SessionId) -> Option<Self::SessionRef> {
            self.session.lock().as_ref().filter(|s| s.id == local_session_id).cloned()
        }

        /// Always accept, handing out the next session ID from the counter.
        fn accept_new_session(&self, _: &[u8], _: &[u8]) -> Option<(SessionId, Secret<64>, Self::AssociatedObject)> {
            let mut next_id = self.session_id_counter.lock();
            *next_id += 1;
            Some((SessionId::new_from_u64(*next_id).unwrap(), self.psk.clone(), 0))
        }
    }

    /// Create a session from Alice to Bob and shuttle packets between the two hosts,
    /// panicking if any received packet produces an error.
    #[allow(unused_variables)]
    #[test]
    fn establish_session() {
        let jedi = true;
        let mut psk: Secret<64> = Secret::default();
        random::fill_bytes_secure(&mut psk.0);
        let alice_host = Box::new(TestHost::new(psk.clone(), "alice", "bob"));
        let bob_host = Box::new(TestHost::new(psk.clone(), "bob", "alice"));
        let alice_rc: Box<ReceiveContext<Box<TestHost>>> = Box::new(ReceiveContext::new(&alice_host));
        let bob_rc: Box<ReceiveContext<Box<TestHost>>> = Box::new(ReceiveContext::new(&bob_host));
        let mut data_buf = [0_u8; 4096];

        // Alice initiates; whatever the new session sends is queued for Bob.
        let alice_session = Session::new(
            &alice_host,
            |data| bob_host.queue.lock().push_front(data.to_vec()),
            SessionId::new_random(),
            bob_host.local_s.public_key_bytes(),
            &[],
            &psk,
            1,
            1280,
            1,
            jedi,
        )
        .unwrap();
        let _ = alice_host.session.lock().insert(Arc::new(alice_session));

        let mut ts = 0;
        for _ in 0..256 {
            for host in [&alice_host, &bob_host] {
                let is_alice = std::ptr::eq(host, &alice_host);

                // Anything this host sends lands in the other host's queue.
                let send_to_other = |data: &mut [u8]| {
                    let peer = if is_alice { &bob_host } else { &alice_host };
                    peer.queue.lock().push_front(data.to_vec());
                };

                let rc = if is_alice { &alice_rc } else { &bob_rc };

                // Drain this host's queue, oldest packet first.
                while let Some(qi) = host.queue.lock().pop_back() {
                    let qi_len = qi.len();
                    ts += 1;
                    match rc.receive(host, send_to_other, &mut data_buf, qi, 1280, jedi, ts) {
                        Ok(ReceiveResult::Ok) => {
                            println!("zssp: {} => {} ({}): Ok", host.other_name, host.this_name, qi_len);
                        }
                        Ok(ReceiveResult::OkData(data)) => {
                            println!("zssp: {} => {} ({}): OkData length=={}", host.other_name, host.this_name, qi_len, data.len());
                        }
                        Ok(ReceiveResult::OkNewSession(new_session)) => {
                            println!("zssp: {} => {} ({}): OkNewSession ({})", host.other_name, host.this_name, qi_len, u64::from(new_session.id));
                            let mut hs = host.session.lock();
                            assert!(hs.is_none());
                            let _ = hs.insert(Arc::new(new_session));
                        }
                        Ok(ReceiveResult::Ignored) => {
                            println!("zssp: {} => {} ({}): Ignored", host.other_name, host.this_name, qi_len);
                        }
                        Err(e) => {
                            println!("zssp: {} => {} ({}): error: {}", host.other_name, host.this_name, qi_len, e.to_string());
                            panic!();
                        }
                    }
                }
            }
        }
    }
}
|