Make database in controller dynamic to prep for having multiple implementations.

This commit is contained in:
Adam Ierymenko 2022-10-23 20:13:03 -04:00
parent 862710f553
commit bc5861c539
No known key found for this signature in database
GPG key ID: C8877CF2D7A5D7F3
8 changed files with 97 additions and 84 deletions

View file

@ -1,4 +1,5 @@
use async_trait::async_trait;
use std::error::Error;
use zerotier_network_hypervisor::vl1::{Address, InetAddress, NodeStorage};
use zerotier_network_hypervisor::vl2::NetworkId;
@ -17,14 +18,12 @@ pub enum Change {
#[async_trait]
pub trait Database: Sync + Send + NodeStorage + 'static {
type Error: std::error::Error + Send + 'static;
async fn get_network(&self, id: NetworkId) -> Result<Option<Network>, Box<dyn Error + Send + Sync>>;
async fn save_network(&self, obj: Network) -> Result<(), Box<dyn Error + Send + Sync>>;
async fn get_network(&self, id: NetworkId) -> Result<Option<Network>, Self::Error>;
async fn save_network(&self, obj: Network) -> Result<(), Self::Error>;
async fn list_members(&self, network_id: NetworkId) -> Result<Vec<Address>, Self::Error>;
async fn get_member(&self, network_id: NetworkId, node_id: Address) -> Result<Option<Member>, Self::Error>;
async fn save_member(&self, obj: Member) -> Result<(), Self::Error>;
async fn list_members(&self, network_id: NetworkId) -> Result<Vec<Address>, Box<dyn Error + Send + Sync>>;
async fn get_member(&self, network_id: NetworkId, node_id: Address) -> Result<Option<Member>, Box<dyn Error + Send + Sync>>;
async fn save_member(&self, obj: Member) -> Result<(), Box<dyn Error + Send + Sync>>;
/// Get a receiver that can be used to receive changes made to networks and members, if supported.
///
@ -45,7 +44,11 @@ pub trait Database: Sync + Send + NodeStorage + 'static {
///
/// The default trait implementation uses a brute force method. This should be reimplemented if a
/// more efficient way is available.
async fn list_members_deauthorized_after(&self, network_id: NetworkId, cutoff: i64) -> Result<Vec<Address>, Self::Error> {
async fn list_members_deauthorized_after(
&self,
network_id: NetworkId,
cutoff: i64,
) -> Result<Vec<Address>, Box<dyn Error + Send + Sync>> {
let mut v = Vec::new();
let members = self.list_members(network_id).await?;
for a in members.iter() {
@ -62,7 +65,7 @@ pub trait Database: Sync + Send + NodeStorage + 'static {
///
/// The default trait implementation uses a brute force method. This should be reimplemented if a
/// more efficient way is available.
async fn is_ip_assigned(&self, network_id: NetworkId, ip: &InetAddress) -> Result<bool, Self::Error> {
async fn is_ip_assigned(&self, network_id: NetworkId, ip: &InetAddress) -> Result<bool, Box<dyn Error + Send + Sync>> {
let members = self.list_members(network_id).await?;
for a in members.iter() {
if let Some(m) = self.get_member(network_id, *a).await? {
@ -74,5 +77,5 @@ pub trait Database: Sync + Send + NodeStorage + 'static {
return Ok(false);
}
async fn log_request(&self, obj: RequestLogItem) -> Result<(), Self::Error>;
async fn log_request(&self, obj: RequestLogItem) -> Result<(), Box<dyn Error + Send + Sync>>;
}

View file

@ -1,5 +1,4 @@
use std::error::Error;
use std::fmt::Display;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::atomic::{AtomicU64, Ordering};
@ -21,35 +20,6 @@ use crate::model::*;
const IDENTITY_SECRET_FILENAME: &'static str = "identity.secret";
#[derive(Debug)]
pub enum FileDatabaseError {
InvalidYaml(String),
IoError(String),
}
impl Display for FileDatabaseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::InvalidYaml(e) => f.write_str(format!("invalid YAML ({})", e).as_str()),
Self::IoError(e) => f.write_str(format!("I/O error ({})", e).as_str()),
}
}
}
impl Error for FileDatabaseError {}
impl From<serde_yaml::Error> for FileDatabaseError {
fn from(e: serde_yaml::Error) -> Self {
Self::InvalidYaml(e.to_string())
}
}
impl From<zerotier_utils::tokio::io::Error> for FileDatabaseError {
fn from(e: zerotier_utils::tokio::io::Error) -> Self {
Self::IoError(e.to_string())
}
}
/// An in-filesystem database that permits live editing.
///
/// A cache is maintained that contains the actual objects. When an object is live edited,
@ -65,7 +35,7 @@ pub struct FileDatabase {
// TODO: should cache at least hashes and detect changes in the filesystem live.
impl FileDatabase {
pub async fn new<P: AsRef<Path>>(base_path: P) -> Result<Self, Box<dyn Error>> {
pub async fn new<P: AsRef<Path>>(base_path: P) -> Result<Self, Box<dyn Error + Send + Sync>> {
let base_path: PathBuf = base_path.as_ref().into();
let _ = fs::create_dir_all(&base_path).await?;
@ -86,6 +56,11 @@ impl FileDatabase {
}
}
})?);
let _ = watcher.configure(
notify::Config::default()
.with_compare_contents(true)
.with_poll_interval(std::time::Duration::from_secs(2)),
);
watcher.watch(&base_path, RecursiveMode::Recursive)?;
Ok(Self {
@ -151,9 +126,7 @@ impl NodeStorage for FileDatabase {
#[async_trait]
impl Database for FileDatabase {
type Error = FileDatabaseError;
async fn get_network(&self, id: NetworkId) -> Result<Option<Network>, Self::Error> {
async fn get_network(&self, id: NetworkId) -> Result<Option<Network>, Box<dyn Error + Send + Sync>> {
let r = fs::read(self.network_path(id)).await;
if let Ok(raw) = r {
let mut network = serde_yaml::from_slice::<Network>(raw.as_slice())?;
@ -166,7 +139,7 @@ impl Database for FileDatabase {
}
}
async fn save_network(&self, obj: Network) -> Result<(), Self::Error> {
async fn save_network(&self, obj: Network) -> Result<(), Box<dyn Error + Send + Sync>> {
let base_network_path = self.network_path(obj.id);
let _ = fs::create_dir_all(base_network_path.parent().unwrap()).await;
//let _ = fs::write(base_network_path, to_json_pretty(&obj).as_bytes()).await?;
@ -174,7 +147,7 @@ impl Database for FileDatabase {
return Ok(());
}
async fn list_members(&self, network_id: NetworkId) -> Result<Vec<Address>, Self::Error> {
async fn list_members(&self, network_id: NetworkId) -> Result<Vec<Address>, Box<dyn Error + Send + Sync>> {
let mut members = Vec::new();
let mut dir = fs::read_dir(self.base_path.join(network_id.to_string())).await?;
while let Ok(Some(ent)) = dir.next_entry().await {
@ -194,7 +167,7 @@ impl Database for FileDatabase {
Ok(members)
}
async fn get_member(&self, network_id: NetworkId, node_id: Address) -> Result<Option<Member>, Self::Error> {
async fn get_member(&self, network_id: NetworkId, node_id: Address) -> Result<Option<Member>, Box<dyn Error + Send + Sync>> {
let r = fs::read(self.member_path(network_id, node_id)).await;
if let Ok(raw) = r {
let mut member = serde_yaml::from_slice::<Member>(raw.as_slice())?;
@ -207,7 +180,7 @@ impl Database for FileDatabase {
}
}
async fn save_member(&self, obj: Member) -> Result<(), Self::Error> {
async fn save_member(&self, obj: Member) -> Result<(), Box<dyn Error + Send + Sync>> {
let base_member_path = self.member_path(obj.network_id, obj.node_id);
let _ = fs::create_dir_all(base_member_path.parent().unwrap()).await;
//let _ = fs::write(base_member_path, to_json_pretty(&obj).as_bytes()).await?;
@ -219,7 +192,7 @@ impl Database for FileDatabase {
Some(self.change_sender.subscribe())
}
async fn log_request(&self, obj: RequestLogItem) -> Result<(), Self::Error> {
async fn log_request(&self, obj: RequestLogItem) -> Result<(), Box<dyn Error + Send + Sync>> {
println!("{}", obj.to_string());
Ok(())
}

View file

@ -25,26 +25,26 @@ use crate::model::{AuthorizationResult, Member, RequestLogItem, CREDENTIAL_WINDO
const REQUEST_TIMEOUT: Duration = Duration::from_secs(10);
/// ZeroTier VL2 network controller packet handler, answers VL2 netconf queries.
pub struct Handler<DatabaseImpl: Database> {
inner: Arc<Inner<DatabaseImpl>>,
pub struct Handler {
inner: Arc<Inner>,
}
struct Inner<DatabaseImpl: Database> {
service: RwLock<Weak<VL1Service<DatabaseImpl, Handler<DatabaseImpl>, Handler<DatabaseImpl>>>>,
struct Inner {
service: RwLock<Weak<VL1Service<dyn Database, Handler, Handler>>>,
reaper: Reaper,
daemons: Mutex<Vec<tokio::task::JoinHandle<()>>>, // drop() aborts these
runtime: tokio::runtime::Handle,
database: Arc<DatabaseImpl>,
database: Arc<dyn Database>,
local_identity: Identity,
}
impl<DatabaseImpl: Database> Handler<DatabaseImpl> {
impl Handler {
/// Start an inner protocol handler answer ZeroTier VL2 network controller queries.
pub async fn new(database: Arc<DatabaseImpl>, runtime: tokio::runtime::Handle) -> Result<Arc<Self>, Box<dyn Error>> {
pub async fn new(database: Arc<dyn Database>, runtime: tokio::runtime::Handle) -> Result<Arc<Self>, Box<dyn Error>> {
if let Some(local_identity) = database.load_node_identity() {
assert!(local_identity.secret.is_some());
let inner = Arc::new(Inner::<DatabaseImpl> {
let inner = Arc::new(Inner {
service: RwLock::new(Weak::default()),
reaper: Reaper::new(&runtime),
daemons: Mutex::new(Vec::with_capacity(1)),
@ -69,7 +69,7 @@ impl<DatabaseImpl: Database> Handler<DatabaseImpl> {
/// won't actually do anything. The reference the handler holds is weak to prevent
/// a circular reference, so if the VL1Service is dropped this must be called again to
/// tell the controller handler about a new instance.
pub fn set_service(&self, service: &Arc<VL1Service<DatabaseImpl, Self, Self>>) {
pub fn set_service(&self, service: &Arc<VL1Service<dyn Database, Handler, Handler>>) {
*self.inner.service.write().unwrap() = Arc::downgrade(service);
}
@ -98,9 +98,9 @@ impl<DatabaseImpl: Database> Handler<DatabaseImpl> {
}
// Default PathFilter implementations permit anything.
impl<DatabaseImpl: Database> PathFilter for Handler<DatabaseImpl> {}
impl PathFilter for Handler {}
impl<DatabaseImpl: Database> InnerProtocol for Handler<DatabaseImpl> {
impl InnerProtocol for Handler {
fn handle_packet<HostSystemImpl: HostSystem + ?Sized>(
&self,
_host_system: &HostSystemImpl,
@ -245,7 +245,7 @@ impl<DatabaseImpl: Database> InnerProtocol for Handler<DatabaseImpl> {
}
}
impl<DatabaseImpl: Database> Inner<DatabaseImpl> {
impl Inner {
fn send_network_config(
&self,
peer: &Peer,
@ -299,14 +299,14 @@ impl<DatabaseImpl: Database> Inner<DatabaseImpl> {
}
}
async fn handle_change_notification(self: Arc<Self>, change: Change) {}
async fn handle_change_notification(self: Arc<Self>, _change: Change) {}
async fn handle_network_config_request<HostSystemImpl: HostSystem + ?Sized>(
self: &Arc<Self>,
source_identity: &Identity,
network_id: NetworkId,
now: i64,
) -> Result<(AuthorizationResult, Option<NetworkConfig>), DatabaseImpl::Error> {
) -> Result<(AuthorizationResult, Option<NetworkConfig>), Box<dyn Error + Send + Sync>> {
let network = self.database.get_network(network_id).await?;
if network.is_none() {
// TODO: send error
@ -424,7 +424,7 @@ impl<DatabaseImpl: Database> Inner<DatabaseImpl> {
}
}
impl<DatabaseImpl: Database> Drop for Inner<DatabaseImpl> {
impl Drop for Inner {
fn drop(&mut self) {
for h in self.daemons.lock().unwrap().drain(..) {
h.abort();

View file

@ -15,7 +15,7 @@ use zerotier_utils::exitcode;
use zerotier_utils::tokio::runtime::Runtime;
use zerotier_vl1_service::VL1Service;
async fn run<DatabaseImpl: Database>(database: Arc<DatabaseImpl>, runtime: &Runtime) -> i32 {
async fn run(database: Arc<dyn Database>, runtime: &Runtime) -> i32 {
let handler = Handler::new(database.clone(), runtime.handle().clone()).await;
if handler.is_err() {
eprintln!("FATAL: error initializing handler: {}", handler.err().unwrap().to_string());

View file

@ -141,7 +141,7 @@ fn troo() -> bool {
impl Network {
/// Check member IP assignments and return 'true' if IP assignments were created or modified.
pub async fn check_zt_ip_assignments<DatabaseImpl: Database>(&self, database: &DatabaseImpl, member: &mut Member) -> bool {
pub async fn check_zt_ip_assignments<DatabaseImpl: Database + ?Sized>(&self, database: &DatabaseImpl, member: &mut Member) -> bool {
let mut modified = false;
if self.v4_assign_mode.zt {

View file

@ -32,6 +32,12 @@ use zerotier_utils::thing::Thing;
/// These methods are basically callbacks that the core calls to request or transmit things. They are called
/// during calls to things like wire_recieve() and do_background_tasks().
pub trait HostSystem: Sync + Send + 'static {
/// Type for implementation of NodeStorage.
type Storage: NodeStorage + ?Sized;
/// Path filter implementation for this host.
type PathFilter: PathFilter + ?Sized;
/// Type for local system sockets.
type LocalSocket: Sync + Send + Hash + PartialEq + Eq + Clone + ToString + Sized + 'static;
@ -41,6 +47,12 @@ pub trait HostSystem: Sync + Send + 'static {
/// A VL1 level event occurred.
fn event(&self, event: Event);
/// Get a reference to the local storage implementation at this host.
fn storage(&self) -> &Self::Storage;
/// Get the path filter implementation for this host.
fn path_filter(&self) -> &Self::PathFilter;
/// Get a pooled packet buffer for internal use.
fn get_buffer(&self) -> PooledPacketBuffer;
@ -265,21 +277,20 @@ pub struct Node {
}
impl Node {
pub fn new<HostSystemImpl: HostSystem + ?Sized, NodeStorageImpl: NodeStorage + ?Sized>(
pub fn new<HostSystemImpl: HostSystem + ?Sized>(
host_system: &HostSystemImpl,
storage: &NodeStorageImpl,
auto_generate_identity: bool,
auto_upgrade_identity: bool,
) -> Result<Self, InvalidParameterError> {
let mut id = {
let id = storage.load_node_identity();
let id = host_system.storage().load_node_identity();
if id.is_none() {
if !auto_generate_identity {
return Err(InvalidParameterError("no identity found and auto-generate not enabled"));
} else {
let id = Identity::generate();
host_system.event(Event::IdentityAutoGenerated(id.clone()));
storage.save_node_identity(&id);
host_system.storage().save_node_identity(&id);
id
}
} else {
@ -290,7 +301,7 @@ impl Node {
if auto_upgrade_identity {
let old = id.clone();
if id.upgrade()? {
storage.save_node_identity(&id);
host_system.storage().save_node_identity(&id);
host_system.event(Event::IdentityAutoUpgraded(old, id.clone()));
}
}

View file

@ -138,6 +138,8 @@ impl Peer {
fn learn_path<HostSystemImpl: HostSystem + ?Sized>(&self, host_system: &HostSystemImpl, new_path: &Arc<Path>, time_ticks: i64) {
let mut paths = self.paths.lock().unwrap();
// TODO: check path filter
match &new_path.endpoint {
Endpoint::IpUdp(new_ip) => {
// If this is an IpUdp endpoint, scan the existing paths and replace any that come from

View file

@ -27,9 +27,9 @@ const UPDATE_UDP_BINDINGS_EVERY_SECS: usize = 10;
/// whatever inner protocol implementation is using it. This would typically be VL2 but could be
/// a test harness or just the controller for a controller that runs stand-alone.
pub struct VL1Service<
NodeStorageImpl: NodeStorage + 'static,
PathFilterImpl: PathFilter + 'static,
InnerProtocolImpl: InnerProtocol + 'static,
NodeStorageImpl: NodeStorage + ?Sized + 'static,
PathFilterImpl: PathFilter + ?Sized + 'static,
InnerProtocolImpl: InnerProtocol + ?Sized + 'static,
> {
state: RwLock<VL1ServiceMutableState>,
storage: Arc<NodeStorageImpl>,
@ -46,8 +46,11 @@ struct VL1ServiceMutableState {
running: bool,
}
impl<NodeStorageImpl: NodeStorage + 'static, PathFilterImpl: PathFilter + 'static, InnerProtocolImpl: InnerProtocol + 'static>
VL1Service<NodeStorageImpl, PathFilterImpl, InnerProtocolImpl>
impl<
NodeStorageImpl: NodeStorage + ?Sized + 'static,
PathFilterImpl: PathFilter + ?Sized + 'static,
InnerProtocolImpl: InnerProtocol + ?Sized + 'static,
> VL1Service<NodeStorageImpl, PathFilterImpl, InnerProtocolImpl>
{
pub fn new(
storage: Arc<NodeStorageImpl>,
@ -55,7 +58,7 @@ impl<NodeStorageImpl: NodeStorage + 'static, PathFilterImpl: PathFilter + 'stati
path_filter: Arc<PathFilterImpl>,
settings: VL1Settings,
) -> Result<Arc<Self>, Box<dyn Error>> {
let mut service = VL1Service {
let mut service = Self {
state: RwLock::new(VL1ServiceMutableState {
daemons: Vec::with_capacity(2),
udp_sockets: HashMap::with_capacity(8),
@ -72,7 +75,7 @@ impl<NodeStorageImpl: NodeStorage + 'static, PathFilterImpl: PathFilter + 'stati
node_container: None,
};
service.node_container.replace(Node::new(&service, &*service.storage, true, false)?);
service.node_container.replace(Node::new(&service, true, false)?);
let service = Arc::new(service);
let mut daemons = Vec::new();
@ -186,8 +189,11 @@ impl<NodeStorageImpl: NodeStorage + 'static, PathFilterImpl: PathFilter + 'stati
}
}
impl<NodeStorageImpl: NodeStorage, PathFilterImpl: PathFilter, InnerProtocolImpl: InnerProtocol> UdpPacketHandler
for VL1Service<NodeStorageImpl, PathFilterImpl, InnerProtocolImpl>
impl<
NodeStorageImpl: NodeStorage + ?Sized + 'static,
PathFilterImpl: PathFilter + ?Sized + 'static,
InnerProtocolImpl: InnerProtocol + ?Sized + 'static,
> UdpPacketHandler for VL1Service<NodeStorageImpl, PathFilterImpl, InnerProtocolImpl>
{
#[inline(always)]
fn incoming_udp_packet(
@ -209,9 +215,14 @@ impl<NodeStorageImpl: NodeStorage, PathFilterImpl: PathFilter, InnerProtocolImpl
}
}
impl<NodeStorageImpl: NodeStorage, PathFilterImpl: PathFilter, InnerProtocolImpl: InnerProtocol> HostSystem
for VL1Service<NodeStorageImpl, PathFilterImpl, InnerProtocolImpl>
impl<
NodeStorageImpl: NodeStorage + ?Sized + 'static,
PathFilterImpl: PathFilter + ?Sized + 'static,
InnerProtocolImpl: InnerProtocol + ?Sized + 'static,
> HostSystem for VL1Service<NodeStorageImpl, PathFilterImpl, InnerProtocolImpl>
{
type Storage = NodeStorageImpl;
type PathFilter = PathFilterImpl;
type LocalSocket = crate::LocalSocket;
type LocalInterface = crate::LocalInterface;
@ -227,6 +238,16 @@ impl<NodeStorageImpl: NodeStorage, PathFilterImpl: PathFilter, InnerProtocolImpl
socket.is_valid()
}
#[inline(always)]
fn storage(&self) -> &Self::Storage {
self.storage.as_ref()
}
#[inline(always)]
fn path_filter(&self) -> &Self::PathFilter {
self.path_filter.as_ref()
}
#[inline]
fn get_buffer(&self) -> zerotier_network_hypervisor::protocol::PooledPacketBuffer {
self.buffer_pool.get()
@ -306,8 +327,11 @@ impl<NodeStorageImpl: NodeStorage, PathFilterImpl: PathFilter, InnerProtocolImpl
}
}
impl<NodeStorageImpl: NodeStorage, PathFilterImpl: PathFilter, InnerProtocolImpl: InnerProtocol> Drop
for VL1Service<NodeStorageImpl, PathFilterImpl, InnerProtocolImpl>
impl<
NodeStorageImpl: NodeStorage + ?Sized + 'static,
PathFilterImpl: PathFilter + ?Sized + 'static,
InnerProtocolImpl: InnerProtocol + ?Sized + 'static,
> Drop for VL1Service<NodeStorageImpl, PathFilterImpl, InnerProtocolImpl>
{
fn drop(&mut self) {
let mut state = self.state.write().unwrap();