diff --git a/controller/EmbeddedNetworkController.cpp b/controller/EmbeddedNetworkController.cpp index e04268102..a9ebe91fb 100644 --- a/controller/EmbeddedNetworkController.cpp +++ b/controller/EmbeddedNetworkController.cpp @@ -1338,17 +1338,16 @@ void EmbeddedNetworkController::_request( int64_t authenticationExpiryTime = (int64_t)OSUtils::jsonInt(member["authenticationExpiryTime"], 0); fprintf(stderr, "authExpiryTime: %lld\n", authenticationExpiryTime); if ((authenticationExpiryTime == 0) || (authenticationExpiryTime < now)) { - - Dictionary<1024> authInfo; std::string authenticationURL = _db.getSSOAuthURL(member); if (!authenticationURL.empty()) { + Dictionary<1024> authInfo; authInfo.add("aU", authenticationURL.c_str()); + fprintf(stderr, "sending auth URL: %s\n", authenticationURL.c_str()); + DB::cleanMember(member); + _db.save(member,true); + _sender->ncSendError(nwid,requestPacketId,identity.address(),NetworkController::NC_ERROR_AUTHENTICATION_REQUIRED, authInfo.data(), authInfo.sizeBytes()); + return; } - fprintf(stderr, "sending auth URL: %s\n", authenticationURL.c_str()); - DB::cleanMember(member); - _db.save(member,true); - _sender->ncSendError(nwid,requestPacketId,identity.address(),NetworkController::NC_ERROR_AUTHENTICATION_REQUIRED, authInfo.data(), authInfo.sizeBytes()); - return; } } diff --git a/controller/PostgreSQL.cpp b/controller/PostgreSQL.cpp index aac3c007c..3828aa3c0 100644 --- a/controller/PostgreSQL.cpp +++ b/controller/PostgreSQL.cpp @@ -330,70 +330,73 @@ std::string PostgreSQL::getSSOAuthURL(const nlohmann::json &member) std::string nonce = ""; - // find an unused nonce, if one exists. - pqxx::result r = w.exec_params("SELECT nonce FROM ztc_sso_expiry " - "WHERE network_id = $1 AND member_id = $2 " - "AND authentication_expiry_time IS NULL AND ((NOW() AT TIME ZONE 'UTC') <= nonce_expiration)", - networkId, memberId); + // check if the member exists first. + pqxx::row count = w.exec_params1("SELECT count(id) FROM ztc_member WHERE id = $1 AND network_id = $2", memberId, networkId); + if (count[0].as() == 1) { + // find an unused nonce, if one exists. + pqxx::result r = w.exec_params("SELECT nonce FROM ztc_sso_expiry " + "WHERE network_id = $1 AND member_id = $2 " + "AND authentication_expiry_time IS NULL AND ((NOW() AT TIME ZONE 'UTC') <= nonce_expiration)", + networkId, memberId); - if (r.size() == 1) { - // we have an existing nonce. Use it - nonce = r.at(0)[0].as(); - } else if (r.empty()) { - // create a nonce - char randBuf[16] = {0}; - Utils::getSecureRandom(randBuf, 16); - char nonceBuf[256] = {0}; - Utils::hex(randBuf, sizeof(randBuf), nonceBuf); - nonce = std::string(nonceBuf); + if (r.size() == 1) { + // we have an existing nonce. Use it + nonce = r.at(0)[0].as(); + } else if (r.empty()) { + // create a nonce + char randBuf[16] = {0}; + Utils::getSecureRandom(randBuf, 16); + char nonceBuf[256] = {0}; + Utils::hex(randBuf, sizeof(randBuf), nonceBuf); + nonce = std::string(nonceBuf); - pqxx::result ir = w.exec_params0("INSERT INTO ztc_sso_expiry " - "(nonce, nonce_expiration, network_id, member_id) VALUES " - "($1, TO_TIMESTAMP($2::double precision/1000), $3, $4)", - nonce, OSUtils::now() + 300000, networkId, memberId); - } else { - // > 1 ?!? Thats an error! - fprintf(stderr, "> 1 unused nonce!\n"); - exit(6); + pqxx::result ir = w.exec_params0("INSERT INTO ztc_sso_expiry " + "(nonce, nonce_expiration, network_id, member_id) VALUES " + "($1, TO_TIMESTAMP($2::double precision/1000), $3, $4)", + nonce, OSUtils::now() + 300000, networkId, memberId); + } else { + // > 1 ?!? Thats an error! + fprintf(stderr, "> 1 unused nonce!\n"); + exit(6); + } + + r = w.exec_params("SELECT org.client_id, org.authorization_endpoint " + "FROM ztc_network AS nw, ztc_org AS org " + "WHERE nw.id = $1 AND nw.sso_enabled = true AND org.owner_id = nw.owner_id", networkId); + + std::string client_id = ""; + std::string authorization_endpoint = ""; + + if (r.size() == 1) { + client_id = r.at(0)[0].as(); + authorization_endpoint = r.at(0)[1].as(); + } else if (r.size() > 1) { + fprintf(stderr, "ERROR: More than one auth endpoint for an organization?!?!? NetworkID: %s\n", networkId.c_str()); + } + // no catch all else because we don't actually care if no records exist here. just continue as normal. + + if ((!client_id.empty())&&(!authorization_endpoint.empty())) { + have_auth = true; + + uint8_t state[48]; + HMACSHA384(_ssoPsk, nonce.data(), (unsigned int)nonce.length(), state); + char state_hex[256]; + Utils::hex(state, 48, state_hex); + + const char *redirect_url = "redirect_uri=https%3A%2F%2Fmy.zerotier.com%2Fapi%2Fnetwork%2Fsso-auth"; // TODO: this should be configurable + OSUtils::ztsnprintf(authenticationURL, sizeof(authenticationURL), + "%s?response_type=id_token&response_mode=form_post&scope=openid+email+profile&redriect_uri=%s&nonce=%s&state=%s&client_id=%s", + authorization_endpoint.c_str(), + redirect_url, + nonce.c_str(), + state_hex, + client_id.c_str()); + } } - r = w.exec_params("SELECT org.client_id, org.authorization_endpoint " - "FROM ztc_network AS nw, ztc_org AS org " - "WHERE nw.id = $1 AND nw.sso_enabled = true AND org.owner_id = nw.owner_id", networkId); - - std::string client_id = ""; - std::string authorization_endpoint = ""; - - if (r.size() == 1) { - client_id = r.at(0)[0].as(); - authorization_endpoint = r.at(0)[1].as(); - } else if (r.size() > 1) { - fprintf(stderr, "ERROR: More than one auth endpoint for an organization?!?!? NetworkID: %s\n", networkId.c_str()); - } - // no catch all else because we don't actually care if no records exist here. just continue as normal. - - if ((!client_id.empty())&&(!authorization_endpoint.empty())) { - have_auth = true; - - uint8_t state[48]; - HMACSHA384(_ssoPsk, nonce.data(), (unsigned int)nonce.length(), state); - char state_hex[256]; - Utils::hex(state, 48, state_hex); - - const char *redirect_url = "redirect_uri=https%3A%2F%2Fmy.zerotier.com%2Fapi%2Fnetwork%2Fsso-auth"; // TODO: this should be configurable - OSUtils::ztsnprintf(authenticationURL, sizeof(authenticationURL), - "%s?response_type=id_token&response_mode=form_post&scope=openid+email+profile&redriect_uri=%s&nonce=%s&state=%s&client_id=%s", - authorization_endpoint.c_str(), - redirect_url, - nonce.c_str(), - state_hex, - client_id.c_str()); - } - _pool->unborrow(c); } catch (std::exception &e) { fprintf(stderr, "ERROR: Error updating member on load: %s\n", e.what()); - exit(-1); } return std::string(authenticationURL);