From 4a60ae5736eebe150772a53517d89a625d2c1bb6 Mon Sep 17 00:00:00 2001 From: Adam Ierymenko Date: Fri, 26 Mar 2021 18:43:50 -0400 Subject: [PATCH] Some efficiency improvements. --- rust-zerotier-core/src/node.rs | 103 +++++++++++--------- service/Cargo.toml | 2 +- service/src/api.rs | 13 ++- service/src/httpclient.rs | 11 ++- service/src/httplistener.rs | 32 ++++--- service/src/service.rs | 167 +++++++++++++++------------------ 6 files changed, 165 insertions(+), 163 deletions(-) diff --git a/rust-zerotier-core/src/node.rs b/rust-zerotier-core/src/node.rs index 9de041811..bbace933a 100644 --- a/rust-zerotier-core/src/node.rs +++ b/rust-zerotier-core/src/node.rs @@ -24,6 +24,7 @@ use serde::{Deserialize, Serialize}; use crate::*; use crate::capi as ztcore; +use std::marker::PhantomData; pub const NODE_BACKGROUND_TASKS_MAX_INTERVAL: i64 = 200; @@ -47,7 +48,7 @@ pub enum StateObjectType { Peer = ztcore::ZT_StateObjectType_ZT_STATE_OBJECT_PEER as isize, NetworkConfig = ztcore::ZT_StateObjectType_ZT_STATE_OBJECT_NETWORK_CONFIG as isize, TrustStore = ztcore::ZT_StateObjectType_ZT_STATE_OBJECT_TRUST_STORE as isize, - Certificate = ztcore::ZT_StateObjectType_ZT_STATE_OBJECT_CERT as isize + Certificate = ztcore::ZT_StateObjectType_ZT_STATE_OBJECT_CERT as isize, } /// The status of a ZeroTier node. @@ -91,24 +92,32 @@ pub trait NodeEventHandler { fn path_lookup(&self, address: Address, id: &Identity, desired_family: InetAddressFamily) -> Option; } -pub struct NodeIntl + Sync + Send + Clone + 'static, N: Sync + Send + 'static> { +pub struct NodeIntl + Sync + Send + Clone + 'static, N: Sync + Send + 'static, H: NodeEventHandler> { event_handler: T, capi: *mut ztcore::ZT_Node, now: PortableAtomicI64, networks_by_id: Mutex>>>, + event_handler_placeholder: PhantomData, } /// An instance of the ZeroTier core. -/// This is templated on the actual implementations of things to avoid "dyn" indirect -/// call overhead. This is a high performance networking thingy so we care about cycles -/// where possible. -pub struct Node + Sync + Send + Clone + 'static, N: Sync + Send + 'static> { - intl: Pin>>, +/// +/// The event handler is templated as AsRef where H is the concrete type of the actual +/// handler. This allows the handler to be an Arc<>, Box<>, or similar. We do this instead +/// of templating it on "dyn NodeEventHandler" because we want the types to all be concrete +/// to avoid dynamic call overhead. Unfortunately it makes the types here a tad more +/// verbose. +/// +/// In most cases you will want the handler to be an Arc<> anyway since most uses will be +/// multithreaded or async. +pub struct Node + Sync + Send + Clone + 'static, N: Sync + Send + 'static, H: NodeEventHandler> { + intl: Pin>>, + event_handler_placeholder: PhantomData, } /********************************************************************************************************************/ -extern "C" fn zt_virtual_network_config_function + Sync + Send + Clone + 'static, N: Sync + Send + 'static>( +extern "C" fn zt_virtual_network_config_function + Sync + Send + Clone + 'static, N: Sync + Send + 'static, H: NodeEventHandler>( _: *mut ztcore::ZT_Node, uptr: *mut c_void, _: *mut c_void, @@ -118,17 +127,17 @@ extern "C" fn zt_virtual_network_config_function + Sync + conf: *const ztcore::ZT_VirtualNetworkConfig, ) { let _ = VirtualNetworkConfigOperation::from_i32(op as i32).map(|op| { - let n = unsafe { &*(uptr.cast::>()) }; + let n = unsafe { &*(uptr.cast::>()) }; if conf.is_null() { - n.event_handler.virtual_network_config(NetworkId(nwid), unsafe { &*(nptr.cast::()) }, op, None); + n.event_handler.as_ref().virtual_network_config(NetworkId(nwid), unsafe { &*(nptr.cast::()) }, op, None); } else { let conf2 = unsafe { VirtualNetworkConfig::new_from_capi(&*conf) }; - n.event_handler.virtual_network_config(NetworkId(nwid), unsafe { &*(nptr.cast::()) }, op, Some(&conf2)); + n.event_handler.as_ref().virtual_network_config(NetworkId(nwid), unsafe { &*(nptr.cast::()) }, op, Some(&conf2)); } }); } -extern "C" fn zt_virtual_network_frame_function + Sync + Send + Clone + 'static, N: Sync + Send + 'static>( +extern "C" fn zt_virtual_network_frame_function + Sync + Send + Clone + 'static, N: Sync + Send + 'static, H: NodeEventHandler>( _: *mut ztcore::ZT_Node, uptr: *mut c_void, _: *mut c_void, @@ -142,7 +151,7 @@ extern "C" fn zt_virtual_network_frame_function + Sync + data_size: c_uint, ) { if !nptr.is_null() { - unsafe { &*(uptr.cast::>()) }.event_handler.virtual_network_frame( + unsafe { &*(uptr.cast::>()) }.event_handler.as_ref().virtual_network_frame( NetworkId(nwid), unsafe { &*(nptr.cast::()) }, MAC(source_mac), @@ -153,25 +162,25 @@ extern "C" fn zt_virtual_network_frame_function + Sync + } } -extern "C" fn zt_event_callback + Sync + Send + Clone + 'static, N: Sync + Send + 'static>( +extern "C" fn zt_event_callback + Sync + Send + Clone + 'static, N: Sync + Send + 'static, H: NodeEventHandler>( _: *mut ztcore::ZT_Node, uptr: *mut c_void, _: *mut c_void, ev: ztcore::ZT_Event, data: *const c_void, - data_size: c_uint + data_size: c_uint, ) { let _ = Event::from_i32(ev as i32).map(|ev: Event| { - let n = unsafe { &*(uptr.cast::>()) }; + let n = unsafe { &*(uptr.cast::>()) }; if data.is_null() { - n.event_handler.event(ev, &EMPTY_BYTE_ARRAY); + n.event_handler.as_ref().event(ev, &EMPTY_BYTE_ARRAY); } else { - n.event_handler.event(ev, unsafe { &*slice_from_raw_parts(data.cast::(), data_size as usize) }); + n.event_handler.as_ref().event(ev, unsafe { &*slice_from_raw_parts(data.cast::(), data_size as usize) }); } }); } -extern "C" fn zt_state_put_function + Sync + Send + Clone + 'static, N: Sync + Send + 'static>( +extern "C" fn zt_state_put_function + Sync + Send + Clone + 'static, N: Sync + Send + 'static, H: NodeEventHandler>( _: *mut ztcore::ZT_Node, uptr: *mut c_void, _: *mut c_void, @@ -182,12 +191,12 @@ extern "C" fn zt_state_put_function + Sync + Send + Clone obj_data_len: c_int, ) { let _ = StateObjectType::from_i32(obj_type as i32).map(|obj_type| { - let n = unsafe { &*(uptr.cast::>()) }; - let _ = n.event_handler.state_put(obj_type, unsafe { &*slice_from_raw_parts(obj_id, obj_id_len as usize) }, unsafe { &*slice_from_raw_parts(obj_data.cast::(), obj_data_len as usize) }); + let n = unsafe { &*(uptr.cast::>()) }; + let _ = n.event_handler.as_ref().state_put(obj_type, unsafe { &*slice_from_raw_parts(obj_id, obj_id_len as usize) }, unsafe { &*slice_from_raw_parts(obj_data.cast::(), obj_data_len as usize) }); }); } -extern "C" fn zt_state_get_function + Sync + Send + Clone + 'static, N: Sync + Send + 'static>( +extern "C" fn zt_state_get_function + Sync + Send + Clone + 'static, N: Sync + Send + 'static, H: NodeEventHandler>( _: *mut ztcore::ZT_Node, uptr: *mut c_void, _: *mut c_void, @@ -207,7 +216,7 @@ extern "C" fn zt_state_get_function + Sync + Send + Clone StateObjectType::from_i32(obj_type as i32).map_or_else(|| { -1 as c_int }, |obj_type| { - unsafe { &*(uptr.cast::>()) }.event_handler.state_get(obj_type, unsafe { &*slice_from_raw_parts(obj_id, obj_id_len as usize) }).map_or_else(|_| { + unsafe { &*(uptr.cast::>()) }.event_handler.as_ref().state_get(obj_type, unsafe { &*slice_from_raw_parts(obj_id, obj_id_len as usize) }).map_or_else(|_| { -1 as c_int }, |obj_data_result| { let obj_data_len = obj_data_result.len() as c_int; @@ -230,7 +239,7 @@ extern "C" fn zt_state_get_function + Sync + Send + Clone } } -extern "C" fn zt_wire_packet_send_function + Sync + Send + Clone + 'static, N: Sync + Send + 'static>( +extern "C" fn zt_wire_packet_send_function + Sync + Send + Clone + 'static, N: Sync + Send + 'static, H: NodeEventHandler>( _: *mut ztcore::ZT_Node, uptr: *mut c_void, _: *mut c_void, @@ -240,10 +249,10 @@ extern "C" fn zt_wire_packet_send_function + Sync + Send data_size: c_uint, packet_ttl: c_uint, ) -> c_int { - unsafe { &*(uptr.cast::>()) }.event_handler.wire_packet_send(local_socket, InetAddress::transmute_capi(unsafe { &*sock_addr }), unsafe { &*slice_from_raw_parts(data.cast::(), data_size as usize) }, packet_ttl as u32) as c_int + unsafe { &*(uptr.cast::>()) }.event_handler.as_ref().wire_packet_send(local_socket, InetAddress::transmute_capi(unsafe { &*sock_addr }), unsafe { &*slice_from_raw_parts(data.cast::(), data_size as usize) }, packet_ttl as u32) as c_int } -extern "C" fn zt_path_check_function + Sync + Send + Clone + 'static, N: Sync + Send + 'static>( +extern "C" fn zt_path_check_function + Sync + Send + Clone + 'static, N: Sync + Send + 'static, H: NodeEventHandler>( _: *mut ztcore::ZT_Node, uptr: *mut c_void, _: *mut c_void, @@ -253,10 +262,10 @@ extern "C" fn zt_path_check_function + Sync + Send + Clon sock_addr: *const ztcore::ZT_InetAddress, ) -> c_int { let id = Identity::new_from_capi(identity, false); - unsafe { &*(uptr.cast::>()) }.event_handler.path_check(Address(address), &id, local_socket, InetAddress::transmute_capi(unsafe{ &*sock_addr })) as c_int + unsafe { &*(uptr.cast::>()) }.event_handler.as_ref().path_check(Address(address), &id, local_socket, InetAddress::transmute_capi(unsafe { &*sock_addr })) as c_int } -extern "C" fn zt_path_lookup_function + Sync + Send + Clone + 'static, N: Sync + Send + 'static>( +extern "C" fn zt_path_lookup_function + Sync + Send + Clone + 'static, N: Sync + Send + 'static, H: NodeEventHandler>( _: *mut ztcore::ZT_Node, uptr: *mut c_void, _: *mut c_void, @@ -280,7 +289,7 @@ extern "C" fn zt_path_lookup_function + Sync + Send + Clo } let id = Identity::new_from_capi(identity, false); - unsafe { &*(uptr.cast::>()) }.event_handler.path_lookup(Address(address), &id, sock_family2).map_or_else(|| { + unsafe { &*(uptr.cast::>()) }.event_handler.as_ref().path_lookup(Address(address), &id, sock_family2).map_or_else(|| { 0 as c_int }, |result| { let result_ptr = &result as *const InetAddress; @@ -293,31 +302,33 @@ extern "C" fn zt_path_lookup_function + Sync + Send + Clo /********************************************************************************************************************/ -impl + Sync + Send + Clone + 'static, N: Sync + Send + 'static> Node { +impl + Sync + Send + Clone + 'static, N: Sync + Send + 'static, H: NodeEventHandler> Node { /// Create a new Node with a given event handler. #[allow(unused_mut)] - pub fn new(event_handler: T, now: i64) -> Result, ResultCode> { + pub fn new(event_handler: T, now: i64) -> Result, ResultCode> { let mut n = Node { intl: Box::pin(NodeIntl { event_handler: event_handler.clone(), capi: null_mut(), now: PortableAtomicI64::new(now), - networks_by_id: Mutex::new(HashMap::new()) + networks_by_id: Mutex::new(HashMap::new()), + event_handler_placeholder: PhantomData::default(), }), + event_handler_placeholder: PhantomData::default(), }; let rc = unsafe { let callbacks = ztcore::ZT_Node_Callbacks { - statePutFunction: transmute(zt_state_put_function:: as *const ()), - stateGetFunction: transmute(zt_state_get_function:: as *const ()), - wirePacketSendFunction: transmute(zt_wire_packet_send_function:: as *const ()), - virtualNetworkFrameFunction: transmute(zt_virtual_network_frame_function:: as *const ()), - virtualNetworkConfigFunction: transmute(zt_virtual_network_config_function:: as *const ()), - eventCallback: transmute(zt_event_callback:: as *const ()), - pathCheckFunction: transmute(zt_path_check_function:: as *const ()), - pathLookupFunction: transmute(zt_path_lookup_function:: as *const ()), + statePutFunction: transmute(zt_state_put_function:: as *const ()), + stateGetFunction: transmute(zt_state_get_function:: as *const ()), + wirePacketSendFunction: transmute(zt_wire_packet_send_function:: as *const ()), + virtualNetworkFrameFunction: transmute(zt_virtual_network_frame_function:: as *const ()), + virtualNetworkConfigFunction: transmute(zt_virtual_network_config_function:: as *const ()), + eventCallback: transmute(zt_event_callback:: as *const ()), + pathCheckFunction: transmute(zt_path_check_function:: as *const ()), + pathLookupFunction: transmute(zt_path_lookup_function:: as *const ()), }; - ztcore::ZT_Node_new(transmute(&(n.intl.capi) as *const *mut ztcore::ZT_Node), transmute(&*n.intl as *const NodeIntl), null_mut(), &callbacks, now) + ztcore::ZT_Node_new(transmute(&(n.intl.capi) as *const *mut ztcore::ZT_Node), transmute(&*n.intl as *const NodeIntl), null_mut(), &callbacks, now) }; if rc == 0 { @@ -505,19 +516,19 @@ impl + Sync + Send + Clone + 'static, N: Sync + Send + 's } } -impl + Sync + Send + Clone + 'static, N: Sync + Send + Clone + 'static> Node { +impl + Sync + Send + Clone + 'static, N: Sync + Send + Clone + 'static, H: NodeEventHandler> Node { /// Get a copy of this network's associated object. /// This is only available if N implements Clone. pub fn network(&self, nwid: NetworkId) -> Option { - self.intl.networks_by_id.lock().unwrap().get(&nwid.0).map_or(None, |nw| { Some(nw.as_ref().get_ref().clone()) }) + self.intl.networks_by_id.lock().unwrap().get(&nwid.0).map_or(None, |nw| Some((**nw).clone())) } } -unsafe impl + Sync + Send + Clone + 'static, N: Sync + Send + 'static> Sync for Node {} +unsafe impl + Sync + Send + Clone + 'static, N: Sync + Send + 'static, H: NodeEventHandler> Sync for Node {} -unsafe impl + Sync + Send + Clone + 'static, N: Sync + Send + 'static> Send for Node {} +unsafe impl + Sync + Send + Clone + 'static, N: Sync + Send + 'static, H: NodeEventHandler> Send for Node {} -impl + Sync + Send + Clone + 'static, N: Sync + Send + 'static> Drop for Node { +impl + Sync + Send + Clone + 'static, N: Sync + Send + 'static, H: NodeEventHandler> Drop for Node { fn drop(&mut self) { unsafe { ztcore::ZT_Node_delete(self.intl.capi, null_mut()); diff --git a/service/Cargo.toml b/service/Cargo.toml index 4f37c8612..680fd1613 100644 --- a/service/Cargo.toml +++ b/service/Cargo.toml @@ -23,7 +23,7 @@ num-derive = "0" hyper = { version = "0", features = ["http1", "runtime", "server", "client", "tcp", "stream"] } socket2 = { version = "0", features = ["reuseport", "unix", "pair"] } dialoguer = "0" -digest_auth = "0" +digest_auth = "0.2.4" colored = "2" [target."cfg(windows)".dependencies] diff --git a/service/src/api.rs b/service/src/api.rs index 472614eeb..43f75d8e6 100644 --- a/service/src/api.rs +++ b/service/src/api.rs @@ -11,10 +11,13 @@ */ /****/ -use crate::service::Service; +use std::sync::Arc; + use hyper::{Request, Body, StatusCode, Method}; -pub(crate) fn status(service: Service, req: Request) -> (StatusCode, Body) { +use crate::service::Service; + +pub(crate) fn status(service: Arc, req: Request) -> (StatusCode, Body) { if req.method() == Method::GET { service.status().map_or_else(|| { (StatusCode::SERVICE_UNAVAILABLE, Body::from("node shutdown in progress")) @@ -26,7 +29,7 @@ pub(crate) fn status(service: Service, req: Request) -> (StatusCode, Body) } } -pub(crate) fn config(service: Service, req: Request) -> (StatusCode, Body) { +pub(crate) fn config(service: Arc, req: Request) -> (StatusCode, Body) { let config = service.local_config(); if req.method() == Method::POST || req.method() == Method::PUT { // TODO: diff config @@ -34,10 +37,10 @@ pub(crate) fn config(service: Service, req: Request) -> (StatusCode, Body) (StatusCode::OK, Body::from(serde_json::to_string(config.as_ref()).unwrap())) } -pub(crate) fn peer(service: Service, req: Request) -> (StatusCode, Body) { +pub(crate) fn peer(service: Arc, req: Request) -> (StatusCode, Body) { (StatusCode::NOT_IMPLEMENTED, Body::from("")) } -pub(crate) fn network(service: Service, req: Request) -> (StatusCode, Body) { +pub(crate) fn network(service: Arc, req: Request) -> (StatusCode, Body) { (StatusCode::NOT_IMPLEMENTED, Body::from("")) } diff --git a/service/src/httpclient.rs b/service/src/httpclient.rs index 91a7e5dde..d63543b21 100644 --- a/service/src/httpclient.rs +++ b/service/src/httpclient.rs @@ -93,10 +93,11 @@ pub(crate) fn run_command< } /// Send a request to the API with support for HTTP digest authentication. -/// The data option is for PUT and POST requests. For GET it is ignored. Errors indicate total -/// failure such as connection refused. A returned result must still have its status checked. -/// Note that if authorization is required and the auth token doesn't work, IncorrectAuthTokenError -/// is returned as an error instead of a 401 response object. +/// The data option is for PUT and POST requests. For GET it is ignored. This will try to +/// authenticate if a WWW-Authorized header is sent in an unauthorized response. If authentication +/// with auth_token fails, IncorrectAuthTokenError is returned as an error. If the request is +/// unauthorizred and no WWW-Authorired header is present, a normal response is returned. The +/// caller must always check the response status code. pub(crate) async fn request(client: &HttpClient, method: Method, uri: Uri, data: Option<&[u8]>, auth_token: &str) -> Result, Box> { let body: Vec = data.map_or_else(|| Vec::new(), |data| data.to_vec()); @@ -113,7 +114,7 @@ pub(crate) async fn request(client: &HttpClient, method: Method, uri: Uri, data: if res.status() == StatusCode::UNAUTHORIZED { let auth = res.headers().get(hyper::header::WWW_AUTHENTICATE); if auth.is_none() { - return Err(Box::new(UnexpectedStatusCodeError(StatusCode::UNAUTHORIZED, "host returned 401 but no WWW-Authenticate header found"))) + return Ok(res); } let auth = auth.unwrap().to_str(); if auth.is_err() { diff --git a/service/src/httplistener.rs b/service/src/httplistener.rs index e0a2ab051..37da33536 100644 --- a/service/src/httplistener.rs +++ b/service/src/httplistener.rs @@ -11,14 +11,16 @@ */ /****/ -use std::cell::RefCell; +use std::cell::Cell; use std::convert::Infallible; +use std::sync::Arc; use std::net::SocketAddr; use hyper::{Body, Request, Response, StatusCode, Method}; use hyper::server::Server; use hyper::service::{make_service_fn, service_fn}; use tokio::task::JoinHandle; +use digest_auth::{AuthContext, AuthorizationHeader, Charset, WwwAuthenticateHeader}; use crate::service::Service; use crate::api; @@ -26,7 +28,6 @@ use crate::utils::{decrypt_http_auth_nonce, ms_since_epoch, create_http_auth_non #[cfg(target_os = "linux")] use std::os::unix::io::AsRawFd; -use digest_auth::{AuthContext, AuthorizationHeader, Charset, WwwAuthenticateHeader}; const HTTP_MAX_NONCE_AGE_MS: i64 = 30000; @@ -35,17 +36,19 @@ const HTTP_MAX_NONCE_AGE_MS: i64 = 30000; /// but it might not shut down instantly as this occurs asynchronously. pub(crate) struct HttpListener { pub address: SocketAddr, - shutdown_tx: RefCell>>, + shutdown_tx: Cell>>, server: JoinHandle>, } -async fn http_handler(service: Service, req: Request) -> Result, Infallible> { +async fn http_handler(service: Arc, req: Request) -> Result, Infallible> { + let req_path = req.uri().path(); + let mut authorized = false; let mut stale = false; let auth_token = service.store().auth_token(false); if auth_token.is_err() { - return Ok::, Infallible>(Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(Body::from("authorization token unreadable")).unwrap()); + return Ok(Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(Body::from("authorization token unreadable")).unwrap()); } let auth_context = AuthContext::new_with_method("", auth_token.unwrap(), req.uri().to_string(), None::<&[u8]>, match *req.method() { Method::GET => digest_auth::HttpMethod::GET, @@ -53,14 +56,16 @@ async fn http_handler(service: Service, req: Request) -> Result digest_auth::HttpMethod::HEAD, Method::PUT => digest_auth::HttpMethod::OTHER("PUT"), Method::DELETE => digest_auth::HttpMethod::OTHER("DELETE"), - _ => digest_auth::HttpMethod::OTHER(""), + _ => { + return Ok(Response::builder().status(StatusCode::METHOD_NOT_ALLOWED).body(Body::from("unrecognized method")).unwrap()); + } }); let auth_header = req.headers().get(hyper::header::AUTHORIZATION); if auth_header.is_some() { let auth_header = AuthorizationHeader::parse(auth_header.unwrap().to_str().unwrap_or("")); if auth_header.is_err() { - return Ok::, Infallible>(Response::builder().status(StatusCode::BAD_REQUEST).body(Body::from(format!("invalid authorization header: {}", auth_header.err().unwrap().to_string()))).unwrap()); + return Ok(Response::builder().status(StatusCode::BAD_REQUEST).body(Body::from(format!("invalid authorization header: {}", auth_header.err().unwrap().to_string()))).unwrap()); } let auth_header = auth_header.unwrap(); @@ -88,11 +93,8 @@ async fn http_handler(service: Service, req: Request) -> Result) -> Result, Infallible>(Response::builder().header("Content-Type", "application/json").status(status).body(body).unwrap()) + Ok(Response::builder().header("Content-Type", "application/json").status(status).body(body).unwrap()) } else { - Ok::, Infallible>(Response::builder().header(hyper::header::WWW_AUTHENTICATE, WwwAuthenticateHeader { + Ok(Response::builder().header(hyper::header::WWW_AUTHENTICATE, WwwAuthenticateHeader { domain: None, realm: "zerotier-service-api".to_owned(), nonce: create_http_auth_nonce(ms_since_epoch()), @@ -126,7 +128,7 @@ async fn http_handler(service: Service, req: Request) -> Result Result> { + pub async fn new(_device_name: &str, address: SocketAddr, service: &Arc) -> Result> { let listener = if address.is_ipv4() { let listener = socket2::Socket::new(socket2::Domain::ipv4(), socket2::Type::stream(), Some(socket2::Protocol::tcp())); if listener.is_err() { @@ -188,7 +190,7 @@ impl HttpListener { Ok(HttpListener { address, - shutdown_tx: RefCell::new(Some(shutdown_tx)), + shutdown_tx: Cell::new(Some(shutdown_tx)), server, }) } diff --git a/service/src/service.rs b/service/src/service.rs index eecff69a8..417d2b92c 100644 --- a/service/src/service.rs +++ b/service/src/service.rs @@ -11,6 +11,7 @@ */ /****/ +use std::cell::Cell; use std::collections::BTreeMap; use std::net::{SocketAddr, Ipv4Addr, IpAddr, Ipv6Addr}; use std::sync::{Arc, Mutex, Weak}; @@ -64,7 +65,10 @@ pub struct ServiceStatus { pub http_local_endpoints: Vec, } -struct ServiceIntl { +/// Core ZeroTier service, which is sort of just a container for all the things. +pub(crate) struct Service { + pub(crate) log: Log, + _node: Cell, Network, Service>>>, // never modified after node is created udp_local_endpoints: Mutex>, http_local_endpoints: Mutex>, interrupt: Mutex>, @@ -76,18 +80,6 @@ struct ServiceIntl { online: AtomicBool, } -unsafe impl Send for ServiceIntl {} - -unsafe impl Sync for ServiceIntl {} - -/// Core ZeroTier service, which is sort of just a container for all the things. -#[derive(Clone)] -pub(crate) struct Service { - pub(crate) log: Arc, - _node: Weak>, - intl: Arc, -} - impl NodeEventHandler for Service { #[inline(always)] fn virtual_network_config(&self, network_id: NetworkId, network_obj: &Network, config_op: VirtualNetworkConfigOperation, config: Option<&VirtualNetworkConfig>) {} @@ -104,17 +96,17 @@ impl NodeEventHandler for Service { Event::Down => { d!(self.log, "node shutdown event received."); - self.intl.online.store(false, Ordering::Relaxed); + self.online.store(false, Ordering::Relaxed); } Event::Online => { d!(self.log, "node is online."); - self.intl.online.store(true, Ordering::Relaxed); + self.online.store(true, Ordering::Relaxed); } Event::Offline => { d!(self.log, "node is offline."); - self.intl.online.store(false, Ordering::Relaxed); + self.online.store(false, Ordering::Relaxed); } Event::Trace => { @@ -144,16 +136,16 @@ impl NodeEventHandler for Service { #[inline(always)] fn state_put(&self, obj_type: StateObjectType, obj_id: &[u64], obj_data: &[u8]) -> std::io::Result<()> { if !obj_data.is_empty() { - self.intl.store.store_object(&obj_type, obj_id, obj_data) + self.store.store_object(&obj_type, obj_id, obj_data) } else { - self.intl.store.erase_object(&obj_type, obj_id); + self.store.erase_object(&obj_type, obj_id); Ok(()) } } #[inline(always)] fn state_get(&self, obj_type: StateObjectType, obj_id: &[u64]) -> std::io::Result> { - self.intl.store.load_object(&obj_type, obj_id) + self.store.load_object(&obj_type, obj_id) } #[inline(always)] @@ -185,33 +177,34 @@ impl NodeEventHandler for Service { impl Service { pub fn local_config(&self) -> Arc { - self.intl.local_config.lock().unwrap().clone() + self.local_config.lock().unwrap().clone() } pub fn set_local_config(&self, new_lc: LocalConfig) { - *(self.intl.local_config.lock().unwrap()) = Arc::new(new_lc); + *(self.local_config.lock().unwrap()) = Arc::new(new_lc); } /// Get the node running with this service. - /// This can return None if we are in the midst of shutdown. In this case - /// whatever operation is in progress should abort. None will never be - /// returned during normal operation. - pub fn node(&self) -> Option>> { - self._node.upgrade() + /// This can return None during shutdown because Service holds a weak + /// reference to Node to avoid circular Arc<> pointers. This will only + /// return None during shutdown, in which case whatever is happening + /// should abort as quietly as possible. + pub fn node(&self) -> Option, Network, Service>>> { + unsafe { &*self._node.as_ptr() }.upgrade() } #[inline(always)] pub fn store(&self) -> &Arc { - &self.intl.store + &self.store } pub fn online(&self) -> bool { - self.intl.online.load(Ordering::Relaxed) + self.online.load(Ordering::Relaxed) } pub fn shutdown(&self) { - self.intl.run.store(false, Ordering::Relaxed); - let _ = self.intl.interrupt.lock().unwrap().try_send(()); + self.run.store(false, Ordering::Relaxed); + let _ = self.interrupt.lock().unwrap().try_send(()); } /// Get service status for API, or None if a shutdown is in progress. @@ -222,8 +215,8 @@ impl Service { object_type: "status".to_owned(), address: node.address(), clock: ms_since_epoch(), - start_time: self.intl.startup_time, - uptime: ms_monotonic() - self.intl.startup_time_monotonic, + start_time: self.startup_time, + uptime: ms_monotonic() - self.startup_time_monotonic, config: (*self.local_config()).clone(), online: self.online(), public_identity: node.identity(), @@ -232,8 +225,8 @@ impl Service { version_minor: ver.1, version_revision: ver.2, version_build: ver.3, - udp_local_endpoints: self.intl.udp_local_endpoints.lock().unwrap().clone(), - http_local_endpoints: self.intl.http_local_endpoints.lock().unwrap().clone(), + udp_local_endpoints: self.udp_local_endpoints.lock().unwrap().clone(), + http_local_endpoints: self.http_local_endpoints.lock().unwrap().clone(), } }) } @@ -243,7 +236,7 @@ unsafe impl Send for Service {} unsafe impl Sync for Service {} -async fn run_async(store: Arc, log: Arc, local_config: Arc) -> i32 { +async fn run_async(store: Arc, local_config: Arc) -> i32 { let mut process_exit_value: i32 = 0; let mut udp_sockets: BTreeMap = BTreeMap::new(); @@ -251,58 +244,64 @@ async fn run_async(store: Arc, log: Arc, local_config: Arc, Option) = (None, None); // 127.0.0.1, ::1 let (interrupt_tx, mut interrupt_rx) = futures::channel::mpsc::channel::<()>(1); - let mut service = Service { - log: log.clone(), - _node: Weak::new(), - intl: Arc::new(ServiceIntl { - udp_local_endpoints: Mutex::new(Vec::new()), - http_local_endpoints: Mutex::new(Vec::new()), - interrupt: Mutex::new(interrupt_tx), - local_config: Mutex::new(local_config), - store: store.clone(), - startup_time: ms_since_epoch(), - startup_time_monotonic: ms_monotonic(), - run: AtomicBool::new(true), - online: AtomicBool::new(false), - }), - }; + let service = Arc::new(Service { + log: Log::new( + if local_config.settings.log.path.as_ref().is_some() { + local_config.settings.log.path.as_ref().unwrap().as_str() + } else { + store.default_log_path.to_str().unwrap() + }, + local_config.settings.log.max_size, + local_config.settings.log.stderr, + local_config.settings.log.debug, + "", + ), + _node: Cell::new(Weak::new()), + udp_local_endpoints: Mutex::new(Vec::new()), + http_local_endpoints: Mutex::new(Vec::new()), + interrupt: Mutex::new(interrupt_tx), + local_config: Mutex::new(local_config), + store: store.clone(), + startup_time: ms_since_epoch(), + startup_time_monotonic: ms_monotonic(), + run: AtomicBool::new(true), + online: AtomicBool::new(false), + }); let node = Node::new(service.clone(), ms_monotonic()); if node.is_err() { - log.fatal(format!("error initializing node: {}", node.err().unwrap().to_str())); + service.log.fatal(format!("error initializing node: {}", node.err().unwrap().to_str())); return 1; } let node = Arc::new(node.ok().unwrap()); - - service._node = Arc::downgrade(&node); - let service = service; // make immutable after setting node + service._node.replace(Arc::downgrade(&node)); let mut local_config = service.local_config(); let mut now: i64 = ms_monotonic(); let mut loop_delay = zerotier_core::NODE_BACKGROUND_TASKS_MAX_INTERVAL; let mut last_checked_config: i64 = 0; - while service.intl.run.load(Ordering::Relaxed) { + while service.run.load(Ordering::Relaxed) { let loop_delay_start = ms_monotonic(); tokio::select! { _ = tokio::time::sleep(Duration::from_millis(loop_delay as u64)) => { now = ms_monotonic(); let actual_delay = now - loop_delay_start; if actual_delay > ((loop_delay as i64) * 4_i64) { - l!(log, "likely sleep/wake detected due to excessive loop delay, cycling links..."); + l!(service.log, "likely sleep/wake detected due to excessive loop delay, cycling links..."); // TODO: handle likely sleep/wake or other system interruption } }, _ = interrupt_rx.next() => { - d!(log, "inner loop delay interrupted!"); - if !service.intl.run.load(Ordering::Relaxed) { + d!(service.log, "inner loop delay interrupted!"); + if !service.run.load(Ordering::Relaxed) { break; } now = ms_monotonic(); }, _ = tokio::signal::ctrl_c() => { - l!(log, "exit signal received, shutting down..."); - service.intl.run.store(false, Ordering::Relaxed); + l!(service.log, "exit signal received, shutting down..."); + service.run.store(false, Ordering::Relaxed); break; }, } @@ -313,7 +312,7 @@ async fn run_async(store: Arc, log: Arc, local_config: Arc, log: Arc, local_config: Arc, log: Arc, local_config: Arc>().iter() { - l!(log, "unbinding UDP socket at {} (address no longer exists on system or port has changed)", k.to_string()); + l!(service.log, "unbinding UDP socket at {} (address no longer exists on system or port has changed)", k.to_string()); udp_sockets.remove(k); bindings_changed = true; } @@ -371,9 +370,9 @@ async fn run_async(store: Arc, log: Arc, local_config: Arc, log: Arc, local_config: Arc>().iter() { - l!(log, "closing HTTP listener at {} (address no longer exists on system or port has changed)", k.to_string()); + l!(service.log, "closing HTTP listener at {} (address no longer exists on system or port has changed)", k.to_string()); http_listeners.remove(k); bindings_changed = true; } @@ -421,9 +420,9 @@ async fn run_async(store: Arc, log: Arc, local_config: Arc, log: Arc, local_config: Arc, log: Arc, local_config: Arc, log: Arc, local_config: Arc, log: Arc, local_config: Arc) -> i32 { let local_config = Arc::new(store.read_local_conf_or_default()); - let log = Arc::new(Log::new( - if local_config.settings.log.path.as_ref().is_some() { - local_config.settings.log.path.as_ref().unwrap().as_str() - } else { - store.default_log_path.to_str().unwrap() - }, - local_config.settings.log.max_size, - local_config.settings.log.stderr, - local_config.settings.log.debug, - "", - )); - if store.auth_token(true).is_err() { eprintln!("FATAL: error writing new web API authorization token (likely permission problem)."); return 1; @@ -521,7 +506,7 @@ pub(crate) fn run(store: Arc) -> i32 { let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap(); let store2 = store.clone(); - let process_exit_value = rt.block_on(async move { run_async(store2, log, local_config).await }); + let process_exit_value = rt.block_on(async move { run_async(store2, local_config).await }); rt.shutdown_timeout(Duration::from_millis(500)); store.erase_pid();