Skip to content

Commit 75e0523

Browse files
Merge pull request #399 from microsoft/psl-pk-cosavmfix
fix: fix copilot comments - enhance batch handling in CosmosDBClient and add tests for conflict scenarios
2 parents 1d4a220 + fc8de18 commit 75e0523

3 files changed

Lines changed: 95 additions & 11 deletions

File tree

src/backend/common/database/cosmosdb.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,19 +94,21 @@ async def create_batch(self, user_id: str, batch_id: UUID) -> BatchRecord:
9494
batchexists = await self.batch_container.read_item(
9595
item=str(batch_id), partition_key=str(batch_id)
9696
)
97-
if batchexists.get("user_id") != user_id:
98-
self.logger.error("Batch belongs to a different user", batch_id=str(batch_id))
99-
raise CosmosResourceNotFoundError(message="Batch not found")
100-
self.logger.info("Returning existing batch record", batch_id=str(batch_id))
101-
return BatchRecord.fromdb(batchexists)
10297
except CosmosResourceNotFoundError:
10398
if attempt < 2:
10499
self.logger.info("Batch read returned 404 after conflict, retrying", batch_id=str(batch_id), attempt=attempt + 1)
105100
await asyncio.sleep(0.5 * (attempt + 1))
106-
else:
107-
raise RuntimeError(
108-
f"Batch {batch_id} already exists but could not be read after retries"
109-
)
101+
continue
102+
raise RuntimeError(
103+
f"Batch {batch_id} already exists but could not be read after retries"
104+
)
105+
106+
if batchexists.get("user_id") != user_id:
107+
self.logger.error("Batch belongs to a different user", batch_id=str(batch_id))
108+
raise PermissionError("Batch not found")
109+
110+
self.logger.info("Returning existing batch record", batch_id=str(batch_id))
111+
return BatchRecord.fromdb(batchexists)
110112

111113
except Exception as e:
112114
self.logger.error("Failed to create batch", error=str(e))

src/backend/common/services/batch_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ async def upload_file_to_batch(self, batch_id: str, user_id: str, file: UploadFi
288288
self.logger.info("File uploaded to blob storage", filename=file.filename, batch_id=batch_id)
289289

290290
# Create file entry
291-
file_record_obj = await self.database.add_file(batch_id, file_id, file.filename, blob_path)
291+
file_record_obj = await self.database.add_file(UUID(batch_id), UUID(file_id), file.filename, blob_path)
292292
file_record_dict = getattr(file_record_obj, "dict", None)
293293
file_record = file_record_dict() if callable(file_record_dict) else file_record_obj
294294

src/tests/backend/common/database/cosmosdb_test.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from uuid import uuid4 # noqa: E402
1212

1313
from azure.cosmos.aio import CosmosClient # noqa: E402
14-
from azure.cosmos.exceptions import CosmosResourceExistsError # noqa: E402
14+
from azure.cosmos.exceptions import CosmosResourceExistsError, CosmosResourceNotFoundError # noqa: E402
1515

1616
from common.database.cosmosdb import ( # noqa: E402
1717
CosmosDBClient,
@@ -188,6 +188,84 @@ async def test_create_batch_exists(cosmos_db_client, mocker):
188188
)
189189

190190

191+
@pytest.mark.asyncio
192+
async def test_create_batch_conflict_retry_on_404(cosmos_db_client, mocker):
193+
"""Test that read_item is retried when it returns 404 after a 409 conflict."""
194+
user_id = "user_1"
195+
batch_id = uuid4()
196+
197+
mock_batch_container = mock.MagicMock()
198+
mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
199+
mock_batch_container.create_item = AsyncMock(side_effect=CosmosResourceExistsError)
200+
201+
existing_batch = {
202+
"id": str(batch_id),
203+
"batch_id": str(batch_id),
204+
"user_id": user_id,
205+
"file_count": 0,
206+
"created_at": datetime.now(timezone.utc).isoformat(),
207+
"updated_at": datetime.now(timezone.utc).isoformat(),
208+
"status": ProcessStatus.READY_TO_PROCESS,
209+
}
210+
211+
# First call raises 404, second call succeeds
212+
mock_batch_container.read_item = AsyncMock(
213+
side_effect=[CosmosResourceNotFoundError(message="Not found"), existing_batch]
214+
)
215+
216+
batch = await cosmos_db_client.create_batch(user_id, batch_id)
217+
218+
assert batch.batch_id == batch_id
219+
assert batch.user_id == user_id
220+
assert mock_batch_container.read_item.call_count == 2
221+
222+
223+
@pytest.mark.asyncio
224+
async def test_create_batch_conflict_cross_user(cosmos_db_client, mocker):
225+
"""Test that a PermissionError is raised when the batch belongs to a different user."""
226+
user_id = "user_1"
227+
batch_id = uuid4()
228+
229+
mock_batch_container = mock.MagicMock()
230+
mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
231+
mock_batch_container.create_item = AsyncMock(side_effect=CosmosResourceExistsError)
232+
233+
existing_batch = {
234+
"id": str(batch_id),
235+
"batch_id": str(batch_id),
236+
"user_id": "different_user",
237+
"file_count": 0,
238+
"created_at": datetime.now(timezone.utc).isoformat(),
239+
"updated_at": datetime.now(timezone.utc).isoformat(),
240+
"status": ProcessStatus.READY_TO_PROCESS,
241+
}
242+
mock_batch_container.read_item = AsyncMock(return_value=existing_batch)
243+
244+
with pytest.raises(PermissionError, match="Batch not found"):
245+
await cosmos_db_client.create_batch(user_id, batch_id)
246+
247+
248+
@pytest.mark.asyncio
249+
async def test_create_batch_conflict_exhausted_retries(cosmos_db_client, mocker):
250+
"""Test that RuntimeError is raised when read_item returns 404 after all retries."""
251+
user_id = "user_1"
252+
batch_id = uuid4()
253+
254+
mock_batch_container = mock.MagicMock()
255+
mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
256+
mock_batch_container.create_item = AsyncMock(side_effect=CosmosResourceExistsError)
257+
258+
# All 3 attempts raise 404
259+
mock_batch_container.read_item = AsyncMock(
260+
side_effect=CosmosResourceNotFoundError(message="Not found")
261+
)
262+
263+
with pytest.raises(RuntimeError, match="already exists but could not be read after retries"):
264+
await cosmos_db_client.create_batch(user_id, batch_id)
265+
266+
assert mock_batch_container.read_item.call_count == 3
267+
268+
191269
@pytest.mark.asyncio
192270
async def test_create_batch_exception(cosmos_db_client, mocker):
193271
user_id = "user_1"
@@ -487,6 +565,8 @@ async def mock_query_items(query, parameters, **kwargs):
487565
assert file["status"] == ProcessStatus.READY_TO_PROCESS
488566

489567
mock_file_container.query_items.assert_called_once()
568+
call_kwargs = mock_file_container.query_items.call_args
569+
assert call_kwargs.kwargs.get("partition_key") == file_id
490570

491571

492572
@pytest.mark.asyncio
@@ -612,6 +692,8 @@ async def mock_query_items(query, parameters, **kwargs):
612692
assert batch["status"] == ProcessStatus.READY_TO_PROCESS
613693

614694
mock_batch_container.query_items.assert_called_once()
695+
call_kwargs = mock_batch_container.query_items.call_args
696+
assert call_kwargs.kwargs.get("partition_key") == batch_id
615697

616698

617699
@pytest.mark.asyncio

0 commit comments

Comments
 (0)