580 lines
17 KiB
Rust
580 lines
17 KiB
Rust
/*!
 * Database Guardrails for Concurrent Processing Safety
 *
 * This module provides database transaction patterns and validation
 * mechanisms to prevent race conditions and data corruption.
 */
|
|
|
|
use anyhow::Result;
use sqlx::{PgPool, Postgres, Row, Transaction};
use tracing::{error, info, warn};
use uuid::Uuid;
|
|
|
|
/// Transaction-safe document operations with validation.
///
/// Cheap-to-clone handle over a shared `PgPool`; cloning shares the same
/// underlying connection pool.
#[derive(Clone)]
pub struct DocumentTransactionManager {
    pool: PgPool,
}
|
|
|
|
impl DocumentTransactionManager {
|
|
pub fn new(pool: PgPool) -> Self {
|
|
Self { pool }
|
|
}
|
|
|
|
/// Update OCR results with full transaction safety and validation
|
|
/// Sanitize text for PostgreSQL storage
|
|
/// Removes null bytes and ensures valid UTF-8 encoding
|
|
fn sanitize_text_for_db(text: &str) -> String {
|
|
// Remove null bytes which PostgreSQL cannot store in TEXT fields
|
|
let cleaned: String = text
|
|
.chars()
|
|
.filter(|&c| c != '\0')
|
|
.collect();
|
|
|
|
// Additional safety: ensure the string is valid UTF-8
|
|
// (should already be, but this is defensive)
|
|
String::from_utf8_lossy(cleaned.as_bytes()).to_string()
|
|
}
|
|
|
|
    /// Update OCR results for `document_id` with full transaction safety.
    ///
    /// Returns `Ok(true)` when the update was committed, `Ok(false)` when it
    /// was refused and rolled back (document missing, filename mismatch, OCR
    /// already completed, suspicious OCR data, or an unexpected row count).
    /// Database errors propagate as `Err`.
    pub async fn update_ocr_with_validation(
        &self,
        document_id: Uuid,
        expected_filename: &str,
        ocr_text: &str,
        confidence: f64,
        word_count: i32,
        processing_time_ms: i64,
    ) -> Result<bool> {
        let mut tx = self.pool.begin().await?;

        // 1. Lock the document row for update
        // (FOR UPDATE holds the row lock until commit/rollback, serializing
        //  concurrent updates to this document)
        let document = sqlx::query(
            r#"
            SELECT id, filename, ocr_status, file_size, created_at
            FROM documents
            WHERE id = $1
            FOR UPDATE
            "#
        )
        .bind(document_id)
        .fetch_optional(&mut *tx)
        .await?;

        let document = match document {
            Some(doc) => doc,
            None => {
                tx.rollback().await?;
                warn!("Document {} not found during OCR update", document_id);
                return Ok(false);
            }
        };

        // 2. Validate document hasn't been modified unexpectedly
        // (`Row::get` needs the `sqlx::Row` trait in scope)
        let filename: String = document.get("filename");
        if filename != expected_filename {
            tx.rollback().await?;
            error!(
                "Document {} filename mismatch: expected '{}', got '{}'",
                document_id, expected_filename, filename
            );
            return Ok(false);
        }

        // 3. Check if OCR is already completed (prevent double processing)
        let ocr_status: Option<String> = document.get("ocr_status");
        if ocr_status.as_deref() == Some("completed") {
            tx.rollback().await?;
            warn!("Document {} OCR already completed, skipping update", document_id);
            return Ok(false);
        }

        // 4. Validate OCR data quality: high confidence with empty text is
        //    contradictory, so treat it as a bad result
        if ocr_text.is_empty() && confidence > 50.0 {
            tx.rollback().await?;
            warn!("Document {} has high confidence ({}) but empty OCR text", document_id, confidence);
            return Ok(false);
        }

        // 5. Sanitize text before database insertion
        let sanitized_text = Self::sanitize_text_for_db(ocr_text);

        // Log if sanitization was needed (sanitization only removes bytes,
        // so a length difference means NUL bytes were stripped)
        if sanitized_text.len() != ocr_text.len() {
            warn!(
                "Text sanitization was required for document {}: original {} chars, sanitized {} chars",
                document_id, ocr_text.len(), sanitized_text.len()
            );
        }

        // 6. Perform the update with additional safety checks
        let updated_rows = sqlx::query!(
            r#"
            UPDATE documents
            SET ocr_text = $2,
                ocr_status = 'completed',
                ocr_completed_at = NOW(),
                ocr_confidence = $3,
                ocr_word_count = $4,
                ocr_processing_time_ms = $5,
                updated_at = NOW()
            WHERE id = $1
            AND ocr_status != 'completed' -- Extra safety check
            "#,
            document_id,
            sanitized_text.as_str(),
            confidence,
            word_count,
            processing_time_ms
        )
        .execute(&mut *tx)
        .await?;

        if updated_rows.rows_affected() != 1 {
            tx.rollback().await?;
            error!("Document {} OCR update affected {} rows (expected 1)", document_id, updated_rows.rows_affected());
            return Ok(false);
        }

        // 7. Remove from OCR queue atomically (same transaction as the update)
        let queue_removed = sqlx::query!(
            r#"
            DELETE FROM ocr_queue
            WHERE document_id = $1
            AND status = 'processing'
            "#,
            document_id
        )
        .execute(&mut *tx)
        .await?;

        // Non-fatal: the queue row may have been cleaned up by another path
        if queue_removed.rows_affected() == 0 {
            warn!("Document {} not found in OCR queue during completion", document_id);
        }

        // 8. Commit transaction
        tx.commit().await?;

        info!(
            "Document {} OCR updated successfully: {} chars, {:.1}% confidence, {} words",
            document_id, sanitized_text.len(), confidence, word_count
        );

        Ok(true)
    }
|
|
|
|
    /// Safely claim a document from the OCR queue with proper locking.
    ///
    /// `FOR UPDATE SKIP LOCKED` lets concurrent workers claim different rows
    /// without blocking each other or double-claiming. Returns `Ok(None)`
    /// when no claimable job exists, or when the claimed job's document is
    /// gone / already processed (the stale queue row is deleted in that case).
    ///
    /// NOTE(review): `sqlx::query_as::<_, OcrJob>` requires `OcrJob` to
    /// implement `sqlx::FromRow` — confirm the derive is present on the
    /// struct definition.
    pub async fn claim_ocr_job(&self, worker_id: &str) -> Result<Option<OcrJob>> {
        let mut tx = self.pool.begin().await?;

        // 1. Find and lock next available job; claiming increments `attempts`
        //    so repeated failures eventually exhaust `max_attempts`
        let job = sqlx::query_as::<_, OcrJob>(
            r#"
            UPDATE ocr_queue
            SET status = 'processing',
                started_at = NOW(),
                worker_id = $1,
                attempts = attempts + 1
            WHERE id = (
                SELECT id
                FROM ocr_queue
                WHERE status = 'pending'
                AND attempts < max_attempts
                ORDER BY priority DESC, created_at ASC
                FOR UPDATE SKIP LOCKED
                LIMIT 1
            )
            RETURNING
                id,
                document_id,
                priority,
                status,
                attempts,
                max_attempts,
                worker_id,
                created_at,
                started_at,
                completed_at,
                error_message
            "#
        )
        .bind(worker_id)
        .fetch_optional(&mut *tx)
        .await?;

        if let Some(job) = job {
            // 2. Validate document still exists and is processable
            let document_exists = sqlx::query!(
                r#"
                SELECT filename, file_path, ocr_status
                FROM documents
                WHERE id = $1
                AND ocr_status IN ('pending', 'processing')
                "#,
                job.document_id
            )
            .fetch_optional(&mut *tx)
            .await?;

            if document_exists.is_none() {
                // Document was deleted or already processed — drop the stale
                // queue row and report no job available
                sqlx::query!(
                    "DELETE FROM ocr_queue WHERE id = $1",
                    job.id
                )
                .execute(&mut *tx)
                .await?;

                tx.commit().await?;
                return Ok(None);
            }

            tx.commit().await?;
            Ok(Some(job))
        } else {
            // Nothing was claimed, so there is nothing to persist
            tx.rollback().await?;
            Ok(None)
        }
    }
|
|
|
|
    /// Safely handle an OCR job failure with retry logic.
    ///
    /// Returns `Ok(true)` when the job was reset to `pending` for another
    /// attempt, `Ok(false)` when it has exhausted `max_attempts` (or the
    /// queue row no longer exists) and the document was marked `failed`.
    pub async fn handle_ocr_failure(
        &self,
        job_id: Uuid,
        document_id: Uuid,
        error_message: &str,
    ) -> Result<bool> {
        let mut tx = self.pool.begin().await?;

        // 1. Check if job should be retried or marked as failed
        //    (FOR UPDATE locks the row so a concurrent worker cannot race
        //     this retry/fail decision)
        let job = sqlx::query!(
            r#"
            SELECT attempts, max_attempts
            FROM ocr_queue
            WHERE id = $1
            FOR UPDATE
            "#,
            job_id
        )
        .fetch_optional(&mut *tx)
        .await?;

        // Missing queue row means no retry is possible
        let should_retry = if let Some(job) = job {
            job.attempts < job.max_attempts
        } else {
            false
        };

        if should_retry {
            // 2. Reset job for retry, recording the failure reason
            sqlx::query!(
                r#"
                UPDATE ocr_queue
                SET status = 'pending',
                    worker_id = NULL,
                    started_at = NULL,
                    error_message = $2
                WHERE id = $1
                "#,
                job_id,
                error_message
            )
            .execute(&mut *tx)
            .await?;

            info!("OCR job {} scheduled for retry", job_id);
        } else {
            // 3. Mark document as failed and remove from queue
            sqlx::query!(
                r#"
                UPDATE documents
                SET ocr_status = 'failed',
                    ocr_error = $2,
                    updated_at = NOW()
                WHERE id = $1
                "#,
                document_id,
                error_message
            )
            .execute(&mut *tx)
            .await?;

            sqlx::query!(
                "DELETE FROM ocr_queue WHERE id = $1",
                job_id
            )
            .execute(&mut *tx)
            .await?;

            error!("OCR job {} failed permanently: {}", job_id, error_message);
        }

        tx.commit().await?;
        Ok(should_retry)
    }
|
|
|
|
/// Validate database consistency and fix orphaned records
|
|
pub async fn validate_consistency(&self) -> Result<ConsistencyReport> {
|
|
let mut report = ConsistencyReport::default();
|
|
|
|
// 1. Find documents with OCR status mismatch
|
|
let orphaned_queue_items = sqlx::query!(
|
|
r#"
|
|
SELECT q.id, q.document_id, d.ocr_status
|
|
FROM ocr_queue q
|
|
LEFT JOIN documents d ON q.document_id = d.id
|
|
WHERE d.id IS NULL
|
|
OR d.ocr_status = 'completed'
|
|
"#
|
|
)
|
|
.fetch_all(&self.pool)
|
|
.await?;
|
|
|
|
report.orphaned_queue_items = orphaned_queue_items.len();
|
|
|
|
// 2. Find documents stuck in processing
|
|
let stuck_processing = sqlx::query!(
|
|
r#"
|
|
SELECT COUNT(*) as count
|
|
FROM documents
|
|
WHERE ocr_status = 'processing'
|
|
AND updated_at < NOW() - INTERVAL '30 minutes'
|
|
"#
|
|
)
|
|
.fetch_one(&self.pool)
|
|
.await?;
|
|
|
|
report.stuck_processing_docs = stuck_processing.count.unwrap_or(0) as usize;
|
|
|
|
// 3. Find queue items without corresponding documents
|
|
let queue_without_docs = sqlx::query!(
|
|
r#"
|
|
SELECT COUNT(*) as count
|
|
FROM ocr_queue q
|
|
LEFT JOIN documents d ON q.document_id = d.id
|
|
WHERE d.id IS NULL
|
|
"#
|
|
)
|
|
.fetch_one(&self.pool)
|
|
.await?;
|
|
|
|
report.queue_without_docs = queue_without_docs.count.unwrap_or(0) as usize;
|
|
|
|
Ok(report)
|
|
}
|
|
|
|
    /// Clean up orphaned and inconsistent records.
    ///
    /// Runs three independent repair statements (each auto-committed via the
    /// pool, not a shared transaction) and reports how many rows each one
    /// touched.
    pub async fn cleanup_orphaned_records(&self) -> Result<CleanupReport> {
        let mut report = CleanupReport::default();

        // 1. Remove queue items for completed documents
        let removed_completed = sqlx::query!(
            r#"
            DELETE FROM ocr_queue
            WHERE document_id IN (
                SELECT d.id FROM documents d
                WHERE d.ocr_status = 'completed'
            )
            "#
        )
        .execute(&self.pool)
        .await?;

        report.removed_completed_queue_items = removed_completed.rows_affected() as usize;

        // 2. Remove queue items for non-existent documents
        let removed_orphaned = sqlx::query!(
            r#"
            DELETE FROM ocr_queue
            WHERE document_id NOT IN (
                SELECT id FROM documents
            )
            "#
        )
        .execute(&self.pool)
        .await?;

        report.removed_orphaned_queue_items = removed_orphaned.rows_affected() as usize;

        // 3. Reset stuck processing documents back to 'pending' so workers
        //    can pick them up again (30-minute staleness threshold matches
        //    validate_consistency)
        let reset_stuck = sqlx::query!(
            r#"
            UPDATE documents
            SET ocr_status = 'pending'
            WHERE ocr_status = 'processing'
            AND updated_at < NOW() - INTERVAL '30 minutes'
            "#
        )
        .execute(&self.pool)
        .await?;

        report.reset_stuck_documents = reset_stuck.rows_affected() as usize;

        Ok(report)
    }
|
|
}
|
|
|
|
/// OCR job structure with all necessary fields
|
|
#[derive(Debug, Clone)]
|
|
pub struct OcrJob {
|
|
pub id: Uuid,
|
|
pub document_id: Uuid,
|
|
pub priority: i32,
|
|
pub status: String,
|
|
pub attempts: i32,
|
|
pub max_attempts: i32,
|
|
pub worker_id: Option<String>,
|
|
pub created_at: chrono::DateTime<chrono::Utc>,
|
|
pub started_at: Option<chrono::DateTime<chrono::Utc>>,
|
|
pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
|
|
pub error_message: Option<String>,
|
|
}
|
|
|
|
/// Database consistency validation report.
#[derive(Debug, Default)]
pub struct ConsistencyReport {
    // Queue rows referencing a deleted or already-completed document
    pub orphaned_queue_items: usize,
    // Documents stuck in 'processing' beyond the staleness threshold
    pub stuck_processing_docs: usize,
    // Queue rows whose document no longer exists
    pub queue_without_docs: usize,
    // NOTE(review): prefer the `is_consistent()` method, which derives the
    // answer from the counters; this field may not be set by every producer.
    pub is_consistent: bool,
}

