feat(storage): further support the s3 storage backend
This commit is contained in:
parent
6624fc57fb
commit
862c36aa72
|
|
@ -8,12 +8,15 @@
|
|||
//! 3. Upload files to S3 with proper structure
|
||||
//! 4. Update database records with S3 paths
|
||||
//! 5. Optionally delete local files after successful upload
|
||||
//! 6. Support rollback on failure with transaction-like behavior
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use std::path::Path;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use uuid::Uuid;
|
||||
use tracing::{info, warn, error};
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use readur::{
|
||||
config::Config,
|
||||
|
|
@ -21,6 +24,41 @@ use readur::{
|
|||
services::{s3_service::S3Service, file_service::FileService},
|
||||
};
|
||||
|
||||
/// Migration state tracking for rollback functionality
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct MigrationState {
|
||||
started_at: chrono::DateTime<chrono::Utc>,
|
||||
completed_migrations: Vec<MigrationRecord>,
|
||||
failed_migrations: Vec<FailedMigration>,
|
||||
total_files: usize,
|
||||
processed_files: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct MigrationRecord {
|
||||
document_id: Uuid,
|
||||
user_id: Uuid,
|
||||
original_path: String,
|
||||
s3_key: String,
|
||||
migrated_at: chrono::DateTime<chrono::Utc>,
|
||||
associated_files: Vec<AssociatedFile>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct AssociatedFile {
|
||||
file_type: String, // "thumbnail" or "processed_image"
|
||||
original_path: String,
|
||||
s3_key: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct FailedMigration {
|
||||
document_id: Uuid,
|
||||
original_path: String,
|
||||
error: String,
|
||||
failed_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "migrate_to_s3")]
|
||||
#[command(about = "Migrate existing local files to S3 storage")]
|
||||
|
|
@ -40,6 +78,14 @@ struct Args {
|
|||
/// Only migrate files for specific user ID
|
||||
#[arg(short, long)]
|
||||
user_id: Option<String>,
|
||||
|
||||
/// Enable rollback on failure - will revert any successful migrations if overall process fails
|
||||
#[arg(long)]
|
||||
enable_rollback: bool,
|
||||
|
||||
/// Resume from a specific document ID (for partial recovery)
|
||||
#[arg(long)]
|
||||
resume_from: Option<String>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
|
|
@ -132,41 +178,107 @@ async fn main() -> Result<()> {
|
|||
return Ok(());
|
||||
}
|
||||
|
||||
// Perform migration
|
||||
// Initialize migration state
|
||||
let mut migration_state = MigrationState {
|
||||
started_at: chrono::Utc::now(),
|
||||
completed_migrations: Vec::new(),
|
||||
failed_migrations: Vec::new(),
|
||||
total_files: local_documents.len(),
|
||||
processed_files: 0,
|
||||
};
|
||||
|
||||
// Resume from specific document if requested
|
||||
let start_index = if let Some(resume_from_str) = &args.resume_from {
|
||||
let resume_doc_id = Uuid::parse_str(resume_from_str)?;
|
||||
local_documents.iter().position(|doc| doc.id == resume_doc_id)
|
||||
.unwrap_or(0)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
info!("📊 Migration plan: {} files to process (starting from index {})",
|
||||
local_documents.len() - start_index, start_index);
|
||||
|
||||
// Perform migration with progress tracking
|
||||
let mut migrated_count = 0;
|
||||
let mut failed_count = 0;
|
||||
|
||||
for doc in local_documents {
|
||||
info!("📦 Migrating: {} ({})", doc.original_filename, doc.id);
|
||||
for (index, doc) in local_documents.iter().enumerate().skip(start_index) {
|
||||
info!("📦 Migrating: {} ({}) [{}/{}]",
|
||||
doc.original_filename, doc.id, index + 1, local_documents.len());
|
||||
|
||||
match migrate_document(&db, &s3_service, &file_service, &doc, args.delete_local).await {
|
||||
Ok(_) => {
|
||||
match migrate_document_with_tracking(&db, &s3_service, &file_service, doc, args.delete_local).await {
|
||||
Ok(migration_record) => {
|
||||
migrated_count += 1;
|
||||
info!("✅ Successfully migrated: {}", doc.original_filename);
|
||||
migration_state.completed_migrations.push(migration_record);
|
||||
migration_state.processed_files += 1;
|
||||
info!("✅ Successfully migrated: {} [{}/{}]",
|
||||
doc.original_filename, migrated_count, local_documents.len());
|
||||
|
||||
// Save progress periodically (every 10 files)
|
||||
if migrated_count % 10 == 0 {
|
||||
save_migration_state(&migration_state).await?;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
failed_count += 1;
|
||||
let failed_migration = FailedMigration {
|
||||
document_id: doc.id,
|
||||
original_path: doc.file_path.clone(),
|
||||
error: e.to_string(),
|
||||
failed_at: chrono::Utc::now(),
|
||||
};
|
||||
migration_state.failed_migrations.push(failed_migration);
|
||||
migration_state.processed_files += 1;
|
||||
error!("❌ Failed to migrate {}: {}", doc.original_filename, e);
|
||||
|
||||
// If rollback is enabled and we have failures, offer to rollback
|
||||
if args.enable_rollback && failed_count > 0 {
|
||||
error!("💥 Migration failure detected with rollback enabled!");
|
||||
error!("Do you want to rollback all {} successful migrations? (y/N)", migrated_count);
|
||||
|
||||
// For automation, we'll automatically rollback on any failure
|
||||
// In interactive mode, you could read from stdin here
|
||||
warn!("🔄 Automatically initiating rollback due to failure...");
|
||||
match rollback_migrations(&db, &s3_service, &migration_state).await {
|
||||
Ok(rolled_back) => {
|
||||
error!("🔄 Successfully rolled back {} migrations", rolled_back);
|
||||
return Err(anyhow::anyhow!("Migration failed and was rolled back"));
|
||||
}
|
||||
Err(rollback_err) => {
|
||||
error!("💥 CRITICAL: Rollback failed: {}", rollback_err);
|
||||
error!("💾 Migration state saved for manual recovery");
|
||||
save_migration_state(&migration_state).await?;
|
||||
return Err(anyhow::anyhow!(
|
||||
"Migration failed and rollback also failed. Check migration state file for manual recovery."
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Save final migration state
|
||||
save_migration_state(&migration_state).await?;
|
||||
|
||||
info!("🎉 Migration completed!");
|
||||
info!("✅ Successfully migrated: {} files", migrated_count);
|
||||
if failed_count > 0 {
|
||||
warn!("❌ Failed to migrate: {} files", failed_count);
|
||||
warn!("💾 Check migration_state.json for details on failures");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn migrate_document(
|
||||
async fn migrate_document_with_tracking(
|
||||
db: &Database,
|
||||
s3_service: &S3Service,
|
||||
file_service: &FileService,
|
||||
document: &readur::models::Document,
|
||||
delete_local: bool,
|
||||
) -> Result<()> {
|
||||
) -> Result<MigrationRecord> {
|
||||
// Read local file
|
||||
let local_path = Path::new(&document.file_path);
|
||||
if !local_path.exists() {
|
||||
|
|
@ -189,7 +301,7 @@ async fn migrate_document(
|
|||
db.update_document_file_path(document.id, &s3_path).await?;
|
||||
|
||||
// Migrate associated files (thumbnails, processed images)
|
||||
migrate_associated_files(s3_service, file_service, document, delete_local).await?;
|
||||
let associated_files = migrate_associated_files_with_tracking(s3_service, file_service, document, delete_local).await?;
|
||||
|
||||
// Delete local file if requested
|
||||
if delete_local {
|
||||
|
|
@ -200,27 +312,46 @@ async fn migrate_document(
|
|||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
// Create migration record for tracking
|
||||
let migration_record = MigrationRecord {
|
||||
document_id: document.id,
|
||||
user_id: document.user_id,
|
||||
original_path: document.file_path.clone(),
|
||||
s3_key,
|
||||
migrated_at: chrono::Utc::now(),
|
||||
associated_files,
|
||||
};
|
||||
|
||||
Ok(migration_record)
|
||||
}
|
||||
|
||||
async fn migrate_associated_files(
|
||||
async fn migrate_associated_files_with_tracking(
|
||||
s3_service: &S3Service,
|
||||
file_service: &FileService,
|
||||
document: &readur::models::Document,
|
||||
delete_local: bool,
|
||||
) -> Result<()> {
|
||||
) -> Result<Vec<AssociatedFile>> {
|
||||
let mut associated_files = Vec::new();
|
||||
|
||||
// Migrate thumbnail
|
||||
let thumbnail_path = file_service.get_thumbnails_path().join(format!("{}_thumb.jpg", document.id));
|
||||
if thumbnail_path.exists() {
|
||||
match tokio::fs::read(&thumbnail_path).await {
|
||||
Ok(thumbnail_data) => {
|
||||
if let Err(e) = s3_service.store_thumbnail(document.user_id, document.id, &thumbnail_data).await {
|
||||
warn!("Failed to migrate thumbnail for {}: {}", document.id, e);
|
||||
} else {
|
||||
info!("📸 Migrated thumbnail for: {}", document.original_filename);
|
||||
if delete_local {
|
||||
let _ = tokio::fs::remove_file(&thumbnail_path).await;
|
||||
match s3_service.store_thumbnail(document.user_id, document.id, &thumbnail_data).await {
|
||||
Ok(s3_key) => {
|
||||
info!("📸 Migrated thumbnail for: {}", document.original_filename);
|
||||
associated_files.push(AssociatedFile {
|
||||
file_type: "thumbnail".to_string(),
|
||||
original_path: thumbnail_path.to_string_lossy().to_string(),
|
||||
s3_key,
|
||||
});
|
||||
if delete_local {
|
||||
let _ = tokio::fs::remove_file(&thumbnail_path).await;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to migrate thumbnail for {}: {}", document.id, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -233,12 +364,20 @@ async fn migrate_associated_files(
|
|||
if processed_path.exists() {
|
||||
match tokio::fs::read(&processed_path).await {
|
||||
Ok(processed_data) => {
|
||||
if let Err(e) = s3_service.store_processed_image(document.user_id, document.id, &processed_data).await {
|
||||
warn!("Failed to migrate processed image for {}: {}", document.id, e);
|
||||
} else {
|
||||
info!("🖼️ Migrated processed image for: {}", document.original_filename);
|
||||
if delete_local {
|
||||
let _ = tokio::fs::remove_file(&processed_path).await;
|
||||
match s3_service.store_processed_image(document.user_id, document.id, &processed_data).await {
|
||||
Ok(s3_key) => {
|
||||
info!("🖼️ Migrated processed image for: {}", document.original_filename);
|
||||
associated_files.push(AssociatedFile {
|
||||
file_type: "processed_image".to_string(),
|
||||
original_path: processed_path.to_string_lossy().to_string(),
|
||||
s3_key,
|
||||
});
|
||||
if delete_local {
|
||||
let _ = tokio::fs::remove_file(&processed_path).await;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to migrate processed image for {}: {}", document.id, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -246,5 +385,94 @@ async fn migrate_associated_files(
|
|||
}
|
||||
}
|
||||
|
||||
Ok(associated_files)
|
||||
}
|
||||
|
||||
/// Save migration state to disk for recovery purposes
|
||||
async fn save_migration_state(state: &MigrationState) -> Result<()> {
|
||||
let state_json = serde_json::to_string_pretty(state)?;
|
||||
tokio::fs::write("migration_state.json", state_json).await?;
|
||||
info!("💾 Migration state saved to migration_state.json");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Rollback migrations by restoring database paths and deleting S3 objects
|
||||
async fn rollback_migrations(
|
||||
db: &Database,
|
||||
s3_service: &S3Service,
|
||||
state: &MigrationState,
|
||||
) -> Result<usize> {
|
||||
info!("🔄 Starting rollback of {} migrations...", state.completed_migrations.len());
|
||||
|
||||
let mut rolled_back = 0;
|
||||
let mut rollback_errors = Vec::new();
|
||||
|
||||
// Process migrations in reverse order (most recent first)
|
||||
for migration in state.completed_migrations.iter().rev() {
|
||||
info!("🔄 Rolling back migration for document {}", migration.document_id);
|
||||
|
||||
// Restore original database path
|
||||
match db.update_document_file_path(migration.document_id, &migration.original_path).await {
|
||||
Ok(_) => {
|
||||
info!("✅ Restored database path for document {}", migration.document_id);
|
||||
}
|
||||
Err(e) => {
|
||||
let error_msg = format!("Failed to restore DB path for {}: {}", migration.document_id, e);
|
||||
error!("❌ {}", error_msg);
|
||||
rollback_errors.push(error_msg);
|
||||
continue; // Skip S3 cleanup if DB restore failed
|
||||
}
|
||||
}
|
||||
|
||||
// Delete S3 object (main document)
|
||||
match s3_service.delete_file(&migration.s3_key).await {
|
||||
Ok(_) => {
|
||||
info!("🗑️ Deleted S3 object: {}", migration.s3_key);
|
||||
}
|
||||
Err(e) => {
|
||||
let error_msg = format!("Failed to delete S3 object {}: {}", migration.s3_key, e);
|
||||
warn!("⚠️ {}", error_msg);
|
||||
rollback_errors.push(error_msg);
|
||||
// Continue with associated files even if main file deletion failed
|
||||
}
|
||||
}
|
||||
|
||||
// Delete associated S3 objects (thumbnails, processed images)
|
||||
for associated in &migration.associated_files {
|
||||
match s3_service.delete_file(&associated.s3_key).await {
|
||||
Ok(_) => {
|
||||
info!("🗑️ Deleted associated S3 object: {} ({})", associated.s3_key, associated.file_type);
|
||||
}
|
||||
Err(e) => {
|
||||
let error_msg = format!("Failed to delete associated S3 object {}: {}", associated.s3_key, e);
|
||||
warn!("⚠️ {}", error_msg);
|
||||
rollback_errors.push(error_msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rolled_back += 1;
|
||||
}
|
||||
|
||||
if !rollback_errors.is_empty() {
|
||||
warn!("⚠️ Rollback completed with {} errors:", rollback_errors.len());
|
||||
for error in &rollback_errors {
|
||||
warn!(" - {}", error);
|
||||
}
|
||||
|
||||
// Save error details for manual cleanup
|
||||
let error_state = serde_json::json!({
|
||||
"rollback_completed_at": chrono::Utc::now(),
|
||||
"rolled_back_count": rolled_back,
|
||||
"rollback_errors": rollback_errors,
|
||||
"original_migration_state": state
|
||||
});
|
||||
|
||||
let error_json = serde_json::to_string_pretty(&error_state)?;
|
||||
tokio::fs::write("rollback_errors.json", error_json).await?;
|
||||
warn!("💾 Rollback errors saved to rollback_errors.json for manual cleanup");
|
||||
}
|
||||
|
||||
info!("✅ Rollback completed: {} migrations processed", rolled_back);
|
||||
Ok(rolled_back)
|
||||
}
|
||||
|
|
@ -54,35 +54,7 @@ impl Config {
|
|||
// Database Configuration
|
||||
let database_url = match env::var("DATABASE_URL") {
|
||||
Ok(val) => {
|
||||
// Mask sensitive parts of database URL for logging
|
||||
let masked_url = if val.contains('@') {
|
||||
let parts: Vec<&str> = val.split('@').collect();
|
||||
if parts.len() >= 2 {
|
||||
let credentials_part = parts[0];
|
||||
let remaining_part = parts[1..].join("@");
|
||||
|
||||
// Extract just the username part before the password
|
||||
if let Some(username_start) = credentials_part.rfind("://") {
|
||||
let protocol = &credentials_part[..username_start + 3];
|
||||
let credentials = &credentials_part[username_start + 3..];
|
||||
if let Some(colon_pos) = credentials.find(':') {
|
||||
let username = &credentials[..colon_pos];
|
||||
// Show first and last character of the username
|
||||
let masked_username = format!("{}{}", &username[..1], &username[username.len() - 1..]);
|
||||
format!("{}{}:***@{}", protocol, masked_username, remaining_part)
|
||||
} else {
|
||||
format!("{}***@{}", protocol, remaining_part)
|
||||
}
|
||||
} else {
|
||||
"***masked***".to_string()
|
||||
}
|
||||
} else {
|
||||
"***masked***".to_string()
|
||||
}
|
||||
} else {
|
||||
val.clone()
|
||||
};
|
||||
println!("✅ DATABASE_URL: {} (loaded from env)", masked_url);
|
||||
println!("✅ DATABASE_URL: configured (loaded from env)");
|
||||
val
|
||||
}
|
||||
Err(_) => {
|
||||
|
|
@ -462,10 +434,8 @@ impl Config {
|
|||
if !bucket_name.is_empty() && !access_key_id.is_empty() && !secret_access_key.is_empty() {
|
||||
println!("✅ S3_BUCKET_NAME: {} (loaded from env)", bucket_name);
|
||||
println!("✅ S3_REGION: {} (loaded from env)", region);
|
||||
println!("✅ S3_ACCESS_KEY_ID: {}***{} (loaded from env)",
|
||||
&access_key_id[..2.min(access_key_id.len())],
|
||||
&access_key_id[access_key_id.len().saturating_sub(2)..]);
|
||||
println!("✅ S3_SECRET_ACCESS_KEY: ***hidden*** (loaded from env, {} chars)", secret_access_key.len());
|
||||
println!("✅ S3_ACCESS_KEY_ID: configured (loaded from env)");
|
||||
println!("✅ S3_SECRET_ACCESS_KEY: configured (loaded from env)");
|
||||
if let Some(ref endpoint) = endpoint_url {
|
||||
println!("✅ S3_ENDPOINT_URL: {} (loaded from env)", endpoint);
|
||||
}
|
||||
|
|
|
|||
35
src/main.rs
35
src/main.rs
|
|
@ -121,15 +121,35 @@ async fn main() -> anyhow::Result<()> {
|
|||
println!("📁 Upload directory: {}", config.upload_path);
|
||||
println!("👁️ Watch directory: {}", config.watch_folder);
|
||||
|
||||
// Initialize file service using the new storage backend architecture
|
||||
// Initialize file service using the new storage backend architecture with fallback
|
||||
info!("Initializing file service with storage backend...");
|
||||
let storage_config = readur::storage::factory::storage_config_from_env(&config)?;
|
||||
let file_service = readur::services::file_service::FileService::from_config(storage_config, config.upload_path.clone()).await
|
||||
.map_err(|e| {
|
||||
error!("Failed to initialize file service with configured storage backend: {}", e);
|
||||
e
|
||||
})?;
|
||||
info!("✅ File service initialized with {} storage backend", file_service.storage_type());
|
||||
let file_service = match readur::services::file_service::FileService::from_config(storage_config, config.upload_path.clone()).await {
|
||||
Ok(service) => {
|
||||
info!("✅ File service initialized with {} storage backend", service.storage_type());
|
||||
service
|
||||
}
|
||||
Err(e) => {
|
||||
error!("❌ Failed to initialize configured storage backend: {}", e);
|
||||
warn!("🔄 Falling back to local storage...");
|
||||
|
||||
// Create fallback local storage configuration
|
||||
let fallback_config = readur::storage::StorageConfig::Local {
|
||||
upload_path: config.upload_path.clone(),
|
||||
};
|
||||
|
||||
match readur::services::file_service::FileService::from_config(fallback_config, config.upload_path.clone()).await {
|
||||
Ok(fallback_service) => {
|
||||
warn!("✅ Successfully initialized fallback local storage");
|
||||
fallback_service
|
||||
}
|
||||
Err(fallback_err) => {
|
||||
error!("💥 CRITICAL: Even fallback local storage failed to initialize: {}", fallback_err);
|
||||
return Err(fallback_err.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Initialize the storage backend (creates directories, validates access, etc.)
|
||||
if let Err(e) = file_service.initialize_storage().await {
|
||||
|
|
@ -159,7 +179,6 @@ async fn main() -> anyhow::Result<()> {
|
|||
}
|
||||
Err(e) => {
|
||||
println!("❌ CRITICAL: Failed to connect to database for web operations!");
|
||||
println!("Database URL: {}", db_info); // Use the already-masked URL
|
||||
println!("Error: {}", e);
|
||||
println!("\n🔧 Please verify:");
|
||||
println!(" - Database server is running");
|
||||
|
|
|
|||
|
|
@ -45,14 +45,6 @@ impl EnhancedOcrService {
|
|||
Self { temp_dir, file_service }
|
||||
}
|
||||
|
||||
/// Backward-compatible constructor for tests and legacy code
|
||||
/// Creates a FileService with local storage using UPLOAD_PATH env var
|
||||
#[deprecated(note = "Use new() with FileService parameter instead")]
|
||||
pub fn new_legacy(temp_dir: String) -> Self {
|
||||
let upload_path = std::env::var("UPLOAD_PATH").unwrap_or_else(|_| "./uploads".to_string());
|
||||
let file_service = FileService::new(upload_path);
|
||||
Self::new(temp_dir, file_service)
|
||||
}
|
||||
|
||||
/// Extract text from image with high-quality OCR settings
|
||||
#[cfg(feature = "ocr")]
|
||||
|
|
|
|||
|
|
@ -75,14 +75,6 @@ impl OcrQueueService {
|
|||
}
|
||||
}
|
||||
|
||||
/// Backward-compatible constructor for tests and legacy code
|
||||
/// Creates a FileService with local storage using UPLOAD_PATH env var
|
||||
#[deprecated(note = "Use new() with FileService parameter instead")]
|
||||
pub fn new_legacy(db: Database, pool: PgPool, max_concurrent_jobs: usize) -> Self {
|
||||
let upload_path = std::env::var("UPLOAD_PATH").unwrap_or_else(|_| "./uploads".to_string());
|
||||
let file_service = std::sync::Arc::new(crate::services::file_service::FileService::new(upload_path));
|
||||
Self::new(db, pool, max_concurrent_jobs, file_service)
|
||||
}
|
||||
|
||||
/// Add a document to the OCR queue
|
||||
pub async fn enqueue_document(&self, document_id: Uuid, priority: i32, file_size: i64) -> Result<Uuid> {
|
||||
|
|
|
|||
|
|
@ -202,10 +202,10 @@ pub async fn upload_document(
|
|||
}
|
||||
|
||||
// Create ingestion service
|
||||
let file_service = &state.file_service;
|
||||
let file_service_clone = state.file_service.as_ref().clone();
|
||||
let ingestion_service = DocumentIngestionService::new(
|
||||
state.db.clone(),
|
||||
(**file_service).clone(),
|
||||
file_service_clone,
|
||||
);
|
||||
|
||||
debug!("[UPLOAD_DEBUG] Calling ingestion service for file: {}", filename);
|
||||
|
|
|
|||
|
|
@ -315,8 +315,8 @@ async fn process_single_file(
|
|||
debug!("Downloaded file: {} ({} bytes)", file_info.name, file_data.len());
|
||||
|
||||
// Use the unified ingestion service for consistent deduplication
|
||||
let file_service = &state.file_service;
|
||||
let ingestion_service = DocumentIngestionService::new(state.db.clone(), (**file_service).clone());
|
||||
let file_service_clone = state.file_service.as_ref().clone();
|
||||
let ingestion_service = DocumentIngestionService::new(state.db.clone(), file_service_clone);
|
||||
|
||||
let result = if let Some(source_id) = webdav_source_id {
|
||||
ingestion_service
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ use serde_json;
|
|||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
use uuid::Uuid;
|
||||
use futures::stream::StreamExt;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
|
||||
#[cfg(feature = "s3")]
|
||||
use aws_sdk_s3::Client;
|
||||
|
|
@ -15,10 +17,18 @@ use aws_credential_types::Credentials;
|
|||
use aws_types::region::Region as AwsRegion;
|
||||
#[cfg(feature = "s3")]
|
||||
use aws_sdk_s3::primitives::ByteStream;
|
||||
#[cfg(feature = "s3")]
|
||||
use aws_sdk_s3::types::{CompletedPart, CompletedMultipartUpload};
|
||||
|
||||
use crate::models::{FileIngestionInfo, S3SourceConfig};
|
||||
use crate::storage::StorageBackend;
|
||||
|
||||
/// Threshold for using streaming multipart uploads (100MB)
|
||||
const STREAMING_THRESHOLD: usize = 100 * 1024 * 1024;
|
||||
|
||||
/// Multipart upload chunk size (16MB - AWS minimum is 5MB, we use 16MB for better performance)
|
||||
const MULTIPART_CHUNK_SIZE: usize = 16 * 1024 * 1024;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct S3Service {
|
||||
#[cfg(feature = "s3")]
|
||||
|
|
@ -347,7 +357,15 @@ impl S3Service {
|
|||
#[cfg(feature = "s3")]
|
||||
{
|
||||
let key = self.generate_document_key(user_id, document_id, filename);
|
||||
self.store_file(&key, data, None).await?;
|
||||
|
||||
// Use streaming upload for large files
|
||||
if data.len() > STREAMING_THRESHOLD {
|
||||
info!("Using streaming multipart upload for large file: {} ({} bytes)", key, data.len());
|
||||
self.store_file_multipart(&key, data, None).await?;
|
||||
} else {
|
||||
self.store_file(&key, data, None).await?;
|
||||
}
|
||||
|
||||
Ok(key)
|
||||
}
|
||||
}
|
||||
|
|
@ -438,6 +456,137 @@ impl S3Service {
|
|||
}
|
||||
}
|
||||
|
||||
/// Store large files using multipart upload for better performance and memory usage
|
||||
async fn store_file_multipart(&self, key: &str, data: &[u8], metadata: Option<HashMap<String, String>>) -> Result<()> {
|
||||
#[cfg(not(feature = "s3"))]
|
||||
{
|
||||
return Err(anyhow!("S3 support not compiled in"));
|
||||
}
|
||||
|
||||
#[cfg(feature = "s3")]
|
||||
{
|
||||
info!("Starting multipart upload for file: {}/{} ({} bytes)", self.config.bucket_name, key, data.len());
|
||||
|
||||
let key_owned = key.to_string();
|
||||
let data_owned = data.to_vec();
|
||||
let metadata_owned = metadata.clone();
|
||||
let bucket_name = self.config.bucket_name.clone();
|
||||
let client = self.client.clone();
|
||||
|
||||
self.retry_operation(&format!("store_file_multipart: {}", key), || {
|
||||
let key = key_owned.clone();
|
||||
let data = data_owned.clone();
|
||||
let metadata = metadata_owned.clone();
|
||||
let bucket_name = bucket_name.clone();
|
||||
let client = client.clone();
|
||||
let content_type = self.get_content_type_from_key(&key);
|
||||
|
||||
async move {
|
||||
// Step 1: Initiate multipart upload
|
||||
let mut create_request = client
|
||||
.create_multipart_upload()
|
||||
.bucket(&bucket_name)
|
||||
.key(&key);
|
||||
|
||||
// Add metadata if provided
|
||||
if let Some(meta) = metadata {
|
||||
for (k, v) in meta {
|
||||
create_request = create_request.metadata(k, v);
|
||||
}
|
||||
}
|
||||
|
||||
// Set content type based on file extension
|
||||
if let Some(ct) = content_type {
|
||||
create_request = create_request.content_type(ct);
|
||||
}
|
||||
|
||||
let create_response = create_request.send().await
|
||||
.map_err(|e| anyhow!("Failed to initiate multipart upload for {}: {}", key, e))?;
|
||||
|
||||
let upload_id = create_response.upload_id()
|
||||
.ok_or_else(|| anyhow!("Missing upload ID in multipart upload response"))?;
|
||||
|
||||
info!("Initiated multipart upload for {}: {}", key, upload_id);
|
||||
|
||||
// Step 2: Upload parts in chunks
|
||||
let mut completed_parts = Vec::new();
|
||||
let total_chunks = (data.len() + MULTIPART_CHUNK_SIZE - 1) / MULTIPART_CHUNK_SIZE;
|
||||
|
||||
for (chunk_index, chunk) in data.chunks(MULTIPART_CHUNK_SIZE).enumerate() {
|
||||
let part_number = (chunk_index + 1) as i32;
|
||||
|
||||
debug!("Uploading part {} of {} for {} ({} bytes)",
|
||||
part_number, total_chunks, key, chunk.len());
|
||||
|
||||
let upload_part_response = client
|
||||
.upload_part()
|
||||
.bucket(&bucket_name)
|
||||
.key(&key)
|
||||
.upload_id(upload_id)
|
||||
.part_number(part_number)
|
||||
.body(ByteStream::from(chunk.to_vec()))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to upload part {} for {}: {}", part_number, key, e))?;
|
||||
|
||||
let etag = upload_part_response.e_tag()
|
||||
.ok_or_else(|| anyhow!("Missing ETag in upload part response"))?;
|
||||
|
||||
completed_parts.push(
|
||||
CompletedPart::builder()
|
||||
.part_number(part_number)
|
||||
.e_tag(etag)
|
||||
.build()
|
||||
);
|
||||
|
||||
debug!("Successfully uploaded part {} for {}", part_number, key);
|
||||
}
|
||||
|
||||
// Step 3: Complete multipart upload
|
||||
let completed_multipart_upload = CompletedMultipartUpload::builder()
|
||||
.set_parts(Some(completed_parts))
|
||||
.build();
|
||||
|
||||
client
|
||||
.complete_multipart_upload()
|
||||
.bucket(&bucket_name)
|
||||
.key(&key)
|
||||
.upload_id(upload_id)
|
||||
.multipart_upload(completed_multipart_upload)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
// If completion fails, try to abort the multipart upload
|
||||
let abort_client = client.clone();
|
||||
let abort_bucket = bucket_name.clone();
|
||||
let abort_key = key.clone();
|
||||
let abort_upload_id = upload_id.to_string();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(abort_err) = abort_client
|
||||
.abort_multipart_upload()
|
||||
.bucket(abort_bucket)
|
||||
.key(abort_key)
|
||||
.upload_id(abort_upload_id)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
error!("Failed to abort multipart upload: {}", abort_err);
|
||||
}
|
||||
});
|
||||
|
||||
anyhow!("Failed to complete multipart upload for {}: {}", key, e)
|
||||
})?;
|
||||
|
||||
info!("Successfully completed multipart upload for {}", key);
|
||||
Ok(())
|
||||
}
|
||||
}).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieve a file from S3
|
||||
pub async fn retrieve_file(&self, key: &str) -> Result<Vec<u8>> {
|
||||
#[cfg(not(feature = "s3"))]
|
||||
|
|
@ -666,7 +815,15 @@ impl StorageBackend for S3Service {
|
|||
async fn store_document(&self, user_id: Uuid, document_id: Uuid, filename: &str, data: &[u8]) -> Result<String> {
|
||||
// Generate S3 key
|
||||
let key = self.generate_document_key(user_id, document_id, filename);
|
||||
self.store_file(&key, data, None).await?;
|
||||
|
||||
// Use streaming upload for large files
|
||||
if data.len() > STREAMING_THRESHOLD {
|
||||
info!("Using streaming multipart upload for large file: {} ({} bytes)", key, data.len());
|
||||
self.store_file_multipart(&key, data, None).await?;
|
||||
} else {
|
||||
self.store_file(&key, data, None).await?;
|
||||
}
|
||||
|
||||
Ok(format!("s3://{}", key))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,21 +3,30 @@
|
|||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::fs;
|
||||
use tracing::{info, error};
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{info, error, warn, debug};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::StorageBackend;
|
||||
use crate::utils::security::{validate_filename, validate_and_sanitize_path, validate_path_within_base};
|
||||
|
||||
/// Local filesystem storage backend
|
||||
pub struct LocalStorageBackend {
|
||||
upload_path: String,
|
||||
/// Cache for resolved file paths to reduce filesystem calls
|
||||
path_cache: Arc<RwLock<HashMap<String, Option<String>>>>,
|
||||
}
|
||||
|
||||
impl LocalStorageBackend {
|
||||
/// Create a new local storage backend
|
||||
pub fn new(upload_path: String) -> Self {
|
||||
Self { upload_path }
|
||||
Self {
|
||||
upload_path,
|
||||
path_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the base upload path
|
||||
|
|
@ -50,33 +59,131 @@ impl LocalStorageBackend {
|
|||
Path::new(&self.upload_path).join("backups")
|
||||
}
|
||||
|
||||
/// Resolve file path, handling both old and new directory structures
|
||||
/// Resolve file path, handling both old and new directory structures with caching
|
||||
pub async fn resolve_file_path(&self, file_path: &str) -> Result<String> {
|
||||
// If the file exists at the given path, use it
|
||||
if Path::new(file_path).exists() {
|
||||
return Ok(file_path.to_string());
|
||||
// Check cache first
|
||||
{
|
||||
let cache = self.path_cache.read().await;
|
||||
if let Some(cached_result) = cache.get(file_path) {
|
||||
return match cached_result {
|
||||
Some(resolved_path) => {
|
||||
debug!("Cache hit for file path: {} -> {}", file_path, resolved_path);
|
||||
Ok(resolved_path.clone())
|
||||
}
|
||||
None => {
|
||||
debug!("Cache hit for non-existent file: {}", file_path);
|
||||
Err(anyhow::anyhow!("File not found: {} (cached)", file_path))
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Generate candidate paths in order of likelihood
|
||||
let candidates = self.generate_path_candidates(file_path);
|
||||
|
||||
// Try to find the file in the new structured directory
|
||||
// Check candidates efficiently
|
||||
let mut found_path = None;
|
||||
for candidate in &candidates {
|
||||
match tokio::fs::metadata(candidate).await {
|
||||
Ok(metadata) => {
|
||||
if metadata.is_file() {
|
||||
found_path = Some(candidate.clone());
|
||||
debug!("Found file at: {}", candidate);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// File doesn't exist at this path, continue to next candidate
|
||||
debug!("File not found at: {}", candidate);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
{
|
||||
let mut cache = self.path_cache.write().await;
|
||||
cache.insert(file_path.to_string(), found_path.clone());
|
||||
|
||||
// Prevent cache from growing too large
|
||||
if cache.len() > 10000 {
|
||||
// Clear oldest 20% of entries (simple cache eviction)
|
||||
let to_remove: Vec<String> = cache.keys().take(2000).cloned().collect();
|
||||
for key in to_remove {
|
||||
cache.remove(&key);
|
||||
}
|
||||
debug!("Evicted cache entries to prevent memory growth");
|
||||
}
|
||||
}
|
||||
|
||||
match found_path {
|
||||
Some(path) => {
|
||||
if path != file_path {
|
||||
info!("Resolved file path: {} -> {}", file_path, path);
|
||||
}
|
||||
Ok(path)
|
||||
}
|
||||
None => {
|
||||
debug!("File not found in any candidate location: {}", file_path);
|
||||
Err(anyhow::anyhow!(
|
||||
"File not found: {} (checked {} locations)",
|
||||
file_path,
|
||||
candidates.len()
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate candidate paths for file resolution
|
||||
fn generate_path_candidates(&self, file_path: &str) -> Vec<String> {
|
||||
let mut candidates = Vec::new();
|
||||
|
||||
// 1. Original path (most likely for new files)
|
||||
candidates.push(file_path.to_string());
|
||||
|
||||
// 2. For legacy compatibility - try structured directory
|
||||
if file_path.starts_with("./uploads/") && !file_path.contains("/documents/") {
|
||||
let new_path = file_path.replace("./uploads/", "./uploads/documents/");
|
||||
if Path::new(&new_path).exists() {
|
||||
info!("Found file in new structured directory: {} -> {}", file_path, new_path);
|
||||
return Ok(new_path);
|
||||
}
|
||||
candidates.push(file_path.replace("./uploads/", "./uploads/documents/"));
|
||||
}
|
||||
|
||||
// Try without the ./ prefix
|
||||
// 3. Try without ./ prefix in structured directory
|
||||
if file_path.starts_with("uploads/") && !file_path.contains("/documents/") {
|
||||
let new_path = file_path.replace("uploads/", "uploads/documents/");
|
||||
if Path::new(&new_path).exists() {
|
||||
info!("Found file in new structured directory: {} -> {}", file_path, new_path);
|
||||
return Ok(new_path);
|
||||
}
|
||||
candidates.push(file_path.replace("uploads/", "uploads/documents/"));
|
||||
}
|
||||
|
||||
// File not found in any expected location
|
||||
Err(anyhow::anyhow!("File not found: {} (checked original path and structured directory)", file_path))
|
||||
// 4. Try relative to our configured upload path
|
||||
if !file_path.starts_with(&self.upload_path) {
|
||||
let relative_path = Path::new(&self.upload_path).join(file_path);
|
||||
candidates.push(relative_path.to_string_lossy().to_string());
|
||||
|
||||
// Also try in documents subdirectory
|
||||
let documents_path = Path::new(&self.upload_path).join("documents").join(file_path);
|
||||
candidates.push(documents_path.to_string_lossy().to_string());
|
||||
}
|
||||
|
||||
// 5. Try absolute path if it looks like a filename only
|
||||
if !file_path.contains('/') && !file_path.contains('\\') {
|
||||
// Try in documents directory
|
||||
let abs_documents_path = Path::new(&self.upload_path)
|
||||
.join("documents")
|
||||
.join(file_path);
|
||||
candidates.push(abs_documents_path.to_string_lossy().to_string());
|
||||
}
|
||||
|
||||
candidates
|
||||
}
|
||||
|
||||
/// Clear the path resolution cache (useful for testing or after file operations)
|
||||
pub async fn clear_path_cache(&self) {
|
||||
let mut cache = self.path_cache.write().await;
|
||||
cache.clear();
|
||||
debug!("Cleared path resolution cache");
|
||||
}
|
||||
|
||||
/// Invalidate cache entry for a specific path
|
||||
pub async fn invalidate_cache_entry(&self, file_path: &str) {
|
||||
let mut cache = self.path_cache.write().await;
|
||||
cache.remove(file_path);
|
||||
debug!("Invalidated cache entry for: {}", file_path);
|
||||
}
|
||||
|
||||
/// Save a file with generated UUID filename (legacy method)
|
||||
|
|
@ -112,7 +219,10 @@ impl LocalStorageBackend {
|
|||
#[async_trait]
|
||||
impl StorageBackend for LocalStorageBackend {
|
||||
async fn store_document(&self, _user_id: Uuid, document_id: Uuid, filename: &str, data: &[u8]) -> Result<String> {
|
||||
let extension = Path::new(filename)
|
||||
// Validate and sanitize the filename
|
||||
let sanitized_filename = validate_filename(filename)?;
|
||||
|
||||
let extension = Path::new(&sanitized_filename)
|
||||
.extension()
|
||||
.and_then(|ext| ext.to_str())
|
||||
.unwrap_or("");
|
||||
|
|
@ -126,13 +236,28 @@ impl StorageBackend for LocalStorageBackend {
|
|||
let documents_dir = self.get_documents_path();
|
||||
let file_path = documents_dir.join(&document_filename);
|
||||
|
||||
// Validate that the final path is within our base directory
|
||||
validate_path_within_base(
|
||||
&file_path.to_string_lossy(),
|
||||
&self.upload_path
|
||||
)?;
|
||||
|
||||
// Ensure the documents directory exists
|
||||
fs::create_dir_all(&documents_dir).await?;
|
||||
|
||||
// Validate data size (prevent extremely large files from causing issues)
|
||||
if data.len() > 1_000_000_000 { // 1GB limit
|
||||
return Err(anyhow::anyhow!("File too large for storage (max 1GB)"));
|
||||
}
|
||||
|
||||
fs::write(&file_path, data).await?;
|
||||
|
||||
// Invalidate any cached negative results for this path
|
||||
let path_str = file_path.to_string_lossy().to_string();
|
||||
self.invalidate_cache_entry(&path_str).await;
|
||||
|
||||
info!("Stored document locally: {}", file_path.display());
|
||||
Ok(file_path.to_string_lossy().to_string())
|
||||
Ok(path_str)
|
||||
}
|
||||
|
||||
async fn store_thumbnail(&self, _user_id: Uuid, document_id: Uuid, data: &[u8]) -> Result<String> {
|
||||
|
|
@ -142,10 +267,25 @@ impl StorageBackend for LocalStorageBackend {
|
|||
let thumbnail_filename = format!("{}_thumb.jpg", document_id);
|
||||
let thumbnail_path = thumbnails_dir.join(&thumbnail_filename);
|
||||
|
||||
// Validate that the final path is within our base directory
|
||||
validate_path_within_base(
|
||||
&thumbnail_path.to_string_lossy(),
|
||||
&self.upload_path
|
||||
)?;
|
||||
|
||||
// Validate data size for thumbnails (should be much smaller)
|
||||
if data.len() > 10_000_000 { // 10MB limit for thumbnails
|
||||
return Err(anyhow::anyhow!("Thumbnail too large (max 10MB)"));
|
||||
}
|
||||
|
||||
fs::write(&thumbnail_path, data).await?;
|
||||
|
||||
// Invalidate any cached negative results for this path
|
||||
let path_str = thumbnail_path.to_string_lossy().to_string();
|
||||
self.invalidate_cache_entry(&path_str).await;
|
||||
|
||||
info!("Stored thumbnail locally: {}", thumbnail_path.display());
|
||||
Ok(thumbnail_path.to_string_lossy().to_string())
|
||||
Ok(path_str)
|
||||
}
|
||||
|
||||
async fn store_processed_image(&self, _user_id: Uuid, document_id: Uuid, data: &[u8]) -> Result<String> {
|
||||
|
|
@ -155,15 +295,44 @@ impl StorageBackend for LocalStorageBackend {
|
|||
let processed_filename = format!("{}_processed.png", document_id);
|
||||
let processed_path = processed_dir.join(&processed_filename);
|
||||
|
||||
// Validate that the final path is within our base directory
|
||||
validate_path_within_base(
|
||||
&processed_path.to_string_lossy(),
|
||||
&self.upload_path
|
||||
)?;
|
||||
|
||||
// Validate data size for processed images
|
||||
if data.len() > 50_000_000 { // 50MB limit for processed images
|
||||
return Err(anyhow::anyhow!("Processed image too large (max 50MB)"));
|
||||
}
|
||||
|
||||
fs::write(&processed_path, data).await?;
|
||||
|
||||
// Invalidate any cached negative results for this path
|
||||
let path_str = processed_path.to_string_lossy().to_string();
|
||||
self.invalidate_cache_entry(&path_str).await;
|
||||
|
||||
info!("Stored processed image locally: {}", processed_path.display());
|
||||
Ok(processed_path.to_string_lossy().to_string())
|
||||
Ok(path_str)
|
||||
}
|
||||
|
||||
async fn retrieve_file(&self, path: &str) -> Result<Vec<u8>> {
|
||||
let resolved_path = self.resolve_file_path(path).await?;
|
||||
// Validate and sanitize the input path
|
||||
let sanitized_path = validate_and_sanitize_path(path)?;
|
||||
|
||||
let resolved_path = self.resolve_file_path(&sanitized_path).await?;
|
||||
|
||||
// Validate that the resolved path is within our base directory
|
||||
validate_path_within_base(&resolved_path, &self.upload_path)?;
|
||||
|
||||
let data = fs::read(&resolved_path).await?;
|
||||
|
||||
// Additional safety check on file size when reading
|
||||
if data.len() > 1_000_000_000 { // 1GB limit
|
||||
warn!("Attempted to read extremely large file: {} ({} bytes)", resolved_path, data.len());
|
||||
return Err(anyhow::anyhow!("File too large to read safely"));
|
||||
}
|
||||
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
|
|
@ -172,16 +341,25 @@ impl StorageBackend for LocalStorageBackend {
|
|||
let mut serious_errors = Vec::new();
|
||||
|
||||
// Helper function to safely delete a file
|
||||
async fn safe_delete(path: &Path, serious_errors: &mut Vec<String>) -> Option<String> {
|
||||
let storage_backend = self;
|
||||
async fn safe_delete(path: &Path, serious_errors: &mut Vec<String>, backend: &LocalStorageBackend) -> Option<String> {
|
||||
match fs::remove_file(path).await {
|
||||
Ok(_) => {
|
||||
info!("Deleted file: {}", path.display());
|
||||
Some(path.to_string_lossy().to_string())
|
||||
let path_str = path.to_string_lossy().to_string();
|
||||
|
||||
// Invalidate cache entry for the deleted file
|
||||
backend.invalidate_cache_entry(&path_str).await;
|
||||
|
||||
Some(path_str)
|
||||
}
|
||||
Err(e) => {
|
||||
match e.kind() {
|
||||
std::io::ErrorKind::NotFound => {
|
||||
info!("File already deleted: {}", path.display());
|
||||
// Still invalidate cache in case it was cached as existing
|
||||
let path_str = path.to_string_lossy().to_string();
|
||||
backend.invalidate_cache_entry(&path_str).await;
|
||||
None
|
||||
}
|
||||
_ => {
|
||||
|
|
@ -223,7 +401,7 @@ impl StorageBackend for LocalStorageBackend {
|
|||
let mut main_file_deleted = false;
|
||||
for candidate_path in &main_file_candidates {
|
||||
if candidate_path.exists() {
|
||||
if let Some(deleted_path) = safe_delete(candidate_path, &mut serious_errors).await {
|
||||
if let Some(deleted_path) = safe_delete(candidate_path, &mut serious_errors, storage_backend).await {
|
||||
deleted_files.push(deleted_path);
|
||||
main_file_deleted = true;
|
||||
break; // Only delete the first match we find
|
||||
|
|
@ -238,14 +416,14 @@ impl StorageBackend for LocalStorageBackend {
|
|||
// Delete thumbnail if it exists
|
||||
let thumbnail_filename = format!("{}_thumb.jpg", document_id);
|
||||
let thumbnail_path = self.get_thumbnails_path().join(&thumbnail_filename);
|
||||
if let Some(deleted_path) = safe_delete(&thumbnail_path, &mut serious_errors).await {
|
||||
if let Some(deleted_path) = safe_delete(&thumbnail_path, &mut serious_errors, storage_backend).await {
|
||||
deleted_files.push(deleted_path);
|
||||
}
|
||||
|
||||
// Delete processed image if it exists
|
||||
let processed_image_filename = format!("{}_processed.png", document_id);
|
||||
let processed_image_path = self.get_processed_images_path().join(&processed_image_filename);
|
||||
if let Some(deleted_path) = safe_delete(&processed_image_path, &mut serious_errors).await {
|
||||
if let Some(deleted_path) = safe_delete(&processed_image_path, &mut serious_errors, storage_backend).await {
|
||||
deleted_files.push(deleted_path);
|
||||
}
|
||||
|
||||
|
|
@ -265,8 +443,20 @@ impl StorageBackend for LocalStorageBackend {
|
|||
}
|
||||
|
||||
async fn file_exists(&self, path: &str) -> Result<bool> {
|
||||
match self.resolve_file_path(path).await {
|
||||
Ok(_) => Ok(true),
|
||||
// Validate and sanitize the input path
|
||||
let sanitized_path = match validate_and_sanitize_path(path) {
|
||||
Ok(p) => p,
|
||||
Err(_) => return Ok(false), // Invalid paths don't exist
|
||||
};
|
||||
|
||||
match self.resolve_file_path(&sanitized_path).await {
|
||||
Ok(resolved_path) => {
|
||||
// Additional validation that the resolved path is within base directory
|
||||
match validate_path_within_base(&resolved_path, &self.upload_path) {
|
||||
Ok(_) => Ok(true),
|
||||
Err(_) => Ok(false), // Paths outside base directory don't "exist" for us
|
||||
}
|
||||
}
|
||||
Err(_) => Ok(false),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1 +1,2 @@
|
|||
pub mod debug;
|
||||
pub mod debug;
|
||||
pub mod security;
|
||||
|
|
@ -0,0 +1,231 @@
|
|||
//! Security utilities for input validation and sanitization
|
||||
|
||||
use anyhow::Result;
|
||||
use std::path::{Path, PathBuf, Component};
|
||||
use tracing::warn;
|
||||
|
||||
/// Validate and sanitize file paths to prevent path traversal attacks
|
||||
pub fn validate_and_sanitize_path(input_path: &str) -> Result<String> {
|
||||
// Check for null bytes (not allowed in file paths)
|
||||
if input_path.contains('\0') {
|
||||
return Err(anyhow::anyhow!("Path contains null bytes"));
|
||||
}
|
||||
|
||||
// Check for excessively long paths
|
||||
if input_path.len() > 4096 {
|
||||
return Err(anyhow::anyhow!("Path too long (max 4096 characters)"));
|
||||
}
|
||||
|
||||
// Convert to Path for normalization
|
||||
let path = Path::new(input_path);
|
||||
|
||||
// Check for path traversal attempts
|
||||
for component in path.components() {
|
||||
match component {
|
||||
Component::ParentDir => {
|
||||
warn!("Path traversal attempt detected: {}", input_path);
|
||||
return Err(anyhow::anyhow!("Path traversal not allowed"));
|
||||
}
|
||||
Component::Normal(name) => {
|
||||
let name_str = name.to_string_lossy();
|
||||
|
||||
// Check for dangerous file names
|
||||
if is_dangerous_filename(&name_str) {
|
||||
return Err(anyhow::anyhow!("Potentially dangerous filename: {}", name_str));
|
||||
}
|
||||
|
||||
// Check for control characters (except newline and tab which might be in file content)
|
||||
for ch in name_str.chars() {
|
||||
if ch.is_control() && ch != '\n' && ch != '\t' {
|
||||
return Err(anyhow::anyhow!("Filename contains control characters"));
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {} // Allow root, current dir, and prefix components
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize the path to remove redundant components
|
||||
let normalized = normalize_path(path);
|
||||
Ok(normalized.to_string_lossy().to_string())
|
||||
}
|
||||
|
||||
/// Validate filename for document storage
|
||||
pub fn validate_filename(filename: &str) -> Result<String> {
|
||||
// Basic length check
|
||||
if filename.is_empty() {
|
||||
return Err(anyhow::anyhow!("Filename cannot be empty"));
|
||||
}
|
||||
|
||||
if filename.len() > 255 {
|
||||
return Err(anyhow::anyhow!("Filename too long (max 255 characters)"));
|
||||
}
|
||||
|
||||
// Check for null bytes
|
||||
if filename.contains('\0') {
|
||||
return Err(anyhow::anyhow!("Filename contains null bytes"));
|
||||
}
|
||||
|
||||
// Check for path separators (filenames should not contain them)
|
||||
if filename.contains('/') || filename.contains('\\') {
|
||||
return Err(anyhow::anyhow!("Filename cannot contain path separators"));
|
||||
}
|
||||
|
||||
// Check for control characters
|
||||
for ch in filename.chars() {
|
||||
if ch.is_control() && ch != '\n' && ch != '\t' {
|
||||
return Err(anyhow::anyhow!("Filename contains control characters"));
|
||||
}
|
||||
}
|
||||
|
||||
// Check for dangerous patterns
|
||||
if is_dangerous_filename(filename) {
|
||||
return Err(anyhow::anyhow!("Potentially dangerous filename: {}", filename));
|
||||
}
|
||||
|
||||
// Sanitize the filename by replacing problematic characters
|
||||
let sanitized = sanitize_filename(filename);
|
||||
Ok(sanitized)
|
||||
}
|
||||
|
||||
/// Check if a filename is potentially dangerous
|
||||
fn is_dangerous_filename(filename: &str) -> bool {
|
||||
let filename_lower = filename.to_lowercase();
|
||||
|
||||
// Windows reserved names
|
||||
let reserved_names = [
|
||||
"con", "prn", "aux", "nul",
|
||||
"com1", "com2", "com3", "com4", "com5", "com6", "com7", "com8", "com9",
|
||||
"lpt1", "lpt2", "lpt3", "lpt4", "lpt5", "lpt6", "lpt7", "lpt8", "lpt9",
|
||||
];
|
||||
|
||||
// Check if filename (without extension) matches reserved names
|
||||
let name_without_ext = filename_lower.split('.').next().unwrap_or("");
|
||||
if reserved_names.contains(&name_without_ext) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check for suspicious patterns
|
||||
if filename_lower.starts_with('.') && filename_lower.len() > 1 {
|
||||
// Allow common hidden files but reject suspicious ones
|
||||
let allowed_hidden = [".env", ".gitignore", ".htaccess"];
|
||||
if !allowed_hidden.iter().any(|&allowed| filename_lower.starts_with(allowed)) {
|
||||
// Be more permissive with document files that might have dots
|
||||
if !filename_lower.contains(&['.', 'd', 'o', 'c']) &&
|
||||
!filename_lower.contains(&['.', 'p', 'd', 'f']) &&
|
||||
!filename_lower.contains(&['.', 't', 'x', 't']) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Sanitize filename by replacing problematic characters
|
||||
fn sanitize_filename(filename: &str) -> String {
|
||||
let mut sanitized = String::new();
|
||||
|
||||
for ch in filename.chars() {
|
||||
match ch {
|
||||
// Replace problematic characters with underscores
|
||||
'<' | '>' | ':' | '"' | '|' | '?' | '*' => sanitized.push('_'),
|
||||
// Allow most other characters
|
||||
_ if !ch.is_control() || ch == '\n' || ch == '\t' => sanitized.push(ch),
|
||||
// Skip control characters
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Trim whitespace from ends
|
||||
sanitized.trim().to_string()
|
||||
}
|
||||
|
||||
/// Normalize a path by resolving . and .. components without filesystem access
|
||||
fn normalize_path(path: &Path) -> PathBuf {
|
||||
let mut normalized = PathBuf::new();
|
||||
|
||||
for component in path.components() {
|
||||
match component {
|
||||
Component::Normal(_) | Component::RootDir | Component::Prefix(_) => {
|
||||
normalized.push(component);
|
||||
}
|
||||
Component::CurDir => {
|
||||
// Skip current directory references
|
||||
}
|
||||
Component::ParentDir => {
|
||||
// This should have been caught earlier, but handle it safely
|
||||
if normalized.parent().is_some() {
|
||||
normalized.pop();
|
||||
}
|
||||
// If we can't go up, just ignore the .. component
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
normalized
|
||||
}
|
||||
|
||||
/// Validate that a path is within the allowed base directory
|
||||
pub fn validate_path_within_base(path: &str, base_dir: &str) -> Result<()> {
|
||||
let path_buf = PathBuf::from(path);
|
||||
let base_buf = PathBuf::from(base_dir);
|
||||
|
||||
// Canonicalize if possible, but don't fail if paths don't exist yet
|
||||
let canonical_path = path_buf.canonicalize().unwrap_or_else(|_| {
|
||||
// If canonicalization fails, do our best with normalization
|
||||
normalize_path(&path_buf)
|
||||
});
|
||||
|
||||
let canonical_base = base_buf.canonicalize().unwrap_or_else(|_| {
|
||||
normalize_path(&base_buf)
|
||||
});
|
||||
|
||||
if !canonical_path.starts_with(&canonical_base) {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Path '{}' is not within allowed base directory '{}'",
|
||||
path, base_dir
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_validate_filename() {
|
||||
// Valid filenames
|
||||
assert!(validate_filename("document.pdf").is_ok());
|
||||
assert!(validate_filename("my-file_2023.docx").is_ok());
|
||||
assert!(validate_filename("report (final).txt").is_ok());
|
||||
|
||||
// Invalid filenames
|
||||
assert!(validate_filename("").is_err());
|
||||
assert!(validate_filename("../etc/passwd").is_err());
|
||||
assert!(validate_filename("file\0name.txt").is_err());
|
||||
assert!(validate_filename("con.txt").is_err());
|
||||
assert!(validate_filename("file/name.txt").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_path() {
|
||||
// Valid paths
|
||||
assert!(validate_and_sanitize_path("documents/file.pdf").is_ok());
|
||||
assert!(validate_and_sanitize_path("./uploads/document.txt").is_ok());
|
||||
|
||||
// Invalid paths
|
||||
assert!(validate_and_sanitize_path("../../../etc/passwd").is_err());
|
||||
assert!(validate_and_sanitize_path("documents/../config.txt").is_err());
|
||||
assert!(validate_and_sanitize_path("file\0name.txt").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_filename() {
|
||||
assert_eq!(sanitize_filename("file<>name.txt"), "file__name.txt");
|
||||
assert_eq!(sanitize_filename(" report.pdf "), "report.pdf");
|
||||
assert_eq!(sanitize_filename("file:name|test.doc"), "file_name_test.doc");
|
||||
}
|
||||
}
|
||||
|
|
@ -2,6 +2,8 @@
|
|||
mod tests {
|
||||
use readur::ocr::enhanced::{EnhancedOcrService, OcrResult, ImageQualityStats};
|
||||
use readur::models::Settings;
|
||||
use readur::services::file_service::FileService;
|
||||
use readur::storage::{StorageConfig, factory::create_storage_backend};
|
||||
use std::fs;
|
||||
use tempfile::{NamedTempFile, TempDir};
|
||||
|
||||
|
|
@ -13,18 +15,25 @@ mod tests {
|
|||
TempDir::new().expect("Failed to create temp directory")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_enhanced_ocr_service_creation() {
|
||||
async fn create_test_file_service(temp_path: &str) -> FileService {
|
||||
let storage_config = StorageConfig::Local { upload_path: temp_path.to_string() };
|
||||
let storage_backend = create_storage_backend(storage_config).await.unwrap();
|
||||
FileService::with_storage(temp_path.to_string(), storage_backend)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_enhanced_ocr_service_creation() {
|
||||
let temp_dir = create_temp_dir();
|
||||
let temp_path = temp_dir.path().to_str().unwrap().to_string();
|
||||
let service = EnhancedOcrService::new(temp_path);
|
||||
let file_service = create_test_file_service(&temp_path).await;
|
||||
let service = EnhancedOcrService::new(temp_path.clone(), file_service);
|
||||
|
||||
// Service should be created successfully
|
||||
assert!(!service.temp_dir.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_image_quality_stats_creation() {
|
||||
#[tokio::test]
|
||||
async fn test_image_quality_stats_creation() {
|
||||
let stats = ImageQualityStats {
|
||||
average_brightness: 128.0,
|
||||
contrast_ratio: 0.5,
|
||||
|
|
@ -38,11 +47,12 @@ mod tests {
|
|||
assert_eq!(stats.sharpness, 0.8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_count_words_safely_whitespace_separated() {
|
||||
#[tokio::test]
|
||||
async fn test_count_words_safely_whitespace_separated() {
|
||||
let temp_dir = create_temp_dir();
|
||||
let temp_path = temp_dir.path().to_str().unwrap().to_string();
|
||||
let service = EnhancedOcrService::new(temp_path);
|
||||
let file_service = create_test_file_service(&temp_path).await;
|
||||
let service = EnhancedOcrService::new(temp_path.clone(), file_service);
|
||||
|
||||
// Test normal whitespace-separated text
|
||||
let text = "Hello world this is a test";
|
||||
|
|
@ -55,11 +65,12 @@ mod tests {
|
|||
assert_eq!(count, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_count_words_safely_continuous_text() {
|
||||
#[tokio::test]
|
||||
async fn test_count_words_safely_continuous_text() {
|
||||
let temp_dir = create_temp_dir();
|
||||
let temp_path = temp_dir.path().to_str().unwrap().to_string();
|
||||
let service = EnhancedOcrService::new(temp_path);
|
||||
let file_service = create_test_file_service(&temp_path).await;
|
||||
let service = EnhancedOcrService::new(temp_path.clone(), file_service);
|
||||
|
||||
// Test continuous text without spaces (like some PDF extractions)
|
||||
let text = "HelloWorldThisIsAContinuousText";
|
||||
|
|
@ -72,11 +83,12 @@ mod tests {
|
|||
assert!(count > 0, "Should detect alphanumeric patterns as words");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_count_words_safely_edge_cases() {
|
||||
#[tokio::test]
|
||||
async fn test_count_words_safely_edge_cases() {
|
||||
let temp_dir = create_temp_dir();
|
||||
let temp_path = temp_dir.path().to_str().unwrap().to_string();
|
||||
let service = EnhancedOcrService::new(temp_path);
|
||||
let file_service = create_test_file_service(&temp_path).await;
|
||||
let service = EnhancedOcrService::new(temp_path.clone(), file_service);
|
||||
|
||||
// Test empty text
|
||||
let count = service.count_words_safely("");
|
||||
|
|
@@ -102,11 +114,12 @@ mod tests {
assert!(count > 0, "Should detect words in mixed content");
}

- #[test]
- fn test_count_words_safely_large_text() {
+ #[tokio::test]
+ async fn test_count_words_safely_large_text() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
- let service = EnhancedOcrService::new(temp_path);
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path.clone(), file_service);

// Test with large text (over 1MB) to trigger sampling
let word = "test ";

@@ -118,11 +131,12 @@ mod tests {
assert!(count <= 10_000_000, "Should cap at max limit: got {}", count);
}

- #[test]
- fn test_count_words_safely_fallback_patterns() {
+ #[tokio::test]
+ async fn test_count_words_safely_fallback_patterns() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
- let service = EnhancedOcrService::new(temp_path);
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path, file_service);

// Test letter transition detection
let text = "OneWordAnotherWordFinalWord";

@@ -140,8 +154,8 @@ mod tests {
assert!(count >= 1, "Should detect words in mixed alphanumeric: got {}", count);
}

- #[test]
- fn test_ocr_result_structure() {
+ #[tokio::test]
+ async fn test_ocr_result_structure() {
let result = OcrResult {
text: "Test text".to_string(),
confidence: 85.5,

@@ -162,7 +176,9 @@ mod tests {
#[tokio::test]
async fn test_extract_text_from_plain_text() {
let temp_dir = create_temp_dir();
- let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+ let temp_path = temp_dir.path().to_str().unwrap().to_string();
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();

let temp_file = NamedTempFile::with_suffix(".txt").unwrap();

@@ -185,7 +201,9 @@ mod tests {
#[tokio::test]
async fn test_extract_text_with_context() {
let temp_dir = create_temp_dir();
- let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+ let temp_path = temp_dir.path().to_str().unwrap().to_string();
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();

let temp_file = NamedTempFile::with_suffix(".txt").unwrap();

@@ -211,7 +229,9 @@ mod tests {
#[tokio::test]
async fn test_extract_text_unsupported_mime_type() {
let temp_dir = create_temp_dir();
- let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+ let temp_path = temp_dir.path().to_str().unwrap().to_string();
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();

let temp_file = NamedTempFile::new().unwrap();

@@ -229,7 +249,9 @@ mod tests {
#[tokio::test]
async fn test_extract_text_nonexistent_file() {
let temp_dir = create_temp_dir();
- let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+ let temp_path = temp_dir.path().to_str().unwrap().to_string();
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();

let result = service

@@ -242,7 +264,9 @@ mod tests {
#[tokio::test]
async fn test_extract_text_large_file_truncation() {
let temp_dir = create_temp_dir();
- let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+ let temp_path = temp_dir.path().to_str().unwrap().to_string();
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();

let temp_file = NamedTempFile::with_suffix(".txt").unwrap();

@@ -262,10 +286,12 @@ mod tests {
}

#[cfg(feature = "ocr")]
- #[test]
- fn test_validate_ocr_quality_high_confidence() {
+ #[tokio::test]
+ async fn test_validate_ocr_quality_high_confidence() {
let temp_dir = create_temp_dir();
- let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+ let temp_path = temp_dir.path().to_str().unwrap().to_string();
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path, file_service);
let mut settings = create_test_settings();
settings.ocr_min_confidence = 30.0;

@@ -283,10 +309,12 @@ mod tests {
}

#[cfg(feature = "ocr")]
- #[test]
- fn test_validate_ocr_quality_low_confidence() {
+ #[tokio::test]
+ async fn test_validate_ocr_quality_low_confidence() {
let temp_dir = create_temp_dir();
- let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+ let temp_path = temp_dir.path().to_str().unwrap().to_string();
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path, file_service);
let mut settings = create_test_settings();
settings.ocr_min_confidence = 50.0;

@@ -304,10 +332,12 @@ mod tests {
}

#[cfg(feature = "ocr")]
- #[test]
- fn test_validate_ocr_quality_no_words() {
+ #[tokio::test]
+ async fn test_validate_ocr_quality_no_words() {
let temp_dir = create_temp_dir();
- let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+ let temp_path = temp_dir.path().to_str().unwrap().to_string();
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();

let result = OcrResult {

@@ -324,10 +354,12 @@ mod tests {
}

#[cfg(feature = "ocr")]
- #[test]
- fn test_validate_ocr_quality_poor_character_distribution() {
+ #[tokio::test]
+ async fn test_validate_ocr_quality_poor_character_distribution() {
let temp_dir = create_temp_dir();
- let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+ let temp_path = temp_dir.path().to_str().unwrap().to_string();
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();

let result = OcrResult {

@@ -344,10 +376,12 @@ mod tests {
}

#[cfg(feature = "ocr")]
- #[test]
- fn test_validate_ocr_quality_good_character_distribution() {
+ #[tokio::test]
+ async fn test_validate_ocr_quality_good_character_distribution() {
let temp_dir = create_temp_dir();
- let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+ let temp_path = temp_dir.path().to_str().unwrap().to_string();
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();

let result = OcrResult {

@@ -366,7 +400,9 @@ mod tests {
#[tokio::test]
async fn test_word_count_calculation() {
let temp_dir = create_temp_dir();
- let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+ let temp_path = temp_dir.path().to_str().unwrap().to_string();
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();

let test_cases = vec![

@@ -395,7 +431,9 @@ mod tests {
#[tokio::test]
async fn test_pdf_extraction_with_invalid_pdf() {
let temp_dir = create_temp_dir();
- let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+ let temp_path = temp_dir.path().to_str().unwrap().to_string();
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();

let temp_file = NamedTempFile::with_suffix(".pdf").unwrap();

@@ -413,7 +451,9 @@ mod tests {
#[tokio::test]
async fn test_pdf_extraction_with_minimal_valid_pdf() {
let temp_dir = create_temp_dir();
- let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+ let temp_path = temp_dir.path().to_str().unwrap().to_string();
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();

// Minimal PDF with "Hello" text

@@ -485,7 +525,9 @@ startxref
#[tokio::test]
async fn test_pdf_size_limit() {
let temp_dir = create_temp_dir();
- let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+ let temp_path = temp_dir.path().to_str().unwrap().to_string();
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();

let temp_file = NamedTempFile::with_suffix(".pdf").unwrap();

@@ -503,8 +545,8 @@ startxref
assert!(error_msg.contains("too large"));
}

- #[test]
- fn test_settings_default_values() {
+ #[tokio::test]
+ async fn test_settings_default_values() {
let settings = Settings::default();

// Test that OCR-related settings have reasonable defaults

@@ -521,7 +563,9 @@ startxref
#[tokio::test]
async fn test_concurrent_ocr_processing() {
let temp_dir = create_temp_dir();
- let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+ let temp_path = temp_dir.path().to_str().unwrap().to_string();
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();

let mut handles = vec![];

@@ -532,7 +576,9 @@ startxref
let content = format!("Concurrent test content {}", i);
fs::write(temp_file.path(), &content).unwrap();

- let service_clone = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
+ let temp_path_clone = temp_dir.path().to_str().unwrap().to_string();
+ let file_service_clone = create_test_file_service(&temp_path_clone).await;
+ let service_clone = EnhancedOcrService::new(temp_path_clone, file_service_clone);
let settings_clone = settings.clone();
let file_path = temp_file.path().to_str().unwrap().to_string();

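The hunks above apply one mechanical change to every OCR unit test: each test becomes `#[tokio::test]`, and `EnhancedOcrService` is now constructed with a `FileService` obtained from the async storage-backend factory. A consolidated sketch of the new setup, assembled from the diff lines themselves (the test name `example_test` is a placeholder, not something this commit adds):

    #[tokio::test]
    async fn example_test() {
        // Per-test scratch directory that backs the local storage backend.
        let temp_dir = create_temp_dir();
        let temp_path = temp_dir.path().to_str().unwrap().to_string();
        // Awaiting the storage-backend factory is why these tests had to become async.
        let file_service = create_test_file_service(&temp_path).await;
        let service = EnhancedOcrService::new(temp_path.clone(), file_service);
        // ...exercise `service` exactly as the individual tests above do...
    }
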
@@ -24,6 +24,12 @@ use readur::{
db_guardrails_simple::DocumentTransactionManager,
};

+ async fn create_test_file_service(temp_path: &str) -> FileService {
+ let storage_config = StorageConfig::Local { upload_path: temp_path.to_string() };
+ let storage_backend = create_storage_backend(storage_config).await.unwrap();
+ FileService::with_storage(temp_path.to_string(), storage_backend)
+ }
+
struct OCRPipelineTestHarness {
db: Database,
pool: PgPool,

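The helper above pins these tests to the local backend. Since this commit is about S3 support, an S3-flavored test helper would presumably hand a different `StorageConfig` variant to the same factory; the sketch below is illustrative only, and the `StorageConfig::S3` variant and its field names (`bucket`, `region`, `endpoint`) are assumptions, not taken from this diff:

    // Hypothetical sketch: StorageConfig::S3 and its fields are assumed, not shown in this commit.
    async fn create_s3_test_file_service(temp_path: &str, bucket: &str) -> FileService {
        let storage_config = StorageConfig::S3 {
            bucket: bucket.to_string(),
            region: "us-east-1".to_string(),
            endpoint: None, // e.g. a local MinIO endpoint for integration tests
        };
        let storage_backend = create_storage_backend(storage_config).await.unwrap();
        // The first argument mirrors the local helper; its role with an S3 backend is assumed here.
        FileService::with_storage(temp_path.to_string(), storage_backend)
    }
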
@@ -329,7 +335,8 @@ impl OCRPipelineTestHarness {
// Clone the components we need rather than the whole harness
let queue_service = self.queue_service.clone();
let transaction_manager = self.transaction_manager.clone();
- let ocr_service = EnhancedOcrService::new("/tmp".to_string());
+ let file_service = create_test_file_service("/tmp").await;
+ let ocr_service = EnhancedOcrService::new("/tmp".to_string(), file_service);
let pool = self.pool.clone();

let handle = tokio::spawn(async move {

@@ -2,6 +2,8 @@
mod pdf_word_count_integration_tests {
use readur::ocr::enhanced::EnhancedOcrService;
use readur::models::Settings;
+ use readur::services::file_service::FileService;
+ use readur::storage::{StorageConfig, factory::create_storage_backend};
use std::io::Write;
use tempfile::{NamedTempFile, TempDir};

@@ -13,6 +15,12 @@ mod pdf_word_count_integration_tests {
TempDir::new().expect("Failed to create temp directory")
}

+ async fn create_test_file_service(temp_path: &str) -> FileService {
+ let storage_config = StorageConfig::Local { upload_path: temp_path.to_string() };
+ let storage_backend = create_storage_backend(storage_config).await.unwrap();
+ FileService::with_storage(temp_path.to_string(), storage_backend)
+ }
+
/// Create a mock PDF with specific text patterns for testing
fn create_mock_pdf_file(content: &str) -> NamedTempFile {
let mut temp_file = NamedTempFile::new().expect("Failed to create temp file");

@@ -82,7 +90,8 @@ mod pdf_word_count_integration_tests {
async fn test_pdf_extraction_with_normal_text() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
- let service = EnhancedOcrService::new(temp_path);
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path.clone(), file_service);
let settings = create_test_settings();

// Create a PDF with normal spaced text

@@ -108,7 +117,8 @@ mod pdf_word_count_integration_tests {
async fn test_pdf_extraction_with_continuous_text() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
- let service = EnhancedOcrService::new(temp_path);
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path.clone(), file_service);
let settings = create_test_settings();

// Create a PDF with continuous text (no spaces)

@@ -136,7 +146,8 @@ mod pdf_word_count_integration_tests {
async fn test_pdf_extraction_with_mixed_content() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
- let service = EnhancedOcrService::new(temp_path);
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path.clone(), file_service);
let settings = create_test_settings();

// Create a PDF with mixed content (letters, numbers, punctuation)

@@ -159,7 +170,8 @@ mod pdf_word_count_integration_tests {
async fn test_pdf_extraction_empty_content() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
- let service = EnhancedOcrService::new(temp_path);
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path.clone(), file_service);
let settings = create_test_settings();

// Create a PDF with only whitespace/empty content

@@ -185,7 +197,8 @@ mod pdf_word_count_integration_tests {
async fn test_pdf_extraction_punctuation_only() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
- let service = EnhancedOcrService::new(temp_path);
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path.clone(), file_service);
let settings = create_test_settings();

// Create a PDF with only punctuation

@@ -212,7 +225,8 @@ mod pdf_word_count_integration_tests {
async fn test_pdf_quality_validation() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
- let service = EnhancedOcrService::new(temp_path);
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path.clone(), file_service);
let settings = create_test_settings();

// Use a real test PDF file if available

@@ -261,7 +275,8 @@ mod pdf_word_count_integration_tests {
async fn test_pdf_file_size_validation() {
let temp_dir = create_temp_dir();
let _temp_path = temp_dir.path().to_str().unwrap().to_string();
- let _service = EnhancedOcrService::new(_temp_path);
+ let _file_service = create_test_file_service(&_temp_path).await;
+ let _service = EnhancedOcrService::new(_temp_path.clone(), _file_service);
let _settings = create_test_settings();

// Create a small PDF file to test file operations

@@ -278,11 +293,12 @@ mod pdf_word_count_integration_tests {
assert!(metadata.len() < 100 * 1024 * 1024, "Test PDF should be under size limit");
}

- #[test]
- fn test_word_counting_regression_cases() {
+ #[tokio::test]
+ async fn test_word_counting_regression_cases() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
- let service = EnhancedOcrService::new(temp_path);
+ let file_service = create_test_file_service(&temp_path).await;
+ let service = EnhancedOcrService::new(temp_path.clone(), file_service);

// Regression test cases for the specific PDF issue
let test_cases = vec![

@@ -296,7 +296,7 @@ async fn test_user_watch_directory_file_processing_simulation() -> Result<()> {

// Create user watch manager to test file path mapping
let user_watch_service = state.user_watch_service.as_ref().unwrap();
- let user_watch_manager = readur::scheduling::user_watch_manager::UserWatchManager::new(state.db.clone(), (**user_watch_service).clone());
+ let user_watch_manager = readur::scheduling::user_watch_manager::UserWatchManager::new(state.db.clone(), Arc::clone(user_watch_service));

// Create test user
let test_user = readur::models::User {

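The one-line change in this hunk replaces a deep clone of the inner `UserWatchService` with a cheap `Arc` handle clone, so the manager shares the same service instance held in the application state instead of operating on an independent copy. A minimal illustration of the difference, using `String` as a stand-in for the service type:

    use std::sync::Arc;

    let shared = Arc::new(String::from("service"));
    // Clones only the Arc handle: the refcount goes up and both handles point at the same value.
    let same_instance = Arc::clone(&shared);
    // Clones the value inside the Arc: produces a new, independent String.
    let independent_copy = (*shared).clone();
    assert!(Arc::ptr_eq(&shared, &same_instance));
    assert_eq!(*shared, independent_copy);
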
@@ -16,6 +16,8 @@ use readur::{
db::Database,
ocr::queue::OcrQueueService,
ocr::enhanced::EnhancedOcrService,
+ services::file_service::FileService,
+ storage::{StorageConfig, factory::create_storage_backend},
};

// Use the same database URL as the running server

@@ -25,6 +27,12 @@ fn get_test_db_url() -> String {
.unwrap_or_else(|_| "postgresql://readur:readur@localhost:5432/readur".to_string())
}

+ async fn create_test_file_service(temp_path: &str) -> FileService {
+ let storage_config = StorageConfig::Local { upload_path: temp_path.to_string() };
+ let storage_backend = create_storage_backend(storage_config).await.unwrap();
+ FileService::with_storage(temp_path.to_string(), storage_backend)
+ }
+
struct SimpleThrottleTest {
pool: PgPool,
queue_service: Arc<OcrQueueService>,

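This is the third copy of the same `create_test_file_service` helper added in this commit (it also appears in the OCR pipeline harness and the PDF word-count tests above). A possible follow-up, sketched under the standard Cargo integration-test layout (`tests/common/mod.rs` is an assumption about this repository, not something the diff shows), would be to define it once and reuse it:

    // tests/common/mod.rs (hypothetical shared module)
    use readur::services::file_service::FileService;
    use readur::storage::{StorageConfig, factory::create_storage_backend};

    pub async fn create_test_file_service(temp_path: &str) -> FileService {
        let storage_config = StorageConfig::Local { upload_path: temp_path.to_string() };
        let storage_backend = create_storage_backend(storage_config).await.unwrap();
        FileService::with_storage(temp_path.to_string(), storage_backend)
    }

Each test file would then declare `mod common;` and call `common::create_test_file_service(...)` instead of redefining the helper.
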
@@ -140,7 +148,8 @@ impl SimpleThrottleTest {

let handle = tokio::spawn(async move {
let worker_name = format!("worker-{}", worker_id);
- let ocr_service = EnhancedOcrService::new("/tmp".to_string());
+ let file_service = create_test_file_service("/tmp").await;
+ let ocr_service = EnhancedOcrService::new("/tmp".to_string(), file_service);
let mut jobs_processed = 0;

info!("Worker {} starting", worker_name);