From de4b9e9a16eecab6b731c7b51cb2d08e112a3044 Mon Sep 17 00:00:00 2001
From: Grant Limberg <grant.limberg@zerotier.com>
Date: Wed, 29 Mar 2017 12:52:29 -0700
Subject: [PATCH] Added path checking interface for Java

---
 java/jni/com_zerotierone_sdk_Node.cpp      | 178 +++++++++++++++++++++
 java/src/com/zerotier/sdk/Node.java        |   5 +-
 java/src/com/zerotier/sdk/PathChecker.java |  45 ++++++
 3 files changed, 227 insertions(+), 1 deletion(-)
 create mode 100644 java/src/com/zerotier/sdk/PathChecker.java

diff --git a/java/jni/com_zerotierone_sdk_Node.cpp b/java/jni/com_zerotierone_sdk_Node.cpp
index 47bbc40e2..307add395 100644
--- a/java/jni/com_zerotierone_sdk_Node.cpp
+++ b/java/jni/com_zerotierone_sdk_Node.cpp
@@ -56,6 +56,7 @@ namespace {
             , eventListener(NULL)
             , frameListener(NULL)
             , configListener(NULL)
+            , pathChecker(NULL)
             , callbacks(NULL)
         {
             callbacks = (ZT_Node_Callbacks*)malloc(sizeof(ZT_Node_Callbacks));
@@ -73,6 +74,7 @@ namespace {
             env->DeleteGlobalRef(eventListener);
             env->DeleteGlobalRef(frameListener);
             env->DeleteGlobalRef(configListener);
+            env->DeleteGlobalRef(pathChecker);
 
             free(callbacks);
             callbacks = NULL;
@@ -90,6 +92,7 @@ namespace {
         jobject eventListener;
         jobject frameListener;
         jobject configListener;
+        jobject pathChecker;
 
         ZT_Node_Callbacks *callbacks;
     };
@@ -487,6 +490,165 @@ namespace {
         return retval;
     }
 
+    int PathCheckFunction(ZT_Node *node,
+        void *userPtr,
+        void *threadPtr,
+        uint64_t address,
+        const struct sockaddr_storage *localAddress,
+        const struct sockaddr_storage *remoteAddress)
+    {
+        JniRef *ref = (JniRef*)userPtr;
+        assert(ref->node == node);
+
+        if(ref->pathChecker == NULL) {
+            return true;
+        }
+
+        JNIEnv *env = NULL;
+        ref->jvm->GetEnv((void**)&env, JNI_VERSION_1_6);
+
+        jclass pathCheckerClass = env->GetObjectClass(ref->pathChecker);
+        if(pathCheckerClass == NULL)
+        {
+            LOGE("Couldn't find class for PathChecker instance");
+            return true;
+        }
+
+        jmethodID pathCheckCallbackMethod = lookup.findMethod(pathCheckerClass,
+            "onPathCheck", "(JLjava/net/InetSocketAddress;Ljava/net/InetSocketAddress;)Z");
+        if(pathCheckCallbackMethod == NULL)
+        {
+            LOGE("Couldn't find onPathCheck method implementation");
+            return true;
+        }
+
+        jobject localAddressObj = NULL;
+        jobject remoteAddressObj = NULL;
+
+        if(memcmp(localAddress, &ZT_SOCKADDR_NULL, sizeof(sockaddr_storage)) != 0)
+        {
+            localAddressObj = newInetSocketAddress(env, *localAddress);
+        }
+        if(memcmp(remoteAddress, &ZT_SOCKADDR_NULL, sizeof(sockaddr_storage)) != 0)
+        {
+            remoteAddressObj = newInetSocketAddress(env, *remoteAddress);
+        }
+
+        return env->CallBooleanMethod(ref->pathChecker, pathCheckCallbackMethod, address, localAddressObj, remoteAddressObj);
+    }
+
+    int PathLookupFunction(ZT_Node *node,
+        void *userPtr,
+        void *threadPtr,
+        uint64_t address,
+        int ss_family,
+        struct sockaddr_storage *result)
+    {
+        JniRef *ref = (JniRef*)userPtr;
+        assert(ref->node == node);
+
+        if(ref->pathChecker == NULL) {
+            return false;
+        }
+
+        JNIEnv *env = NULL;
+        ref->jvm->GetEnv((void**)&env, JNI_VERSION_1_6);
+
+        jclass pathCheckerClass = env->GetObjectClass(ref->pathChecker);
+        if(pathCheckerClass == NULL)
+        {
+            LOGE("Couldn't find class for PathChecker instance");
+            return false;
+        }
+
+        jmethodID pathLookupMethod = lookup.findMethod(pathCheckerClass,
+            "onPathLookup", "(JI)Ljava/net/InetSocketAddress;");
+        if(pathLookupMethod == NULL) {
+            return false;
+        }
+
+        jobject sockAddressObject = env->CallObjectMethod(ref->pathChecker, pathLookupMethod, address, ss_family);
+        if(sockAddressObject == NULL)
+        {
+            LOGE("Unable to call onPathLookup implementation");
+            return false;
+        }
+
+        jclass inetSockAddressClass = env->GetObjectClass(sockAddressObject);
+        if(inetSockAddressClass == NULL)
+        {
+            LOGE("Unable to find InetSocketAddress class");
+            return false;
+        }
+
+        jmethodID getAddressMethod = lookup.findMethod(inetSockAddressClass, "getAddress", "()Ljava/net/InetAddress;");
+        if(getAddressMethod == NULL)
+        {
+            LOGE("Unable to find InetSocketAddress.getAddress() method");
+            return false;
+        }
+
+        jmethodID getPortMethod = lookup.findMethod(inetSockAddressClass, "getPort", "()I");
+        if(getPortMethod == NULL)
+        {
+            LOGE("Unable to find InetSocketAddress.getPort() method");
+            return false;
+        }
+
+        jint port = env->CallIntMethod(sockAddressObject, getPortMethod);
+        jobject addressObject = env->CallObjectMethod(sockAddressObject, getAddressMethod);
+        
+        jclass inetAddressClass = lookup.findClass("java/net/InetAddress");
+        if(inetAddressClass == NULL)
+        {
+            LOGE("Unable to find InetAddress class");
+            return false;
+        }
+
+        getAddressMethod = lookup.findMethod(inetAddressClass, "getAddress", "()[B");
+        if(getAddressMethod == NULL)
+        {
+            LOGE("Unable to find InetAddress.getAddress() method");
+            return false;
+        }
+
+        jbyteArray addressBytes = (jbyteArray)env->CallObjectMethod(addressObject, getAddressMethod);
+        if(addressBytes == NULL)
+        {
+            LOGE("Unable to call InetAddress.getBytes()");
+            return false;
+        }
+
+        int addressSize = env->GetArrayLength(addressBytes);
+        if(addressSize == 4)
+        {
+            // IPV4
+            sockaddr_in *addr = (sockaddr_in*)result;
+            addr->sin_family = AF_INET;
+            addr->sin_port = htons(port);
+            
+            void *data = env->GetPrimitiveArrayCritical(addressBytes, NULL);
+            memcpy(&addr->sin_addr, data, 4);
+            env->ReleasePrimitiveArrayCritical(addressBytes, data, 0);
+        }
+        else if (addressSize == 16)
+        {
+            // IPV6
+            sockaddr_in6 *addr = (sockaddr_in6*)result;
+            addr->sin6_family = AF_INET6;
+            addr->sin6_port = htons(port);
+            void *data = env->GetPrimitiveArrayCritical(addressBytes, NULL);
+            memcpy(&addr->sin6_addr, data, 16);
+            env->ReleasePrimitiveArrayCritical(addressBytes, data, 0);
+        }
+        else
+        {
+            return false;
+        }
+
+        return true;
+    }
+
     typedef std::map<uint64_t, JniRef*> NodeMap;
     static NodeMap nodeMap;
     ZeroTier::Mutex nodeMapMutex;
@@ -619,12 +781,28 @@ JNIEXPORT jobject JNICALL Java_com_zerotier_sdk_Node_node_1init(
     }
     ref->eventListener = env->NewGlobalRef(tmp);
 
+    fid = lookup.findField(
+        cls, "pathChecker", "Lcom/zerotier/sdk/PathChecker;");
+    if(fid == NULL)
+    {
+        return NULL;
+    }
+
+    tmp = env->GetObjectField(obj, fid);
+    if(tmp == NULL)
+    {
+        return NULL;
+    }
+    ref->pathChecker = env->NewGlobalRef(tmp);
+
     ref->callbacks->dataStoreGetFunction = &DataStoreGetFunction;
     ref->callbacks->dataStorePutFunction = &DataStorePutFunction;
     ref->callbacks->wirePacketSendFunction = &WirePacketSendFunction;
     ref->callbacks->virtualNetworkFrameFunction = &VirtualNetworkFrameFunctionCallback;
     ref->callbacks->virtualNetworkConfigFunction = &VirtualNetworkConfigFunctionCallback;
     ref->callbacks->eventCallback = &EventCallback;
+    ref->callbacks->pathCheckFunction = &PathCheckFunction;
+    ref->callbacks->pathLookupFunction = &PathLookupFunction;
 
     ZT_ResultCode rc = ZT_Node_new(
         &node,
diff --git a/java/src/com/zerotier/sdk/Node.java b/java/src/com/zerotier/sdk/Node.java
index 4bc6e1846..7b111f740 100644
--- a/java/src/com/zerotier/sdk/Node.java
+++ b/java/src/com/zerotier/sdk/Node.java
@@ -74,6 +74,7 @@ public class Node {
     private final EventListener eventListener;
     private final VirtualNetworkFrameListener frameListener;
     private final VirtualNetworkConfigListener configListener;
+    private final PathChecker pathChecker;
     
     /**
      * Create a new ZeroTier One node
@@ -95,7 +96,8 @@ public class Node {
                 PacketSender sender,
                 EventListener eventListener,
                 VirtualNetworkFrameListener frameListener,
-                VirtualNetworkConfigListener configListener) throws NodeException
+                VirtualNetworkConfigListener configListener,
+                PathChecker pathChecker) throws NodeException
 	{
         this.nodeId = now;
 
@@ -105,6 +107,7 @@ public class Node {
         this.eventListener = eventListener;
         this.frameListener = frameListener;
         this.configListener = configListener;
+        this.pathChecker = pathChecker;
 
         ResultCode rc = node_init(now);
         if(rc != ResultCode.RESULT_OK)
diff --git a/java/src/com/zerotier/sdk/PathChecker.java b/java/src/com/zerotier/sdk/PathChecker.java
new file mode 100644
index 000000000..3e02f1124
--- /dev/null
+++ b/java/src/com/zerotier/sdk/PathChecker.java
@@ -0,0 +1,45 @@
+/*
+ * ZeroTier One - Network Virtualization Everywhere
+ * Copyright (C) 2011-2017  ZeroTier, Inc.  https://www.zerotier.com/
+ */
+
+package com.zerotier.sdk;
+
+import java.net.InetSocketAddress;
+
+public interface PathChecker {
+    /**
+     * Callback to check whether a path should be used for ZeroTier traffic
+     *
+     * This function must return true if the path should be used.
+     *
+     * If no path check function is specified, ZeroTier will still exclude paths
+     * that overlap with ZeroTier-assigned and managed IP address blocks. But the
+     * use of a path check function is recommended to ensure that recursion does
+     * not occur in cases where addresses are assigned by the OS or managed by
+     * an out of band mechanism like DHCP. The path check function should examine
+     * all configured ZeroTier interfaces and check to ensure that the supplied
+     * addresses will not result in ZeroTier traffic being sent over a ZeroTier
+     * interface (recursion).
+     *
+     * Obviously this is not required in configurations where this can't happen,
+     * such as network containers or embedded.
+     *
+     * @param ztAddress ZeroTier address or 0 for none/any
+     * @param localAddress Local interface address
+     * @param remoteAddress remote address
+     */
+    boolean onPathCheck(long ztAddress, InetSocketAddress localAddress, InetSocketAddress remoteAddress);
+
+    /**
+     * Function to get physical addresses for ZeroTier peers
+     *
+     * If provided this function will be occasionally called to get physical
+     * addresses that might be tried to reach a ZeroTier address.
+     *
+     * @param ztAddress ZeroTier address (least significant 40 bits)
+     * @param ss_family desired address family or -1 for any
+     * @return address and port of ztAddress or null
+     */
+    InetSocketAddress onPathLookup(long ztAddress, int ss_family);
+}