Some more renaming to make code more readable.

This commit is contained in:
Adam Ierymenko 2022-12-16 10:13:16 -05:00
parent 45bf978dcd
commit 826f0d3ab5
2 changed files with 83 additions and 78 deletions

View file

@ -52,10 +52,10 @@ pub(crate) const REKEY_AFTER_TIME_MS_MAX_JITTER: u32 = 1000 * 60 * 10; // 10 min
pub(crate) const SESSION_PROTOCOL_VERSION: u8 = 0x00; pub(crate) const SESSION_PROTOCOL_VERSION: u8 = 0x00;
/// Secondary key type: none, use only P-384 for forward secrecy. /// Secondary key type: none, use only P-384 for forward secrecy.
pub(crate) const E1_TYPE_NONE: u8 = 0; pub(crate) const HYBRID_KEY_TYPE_NONE: u8 = 0;
/// Secondary key type: Kyber1024, PQ forward secrecy enabled. /// Secondary key type: Kyber1024, PQ forward secrecy enabled.
pub(crate) const E1_TYPE_KYBER1024: u8 = 1; pub(crate) const HYBRID_KEY_TYPE_KYBER1024: u8 = 1;
/// Size of packet header /// Size of packet header
pub(crate) const HEADER_SIZE: usize = 16; pub(crate) const HEADER_SIZE: usize = 16;

View file

@ -35,7 +35,7 @@ pub enum Error {
/// Packet failed one or more authentication (MAC) checks /// Packet failed one or more authentication (MAC) checks
FailedAuthentication, FailedAuthentication,
/// New session was rejected via Host::check_new_session_attempt or Host::accept_new_session. /// New session was rejected by the application layer.
NewSessionRejected, NewSessionRejected,
/// Rekeying failed and session secret has reached its hard usage count limit /// Rekeying failed and session secret has reached its hard usage count limit
@ -56,7 +56,10 @@ pub enum Error {
/// Data object is too large to send, even with fragmentation /// Data object is too large to send, even with fragmentation
DataTooLarge, DataTooLarge,
/// An unexpected buffer overrun occured while attempting to encode or decode a packet, this can only ever happen if exceptionally large key blobs or metadata are being used, or as the result of an internal encoding bug. /// An unexpected buffer overrun occured while attempting to encode or decode a packet.
///
/// This can only ever happen if exceptionally large key blobs or metadata are being used,
/// or as the result of an internal encoding bug.
UnexpectedBufferOverrun, UnexpectedBufferOverrun,
} }
@ -99,12 +102,12 @@ pub struct ReceiveContext<H: ApplicationLayer> {
} }
/// ZSSP bi-directional packet transport channel. /// ZSSP bi-directional packet transport channel.
pub struct Session<Layer: ApplicationLayer> { pub struct Session<Application: ApplicationLayer> {
/// This side's session ID (unique on this side) /// This side's session ID (unique on this side)
pub id: SessionId, pub id: SessionId,
/// An arbitrary object associated with session (type defined in Host trait) /// An arbitrary object associated with session (type defined in Host trait)
pub user_data: Layer::SessionUserData, pub user_data: Application::SessionUserData,
send_counter: Counter, // Outgoing packet counter and nonce state send_counter: Counter, // Outgoing packet counter and nonce state
psk: Secret<64>, // Arbitrary PSK provided by external code psk: Secret<64>, // Arbitrary PSK provided by external code
@ -114,7 +117,7 @@ pub struct Session<Layer: ApplicationLayer> {
remote_s_public_blob_hash: [u8; 48], // SHA384(remote static public key blob) remote_s_public_blob_hash: [u8; 48], // SHA384(remote static public key blob)
remote_s_public_raw: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key remote_s_public_raw: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key
defrag: Mutex<RingBufferMap<u32, GatherArray<Layer::IncomingPacketBuffer, MAX_FRAGMENTS>, 8, 8>>, defrag: Mutex<RingBufferMap<u32, GatherArray<Application::IncomingPacketBuffer, MAX_FRAGMENTS>, 8, 8>>,
} }
struct SessionMutableState { struct SessionMutableState {
@ -208,7 +211,6 @@ fn safe_write_all(buffer: &mut [u8], idx: usize, src: &[u8]) -> Result<usize, Er
} }
} }
/// Write a variable length integer, which can consume up to 10 bytes. Uses safe_write_all to do so. /// Write a variable length integer, which can consume up to 10 bytes. Uses safe_write_all to do so.
#[inline(always)]
fn varint_safe_write(buffer: &mut [u8], idx: usize, v: u64) -> Result<usize, Error> { fn varint_safe_write(buffer: &mut [u8], idx: usize, v: u64) -> Result<usize, Error> {
let mut b = [0_u8; varint::VARINT_MAX_SIZE_BYTES]; let mut b = [0_u8; varint::VARINT_MAX_SIZE_BYTES];
let i = varint::encode(&mut b, v); let i = varint::encode(&mut b, v);
@ -227,7 +229,6 @@ fn safe_read_exact<'a>(src: &mut &'a [u8], amt: usize) -> Result<&'a [u8], Error
} }
} }
/// Read a variable length integer, which can consume up to 10 bytes. Uses varint_safe_read to do so. /// Read a variable length integer, which can consume up to 10 bytes. Uses varint_safe_read to do so.
#[inline(always)]
fn varint_safe_read(src: &mut &[u8]) -> Result<u64, Error> { fn varint_safe_read(src: &mut &[u8]) -> Result<u64, Error> {
let (v, amt) = varint::decode(*src).ok_or(Error::InvalidPacket)?; let (v, amt) = varint::decode(*src).ok_or(Error::InvalidPacket)?;
let (_, b) = src.split_at(amt); let (_, b) = src.split_at(amt);
@ -235,7 +236,7 @@ fn varint_safe_read(src: &mut &[u8]) -> Result<u64, Error> {
Ok(v) Ok(v)
} }
impl<Layer: ApplicationLayer> Session<Layer> { impl<Application: ApplicationLayer> Session<Application> {
/// Create a new session and send an initial key offer message to the other end. /// Create a new session and send an initial key offer message to the other end.
/// ///
/// * `host` - Interface to application using ZSSP /// * `host` - Interface to application using ZSSP
@ -247,19 +248,19 @@ impl<Layer: ApplicationLayer> Session<Layer> {
/// * `mtu` - Physical wire maximum transmition unit /// * `mtu` - Physical wire maximum transmition unit
/// * `current_time` - Current monotonic time in milliseconds /// * `current_time` - Current monotonic time in milliseconds
pub fn start_new<SendFunction: FnMut(&mut [u8])>( pub fn start_new<SendFunction: FnMut(&mut [u8])>(
host: &Layer, app: &Application,
mut send: SendFunction, mut send: SendFunction,
local_session_id: SessionId, local_session_id: SessionId,
remote_s_public_blob: &[u8], remote_s_public_blob: &[u8],
offer_metadata: &[u8], offer_metadata: &[u8],
psk: &Secret<64>, psk: &Secret<64>,
user_data: Layer::SessionUserData, user_data: Application::SessionUserData,
mtu: usize, mtu: usize,
current_time: i64, current_time: i64,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
let bob_s_public_blob = remote_s_public_blob; let bob_s_public_blob = remote_s_public_blob;
if let Some(bob_s_public) = Layer::extract_s_public_from_raw(bob_s_public_blob) { if let Some(bob_s_public) = Application::extract_s_public_from_raw(bob_s_public_blob) {
if let Some(noise_ss) = host.get_local_s_keypair().agree(&bob_s_public) { if let Some(noise_ss) = app.get_local_s_keypair().agree(&bob_s_public) {
let send_counter = Counter::new(); let send_counter = Counter::new();
let bob_s_public_blob_hash = SHA384::hash(bob_s_public_blob); let bob_s_public_blob_hash = SHA384::hash(bob_s_public_blob);
let header_check_cipher = let header_check_cipher =
@ -270,7 +271,7 @@ impl<Layer: ApplicationLayer> Session<Layer> {
send_counter.next(), send_counter.next(),
local_session_id, local_session_id,
None, None,
host.get_local_s_public_blob(), app.get_local_s_public_blob(),
offer_metadata, offer_metadata,
&bob_s_public, &bob_s_public,
&bob_s_public_blob_hash, &bob_s_public_blob_hash,
@ -310,16 +311,16 @@ impl<Layer: ApplicationLayer> Session<Layer> {
/// Send data over the session. /// Send data over the session.
/// ///
/// * `send` - Function to call to send physical packet(s) /// * `send` - Function to call to send physical packet(s)
/// * `mtu_buffer` - A writable work buffer whose size also specifies the physical MTU /// * `mtu_sized_buffer` - A writable work buffer whose size also specifies the physical MTU
/// * `data` - Data to send /// * `data` - Data to send
#[inline] #[inline]
pub fn send<SendFunction: FnMut(&mut [u8])>( pub fn send<SendFunction: FnMut(&mut [u8])>(
&self, &self,
mut send: SendFunction, mut send: SendFunction,
mtu_buffer: &mut [u8], mtu_sized_buffer: &mut [u8],
mut data: &[u8], mut data: &[u8],
) -> Result<(), Error> { ) -> Result<(), Error> {
debug_assert!(mtu_buffer.len() >= MIN_TRANSPORT_MTU); debug_assert!(mtu_sized_buffer.len() >= MIN_TRANSPORT_MTU);
let state = self.state.read().unwrap(); let state = self.state.read().unwrap();
if let Some(remote_session_id) = state.remote_session_id { if let Some(remote_session_id) = state.remote_session_id {
if let Some(session_key) = state.session_keys[state.cur_session_key_idx].as_ref() { if let Some(session_key) = state.session_keys[state.cur_session_key_idx].as_ref() {
@ -335,9 +336,9 @@ impl<Layer: ApplicationLayer> Session<Layer> {
// Create initial header for first fragment of packet and place in first HEADER_SIZE bytes of buffer. // Create initial header for first fragment of packet and place in first HEADER_SIZE bytes of buffer.
create_packet_header( create_packet_header(
mtu_buffer, mtu_sized_buffer,
packet_len, packet_len,
mtu_buffer.len(), mtu_sized_buffer.len(),
PACKET_TYPE_DATA, PACKET_TYPE_DATA,
remote_session_id.into(), remote_session_id.into(),
counter, counter,
@ -350,21 +351,21 @@ impl<Layer: ApplicationLayer> Session<Layer> {
// Send first N-1 fragments of N total fragments. // Send first N-1 fragments of N total fragments.
let last_fragment_size; let last_fragment_size;
if packet_len > mtu_buffer.len() { if packet_len > mtu_sized_buffer.len() {
let mut header: [u8; 16] = mtu_buffer[..HEADER_SIZE].try_into().unwrap(); let mut header: [u8; 16] = mtu_sized_buffer[..HEADER_SIZE].try_into().unwrap();
let fragment_data_mtu = mtu_buffer.len() - HEADER_SIZE; let fragment_data_mtu = mtu_sized_buffer.len() - HEADER_SIZE;
let last_fragment_data_mtu = mtu_buffer.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE); let last_fragment_data_mtu = mtu_sized_buffer.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
loop { loop {
let fragment_data_size = fragment_data_mtu.min(data.len()); let fragment_data_size = fragment_data_mtu.min(data.len());
let fragment_size = fragment_data_size + HEADER_SIZE; let fragment_size = fragment_data_size + HEADER_SIZE;
c.crypt(&data[..fragment_data_size], &mut mtu_buffer[HEADER_SIZE..fragment_size]); c.crypt(&data[..fragment_data_size], &mut mtu_sized_buffer[HEADER_SIZE..fragment_size]);
data = &data[fragment_data_size..]; data = &data[fragment_data_size..];
set_header_check_code(mtu_buffer, &self.header_check_cipher); set_header_check_code(mtu_sized_buffer, &self.header_check_cipher);
send(&mut mtu_buffer[..fragment_size]); send(&mut mtu_sized_buffer[..fragment_size]);
debug_assert!(header[15].wrapping_shr(2) < 63); debug_assert!(header[15].wrapping_shr(2) < 63);
header[15] += 0x04; // increment fragment number header[15] += 0x04; // increment fragment number
mtu_buffer[..HEADER_SIZE].copy_from_slice(&header); mtu_sized_buffer[..HEADER_SIZE].copy_from_slice(&header);
if data.len() <= last_fragment_data_mtu { if data.len() <= last_fragment_data_mtu {
break; break;
@ -377,11 +378,11 @@ impl<Layer: ApplicationLayer> Session<Layer> {
// Send final fragment (or only fragment if no fragmentation was needed) // Send final fragment (or only fragment if no fragmentation was needed)
let payload_end = data.len() + HEADER_SIZE; let payload_end = data.len() + HEADER_SIZE;
c.crypt(data, &mut mtu_buffer[HEADER_SIZE..payload_end]); c.crypt(data, &mut mtu_sized_buffer[HEADER_SIZE..payload_end]);
let gcm_tag = c.finish_encrypt(); let gcm_tag = c.finish_encrypt();
mtu_buffer[payload_end..last_fragment_size].copy_from_slice(&gcm_tag); mtu_sized_buffer[payload_end..last_fragment_size].copy_from_slice(&gcm_tag);
set_header_check_code(mtu_buffer, &self.header_check_cipher); set_header_check_code(mtu_sized_buffer, &self.header_check_cipher);
send(&mut mtu_buffer[..last_fragment_size]); send(&mut mtu_sized_buffer[..last_fragment_size]);
// Check reusable AES-GCM instance back into pool. // Check reusable AES-GCM instance back into pool.
session_key.return_send_cipher(c); session_key.return_send_cipher(c);
@ -425,7 +426,7 @@ impl<Layer: ApplicationLayer> Session<Layer> {
/// * `force_rekey` - Re-key the session now regardless of key aging (still subject to rate limiting) /// * `force_rekey` - Re-key the session now regardless of key aging (still subject to rate limiting)
pub fn service<SendFunction: FnMut(&mut [u8])>( pub fn service<SendFunction: FnMut(&mut [u8])>(
&self, &self,
host: &Layer, host: &Application,
mut send: SendFunction, mut send: SendFunction,
offer_metadata: &[u8], offer_metadata: &[u8],
mtu: usize, mtu: usize,
@ -440,7 +441,7 @@ impl<Layer: ApplicationLayer> Session<Layer> {
&& state && state
.offer .offer
.as_ref() .as_ref()
.map_or(true, |o| (current_time - o.creation_time) > Layer::REKEY_RATE_LIMIT_MS) .map_or(true, |o| (current_time - o.creation_time) > Application::REKEY_RATE_LIMIT_MS)
{ {
if let Some(remote_s_public) = P384PublicKey::from_bytes(&self.remote_s_public_raw) { if let Some(remote_s_public) = P384PublicKey::from_bytes(&self.remote_s_public_raw) {
let mut offer = None; let mut offer = None;
@ -474,8 +475,8 @@ impl<Layer: ApplicationLayer> Session<Layer> {
} }
} }
impl<Layer: ApplicationLayer> ReceiveContext<Layer> { impl<Application: ApplicationLayer> ReceiveContext<Application> {
pub fn new(host: &Layer) -> Self { pub fn new(host: &Application) -> Self {
Self { Self {
initial_offer_defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), initial_offer_defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)),
incoming_init_header_check_cipher: Aes::new( incoming_init_header_check_cipher: Aes::new(
@ -495,14 +496,14 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
#[inline] #[inline]
pub fn receive<'a, SendFunction: FnMut(&mut [u8])>( pub fn receive<'a, SendFunction: FnMut(&mut [u8])>(
&self, &self,
host: &Layer, app: &Application,
remote_address: &Layer::RemoteAddress, remote_address: &Application::RemoteAddress,
mut send: SendFunction, mut send: SendFunction,
data_buf: &'a mut [u8], data_buf: &'a mut [u8],
incoming_packet_buf: Layer::IncomingPacketBuffer, incoming_packet_buf: Application::IncomingPacketBuffer,
mtu: usize, mtu: usize,
current_time: i64, current_time: i64,
) -> Result<ReceiveResult<'a, Layer>, Error> { ) -> Result<ReceiveResult<'a, Application>, Error> {
let incoming_packet = incoming_packet_buf.as_ref(); let incoming_packet = incoming_packet_buf.as_ref();
if incoming_packet.len() < MIN_PACKET_SIZE { if incoming_packet.len() < MIN_PACKET_SIZE {
unlikely_branch(); unlikely_branch();
@ -517,7 +518,7 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
if let Some(local_session_id) = SessionId::new_from_u64(u64::from_le(memory::load_raw(&incoming_packet[8..16])) & 0xffffffffffffu64) if let Some(local_session_id) = SessionId::new_from_u64(u64::from_le(memory::load_raw(&incoming_packet[8..16])) & 0xffffffffffffu64)
{ {
if let Some(session) = host.lookup_session(local_session_id) { if let Some(session) = app.lookup_session(local_session_id) {
if verify_header_check_code(incoming_packet, &session.header_check_cipher) { if verify_header_check_code(incoming_packet, &session.header_check_cipher) {
let canonical_header = CanonicalHeader::make(local_session_id, packet_type, counter); let canonical_header = CanonicalHeader::make(local_session_id, packet_type, counter);
if fragment_count > 1 { if fragment_count > 1 {
@ -527,7 +528,7 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
drop(defrag); // release lock drop(defrag); // release lock
return self.receive_complete( return self.receive_complete(
host, app,
remote_address, remote_address,
&mut send, &mut send,
data_buf, data_buf,
@ -545,7 +546,7 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
} }
} else { } else {
return self.receive_complete( return self.receive_complete(
host, app,
remote_address, remote_address,
&mut send, &mut send,
data_buf, data_buf,
@ -576,7 +577,7 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
drop(defrag); // release lock drop(defrag); // release lock
return self.receive_complete( return self.receive_complete(
host, app,
remote_address, remote_address,
&mut send, &mut send,
data_buf, data_buf,
@ -590,7 +591,7 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
} }
} else { } else {
return self.receive_complete( return self.receive_complete(
host, app,
remote_address, remote_address,
&mut send, &mut send,
data_buf, data_buf,
@ -618,17 +619,17 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
#[inline] #[inline]
fn receive_complete<'a, SendFunction: FnMut(&mut [u8])>( fn receive_complete<'a, SendFunction: FnMut(&mut [u8])>(
&self, &self,
host: &Layer, app: &Application,
remote_address: &Layer::RemoteAddress, remote_address: &Application::RemoteAddress,
send: &mut SendFunction, send: &mut SendFunction,
data_buf: &'a mut [u8], data_buf: &'a mut [u8],
canonical_header_bytes: &[u8; 12], canonical_header_bytes: &[u8; AES_GCM_TAG_SIZE],
fragments: &[Layer::IncomingPacketBuffer], fragments: &[Application::IncomingPacketBuffer],
packet_type: u8, packet_type: u8,
session: Option<Layer::SessionRef>, session: Option<Application::SessionRef>,
mtu: usize, mtu: usize,
current_time: i64, current_time: i64,
) -> Result<ReceiveResult<'a, Layer>, Error> { ) -> Result<ReceiveResult<'a, Application>, Error> {
debug_assert!(fragments.len() >= 1); debug_assert!(fragments.len() >= 1);
// The first 'if' below should capture both DATA and NOP but not other types. Sanity check this. // The first 'if' below should capture both DATA and NOP but not other types. Sanity check this.
@ -769,7 +770,7 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
// Check the second HMAC first, which proves that the sender knows the recipient's full static identity. // Check the second HMAC first, which proves that the sender knows the recipient's full static identity.
let hmac2 = &kex_packet[hmac1_end..kex_packet_len]; let hmac2 = &kex_packet[hmac1_end..kex_packet_len];
if !hmac_sha384_2( if !hmac_sha384_2(
host.get_local_s_public_blob_hash(), app.get_local_s_public_blob_hash(),
canonical_header_bytes, canonical_header_bytes,
&kex_packet[HEADER_SIZE..hmac1_end], &kex_packet[HEADER_SIZE..hmac1_end],
) )
@ -780,11 +781,11 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
// Check rate limits. // Check rate limits.
if let Some(session) = session.as_ref() { if let Some(session) = session.as_ref() {
if (current_time - session.state.read().unwrap().last_remote_offer) < Layer::REKEY_RATE_LIMIT_MS { if (current_time - session.state.read().unwrap().last_remote_offer) < Application::REKEY_RATE_LIMIT_MS {
return Err(Error::RateLimited); return Err(Error::RateLimited);
} }
} else { } else {
if !host.check_new_session(self, remote_address) { if !app.check_new_session(self, remote_address) {
return Err(Error::RateLimited); return Err(Error::RateLimited);
} }
} }
@ -792,7 +793,7 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
// Key agreement: alice (remote) ephemeral NIST P-384 <> local static NIST P-384 // Key agreement: alice (remote) ephemeral NIST P-384 <> local static NIST P-384
let alice_e_public = let alice_e_public =
P384PublicKey::from_bytes(&kex_packet[(HEADER_SIZE + 1)..plaintext_end]).ok_or(Error::FailedAuthentication)?; P384PublicKey::from_bytes(&kex_packet[(HEADER_SIZE + 1)..plaintext_end]).ok_or(Error::FailedAuthentication)?;
let noise_es = host let noise_es = app
.get_local_s_keypair() .get_local_s_keypair()
.agree(&alice_e_public) .agree(&alice_e_public)
.ok_or(Error::FailedAuthentication)?; .ok_or(Error::FailedAuthentication)?;
@ -832,10 +833,10 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
} }
// Extract alice's static NIST P-384 public key from her public blob. // Extract alice's static NIST P-384 public key from her public blob.
let alice_s_public = Layer::extract_s_public_from_raw(alice_s_public_blob).ok_or(Error::InvalidPacket)?; let alice_s_public = Application::extract_s_public_from_raw(alice_s_public_blob).ok_or(Error::InvalidPacket)?;
// Key agreement: both sides' static P-384 keys. // Key agreement: both sides' static P-384 keys.
let noise_ss = host let noise_ss = app
.get_local_s_keypair() .get_local_s_keypair()
.agree(&alice_s_public) .agree(&alice_s_public)
.ok_or(Error::FailedAuthentication)?; .ok_or(Error::FailedAuthentication)?;
@ -874,7 +875,7 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
let state = session.state.read().unwrap(); let state = session.state.read().unwrap();
for k in state.session_keys.iter() { for k in state.session_keys.iter() {
if let Some(k) = k.as_ref() { if let Some(k) = k.as_ref() {
if secret_fingerprint(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_fingerprint) { if public_fingerprint_of_secret(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_fingerprint) {
ratchet_key = Some(k.ratchet_key.clone()); ratchet_key = Some(k.ratchet_key.clone());
ratchet_count = k.ratchet_count; ratchet_count = k.ratchet_count;
break; break;
@ -888,13 +889,13 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
(None, ratchet_key, ratchet_count) (None, ratchet_key, ratchet_count)
} else { } else {
if let Some((new_session_id, psk, associated_object)) = if let Some((new_session_id, psk, associated_object)) =
host.accept_new_session(self, remote_address, alice_s_public_blob, alice_metadata) app.accept_new_session(self, remote_address, alice_s_public_blob, alice_metadata)
{ {
let header_check_cipher = Aes::new( let header_check_cipher = Aes::new(
kbkdf512(noise_ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<HEADER_CHECK_AES_KEY_SIZE>(), kbkdf512(noise_ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<HEADER_CHECK_AES_KEY_SIZE>(),
); );
( (
Some(Session::<Layer> { Some(Session::<Application> {
id: new_session_id, id: new_session_id,
user_data: associated_object, user_data: associated_object,
send_counter: Counter::new(), send_counter: Counter::new(),
@ -983,10 +984,10 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
idx = varint_safe_write(&mut reply_buf, idx, 0)?; // they don't need our static public; they have it idx = varint_safe_write(&mut reply_buf, idx, 0)?; // they don't need our static public; they have it
idx = varint_safe_write(&mut reply_buf, idx, 0)?; // no meta-data in counter-offers (could be used in the future) idx = varint_safe_write(&mut reply_buf, idx, 0)?; // no meta-data in counter-offers (could be used in the future)
if let Some(bob_hk_public) = bob_hk_public.as_ref() { if let Some(bob_hk_public) = bob_hk_public.as_ref() {
idx = safe_write_all(&mut reply_buf, idx, &[E1_TYPE_KYBER1024])?; idx = safe_write_all(&mut reply_buf, idx, &[HYBRID_KEY_TYPE_KYBER1024])?;
idx = safe_write_all(&mut reply_buf, idx, bob_hk_public)?; idx = safe_write_all(&mut reply_buf, idx, bob_hk_public)?;
} else { } else {
idx = safe_write_all(&mut reply_buf, idx, &[E1_TYPE_NONE])?; idx = safe_write_all(&mut reply_buf, idx, &[HYBRID_KEY_TYPE_NONE])?;
} }
if ratchet_key.is_some() { if ratchet_key.is_some() {
idx = safe_write_all(&mut reply_buf, idx, &[0x01])?; idx = safe_write_all(&mut reply_buf, idx, &[0x01])?;
@ -1086,7 +1087,7 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
let bob_e_public = P384PublicKey::from_bytes(&kex_packet[(HEADER_SIZE + 1)..plaintext_end]) let bob_e_public = P384PublicKey::from_bytes(&kex_packet[(HEADER_SIZE + 1)..plaintext_end])
.ok_or(Error::FailedAuthentication)?; .ok_or(Error::FailedAuthentication)?;
let noise_ee = offer.alice_e_keypair.agree(&bob_e_public).ok_or(Error::FailedAuthentication)?; let noise_ee = offer.alice_e_keypair.agree(&bob_e_public).ok_or(Error::FailedAuthentication)?;
let noise_se = host.get_local_s_keypair().agree(&bob_e_public).ok_or(Error::FailedAuthentication)?; let noise_se = app.get_local_s_keypair().agree(&bob_e_public).ok_or(Error::FailedAuthentication)?;
let noise_ik_key = Secret(hmac_sha512( let noise_ik_key = Secret(hmac_sha512(
session.psk.as_bytes(), session.psk.as_bytes(),
@ -1229,7 +1230,7 @@ fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
// Perform key agreement with the other side's static P-384 public key. // Perform key agreement with the other side's static P-384 public key.
let noise_es = alice_e_keypair.agree(bob_s_public).ok_or(Error::InvalidPacket)?; let noise_es = alice_e_keypair.agree(bob_s_public).ok_or(Error::InvalidPacket)?;
// Generate a Kyber1024 pair if enabled. // Generate a Kyber1024 (hybrid PQ crypto) pair if enabled.
let alice_hk_keypair = if JEDI { let alice_hk_keypair = if JEDI {
Some(pqc_kyber::keypair(&mut random::SecureRandom::get())) Some(pqc_kyber::keypair(&mut random::SecureRandom::get()))
} else { } else {
@ -1268,14 +1269,14 @@ fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
idx = varint_safe_write(&mut packet_buf, idx, alice_metadata.len() as u64)?; idx = varint_safe_write(&mut packet_buf, idx, alice_metadata.len() as u64)?;
idx = safe_write_all(&mut packet_buf, idx, alice_metadata)?; idx = safe_write_all(&mut packet_buf, idx, alice_metadata)?;
if let Some(hkp) = alice_hk_keypair { if let Some(hkp) = alice_hk_keypair {
idx = safe_write_all(&mut packet_buf, idx, &[E1_TYPE_KYBER1024])?; idx = safe_write_all(&mut packet_buf, idx, &[HYBRID_KEY_TYPE_KYBER1024])?;
idx = safe_write_all(&mut packet_buf, idx, &hkp.public)?; idx = safe_write_all(&mut packet_buf, idx, &hkp.public)?;
} else { } else {
idx = safe_write_all(&mut packet_buf, idx, &[E1_TYPE_NONE])?; idx = safe_write_all(&mut packet_buf, idx, &[HYBRID_KEY_TYPE_NONE])?;
} }
if let Some(ratchet_key) = ratchet_key.as_ref() { if let Some(ratchet_key) = ratchet_key.as_ref() {
idx = safe_write_all(&mut packet_buf, idx, &[0x01])?; idx = safe_write_all(&mut packet_buf, idx, &[0x01])?;
idx = safe_write_all(&mut packet_buf, idx, &secret_fingerprint(ratchet_key.as_bytes())[..16])?; idx = safe_write_all(&mut packet_buf, idx, &public_fingerprint_of_secret(ratchet_key.as_bytes())[..16])?;
} else { } else {
idx = safe_write_all(&mut packet_buf, idx, &[0x00])?; idx = safe_write_all(&mut packet_buf, idx, &[0x00])?;
} }
@ -1362,7 +1363,7 @@ fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
/// Populate all but the header check code in the first 16 bytes of a packet or fragment. /// Populate all but the header check code in the first 16 bytes of a packet or fragment.
#[inline(always)] #[inline(always)]
fn create_packet_header( fn create_packet_header(
header: &mut [u8], header_destination_buffer: &mut [u8],
packet_len: usize, packet_len: usize,
mtu: usize, mtu: usize,
packet_type: u8, packet_type: u8,
@ -1371,7 +1372,7 @@ fn create_packet_header(
) -> Result<(), Error> { ) -> Result<(), Error> {
let fragment_count = ((packet_len as f32) / (mtu - HEADER_SIZE) as f32).ceil() as usize; let fragment_count = ((packet_len as f32) / (mtu - HEADER_SIZE) as f32).ceil() as usize;
debug_assert!(header.len() >= HEADER_SIZE); debug_assert!(header_destination_buffer.len() >= HEADER_SIZE);
debug_assert!(mtu >= MIN_TRANSPORT_MTU); debug_assert!(mtu >= MIN_TRANSPORT_MTU);
debug_assert!(packet_len >= MIN_PACKET_SIZE); debug_assert!(packet_len >= MIN_PACKET_SIZE);
debug_assert!(fragment_count > 0); debug_assert!(fragment_count > 0);
@ -1386,11 +1387,11 @@ fn create_packet_header(
// [112-115] packet type (0-15) // [112-115] packet type (0-15)
// [116-121] number of fragments (0..63 for 1..64 fragments total) // [116-121] number of fragments (0..63 for 1..64 fragments total)
// [122-127] fragment number (0, 1, 2, ...) // [122-127] fragment number (0, 1, 2, ...)
memory::store_raw((counter.to_u32() as u64).to_le(), header); memory::store_raw((counter.to_u32() as u64).to_le(), header_destination_buffer);
memory::store_raw( memory::store_raw(
(u64::from(recipient_session_id) | (packet_type as u64).wrapping_shl(48) | ((fragment_count - 1) as u64).wrapping_shl(52)) (u64::from(recipient_session_id) | (packet_type as u64).wrapping_shl(48) | ((fragment_count - 1) as u64).wrapping_shl(52))
.to_le(), .to_le(),
&mut header[8..], &mut header_destination_buffer[8..],
); );
Ok(()) Ok(())
} else { } else {
@ -1464,7 +1465,7 @@ fn parse_dec_key_offer_after_header(
let alice_metadata = safe_read_exact(&mut p, alice_metadata_len as usize)?; let alice_metadata = safe_read_exact(&mut p, alice_metadata_len as usize)?;
let alice_hk_public_raw = match safe_read_exact(&mut p, 1)?[0] { let alice_hk_public_raw = match safe_read_exact(&mut p, 1)?[0] {
E1_TYPE_KYBER1024 => { HYBRID_KEY_TYPE_KYBER1024 => {
if packet_type == PACKET_TYPE_INITIAL_KEY_OFFER { if packet_type == PACKET_TYPE_INITIAL_KEY_OFFER {
safe_read_exact(&mut p, pqc_kyber::KYBER_PUBLICKEYBYTES)? safe_read_exact(&mut p, pqc_kyber::KYBER_PUBLICKEYBYTES)?
} else { } else {
@ -1473,6 +1474,7 @@ fn parse_dec_key_offer_after_header(
} }
_ => &[], _ => &[],
}; };
if p.is_empty() { if p.is_empty() {
return Err(Error::InvalidPacket); return Err(Error::InvalidPacket);
} }
@ -1481,6 +1483,7 @@ fn parse_dec_key_offer_after_header(
} else { } else {
None None
}; };
Ok(( Ok((
offer_id, //always 16 bytes offer_id, //always 16 bytes
alice_session_id, alice_session_id,
@ -1501,7 +1504,7 @@ impl SessionKey {
Role::Bob => (a2b, b2a), Role::Bob => (a2b, b2a),
}; };
Self { Self {
secret_fingerprint: secret_fingerprint(key.as_bytes())[..16].try_into().unwrap(), secret_fingerprint: public_fingerprint_of_secret(key.as_bytes())[..16].try_into().unwrap(),
establish_time: current_time, establish_time: current_time,
establish_counter: current_counter, establish_counter: current_counter,
lifetime: KeyLifetime::new(current_counter, current_time), lifetime: KeyLifetime::new(current_counter, current_time),
@ -1599,16 +1602,18 @@ fn hmac_sha384_2(key: &[u8], a: &[u8], b: &[u8]) -> [u8; 48] {
hmac.finish() hmac.finish()
} }
/// HMAC-SHA512 key derivation function modeled on: https://csrc.nist.gov/publications/detail/sp/800-108/final (page 12) /// HMAC-SHA512 key derivation based on: https://csrc.nist.gov/publications/detail/sp/800-108/final (page 12)
/// Cryptographically this isn't really different from HMAC(key, [label]) with just one byte. ///
/// Cryptographically this isn't meaningfully different from HMAC(key, [label]),
/// but NIST seems to like it this way.
fn kbkdf512(key: &[u8], label: u8) -> Secret<64> { fn kbkdf512(key: &[u8], label: u8) -> Secret<64> {
Secret(hmac_sha512(key, &[0, 0, 0, 0, b'Z', b'T', label, 0, 0, 0, 0, 0x02, 0x00])) Secret(hmac_sha512(key, &[0, 0, 0, 0, b'Z', b'T', label, 0, 0, 0, 0, 0x02, 0x00]))
} }
/// Get a hash of a secret key that can be used as a public fingerprint. /// Get a hash of a secret that can be used as a public key fingerprint to check ratcheting during key exchange.
fn secret_fingerprint(key: &[u8]) -> [u8; 48] { fn public_fingerprint_of_secret(key: &[u8]) -> [u8; 48] {
let mut tmp = SHA384::new(); let mut tmp = SHA384::new();
tmp.update("fp".as_bytes()); tmp.update(&[0xf0, 0x0d]); // arbitrary salt
tmp.update(key); tmp.update(key);
tmp.finish() tmp.finish()
} }