impl ConsistencyReport {
    /// True when every anomaly counter is zero.
    pub fn is_consistent(&self) -> bool {
        [
            self.orphaned_queue_items,
            self.stuck_processing_docs,
            self.queue_without_docs,
        ]
        .iter()
        .all(|&count| count == 0)
    }
}
|
|
|
|
/// Database cleanup operation report: what `cleanup_orphaned_records` changed.
#[derive(Debug, Default)]
pub struct CleanupReport {
    // Queue rows deleted because their document already finished OCR
    pub removed_completed_queue_items: usize,
    // Queue rows deleted because their document no longer exists
    pub removed_orphaned_queue_items: usize,
    // Documents moved from a stale 'processing' state back to 'pending'
    pub reset_stuck_documents: usize,
}
|
|
|
|
/// Database connection health checker.
pub struct DatabaseHealthChecker {
    pool: PgPool,
}

impl DatabaseHealthChecker {
    /// Wrap `pool` for health probing.
    pub fn new(pool: PgPool) -> Self {
        Self { pool }
    }

    /// Check database connection pool health.
    ///
    /// Runs a trivial `SELECT 1` probe, times it, and samples pool
    /// statistics into a `PoolHealthReport`.
    pub async fn check_pool_health(&self) -> Result<PoolHealthReport> {
        let start = std::time::Instant::now();

        // Test basic connectivity
        let test_query = sqlx::query!("SELECT 1 as test")
            .fetch_one(&self.pool)
            .await?;

        let response_time = start.elapsed();

        // Get pool statistics if available
        let pool_size = self.pool.size();
        let idle_connections = self.pool.num_idle();

        Ok(PoolHealthReport {
            // the query! macro types a computed column as Option, hence Some(1)
            is_healthy: test_query.test == Some(1),
            response_time_ms: response_time.as_millis() as u64,
            pool_size,
            idle_connections,
            // busy / total as a percentage; the u8 cast is safe because the
            // value is bounded by 100
            utilization_percent: if pool_size > 0 {
                ((pool_size - idle_connections) as f64 / pool_size as f64 * 100.0) as u8
            } else {
                0
            },
        })
    }
}
|
|
|
|
/// Snapshot of connection-pool health produced by `check_pool_health`.
#[derive(Debug)]
pub struct PoolHealthReport {
    // True when the connectivity probe returned the expected value
    pub is_healthy: bool,
    // Round-trip time of the probe query, in milliseconds
    pub response_time_ms: u64,
    // Total connections currently managed by the pool
    pub pool_size: u32,
    // Connections sitting idle (not checked out)
    pub idle_connections: u32,
    // Busy connections as a percentage of pool_size (0-100)
    pub utilization_percent: u8,
}
|
|
|
|
/// Distributed locking for critical sections
|
|
pub struct DistributedLock {
|
|
pool: PgPool,
|
|
}
|
|
|
|
impl DistributedLock {
|
|
pub fn new(pool: PgPool) -> Self {
|
|
Self { pool }
|
|
}
|
|
|
|
/// Acquire a named lock with timeout
|
|
pub async fn acquire_lock(&self, lock_name: &str, timeout_secs: i32) -> Result<bool> {
|
|
let lock_id = self.hash_lock_name(lock_name);
|
|
|
|
let result = sqlx::query!(
|
|
"SELECT pg_try_advisory_lock($1, $2) as acquired",
|
|
lock_id,
|
|
timeout_secs
|
|
)
|
|
.fetch_one(&self.pool)
|
|
.await?;
|
|
|
|
Ok(result.acquired.unwrap_or(false))
|
|
}
|
|
|
|
/// Release a named lock
|
|
pub async fn release_lock(&self, lock_name: &str) -> Result<bool> {
|
|
let lock_id = self.hash_lock_name(lock_name);
|
|
|
|
let result = sqlx::query!(
|
|
"SELECT pg_advisory_unlock($1, 0) as released",
|
|
lock_id
|
|
)
|
|
.fetch_one(&self.pool)
|
|
.await?;
|
|
|
|
Ok(result.released.unwrap_or(false))
|
|
}
|
|
|
|
fn hash_lock_name(&self, name: &str) -> i64 {
|
|
use std::collections::hash_map::DefaultHasher;
|
|
use std::hash::{Hash, Hasher};
|
|
|
|
let mut hasher = DefaultHasher::new();
|
|
name.hash(&mut hasher);
|
|
hasher.finish() as i64
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    // `sanitize_text_for_db` must strip NUL bytes and leave everything else
    // (including multi-byte UTF-8) untouched.
    #[test]
    fn test_sanitize_text_for_db() {
        // Test removing null bytes
        let text_with_nulls = "Hello\0World\0!";
        let sanitized = DocumentTransactionManager::sanitize_text_for_db(text_with_nulls);
        assert_eq!(sanitized, "HelloWorld!");

        // Test preserving normal text (multi-byte characters must survive)
        let normal_text = "This is a normal PDF text with special chars: €£¥";
        let sanitized = DocumentTransactionManager::sanitize_text_for_db(normal_text);
        assert_eq!(sanitized, normal_text);

        // Test handling empty string
        let empty = "";
        let sanitized = DocumentTransactionManager::sanitize_text_for_db(empty);
        assert_eq!(sanitized, "");

        // Test handling text with multiple null bytes, including leading
        // and trailing runs
        let many_nulls = "\0\0Start\0Middle\0\0End\0\0";
        let sanitized = DocumentTransactionManager::sanitize_text_for_db(many_nulls);
        assert_eq!(sanitized, "StartMiddleEnd");
    }
}