Skip to content

Commit 4b70934

Browse files
Merge pull request #392 from microsoft/psl-pk-cosavmfix
fix: Handle Cosmos DB replication lag during concurrent batch creation
2 parents 0bbd905 + 27e2f53 commit 4b70934

4 files changed

Lines changed: 84 additions & 40 deletions

File tree

src/backend/common/database/cosmosdb.py

Lines changed: 32 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,13 @@
1+
import asyncio
12
from datetime import datetime, timezone
23
from typing import Dict, List, Optional
34
from uuid import UUID, uuid4
45

56
from azure.cosmos.aio import CosmosClient
67
from azure.cosmos.aio._database import DatabaseProxy
78
from azure.cosmos.exceptions import (
8-
CosmosResourceExistsError
9+
CosmosResourceExistsError,
10+
CosmosResourceNotFoundError,
911
)
1012

1113
from common.database.database_base import DatabaseBase
@@ -85,9 +87,26 @@ async def create_batch(self, user_id: str, batch_id: UUID) -> BatchRecord:
8587
await self.batch_container.create_item(body=batch.dict())
8688
return batch
8789
except CosmosResourceExistsError:
88-
self.logger.info(f"Batch with ID {batch_id} already exists")
89-
batchexists = await self.get_batch(user_id, str(batch_id))
90-
return batchexists
90+
self.logger.info("Batch already exists, reading existing record", batch_id=str(batch_id))
91+
# Retry read with backoff to handle replication lag after 409 conflict
92+
for attempt in range(3):
93+
try:
94+
batchexists = await self.batch_container.read_item(
95+
item=str(batch_id), partition_key=str(batch_id)
96+
)
97+
if batchexists.get("user_id") != user_id:
98+
self.logger.error("Batch belongs to a different user", batch_id=str(batch_id))
99+
raise CosmosResourceNotFoundError(message="Batch not found")
100+
self.logger.info("Returning existing batch record", batch_id=str(batch_id))
101+
return BatchRecord.fromdb(batchexists)
102+
except CosmosResourceNotFoundError:
103+
if attempt < 2:
104+
self.logger.info("Batch read returned 404 after conflict, retrying", batch_id=str(batch_id), attempt=attempt + 1)
105+
await asyncio.sleep(0.5 * (attempt + 1))
106+
else:
107+
raise RuntimeError(
108+
f"Batch {batch_id} already exists but could not be read after retries"
109+
)
91110

92111
except Exception as e:
93112
self.logger.error("Failed to create batch", error=str(e))
@@ -158,7 +177,7 @@ async def get_batch(self, user_id: str, batch_id: str) -> Optional[Dict]:
158177
]
159178
batch = None
160179
async for item in self.batch_container.query_items(
161-
query=query, parameters=params
180+
query=query, parameters=params, partition_key=batch_id
162181
):
163182
batch = item
164183

@@ -173,7 +192,7 @@ async def get_file(self, file_id: str) -> Optional[Dict]:
173192
params = [{"name": "@file_id", "value": file_id}]
174193
file_entry = None
175194
async for item in self.file_container.query_items(
176-
query=query, parameters=params
195+
query=query, parameters=params, partition_key=file_id
177196
):
178197
file_entry = item
179198
return file_entry
@@ -209,7 +228,7 @@ async def get_batch_from_id(self, batch_id: str) -> Dict:
209228

210229
batch = None # Store the batch
211230
async for item in self.batch_container.query_items(
212-
query=query, parameters=params
231+
query=query, parameters=params, partition_key=batch_id
213232
):
214233
batch = item # Assign the batch to the variable
215234

