1+ import asyncio
12from datetime import datetime , timezone
23from typing import Dict , List , Optional
34from uuid import UUID , uuid4
45
56from azure .cosmos .aio import CosmosClient
67from azure .cosmos .aio ._database import DatabaseProxy
78from azure .cosmos .exceptions import (
8- CosmosResourceExistsError
9+ CosmosResourceExistsError ,
10+ CosmosResourceNotFoundError ,
911)
1012
1113from common .database .database_base import DatabaseBase
@@ -85,9 +87,26 @@ async def create_batch(self, user_id: str, batch_id: UUID) -> BatchRecord:
8587 await self .batch_container .create_item (body = batch .dict ())
8688 return batch
8789 except CosmosResourceExistsError :
88- self .logger .info (f"Batch with ID { batch_id } already exists" )
89- batchexists = await self .get_batch (user_id , str (batch_id ))
90- return batchexists
90+ self .logger .info ("Batch already exists, reading existing record" , batch_id = str (batch_id ))
91+ # Retry read with backoff to handle replication lag after 409 conflict
92+ for attempt in range (3 ):
93+ try :
94+ batchexists = await self .batch_container .read_item (
95+ item = str (batch_id ), partition_key = str (batch_id )
96+ )
97+ if batchexists .get ("user_id" ) != user_id :
98+ self .logger .error ("Batch belongs to a different user" , batch_id = str (batch_id ))
99+ raise CosmosResourceNotFoundError (message = "Batch not found" )
100+ self .logger .info ("Returning existing batch record" , batch_id = str (batch_id ))
101+ return BatchRecord .fromdb (batchexists )
102+ except CosmosResourceNotFoundError :
103+ if attempt < 2 :
104+ self .logger .info ("Batch read returned 404 after conflict, retrying" , batch_id = str (batch_id ), attempt = attempt + 1 )
105+ await asyncio .sleep (0.5 * (attempt + 1 ))
106+ else :
107+ raise RuntimeError (
108+ f"Batch { batch_id } already exists but could not be read after retries"
109+ )
91110
92111 except Exception as e :
93112 self .logger .error ("Failed to create batch" , error = str (e ))
@@ -158,7 +177,7 @@ async def get_batch(self, user_id: str, batch_id: str) -> Optional[Dict]:
158177 ]
159178 batch = None
160179 async for item in self .batch_container .query_items (
161- query = query , parameters = params
180+ query = query , parameters = params , partition_key = batch_id
162181 ):
163182 batch = item
164183
@@ -173,7 +192,7 @@ async def get_file(self, file_id: str) -> Optional[Dict]:
173192 params = [{"name" : "@file_id" , "value" : file_id }]
174193 file_entry = None
175194 async for item in self .file_container .query_items (
176- query = query , parameters = params
195+ query = query , parameters = params , partition_key = file_id
177196 ):
178197 file_entry = item
179198 return file_entry
@@ -209,7 +228,7 @@ async def get_batch_from_id(self, batch_id: str) -> Dict:
209228
210229 batch = None # Store the batch
211230 async for item in self .batch_container .query_items (
212- query = query , parameters = params
231+ query = query , parameters = params , partition_key = batch_id
213232 ):
214233 batch = item # Assign the batch to the variable
215234
@@ -335,11 +354,14 @@ async def add_file_log(
335354 raise
336355
337356 async def update_batch_entry (
338- self , batch_id : str , user_id : str , status : ProcessStatus , file_count : int
357+ self , batch_id : str , user_id : str , status : ProcessStatus , file_count : int ,
358+ existing_batch : Optional [Dict ] = None
339359 ):
340- """Update batch status."""
360+ """Update batch status. If existing_batch is provided, skip the re-fetch. """
341361 try :
342- batch = await self .get_batch (user_id , batch_id )
362+ batch = existing_batch
363+ if batch is None :
364+ batch = await self .get_batch (user_id , batch_id )
343365 if not batch :
344366 raise ValueError ("Batch not found" )
345367
0 commit comments