From 2851a9577c88af465e269988a4f35cb23111bb0f Mon Sep 17 00:00:00 2001
From: Grant Limberg <grant.limberg@zerotier.com>
Date: Wed, 21 Oct 2020 14:18:04 -0700
Subject: [PATCH] JNI for dns configuration

---
 java/jni/Android.mk                           |   1 +
 java/jni/ZT_jniarray.cpp                      | 111 ++++++++++++++++++
 java/jni/ZT_jniarray.h                        |  60 ++++++++++
 java/jni/ZT_jniutils.cpp                      |  79 +++++++++++++
 java/jni/ZT_jniutils.h                        |   2 +
 .../zerotier/sdk/VirtualNetworkConfig.java    |   4 +
 .../com/zerotier/sdk/VirtualNetworkDNS.java   |  25 ++++
 .../com/zerotier/sdk/VirtualNetworkRoute.java |   3 +-
 8 files changed, 283 insertions(+), 2 deletions(-)
 create mode 100644 java/jni/ZT_jniarray.cpp
 create mode 100644 java/jni/ZT_jniarray.h
 create mode 100644 java/src/com/zerotier/sdk/VirtualNetworkDNS.java

diff --git a/java/jni/Android.mk b/java/jni/Android.mk
index 7aa9f41b7..4b430548a 100644
--- a/java/jni/Android.mk
+++ b/java/jni/Android.mk
@@ -58,6 +58,7 @@ LOCAL_SRC_FILES := \
 # JNI Files
 LOCAL_SRC_FILES += \
 	com_zerotierone_sdk_Node.cpp \
+	ZT_jniarray.cpp \
 	ZT_jniutils.cpp \
 	ZT_jnilookup.cpp
 
