Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions src/backend/common/database/cosmosdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,19 +94,21 @@ async def create_batch(self, user_id: str, batch_id: UUID) -> BatchRecord:
batchexists = await self.batch_container.read_item(
item=str(batch_id), partition_key=str(batch_id)
)
if batchexists.get("user_id") != user_id:
self.logger.error("Batch belongs to a different user", batch_id=str(batch_id))
raise CosmosResourceNotFoundError(message="Batch not found")
self.logger.info("Returning existing batch record", batch_id=str(batch_id))
return BatchRecord.fromdb(batchexists)
except CosmosResourceNotFoundError:
if attempt < 2:
self.logger.info("Batch read returned 404 after conflict, retrying", batch_id=str(batch_id), attempt=attempt + 1)
await asyncio.sleep(0.5 * (attempt + 1))
else:
raise RuntimeError(
f"Batch {batch_id} already exists but could not be read after retries"
)
continue
raise RuntimeError(
f"Batch {batch_id} already exists but could not be read after retries"
)

if batchexists.get("user_id") != user_id:
self.logger.error("Batch belongs to a different user", batch_id=str(batch_id))
raise PermissionError("Batch not found")

self.logger.info("Returning existing batch record", batch_id=str(batch_id))
return BatchRecord.fromdb(batchexists)

