Rewrite fileDB to be simpler in controller. Simplify controller. ZOMG IT BUILDS AND TESTS PASS.

This commit is contained in:
Adam Ierymenko 2023-03-29 12:37:27 -04:00
parent b97ed1e97a
commit a24be3dbe5
14 changed files with 544 additions and 492 deletions

View file

@ -1,14 +1,15 @@
// (c) 2020-2022 ZeroTier, Inc. -- currently proprietary pending actual release and licensing. See LICENSE.md.
use std::collections::HashMap;
use std::collections::{BTreeMap, HashMap};
use std::error::Error;
use std::mem::replace;
use std::ops::Bound;
use std::sync::{Mutex, RwLock};
use crate::database::Database;
use crate::model::{Member, Network};
use zerotier_network_hypervisor::vl1::Address;
use zerotier_network_hypervisor::vl1::{Address, PartialAddress};
use zerotier_network_hypervisor::vl2::NetworkId;
/// Network and member cache used by database implementations to implement change detection.
@ -18,7 +19,7 @@ use zerotier_network_hypervisor::vl2::NetworkId;
/// okay but calls out of order will result in extra updated events being generated for
/// movements forward and backward in time. Calls must be temporally ordered.
pub struct Cache {
by_nwid: RwLock<HashMap<NetworkId, (Network, Mutex<HashMap<Address, Member>>)>>,
by_nwid: RwLock<HashMap<NetworkId, (Network, Mutex<BTreeMap<PartialAddress, Member>>)>>,
}
impl Cache {
@ -34,12 +35,14 @@ impl Cache {
let networks = db.list_networks().await?;
for network_id in networks {
if let Some(network) = db.get_network(&network_id).await? {
let network_entry = by_nwid.entry(network_id.clone()).or_insert_with(|| (network, Mutex::new(HashMap::new())));
let network_entry = by_nwid
.entry(network_id.clone())
.or_insert_with(|| (network, Mutex::new(BTreeMap::new())));
let mut by_node_id = network_entry.1.lock().unwrap();
let members = db.list_members(&network_id).await?;
for node_id in members {
if let Some(member) = db.get_member(&network_id, &node_id).await? {
let _ = by_node_id.insert(node_id.clone(), member);
let _ = by_node_id.insert(node_id, member);
}
}
}
@ -49,6 +52,7 @@ impl Cache {
}
/// Update a network if changed, returning whether or not any update was made and the old version if any.
///
/// A value of (true, None) indicates that there was no network by that ID in which case it is added.
pub fn on_network_updated(&self, network: Network) -> (bool, Option<Network>) {
let mut by_nwid = self.by_nwid.write().unwrap();
@ -59,31 +63,80 @@ impl Cache {
(false, None)
}
} else {
let _ = by_nwid.insert(network.id.clone(), (network.clone(), Mutex::new(HashMap::new())));
assert!(by_nwid
.insert(network.id.clone(), (network.clone(), Mutex::new(BTreeMap::new())))
.is_none());
(true, None)
}
}
/// Update a member if changed, returning whether or not any update was made and the old version if any.
/// A value of (true, None) indicates that there was no member with that ID. If there is no network with
/// the member's network ID (false, None) is returned and no action is taken.
///
/// A value of (true, None) indicates that there was no member with that ID and that it was added. If
/// there is no network with the member's network ID (false, None) is returned and no action is taken.
pub fn on_member_updated(&self, member: Member) -> (bool, Option<Member>) {
let by_nwid = self.by_nwid.read().unwrap();
if let Some(network) = by_nwid.get(&member.network_id) {
let mut by_node_id = network.1.lock().unwrap();
if let Some(prev_member) = by_node_id.get_mut(&member.node_id) {
if !member.eq(prev_member) {
(true, Some(replace(prev_member, member)))
} else {
(false, None)
if let Some(exact_address_match) = by_node_id.get_mut(&member.node_id) {
if !member.eq(exact_address_match) {
return (true, Some(std::mem::replace(exact_address_match, member)));
}
} else {
let _ = by_node_id.insert(member.node_id.clone(), member);
(true, None)
let mut partial_address_match = None;
for m in by_node_id.range_mut::<PartialAddress, (Bound<&PartialAddress>, Bound<&PartialAddress>)>((
Bound::Included(&member.node_id),
Bound::Unbounded,
)) {
if m.0.matches_partial(&member.node_id) {
if partial_address_match.is_some() {
return (false, None);
}
let _ = partial_address_match.insert(m.1);
} else {
break;
}
}
if let Some(partial_address_match) = partial_address_match {
if !member.eq(partial_address_match) {
return (true, Some(std::mem::replace(partial_address_match, member)));
} else {
return (false, None);
}
}
let mut partial_address_match = None;
for m in by_node_id
.range_mut::<PartialAddress, (Bound<&PartialAddress>, Bound<&PartialAddress>)>((
Bound::Unbounded,
Bound::Included(&member.node_id),
))
.rev()
{
if m.0.matches_partial(&member.node_id) {
if partial_address_match.is_some() {
return (false, None);
}
let _ = partial_address_match.insert(m.1);
} else {
break;
}
}
if let Some(partial_address_match) = partial_address_match {
if !member.eq(partial_address_match) {
return (true, Some(std::mem::replace(partial_address_match, member)));
} else {
return (false, None);
}
}
assert!(by_node_id.insert(member.node_id.clone(), member).is_none());
return (true, None);
}
} else {
(false, None)
}
return (false, None);
}
/// Delete a network, returning it if it existed.
@ -91,7 +144,7 @@ impl Cache {
let mut by_nwid = self.by_nwid.write().unwrap();
let network = by_nwid.remove(&network_id)?;
let mut members = network.1.lock().unwrap();
Some((network.0, members.drain().map(|(_, v)| v).collect()))
Some((network.0, members.values().cloned().collect()))
}
/// Delete a member, returning it if it existed.
@ -99,6 +152,6 @@ impl Cache {
let by_nwid = self.by_nwid.read().unwrap();
let network = by_nwid.get(&network_id)?;
let mut members = network.1.lock().unwrap();
members.remove(&node_id)
members.remove(&node_id.to_partial())
}
}

View file

@ -15,7 +15,6 @@ zerotier-service = { path = "../service" }
async-trait = "^0"
serde = { version = "^1", features = ["derive"], default-features = false }
serde_json = { version = "^1", features = ["std"], default-features = false }
serde_yaml = "^0"
clap = { version = "^3", features = ["std", "suggestions"], default-features = false }
notify = { version = "^5", features = ["macos_fsevent"], default-features = false }
tokio-postgres = "^0"

View file

@ -15,12 +15,12 @@ use zerotier_network_hypervisor::vl2::multicastauthority::MulticastAuthority;
use zerotier_network_hypervisor::vl2::v1::networkconfig::*;
use zerotier_network_hypervisor::vl2::v1::Revocation;
use zerotier_network_hypervisor::vl2::NetworkId;
use zerotier_service::vl1::VL1Service;
use zerotier_utils::buffer::OutOfBoundsError;
use zerotier_utils::cast::cast_ref;
use zerotier_utils::reaper::Reaper;
use zerotier_utils::tokio;
use zerotier_utils::{ms_monotonic, ms_since_epoch};
use zerotier_vl1_service::VL1Service;
use crate::database::*;
use crate::model::{AuthenticationResult, Member, RequestLogItem, CREDENTIAL_WINDOW_SIZE_DEFAULT};
@ -53,7 +53,7 @@ impl Controller {
local_identity: IdentitySecret,
database: Arc<dyn Database>,
) -> Result<Arc<Self>, Box<dyn Error>> {
let c = Arc::new_cyclic(|self_ref| Self {
Ok(Arc::new_cyclic(|self_ref| Self {
self_ref: self_ref.clone(),
reaper: Reaper::new(&runtime),
runtime,
@ -62,16 +62,22 @@ impl Controller {
multicast_authority: MulticastAuthority::new(),
daemons: Mutex::new(Vec::with_capacity(2)),
recently_authorized: RwLock::new(HashMap::new()),
});
}))
}
if let Some(cw) = c.database.changes().await.map(|mut ch| {
let self2 = c.self_ref.clone();
c.runtime.spawn(async move {
/// Start this controller's background tasks.
///
/// Note that the controller only holds a Weak<VL1Service<Self>> to avoid circular references.
pub async fn start(&self, app: &Arc<VL1Service<Self>>) {
if let Some(cw) = self.database.changes().await.map(|mut ch| {
let controller_weak = self.self_ref.clone();
let app_weak = Arc::downgrade(app);
self.runtime.spawn(async move {
loop {
if let Ok(change) = ch.recv().await {
if let Some(self2) = self2.upgrade() {
self2.reaper.add(
self2.runtime.spawn(self2.clone().handle_change_notification(change)),
if let (Some(controller), Some(app)) = (controller_weak.upgrade(), app_weak.upgrade()) {
controller.reaper.add(
controller.runtime.spawn(controller.clone().handle_change_notification(app, change)),
Instant::now().checked_add(REQUEST_TIMEOUT).unwrap(),
);
} else {
@ -81,19 +87,19 @@ impl Controller {
}
})
}) {
c.daemons.lock().unwrap().push(cw);
self.daemons.lock().unwrap().push(cw);
}
let self2 = c.self_ref.clone();
c.daemons.lock().unwrap().push(c.runtime.spawn(async move {
let controller_weak = self.self_ref.clone();
self.daemons.lock().unwrap().push(self.runtime.spawn(async move {
let sleep_duration = Duration::from_millis((protocol::VL2_DEFAULT_MULTICAST_LIKE_EXPIRE / 2).min(2500) as u64);
loop {
tokio::time::sleep(sleep_duration).await;
if let Some(self2) = self2.upgrade() {
if let Some(controller) = controller_weak.upgrade() {
let time_ticks = ms_monotonic();
self2.multicast_authority.clean(time_ticks);
self2.recently_authorized.write().unwrap().retain(|_, by_network| {
controller.multicast_authority.clean(time_ticks);
controller.recently_authorized.write().unwrap().retain(|_, by_network| {
by_network.retain(|_, timeout| *timeout > time_ticks);
!by_network.is_empty()
});
@ -102,12 +108,10 @@ impl Controller {
}
}
}));
Ok(c)
}
/// Launched as a task when the DB informs us of a change.
async fn handle_change_notification(self: Arc<Self>, change: Change) {
async fn handle_change_notification(self: Arc<Self>, app: Arc<VL1Service<Self>>, change: Change) {
match change {
Change::NetworkCreated(_) => {}
Change::NetworkChanged(_, _) => {}
@ -115,10 +119,10 @@ impl Controller {
Change::MemberCreated(_) => {}
Change::MemberChanged(old_member, new_member) => {
if !new_member.authorized() && old_member.authorized() {
self.deauthorize_member(&new_member).await;
self.deauthorize_member(&app, &new_member).await;
}
}
Change::MemberDeleted(member) => self.deauthorize_member(&member).await,
Change::MemberDeleted(member) => self.deauthorize_member(&app, &member).await,
}
}
@ -201,15 +205,16 @@ impl Controller {
let time_clock = ms_since_epoch();
let mut revocations = Vec::with_capacity(1);
if let Ok(all_network_members) = self.database.list_members(&member.network_id).await {
for m in all_network_members.iter() {
if member.node_id != *m {
if let Some(peer) = app.node.peer(m) {
for other_member in all_network_members.iter() {
if member.node_id != *other_member && member.node_id.is_complete() && other_member.is_complete() {
let node_id = member.node_id.as_complete().unwrap();
if let Some(peer) = app.node.peer(node_id) {
revocations.clear();
revocations.push(Revocation::new(
&member.network_id,
time_clock,
&member.node_id,
m,
node_id,
other_member.as_complete().unwrap(),
&self.local_identity,
false,
));
@ -239,7 +244,7 @@ impl Controller {
}
let network = network.unwrap();
let mut member = self.database.get_member(&network_id, &source_identity.address).await?;
let mut member = self.database.get_member(&network_id, &source_identity.address.to_partial()).await?;
let mut member_changed = false;
let mut authentication_result = AuthenticationResult::Rejected;
@ -253,7 +258,7 @@ impl Controller {
if !member_authorized {
if member.is_none() {
if network.learn_members.unwrap_or(true) {
let _ = member.insert(Member::new(source_identity.address.clone(), network_id.clone()));
let _ = member.insert(Member::new(source_identity.clone(), network_id.clone()));
member_changed = true;
} else {
return Ok((AuthenticationResult::Rejected, None));
@ -289,6 +294,16 @@ impl Controller {
let member_authorized = member_authorized;
let authentication_result = authentication_result;
// Pin full address and full identity if these aren't pinned already.
if !member.node_id.is_complete() {
member.node_id = source_identity.address.to_partial();
member_changed = true;
}
if member.identity.is_none() {
let _ = member.identity.insert(source_identity.clone().remove_typestate());
member_changed = true;
}
// Generate network configuration if the member is authorized.
let network_config = if authentication_result.approved() {
// We should not be able to make it here if this is still false.
@ -325,8 +340,7 @@ impl Controller {
nc.rules.reserve(deauthed_members_still_in_window.len() + 1);
let mut or = false;
for dead in deauthed_members_still_in_window.iter() {
nc.rules
.push(vl2::rule::Rule::match_source_zerotier_address(false, or, dead.to_partial()));
nc.rules.push(vl2::rule::Rule::match_source_zerotier_address(false, or, dead.clone()));
or = true;
}
nc.rules.push(vl2::rule::Rule::action_drop());
@ -455,21 +469,26 @@ impl InnerProtocolLayer for Controller {
};
// Launch handler as an async background task.
let app: &VL1Service<Self> = cast_ref(app).unwrap();
let app = app.get();
let (self2, source, source_remote_endpoint) = (self.self_ref.upgrade().unwrap(), source.clone(), source_path.endpoint.clone());
let app = app.concrete_self::<VL1Service<Self>>().unwrap().get_self_arc(); // can't be a dead pointer since we're in a handler being called by it
let (controller, source, source_remote_endpoint) = (self.self_ref.upgrade().unwrap(), source.clone(), source_path.endpoint.clone());
self.reaper.add(
self.runtime.spawn(async move {
let node_id = source.identity.address.clone();
let now = ms_since_epoch();
let (result, config) = match self2.authorize(&source.identity, &network_id, now).await {
let result = match controller.authorize(&source.identity, &network_id, now).await {
Result::Ok((result, Some(config))) => {
//println!("{}", serde_yaml::to_string(&config).unwrap());
self2.send_network_config(app.as_ref(), &app.node, cast_ref(source.as_ref()).unwrap(), &config, Some(message_id));
(result, Some(config))
controller.send_network_config(
app.as_ref(),
&app.node,
cast_ref(source.as_ref()).unwrap(),
&config,
Some(message_id),
);
result
}
Result::Ok((result, None)) => (result, None),
Result::Ok((result, None)) => result,
Result::Err(e) => {
#[cfg(debug_assertions)]
debug_event!(app, "[vl2] ERROR getting network config: {}", e.to_string());
@ -477,12 +496,12 @@ impl InnerProtocolLayer for Controller {
}
};
let _ = self2
let _ = controller
.database
.log_request(RequestLogItem {
network_id,
node_id,
controller_node_id: self2.local_identity.public.address.clone(),
controller_node_id: controller.local_identity.public.address.clone(),
metadata,
peer_version: source.version(),
peer_protocol_version: source.protocol_version(),

View file

@ -1,7 +1,7 @@
use async_trait::async_trait;
use zerotier_crypto::secure_eq;
use zerotier_network_hypervisor::vl1::{Address, InetAddress};
use zerotier_network_hypervisor::vl1::{InetAddress, PartialAddress};
use zerotier_network_hypervisor::vl2::NetworkId;
use zerotier_utils::tokio::sync::broadcast::Receiver;
@ -22,22 +22,50 @@ pub enum Change {
#[async_trait]
pub trait Database: Sync + Send + 'static {
/// List networks on this controller.
async fn list_networks(&self) -> Result<Vec<NetworkId>, Error>;
/// Get a network by network ID.
async fn get_network(&self, id: &NetworkId) -> Result<Option<Network>, Error>;
/// Save a network.
///
/// Note that unlike members the network ID is not automatically promoted from legacy to full
/// ID format.
async fn save_network(&self, obj: Network, generate_change_notification: bool) -> Result<(), Error>;
async fn list_members(&self, network_id: &NetworkId) -> Result<Vec<Address>, Error>;
async fn get_member(&self, network_id: &NetworkId, node_id: &Address) -> Result<Option<Member>, Error>;
/// List members of a network.
async fn list_members(&self, network_id: &NetworkId) -> Result<Vec<PartialAddress>, Error>;
/// Get a member of network.
///
/// If node_id is not a complete address, the best unique match should be returned. None should
/// be returned not only if the member is not found but if node_id is ambiguous (would match more
/// than one member).
async fn get_member(&self, network_id: &NetworkId, node_id: &PartialAddress) -> Result<Option<Member>, Error>;
/// Save a modified member to a network.
///
/// Note that member modifications can include the automatic replacement of a less specific address
/// in node_id with a fully specific address. This happens the first time a member added with an
/// incomplete address is actually seen. In that case the implementation must correctly find the
/// best matching existing member and replace it with a member identified by the fully specified
/// address, removing and re-adding if needed.
///
/// This must also handle the (rare) case when someone may try to save a member with a less
/// specific address than the one currently in the database. In that case the "old" more specific
/// address should replace the less specific address in the node_id field. This can only happen if
/// an external user manually does this. The controller won't do this automatically.
async fn save_member(&self, obj: Member, generate_change_notification: bool) -> Result<(), Error>;
/// Save a log entry for a request this controller has handled.
async fn log_request(&self, obj: RequestLogItem) -> Result<(), Error>;
/// Get a receiver that can be used to receive changes made to networks and members, if supported.
///
/// The receiver returned is a broadcast receiver. This can be called more than once if there are
/// multiple parts of the controller that listen.
///
/// Changes should NOT be broadcast on call to save_network() or save_member(). They should only
/// be broadcast when externally generated changes occur.
///
/// The default implementation returns None indicating that change following is not supported.
/// Change following is required for instant deauthorization with revocations and other instant
/// changes in response to modifications to network and member configuration.
@ -49,7 +77,7 @@ pub trait Database: Sync + Send + 'static {
///
/// The default trait implementation uses a brute force method. This should be reimplemented if a
/// more efficient way is available.
async fn list_members_deauthorized_after(&self, network_id: &NetworkId, cutoff: i64) -> Result<Vec<Address>, Error> {
async fn list_members_deauthorized_after(&self, network_id: &NetworkId, cutoff: i64) -> Result<Vec<PartialAddress>, Error> {
let mut v = Vec::new();
let members = self.list_members(network_id).await?;
for a in members.iter() {
@ -77,6 +105,4 @@ pub trait Database: Sync + Send + 'static {
}
return Ok(false);
}
async fn log_request(&self, obj: RequestLogItem) -> Result<(), Error>;
}

View file

@ -1,404 +1,201 @@
use std::collections::BTreeMap;
use std::mem::replace;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, Mutex, Weak};
use std::sync::{Arc, Weak};
use serde::{Deserialize, Serialize};
use async_trait::async_trait;
use notify::{RecursiveMode, Watcher};
use serde::de::DeserializeOwned;
use zerotier_utils::tokio::io::AsyncWriteExt;
use zerotier_network_hypervisor::vl1::Address;
use crate::database;
use crate::database::Change;
use crate::model::{Member, Network, RequestLogItem};
use zerotier_network_hypervisor::vl1::PartialAddress;
use zerotier_network_hypervisor::vl2::NetworkId;
use zerotier_utils::reaper::Reaper;
use zerotier_utils::tokio::fs;
use zerotier_utils::tokio::runtime::Handle;
use zerotier_utils::tokio::sync::broadcast::{channel, Receiver, Sender};
use zerotier_utils::tokio::task::JoinHandle;
use zerotier_utils::tokio::time::{sleep, Duration, Instant};
use zerotier_utils::tokio;
use zerotier_utils::tokio::sync::{broadcast, mpsc};
use crate::cache::Cache;
use crate::database::{Change, Database, Error};
use crate::model::*;
const EVENT_HANDLER_TASK_TIMEOUT: Duration = Duration::from_secs(10);
/// An in-filesystem database that permits live editing.
///
/// A cache is maintained that contains the actual objects. When an object is live edited,
/// once it successfully reads and loads it is merged with the cached object and saved to
/// the cache. The cache will also contain any ephemeral data, generated data, etc.
///
/// The file format is YAML instead of JSON for better human friendliness and the layout
/// is different from V1 so it'll need a converter to use with V1 FileDb controller data.
pub struct FileDatabase {
base_path: PathBuf,
controller_address: Address,
change_sender: Sender<Change>,
tasks: Reaper,
cache: Cache,
daemon: JoinHandle<()>,
db_path: PathBuf,
log: Option<tokio::sync::Mutex<tokio::fs::File>>,
data: tokio::sync::Mutex<(BTreeMap<NetworkId, FileDbNetwork>, bool)>,
change_sender: broadcast::Sender<Change>,
file_write_notify_sender: mpsc::Sender<()>,
file_writer: tokio::task::JoinHandle<()>,
}
// TODO: should cache at least hashes and detect changes in the filesystem live.
#[derive(Serialize, Deserialize)]
struct FileDbNetwork {
pub config: Network,
pub members: BTreeMap<PartialAddress, Member>,
}
impl FileDatabase {
pub async fn new<P: AsRef<Path>>(runtime: Handle, base_path: P, controller_address: Address) -> Result<Arc<Self>, Error> {
let base_path: PathBuf = base_path.as_ref().into();
pub async fn new(db_path: &Path, log_path: Option<&Path>) -> Result<Arc<Self>, Box<dyn std::error::Error + Send + Sync>> {
let data_bytes = tokio::fs::read(db_path).await;
let mut data: BTreeMap<NetworkId, FileDbNetwork> = BTreeMap::new();
if let Err(e) = data_bytes {
if !matches!(e.kind(), tokio::io::ErrorKind::NotFound) {
return Err(Box::new(e));
}
} else {
data = serde_json::from_slice(data_bytes.as_ref().unwrap().as_slice())?;
}
let (change_sender, _) = channel(256);
let db_weak_tmp: Arc<Mutex<Weak<Self>>> = Arc::new(Mutex::new(Weak::default()));
let db_weak = db_weak_tmp.clone();
let runtime2 = runtime.clone();
let log = if let Some(log_path) = log_path {
Some(tokio::sync::Mutex::new(
tokio::fs::OpenOptions::new().append(true).create(true).mode(0o600).open(log_path).await?,
))
} else {
None
};
let db = Arc::new(Self {
base_path: base_path.clone(),
controller_address: controller_address.clone(),
change_sender,
tasks: Reaper::new(&runtime2),
cache: Cache::new(),
daemon: runtime2.spawn(async move {
let mut watcher = notify::recommended_watcher(move |event: notify::Result<notify::event::Event>| {
if let Ok(event) = event {
match event.kind {
notify::EventKind::Create(_) | notify::EventKind::Modify(_) | notify::EventKind::Remove(_) => {
if let Some(db) = db_weak.lock().unwrap().upgrade() {
let controller_address2 = controller_address.clone();
db.clone().tasks.add(
runtime.spawn(async move {
if let Some(path0) = event.paths.first() {
if let Some((record_type, network_id, node_id)) =
Self::record_type_from_path(controller_address2, path0.as_path())
{
// Paths to objects that were deleted or changed. Changed includes adding new objects.
let mut deleted = None;
let mut changed = None;
match event.kind {
notify::EventKind::Create(create_kind) => match create_kind {
notify::event::CreateKind::File => {
changed = Some(path0.as_path());
}
_ => {}
},
notify::EventKind::Modify(modify_kind) => match modify_kind {
notify::event::ModifyKind::Data(_) => {
changed = Some(path0.as_path());
}
notify::event::ModifyKind::Name(rename_mode) => match rename_mode {
notify::event::RenameMode::Both => {
if event.paths.len() >= 2 {
if let Some(path1) = event.paths.last() {
deleted = Some(path0.as_path());
changed = Some(path1.as_path());
}
}
}
notify::event::RenameMode::From => {
deleted = Some(path0.as_path());
}
notify::event::RenameMode::To => {
changed = Some(path0.as_path());
}
_ => {}
},
_ => {}
},
notify::EventKind::Remove(remove_kind) => match remove_kind {
notify::event::RemoveKind::File => {
deleted = Some(path0.as_path());
}
_ => {}
},
_ => {}
}
if deleted.is_some() {
match record_type {
RecordType::Network => {
if let Some((network, members)) = db.cache.on_network_deleted(network_id) {
let _ = db.change_sender.send(Change::NetworkDeleted(network, members));
}
}
RecordType::Member => {
if let Some(node_id) = node_id {
if let Some(member) = db.cache.on_member_deleted(network_id, node_id) {
let _ = db.change_sender.send(Change::MemberDeleted(member));
}
}
}
_ => {}
}
}
if let Some(changed) = changed {
match record_type {
RecordType::Network => {
if let Ok(Some(new_network)) = Self::load_object::<Network>(changed).await {
match db.cache.on_network_updated(new_network.clone()) {
(true, Some(old_network)) => {
let _ = db
.change_sender
.send(Change::NetworkChanged(old_network, new_network));
}
(true, None) => {
let _ = db.change_sender.send(Change::NetworkCreated(new_network));
}
_ => {}
}
}
}
RecordType::Member => {
if let Ok(Some(new_member)) = Self::load_object::<Member>(changed).await {
match db.cache.on_member_updated(new_member.clone()) {
(true, Some(old_member)) => {
let _ =
db.change_sender.send(Change::MemberChanged(old_member, new_member));
}
(true, None) => {
let _ = db.change_sender.send(Change::MemberCreated(new_member));
}
_ => {}
}
}
}
_ => {}
}
}
}
}
}),
Instant::now().checked_add(EVENT_HANDLER_TASK_TIMEOUT).unwrap(),
let (file_write_notify_sender, mut file_write_notify_receiver) = mpsc::channel(16);
let db = Arc::new_cyclic(|self_weak: &Weak<FileDatabase>| {
let self_weak = self_weak.clone();
Self {
db_path: db_path.to_path_buf(),
log,
data: tokio::sync::Mutex::new((data, false)),
change_sender: broadcast::channel(16).0,
file_write_notify_sender,
file_writer: tokio::task::spawn(async move {
loop {
file_write_notify_receiver.recv().await;
if let Some(db) = self_weak.upgrade() {
let mut data = db.data.lock().await;
if data.1 {
let json = zerotier_utils::json::to_json_pretty(&data.0);
if let Err(e) = tokio::fs::write(db.db_path.as_path(), json.as_bytes()).await {
eprintln!(
"WARNING: controller changes not persisted! unable to write file database to '{}': {}",
db.db_path.to_string_lossy(),
e.to_string()
);
} else {
data.1 = false;
}
}
_ => {}
} else {
break;
}
}
})
.expect("FATAL: unable to start filesystem change listener");
let _ = watcher.configure(
notify::Config::default()
.with_compare_contents(true)
.with_poll_interval(std::time::Duration::from_secs(2)),
);
watcher
.watch(&base_path, RecursiveMode::Recursive)
.expect("FATAL: unable to watch base path");
loop {
// Any periodic background stuff can be put here. Adjust timing as needed.
sleep(Duration::from_secs(10)).await;
}
}),
}),
}
});
db.cache.load_all(db.as_ref()).await?;
*db_weak_tmp.lock().unwrap() = Arc::downgrade(&db); // this starts the daemon tasks and starts watching for file changes
Ok(db)
}
fn network_path(&self, network_id: &NetworkId) -> PathBuf {
self.base_path.join(format!("N{:06x}", network_id.network_no())).join("config.yaml")
}
fn member_path(&self, network_id: &NetworkId, member_id: &Address) -> PathBuf {
self.base_path
.join(format!("N{:06x}", network_id.network_no()))
.join(format!("M{}.yaml", member_id.to_string()))
}
async fn load_object<O: DeserializeOwned>(path: &Path) -> Result<Option<O>, Error> {
if let Ok(raw) = fs::read(path).await {
return Ok(Some(serde_yaml::from_slice::<O>(raw.as_slice())?));
} else {
return Ok(None);
}
}
/// Get record type and also the number after it: network number or address.
fn record_type_from_path(controller_address: Address, p: &Path) -> Option<(RecordType, NetworkId, Option<Address>)> {
let parent = p.parent()?.file_name()?.to_string_lossy();
if parent.len() == 7 && (parent.starts_with("N") || parent.starts_with('n')) {
let network_id = NetworkId::Full(controller_address, u32::from_str_radix(&parent[1..], 16).ok()?);
if let Some(file_name) = p.file_name().map(|p| p.to_string_lossy().to_lowercase()) {
if file_name.eq("config.yaml") {
return Some((RecordType::Network, network_id, None));
} else if file_name.len() == 16 && file_name.starts_with("m") && file_name.ends_with(".yaml") {
return Some((
RecordType::Member,
network_id,
Some(Address::from_str(&file_name.as_str()[1..file_name.len() - 5]).ok()?),
));
}
}
}
return None;
}
}
impl Drop for FileDatabase {
fn drop(&mut self) {
self.daemon.abort();
self.file_writer.abort();
}
}
#[async_trait]
impl Database for FileDatabase {
async fn list_networks(&self) -> Result<Vec<NetworkId>, Error> {
let mut networks = Vec::new();
let mut dir = fs::read_dir(&self.base_path).await?;
while let Ok(Some(ent)) = dir.next_entry().await {
if ent.file_type().await.map_or(false, |t| t.is_dir()) {
let osname = ent.file_name();
let name = osname.to_string_lossy();
if name.len() == 7 && name.starts_with("N") {
if fs::metadata(ent.path().join("config.yaml")).await.is_ok() {
if let Ok(network_no) = u32::from_str_radix(&name[1..], 16) {
networks.push(NetworkId::Full(self.controller_address.clone(), network_no));
}
}
impl database::Database for FileDatabase {
async fn list_networks(&self) -> Result<Vec<NetworkId>, database::Error> {
Ok(self.data.lock().await.0.keys().cloned().collect())
}
async fn get_network(&self, id: &NetworkId) -> Result<Option<Network>, database::Error> {
Ok(self.data.lock().await.0.get(id).map(|x| x.config.clone()))
}
async fn save_network(&self, obj: Network, generate_change_notification: bool) -> Result<(), database::Error> {
let mut data = self.data.lock().await;
if let Some(nw) = data.0.get_mut(&obj.id) {
if !nw.config.eq(&obj) {
let old = replace(&mut nw.config, obj);
if generate_change_notification {
let _ = self.change_sender.send(Change::NetworkChanged(old, nw.config.clone()));
}
let _ = self.file_write_notify_sender.send(()).await;
}
}
Ok(networks)
}
async fn get_network(&self, id: &NetworkId) -> Result<Option<Network>, Error> {
let mut network = Self::load_object::<Network>(self.network_path(id).as_path()).await?;
if let Some(network) = network.as_mut() {
// FileDatabase stores networks by their "network number" and automatically adapts their IDs
// if the controller's identity changes. This is done to make it easy to just clone networks,
// including storing them in "git."
let network_id_should_be = NetworkId::Full(self.controller_address.clone(), network.id.network_no());
if network.id != network_id_should_be {
network.id = network_id_should_be;
let _ = self.save_network(network.clone(), false).await?;
} else {
data.0
.insert(obj.id.clone(), FileDbNetwork { config: obj.clone(), members: BTreeMap::new() });
if generate_change_notification {
let _ = self.change_sender.send(Change::NetworkCreated(obj));
}
let _ = self.file_write_notify_sender.send(()).await;
}
Ok(network)
}
async fn save_network(&self, obj: Network, generate_change_notification: bool) -> Result<(), Error> {
if !generate_change_notification {
let _ = self.cache.on_network_updated(obj.clone());
}
let base_network_path = self.network_path(&obj.id);
let _ = fs::create_dir_all(base_network_path.parent().unwrap()).await;
let _ = fs::write(base_network_path, serde_yaml::to_string(&obj)?.as_bytes()).await?;
return Ok(());
}
async fn list_members(&self, network_id: &NetworkId) -> Result<Vec<Address>, Error> {
let mut members = Vec::new();
let mut dir = fs::read_dir(self.base_path.join(format!("N{:06x}", network_id.network_no()))).await?;
while let Ok(Some(ent)) = dir.next_entry().await {
if ent.file_type().await.map_or(false, |t| t.is_file() || t.is_symlink()) {
let osname = ent.file_name();
let name = osname.to_string_lossy();
if name.starts_with("M") && name.ends_with(".yaml") {
if let Ok(member_address) = Address::from_str(&name[1..name.len() - 5]) {
members.push(member_address);
}
}
}
}
Ok(members)
async fn list_members(&self, network_id: &NetworkId) -> Result<Vec<PartialAddress>, database::Error> {
Ok(self
.data
.lock()
.await
.0
.get(network_id)
.map_or_else(|| Vec::new(), |x| x.members.keys().cloned().collect()))
}
async fn get_member(&self, network_id: &NetworkId, node_id: &Address) -> Result<Option<Member>, Error> {
let mut member = Self::load_object::<Member>(self.member_path(&network_id, node_id).as_path()).await?;
if let Some(member) = member.as_mut() {
if member.network_id.eq(network_id) {
// Also auto-update member network IDs, see get_network().
member.network_id = network_id.clone();
self.save_member(member.clone(), false).await?;
}
}
Ok(member)
async fn get_member(&self, network_id: &NetworkId, node_id: &PartialAddress) -> Result<Option<Member>, database::Error> {
Ok(self
.data
.lock()
.await
.0
.get_mut(network_id)
.and_then(|x| node_id.find_unique_match(&x.members).cloned()))
}
async fn save_member(&self, obj: Member, generate_change_notification: bool) -> Result<(), Error> {
if !generate_change_notification {
let _ = self.cache.on_member_updated(obj.clone());
}
let base_member_path = self.member_path(&obj.network_id, &obj.node_id);
let _ = fs::create_dir_all(base_member_path.parent().unwrap()).await;
let _ = fs::write(base_member_path, serde_yaml::to_string(&obj)?.as_bytes()).await?;
Ok(())
}
async fn save_member(&self, mut obj: Member, generate_change_notification: bool) -> Result<(), database::Error> {
let mut data = self.data.lock().await;
if let Some(nw) = data.0.get_mut(&obj.network_id) {
if let Some(member) = obj.node_id.find_unique_match_mut(&mut nw.members) {
if !obj.eq(member) {
if member.node_id.specificity_bytes() != obj.node_id.specificity_bytes() {
// If the specificity of the node_id has changed we have to delete and re-add the entry.
async fn changes(&self) -> Option<Receiver<Change>> {
Some(self.change_sender.subscribe())
}
let old_node_id = member.node_id.clone();
let old = nw.members.remove(&old_node_id);
async fn log_request(&self, obj: RequestLogItem) -> Result<(), Error> {
println!("{}", obj.to_string());
Ok(())
}
}
if old_node_id.specificity_bytes() > obj.node_id.specificity_bytes() {
obj.node_id = old_node_id;
}
#[cfg(test)]
mod tests {
#[allow(unused_imports)]
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use zerotier_network_hypervisor::vl1::identity::Identity;
nw.members.insert(obj.node_id.clone(), obj.clone());
/* TODO
#[allow(unused)]
#[test]
fn test_db() {
if let Ok(tokio_runtime) = zerotier_utils::tokio::runtime::Builder::new_current_thread().enable_all().build() {
let _ = tokio_runtime.block_on(async {
let node_id = Address::from_u64(0xdeadbeefu64).unwrap();
let network_id = NetworkId::from_u64(0xfeedbeefcafebabeu64).unwrap();
let test_dir = std::env::temp_dir().join("zt_filedatabase_test");
println!("test filedatabase is in: {}", test_dir.as_os_str().to_str().unwrap());
let _ = std::fs::remove_dir_all(&test_dir);
let controller_id = Identity::generate(false);
assert!(fs::create_dir_all(&test_dir).await.is_ok());
let db = Arc::new(
FileDatabase::new(tokio_runtime.handle().clone(), test_dir, controller_id.public.address.clone())
.await
.expect("new db"),
);
let change_count = Arc::new(AtomicUsize::new(0));
let db2 = db.clone();
let change_count2 = change_count.clone();
tokio_runtime.spawn(async move {
let mut change_receiver = db2.changes().await.unwrap();
loop {
if let Ok(change) = change_receiver.recv().await {
change_count2.fetch_add(1, Ordering::SeqCst);
//println!("[FileDatabase] {:#?}", change);
} else {
break;
if generate_change_notification {
let _ = self.change_sender.send(Change::MemberChanged(old.unwrap(), obj));
}
} else {
let old = replace(member, obj);
if generate_change_notification {
let _ = self.change_sender.send(Change::MemberChanged(old, member.clone()));
}
}
});
let mut test_network = Network::new(network_id);
db.save_network(test_network.clone(), true).await.expect("network save error");
let mut test_member = Member::new_without_identity(node_id, network_id);
for x in 0..3 {
test_member.name = x.to_string();
db.save_member(test_member.clone(), true).await.expect("member save error");
zerotier_utils::tokio::task::yield_now().await;
sleep(Duration::from_millis(100)).await;
zerotier_utils::tokio::task::yield_now().await;
let test_member2 = db.get_member(&network_id, &node_id).await.unwrap().unwrap();
assert!(test_member == test_member2);
let _ = self.file_write_notify_sender.send(()).await;
}
});
} else {
let _ = nw.members.insert(obj.node_id.clone(), obj.clone());
if generate_change_notification {
let _ = self.change_sender.send(Change::MemberCreated(obj));
}
let _ = self.file_write_notify_sender.send(()).await;
}
}
return Ok(());
}
async fn log_request(&self, obj: RequestLogItem) -> Result<(), database::Error> {
if let Some(log) = self.log.as_ref() {
let mut json_line = zerotier_utils::json::to_json(&obj);
json_line.push('\n');
let _ = log.lock().await.write_all(json_line.as_bytes()).await;
}
Ok(())
}
async fn changes(&self) -> Option<broadcast::Receiver<Change>> {
Some(self.change_sender.subscribe())
}
*/
}

View file

@ -2,8 +2,6 @@
mod controller;
pub(crate) mod cache;
pub mod database;
pub mod filedatabase;
pub mod model;

View file

@ -1,5 +1,7 @@
// (c) 2020-2022 ZeroTier, Inc. -- currently proprietary pending actual release and licensing. See LICENSE.md.
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use zerotier_network_controller::database::Database;
@ -7,23 +9,24 @@ use zerotier_network_controller::filedatabase::FileDatabase;
use zerotier_network_controller::Controller;
use zerotier_network_hypervisor::vl1::identity::IdentitySecret;
use zerotier_network_hypervisor::{VERSION_MAJOR, VERSION_MINOR, VERSION_REVISION};
use zerotier_service::vl1::{VL1Service, VL1Settings};
use zerotier_utils::exitcode;
use zerotier_utils::tokio;
use zerotier_utils::tokio::runtime::Runtime;
use zerotier_vl1_service::VL1Service;
async fn run(identity: IdentitySecret, runtime: &Runtime) -> i32 {
match Controller::new(database.clone(), runtime.handle().clone()).await {
async fn run(database: Arc<dyn Database>, identity: IdentitySecret, runtime: &Runtime) -> i32 {
match Controller::new(runtime.handle().clone(), identity.clone(), database.clone()).await {
Err(err) => {
eprintln!("FATAL: error initializing handler: {}", err.to_string());
exitcode::ERR_CONFIG
}
Ok(handler) => match VL1Service::new(identity, handler.clone(), zerotier_vl1_service::VL1Settings::default()) {
Ok(handler) => match VL1Service::new(identity, handler.clone(), VL1Settings::default()) {
Err(err) => {
eprintln!("FATAL: error launching service: {}", err.to_string());
exitcode::ERR_IOERR
}
Ok(svc) => {
svc.node().init_default_roots();
svc.node.init_default_roots();
handler.start(&svc).await;
zerotier_utils::wait_for_process_abort();
println!("Terminate signal received, shutting down...");
@ -36,13 +39,33 @@ async fn run(identity: IdentitySecret, runtime: &Runtime) -> i32 {
fn main() {
const REQUIRE_ONE_OF_ARGS: [&'static str; 2] = ["postgres", "filedb"];
let global_args = clap::Command::new("zerotier-controller")
.arg(
clap::Arg::new("identity")
.short('i')
.long("identity")
.takes_value(true)
.forbid_empty_values(true)
.value_name("identity")
.help(Some("Path to secret ZeroTier identity"))
.required(true),
)
.arg(
clap::Arg::new("logfile")
.short('l')
.long("logfile")
.takes_value(true)
.forbid_empty_values(true)
.value_name("logfile")
.help(Some("Path to log file"))
.required(false),
)
.arg(
clap::Arg::new("filedb")
.short('f')
.long("filedb")
.takes_value(true)
.forbid_empty_values(true)
.value_name("path")
.value_name("filedb")
.help(Some("Use filesystem database at path"))
.required_unless_present_any(&REQUIRE_ONE_OF_ARGS),
)
@ -52,8 +75,8 @@ fn main() {
.long("postgres")
.takes_value(true)
.forbid_empty_values(true)
.value_name("path")
.help(Some("Connect to postgres with parameters in YAML file"))
.value_name("postgres")
.help(Some("Connect to postgres with supplied URL"))
.required_unless_present_any(&REQUIRE_ONE_OF_ARGS),
)
.version(format!("{}.{}.{}", VERSION_MAJOR, VERSION_MINOR, VERSION_REVISION).as_str())
@ -64,23 +87,39 @@ fn main() {
std::process::exit(exitcode::ERR_USAGE);
});
if let Ok(tokio_runtime) = zerotier_utils::tokio::runtime::Builder::new_multi_thread().enable_all().build() {
if let Ok(tokio_runtime) = tokio::runtime::Builder::new_multi_thread().enable_all().build() {
tokio_runtime.block_on(async {
if let Some(filedb_base_path) = global_args.value_of("filedb") {
let file_db = FileDatabase::new(tokio_runtime.handle().clone(), filedb_base_path).await;
let identity = if let Ok(identity_data) = tokio::fs::read(global_args.value_of("identity").unwrap()).await {
if let Ok(identity) = IdentitySecret::from_str(String::from_utf8_lossy(identity_data.as_slice()).as_ref()) {
identity
} else {
eprintln!("FATAL: invalid secret identity");
std::process::exit(exitcode::ERR_CONFIG);
}
} else {
eprintln!("FATAL: unable to read secret identity");
std::process::exit(exitcode::ERR_IOERR);
};
let db: Arc<dyn Database> = if let Some(filedb_path) = global_args.value_of("filedb") {
let file_db = FileDatabase::new(Path::new(filedb_path), global_args.value_of("logfile").map(|l| Path::new(l))).await;
if file_db.is_err() {
eprintln!(
"FATAL: unable to open filesystem database at {}: {}",
filedb_base_path,
filedb_path,
file_db.as_ref().err().unwrap().to_string()
);
std::process::exit(exitcode::ERR_IOERR)
}
std::process::exit(run(file_db.unwrap(), &tokio_runtime).await);
file_db.unwrap()
} else if let Some(_postgres_url) = global_args.value_of("postgres") {
panic!("not implemented yet");
} else {
eprintln!("FATAL: no database type selected.");
std::process::exit(exitcode::ERR_USAGE);
};
std::process::exit(run(db, identity, &tokio_runtime).await);
});
} else {
eprintln!("FATAL: can't start async runtime");

View file

@ -5,17 +5,28 @@ use std::hash::Hash;
use serde::{Deserialize, Serialize};
use zerotier_network_hypervisor::vl1::{Address, InetAddress};
use zerotier_crypto::typestate::Valid;
use zerotier_network_hypervisor::vl1::identity::Identity;
use zerotier_network_hypervisor::vl1::{InetAddress, PartialAddress};
use zerotier_network_hypervisor::vl2::NetworkId;
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug)]
pub struct Member {
/// Member node ID
///
/// This can be a partial address if it was manually added as such by a user. As soon as a node matching
/// this partial is seen, this will be replaced by a full specificity PartialAddress from the querying
/// node's full identity. The 'identity' field will also be populated in this case.
#[serde(rename = "address")]
pub node_id: Address,
pub node_id: PartialAddress,
#[serde(rename = "networkId")]
pub network_id: NetworkId,
/// Full identity of this node, if known.
#[serde(skip_serializing_if = "Option::is_none")]
pub identity: Option<Identity>,
/// A short name that can also be used for DNS, etc.
#[serde(skip_serializing_if = "String::is_empty")]
#[serde(default)]
@ -68,11 +79,11 @@ pub struct Member {
}
impl Member {
/// Create a new network member without specifying a "pinned" identity.
pub fn new(node_id: Address, network_id: NetworkId) -> Self {
pub fn new(node_identity: Valid<Identity>, network_id: NetworkId) -> Self {
Self {
node_id,
node_id: node_identity.address.to_partial(),
network_id,
identity: Some(node_identity.remove_typestate()),
name: String::new(),
last_authorized_time: None,
last_deauthorized_time: None,
@ -85,7 +96,8 @@ impl Member {
}
}
/// Check whether this member is authorized, which is true if the last authorized time is after last deauthorized time.
/// Check whether this member is authorized.
/// This is true if the last authorized time is after last deauthorized time.
pub fn authorized(&self) -> bool {
self.last_authorized_time
.map_or(false, |la| self.last_deauthorized_time.map_or(true, |ld| la > ld))

View file

@ -6,8 +6,6 @@ mod network;
pub use member::*;
pub use network::*;
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use zerotier_network_hypervisor::vl1::{Address, Endpoint};
@ -20,13 +18,6 @@ pub enum RecordType {
RequestLogItem,
}
/// A complete network with all member configuration information for import/export or blob storage.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct NetworkExport {
pub network: Network,
pub members: HashMap<Address, Member>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
pub enum AuthenticationResult {

View file

@ -1,8 +1,10 @@
// (c) 2020-2022 ZeroTier, Inc. -- currently proprietary pending actual release and licensing. See LICENSE.md.
use std::borrow::Borrow;
use std::collections::BTreeMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::ops::Bound;
use std::str::FromStr;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
@ -24,8 +26,11 @@ pub struct Address(pub(super) [u8; Self::SIZE_BYTES]);
/// A partial address, which is bytes and the number of bytes of specificity (similar to a CIDR IP address).
///
/// Partial addresses are looked up to get full addresses (and identities) via roots using WHOIS messages.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct PartialAddress(pub(super) Address, pub(super) u16);
#[derive(Clone, PartialEq, Eq)]
pub struct PartialAddress {
pub(super) address: Address,
pub(super) specificity: u16,
}
impl Address {
pub const SIZE_BYTES: usize = 48;
@ -60,21 +65,25 @@ impl Address {
}
/// Get a partial address object (with full specificity) for this address
#[inline(always)]
#[inline]
pub fn to_partial(&self) -> PartialAddress {
PartialAddress(Address(self.0), Self::SIZE_BYTES as u16)
PartialAddress {
address: Address(self.0),
specificity: Self::SIZE_BYTES as u16,
}
}
/// Get a partial address covering the 40-bit legacy address.
#[inline]
pub fn to_legacy_partial(&self) -> PartialAddress {
PartialAddress(
Address({
PartialAddress {
address: Address({
let mut tmp = [0u8; PartialAddress::MAX_SIZE_BYTES];
tmp[..PartialAddress::LEGACY_SIZE_BYTES].copy_from_slice(&self.0[..PartialAddress::LEGACY_SIZE_BYTES]);
tmp
}),
PartialAddress::LEGACY_SIZE_BYTES as u16,
)
specificity: PartialAddress::LEGACY_SIZE_BYTES as u16,
}
}
#[inline(always)]
@ -195,8 +204,11 @@ impl PartialAddress {
&& b[0] != Address::RESERVED_PREFIX
&& b[..Self::LEGACY_SIZE_BYTES].iter().any(|i| *i != 0)
{
let mut a = Self(Address([0u8; Address::SIZE_BYTES]), b.len() as u16);
a.0 .0[..b.len()].copy_from_slice(b);
let mut a = Self {
address: Address([0u8; Address::SIZE_BYTES]),
specificity: b.len() as u16,
};
a.address.0[..b.len()].copy_from_slice(b);
Ok(a)
} else {
Err(InvalidParameterError("invalid address"))
@ -206,14 +218,14 @@ impl PartialAddress {
#[inline]
pub(crate) fn from_legacy_address_bytes(b: &[u8; 5]) -> Result<Self, InvalidParameterError> {
if b[0] != Address::RESERVED_PREFIX && b.iter().any(|i| *i != 0) {
Ok(Self(
Address({
Ok(Self {
address: Address({
let mut tmp = [0u8; Self::MAX_SIZE_BYTES];
tmp[..5].copy_from_slice(b);
tmp
}),
Self::LEGACY_SIZE_BYTES as u16,
))
specificity: Self::LEGACY_SIZE_BYTES as u16,
})
} else {
Err(InvalidParameterError("invalid address"))
}
@ -223,14 +235,14 @@ impl PartialAddress {
pub(crate) fn from_legacy_address_u64(mut b: u64) -> Result<Self, InvalidParameterError> {
b &= 0xffffffffff;
if b.wrapping_shr(32) != (Address::RESERVED_PREFIX as u64) && b != 0 {
Ok(Self(
Address({
Ok(Self {
address: Address({
let mut tmp = [0u8; Self::MAX_SIZE_BYTES];
tmp[..5].copy_from_slice(&b.to_be_bytes()[..5]);
tmp
}),
Self::LEGACY_SIZE_BYTES as u16,
))
specificity: Self::LEGACY_SIZE_BYTES as u16,
})
} else {
Err(InvalidParameterError("invalid address"))
}
@ -238,45 +250,60 @@ impl PartialAddress {
#[inline(always)]
pub fn as_bytes(&self) -> &[u8] {
debug_assert!(self.1 >= Self::MIN_SIZE_BYTES as u16);
&self.0 .0[..self.1 as usize]
debug_assert!(self.specificity >= Self::MIN_SIZE_BYTES as u16);
&self.address.0[..self.specificity as usize]
}
#[inline(always)]
pub(crate) fn legacy_bytes(&self) -> &[u8; 5] {
debug_assert!(self.1 >= Self::MIN_SIZE_BYTES as u16);
memory::array_range::<u8, { Address::SIZE_BYTES }, 0, { PartialAddress::LEGACY_SIZE_BYTES }>(&self.0 .0)
debug_assert!(self.specificity >= Self::MIN_SIZE_BYTES as u16);
memory::array_range::<u8, { Address::SIZE_BYTES }, 0, { PartialAddress::LEGACY_SIZE_BYTES }>(&self.address.0)
}
#[inline(always)]
pub(crate) fn legacy_u64(&self) -> u64 {
u64::from_be(memory::load_raw(&self.0 .0)).wrapping_shr(24)
u64::from_be(memory::load_raw(&self.address.0)).wrapping_shr(24)
}
/// Returns true if this partial address matches a full length address up to this partial's specificity.
#[inline(always)]
pub(super) fn matches(&self, k: &Address) -> bool {
debug_assert!(self.1 >= Self::MIN_SIZE_BYTES as u16);
let l = self.1 as usize;
self.0 .0[..l].eq(&k.0[..l])
pub fn matches(&self, k: &Address) -> bool {
debug_assert!(self.specificity >= Self::MIN_SIZE_BYTES as u16);
let l = self.specificity as usize;
self.address.0[..l].eq(&k.0[..l])
}
/// Returns true if this partial address matches another up to the lower of the two addresses' specificities.
#[inline(always)]
pub fn matches_partial(&self, k: &PartialAddress) -> bool {
debug_assert!(self.specificity >= Self::MIN_SIZE_BYTES as u16);
let l = self.specificity.min(k.specificity) as usize;
self.address.0[..l].eq(&k.address.0[..l])
}
/// Get the number of bits of specificity in this address
#[inline(always)]
pub fn specificity(&self) -> usize {
(self.1 * 8) as usize
pub fn specificity_bits(&self) -> usize {
(self.specificity * 8) as usize
}
/// Get the number of bytes of specificity in this address (only 8 bit increments in specificity are allowed)
#[inline(always)]
pub fn specificity_bytes(&self) -> usize {
self.specificity as usize
}
/// Returns true if this address has legacy 40 bit specificity (V1 ZeroTier address)
#[inline(always)]
pub fn is_legacy(&self) -> bool {
self.1 == Self::LEGACY_SIZE_BYTES as u16
self.specificity == Self::LEGACY_SIZE_BYTES as u16
}
/// Get a full length address if this partial address is actually complete (384 bits of specificity)
#[inline(always)]
pub fn as_address(&self) -> Option<&Address> {
if self.1 == Self::MAX_SIZE_BYTES as u16 {
Some(&self.0)
/// Get a complete address from this partial if it is in fact complete.
#[inline]
pub fn as_complete(&self) -> Option<&Address> {
if self.specificity == Self::MAX_SIZE_BYTES as u16 {
Some(&self.address)
} else {
None
}
@ -285,14 +312,84 @@ impl PartialAddress {
/// Returns true if specificity is at the maximum value (384 bits)
#[inline(always)]
pub fn is_complete(&self) -> bool {
self.1 == Self::MAX_SIZE_BYTES as u16
self.specificity == Self::MAX_SIZE_BYTES as u16
}
/// Efficiently find an entry in a BTreeMap of partial addresses that uniquely matches this partial.
///
/// This returns None if there is no match or if this partial matches more than one entry, in which
/// case it's ambiguous and may be unsafe to use. This should be prohibited at other levels of the
/// system but is checked for here as well.
#[inline]
pub fn find_unique_match<'a, T>(&self, map: &'a BTreeMap<PartialAddress, T>) -> Option<&'a T> {
// Search for an exact or more specific match.
let mut m = None;
// First search for exact or more specific matches, which would appear later in the sorted key list.
let mut pos = map.range((Bound::Included(self), Bound::Unbounded));
while let Some(e) = pos.next() {
if self.matches_partial(e.0) {
if m.is_some() {
// Ambiguous!
return None;
}
let _ = m.insert(e.1);
} else {
break;
}
}
// Then search for less specific matches or verify that the match we found above is not ambiguous.
let mut pos = map.range((Bound::Unbounded, Bound::Excluded(self)));
while let Some(e) = pos.next_back() {
if self.matches_partial(e.0) {
if m.is_some() {
return None;
}
let _ = m.insert(e.1);
} else {
break;
}
}
return m;
}
/// Efficiently find an entry in a BTreeMap of partial addresses that uniquely matches this partial.
///
/// This returns None if there is no match or if this partial matches more than one entry, in which
/// case it's ambiguous and may be unsafe to use. This should be prohibited at other levels of the
/// system but is checked for here as well.
#[inline]
pub fn find_unique_match_mut<'a, T>(&self, map: &'a mut BTreeMap<PartialAddress, T>) -> Option<&'a mut T> {
// This not only saves some repetition but is in fact the only way to easily do this. The same code as
// find_unique_match() but with range_mut() doesn't compile because the second range_mut() would
// borrow 'map' a second time (since 'm' may have it borrowed). This is primarily due to the too-limited
// API of BTreeMap which is missing a good way to find the nearest match. This should be safe since
// we do not mutate the map and the signature of find_unique_match_mut() should properly guarantee
// that the semantics of mutable references are obeyed in the calling context.
unsafe { std::mem::transmute(self.find_unique_match::<T>(map)) }
}
}
impl Ord for PartialAddress {
#[inline(always)]
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.address.cmp(&other.address).then(self.specificity.cmp(&other.specificity))
}
}
impl PartialOrd for PartialAddress {
#[inline(always)]
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl ToString for PartialAddress {
fn to_string(&self) -> String {
if self.is_legacy() {
hex::to_string(&self.0 .0[..Self::LEGACY_SIZE_BYTES])
hex::to_string(&self.address.0[..Self::LEGACY_SIZE_BYTES])
} else {
base24::encode(self.as_bytes())
}
@ -315,7 +412,7 @@ impl Hash for PartialAddress {
#[inline(always)]
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
// Since this contains a random hash, the first 64 bits should be enough for a local HashMap etc.
state.write_u64(memory::load_raw(&self.0 .0))
state.write_u64(memory::load_raw(&self.address.0))
}
}

View file

@ -97,6 +97,17 @@ pub trait ApplicationLayer: Sync + Send + 'static {
/// Called to get the current time in milliseconds since epoch from the real-time clock.
/// This needs to be accurate to about one second resolution or better.
fn time_clock(&self) -> i64;
/// Get this application implementation cast to its concrete type.
///
/// The default implementation just returns None, but this can be implemented using the cast_ref()
/// function in zerotier_utils::cast to return the concrete implementation of this type. It's exposed
/// in this interface for convenience since it's common for inner protocol or other handlers to want
/// to get 'app' as its concrete type to access internal fields and methods. Implement it if possible
/// and convenient.
fn concrete_self<T: ApplicationLayer>(&self) -> Option<&T> {
None
}
}
/// Result of a packet handler in the InnerProtocolLayer trait.

View file

@ -7,7 +7,7 @@ use std::sync::{Arc, Mutex, RwLock};
use std::time::Duration;
use super::address::{Address, PartialAddress};
use super::api::{ApplicationLayer, InnerProtocolLayer, PacketHandlerResult};
use super::api::{ApplicationLayer, InnerProtocolLayer};
use super::debug_event;
use super::endpoint::Endpoint;
use super::event::Event;

View file

@ -39,8 +39,11 @@ impl<Application: ApplicationLayer> PeerMap<Application> {
/// Get a matching peer for a partial address of any specificity, but return None if the match is ambiguous.
pub fn get_unambiguous(&self, address: &PartialAddress) -> Option<Arc<Peer<Application>>> {
let mm = self.maps[address.0 .0[0] as usize].read().unwrap();
let matches = mm.range::<[u8; 48], (Bound<&[u8; 48]>, Bound<&[u8; 48]>)>((Bound::Included(&address.0 .0), Bound::Unbounded));
let mm = self.maps[address.address.0[0] as usize].read().unwrap();
let matches = mm.range::<[u8; Address::SIZE_BYTES], (Bound<&[u8; Address::SIZE_BYTES]>, Bound<&[u8; Address::SIZE_BYTES]>)>((
Bound::Included(&address.address.0),
Bound::Unbounded,
));
let mut r = None;
for m in matches {
if address.matches(m.0) {

View file

@ -10,6 +10,7 @@ use zerotier_crypto::random;
use zerotier_network_hypervisor::protocol::{PacketBufferFactory, PacketBufferPool};
use zerotier_network_hypervisor::vl1::identity::IdentitySecret;
use zerotier_network_hypervisor::vl1::*;
use zerotier_utils::cast::cast_ref;
use zerotier_utils::{ms_monotonic, ms_since_epoch};
use super::vl1settings::{VL1Settings, UNASSIGNED_PRIVILEGED_PORTS};
@ -67,7 +68,8 @@ impl<Inner: InnerProtocolLayer + 'static> VL1Service<Inner> {
Ok(service)
}
pub fn get(&self) -> Arc<Self> {
#[inline]
pub fn get_self_arc(&self) -> Arc<Self> {
self.self_ref.upgrade().unwrap()
}
@ -281,6 +283,11 @@ impl<Inner: InnerProtocolLayer + 'static> ApplicationLayer for VL1Service<Inner>
fn time_clock(&self) -> i64 {
ms_since_epoch()
}
#[inline(always)]
fn concrete_self<T: ApplicationLayer>(&self) -> Option<&T> {
cast_ref(self)
}
}
impl<Inner: InnerProtocolLayer + 'static> Drop for VL1Service<Inner> {