feat(server): gracefully manage requeue requests for the same document

perf3ct 2025-07-11 21:27:12 +00:00
parent be31c14814
commit b31e1a672d
6 changed files with 134 additions and 18 deletions

View File

@@ -967,6 +967,86 @@ impl OcrQueueService {
/// Requeue failed items
pub async fn requeue_failed_items(&self) -> Result<i64> {
tracing::debug!("Attempting to requeue failed items");
// First check if there are any failed items to requeue
let failed_count: i64 = sqlx::query_scalar(
r#"
SELECT COUNT(*)
FROM ocr_queue
WHERE status = 'failed'
AND attempts < max_attempts
"#
)
.fetch_one(&self.pool)
.await
.map_err(|e| {
tracing::error!("Failed to count failed items: {:?}", e);
e
})?;
tracing::debug!("Found {} failed items eligible for requeue", failed_count);
if failed_count == 0 {
return Ok(0);
}
// Check for potential constraint violations
let conflict_check: i64 = sqlx::query_scalar(
r#"
SELECT COUNT(*)
FROM ocr_queue q1
WHERE q1.status = 'failed'
AND q1.attempts < q1.max_attempts
AND EXISTS (
SELECT 1 FROM ocr_queue q2
WHERE q2.document_id = q1.document_id
AND q2.id != q1.id
AND q2.status IN ('pending', 'processing')
)
"#
)
.fetch_one(&self.pool)
.await
.map_err(|e| {
tracing::error!("Failed to check for conflicts: {:?}", e);
e
})?;
if conflict_check > 0 {
tracing::warn!("Found {} documents with existing pending/processing entries", conflict_check);
// Update only items that won't violate the unique constraint
let result = sqlx::query(
r#"
UPDATE ocr_queue
SET status = 'pending',
attempts = 0,
error_message = NULL,
started_at = NULL,
worker_id = NULL
WHERE status = 'failed'
AND attempts < max_attempts
AND NOT EXISTS (
SELECT 1 FROM ocr_queue q2
WHERE q2.document_id = ocr_queue.document_id
AND q2.id != ocr_queue.id
AND q2.status IN ('pending', 'processing')
)
"#
)
.execute(&self.pool)
.await
.map_err(|e| {
tracing::error!("Database error in requeue_failed_items (with conflict check): {:?}", e);
e
})?;
let rows_affected = result.rows_affected() as i64;
tracing::debug!("Requeued {} failed items (skipped {} due to conflicts)", rows_affected, conflict_check);
return Ok(rows_affected);
}
// No conflicts, proceed with normal update
let result = sqlx::query(
r#"
UPDATE ocr_queue
@@ -980,9 +1060,16 @@ impl OcrQueueService {
"#
)
.execute(&self.pool)
- .await?;
+ .await
+ .map_err(|e| {
+ tracing::error!("Database error in requeue_failed_items: {:?}", e);
+ e
+ })?;
- Ok(result.rows_affected() as i64)
+ let rows_affected = result.rows_affected() as i64;
+ tracing::debug!("Requeued {} failed items", rows_affected);
+ Ok(rows_affected)
}
/// Clean up old completed items

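The NOT EXISTS guard above only matters because the schema allows at most one active queue row per document; the in-code comment refers to this as "the unique constraint", but the constraint itself is not part of this commit. A plausible shape for it, written as a sqlx call with an assumed index name and predicate, might be:

use sqlx::PgPool;

/// Hypothetical migration sketch: enforce at most one 'pending' or
/// 'processing' queue row per document. The index name and predicate are
/// assumptions; the real schema is not shown in this diff.
async fn ensure_unique_active_entry(pool: &PgPool) -> Result<(), sqlx::Error> {
    sqlx::query(
        r#"
        CREATE UNIQUE INDEX IF NOT EXISTS idx_ocr_queue_active_document
        ON ocr_queue (document_id)
        WHERE status IN ('pending', 'processing')
        "#,
    )
    .execute(pool)
    .await?;
    Ok(())
}

Under such an index, flipping a failed row back to 'pending' while a sibling row for the same document is still pending or processing would raise a unique violation, which is exactly the case the conditional UPDATE skips.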
View File

@@ -64,7 +64,7 @@ pub struct DocumentDebugInfo {
pub permissions: Option<String>,
}
- #[derive(Serialize, ToSchema)]
+ #[derive(Serialize, Deserialize, ToSchema)]
pub struct DocumentPaginationInfo {
pub total: i64,
pub limit: i64,
@@ -72,7 +72,7 @@ pub struct DocumentPaginationInfo {
pub has_more: bool,
}
- #[derive(Serialize, ToSchema)]
+ #[derive(Serialize, Deserialize, ToSchema)]
pub struct PaginatedDocumentsResponse {
pub documents: Vec<crate::models::DocumentResponse>,
pub pagination: DocumentPaginationInfo,

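The Deserialize derives added here exist so the integration tests further down can parse the paginated body back into the same types the server serializes. A minimal round-trip sketch with simplified stand-in structs (field sets abbreviated; the real types also derive utoipa's ToSchema):

use serde::{Deserialize, Serialize};

// Stand-ins for DocumentPaginationInfo / PaginatedDocumentsResponse with
// abbreviated fields, just to show the round trip the tests rely on.
#[derive(Serialize, Deserialize)]
struct PaginationInfo {
    total: i64,
    limit: i64,
    has_more: bool,
}

#[derive(Serialize, Deserialize)]
struct PaginatedResponse {
    documents: Vec<serde_json::Value>,
    pagination: PaginationInfo,
}

fn main() -> Result<(), serde_json::Error> {
    let body = r#"{"documents":[],"pagination":{"total":0,"limit":25,"has_more":false}}"#;
    let parsed: PaginatedResponse = serde_json::from_str(body)?;
    assert!(!parsed.pagination.has_more);
    Ok(())
}

Without Deserialize, response.json::<PaginatedDocumentsResponse>() in the tests would not compile, since reqwest's json() requires the target type to implement DeserializeOwned.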
View File

@@ -6,7 +6,7 @@ use axum::{
Router,
};
use sqlx::Row;
- use std::sync::Arc;
+ use std::{sync::Arc, error::Error};
use crate::{auth::AuthUser, ocr::queue::OcrQueueService, AppState, models::UserRole};
@@ -85,10 +85,28 @@ async fn requeue_failed(
require_admin(&auth_user)?;
let queue_service = OcrQueueService::new(state.db.clone(), state.db.get_pool().clone(), 1);
- let count = queue_service
- .requeue_failed_items()
- .await
- .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
+ let count = match queue_service.requeue_failed_items().await {
+ Ok(count) => count,
+ Err(e) => {
+ let error_msg = format!("Failed to requeue failed items: {:?}", e);
+ tracing::error!("{}", error_msg);
+ // Print to stderr so we can see it in test output
+ eprintln!("REQUEUE ERROR: {}", error_msg);
+ // Try to get the source chain
+ eprintln!("Error chain:");
+ let mut source = e.source();
+ let mut depth = 1;
+ while let Some(err) = source {
+ eprintln!(" {}: {}", depth, err);
+ source = err.source();
+ depth += 1;
+ }
+ return Err(StatusCode::INTERNAL_SERVER_ERROR);
+ }
+ };
Ok(Json(serde_json::json!({
"requeued_count": count,

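The handler walks the error's source() chain by hand to surface root causes in test output. That walk could be factored into a small helper; a sketch, not part of this commit:

use std::error::Error;

/// Render an error and each transitive cause on its own line, mirroring the
/// inline loop in the handler above.
fn format_error_chain(err: &(dyn Error + 'static)) -> String {
    let mut out = err.to_string();
    let mut source = err.source();
    let mut depth = 1;
    while let Some(cause) = source {
        out.push_str(&format!("\n  {}: {}", depth, cause));
        source = cause.source();
        depth += 1;
    }
    out
}

This relies only on std::error::Error::source, so it works for any error type that exposes its causes, including anyhow::Error via its Deref to dyn Error.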
View File

@@ -20,6 +20,7 @@ use tokio::time::sleep;
use uuid::Uuid;
use readur::models::{CreateUser, LoginRequest, LoginResponse, UserRole, DocumentResponse};
use readur::routes::documents::types::PaginatedDocumentsResponse;
use readur::routes::documents::types::DocumentUploadResponse;
fn get_base_url() -> String {
@@ -216,7 +217,8 @@ impl FileProcessingTestClient {
.await?;
if response.status().is_success() {
- let documents: Vec<DocumentResponse> = response.json().await?;
+ let paginated_response: PaginatedDocumentsResponse = response.json().await?;
+ let documents = paginated_response.documents;
if let Some(doc) = documents.iter().find(|d| d.id.to_string() == document_id) {
println!("📄 DEBUG: Found document with OCR status: {:?}", doc.ocr_status);
@@ -590,8 +592,9 @@ async fn test_image_processing_pipeline() {
.await
.expect("Failed to get documents");
- let documents: Vec<DocumentResponse> = response.json().await
+ let paginated_response: PaginatedDocumentsResponse = response.json().await
.expect("Failed to parse response");
+ let documents = paginated_response.documents;
documents.into_iter()
.find(|d| d.id.to_string() == document_id)
@@ -707,7 +710,8 @@ async fn test_processing_error_recovery() {
.await;
if let Ok(resp) = response {
- if let Ok(docs) = resp.json::<Vec<DocumentResponse>>().await {
+ if let Ok(paginated_response) = resp.json::<PaginatedDocumentsResponse>().await {
+ let docs = paginated_response.documents;
if let Some(doc) = docs.iter().find(|d| d.id.to_string() == document.id.to_string()) {
match doc.ocr_status.as_deref() {
Some("completed") => {
@@ -998,8 +1002,9 @@ async fn test_concurrent_file_processing() {
.expect("Should get documents");
if response.status().is_success() {
- let documents: Vec<DocumentResponse> = response.json().await
+ let paginated_response: PaginatedDocumentsResponse = response.json().await
.expect("Should parse response");
+ let documents = paginated_response.documents;
if let Some(doc) = documents.iter().find(|d| d.id.to_string() == document_id) {
match doc.ocr_status.as_deref() {

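The changes in this test file repeat one pattern several times: fetch the document list, unwrap the new paginated envelope, and look up a single document's OCR status. A reusable helper along these lines could collapse that repetition (the /api/documents path and bearer-token auth are assumptions inferred from the surrounding test client, not shown in this diff):

use readur::routes::documents::types::PaginatedDocumentsResponse;

/// Fetch the paginated document list and return one document's OCR status,
/// if the document is present in the returned page of results.
async fn fetch_ocr_status(
    client: &reqwest::Client,
    base_url: &str,
    token: &str,
    document_id: &str,
) -> Result<Option<String>, Box<dyn std::error::Error>> {
    let response = client
        .get(format!("{}/api/documents", base_url)) // assumed endpoint path
        .bearer_auth(token)
        .send()
        .await?;
    let paginated: PaginatedDocumentsResponse = response.json().await?;
    Ok(paginated
        .documents
        .iter()
        .find(|d| d.id.to_string() == document_id)
        .and_then(|d| d.ocr_status.clone()))
}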
View File

@@ -17,7 +17,7 @@ use tokio::time::sleep;
use uuid::Uuid;
use readur::models::{CreateUser, LoginRequest, LoginResponse, UserRole, DocumentResponse};
- use readur::routes::documents::types::DocumentUploadResponse;
+ use readur::routes::documents::types::{DocumentUploadResponse, PaginatedDocumentsResponse};
fn get_base_url() -> String {
std::env::var("API_URL").unwrap_or_else(|_| "http://localhost:8000".to_string())
@@ -227,7 +227,8 @@ impl OCRQueueTestClient {
.await?;
if response.status().is_success() {
- let documents: Vec<DocumentResponse> = response.json().await?;
+ let paginated_response: PaginatedDocumentsResponse = response.json().await?;
+ let documents = paginated_response.documents;
for (i, doc_id) in document_ids.iter().enumerate() {
if !completed_status[i] {
@@ -275,7 +276,8 @@ impl OCRQueueTestClient {
return Err(format!("Get documents failed: {}", response.text().await?).into());
}
- let documents: Vec<DocumentResponse> = response.json().await?;
+ let paginated_response: PaginatedDocumentsResponse = response.json().await?;
+ let documents = paginated_response.documents;
Ok(documents)
}
}

View File

@@ -22,7 +22,7 @@ use uuid::Uuid;
use chrono;
use readur::models::{CreateUser, LoginRequest, LoginResponse, UserRole};
- use readur::routes::documents::types::DocumentUploadResponse;
+ use readur::routes::documents::types::{DocumentUploadResponse, PaginatedDocumentsResponse};
fn get_base_url() -> String {
std::env::var("API_URL").unwrap_or_else(|_| "http://localhost:8000".to_string())
@@ -239,7 +239,11 @@ impl LoadTestClient {
return Err(format!("List documents failed: {}", response.text().await?).into());
}
- let documents_array: Vec<serde_json::Value> = response.json().await?;
+ let paginated_response: PaginatedDocumentsResponse = response.json().await?;
+ let documents_array: Vec<serde_json::Value> = paginated_response.documents
+ .into_iter()
+ .map(|doc| serde_json::to_value(doc).unwrap())
+ .collect();
Ok((documents_array, elapsed))
}
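
One note on the hunk above: the unwrap inside the map panics if serializing a document ever fails. Collecting into a Result propagates the error through the function's existing error type instead; a sketch of that alternative, behavior otherwise identical:

use readur::routes::documents::types::PaginatedDocumentsResponse;

/// Convert the typed documents back into raw JSON values, propagating any
/// serialization error rather than panicking.
fn documents_as_values(
    resp: PaginatedDocumentsResponse,
) -> Result<Vec<serde_json::Value>, serde_json::Error> {
    resp.documents
        .into_iter()
        .map(serde_json::to_value)
        .collect()
}

In practice serde_json::to_value on a plain derived struct does not fail, so the unwrap is safe; the Result form just makes that reasoning unnecessary.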