|
11 | 11 | from uuid import uuid4 # noqa: E402 |
12 | 12 |
|
13 | 13 | from azure.cosmos.aio import CosmosClient # noqa: E402 |
14 | | -from azure.cosmos.exceptions import CosmosResourceExistsError # noqa: E402 |
| 14 | +from azure.cosmos.exceptions import CosmosResourceExistsError, CosmosResourceNotFoundError # noqa: E402 |
15 | 15 |
|
16 | 16 | from common.database.cosmosdb import ( # noqa: E402 |
17 | 17 | CosmosDBClient, |
@@ -188,6 +188,84 @@ async def test_create_batch_exists(cosmos_db_client, mocker): |
188 | 188 | ) |
189 | 189 |
|
190 | 190 |
|
| 191 | +@pytest.mark.asyncio |
| 192 | +async def test_create_batch_conflict_retry_on_404(cosmos_db_client, mocker): |
| 193 | + """Test that read_item is retried when it returns 404 after a 409 conflict.""" |
| 194 | + user_id = "user_1" |
| 195 | + batch_id = uuid4() |
| 196 | + |
| 197 | + mock_batch_container = mock.MagicMock() |
| 198 | + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) |
| 199 | + mock_batch_container.create_item = AsyncMock(side_effect=CosmosResourceExistsError) |
| 200 | + |
| 201 | + existing_batch = { |
| 202 | + "id": str(batch_id), |
| 203 | + "batch_id": str(batch_id), |
| 204 | + "user_id": user_id, |
| 205 | + "file_count": 0, |
| 206 | + "created_at": datetime.now(timezone.utc).isoformat(), |
| 207 | + "updated_at": datetime.now(timezone.utc).isoformat(), |
| 208 | + "status": ProcessStatus.READY_TO_PROCESS, |
| 209 | + } |
| 210 | + |
| 211 | + # First call raises 404, second call succeeds |
| 212 | + mock_batch_container.read_item = AsyncMock( |
| 213 | + side_effect=[CosmosResourceNotFoundError(message="Not found"), existing_batch] |
| 214 | + ) |
| 215 | + |
| 216 | + batch = await cosmos_db_client.create_batch(user_id, batch_id) |
| 217 | + |
| 218 | + assert batch.batch_id == batch_id |
| 219 | + assert batch.user_id == user_id |
| 220 | + assert mock_batch_container.read_item.call_count == 2 |
| 221 | + |
| 222 | + |
| 223 | +@pytest.mark.asyncio |
| 224 | +async def test_create_batch_conflict_cross_user(cosmos_db_client, mocker): |
| 225 | + """Test that a PermissionError is raised when the batch belongs to a different user.""" |
| 226 | + user_id = "user_1" |
| 227 | + batch_id = uuid4() |
| 228 | + |
| 229 | + mock_batch_container = mock.MagicMock() |
| 230 | + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) |
| 231 | + mock_batch_container.create_item = AsyncMock(side_effect=CosmosResourceExistsError) |
| 232 | + |
| 233 | + existing_batch = { |
| 234 | + "id": str(batch_id), |
| 235 | + "batch_id": str(batch_id), |
| 236 | + "user_id": "different_user", |
| 237 | + "file_count": 0, |
| 238 | + "created_at": datetime.now(timezone.utc).isoformat(), |
| 239 | + "updated_at": datetime.now(timezone.utc).isoformat(), |
| 240 | + "status": ProcessStatus.READY_TO_PROCESS, |
| 241 | + } |
| 242 | + mock_batch_container.read_item = AsyncMock(return_value=existing_batch) |
| 243 | + |
| 244 | + with pytest.raises(PermissionError, match="Batch not found"): |
| 245 | + await cosmos_db_client.create_batch(user_id, batch_id) |
| 246 | + |
| 247 | + |
| 248 | +@pytest.mark.asyncio |
| 249 | +async def test_create_batch_conflict_exhausted_retries(cosmos_db_client, mocker): |
| 250 | + """Test that RuntimeError is raised when read_item returns 404 after all retries.""" |
| 251 | + user_id = "user_1" |
| 252 | + batch_id = uuid4() |
| 253 | + |
| 254 | + mock_batch_container = mock.MagicMock() |
| 255 | + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) |
| 256 | + mock_batch_container.create_item = AsyncMock(side_effect=CosmosResourceExistsError) |
| 257 | + |
| 258 | + # All 3 attempts raise 404 |
| 259 | + mock_batch_container.read_item = AsyncMock( |
| 260 | + side_effect=CosmosResourceNotFoundError(message="Not found") |
| 261 | + ) |
| 262 | + |
| 263 | + with pytest.raises(RuntimeError, match="already exists but could not be read after retries"): |
| 264 | + await cosmos_db_client.create_batch(user_id, batch_id) |
| 265 | + |
| 266 | + assert mock_batch_container.read_item.call_count == 3 |
| 267 | + |
| 268 | + |
191 | 269 | @pytest.mark.asyncio |
192 | 270 | async def test_create_batch_exception(cosmos_db_client, mocker): |
193 | 271 | user_id = "user_1" |
@@ -487,6 +565,8 @@ async def mock_query_items(query, parameters, **kwargs): |
487 | 565 | assert file["status"] == ProcessStatus.READY_TO_PROCESS |
488 | 566 |
|
489 | 567 | mock_file_container.query_items.assert_called_once() |
| 568 | + call_kwargs = mock_file_container.query_items.call_args |
| 569 | + assert call_kwargs.kwargs.get("partition_key") == file_id |
490 | 570 |
|
491 | 571 |
|
492 | 572 | @pytest.mark.asyncio |
@@ -612,6 +692,8 @@ async def mock_query_items(query, parameters, **kwargs): |
612 | 692 | assert batch["status"] == ProcessStatus.READY_TO_PROCESS |
613 | 693 |
|
614 | 694 | mock_batch_container.query_items.assert_called_once() |
| 695 | + call_kwargs = mock_batch_container.query_items.call_args |
| 696 | + assert call_kwargs.kwargs.get("partition_key") == batch_id |
615 | 697 |
|
616 | 698 |
|
617 | 699 | @pytest.mark.asyncio |
|
0 commit comments