diff --git a/src/backend/common/database/cosmosdb.py b/src/backend/common/database/cosmosdb.py index 46c020d0..8c56286a 100644 --- a/src/backend/common/database/cosmosdb.py +++ b/src/backend/common/database/cosmosdb.py @@ -94,19 +94,21 @@ async def create_batch(self, user_id: str, batch_id: UUID) -> BatchRecord: batchexists = await self.batch_container.read_item( item=str(batch_id), partition_key=str(batch_id) ) - if batchexists.get("user_id") != user_id: - self.logger.error("Batch belongs to a different user", batch_id=str(batch_id)) - raise CosmosResourceNotFoundError(message="Batch not found") - self.logger.info("Returning existing batch record", batch_id=str(batch_id)) - return BatchRecord.fromdb(batchexists) except CosmosResourceNotFoundError: if attempt < 2: self.logger.info("Batch read returned 404 after conflict, retrying", batch_id=str(batch_id), attempt=attempt + 1) await asyncio.sleep(0.5 * (attempt + 1)) - else: - raise RuntimeError( - f"Batch {batch_id} already exists but could not be read after retries" - ) + continue + raise RuntimeError( + f"Batch {batch_id} already exists but could not be read after retries" + ) + + if batchexists.get("user_id") != user_id: + self.logger.error("Batch belongs to a different user", batch_id=str(batch_id)) + raise PermissionError("Batch not found") + + self.logger.info("Returning existing batch record", batch_id=str(batch_id)) + return BatchRecord.fromdb(batchexists) except Exception as e: self.logger.error("Failed to create batch", error=str(e)) diff --git a/src/backend/common/services/batch_service.py b/src/backend/common/services/batch_service.py index 27a874e0..efce27fa 100644 --- a/src/backend/common/services/batch_service.py +++ b/src/backend/common/services/batch_service.py @@ -288,7 +288,7 @@ async def upload_file_to_batch(self, batch_id: str, user_id: str, file: UploadFi self.logger.info("File uploaded to blob storage", filename=file.filename, batch_id=batch_id) # Create file entry - file_record_obj = await self.database.add_file(batch_id, file_id, file.filename, blob_path) + file_record_obj = await self.database.add_file(UUID(batch_id), UUID(file_id), file.filename, blob_path) file_record_dict = getattr(file_record_obj, "dict", None) file_record = file_record_dict() if callable(file_record_dict) else file_record_obj diff --git a/src/tests/backend/common/database/cosmosdb_test.py b/src/tests/backend/common/database/cosmosdb_test.py index d08fb015..faaaaa09 100644 --- a/src/tests/backend/common/database/cosmosdb_test.py +++ b/src/tests/backend/common/database/cosmosdb_test.py @@ -11,7 +11,7 @@ from uuid import uuid4 # noqa: E402 from azure.cosmos.aio import CosmosClient # noqa: E402 -from azure.cosmos.exceptions import CosmosResourceExistsError # noqa: E402 +from azure.cosmos.exceptions import CosmosResourceExistsError, CosmosResourceNotFoundError # noqa: E402 from common.database.cosmosdb import ( # noqa: E402 CosmosDBClient, @@ -188,6 +188,84 @@ async def test_create_batch_exists(cosmos_db_client, mocker): ) +@pytest.mark.asyncio +async def test_create_batch_conflict_retry_on_404(cosmos_db_client, mocker): + """Test that read_item is retried when it returns 404 after a 409 conflict.""" + user_id = "user_1" + batch_id = uuid4() + + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + mock_batch_container.create_item = AsyncMock(side_effect=CosmosResourceExistsError) + + existing_batch = { + "id": str(batch_id), + "batch_id": str(batch_id), + "user_id": user_id, + "file_count": 0, + "created_at": datetime.now(timezone.utc).isoformat(), + "updated_at": datetime.now(timezone.utc).isoformat(), + "status": ProcessStatus.READY_TO_PROCESS, + } + + # First call raises 404, second call succeeds + mock_batch_container.read_item = AsyncMock( + side_effect=[CosmosResourceNotFoundError(message="Not found"), existing_batch] + ) + + batch = await cosmos_db_client.create_batch(user_id, batch_id) + + assert batch.batch_id == batch_id + assert batch.user_id == user_id + assert mock_batch_container.read_item.call_count == 2 + + +@pytest.mark.asyncio +async def test_create_batch_conflict_cross_user(cosmos_db_client, mocker): + """Test that a PermissionError is raised when the batch belongs to a different user.""" + user_id = "user_1" + batch_id = uuid4() + + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + mock_batch_container.create_item = AsyncMock(side_effect=CosmosResourceExistsError) + + existing_batch = { + "id": str(batch_id), + "batch_id": str(batch_id), + "user_id": "different_user", + "file_count": 0, + "created_at": datetime.now(timezone.utc).isoformat(), + "updated_at": datetime.now(timezone.utc).isoformat(), + "status": ProcessStatus.READY_TO_PROCESS, + } + mock_batch_container.read_item = AsyncMock(return_value=existing_batch) + + with pytest.raises(PermissionError, match="Batch not found"): + await cosmos_db_client.create_batch(user_id, batch_id) + + +@pytest.mark.asyncio +async def test_create_batch_conflict_exhausted_retries(cosmos_db_client, mocker): + """Test that RuntimeError is raised when read_item returns 404 after all retries.""" + user_id = "user_1" + batch_id = uuid4() + + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + mock_batch_container.create_item = AsyncMock(side_effect=CosmosResourceExistsError) + + # All 3 attempts raise 404 + mock_batch_container.read_item = AsyncMock( + side_effect=CosmosResourceNotFoundError(message="Not found") + ) + + with pytest.raises(RuntimeError, match="already exists but could not be read after retries"): + await cosmos_db_client.create_batch(user_id, batch_id) + + assert mock_batch_container.read_item.call_count == 3 + + @pytest.mark.asyncio async def test_create_batch_exception(cosmos_db_client, mocker): user_id = "user_1" @@ -487,6 +565,8 @@ async def mock_query_items(query, parameters, **kwargs): assert file["status"] == ProcessStatus.READY_TO_PROCESS mock_file_container.query_items.assert_called_once() + call_kwargs = mock_file_container.query_items.call_args + assert call_kwargs.kwargs.get("partition_key") == file_id @pytest.mark.asyncio @@ -612,6 +692,8 @@ async def mock_query_items(query, parameters, **kwargs): assert batch["status"] == ProcessStatus.READY_TO_PROCESS mock_batch_container.query_items.assert_called_once() + call_kwargs = mock_batch_container.query_items.call_args + assert call_kwargs.kwargs.get("partition_key") == batch_id @pytest.mark.asyncio