use anyhow::Result; use chrono::Utc; use sqlx::{PgPool, Row}; use uuid::Uuid; use crate::models::{CreateUser, Document, SearchRequest, SearchMode, SearchSnippet, HighlightRange, EnhancedDocumentResponse, User}; #[derive(Clone)] pub struct Database { pool: PgPool, } impl Database { pub async fn new(database_url: &str) -> Result { let pool = PgPool::connect(database_url).await?; Ok(Self { pool }) } pub async fn migrate(&self) -> Result<()> { // Create extensions sqlx::query(r#"CREATE EXTENSION IF NOT EXISTS "uuid-ossp""#) .execute(&self.pool) .await?; sqlx::query(r#"CREATE EXTENSION IF NOT EXISTS "pg_trgm""#) .execute(&self.pool) .await?; // Create users table sqlx::query( r#" CREATE TABLE IF NOT EXISTS users ( id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), username VARCHAR(255) UNIQUE NOT NULL, email VARCHAR(255) UNIQUE NOT NULL, password_hash VARCHAR(255) NOT NULL, created_at TIMESTAMPTZ DEFAULT NOW(), updated_at TIMESTAMPTZ DEFAULT NOW() ) "#, ) .execute(&self.pool) .await?; // Create documents table sqlx::query( r#" CREATE TABLE IF NOT EXISTS documents ( id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), filename VARCHAR(255) NOT NULL, original_filename VARCHAR(255) NOT NULL, file_path VARCHAR(500) NOT NULL, file_size BIGINT NOT NULL, mime_type VARCHAR(100) NOT NULL, content TEXT, ocr_text TEXT, tags TEXT[] DEFAULT '{}', created_at TIMESTAMPTZ DEFAULT NOW(), updated_at TIMESTAMPTZ DEFAULT NOW(), user_id UUID REFERENCES users(id) ON DELETE CASCADE ) "#, ) .execute(&self.pool) .await?; // Create indexes sqlx::query(r#"CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)"#) .execute(&self.pool) .await?; sqlx::query(r#"CREATE INDEX IF NOT EXISTS idx_documents_filename ON documents(filename)"#) .execute(&self.pool) .await?; sqlx::query(r#"CREATE INDEX IF NOT EXISTS idx_documents_mime_type ON documents(mime_type)"#) .execute(&self.pool) .await?; sqlx::query(r#"CREATE INDEX IF NOT EXISTS idx_documents_tags ON documents USING GIN(tags)"#) .execute(&self.pool) .await?; sqlx::query(r#"CREATE INDEX IF NOT EXISTS idx_documents_content_search ON documents USING GIN(to_tsvector('english', COALESCE(content, '') || ' ' || COALESCE(ocr_text, '')))"#) .execute(&self.pool) .await?; // Create settings table sqlx::query( r#" CREATE TABLE IF NOT EXISTS settings ( id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), user_id UUID REFERENCES users(id) ON DELETE CASCADE UNIQUE, ocr_language VARCHAR(10) DEFAULT 'eng', created_at TIMESTAMPTZ DEFAULT NOW(), updated_at TIMESTAMPTZ DEFAULT NOW() ) "#, ) .execute(&self.pool) .await?; Ok(()) } pub async fn create_user(&self, user: CreateUser) -> Result { let password_hash = bcrypt::hash(&user.password, 12)?; let now = Utc::now(); let row = sqlx::query( r#" INSERT INTO users (username, email, password_hash, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) RETURNING id, username, email, password_hash, created_at, updated_at "# ) .bind(&user.username) .bind(&user.email) .bind(&password_hash) .bind(now) .bind(now) .fetch_one(&self.pool) .await?; Ok(User { id: row.get("id"), username: row.get("username"), email: row.get("email"), password_hash: row.get("password_hash"), created_at: row.get("created_at"), updated_at: row.get("updated_at"), }) } pub async fn get_user_by_username(&self, username: &str) -> Result> { let row = sqlx::query( "SELECT id, username, email, password_hash, created_at, updated_at FROM users WHERE username = $1" ) .bind(username) .fetch_optional(&self.pool) .await?; match row { Some(row) => Ok(Some(User { id: row.get("id"), username: row.get("username"), email: row.get("email"), password_hash: row.get("password_hash"), created_at: row.get("created_at"), updated_at: row.get("updated_at"), })), None => Ok(None), } } pub async fn get_user_by_id(&self, id: Uuid) -> Result> { let row = sqlx::query( "SELECT id, username, email, password_hash, created_at, updated_at FROM users WHERE id = $1" ) .bind(id) .fetch_optional(&self.pool) .await?; match row { Some(row) => Ok(Some(User { id: row.get("id"), username: row.get("username"), email: row.get("email"), password_hash: row.get("password_hash"), created_at: row.get("created_at"), updated_at: row.get("updated_at"), })), None => Ok(None), } } pub async fn create_document(&self, document: Document) -> Result { let row = sqlx::query( r#" INSERT INTO documents (id, filename, original_filename, file_path, file_size, mime_type, content, ocr_text, tags, created_at, updated_at, user_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING id, filename, original_filename, file_path, file_size, mime_type, content, ocr_text, tags, created_at, updated_at, user_id "# ) .bind(document.id) .bind(&document.filename) .bind(&document.original_filename) .bind(&document.file_path) .bind(document.file_size) .bind(&document.mime_type) .bind(&document.content) .bind(&document.ocr_text) .bind(&document.tags) .bind(document.created_at) .bind(document.updated_at) .bind(document.user_id) .fetch_one(&self.pool) .await?; Ok(Document { id: row.get("id"), filename: row.get("filename"), original_filename: row.get("original_filename"), file_path: row.get("file_path"), file_size: row.get("file_size"), mime_type: row.get("mime_type"), content: row.get("content"), ocr_text: row.get("ocr_text"), tags: row.get("tags"), created_at: row.get("created_at"), updated_at: row.get("updated_at"), user_id: row.get("user_id"), }) } pub async fn get_documents_by_user(&self, user_id: Uuid, limit: i64, offset: i64) -> Result> { let rows = sqlx::query( r#" SELECT id, filename, original_filename, file_path, file_size, mime_type, content, ocr_text, tags, created_at, updated_at, user_id FROM documents WHERE user_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3 "# ) .bind(user_id) .bind(limit) .bind(offset) .fetch_all(&self.pool) .await?; let documents = rows .into_iter() .map(|row| Document { id: row.get("id"), filename: row.get("filename"), original_filename: row.get("original_filename"), file_path: row.get("file_path"), file_size: row.get("file_size"), mime_type: row.get("mime_type"), content: row.get("content"), ocr_text: row.get("ocr_text"), tags: row.get("tags"), created_at: row.get("created_at"), updated_at: row.get("updated_at"), user_id: row.get("user_id"), }) .collect(); Ok(documents) } pub async fn search_documents(&self, user_id: Uuid, search: SearchRequest) -> Result<(Vec, i64)> { let mut query_builder = sqlx::QueryBuilder::new( r#" SELECT id, filename, original_filename, file_path, file_size, mime_type, content, ocr_text, tags, created_at, updated_at, user_id, ts_rank(to_tsvector('english', COALESCE(content, '') || ' ' || COALESCE(ocr_text, '')), plainto_tsquery('english', "# ); query_builder.push_bind(&search.query); query_builder.push(")) as rank FROM documents WHERE user_id = "); query_builder.push_bind(user_id); query_builder.push(" AND to_tsvector('english', COALESCE(content, '') || ' ' || COALESCE(ocr_text, '')) @@ plainto_tsquery('english', "); query_builder.push_bind(&search.query); query_builder.push(")"); if let Some(tags) = &search.tags { if !tags.is_empty() { query_builder.push(" AND tags && "); query_builder.push_bind(tags); } } if let Some(mime_types) = &search.mime_types { if !mime_types.is_empty() { query_builder.push(" AND mime_type = ANY("); query_builder.push_bind(mime_types); query_builder.push(")"); } } query_builder.push(" ORDER BY rank DESC, created_at DESC"); if let Some(limit) = search.limit { query_builder.push(" LIMIT "); query_builder.push_bind(limit); } if let Some(offset) = search.offset { query_builder.push(" OFFSET "); query_builder.push_bind(offset); } let rows = query_builder.build().fetch_all(&self.pool).await?; let documents = rows .into_iter() .map(|row| Document { id: row.get("id"), filename: row.get("filename"), original_filename: row.get("original_filename"), file_path: row.get("file_path"), file_size: row.get("file_size"), mime_type: row.get("mime_type"), content: row.get("content"), ocr_text: row.get("ocr_text"), tags: row.get("tags"), created_at: row.get("created_at"), updated_at: row.get("updated_at"), user_id: row.get("user_id"), }) .collect(); let total_row = sqlx::query( r#" SELECT COUNT(*) as total FROM documents WHERE user_id = $1 AND to_tsvector('english', COALESCE(content, '') || ' ' || COALESCE(ocr_text, '')) @@ plainto_tsquery('english', $2) "# ) .bind(user_id) .bind(&search.query) .fetch_one(&self.pool) .await?; let total: i64 = total_row.get("total"); Ok((documents, total)) } pub async fn enhanced_search_documents(&self, user_id: Uuid, search: SearchRequest) -> Result<(Vec, i64, u64)> { let start_time = std::time::Instant::now(); // Build search query based on search mode let search_mode = search.search_mode.as_ref().unwrap_or(&SearchMode::Simple); let query_function = match search_mode { SearchMode::Simple => "plainto_tsquery", SearchMode::Phrase => "phraseto_tsquery", SearchMode::Fuzzy => "plainto_tsquery", // Could be enhanced with similarity SearchMode::Boolean => "to_tsquery", }; let mut query_builder = sqlx::QueryBuilder::new(&format!( r#" SELECT id, filename, original_filename, file_path, file_size, mime_type, content, ocr_text, tags, created_at, updated_at, user_id, ts_rank(to_tsvector('english', COALESCE(content, '') || ' ' || COALESCE(ocr_text, '')), {}('english', "#, query_function )); query_builder.push_bind(&search.query); query_builder.push(&format!(")) as rank FROM documents WHERE user_id = ")); query_builder.push_bind(user_id); query_builder.push(&format!(" AND to_tsvector('english', COALESCE(content, '') || ' ' || COALESCE(ocr_text, '')) @@ {}('english', ", query_function)); query_builder.push_bind(&search.query); query_builder.push(")"); if let Some(tags) = &search.tags { if !tags.is_empty() { query_builder.push(" AND tags && "); query_builder.push_bind(tags); } } if let Some(mime_types) = &search.mime_types { if !mime_types.is_empty() { query_builder.push(" AND mime_type = ANY("); query_builder.push_bind(mime_types); query_builder.push(")"); } } query_builder.push(" ORDER BY rank DESC, created_at DESC"); if let Some(limit) = search.limit { query_builder.push(" LIMIT "); query_builder.push_bind(limit); } if let Some(offset) = search.offset { query_builder.push(" OFFSET "); query_builder.push_bind(offset); } let rows = query_builder.build().fetch_all(&self.pool).await?; let include_snippets = search.include_snippets.unwrap_or(true); let snippet_length = search.snippet_length.unwrap_or(200); let mut documents = Vec::new(); for row in rows { let doc_id: Uuid = row.get("id"); let content: Option = row.get("content"); let ocr_text: Option = row.get("ocr_text"); let rank: f32 = row.get("rank"); let snippets = if include_snippets { self.generate_snippets(&search.query, content.as_deref(), ocr_text.as_deref(), snippet_length) } else { Vec::new() }; documents.push(EnhancedDocumentResponse { id: doc_id, filename: row.get("filename"), original_filename: row.get("original_filename"), file_size: row.get("file_size"), mime_type: row.get("mime_type"), tags: row.get("tags"), created_at: row.get("created_at"), has_ocr_text: ocr_text.is_some(), search_rank: Some(rank), snippets, }); } let total_row = sqlx::query(&format!( r#" SELECT COUNT(*) as total FROM documents WHERE user_id = $1 AND to_tsvector('english', COALESCE(content, '') || ' ' || COALESCE(ocr_text, '')) @@ {}('english', $2) "#, query_function )) .bind(user_id) .bind(&search.query) .fetch_one(&self.pool) .await?; let total: i64 = total_row.get("total"); let query_time = start_time.elapsed().as_millis() as u64; Ok((documents, total, query_time)) } fn generate_snippets(&self, query: &str, content: Option<&str>, ocr_text: Option<&str>, snippet_length: i32) -> Vec { let mut snippets = Vec::new(); // Combine content and OCR text let full_text = match (content, ocr_text) { (Some(c), Some(o)) => format!("{} {}", c, o), (Some(c), None) => c.to_string(), (None, Some(o)) => o.to_string(), (None, None) => return snippets, }; // Simple keyword matching for snippets (could be enhanced with better search algorithms) let _query_terms: Vec<&str> = query.split_whitespace().collect(); let text_lower = full_text.to_lowercase(); let query_lower = query.to_lowercase(); // Find matches for (i, _) in text_lower.match_indices(&query_lower) { let snippet_start = if i >= snippet_length as usize / 2 { i - snippet_length as usize / 2 } else { 0 }; let snippet_end = std::cmp::min( snippet_start + snippet_length as usize, full_text.len() ); if snippet_start < full_text.len() { let snippet_text = &full_text[snippet_start..snippet_end]; // Find highlight ranges within this snippet let mut highlight_ranges = Vec::new(); let snippet_lower = snippet_text.to_lowercase(); for (match_start, _) in snippet_lower.match_indices(&query_lower) { highlight_ranges.push(HighlightRange { start: match_start as i32, end: (match_start + query.len()) as i32, }); } snippets.push(SearchSnippet { text: snippet_text.to_string(), start_offset: snippet_start as i32, end_offset: snippet_end as i32, highlight_ranges, }); // Limit to a few snippets per document if snippets.len() >= 3 { break; } } } snippets } pub async fn update_document_ocr(&self, id: Uuid, ocr_text: &str) -> Result<()> { sqlx::query("UPDATE documents SET ocr_text = $1, updated_at = NOW() WHERE id = $2") .bind(ocr_text) .bind(id) .execute(&self.pool) .await?; Ok(()) } pub async fn get_all_users(&self) -> Result> { let rows = sqlx::query( "SELECT id, username, email, password_hash, created_at, updated_at FROM users ORDER BY created_at DESC" ) .fetch_all(&self.pool) .await?; let users = rows .into_iter() .map(|row| User { id: row.get("id"), username: row.get("username"), email: row.get("email"), password_hash: row.get("password_hash"), created_at: row.get("created_at"), updated_at: row.get("updated_at"), }) .collect(); Ok(users) } pub async fn update_user(&self, id: Uuid, username: Option, email: Option, password: Option) -> Result { let user = self.get_user_by_id(id).await?.ok_or_else(|| anyhow::anyhow!("User not found"))?; let username = username.unwrap_or(user.username); let email = email.unwrap_or(user.email); let password_hash = if let Some(pwd) = password { bcrypt::hash(&pwd, 12)? } else { user.password_hash }; let row = sqlx::query( r#" UPDATE users SET username = $1, email = $2, password_hash = $3, updated_at = NOW() WHERE id = $4 RETURNING id, username, email, password_hash, created_at, updated_at "# ) .bind(&username) .bind(&email) .bind(&password_hash) .bind(id) .fetch_one(&self.pool) .await?; Ok(User { id: row.get("id"), username: row.get("username"), email: row.get("email"), password_hash: row.get("password_hash"), created_at: row.get("created_at"), updated_at: row.get("updated_at"), }) } pub async fn delete_user(&self, id: Uuid) -> Result<()> { sqlx::query("DELETE FROM users WHERE id = $1") .bind(id) .execute(&self.pool) .await?; Ok(()) } pub async fn get_user_settings(&self, user_id: Uuid) -> Result> { let row = sqlx::query( "SELECT id, user_id, ocr_language, created_at, updated_at FROM settings WHERE user_id = $1" ) .bind(user_id) .fetch_optional(&self.pool) .await?; match row { Some(row) => Ok(Some(crate::models::Settings { id: row.get("id"), user_id: row.get("user_id"), ocr_language: row.get("ocr_language"), created_at: row.get("created_at"), updated_at: row.get("updated_at"), })), None => Ok(None), } } pub async fn create_or_update_settings(&self, user_id: Uuid, ocr_language: &str) -> Result { let row = sqlx::query( r#" INSERT INTO settings (user_id, ocr_language) VALUES ($1, $2) ON CONFLICT (user_id) DO UPDATE SET ocr_language = $2, updated_at = NOW() RETURNING id, user_id, ocr_language, created_at, updated_at "# ) .bind(user_id) .bind(ocr_language) .fetch_one(&self.pool) .await?; Ok(crate::models::Settings { id: row.get("id"), user_id: row.get("user_id"), ocr_language: row.get("ocr_language"), created_at: row.get("created_at"), updated_at: row.get("updated_at"), }) } }