diff --git a/java/jni/ZT_jniarray.cpp b/java/jni/ZT_jniarray.cpp
new file mode 100644
index 000000000..24ae97c71
--- /dev/null
+++ b/java/jni/ZT_jniarray.cpp
@@ -0,0 +1,111 @@
+//
+// Created by Grant Limberg on 10/21/20.
+//
+
+#include "ZT_jniarray.h"
+#include <vector>
+#include <string>
+
+jclass java_util_ArrayList;
+jmethodID java_util_ArrayList_;
+jmethodID java_util_ArrayList_size;
+jmethodID java_util_ArrayList_get;
+jmethodID java_util_ArrayList_add;
+
+void InitListJNI(JNIEnv* env) {
+    java_util_ArrayList      = static_cast<jclass>(env->NewGlobalRef(env->FindClass("java/util/ArrayList")));
+    java_util_ArrayList_     = env->GetMethodID(java_util_ArrayList, "<init>", "(I)V");
+    java_util_ArrayList_size = env->GetMethodID (java_util_ArrayList, "size", "()I");
+    java_util_ArrayList_get  = env->GetMethodID(java_util_ArrayList, "get", "(I)Ljava/lang/Object;");
+    java_util_ArrayList_add  = env->GetMethodID(java_util_ArrayList, "add", "(Ljava/lang/Object;)Z");
+}
+
+jclass ListJNI::getListClass(JNIEnv* env) {
+    jclass jclazz = env->FindClass("java/util/List");
+    assert(jclazz != nullptr);
+    return jclazz;
+}
+
+jclass ListJNI::getArrayListClass(JNIEnv* env) {
+    jclass jclazz = env->FindClass("java/util/ArrayList");
+    assert(jclazz != nullptr);
+    return jclazz;
+}
+
+jclass ListJNI::getIteratorClass(JNIEnv* env) {
+    jclass jclazz = env->FindClass("java/util/Iterator");
+    assert(jclazz != nullptr);
+    return jclazz;
+}
+
+jmethodID ListJNI::getIteratorMethod(JNIEnv* env) {
+    static jmethodID mid = env->GetMethodID(
+            getListClass(env), "iterator", "()Ljava/util/Iterator;");
+    assert(mid != nullptr);
+    return mid;
+}
+
+jmethodID ListJNI::getHasNextMethod(JNIEnv* env) {
+    static jmethodID mid = env->GetMethodID(
+            getIteratorClass(env), "hasNext", "()Z");
+    assert(mid != nullptr);
+    return mid;
+}
+
+jmethodID ListJNI::getNextMethod(JNIEnv* env) {
+    static jmethodID mid = env->GetMethodID(
+            getIteratorClass(env), "next", "()Ljava/lang/Object;");
+    assert(mid != nullptr);
+    return mid;
+}
+
+jmethodID ListJNI::getArrayListConstructorMethodId(JNIEnv* env, jclass jclazz) {
+    static jmethodID mid = env->GetMethodID(
+            jclazz, "<init>", "(I)V");
+    assert(mid != nullptr);
+    return mid;
+}
+
+jmethodID ListJNI::getListAddMethodId(JNIEnv* env) {
+    static jmethodID mid = env->GetMethodID(
+            getListClass(env), "add", "(Ljava/lang/Object;)Z");
+    assert(mid != nullptr);
+    return mid;
+}
+
+jclass ByteJNI::getByteClass(JNIEnv* env) {
+    jclass jclazz = env->FindClass("java/lang/Byte");
+    assert(jclazz != nullptr);
+    return jclazz;
+}
+
+jmethodID ByteJNI::getByteValueMethod(JNIEnv* env) {
+    static jmethodID mid = env->GetMethodID(
+            getByteClass(env), "byteValue", "()B");
+    assert(mid != nullptr);
+    return mid;
+}
+
+jobject cppToJava(JNIEnv* env, std::vector<std::string> vector) {
+    jobject result = env->NewObject(java_util_ArrayList, java_util_ArrayList_, vector.size());
+    for (std::string s: vector) {
+        jstring element = env->NewStringUTF(s.c_str());
+        env->CallBooleanMethod(result, java_util_ArrayList_add, element);
+        env->DeleteLocalRef(element);
+    }
+    return result;
+}
+
+std::vector<std::string> javaToCpp(JNIEnv* env, jobject arrayList) {
+    jint len = env->CallIntMethod(arrayList, java_util_ArrayList_size);
+    std::vector<std::string> result;
+    result.reserve(len);
+    for (jint i=0; i<len; i++) {
+        jstring element = static_cast<jstring>(env->CallObjectMethod(arrayList, java_util_ArrayList_get, i));
+        const char* pchars = env->GetStringUTFChars(element, nullptr);
+        result.emplace_back(pchars);
+        env->ReleaseStringUTFChars(element, pchars);
+        env->DeleteLocalRef(element);
+    }
+    return result;
+}
diff --git a/java/jni/ZT_jniarray.h b/java/jni/ZT_jniarray.h
new file mode 100644
index 000000000..d93c87b9c
--- /dev/null
+++ b/java/jni/ZT_jniarray.h
@@ -0,0 +1,60 @@
+//
+// Created by Grant Limberg on 10/21/20.
+//
+
+#ifndef ZEROTIERANDROID_ZT_JNIARRAY_H
+#define ZEROTIERANDROID_ZT_JNIARRAY_H
+
+#include <jni.h>
+#include <vector>
+#include <string>
+
+extern jclass java_util_ArrayList;
+extern jmethodID java_util_ArrayList_;
+extern jmethodID java_util_ArrayList_size;
+extern jmethodID java_util_ArrayList_get;
+extern jmethodID java_util_ArrayList_add;
+
+void InitListJNI(JNIEnv* env);
+
+class ListJNI {
+public:
+    // Get the java class id of java.util.List.
+    static jclass getListClass(JNIEnv* env);
+
+    // Get the java class id of java.util.ArrayList.
+    static jclass getArrayListClass(JNIEnv* env);
+
+    // Get the java class id of java.util.Iterator.
+    static jclass getIteratorClass(JNIEnv* env);
+
+    // Get the java method id of java.util.List.iterator().
+    static jmethodID getIteratorMethod(JNIEnv* env);
+
+    // Get the java method id of java.util.Iterator.hasNext().
+    static jmethodID getHasNextMethod(JNIEnv* env);
+
+    // Get the java method id of java.util.Iterator.next().
+    static jmethodID getNextMethod(JNIEnv* env);
+
+    // Get the java method id of arrayList constructor.
+    static jmethodID getArrayListConstructorMethodId(JNIEnv* env, jclass jclazz);
+
+    // Get the java method id of java.util.List.add().
+    static jmethodID getListAddMethodId(JNIEnv* env);
+};
+
+class ByteJNI {
+public:
+    // Get the java class id of java.lang.Byte.
+    static jclass getByteClass(JNIEnv* env);
+
+    // Get the java method id of java.lang.Byte.byteValue.
+    static jmethodID getByteValueMethod(JNIEnv* env);
+};
+
+jobject cppToJava(JNIEnv* env, std::vector<std::string> vector);
+
+std::vector<std::string> javaToCpp(JNIEnv* env, jobject arrayList);
+
+#endif //ZEROTIERANDROID_ZT_JNIARRAY_H
diff --git a/java/jni/ZT_jniutils.cpp b/java/jni/ZT_jniutils.cpp
index 3f7047795..c050c07d9 100644
--- a/java/jni/ZT_jniutils.cpp
+++ b/java/jni/ZT_jniutils.cpp
@@ -18,9 +18,16 @@
 
 #include "ZT_jniutils.h"
 #include "ZT_jnilookup.h"
+#include "ZT_jniarray.h"
+
 #include <string>
 #include <assert.h>
 
+#include <arpa/inet.h>
+#include <netinet/in.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+
 extern JniLookup lookup;
 
 #ifdef __cplusplus
@@ -623,6 +630,7 @@ jobject newNetworkConfig(JNIEnv *env, const ZT_VirtualNetworkConfig &vnetConfig)
     jfieldID netconfRevisionField = NULL;
     jfieldID assignedAddressesField = NULL;
     jfieldID routesField = NULL;
+    jfieldID dnsField = NULL;
 
     vnetConfigClass = lookup.findClass("com/zerotier/sdk/VirtualNetworkConfig");
     if(vnetConfigClass == NULL)
@@ -739,6 +747,13 @@ jobject newNetworkConfig(JNIEnv *env, const ZT_VirtualNetworkConfig &vnetConfig)
         return NULL;
     }
 
+    dnsField = lookup.findField(vnetConfigClass, "dns", "Lcom/zerotier/sdk/VirtualNetworkDNS");
+    if(env->ExceptionCheck() || dnsField == NULL)
+    {
+        LOGE("Error getting DNS field");
+        return NULL;
+    }
+
     env->SetLongField(vnetConfigObj, nwidField, vnetConfig.nwid);
     env->SetLongField(vnetConfigObj, macField, vnetConfig.mac);
     jstring nameStr = env->NewStringUTF(vnetConfig.name);
@@ -824,6 +839,10 @@ jobject newNetworkConfig(JNIEnv *env, const ZT_VirtualNetworkConfig &vnetConfig)
 
     env->SetObjectField(vnetConfigObj, routesField, routesArrayObj);
 
+    jobject dnsObj = newVirtualNetworkDNS(env, vnetConfig.dns);
+    if (dnsObj != NULL) {
+        env->SetObjectField(vnetConfigObj, dnsField, dnsObj);
+    }
     return vnetConfigObj;
 }
 
@@ -947,6 +966,66 @@ jobject newVirtualNetworkRoute(JNIEnv *env, const ZT_VirtualNetworkRoute &route)
     return routeObj;
 }
 
+jobject newVirtualNetworkDNS(JNIEnv *env, const ZT_VirtualNetworkDNS &dns)
+{
+    jclass virtualNetworkDNSClass = NULL;
+    jmethodID dnsConstructor = NULL;
+
+    virtualNetworkDNSClass = lookup.findClass("com/zerotier/sdk/VirtualNetworkDNS");
+    if (env->ExceptionCheck() || virtualNetworkDNSClass == NULL) {
+        return NULL;
+    }
+
+    dnsConstructor = lookup.findMethod(virtualNetworkDNSClass, "<init>", "()V");
+    if(env->ExceptionCheck() || dnsConstructor == NULL) {
+        return NULL;
+    }
+
+    jobject dnsObj = env->NewObject(virtualNetworkDNSClass, dnsConstructor);
+    if(env->ExceptionCheck() || dnsObj == NULL) {
+        return NULL;
+    }
+
+    jfieldID domainField = NULL;
+    jfieldID serversField = NULL;
+
+    domainField = lookup.findField(virtualNetworkDNSClass, "domain", "Ljava/lang/String;");
+    if(env->ExceptionCheck() || domainField == NULL)
+    {
+        return NULL;
+    }
+
+    serversField = lookup.findField(virtualNetworkDNSClass, "servers", "[Ljava/net/InetSocketAddress;");
+    if(env->ExceptionCheck() || serversField == NULL) {
+        return NULL;
+    }
+
+    if (strlen(dns.domain) > 0) {
+        InitListJNI(env);
+        jstring domain = env->NewStringUTF(dns.domain);
+
+        jobject addrArray = env->NewObject(java_util_ArrayList, java_util_ArrayList_, 0);
+
+        struct sockaddr_storage nullAddr;
+        memset(&nullAddr, 0, sizeof(struct sockaddr_storage));
+        for(int i = 0; i < ZT_MAX_DNS_SERVERS; ++i) {
+            struct sockaddr_storage tmp = dns.server_addr[i];
+
+            if (memcmp(&tmp, &nullAddr, sizeof(struct sockaddr_storage)) != 0) {
+                jobject addr = newInetSocketAddress(env, tmp);
+                env->CallBooleanMethod(addrArray, java_util_ArrayList_add, addr);
+                env->DeleteLocalRef(addr);
+            }
+        }
+
+        env->SetObjectField(dnsObj, domainField, domain);
+        env->SetObjectField(dnsObj, serversField, addrArray);
+
+        return dnsObj;
+    }
+    return NULL;
+}
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/java/jni/ZT_jniutils.h b/java/jni/ZT_jniutils.h
index 56b63179e..3e81b934d 100644
--- a/java/jni/ZT_jniutils.h
+++ b/java/jni/ZT_jniutils.h
@@ -76,6 +76,8 @@ jobject newVersion(JNIEnv *env, int major, int minor, int rev);
 
 jobject newVirtualNetworkRoute(JNIEnv *env, const ZT_VirtualNetworkRoute &route);
 
