fi(backend): migrate python code to rust lol
This commit is contained in:
parent
13ea0b3a5c
commit
aeb98acea8
|
|
@ -1,471 +0,0 @@
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
||||||
from typing import List, Optional
|
|
||||||
from uuid import UUID
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
import asyncpg
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from ..database import get_db
|
|
||||||
from ..auth import get_current_user
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/labels", tags=["labels"])
|
|
||||||
|
|
||||||
|
|
||||||
class LabelCreate(BaseModel):
|
|
||||||
name: str = Field(..., max_length=50)
|
|
||||||
description: Optional[str] = None
|
|
||||||
color: str = Field(default="#0969da", regex="^#[0-9a-fA-F]{6}$")
|
|
||||||
background_color: Optional[str] = Field(None, regex="^#[0-9a-fA-F]{6}$")
|
|
||||||
icon: Optional[str] = Field(None, max_length=50)
|
|
||||||
|
|
||||||
|
|
||||||
class LabelUpdate(BaseModel):
|
|
||||||
name: Optional[str] = Field(None, max_length=50)
|
|
||||||
description: Optional[str] = None
|
|
||||||
color: Optional[str] = Field(None, regex="^#[0-9a-fA-F]{6}$")
|
|
||||||
background_color: Optional[str] = Field(None, regex="^#[0-9a-fA-F]{6}$")
|
|
||||||
icon: Optional[str] = Field(None, max_length=50)
|
|
||||||
|
|
||||||
|
|
||||||
class Label(BaseModel):
|
|
||||||
id: UUID
|
|
||||||
user_id: UUID
|
|
||||||
name: str
|
|
||||||
description: Optional[str]
|
|
||||||
color: str
|
|
||||||
background_color: Optional[str]
|
|
||||||
icon: Optional[str]
|
|
||||||
is_system: bool
|
|
||||||
created_at: datetime
|
|
||||||
updated_at: datetime
|
|
||||||
document_count: Optional[int] = 0
|
|
||||||
source_count: Optional[int] = 0
|
|
||||||
|
|
||||||
|
|
||||||
class LabelAssignment(BaseModel):
|
|
||||||
label_ids: List[UUID]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=List[Label])
|
|
||||||
async def get_labels(
|
|
||||||
include_counts: bool = Query(False, description="Include usage counts"),
|
|
||||||
current_user: dict = Depends(get_current_user),
|
|
||||||
db: asyncpg.Connection = Depends(get_db)
|
|
||||||
):
|
|
||||||
"""Get all labels available to the current user."""
|
|
||||||
user_id = current_user["id"]
|
|
||||||
|
|
||||||
if include_counts:
|
|
||||||
query = """
|
|
||||||
SELECT
|
|
||||||
l.id, l.user_id, l.name, l.description, l.color,
|
|
||||||
l.background_color, l.icon, l.is_system, l.created_at, l.updated_at,
|
|
||||||
COUNT(DISTINCT dl.document_id) as document_count,
|
|
||||||
COUNT(DISTINCT sl.source_id) as source_count
|
|
||||||
FROM labels l
|
|
||||||
LEFT JOIN document_labels dl ON l.id = dl.label_id
|
|
||||||
LEFT JOIN source_labels sl ON l.id = sl.label_id
|
|
||||||
WHERE l.user_id = $1 OR l.is_system = TRUE
|
|
||||||
GROUP BY l.id
|
|
||||||
ORDER BY l.name
|
|
||||||
"""
|
|
||||||
else:
|
|
||||||
query = """
|
|
||||||
SELECT
|
|
||||||
id, user_id, name, description, color,
|
|
||||||
background_color, icon, is_system, created_at, updated_at,
|
|
||||||
0 as document_count, 0 as source_count
|
|
||||||
FROM labels
|
|
||||||
WHERE user_id = $1 OR is_system = TRUE
|
|
||||||
ORDER BY name
|
|
||||||
"""
|
|
||||||
|
|
||||||
rows = await db.fetch(query, user_id)
|
|
||||||
return [Label(**row) for row in rows]
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/", response_model=Label)
|
|
||||||
async def create_label(
|
|
||||||
label: LabelCreate,
|
|
||||||
current_user: dict = Depends(get_current_user),
|
|
||||||
db: asyncpg.Connection = Depends(get_db)
|
|
||||||
):
|
|
||||||
"""Create a new label."""
|
|
||||||
user_id = current_user["id"]
|
|
||||||
|
|
||||||
try:
|
|
||||||
row = await db.fetchrow(
|
|
||||||
"""
|
|
||||||
INSERT INTO labels (user_id, name, description, color, background_color, icon)
|
|
||||||
VALUES ($1, $2, $3, $4, $5, $6)
|
|
||||||
RETURNING *
|
|
||||||
""",
|
|
||||||
user_id, label.name, label.description, label.color,
|
|
||||||
label.background_color, label.icon
|
|
||||||
)
|
|
||||||
return Label(**row, document_count=0, source_count=0)
|
|
||||||
except asyncpg.UniqueViolationError:
|
|
||||||
raise HTTPException(status_code=400, detail="Label with this name already exists")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{label_id}", response_model=Label)
|
|
||||||
async def get_label(
|
|
||||||
label_id: UUID,
|
|
||||||
current_user: dict = Depends(get_current_user),
|
|
||||||
db: asyncpg.Connection = Depends(get_db)
|
|
||||||
):
|
|
||||||
"""Get a specific label."""
|
|
||||||
user_id = current_user["id"]
|
|
||||||
|
|
||||||
row = await db.fetchrow(
|
|
||||||
"""
|
|
||||||
SELECT
|
|
||||||
l.*,
|
|
||||||
COUNT(DISTINCT dl.document_id) as document_count,
|
|
||||||
COUNT(DISTINCT sl.source_id) as source_count
|
|
||||||
FROM labels l
|
|
||||||
LEFT JOIN document_labels dl ON l.id = dl.label_id
|
|
||||||
LEFT JOIN source_labels sl ON l.id = sl.label_id
|
|
||||||
WHERE l.id = $1 AND (l.user_id = $2 OR l.is_system = TRUE)
|
|
||||||
GROUP BY l.id
|
|
||||||
""",
|
|
||||||
label_id, user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if not row:
|
|
||||||
raise HTTPException(status_code=404, detail="Label not found")
|
|
||||||
|
|
||||||
return Label(**row)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{label_id}", response_model=Label)
|
|
||||||
async def update_label(
|
|
||||||
label_id: UUID,
|
|
||||||
label: LabelUpdate,
|
|
||||||
current_user: dict = Depends(get_current_user),
|
|
||||||
db: asyncpg.Connection = Depends(get_db)
|
|
||||||
):
|
|
||||||
"""Update a label."""
|
|
||||||
user_id = current_user["id"]
|
|
||||||
|
|
||||||
# Check if label exists and user has permission
|
|
||||||
existing = await db.fetchrow(
|
|
||||||
"SELECT * FROM labels WHERE id = $1 AND user_id = $2 AND is_system = FALSE",
|
|
||||||
label_id, user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if not existing:
|
|
||||||
raise HTTPException(status_code=404, detail="Label not found or cannot be modified")
|
|
||||||
|
|
||||||
# Build update query dynamically
|
|
||||||
updates = []
|
|
||||||
values = [label_id]
|
|
||||||
param_index = 2
|
|
||||||
|
|
||||||
if label.name is not None:
|
|
||||||
updates.append(f"name = ${param_index}")
|
|
||||||
values.append(label.name)
|
|
||||||
param_index += 1
|
|
||||||
|
|
||||||
if label.description is not None:
|
|
||||||
updates.append(f"description = ${param_index}")
|
|
||||||
values.append(label.description)
|
|
||||||
param_index += 1
|
|
||||||
|
|
||||||
if label.color is not None:
|
|
||||||
updates.append(f"color = ${param_index}")
|
|
||||||
values.append(label.color)
|
|
||||||
param_index += 1
|
|
||||||
|
|
||||||
if label.background_color is not None:
|
|
||||||
updates.append(f"background_color = ${param_index}")
|
|
||||||
values.append(label.background_color)
|
|
||||||
param_index += 1
|
|
||||||
|
|
||||||
if label.icon is not None:
|
|
||||||
updates.append(f"icon = ${param_index}")
|
|
||||||
values.append(label.icon)
|
|
||||||
param_index += 1
|
|
||||||
|
|
||||||
if not updates:
|
|
||||||
return Label(**existing, document_count=0, source_count=0)
|
|
||||||
|
|
||||||
updates.append("updated_at = CURRENT_TIMESTAMP")
|
|
||||||
|
|
||||||
try:
|
|
||||||
row = await db.fetchrow(
|
|
||||||
f"""
|
|
||||||
UPDATE labels
|
|
||||||
SET {', '.join(updates)}
|
|
||||||
WHERE id = $1
|
|
||||||
RETURNING *
|
|
||||||
""",
|
|
||||||
*values
|
|
||||||
)
|
|
||||||
return Label(**row, document_count=0, source_count=0)
|
|
||||||
except asyncpg.UniqueViolationError:
|
|
||||||
raise HTTPException(status_code=400, detail="Label with this name already exists")
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{label_id}")
|
|
||||||
async def delete_label(
|
|
||||||
label_id: UUID,
|
|
||||||
current_user: dict = Depends(get_current_user),
|
|
||||||
db: asyncpg.Connection = Depends(get_db)
|
|
||||||
):
|
|
||||||
"""Delete a label."""
|
|
||||||
user_id = current_user["id"]
|
|
||||||
|
|
||||||
result = await db.execute(
|
|
||||||
"DELETE FROM labels WHERE id = $1 AND user_id = $2 AND is_system = FALSE",
|
|
||||||
label_id, user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if result == "DELETE 0":
|
|
||||||
raise HTTPException(status_code=404, detail="Label not found or cannot be deleted")
|
|
||||||
|
|
||||||
return {"message": "Label deleted successfully"}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/documents/{document_id}", response_model=List[Label])
|
|
||||||
async def get_document_labels(
|
|
||||||
document_id: UUID,
|
|
||||||
current_user: dict = Depends(get_current_user),
|
|
||||||
db: asyncpg.Connection = Depends(get_db)
|
|
||||||
):
|
|
||||||
"""Get all labels for a document."""
|
|
||||||
user_id = current_user["id"]
|
|
||||||
|
|
||||||
# Verify document ownership
|
|
||||||
doc = await db.fetchrow(
|
|
||||||
"SELECT id FROM documents WHERE id = $1 AND user_id = $2",
|
|
||||||
document_id, user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if not doc:
|
|
||||||
raise HTTPException(status_code=404, detail="Document not found")
|
|
||||||
|
|
||||||
rows = await db.fetch(
|
|
||||||
"""
|
|
||||||
SELECT l.*, 0 as document_count, 0 as source_count
|
|
||||||
FROM labels l
|
|
||||||
INNER JOIN document_labels dl ON l.id = dl.label_id
|
|
||||||
WHERE dl.document_id = $1
|
|
||||||
ORDER BY l.name
|
|
||||||
""",
|
|
||||||
document_id
|
|
||||||
)
|
|
||||||
|
|
||||||
return [Label(**row) for row in rows]
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/documents/{document_id}", response_model=List[Label])
|
|
||||||
async def update_document_labels(
|
|
||||||
document_id: UUID,
|
|
||||||
assignment: LabelAssignment,
|
|
||||||
current_user: dict = Depends(get_current_user),
|
|
||||||
db: asyncpg.Connection = Depends(get_db)
|
|
||||||
):
|
|
||||||
"""Replace all labels for a document."""
|
|
||||||
user_id = current_user["id"]
|
|
||||||
|
|
||||||
# Verify document ownership
|
|
||||||
doc = await db.fetchrow(
|
|
||||||
"SELECT id FROM documents WHERE id = $1 AND user_id = $2",
|
|
||||||
document_id, user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if not doc:
|
|
||||||
raise HTTPException(status_code=404, detail="Document not found")
|
|
||||||
|
|
||||||
# Verify all labels exist and are accessible
|
|
||||||
label_check = await db.fetch(
|
|
||||||
"""
|
|
||||||
SELECT id FROM labels
|
|
||||||
WHERE id = ANY($1::uuid[]) AND (user_id = $2 OR is_system = TRUE)
|
|
||||||
""",
|
|
||||||
assignment.label_ids, user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(label_check) != len(assignment.label_ids):
|
|
||||||
raise HTTPException(status_code=400, detail="One or more labels not found")
|
|
||||||
|
|
||||||
async with db.transaction():
|
|
||||||
# Remove existing labels
|
|
||||||
await db.execute(
|
|
||||||
"DELETE FROM document_labels WHERE document_id = $1",
|
|
||||||
document_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add new labels
|
|
||||||
if assignment.label_ids:
|
|
||||||
await db.executemany(
|
|
||||||
"""
|
|
||||||
INSERT INTO document_labels (document_id, label_id, assigned_by)
|
|
||||||
VALUES ($1, $2, $3)
|
|
||||||
""",
|
|
||||||
[(document_id, label_id, user_id) for label_id in assignment.label_ids]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return updated labels
|
|
||||||
return await get_document_labels(document_id, current_user, db)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/documents/{document_id}/labels/{label_id}")
|
|
||||||
async def add_document_label(
|
|
||||||
document_id: UUID,
|
|
||||||
label_id: UUID,
|
|
||||||
current_user: dict = Depends(get_current_user),
|
|
||||||
db: asyncpg.Connection = Depends(get_db)
|
|
||||||
):
|
|
||||||
"""Add a single label to a document."""
|
|
||||||
user_id = current_user["id"]
|
|
||||||
|
|
||||||
# Verify document ownership
|
|
||||||
doc = await db.fetchrow(
|
|
||||||
"SELECT id FROM documents WHERE id = $1 AND user_id = $2",
|
|
||||||
document_id, user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if not doc:
|
|
||||||
raise HTTPException(status_code=404, detail="Document not found")
|
|
||||||
|
|
||||||
# Verify label exists and is accessible
|
|
||||||
label = await db.fetchrow(
|
|
||||||
"SELECT id FROM labels WHERE id = $1 AND (user_id = $2 OR is_system = TRUE)",
|
|
||||||
label_id, user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if not label:
|
|
||||||
raise HTTPException(status_code=404, detail="Label not found")
|
|
||||||
|
|
||||||
try:
|
|
||||||
await db.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO document_labels (document_id, label_id, assigned_by)
|
|
||||||
VALUES ($1, $2, $3)
|
|
||||||
""",
|
|
||||||
document_id, label_id, user_id
|
|
||||||
)
|
|
||||||
return {"message": "Label added successfully"}
|
|
||||||
except asyncpg.UniqueViolationError:
|
|
||||||
return {"message": "Label already assigned"}
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/documents/{document_id}/labels/{label_id}")
|
|
||||||
async def remove_document_label(
|
|
||||||
document_id: UUID,
|
|
||||||
label_id: UUID,
|
|
||||||
current_user: dict = Depends(get_current_user),
|
|
||||||
db: asyncpg.Connection = Depends(get_db)
|
|
||||||
):
|
|
||||||
"""Remove a single label from a document."""
|
|
||||||
user_id = current_user["id"]
|
|
||||||
|
|
||||||
# Verify document ownership
|
|
||||||
doc = await db.fetchrow(
|
|
||||||
"SELECT id FROM documents WHERE id = $1 AND user_id = $2",
|
|
||||||
document_id, user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if not doc:
|
|
||||||
raise HTTPException(status_code=404, detail="Document not found")
|
|
||||||
|
|
||||||
result = await db.execute(
|
|
||||||
"DELETE FROM document_labels WHERE document_id = $1 AND label_id = $2",
|
|
||||||
document_id, label_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if result == "DELETE 0":
|
|
||||||
raise HTTPException(status_code=404, detail="Label not found on document")
|
|
||||||
|
|
||||||
return {"message": "Label removed successfully"}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/bulk/documents", response_model=dict)
|
|
||||||
async def bulk_update_document_labels(
|
|
||||||
document_ids: List[UUID],
|
|
||||||
assignment: LabelAssignment,
|
|
||||||
mode: str = Query("replace", regex="^(replace|add|remove)$"),
|
|
||||||
current_user: dict = Depends(get_current_user),
|
|
||||||
db: asyncpg.Connection = Depends(get_db)
|
|
||||||
):
|
|
||||||
"""Bulk update labels for multiple documents."""
|
|
||||||
user_id = current_user["id"]
|
|
||||||
|
|
||||||
# Verify document ownership
|
|
||||||
docs = await db.fetch(
|
|
||||||
"SELECT id FROM documents WHERE id = ANY($1::uuid[]) AND user_id = $2",
|
|
||||||
document_ids, user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(docs) != len(document_ids):
|
|
||||||
raise HTTPException(status_code=400, detail="One or more documents not found")
|
|
||||||
|
|
||||||
# Verify labels
|
|
||||||
if assignment.label_ids:
|
|
||||||
label_check = await db.fetch(
|
|
||||||
"""
|
|
||||||
SELECT id FROM labels
|
|
||||||
WHERE id = ANY($1::uuid[]) AND (user_id = $2 OR is_system = TRUE)
|
|
||||||
""",
|
|
||||||
assignment.label_ids, user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(label_check) != len(assignment.label_ids):
|
|
||||||
raise HTTPException(status_code=400, detail="One or more labels not found")
|
|
||||||
|
|
||||||
async with db.transaction():
|
|
||||||
if mode == "replace":
|
|
||||||
# Remove all existing labels
|
|
||||||
await db.execute(
|
|
||||||
"DELETE FROM document_labels WHERE document_id = ANY($1::uuid[])",
|
|
||||||
document_ids
|
|
||||||
)
|
|
||||||
# Add new labels
|
|
||||||
if assignment.label_ids:
|
|
||||||
values = [
|
|
||||||
(doc_id, label_id, user_id)
|
|
||||||
for doc_id in document_ids
|
|
||||||
for label_id in assignment.label_ids
|
|
||||||
]
|
|
||||||
await db.executemany(
|
|
||||||
"""
|
|
||||||
INSERT INTO document_labels (document_id, label_id, assigned_by)
|
|
||||||
VALUES ($1, $2, $3)
|
|
||||||
ON CONFLICT DO NOTHING
|
|
||||||
""",
|
|
||||||
values
|
|
||||||
)
|
|
||||||
|
|
||||||
elif mode == "add":
|
|
||||||
if assignment.label_ids:
|
|
||||||
values = [
|
|
||||||
(doc_id, label_id, user_id)
|
|
||||||
for doc_id in document_ids
|
|
||||||
for label_id in assignment.label_ids
|
|
||||||
]
|
|
||||||
await db.executemany(
|
|
||||||
"""
|
|
||||||
INSERT INTO document_labels (document_id, label_id, assigned_by)
|
|
||||||
VALUES ($1, $2, $3)
|
|
||||||
ON CONFLICT DO NOTHING
|
|
||||||
""",
|
|
||||||
values
|
|
||||||
)
|
|
||||||
|
|
||||||
elif mode == "remove":
|
|
||||||
if assignment.label_ids:
|
|
||||||
await db.execute(
|
|
||||||
"""
|
|
||||||
DELETE FROM document_labels
|
|
||||||
WHERE document_id = ANY($1::uuid[])
|
|
||||||
AND label_id = ANY($2::uuid[])
|
|
||||||
""",
|
|
||||||
document_ids, assignment.label_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"message": f"Labels {mode}d successfully",
|
|
||||||
"documents_updated": len(document_ids)
|
|
||||||
}
|
|
||||||
|
|
@ -67,7 +67,9 @@ pub struct LabelQuery {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, ToSchema)]
|
#[derive(Debug, Deserialize, ToSchema)]
|
||||||
pub struct BulkUpdateMode {
|
pub struct BulkUpdateRequest {
|
||||||
|
pub document_ids: Vec<Uuid>,
|
||||||
|
pub label_ids: Vec<Uuid>,
|
||||||
#[serde(default = "default_bulk_mode")]
|
#[serde(default = "default_bulk_mode")]
|
||||||
pub mode: String, // "replace", "add", or "remove"
|
pub mode: String, // "replace", "add", or "remove"
|
||||||
}
|
}
|
||||||
|
|
@ -688,26 +690,147 @@ pub async fn remove_document_label(
|
||||||
path = "/api/labels/bulk/documents",
|
path = "/api/labels/bulk/documents",
|
||||||
tag = "labels",
|
tag = "labels",
|
||||||
security(("bearer_auth" = [])),
|
security(("bearer_auth" = [])),
|
||||||
request_body = LabelAssignment,
|
request_body = BulkUpdateRequest,
|
||||||
params(
|
|
||||||
("mode" = String, Query, description = "Operation mode: replace, add, or remove")
|
|
||||||
),
|
|
||||||
responses(
|
responses(
|
||||||
(status = 200, description = "Bulk operation completed"),
|
(status = 200, description = "Bulk operation completed"),
|
||||||
(status = 400, description = "Invalid input"),
|
(status = 400, description = "Invalid input"),
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn bulk_update_document_labels(
|
pub async fn bulk_update_document_labels(
|
||||||
Query(mode_query): Query<BulkUpdateMode>,
|
State(state): State<Arc<AppState>>,
|
||||||
State(_state): State<Arc<AppState>>,
|
auth_user: AuthUser,
|
||||||
_auth_user: AuthUser,
|
Json(payload): Json<BulkUpdateRequest>,
|
||||||
Json(_payload): Json<BulkUpdateMode>, // This should actually be a combined payload
|
|
||||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||||
// Note: This is a simplified implementation. In a real scenario, you'd want
|
let user_id = auth_user.user.id;
|
||||||
// a more complex payload structure that includes both document_ids and label_ids
|
|
||||||
// For now, returning a placeholder response
|
// Validate mode
|
||||||
|
if !["replace", "add", "remove"].contains(&payload.mode.as_str()) {
|
||||||
|
return Err(StatusCode::BAD_REQUEST);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify document ownership
|
||||||
|
let doc_count = sqlx::query(
|
||||||
|
"SELECT COUNT(*) as count FROM documents WHERE id = ANY($1) AND user_id = $2"
|
||||||
|
)
|
||||||
|
.bind(&payload.document_ids)
|
||||||
|
.bind(user_id)
|
||||||
|
.fetch_one(state.db.get_pool())
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
tracing::error!("Failed to verify document ownership: {}", e);
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let count: i64 = doc_count.try_get("count").unwrap_or(0);
|
||||||
|
if count as usize != payload.document_ids.len() {
|
||||||
|
return Err(StatusCode::BAD_REQUEST);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify labels exist and are accessible (if any labels provided)
|
||||||
|
if !payload.label_ids.is_empty() {
|
||||||
|
let label_count = sqlx::query(
|
||||||
|
"SELECT COUNT(*) as count FROM labels WHERE id = ANY($1) AND (user_id = $2 OR is_system = TRUE)"
|
||||||
|
)
|
||||||
|
.bind(&payload.label_ids)
|
||||||
|
.bind(user_id)
|
||||||
|
.fetch_one(state.db.get_pool())
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
tracing::error!("Failed to verify labels: {}", e);
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let count: i64 = label_count.try_get("count").unwrap_or(0);
|
||||||
|
if count as usize != payload.label_ids.len() {
|
||||||
|
return Err(StatusCode::BAD_REQUEST);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Begin transaction
|
||||||
|
let mut tx = state.db.get_pool().begin().await.map_err(|e| {
|
||||||
|
tracing::error!("Failed to begin transaction: {}", e);
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
})?;
|
||||||
|
|
||||||
|
match payload.mode.as_str() {
|
||||||
|
"replace" => {
|
||||||
|
// Remove all existing labels for these documents
|
||||||
|
sqlx::query(
|
||||||
|
"DELETE FROM document_labels WHERE document_id = ANY($1)"
|
||||||
|
)
|
||||||
|
.bind(&payload.document_ids)
|
||||||
|
.execute(&mut *tx)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
tracing::error!("Failed to remove existing labels: {}", e);
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Add new labels
|
||||||
|
if !payload.label_ids.is_empty() {
|
||||||
|
for document_id in &payload.document_ids {
|
||||||
|
for label_id in &payload.label_ids {
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO document_labels (document_id, label_id, assigned_by) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING"
|
||||||
|
)
|
||||||
|
.bind(document_id)
|
||||||
|
.bind(label_id)
|
||||||
|
.bind(user_id)
|
||||||
|
.execute(&mut *tx)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
tracing::error!("Failed to add label: {}", e);
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"add" => {
|
||||||
|
if !payload.label_ids.is_empty() {
|
||||||
|
for document_id in &payload.document_ids {
|
||||||
|
for label_id in &payload.label_ids {
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO document_labels (document_id, label_id, assigned_by) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING"
|
||||||
|
)
|
||||||
|
.bind(document_id)
|
||||||
|
.bind(label_id)
|
||||||
|
.bind(user_id)
|
||||||
|
.execute(&mut *tx)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
tracing::error!("Failed to add label: {}", e);
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"remove" => {
|
||||||
|
if !payload.label_ids.is_empty() {
|
||||||
|
sqlx::query(
|
||||||
|
"DELETE FROM document_labels WHERE document_id = ANY($1) AND label_id = ANY($2)"
|
||||||
|
)
|
||||||
|
.bind(&payload.document_ids)
|
||||||
|
.bind(&payload.label_ids)
|
||||||
|
.execute(&mut *tx)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
tracing::error!("Failed to remove labels: {}", e);
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => return Err(StatusCode::BAD_REQUEST)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.commit().await.map_err(|e| {
|
||||||
|
tracing::error!("Failed to commit transaction: {}", e);
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
})?;
|
||||||
|
|
||||||
Ok(Json(serde_json::json!({
|
Ok(Json(serde_json::json!({
|
||||||
"message": "Bulk update functionality placeholder",
|
"message": format!("Labels {}d successfully", payload.mode),
|
||||||
"mode": mode_query.mode
|
"documents_updated": payload.document_ids.len()
|
||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
Loading…
Reference in New Issue