From 1fc4dce835627d3caf96959bae9768db03d54635 Mon Sep 17 00:00:00 2001 From: Adam Ierymenko Date: Fri, 31 Jul 2020 13:27:27 -0700 Subject: [PATCH] A bunch of cleanup, make Location serialization format saner, reduce core memory use. --- cmd/zerotier/cli/cert.go | 67 ++++++++---- cmd/zerotier/cli/help.go | 13 +-- cmd/zerotier/cli/misc.go | 8 ++ cmd/zerotier/zerotier.go | 10 +- core/Dictionary.hpp | 2 +- core/Expect.hpp | 12 ++- core/Locator.cpp | 35 ++++--- core/Locator.hpp | 44 +++++--- core/Meter.hpp | 5 +- core/Peer.cpp | 2 +- core/Tests.cpp | 4 +- core/zerotier.h | 6 +- pkg/zerotier/api.go | 4 +- pkg/zerotier/certificate.go | 5 +- pkg/zerotier/locator.go | 28 ++--- pkg/zerotier/node.go | 201 ++++++++++++++++-------------------- 16 files changed, 251 insertions(+), 195 deletions(-) diff --git a/cmd/zerotier/cli/cert.go b/cmd/zerotier/cli/cert.go index 84c80b96a..67af8abdb 100644 --- a/cmd/zerotier/cli/cert.go +++ b/cmd/zerotier/cli/cert.go @@ -17,11 +17,9 @@ import ( "encoding/json" "fmt" "io/ioutil" - "os" "zerotier/pkg/zerotier" ) - func Cert(basePath, authToken string, args []string, jsonOutput bool) int { if len(args) < 1 { Help() @@ -52,28 +50,18 @@ func Cert(basePath, authToken string, args []string, jsonOutput bool) int { } case "newcsr": - if len(args) < 3 { + if len(args) != 4 { Help() return 1 } var cs zerotier.CertificateSubject - csb, err := ioutil.ReadFile(args[1]) - if err != nil { - fmt.Printf("ERROR: unable to read subject from %s: %s\n", args[1], err.Error()) - return 1 - } - err = json.Unmarshal(csb, &cs) + err := readJSONFile(args[1], &cs) if err != nil { fmt.Printf("ERROR: unable to read subject from %s: %s\n", args[1], err.Error()) return 1 } var subj zerotier.CertificateSubjectUniqueIDSecret - subjb, err := ioutil.ReadFile(args[2]) - if err != nil { - fmt.Printf("ERROR: unable to read unique ID secret from %s: %s\n", args[2], err.Error()) - return 1 - } - err = json.Unmarshal(subjb, &subj) + err = readJSONFile(args[2], &subj) if err != nil { fmt.Printf("ERROR: unable to read unique ID secret from %s: %s\n", args[2], err.Error()) return 1 @@ -83,13 +71,56 @@ func Cert(basePath, authToken string, args []string, jsonOutput bool) int { fmt.Printf("ERROR: problem creating CSR: %s\n", err.Error()) return 1 } - if len(args) == 3 { - _, _ = os.Stdout.Write(csr) + err = ioutil.WriteFile(args[3], csr, 0644) + if err == nil { + fmt.Printf("Wrote CSR to %s\n", args[3]) } else { - _ = ioutil.WriteFile(args[3], csr, 0644) + fmt.Printf("ERROR: unable to write CSR to %s: %s\n", args[3], err.Error()) + return 1 } case "sign": + if len(args) != 4 { + Help() + return 1 + } + var csr zerotier.Certificate + csrBytes, err := ioutil.ReadFile(args[1]) + if err != nil { + fmt.Printf("ERROR: unable to read CSR from %s: %s\n", args[1], err.Error()) + return 1 + } + c, err := zerotier.NewCertificateFromBytes(csrBytes, false) + if err != nil { + fmt.Printf("ERROR: CSR in %s is invalid: %s\n", args[1], err.Error()) + return 1 + } + id := readIdentity(args[2]) + if id == nil { + fmt.Printf("ERROR: unable to read identity from %s\n", args[2]) + return 1 + } + if !id.HasPrivate() { + fmt.Printf("ERROR: signing identity in %s lacks private key\n", args[2]) + return 1 + } + c, err = csr.Sign(id) + if err != nil { + fmt.Printf("ERROR: error signing CSR or generating certificate: %s\n", err.Error()) + return 1 + } + cb, err := c.Marshal() + if err != nil { + fmt.Printf("ERROR: error marshaling signed certificate: %s\n", err.Error()) + return 1 + } + err = ioutil.WriteFile(args[3], cb, 0644) + if err == nil { + fmt.Printf("Wrote signed certificate to %s\n", args[3]) + } else { + fmt.Printf("ERROR: unable to write signed certificate to %s: %s\n", args[3], err.Error()) + return 1 + } case "verify": diff --git a/cmd/zerotier/cli/help.go b/cmd/zerotier/cli/help.go index 25938b1ce..7b7081945 100644 --- a/cmd/zerotier/cli/help.go +++ b/cmd/zerotier/cli/help.go @@ -75,15 +75,16 @@ Commands: validate Locally validate an identity sign Sign a file with an identity's key verify Verify a signature + certs List certificates cert [args] - Certificate commands show [serial] List or show details of a certificate - newsid [secret] Create a new subject unique ID - newcsr [csr] Create a subject CSR - sign [certificate] Sign a CSR to create a certificate - verify Verify a certificate - import [trust,[trust]] Import certificate into this node + newsid Create a new subject unique ID + newcsr Create a subject CSR + sign Sign a CSR to create a certificate + verify Verify a certificate + import [trust,[trust]] Import certificate into this node rootca Certificate is a root CA - rootset ZeroTier root node set + ztrootset ZeroTier root node set restore Re-import default certificates export [path] Export a certificate from this node delete Delete certificate from this node diff --git a/cmd/zerotier/cli/misc.go b/cmd/zerotier/cli/misc.go index c9c62e933..066f42b26 100644 --- a/cmd/zerotier/cli/misc.go +++ b/cmd/zerotier/cli/misc.go @@ -203,3 +203,11 @@ func networkStatusStr(status int) string { } return "???" } + +func readJSONFile(p string, obj interface{}) error { + b, err := ioutil.ReadFile(p) + if err != nil { + return err + } + return json.Unmarshal(b, obj) +} diff --git a/cmd/zerotier/zerotier.go b/cmd/zerotier/zerotier.go index f41992038..c6f00629d 100644 --- a/cmd/zerotier/zerotier.go +++ b/cmd/zerotier/zerotier.go @@ -81,8 +81,12 @@ func main() { // Reduce Go's thread and memory footprint. This would slow things down if the Go code // were doing a lot, but it's not. It just manages the core and is not directly involved // in pushing a lot of packets around. If that ever changes this should be adjusted. - runtime.GOMAXPROCS(1) - debug.SetGCPercent(15) + if runtime.NumCPU() > 1 { + runtime.GOMAXPROCS(2) + } else { + runtime.GOMAXPROCS(1) + } + debug.SetGCPercent(10) globalOpts := flag.NewFlagSet("global", flag.ContinueOnError) hflag := globalOpts.Bool("h", false, "") // support -h to be canonical with other Unix utilities @@ -142,6 +146,8 @@ func main() { exitCode = cli.Set(basePath, authToken(basePath, *tflag, *tTflag), cmdArgs) case "identity": exitCode = cli.Identity(cmdArgs) + case "certs", "listcerts", "lscerts": // same as "cert show" with no specific serial to show + exitCode = cli.Cert(basePath, authToken(basePath, *tflag, *tTflag), []string{"show"}, *jflag) case "cert": exitCode = cli.Cert(basePath, authToken(basePath, *tflag, *tTflag), cmdArgs, *jflag) } diff --git a/core/Dictionary.hpp b/core/Dictionary.hpp index 167e09b43..393e5dcba 100644 --- a/core/Dictionary.hpp +++ b/core/Dictionary.hpp @@ -211,7 +211,7 @@ public: * @return Number of entries */ ZT_INLINE unsigned int size() const noexcept - { return m_entries.size(); } + { return (unsigned int)m_entries.size(); } /** * @return True if dictionary is not empty diff --git a/core/Expect.hpp b/core/Expect.hpp index 14c8c08f2..940378c74 100644 --- a/core/Expect.hpp +++ b/core/Expect.hpp @@ -20,14 +20,14 @@ /** * Number of buckets to use to maintain a list of expected replies. * - * Making this a power of two improves efficiency a little by allowing bit shift division. + * This must be a power of two. Memory consumed will be about this*4 bytes. */ #define ZT_EXPECT_BUCKETS 32768 /** * 1/2 the TTL for expected replies in milliseconds * - * Making this a power of two improves efficiency a little by allowing bit shift division. + * This must be a power of two. */ #define ZT_EXPECT_TTL 4096LL @@ -39,7 +39,8 @@ namespace ZeroTier { class Expect { public: - ZT_INLINE Expect() + ZT_INLINE Expect() : + m_packetIdSent() {} /** @@ -49,13 +50,14 @@ public: * @param now Current time */ ZT_INLINE void sending(const uint64_t packetId, const int64_t now) noexcept - { m_packetIdSent[Utils::hash64(packetId ^ Utils::s_mapNonce) % ZT_EXPECT_BUCKETS].store((uint32_t)(now / ZT_EXPECT_TTL)); } + { m_packetIdSent[Utils::hash64(packetId ^ Utils::s_mapNonce) % ZT_EXPECT_BUCKETS] = (uint32_t)(now / ZT_EXPECT_TTL); } /** * Check if an OK is expected and if so reset the corresponding bucket. * * This means this call mutates the state. If it returns true, it will - * subsequently return false. This is for replay protection for OKs. + * subsequently return false. This is to filter OKs against replays or + * responses to queries we did not send. * * @param inRePacketId In-re packet ID we're expecting * @param now Current time diff --git a/core/Locator.cpp b/core/Locator.cpp index 606657fef..3b05ba504 100644 --- a/core/Locator.cpp +++ b/core/Locator.cpp @@ -18,7 +18,10 @@ namespace ZeroTier { -Locator::Locator(const char *const str) noexcept +const SharedPtr< const Locator::EndpointAttributes > Locator::EndpointAttributes::DEFAULT(new Locator::EndpointAttributes()); + +Locator::Locator(const char *const str) noexcept : + __refCount(0) { if (!fromString(str)) { m_ts = 0; @@ -28,16 +31,16 @@ Locator::Locator(const char *const str) noexcept } } -bool Locator::add(const Endpoint &ep, const EndpointAttributes &a) +bool Locator::add(const Endpoint &ep, const SharedPtr< const EndpointAttributes > &a) { - for (Vector< std::pair< Endpoint, EndpointAttributes > >::iterator i(m_endpoints.begin());i!=m_endpoints.end();++i) { + for (Vector< std::pair< Endpoint, SharedPtr< const EndpointAttributes > > >::iterator i(m_endpoints.begin());i!=m_endpoints.end();++i) { if (i->first == ep) { - i->second = a; + i->second = (a) ? a : EndpointAttributes::DEFAULT; return true; } } if (m_endpoints.size() < ZT_LOCATOR_MAX_ENDPOINTS) { - m_endpoints.push_back(std::pair(ep, a)); + m_endpoints.push_back(std::pair >(ep, (a) ? a : EndpointAttributes::DEFAULT)); return true; } return false; @@ -47,7 +50,7 @@ struct p_SortByEndpoint { // There can't be more than one of the same endpoint, so only need to sort // by endpoint. - ZT_INLINE bool operator()(const std::pair< Endpoint, Locator::EndpointAttributes > &a,const std::pair< Endpoint, Locator::EndpointAttributes > &b) const noexcept + ZT_INLINE bool operator()(const std::pair< Endpoint, SharedPtr< const Locator::EndpointAttributes > > &a,const std::pair< Endpoint, SharedPtr< const Locator::EndpointAttributes > > &b) const noexcept { return a.first < b.first; } }; @@ -116,14 +119,14 @@ int Locator::marshal(uint8_t data[ZT_LOCATOR_MARSHAL_SIZE_MAX], const bool exclu Utils::storeBigEndian(data + p, (uint16_t) m_endpoints.size()); p += 2; - for (Vector< std::pair< Endpoint, EndpointAttributes> >::const_iterator e(m_endpoints.begin());e != m_endpoints.end();++e) { + for (Vector< std::pair< Endpoint, SharedPtr< const EndpointAttributes > > >::const_iterator e(m_endpoints.begin());e != m_endpoints.end();++e) { l = e->first.marshal(data + p); if (l <= 0) return -1; p += l; - l = (int)e->second.data[0] + 1; - Utils::copy(data + p, e->second.data, (unsigned int)l); + l = (int)e->second->data[0] + 1; + Utils::copy(data + p, e->second->data, (unsigned int)l); p += l; } @@ -167,7 +170,12 @@ int Locator::unmarshal(const uint8_t *data, const int len) noexcept p += l; l = (int)data[p] + 1; - Utils::copy(m_endpoints[i].second.data, data + p, (unsigned int)l); + if (l <= 1) { + m_endpoints[i].second = EndpointAttributes::DEFAULT; + } else { + m_endpoints[i].second.set(new EndpointAttributes()); + Utils::copy(const_cast< uint8_t * >(m_endpoints[i].second->data), data + p, (unsigned int)l); + } p += l; } @@ -183,7 +191,7 @@ int Locator::unmarshal(const uint8_t *data, const int len) noexcept return -1; m_signature.unsafeSetSize(siglen); Utils::copy(m_signature.data(), data + p, siglen); - p += siglen; + p += (int)siglen; if (unlikely(p > len)) return -1; @@ -204,11 +212,10 @@ ZT_Locator *ZT_Locator_create( try { if ((ts <= 0) || (!endpoints) || (endpointCount == 0) || (!signer)) return nullptr; - ZeroTier::Locator::EndpointAttributes emptyAttributes; ZeroTier::Locator *loc = new ZeroTier::Locator(); for (unsigned int i = 0;i < endpointCount;++i) - loc->add(reinterpret_cast(endpoints)[i], emptyAttributes); - if (!loc->sign(ts, *reinterpret_cast(signer))) { + loc->add(reinterpret_cast< const ZeroTier::Endpoint * >(endpoints)[i], ZeroTier::Locator::EndpointAttributes::DEFAULT); + if (!loc->sign(ts, *reinterpret_cast< const ZeroTier::Identity * >(signer))) { delete loc; return nullptr; } diff --git a/core/Locator.hpp b/core/Locator.hpp index f190526c5..f1c88517f 100644 --- a/core/Locator.hpp +++ b/core/Locator.hpp @@ -53,7 +53,6 @@ namespace ZeroTier { class Locator { friend class SharedPtr< Locator >; - friend class SharedPtr< const Locator >; public: @@ -65,6 +64,14 @@ public: */ struct EndpointAttributes { + friend class SharedPtr< Locator::EndpointAttributes >; + friend class SharedPtr< const Locator::EndpointAttributes >; + + /** + * Default endpoint attributes + */ + static const SharedPtr< const Locator::EndpointAttributes > DEFAULT; + /** * Raw attributes data in the form of a dictionary prefixed by its size. * @@ -93,6 +100,9 @@ public: ZT_INLINE bool operator>=(const EndpointAttributes &a) const noexcept { return !(*this < a); } + + private: + std::atomic< int > __refCount; }; ZT_INLINE Locator() noexcept: @@ -124,7 +134,7 @@ public: /** * @return Endpoints specified in locator */ - ZT_INLINE const Vector< std::pair< Endpoint, EndpointAttributes > > &endpoints() const noexcept + ZT_INLINE const Vector< std::pair< Endpoint, SharedPtr< const EndpointAttributes > > > &endpoints() const noexcept { return m_endpoints; } /** @@ -140,10 +150,10 @@ public: * care not to add duplicates. * * @param ep Endpoint to add - * @param a Endpoint attributes + * @param a Endpoint attributes or NULL to use default * @return True if endpoint was added (or already present), false if locator is full */ - bool add(const Endpoint &ep, const EndpointAttributes &a); + bool add(const Endpoint &ep, const SharedPtr< const EndpointAttributes > &a); /** * Sign this locator @@ -191,27 +201,35 @@ public: static constexpr int marshalSizeMax() noexcept { return ZT_LOCATOR_MARSHAL_SIZE_MAX; } - int marshal(uint8_t data[ZT_LOCATOR_MARSHAL_SIZE_MAX], bool excludeSignature = false) const noexcept; - int unmarshal(const uint8_t *data, int len) noexcept; ZT_INLINE bool operator==(const Locator &l) const noexcept { - return ( - (m_ts == l.m_ts) && - (m_signer == l.m_signer) && - (m_endpoints == l.m_endpoints) && - (m_signature == l.m_signature)); + const unsigned long es = (unsigned long)m_endpoints.size(); + if ((m_ts == l.m_ts) && (m_signer == l.m_signer) && (es == (unsigned long)l.m_endpoints.size()) && (m_signature == l.m_signature)) { + for(unsigned long i=0;i > m_endpoints; + Vector< std::pair< Endpoint, SharedPtr< const EndpointAttributes > > > m_endpoints; FCV< uint8_t, ZT_SIGNATURE_BUFFER_SIZE > m_signature; std::atomic< int > __refCount; }; diff --git a/core/Meter.hpp b/core/Meter.hpp index ef03205a8..65e0cf3a9 100644 --- a/core/Meter.hpp +++ b/core/Meter.hpp @@ -43,7 +43,10 @@ public: * * @param now Start time */ - ZT_INLINE Meter() noexcept + ZT_INLINE Meter() noexcept : + m_counts(), + m_totalExclCounts(0), + m_bucket(0) {} /** diff --git a/core/Peer.cpp b/core/Peer.cpp index 4b895282d..a5094f9ce 100644 --- a/core/Peer.cpp +++ b/core/Peer.cpp @@ -254,7 +254,7 @@ void Peer::pulse(void *const tPtr, const int64_t now, const bool isRoot) // callback (if one was supplied). if (m_locator) { - for (Vector< std::pair >::const_iterator ep(m_locator->endpoints().begin()); ep != m_locator->endpoints().end(); ++ep) { + for (Vector< std::pair > >::const_iterator ep(m_locator->endpoints().begin()); ep != m_locator->endpoints().end(); ++ep) { if (ep->first.type == ZT_ENDPOINT_TYPE_IP_UDP) { if (RR->node->shouldUsePathForZeroTierTraffic(tPtr, m_id, -1, ep->first.ip())) { int64_t < = m_lastTried[ep->first]; diff --git a/core/Tests.cpp b/core/Tests.cpp index ad01c5c29..6afd50a41 100644 --- a/core/Tests.cpp +++ b/core/Tests.cpp @@ -907,8 +907,8 @@ extern "C" const char *ZTT_general() Endpoint ep0(InetAddress::LO4); Endpoint ep1(InetAddress::LO6); Locator loc; - loc.add(ep0, Locator::EndpointAttributes()); - loc.add(ep1, Locator::EndpointAttributes()); + loc.add(ep0, Locator::EndpointAttributes::DEFAULT); + loc.add(ep1, Locator::EndpointAttributes::DEFAULT); loc.sign(now(), v1id); String locStr(loc.toString()); //ZT_T_PRINTF("%s %s %s ",locStr.c_str(),loc.endpoints()[0].toString().c_str(),loc.endpoints()[1].toString().c_str()); diff --git a/core/zerotier.h b/core/zerotier.h index b2ef498d8..aeaef1722 100644 --- a/core/zerotier.h +++ b/core/zerotier.h @@ -51,7 +51,7 @@ extern "C" { /** * IP protocol number for naked IP encapsulation (this is not currently used) */ -#define ZT_DEFAULT_IP_PROTOCOL 193 +#define ZT_DEFAULT_RAW_IP_PROTOCOL 193 /** * Ethernet type for naked Ethernet encapsulation (this is not currently used) @@ -260,6 +260,8 @@ extern "C" { /** * Identity type codes (must be the same as Identity.hpp). + * + * Do not change these integer values. They're protocol constants. */ enum ZT_IdentityType { @@ -2726,7 +2728,7 @@ ZT_SDK_API unsigned int ZT_Locator_endpointCount(const ZT_Locator *loc); */ ZT_SDK_API const ZT_Endpoint *ZT_Locator_endpoint( const ZT_Locator *loc, - const unsigned int ep); + unsigned int ep); /** * Verify this locator's signature diff --git a/pkg/zerotier/api.go b/pkg/zerotier/api.go index be20c5b15..c1afc81c0 100644 --- a/pkg/zerotier/api.go +++ b/pkg/zerotier/api.go @@ -131,7 +131,7 @@ type APIStatus struct { PeerCount int `json:"peerCount"` PathCount int `json:"pathCount"` Identity *Identity `json:"identity"` - InterfaceAddresses []net.IP `json:"interfaceAddresses,omitempty"` + InterfaceAddresses []net.IP `json:"localInterfaceAddresses,omitempty"` MappedExternalAddresses []*InetAddress `json:"mappedExternalAddresses,omitempty"` Version string `json:"version"` VersionMajor int `json:"versionMajor"` @@ -280,7 +280,7 @@ func createAPIServer(basePath string, node *Node) (*http.Server, *http.Server, e PeerCount: len(peers), PathCount: pathCount, Identity: node.Identity(), - InterfaceAddresses: node.InterfaceAddresses(), + InterfaceAddresses: node.LocalInterfaceAddresses(), MappedExternalAddresses: nil, Version: fmt.Sprintf("%d.%d.%d", CoreVersionMajor, CoreVersionMinor, CoreVersionRevision), VersionMajor: CoreVersionMajor, diff --git a/pkg/zerotier/certificate.go b/pkg/zerotier/certificate.go index d2dc23a48..b70f4147c 100644 --- a/pkg/zerotier/certificate.go +++ b/pkg/zerotier/certificate.go @@ -27,6 +27,9 @@ const ( CertificateSerialNoSize = 48 CertificateMaxStringLength = int(C.ZT_CERTIFICATE_MAX_STRING_LENGTH) + CertificateLocalTrustFlagRootCA = int(C.ZT_CERTIFICATE_LOCAL_TRUST_FLAG_ROOT_CA) + CertificateLocalTrustFlagZeroTierRootSet = int(C.ZT_CERTIFICATE_LOCAL_TRUST_FLAG_ZEROTIER_ROOT_SET) + CertificateUniqueIdTypeNistP384 = int(C.ZT_CERTIFICATE_UNIQUE_ID_TYPE_NIST_P_384) CertificateUniqueIdTypeNistP384Size = int(C.ZT_CERTIFICATE_UNIQUE_ID_TYPE_NIST_P_384_SIZE) CertificateUniqueIdTypeNistP384PrivateSize = int(C.ZT_CERTIFICATE_UNIQUE_ID_TYPE_NIST_P_384_PRIVATE_SIZE) @@ -425,7 +428,7 @@ func (c *Certificate) cCertificate() unsafe.Pointer { return unsafe.Pointer(C._ZT_Certificate_clone2(C.uintptr_t(uintptr(unsafe.Pointer(&cc))))) } -// Marshal encodes this certificate as a byte array. +// Marshal encodes this certificate as a byte array (binary format). func (c *Certificate) Marshal() ([]byte, error) { cc := c.cCertificate() if cc == nil { diff --git a/pkg/zerotier/locator.go b/pkg/zerotier/locator.go index d0fbaf051..bfc908f5c 100644 --- a/pkg/zerotier/locator.go +++ b/pkg/zerotier/locator.go @@ -26,7 +26,6 @@ type Locator struct { Timestamp int64 `json:"timestamp"` Fingerprint *Fingerprint `json:"fingerprint"` Endpoints []Endpoint `json:"endpoints"` - String string `json:"string"` cl unsafe.Pointer } @@ -96,6 +95,9 @@ func (loc *Locator) Validate(id *Identity) bool { } func (loc *Locator) Bytes() []byte { + if loc.cl == nil { + return nil + } var buf [4096]byte bl := C.ZT_Locator_marshal(loc.cl, unsafe.Pointer(&buf[0]), 4096) if bl <= 0 { @@ -104,26 +106,28 @@ func (loc *Locator) Bytes() []byte { return buf[0:int(bl)] } +func (loc *Locator) String() string { + if loc.cl == nil { + return "" + } + var buf [4096]C.char + return C.GoString(C.ZT_Locator_toString(loc.cl, &buf[0], 4096)) +} + func (loc *Locator) MarshalJSON() ([]byte, error) { return json.Marshal(loc) } func (loc *Locator) UnmarshalJSON(j []byte) error { - C.ZT_Locator_delete(loc.cl) - loc.cl = unsafe.Pointer(nil) + if loc.cl != nil { + C.ZT_Locator_delete(loc.cl) + loc.cl = unsafe.Pointer(nil) + } err := json.Unmarshal(j, loc) if err != nil { return err } - - sb := []byte(loc.String) - sb = append(sb, 0) - cl := C.ZT_Locator_fromString((*C.char)(unsafe.Pointer(&sb[0]))) - if cl == nil { - return ErrInvalidParameter - } - loc.cl = cl return loc.init(true) } @@ -148,8 +152,6 @@ func (loc *Locator) init(needFinalizer bool) error { for i := 0; i < epc; i++ { loc.Endpoints[i].cep = *C.ZT_Locator_endpoint(loc.cl, C.uint(i)) } - var buf [4096]byte - loc.String = C.GoString(C.ZT_Locator_toString(loc.cl, (*C.char)(unsafe.Pointer(&buf[0])), 4096)) if needFinalizer { runtime.SetFinalizer(loc, locatorFinalizer) } diff --git a/pkg/zerotier/node.go b/pkg/zerotier/node.go index 0583ac7b5..05232517d 100644 --- a/pkg/zerotier/node.go +++ b/pkg/zerotier/node.go @@ -44,27 +44,27 @@ import ( var nullLogger = log.New(ioutil.Discard, "", 0) const ( - NetworkIDStringLength = 16 - NetworkIDLength = 8 - AddressStringLength = 10 - AddressLength = 5 + NetworkIDStringLength = 16 + NetworkIDLength = 8 + AddressStringLength = 10 + AddressLength = 5 + DefaultPort = int(C.ZT_DEFAULT_PORT) + DefaultRawIPProto = int(C.ZT_DEFAULT_RAW_IP_PROTOCOL) + DefaultEthernetProto = int(C.ZT_DEFAULT_ETHERNET_PROTOCOL) + NetworkMaxShortNameLength = int(C.ZT_MAX_NETWORK_SHORT_NAME_LENGTH) - NetworkStatusRequestingConfiguration int = C.ZT_NETWORK_STATUS_REQUESTING_CONFIGURATION - NetworkStatusOK int = C.ZT_NETWORK_STATUS_OK - NetworkStatusAccessDenied int = C.ZT_NETWORK_STATUS_ACCESS_DENIED - NetworkStatusNotFound int = C.ZT_NETWORK_STATUS_NOT_FOUND + NetworkStatusRequestingConfiguration = int(C.ZT_NETWORK_STATUS_REQUESTING_CONFIGURATION) + NetworkStatusOK = int(C.ZT_NETWORK_STATUS_OK) + NetworkStatusAccessDenied = int(C.ZT_NETWORK_STATUS_ACCESS_DENIED) + NetworkStatusNotFound = int(C.ZT_NETWORK_STATUS_NOT_FOUND) - NetworkTypePrivate int = C.ZT_NETWORK_TYPE_PRIVATE - NetworkTypePublic int = C.ZT_NETWORK_TYPE_PUBLIC + NetworkTypePrivate = int(C.ZT_NETWORK_TYPE_PRIVATE) + NetworkTypePublic = int(C.ZT_NETWORK_TYPE_PUBLIC) - networkConfigOpUp int = C.ZT_VIRTUAL_NETWORK_CONFIG_OPERATION_UP - networkConfigOpUpdate int = C.ZT_VIRTUAL_NETWORK_CONFIG_OPERATION_CONFIG_UPDATE - - defaultVirtualNetworkMTU = C.ZT_DEFAULT_MTU - - // maxCNodeRefs is the maximum number of Node instances that can be created in this process. - // This is perfectly fine to increase. - maxCNodeRefs = 8 + networkConfigOpUp = int(C.ZT_VIRTUAL_NETWORK_CONFIG_OPERATION_UP) + networkConfigOpUpdate = int(C.ZT_VIRTUAL_NETWORK_CONFIG_OPERATION_CONFIG_UPDATE) + defaultVirtualNetworkMTU = int(C.ZT_DEFAULT_MTU) + maxCNodeRefs = 8 // perfectly fine to increase this ) var ( @@ -77,8 +77,9 @@ var ( // cNodeRefs maps an index to a *Node cNodeRefs [maxCNodeRefs]*Node - // cNodeRefsUsed maps an index to whether or not the corresponding cNodeRefs[] entry is used. - cNodeRefUsed [maxCNodeRefs]uint32 + // cNodeRefUsed maps an index to whether or not the corresponding cNodeRefs[] entry is used. + // This is accessed atomically to provide a really fast way to gate cNodeRefs. + cNodeRefUsed [maxCNodeRefs]uintptr ) func init() { @@ -96,46 +97,37 @@ type Node struct { // Time this node was created startupTime int64 - // an arbitrary uintptr given to the core as its pointer back to Go's Node instance. - // This is an index in the cNodeRefs array, which is synchronized by way of a set of - // used/free booleans accessed atomically. + // cPtr is an arbitrary pseudo-pointer given to the core to map back to our Go object. + // This is an index into the cNodeRefs array. cPtr uintptr - // networks contains networks we have joined, and networksByMAC by their local MAC address - networks map[NetworkID]*Network - networksByMAC map[MAC]*Network // locked by networksLock - networksLock sync.RWMutex + networks map[NetworkID]*Network + networksByMAC map[MAC]*Network // locked by networksLock + networksLock sync.RWMutex + localInterfaceAddresses map[string]net.IP + localInterfaceAddressesLock sync.Mutex + running uintptr // atomic flag + online uintptr // atomic flag + basePath string + peersPath string + certsPath string + networksPath string + localConfigPath string + infoLogPath string + errorLogPath string + localConfig *LocalConfig + previousLocalConfig *LocalConfig + localConfigLock sync.RWMutex + infoLogW *sizeLimitWriter + errLogW *sizeLimitWriter + traceLogW io.Writer + infoLog *log.Logger + errLog *log.Logger + traceLog *log.Logger + namedSocketAPIServer *http.Server + tcpAPIServer *http.Server - // interfaceAddresses are physical IPs assigned to the local machine. - // These are the detected IPs, not those configured explicitly. They include - // both private and global IPs. - interfaceAddresses map[string]net.IP - interfaceAddressesLock sync.Mutex - - running uint32 - online uint32 - - basePath string - peersPath string - networksPath string - localConfigPath string - infoLogPath string - errorLogPath string - - // localConfig is the current state of local.conf. - localConfig *LocalConfig - previousLocalConfig *LocalConfig - localConfigLock sync.RWMutex - - infoLogW *sizeLimitWriter - errLogW *sizeLimitWriter - traceLogW io.Writer - - infoLog *log.Logger - errLog *log.Logger - traceLog *log.Logger - - // gn is the GoNode instance, see go/native/GoNode.hpp + // gn is the GoNode instance, see serviceiocore/GoNode.hpp gn *C.ZT_GoNode // zn is the underlying ZT_Node (ZeroTier::Node) instance @@ -144,9 +136,6 @@ type Node struct { // id is the identity of this node (includes private key) id *Identity - namedSocketAPIServer *http.Server - tcpAPIServer *http.Server - // runWaitGroup is used to wait for all node goroutines on shutdown. // Any new goroutine is tracked via this wait group so node shutdown can // itself wait until all goroutines have exited. @@ -163,7 +152,7 @@ func NewNode(basePath string) (n *Node, err error) { // returning an error. cPtr := -1 for i := 0; i < maxCNodeRefs; i++ { - if atomic.CompareAndSwapUint32(&cNodeRefUsed[i], 0, 1) { + if atomic.CompareAndSwapUintptr(&cNodeRefUsed[i], 0, 1) { cNodeRefs[i] = n cPtr = i n.cPtr = uintptr(i) @@ -173,20 +162,16 @@ func NewNode(basePath string) (n *Node, err error) { if cPtr < 0 { return nil, ErrInternal } - - // Check and delete node reference pointer if it's non-negative. This helps - // with error handling cleanup. At the end we set cPtr to -1 to disable. defer func() { if cPtr >= 0 { - atomic.StoreUint32(&cNodeRefUsed[cPtr], 0) + atomic.StoreUintptr(&cNodeRefUsed[cPtr], 0) cNodeRefs[cPtr] = nil } }() n.networks = make(map[NetworkID]*Network) n.networksByMAC = make(map[MAC]*Network) - n.interfaceAddresses = make(map[string]net.IP) - + n.localInterfaceAddresses = make(map[string]net.IP) n.running = 1 _ = os.MkdirAll(basePath, 0755) @@ -200,16 +185,12 @@ func NewNode(basePath string) (n *Node, err error) { if _, err = os.Stat(n.peersPath); err != nil { return } + n.certsPath = path.Join(basePath, "certs.d") + _ = os.MkdirAll(n.certsPath, 0755) n.networksPath = path.Join(basePath, "networks.d") _ = os.MkdirAll(n.networksPath, 0755) - if _, err = os.Stat(n.networksPath); err != nil { - return - } n.localConfigPath = path.Join(basePath, "local.conf") - // Read local configuration, initializing with defaults if not found. We - // check for identity.secret's existence to determine if this is a new - // node or one that already existed. This influences some of the defaults. _, isTotallyNewNode := os.Stat(path.Join(basePath, "identity.secret")) n.localConfig = new(LocalConfig) err = n.localConfig.Read(n.localConfigPath, true, isTotallyNewNode != nil) @@ -282,9 +263,8 @@ func NewNode(basePath string) (n *Node, err error) { return nil, err } - cPath := C.CString(basePath) - n.gn = C.ZT_GoNode_new(cPath, C.uintptr_t(n.cPtr)) - C.free(unsafe.Pointer(cPath)) + cPath := cStr(basePath) + n.gn = C.ZT_GoNode_new((*C.char)(unsafe.Pointer(&cPath[0])), C.uintptr_t(n.cPtr)) if n.gn == nil { n.infoLog.Println("FATAL: node initialization failed") return nil, ErrNodeInitFailed @@ -297,13 +277,11 @@ func NewNode(basePath string) (n *Node, err error) { return nil, err } - // Background maintenance goroutine that handles polling for local network changes, cleaning internal data - // structures, syncing local config changes, and numerous other things that must happen from time to time. n.runWaitGroup.Add(1) go func() { defer n.runWaitGroup.Done() lastMaintenanceRun := int64(0) - for atomic.LoadUint32(&n.running) != 0 { + for atomic.LoadUintptr(&n.running) != 0 { time.Sleep(250 * time.Millisecond) nowS := time.Now().Unix() if (nowS - lastMaintenanceRun) >= 30 { @@ -322,7 +300,7 @@ func NewNode(basePath string) (n *Node, err error) { // Close closes this Node and frees its underlying C++ Node structures func (n *Node) Close() { - if atomic.SwapUint32(&n.running, 0) != 0 { + if atomic.SwapUintptr(&n.running, 0) != 0 { if n.namedSocketAPIServer != nil { _ = n.namedSocketAPIServer.Close() } @@ -335,7 +313,7 @@ func (n *Node) Close() { n.runWaitGroup.Wait() cNodeRefs[n.cPtr] = nil - atomic.StoreUint32(&cNodeRefUsed[n.cPtr], 0) + atomic.StoreUintptr(&cNodeRefUsed[n.cPtr], 0) } } @@ -346,16 +324,16 @@ func (n *Node) Address() Address { return n.id.address } func (n *Node) Identity() *Identity { return n.id } // Online returns true if this node can reach something -func (n *Node) Online() bool { return atomic.LoadUint32(&n.online) != 0 } +func (n *Node) Online() bool { return atomic.LoadUintptr(&n.online) != 0 } -// InterfaceAddresses are external IPs belonging to physical interfaces on this machine -func (n *Node) InterfaceAddresses() []net.IP { +// LocalInterfaceAddresses are external IPs belonging to physical interfaces on this machine +func (n *Node) LocalInterfaceAddresses() []net.IP { + n.localInterfaceAddressesLock.Lock() + defer n.localInterfaceAddressesLock.Unlock() var ea []net.IP - n.interfaceAddressesLock.Lock() - for _, a := range n.interfaceAddresses { + for _, a := range n.localInterfaceAddresses { ea = append(ea, a) } - n.interfaceAddressesLock.Unlock() sort.Slice(ea, func(a, b int) bool { return bytes.Compare(ea[a], ea[b]) < 0 }) return ea } @@ -395,14 +373,15 @@ func (n *Node) SetLocalConfig(lc *LocalConfig) (restartRequired bool, err error) // If tap is nil, the default system tap for this OS/platform is used (if available). func (n *Node) Join(nwid NetworkID, controllerFingerprint *Fingerprint, settings *NetworkLocalSettings, tap Tap) (*Network, error) { n.networksLock.RLock() + defer n.networksLock.RUnlock() + if nw, have := n.networks[nwid]; have { n.infoLog.Printf("join network %.16x ignored: already a member", nwid) if settings != nil { - nw.SetLocalSettings(settings) + go nw.SetLocalSettings(settings) // "go" this to avoid possible deadlocks } return nw, nil } - n.networksLock.RUnlock() if tap != nil { panic("non-native taps not yet implemented") @@ -423,11 +402,9 @@ func (n *Node) Join(nwid NetworkID, controllerFingerprint *Fingerprint, settings C.ZT_GoNode_leave(n.gn, C.uint64_t(nwid)) return nil, err } - n.networksLock.Lock() n.networks[nwid] = nw - n.networksLock.Unlock() if settings != nil { - nw.SetLocalSettings(settings) + go nw.SetLocalSettings(settings) } return nw, nil @@ -457,12 +434,12 @@ func (n *Node) Network(nwid NetworkID) *Network { // Networks returns a list of networks that this node has joined func (n *Node) Networks() []*Network { - var nws []*Network n.networksLock.RLock() + defer n.networksLock.RUnlock() + var nws []*Network for _, nw := range n.networks { nws = append(nws, nw) } - n.networksLock.RUnlock() return nws } @@ -471,13 +448,13 @@ func (n *Node) Peers() []*Peer { var peers []*Peer pl := C.ZT_Node_peers(n.zn) if pl != nil { + defer C.ZT_freeQueryResult(unsafe.Pointer(pl)) for i := uintptr(0); i < uintptr(pl.peerCount); i++ { p, _ := newPeerFromCPeer((*C.ZT_Peer)(unsafe.Pointer(uintptr(unsafe.Pointer(pl.peers)) + (i * C.sizeof_ZT_Peer)))) if p != nil { peers = append(peers, p) } } - C.ZT_freeQueryResult(unsafe.Pointer(pl)) } sort.Slice(peers, func(a, b int) bool { return peers[a].Address < peers[b].Address @@ -499,14 +476,13 @@ func (n *Node) Peer(fpOrAddress interface{}) *Peer { } pl := C.ZT_Node_peers(n.zn) if pl != nil { + defer C.ZT_freeQueryResult(unsafe.Pointer(pl)) for i := uintptr(0); i < uintptr(pl.peerCount); i++ { p, _ := newPeerFromCPeer((*C.ZT_Peer)(unsafe.Pointer(uintptr(unsafe.Pointer(pl.peers)) + (i * C.sizeof_ZT_Peer)))) if p != nil && p.Identity.Fingerprint().BestSpecificityEquals(fp) { - C.ZT_freeQueryResult(unsafe.Pointer(pl)) return p } } - C.ZT_freeQueryResult(unsafe.Pointer(pl)) } return nil } @@ -545,6 +521,7 @@ func (n *Node) TryPeer(fpOrAddress interface{}, ep *Endpoint, retries int) bool func (n *Node) ListCertificates() (certs []*Certificate, localTrust []uint, err error) { cl := C.ZT_Node_listCertificates(n.zn) if cl != nil { + defer C.ZT_freeQueryResult(unsafe.Pointer(cl)) for i := uintptr(0); i < uintptr(cl.certCount); i++ { c := newCertificateFromCCertificate(unsafe.Pointer(uintptr(unsafe.Pointer(cl.certs)) + (i * pointerSize))) if c != nil { @@ -553,7 +530,6 @@ func (n *Node) ListCertificates() (certs []*Certificate, localTrust []uint, err localTrust = append(localTrust, uint(lt)) } } - C.ZT_freeQueryResult(unsafe.Pointer(cl)) } return } @@ -615,9 +591,9 @@ func (n *Node) runMaintenance() { if n.localConfig.Settings.SecondaryPort > 0 && n.localConfig.Settings.SecondaryPort < 65536 { ports = append(ports, n.localConfig.Settings.SecondaryPort) } - n.interfaceAddressesLock.Lock() + n.localInterfaceAddressesLock.Lock() for astr, ipn := range interfaceAddresses { - if _, alreadyKnown := n.interfaceAddresses[astr]; !alreadyKnown { + if _, alreadyKnown := n.localInterfaceAddresses[astr]; !alreadyKnown { interfaceAddressesChanged = true ipCstr := C.CString(ipn.String()) for pn, p := range ports { @@ -631,7 +607,7 @@ func (n *Node) runMaintenance() { C.free(unsafe.Pointer(ipCstr)) } } - for astr, ipn := range n.interfaceAddresses { + for astr, ipn := range n.localInterfaceAddresses { if _, stillPresent := interfaceAddresses[astr]; !stillPresent { interfaceAddressesChanged = true ipCstr := C.CString(ipn.String()) @@ -642,8 +618,8 @@ func (n *Node) runMaintenance() { C.free(unsafe.Pointer(ipCstr)) } } - n.interfaceAddresses = interfaceAddresses - n.interfaceAddressesLock.Unlock() + n.localInterfaceAddresses = interfaceAddresses + n.localInterfaceAddressesLock.Unlock() // Update node's interface address list if detected or configured addresses have changed. if interfaceAddressesChanged || n.previousLocalConfig == nil || !reflect.DeepEqual(n.localConfig.Settings.ExplicitAddresses, n.previousLocalConfig.Settings.ExplicitAddresses) { @@ -729,20 +705,17 @@ func (n *Node) makeStateObjectPath(objType int, id []uint64) (string, bool) { case C.ZT_STATE_OBJECT_LOCATOR: fp = path.Join(n.basePath, "locator") case C.ZT_STATE_OBJECT_PEER: - fp = path.Join(n.basePath, "peers.d") - _ = os.Mkdir(fp, 0700) - fp = path.Join(fp, fmt.Sprintf("%.10x.peer", id[0])) + _ = os.Mkdir(n.peersPath, 0700) + fp = path.Join(n.peersPath, fmt.Sprintf("%.10x.peer", id[0])) secret = true case C.ZT_STATE_OBJECT_NETWORK_CONFIG: - fp = path.Join(n.basePath, "networks.d") - _ = os.Mkdir(fp, 0755) - fp = path.Join(fp, fmt.Sprintf("%.16x.conf", id[0])) + _ = os.Mkdir(n.networksPath, 0755) + fp = path.Join(n.networksPath, fmt.Sprintf("%.16x.conf", id[0])) case C.ZT_STATE_OBJECT_TRUST_STORE: fp = path.Join(n.basePath, "truststore") case C.ZT_STATE_OBJECT_CERT: - fp = path.Join(n.basePath, "certs.d") - _ = os.Mkdir(fp, 0755) - fp = path.Join(fp, Base32StdLowerCase.EncodeToString((*[48]byte)(unsafe.Pointer(&id[0]))[:])) + _ = os.Mkdir(n.certsPath, 0755) + fp = path.Join(n.certsPath, Base32StdLowerCase.EncodeToString((*[48]byte)(unsafe.Pointer(&id[0]))[:])) } return fp, secret } @@ -956,8 +929,8 @@ func goZtEvent(gn unsafe.Pointer, eventType C.int, data unsafe.Pointer) { switch eventType { case C.ZT_EVENT_OFFLINE: - atomic.StoreUint32(&node.online, 0) + atomic.StoreUintptr(&node.online, 0) case C.ZT_EVENT_ONLINE: - atomic.StoreUint32(&node.online, 1) + atomic.StoreUintptr(&node.online, 1) } }