@@ -335,11 +354,14 @@ async def add_file_log(
335354
raise
336355

337356
async def update_batch_entry(
338-
self, batch_id: str, user_id: str, status: ProcessStatus, file_count: int
357+
self, batch_id: str, user_id: str, status: ProcessStatus, file_count: int,
358+
existing_batch: Optional[Dict] = None
339359
):
340-
"""Update batch status."""
360+
"""Update batch status. If existing_batch is provided, skip the re-fetch."""
341361
try:
342-
batch = await self.get_batch(user_id, batch_id)
362+
batch = existing_batch
363+
if batch is None:
364+
batch = await self.get_batch(user_id, batch_id)
343365
if not batch:
344366
raise ValueError("Batch not found")
345367

src/backend/common/database/database_factory.py

Lines changed: 26 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -9,25 +9,40 @@
99

1010
class DatabaseFactory:
1111
_instance: Optional[DatabaseBase] = None
12+
_lock: Optional[asyncio.Lock] = None
1213
_logger = AppLogger("DatabaseFactory")
1314

15+
@staticmethod
16+
def _get_lock() -> asyncio.Lock:
17+
if DatabaseFactory._lock is None:
18+
DatabaseFactory._lock = asyncio.Lock()
19+
return DatabaseFactory._lock
20+
1421
@staticmethod
1522
async def get_database():
23+
if DatabaseFactory._instance is not None:
24+
return DatabaseFactory._instance
25+
26+
async with DatabaseFactory._get_lock():
27+
# Double-check after acquiring the lock
28+
if DatabaseFactory._instance is not None:
29+
return DatabaseFactory._instance
1630

17-
config = Config() # Create an instance of Config
31+
config = Config() # Create an instance of Config
1832

19-
cosmos_db_client = CosmosDBClient(
20-
endpoint=config.cosmosdb_endpoint,
21-
credential=config.get_azure_credentials(),
22-
database_name=config.cosmosdb_database,
23-
batch_container=config.cosmosdb_batch_container,
24-
file_container=config.cosmosdb_file_container,
25-
log_container=config.cosmosdb_log_container,
26-
)
33+
cosmos_db_client = CosmosDBClient(
34+
endpoint=config.cosmosdb_endpoint,
35+
credential=config.get_azure_credentials(),
36+
database_name=config.cosmosdb_database,
37+
batch_container=config.cosmosdb_batch_container,
38+
file_container=config.cosmosdb_file_container,
39+
log_container=config.cosmosdb_log_container,
40+
)
2741

28-
await cosmos_db_client.initialize_cosmos()
42+
await cosmos_db_client.initialize_cosmos()
2943

30-
return cosmos_db_client
44+
DatabaseFactory._instance = cosmos_db_client
45+
return cosmos_db_client
3146

3247

3348
# Local testing of config and code

src/backend/common/services/batch_service.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -288,8 +288,9 @@ async def upload_file_to_batch(self, batch_id: str, user_id: str, file: UploadFi
288288
self.logger.info("File uploaded to blob storage", filename=file.filename, batch_id=batch_id)
289289

290290
# Create file entry
291-
await self.database.add_file(batch_id, file_id, file.filename, blob_path)
292-
file_record = await self.database.get_file(file_id)
291+
file_record_obj = await self.database.add_file(batch_id, file_id, file.filename, blob_path)
292+
file_record_dict = getattr(file_record_obj, "dict", None)
293+
file_record = file_record_dict() if callable(file_record_dict) else file_record_obj
293294

294295
await self.database.add_file_log(
295296
UUID(file_id),
@@ -308,6 +309,7 @@ async def upload_file_to_batch(self, batch_id: str, user_id: str, file: UploadFi
308309
user_id,
309310
ProcessStatus.READY_TO_PROCESS,
310311
batch["file_count"],
312+
existing_batch=batch,
311313
)
312314
# Return response
313315
return {"batch": batch, "file": file_record}
@@ -318,7 +320,8 @@ async def upload_file_to_batch(self, batch_id: str, user_id: str, file: UploadFi
318320
batch.file_count = len(files)
319321
batch.updated_at = datetime.utcnow().isoformat()
320322
await self.database.update_batch_entry(
321-
batch_id, user_id, ProcessStatus.READY_TO_PROCESS, batch.file_count
323+
batch_id, user_id, ProcessStatus.READY_TO_PROCESS, batch.file_count,
324+
existing_batch=batch.dict(),
322325
)
323326
# Return response
324327
return {"batch": batch, "file": file_record}

src/tests/backend/common/database/cosmosdb_test.py

Lines changed: 20 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -158,21 +158,22 @@ async def test_create_batch_exists(cosmos_db_client, mocker):
158158
user_id = "user_1"
159159
batch_id = uuid4()
160160

161-
# Mock container creation and get_batch
161+
# Mock container creation and read_item
162162
mock_batch_container = mock.MagicMock()
163163
mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container)
164164
mock_batch_container.create_item = AsyncMock(side_effect=CosmosResourceExistsError)
165165

166-
# Mock the get_batch method
167-
mock_get_batch = AsyncMock(return_value=BatchRecord(
168-
batch_id=batch_id,
169-
user_id=user_id,
170-
file_count=0,
171-
created_at=datetime.now(timezone.utc),
172-
updated_at=datetime.now(timezone.utc),
173-
status=ProcessStatus.READY_TO_PROCESS
174-
))
175-
mocker.patch.object(cosmos_db_client, 'get_batch', mock_get_batch)
166+
# Mock read_item to return the existing batch record
167+
existing_batch = {
168+
"id": str(batch_id),
169+
"batch_id": str(batch_id),
170+
"user_id": user_id,
171+
"file_count": 0,
172+
"created_at": datetime.now(timezone.utc).isoformat(),
173+
"updated_at": datetime.now(timezone.utc).isoformat(),
174+
"status": ProcessStatus.READY_TO_PROCESS,
175+
}
176+
mock_batch_container.read_item = AsyncMock(return_value=existing_batch)
176177

177178
# Call the method
178179
batch = await cosmos_db_client.create_batch(user_id, batch_id)
@@ -182,7 +183,9 @@ async def test_create_batch_exists(cosmos_db_client, mocker):
182183
assert batch.user_id == user_id
183184
assert batch.status == ProcessStatus.READY_TO_PROCESS
184185

185-
mock_get_batch.assert_called_once_with(user_id, str(batch_id))
186+
mock_batch_container.read_item.assert_called_once_with(
187+
item=str(batch_id), partition_key=str(batch_id)
188+
)
186189

187190

188191
@pytest.mark.asyncio
@@ -404,7 +407,7 @@ async def test_get_batch(cosmos_db_client, mocker):
404407
}
405408

406409
# We define the async generator function that will yield the expected batch
407-
async def mock_query_items(query, parameters):
410+
async def mock_query_items(query, parameters, **kwargs):
408411
yield expected_batch
409412

410413
# Assign the async generator to query_items mock
@@ -422,6 +425,7 @@ async def mock_query_items(query, parameters):
422425
{"name": "@batch_id", "value": batch_id},
423426
{"name": "@user_id", "value": user_id},
424427
],
428+
partition_key=batch_id,
425429
)
426430

427431

@@ -468,8 +472,8 @@ async def test_get_file(cosmos_db_client, mocker):
468472
"blob_path": "/path/to/file"
469473
}
470474

471-
# We define the async generator function that will yield the expected batch
472-
async def mock_query_items(query, parameters):
475+
# We define the async generator function that will yield the expected file
476+
async def mock_query_items(query, parameters, **kwargs):
473477
yield expected_file
474478

475479
# Assign the async generator to query_items mock
@@ -594,7 +598,7 @@ async def test_get_batch_from_id(cosmos_db_client, mocker):
594598
}
595599

596600
# Define the async generator function that will yield the expected batch
597-
async def mock_query_items(query, parameters):
601+
async def mock_query_items(query, parameters, **kwargs):
598602
yield expected_batch
599603

600604
# Assign the async generator to query_items mock

0 commit comments

Comments
 (0)