diff --git a/src/bin/migrate_to_s3.rs b/src/bin/migrate_to_s3.rs index 9d31b17..5b66769 100644 --- a/src/bin/migrate_to_s3.rs +++ b/src/bin/migrate_to_s3.rs @@ -8,12 +8,15 @@ //! 3. Upload files to S3 with proper structure //! 4. Update database records with S3 paths //! 5. Optionally delete local files after successful upload +//! 6. Support rollback on failure with transaction-like behavior use anyhow::Result; use clap::Parser; use std::path::Path; +use std::collections::{HashMap, VecDeque}; use uuid::Uuid; use tracing::{info, warn, error}; +use serde::{Serialize, Deserialize}; use readur::{ config::Config, @@ -21,6 +24,41 @@ use readur::{ services::{s3_service::S3Service, file_service::FileService}, }; +/// Migration state tracking for rollback functionality +#[derive(Debug, Clone, Serialize, Deserialize)] +struct MigrationState { + started_at: chrono::DateTime, + completed_migrations: Vec, + failed_migrations: Vec, + total_files: usize, + processed_files: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct MigrationRecord { + document_id: Uuid, + user_id: Uuid, + original_path: String, + s3_key: String, + migrated_at: chrono::DateTime, + associated_files: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct AssociatedFile { + file_type: String, // "thumbnail" or "processed_image" + original_path: String, + s3_key: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct FailedMigration { + document_id: Uuid, + original_path: String, + error: String, + failed_at: chrono::DateTime, +} + #[derive(Parser)] #[command(name = "migrate_to_s3")] #[command(about = "Migrate existing local files to S3 storage")] @@ -40,6 +78,14 @@ struct Args { /// Only migrate files for specific user ID #[arg(short, long)] user_id: Option, + + /// Enable rollback on failure - will revert any successful migrations if overall process fails + #[arg(long)] + enable_rollback: bool, + + /// Resume from a specific document ID (for partial recovery) + #[arg(long)] + resume_from: Option, } #[tokio::main] @@ -132,41 +178,107 @@ async fn main() -> Result<()> { return Ok(()); } - // Perform migration + // Initialize migration state + let mut migration_state = MigrationState { + started_at: chrono::Utc::now(), + completed_migrations: Vec::new(), + failed_migrations: Vec::new(), + total_files: local_documents.len(), + processed_files: 0, + }; + + // Resume from specific document if requested + let start_index = if let Some(resume_from_str) = &args.resume_from { + let resume_doc_id = Uuid::parse_str(resume_from_str)?; + local_documents.iter().position(|doc| doc.id == resume_doc_id) + .unwrap_or(0) + } else { + 0 + }; + + info!("šŸ“Š Migration plan: {} files to process (starting from index {})", + local_documents.len() - start_index, start_index); + + // Perform migration with progress tracking let mut migrated_count = 0; let mut failed_count = 0; - for doc in local_documents { - info!("šŸ“¦ Migrating: {} ({})", doc.original_filename, doc.id); + for (index, doc) in local_documents.iter().enumerate().skip(start_index) { + info!("šŸ“¦ Migrating: {} ({}) [{}/{}]", + doc.original_filename, doc.id, index + 1, local_documents.len()); - match migrate_document(&db, &s3_service, &file_service, &doc, args.delete_local).await { - Ok(_) => { + match migrate_document_with_tracking(&db, &s3_service, &file_service, doc, args.delete_local).await { + Ok(migration_record) => { migrated_count += 1; - info!("āœ… Successfully migrated: {}", doc.original_filename); + 
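Review note on `--resume-from`: the `position(...).unwrap_or(0)` above silently restarts from the beginning whenever the supplied UUID is not in the candidate set. A stricter lookup is sketched below, under the assumption that failing loudly is preferable to re-processing everything; `Document` is a stand-in for `readur::models::Document`.

```rust
use anyhow::{anyhow, Result};
use uuid::Uuid;

struct Document { id: Uuid } // stand-in for readur::models::Document

/// Map an optional --resume-from UUID to a start index, erroring on an
/// unknown ID instead of falling back to index 0 as the current code does.
fn resume_index(docs: &[Document], resume_from: Option<&str>) -> Result<usize> {
    match resume_from {
        None => Ok(0),
        Some(raw) => {
            let id = Uuid::parse_str(raw)?;
            docs.iter()
                .position(|d| d.id == id)
                .ok_or_else(|| anyhow!("--resume-from {} is not in the migration set", id))
        }
    }
}
```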
migration_state.completed_migrations.push(migration_record); + migration_state.processed_files += 1; + info!("āœ… Successfully migrated: {} [{}/{}]", + doc.original_filename, migrated_count, local_documents.len()); + + // Save progress periodically (every 10 files) + if migrated_count % 10 == 0 { + save_migration_state(&migration_state).await?; + } } Err(e) => { failed_count += 1; + let failed_migration = FailedMigration { + document_id: doc.id, + original_path: doc.file_path.clone(), + error: e.to_string(), + failed_at: chrono::Utc::now(), + }; + migration_state.failed_migrations.push(failed_migration); + migration_state.processed_files += 1; error!("āŒ Failed to migrate {}: {}", doc.original_filename, e); + + // If rollback is enabled and we have failures, offer to rollback + if args.enable_rollback && failed_count > 0 { + error!("šŸ’„ Migration failure detected with rollback enabled!"); + error!("Do you want to rollback all {} successful migrations? (y/N)", migrated_count); + + // For automation, we'll automatically rollback on any failure + // In interactive mode, you could read from stdin here + warn!("šŸ”„ Automatically initiating rollback due to failure..."); + match rollback_migrations(&db, &s3_service, &migration_state).await { + Ok(rolled_back) => { + error!("šŸ”„ Successfully rolled back {} migrations", rolled_back); + return Err(anyhow::anyhow!("Migration failed and was rolled back")); + } + Err(rollback_err) => { + error!("šŸ’„ CRITICAL: Rollback failed: {}", rollback_err); + error!("šŸ’¾ Migration state saved for manual recovery"); + save_migration_state(&migration_state).await?; + return Err(anyhow::anyhow!( + "Migration failed and rollback also failed. Check migration state file for manual recovery." + )); + } + } + } } } } + // Save final migration state + save_migration_state(&migration_state).await?; + info!("šŸŽ‰ Migration completed!"); info!("āœ… Successfully migrated: {} files", migrated_count); if failed_count > 0 { warn!("āŒ Failed to migrate: {} files", failed_count); + warn!("šŸ’¾ Check migration_state.json for details on failures"); } Ok(()) } -async fn migrate_document( +async fn migrate_document_with_tracking( db: &Database, s3_service: &S3Service, file_service: &FileService, document: &readur::models::Document, delete_local: bool, -) -> Result<()> { +) -> Result { // Read local file let local_path = Path::new(&document.file_path); if !local_path.exists() { @@ -189,7 +301,7 @@ async fn migrate_document( db.update_document_file_path(document.id, &s3_path).await?; // Migrate associated files (thumbnails, processed images) - migrate_associated_files(s3_service, file_service, document, delete_local).await?; + let associated_files = migrate_associated_files_with_tracking(s3_service, file_service, document, delete_local).await?; // Delete local file if requested if delete_local { @@ -200,27 +312,46 @@ async fn migrate_document( } } - Ok(()) + // Create migration record for tracking + let migration_record = MigrationRecord { + document_id: document.id, + user_id: document.user_id, + original_path: document.file_path.clone(), + s3_key, + migrated_at: chrono::Utc::now(), + associated_files, + }; + + Ok(migration_record) } -async fn migrate_associated_files( +async fn migrate_associated_files_with_tracking( s3_service: &S3Service, file_service: &FileService, document: &readur::models::Document, delete_local: bool, -) -> Result<()> { +) -> Result> { + let mut associated_files = Vec::new(); // Migrate thumbnail let thumbnail_path = 
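Review note: the comment in the failure branch above says an interactive mode "could read from stdin here". A minimal sketch of that prompt, assuming a tokio runtime (this helper is not part of the current tool):

```rust
use tokio::io::{AsyncBufReadExt, BufReader};

/// Read one line from stdin and treat only an explicit "y"/"yes" as consent.
async fn confirm_rollback() -> bool {
    let mut answer = String::new();
    let mut stdin = BufReader::new(tokio::io::stdin());
    if stdin.read_line(&mut answer).await.is_err() {
        return false; // treat I/O errors as "no"
    }
    matches!(answer.trim().to_ascii_lowercase().as_str(), "y" | "yes")
}
```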
file_service.get_thumbnails_path().join(format!("{}_thumb.jpg", document.id)); if thumbnail_path.exists() { match tokio::fs::read(&thumbnail_path).await { Ok(thumbnail_data) => { - if let Err(e) = s3_service.store_thumbnail(document.user_id, document.id, &thumbnail_data).await { - warn!("Failed to migrate thumbnail for {}: {}", document.id, e); - } else { - info!("šŸ“ø Migrated thumbnail for: {}", document.original_filename); - if delete_local { - let _ = tokio::fs::remove_file(&thumbnail_path).await; + match s3_service.store_thumbnail(document.user_id, document.id, &thumbnail_data).await { + Ok(s3_key) => { + info!("šŸ“ø Migrated thumbnail for: {}", document.original_filename); + associated_files.push(AssociatedFile { + file_type: "thumbnail".to_string(), + original_path: thumbnail_path.to_string_lossy().to_string(), + s3_key, + }); + if delete_local { + let _ = tokio::fs::remove_file(&thumbnail_path).await; + } + } + Err(e) => { + warn!("Failed to migrate thumbnail for {}: {}", document.id, e); } } } @@ -233,12 +364,20 @@ async fn migrate_associated_files( if processed_path.exists() { match tokio::fs::read(&processed_path).await { Ok(processed_data) => { - if let Err(e) = s3_service.store_processed_image(document.user_id, document.id, &processed_data).await { - warn!("Failed to migrate processed image for {}: {}", document.id, e); - } else { - info!("šŸ–¼ļø Migrated processed image for: {}", document.original_filename); - if delete_local { - let _ = tokio::fs::remove_file(&processed_path).await; + match s3_service.store_processed_image(document.user_id, document.id, &processed_data).await { + Ok(s3_key) => { + info!("šŸ–¼ļø Migrated processed image for: {}", document.original_filename); + associated_files.push(AssociatedFile { + file_type: "processed_image".to_string(), + original_path: processed_path.to_string_lossy().to_string(), + s3_key, + }); + if delete_local { + let _ = tokio::fs::remove_file(&processed_path).await; + } + } + Err(e) => { + warn!("Failed to migrate processed image for {}: {}", document.id, e); } } } @@ -246,5 +385,94 @@ async fn migrate_associated_files( } } + Ok(associated_files) +} + +/// Save migration state to disk for recovery purposes +async fn save_migration_state(state: &MigrationState) -> Result<()> { + let state_json = serde_json::to_string_pretty(state)?; + tokio::fs::write("migration_state.json", state_json).await?; + info!("šŸ’¾ Migration state saved to migration_state.json"); Ok(()) +} + +/// Rollback migrations by restoring database paths and deleting S3 objects +async fn rollback_migrations( + db: &Database, + s3_service: &S3Service, + state: &MigrationState, +) -> Result { + info!("šŸ”„ Starting rollback of {} migrations...", state.completed_migrations.len()); + + let mut rolled_back = 0; + let mut rollback_errors = Vec::new(); + + // Process migrations in reverse order (most recent first) + for migration in state.completed_migrations.iter().rev() { + info!("šŸ”„ Rolling back migration for document {}", migration.document_id); + + // Restore original database path + match db.update_document_file_path(migration.document_id, &migration.original_path).await { + Ok(_) => { + info!("āœ… Restored database path for document {}", migration.document_id); + } + Err(e) => { + let error_msg = format!("Failed to restore DB path for {}: {}", migration.document_id, e); + error!("āŒ {}", error_msg); + rollback_errors.push(error_msg); + continue; // Skip S3 cleanup if DB restore failed + } + } + + // Delete S3 object (main document) + match 
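Review note: `save_migration_state` checkpoints progress to `migration_state.json` every ten files, but this diff only ever writes the file. A matching loader, sketched here as an assumption, would let a restarted run derive its resume point from `completed_migrations` instead of requiring a manually supplied `--resume-from`:

```rust
/// Load a previously checkpointed MigrationState, if one exists.
/// Hypothetical helper; the tool currently only writes this file.
async fn load_migration_state() -> anyhow::Result<Option<MigrationState>> {
    match tokio::fs::read_to_string("migration_state.json").await {
        Ok(json) => Ok(Some(serde_json::from_str(&json)?)),
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
        Err(e) => Err(e.into()),
    }
}
```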
s3_service.delete_file(&migration.s3_key).await { + Ok(_) => { + info!("šŸ—‘ļø Deleted S3 object: {}", migration.s3_key); + } + Err(e) => { + let error_msg = format!("Failed to delete S3 object {}: {}", migration.s3_key, e); + warn!("āš ļø {}", error_msg); + rollback_errors.push(error_msg); + // Continue with associated files even if main file deletion failed + } + } + + // Delete associated S3 objects (thumbnails, processed images) + for associated in &migration.associated_files { + match s3_service.delete_file(&associated.s3_key).await { + Ok(_) => { + info!("šŸ—‘ļø Deleted associated S3 object: {} ({})", associated.s3_key, associated.file_type); + } + Err(e) => { + let error_msg = format!("Failed to delete associated S3 object {}: {}", associated.s3_key, e); + warn!("āš ļø {}", error_msg); + rollback_errors.push(error_msg); + } + } + } + + rolled_back += 1; + } + + if !rollback_errors.is_empty() { + warn!("āš ļø Rollback completed with {} errors:", rollback_errors.len()); + for error in &rollback_errors { + warn!(" - {}", error); + } + + // Save error details for manual cleanup + let error_state = serde_json::json!({ + "rollback_completed_at": chrono::Utc::now(), + "rolled_back_count": rolled_back, + "rollback_errors": rollback_errors, + "original_migration_state": state + }); + + let error_json = serde_json::to_string_pretty(&error_state)?; + tokio::fs::write("rollback_errors.json", error_json).await?; + warn!("šŸ’¾ Rollback errors saved to rollback_errors.json for manual cleanup"); + } + + info!("āœ… Rollback completed: {} migrations processed", rolled_back); + Ok(rolled_back) } \ No newline at end of file diff --git a/src/config.rs b/src/config.rs index 720046b..6ac59e0 100644 --- a/src/config.rs +++ b/src/config.rs @@ -54,35 +54,7 @@ impl Config { // Database Configuration let database_url = match env::var("DATABASE_URL") { Ok(val) => { - // Mask sensitive parts of database URL for logging - let masked_url = if val.contains('@') { - let parts: Vec<&str> = val.split('@').collect(); - if parts.len() >= 2 { - let credentials_part = parts[0]; - let remaining_part = parts[1..].join("@"); - - // Extract just the username part before the password - if let Some(username_start) = credentials_part.rfind("://") { - let protocol = &credentials_part[..username_start + 3]; - let credentials = &credentials_part[username_start + 3..]; - if let Some(colon_pos) = credentials.find(':') { - let username = &credentials[..colon_pos]; - // Show first and last character of the username - let masked_username = format!("{}{}", &username[..1], &username[username.len() - 1..]); - format!("{}{}:***@{}", protocol, masked_username, remaining_part) - } else { - format!("{}***@{}", protocol, remaining_part) - } - } else { - "***masked***".to_string() - } - } else { - "***masked***".to_string() - } - } else { - val.clone() - }; - println!("āœ… DATABASE_URL: {} (loaded from env)", masked_url); + println!("āœ… DATABASE_URL: configured (loaded from env)"); val } Err(_) => { @@ -462,10 +434,8 @@ impl Config { if !bucket_name.is_empty() && !access_key_id.is_empty() && !secret_access_key.is_empty() { println!("āœ… S3_BUCKET_NAME: {} (loaded from env)", bucket_name); println!("āœ… S3_REGION: {} (loaded from env)", region); - println!("āœ… S3_ACCESS_KEY_ID: {}***{} (loaded from env)", - &access_key_id[..2.min(access_key_id.len())], - &access_key_id[access_key_id.len().saturating_sub(2)..]); - println!("āœ… S3_SECRET_ACCESS_KEY: ***hidden*** (loaded from env, {} chars)", secret_access_key.len()); + println!("āœ… 
S3_ACCESS_KEY_ID: configured (loaded from env)"); + println!("āœ… S3_SECRET_ACCESS_KEY: configured (loaded from env)"); if let Some(ref endpoint) = endpoint_url { println!("āœ… S3_ENDPOINT_URL: {} (loaded from env)", endpoint); } diff --git a/src/main.rs b/src/main.rs index 8623b14..cd991bf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -121,15 +121,35 @@ async fn main() -> anyhow::Result<()> { println!("šŸ“ Upload directory: {}", config.upload_path); println!("šŸ‘ļø Watch directory: {}", config.watch_folder); - // Initialize file service using the new storage backend architecture + // Initialize file service using the new storage backend architecture with fallback info!("Initializing file service with storage backend..."); let storage_config = readur::storage::factory::storage_config_from_env(&config)?; - let file_service = readur::services::file_service::FileService::from_config(storage_config, config.upload_path.clone()).await - .map_err(|e| { - error!("Failed to initialize file service with configured storage backend: {}", e); - e - })?; - info!("āœ… File service initialized with {} storage backend", file_service.storage_type()); + let file_service = match readur::services::file_service::FileService::from_config(storage_config, config.upload_path.clone()).await { + Ok(service) => { + info!("āœ… File service initialized with {} storage backend", service.storage_type()); + service + } + Err(e) => { + error!("āŒ Failed to initialize configured storage backend: {}", e); + warn!("šŸ”„ Falling back to local storage..."); + + // Create fallback local storage configuration + let fallback_config = readur::storage::StorageConfig::Local { + upload_path: config.upload_path.clone(), + }; + + match readur::services::file_service::FileService::from_config(fallback_config, config.upload_path.clone()).await { + Ok(fallback_service) => { + warn!("āœ… Successfully initialized fallback local storage"); + fallback_service + } + Err(fallback_err) => { + error!("šŸ’„ CRITICAL: Even fallback local storage failed to initialize: {}", fallback_err); + return Err(fallback_err.into()); + } + } + } + }; // Initialize the storage backend (creates directories, validates access, etc.) 
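Review note: the storage fallback above is written inline in `main`; condensed into a helper it reads as follows. This is a sketch only, reusing the `FileService::from_config` and `StorageConfig::Local` APIs from this diff; the exact error conversions are assumptions.

```rust
use readur::services::file_service::FileService;
use readur::storage::StorageConfig;
use tracing::warn;

/// Try the configured backend first, then fall back to local storage.
async fn init_file_service_with_fallback(
    primary: StorageConfig,
    upload_path: String,
) -> anyhow::Result<FileService> {
    match FileService::from_config(primary, upload_path.clone()).await {
        Ok(service) => Ok(service),
        Err(e) => {
            warn!("configured storage backend failed ({}); falling back to local storage", e);
            let local = StorageConfig::Local { upload_path: upload_path.clone() };
            Ok(FileService::from_config(local, upload_path).await?)
        }
    }
}
```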
if let Err(e) = file_service.initialize_storage().await { @@ -159,7 +179,6 @@ async fn main() -> anyhow::Result<()> { } Err(e) => { println!("āŒ CRITICAL: Failed to connect to database for web operations!"); - println!("Database URL: {}", db_info); // Use the already-masked URL println!("Error: {}", e); println!("\nšŸ”§ Please verify:"); println!(" - Database server is running"); diff --git a/src/ocr/enhanced.rs b/src/ocr/enhanced.rs index 3d355b4..4030d18 100644 --- a/src/ocr/enhanced.rs +++ b/src/ocr/enhanced.rs @@ -45,14 +45,6 @@ impl EnhancedOcrService { Self { temp_dir, file_service } } - /// Backward-compatible constructor for tests and legacy code - /// Creates a FileService with local storage using UPLOAD_PATH env var - #[deprecated(note = "Use new() with FileService parameter instead")] - pub fn new_legacy(temp_dir: String) -> Self { - let upload_path = std::env::var("UPLOAD_PATH").unwrap_or_else(|_| "./uploads".to_string()); - let file_service = FileService::new(upload_path); - Self::new(temp_dir, file_service) - } /// Extract text from image with high-quality OCR settings #[cfg(feature = "ocr")] diff --git a/src/ocr/queue.rs b/src/ocr/queue.rs index d08c935..5cb91e5 100644 --- a/src/ocr/queue.rs +++ b/src/ocr/queue.rs @@ -75,14 +75,6 @@ impl OcrQueueService { } } - /// Backward-compatible constructor for tests and legacy code - /// Creates a FileService with local storage using UPLOAD_PATH env var - #[deprecated(note = "Use new() with FileService parameter instead")] - pub fn new_legacy(db: Database, pool: PgPool, max_concurrent_jobs: usize) -> Self { - let upload_path = std::env::var("UPLOAD_PATH").unwrap_or_else(|_| "./uploads".to_string()); - let file_service = std::sync::Arc::new(crate::services::file_service::FileService::new(upload_path)); - Self::new(db, pool, max_concurrent_jobs, file_service) - } /// Add a document to the OCR queue pub async fn enqueue_document(&self, document_id: Uuid, priority: i32, file_size: i64) -> Result { diff --git a/src/routes/documents/crud.rs b/src/routes/documents/crud.rs index 78d50f1..984607d 100644 --- a/src/routes/documents/crud.rs +++ b/src/routes/documents/crud.rs @@ -202,10 +202,10 @@ pub async fn upload_document( } // Create ingestion service - let file_service = &state.file_service; + let file_service_clone = state.file_service.as_ref().clone(); let ingestion_service = DocumentIngestionService::new( state.db.clone(), - (**file_service).clone(), + file_service_clone, ); debug!("[UPLOAD_DEBUG] Calling ingestion service for file: {}", filename); diff --git a/src/routes/webdav/webdav_sync.rs b/src/routes/webdav/webdav_sync.rs index 29f77d1..5fcc510 100644 --- a/src/routes/webdav/webdav_sync.rs +++ b/src/routes/webdav/webdav_sync.rs @@ -315,8 +315,8 @@ async fn process_single_file( debug!("Downloaded file: {} ({} bytes)", file_info.name, file_data.len()); // Use the unified ingestion service for consistent deduplication - let file_service = &state.file_service; - let ingestion_service = DocumentIngestionService::new(state.db.clone(), (**file_service).clone()); + let file_service_clone = state.file_service.as_ref().clone(); + let ingestion_service = DocumentIngestionService::new(state.db.clone(), file_service_clone); let result = if let Some(source_id) = webdav_source_id { ingestion_service diff --git a/src/services/s3_service.rs b/src/services/s3_service.rs index 1d392f2..d932a85 100644 --- a/src/services/s3_service.rs +++ b/src/services/s3_service.rs @@ -6,6 +6,8 @@ use serde_json; use std::collections::HashMap; use 
std::time::Duration; use uuid::Uuid; +use futures::stream::StreamExt; +use tokio::io::{AsyncRead, AsyncReadExt}; #[cfg(feature = "s3")] use aws_sdk_s3::Client; @@ -15,10 +17,18 @@ use aws_credential_types::Credentials; use aws_types::region::Region as AwsRegion; #[cfg(feature = "s3")] use aws_sdk_s3::primitives::ByteStream; +#[cfg(feature = "s3")] +use aws_sdk_s3::types::{CompletedPart, CompletedMultipartUpload}; use crate::models::{FileIngestionInfo, S3SourceConfig}; use crate::storage::StorageBackend; +/// Threshold for using streaming multipart uploads (100MB) +const STREAMING_THRESHOLD: usize = 100 * 1024 * 1024; + +/// Multipart upload chunk size (16MB - AWS minimum is 5MB, we use 16MB for better performance) +const MULTIPART_CHUNK_SIZE: usize = 16 * 1024 * 1024; + #[derive(Debug, Clone)] pub struct S3Service { #[cfg(feature = "s3")] @@ -347,7 +357,15 @@ impl S3Service { #[cfg(feature = "s3")] { let key = self.generate_document_key(user_id, document_id, filename); - self.store_file(&key, data, None).await?; + + // Use streaming upload for large files + if data.len() > STREAMING_THRESHOLD { + info!("Using streaming multipart upload for large file: {} ({} bytes)", key, data.len()); + self.store_file_multipart(&key, data, None).await?; + } else { + self.store_file(&key, data, None).await?; + } + Ok(key) } } @@ -438,6 +456,137 @@ impl S3Service { } } + /// Store large files using multipart upload for better performance and memory usage + async fn store_file_multipart(&self, key: &str, data: &[u8], metadata: Option>) -> Result<()> { + #[cfg(not(feature = "s3"))] + { + return Err(anyhow!("S3 support not compiled in")); + } + + #[cfg(feature = "s3")] + { + info!("Starting multipart upload for file: {}/{} ({} bytes)", self.config.bucket_name, key, data.len()); + + let key_owned = key.to_string(); + let data_owned = data.to_vec(); + let metadata_owned = metadata.clone(); + let bucket_name = self.config.bucket_name.clone(); + let client = self.client.clone(); + + self.retry_operation(&format!("store_file_multipart: {}", key), || { + let key = key_owned.clone(); + let data = data_owned.clone(); + let metadata = metadata_owned.clone(); + let bucket_name = bucket_name.clone(); + let client = client.clone(); + let content_type = self.get_content_type_from_key(&key); + + async move { + // Step 1: Initiate multipart upload + let mut create_request = client + .create_multipart_upload() + .bucket(&bucket_name) + .key(&key); + + // Add metadata if provided + if let Some(meta) = metadata { + for (k, v) in meta { + create_request = create_request.metadata(k, v); + } + } + + // Set content type based on file extension + if let Some(ct) = content_type { + create_request = create_request.content_type(ct); + } + + let create_response = create_request.send().await + .map_err(|e| anyhow!("Failed to initiate multipart upload for {}: {}", key, e))?; + + let upload_id = create_response.upload_id() + .ok_or_else(|| anyhow!("Missing upload ID in multipart upload response"))?; + + info!("Initiated multipart upload for {}: {}", key, upload_id); + + // Step 2: Upload parts in chunks + let mut completed_parts = Vec::new(); + let total_chunks = (data.len() + MULTIPART_CHUNK_SIZE - 1) / MULTIPART_CHUNK_SIZE; + + for (chunk_index, chunk) in data.chunks(MULTIPART_CHUNK_SIZE).enumerate() { + let part_number = (chunk_index + 1) as i32; + + debug!("Uploading part {} of {} for {} ({} bytes)", + part_number, total_chunks, key, chunk.len()); + + let upload_part_response = client + .upload_part() + .bucket(&bucket_name) + 
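Review note: a quick sanity check on the multipart constants. S3 permits at most 10,000 parts per upload and a 5 MiB minimum part size (except for the last part), so 16 MiB chunks comfortably clear the minimum and cap a single object at roughly 156 GiB. The ceiling division below matches the `total_chunks` computation in the upload loop:

```rust
const MULTIPART_CHUNK_SIZE: usize = 16 * 1024 * 1024;

/// Ceiling division, matching the total_chunks computation in the upload loop.
fn part_count(total_len: usize) -> usize {
    (total_len + MULTIPART_CHUNK_SIZE - 1) / MULTIPART_CHUNK_SIZE
}

fn main() {
    let five_gib = 5_usize * 1024 * 1024 * 1024;
    assert_eq!(part_count(five_gib), 320);
    assert!(part_count(five_gib) <= 10_000); // S3's per-upload part limit
}
```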
.key(&key) + .upload_id(upload_id) + .part_number(part_number) + .body(ByteStream::from(chunk.to_vec())) + .send() + .await + .map_err(|e| anyhow!("Failed to upload part {} for {}: {}", part_number, key, e))?; + + let etag = upload_part_response.e_tag() + .ok_or_else(|| anyhow!("Missing ETag in upload part response"))?; + + completed_parts.push( + CompletedPart::builder() + .part_number(part_number) + .e_tag(etag) + .build() + ); + + debug!("Successfully uploaded part {} for {}", part_number, key); + } + + // Step 3: Complete multipart upload + let completed_multipart_upload = CompletedMultipartUpload::builder() + .set_parts(Some(completed_parts)) + .build(); + + client + .complete_multipart_upload() + .bucket(&bucket_name) + .key(&key) + .upload_id(upload_id) + .multipart_upload(completed_multipart_upload) + .send() + .await + .map_err(|e| { + // If completion fails, try to abort the multipart upload + let abort_client = client.clone(); + let abort_bucket = bucket_name.clone(); + let abort_key = key.clone(); + let abort_upload_id = upload_id.to_string(); + + tokio::spawn(async move { + if let Err(abort_err) = abort_client + .abort_multipart_upload() + .bucket(abort_bucket) + .key(abort_key) + .upload_id(abort_upload_id) + .send() + .await + { + error!("Failed to abort multipart upload: {}", abort_err); + } + }); + + anyhow!("Failed to complete multipart upload for {}: {}", key, e) + })?; + + info!("Successfully completed multipart upload for {}", key); + Ok(()) + } + }).await?; + + Ok(()) + } + } + /// Retrieve a file from S3 pub async fn retrieve_file(&self, key: &str) -> Result> { #[cfg(not(feature = "s3"))] @@ -666,7 +815,15 @@ impl StorageBackend for S3Service { async fn store_document(&self, user_id: Uuid, document_id: Uuid, filename: &str, data: &[u8]) -> Result { // Generate S3 key let key = self.generate_document_key(user_id, document_id, filename); - self.store_file(&key, data, None).await?; + + // Use streaming upload for large files + if data.len() > STREAMING_THRESHOLD { + info!("Using streaming multipart upload for large file: {} ({} bytes)", key, data.len()); + self.store_file_multipart(&key, data, None).await?; + } else { + self.store_file(&key, data, None).await?; + } + Ok(format!("s3://{}", key)) } diff --git a/src/storage/local.rs b/src/storage/local.rs index 45bdb96..7f48c2e 100644 --- a/src/storage/local.rs +++ b/src/storage/local.rs @@ -3,21 +3,30 @@ use anyhow::Result; use async_trait::async_trait; use std::path::{Path, PathBuf}; +use std::collections::HashMap; +use std::sync::Arc; use tokio::fs; -use tracing::{info, error}; +use tokio::sync::RwLock; +use tracing::{info, error, warn, debug}; use uuid::Uuid; use super::StorageBackend; +use crate::utils::security::{validate_filename, validate_and_sanitize_path, validate_path_within_base}; /// Local filesystem storage backend pub struct LocalStorageBackend { upload_path: String, + /// Cache for resolved file paths to reduce filesystem calls + path_cache: Arc>>>, } impl LocalStorageBackend { /// Create a new local storage backend pub fn new(upload_path: String) -> Self { - Self { upload_path } + Self { + upload_path, + path_cache: Arc::new(RwLock::new(HashMap::new())), + } } /// Get the base upload path @@ -50,33 +59,131 @@ impl LocalStorageBackend { Path::new(&self.upload_path).join("backups") } - /// Resolve file path, handling both old and new directory structures + /// Resolve file path, handling both old and new directory structures with caching pub async fn resolve_file_path(&self, file_path: &str) -> Result { - // 
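Review note on the completion failure path: the `abort_multipart_upload` call runs in a detached `tokio::spawn`, so the process can exit before it finishes, and its own failure only surfaces in logs. S3 retains (and bills for) the parts of an uncompleted multipart upload until they are aborted, so a bucket lifecycle rule with the `AbortIncompleteMultipartUpload` action is a sensible backstop either way.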
If the file exists at the given path, use it - if Path::new(file_path).exists() { - return Ok(file_path.to_string()); + // Check cache first + { + let cache = self.path_cache.read().await; + if let Some(cached_result) = cache.get(file_path) { + return match cached_result { + Some(resolved_path) => { + debug!("Cache hit for file path: {} -> {}", file_path, resolved_path); + Ok(resolved_path.clone()) + } + None => { + debug!("Cache hit for non-existent file: {}", file_path); + Err(anyhow::anyhow!("File not found: {} (cached)", file_path)) + } + }; + } } + + // Generate candidate paths in order of likelihood + let candidates = self.generate_path_candidates(file_path); - // Try to find the file in the new structured directory + // Check candidates efficiently + let mut found_path = None; + for candidate in &candidates { + match tokio::fs::metadata(candidate).await { + Ok(metadata) => { + if metadata.is_file() { + found_path = Some(candidate.clone()); + debug!("Found file at: {}", candidate); + break; + } + } + Err(_) => { + // File doesn't exist at this path, continue to next candidate + debug!("File not found at: {}", candidate); + } + } + } + + // Cache the result + { + let mut cache = self.path_cache.write().await; + cache.insert(file_path.to_string(), found_path.clone()); + + // Prevent cache from growing too large + if cache.len() > 10000 { + // Clear oldest 20% of entries (simple cache eviction) + let to_remove: Vec = cache.keys().take(2000).cloned().collect(); + for key in to_remove { + cache.remove(&key); + } + debug!("Evicted cache entries to prevent memory growth"); + } + } + + match found_path { + Some(path) => { + if path != file_path { + info!("Resolved file path: {} -> {}", file_path, path); + } + Ok(path) + } + None => { + debug!("File not found in any candidate location: {}", file_path); + Err(anyhow::anyhow!( + "File not found: {} (checked {} locations)", + file_path, + candidates.len() + )) + } + } + } + + /// Generate candidate paths for file resolution + fn generate_path_candidates(&self, file_path: &str) -> Vec { + let mut candidates = Vec::new(); + + // 1. Original path (most likely for new files) + candidates.push(file_path.to_string()); + + // 2. For legacy compatibility - try structured directory if file_path.starts_with("./uploads/") && !file_path.contains("/documents/") { - let new_path = file_path.replace("./uploads/", "./uploads/documents/"); - if Path::new(&new_path).exists() { - info!("Found file in new structured directory: {} -> {}", file_path, new_path); - return Ok(new_path); - } + candidates.push(file_path.replace("./uploads/", "./uploads/documents/")); } - // Try without the ./ prefix + // 3. Try without ./ prefix in structured directory if file_path.starts_with("uploads/") && !file_path.contains("/documents/") { - let new_path = file_path.replace("uploads/", "uploads/documents/"); - if Path::new(&new_path).exists() { - info!("Found file in new structured directory: {} -> {}", file_path, new_path); - return Ok(new_path); - } + candidates.push(file_path.replace("uploads/", "uploads/documents/")); } - // File not found in any expected location - Err(anyhow::anyhow!("File not found: {} (checked original path and structured directory)", file_path)) + // 4. 
Try relative to our configured upload path + if !file_path.starts_with(&self.upload_path) { + let relative_path = Path::new(&self.upload_path).join(file_path); + candidates.push(relative_path.to_string_lossy().to_string()); + + // Also try in documents subdirectory + let documents_path = Path::new(&self.upload_path).join("documents").join(file_path); + candidates.push(documents_path.to_string_lossy().to_string()); + } + + // 5. Try absolute path if it looks like a filename only + if !file_path.contains('/') && !file_path.contains('\\') { + // Try in documents directory + let abs_documents_path = Path::new(&self.upload_path) + .join("documents") + .join(file_path); + candidates.push(abs_documents_path.to_string_lossy().to_string()); + } + + candidates + } + + /// Clear the path resolution cache (useful for testing or after file operations) + pub async fn clear_path_cache(&self) { + let mut cache = self.path_cache.write().await; + cache.clear(); + debug!("Cleared path resolution cache"); + } + + /// Invalidate cache entry for a specific path + pub async fn invalidate_cache_entry(&self, file_path: &str) { + let mut cache = self.path_cache.write().await; + cache.remove(file_path); + debug!("Invalidated cache entry for: {}", file_path); } /// Save a file with generated UUID filename (legacy method) @@ -112,7 +219,10 @@ impl LocalStorageBackend { #[async_trait] impl StorageBackend for LocalStorageBackend { async fn store_document(&self, _user_id: Uuid, document_id: Uuid, filename: &str, data: &[u8]) -> Result { - let extension = Path::new(filename) + // Validate and sanitize the filename + let sanitized_filename = validate_filename(filename)?; + + let extension = Path::new(&sanitized_filename) .extension() .and_then(|ext| ext.to_str()) .unwrap_or(""); @@ -126,13 +236,28 @@ impl StorageBackend for LocalStorageBackend { let documents_dir = self.get_documents_path(); let file_path = documents_dir.join(&document_filename); + // Validate that the final path is within our base directory + validate_path_within_base( + &file_path.to_string_lossy(), + &self.upload_path + )?; + // Ensure the documents directory exists fs::create_dir_all(&documents_dir).await?; + // Validate data size (prevent extremely large files from causing issues) + if data.len() > 1_000_000_000 { // 1GB limit + return Err(anyhow::anyhow!("File too large for storage (max 1GB)")); + } + fs::write(&file_path, data).await?; + // Invalidate any cached negative results for this path + let path_str = file_path.to_string_lossy().to_string(); + self.invalidate_cache_entry(&path_str).await; + info!("Stored document locally: {}", file_path.display()); - Ok(file_path.to_string_lossy().to_string()) + Ok(path_str) } async fn store_thumbnail(&self, _user_id: Uuid, document_id: Uuid, data: &[u8]) -> Result { @@ -142,10 +267,25 @@ impl StorageBackend for LocalStorageBackend { let thumbnail_filename = format!("{}_thumb.jpg", document_id); let thumbnail_path = thumbnails_dir.join(&thumbnail_filename); + // Validate that the final path is within our base directory + validate_path_within_base( + &thumbnail_path.to_string_lossy(), + &self.upload_path + )?; + + // Validate data size for thumbnails (should be much smaller) + if data.len() > 10_000_000 { // 10MB limit for thumbnails + return Err(anyhow::anyhow!("Thumbnail too large (max 10MB)")); + } + fs::write(&thumbnail_path, data).await?; + // Invalidate any cached negative results for this path + let path_str = thumbnail_path.to_string_lossy().to_string(); + self.invalidate_cache_entry(&path_str).await; 
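Review note on the cache eviction in `resolve_file_path`: `HashMap` iteration order in Rust is arbitrary, so `cache.keys().take(2000)` removes an arbitrary 2,000 entries, not the "oldest 20%" the comment claims. If insertion-order eviction is actually wanted, pairing the map with a queue is enough; a minimal sketch (not what this diff implements):

```rust
use std::collections::{HashMap, VecDeque};

/// Path cache with deterministic, insertion-ordered eviction.
struct PathCache {
    entries: HashMap<String, Option<String>>,
    order: VecDeque<String>, // front = oldest key
    capacity: usize,
}

impl PathCache {
    fn insert(&mut self, key: String, value: Option<String>) {
        if self.entries.insert(key.clone(), value).is_none() {
            self.order.push_back(key); // only track first insertion per key
        }
        while self.entries.len() > self.capacity {
            match self.order.pop_front() {
                Some(oldest) => { self.entries.remove(&oldest); }
                None => break,
            }
        }
    }
}
```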
+ info!("Stored thumbnail locally: {}", thumbnail_path.display()); - Ok(thumbnail_path.to_string_lossy().to_string()) + Ok(path_str) } async fn store_processed_image(&self, _user_id: Uuid, document_id: Uuid, data: &[u8]) -> Result { @@ -155,15 +295,44 @@ impl StorageBackend for LocalStorageBackend { let processed_filename = format!("{}_processed.png", document_id); let processed_path = processed_dir.join(&processed_filename); + // Validate that the final path is within our base directory + validate_path_within_base( + &processed_path.to_string_lossy(), + &self.upload_path + )?; + + // Validate data size for processed images + if data.len() > 50_000_000 { // 50MB limit for processed images + return Err(anyhow::anyhow!("Processed image too large (max 50MB)")); + } + fs::write(&processed_path, data).await?; + // Invalidate any cached negative results for this path + let path_str = processed_path.to_string_lossy().to_string(); + self.invalidate_cache_entry(&path_str).await; + info!("Stored processed image locally: {}", processed_path.display()); - Ok(processed_path.to_string_lossy().to_string()) + Ok(path_str) } async fn retrieve_file(&self, path: &str) -> Result> { - let resolved_path = self.resolve_file_path(path).await?; + // Validate and sanitize the input path + let sanitized_path = validate_and_sanitize_path(path)?; + + let resolved_path = self.resolve_file_path(&sanitized_path).await?; + + // Validate that the resolved path is within our base directory + validate_path_within_base(&resolved_path, &self.upload_path)?; + let data = fs::read(&resolved_path).await?; + + // Additional safety check on file size when reading + if data.len() > 1_000_000_000 { // 1GB limit + warn!("Attempted to read extremely large file: {} ({} bytes)", resolved_path, data.len()); + return Err(anyhow::anyhow!("File too large to read safely")); + } + Ok(data) } @@ -172,16 +341,25 @@ impl StorageBackend for LocalStorageBackend { let mut serious_errors = Vec::new(); // Helper function to safely delete a file - async fn safe_delete(path: &Path, serious_errors: &mut Vec) -> Option { + let storage_backend = self; + async fn safe_delete(path: &Path, serious_errors: &mut Vec, backend: &LocalStorageBackend) -> Option { match fs::remove_file(path).await { Ok(_) => { info!("Deleted file: {}", path.display()); - Some(path.to_string_lossy().to_string()) + let path_str = path.to_string_lossy().to_string(); + + // Invalidate cache entry for the deleted file + backend.invalidate_cache_entry(&path_str).await; + + Some(path_str) } Err(e) => { match e.kind() { std::io::ErrorKind::NotFound => { info!("File already deleted: {}", path.display()); + // Still invalidate cache in case it was cached as existing + let path_str = path.to_string_lossy().to_string(); + backend.invalidate_cache_entry(&path_str).await; None } _ => { @@ -223,7 +401,7 @@ impl StorageBackend for LocalStorageBackend { let mut main_file_deleted = false; for candidate_path in &main_file_candidates { if candidate_path.exists() { - if let Some(deleted_path) = safe_delete(candidate_path, &mut serious_errors).await { + if let Some(deleted_path) = safe_delete(candidate_path, &mut serious_errors, storage_backend).await { deleted_files.push(deleted_path); main_file_deleted = true; break; // Only delete the first match we find @@ -238,14 +416,14 @@ impl StorageBackend for LocalStorageBackend { // Delete thumbnail if it exists let thumbnail_filename = format!("{}_thumb.jpg", document_id); let thumbnail_path = self.get_thumbnails_path().join(&thumbnail_filename); - if let 
Some(deleted_path) = safe_delete(&thumbnail_path, &mut serious_errors).await { + if let Some(deleted_path) = safe_delete(&thumbnail_path, &mut serious_errors, storage_backend).await { deleted_files.push(deleted_path); } // Delete processed image if it exists let processed_image_filename = format!("{}_processed.png", document_id); let processed_image_path = self.get_processed_images_path().join(&processed_image_filename); - if let Some(deleted_path) = safe_delete(&processed_image_path, &mut serious_errors).await { + if let Some(deleted_path) = safe_delete(&processed_image_path, &mut serious_errors, storage_backend).await { deleted_files.push(deleted_path); } @@ -265,8 +443,20 @@ impl StorageBackend for LocalStorageBackend { } async fn file_exists(&self, path: &str) -> Result { - match self.resolve_file_path(path).await { - Ok(_) => Ok(true), + // Validate and sanitize the input path + let sanitized_path = match validate_and_sanitize_path(path) { + Ok(p) => p, + Err(_) => return Ok(false), // Invalid paths don't exist + }; + + match self.resolve_file_path(&sanitized_path).await { + Ok(resolved_path) => { + // Additional validation that the resolved path is within base directory + match validate_path_within_base(&resolved_path, &self.upload_path) { + Ok(_) => Ok(true), + Err(_) => Ok(false), // Paths outside base directory don't "exist" for us + } + } Err(_) => Ok(false), } } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 477cc78..89de501 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1 +1,2 @@ -pub mod debug; \ No newline at end of file +pub mod debug; +pub mod security; \ No newline at end of file diff --git a/src/utils/security.rs b/src/utils/security.rs new file mode 100644 index 0000000..7f693d2 --- /dev/null +++ b/src/utils/security.rs @@ -0,0 +1,231 @@ +//! 
Security utilities for input validation and sanitization + +use anyhow::Result; +use std::path::{Path, PathBuf, Component}; +use tracing::warn; + +/// Validate and sanitize file paths to prevent path traversal attacks +pub fn validate_and_sanitize_path(input_path: &str) -> Result { + // Check for null bytes (not allowed in file paths) + if input_path.contains('\0') { + return Err(anyhow::anyhow!("Path contains null bytes")); + } + + // Check for excessively long paths + if input_path.len() > 4096 { + return Err(anyhow::anyhow!("Path too long (max 4096 characters)")); + } + + // Convert to Path for normalization + let path = Path::new(input_path); + + // Check for path traversal attempts + for component in path.components() { + match component { + Component::ParentDir => { + warn!("Path traversal attempt detected: {}", input_path); + return Err(anyhow::anyhow!("Path traversal not allowed")); + } + Component::Normal(name) => { + let name_str = name.to_string_lossy(); + + // Check for dangerous file names + if is_dangerous_filename(&name_str) { + return Err(anyhow::anyhow!("Potentially dangerous filename: {}", name_str)); + } + + // Check for control characters (except newline and tab which might be in file content) + for ch in name_str.chars() { + if ch.is_control() && ch != '\n' && ch != '\t' { + return Err(anyhow::anyhow!("Filename contains control characters")); + } + } + } + _ => {} // Allow root, current dir, and prefix components + } + } + + // Normalize the path to remove redundant components + let normalized = normalize_path(path); + Ok(normalized.to_string_lossy().to_string()) +} + +/// Validate filename for document storage +pub fn validate_filename(filename: &str) -> Result { + // Basic length check + if filename.is_empty() { + return Err(anyhow::anyhow!("Filename cannot be empty")); + } + + if filename.len() > 255 { + return Err(anyhow::anyhow!("Filename too long (max 255 characters)")); + } + + // Check for null bytes + if filename.contains('\0') { + return Err(anyhow::anyhow!("Filename contains null bytes")); + } + + // Check for path separators (filenames should not contain them) + if filename.contains('/') || filename.contains('\\') { + return Err(anyhow::anyhow!("Filename cannot contain path separators")); + } + + // Check for control characters + for ch in filename.chars() { + if ch.is_control() && ch != '\n' && ch != '\t' { + return Err(anyhow::anyhow!("Filename contains control characters")); + } + } + + // Check for dangerous patterns + if is_dangerous_filename(filename) { + return Err(anyhow::anyhow!("Potentially dangerous filename: {}", filename)); + } + + // Sanitize the filename by replacing problematic characters + let sanitized = sanitize_filename(filename); + Ok(sanitized) +} + +/// Check if a filename is potentially dangerous +fn is_dangerous_filename(filename: &str) -> bool { + let filename_lower = filename.to_lowercase(); + + // Windows reserved names + let reserved_names = [ + "con", "prn", "aux", "nul", + "com1", "com2", "com3", "com4", "com5", "com6", "com7", "com8", "com9", + "lpt1", "lpt2", "lpt3", "lpt4", "lpt5", "lpt6", "lpt7", "lpt8", "lpt9", + ]; + + // Check if filename (without extension) matches reserved names + let name_without_ext = filename_lower.split('.').next().unwrap_or(""); + if reserved_names.contains(&name_without_ext) { + return true; + } + + // Check for suspicious patterns + if filename_lower.starts_with('.') && filename_lower.len() > 1 { + // Allow common hidden files but reject suspicious ones + let allowed_hidden = [".env", 
".gitignore", ".htaccess"]; + if !allowed_hidden.iter().any(|&allowed| filename_lower.starts_with(allowed)) { + // Be more permissive with document files that might have dots + if !filename_lower.contains(&['.', 'd', 'o', 'c']) && + !filename_lower.contains(&['.', 'p', 'd', 'f']) && + !filename_lower.contains(&['.', 't', 'x', 't']) { + return true; + } + } + } + + false +} + +/// Sanitize filename by replacing problematic characters +fn sanitize_filename(filename: &str) -> String { + let mut sanitized = String::new(); + + for ch in filename.chars() { + match ch { + // Replace problematic characters with underscores + '<' | '>' | ':' | '"' | '|' | '?' | '*' => sanitized.push('_'), + // Allow most other characters + _ if !ch.is_control() || ch == '\n' || ch == '\t' => sanitized.push(ch), + // Skip control characters + _ => {} + } + } + + // Trim whitespace from ends + sanitized.trim().to_string() +} + +/// Normalize a path by resolving . and .. components without filesystem access +fn normalize_path(path: &Path) -> PathBuf { + let mut normalized = PathBuf::new(); + + for component in path.components() { + match component { + Component::Normal(_) | Component::RootDir | Component::Prefix(_) => { + normalized.push(component); + } + Component::CurDir => { + // Skip current directory references + } + Component::ParentDir => { + // This should have been caught earlier, but handle it safely + if normalized.parent().is_some() { + normalized.pop(); + } + // If we can't go up, just ignore the .. component + } + } + } + + normalized +} + +/// Validate that a path is within the allowed base directory +pub fn validate_path_within_base(path: &str, base_dir: &str) -> Result<()> { + let path_buf = PathBuf::from(path); + let base_buf = PathBuf::from(base_dir); + + // Canonicalize if possible, but don't fail if paths don't exist yet + let canonical_path = path_buf.canonicalize().unwrap_or_else(|_| { + // If canonicalization fails, do our best with normalization + normalize_path(&path_buf) + }); + + let canonical_base = base_buf.canonicalize().unwrap_or_else(|_| { + normalize_path(&base_buf) + }); + + if !canonical_path.starts_with(&canonical_base) { + return Err(anyhow::anyhow!( + "Path '{}' is not within allowed base directory '{}'", + path, base_dir + )); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_validate_filename() { + // Valid filenames + assert!(validate_filename("document.pdf").is_ok()); + assert!(validate_filename("my-file_2023.docx").is_ok()); + assert!(validate_filename("report (final).txt").is_ok()); + + // Invalid filenames + assert!(validate_filename("").is_err()); + assert!(validate_filename("../etc/passwd").is_err()); + assert!(validate_filename("file\0name.txt").is_err()); + assert!(validate_filename("con.txt").is_err()); + assert!(validate_filename("file/name.txt").is_err()); + } + + #[test] + fn test_validate_path() { + // Valid paths + assert!(validate_and_sanitize_path("documents/file.pdf").is_ok()); + assert!(validate_and_sanitize_path("./uploads/document.txt").is_ok()); + + // Invalid paths + assert!(validate_and_sanitize_path("../../../etc/passwd").is_err()); + assert!(validate_and_sanitize_path("documents/../config.txt").is_err()); + assert!(validate_and_sanitize_path("file\0name.txt").is_err()); + } + + #[test] + fn test_sanitize_filename() { + assert_eq!(sanitize_filename("file<>name.txt"), "file__name.txt"); + assert_eq!(sanitize_filename(" report.pdf "), "report.pdf"); + assert_eq!(sanitize_filename("file:name|test.doc"), 
"file_name_test.doc"); + } +} \ No newline at end of file diff --git a/tests/integration_enhanced_ocr_tests.rs b/tests/integration_enhanced_ocr_tests.rs index 7852aee..f9bcad1 100644 --- a/tests/integration_enhanced_ocr_tests.rs +++ b/tests/integration_enhanced_ocr_tests.rs @@ -2,6 +2,8 @@ mod tests { use readur::ocr::enhanced::{EnhancedOcrService, OcrResult, ImageQualityStats}; use readur::models::Settings; + use readur::services::file_service::FileService; + use readur::storage::{StorageConfig, factory::create_storage_backend}; use std::fs; use tempfile::{NamedTempFile, TempDir}; @@ -13,18 +15,25 @@ mod tests { TempDir::new().expect("Failed to create temp directory") } - #[test] - fn test_enhanced_ocr_service_creation() { + async fn create_test_file_service(temp_path: &str) -> FileService { + let storage_config = StorageConfig::Local { upload_path: temp_path.to_string() }; + let storage_backend = create_storage_backend(storage_config).await.unwrap(); + FileService::with_storage(temp_path.to_string(), storage_backend) + } + + #[tokio::test] + async fn test_enhanced_ocr_service_creation() { let temp_dir = create_temp_dir(); let temp_path = temp_dir.path().to_str().unwrap().to_string(); - let service = EnhancedOcrService::new(temp_path); + let file_service = create_test_file_service(&temp_path).await; + let service = EnhancedOcrService::new(temp_path.clone(), file_service); // Service should be created successfully assert!(!service.temp_dir.is_empty()); } - #[test] - fn test_image_quality_stats_creation() { + #[tokio::test] + async fn test_image_quality_stats_creation() { let stats = ImageQualityStats { average_brightness: 128.0, contrast_ratio: 0.5, @@ -38,11 +47,12 @@ mod tests { assert_eq!(stats.sharpness, 0.8); } - #[test] - fn test_count_words_safely_whitespace_separated() { + #[tokio::test] + async fn test_count_words_safely_whitespace_separated() { let temp_dir = create_temp_dir(); let temp_path = temp_dir.path().to_str().unwrap().to_string(); - let service = EnhancedOcrService::new(temp_path); + let file_service = create_test_file_service(&temp_path).await; + let service = EnhancedOcrService::new(temp_path.clone(), file_service); // Test normal whitespace-separated text let text = "Hello world this is a test"; @@ -55,11 +65,12 @@ mod tests { assert_eq!(count, 3); } - #[test] - fn test_count_words_safely_continuous_text() { + #[tokio::test] + async fn test_count_words_safely_continuous_text() { let temp_dir = create_temp_dir(); let temp_path = temp_dir.path().to_str().unwrap().to_string(); - let service = EnhancedOcrService::new(temp_path); + let file_service = create_test_file_service(&temp_path).await; + let service = EnhancedOcrService::new(temp_path.clone(), file_service); // Test continuous text without spaces (like some PDF extractions) let text = "HelloWorldThisIsAContinuousText"; @@ -72,11 +83,12 @@ mod tests { assert!(count > 0, "Should detect alphanumeric patterns as words"); } - #[test] - fn test_count_words_safely_edge_cases() { + #[tokio::test] + async fn test_count_words_safely_edge_cases() { let temp_dir = create_temp_dir(); let temp_path = temp_dir.path().to_str().unwrap().to_string(); - let service = EnhancedOcrService::new(temp_path); + let file_service = create_test_file_service(&temp_path).await; + let service = EnhancedOcrService::new(temp_path.clone(), file_service); // Test empty text let count = service.count_words_safely(""); @@ -102,11 +114,12 @@ mod tests { assert!(count > 0, "Should detect words in mixed content"); } - #[test] - fn 
test_count_words_safely_large_text() { + #[tokio::test] + async fn test_count_words_safely_large_text() { let temp_dir = create_temp_dir(); let temp_path = temp_dir.path().to_str().unwrap().to_string(); - let service = EnhancedOcrService::new(temp_path); + let file_service = create_test_file_service(&temp_path).await; + let service = EnhancedOcrService::new(temp_path.clone(), file_service); // Test with large text (over 1MB) to trigger sampling let word = "test "; @@ -118,11 +131,12 @@ mod tests { assert!(count <= 10_000_000, "Should cap at max limit: got {}", count); } - #[test] - fn test_count_words_safely_fallback_patterns() { + #[tokio::test] + async fn test_count_words_safely_fallback_patterns() { let temp_dir = create_temp_dir(); let temp_path = temp_dir.path().to_str().unwrap().to_string(); - let service = EnhancedOcrService::new(temp_path); + let file_service = create_test_file_service(&temp_path).await; + let service = EnhancedOcrService::new(temp_path, file_service); // Test letter transition detection let text = "OneWordAnotherWordFinalWord"; @@ -140,8 +154,8 @@ mod tests { assert!(count >= 1, "Should detect words in mixed alphanumeric: got {}", count); } - #[test] - fn test_ocr_result_structure() { + #[tokio::test] + async fn test_ocr_result_structure() { let result = OcrResult { text: "Test text".to_string(), confidence: 85.5, @@ -162,7 +176,9 @@ mod tests { #[tokio::test] async fn test_extract_text_from_plain_text() { let temp_dir = create_temp_dir(); - let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string()); + let temp_path = temp_dir.path().to_str().unwrap().to_string(); + let file_service = create_test_file_service(&temp_path).await; + let service = EnhancedOcrService::new(temp_path, file_service); let settings = create_test_settings(); let temp_file = NamedTempFile::with_suffix(".txt").unwrap(); @@ -185,7 +201,9 @@ mod tests { #[tokio::test] async fn test_extract_text_with_context() { let temp_dir = create_temp_dir(); - let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string()); + let temp_path = temp_dir.path().to_str().unwrap().to_string(); + let file_service = create_test_file_service(&temp_path).await; + let service = EnhancedOcrService::new(temp_path, file_service); let settings = create_test_settings(); let temp_file = NamedTempFile::with_suffix(".txt").unwrap(); @@ -211,7 +229,9 @@ mod tests { #[tokio::test] async fn test_extract_text_unsupported_mime_type() { let temp_dir = create_temp_dir(); - let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string()); + let temp_path = temp_dir.path().to_str().unwrap().to_string(); + let file_service = create_test_file_service(&temp_path).await; + let service = EnhancedOcrService::new(temp_path, file_service); let settings = create_test_settings(); let temp_file = NamedTempFile::new().unwrap(); @@ -229,7 +249,9 @@ mod tests { #[tokio::test] async fn test_extract_text_nonexistent_file() { let temp_dir = create_temp_dir(); - let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string()); + let temp_path = temp_dir.path().to_str().unwrap().to_string(); + let file_service = create_test_file_service(&temp_path).await; + let service = EnhancedOcrService::new(temp_path, file_service); let settings = create_test_settings(); let result = service @@ -242,7 +264,9 @@ mod tests { #[tokio::test] async fn test_extract_text_large_file_truncation() { let temp_dir = create_temp_dir(); - let service = 
EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string()); + let temp_path = temp_dir.path().to_str().unwrap().to_string(); + let file_service = create_test_file_service(&temp_path).await; + let service = EnhancedOcrService::new(temp_path, file_service); let settings = create_test_settings(); let temp_file = NamedTempFile::with_suffix(".txt").unwrap(); @@ -262,10 +286,12 @@ mod tests { } #[cfg(feature = "ocr")] - #[test] - fn test_validate_ocr_quality_high_confidence() { + #[tokio::test] + async fn test_validate_ocr_quality_high_confidence() { let temp_dir = create_temp_dir(); - let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string()); + let temp_path = temp_dir.path().to_str().unwrap().to_string(); + let file_service = create_test_file_service(&temp_path).await; + let service = EnhancedOcrService::new(temp_path, file_service); let mut settings = create_test_settings(); settings.ocr_min_confidence = 30.0; @@ -283,10 +309,12 @@ mod tests { } #[cfg(feature = "ocr")] - #[test] - fn test_validate_ocr_quality_low_confidence() { + #[tokio::test] + async fn test_validate_ocr_quality_low_confidence() { let temp_dir = create_temp_dir(); - let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string()); + let temp_path = temp_dir.path().to_str().unwrap().to_string(); + let file_service = create_test_file_service(&temp_path).await; + let service = EnhancedOcrService::new(temp_path, file_service); let mut settings = create_test_settings(); settings.ocr_min_confidence = 50.0; @@ -304,10 +332,12 @@ mod tests { } #[cfg(feature = "ocr")] - #[test] - fn test_validate_ocr_quality_no_words() { + #[tokio::test] + async fn test_validate_ocr_quality_no_words() { let temp_dir = create_temp_dir(); - let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string()); + let temp_path = temp_dir.path().to_str().unwrap().to_string(); + let file_service = create_test_file_service(&temp_path).await; + let service = EnhancedOcrService::new(temp_path, file_service); let settings = create_test_settings(); let result = OcrResult { @@ -324,10 +354,12 @@ mod tests { } #[cfg(feature = "ocr")] - #[test] - fn test_validate_ocr_quality_poor_character_distribution() { + #[tokio::test] + async fn test_validate_ocr_quality_poor_character_distribution() { let temp_dir = create_temp_dir(); - let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string()); + let temp_path = temp_dir.path().to_str().unwrap().to_string(); + let file_service = create_test_file_service(&temp_path).await; + let service = EnhancedOcrService::new(temp_path, file_service); let settings = create_test_settings(); let result = OcrResult { @@ -344,10 +376,12 @@ mod tests { } #[cfg(feature = "ocr")] - #[test] - fn test_validate_ocr_quality_good_character_distribution() { + #[tokio::test] + async fn test_validate_ocr_quality_good_character_distribution() { let temp_dir = create_temp_dir(); - let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string()); + let temp_path = temp_dir.path().to_str().unwrap().to_string(); + let file_service = create_test_file_service(&temp_path).await; + let service = EnhancedOcrService::new(temp_path, file_service); let settings = create_test_settings(); let result = OcrResult { @@ -366,7 +400,9 @@ mod tests { #[tokio::test] async fn test_word_count_calculation() { let temp_dir = create_temp_dir(); - let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string()); + let temp_path = 
+        let temp_path = temp_dir.path().to_str().unwrap().to_string();
+        let file_service = create_test_file_service(&temp_path).await;
+        let service = EnhancedOcrService::new(temp_path, file_service);
         let settings = create_test_settings();
 
         let test_cases = vec![
@@ -395,7 +431,9 @@ mod tests {
     #[tokio::test]
     async fn test_pdf_extraction_with_invalid_pdf() {
         let temp_dir = create_temp_dir();
-        let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+        let temp_path = temp_dir.path().to_str().unwrap().to_string();
+        let file_service = create_test_file_service(&temp_path).await;
+        let service = EnhancedOcrService::new(temp_path, file_service);
         let settings = create_test_settings();
 
         let temp_file = NamedTempFile::with_suffix(".pdf").unwrap();
@@ -413,7 +451,9 @@
     #[tokio::test]
     async fn test_pdf_extraction_with_minimal_valid_pdf() {
         let temp_dir = create_temp_dir();
-        let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+        let temp_path = temp_dir.path().to_str().unwrap().to_string();
+        let file_service = create_test_file_service(&temp_path).await;
+        let service = EnhancedOcrService::new(temp_path, file_service);
         let settings = create_test_settings();
 
         // Minimal PDF with "Hello" text
@@ -485,7 +525,9 @@ startxref
     #[tokio::test]
     async fn test_pdf_size_limit() {
         let temp_dir = create_temp_dir();
-        let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+        let temp_path = temp_dir.path().to_str().unwrap().to_string();
+        let file_service = create_test_file_service(&temp_path).await;
+        let service = EnhancedOcrService::new(temp_path, file_service);
         let settings = create_test_settings();
 
         let temp_file = NamedTempFile::with_suffix(".pdf").unwrap();
@@ -503,8 +545,8 @@ startxref
         assert!(error_msg.contains("too large"));
     }
 
-    #[test]
-    fn test_settings_default_values() {
+    #[tokio::test]
+    async fn test_settings_default_values() {
         let settings = Settings::default();
 
         // Test that OCR-related settings have reasonable defaults
@@ -521,7 +563,9 @@ startxref
     #[tokio::test]
     async fn test_concurrent_ocr_processing() {
         let temp_dir = create_temp_dir();
-        let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+        let temp_path = temp_dir.path().to_str().unwrap().to_string();
+        let file_service = create_test_file_service(&temp_path).await;
+        let service = EnhancedOcrService::new(temp_path, file_service);
         let settings = create_test_settings();
 
         let mut handles = vec![];
@@ -532,7 +576,9 @@ startxref
             let content = format!("Concurrent test content {}", i);
             fs::write(temp_file.path(), &content).unwrap();
 
-            let service_clone = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+            let temp_path_clone = temp_dir.path().to_str().unwrap().to_string();
+            let file_service_clone = create_test_file_service(&temp_path_clone).await;
+            let service_clone = EnhancedOcrService::new(temp_path_clone, file_service_clone);
             let settings_clone = settings.clone();
             let file_path = temp_file.path().to_str().unwrap().to_string();
diff --git a/tests/integration_ocr_pipeline_integration_test.rs b/tests/integration_ocr_pipeline_integration_test.rs
index 799598a..13fd1d1 100644
--- a/tests/integration_ocr_pipeline_integration_test.rs
+++ b/tests/integration_ocr_pipeline_integration_test.rs
@@ -24,6 +24,12 @@ use readur::{
     db_guardrails_simple::DocumentTransactionManager,
 };
 
+async fn create_test_file_service(temp_path: &str) -> FileService {
+    let storage_config = StorageConfig::Local { upload_path: temp_path.to_string() };
+    let storage_backend = create_storage_backend(storage_config).await.unwrap();
+    FileService::with_storage(temp_path.to_string(), storage_backend)
+}
+
 struct OCRPipelineTestHarness {
     db: Database,
     pool: PgPool,
@@ -329,7 +335,8 @@ impl OCRPipelineTestHarness {
         // Clone the components we need rather than the whole harness
         let queue_service = self.queue_service.clone();
         let transaction_manager = self.transaction_manager.clone();
-        let ocr_service = EnhancedOcrService::new("/tmp".to_string());
+        let file_service = create_test_file_service("/tmp").await;
+        let ocr_service = EnhancedOcrService::new("/tmp".to_string(), file_service);
         let pool = self.pool.clone();
 
         let handle = tokio::spawn(async move {
diff --git a/tests/integration_pdf_word_count_tests.rs b/tests/integration_pdf_word_count_tests.rs
index b015fcd..2c4b509 100644
--- a/tests/integration_pdf_word_count_tests.rs
+++ b/tests/integration_pdf_word_count_tests.rs
@@ -2,6 +2,8 @@ mod pdf_word_count_integration_tests {
     use readur::ocr::enhanced::EnhancedOcrService;
     use readur::models::Settings;
+    use readur::services::file_service::FileService;
+    use readur::storage::{StorageConfig, factory::create_storage_backend};
     use std::io::Write;
     use tempfile::{NamedTempFile, TempDir};
@@ -13,6 +15,12 @@ mod pdf_word_count_integration_tests {
         TempDir::new().expect("Failed to create temp directory")
     }
 
+    async fn create_test_file_service(temp_path: &str) -> FileService {
+        let storage_config = StorageConfig::Local { upload_path: temp_path.to_string() };
+        let storage_backend = create_storage_backend(storage_config).await.unwrap();
+        FileService::with_storage(temp_path.to_string(), storage_backend)
+    }
+
     /// Create a mock PDF with specific text patterns for testing
     fn create_mock_pdf_file(content: &str) -> NamedTempFile {
         let mut temp_file = NamedTempFile::new().expect("Failed to create temp file");
@@ -82,7 +90,8 @@ mod pdf_word_count_integration_tests {
     async fn test_pdf_extraction_with_normal_text() {
         let temp_dir = create_temp_dir();
         let temp_path = temp_dir.path().to_str().unwrap().to_string();
-        let service = EnhancedOcrService::new(temp_path);
+        let file_service = create_test_file_service(&temp_path).await;
+        let service = EnhancedOcrService::new(temp_path.clone(), file_service);
         let settings = create_test_settings();
 
         // Create a PDF with normal spaced text
@@ -108,7 +117,8 @@ mod pdf_word_count_integration_tests {
     async fn test_pdf_extraction_with_continuous_text() {
         let temp_dir = create_temp_dir();
         let temp_path = temp_dir.path().to_str().unwrap().to_string();
-        let service = EnhancedOcrService::new(temp_path);
+        let file_service = create_test_file_service(&temp_path).await;
+        let service = EnhancedOcrService::new(temp_path.clone(), file_service);
         let settings = create_test_settings();
 
         // Create a PDF with continuous text (no spaces)
@@ -136,7 +146,8 @@ mod pdf_word_count_integration_tests {
     async fn test_pdf_extraction_with_mixed_content() {
         let temp_dir = create_temp_dir();
         let temp_path = temp_dir.path().to_str().unwrap().to_string();
-        let service = EnhancedOcrService::new(temp_path);
+        let file_service = create_test_file_service(&temp_path).await;
+        let service = EnhancedOcrService::new(temp_path.clone(), file_service);
         let settings = create_test_settings();
 
         // Create a PDF with mixed content (letters, numbers, punctuation)
@@ -159,7 +170,8 @@ mod pdf_word_count_integration_tests {
     async fn test_pdf_extraction_empty_content() {
         let temp_dir = create_temp_dir();
         let temp_path = temp_dir.path().to_str().unwrap().to_string();
-        let service = EnhancedOcrService::new(temp_path);
+        let file_service = create_test_file_service(&temp_path).await;
+        let service = EnhancedOcrService::new(temp_path.clone(), file_service);
         let settings = create_test_settings();
 
         // Create a PDF with only whitespace/empty content
@@ -185,7 +197,8 @@ mod pdf_word_count_integration_tests {
     async fn test_pdf_extraction_punctuation_only() {
         let temp_dir = create_temp_dir();
         let temp_path = temp_dir.path().to_str().unwrap().to_string();
-        let service = EnhancedOcrService::new(temp_path);
+        let file_service = create_test_file_service(&temp_path).await;
+        let service = EnhancedOcrService::new(temp_path.clone(), file_service);
         let settings = create_test_settings();
 
         // Create a PDF with only punctuation
@@ -212,7 +225,8 @@ mod pdf_word_count_integration_tests {
    async fn test_pdf_quality_validation() {
         let temp_dir = create_temp_dir();
         let temp_path = temp_dir.path().to_str().unwrap().to_string();
-        let service = EnhancedOcrService::new(temp_path);
+        let file_service = create_test_file_service(&temp_path).await;
+        let service = EnhancedOcrService::new(temp_path.clone(), file_service);
         let settings = create_test_settings();
 
         // Use a real test PDF file if available
@@ -261,7 +275,8 @@ mod pdf_word_count_integration_tests {
    async fn test_pdf_file_size_validation() {
         let temp_dir = create_temp_dir();
         let _temp_path = temp_dir.path().to_str().unwrap().to_string();
-        let _service = EnhancedOcrService::new(_temp_path);
+        let _file_service = create_test_file_service(&_temp_path).await;
+        let _service = EnhancedOcrService::new(_temp_path.clone(), _file_service);
         let _settings = create_test_settings();
 
         // Create a small PDF file to test file operations
@@ -278,11 +293,12 @@ mod pdf_word_count_integration_tests {
         assert!(metadata.len() < 100 * 1024 * 1024, "Test PDF should be under size limit");
     }
 
-    #[test]
-    fn test_word_counting_regression_cases() {
+    #[tokio::test]
+    async fn test_word_counting_regression_cases() {
         let temp_dir = create_temp_dir();
         let temp_path = temp_dir.path().to_str().unwrap().to_string();
-        let service = EnhancedOcrService::new(temp_path);
+        let file_service = create_test_file_service(&temp_path).await;
+        let service = EnhancedOcrService::new(temp_path.clone(), file_service);
 
         // Regression test cases for the specific PDF issue
         let test_cases = vec![
diff --git a/tests/integration_per_user_watch_directories_tests.rs b/tests/integration_per_user_watch_directories_tests.rs
index d0f66a8..2856dd7 100644
--- a/tests/integration_per_user_watch_directories_tests.rs
+++ b/tests/integration_per_user_watch_directories_tests.rs
@@ -296,7 +296,7 @@ async fn test_user_watch_directory_file_processing_simulation() -> Result<()> {
 
     // Create user watch manager to test file path mapping
     let user_watch_service = state.user_watch_service.as_ref().unwrap();
-    let user_watch_manager = readur::scheduling::user_watch_manager::UserWatchManager::new(state.db.clone(), (**user_watch_service).clone());
+    let user_watch_manager = readur::scheduling::user_watch_manager::UserWatchManager::new(state.db.clone(), Arc::clone(user_watch_service));
 
     // Create test user
     let test_user = readur::models::User {
diff --git a/tests/integration_simple_throttling_test.rs b/tests/integration_simple_throttling_test.rs
index c21ffe5..05dc98e 100644
--- a/tests/integration_simple_throttling_test.rs
+++ b/tests/integration_simple_throttling_test.rs
@@ -16,6 +16,8 @@ use readur::{
     db::Database,
     ocr::queue::OcrQueueService,
     ocr::enhanced::EnhancedOcrService,
+    services::file_service::FileService,
+    storage::{StorageConfig, factory::create_storage_backend},
 };
 
 // Use the same database URL as the running server
@@ -25,6 +27,12 @@ fn get_test_db_url() -> String {
         .unwrap_or_else(|_| "postgresql://readur:readur@localhost:5432/readur".to_string())
 }
 
+async fn create_test_file_service(temp_path: &str) -> FileService {
+    let storage_config = StorageConfig::Local { upload_path: temp_path.to_string() };
+    let storage_backend = create_storage_backend(storage_config).await.unwrap();
+    FileService::with_storage(temp_path.to_string(), storage_backend)
+}
+
 struct SimpleThrottleTest {
     pool: PgPool,
     queue_service: Arc<OcrQueueService>,
@@ -140,7 +148,8 @@ impl SimpleThrottleTest {
 
             let handle = tokio::spawn(async move {
                 let worker_name = format!("worker-{}", worker_id);
-                let ocr_service = EnhancedOcrService::new("/tmp".to_string());
+                let file_service = create_test_file_service("/tmp").await;
+                let ocr_service = EnhancedOcrService::new("/tmp".to_string(), file_service);
                 let mut jobs_processed = 0;
 
                 info!("Worker {} starting", worker_name);
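
Note on the repeated helper: each file under tests/ compiles as its own crate, so the `create_test_file_service` helper is duplicated in every test file this patch touches. The following standalone sketch shows the pattern in one place; it uses only the readur and tempfile APIs exactly as they appear in the diff above, and the test name `example_ocr_service_setup` is illustrative, not part of the patch.

use readur::ocr::enhanced::EnhancedOcrService;
use readur::services::file_service::FileService;
use readur::storage::{factory::create_storage_backend, StorageConfig};

/// Build a FileService over a local-disk storage backend rooted at
/// `temp_path`, mirroring the helper duplicated across the test files.
async fn create_test_file_service(temp_path: &str) -> FileService {
    let storage_config = StorageConfig::Local { upload_path: temp_path.to_string() };
    let storage_backend = create_storage_backend(storage_config).await.unwrap();
    FileService::with_storage(temp_path.to_string(), storage_backend)
}

#[tokio::test]
async fn example_ocr_service_setup() {
    // TempDir removes the directory when dropped, so it must stay in
    // scope for the lifetime of the test.
    let temp_dir = tempfile::TempDir::new().unwrap();
    let temp_path = temp_dir.path().to_str().unwrap().to_string();

    // EnhancedOcrService::new now takes the FileService as a second
    // argument, which is the change this patch threads through every test.
    let file_service = create_test_file_service(&temp_path).await;
    let _service = EnhancedOcrService::new(temp_path, file_service);
}

If the duplication grows, a shared tests/common/mod.rs module is the conventional way to consolidate such helpers. Separately, the per-user watch test above switches from `(**user_watch_service).clone()` to `Arc::clone(user_watch_service)`, which bumps the reference count instead of deep-copying the service.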