Skip to content

Commit d7c6779

Browse files
Enhance CosmosDBClient and DatabaseFactory for improved error handling and concurrency
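- Added asyncio support and an `asyncio.Lock` in DatabaseFactory so that concurrent coroutines share a single, lazily initialized database client (note: `asyncio.Lock` coordinates coroutines on one event loop; it does not provide OS-thread safety).
- Implemented retry logic with linear backoff (up to three attempts) for reading an existing batch in CosmosDBClient after a create conflict.
- Updated the batch service to handle existing batch entries more efficiently: it passes the already-loaded batch to `update_batch_entry` and reuses the record returned by `add_file` instead of re-fetching them.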
1 parent 20ec1ce commit d7c6779

3 files changed

Lines changed: 52 additions & 24 deletions

File tree

src/backend/common/database/cosmosdb.py

Lines changed: 27 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,13 @@
1+
import asyncio
12
from datetime import datetime, timezone
23
from typing import Dict, List, Optional
34
from uuid import UUID, uuid4
45

56
from azure.cosmos.aio import CosmosClient
67
from azure.cosmos.aio._database import DatabaseProxy
78
from azure.cosmos.exceptions import (
8-
CosmosResourceExistsError
9+
CosmosResourceExistsError,
10+
CosmosResourceNotFoundError,
911
)
1012

1113
from common.database.database_base import DatabaseBase
@@ -85,9 +87,21 @@ async def create_batch(self, user_id: str, batch_id: UUID) -> BatchRecord:
8587
await self.batch_container.create_item(body=batch.dict())
8688
return batch
8789
except CosmosResourceExistsError:
88-
self.logger.info(f"Batch with ID {batch_id} already exists")
89-
batchexists = await self.get_batch(user_id, str(batch_id))
90-
return batchexists
90+
self.logger.info("Batch already exists, reading existing record", batch_id=str(batch_id))
91+
# Retry read with backoff to handle replication lag after 409 conflict
92+
for attempt in range(3):
93+
try:
94+
batchexists = await self.batch_container.read_item(
95+
item=str(batch_id), partition_key=str(batch_id)
96+
)
97+
if batchexists.get("user_id") != user_id:
98+
raise ValueError("Batch belongs to a different user")
99+
return batchexists
100+
except CosmosResourceNotFoundError:
101+
if attempt < 2:
102+
await asyncio.sleep(0.5 * (attempt + 1))
103+
else:
104+
return batch
91105

92106
except Exception as e:
93107
self.logger.error("Failed to create batch", error=str(e))
@@ -158,7 +172,7 @@ async def get_batch(self, user_id: str, batch_id: str) -> Optional[Dict]:
158172
]
159173
batch = None
160174
async for item in self.batch_container.query_items(
161-
query=query, parameters=params
175+
query=query, parameters=params, partition_key=batch_id
162176
):
163177
batch = item
164178

@@ -173,7 +187,7 @@ async def get_file(self, file_id: str) -> Optional[Dict]:
173187
params = [{"name": "@file_id", "value": file_id}]
174188
file_entry = None
175189
async for item in self.file_container.query_items(
176-
query=query, parameters=params
190+
query=query, parameters=params, partition_key=file_id
177191
):
178192
file_entry = item
179193
return file_entry
@@ -209,7 +223,7 @@ async def get_batch_from_id(self, batch_id: str) -> Dict:
209223

210224
batch = None # Store the batch
211225
async for item in self.batch_container.query_items(
212-
query=query, parameters=params
226+
query=query, parameters=params, partition_key=batch_id
213227
):
214228
batch = item # Assign the batch to the variable
215229

@@ -335,11 +349,14 @@ async def add_file_log(
335349
raise
336350

337351
async def update_batch_entry(
338-
self, batch_id: str, user_id: str, status: ProcessStatus, file_count: int
352+
self, batch_id: str, user_id: str, status: ProcessStatus, file_count: int,
353+
existing_batch: Optional[Dict] = None
339354
):
340-
"""Update batch status."""
355+
"""Update batch status. If existing_batch is provided, skip the re-fetch."""
341356
try:
342-
batch = await self.get_batch(user_id, batch_id)
357+
batch = existing_batch
358+
if batch is None:
359+
batch = await self.get_batch(user_id, batch_id)
343360
if not batch:
344361
raise ValueError("Batch not found")
345362

src/backend/common/database/database_factory.py

Lines changed: 20 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -9,25 +9,34 @@
99

1010
class DatabaseFactory:
1111
_instance: Optional[DatabaseBase] = None
12+
_lock: asyncio.Lock = asyncio.Lock()
1213
_logger = AppLogger("DatabaseFactory")
1314

1415
@staticmethod
1516
async def get_database():
17+
if DatabaseFactory._instance is not None:
18+
return DatabaseFactory._instance
1619

17-
config = Config() # Create an instance of Config
20+
async with DatabaseFactory._lock:
21+
# Double-check after acquiring the lock
22+
if DatabaseFactory._instance is not None:
23+
return DatabaseFactory._instance
1824

19-
cosmos_db_client = CosmosDBClient(
20-
endpoint=config.cosmosdb_endpoint,
21-
credential=config.get_azure_credentials(),
22-
database_name=config.cosmosdb_database,
23-
batch_container=config.cosmosdb_batch_container,
24-
file_container=config.cosmosdb_file_container,
25-
log_container=config.cosmosdb_log_container,
26-
)
25+
config = Config() # Create an instance of Config
2726

28-
await cosmos_db_client.initialize_cosmos()
27+
cosmos_db_client = CosmosDBClient(
28+
endpoint=config.cosmosdb_endpoint,
29+
credential=config.get_azure_credentials(),
30+
database_name=config.cosmosdb_database,
31+
batch_container=config.cosmosdb_batch_container,
32+
file_container=config.cosmosdb_file_container,
33+
log_container=config.cosmosdb_log_container,
34+
)
2935

30-
return cosmos_db_client
36+
await cosmos_db_client.initialize_cosmos()
37+
38+
DatabaseFactory._instance = cosmos_db_client
39+
return cosmos_db_client
3140

3241

3342
# Local testing of config and code

src/backend/common/services/batch_service.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,8 @@ async def upload_file_to_batch(self, batch_id: str, user_id: str, file: UploadFi
287287
)
288288

289289
# Create file entry
290-
await self.database.add_file(batch_id, file_id, file.filename, blob_path)
291-
file_record = await self.database.get_file(file_id)
290+
file_record_obj = await self.database.add_file(batch_id, file_id, file.filename, blob_path)
291+
file_record = file_record_obj.dict() if hasattr(file_record_obj, 'dict') else file_record_obj
292292

293293
await self.database.add_file_log(
294294
UUID(file_id),
@@ -307,6 +307,7 @@ async def upload_file_to_batch(self, batch_id: str, user_id: str, file: UploadFi
307307
user_id,
308308
ProcessStatus.READY_TO_PROCESS,
309309
batch["file_count"],
310+
existing_batch=batch,
310311
)
311312
# Return response
312313
return {"batch": batch, "file": file_record}
@@ -317,7 +318,8 @@ async def upload_file_to_batch(self, batch_id: str, user_id: str, file: UploadFi
317318
batch.file_count = len(files)
318319
batch.updated_at = datetime.utcnow().isoformat()
319320
await self.database.update_batch_entry(
320-
batch_id, user_id, ProcessStatus.READY_TO_PROCESS, batch.file_count
321+
batch_id, user_id, ProcessStatus.READY_TO_PROCESS, batch.file_count,
322+
existing_batch=batch.dict(),
321323
)
322324
# Return response
323325
return {"batch": batch, "file": file_record}

0 commit comments

Comments
 (0)