feat(server): implement a throttled queue for OCR processing to prevent resource exhaustion

perf3ct 2025-06-16 01:20:13 +00:00
parent 6dba46b021
commit e33240a811
No known key found for this signature in database
GPG Key ID: 569C4EEC436F5232
14 changed files with 1146 additions and 62 deletions

View File

@@ -14,6 +14,7 @@ pub mod ocr_error;
pub mod ocr_health;
pub mod ocr_queue;
pub mod ocr_tests;
pub mod request_throttler;
pub mod routes;
pub mod s3_service;
pub mod seed;
@@ -38,6 +39,7 @@ pub struct AppState {
pub config: Config,
pub webdav_scheduler: Option<std::sync::Arc<webdav_scheduler::WebDAVScheduler>>,
pub source_scheduler: Option<std::sync::Arc<source_scheduler::SourceScheduler>>,
pub queue_service: std::sync::Arc<ocr_queue::OcrQueueService>,
}
/// Health check endpoint for monitoring

View File

@@ -130,21 +130,31 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
}
// Create web-facing state with dedicated web DB pool
// Create shared OCR queue service for both web and background operations
let concurrent_jobs = 15; // Limit concurrent OCR jobs to prevent DB pool exhaustion
let shared_queue_service = Arc::new(readur::ocr_queue::OcrQueueService::new(
background_db.clone(),
background_db.get_pool().clone(),
concurrent_jobs
));
// Create web-facing state with shared queue service
let web_state = AppState {
db: web_db,
config: config.clone(),
webdav_scheduler: None, // Will be set after creating scheduler
source_scheduler: None, // Will be set after creating scheduler
queue_service: shared_queue_service.clone(),
};
let web_state = Arc::new(web_state);
// Create background state with dedicated background DB pool
// Create background state with shared queue service
let background_state = AppState {
db: background_db,
config: config.clone(),
webdav_scheduler: None,
source_scheduler: None,
queue_service: shared_queue_service.clone(),
};
let background_state = Arc::new(background_state);
@@ -177,15 +187,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.enable_all()
.build()?;
// Start OCR queue worker on dedicated OCR runtime using background DB pool
let concurrent_jobs = 4; // TODO: Get from config/settings
let queue_service = Arc::new(readur::ocr_queue::OcrQueueService::new(
background_state.db.clone(),
background_state.db.get_pool().clone(),
concurrent_jobs
));
let queue_worker = queue_service.clone();
// Start OCR queue worker on dedicated OCR runtime using shared queue service
let queue_worker = shared_queue_service.clone();
ocr_runtime.spawn(async move {
if let Err(e) = queue_worker.start_worker().await {
error!("OCR queue worker error: {}", e);
@@ -193,7 +196,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
});
// Start OCR maintenance tasks on dedicated OCR runtime
let queue_maintenance = queue_service.clone();
let queue_maintenance = shared_queue_service.clone();
ocr_runtime.spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(300)); // Every 5 minutes
loop {
@@ -223,6 +226,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
config: web_state.config.clone(),
webdav_scheduler: Some(webdav_scheduler.clone()),
source_scheduler: Some(source_scheduler.clone()),
queue_service: shared_queue_service.clone(),
};
let web_state = Arc::new(updated_web_state);
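Taken together, the main.rs hunks above wire one shared service through both states; condensed, the wiring looks roughly like the sketch below (error handling and scheduler setup elided, names as in the hunks):

let shared_queue_service = Arc::new(readur::ocr_queue::OcrQueueService::new(
    background_db.clone(),
    background_db.get_pool().clone(),
    15, // cap concurrent OCR jobs to protect the DB pool
));

// Web uploads and background workers now enqueue into the same queue
// instead of each constructing a private OcrQueueService.
let web_state = Arc::new(AppState {
    db: web_db,
    config: config.clone(),
    webdav_scheduler: None,
    source_scheduler: None,
    queue_service: shared_queue_service.clone(),
});

// A single worker, driven by the dedicated OCR runtime, drains the queue.
let queue_worker = shared_queue_service.clone();
ocr_runtime.spawn(async move {
    if let Err(e) = queue_worker.start_worker().await {
        error!("OCR queue worker error: {}", e);
    }
});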

View File

@@ -8,7 +8,7 @@ use tokio::time::{sleep, Duration};
use tracing::{error, info, warn};
use uuid::Uuid;
use crate::{db::Database, enhanced_ocr::EnhancedOcrService, db_guardrails_simple::DocumentTransactionManager};
use crate::{db::Database, enhanced_ocr::EnhancedOcrService, db_guardrails_simple::DocumentTransactionManager, request_throttler::RequestThrottler};
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct OcrQueueItem {
@@ -44,18 +44,29 @@ pub struct OcrQueueService {
max_concurrent_jobs: usize,
worker_id: String,
transaction_manager: DocumentTransactionManager,
processing_throttler: Arc<RequestThrottler>,
}
impl OcrQueueService {
pub fn new(db: Database, pool: PgPool, max_concurrent_jobs: usize) -> Self {
let worker_id = format!("worker-{}-{}", hostname::get().unwrap_or_default().to_string_lossy(), Uuid::new_v4());
let transaction_manager = DocumentTransactionManager::new(pool.clone());
// Create a processing throttler to limit concurrent OCR operations
// This prevents overwhelming the database connection pool
let processing_throttler = Arc::new(RequestThrottler::new(
max_concurrent_jobs.min(15), // Don't exceed 15 concurrent OCR processes
60, // 60 second max wait time for OCR processing
format!("ocr-processing-{}", worker_id),
));
Self {
db,
pool,
max_concurrent_jobs,
worker_id,
transaction_manager,
processing_throttler,
}
}
@@ -260,7 +271,7 @@ impl OcrQueueService {
}
/// Process a single queue item
async fn process_item(&self, item: OcrQueueItem, ocr_service: &EnhancedOcrService) -> Result<()> {
pub async fn process_item(&self, item: OcrQueueItem, ocr_service: &EnhancedOcrService) -> Result<()> {
let start_time = std::time::Instant::now();
info!("Processing OCR job {} for document {}", item.id, item.document_id);
@@ -408,10 +419,24 @@ impl OcrQueueService {
let self_clone = self.clone();
let ocr_service_clone = ocr_service.clone();
// Spawn task to process item
// Spawn task to process item with throttling
tokio::spawn(async move {
if let Err(e) = self_clone.process_item(item, &ocr_service_clone).await {
error!("Error processing OCR item: {}", e);
// Acquire throttling permit to prevent overwhelming the database
match self_clone.processing_throttler.acquire_permit().await {
Ok(_throttle_permit) => {
// Process the item with both semaphore and throttle permits held
if let Err(e) = self_clone.process_item(item, &ocr_service_clone).await {
error!("Error processing OCR item: {}", e);
}
// Permits are automatically released when dropped
}
Err(e) => {
error!("Failed to acquire throttling permit for OCR processing: {}", e);
// Mark the item as failed due to throttling
if let Err(mark_err) = self_clone.mark_failed(item.id, &format!("Throttling error: {}", e)).await {
error!("Failed to mark item as failed after throttling error: {}", mark_err);
}
}
}
drop(permit);
});

184
src/request_throttler.rs Normal file
View File

@@ -0,0 +1,184 @@
/*!
* Request Throttling for High-Concurrency Scenarios
*
* This module provides throttling mechanisms to prevent resource exhaustion
* when processing large numbers of concurrent requests.
*/
use std::sync::Arc;
use tokio::sync::Semaphore;
use tokio::time::{Duration, Instant};
use tracing::{warn, info};
/// Request throttler to limit concurrent operations
#[derive(Clone)]
pub struct RequestThrottler {
/// Semaphore to limit concurrent operations
semaphore: Arc<Semaphore>,
/// Maximum wait time for acquiring a permit
max_wait_time: Duration,
/// Name for logging purposes
name: String,
}
impl RequestThrottler {
/// Create a new request throttler
pub fn new(max_concurrent: usize, max_wait_seconds: u64, name: String) -> Self {
Self {
semaphore: Arc::new(Semaphore::new(max_concurrent)),
max_wait_time: Duration::from_secs(max_wait_seconds),
name,
}
}
/// Acquire a permit for processing, with timeout
pub async fn acquire_permit(&self) -> Result<ThrottlePermit, ThrottleError> {
let start = Instant::now();
// Try to acquire permit with timeout
let permit = tokio::time::timeout(self.max_wait_time, self.semaphore.clone().acquire_owned())
.await
.map_err(|_| ThrottleError::Timeout)?
.map_err(|_| ThrottleError::Cancelled)?;
let wait_time = start.elapsed();
if wait_time > Duration::from_millis(100) {
info!("Throttler '{}': Acquired permit after {:?} wait", self.name, wait_time);
}
Ok(ThrottlePermit {
_permit: permit,
throttler_name: self.name.clone(),
})
}
/// Get current available permits
pub fn available_permits(&self) -> usize {
self.semaphore.available_permits()
}
/// Check if throttling is active
pub fn is_throttling(&self) -> bool {
self.semaphore.available_permits() == 0
}
}
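// Usage sketch (illustrative only; `run_ocr` is a hypothetical async fn,
// everything else is the API defined in this module):
//
//     let throttler = RequestThrottler::new(15, 60, "ocr".to_string());
//     match throttler.acquire_permit().await {
//         Ok(_permit) => run_ocr().await, // permit held while processing
//         Err(e) => warn!("OCR throttled: {}", e), // Timeout or Cancelled
//     }
//     // the permit is released as soon as `_permit` drops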
/// A permit that must be held while processing
pub struct ThrottlePermit {
_permit: tokio::sync::OwnedSemaphorePermit,
throttler_name: String,
}
impl Drop for ThrottlePermit {
fn drop(&mut self) {
// Permit is automatically released when dropped
}
}
/// Throttling errors
#[derive(Debug)]
pub enum ThrottleError {
/// Timeout waiting for permit
Timeout,
/// Operation was cancelled
Cancelled,
}
impl std::fmt::Display for ThrottleError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ThrottleError::Timeout => write!(f, "Timeout waiting for throttling permit"),
ThrottleError::Cancelled => write!(f, "Throttling operation was cancelled"),
}
}
}
impl std::error::Error for ThrottleError {}
/// Batch processor for handling high-volume operations
pub struct BatchProcessor<T> {
batch_size: usize,
flush_interval: Duration,
processor: Box<dyn Fn(Vec<T>) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>,
}
impl<T: Send + Clone + 'static> BatchProcessor<T> {
/// Create a new batch processor
pub fn new<F, Fut>(
batch_size: usize,
flush_interval_seconds: u64,
processor: F,
) -> Self
where
F: Fn(Vec<T>) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = ()> + Send + 'static,
{
Self {
batch_size,
flush_interval: Duration::from_secs(flush_interval_seconds),
processor: Box::new(move |items| Box::pin(processor(items))),
}
}
/// Process items in batches
pub async fn process_batch(&self, items: Vec<T>) {
if items.is_empty() {
return;
}
// Split into batches
for chunk in items.chunks(self.batch_size) {
let batch = chunk.to_vec();
info!("Processing batch of {} items", batch.len());
(self.processor)(batch).await;
// Small delay between batches to prevent overwhelming the system
tokio::time::sleep(Duration::from_millis(10)).await;
}
}
}
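// Usage sketch (illustrative only; `flush_to_db` is a hypothetical async fn
// that persists one batch of IDs):
//
//     let batcher = BatchProcessor::new(100, 5, |ids: Vec<Uuid>| async move {
//         flush_to_db(ids).await;
//     });
//     batcher.process_batch(pending_ids).await; // chunks of 100, 10ms apart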
#[cfg(test)]
mod tests {
use super::*;
use tokio::time::sleep;
#[tokio::test]
async fn test_throttler_basic() {
let throttler = RequestThrottler::new(2, 5, "test".to_string());
// Should be able to acquire 2 permits
let _permit1 = throttler.acquire_permit().await.unwrap();
let _permit2 = throttler.acquire_permit().await.unwrap();
// Third permit should be throttled
assert_eq!(throttler.available_permits(), 0);
assert!(throttler.is_throttling());
}
#[tokio::test]
async fn test_throttler_timeout() {
let throttler = RequestThrottler::new(1, 1, "test".to_string());
let _permit = throttler.acquire_permit().await.unwrap();
// This should timeout
let result = throttler.acquire_permit().await;
assert!(matches!(result, Err(ThrottleError::Timeout)));
}
#[tokio::test]
async fn test_permit_release() {
let throttler = RequestThrottler::new(1, 5, "test".to_string());
{
let _permit = throttler.acquire_permit().await.unwrap();
assert_eq!(throttler.available_permits(), 0);
} // permit dropped here
// Should be available again
assert_eq!(throttler.available_permits(), 1);
let _permit2 = throttler.acquire_permit().await.unwrap();
}
}

View File

@@ -14,7 +14,6 @@ use crate::{
auth::AuthUser,
file_service::FileService,
models::DocumentResponse,
ocr_queue::OcrQueueService,
AppState,
};
@@ -137,8 +136,7 @@ async fn upload_document(
let enable_background_ocr = settings.enable_background_ocr;
if enable_background_ocr {
let queue_service = OcrQueueService::new(state.db.clone(), state.db.pool.clone(), 1);
// Use the shared queue service from AppState instead of creating a new one
// Calculate priority based on file size
let priority = match file_size {
0..=1048576 => 10, // <= 1MB: highest priority
@@ -148,7 +146,7 @@ async fn upload_document(
_ => 2, // > 50MB: lowest priority
};
queue_service.enqueue_document(document_id, priority, file_size).await
state.queue_service.enqueue_document(document_id, priority, file_size).await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
}

View File

@@ -2,7 +2,7 @@ use crate::{AppState, models::UserResponse};
use axum::Router;
use serde_json::json;
use std::sync::Arc;
use testcontainers::{core::WaitFor, runners::AsyncRunner, ContainerAsync, GenericImage};
use testcontainers::{core::WaitFor, runners::AsyncRunner, ContainerAsync, GenericImage, ImageExt};
use testcontainers_modules::postgres::Postgres;
use tower::util::ServiceExt;

View File

@@ -14,7 +14,7 @@
*/
use reqwest::Client;
use serde_json::{json, Value};
use serde_json::Value;
use std::time::{Duration, Instant};
use tokio::time::sleep;
use uuid::Uuid;
@@ -177,8 +177,13 @@ impl FileProcessingTestClient {
original_filename: doc.original_filename.clone(),
file_size: doc.file_size,
mime_type: doc.mime_type.clone(),
tags: doc.tags.clone(),
created_at: doc.created_at,
has_ocr_text: doc.has_ocr_text,
ocr_confidence: doc.ocr_confidence,
ocr_word_count: doc.ocr_word_count,
ocr_processing_time_ms: doc.ocr_processing_time_ms,
ocr_status: doc.ocr_status.clone(),
upload_date: doc.upload_date,
};
return Ok(doc_copy);
}
@@ -772,7 +777,7 @@ async fn test_pipeline_performance_monitoring() {
// Analyze performance results
println!("📊 Performance Analysis:");
println!(" {'File':<12} {'Size':<8} {'Upload':<10} {'Processing':<12} {'Reported':<10} {'Status'}");
println!(" {:<12} {:<8} {:<10} {:<12} {:<10} {}", "File", "Size", "Upload", "Processing", "Reported", "Status");
println!(" {}", "-".repeat(70));
for (filename, size, upload_time, processing_time, reported_time, status) in &performance_results {
@@ -782,7 +787,7 @@ async fn test_pipeline_performance_monitoring() {
let status_str = status.as_deref().unwrap_or("unknown");
println!(" {:<12} {:<8} {:?:<10} {:?:<12} {:<10} {}",
println!(" {:<12} {:<8} {:<10?} {:<12?} {:<10} {}",
filename, size, upload_time, processing_time, reported_str, status_str);
}

View File

@@ -0,0 +1,238 @@
/*!
* Investigate why high document volumes return empty OCR content
*/
use reqwest::Client;
use serde_json::Value;
use std::time::{Duration, Instant};
use tokio::time::sleep;
use uuid::Uuid;
use futures;
use readur::models::{DocumentResponse, CreateUser, LoginRequest, LoginResponse};
const BASE_URL: &str = "http://localhost:8000";
struct Investigator {
client: Client,
token: String,
}
impl Investigator {
async fn new() -> Self {
let client = Client::new();
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis();
let username = format!("investigator_{}", timestamp);
let email = format!("investigator_{}@test.com", timestamp);
// Register and login
let user_data = CreateUser {
username: username.clone(),
email: email.clone(),
password: "testpass123".to_string(),
role: Some(readur::models::UserRole::User),
};
client.post(&format!("{}/api/auth/register", BASE_URL))
.json(&user_data)
.send()
.await
.expect("Registration should work");
let login_data = LoginRequest {
username: username.clone(),
password: "testpass123".to_string(),
};
let login_response = client
.post(&format!("{}/api/auth/login", BASE_URL))
.json(&login_data)
.send()
.await
.expect("Login should work");
let login_result: LoginResponse = login_response.json().await.expect("Login should return JSON");
let token = login_result.token;
Self { client, token }
}
async fn upload_document(&self, content: &str, filename: &str) -> DocumentResponse {
let part = reqwest::multipart::Part::text(content.to_string())
.file_name(filename.to_string())
.mime_str("text/plain")
.expect("Valid mime type");
let form = reqwest::multipart::Form::new().part("file", part);
let response = self.client
.post(&format!("{}/api/documents", BASE_URL))
.header("Authorization", format!("Bearer {}", self.token))
.multipart(form)
.send()
.await
.expect("Upload should work");
response.json().await.expect("Valid JSON")
}
async fn get_document_details(&self, doc_id: &str) -> Value {
let response = self.client
.get(&format!("{}/api/documents/{}/ocr", BASE_URL, doc_id))
.header("Authorization", format!("Bearer {}", self.token))
.send()
.await
.expect("Should get document details");
response.json().await.expect("Valid JSON")
}
async fn get_queue_stats(&self) -> Value {
let response = self.client
.get(&format!("{}/api/queue/stats", BASE_URL))
.header("Authorization", format!("Bearer {}", self.token))
.send()
.await;
match response {
Ok(resp) => resp.json().await.unwrap_or_else(|_| serde_json::json!({"error": "Failed to parse"})),
Err(_) => serde_json::json!({"error": "Failed to get queue stats"})
}
}
}
#[tokio::test]
async fn investigate_empty_content_issue() {
println!("🔍 INVESTIGATING EMPTY CONTENT ISSUE");
println!("===================================");
let investigator = Investigator::new().await;
// Test with different document counts to find the threshold
let test_cases = vec![
("Low concurrency", 3),
("Medium concurrency", 10),
("High concurrency", 20),
];
for (test_name, doc_count) in test_cases {
println!("\n📊 TEST: {} ({} documents)", test_name, doc_count);
println!("{}=", "=".repeat(50));
// Upload documents
let mut documents = Vec::new();
for i in 1..=doc_count {
let content = format!("TEST-{}-CONTENT-{:02}", test_name.replace(" ", "_").to_uppercase(), i);
let filename = format!("test_{}_{:02}.txt", test_name.replace(" ", "_"), i);
documents.push((content, filename));
}
println!("📤 Uploading {} documents...", doc_count);
let upload_start = Instant::now();
let uploaded_docs = futures::future::join_all(
documents.iter().map(|(content, filename)| {
investigator.upload_document(content, filename)
}).collect::<Vec<_>>()
).await;
let upload_time = upload_start.elapsed();
println!("✅ Upload completed in {:?}", upload_time);
// Check queue stats immediately after upload
let queue_stats = investigator.get_queue_stats().await;
println!("📊 Queue stats after upload: {}", serde_json::to_string_pretty(&queue_stats).unwrap_or_default());
// Wait for processing with detailed monitoring
println!("🔄 Monitoring OCR processing...");
let mut completed_count = 0;
let process_start = Instant::now();
while completed_count < doc_count && process_start.elapsed() < Duration::from_secs(60) {
sleep(Duration::from_secs(2)).await;
let mut current_completed = 0;
let mut sample_results = Vec::new();
for (i, doc) in uploaded_docs.iter().enumerate().take(3) { // Sample first 3 docs
let details = investigator.get_document_details(&doc.id.to_string()).await;
let status = details["ocr_status"].as_str().unwrap_or("unknown");
let ocr_text = details["ocr_text"].as_str().unwrap_or("");
let expected = &documents[i].0;
if status == "completed" {
current_completed += 1;
}
sample_results.push((doc.id.to_string(), status.to_string(), expected.clone(), ocr_text.to_string()));
}
// Estimate total completed (this is rough but gives us an idea)
let estimated_total_completed = if current_completed > 0 {
(current_completed as f64 / 3.0 * doc_count as f64) as usize
} else {
0
};
if estimated_total_completed != completed_count {
completed_count = estimated_total_completed;
println!(" 📈 Progress: ~{}/{} completed", completed_count, doc_count);
// Show sample results
for (doc_id, status, expected, actual) in sample_results {
if status == "completed" {
let is_correct = actual == expected;
let result_icon = if is_correct { "✅" } else if actual.is_empty() { "❌📄" } else { "❌🔄" };
println!(" {} {}: expected='{}' actual='{}'", result_icon, &doc_id[..8], expected, actual);
}
}
}
if estimated_total_completed >= doc_count {
break;
}
}
let process_time = process_start.elapsed();
println!("⏱️ Processing time: {:?}", process_time);
// Final analysis
let mut success_count = 0;
let mut empty_count = 0;
let mut other_corruption = 0;
for (i, doc) in uploaded_docs.iter().enumerate() {
let details = investigator.get_document_details(&doc.id.to_string()).await;
let status = details["ocr_status"].as_str().unwrap_or("unknown");
let ocr_text = details["ocr_text"].as_str().unwrap_or("");
let expected = &documents[i].0;
if status == "completed" {
if ocr_text == expected {
success_count += 1;
} else if ocr_text.is_empty() {
empty_count += 1;
} else {
other_corruption += 1;
}
}
}
println!("\n📊 RESULTS for {} documents:", doc_count);
println!(" ✅ Successful: {}", success_count);
println!(" ❌ Empty content: {}", empty_count);
println!(" 🔄 Other corruption: {}", other_corruption);
println!(" 📈 Success rate: {:.1}%", (success_count as f64 / doc_count as f64) * 100.0);
// Get final queue stats
let final_queue_stats = investigator.get_queue_stats().await;
println!("📊 Final queue stats: {}", serde_json::to_string_pretty(&final_queue_stats).unwrap_or_default());
if empty_count > 0 {
println!("⚠️ EMPTY CONTENT THRESHOLD FOUND AT {} DOCUMENTS", doc_count);
}
}
}

View File

@@ -35,9 +35,10 @@ struct OCRPipelineTestHarness {
impl OCRPipelineTestHarness {
async fn new() -> Result<Self> {
// Initialize database connection
// Initialize database connection with higher limits for stress testing
let pool = sqlx::postgres::PgPoolOptions::new()
.max_connections(10)
.max_connections(50) // Increased to support high concurrency tests
.acquire_timeout(std::time::Duration::from_secs(10))
.connect(TEST_DB_URL)
.await?;

View File

@@ -11,7 +11,7 @@
*/
use reqwest::Client;
use serde_json::{json, Value};
use serde_json::Value;
use std::time::{Duration, Instant};
use tokio::time::sleep;
use uuid::Uuid;
@@ -38,7 +38,7 @@ impl OCRQueueTestClient {
}
/// Register and login a test user
async fn register_and_login(&mut self, role: UserRole) -> Result<String, Box<dyn std::error::Error>> {
async fn register_and_login(&mut self, role: UserRole) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
@@ -100,7 +100,7 @@ impl OCRQueueTestClient {
}
/// Get OCR queue statistics
async fn get_queue_stats(&self) -> Result<Value, Box<dyn std::error::Error>> {
async fn get_queue_stats(&self) -> Result<Value, Box<dyn std::error::Error + Send + Sync>> {
let token = self.token.as_ref().ok_or("Not authenticated")?;
let response = self.client
@@ -118,7 +118,7 @@ impl OCRQueueTestClient {
}
/// Requeue failed OCR jobs
async fn requeue_failed_jobs(&self) -> Result<Value, Box<dyn std::error::Error>> {
async fn requeue_failed_jobs(&self) -> Result<Value, Box<dyn std::error::Error + Send + Sync>> {
let token = self.token.as_ref().ok_or("Not authenticated")?;
let response = self.client
@@ -136,7 +136,7 @@ impl OCRQueueTestClient {
}
/// Upload a document for OCR processing
async fn upload_document(&self, content: &str, filename: &str) -> Result<DocumentResponse, Box<dyn std::error::Error>> {
async fn upload_document(&self, content: &str, filename: &str) -> Result<DocumentResponse, Box<dyn std::error::Error + Send + Sync>> {
let token = self.token.as_ref().ok_or("Not authenticated")?;
let part = reqwest::multipart::Part::text(content.to_string())
@@ -161,7 +161,7 @@ impl OCRQueueTestClient {
}
/// Upload multiple documents concurrently
async fn upload_multiple_documents(&self, count: usize, base_content: &str) -> Result<Vec<DocumentResponse>, Box<dyn std::error::Error>> {
async fn upload_multiple_documents(&self, count: usize, base_content: &str) -> Result<Vec<DocumentResponse>, Box<dyn std::error::Error + Send + Sync>> {
let mut handles = Vec::new();
for i in 0..count {
@@ -180,7 +180,7 @@ impl OCRQueueTestClient {
for handle in handles {
match handle.await? {
Ok(doc) => documents.push(doc),
Err(e) => return Err(e),
Err(e) => return Err(format!("Upload failed: {}", e).into()),
}
}
@@ -188,7 +188,7 @@ impl OCRQueueTestClient {
}
/// Wait for OCR processing to complete for multiple documents
async fn wait_for_multiple_ocr_completion(&self, document_ids: &[String]) -> Result<Vec<bool>, Box<dyn std::error::Error>> {
async fn wait_for_multiple_ocr_completion(&self, document_ids: &[String]) -> Result<Vec<bool>, Box<dyn std::error::Error + Send + Sync>> {
let start = Instant::now();
let mut completed_status = vec![false; document_ids.len()];
@@ -224,7 +224,7 @@ impl OCRQueueTestClient {
}
/// Get all documents for the user
async fn get_documents(&self) -> Result<Vec<DocumentResponse>, Box<dyn std::error::Error>> {
async fn get_documents(&self) -> Result<Vec<DocumentResponse>, Box<dyn std::error::Error + Send + Sync>> {
let token = self.token.as_ref().ok_or("Not authenticated")?;
let response = self.client
@@ -478,14 +478,14 @@ async fn test_queue_performance_monitoring() {
let sample_duration = sample_time.elapsed();
performance_samples.push((start_time.elapsed(), stats, sample_duration));
println!("📊 Sample at {:?}: response_time={:?}, pending={}, processing={}",
start_time.elapsed(),
sample_duration,
stats["pending"].as_i64().unwrap_or(0),
stats["processing"].as_i64().unwrap_or(0));
performance_samples.push((start_time.elapsed(), stats, sample_duration));
if start_time.elapsed() + sample_interval < monitoring_duration {
sleep(sample_interval).await;
}

View File

@@ -13,7 +13,7 @@
*/
use reqwest::Client;
use serde_json::{json, Value};
use serde_json::Value;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Semaphore;
@@ -126,7 +126,7 @@ impl LoadTestClient {
}
/// Setup a test user for load testing
async fn setup_user(&mut self, user_index: usize) -> Result<String, Box<dyn std::error::Error>> {
async fn setup_user(&mut self, user_index: usize) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
@@ -188,7 +188,7 @@ impl LoadTestClient {
}
/// Perform a timed document upload
async fn timed_upload(&self, content: &str, filename: &str) -> Result<(DocumentResponse, Duration), Box<dyn std::error::Error>> {
async fn timed_upload(&self, content: &str, filename: &str) -> Result<(DocumentResponse, Duration), Box<dyn std::error::Error + Send + Sync>> {
let start = Instant::now();
let token = self.token.as_ref().ok_or("Not authenticated")?;
@@ -216,7 +216,7 @@ impl LoadTestClient {
}
/// Perform a timed document list request
async fn timed_list_documents(&self) -> Result<(Vec<DocumentResponse>, Duration), Box<dyn std::error::Error>> {
async fn timed_list_documents(&self) -> Result<(Vec<DocumentResponse>, Duration), Box<dyn std::error::Error + Send + Sync>> {
let start = Instant::now();
let token = self.token.as_ref().ok_or("Not authenticated")?;
@@ -237,7 +237,7 @@ impl LoadTestClient {
}
/// Perform a timed search request
async fn timed_search(&self, query: &str) -> Result<(Value, Duration), Box<dyn std::error::Error>> {
async fn timed_search(&self, query: &str) -> Result<(Value, Duration), Box<dyn std::error::Error + Send + Sync>> {
let start = Instant::now();
let token = self.token.as_ref().ok_or("Not authenticated")?;

View File

@@ -14,7 +14,6 @@
use reqwest::Client;
use serde_json::{json, Value};
use std::time::Duration;
use uuid::Uuid;
use readur::models::{CreateUser, LoginRequest, LoginResponse, UserRole};
@@ -732,26 +731,24 @@ async fn test_data_visibility_boundaries() {
println!("✅ Document visibility boundaries verified");
// Test search isolation (if available)
if let Ok((user1_search, _)) = client.client
let search_response = client.client
.get(&format!("{}/api/search", BASE_URL))
.header("Authorization", format!("Bearer {}", client.user1_token.as_ref().unwrap()))
.query(&[("q", "confidential")])
.send()
.await
.and_then(|r| async move {
let status = r.status();
let json: Result<Value, _> = r.json().await;
json.map(|j| (j, status))
})
.await
{
if let Some(results) = user1_search["documents"].as_array() {
let user1_search_sees_user2 = results.iter().any(|doc| {
doc["id"] == user2_doc_id
});
assert!(!user1_search_sees_user2, "User1 search should not return User2 documents");
println!("✅ Search isolation verified");
.await;
if let Ok(response) = search_response {
let status = response.status();
if let Ok(user1_search) = response.json::<Value>().await {
if let Some(results) = user1_search["documents"].as_array() {
let user1_search_sees_user2 = results.iter().any(|doc| {
doc["id"] == user2_doc_id
});
assert!(!user1_search_sees_user2, "User1 search should not return User2 documents");
println!("✅ Search isolation verified");
}
}
}
@@ -820,7 +817,7 @@ async fn test_token_and_session_security() {
// Test 2: Token for one user accessing another user's resources
println!("🔍 Testing token cross-contamination...");
let user1_token = client.user1_token.as_ref().unwrap();
let _user1_token = client.user1_token.as_ref().unwrap();
let user2_token = client.user2_token.as_ref().unwrap();
// Upload documents with each user

225
tests/stress_test_25.rs Normal file
View File

@@ -0,0 +1,225 @@
/*!
* Moderate Stress Test - 25 Documents for Complete Verification
*/
use reqwest::Client;
use serde_json::Value;
use std::time::{Duration, Instant};
use tokio::time::sleep;
use uuid::Uuid;
use futures;
use readur::models::{DocumentResponse, CreateUser, LoginRequest, LoginResponse};
const BASE_URL: &str = "http://localhost:8000";
const TIMEOUT: Duration = Duration::from_secs(120);
struct StressTester {
client: Client,
token: String,
}
impl StressTester {
async fn new() -> Self {
let client = Client::new();
// Check server health
client.get(&format!("{}/api/health", BASE_URL))
.timeout(Duration::from_secs(5))
.send()
.await
.expect("Server should be running");
// Create test user
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis();
let username = format!("stress_25_{}", timestamp);
let email = format!("stress_25_{}@test.com", timestamp);
// Register user
let user_data = CreateUser {
username: username.clone(),
email: email.clone(),
password: "testpass123".to_string(),
role: Some(readur::models::UserRole::User),
};
client.post(&format!("{}/api/auth/register", BASE_URL))
.json(&user_data)
.send()
.await
.expect("Registration should work");
// Login
let login_data = LoginRequest {
username: username.clone(),
password: "testpass123".to_string(),
};
let login_response = client
.post(&format!("{}/api/auth/login", BASE_URL))
.json(&login_data)
.send()
.await
.expect("Login should work");
let login_result: LoginResponse = login_response.json().await.expect("Login should return JSON");
let token = login_result.token;
println!("✅ Stress tester initialized");
Self { client, token }
}
async fn upload_document(&self, content: &str, filename: &str) -> DocumentResponse {
let part = reqwest::multipart::Part::text(content.to_string())
.file_name(filename.to_string())
.mime_str("text/plain")
.expect("Valid mime type");
let form = reqwest::multipart::Form::new().part("file", part);
let response = self.client
.post(&format!("{}/api/documents", BASE_URL))
.header("Authorization", format!("Bearer {}", self.token))
.multipart(form)
.send()
.await
.expect("Upload should work");
response.json().await.expect("Valid JSON")
}
async fn wait_for_ocr_completion(&self, document_ids: &[Uuid]) -> Vec<Value> {
let start = Instant::now();
while start.elapsed() < TIMEOUT {
let all_docs = self.get_all_documents().await;
let completed = all_docs.iter()
.filter(|doc| {
let doc_id_str = doc["id"].as_str().unwrap_or("");
let status = doc["ocr_status"].as_str().unwrap_or("");
document_ids.iter().any(|id| id.to_string() == doc_id_str) && status == "completed"
})
.count();
if completed == document_ids.len() {
return all_docs.into_iter()
.filter(|doc| {
let doc_id_str = doc["id"].as_str().unwrap_or("");
document_ids.iter().any(|id| id.to_string() == doc_id_str)
})
.collect();
}
sleep(Duration::from_millis(500)).await;
}
panic!("OCR processing did not complete within timeout");
}
async fn get_all_documents(&self) -> Vec<Value> {
let response = self.client
.get(&format!("{}/api/documents", BASE_URL))
.header("Authorization", format!("Bearer {}", self.token))
.send()
.await
.expect("Documents endpoint should work");
let data: Value = response.json().await.expect("Valid JSON");
match data {
Value::Object(obj) if obj.contains_key("documents") => {
obj["documents"].as_array().unwrap_or(&vec![]).clone()
}
Value::Array(arr) => arr,
_ => vec![]
}
}
}
#[tokio::test]
async fn stress_test_25_documents() {
println!("🚀 MODERATE STRESS TEST: 25 DOCUMENTS");
println!("======================================");
let tester = StressTester::new().await;
// Create 25 documents with unique content
let mut documents = Vec::new();
for i in 1..=25 {
let content = format!("STRESS-DOC-{:02}-SIGNATURE-{:02}-UNIQUE-CONTENT", i, i);
let filename = format!("stress_{:02}.txt", i);
documents.push((content, filename));
}
println!("📊 Testing {} documents concurrently", documents.len());
// Phase 1: Upload all documents concurrently
println!("\n🏁 UPLOADING...");
let upload_start = Instant::now();
let uploaded_docs = futures::future::join_all(
documents.iter().map(|(content, filename)| {
tester.upload_document(content, filename)
}).collect::<Vec<_>>()
).await;
let upload_duration = upload_start.elapsed();
println!("{} uploads completed in {:?}", uploaded_docs.len(), upload_duration);
// Phase 2: Wait for OCR completion
println!("\n🔬 PROCESSING OCR...");
let processing_start = Instant::now();
let document_ids: Vec<Uuid> = uploaded_docs.iter().map(|doc| doc.id).collect();
let final_docs = tester.wait_for_ocr_completion(&document_ids).await;
let processing_duration = processing_start.elapsed();
println!("✅ OCR processing completed in {:?}", processing_duration);
// Phase 3: Corruption Analysis
println!("\n📊 VERIFYING RESULTS...");
let mut successful = 0;
let mut corrupted = 0;
let mut corruption_details = Vec::new();
for (i, doc) in final_docs.iter().enumerate() {
let expected_content = &documents[i].0;
let actual_text = doc["ocr_text"].as_str().unwrap_or("");
let doc_id = doc["id"].as_str().unwrap_or("");
if actual_text == expected_content {
successful += 1;
} else {
corrupted += 1;
corruption_details.push((doc_id.to_string(), expected_content.clone(), actual_text.to_string()));
}
}
// Final Results
println!("\n🏆 STRESS TEST RESULTS");
println!("======================");
println!("📊 Total Documents: {}", documents.len());
println!("✅ Successful: {}", successful);
println!("❌ Corrupted: {}", corrupted);
println!("📈 Success Rate: {:.1}%", (successful as f64 / documents.len() as f64) * 100.0);
println!("⏱️ Upload Time: {:?}", upload_duration);
println!("⏱️ OCR Time: {:?}", processing_duration);
println!("⏱️ Total Time: {:?}", upload_duration + processing_duration);
if corrupted == 0 {
println!("\n🎉 STRESS TEST PASSED!");
println!("🎯 ALL {} DOCUMENTS PROCESSED WITHOUT CORRUPTION!", documents.len());
println!("🚀 HIGH CONCURRENCY OCR CORRUPTION ISSUE IS FULLY RESOLVED!");
} else {
println!("\n🚨 STRESS TEST FAILED!");
println!("❌ CORRUPTION DETECTED IN {} DOCUMENTS:", corrupted);
for (doc_id, expected, actual) in &corruption_details {
println!(" 📄 {}: expected '{}' got '{}'", doc_id, expected, actual);
}
panic!("CORRUPTION DETECTED in {} out of {} documents", corrupted, documents.len());
}
}

View File

@@ -0,0 +1,405 @@
/*!
* Throttled High Concurrency OCR Test
*
* This test verifies that our new throttling mechanism properly handles
* high-concurrency scenarios (50+ documents) without exhausting the database
* connection pool or corrupting OCR results.
*/
use anyhow::Result;
use sqlx::{PgPool, Row};
use std::sync::Arc;
use tokio::time::{Duration, Instant};
use tracing::{info, warn, error};
use uuid::Uuid;
use readur::{
config::Config,
db::Database,
models::{Document, Settings},
file_service::FileService,
enhanced_ocr::EnhancedOcrService,
ocr_queue::OcrQueueService,
db_guardrails_simple::DocumentTransactionManager,
request_throttler::RequestThrottler,
};
const TEST_DB_URL: &str = "postgresql://readur_user:readur_password@localhost:5432/readur";
struct ThrottledTestHarness {
db: Database,
pool: PgPool,
file_service: FileService,
queue_service: Arc<OcrQueueService>,
transaction_manager: DocumentTransactionManager,
}
impl ThrottledTestHarness {
async fn new() -> Result<Self> {
// Initialize database with proper connection limits
let pool = sqlx::postgres::PgPoolOptions::new()
.max_connections(30) // Higher limit for stress testing
.acquire_timeout(std::time::Duration::from_secs(15))
.connect(TEST_DB_URL)
.await?;
let db = Database::new(TEST_DB_URL).await?;
// Initialize services
let file_service = FileService::new("./test_uploads".to_string());
// Create throttled queue service - this is the key improvement
let queue_service = Arc::new(OcrQueueService::new(
db.clone(),
pool.clone(),
15 // Limit to 15 concurrent OCR jobs to prevent DB pool exhaustion
));
let transaction_manager = DocumentTransactionManager::new(pool.clone());
// Ensure test upload directory exists
std::fs::create_dir_all("./test_uploads").unwrap_or_default();
Ok(Self {
db,
pool,
file_service,
queue_service,
transaction_manager,
})
}
async fn create_test_user(&self) -> Result<Uuid> {
let user_id = Uuid::new_v4();
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis();
sqlx::query(
r#"
INSERT INTO users (id, username, email, password_hash, role)
VALUES ($1, $2, $3, $4, 'user')
"#
)
.bind(user_id)
.bind(format!("throttle_test_user_{}", timestamp))
.bind(format!("throttle_test_{}@example.com", timestamp))
.bind("dummy_hash")
.execute(&self.pool)
.await?;
info!("✅ Created test user: {}", user_id);
Ok(user_id)
}
async fn create_test_documents(&self, user_id: Uuid, count: usize) -> Result<Vec<(Uuid, String)>> {
let mut documents = Vec::new();
info!("📝 Creating {} test documents", count);
for i in 1..=count {
let content = format!("THROTTLE-TEST-DOC-{:03}-UNIQUE-CONTENT-{}", i, Uuid::new_v4());
let filename = format!("throttle_test_{:03}.txt", i);
let doc_id = Uuid::new_v4();
let file_path = format!("./test_uploads/{}.txt", doc_id);
// Write content to file
tokio::fs::write(&file_path, &content).await?;
// Create document record
sqlx::query(
r#"
INSERT INTO documents (
id, filename, original_filename, file_path, file_size,
mime_type, content, user_id, ocr_status, created_at, updated_at
)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'pending', NOW(), NOW())
"#
)
.bind(doc_id)
.bind(&filename)
.bind(&filename)
.bind(&file_path)
.bind(content.len() as i64)
.bind("text/plain")
.bind(&content)
.bind(user_id)
.execute(&self.pool)
.await?;
// Enqueue for OCR processing with varied priority
let priority = 10 - (i % 5) as i32; // Priorities cycle through 6-10
self.queue_service.enqueue_document(doc_id, priority, content.len() as i64).await?;
documents.push((doc_id, content));
if i % 10 == 0 {
info!(" ✅ Created {} documents so far", i);
}
}
info!("✅ All {} documents created and enqueued", count);
Ok(documents)
}
async fn start_throttled_workers(&self, num_workers: usize) -> Result<()> {
info!("🏭 Starting {} throttled OCR workers", num_workers);
let mut handles = Vec::new();
for worker_num in 1..=num_workers {
let queue_service = self.queue_service.clone();
let handle = tokio::spawn(async move {
let worker_id = format!("throttled-worker-{}", worker_num);
info!("Worker {} starting", worker_id);
// Each worker runs for a limited time to avoid infinite loops
let start_time = Instant::now();
let max_runtime = Duration::from_secs(300); // 5 minutes max
// Run a simplified worker loop instead of calling start_worker
// start_worker() consumes the Arc<Self>, so we can't call it multiple times
loop {
if start_time.elapsed() > max_runtime {
break;
}
// Process a single job if available
match queue_service.dequeue().await {
Ok(Some(item)) => {
info!("Worker {} processing job {}", worker_id, item.id);
// Process item using the built-in throttling
let ocr_service = readur::enhanced_ocr::EnhancedOcrService::new("/tmp".to_string());
if let Err(e) = queue_service.process_item(item, &ocr_service).await {
error!("Worker {} processing error: {}", worker_id, e);
}
}
Ok(None) => {
// No jobs available, wait a bit
tokio::time::sleep(Duration::from_millis(100)).await;
}
Err(e) => {
error!("Worker {} dequeue error: {}", worker_id, e);
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
}
info!("Worker {} completed", worker_id);
});
handles.push(handle);
}
// Don't wait for all workers to complete - they run in background
Ok(())
}
async fn wait_for_completion(&self, expected_docs: usize, timeout_minutes: u64) -> Result<()> {
let start_time = Instant::now();
let timeout = Duration::from_secs(timeout_minutes * 60);
info!("⏳ Waiting for {} documents to complete (timeout: {} minutes)", expected_docs, timeout_minutes);
loop {
if start_time.elapsed() > timeout {
warn!("⏰ Timeout reached waiting for OCR completion");
break;
}
// Check completion status
let completed_count: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM documents WHERE ocr_status = 'completed'"
)
.fetch_one(&self.pool)
.await?;
let failed_count: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM documents WHERE ocr_status = 'failed'"
)
.fetch_one(&self.pool)
.await?;
let processing_count: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM documents WHERE ocr_status = 'processing'"
)
.fetch_one(&self.pool)
.await?;
let pending_count: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM documents WHERE ocr_status = 'pending'"
)
.fetch_one(&self.pool)
.await?;
info!("📊 Status: {} completed, {} failed, {} processing, {} pending",
completed_count, failed_count, processing_count, pending_count);
if completed_count + failed_count >= expected_docs as i64 {
info!("✅ All documents have been processed!");
break;
}
tokio::time::sleep(Duration::from_secs(5)).await;
}
Ok(())
}
async fn verify_results(&self, expected_documents: &[(Uuid, String)]) -> Result<ThrottleTestResults> {
info!("🔍 Verifying OCR results for {} documents", expected_documents.len());
let mut results = ThrottleTestResults {
total_documents: expected_documents.len(),
completed: 0,
failed: 0,
corrupted: 0,
empty_content: 0,
correct_content: 0,
};
for (doc_id, expected_content) in expected_documents {
let row = sqlx::query(
r#"
SELECT ocr_status, ocr_text, ocr_error, filename
FROM documents
WHERE id = $1
"#
)
.bind(doc_id)
.fetch_one(&self.pool)
.await?;
let status: Option<String> = row.get("ocr_status");
let ocr_text: Option<String> = row.get("ocr_text");
let ocr_error: Option<String> = row.get("ocr_error");
let filename: String = row.get("filename");
match status.as_deref() {
Some("completed") => {
results.completed += 1;
match ocr_text.as_deref() {
Some(text) if text.is_empty() => {
warn!("❌ Document {} ({}) has empty OCR content", doc_id, filename);
results.empty_content += 1;
}
Some(text) if text == expected_content => {
results.correct_content += 1;
}
Some(text) => {
warn!("❌ Document {} ({}) has corrupted content:", doc_id, filename);
warn!(" Expected: {}", expected_content);
warn!(" Got: {}", text);
results.corrupted += 1;
}
None => {
warn!("❌ Document {} ({}) has NULL OCR content", doc_id, filename);
results.empty_content += 1;
}
}
}
Some("failed") => {
results.failed += 1;
info!("⚠️ Document {} ({}) failed: {}", doc_id, filename,
ocr_error.as_deref().unwrap_or("Unknown error"));
}
other => {
warn!("❓ Document {} ({}) has unexpected status: {:?}", doc_id, filename, other);
}
}
}
Ok(results)
}
async fn cleanup(&self) -> Result<()> {
// Clean up test files
let _ = tokio::fs::remove_dir_all("./test_uploads").await;
Ok(())
}
}
#[derive(Debug)]
struct ThrottleTestResults {
total_documents: usize,
completed: usize,
failed: usize,
corrupted: usize,
empty_content: usize,
correct_content: usize,
}
impl ThrottleTestResults {
fn success_rate(&self) -> f64 {
if self.total_documents == 0 { return 0.0; }
(self.correct_content as f64 / self.total_documents as f64) * 100.0
}
fn completion_rate(&self) -> f64 {
if self.total_documents == 0 { return 0.0; }
((self.completed + self.failed) as f64 / self.total_documents as f64) * 100.0
}
}
#[tokio::test]
async fn test_throttled_high_concurrency_50_documents() {
println!("🚀 THROTTLED HIGH CONCURRENCY TEST - 50 DOCUMENTS");
println!("================================================");
let harness = ThrottledTestHarness::new().await
.expect("Failed to initialize throttled test harness");
// Create test user
let user_id = harness.create_test_user().await
.expect("Failed to create test user");
// Create 50 test documents
let document_count = 50;
let test_documents = harness.create_test_documents(user_id, document_count).await
.expect("Failed to create test documents");
// Start multiple throttled workers
harness.start_throttled_workers(5).await
.expect("Failed to start throttled workers");
// Wait for completion with generous timeout
harness.wait_for_completion(document_count, 10).await
.expect("Failed to wait for completion");
// Verify results
let results = harness.verify_results(&test_documents).await
.expect("Failed to verify results");
// Cleanup
harness.cleanup().await.expect("Failed to cleanup");
// Print detailed results
println!("\n🏆 THROTTLED TEST RESULTS:");
println!("========================");
println!("📊 Total Documents: {}", results.total_documents);
println!("✅ Completed: {}", results.completed);
println!("❌ Failed: {}", results.failed);
println!("🔧 Correct Content: {}", results.correct_content);
println!("🚫 Empty Content: {}", results.empty_content);
println!("💥 Corrupted Content: {}", results.corrupted);
println!("📈 Success Rate: {:.1}%", results.success_rate());
println!("📊 Completion Rate: {:.1}%", results.completion_rate());
// Assertions
assert!(results.completion_rate() >= 90.0,
"Completion rate too low: {:.1}% (expected >= 90%)", results.completion_rate());
assert!(results.empty_content == 0,
"Found {} documents with empty content (should be 0 with throttling)", results.empty_content);
assert!(results.corrupted == 0,
"Found {} documents with corrupted content (should be 0 with throttling)", results.corrupted);
assert!(results.success_rate() >= 80.0,
"Success rate too low: {:.1}% (expected >= 80%)", results.success_rate());
println!("🎉 Throttled high concurrency test PASSED!");
}