Use 0xffffffffffff as NIL session ID because magic backward compatibility check for ZT protocol.

This commit is contained in:
Adam Ierymenko 2022-09-20 12:09:54 -04:00
parent 30d3f6e176
commit 7724092551
No known key found for this signature in database
GPG key ID: C8877CF2D7A5D7F3

View file

@ -4,7 +4,6 @@
// FIPS compliant Noise_IK with Jedi powers and built-in attack-resistant large payload (fragmentation) support.
use std::io::{Read, Write};
use std::num::NonZeroU64;
use std::ops::Deref;
use std::sync::atomic::{AtomicU64, Ordering};
@ -17,6 +16,7 @@ use crate::secret::Secret;
use zerotier_utils::gatherarray::GatherArray;
use zerotier_utils::memory;
use zerotier_utils::ringbuffermap::RingBufferMap;
use zerotier_utils::unlikely_branch;
use zerotier_utils::varint;
use parking_lot::{Mutex, RwLock, RwLockUpgradableReadGuard};
@ -165,14 +165,10 @@ impl From<std::io::Error> for Error {
}
}
#[cold]
#[inline(never)]
extern "C" fn unlikely_branch() {}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::UnknownLocalSessionId(id) => f.write_str(format!("UnknownLocalSessionId({})", id.0.get()).as_str()),
Self::UnknownLocalSessionId(id) => f.write_str(format!("UnknownLocalSessionId({})", id.0).as_str()),
Self::InvalidPacket => f.write_str("InvalidPacket"),
Self::InvalidParameter => f.write_str("InvalidParameter"),
Self::FailedAuthentication => f.write_str("FailedAuthentication"),
@ -217,28 +213,37 @@ pub enum ReceiveResult<'a, H: Host> {
/// 48-bit session ID (most significant 24 bits of u64 are unused)
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[repr(transparent)]
pub struct SessionId(NonZeroU64);
pub struct SessionId(u64);
impl SessionId {
pub const MAX_BIT_MASK: u64 = 0xffffffffffff;
pub const NIL: SessionId = SessionId(0xffffffffffff);
#[inline(always)]
#[inline]
pub fn new_from_u64(i: u64) -> Option<SessionId> {
debug_assert!(i <= Self::MAX_BIT_MASK);
NonZeroU64::new(i).map(|i| Self(i))
if i < Self::NIL.0 {
Some(Self(i))
} else {
None
}
}
#[inline]
pub fn new_from_reader<R: Read>(r: &mut R) -> std::io::Result<Option<SessionId>> {
let mut tmp = [0_u8; 8];
r.read_exact(&mut tmp[..SESSION_ID_SIZE])
.map(|_| NonZeroU64::new(u64::from_le_bytes(tmp)).map(|i| Self(i)))
r.read_exact(&mut tmp[..SESSION_ID_SIZE])?;
Ok(Self::new_from_u64(u64::from_le_bytes(tmp)))
}
#[inline]
pub fn new_random() -> Self {
Self(random::xorshift64_random() % (Self::NIL.0 - 1))
}
}
impl From<SessionId> for u64 {
#[inline(always)]
fn from(sid: SessionId) -> Self {
sid.0.get()
sid.0
}
}
@ -578,8 +583,7 @@ impl<H: Host> ReceiveContext<H> {
let fragment_count = ((packet_type_fragment_info.wrapping_shr(4) + 1) as u8) & 63;
let fragment_no = packet_type_fragment_info.wrapping_shr(10) as u8;
if let Some(local_session_id) =
SessionId::new_from_u64(u64::from_le(memory::load_raw(&incoming_packet[8..16])) & SessionId::MAX_BIT_MASK)
if let Some(local_session_id) = SessionId::new_from_u64(u64::from_le(memory::load_raw(&incoming_packet[8..16])) & 0xffffffffffffu64)
{
if let Some(session) = host.session_lookup(local_session_id) {
if check_header_mac(incoming_packet, &session.header_check_cipher) {
@ -632,7 +636,7 @@ impl<H: Host> ReceiveContext<H> {
} else {
unlikely_branch();
if check_header_mac(incoming_packet, &self.incoming_init_header_check_cipher) {
let pseudoheader = Pseudoheader::make(0, packet_type, counter);
let pseudoheader = Pseudoheader::make(SessionId::NIL.0, packet_type, counter);
if fragment_count > 1 {
let mut defrag = self.initial_offer_defrag.lock();
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count));
@ -993,7 +997,7 @@ impl<H: Host> ReceiveContext<H> {
rp.write_all(bob_e0_keypair.public_key_bytes())?;
rp.write_all(&offer_id)?;
rp.write_all(&session.id.0.get().to_le_bytes()[..SESSION_ID_SIZE])?;
rp.write_all(&session.id.0.to_le_bytes()[..SESSION_ID_SIZE])?;
varint::write(&mut rp, 0)?; // they don't need our static public; they have it
varint::write(&mut rp, 0)?; // no meta-data in counter-offers (could be used in the future)
if let Some(bob_e1_public) = bob_e1_public.as_ref() {
@ -1300,7 +1304,7 @@ fn create_initial_offer<SendFunction: FnMut(&mut [u8])>(
p.write_all(alice_e0_keypair.public_key_bytes())?;
p.write_all(&id)?;
p.write_all(&alice_session_id.0.get().to_le_bytes()[..SESSION_ID_SIZE])?;
p.write_all(&alice_session_id.0.to_le_bytes()[..SESSION_ID_SIZE])?;
varint::write(&mut p, alice_s_public.len() as u64)?;
p.write_all(alice_s_public)?;
varint::write(&mut p, alice_metadata.len() as u64)?;
@ -1321,7 +1325,7 @@ fn create_initial_offer<SendFunction: FnMut(&mut [u8])>(
PACKET_BUF_SIZE - p.len()
};
let bob_session_id: u64 = bob_session_id.map_or(0_u64, |i| i.into());
let bob_session_id: u64 = bob_session_id.map_or(SessionId::NIL.0, |i| i.into());
create_packet_header(&mut packet_buf, packet_len, mtu, PACKET_TYPE_KEY_OFFER, bob_session_id, counter)?;
let pseudoheader = Pseudoheader::make(bob_session_id, PACKET_TYPE_KEY_OFFER, counter.to_u32());
@ -1684,7 +1688,7 @@ mod tests {
local_s_hash,
psk,
session: Mutex::new(None),
session_id_counter: Mutex::new(random::next_u64_secure().wrapping_shr(16) | 1),
session_id_counter: Mutex::new(1),
queue: Mutex::new(LinkedList::new()),
key_id: Mutex::new([0; 16]),
this_name,
@ -1765,7 +1769,7 @@ mod tests {
Session::new(
&alice_host,
|data| bob_host.queue.lock().push_front(data.to_vec()),
SessionId::new_from_u64(random::xorshift64_random().wrapping_shr(16)).unwrap(),
SessionId::new_random(),
bob_host.local_s.public_key_bytes(),
&[],
&psk,