diff --git a/controller/DB.hpp b/controller/DB.hpp index 10adbec10..64bd83af0 100644 --- a/controller/DB.hpp +++ b/controller/DB.hpp @@ -53,6 +53,7 @@ public: , ssoNonce() , ssoState() , ssoClientID() + , ssoProvider("default") {} bool enabled; @@ -64,6 +65,7 @@ public: std::string ssoNonce; std::string ssoState; std::string ssoClientID; + std::string ssoProvider; }; /** diff --git a/controller/PostgreSQL.cpp b/controller/PostgreSQL.cpp index ef1ba0ff9..183cb3c4b 100644 --- a/controller/PostgreSQL.cpp +++ b/controller/PostgreSQL.cpp @@ -34,7 +34,7 @@ using json = nlohmann::json; namespace { -static const int DB_MINIMUM_VERSION = 20; +static const int DB_MINIMUM_VERSION = 38; static const char *_timestr() { @@ -442,20 +442,29 @@ AuthInfo PostgreSQL::getSSOAuthInfo(const nlohmann::json &member, const std::str exit(7); } - r = w.exec_params("SELECT org.client_id, org.authorization_endpoint, org.issuer, org.sso_impl_version " - "FROM ztc_network AS nw, ztc_org AS org " - "WHERE nw.id = $1 AND nw.sso_enabled = true AND org.owner_id = nw.owner_id", networkId); + r = w.exec_params( + "SELECT oc.client_id, oc.authorization_endpoint, oc.issuer, oc.provider, oc.sso_impl_version " + "FROM ztc_network AS n " + "INNER JOIN ztc_org o " + " ON o.owner_id = n.owner_id " + "LEFT OUTER JOIN ztc_network_oidc_config noc " + " ON noc.network_id = n.id " + "LEFT OUTER JOIN ztc_oidc_config oc " + " ON noc.client_id = oc.client_id AND noc.org_id = o.org_id " + "WHERE n.id = $1 AND n.sso_enabled = true", networkId); std::string client_id = ""; std::string authorization_endpoint = ""; std::string issuer = ""; + std::string provider = ""; uint64_t sso_version = 0; if (r.size() == 1) { client_id = r.at(0)[0].as(); authorization_endpoint = r.at(0)[1].as(); issuer = r.at(0)[2].as(); - sso_version = r.at(0)[3].as(); + provider = r.at(0)[3].as(); + sso_version = r.at(0)[4].as(); } else if (r.size() > 1) { fprintf(stderr, "ERROR: More than one auth endpoint for an organization?!?!? NetworkID: %s\n", networkId.c_str()); } else { @@ -485,18 +494,20 @@ AuthInfo PostgreSQL::getSSOAuthInfo(const nlohmann::json &member, const std::str } else if (info.version == 1) { info.ssoClientID = client_id; info.issuerURL = issuer; + info.ssoProvider = provider; info.ssoNonce = nonce; info.ssoState = std::string(state_hex) + "_" +networkId; info.centralAuthURL = redirectURL; #ifdef ZT_DEBUG fprintf( stderr, - "ssoClientID: %s\nissuerURL: %s\nssoNonce: %s\nssoState: %s\ncentralAuthURL: %s\n", + "ssoClientID: %s\nissuerURL: %s\nssoNonce: %s\nssoState: %s\ncentralAuthURL: %s\nprovider: %s\n", info.ssoClientID.c_str(), info.issuerURL.c_str(), info.ssoNonce.c_str(), info.ssoState.c_str(), - info.centralAuthURL.c_str()); + info.centralAuthURL.c_str(), + provider.c_str()); #endif } } else { @@ -535,15 +546,21 @@ void PostgreSQL::initializeNetworks() std::unordered_set networkSet; char qbuf[2048] = {0}; - sprintf(qbuf, "SELECT n.id, (EXTRACT(EPOCH FROM n.creation_time AT TIME ZONE 'UTC')*1000)::bigint as creation_time, n.capabilities, " + sprintf(qbuf, + "SELECT n.id, (EXTRACT(EPOCH FROM n.creation_time AT TIME ZONE 'UTC')*1000)::bigint as creation_time, n.capabilities, " "n.enable_broadcast, (EXTRACT(EPOCH FROM n.last_modified AT TIME ZONE 'UTC')*1000)::bigint AS last_modified, n.mtu, n.multicast_limit, n.name, n.private, n.remote_trace_level, " - "n.remote_trace_target, n.revision, n.rules, n.tags, n.v4_assign_mode, n.v6_assign_mode, n.sso_enabled, (CASE WHEN n.sso_enabled THEN o.client_id ELSE NULL END) as client_id, " - "(CASE WHEN n.sso_enabled THEN o.authorization_endpoint ELSE NULL END) as authorization_endpoint, d.domain, d.servers, " + "n.remote_trace_target, n.revision, n.rules, n.tags, n.v4_assign_mode, n.v6_assign_mode, n.sso_enabled, (CASE WHEN n.sso_enabled THEN noc.client_id ELSE NULL END) as client_id, " + "(CASE WHEN n.sso_enabled THEN oc.authorization_endpoint ELSE NULL END) as authorization_endpoint, " + "(CASE WHEN n.sso_enabled THEN oc.provider ELSE NULL END) as provider, d.domain, d.servers, " "ARRAY(SELECT CONCAT(host(ip_range_start),'|', host(ip_range_end)) FROM ztc_network_assignment_pool WHERE network_id = n.id) AS assignment_pool, " "ARRAY(SELECT CONCAT(host(address),'/',bits::text,'|',COALESCE(host(via), 'NULL'))FROM ztc_network_route WHERE network_id = n.id) AS routes " "FROM ztc_network n " "LEFT OUTER JOIN ztc_org o " - " ON o.owner_id = n.owner_id " + " ON o.owner_id = n.owner_id " + "LEFT OUTER JOIN ztc_network_oidc_config noc " + " ON noc.network_id = n.id " + "LEFT OUTER JOIN ztc_oidc_config oc " + " ON noc.client_id = oc.client_id AND oc.org_id = o.org_id " "LEFT OUTER JOIN ztc_network_dns d " " ON d.network_id = n.id " "WHERE deleted = false AND controller_id = '%s'", _myAddressStr.c_str()); @@ -574,6 +591,7 @@ void PostgreSQL::initializeNetworks() , std::optional // ssoEnabled , std::optional // clientId , std::optional // authorizationEndpoint + , std::optional // ssoProvider , std::optional // domain , std::optional // servers , std::string // assignmentPoolString @@ -610,10 +628,11 @@ void PostgreSQL::initializeNetworks() std::optional ssoEnabled = std::get<16>(row); std::optional clientId = std::get<17>(row); std::optional authorizationEndpoint = std::get<18>(row); - std::optional dnsDomain = std::get<19>(row); - std::optional dnsServers = std::get<20>(row); - std::string assignmentPoolString = std::get<21>(row); - std::string routesString = std::get<22>(row); + std::optional ssoProvider = std::get<19>(row); + std::optional dnsDomain = std::get<20>(row); + std::optional dnsServers = std::get<21>(row); + std::string assignmentPoolString = std::get<22>(row); + std::string routesString = std::get<23>(row); config["id"] = nwid; config["nwid"] = nwid; @@ -638,6 +657,7 @@ void PostgreSQL::initializeNetworks() config["routes"] = json::array(); config["clientId"] = clientId.value_or(""); config["authorizationEndpoint"] = authorizationEndpoint.value_or(""); + config["provider"] = ssoProvider.value_or(""); networkSet.insert(nwid); diff --git a/include/ZeroTierOne.h b/include/ZeroTierOne.h index a12b3291b..a60116bd9 100644 --- a/include/ZeroTierOne.h +++ b/include/ZeroTierOne.h @@ -1246,6 +1246,11 @@ typedef struct * oidc client id */ char ssoClientID[256]; + + /** + * sso provider + **/ + char ssoProvider[64]; } ZT_VirtualNetworkConfig; /** diff --git a/node/IncomingPacket.cpp b/node/IncomingPacket.cpp index 9080128b6..a120a208a 100644 --- a/node/IncomingPacket.cpp +++ b/node/IncomingPacket.cpp @@ -217,6 +217,7 @@ bool IncomingPacket::_doERROR(const RuntimeEnvironment *RR,void *tPtr,const Shar char ssoNonce[64] = { 0 }; char ssoState[128] = {0}; char ssoClientID[256] = { 0 }; + char ssoProvider[64] = { 0 }; if (authInfo.get(ZT_AUTHINFO_DICT_KEY_ISSUER_URL, issuerURL, sizeof(issuerURL)) > 0) { issuerURL[sizeof(issuerURL) - 1] = 0; @@ -233,8 +234,13 @@ bool IncomingPacket::_doERROR(const RuntimeEnvironment *RR,void *tPtr,const Shar if (authInfo.get(ZT_AUTHINFO_DICT_KEY_CLIENT_ID, ssoClientID, sizeof(ssoClientID)) > 0) { ssoClientID[sizeof(ssoClientID) - 1] = 0; } + if (authInfo.get(ZT_AUTHINFO_DICT_KEY_SSO_PROVIDER, ssoProvider, sizeof(ssoProvider)) > 0 ) { + ssoProvider[sizeof(ssoProvider) - 1] = 0; + } else { + strncpy(ssoProvider, "default", sizeof(ssoProvider)); + } - network->setAuthenticationRequired(tPtr, issuerURL, centralAuthURL, ssoClientID, ssoNonce, ssoState); + network->setAuthenticationRequired(tPtr, issuerURL, centralAuthURL, ssoClientID, ssoProvider, ssoNonce, ssoState); } } } else { diff --git a/node/Network.cpp b/node/Network.cpp index a3810162b..b03f4b3d0 100644 --- a/node/Network.cpp +++ b/node/Network.cpp @@ -1450,6 +1450,7 @@ void Network::_externalConfig(ZT_VirtualNetworkConfig *ec) const Utils::scopy(ec->ssoNonce, sizeof(ec->ssoNonce), _config.ssoNonce); Utils::scopy(ec->ssoState, sizeof(ec->ssoState), _config.ssoState); Utils::scopy(ec->ssoClientID, sizeof(ec->ssoClientID), _config.ssoClientID); + Utils::scopy(ec->ssoProvider, sizeof(ec->ssoProvider), _config.ssoProvider); } void Network::_sendUpdatesToMembers(void *tPtr,const MulticastGroup *const newMulticastGroup) @@ -1556,7 +1557,7 @@ Membership &Network::_membership(const Address &a) return _memberships[a]; } -void Network::setAuthenticationRequired(void *tPtr, const char* issuerURL, const char* centralEndpoint, const char* clientID, const char* nonce, const char* state) +void Network::setAuthenticationRequired(void *tPtr, const char* issuerURL, const char* centralEndpoint, const char* clientID, const char *ssoProvider, const char* nonce, const char* state) { Mutex::Lock _l(_lock); _netconfFailure = NETCONF_FAILURE_AUTHENTICATION_REQUIRED; @@ -1568,6 +1569,7 @@ void Network::setAuthenticationRequired(void *tPtr, const char* issuerURL, const Utils::scopy(_config.ssoClientID, sizeof(_config.ssoClientID), clientID); Utils::scopy(_config.ssoNonce, sizeof(_config.ssoNonce), nonce); Utils::scopy(_config.ssoState, sizeof(_config.ssoState), state); + Utils::scopy(_config.ssoProvider, sizeof(_config.ssoProvider), ssoProvider); _sendUpdateEvent(tPtr); } diff --git a/node/Network.hpp b/node/Network.hpp index b427a83d6..275e82f02 100644 --- a/node/Network.hpp +++ b/node/Network.hpp @@ -241,7 +241,7 @@ public: * set netconf failure to 'authentication required' along with info needed * for sso full flow authentication. */ - void setAuthenticationRequired(void *tPtr, const char* issuerURL, const char* centralEndpoint, const char* clientID, const char* nonce, const char* state); + void setAuthenticationRequired(void *tPtr, const char* issuerURL, const char* centralEndpoint, const char* clientID, const char *ssoProvider, const char* nonce, const char* state); /** * Causes this network to request an updated configuration from its master node now diff --git a/node/NetworkConfig.cpp b/node/NetworkConfig.cpp index 13a9313aa..3dc3b36d6 100644 --- a/node/NetworkConfig.cpp +++ b/node/NetworkConfig.cpp @@ -201,6 +201,7 @@ bool NetworkConfig::toDictionary(Dictionary &d,b if (!d.add(ZT_NETWORKCONFIG_DICT_KEY_NONCE, this->ssoNonce)) return false; if (!d.add(ZT_NETWORKCONFIG_DICT_KEY_STATE, this->ssoState)) return false; if (!d.add(ZT_NETWORKCONFIG_DICT_KEY_CLIENT_ID, this->ssoClientID)) return false; + if (!d.add(ZT_NETWORKCONFIG_DICT_KEY_SSO_PROVIDER, this->ssoProvider)) return false; } delete tmp; @@ -424,6 +425,12 @@ bool NetworkConfig::fromDictionary(const DictionaryssoClientID, (unsigned int)sizeof(this->ssoClientID)) > 0) { this->ssoClientID[sizeof(this->ssoClientID) - 1] = 0; } + if (d.get(ZT_NETWORKCONFIG_DICT_KEY_SSO_PROVIDER, this->ssoProvider, (unsigned int)(sizeof(this->ssoProvider))) > 0) { + this->ssoProvider[sizeof(this->ssoProvider) - 1] = 0; + } else { + strncpy(this->ssoProvider, "default", sizeof(this->ssoProvider)); + this->ssoProvider[sizeof(this->ssoProvider) - 1] = 0; + } } else { this->authenticationURL[0] = 0; this->authenticationExpiryTime = 0; @@ -432,6 +439,7 @@ bool NetworkConfig::fromDictionary(const DictionaryssoState[0] = 0; this->ssoClientID[0] = 0; this->issuerURL[0] = 0; + this->ssoProvider[0] = 0; } } } diff --git a/node/NetworkConfig.hpp b/node/NetworkConfig.hpp index 0161b4fa9..cd713dde8 100644 --- a/node/NetworkConfig.hpp +++ b/node/NetworkConfig.hpp @@ -195,6 +195,8 @@ namespace ZeroTier { #define ZT_NETWORKCONFIG_DICT_KEY_STATE "ssos" // client ID #define ZT_NETWORKCONFIG_DICT_KEY_CLIENT_ID "ssocid" +// SSO Provider +#define ZT_NETWORKCONFIG_DICT_KEY_SSO_PROVIDER "ssop" // AuthInfo fields -- used by ncSendError for sso @@ -212,6 +214,8 @@ namespace ZeroTier { #define ZT_AUTHINFO_DICT_KEY_STATE "aS" // Client ID #define ZT_AUTHINFO_DICT_KEY_CLIENT_ID "aCID" +// SSO Provider +#define ZT_AUTHINFO_DICT_KEY_SSO_PROVIDER "aSSOp" // Legacy fields -- these are obsoleted but are included when older clients query @@ -289,6 +293,7 @@ public: memset(ssoNonce, 0, sizeof(ssoNonce)); memset(ssoState, 0, sizeof(ssoState)); memset(ssoClientID, 0, sizeof(ssoClientID)); + strncpy(ssoProvider, "default", sizeof(ssoProvider)); } /** @@ -699,6 +704,15 @@ public: * oidc client id */ char ssoClientID[256]; + + /** + * oidc provider + * + * because certain providers require specific scopes to be requested + * and others to be not requested in order to make everything work + * correctly + **/ + char ssoProvider[64]; }; } // namespace ZeroTier diff --git a/service/OneService.cpp b/service/OneService.cpp index 429ae20e4..292f47e97 100644 --- a/service/OneService.cpp +++ b/service/OneService.cpp @@ -303,11 +303,13 @@ public: assert(_config.issuerURL != nullptr); assert(_config.ssoClientID != nullptr); assert(_config.centralAuthURL != nullptr); + assert(_config.ssoProvider != nullptr); _idc = zeroidc::zeroidc_new( _config.issuerURL, _config.ssoClientID, _config.centralAuthURL, + _config.ssoProvider, _webPort ); diff --git a/zeroidc/src/ext.rs b/zeroidc/src/ext.rs index dfb25bd1a..d87724a78 100644 --- a/zeroidc/src/ext.rs +++ b/zeroidc/src/ext.rs @@ -28,6 +28,7 @@ pub extern "C" fn zeroidc_new( issuer: *const c_char, client_id: *const c_char, auth_endpoint: *const c_char, + provider: *const c_char, web_listen_port: u16, ) -> *mut ZeroIDC { if issuer.is_null() { @@ -40,6 +41,11 @@ pub extern "C" fn zeroidc_new( return std::ptr::null_mut(); } + if provider.is_null() { + println!("provider is null"); + return std::ptr::null_mut(); + } + if auth_endpoint.is_null() { println!("auth_endpoint is null"); return std::ptr::null_mut(); @@ -47,10 +53,12 @@ pub extern "C" fn zeroidc_new( let issuer = unsafe { CStr::from_ptr(issuer) }; let client_id = unsafe { CStr::from_ptr(client_id) }; + let provider = unsafe { CStr::from_ptr(provider) }; let auth_endpoint = unsafe { CStr::from_ptr(auth_endpoint) }; match ZeroIDC::new( issuer.to_str().unwrap(), client_id.to_str().unwrap(), + provider.to_str().unwrap(), auth_endpoint.to_str().unwrap(), web_listen_port, ) { diff --git a/zeroidc/src/lib.rs b/zeroidc/src/lib.rs index 52ab56f1c..21319481f 100644 --- a/zeroidc/src/lib.rs +++ b/zeroidc/src/lib.rs @@ -59,7 +59,9 @@ pub struct ZeroIDC { ))] struct Inner { running: bool, + issuer: String, auth_endpoint: String, + provider: String, oidc_thread: Option>, oidc_client: Option, access_token: Option, @@ -114,12 +116,15 @@ impl ZeroIDC { pub fn new( issuer: &str, client_id: &str, + provider: &str, auth_ep: &str, local_web_port: u16, ) -> Result { let idc = ZeroIDC { inner: Arc::new(Mutex::new(Inner { running: false, + issuer: issuer.to_string(), + provider: provider.to_string(), auth_endpoint: auth_ep.to_string(), oidc_thread: None, oidc_client: None, @@ -148,7 +153,7 @@ impl ZeroIDC { let redirect = RedirectUrl::new(redir_url.to_string())?; - (*idc.inner.lock().unwrap()).oidc_client = Some( + idc.inner.lock().unwrap().oidc_client = Some( CoreClient::from_provider_metadata( provider_meta, ClientId::new(client_id.to_string()), @@ -163,25 +168,25 @@ impl ZeroIDC { fn kick_refresh_thread(&mut self) { let local = Arc::clone(&self.inner); - (*local.lock().unwrap()).kick = true; + local.lock().unwrap().kick = true; } fn start(&mut self) { let local = Arc::clone(&self.inner); - if !(*local.lock().unwrap()).running { + if !local.lock().unwrap().running { let inner_local = Arc::clone(&self.inner); - (*local.lock().unwrap()).oidc_thread = Some(spawn(move || { - (*inner_local.lock().unwrap()).running = true; + local.lock().unwrap().oidc_thread = Some(spawn(move || { + inner_local.lock().unwrap().running = true; let mut running = true; // Keep a copy of the initial nonce used to get the tokens // Will be needed later when verifying the responses from refresh tokens - let nonce = (*inner_local.lock().unwrap()).nonce.clone(); + let nonce = inner_local.lock().unwrap().nonce.clone(); while running { let exp = - UNIX_EPOCH + Duration::from_secs((*inner_local.lock().unwrap()).exp_time); + UNIX_EPOCH + Duration::from_secs(inner_local.lock().unwrap().exp_time); let now = SystemTime::now(); #[cfg(debug_assertions)] @@ -198,17 +203,17 @@ impl ZeroIDC { ) ); } - let refresh_token = (*inner_local.lock().unwrap()).refresh_token.clone(); + let refresh_token = inner_local.lock().unwrap().refresh_token.clone(); if let Some(refresh_token) = refresh_token { - let should_kick = (*inner_local.lock().unwrap()).kick; + let should_kick = inner_local.lock().unwrap().kick; if now >= (exp - Duration::from_secs(30)) || should_kick { if should_kick { #[cfg(debug_assertions)] { println!("refresh thread kicked"); } - (*inner_local.lock().unwrap()).kick = false; + inner_local.lock().unwrap().kick = false; } #[cfg(debug_assertions)] @@ -216,10 +221,8 @@ impl ZeroIDC { println!("Refresh Token: {}", refresh_token.secret()); } - let token_response = (*inner_local.lock().unwrap()) - .oidc_client - .as_ref() - .map(|c| { + let token_response = + inner_local.lock().unwrap().oidc_client.as_ref().map(|c| { let res = c .exchange_refresh_token(&refresh_token) .request(http_client); @@ -252,7 +255,9 @@ impl ZeroIDC { let client = reqwest::blocking::Client::new(); let r = client .post( - (*inner_local.lock().unwrap()) + inner_local + .lock() + .unwrap() .auth_endpoint .clone(), ) @@ -289,10 +294,10 @@ impl ZeroIDC { match claims.expiration { Some(exp) => { println!("exp: {}", exp); - (*inner_local + inner_local .lock() - .unwrap()) - .exp_time = exp; + .unwrap() + .exp_time = exp; } None => { panic!("expiration is None. This shouldn't happen") @@ -302,12 +307,16 @@ impl ZeroIDC { panic!("error parsing claims"); } - (*inner_local.lock().unwrap()) + inner_local + .lock() + .unwrap() .access_token = Some(access_token.clone()); if let Some(t) = res.refresh_token() { // println!("New Refresh Token: {}", t.secret()); - (*inner_local.lock().unwrap()) + inner_local + .lock() + .unwrap() .refresh_token = Some(t.clone()); } @@ -333,10 +342,10 @@ impl ZeroIDC { } } - (*inner_local.lock().unwrap()) - .exp_time = 0; - (*inner_local.lock().unwrap()) - .running = false; + inner_local.lock().unwrap().exp_time = + 0; + inner_local.lock().unwrap().running = + false; } } Err(e) => { @@ -346,29 +355,28 @@ impl ZeroIDC { e.url().unwrap().as_str() ); println!("Status: {}", e.status().unwrap()); - (*inner_local.lock().unwrap()).exp_time = 0; - (*inner_local.lock().unwrap()).running = - false; + inner_local.lock().unwrap().exp_time = 0; + inner_local.lock().unwrap().running = false; } } } None => { println!("no id token?!?"); - (*inner_local.lock().unwrap()).exp_time = 0; - (*inner_local.lock().unwrap()).running = false; + inner_local.lock().unwrap().exp_time = 0; + inner_local.lock().unwrap().running = false; } } } Err(e) => { println!("token error: {}", e); - (*inner_local.lock().unwrap()).exp_time = 0; - (*inner_local.lock().unwrap()).running = false; + inner_local.lock().unwrap().exp_time = 0; + inner_local.lock().unwrap().running = false; } } } else { println!("token response??"); - (*inner_local.lock().unwrap()).exp_time = 0; - (*inner_local.lock().unwrap()).running = false; + inner_local.lock().unwrap().exp_time = 0; + inner_local.lock().unwrap().running = false; } } else { #[cfg(debug_assertions)] @@ -376,19 +384,19 @@ impl ZeroIDC { } } else { println!("no refresh token?"); - (*inner_local.lock().unwrap()).exp_time = 0; - (*inner_local.lock().unwrap()).running = false; + inner_local.lock().unwrap().exp_time = 0; + inner_local.lock().unwrap().running = false; } sleep(Duration::from_secs(1)); { - running = (*inner_local.lock().unwrap()).running; + running = inner_local.lock().unwrap().running; } } // end run loop println!("thread done!"); - (*inner_local.lock().unwrap()).running = false; + inner_local.lock().unwrap().running = false; println!("set idc thread running flag to false"); })); } @@ -397,19 +405,19 @@ impl ZeroIDC { pub fn stop(&mut self) { let local = self.inner.clone(); if self.is_running() { - (*local.lock().unwrap()).running = false; + local.lock().unwrap().running = false; } } pub fn is_running(&mut self) -> bool { let local = Arc::clone(&self.inner); - let running = (*local.lock().unwrap()).running; + let running = local.lock().unwrap().running; running } pub fn get_exp_time(&mut self) -> u64 { - return (*self.inner.lock().unwrap()).exp_time; + return self.inner.lock().unwrap().exp_time; } pub fn set_nonce_and_csrf(&mut self, csrf_token: String, nonce: String) { @@ -439,20 +447,53 @@ impl ZeroIDC { if need_verifier || csrf_diff || nonce_diff { let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); let r = i.oidc_client.as_ref().map(|c| { - let (auth_url, csrf_token, nonce) = c + let mut auth_builder = c .authorize_url( AuthenticationFlow::::AuthorizationCode, csrf_func(csrf_token), nonce_func(nonce), ) - .add_scope(Scope::new("profile".to_string())) - .add_scope(Scope::new("email".to_string())) - .add_scope(Scope::new("offline_access".to_string())) - .add_scope(Scope::new("openid".to_string())) - .set_pkce_challenge(pkce_challenge) - .url(); + .set_pkce_challenge(pkce_challenge); + match i.provider.as_str() { + "auth0" => { + auth_builder = auth_builder + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("offline_access".to_string())); + } + "okta" => { + auth_builder = auth_builder + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("groups".to_string())) + .add_scope(Scope::new("offline_access".to_string())); + } + "keycloak" => { + auth_builder = auth_builder + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())); + } + "onelogin" => { + auth_builder = auth_builder + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("groups".to_string())) + } + "default" => { + auth_builder = auth_builder + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("offline_access".to_string())); + } + _ => { + auth_builder = auth_builder + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("offline_access".to_string())); + } + } - (auth_url, csrf_token, nonce) + auth_builder.url() }); if let Some(r) = r { @@ -653,8 +694,7 @@ impl ZeroIDC { } Err(res) => { println!("error result: {}", res); - println!("hit url: {}", res.url().unwrap().as_str()); - println!("Status: {}", res.status().unwrap()); + println!("hit url: {}", i.auth_endpoint.clone()); println!("Post error: {}", res); i.exp_time = 0; i.running = false;