diff --git a/controller/DB.cpp b/controller/DB.cpp index 530582149..d4bd34eb4 100644 --- a/controller/DB.cpp +++ b/controller/DB.cpp @@ -49,6 +49,9 @@ void DB::initNetwork(nlohmann::json &network) }}; } if (!network.count("dns")) network["dns"] = nlohmann::json::array(); + if (!network.count("ssoEnabled")) network["ssoEnabled"] = false; + if (!network.count("clientId")) network["clientId"] = ""; + if (!network.count("authorizationEndpoint")) network["authorizationEndpoint"] = ""; network["objtype"] = "network"; } @@ -136,7 +139,6 @@ bool DB::get(const uint64_t networkId,nlohmann::json &network,const uint64_t mem if (m == nw->members.end()) return false; member = m->second; - updateMemberOnLoad(networkId, memberId, member); } return true; } @@ -160,7 +162,6 @@ bool DB::get(const uint64_t networkId,nlohmann::json &network,const uint64_t mem if (m == nw->members.end()) return false; member = m->second; - updateMemberOnLoad(networkId, memberId, member); } return true; } @@ -181,7 +182,6 @@ bool DB::get(const uint64_t networkId,nlohmann::json &network,std::vectorconfig; for(auto m=nw->members.begin();m!=nw->members.end();++m) { members.push_back(m->second); - updateMemberOnLoad(networkId, m->first, members.back()); } } return true; diff --git a/controller/DB.hpp b/controller/DB.hpp index a2edb9243..fb8ac6133 100644 --- a/controller/DB.hpp +++ b/controller/DB.hpp @@ -104,7 +104,7 @@ public: virtual void eraseNetwork(const uint64_t networkId) = 0; virtual void eraseMember(const uint64_t networkId,const uint64_t memberId) = 0; virtual void nodeIsOnline(const uint64_t networkId,const uint64_t memberId,const InetAddress &physicalAddress) = 0; - virtual void updateMemberOnLoad(const uint64_t networkId, const uint64_t memberId, nlohmann::json &member) {} + virtual std::string getSSOAuthURL(const nlohmann::json &member) { return ""; } inline void addListener(DB::ChangeListener *const listener) { diff --git a/controller/DBMirrorSet.cpp b/controller/DBMirrorSet.cpp index f19741bb3..341c2d8ca 100644 --- a/controller/DBMirrorSet.cpp +++ b/controller/DBMirrorSet.cpp @@ -125,6 +125,18 @@ bool DBMirrorSet::get(const uint64_t networkId,nlohmann::json &network,std::vect return false; } +std::string DBMirrorSet::getSSOAuthURL(const nlohmann::json &member) +{ + std::lock_guard l(_dbs_l); + for(auto d=_dbs.begin();d!=_dbs.end();++d) { + std::string url = (*d)->getSSOAuthURL(member); + if (!url.empty()) { + return url; + } + } + return ""; +} + void DBMirrorSet::networks(std::set &networks) { std::lock_guard l(_dbs_l); diff --git a/controller/DBMirrorSet.hpp b/controller/DBMirrorSet.hpp index 967cd9360..bf1106e8f 100644 --- a/controller/DBMirrorSet.hpp +++ b/controller/DBMirrorSet.hpp @@ -51,6 +51,8 @@ public: virtual void onNetworkMemberUpdate(const void *db,uint64_t networkId,uint64_t memberId,const nlohmann::json &member); virtual void onNetworkMemberDeauthorize(const void *db,uint64_t networkId,uint64_t memberId); + std::string getSSOAuthURL(const nlohmann::json &member); + inline void addDB(const std::shared_ptr &db) { db->addListener(this); diff --git a/controller/EmbeddedNetworkController.cpp b/controller/EmbeddedNetworkController.cpp index 5ab77f703..06bcef2a9 100644 --- a/controller/EmbeddedNetworkController.cpp +++ b/controller/EmbeddedNetworkController.cpp @@ -1325,17 +1325,12 @@ void EmbeddedNetworkController::_request( member["lastAuthorizedCredential"] = autoAuthCredential; } - - int64_t authenticationExpiryTime = -1; - if (!member["authenticationExpiryTime"].is_null()) { - authenticationExpiryTime = member["authenticationExpiryTime"]; - } - - std::string authenticationURL = ""; - if (!member["authenticationURL"].is_null()) { - authenticationURL = member["authenticationURL"]; - } + // Should we check SSO Stuff? + // If network is configured with SSO, and the member is not marked exempt: yes + // Otherwise no, we use standard auth logic. + bool networkSSOEnabled = OSUtils::jsonBool(network["ssoEnabled"], false); + bool memberSSOExempt = OSUtils::jsonBool(member["ssoExempt"], false); if (authorized) { // Update version info and meta-data if authorized and if this is a genuine request if (requestPacketId) { @@ -1361,14 +1356,20 @@ void EmbeddedNetworkController::_request( ms.identity = identity; } } - - if ((authenticationExpiryTime >= 0)&&(authenticationExpiryTime < now)) { - Dictionary<1024> authInfo; - if (!authenticationURL.empty()) - authInfo.add("aU", authenticationURL.c_str()); - _sender->ncSendError(nwid,requestPacketId,identity.address(),NetworkController::NC_ERROR_AUTHENTICATION_REQUIRED, authInfo.data(), authInfo.sizeBytes()); - return; + + if (networkSSOEnabled && !memberSSOExempt) { + int64_t authenticationExpiryTime = (int64_t)OSUtils::jsonInt(member["authenticationExpiryTime"], 0); + if ((authenticationExpiryTime == 0) || (authenticationExpiryTime < now)) { + Dictionary<1024> authInfo; + std::string authenticationURL = _db.getSSOAuthURL(member); + if (!authenticationURL.empty()) { + authInfo.add("aU", authenticationURL.c_str()); + } + _sender->ncSendError(nwid,requestPacketId,identity.address(),NetworkController::NC_ERROR_AUTHENTICATION_REQUIRED, authInfo.data(), authInfo.sizeBytes()); + return; + } } + } else { // If they are not authorized, STOP! DB::cleanMember(member); @@ -1406,8 +1407,11 @@ void EmbeddedNetworkController::_request( Utils::scopy(nc->name,sizeof(nc->name),OSUtils::jsonString(network["name"],"").c_str()); nc->mtu = std::max(std::min((unsigned int)OSUtils::jsonInt(network["mtu"],ZT_DEFAULT_MTU),(unsigned int)ZT_MAX_MTU),(unsigned int)ZT_MIN_MTU); nc->multicastLimit = (unsigned int)OSUtils::jsonInt(network["multicastLimit"],32ULL); - Utils::scopy(nc->authenticationURL, sizeof(nc->authenticationURL), authenticationURL.c_str()); - nc->authenticationExpiryTime = authenticationExpiryTime; + + // TODO: Decide what to do with these, or if to remove them + // they don't make sense here as is. + // Utils::scopy(nc->authenticationURL, sizeof(nc->authenticationURL), authenticationURL.c_str()); + // nc->authenticationExpiryTime = authenticationExpiryTime; std::string rtt(OSUtils::jsonString(member["remoteTraceTarget"],"")); if (rtt.length() == 10) { diff --git a/controller/PostgreSQL.cpp b/controller/PostgreSQL.cpp index 2c3200bc5..ebcdd7303 100644 --- a/controller/PostgreSQL.cpp +++ b/controller/PostgreSQL.cpp @@ -30,7 +30,7 @@ using json = nlohmann::json; namespace { -static const int DB_MINIMUM_VERSION = 19; +static const int DB_MINIMUM_VERSION = 20; static const char *_timestr() { @@ -309,81 +309,94 @@ void PostgreSQL::nodeIsOnline(const uint64_t networkId, const uint64_t memberId, } } -void PostgreSQL::updateMemberOnLoad(const uint64_t networkId, const uint64_t memberId, nlohmann::json &member) +std::string PostgreSQL::getSSOAuthURL(const nlohmann::json &member) { - - const uint64_t nwid = OSUtils::jsonIntHex(member["nwid"],0ULL); - const uint64_t id = OSUtils::jsonIntHex(member["id"],0ULL); - char nwids[24],ids[24]; - OSUtils::ztsnprintf(nwids, sizeof(nwids), "%.16llx", nwid); - OSUtils::ztsnprintf(ids, sizeof(ids), "%.10llx", id); + // NONCE is just a random character string. no semantic meaning + // state = HMAC SHA384 of Nonce based on shared sso key + // + // need nonce timeout in database? make sure it's used within X time + // X is 5 minutes for now. Make configurable later? + // + // how do we tell when a nonce is used? if auth_expiration_time is set + std::string networkId = member["nwid"]; + std::string memberId = member["id"]; + char authenticationURL[4096] = {0}; - fprintf(stderr, "PostgreSQL::updateMemberOnLoad: %s-%s\n", nwids, ids); + fprintf(stderr, "PostgreSQL::updateMemberOnLoad: %s-%s\n", networkId.c_str(), memberId.c_str()); bool have_auth = false; try { auto c = _pool->borrow(); pqxx::work w(*c->c); - pqxx::result r = w.exec_params("SELECT org.client_id, org.authorization_endpoint " + std::string nonce = ""; + + // find an unused nonce, if one exists. + pqxx::result r = w.exec_params("SELECT nonce FROM ztc_sso_expiry " + "WHERE network_id = $1 AND member_id = $2 AND " + "AND authentication_expiry_time IS NULL AND ((NOW() AT TIME ZONE 'UTC') <= nonce_expiry", + networkId, memberId); + + if (r.size() == 1) { + // we have an existing nonce. Use it + nonce = r.at(0)[0].as(); + } else if (r.empty()) { + // create a nonce + char randBuf[16] = {0}; + Utils::getSecureRandom(randBuf, 16); + char nonceBuf[256] = {0}; + Utils::hex(randBuf, sizeof(randBuf), nonceBuf); + nonce = std::string(nonceBuf); + + pqxx::result ir = w.exec_params0("INSERT INTO ztc_sso_expiry " + "(nonce, nonce_expiry, network_id, member_id) VALUES " + "($1, TO_TIMESTAMP($2::double precision/1000) $3, $4)", + nonce, OSUtils::now() + 300000, networkId, memberId); + } else { + // > 1 ?!? Thats an error! + fprintf(stderr, "> 1 unused nonce!\n"); + exit(6); + } + + r = w.exec_params("SELECT org.client_id, org.authorization_endpoint " "FROM ztc_network AS nw, ztc_org AS org " - "WHERE nw.id = $1 AND nw.sso_enabled = true AND org.owner_id = nw.owner_id", nwids); + "WHERE nw.id = $1 AND nw.sso_enabled = true AND org.owner_id = nw.owner_id", networkId); std::string client_id = ""; std::string authorization_endpoint = ""; if (r.size() == 1) { - // only one should exist - pqxx::row row = r.at(0); - client_id = row[0].as(); - authorization_endpoint = row[1].as(); + client_id = r.at(0)[0].as(); + authorization_endpoint = r.at(0)[1].as(); } else if (r.size() > 1) { - fprintf(stderr, "ERROR: More than one auth endpoint for an organization?!?!? NetworkID: %s\n", nwids); + fprintf(stderr, "ERROR: More than one auth endpoint for an organization?!?!? NetworkID: %s\n", networkId.c_str()); } // no catch all else because we don't actually care if no records exist here. just continue as normal. if ((!client_id.empty())&&(!authorization_endpoint.empty())) { - pqxx::row r2 = w.exec_params1( - "SELECT e.nonce, EXTRACT(EPOCH FROM e.authentication_expiry_time AT TIME ZONE 'UTC')*1000 as authentication_expiry_time" - "FROM ztc_sso_expiry e " - "WHERE e.network_id = $1 AND e.member_id = $2 " - "ORDER BY n.authentication_expiry_time DESC LIMIT 1", nwids, ids); + have_auth = true; - std::string nonce = r2[0].as(); - int64_t authentication_expiry_time = r2[0].as(); - if ((authentication_expiry_time >= 0)&&(!nonce.empty())) { - have_auth = true; - - uint8_t state[48]; - HMACSHA384(_ssoPsk, nonce.data(), (unsigned int)nonce.length(), state); - char state_hex[256]; - Utils::hex(state, 48, state_hex); - char authenticationURL[4096]; - const char *redirect_url = "redirect_uri=http%3A%2F%2Fmy.zerotier.com%2Fapi%2Fnetwork%2Fsso-auth"; // TODO: this should be configurable - OSUtils::ztsnprintf(authenticationURL, sizeof(authenticationURL), - "%s?response_type=id_token&response_mode=form_post&scope=openid+email+profile&redriect_uri=%s&nonce=%s&state=%s&client_id=%s", - authorization_endpoint.c_str(), - redirect_url, - nonce.c_str(), - state_hex, // NOTE: should these be URL escaped? Don't think there's a risk as they are not user definable. - client_id.c_str()); - - member["authenticationExpiryTime"] = authentication_expiry_time; - member["authenticationURL"] = authenticationURL; - } - } else { - member["authenticationExpiryTime"] = -1LL; - member["authenticationURL"] = ""; - } + uint8_t state[48]; + HMACSHA384(_ssoPsk, nonce.data(), (unsigned int)nonce.length(), state); + char state_hex[256]; + Utils::hex(state, 48, state_hex); + + const char *redirect_url = "redirect_uri=https%3A%2F%2Fmy.zerotier.com%2Fapi%2Fnetwork%2Fsso-auth"; // TODO: this should be configurable + OSUtils::ztsnprintf(authenticationURL, sizeof(authenticationURL), + "%s?response_type=id_token&response_mode=form_post&scope=openid+email+profile&redriect_uri=%s&nonce=%s&state=%s&client_id=%s", + authorization_endpoint.c_str(), + redirect_url, + nonce.c_str(), + state_hex, + client_id.c_str()); + } _pool->unborrow(c); - - } catch (sw::redis::Error &e) { - fprintf(stderr, "ERROR: Error updating member on load, in Redis: %s\n", e.what()); - exit(-1); } catch (std::exception &e) { fprintf(stderr, "ERROR: Error updating member on load: %s\n", e.what()); exit(-1); } + + return std::string(authenticationURL); } void PostgreSQL::initializeNetworks() @@ -398,13 +411,15 @@ void PostgreSQL::initializeNetworks() pqxx::work w{*c->c}; pqxx::result r = w.exec_params("SELECT id, EXTRACT(EPOCH FROM creation_time AT TIME ZONE 'UTC')*1000 as creation_time, capabilities, " "enable_broadcast, EXTRACT(EPOCH FROM last_modified AT TIME ZONE 'UTC')*1000 AS last_modified, mtu, multicast_limit, name, private, remote_trace_level, " - "remote_trace_target, revision, rules, tags, v4_assign_mode, v6_assign_mode FROM ztc_network " + "remote_trace_target, revision, rules, tags, v4_assign_mode, v6_assign_mode, sso_enabled FROM ztc_network " "WHERE deleted = false AND controller_id = $1", _myAddressStr); for (auto row = r.begin(); row != r.end(); row++) { json empty; json config; + initNetwork(config); + std::string nwid = row[0].as(); networkSet.insert(nwid); @@ -458,6 +473,7 @@ void PostgreSQL::initializeNetworks() config["tags"] = json::parse(row[13].as()); config["v4AssignMode"] = json::parse(row[14].as()); config["v6AssignMode"] = json::parse(row[15].as()); + config["ssoEnabled"] = row[16].as(); config["objtype"] = "network"; config["ipAssignmentPools"] = json::array(); config["routes"] = json::array(); @@ -514,6 +530,19 @@ void PostgreSQL::initializeNetworks() config["dns"] = obj; } + r2 = w.exec_params("SELECT org.client_id, org.authorization_endpoint " + "FROM ztc_network nw " + "INNER JOIN ztc_org org " + " ON org.owner_id = nw.owner_id " + "WHERE nw.id = $1 AND nw.sso_enabled = true", nwid); + + if (r2.size() == 1) { + // only one should exist + pqxx::row row = r.at(0); + config["clientId"] = row[0].as(); + config["authorizationEndpoint"] = row[1].as(); + } + _networkChanged(empty, config, false); fprintf(stderr, "Initialized Network: %s\n", nwid.c_str()); } @@ -549,7 +578,7 @@ void PostgreSQL::initializeMembers() " (EXTRACT(EPOCH FROM m.last_authorized_time AT TIME ZONE 'UTC')*1000)::bigint, " " (EXTRACT(EPOCH FROM m.last_deauthorized_time AT TIME ZONE 'UTC')*1000)::bigint, " " m.remote_trace_level, m.remote_trace_target, m.tags, m.v_major, m.v_minor, m.v_rev, m.v_proto, " - " m.no_auto_assign_ips, m.revision " + " m.no_auto_assign_ips, m.revision, sso_exempt " "FROM ztc_member m " "INNER JOIN ztc_network n " " ON n.id = m.network_id " @@ -559,6 +588,8 @@ void PostgreSQL::initializeMembers() json empty; json config; + initMember(config); + std::string memberId = row[0].as(); std::string networkId = row[1].as(); @@ -627,13 +658,28 @@ void PostgreSQL::initializeMembers() config["revision"] = 0ULL; //fprintf(stderr, "Error updating revision (member): %s\n", PQgetvalue(res, i, 17)); } + config["ssoExempt"] = row[18].as(); + + config["authenticationExpiryTime"] = 0LL; + pqxx::result authRes = w.exec_params( + "SELECT (EXTRACT(EPOCH FROM e.authentication_expiry_time)*1000)::bigint " + "FROM ztc_sso_expiry e " + "INNER JOIN ztc_network n " + " ON n.id = e.network_id " + "WHERE e.network_id = $1 AND e.member_id = $2 AND n.sso_enabled = TRUE " + "ORDER BY e.authentication_expiry_time LIMIT 1", networkId, memberId); + + if (authRes.size() == 1) { + // there is an expiry time record + config["authenticationExpiryTime"] = authRes.at(0)[0].as(); + } + config["objtype"] = "member"; config["ipAssignments"] = json::array(); - pqxx::result r2 = w.exec("SELECT DISTINCT address " + pqxx::result r2 = w.exec_params("SELECT DISTINCT address " "FROM ztc_member_ip_assignment " - "WHERE member_id = "+w.quote(memberId)+" AND network_id = "+w.quote(networkId)); - + "WHERE member_id = $1 AND network_id = $2", memberId, networkId); for (auto row2 = r2.begin(); row2 != r2.end(); row2++) { std::string ipaddr = row2[0].as(); diff --git a/controller/PostgreSQL.hpp b/controller/PostgreSQL.hpp index 49fa9a6fa..5e4b32ddd 100644 --- a/controller/PostgreSQL.hpp +++ b/controller/PostgreSQL.hpp @@ -107,7 +107,7 @@ public: virtual void eraseNetwork(const uint64_t networkId); virtual void eraseMember(const uint64_t networkId, const uint64_t memberId); virtual void nodeIsOnline(const uint64_t networkId, const uint64_t memberId, const InetAddress &physicalAddress); - virtual void updateMemberOnLoad(const uint64_t networkId, const uint64_t memberId, nlohmann::json &member); + virtual std::string getSSOAuthURL(const nlohmann::json &member); protected: struct _PairHasher