feat(storage): extend support for the S3 storage backend

perf3ct 2025-08-01 17:57:09 +00:00
parent 6624fc57fb
commit 862c36aa72
16 changed files with 1041 additions and 183 deletions

View File

@ -8,12 +8,15 @@
//! 3. Upload files to S3 with proper structure
//! 4. Update database records with S3 paths
//! 5. Optionally delete local files after successful upload
//! 6. Support rollback on failure with transaction-like behavior
use anyhow::Result;
use clap::Parser;
use std::path::Path;
use std::collections::{HashMap, VecDeque};
use uuid::Uuid;
use tracing::{info, warn, error};
use serde::{Serialize, Deserialize};
use readur::{
config::Config,
@ -21,6 +24,41 @@ use readur::{
services::{s3_service::S3Service, file_service::FileService},
};
/// Migration state tracking for rollback functionality
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MigrationState {
started_at: chrono::DateTime<chrono::Utc>,
completed_migrations: Vec<MigrationRecord>,
failed_migrations: Vec<FailedMigration>,
total_files: usize,
processed_files: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MigrationRecord {
document_id: Uuid,
user_id: Uuid,
original_path: String,
s3_key: String,
migrated_at: chrono::DateTime<chrono::Utc>,
associated_files: Vec<AssociatedFile>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct AssociatedFile {
file_type: String, // "thumbnail" or "processed_image"
original_path: String,
s3_key: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct FailedMigration {
document_id: Uuid,
original_path: String,
error: String,
failed_at: chrono::DateTime<chrono::Utc>,
}
#[derive(Parser)]
#[command(name = "migrate_to_s3")]
#[command(about = "Migrate existing local files to S3 storage")]
@ -40,6 +78,14 @@ struct Args {
/// Only migrate files for specific user ID
#[arg(short, long)]
user_id: Option<String>,
/// Enable rollback on failure: revert any successful migrations if the overall process fails
#[arg(long)]
enable_rollback: bool,
/// Resume from a specific document ID (for partial recovery)
#[arg(long)]
resume_from: Option<String>,
}
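The two new flags drive the rollback and resume behaviour implemented below. As a minimal sketch (assuming only clap and a pared-down mirror of the Args struct, not the full definition from this file), they parse like this:
// Hedged sketch, not part of the commit: a trimmed Args mirror showing how
// --enable-rollback and --resume-from are consumed by clap.
use clap::Parser;
#[derive(Parser)]
#[command(name = "migrate_to_s3")]
struct SketchArgs {
    /// Enable rollback on failure
    #[arg(long)]
    enable_rollback: bool,
    /// Resume from a specific document ID
    #[arg(long)]
    resume_from: Option<String>,
}
fn main() {
    let args = SketchArgs::parse_from([
        "migrate_to_s3",
        "--enable-rollback",
        "--resume-from",
        "00000000-0000-0000-0000-000000000000",
    ]);
    assert!(args.enable_rollback);
    assert_eq!(args.resume_from.as_deref(), Some("00000000-0000-0000-0000-000000000000"));
}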
#[tokio::main]
@ -132,41 +178,107 @@ async fn main() -> Result<()> {
return Ok(());
}
// Perform migration
// Initialize migration state
let mut migration_state = MigrationState {
started_at: chrono::Utc::now(),
completed_migrations: Vec::new(),
failed_migrations: Vec::new(),
total_files: local_documents.len(),
processed_files: 0,
};
// Resume from specific document if requested
let start_index = if let Some(resume_from_str) = &args.resume_from {
let resume_doc_id = Uuid::parse_str(resume_from_str)?;
local_documents.iter().position(|doc| doc.id == resume_doc_id)
.unwrap_or(0)
} else {
0
};
info!("📊 Migration plan: {} files to process (starting from index {})",
local_documents.len() - start_index, start_index);
// Perform migration with progress tracking
let mut migrated_count = 0;
let mut failed_count = 0;
for doc in local_documents {
info!("📦 Migrating: {} ({})", doc.original_filename, doc.id);
for (index, doc) in local_documents.iter().enumerate().skip(start_index) {
info!("📦 Migrating: {} ({}) [{}/{}]",
doc.original_filename, doc.id, index + 1, local_documents.len());
match migrate_document(&db, &s3_service, &file_service, &doc, args.delete_local).await {
Ok(_) => {
match migrate_document_with_tracking(&db, &s3_service, &file_service, doc, args.delete_local).await {
Ok(migration_record) => {
migrated_count += 1;
info!("✅ Successfully migrated: {}", doc.original_filename);
migration_state.completed_migrations.push(migration_record);
migration_state.processed_files += 1;
info!("✅ Successfully migrated: {} [{}/{}]",
doc.original_filename, migrated_count, local_documents.len());
// Save progress periodically (every 10 files)
if migrated_count % 10 == 0 {
save_migration_state(&migration_state).await?;
}
}
Err(e) => {
failed_count += 1;
let failed_migration = FailedMigration {
document_id: doc.id,
original_path: doc.file_path.clone(),
error: e.to_string(),
failed_at: chrono::Utc::now(),
};
migration_state.failed_migrations.push(failed_migration);
migration_state.processed_files += 1;
error!("❌ Failed to migrate {}: {}", doc.original_filename, e);
// If rollback is enabled and we have failures, offer to rollback
if args.enable_rollback && failed_count > 0 {
error!("💥 Migration failure detected with rollback enabled!");
error!("Do you want to rollback all {} successful migrations? (y/N)", migrated_count);
// For automation, we'll automatically rollback on any failure
// In interactive mode, you could read from stdin here
warn!("🔄 Automatically initiating rollback due to failure...");
match rollback_migrations(&db, &s3_service, &migration_state).await {
Ok(rolled_back) => {
error!("🔄 Successfully rolled back {} migrations", rolled_back);
return Err(anyhow::anyhow!("Migration failed and was rolled back"));
}
Err(rollback_err) => {
error!("💥 CRITICAL: Rollback failed: {}", rollback_err);
error!("💾 Migration state saved for manual recovery");
save_migration_state(&migration_state).await?;
return Err(anyhow::anyhow!(
"Migration failed and rollback also failed. Check migration state file for manual recovery."
));
}
}
}
}
}
}
// Save final migration state
save_migration_state(&migration_state).await?;
info!("🎉 Migration completed!");
info!("✅ Successfully migrated: {} files", migrated_count);
if failed_count > 0 {
warn!("❌ Failed to migrate: {} files", failed_count);
warn!("💾 Check migration_state.json for details on failures");
}
Ok(())
}
async fn migrate_document(
async fn migrate_document_with_tracking(
db: &Database,
s3_service: &S3Service,
file_service: &FileService,
document: &readur::models::Document,
delete_local: bool,
) -> Result<()> {
) -> Result<MigrationRecord> {
// Read local file
let local_path = Path::new(&document.file_path);
if !local_path.exists() {
@ -189,7 +301,7 @@ async fn migrate_document(
db.update_document_file_path(document.id, &s3_path).await?;
// Migrate associated files (thumbnails, processed images)
migrate_associated_files(s3_service, file_service, document, delete_local).await?;
let associated_files = migrate_associated_files_with_tracking(s3_service, file_service, document, delete_local).await?;
// Delete local file if requested
if delete_local {
@ -200,27 +312,46 @@ async fn migrate_document(
}
}
Ok(())
// Create migration record for tracking
let migration_record = MigrationRecord {
document_id: document.id,
user_id: document.user_id,
original_path: document.file_path.clone(),
s3_key,
migrated_at: chrono::Utc::now(),
associated_files,
};
Ok(migration_record)
}
async fn migrate_associated_files(
async fn migrate_associated_files_with_tracking(
s3_service: &S3Service,
file_service: &FileService,
document: &readur::models::Document,
delete_local: bool,
) -> Result<()> {
) -> Result<Vec<AssociatedFile>> {
let mut associated_files = Vec::new();
// Migrate thumbnail
let thumbnail_path = file_service.get_thumbnails_path().join(format!("{}_thumb.jpg", document.id));
if thumbnail_path.exists() {
match tokio::fs::read(&thumbnail_path).await {
Ok(thumbnail_data) => {
if let Err(e) = s3_service.store_thumbnail(document.user_id, document.id, &thumbnail_data).await {
warn!("Failed to migrate thumbnail for {}: {}", document.id, e);
} else {
info!("📸 Migrated thumbnail for: {}", document.original_filename);
if delete_local {
let _ = tokio::fs::remove_file(&thumbnail_path).await;
match s3_service.store_thumbnail(document.user_id, document.id, &thumbnail_data).await {
Ok(s3_key) => {
info!("📸 Migrated thumbnail for: {}", document.original_filename);
associated_files.push(AssociatedFile {
file_type: "thumbnail".to_string(),
original_path: thumbnail_path.to_string_lossy().to_string(),
s3_key,
});
if delete_local {
let _ = tokio::fs::remove_file(&thumbnail_path).await;
}
}
Err(e) => {
warn!("Failed to migrate thumbnail for {}: {}", document.id, e);
}
}
}
@ -233,12 +364,20 @@ async fn migrate_associated_files(
if processed_path.exists() {
match tokio::fs::read(&processed_path).await {
Ok(processed_data) => {
if let Err(e) = s3_service.store_processed_image(document.user_id, document.id, &processed_data).await {
warn!("Failed to migrate processed image for {}: {}", document.id, e);
} else {
info!("🖼️ Migrated processed image for: {}", document.original_filename);
if delete_local {
let _ = tokio::fs::remove_file(&processed_path).await;
match s3_service.store_processed_image(document.user_id, document.id, &processed_data).await {
Ok(s3_key) => {
info!("🖼️ Migrated processed image for: {}", document.original_filename);
associated_files.push(AssociatedFile {
file_type: "processed_image".to_string(),
original_path: processed_path.to_string_lossy().to_string(),
s3_key,
});
if delete_local {
let _ = tokio::fs::remove_file(&processed_path).await;
}
}
Err(e) => {
warn!("Failed to migrate processed image for {}: {}", document.id, e);
}
}
}
@ -246,5 +385,94 @@ async fn migrate_associated_files(
}
}
Ok(associated_files)
}
/// Save migration state to disk for recovery purposes
async fn save_migration_state(state: &MigrationState) -> Result<()> {
let state_json = serde_json::to_string_pretty(state)?;
tokio::fs::write("migration_state.json", state_json).await?;
info!("💾 Migration state saved to migration_state.json");
Ok(())
}
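save_migration_state only writes the file; reading it back (for example to find the last processed document and feed it to --resume-from after a crash) is not part of this commit. A minimal, hypothetical counterpart could look like this, with load_migration_state an assumed helper name:
// Hypothetical helper, not included in the commit: deserialize the JSON
// written by save_migration_state() above.
async fn load_migration_state(path: &str) -> Result<MigrationState> {
    let raw = tokio::fs::read_to_string(path).await?;
    Ok(serde_json::from_str(&raw)?)
}
// Usage sketch: let state = load_migration_state("migration_state.json").await?;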
/// Rollback migrations by restoring database paths and deleting S3 objects
async fn rollback_migrations(
db: &Database,
s3_service: &S3Service,
state: &MigrationState,
) -> Result<usize> {
info!("🔄 Starting rollback of {} migrations...", state.completed_migrations.len());
let mut rolled_back = 0;
let mut rollback_errors = Vec::new();
// Process migrations in reverse order (most recent first)
for migration in state.completed_migrations.iter().rev() {
info!("🔄 Rolling back migration for document {}", migration.document_id);
// Restore original database path
match db.update_document_file_path(migration.document_id, &migration.original_path).await {
Ok(_) => {
info!("✅ Restored database path for document {}", migration.document_id);
}
Err(e) => {
let error_msg = format!("Failed to restore DB path for {}: {}", migration.document_id, e);
error!("❌ {}", error_msg);
rollback_errors.push(error_msg);
continue; // Skip S3 cleanup if DB restore failed
}
}
// Delete S3 object (main document)
match s3_service.delete_file(&migration.s3_key).await {
Ok(_) => {
info!("🗑️ Deleted S3 object: {}", migration.s3_key);
}
Err(e) => {
let error_msg = format!("Failed to delete S3 object {}: {}", migration.s3_key, e);
warn!("⚠️ {}", error_msg);
rollback_errors.push(error_msg);
// Continue with associated files even if main file deletion failed
}
}
// Delete associated S3 objects (thumbnails, processed images)
for associated in &migration.associated_files {
match s3_service.delete_file(&associated.s3_key).await {
Ok(_) => {
info!("🗑️ Deleted associated S3 object: {} ({})", associated.s3_key, associated.file_type);
}
Err(e) => {
let error_msg = format!("Failed to delete associated S3 object {}: {}", associated.s3_key, e);
warn!("⚠️ {}", error_msg);
rollback_errors.push(error_msg);
}
}
}
rolled_back += 1;
}
if !rollback_errors.is_empty() {
warn!("⚠️ Rollback completed with {} errors:", rollback_errors.len());
for error in &rollback_errors {
warn!(" - {}", error);
}
// Save error details for manual cleanup
let error_state = serde_json::json!({
"rollback_completed_at": chrono::Utc::now(),
"rolled_back_count": rolled_back,
"rollback_errors": rollback_errors,
"original_migration_state": state
});
let error_json = serde_json::to_string_pretty(&error_state)?;
tokio::fs::write("rollback_errors.json", error_json).await?;
warn!("💾 Rollback errors saved to rollback_errors.json for manual cleanup");
}
info!("✅ Rollback completed: {} migrations processed", rolled_back);
Ok(rolled_back)
}

View File

@ -54,35 +54,7 @@ impl Config {
// Database Configuration
let database_url = match env::var("DATABASE_URL") {
Ok(val) => {
// Mask sensitive parts of database URL for logging
let masked_url = if val.contains('@') {
let parts: Vec<&str> = val.split('@').collect();
if parts.len() >= 2 {
let credentials_part = parts[0];
let remaining_part = parts[1..].join("@");
// Extract just the username part before the password
if let Some(username_start) = credentials_part.rfind("://") {
let protocol = &credentials_part[..username_start + 3];
let credentials = &credentials_part[username_start + 3..];
if let Some(colon_pos) = credentials.find(':') {
let username = &credentials[..colon_pos];
// Show first and last character of the username
let masked_username = format!("{}{}", &username[..1], &username[username.len() - 1..]);
format!("{}{}:***@{}", protocol, masked_username, remaining_part)
} else {
format!("{}***@{}", protocol, remaining_part)
}
} else {
"***masked***".to_string()
}
} else {
"***masked***".to_string()
}
} else {
val.clone()
};
println!("✅ DATABASE_URL: {} (loaded from env)", masked_url);
println!("✅ DATABASE_URL: configured (loaded from env)");
val
}
Err(_) => {
@ -462,10 +434,8 @@ impl Config {
if !bucket_name.is_empty() && !access_key_id.is_empty() && !secret_access_key.is_empty() {
println!("✅ S3_BUCKET_NAME: {} (loaded from env)", bucket_name);
println!("✅ S3_REGION: {} (loaded from env)", region);
println!("✅ S3_ACCESS_KEY_ID: {}***{} (loaded from env)",
&access_key_id[..2.min(access_key_id.len())],
&access_key_id[access_key_id.len().saturating_sub(2)..]);
println!("✅ S3_SECRET_ACCESS_KEY: ***hidden*** (loaded from env, {} chars)", secret_access_key.len());
println!("✅ S3_ACCESS_KEY_ID: configured (loaded from env)");
println!("✅ S3_SECRET_ACCESS_KEY: configured (loaded from env)");
if let Some(ref endpoint) = endpoint_url {
println!("✅ S3_ENDPOINT_URL: {} (loaded from env)", endpoint);
}

View File

@ -121,15 +121,35 @@ async fn main() -> anyhow::Result<()> {
println!("📁 Upload directory: {}", config.upload_path);
println!("👁️ Watch directory: {}", config.watch_folder);
// Initialize file service using the new storage backend architecture
// Initialize file service using the new storage backend architecture with fallback
info!("Initializing file service with storage backend...");
let storage_config = readur::storage::factory::storage_config_from_env(&config)?;
let file_service = readur::services::file_service::FileService::from_config(storage_config, config.upload_path.clone()).await
.map_err(|e| {
error!("Failed to initialize file service with configured storage backend: {}", e);
e
})?;
info!("✅ File service initialized with {} storage backend", file_service.storage_type());
let file_service = match readur::services::file_service::FileService::from_config(storage_config, config.upload_path.clone()).await {
Ok(service) => {
info!("✅ File service initialized with {} storage backend", service.storage_type());
service
}
Err(e) => {
error!("❌ Failed to initialize configured storage backend: {}", e);
warn!("🔄 Falling back to local storage...");
// Create fallback local storage configuration
let fallback_config = readur::storage::StorageConfig::Local {
upload_path: config.upload_path.clone(),
};
match readur::services::file_service::FileService::from_config(fallback_config, config.upload_path.clone()).await {
Ok(fallback_service) => {
warn!("✅ Successfully initialized fallback local storage");
fallback_service
}
Err(fallback_err) => {
error!("💥 CRITICAL: Even fallback local storage failed to initialize: {}", fallback_err);
return Err(fallback_err.into());
}
}
}
};
// Initialize the storage backend (creates directories, validates access, etc.)
if let Err(e) = file_service.initialize_storage().await {
@ -159,7 +179,6 @@ async fn main() -> anyhow::Result<()> {
}
Err(e) => {
println!("❌ CRITICAL: Failed to connect to database for web operations!");
println!("Database URL: {}", db_info); // Use the already-masked URL
println!("Error: {}", e);
println!("\n🔧 Please verify:");
println!(" - Database server is running");

View File

@ -45,14 +45,6 @@ impl EnhancedOcrService {
Self { temp_dir, file_service }
}
/// Backward-compatible constructor for tests and legacy code
/// Creates a FileService with local storage using UPLOAD_PATH env var
#[deprecated(note = "Use new() with FileService parameter instead")]
pub fn new_legacy(temp_dir: String) -> Self {
let upload_path = std::env::var("UPLOAD_PATH").unwrap_or_else(|_| "./uploads".to_string());
let file_service = FileService::new(upload_path);
Self::new(temp_dir, file_service)
}
/// Extract text from image with high-quality OCR settings
#[cfg(feature = "ocr")]

View File

@ -75,14 +75,6 @@ impl OcrQueueService {
}
}
/// Backward-compatible constructor for tests and legacy code
/// Creates a FileService with local storage using UPLOAD_PATH env var
#[deprecated(note = "Use new() with FileService parameter instead")]
pub fn new_legacy(db: Database, pool: PgPool, max_concurrent_jobs: usize) -> Self {
let upload_path = std::env::var("UPLOAD_PATH").unwrap_or_else(|_| "./uploads".to_string());
let file_service = std::sync::Arc::new(crate::services::file_service::FileService::new(upload_path));
Self::new(db, pool, max_concurrent_jobs, file_service)
}
/// Add a document to the OCR queue
pub async fn enqueue_document(&self, document_id: Uuid, priority: i32, file_size: i64) -> Result<Uuid> {

View File

@ -202,10 +202,10 @@ pub async fn upload_document(
}
// Create ingestion service
let file_service = &state.file_service;
let file_service_clone = state.file_service.as_ref().clone();
let ingestion_service = DocumentIngestionService::new(
state.db.clone(),
(**file_service).clone(),
file_service_clone,
);
debug!("[UPLOAD_DEBUG] Calling ingestion service for file: {}", filename);

View File

@ -315,8 +315,8 @@ async fn process_single_file(
debug!("Downloaded file: {} ({} bytes)", file_info.name, file_data.len());
// Use the unified ingestion service for consistent deduplication
let file_service = &state.file_service;
let ingestion_service = DocumentIngestionService::new(state.db.clone(), (**file_service).clone());
let file_service_clone = state.file_service.as_ref().clone();
let ingestion_service = DocumentIngestionService::new(state.db.clone(), file_service_clone);
let result = if let Some(source_id) = webdav_source_id {
ingestion_service

View File

@ -6,6 +6,8 @@ use serde_json;
use std::collections::HashMap;
use std::time::Duration;
use uuid::Uuid;
use futures::stream::StreamExt;
use tokio::io::{AsyncRead, AsyncReadExt};
#[cfg(feature = "s3")]
use aws_sdk_s3::Client;
@ -15,10 +17,18 @@ use aws_credential_types::Credentials;
use aws_types::region::Region as AwsRegion;
#[cfg(feature = "s3")]
use aws_sdk_s3::primitives::ByteStream;
#[cfg(feature = "s3")]
use aws_sdk_s3::types::{CompletedPart, CompletedMultipartUpload};
use crate::models::{FileIngestionInfo, S3SourceConfig};
use crate::storage::StorageBackend;
/// Threshold for using streaming multipart uploads (100MB)
const STREAMING_THRESHOLD: usize = 100 * 1024 * 1024;
/// Multipart upload chunk size (16MB - AWS minimum is 5MB, we use 16MB for better performance)
const MULTIPART_CHUNK_SIZE: usize = 16 * 1024 * 1024;
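For context, a small sketch (not from the commit) of the part arithmetic these constants imply: a payload just above the threshold splits into ceil(len / 16 MiB) parts, each comfortably above S3's 5 MB minimum for non-final parts.
// Illustrative only: a 250 MiB payload exceeds STREAMING_THRESHOLD and is
// split into 16 parts (15 full 16 MiB chunks plus a 10 MiB tail), using the
// same ceiling division as store_file_multipart() below.
#[cfg(test)]
mod multipart_math_sketch {
    use super::{MULTIPART_CHUNK_SIZE, STREAMING_THRESHOLD};
    #[test]
    fn part_count_for_250_mib() {
        let len = 250 * 1024 * 1024;
        assert!(len > STREAMING_THRESHOLD);
        let parts = (len + MULTIPART_CHUNK_SIZE - 1) / MULTIPART_CHUNK_SIZE;
        assert_eq!(parts, 16);
    }
}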
#[derive(Debug, Clone)]
pub struct S3Service {
#[cfg(feature = "s3")]
@ -347,7 +357,15 @@ impl S3Service {
#[cfg(feature = "s3")]
{
let key = self.generate_document_key(user_id, document_id, filename);
self.store_file(&key, data, None).await?;
// Use streaming upload for large files
if data.len() > STREAMING_THRESHOLD {
info!("Using streaming multipart upload for large file: {} ({} bytes)", key, data.len());
self.store_file_multipart(&key, data, None).await?;
} else {
self.store_file(&key, data, None).await?;
}
Ok(key)
}
}
@ -438,6 +456,137 @@ impl S3Service {
}
}
/// Store large files using multipart upload for better performance and memory usage
async fn store_file_multipart(&self, key: &str, data: &[u8], metadata: Option<HashMap<String, String>>) -> Result<()> {
#[cfg(not(feature = "s3"))]
{
return Err(anyhow!("S3 support not compiled in"));
}
#[cfg(feature = "s3")]
{
info!("Starting multipart upload for file: {}/{} ({} bytes)", self.config.bucket_name, key, data.len());
let key_owned = key.to_string();
let data_owned = data.to_vec();
let metadata_owned = metadata.clone();
let bucket_name = self.config.bucket_name.clone();
let client = self.client.clone();
self.retry_operation(&format!("store_file_multipart: {}", key), || {
let key = key_owned.clone();
let data = data_owned.clone();
let metadata = metadata_owned.clone();
let bucket_name = bucket_name.clone();
let client = client.clone();
let content_type = self.get_content_type_from_key(&key);
async move {
// Step 1: Initiate multipart upload
let mut create_request = client
.create_multipart_upload()
.bucket(&bucket_name)
.key(&key);
// Add metadata if provided
if let Some(meta) = metadata {
for (k, v) in meta {
create_request = create_request.metadata(k, v);
}
}
// Set content type based on file extension
if let Some(ct) = content_type {
create_request = create_request.content_type(ct);
}
let create_response = create_request.send().await
.map_err(|e| anyhow!("Failed to initiate multipart upload for {}: {}", key, e))?;
let upload_id = create_response.upload_id()
.ok_or_else(|| anyhow!("Missing upload ID in multipart upload response"))?;
info!("Initiated multipart upload for {}: {}", key, upload_id);
// Step 2: Upload parts in chunks
let mut completed_parts = Vec::new();
let total_chunks = (data.len() + MULTIPART_CHUNK_SIZE - 1) / MULTIPART_CHUNK_SIZE;
for (chunk_index, chunk) in data.chunks(MULTIPART_CHUNK_SIZE).enumerate() {
let part_number = (chunk_index + 1) as i32;
debug!("Uploading part {} of {} for {} ({} bytes)",
part_number, total_chunks, key, chunk.len());
let upload_part_response = client
.upload_part()
.bucket(&bucket_name)
.key(&key)
.upload_id(upload_id)
.part_number(part_number)
.body(ByteStream::from(chunk.to_vec()))
.send()
.await
.map_err(|e| anyhow!("Failed to upload part {} for {}: {}", part_number, key, e))?;
let etag = upload_part_response.e_tag()
.ok_or_else(|| anyhow!("Missing ETag in upload part response"))?;
completed_parts.push(
CompletedPart::builder()
.part_number(part_number)
.e_tag(etag)
.build()
);
debug!("Successfully uploaded part {} for {}", part_number, key);
}
// Step 3: Complete multipart upload
let completed_multipart_upload = CompletedMultipartUpload::builder()
.set_parts(Some(completed_parts))
.build();
client
.complete_multipart_upload()
.bucket(&bucket_name)
.key(&key)
.upload_id(upload_id)
.multipart_upload(completed_multipart_upload)
.send()
.await
.map_err(|e| {
// If completion fails, try to abort the multipart upload
let abort_client = client.clone();
let abort_bucket = bucket_name.clone();
let abort_key = key.clone();
let abort_upload_id = upload_id.to_string();
tokio::spawn(async move {
if let Err(abort_err) = abort_client
.abort_multipart_upload()
.bucket(abort_bucket)
.key(abort_key)
.upload_id(abort_upload_id)
.send()
.await
{
error!("Failed to abort multipart upload: {}", abort_err);
}
});
anyhow!("Failed to complete multipart upload for {}: {}", key, e)
})?;
info!("Successfully completed multipart upload for {}", key);
Ok(())
}
}).await?;
Ok(())
}
}
/// Retrieve a file from S3
pub async fn retrieve_file(&self, key: &str) -> Result<Vec<u8>> {
#[cfg(not(feature = "s3"))]
@ -666,7 +815,15 @@ impl StorageBackend for S3Service {
async fn store_document(&self, user_id: Uuid, document_id: Uuid, filename: &str, data: &[u8]) -> Result<String> {
// Generate S3 key
let key = self.generate_document_key(user_id, document_id, filename);
self.store_file(&key, data, None).await?;
// Use streaming upload for large files
if data.len() > STREAMING_THRESHOLD {
info!("Using streaming multipart upload for large file: {} ({} bytes)", key, data.len());
self.store_file_multipart(&key, data, None).await?;
} else {
self.store_file(&key, data, None).await?;
}
Ok(format!("s3://{}", key))
}

View File

@ -3,21 +3,30 @@
use anyhow::Result;
use async_trait::async_trait;
use std::path::{Path, PathBuf};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::fs;
use tracing::{info, error};
use tokio::sync::RwLock;
use tracing::{info, error, warn, debug};
use uuid::Uuid;
use super::StorageBackend;
use crate::utils::security::{validate_filename, validate_and_sanitize_path, validate_path_within_base};
/// Local filesystem storage backend
pub struct LocalStorageBackend {
upload_path: String,
/// Cache for resolved file paths to reduce filesystem calls
path_cache: Arc<RwLock<HashMap<String, Option<String>>>>,
}
impl LocalStorageBackend {
/// Create a new local storage backend
pub fn new(upload_path: String) -> Self {
Self { upload_path }
Self {
upload_path,
path_cache: Arc::new(RwLock::new(HashMap::new())),
}
}
/// Get the base upload path
@ -50,33 +59,131 @@ impl LocalStorageBackend {
Path::new(&self.upload_path).join("backups")
}
/// Resolve file path, handling both old and new directory structures
/// Resolve file path, handling both old and new directory structures with caching
pub async fn resolve_file_path(&self, file_path: &str) -> Result<String> {
// If the file exists at the given path, use it
if Path::new(file_path).exists() {
return Ok(file_path.to_string());
// Check cache first
{
let cache = self.path_cache.read().await;
if let Some(cached_result) = cache.get(file_path) {
return match cached_result {
Some(resolved_path) => {
debug!("Cache hit for file path: {} -> {}", file_path, resolved_path);
Ok(resolved_path.clone())
}
None => {
debug!("Cache hit for non-existent file: {}", file_path);
Err(anyhow::anyhow!("File not found: {} (cached)", file_path))
}
};
}
}
// Generate candidate paths in order of likelihood
let candidates = self.generate_path_candidates(file_path);
// Try to find the file in the new structured directory
// Check candidates efficiently
let mut found_path = None;
for candidate in &candidates {
match tokio::fs::metadata(candidate).await {
Ok(metadata) => {
if metadata.is_file() {
found_path = Some(candidate.clone());
debug!("Found file at: {}", candidate);
break;
}
}
Err(_) => {
// File doesn't exist at this path, continue to next candidate
debug!("File not found at: {}", candidate);
}
}
}
// Cache the result
{
let mut cache = self.path_cache.write().await;
cache.insert(file_path.to_string(), found_path.clone());
// Prevent cache from growing too large
if cache.len() > 10000 {
// Evict an arbitrary ~20% of entries (HashMap iteration order is unspecified)
let to_remove: Vec<String> = cache.keys().take(2000).cloned().collect();
for key in to_remove {
cache.remove(&key);
}
debug!("Evicted cache entries to prevent memory growth");
}
}
match found_path {
Some(path) => {
if path != file_path {
info!("Resolved file path: {} -> {}", file_path, path);
}
Ok(path)
}
None => {
debug!("File not found in any candidate location: {}", file_path);
Err(anyhow::anyhow!(
"File not found: {} (checked {} locations)",
file_path,
candidates.len()
))
}
}
}
/// Generate candidate paths for file resolution
fn generate_path_candidates(&self, file_path: &str) -> Vec<String> {
let mut candidates = Vec::new();
// 1. Original path (most likely for new files)
candidates.push(file_path.to_string());
// 2. For legacy compatibility - try structured directory
if file_path.starts_with("./uploads/") && !file_path.contains("/documents/") {
let new_path = file_path.replace("./uploads/", "./uploads/documents/");
if Path::new(&new_path).exists() {
info!("Found file in new structured directory: {} -> {}", file_path, new_path);
return Ok(new_path);
}
candidates.push(file_path.replace("./uploads/", "./uploads/documents/"));
}
// Try without the ./ prefix
// 3. Try without ./ prefix in structured directory
if file_path.starts_with("uploads/") && !file_path.contains("/documents/") {
let new_path = file_path.replace("uploads/", "uploads/documents/");
if Path::new(&new_path).exists() {
info!("Found file in new structured directory: {} -> {}", file_path, new_path);
return Ok(new_path);
}
candidates.push(file_path.replace("uploads/", "uploads/documents/"));
}
// File not found in any expected location
Err(anyhow::anyhow!("File not found: {} (checked original path and structured directory)", file_path))
// 4. Try relative to our configured upload path
if !file_path.starts_with(&self.upload_path) {
let relative_path = Path::new(&self.upload_path).join(file_path);
candidates.push(relative_path.to_string_lossy().to_string());
// Also try in documents subdirectory
let documents_path = Path::new(&self.upload_path).join("documents").join(file_path);
candidates.push(documents_path.to_string_lossy().to_string());
}
// 5. Try absolute path if it looks like a filename only
if !file_path.contains('/') && !file_path.contains('\\') {
// Try in documents directory
let abs_documents_path = Path::new(&self.upload_path)
.join("documents")
.join(file_path);
candidates.push(abs_documents_path.to_string_lossy().to_string());
}
candidates
}
/// Clear the path resolution cache (useful for testing or after file operations)
pub async fn clear_path_cache(&self) {
let mut cache = self.path_cache.write().await;
cache.clear();
debug!("Cleared path resolution cache");
}
/// Invalidate cache entry for a specific path
pub async fn invalidate_cache_entry(&self, file_path: &str) {
let mut cache = self.path_cache.write().await;
cache.remove(file_path);
debug!("Invalidated cache entry for: {}", file_path);
}
/// Save a file with generated UUID filename (legacy method)
@ -112,7 +219,10 @@ impl LocalStorageBackend {
#[async_trait]
impl StorageBackend for LocalStorageBackend {
async fn store_document(&self, _user_id: Uuid, document_id: Uuid, filename: &str, data: &[u8]) -> Result<String> {
let extension = Path::new(filename)
// Validate and sanitize the filename
let sanitized_filename = validate_filename(filename)?;
let extension = Path::new(&sanitized_filename)
.extension()
.and_then(|ext| ext.to_str())
.unwrap_or("");
@ -126,13 +236,28 @@ impl StorageBackend for LocalStorageBackend {
let documents_dir = self.get_documents_path();
let file_path = documents_dir.join(&document_filename);
// Validate that the final path is within our base directory
validate_path_within_base(
&file_path.to_string_lossy(),
&self.upload_path
)?;
// Ensure the documents directory exists
fs::create_dir_all(&documents_dir).await?;
// Validate data size (prevent extremely large files from causing issues)
if data.len() > 1_000_000_000 { // 1GB limit
return Err(anyhow::anyhow!("File too large for storage (max 1GB)"));
}
fs::write(&file_path, data).await?;
// Invalidate any cached negative results for this path
let path_str = file_path.to_string_lossy().to_string();
self.invalidate_cache_entry(&path_str).await;
info!("Stored document locally: {}", file_path.display());
Ok(file_path.to_string_lossy().to_string())
Ok(path_str)
}
async fn store_thumbnail(&self, _user_id: Uuid, document_id: Uuid, data: &[u8]) -> Result<String> {
@ -142,10 +267,25 @@ impl StorageBackend for LocalStorageBackend {
let thumbnail_filename = format!("{}_thumb.jpg", document_id);
let thumbnail_path = thumbnails_dir.join(&thumbnail_filename);
// Validate that the final path is within our base directory
validate_path_within_base(
&thumbnail_path.to_string_lossy(),
&self.upload_path
)?;
// Validate data size for thumbnails (should be much smaller)
if data.len() > 10_000_000 { // 10MB limit for thumbnails
return Err(anyhow::anyhow!("Thumbnail too large (max 10MB)"));
}
fs::write(&thumbnail_path, data).await?;
// Invalidate any cached negative results for this path
let path_str = thumbnail_path.to_string_lossy().to_string();
self.invalidate_cache_entry(&path_str).await;
info!("Stored thumbnail locally: {}", thumbnail_path.display());
Ok(thumbnail_path.to_string_lossy().to_string())
Ok(path_str)
}
async fn store_processed_image(&self, _user_id: Uuid, document_id: Uuid, data: &[u8]) -> Result<String> {
@ -155,15 +295,44 @@ impl StorageBackend for LocalStorageBackend {
let processed_filename = format!("{}_processed.png", document_id);
let processed_path = processed_dir.join(&processed_filename);
// Validate that the final path is within our base directory
validate_path_within_base(
&processed_path.to_string_lossy(),
&self.upload_path
)?;
// Validate data size for processed images
if data.len() > 50_000_000 { // 50MB limit for processed images
return Err(anyhow::anyhow!("Processed image too large (max 50MB)"));
}
fs::write(&processed_path, data).await?;
// Invalidate any cached negative results for this path
let path_str = processed_path.to_string_lossy().to_string();
self.invalidate_cache_entry(&path_str).await;
info!("Stored processed image locally: {}", processed_path.display());
Ok(processed_path.to_string_lossy().to_string())
Ok(path_str)
}
async fn retrieve_file(&self, path: &str) -> Result<Vec<u8>> {
let resolved_path = self.resolve_file_path(path).await?;
// Validate and sanitize the input path
let sanitized_path = validate_and_sanitize_path(path)?;
let resolved_path = self.resolve_file_path(&sanitized_path).await?;
// Validate that the resolved path is within our base directory
validate_path_within_base(&resolved_path, &self.upload_path)?;
let data = fs::read(&resolved_path).await?;
// Additional safety check on file size when reading
if data.len() > 1_000_000_000 { // 1GB limit
warn!("Attempted to read extremely large file: {} ({} bytes)", resolved_path, data.len());
return Err(anyhow::anyhow!("File too large to read safely"));
}
Ok(data)
}
@ -172,16 +341,25 @@ impl StorageBackend for LocalStorageBackend {
let mut serious_errors = Vec::new();
// Helper function to safely delete a file
async fn safe_delete(path: &Path, serious_errors: &mut Vec<String>) -> Option<String> {
let storage_backend = self;
async fn safe_delete(path: &Path, serious_errors: &mut Vec<String>, backend: &LocalStorageBackend) -> Option<String> {
match fs::remove_file(path).await {
Ok(_) => {
info!("Deleted file: {}", path.display());
Some(path.to_string_lossy().to_string())
let path_str = path.to_string_lossy().to_string();
// Invalidate cache entry for the deleted file
backend.invalidate_cache_entry(&path_str).await;
Some(path_str)
}
Err(e) => {
match e.kind() {
std::io::ErrorKind::NotFound => {
info!("File already deleted: {}", path.display());
// Still invalidate cache in case it was cached as existing
let path_str = path.to_string_lossy().to_string();
backend.invalidate_cache_entry(&path_str).await;
None
}
_ => {
@ -223,7 +401,7 @@ impl StorageBackend for LocalStorageBackend {
let mut main_file_deleted = false;
for candidate_path in &main_file_candidates {
if candidate_path.exists() {
if let Some(deleted_path) = safe_delete(candidate_path, &mut serious_errors).await {
if let Some(deleted_path) = safe_delete(candidate_path, &mut serious_errors, storage_backend).await {
deleted_files.push(deleted_path);
main_file_deleted = true;
break; // Only delete the first match we find
@ -238,14 +416,14 @@ impl StorageBackend for LocalStorageBackend {
// Delete thumbnail if it exists
let thumbnail_filename = format!("{}_thumb.jpg", document_id);
let thumbnail_path = self.get_thumbnails_path().join(&thumbnail_filename);
if let Some(deleted_path) = safe_delete(&thumbnail_path, &mut serious_errors).await {
if let Some(deleted_path) = safe_delete(&thumbnail_path, &mut serious_errors, storage_backend).await {
deleted_files.push(deleted_path);
}
// Delete processed image if it exists
let processed_image_filename = format!("{}_processed.png", document_id);
let processed_image_path = self.get_processed_images_path().join(&processed_image_filename);
if let Some(deleted_path) = safe_delete(&processed_image_path, &mut serious_errors).await {
if let Some(deleted_path) = safe_delete(&processed_image_path, &mut serious_errors, storage_backend).await {
deleted_files.push(deleted_path);
}
@ -265,8 +443,20 @@ impl StorageBackend for LocalStorageBackend {
}
async fn file_exists(&self, path: &str) -> Result<bool> {
match self.resolve_file_path(path).await {
Ok(_) => Ok(true),
// Validate and sanitize the input path
let sanitized_path = match validate_and_sanitize_path(path) {
Ok(p) => p,
Err(_) => return Ok(false), // Invalid paths don't exist
};
match self.resolve_file_path(&sanitized_path).await {
Ok(resolved_path) => {
// Additional validation that the resolved path is within base directory
match validate_path_within_base(&resolved_path, &self.upload_path) {
Ok(_) => Ok(true),
Err(_) => Ok(false), // Paths outside base directory don't "exist" for us
}
}
Err(_) => Ok(false),
}
}

View File

@ -1 +1,2 @@
pub mod debug;
pub mod debug;
pub mod security;

src/utils/security.rs (new file, 231 lines)
View File

@ -0,0 +1,231 @@
//! Security utilities for input validation and sanitization
use anyhow::Result;
use std::path::{Path, PathBuf, Component};
use tracing::warn;
/// Validate and sanitize file paths to prevent path traversal attacks
pub fn validate_and_sanitize_path(input_path: &str) -> Result<String> {
// Check for null bytes (not allowed in file paths)
if input_path.contains('\0') {
return Err(anyhow::anyhow!("Path contains null bytes"));
}
// Check for excessively long paths
if input_path.len() > 4096 {
return Err(anyhow::anyhow!("Path too long (max 4096 characters)"));
}
// Convert to Path for normalization
let path = Path::new(input_path);
// Check for path traversal attempts
for component in path.components() {
match component {
Component::ParentDir => {
warn!("Path traversal attempt detected: {}", input_path);
return Err(anyhow::anyhow!("Path traversal not allowed"));
}
Component::Normal(name) => {
let name_str = name.to_string_lossy();
// Check for dangerous file names
if is_dangerous_filename(&name_str) {
return Err(anyhow::anyhow!("Potentially dangerous filename: {}", name_str));
}
// Check for control characters (except newline and tab which might be in file content)
for ch in name_str.chars() {
if ch.is_control() && ch != '\n' && ch != '\t' {
return Err(anyhow::anyhow!("Filename contains control characters"));
}
}
}
_ => {} // Allow root, current dir, and prefix components
}
}
// Normalize the path to remove redundant components
let normalized = normalize_path(path);
Ok(normalized.to_string_lossy().to_string())
}
/// Validate filename for document storage
pub fn validate_filename(filename: &str) -> Result<String> {
// Basic length check
if filename.is_empty() {
return Err(anyhow::anyhow!("Filename cannot be empty"));
}
if filename.len() > 255 {
return Err(anyhow::anyhow!("Filename too long (max 255 characters)"));
}
// Check for null bytes
if filename.contains('\0') {
return Err(anyhow::anyhow!("Filename contains null bytes"));
}
// Check for path separators (filenames should not contain them)
if filename.contains('/') || filename.contains('\\') {
return Err(anyhow::anyhow!("Filename cannot contain path separators"));
}
// Check for control characters
for ch in filename.chars() {
if ch.is_control() && ch != '\n' && ch != '\t' {
return Err(anyhow::anyhow!("Filename contains control characters"));
}
}
// Check for dangerous patterns
if is_dangerous_filename(filename) {
return Err(anyhow::anyhow!("Potentially dangerous filename: {}", filename));
}
// Sanitize the filename by replacing problematic characters
let sanitized = sanitize_filename(filename);
Ok(sanitized)
}
/// Check if a filename is potentially dangerous
fn is_dangerous_filename(filename: &str) -> bool {
let filename_lower = filename.to_lowercase();
// Windows reserved names
let reserved_names = [
"con", "prn", "aux", "nul",
"com1", "com2", "com3", "com4", "com5", "com6", "com7", "com8", "com9",
"lpt1", "lpt2", "lpt3", "lpt4", "lpt5", "lpt6", "lpt7", "lpt8", "lpt9",
];
// Check if filename (without extension) matches reserved names
let name_without_ext = filename_lower.split('.').next().unwrap_or("");
if reserved_names.contains(&name_without_ext) {
return true;
}
// Check for suspicious patterns
if filename_lower.starts_with('.') && filename_lower.len() > 1 {
// Allow common hidden files but reject suspicious ones
let allowed_hidden = [".env", ".gitignore", ".htaccess"];
if !allowed_hidden.iter().any(|&allowed| filename_lower.starts_with(allowed)) {
// Be more permissive with document files that might have dots
if !filename_lower.contains(".doc") &&
!filename_lower.contains(".pdf") &&
!filename_lower.contains(".txt") {
return true;
}
}
}
false
}
/// Sanitize filename by replacing problematic characters
fn sanitize_filename(filename: &str) -> String {
let mut sanitized = String::new();
for ch in filename.chars() {
match ch {
// Replace problematic characters with underscores
'<' | '>' | ':' | '"' | '|' | '?' | '*' => sanitized.push('_'),
// Allow most other characters
_ if !ch.is_control() || ch == '\n' || ch == '\t' => sanitized.push(ch),
// Skip control characters
_ => {}
}
}
// Trim whitespace from ends
sanitized.trim().to_string()
}
/// Normalize a path by resolving . and .. components without filesystem access
fn normalize_path(path: &Path) -> PathBuf {
let mut normalized = PathBuf::new();
for component in path.components() {
match component {
Component::Normal(_) | Component::RootDir | Component::Prefix(_) => {
normalized.push(component);
}
Component::CurDir => {
// Skip current directory references
}
Component::ParentDir => {
// This should have been caught earlier, but handle it safely
if normalized.parent().is_some() {
normalized.pop();
}
// If we can't go up, just ignore the .. component
}
}
}
normalized
}
/// Validate that a path is within the allowed base directory
pub fn validate_path_within_base(path: &str, base_dir: &str) -> Result<()> {
let path_buf = PathBuf::from(path);
let base_buf = PathBuf::from(base_dir);
// Canonicalize if possible, but don't fail if paths don't exist yet
let canonical_path = path_buf.canonicalize().unwrap_or_else(|_| {
// If canonicalization fails, do our best with normalization
normalize_path(&path_buf)
});
let canonical_base = base_buf.canonicalize().unwrap_or_else(|_| {
normalize_path(&base_buf)
});
if !canonical_path.starts_with(&canonical_base) {
return Err(anyhow::anyhow!(
"Path '{}' is not within allowed base directory '{}'",
path, base_dir
));
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_validate_filename() {
// Valid filenames
assert!(validate_filename("document.pdf").is_ok());
assert!(validate_filename("my-file_2023.docx").is_ok());
assert!(validate_filename("report (final).txt").is_ok());
// Invalid filenames
assert!(validate_filename("").is_err());
assert!(validate_filename("../etc/passwd").is_err());
assert!(validate_filename("file\0name.txt").is_err());
assert!(validate_filename("con.txt").is_err());
assert!(validate_filename("file/name.txt").is_err());
}
#[test]
fn test_validate_path() {
// Valid paths
assert!(validate_and_sanitize_path("documents/file.pdf").is_ok());
assert!(validate_and_sanitize_path("./uploads/document.txt").is_ok());
// Invalid paths
assert!(validate_and_sanitize_path("../../../etc/passwd").is_err());
assert!(validate_and_sanitize_path("documents/../config.txt").is_err());
assert!(validate_and_sanitize_path("file\0name.txt").is_err());
}
#[test]
fn test_sanitize_filename() {
assert_eq!(sanitize_filename("file<>name.txt"), "file__name.txt");
assert_eq!(sanitize_filename(" report.pdf "), "report.pdf");
assert_eq!(sanitize_filename("file:name|test.doc"), "file_name_test.doc");
}
}

View File

@ -2,6 +2,8 @@
mod tests {
use readur::ocr::enhanced::{EnhancedOcrService, OcrResult, ImageQualityStats};
use readur::models::Settings;
use readur::services::file_service::FileService;
use readur::storage::{StorageConfig, factory::create_storage_backend};
use std::fs;
use tempfile::{NamedTempFile, TempDir};
@ -13,18 +15,25 @@ mod tests {
TempDir::new().expect("Failed to create temp directory")
}
#[test]
fn test_enhanced_ocr_service_creation() {
async fn create_test_file_service(temp_path: &str) -> FileService {
let storage_config = StorageConfig::Local { upload_path: temp_path.to_string() };
let storage_backend = create_storage_backend(storage_config).await.unwrap();
FileService::with_storage(temp_path.to_string(), storage_backend)
}
#[tokio::test]
async fn test_enhanced_ocr_service_creation() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let service = EnhancedOcrService::new(temp_path);
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path.clone(), file_service);
// Service should be created successfully
assert!(!service.temp_dir.is_empty());
}
#[test]
fn test_image_quality_stats_creation() {
#[tokio::test]
async fn test_image_quality_stats_creation() {
let stats = ImageQualityStats {
average_brightness: 128.0,
contrast_ratio: 0.5,
@ -38,11 +47,12 @@ mod tests {
assert_eq!(stats.sharpness, 0.8);
}
#[test]
fn test_count_words_safely_whitespace_separated() {
#[tokio::test]
async fn test_count_words_safely_whitespace_separated() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let service = EnhancedOcrService::new(temp_path);
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path.clone(), file_service);
// Test normal whitespace-separated text
let text = "Hello world this is a test";
@ -55,11 +65,12 @@ mod tests {
assert_eq!(count, 3);
}
#[test]
fn test_count_words_safely_continuous_text() {
#[tokio::test]
async fn test_count_words_safely_continuous_text() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let service = EnhancedOcrService::new(temp_path);
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path.clone(), file_service);
// Test continuous text without spaces (like some PDF extractions)
let text = "HelloWorldThisIsAContinuousText";
@ -72,11 +83,12 @@ mod tests {
assert!(count > 0, "Should detect alphanumeric patterns as words");
}
#[test]
fn test_count_words_safely_edge_cases() {
#[tokio::test]
async fn test_count_words_safely_edge_cases() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let service = EnhancedOcrService::new(temp_path);
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path.clone(), file_service);
// Test empty text
let count = service.count_words_safely("");
@ -102,11 +114,12 @@ mod tests {
assert!(count > 0, "Should detect words in mixed content");
}
#[test]
fn test_count_words_safely_large_text() {
#[tokio::test]
async fn test_count_words_safely_large_text() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let service = EnhancedOcrService::new(temp_path);
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path.clone(), file_service);
// Test with large text (over 1MB) to trigger sampling
let word = "test ";
@ -118,11 +131,12 @@ mod tests {
assert!(count <= 10_000_000, "Should cap at max limit: got {}", count);
}
#[test]
fn test_count_words_safely_fallback_patterns() {
#[tokio::test]
async fn test_count_words_safely_fallback_patterns() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let service = EnhancedOcrService::new(temp_path);
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path, file_service);
// Test letter transition detection
let text = "OneWordAnotherWordFinalWord";
@ -140,8 +154,8 @@ mod tests {
assert!(count >= 1, "Should detect words in mixed alphanumeric: got {}", count);
}
#[test]
fn test_ocr_result_structure() {
#[tokio::test]
async fn test_ocr_result_structure() {
let result = OcrResult {
text: "Test text".to_string(),
confidence: 85.5,
@ -162,7 +176,9 @@ mod tests {
#[tokio::test]
async fn test_extract_text_from_plain_text() {
let temp_dir = create_temp_dir();
let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();
let temp_file = NamedTempFile::with_suffix(".txt").unwrap();
@ -185,7 +201,9 @@ mod tests {
#[tokio::test]
async fn test_extract_text_with_context() {
let temp_dir = create_temp_dir();
let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();
let temp_file = NamedTempFile::with_suffix(".txt").unwrap();
@ -211,7 +229,9 @@ mod tests {
#[tokio::test]
async fn test_extract_text_unsupported_mime_type() {
let temp_dir = create_temp_dir();
let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();
let temp_file = NamedTempFile::new().unwrap();
@ -229,7 +249,9 @@ mod tests {
#[tokio::test]
async fn test_extract_text_nonexistent_file() {
let temp_dir = create_temp_dir();
let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();
let result = service
@ -242,7 +264,9 @@ mod tests {
#[tokio::test]
async fn test_extract_text_large_file_truncation() {
let temp_dir = create_temp_dir();
let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();
let temp_file = NamedTempFile::with_suffix(".txt").unwrap();
@ -262,10 +286,12 @@ mod tests {
}
#[cfg(feature = "ocr")]
#[test]
fn test_validate_ocr_quality_high_confidence() {
#[tokio::test]
async fn test_validate_ocr_quality_high_confidence() {
let temp_dir = create_temp_dir();
let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path, file_service);
let mut settings = create_test_settings();
settings.ocr_min_confidence = 30.0;
@ -283,10 +309,12 @@ mod tests {
}
#[cfg(feature = "ocr")]
#[test]
fn test_validate_ocr_quality_low_confidence() {
#[tokio::test]
async fn test_validate_ocr_quality_low_confidence() {
let temp_dir = create_temp_dir();
let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path, file_service);
let mut settings = create_test_settings();
settings.ocr_min_confidence = 50.0;
@ -304,10 +332,12 @@ mod tests {
}
#[cfg(feature = "ocr")]
#[test]
fn test_validate_ocr_quality_no_words() {
#[tokio::test]
async fn test_validate_ocr_quality_no_words() {
let temp_dir = create_temp_dir();
let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();
let result = OcrResult {
@ -324,10 +354,12 @@ mod tests {
}
#[cfg(feature = "ocr")]
#[test]
fn test_validate_ocr_quality_poor_character_distribution() {
#[tokio::test]
async fn test_validate_ocr_quality_poor_character_distribution() {
let temp_dir = create_temp_dir();
let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();
let result = OcrResult {
@ -344,10 +376,12 @@ mod tests {
}
#[cfg(feature = "ocr")]
#[test]
fn test_validate_ocr_quality_good_character_distribution() {
#[tokio::test]
async fn test_validate_ocr_quality_good_character_distribution() {
let temp_dir = create_temp_dir();
let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();
let result = OcrResult {
@ -366,7 +400,9 @@ mod tests {
#[tokio::test]
async fn test_word_count_calculation() {
let temp_dir = create_temp_dir();
let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();
let test_cases = vec![
@ -395,7 +431,9 @@ mod tests {
#[tokio::test]
async fn test_pdf_extraction_with_invalid_pdf() {
let temp_dir = create_temp_dir();
let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();
let temp_file = NamedTempFile::with_suffix(".pdf").unwrap();
@ -413,7 +451,9 @@ mod tests {
#[tokio::test]
async fn test_pdf_extraction_with_minimal_valid_pdf() {
let temp_dir = create_temp_dir();
let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();
// Minimal PDF with "Hello" text
@ -485,7 +525,9 @@ startxref
#[tokio::test]
async fn test_pdf_size_limit() {
let temp_dir = create_temp_dir();
let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();
let temp_file = NamedTempFile::with_suffix(".pdf").unwrap();
@ -503,8 +545,8 @@ startxref
assert!(error_msg.contains("too large"));
}
#[test]
fn test_settings_default_values() {
#[tokio::test]
async fn test_settings_default_values() {
let settings = Settings::default();
// Test that OCR-related settings have reasonable defaults
@ -521,7 +563,9 @@ startxref
#[tokio::test]
async fn test_concurrent_ocr_processing() {
let temp_dir = create_temp_dir();
let service = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path, file_service);
let settings = create_test_settings();
let mut handles = vec![];
@ -532,7 +576,9 @@ startxref
let content = format!("Concurrent test content {}", i);
fs::write(temp_file.path(), &content).unwrap();
let service_clone = EnhancedOcrService::new(temp_dir.path().to_str().unwrap().to_string());
let temp_path_clone = temp_dir.path().to_str().unwrap().to_string();
let file_service_clone = create_test_file_service(&temp_path_clone).await;
let service_clone = EnhancedOcrService::new(temp_path_clone, file_service_clone);
let settings_clone = settings.clone();
let file_path = temp_file.path().to_str().unwrap().to_string();

View File

@ -24,6 +24,12 @@ use readur::{
db_guardrails_simple::DocumentTransactionManager,
};
async fn create_test_file_service(temp_path: &str) -> FileService {
let storage_config = StorageConfig::Local { upload_path: temp_path.to_string() };
let storage_backend = create_storage_backend(storage_config).await.unwrap();
FileService::with_storage(temp_path.to_string(), storage_backend)
}
struct OCRPipelineTestHarness {
db: Database,
pool: PgPool,
@ -329,7 +335,8 @@ impl OCRPipelineTestHarness {
// Clone the components we need rather than the whole harness
let queue_service = self.queue_service.clone();
let transaction_manager = self.transaction_manager.clone();
let ocr_service = EnhancedOcrService::new("/tmp".to_string());
let file_service = create_test_file_service("/tmp").await;
let ocr_service = EnhancedOcrService::new("/tmp".to_string(), file_service);
let pool = self.pool.clone();
let handle = tokio::spawn(async move {

View File

@ -2,6 +2,8 @@
mod pdf_word_count_integration_tests {
use readur::ocr::enhanced::EnhancedOcrService;
use readur::models::Settings;
use readur::services::file_service::FileService;
use readur::storage::{StorageConfig, factory::create_storage_backend};
use std::io::Write;
use tempfile::{NamedTempFile, TempDir};
@ -13,6 +15,12 @@ mod pdf_word_count_integration_tests {
TempDir::new().expect("Failed to create temp directory")
}
async fn create_test_file_service(temp_path: &str) -> FileService {
let storage_config = StorageConfig::Local { upload_path: temp_path.to_string() };
let storage_backend = create_storage_backend(storage_config).await.unwrap();
FileService::with_storage(temp_path.to_string(), storage_backend)
}
/// Create a mock PDF with specific text patterns for testing
fn create_mock_pdf_file(content: &str) -> NamedTempFile {
let mut temp_file = NamedTempFile::new().expect("Failed to create temp file");
@ -82,7 +90,8 @@ mod pdf_word_count_integration_tests {
async fn test_pdf_extraction_with_normal_text() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let service = EnhancedOcrService::new(temp_path);
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path.clone(), file_service);
let settings = create_test_settings();
// Create a PDF with normal spaced text
@ -108,7 +117,8 @@ mod pdf_word_count_integration_tests {
async fn test_pdf_extraction_with_continuous_text() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let service = EnhancedOcrService::new(temp_path);
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path.clone(), file_service);
let settings = create_test_settings();
// Create a PDF with continuous text (no spaces)
@ -136,7 +146,8 @@ mod pdf_word_count_integration_tests {
async fn test_pdf_extraction_with_mixed_content() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let service = EnhancedOcrService::new(temp_path);
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path.clone(), file_service);
let settings = create_test_settings();
// Create a PDF with mixed content (letters, numbers, punctuation)
@ -159,7 +170,8 @@ mod pdf_word_count_integration_tests {
async fn test_pdf_extraction_empty_content() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let service = EnhancedOcrService::new(temp_path);
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path.clone(), file_service);
let settings = create_test_settings();
// Create a PDF with only whitespace/empty content
@ -185,7 +197,8 @@ mod pdf_word_count_integration_tests {
async fn test_pdf_extraction_punctuation_only() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let service = EnhancedOcrService::new(temp_path);
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path.clone(), file_service);
let settings = create_test_settings();
// Create a PDF with only punctuation
@ -212,7 +225,8 @@ mod pdf_word_count_integration_tests {
async fn test_pdf_quality_validation() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let service = EnhancedOcrService::new(temp_path);
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path.clone(), file_service);
let settings = create_test_settings();
// Use a real test PDF file if available
@ -261,7 +275,8 @@ mod pdf_word_count_integration_tests {
async fn test_pdf_file_size_validation() {
let temp_dir = create_temp_dir();
let _temp_path = temp_dir.path().to_str().unwrap().to_string();
let _service = EnhancedOcrService::new(_temp_path);
let _file_service = create_test_file_service(&_temp_path).await;
let _service = EnhancedOcrService::new(_temp_path.clone(), _file_service);
let _settings = create_test_settings();
// Create a small PDF file to test file operations
@ -278,11 +293,12 @@ mod pdf_word_count_integration_tests {
assert!(metadata.len() < 100 * 1024 * 1024, "Test PDF should be under size limit");
}
#[test]
fn test_word_counting_regression_cases() {
#[tokio::test]
async fn test_word_counting_regression_cases() {
let temp_dir = create_temp_dir();
let temp_path = temp_dir.path().to_str().unwrap().to_string();
let service = EnhancedOcrService::new(temp_path);
let file_service = create_test_file_service(&temp_path).await;
let service = EnhancedOcrService::new(temp_path.clone(), file_service);
// Regression test cases for the specific PDF issue
let test_cases = vec![

View File

@ -296,7 +296,7 @@ async fn test_user_watch_directory_file_processing_simulation() -> Result<()> {
// Create user watch manager to test file path mapping
let user_watch_service = state.user_watch_service.as_ref().unwrap();
let user_watch_manager = readur::scheduling::user_watch_manager::UserWatchManager::new(state.db.clone(), (**user_watch_service).clone());
let user_watch_manager = readur::scheduling::user_watch_manager::UserWatchManager::new(state.db.clone(), Arc::clone(user_watch_service));
// Create test user
let test_user = readur::models::User {

View File

@ -16,6 +16,8 @@ use readur::{
db::Database,
ocr::queue::OcrQueueService,
ocr::enhanced::EnhancedOcrService,
services::file_service::FileService,
storage::{StorageConfig, factory::create_storage_backend},
};
// Use the same database URL as the running server
@ -25,6 +27,12 @@ fn get_test_db_url() -> String {
.unwrap_or_else(|_| "postgresql://readur:readur@localhost:5432/readur".to_string())
}
async fn create_test_file_service(temp_path: &str) -> FileService {
let storage_config = StorageConfig::Local { upload_path: temp_path.to_string() };
let storage_backend = create_storage_backend(storage_config).await.unwrap();
FileService::with_storage(temp_path.to_string(), storage_backend)
}
struct SimpleThrottleTest {
pool: PgPool,
queue_service: Arc<OcrQueueService>,
@ -140,7 +148,8 @@ impl SimpleThrottleTest {
let handle = tokio::spawn(async move {
let worker_name = format!("worker-{}", worker_id);
let ocr_service = EnhancedOcrService::new("/tmp".to_string());
let file_service = create_test_file_service("/tmp").await;
let ocr_service = EnhancedOcrService::new("/tmp".to_string(), file_service);
let mut jobs_processed = 0;
info!("Worker {} starting", worker_name);