except Exception as e:
self.logger.error("Failed to create batch", error=str(e))
Expand Down
2 changes: 1 addition & 1 deletion src/backend/common/services/batch_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ async def upload_file_to_batch(self, batch_id: str, user_id: str, file: UploadFi
self.logger.info("File uploaded to blob storage", filename=file.filename, batch_id=batch_id)

# Create file entry
file_record_obj = await self.database.add_file(batch_id, file_id, file.filename, blob_path)
file_record_obj = await self.database.add_file(UUID(batch_id), UUID(file_id), file.filename, blob_path)
file_record_dict = getattr(file_record_obj, "dict", None)
file_record = file_record_dict() if callable(file_record_dict) else file_record_obj

Expand Down
84 changes: 83 additions & 1 deletion src/tests/backend/common/database/cosmosdb_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from uuid import uuid4 # noqa: E402

from azure.cosmos.aio import CosmosClient # noqa: E402
from azure.cosmos.exceptions import CosmosResourceExistsError # noqa: E402
from azure.cosmos.exceptions import CosmosResourceExistsError, CosmosResourceNotFoundError # noqa: E402

from common.database.cosmosdb import ( # noqa: E402
CosmosDBClient,
Expand Down Expand Up @@ -188,6 +188,84 @@ async def test_create_batch_exists(cosmos_db_client, mocker):
)


@pytest.mark.asyncio
async def test_create_batch_conflict_retry_on_404(cosmos_db_client, mocker):
    """Test that read_item is retried when it returns 404 after a 409 conflict.

    ``create_item`` is forced to raise ``CosmosResourceExistsError``; the first
    ``read_item`` call then raises ``CosmosResourceNotFoundError`` and the second
    returns the stored record, which ``create_batch`` is expected to return.
    """
    user_id = "user_1"
    batch_id = uuid4()

    # The retry path awaits asyncio.sleep between attempts; stub it out so the
    # test exercises the retry logic without waiting in real time.
    sleep_mock = mocker.patch("asyncio.sleep", new_callable=AsyncMock)

    mock_batch_container = mock.MagicMock()
    mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
    mock_batch_container.create_item = AsyncMock(side_effect=CosmosResourceExistsError)

    existing_batch = {
        "id": str(batch_id),
        "batch_id": str(batch_id),
        "user_id": user_id,
        "file_count": 0,
        "created_at": datetime.now(timezone.utc).isoformat(),
        "updated_at": datetime.now(timezone.utc).isoformat(),
        "status": ProcessStatus.READY_TO_PROCESS,
    }

    # First call raises 404, second call succeeds
    mock_batch_container.read_item = AsyncMock(
        side_effect=[CosmosResourceNotFoundError(message="Not found"), existing_batch]
    )

    batch = await cosmos_db_client.create_batch(user_id, batch_id)

    assert batch.batch_id == batch_id
    assert batch.user_id == user_id
    assert mock_batch_container.read_item.call_count == 2
    # Exactly one back-off happened: after the single failed read.
    assert sleep_mock.await_count == 1


@pytest.mark.asyncio
async def test_create_batch_conflict_cross_user(cosmos_db_client, mocker):
    """Verify a conflicting batch owned by another user raises PermissionError.

    ``create_item`` is mocked to raise ``CosmosResourceExistsError`` and
    ``read_item`` returns a record whose ``user_id`` differs from the caller's,
    so ``create_batch`` must refuse to hand the record back.
    """
    requesting_user = "user_1"
    batch_id = uuid4()
    batch_key = str(batch_id)

    # Stored record that belongs to somebody other than the requesting user.
    foreign_batch = dict(
        id=batch_key,
        batch_id=batch_key,
        user_id="different_user",
        file_count=0,
        created_at=datetime.now(timezone.utc).isoformat(),
        updated_at=datetime.now(timezone.utc).isoformat(),
        status=ProcessStatus.READY_TO_PROCESS,
    )

    container = mock.MagicMock()
    container.create_item = AsyncMock(side_effect=CosmosResourceExistsError)
    container.read_item = AsyncMock(return_value=foreign_batch)
    mocker.patch.object(cosmos_db_client, "batch_container", container)

    with pytest.raises(PermissionError, match="Batch not found"):
        await cosmos_db_client.create_batch(requesting_user, batch_id)


@pytest.mark.asyncio
async def test_create_batch_conflict_exhausted_retries(cosmos_db_client, mocker):
    """Test that RuntimeError is raised when read_item returns 404 after all retries.

    ``create_item`` always raises ``CosmosResourceExistsError`` and every
    ``read_item`` call raises ``CosmosResourceNotFoundError``, so ``create_batch``
    must give up after its final attempt.
    """
    user_id = "user_1"
    batch_id = uuid4()

    # The retry path awaits asyncio.sleep between attempts; stub it out so the
    # test does not spend real time in the back-off delays.
    sleep_mock = mocker.patch("asyncio.sleep", new_callable=AsyncMock)

    mock_batch_container = mock.MagicMock()
    mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
    mock_batch_container.create_item = AsyncMock(side_effect=CosmosResourceExistsError)

    # All 3 attempts raise 404
    mock_batch_container.read_item = AsyncMock(
        side_effect=CosmosResourceNotFoundError(message="Not found")
    )

    with pytest.raises(RuntimeError, match="already exists but could not be read after retries"):
        await cosmos_db_client.create_batch(user_id, batch_id)

    assert mock_batch_container.read_item.call_count == 3
    # A back-off follows each failed read except the last one, which raises.
    assert sleep_mock.await_count == 2


@pytest.mark.asyncio
async def test_create_batch_exception(cosmos_db_client, mocker):
user_id = "user_1"
Expand Down Expand Up @@ -487,6 +565,8 @@ async def mock_query_items(query, parameters, **kwargs):
assert file["status"] == ProcessStatus.READY_TO_PROCESS

mock_file_container.query_items.assert_called_once()
call_kwargs = mock_file_container.query_items.call_args
assert call_kwargs.kwargs.get("partition_key") == file_id


@pytest.mark.asyncio
Expand Down Expand Up @@ -612,6 +692,8 @@ async def mock_query_items(query, parameters, **kwargs):
assert batch["status"] == ProcessStatus.READY_TO_PROCESS

mock_batch_container.query_items.assert_called_once()
call_kwargs = mock_batch_container.query_items.call_args
assert call_kwargs.kwargs.get("partition_key") == batch_id


@pytest.mark.asyncio
Expand Down
Loading