fix(OIDC): handle confedential client providers
This commit is contained in:
parent
a92c028b74
commit
a5edcfdd1d
79
src/oidc.rs
79
src/oidc.rs
|
|
@ -1,10 +1,13 @@
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use oauth2::{
|
use oauth2::{
|
||||||
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
|
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
|
||||||
ClientSecret, CsrfToken, PkceCodeChallenge, RedirectUrl, Scope, TokenResponse, TokenUrl,
|
ClientSecret, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope, TokenResponse, TokenUrl,
|
||||||
};
|
};
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
|
|
@ -25,11 +28,16 @@ pub struct OidcUserInfo {
|
||||||
pub preferred_username: Option<String>,
|
pub preferred_username: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Storage for PKCE verifiers (csrf_token -> (verifier, expiry))
|
||||||
|
type PkceStore = Mutex<HashMap<String, (PkceCodeVerifier, Instant)>>;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct OidcClient {
|
pub struct OidcClient {
|
||||||
oauth_client: BasicClient,
|
oauth_client: BasicClient,
|
||||||
discovery: OidcDiscovery,
|
discovery: OidcDiscovery,
|
||||||
http_client: Client,
|
http_client: Client,
|
||||||
|
is_public_client: bool,
|
||||||
|
pkce_store: PkceStore,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OidcClient {
|
impl OidcClient {
|
||||||
|
|
@ -42,10 +50,11 @@ impl OidcClient {
|
||||||
.oidc_client_id
|
.oidc_client_id
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.ok_or_else(|| anyhow!("OIDC client ID not configured"))?;
|
.ok_or_else(|| anyhow!("OIDC client ID not configured"))?;
|
||||||
let client_secret = config
|
|
||||||
.oidc_client_secret
|
// Client secret is optional - if not provided, this is a public client
|
||||||
.as_ref()
|
let client_secret_opt = config.oidc_client_secret.as_ref();
|
||||||
.ok_or_else(|| anyhow!("OIDC client secret not configured"))?;
|
let is_public_client = client_secret_opt.is_none();
|
||||||
|
|
||||||
let issuer_url = config
|
let issuer_url = config
|
||||||
.oidc_issuer_url
|
.oidc_issuer_url
|
||||||
.as_ref()
|
.as_ref()
|
||||||
|
|
@ -63,7 +72,7 @@ impl OidcClient {
|
||||||
// Create OAuth2 client
|
// Create OAuth2 client
|
||||||
let oauth_client = BasicClient::new(
|
let oauth_client = BasicClient::new(
|
||||||
ClientId::new(client_id.clone()),
|
ClientId::new(client_id.clone()),
|
||||||
Some(ClientSecret::new(client_secret.clone())),
|
client_secret_opt.map(|s| ClientSecret::new(s.clone())),
|
||||||
AuthUrl::new(discovery.authorization_endpoint.clone())?,
|
AuthUrl::new(discovery.authorization_endpoint.clone())?,
|
||||||
Some(TokenUrl::new(discovery.token_endpoint.clone())?),
|
Some(TokenUrl::new(discovery.token_endpoint.clone())?),
|
||||||
)
|
)
|
||||||
|
|
@ -73,6 +82,8 @@ impl OidcClient {
|
||||||
oauth_client,
|
oauth_client,
|
||||||
discovery,
|
discovery,
|
||||||
http_client,
|
http_client,
|
||||||
|
is_public_client,
|
||||||
|
pkce_store: Mutex::new(HashMap::new()),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -101,21 +112,61 @@ impl OidcClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_authorization_url(&self) -> (Url, CsrfToken) {
|
pub fn get_authorization_url(&self) -> (Url, CsrfToken) {
|
||||||
let (pkce_challenge, _pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
// Clean up expired PKCE verifiers (older than 10 minutes)
|
||||||
|
self.cleanup_expired_verifiers();
|
||||||
|
|
||||||
self.oauth_client
|
let mut auth_request = self.oauth_client
|
||||||
.authorize_url(CsrfToken::new_random)
|
.authorize_url(CsrfToken::new_random)
|
||||||
.add_scope(Scope::new("openid".to_string()))
|
.add_scope(Scope::new("openid".to_string()))
|
||||||
.add_scope(Scope::new("email".to_string()))
|
.add_scope(Scope::new("email".to_string()))
|
||||||
.add_scope(Scope::new("profile".to_string()))
|
.add_scope(Scope::new("profile".to_string()));
|
||||||
.set_pkce_challenge(pkce_challenge)
|
|
||||||
.url()
|
// For public clients (no client_secret), PKCE is required for security
|
||||||
|
// For confidential clients, PKCE is optional but we don't use it to avoid state management
|
||||||
|
if self.is_public_client {
|
||||||
|
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||||
|
auth_request = auth_request.set_pkce_challenge(pkce_challenge);
|
||||||
|
|
||||||
|
// Store the verifier for later use in token exchange
|
||||||
|
let (url, csrf_token) = auth_request.url();
|
||||||
|
let mut store = self.pkce_store.lock().unwrap();
|
||||||
|
store.insert(
|
||||||
|
csrf_token.secret().clone(),
|
||||||
|
(pkce_verifier, Instant::now() + Duration::from_secs(600)), // 10 minute expiry
|
||||||
|
);
|
||||||
|
(url, csrf_token)
|
||||||
|
} else {
|
||||||
|
// Confidential client - no PKCE needed
|
||||||
|
auth_request.url()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn exchange_code(&self, code: &str) -> Result<String> {
|
fn cleanup_expired_verifiers(&self) {
|
||||||
let token_result = self
|
let mut store = self.pkce_store.lock().unwrap();
|
||||||
|
let now = Instant::now();
|
||||||
|
store.retain(|_, (_, expiry)| *expiry > now);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn exchange_code(&self, code: &str, state: Option<&str>) -> Result<String> {
|
||||||
|
let mut token_request = self
|
||||||
.oauth_client
|
.oauth_client
|
||||||
.exchange_code(AuthorizationCode::new(code.to_string()))
|
.exchange_code(AuthorizationCode::new(code.to_string()));
|
||||||
|
|
||||||
|
// For public clients, retrieve and use the PKCE verifier
|
||||||
|
if self.is_public_client {
|
||||||
|
if let Some(state_token) = state {
|
||||||
|
let mut store = self.pkce_store.lock().unwrap();
|
||||||
|
if let Some((verifier, _)) = store.remove(state_token) {
|
||||||
|
token_request = token_request.set_pkce_verifier(verifier);
|
||||||
|
} else {
|
||||||
|
return Err(anyhow!("PKCE verifier not found for state token (expired or invalid)"));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Err(anyhow!("State parameter required for public client PKCE flow"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let token_result = token_request
|
||||||
.request_async(async_http_client)
|
.request_async(async_http_client)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| anyhow!("Failed to exchange authorization code: {}", e))?;
|
.map_err(|e| anyhow!("Failed to exchange authorization code: {}", e))?;
|
||||||
|
|
|
||||||
|
|
@ -207,7 +207,7 @@ async fn oidc_callback(
|
||||||
|
|
||||||
// Exchange authorization code for access token
|
// Exchange authorization code for access token
|
||||||
let access_token = oidc_client
|
let access_token = oidc_client
|
||||||
.exchange_code(&code)
|
.exchange_code(&code, params.state.as_deref())
|
||||||
.await
|
.await
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
tracing::error!("Failed to exchange code: {}", e);
|
tracing::error!("Failed to exchange code: {}", e);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue