diff --git a/rust-zerotier-core/src/address.rs b/rust-zerotier-core/src/address.rs index dd3b37e81..dc0cdac71 100644 --- a/rust-zerotier-core/src/address.rs +++ b/rust-zerotier-core/src/address.rs @@ -32,6 +32,15 @@ impl From<&str> for Address { } } +impl PartialEq for Address { + #[inline(always)] + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for Address {} + impl serde::Serialize for Address { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { serializer.serialize_str(self.to_string().as_str()) diff --git a/rust-zerotier-core/src/certificate.rs b/rust-zerotier-core/src/certificate.rs index 29da2f4f2..2b3265134 100644 --- a/rust-zerotier-core/src/certificate.rs +++ b/rust-zerotier-core/src/certificate.rs @@ -25,7 +25,6 @@ use serde::{Deserialize, Serialize}; use crate::*; use crate::bindings::capi as ztcore; -use crate::bindings::capi::ZT_CertificateError; /// Maximum length of a string in a certificate (mostly for the certificate name fields). pub const CERTIFICATE_MAX_STRING_LENGTH: isize = ztcore::ZT_CERTIFICATE_MAX_STRING_LENGTH as isize; @@ -44,6 +43,7 @@ pub const CERTIFICATE_UNIQUE_ID_TYPE_NIST_P_384_PRIVATE_SIZE: u32 = ztcore::ZT_C ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +#[derive(PartialEq, Eq)] pub struct CertificateSerialNo(pub [u8; 48]); impl CertificateSerialNo { @@ -128,7 +128,7 @@ impl<'de> serde::Deserialize<'de> for CertificateSerialNo { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// /// Type of certificate subject unique ID -#[derive(FromPrimitive, ToPrimitive)] +#[derive(FromPrimitive, ToPrimitive, PartialEq, Eq)] pub enum CertificateUniqueIdType { NistP384 = ztcore::ZT_CertificateUniqueIdType_ZT_CERTIFICATE_UNIQUE_ID_TYPE_NIST_P_384 as isize } @@ -182,7 +182,7 @@ impl<'de> serde::Deserialize<'de> for CertificateUniqueIdType { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, PartialEq, Eq)] pub struct CertificateSubjectUniqueIdSecret { pub public: Vec, pub private: Vec, @@ -293,7 +293,7 @@ impl<'de> serde::Deserialize<'de> for CertificateError { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, PartialEq, Eq)] pub struct CertificateName { #[serde(rename = "serialNo")] pub serial_no: String, @@ -383,7 +383,7 @@ impl CertificateName { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, PartialEq, Eq)] pub struct CertificateNetwork { pub id: NetworkId, pub controller: Fingerprint, @@ -413,7 +413,7 @@ impl CertificateNetwork { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, PartialEq, Eq)] pub struct CertificateIdentity { pub identity: Identity, pub locator: Option, @@ -440,7 +440,7 @@ impl CertificateIdentity { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, PartialEq, Eq)] pub struct CertificateSubject { pub timestamp: i64, pub identities: Vec, @@ -614,7 +614,7 @@ impl CertificateSubject { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, PartialEq, Eq)] pub struct Certificate { #[serde(rename = "serialNo")] pub serial_no: CertificateSerialNo, @@ -745,6 +745,8 @@ impl Certificate { } } +implement_to_from_json!(Certificate); + ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// #[cfg(test)] @@ -765,9 +767,11 @@ mod tests { println!("certificate unique ID private: {}", hex::encode(uid.private).as_str()); } - /* #[test] - fn cert_encoding_decoding() { + fn cert_encode_decode() { + let uid = CertificateSubjectUniqueIdSecret::new(CertificateUniqueIdType::NistP384); + let id0 = Identity::new_generate(IdentityType::NistP384).ok().unwrap(); + let mut cert = Certificate{ serial_no: CertificateSerialNo::new(), flags: 1, @@ -780,8 +784,46 @@ mod tests { max_path_length: 123, signature: Vec::new() }; - cert.serial_no.0[1] = 2; + cert.serial_no.0[1] = 99; + cert.subject.timestamp = 5; + cert.subject.identities.push(CertificateIdentity{ + identity: id0.clone(), + locator: None + }); + cert.subject.networks.push(CertificateNetwork{ + id: NetworkId(0xdeadbeef), + controller: id0.fingerprint() + }); cert.subject.certificates.push(CertificateSerialNo::new()); + cert.subject.update_urls.push(String::from("http://foo.bar")); + cert.subject.name = CertificateName{ + serial_no: String::from("12345"), + common_name: String::from("foo"), + country: String::from("bar"), + organization: String::from("baz"), + unit: String::from("asdf"), + locality: String::from("qwerty"), + province: String::from("province"), + street_address: String::from("street address"), + postal_code: String::from("postal code"), + email: String::from("nobody@nowhere.org"), + url: String::from("https://www.zerotier.com/"), + host: String::from("zerotier.com") + }; + + //println!("{}", cert.to_json().as_str()); + + unsafe { + let cert_capi = cert.to_capi(); + let cert2 = Certificate::new_from_capi(&cert_capi.certificate); + assert!(cert == cert2); + //println!("{}", cert2.to_json().as_str()); + } + + { + let cert2 = Certificate::new_from_json(cert.to_json().as_str()); + assert!(cert2.is_ok()); + assert!(cert2.ok().unwrap() == cert); + } } - */ } diff --git a/rust-zerotier-core/src/endpoint.rs b/rust-zerotier-core/src/endpoint.rs index 8d3e3fdf5..2fafbe141 100644 --- a/rust-zerotier-core/src/endpoint.rs +++ b/rust-zerotier-core/src/endpoint.rs @@ -92,6 +92,14 @@ impl ToString for Endpoint { } } +impl PartialEq for Endpoint { + fn eq(&self, other: &Endpoint) -> bool { + self.to_string() == other.to_string() + } +} + +impl Eq for Endpoint {} + impl serde::Serialize for Endpoint { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { serializer.serialize_str(self.to_string().as_str()) diff --git a/rust-zerotier-core/src/fingerprint.rs b/rust-zerotier-core/src/fingerprint.rs index 747f4bbec..18f8de4b4 100644 --- a/rust-zerotier-core/src/fingerprint.rs +++ b/rust-zerotier-core/src/fingerprint.rs @@ -18,6 +18,7 @@ use std::os::raw::{c_char, c_int}; use crate::*; use crate::bindings::capi as ztcore; +#[derive(PartialEq, Eq)] pub struct Fingerprint { pub address: Address, pub hash: [u8; 48] diff --git a/rust-zerotier-core/src/identity.rs b/rust-zerotier-core/src/identity.rs index 36c6475a0..f246cf871 100644 --- a/rust-zerotier-core/src/identity.rs +++ b/rust-zerotier-core/src/identity.rs @@ -87,7 +87,7 @@ impl Identity { /// Convert to a string and include the private key if present. /// If the private key is not present this is the same as to_string(). - #[inline] + #[inline(always)] pub fn to_secret_string(&self) -> String { self.intl_to_string(true) } @@ -150,6 +150,14 @@ impl Identity { } } +impl PartialEq for Identity { + fn eq(&self, other: &Self) -> bool { + self.intl_to_string(false) == other.intl_to_string(false) + } +} + +impl Eq for Identity {} + impl Clone for Identity { fn clone(&self) -> Identity { unsafe { @@ -169,7 +177,7 @@ impl Drop for Identity { } impl ToString for Identity { - #[inline] + #[inline(always)] fn to_string(&self) -> String { self.intl_to_string(false) } @@ -208,12 +216,14 @@ impl<'de> serde::Deserialize<'de> for Identity { #[cfg(test)] mod tests { use crate::*; + use crate::StateObjectType::IdentitySecret; #[test] fn identity() { let test1 = Identity::new_generate(IdentityType::Curve25519); assert!(test1.is_ok()); let test1 = test1.ok().unwrap(); + assert!(test1.has_private()); let test2 = Identity::new_generate(IdentityType::NistP384); assert!(test2.is_ok()); @@ -221,5 +231,40 @@ mod tests { println!("test type 0: {}", test1.to_secret_string()); println!("test type 1: {}", test2.to_secret_string()); + + assert!(test1.clone() == test1); + + let test12 = Identity::new_from_string(test1.to_string().as_str()); + assert!(test12.is_ok()); + let test12 = test12.ok().unwrap(); + assert!(!test12.has_private()); + let test22 = Identity::new_from_string(test2.to_string().as_str()); + assert!(test22.is_ok()); + let test22 = test22.ok().unwrap(); + assert!(test1 == test12); + assert!(test2 == test22); + + println!("test type 0, from string: {}", test12.to_string()); + println!("test type 1, from string: {}", test22.to_string()); + + let from_str_fail = Identity::new_from_string("asdf:foo:invalid"); + assert!(from_str_fail.is_err()); + + let mut to_sign: [u8; 4] = [ 1,2,3,4 ]; + + let signed = test1.sign(&to_sign); + assert!(signed.is_ok()); + let signed = signed.ok().unwrap(); + assert!(test1.verify(&to_sign, signed.as_ref())); + to_sign[0] = 2; + assert!(!test1.verify(&to_sign, signed.as_ref())); + to_sign[0] = 1; + + let signed = test2.sign(&to_sign); + assert!(signed.is_ok()); + let signed = signed.ok().unwrap(); + assert!(test2.verify(&to_sign, signed.as_ref())); + to_sign[0] = 2; + assert!(!test2.verify(&to_sign, signed.as_ref())); } } diff --git a/rust-zerotier-core/src/lib.rs b/rust-zerotier-core/src/lib.rs index eedc104b8..690e07660 100644 --- a/rust-zerotier-core/src/lib.rs +++ b/rust-zerotier-core/src/lib.rs @@ -183,7 +183,6 @@ pub unsafe fn cstr_to_string(cstr: *const c_char, max_len: isize) -> String { String::new() } -/* #[macro_export(crate)] macro_rules! implement_to_from_json { ($struct_name:ident) => { @@ -206,4 +205,3 @@ macro_rules! implement_to_from_json { } }; } -*/ diff --git a/rust-zerotier-core/src/locator.rs b/rust-zerotier-core/src/locator.rs index 1ab47b51b..912d2279e 100644 --- a/rust-zerotier-core/src/locator.rs +++ b/rust-zerotier-core/src/locator.rs @@ -96,6 +96,14 @@ impl ToString for Locator { } } +impl PartialEq for Locator { + fn eq(&self, other: &Locator) -> bool { + self.to_string() == other.to_string() + } +} + +impl Eq for Locator {} + impl serde::Serialize for Locator { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { serializer.serialize_str(self.to_string().as_str()) diff --git a/rust-zerotier-core/src/mac.rs b/rust-zerotier-core/src/mac.rs index 4ba140131..e712db1bb 100644 --- a/rust-zerotier-core/src/mac.rs +++ b/rust-zerotier-core/src/mac.rs @@ -38,6 +38,15 @@ impl serde::Serialize for MAC { } } +impl PartialEq for MAC { + #[inline(always)] + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for MAC {} + struct AddressVisitor; impl<'de> serde::de::Visitor<'de> for AddressVisitor { diff --git a/rust-zerotier-core/src/networkid.rs b/rust-zerotier-core/src/networkid.rs index 330ffdb57..e66c20397 100644 --- a/rust-zerotier-core/src/networkid.rs +++ b/rust-zerotier-core/src/networkid.rs @@ -41,6 +41,15 @@ impl From<&str> for NetworkId { } } +impl PartialEq for NetworkId { + #[inline(always)] + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for NetworkId {} + impl serde::Serialize for NetworkId { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { serializer.serialize_str(self.to_string().as_str())