feat(oidc): fix oidc, tests, and everything in between
This commit is contained in:
parent
10d9a1a661
commit
72708a05f3
|
|
@ -44,6 +44,7 @@ Object.defineProperty(window, 'location', {
|
||||||
writable: true
|
writable: true
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
||||||
// Mock AuthContext
|
// Mock AuthContext
|
||||||
const mockAuthContextValue = {
|
const mockAuthContextValue = {
|
||||||
user: null,
|
user: null,
|
||||||
|
|
@ -54,16 +55,35 @@ const mockAuthContextValue = {
|
||||||
};
|
};
|
||||||
|
|
||||||
const MockAuthProvider = ({ children }: { children: React.ReactNode }) => (
|
const MockAuthProvider = ({ children }: { children: React.ReactNode }) => (
|
||||||
<div data-testid="mock-auth-provider">{children}</div>
|
<AuthProvider>
|
||||||
|
{children}
|
||||||
|
</AuthProvider>
|
||||||
);
|
);
|
||||||
|
|
||||||
const MockThemeProvider = ({ children }: { children: React.ReactNode }) => (
|
const MockThemeProvider = ({ children }: { children: React.ReactNode }) => (
|
||||||
<div data-testid="mock-theme-provider">{children}</div>
|
<ThemeProvider>
|
||||||
|
{children}
|
||||||
|
</ThemeProvider>
|
||||||
);
|
);
|
||||||
|
|
||||||
describe('Login - OIDC Features', () => {
|
describe('Login - OIDC Features', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
vi.clearAllMocks();
|
vi.clearAllMocks();
|
||||||
|
|
||||||
|
// Mock window.matchMedia
|
||||||
|
Object.defineProperty(window, 'matchMedia', {
|
||||||
|
writable: true,
|
||||||
|
value: vi.fn().mockImplementation(query => ({
|
||||||
|
matches: false,
|
||||||
|
media: query,
|
||||||
|
onchange: null,
|
||||||
|
addListener: vi.fn(),
|
||||||
|
removeListener: vi.fn(),
|
||||||
|
addEventListener: vi.fn(),
|
||||||
|
removeEventListener: vi.fn(),
|
||||||
|
dispatchEvent: vi.fn(),
|
||||||
|
})),
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
const renderLogin = () => {
|
const renderLogin = () => {
|
||||||
|
|
|
||||||
|
|
@ -17,14 +17,16 @@ vi.mock('../../../services/api', () => ({
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Mock useNavigate
|
// Mock useNavigate and useSearchParams
|
||||||
const mockNavigate = vi.fn();
|
const mockNavigate = vi.fn();
|
||||||
|
const mockUseSearchParams = vi.fn(() => [new URLSearchParams('code=test-code&state=test-state')]);
|
||||||
|
|
||||||
vi.mock('react-router-dom', async () => {
|
vi.mock('react-router-dom', async () => {
|
||||||
const actual = await vi.importActual('react-router-dom');
|
const actual = await vi.importActual('react-router-dom');
|
||||||
return {
|
return {
|
||||||
...actual,
|
...actual,
|
||||||
useNavigate: () => mockNavigate,
|
useNavigate: () => mockNavigate,
|
||||||
useSearchParams: () => [new URLSearchParams('code=test-code&state=test-state')]
|
useSearchParams: mockUseSearchParams
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -45,6 +47,7 @@ Object.defineProperty(window, 'location', {
|
||||||
writable: true
|
writable: true
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
||||||
// Mock AuthContext
|
// Mock AuthContext
|
||||||
const mockAuthContextValue = {
|
const mockAuthContextValue = {
|
||||||
user: null,
|
user: null,
|
||||||
|
|
@ -55,12 +58,29 @@ const mockAuthContextValue = {
|
||||||
};
|
};
|
||||||
|
|
||||||
const MockAuthProvider = ({ children }: { children: React.ReactNode }) => (
|
const MockAuthProvider = ({ children }: { children: React.ReactNode }) => (
|
||||||
<div data-testid="mock-auth-provider">{children}</div>
|
<AuthProvider>
|
||||||
|
{children}
|
||||||
|
</AuthProvider>
|
||||||
);
|
);
|
||||||
|
|
||||||
describe('OidcCallback', () => {
|
describe('OidcCallback', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
vi.clearAllMocks();
|
vi.clearAllMocks();
|
||||||
|
|
||||||
|
// Mock window.matchMedia
|
||||||
|
Object.defineProperty(window, 'matchMedia', {
|
||||||
|
writable: true,
|
||||||
|
value: vi.fn().mockImplementation(query => ({
|
||||||
|
matches: false,
|
||||||
|
media: query,
|
||||||
|
onchange: null,
|
||||||
|
addListener: vi.fn(),
|
||||||
|
removeListener: vi.fn(),
|
||||||
|
addEventListener: vi.fn(),
|
||||||
|
removeEventListener: vi.fn(),
|
||||||
|
dispatchEvent: vi.fn(),
|
||||||
|
})),
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
const renderOidcCallback = () => {
|
const renderOidcCallback = () => {
|
||||||
|
|
@ -106,7 +126,7 @@ describe('OidcCallback', () => {
|
||||||
|
|
||||||
it('handles authentication error from URL params', () => {
|
it('handles authentication error from URL params', () => {
|
||||||
// Mock useSearchParams to return error
|
// Mock useSearchParams to return error
|
||||||
vi.mocked(require('react-router-dom').useSearchParams).mockReturnValue([
|
mockUseSearchParams.mockReturnValueOnce([
|
||||||
new URLSearchParams('error=access_denied&error_description=User+denied+access')
|
new URLSearchParams('error=access_denied&error_description=User+denied+access')
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
|
@ -118,7 +138,7 @@ describe('OidcCallback', () => {
|
||||||
|
|
||||||
it('handles missing authorization code', () => {
|
it('handles missing authorization code', () => {
|
||||||
// Mock useSearchParams to return no code
|
// Mock useSearchParams to return no code
|
||||||
vi.mocked(require('react-router-dom').useSearchParams).mockReturnValue([
|
mockUseSearchParams.mockReturnValueOnce([
|
||||||
new URLSearchParams('')
|
new URLSearchParams('')
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,4 +19,19 @@ vi.mock('axios', () => ({
|
||||||
defaults: { headers: { common: {} } },
|
defaults: { headers: { common: {} } },
|
||||||
})),
|
})),
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
// Mock window.matchMedia
|
||||||
|
Object.defineProperty(window, 'matchMedia', {
|
||||||
|
writable: true,
|
||||||
|
value: vi.fn().mockImplementation(query => ({
|
||||||
|
matches: false,
|
||||||
|
media: query,
|
||||||
|
onchange: null,
|
||||||
|
addListener: vi.fn(), // deprecated
|
||||||
|
removeListener: vi.fn(), // deprecated
|
||||||
|
addEventListener: vi.fn(),
|
||||||
|
removeEventListener: vi.fn(),
|
||||||
|
dispatchEvent: vi.fn(),
|
||||||
|
})),
|
||||||
|
})
|
||||||
|
|
@ -88,23 +88,41 @@ impl Database {
|
||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Create users table
|
// Create users table with OIDC support
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
CREATE TABLE IF NOT EXISTS users (
|
CREATE TABLE IF NOT EXISTS users (
|
||||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||||
username VARCHAR(255) UNIQUE NOT NULL,
|
username VARCHAR(255) UNIQUE NOT NULL,
|
||||||
email VARCHAR(255) UNIQUE NOT NULL,
|
email VARCHAR(255) UNIQUE NOT NULL,
|
||||||
password_hash VARCHAR(255) NOT NULL,
|
password_hash VARCHAR(255),
|
||||||
role VARCHAR(10) DEFAULT 'user',
|
role VARCHAR(20) DEFAULT 'user',
|
||||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||||
updated_at TIMESTAMPTZ DEFAULT NOW()
|
updated_at TIMESTAMPTZ DEFAULT NOW(),
|
||||||
|
oidc_subject VARCHAR(255),
|
||||||
|
oidc_issuer VARCHAR(255),
|
||||||
|
oidc_email VARCHAR(255),
|
||||||
|
auth_provider VARCHAR(50) DEFAULT 'local',
|
||||||
|
CONSTRAINT check_auth_method CHECK (
|
||||||
|
(auth_provider = 'local' AND password_hash IS NOT NULL) OR
|
||||||
|
(auth_provider = 'oidc' AND oidc_subject IS NOT NULL AND oidc_issuer IS NOT NULL)
|
||||||
|
),
|
||||||
|
CONSTRAINT check_user_role CHECK (role IN ('admin', 'user'))
|
||||||
)
|
)
|
||||||
"#,
|
"#,
|
||||||
)
|
)
|
||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
// Create indexes for OIDC
|
||||||
|
sqlx::query(r#"CREATE INDEX IF NOT EXISTS idx_users_oidc_subject_issuer ON users(oidc_subject, oidc_issuer)"#)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
sqlx::query(r#"CREATE INDEX IF NOT EXISTS idx_users_auth_provider ON users(auth_provider)"#)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
|
||||||
// Create documents table
|
// Create documents table
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ pub fn router() -> Router<Arc<AppState>> {
|
||||||
.route("/oidc/callback", get(oidc_callback))
|
.route("/oidc/callback", get(oidc_callback))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[utoipa::path(
|
#[utoipa::path(
|
||||||
post,
|
post,
|
||||||
path = "/api/auth/register",
|
path = "/api/auth/register",
|
||||||
|
|
@ -170,6 +171,9 @@ async fn oidc_callback(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
Query(params): Query<OidcCallbackQuery>,
|
Query(params): Query<OidcCallbackQuery>,
|
||||||
) -> Result<Json<LoginResponse>, StatusCode> {
|
) -> Result<Json<LoginResponse>, StatusCode> {
|
||||||
|
tracing::info!("OIDC callback called with params: code={:?}, state={:?}, error={:?}",
|
||||||
|
params.code, params.state, params.error);
|
||||||
|
|
||||||
if let Some(error) = params.error {
|
if let Some(error) = params.error {
|
||||||
tracing::error!("OIDC callback error: {}", error);
|
tracing::error!("OIDC callback error: {}", error);
|
||||||
return Err(StatusCode::UNAUTHORIZED);
|
return Err(StatusCode::UNAUTHORIZED);
|
||||||
|
|
@ -201,9 +205,15 @@ async fn oidc_callback(
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// Find or create user in database
|
// Find or create user in database
|
||||||
let user = match state.db.get_user_by_oidc_subject(&user_info.sub, &state.config.oidc_issuer_url.as_ref().unwrap()).await {
|
let issuer_url = state.config.oidc_issuer_url.as_ref().unwrap();
|
||||||
Ok(Some(existing_user)) => existing_user,
|
tracing::debug!("Looking up user by OIDC subject: {} and issuer: {}", user_info.sub, issuer_url);
|
||||||
|
let user = match state.db.get_user_by_oidc_subject(&user_info.sub, issuer_url).await {
|
||||||
|
Ok(Some(existing_user)) => {
|
||||||
|
tracing::debug!("Found existing OIDC user: {}", existing_user.username);
|
||||||
|
existing_user
|
||||||
|
},
|
||||||
Ok(None) => {
|
Ok(None) => {
|
||||||
|
tracing::debug!("Creating new OIDC user");
|
||||||
// Create new user
|
// Create new user
|
||||||
let username = user_info.preferred_username
|
let username = user_info.preferred_username
|
||||||
.or_else(|| user_info.email.clone())
|
.or_else(|| user_info.email.clone())
|
||||||
|
|
@ -211,6 +221,8 @@ async fn oidc_callback(
|
||||||
|
|
||||||
let email = user_info.email.unwrap_or_else(|| format!("{}@oidc.local", username));
|
let email = user_info.email.unwrap_or_else(|| format!("{}@oidc.local", username));
|
||||||
|
|
||||||
|
tracing::debug!("New user details - username: {}, email: {}", username, email);
|
||||||
|
|
||||||
let create_user = CreateUser {
|
let create_user = CreateUser {
|
||||||
username,
|
username,
|
||||||
email: email.clone(),
|
email: email.clone(),
|
||||||
|
|
@ -218,15 +230,23 @@ async fn oidc_callback(
|
||||||
role: Some(UserRole::User),
|
role: Some(UserRole::User),
|
||||||
};
|
};
|
||||||
|
|
||||||
state.db.create_oidc_user(
|
let result = state.db.create_oidc_user(
|
||||||
create_user,
|
create_user,
|
||||||
&user_info.sub,
|
&user_info.sub,
|
||||||
&state.config.oidc_issuer_url.as_ref().unwrap(),
|
issuer_url,
|
||||||
&email,
|
&email,
|
||||||
).await.map_err(|e| {
|
).await;
|
||||||
tracing::error!("Failed to create OIDC user: {}", e);
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR
|
match result {
|
||||||
})?
|
Ok(user) => {
|
||||||
|
tracing::info!("Successfully created OIDC user: {}", user.username);
|
||||||
|
user
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Failed to create OIDC user: {} (full error: {:#})", e, e);
|
||||||
|
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!("Database error during OIDC lookup: {}", e);
|
tracing::error!("Database error during OIDC lookup: {}", e);
|
||||||
|
|
@ -236,7 +256,10 @@ async fn oidc_callback(
|
||||||
|
|
||||||
// Create JWT token
|
// Create JWT token
|
||||||
let token = create_jwt(&user, &state.config.jwt_secret)
|
let token = create_jwt(&user, &state.config.jwt_secret)
|
||||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
.map_err(|e| {
|
||||||
|
tracing::error!("Failed to create JWT token: {}", e);
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
})?;
|
||||||
|
|
||||||
Ok(Json(LoginResponse {
|
Ok(Json(LoginResponse {
|
||||||
token,
|
token,
|
||||||
|
|
|
||||||
|
|
@ -40,11 +40,22 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_oidc_enabled_from_env() {
|
fn test_oidc_enabled_from_env() {
|
||||||
|
// Clean up environment first to ensure test isolation
|
||||||
|
env::remove_var("OIDC_ENABLED");
|
||||||
|
env::remove_var("OIDC_CLIENT_ID");
|
||||||
|
env::remove_var("OIDC_CLIENT_SECRET");
|
||||||
|
env::remove_var("OIDC_ISSUER_URL");
|
||||||
|
env::remove_var("OIDC_REDIRECT_URI");
|
||||||
|
env::remove_var("DATABASE_URL");
|
||||||
|
env::remove_var("JWT_SECRET");
|
||||||
|
|
||||||
env::set_var("OIDC_ENABLED", "true");
|
env::set_var("OIDC_ENABLED", "true");
|
||||||
env::set_var("OIDC_CLIENT_ID", "test-client-id");
|
env::set_var("OIDC_CLIENT_ID", "test-client-id");
|
||||||
env::set_var("OIDC_CLIENT_SECRET", "test-client-secret");
|
env::set_var("OIDC_CLIENT_SECRET", "test-client-secret");
|
||||||
env::set_var("OIDC_ISSUER_URL", "https://provider.example.com");
|
env::set_var("OIDC_ISSUER_URL", "https://provider.example.com");
|
||||||
env::set_var("OIDC_REDIRECT_URI", "http://localhost:8000/auth/oidc/callback");
|
env::set_var("OIDC_REDIRECT_URI", "http://localhost:8000/auth/oidc/callback");
|
||||||
|
env::set_var("DATABASE_URL", "postgresql://test:test@localhost/test");
|
||||||
|
env::set_var("JWT_SECRET", "test-secret");
|
||||||
|
|
||||||
let config = Config::from_env().unwrap();
|
let config = Config::from_env().unwrap();
|
||||||
|
|
||||||
|
|
@ -60,6 +71,8 @@ mod tests {
|
||||||
env::remove_var("OIDC_CLIENT_SECRET");
|
env::remove_var("OIDC_CLIENT_SECRET");
|
||||||
env::remove_var("OIDC_ISSUER_URL");
|
env::remove_var("OIDC_ISSUER_URL");
|
||||||
env::remove_var("OIDC_REDIRECT_URI");
|
env::remove_var("OIDC_REDIRECT_URI");
|
||||||
|
env::remove_var("DATABASE_URL");
|
||||||
|
env::remove_var("JWT_SECRET");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -83,6 +96,15 @@ mod tests {
|
||||||
];
|
];
|
||||||
|
|
||||||
for (value, expected) in test_cases {
|
for (value, expected) in test_cases {
|
||||||
|
// Clean up environment first for each iteration
|
||||||
|
env::remove_var("OIDC_ENABLED");
|
||||||
|
env::remove_var("OIDC_CLIENT_ID");
|
||||||
|
env::remove_var("OIDC_CLIENT_SECRET");
|
||||||
|
env::remove_var("OIDC_ISSUER_URL");
|
||||||
|
env::remove_var("OIDC_REDIRECT_URI");
|
||||||
|
env::remove_var("DATABASE_URL");
|
||||||
|
env::remove_var("JWT_SECRET");
|
||||||
|
|
||||||
env::set_var("OIDC_ENABLED", value);
|
env::set_var("OIDC_ENABLED", value);
|
||||||
env::set_var("DATABASE_URL", "postgresql://test:test@localhost/test");
|
env::set_var("DATABASE_URL", "postgresql://test:test@localhost/test");
|
||||||
env::set_var("JWT_SECRET", "test-secret");
|
env::set_var("JWT_SECRET", "test-secret");
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,76 @@
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::models::{AuthProvider, CreateUser, UserRole};
|
use crate::models::{AuthProvider, CreateUser, UserRole};
|
||||||
use super::super::helpers::{create_test_app};
|
|
||||||
use axum::http::StatusCode;
|
use axum::http::StatusCode;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use tower::util::ServiceExt;
|
use tower::util::ServiceExt;
|
||||||
use wiremock::{matchers::{method, path, query_param}, Mock, MockServer, ResponseTemplate};
|
use wiremock::{matchers::{method, path, query_param, header}, Mock, MockServer, ResponseTemplate};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use crate::{AppState, oidc::OidcClient};
|
use crate::{AppState, oidc::OidcClient};
|
||||||
|
|
||||||
async fn create_test_app_with_oidc() -> (axum::Router, testcontainers::ContainerAsync<testcontainers_modules::postgres::Postgres>, MockServer) {
|
async fn create_test_app_simple() -> (axum::Router, ()) {
|
||||||
let (mut app, container) = create_test_app().await;
|
// Use TEST_DATABASE_URL directly, no containers
|
||||||
|
let database_url = std::env::var("TEST_DATABASE_URL")
|
||||||
|
.or_else(|_| std::env::var("DATABASE_URL"))
|
||||||
|
.unwrap_or_else(|_| "postgresql://readur:readur@localhost:5432/readur".to_string());
|
||||||
|
|
||||||
|
let config = crate::config::Config {
|
||||||
|
database_url: database_url.clone(),
|
||||||
|
server_address: "127.0.0.1:0".to_string(),
|
||||||
|
jwt_secret: "test-secret".to_string(),
|
||||||
|
upload_path: "./test-uploads".to_string(),
|
||||||
|
watch_folder: "./test-watch".to_string(),
|
||||||
|
allowed_file_types: vec!["pdf".to_string()],
|
||||||
|
watch_interval_seconds: Some(30),
|
||||||
|
file_stability_check_ms: Some(500),
|
||||||
|
max_file_age_hours: None,
|
||||||
|
ocr_language: "eng".to_string(),
|
||||||
|
concurrent_ocr_jobs: 2,
|
||||||
|
ocr_timeout_seconds: 60,
|
||||||
|
max_file_size_mb: 10,
|
||||||
|
memory_limit_mb: 256,
|
||||||
|
cpu_priority: "normal".to_string(),
|
||||||
|
oidc_enabled: false,
|
||||||
|
oidc_client_id: None,
|
||||||
|
oidc_client_secret: None,
|
||||||
|
oidc_issuer_url: None,
|
||||||
|
oidc_redirect_uri: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let db = crate::db::Database::new(&config.database_url).await.unwrap();
|
||||||
|
|
||||||
|
// Retry migration up to 3 times to handle concurrent test execution
|
||||||
|
for attempt in 1..=3 {
|
||||||
|
match db.migrate().await {
|
||||||
|
Ok(_) => break,
|
||||||
|
Err(e) if attempt < 3 && e.to_string().contains("tuple concurrently updated") => {
|
||||||
|
// Wait a bit and retry
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(100 * attempt)).await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Err(e) => panic!("Migration failed after {} attempts: {}", attempt, e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let app = axum::Router::new()
|
||||||
|
.nest("/api/auth", crate::routes::auth::router())
|
||||||
|
.with_state(Arc::new(AppState {
|
||||||
|
db: db.clone(),
|
||||||
|
config,
|
||||||
|
webdav_scheduler: None,
|
||||||
|
source_scheduler: None,
|
||||||
|
queue_service: Arc::new(crate::ocr_queue::OcrQueueService::new(
|
||||||
|
db.clone(),
|
||||||
|
db.pool.clone(),
|
||||||
|
2
|
||||||
|
)),
|
||||||
|
oidc_client: None,
|
||||||
|
}));
|
||||||
|
|
||||||
|
(app, ())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_test_app_with_oidc() -> (axum::Router, MockServer) {
|
||||||
let mock_server = MockServer::start().await;
|
let mock_server = MockServer::start().await;
|
||||||
|
|
||||||
// Mock OIDC discovery endpoint
|
// Mock OIDC discovery endpoint
|
||||||
|
|
@ -27,9 +87,14 @@ mod tests {
|
||||||
.mount(&mock_server)
|
.mount(&mock_server)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
// Use TEST_DATABASE_URL directly, no containers
|
||||||
|
let database_url = std::env::var("TEST_DATABASE_URL")
|
||||||
|
.or_else(|_| std::env::var("DATABASE_URL"))
|
||||||
|
.unwrap_or_else(|_| "postgresql://readur:readur@localhost:5432/readur".to_string());
|
||||||
|
|
||||||
// Update the app state to include OIDC client
|
// Update the app state to include OIDC client
|
||||||
let config = crate::config::Config {
|
let config = crate::config::Config {
|
||||||
database_url: "postgresql://test:test@localhost/test".to_string(),
|
database_url: database_url.clone(),
|
||||||
server_address: "127.0.0.1:0".to_string(),
|
server_address: "127.0.0.1:0".to_string(),
|
||||||
jwt_secret: "test-secret".to_string(),
|
jwt_secret: "test-secret".to_string(),
|
||||||
upload_path: "./test-uploads".to_string(),
|
upload_path: "./test-uploads".to_string(),
|
||||||
|
|
@ -51,34 +116,51 @@ mod tests {
|
||||||
oidc_redirect_uri: Some("http://localhost:8000/auth/oidc/callback".to_string()),
|
oidc_redirect_uri: Some("http://localhost:8000/auth/oidc/callback".to_string()),
|
||||||
};
|
};
|
||||||
|
|
||||||
let oidc_client = OidcClient::new(&config).await.ok().map(Arc::new);
|
let oidc_client = match OidcClient::new(&config).await {
|
||||||
|
Ok(client) => Some(Arc::new(client)),
|
||||||
|
Err(e) => {
|
||||||
|
panic!("OIDC client creation failed: {}", e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// We need to extract the state from the existing app and recreate it
|
// Connect to the database and run migrations with retry logic for concurrency
|
||||||
// This is a bit hacky, but necessary for testing
|
let db = crate::db::Database::new(&config.database_url).await.unwrap();
|
||||||
app = axum::Router::new()
|
|
||||||
|
// Retry migration up to 3 times to handle concurrent test execution
|
||||||
|
for attempt in 1..=3 {
|
||||||
|
match db.migrate().await {
|
||||||
|
Ok(_) => break,
|
||||||
|
Err(e) if attempt < 3 && e.to_string().contains("tuple concurrently updated") => {
|
||||||
|
// Wait a bit and retry
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(100 * attempt)).await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Err(e) => panic!("Migration failed after {} attempts: {}", attempt, e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create app with OIDC configuration
|
||||||
|
let app = axum::Router::new()
|
||||||
.nest("/api/auth", crate::routes::auth::router())
|
.nest("/api/auth", crate::routes::auth::router())
|
||||||
.with_state(Arc::new(AppState {
|
.with_state(Arc::new(AppState {
|
||||||
db: crate::db::Database::new(&format!("postgresql://test:test@localhost:{}/test",
|
db: db.clone(),
|
||||||
container.get_host_port_ipv4(5432).await.unwrap())).await.unwrap(),
|
|
||||||
config,
|
config,
|
||||||
webdav_scheduler: None,
|
webdav_scheduler: None,
|
||||||
source_scheduler: None,
|
source_scheduler: None,
|
||||||
queue_service: Arc::new(crate::ocr_queue::OcrQueueService::new(
|
queue_service: Arc::new(crate::ocr_queue::OcrQueueService::new(
|
||||||
crate::db::Database::new(&format!("postgresql://test:test@localhost:{}/test",
|
db.clone(),
|
||||||
container.get_host_port_ipv4(5432).await.unwrap())).await.unwrap(),
|
db.pool.clone(),
|
||||||
sqlx::PgPool::connect(&format!("postgresql://test:test@localhost:{}/test",
|
|
||||||
container.get_host_port_ipv4(5432).await.unwrap())).await.unwrap(),
|
|
||||||
2
|
2
|
||||||
)),
|
)),
|
||||||
oidc_client,
|
oidc_client,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
(app, container, mock_server)
|
(app, mock_server)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_oidc_login_redirect() {
|
async fn test_oidc_login_redirect() {
|
||||||
let (app, _container, _mock_server) = create_test_app_with_oidc().await;
|
let (app, _mock_server) = create_test_app_with_oidc().await;
|
||||||
|
|
||||||
let response = app
|
let response = app
|
||||||
.oneshot(
|
.oneshot(
|
||||||
|
|
@ -101,7 +183,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_oidc_login_disabled() {
|
async fn test_oidc_login_disabled() {
|
||||||
let (app, _container) = create_test_app().await; // Regular app without OIDC
|
let (app, _container) = create_test_app_simple().await; // Regular app without OIDC
|
||||||
|
|
||||||
let response = app
|
let response = app
|
||||||
.oneshot(
|
.oneshot(
|
||||||
|
|
@ -119,7 +201,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_oidc_callback_missing_code() {
|
async fn test_oidc_callback_missing_code() {
|
||||||
let (app, _container, _mock_server) = create_test_app_with_oidc().await;
|
let (app, _mock_server) = create_test_app_with_oidc().await;
|
||||||
|
|
||||||
let response = app
|
let response = app
|
||||||
.oneshot(
|
.oneshot(
|
||||||
|
|
@ -137,7 +219,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_oidc_callback_with_error() {
|
async fn test_oidc_callback_with_error() {
|
||||||
let (app, _container, _mock_server) = create_test_app_with_oidc().await;
|
let (app, _mock_server) = create_test_app_with_oidc().await;
|
||||||
|
|
||||||
let response = app
|
let response = app
|
||||||
.oneshot(
|
.oneshot(
|
||||||
|
|
@ -155,7 +237,21 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_oidc_callback_success_new_user() {
|
async fn test_oidc_callback_success_new_user() {
|
||||||
let (app, _container, mock_server) = create_test_app_with_oidc().await;
|
let (app, mock_server) = create_test_app_with_oidc().await;
|
||||||
|
|
||||||
|
// Clean up any existing test user to ensure test isolation
|
||||||
|
let database_url = std::env::var("TEST_DATABASE_URL")
|
||||||
|
.or_else(|_| std::env::var("DATABASE_URL"))
|
||||||
|
.unwrap_or_else(|_| "postgresql://readur:readur@localhost:5432/readur".to_string());
|
||||||
|
let db = crate::db::Database::new(&database_url).await.unwrap();
|
||||||
|
|
||||||
|
// Delete any existing user with the test username or OIDC subject
|
||||||
|
let _ = sqlx::query("DELETE FROM users WHERE username = $1 OR oidc_subject = $2")
|
||||||
|
.bind("oidcuser")
|
||||||
|
.bind("oidc-user-123")
|
||||||
|
.execute(&db.pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
|
||||||
// Mock token exchange
|
// Mock token exchange
|
||||||
let token_response = json!({
|
let token_response = json!({
|
||||||
|
|
@ -166,7 +262,10 @@ mod tests {
|
||||||
|
|
||||||
Mock::given(method("POST"))
|
Mock::given(method("POST"))
|
||||||
.and(path("/token"))
|
.and(path("/token"))
|
||||||
.respond_with(ResponseTemplate::new(200).set_body_json(token_response))
|
.and(header("content-type", "application/x-www-form-urlencoded"))
|
||||||
|
.respond_with(ResponseTemplate::new(200)
|
||||||
|
.set_body_json(token_response)
|
||||||
|
.insert_header("content-type", "application/json"))
|
||||||
.mount(&mock_server)
|
.mount(&mock_server)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
|
@ -180,10 +279,15 @@ mod tests {
|
||||||
|
|
||||||
Mock::given(method("GET"))
|
Mock::given(method("GET"))
|
||||||
.and(path("/userinfo"))
|
.and(path("/userinfo"))
|
||||||
.respond_with(ResponseTemplate::new(200).set_body_json(user_info_response))
|
.respond_with(ResponseTemplate::new(200)
|
||||||
|
.set_body_json(user_info_response)
|
||||||
|
.insert_header("content-type", "application/json"))
|
||||||
.mount(&mock_server)
|
.mount(&mock_server)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
// Add a small delay to make sure everything is set up
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||||
|
|
||||||
let response = app
|
let response = app
|
||||||
.oneshot(
|
.oneshot(
|
||||||
axum::http::Request::builder()
|
axum::http::Request::builder()
|
||||||
|
|
@ -195,11 +299,31 @@ mod tests {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(response.status(), StatusCode::OK);
|
let status = response.status();
|
||||||
|
|
||||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
if status != StatusCode::OK {
|
||||||
|
let error_text = String::from_utf8_lossy(&body);
|
||||||
|
eprintln!("Response status: {}", status);
|
||||||
|
eprintln!("Response body: {}", error_text);
|
||||||
|
|
||||||
|
// Also check if we made the expected API calls to the mock server
|
||||||
|
eprintln!("Mock server received calls:");
|
||||||
|
let received_requests = mock_server.received_requests().await.unwrap();
|
||||||
|
for req in received_requests {
|
||||||
|
eprintln!(" {} {} - {}", req.method, req.url.path(), String::from_utf8_lossy(&req.body));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to parse as JSON to see if there's a more detailed error message
|
||||||
|
if let Ok(error_json) = serde_json::from_slice::<serde_json::Value>(&body) {
|
||||||
|
eprintln!("Error JSON: {:#}", error_json);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
|
||||||
let login_response: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
let login_response: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||||
|
|
||||||
assert!(login_response["token"].is_string());
|
assert!(login_response["token"].is_string());
|
||||||
|
|
@ -209,7 +333,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_oidc_callback_invalid_token() {
|
async fn test_oidc_callback_invalid_token() {
|
||||||
let (app, _container, mock_server) = create_test_app_with_oidc().await;
|
let (app, mock_server) = create_test_app_with_oidc().await;
|
||||||
|
|
||||||
// Mock failed token exchange
|
// Mock failed token exchange
|
||||||
Mock::given(method("POST"))
|
Mock::given(method("POST"))
|
||||||
|
|
@ -236,7 +360,18 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_oidc_callback_invalid_user_info() {
|
async fn test_oidc_callback_invalid_user_info() {
|
||||||
let (app, _container, mock_server) = create_test_app_with_oidc().await;
|
let (app, mock_server) = create_test_app_with_oidc().await;
|
||||||
|
|
||||||
|
// Clean up any existing test user to ensure test isolation
|
||||||
|
let database_url = std::env::var("TEST_DATABASE_URL")
|
||||||
|
.or_else(|_| std::env::var("DATABASE_URL"))
|
||||||
|
.unwrap_or_else(|_| "postgresql://readur:readur@localhost:5432/readur".to_string());
|
||||||
|
let db = crate::db::Database::new(&database_url).await.unwrap();
|
||||||
|
|
||||||
|
// Delete any existing user that might conflict
|
||||||
|
let _ = sqlx::query("DELETE FROM users WHERE username LIKE 'oidc%' OR oidc_subject IS NOT NULL")
|
||||||
|
.execute(&db.pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
// Mock successful token exchange
|
// Mock successful token exchange
|
||||||
let token_response = json!({
|
let token_response = json!({
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,11 @@ async fn create_test_app_state() -> Arc<AppState> {
|
||||||
max_file_size_mb: 50,
|
max_file_size_mb: 50,
|
||||||
memory_limit_mb: 512,
|
memory_limit_mb: 512,
|
||||||
cpu_priority: "normal".to_string(),
|
cpu_priority: "normal".to_string(),
|
||||||
|
oidc_enabled: false,
|
||||||
|
oidc_client_id: None,
|
||||||
|
oidc_client_secret: None,
|
||||||
|
oidc_issuer_url: None,
|
||||||
|
oidc_redirect_uri: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let db = Database::new(&config.database_url).await.unwrap();
|
let db = Database::new(&config.database_url).await.unwrap();
|
||||||
|
|
@ -62,6 +67,7 @@ async fn create_test_app_state() -> Arc<AppState> {
|
||||||
webdav_scheduler: None,
|
webdav_scheduler: None,
|
||||||
source_scheduler: None,
|
source_scheduler: None,
|
||||||
queue_service,
|
queue_service,
|
||||||
|
oidc_client: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -76,6 +76,11 @@ async fn create_test_app_state() -> Result<Arc<AppState>> {
|
||||||
max_file_size_mb: 10,
|
max_file_size_mb: 10,
|
||||||
memory_limit_mb: 256,
|
memory_limit_mb: 256,
|
||||||
cpu_priority: "normal".to_string(),
|
cpu_priority: "normal".to_string(),
|
||||||
|
oidc_enabled: false,
|
||||||
|
oidc_client_id: None,
|
||||||
|
oidc_client_secret: None,
|
||||||
|
oidc_issuer_url: None,
|
||||||
|
oidc_redirect_uri: None,
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
let db = Database::new(&config.database_url).await?;
|
let db = Database::new(&config.database_url).await?;
|
||||||
|
|
@ -89,6 +94,7 @@ async fn create_test_app_state() -> Result<Arc<AppState>> {
|
||||||
webdav_scheduler: None,
|
webdav_scheduler: None,
|
||||||
source_scheduler: None,
|
source_scheduler: None,
|
||||||
queue_service,
|
queue_service,
|
||||||
|
oidc_client: None,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,11 @@ async fn create_test_app_state() -> Result<Arc<AppState>> {
|
||||||
max_file_size_mb: 10,
|
max_file_size_mb: 10,
|
||||||
memory_limit_mb: 256,
|
memory_limit_mb: 256,
|
||||||
cpu_priority: "normal".to_string(),
|
cpu_priority: "normal".to_string(),
|
||||||
|
oidc_enabled: false,
|
||||||
|
oidc_client_id: None,
|
||||||
|
oidc_client_secret: None,
|
||||||
|
oidc_issuer_url: None,
|
||||||
|
oidc_redirect_uri: None,
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -43,6 +48,7 @@ async fn create_test_app_state() -> Result<Arc<AppState>> {
|
||||||
webdav_scheduler: None,
|
webdav_scheduler: None,
|
||||||
source_scheduler: None,
|
source_scheduler: None,
|
||||||
queue_service,
|
queue_service,
|
||||||
|
oidc_client: None,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,11 @@ async fn create_test_app_state() -> Arc<AppState> {
|
||||||
max_file_size_mb: 100,
|
max_file_size_mb: 100,
|
||||||
memory_limit_mb: 512,
|
memory_limit_mb: 512,
|
||||||
cpu_priority: "normal".to_string(),
|
cpu_priority: "normal".to_string(),
|
||||||
|
oidc_enabled: false,
|
||||||
|
oidc_client_id: None,
|
||||||
|
oidc_client_secret: None,
|
||||||
|
oidc_issuer_url: None,
|
||||||
|
oidc_redirect_uri: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let db = Database::new(&config.database_url).await.unwrap();
|
let db = Database::new(&config.database_url).await.unwrap();
|
||||||
|
|
@ -50,6 +55,7 @@ async fn create_test_app_state() -> Arc<AppState> {
|
||||||
webdav_scheduler: None,
|
webdav_scheduler: None,
|
||||||
source_scheduler: None,
|
source_scheduler: None,
|
||||||
queue_service,
|
queue_service,
|
||||||
|
oidc_client: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -176,6 +176,11 @@ async fn create_test_app_state() -> Arc<AppState> {
|
||||||
max_file_size_mb: 10,
|
max_file_size_mb: 10,
|
||||||
memory_limit_mb: 256,
|
memory_limit_mb: 256,
|
||||||
cpu_priority: "normal".to_string(),
|
cpu_priority: "normal".to_string(),
|
||||||
|
oidc_enabled: false,
|
||||||
|
oidc_client_id: None,
|
||||||
|
oidc_client_secret: None,
|
||||||
|
oidc_issuer_url: None,
|
||||||
|
oidc_redirect_uri: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let db = Database::new(&config.database_url).await.unwrap();
|
let db = Database::new(&config.database_url).await.unwrap();
|
||||||
|
|
@ -187,6 +192,7 @@ async fn create_test_app_state() -> Arc<AppState> {
|
||||||
webdav_scheduler: None,
|
webdav_scheduler: None,
|
||||||
source_scheduler: None,
|
source_scheduler: None,
|
||||||
queue_service,
|
queue_service,
|
||||||
|
oidc_client: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -113,6 +113,11 @@ async fn create_test_app_state() -> Result<Arc<AppState>> {
|
||||||
max_file_size_mb: 10,
|
max_file_size_mb: 10,
|
||||||
memory_limit_mb: 256,
|
memory_limit_mb: 256,
|
||||||
cpu_priority: "normal".to_string(),
|
cpu_priority: "normal".to_string(),
|
||||||
|
oidc_enabled: false,
|
||||||
|
oidc_client_id: None,
|
||||||
|
oidc_client_secret: None,
|
||||||
|
oidc_issuer_url: None,
|
||||||
|
oidc_redirect_uri: None,
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
let db = Database::new(&config.database_url).await?;
|
let db = Database::new(&config.database_url).await?;
|
||||||
|
|
@ -126,6 +131,7 @@ async fn create_test_app_state() -> Result<Arc<AppState>> {
|
||||||
webdav_scheduler: None,
|
webdav_scheduler: None,
|
||||||
source_scheduler: None,
|
source_scheduler: None,
|
||||||
queue_service,
|
queue_service,
|
||||||
|
oidc_client: None,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -340,6 +340,11 @@ fn test_webdav_scheduler_creation() {
|
||||||
max_file_size_mb: 50,
|
max_file_size_mb: 50,
|
||||||
ocr_language: "eng".to_string(),
|
ocr_language: "eng".to_string(),
|
||||||
ocr_timeout_seconds: 300,
|
ocr_timeout_seconds: 300,
|
||||||
|
oidc_enabled: false,
|
||||||
|
oidc_client_id: None,
|
||||||
|
oidc_client_secret: None,
|
||||||
|
oidc_issuer_url: None,
|
||||||
|
oidc_redirect_uri: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Note: This is a minimal test since we can't easily mock the database
|
// Note: This is a minimal test since we can't easily mock the database
|
||||||
|
|
|
||||||
|
|
@ -93,6 +93,11 @@ async fn setup_test_app() -> (Router, Arc<AppState>) {
|
||||||
max_file_size_mb: 50,
|
max_file_size_mb: 50,
|
||||||
ocr_language: "eng".to_string(),
|
ocr_language: "eng".to_string(),
|
||||||
ocr_timeout_seconds: 300,
|
ocr_timeout_seconds: 300,
|
||||||
|
oidc_enabled: false,
|
||||||
|
oidc_client_id: None,
|
||||||
|
oidc_client_secret: None,
|
||||||
|
oidc_issuer_url: None,
|
||||||
|
oidc_redirect_uri: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Use the environment-based database URL
|
// Use the environment-based database URL
|
||||||
|
|
@ -106,6 +111,7 @@ async fn setup_test_app() -> (Router, Arc<AppState>) {
|
||||||
webdav_scheduler: None,
|
webdav_scheduler: None,
|
||||||
source_scheduler: None,
|
source_scheduler: None,
|
||||||
queue_service,
|
queue_service,
|
||||||
|
oidc_client: None,
|
||||||
});
|
});
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue