fix and enable auth handler

This commit is contained in:
Grant Limberg 2023-04-27 09:01:18 -07:00
parent 9316c21631
commit a16d9abb79
No known key found for this signature in database
GPG key ID: 8F2F97D3BE8D7735

View file

@ -11,6 +11,7 @@
*/ */
/****/ /****/
#include <exception>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
@ -1447,21 +1448,21 @@ public:
std::vector<std::string> noAuthEndpoints { "/sso", "/health" }; std::vector<std::string> noAuthEndpoints { "/sso", "/health" };
auto authCheck = [=] (const httplib::Request &req, httplib::Response &res) { auto authCheck = [=] (const httplib::Request &req, httplib::Response &res) {
std::string r = req.remote_addr + "/32"; char buf[64];
std::string r = req.remote_addr;
InetAddress remoteAddr(r.c_str()); InetAddress remoteAddr(r.c_str());
bool ipAllowed = false; bool ipAllowed = false;
bool isAuth = false; bool isAuth = false;
// If localhost, allow // If localhost, allow
if (remoteAddr.ipScope() != InetAddress::IP_SCOPE_LOOPBACK) { if (remoteAddr.ipScope() == InetAddress::IP_SCOPE_LOOPBACK) {
fprintf(stderr, "loopback address\n");
ipAllowed = true; ipAllowed = true;
} }
if (!ipAllowed) { if (!ipAllowed) {
for (auto i = _allowManagementFrom.begin(); i != _allowManagementFrom.end(); ++i) { for (auto i = _allowManagementFrom.begin(); i != _allowManagementFrom.end(); ++i) {
if (i->containsAddress(remoteAddr)) { if (i->containsAddress(remoteAddr)) {
fprintf(stderr, "ip in allowed range\n");
ipAllowed = true; ipAllowed = true;
break; break;
} }
@ -1470,7 +1471,6 @@ public:
if (ipAllowed) { if (ipAllowed) {
fprintf(stderr, "ip allowed\n");
// auto-pass endpoints in `noAuthEndpoints`. No auth token required // auto-pass endpoints in `noAuthEndpoints`. No auth token required
if (std::find(noAuthEndpoints.begin(), noAuthEndpoints.end(), req.path) != noAuthEndpoints.end()) { if (std::find(noAuthEndpoints.begin(), noAuthEndpoints.end(), req.path) != noAuthEndpoints.end()) {
isAuth = true; isAuth = true;
@ -1481,17 +1481,13 @@ public:
if (req.has_header("x-zt1-auth")) { if (req.has_header("x-zt1-auth")) {
std::string token = req.get_header_value("x-zt1-auth"); std::string token = req.get_header_value("x-zt1-auth");
if (token == _authToken) { if (token == _authToken) {
fprintf(stderr, "auth via header\n");
isAuth = true; isAuth = true;
} }
} else if (req.has_param("auth")) { } else if (req.has_param("auth")) {
std::string token = req.get_param_value("auth"); std::string token = req.get_param_value("auth");
if (token == _authToken) { if (token == _authToken) {
fprintf(stderr, "auth via param\n");
isAuth = true; isAuth = true;
} }
} else {
fprintf(stderr, "no auth header or parameter\n");
} }
} }
} }
@ -1504,7 +1500,7 @@ public:
return httplib::Server::HandlerResponse::Handled; return httplib::Server::HandlerResponse::Handled;
}; };
//_controlPlane.set_pre_routing_handler(authCheck);
_controlPlane.Get("/bond/show/([0-9a-fA-F]{10})", [this](const httplib::Request &req, httplib::Response &res) { _controlPlane.Get("/bond/show/([0-9a-fA-F]{10})", [this](const httplib::Request &req, httplib::Response &res) {
if (!_node->bondController()->inUse()) { if (!_node->bondController()->inUse()) {
@ -1959,10 +1955,26 @@ public:
} }
}); });
_controlPlane.set_exception_handler([](const httplib::Request &req, httplib::Response &res, std::exception_ptr ep) {
char buf[1024];
auto fmt = "{\"error\": %d, \"description\": \"%\"}";
try {
std::rethrow_exception(ep);
} catch (std::exception &e) {
snprintf(buf, sizeof(buf), fmt, 500, e.what());
} catch (...) {
snprintf(buf, sizeof(buf), fmt, 500, "Unknown Exception");
}
res.set_content(buf, "application/json");
res.status = 500;
});
if (_controller) { if (_controller) {
// TODO: Wire up controller // TODO: Wire up controller
} }
_controlPlane.set_pre_routing_handler(authCheck);
_controlPlane.set_logger([](const httplib::Request &req, const httplib::Response &res) { _controlPlane.set_logger([](const httplib::Request &req, const httplib::Response &res) {
fprintf(stderr, "%s", http_log(req, res).c_str()); fprintf(stderr, "%s", http_log(req, res).c_str());
}); });