diff --git a/CMakeLists.txt b/CMakeLists.txt index fdc46973d..3dd2d8034 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -248,11 +248,11 @@ if(NOT PACKAGE_STATIC) if(WIN32) set(GO_EXE_NAME "zerotier.exe") - set(GO_SERVICE_TESTS_EXE_NAME "go_service_tests.exe") + set(GO_SERVICE_TESTS_EXE_NAME "zt_service_tests.exe") set(GO_EXTRA_LIBRARIES "-lstdc++ -lwsock32 -lws2_32 -liphlpapi -lole32 -loleaut32 -lrpcrt4 -luuid") else(WIN32) set(GO_EXE_NAME "zerotier") - set(GO_SERVICE_TESTS_EXE_NAME "go_service_tests") + set(GO_SERVICE_TESTS_EXE_NAME "zt_service_tests") if(CMAKE_SYSTEM_NAME MATCHES "Linux") set(GO_EXTRA_LIBRARIES "-lstdc++") if(BUILD_ARM_V5) diff --git a/cmd/zt_service_tests/certificate.go b/cmd/zt_service_tests/certificate.go index 56efd1fe0..b5c273b82 100644 --- a/cmd/zt_service_tests/certificate.go +++ b/cmd/zt_service_tests/certificate.go @@ -107,5 +107,28 @@ func TestCertificate() bool { return false } + fmt.Printf("Checking certificate marshal/unmarshal... ") + cb, err := c.Marshal() + if err != nil { + fmt.Printf("marshal FAILED (%s)\n", err.Error()) + return false + } + fmt.Printf("marshal: %d bytes ", len(cb)) + c2, err = zerotier.NewCertificateFromBytes(cb, false) + if err != nil { + fmt.Printf("unmarshal FAILED (%s)\n", err.Error()) + return false + } + cb2, err := c2.Marshal() + if err != nil { + fmt.Printf("second marshal FAILED (%s)\n", err.Error()) + return false + } + if !bytes.Equal(cb, cb2) { + fmt.Printf("FAILED (results not equal)\n") + return false + } + fmt.Println("OK") + return true } diff --git a/core/Certificate.cpp b/core/Certificate.cpp index 024eb1522..f4bb399e7 100644 --- a/core/Certificate.cpp +++ b/core/Certificate.cpp @@ -643,7 +643,7 @@ int ZT_Certificate_sign( } enum ZT_CertificateError ZT_Certificate_decode( - ZT_Certificate **decodedCert, + const ZT_Certificate **decodedCert, const void *cert, int certSize, int verify) @@ -698,10 +698,21 @@ enum ZT_CertificateError ZT_Certificate_verify(const ZT_Certificate *cert) } } -void ZT_Certificate_delete(ZT_Certificate *cert) +const ZT_Certificate *ZT_Certificate_clone(const ZT_Certificate *cert) +{ + try { + if (!cert) + return nullptr; + return (const ZT_Certificate *)(new ZeroTier::Certificate(*cert)); + } catch ( ... ) { + return nullptr; + } +} + +void ZT_Certificate_delete(const ZT_Certificate *cert) { if (cert) - delete reinterpret_cast(cert); + delete (const ZeroTier::Certificate *)(cert); } } diff --git a/core/Utils.cpp b/core/Utils.cpp index 01e54e25a..0ef7f6a4e 100644 --- a/core/Utils.cpp +++ b/core/Utils.cpp @@ -79,7 +79,7 @@ CPUIDRegisters::CPUIDRegisters() noexcept avx2 = avx && ((ebx & (1U << 5U)) != 0); avx512f = avx && ((ebx & (1U << 16U)) != 0); sha = ((ebx & (1U << 29U)) != 0); - fsrm = sha = ((edx & (1U << 4U)) != 0); + fsrm = ((edx & (1U << 4U)) != 0); } const CPUIDRegisters CPUID; diff --git a/core/zerotier.h b/core/zerotier.h index 62a70a769..ef18a89de 100644 --- a/core/zerotier.h +++ b/core/zerotier.h @@ -2794,7 +2794,7 @@ ZT_SDK_API int ZT_Certificate_sign( * @return Certificate error, if any */ ZT_SDK_API enum ZT_CertificateError ZT_Certificate_decode( - ZT_Certificate **decodedCert, + const ZT_Certificate **decodedCert, const void *cert, int certSize, int verify); @@ -2820,12 +2820,24 @@ ZT_SDK_API int ZT_Certificate_encode( */ ZT_SDK_API enum ZT_CertificateError ZT_Certificate_verify(const ZT_Certificate *cert); +/** + * Deep clone a certificate, returning one allocated C-side. + * + * The returned certificate must be freed with ZT_Certificate_delete(). This is + * primarily to make copies of certificates that may contain pointers to objects + * on the stack, etc., before actually returning them. + * + * @param cert Certificate to deep clone + * @return New certificate with copies of all objects + */ +ZT_SDK_API const ZT_Certificate *ZT_Certificate_clone(const ZT_Certificate *cert); + /** * Free a certificate created with ZT_Certificate_decode() * * @param cert Certificate to free */ -ZT_SDK_API void ZT_Certificate_delete(ZT_Certificate *cert); +ZT_SDK_API void ZT_Certificate_delete(const ZT_Certificate *cert); /* ---------------------------------------------------------------------------------------------------------------- */ diff --git a/pkg/zerotier/certificate.go b/pkg/zerotier/certificate.go index 2a5acb2f3..be0f5642b 100644 --- a/pkg/zerotier/certificate.go +++ b/pkg/zerotier/certificate.go @@ -14,10 +14,12 @@ package zerotier // #include "../../serviceiocore/GoGlue.h" +// static inline void *_ZT_Certificate_clone2(uintptr_t p) { return (void *)ZT_Certificate_clone((const ZT_Certificate *)p); } import "C" import ( "fmt" + "runtime" "unsafe" ) @@ -25,7 +27,9 @@ const ( CertificateSerialNoSize = 48 CertificateMaxStringLength = int(C.ZT_CERTIFICATE_MAX_STRING_LENGTH) - CertificateUniqueIdTypeNistP384 = int(C.ZT_CERTIFICATE_UNIQUE_ID_TYPE_NIST_P_384) + CertificateUniqueIdTypeNistP384 = int(C.ZT_CERTIFICATE_UNIQUE_ID_TYPE_NIST_P_384) + CertificateUniqueIdTypeNistP384Size = int(C.ZT_CERTIFICATE_UNIQUE_ID_TYPE_NIST_P_384_SIZE) + CertificateUniqueIdTypeNistP384PrivateSize = int(C.ZT_CERTIFICATE_UNIQUE_ID_TYPE_NIST_P_384_PRIVATE_SIZE) ) // CertificateName identifies a real-world entity that owns a subject or has signed a certificate. @@ -82,16 +86,12 @@ type Certificate struct { Signature []byte `json:"signature,omitempty"` } -// CCertificate wraps a pointer to a C ZT_Certificate with any related allocated memory. -// Only the 'C' field should be used directly, and only this field is exported. +// CCertificate just wraps a C pointer so a Go finalizer can be attached to it. +// This allows CCertificate() to be used without requiring the caller to +// explicitly free memory. Ensure that a pointer to this structure is held until +// the underlying C memory is no longer needed. type CCertificate struct { - C unsafe.Pointer - internalCertificate C.ZT_Certificate - internalSubjectIdentities []C.ZT_Certificate_Identity - internalSubjectNetworks []C.ZT_Certificate_Network - internalSubjectCertificates []uintptr - internalSubjectUpdateURLs []uintptr - internalSubjectUpdateURLsData [][]byte + C unsafe.Pointer } func certificateErrorToError(cerr int) error { @@ -279,103 +279,94 @@ func NewCertificateFromCCertificate(ccptr unsafe.Pointer) *Certificate { // // This will return nil if an error occurs, which would indicate an invalid C // structure or one with invalid values. -// -// The returned Go structure bundles this with some objects that have -// to be created to set their pointers in ZT_Certificate. It's easier to -// manage allocation of these in Go and bundle them so Go's GC will clean -// them up automatically when CCertificate is released. Only the 'C' field -// in CCertificate should be directly used. The rest are internal and are -// hidden outside the package. -// -// Ensure that Certificate is not modified until the generated C certificate -// is no longer in use. func (c *Certificate) CCertificate() *CCertificate { - var cc CCertificate - var ccC *C.ZT_Certificate - - cc.C = unsafe.Pointer(&cc.internalCertificate) - ccC = &cc.internalCertificate + var cc C.ZT_Certificate + var subjectIdentities []C.ZT_Certificate_Identity + var subjectNetworks []C.ZT_Certificate_Network + var subjectCertificates []uintptr + var subjectUpdateURLs []uintptr + var subjectUpdateURLsData [][]byte if len(c.SerialNo) == 48 { - copy((*[48]byte)(unsafe.Pointer(&ccC.serialNo[0]))[:], c.SerialNo) + copy((*[48]byte)(unsafe.Pointer(&cc.serialNo[0]))[:], c.SerialNo) } - ccC.flags = C.uint64_t(c.Flags) - ccC.timestamp = C.int64_t(c.Timestamp) - ccC.validity[0] = C.int64_t(c.Validity[0]) - ccC.validity[1] = C.int64_t(c.Validity[1]) + cc.flags = C.uint64_t(c.Flags) + cc.timestamp = C.int64_t(c.Timestamp) + cc.validity[0] = C.int64_t(c.Validity[0]) + cc.validity[1] = C.int64_t(c.Validity[1]) - ccC.subject.timestamp = C.int64_t(c.Subject.Timestamp) + cc.subject.timestamp = C.int64_t(c.Subject.Timestamp) if len(c.Subject.Identities) > 0 { - cc.internalSubjectIdentities = make([]C.ZT_Certificate_Identity, len(c.Subject.Identities)) + subjectIdentities = make([]C.ZT_Certificate_Identity, len(c.Subject.Identities)) for i, id := range c.Subject.Identities { if id.Identity == nil || !id.Identity.initCIdentityPtr() { return nil } - cc.internalSubjectIdentities[i].identity = id.Identity.cid + subjectIdentities[i].identity = id.Identity.cid if id.Locator != nil { - cc.internalSubjectIdentities[i].locator = id.Locator.cl + subjectIdentities[i].locator = id.Locator.cl } } - ccC.subject.identities = &cc.internalSubjectIdentities[0] - ccC.subject.identityCount = C.uint(len(c.Subject.Identities)) + cc.subject.identities = &subjectIdentities[0] + cc.subject.identityCount = C.uint(len(c.Subject.Identities)) } if len(c.Subject.Networks) > 0 { - cc.internalSubjectNetworks = make([]C.ZT_Certificate_Network, len(c.Subject.Networks)) + subjectNetworks = make([]C.ZT_Certificate_Network, len(c.Subject.Networks)) for i, n := range c.Subject.Networks { - cc.internalSubjectNetworks[i].id = C.uint64_t(n.ID) - cc.internalSubjectNetworks[i].controller.address = C.uint64_t(n.Controller.Address) + subjectNetworks[i].id = C.uint64_t(n.ID) + subjectNetworks[i].controller.address = C.uint64_t(n.Controller.Address) if len(n.Controller.Hash) == 48 { - copy((*[48]byte)(unsafe.Pointer(&cc.internalSubjectNetworks[i].controller.hash[0]))[:], n.Controller.Hash) + copy((*[48]byte)(unsafe.Pointer(&subjectNetworks[i].controller.hash[0]))[:], n.Controller.Hash) } } - ccC.subject.networks = &cc.internalSubjectNetworks[0] - ccC.subject.networkCount = C.uint(len(c.Subject.Networks)) + cc.subject.networks = &subjectNetworks[0] + cc.subject.networkCount = C.uint(len(c.Subject.Networks)) } if len(c.Subject.Certificates) > 0 { - cc.internalSubjectCertificates = make([]uintptr, len(c.Subject.Certificates)) + subjectCertificates = make([]uintptr, len(c.Subject.Certificates)) for i, cert := range c.Subject.Certificates { if len(cert) != 48 { return nil } - cc.internalSubjectCertificates[i] = uintptr(unsafe.Pointer(&cert[0])) + subjectCertificates[i] = uintptr(unsafe.Pointer(&cert[0])) } - ccC.subject.certificates = (**C.uint8_t)(unsafe.Pointer(&cc.internalSubjectCertificates[0])) - ccC.subject.certificateCount = C.uint(len(c.Subject.Certificates)) + cc.subject.certificates = (**C.uint8_t)(unsafe.Pointer(&subjectCertificates[0])) + cc.subject.certificateCount = C.uint(len(c.Subject.Certificates)) } if len(c.Subject.UpdateURLs) > 0 { - cc.internalSubjectUpdateURLs = make([]uintptr, len(c.Subject.UpdateURLs)) - cc.internalSubjectUpdateURLsData = make([][]byte, len(c.Subject.UpdateURLs)) + subjectUpdateURLs = make([]uintptr, len(c.Subject.UpdateURLs)) + subjectUpdateURLsData = make([][]byte, len(c.Subject.UpdateURLs)) for i, u := range c.Subject.UpdateURLs { - cc.internalSubjectUpdateURLsData[i] = stringAsZeroTerminatedBytes(u) - cc.internalSubjectUpdateURLs[i] = uintptr(unsafe.Pointer(&cc.internalSubjectUpdateURLsData[0][0])) + subjectUpdateURLsData[i] = stringAsZeroTerminatedBytes(u) + subjectUpdateURLs[i] = uintptr(unsafe.Pointer(&subjectUpdateURLsData[0][0])) } - ccC.subject.updateURLs = (**C.char)(unsafe.Pointer(&cc.internalSubjectUpdateURLs[0])) - ccC.subject.updateURLCount = C.uint(len(c.Subject.UpdateURLs)) + cc.subject.updateURLs = (**C.char)(unsafe.Pointer(&subjectUpdateURLs[0])) + cc.subject.updateURLCount = C.uint(len(c.Subject.UpdateURLs)) } - cStrCopy(unsafe.Pointer(&ccC.subject.name.serialNo[0]), CertificateMaxStringLength+1, c.Subject.Name.SerialNo) - cStrCopy(unsafe.Pointer(&ccC.subject.name.commonName[0]), CertificateMaxStringLength+1, c.Subject.Name.CommonName) - cStrCopy(unsafe.Pointer(&ccC.subject.name.country[0]), CertificateMaxStringLength+1, c.Subject.Name.Country) - cStrCopy(unsafe.Pointer(&ccC.subject.name.organization[0]), CertificateMaxStringLength+1, c.Subject.Name.Organization) - cStrCopy(unsafe.Pointer(&ccC.subject.name.unit[0]), CertificateMaxStringLength+1, c.Subject.Name.Unit) - cStrCopy(unsafe.Pointer(&ccC.subject.name.locality[0]), CertificateMaxStringLength+1, c.Subject.Name.Locality) - cStrCopy(unsafe.Pointer(&ccC.subject.name.province[0]), CertificateMaxStringLength+1, c.Subject.Name.Province) - cStrCopy(unsafe.Pointer(&ccC.subject.name.streetAddress[0]), CertificateMaxStringLength+1, c.Subject.Name.StreetAddress) - cStrCopy(unsafe.Pointer(&ccC.subject.name.postalCode[0]), CertificateMaxStringLength+1, c.Subject.Name.PostalCode) - cStrCopy(unsafe.Pointer(&ccC.subject.name.email[0]), CertificateMaxStringLength+1, c.Subject.Name.Email) - cStrCopy(unsafe.Pointer(&ccC.subject.name.url[0]), CertificateMaxStringLength+1, c.Subject.Name.URL) - cStrCopy(unsafe.Pointer(&ccC.subject.name.host[0]), CertificateMaxStringLength+1, c.Subject.Name.Host) + cStrCopy(unsafe.Pointer(&cc.subject.name.serialNo[0]), CertificateMaxStringLength+1, c.Subject.Name.SerialNo) + cStrCopy(unsafe.Pointer(&cc.subject.name.commonName[0]), CertificateMaxStringLength+1, c.Subject.Name.CommonName) + cStrCopy(unsafe.Pointer(&cc.subject.name.country[0]), CertificateMaxStringLength+1, c.Subject.Name.Country) + cStrCopy(unsafe.Pointer(&cc.subject.name.organization[0]), CertificateMaxStringLength+1, c.Subject.Name.Organization) + cStrCopy(unsafe.Pointer(&cc.subject.name.unit[0]), CertificateMaxStringLength+1, c.Subject.Name.Unit) + cStrCopy(unsafe.Pointer(&cc.subject.name.locality[0]), CertificateMaxStringLength+1, c.Subject.Name.Locality) + cStrCopy(unsafe.Pointer(&cc.subject.name.province[0]), CertificateMaxStringLength+1, c.Subject.Name.Province) + cStrCopy(unsafe.Pointer(&cc.subject.name.streetAddress[0]), CertificateMaxStringLength+1, c.Subject.Name.StreetAddress) + cStrCopy(unsafe.Pointer(&cc.subject.name.postalCode[0]), CertificateMaxStringLength+1, c.Subject.Name.PostalCode) + cStrCopy(unsafe.Pointer(&cc.subject.name.email[0]), CertificateMaxStringLength+1, c.Subject.Name.Email) + cStrCopy(unsafe.Pointer(&cc.subject.name.url[0]), CertificateMaxStringLength+1, c.Subject.Name.URL) + cStrCopy(unsafe.Pointer(&cc.subject.name.host[0]), CertificateMaxStringLength+1, c.Subject.Name.Host) if len(c.Subject.UniqueID) > 0 { - ccC.subject.uniqueId = (*C.uint8_t)(unsafe.Pointer(&c.Subject.UniqueID[0])) - ccC.subject.uniqueIdSize = C.uint(len(c.Subject.UniqueID)) + cc.subject.uniqueId = (*C.uint8_t)(unsafe.Pointer(&c.Subject.UniqueID[0])) + cc.subject.uniqueIdSize = C.uint(len(c.Subject.UniqueID)) if len(c.Subject.UniqueIDProofSignature) > 0 { - ccC.subject.uniqueIdProofSignature = (*C.uint8_t)(unsafe.Pointer(&c.Subject.UniqueIDProofSignature[0])) - ccC.subject.uniqueIdProofSignatureSize = C.uint(len(c.Subject.UniqueIDProofSignature)) + cc.subject.uniqueIdProofSignature = (*C.uint8_t)(unsafe.Pointer(&c.Subject.UniqueIDProofSignature[0])) + cc.subject.uniqueIdProofSignatureSize = C.uint(len(c.Subject.UniqueIDProofSignature)) } } @@ -383,35 +374,42 @@ func (c *Certificate) CCertificate() *CCertificate { if !c.Issuer.initCIdentityPtr() { return nil } - ccC.issuer = c.Issuer.cid + cc.issuer = c.Issuer.cid } - cStrCopy(unsafe.Pointer(&ccC.issuerName.serialNo[0]), CertificateMaxStringLength+1, c.IssuerName.SerialNo) - cStrCopy(unsafe.Pointer(&ccC.issuerName.commonName[0]), CertificateMaxStringLength+1, c.IssuerName.CommonName) - cStrCopy(unsafe.Pointer(&ccC.issuerName.country[0]), CertificateMaxStringLength+1, c.IssuerName.Country) - cStrCopy(unsafe.Pointer(&ccC.issuerName.organization[0]), CertificateMaxStringLength+1, c.IssuerName.Organization) - cStrCopy(unsafe.Pointer(&ccC.issuerName.unit[0]), CertificateMaxStringLength+1, c.IssuerName.Unit) - cStrCopy(unsafe.Pointer(&ccC.issuerName.locality[0]), CertificateMaxStringLength+1, c.IssuerName.Locality) - cStrCopy(unsafe.Pointer(&ccC.issuerName.province[0]), CertificateMaxStringLength+1, c.IssuerName.Province) - cStrCopy(unsafe.Pointer(&ccC.issuerName.streetAddress[0]), CertificateMaxStringLength+1, c.IssuerName.StreetAddress) - cStrCopy(unsafe.Pointer(&ccC.issuerName.postalCode[0]), CertificateMaxStringLength+1, c.IssuerName.PostalCode) - cStrCopy(unsafe.Pointer(&ccC.issuerName.email[0]), CertificateMaxStringLength+1, c.IssuerName.Email) - cStrCopy(unsafe.Pointer(&ccC.issuerName.url[0]), CertificateMaxStringLength+1, c.IssuerName.URL) - cStrCopy(unsafe.Pointer(&ccC.issuerName.host[0]), CertificateMaxStringLength+1, c.IssuerName.Host) + cStrCopy(unsafe.Pointer(&cc.issuerName.serialNo[0]), CertificateMaxStringLength+1, c.IssuerName.SerialNo) + cStrCopy(unsafe.Pointer(&cc.issuerName.commonName[0]), CertificateMaxStringLength+1, c.IssuerName.CommonName) + cStrCopy(unsafe.Pointer(&cc.issuerName.country[0]), CertificateMaxStringLength+1, c.IssuerName.Country) + cStrCopy(unsafe.Pointer(&cc.issuerName.organization[0]), CertificateMaxStringLength+1, c.IssuerName.Organization) + cStrCopy(unsafe.Pointer(&cc.issuerName.unit[0]), CertificateMaxStringLength+1, c.IssuerName.Unit) + cStrCopy(unsafe.Pointer(&cc.issuerName.locality[0]), CertificateMaxStringLength+1, c.IssuerName.Locality) + cStrCopy(unsafe.Pointer(&cc.issuerName.province[0]), CertificateMaxStringLength+1, c.IssuerName.Province) + cStrCopy(unsafe.Pointer(&cc.issuerName.streetAddress[0]), CertificateMaxStringLength+1, c.IssuerName.StreetAddress) + cStrCopy(unsafe.Pointer(&cc.issuerName.postalCode[0]), CertificateMaxStringLength+1, c.IssuerName.PostalCode) + cStrCopy(unsafe.Pointer(&cc.issuerName.email[0]), CertificateMaxStringLength+1, c.IssuerName.Email) + cStrCopy(unsafe.Pointer(&cc.issuerName.url[0]), CertificateMaxStringLength+1, c.IssuerName.URL) + cStrCopy(unsafe.Pointer(&cc.issuerName.host[0]), CertificateMaxStringLength+1, c.IssuerName.Host) if len(c.ExtendedAttributes) > 0 { - ccC.extendedAttributes = (*C.uint8_t)(unsafe.Pointer(&c.ExtendedAttributes[0])) - ccC.extendedAttributesSize = C.uint(len(c.ExtendedAttributes)) + cc.extendedAttributes = (*C.uint8_t)(unsafe.Pointer(&c.ExtendedAttributes[0])) + cc.extendedAttributesSize = C.uint(len(c.ExtendedAttributes)) } - ccC.maxPathLength = C.uint(c.MaxPathLength) + cc.maxPathLength = C.uint(c.MaxPathLength) if len(c.Signature) > 0 { - ccC.signature = (*C.uint8_t)(unsafe.Pointer(&c.Signature[0])) - ccC.signatureSize = C.uint(len(c.Signature)) + cc.signature = (*C.uint8_t)(unsafe.Pointer(&c.Signature[0])) + cc.signatureSize = C.uint(len(c.Signature)) } - return &cc + // HACK: pass pointer to cc as uintptr to disable Go's protection against go pointers to + // go pointers, as the C function called here will make a deep clone and then we are going + // to throw away 'cc' and its components. + cc2 := &CCertificate{C: unsafe.Pointer(C._ZT_Certificate_clone2(C.uintptr_t(uintptr(unsafe.Pointer(&cc)))))} + runtime.SetFinalizer(cc2, func(obj interface{}) { + C.ZT_Certificate_delete((*C.ZT_Certificate)(obj.(*CCertificate).C)) + }) + return cc2 } // Marshal encodes this certificae as a byte array.