+jobject newVirtualNetworkDNS(JNIEnv *env, const ZT_VirtualNetworkDNS &dns);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/java/src/com/zerotier/sdk/VirtualNetworkConfig.java b/java/src/com/zerotier/sdk/VirtualNetworkConfig.java
index bb4e07110..b5e9041d7 100644
--- a/java/src/com/zerotier/sdk/VirtualNetworkConfig.java
+++ b/java/src/com/zerotier/sdk/VirtualNetworkConfig.java
@@ -56,6 +56,7 @@ public final class VirtualNetworkConfig implements Comparable<VirtualNetworkConf
     private long netconfRevision;
     private InetSocketAddress[] assignedAddresses;
     private VirtualNetworkRoute[] routes;
+    private VirtualNetworkDNS dns;
 
     private VirtualNetworkConfig() {
 
@@ -161,6 +162,7 @@ public final class VirtualNetworkConfig implements Comparable<VirtualNetworkConf
                this.broadcastEnabled == cfg.broadcastEnabled &&
                this.portError == cfg.portError &&
                this.enabled == cfg.enabled &&
+               this.dns.equals(cfg.dns) &&
                aaEqual && routesEqual;
     }
 
@@ -278,4 +280,6 @@ public final class VirtualNetworkConfig implements Comparable<VirtualNetworkConf
      * @return
      */
     public final VirtualNetworkRoute[] routes() { return routes; }
+
+    public final VirtualNetworkDNS dns() { return dns; }
 }
diff --git a/java/src/com/zerotier/sdk/VirtualNetworkDNS.java b/java/src/com/zerotier/sdk/VirtualNetworkDNS.java
new file mode 100644
index 000000000..f700218b9
--- /dev/null
+++ b/java/src/com/zerotier/sdk/VirtualNetworkDNS.java
@@ -0,0 +1,25 @@
+/*
+ * ZeroTier One - Network Virtualization Everywhere
+ * Copyright (C) 2011-2020  ZeroTier, Inc.  https://www.zerotier.com/
+ */
+
+package com.zerotier.sdk;
+
+import java.net.InetSocketAddress;
+import java.util.ArrayList;
+
+public class VirtualNetworkDNS implements Comparable<VirtualNetworkDNS> {
+    private String domain;
+    private ArrayList<InetSocketAddress> servers;
+
+    public VirtualNetworkDNS() {}
+
+    public boolean equals(VirtualNetworkDNS o) {
+        return domain.equals(o.domain) && servers.equals(o.servers);
+    }
+
+    @Override
+    public int compareTo(VirtualNetworkDNS o) {
+        return domain.compareTo(o.domain);
+    }
+}
diff --git a/java/src/com/zerotier/sdk/VirtualNetworkRoute.java b/java/src/com/zerotier/sdk/VirtualNetworkRoute.java
index 51bdfef33..8dd700c09 100644
--- a/java/src/com/zerotier/sdk/VirtualNetworkRoute.java
+++ b/java/src/com/zerotier/sdk/VirtualNetworkRoute.java
@@ -103,7 +103,6 @@ public final class VirtualNetworkRoute implements Comparable<VirtualNetworkRoute
             viaEquals = via.toString().equals(other.via.toString());
         }
 
-        return viaEquals &&
-                viaEquals;
+        return viaEquals && targetEquals;
     }
 }