diff --git a/.github/workflows/azd-template-validation.yml b/.github/workflows/azd-template-validation.yml new file mode 100644 index 00000000..397607fe --- /dev/null +++ b/.github/workflows/azd-template-validation.yml @@ -0,0 +1,46 @@ +name: AZD Template Validation + +on: + schedule: + - cron: '30 1 * * 4' # Every Thursday at 7:00 AM IST (1:30 AM UTC) + workflow_dispatch: + +permissions: + contents: read + id-token: write + pull-requests: write + +jobs: + template_validation: + runs-on: ubuntu-latest + name: azd template validation + environment: production + steps: + - name: Checkout code + uses: actions/checkout@v6 + + - name: Set timestamp + run: echo "HHMM=$(date -u +'%H%M')" >> $GITHUB_ENV + + - name: Validate Azure Template + id: validation + uses: microsoft/template-validation-action@v0.4.3 + with: + validateAzd: ${{ vars.TEMPLATE_VALIDATE_AZD }} + validateTests: ${{ vars.TEMPLATE_VALIDATE_TESTS }} + useDevContainer: ${{ vars.TEMPLATE_USE_DEV_CONTAINER }} + + env: + AZURE_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }} + AZURE_TENANT_ID: ${{ secrets.AZURE_TENANT_ID }} + AZURE_SUBSCRIPTION_ID: ${{ secrets.AZURE_SUBSCRIPTION_ID }} + AZURE_ENV_NAME: azd-${{ secrets.AZURE_ENV_NAME }}-${{ env.HHMM }} + AZURE_LOCATION: ${{ secrets.AZURE_LOCATION }} + AZURE_ENV_AI_SERVICE_LOCATION: ${{ secrets.AZURE_AI_DEPLOYMENT_LOCATION || secrets.AZURE_LOCATION }} + AZURE_ENV_MODEL_CAPACITY: 1 + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + AZURE_DEV_COLLECT_TELEMETRY: ${{ vars.AZURE_DEV_COLLECT_TELEMETRY }} + + - name: Print result + shell: bash + run: cat "${{ steps.validation.outputs.resultFile }}" \ No newline at end of file diff --git a/.github/workflows/azure-dev.yml b/.github/workflows/azure-dev.yml index 7e5a6f57..2b887434 100644 --- a/.github/workflows/azure-dev.yml +++ b/.github/workflows/azure-dev.yml @@ -1,37 +1,65 @@ -name: Azure Template Validation -on: - workflow_dispatch: - -permissions: - contents: read - id-token: write - pull-requests: write -jobs: - 
template_validation_job: - runs-on: ubuntu-latest +name: Azure Dev Deploy + +on: + workflow_dispatch: + +permissions: + contents: read + id-token: write + +jobs: + deploy: + runs-on: ubuntu-latest environment: production - name: Template validation - steps: - # Step 1: Checkout the code from your repository - - name: Checkout code - uses: actions/checkout@v6 - # Step 2: Validate the Azure template using microsoft/template-validation-action - - name: Validate Azure Template - uses: microsoft/template-validation-action@v0.4.3 - with: - validateAzd: true - useDevContainer: false - id: validation - env: - AZURE_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }} - AZURE_TENANT_ID: ${{ secrets.AZURE_TENANT_ID }} - AZURE_SUBSCRIPTION_ID: ${{ secrets.AZURE_SUBSCRIPTION_ID }} - AZURE_ENV_NAME: ${{ secrets.AZURE_ENV_NAME }} - AZURE_LOCATION: ${{ secrets.AZURE_LOCATION }} - AZURE_AI_DEPLOYMENT_LOCATION : ${{ secrets.AZURE_AI_DEPLOYMENT_LOCATION }} - AZURE_ENV_MODEL_CAPACITY : 1 - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - AZURE_DEV_COLLECT_TELEMETRY: ${{ vars.AZURE_DEV_COLLECT_TELEMETRY }} - # Step 3: Print the result of the validation - - name: Print result - run: cat ${{ steps.validation.outputs.resultFile }} + env: + AZURE_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }} + AZURE_TENANT_ID: ${{ secrets.AZURE_TENANT_ID }} + AZURE_SUBSCRIPTION_ID: ${{ secrets.AZURE_SUBSCRIPTION_ID }} + AZURE_ENV_NAME: ${{ secrets.AZURE_ENV_NAME }} + AZURE_LOCATION: ${{ secrets.AZURE_LOCATION }} + AZURE_AI_DEPLOYMENT_LOCATION: ${{ secrets.AZURE_AI_DEPLOYMENT_LOCATION || secrets.AZURE_LOCATION }} + AZURE_ENV_MODEL_CAPACITY: 1 + AZURE_DEV_COLLECT_TELEMETRY: ${{ vars.AZURE_DEV_COLLECT_TELEMETRY }} + steps: + - name: Checkout code + uses: actions/checkout@v6 + + - name: Set timestamp and env name + run: | + HHMM=$(date -u +'%H%M') + echo "AZURE_ENV_NAME=azd-${{ secrets.AZURE_ENV_NAME }}-${HHMM}" >> $GITHUB_ENV + + - name: Install azd + uses: Azure/setup-azd@v2 + + - name: Login to Azure + uses: azure/login@v2 + 
with: + client-id: ${{ secrets.AZURE_CLIENT_ID }} + tenant-id: ${{ secrets.AZURE_TENANT_ID }} + subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID }} + + - name: Login to AZD + shell: bash + run: | + azd auth login \ + --client-id "$AZURE_CLIENT_ID" \ + --federated-credential-provider "github" \ + --tenant-id "$AZURE_TENANT_ID" + + - name: Provision and deploy + shell: bash + run: | + set -e + + if ! azd env select "$AZURE_ENV_NAME"; then + azd env new "$AZURE_ENV_NAME" --subscription "$AZURE_SUBSCRIPTION_ID" --location "$AZURE_LOCATION" --no-prompt + fi + + azd config set defaults.subscription "$AZURE_SUBSCRIPTION_ID" + azd env set AZURE_SUBSCRIPTION_ID "$AZURE_SUBSCRIPTION_ID" + azd env set AZURE_LOCATION "$AZURE_LOCATION" + azd env set AZURE_ENV_AI_SERVICE_LOCATION "${AZURE_AI_DEPLOYMENT_LOCATION:-$AZURE_LOCATION}" + azd env set AZURE_ENV_MODEL_CAPACITY "$AZURE_ENV_MODEL_CAPACITY" + + azd up --no-prompt diff --git a/infra/main.json b/infra/main.json index f98bb0c5..09e37656 100644 --- a/infra/main.json +++ b/infra/main.json @@ -5,8 +5,8 @@ "metadata": { "_generator": { "name": "bicep", - "version": "0.41.2.15936", - "templateHash": "3093757051086668797" + "version": "0.42.1.51946", + "templateHash": "7222423000870488333" }, "name": "Modernize Your Code Solution Accelerator", "description": "CSA CTO Gold Standard Solution Accelerator for Modernize Your Code. 
\r\n" @@ -5052,8 +5052,8 @@ "metadata": { "_generator": { "name": "bicep", - "version": "0.41.2.15936", - "templateHash": "8663094775498995429" + "version": "0.42.1.51946", + "templateHash": "3406526791248457038" } }, "definitions": { @@ -12895,11 +12895,11 @@ }, "dependsOn": [ "applicationInsights", - "[format('avmPrivateDnsZones[{0}]', variables('dnsZoneIndex').storageBlob)]", "[format('avmPrivateDnsZones[{0}]', variables('dnsZoneIndex').ods)]", - "[format('avmPrivateDnsZones[{0}]', variables('dnsZoneIndex').agentSvc)]", "[format('avmPrivateDnsZones[{0}]', variables('dnsZoneIndex').monitor)]", "[format('avmPrivateDnsZones[{0}]', variables('dnsZoneIndex').oms)]", + "[format('avmPrivateDnsZones[{0}]', variables('dnsZoneIndex').storageBlob)]", + "[format('avmPrivateDnsZones[{0}]', variables('dnsZoneIndex').agentSvc)]", "dataCollectionEndpoint", "logAnalyticsWorkspace", "virtualNetwork" @@ -25611,8 +25611,8 @@ "metadata": { "_generator": { "name": "bicep", - "version": "0.41.2.15936", - "templateHash": "17285204072656433491" + "version": "0.42.1.51946", + "templateHash": "16969185198334420434" }, "name": "AI Services and Project Module", "description": "This module creates an AI Services resource and an AI Foundry project within it. It supports private networking, OpenAI deployments, and role assignments." 
@@ -26952,8 +26952,8 @@ "metadata": { "_generator": { "name": "bicep", - "version": "0.41.2.15936", - "templateHash": "5121330425393020264" + "version": "0.42.1.51946", + "templateHash": "4140498216793917924" } }, "definitions": { @@ -27910,7 +27910,10 @@ "raiPolicyName": "[tryGet(coalesce(parameters('deployments'), createArray())[copyIndex()], 'raiPolicyName')]", "versionUpgradeOption": "[tryGet(coalesce(parameters('deployments'), createArray())[copyIndex()], 'versionUpgradeOption')]" }, - "sku": "[coalesce(tryGet(coalesce(parameters('deployments'), createArray())[copyIndex()], 'sku'), createObject('name', parameters('sku'), 'capacity', tryGet(parameters('sku'), 'capacity'), 'tier', tryGet(parameters('sku'), 'tier'), 'size', tryGet(parameters('sku'), 'size'), 'family', tryGet(parameters('sku'), 'family')))]" + "sku": "[coalesce(tryGet(coalesce(parameters('deployments'), createArray())[copyIndex()], 'sku'), createObject('name', parameters('sku'), 'capacity', tryGet(parameters('sku'), 'capacity'), 'tier', tryGet(parameters('sku'), 'tier'), 'size', tryGet(parameters('sku'), 'size'), 'family', tryGet(parameters('sku'), 'family')))]", + "dependsOn": [ + "aiProject" + ] }, "cognitiveService_lock": { "condition": "[and(not(empty(coalesce(parameters('lock'), createObject()))), not(equals(tryGet(parameters('lock'), 'kind'), 'None')))]", @@ -28664,8 +28667,8 @@ "metadata": { "_generator": { "name": "bicep", - "version": "0.41.2.15936", - "templateHash": "10989408486030617267" + "version": "0.42.1.51946", + "templateHash": "2422737205646151487" } }, "definitions": { @@ -28818,8 +28821,8 @@ "metadata": { "_generator": { "name": "bicep", - "version": "0.41.2.15936", - "templateHash": "7933643033523871028" + "version": "0.42.1.51946", + "templateHash": "11911242767938607365" } }, "definitions": { @@ -29036,8 +29039,8 @@ "metadata": { "_generator": { "name": "bicep", - "version": "0.41.2.15936", - "templateHash": "5121330425393020264" + "version": "0.42.1.51946", + 
"templateHash": "4140498216793917924" } }, "definitions": { @@ -29994,7 +29997,10 @@ "raiPolicyName": "[tryGet(coalesce(parameters('deployments'), createArray())[copyIndex()], 'raiPolicyName')]", "versionUpgradeOption": "[tryGet(coalesce(parameters('deployments'), createArray())[copyIndex()], 'versionUpgradeOption')]" }, - "sku": "[coalesce(tryGet(coalesce(parameters('deployments'), createArray())[copyIndex()], 'sku'), createObject('name', parameters('sku'), 'capacity', tryGet(parameters('sku'), 'capacity'), 'tier', tryGet(parameters('sku'), 'tier'), 'size', tryGet(parameters('sku'), 'size'), 'family', tryGet(parameters('sku'), 'family')))]" + "sku": "[coalesce(tryGet(coalesce(parameters('deployments'), createArray())[copyIndex()], 'sku'), createObject('name', parameters('sku'), 'capacity', tryGet(parameters('sku'), 'capacity'), 'tier', tryGet(parameters('sku'), 'tier'), 'size', tryGet(parameters('sku'), 'size'), 'family', tryGet(parameters('sku'), 'family')))]", + "dependsOn": [ + "aiProject" + ] }, "cognitiveService_lock": { "condition": "[and(not(empty(coalesce(parameters('lock'), createObject()))), not(equals(tryGet(parameters('lock'), 'kind'), 'None')))]", @@ -30748,8 +30754,8 @@ "metadata": { "_generator": { "name": "bicep", - "version": "0.41.2.15936", - "templateHash": "10989408486030617267" + "version": "0.42.1.51946", + "templateHash": "2422737205646151487" } }, "definitions": { @@ -30902,8 +30908,8 @@ "metadata": { "_generator": { "name": "bicep", - "version": "0.41.2.15936", - "templateHash": "7933643033523871028" + "version": "0.42.1.51946", + "templateHash": "11911242767938607365" } }, "definitions": { @@ -31917,8 +31923,8 @@ "dependsOn": [ "aiServices", "[format('avmPrivateDnsZones[{0}]', variables('dnsZoneIndex').openAI)]", - "[format('avmPrivateDnsZones[{0}]', variables('dnsZoneIndex').cognitiveServices)]", "[format('avmPrivateDnsZones[{0}]', variables('dnsZoneIndex').aiServices)]", + "[format('avmPrivateDnsZones[{0}]', 
variables('dnsZoneIndex').cognitiveServices)]", "virtualNetwork" ] }, @@ -31974,8 +31980,8 @@ "metadata": { "_generator": { "name": "bicep", - "version": "0.41.2.15936", - "templateHash": "3881415837167031634" + "version": "0.42.1.51946", + "templateHash": "522477461329004641" } }, "definitions": { @@ -40219,8 +40225,8 @@ "metadata": { "_generator": { "name": "bicep", - "version": "0.41.2.15936", - "templateHash": "15962472891869337617" + "version": "0.42.1.51946", + "templateHash": "15355322017409205910" } }, "definitions": { @@ -44082,8 +44088,8 @@ "metadata": { "_generator": { "name": "bicep", - "version": "0.41.2.15936", - "templateHash": "17636459140972536078" + "version": "0.42.1.51946", + "templateHash": "4242598725709304634" } }, "definitions": { diff --git a/infra/modules/ai-foundry/dependencies.bicep b/infra/modules/ai-foundry/dependencies.bicep index f0d119cb..0d28aa1a 100644 --- a/infra/modules/ai-foundry/dependencies.bicep +++ b/infra/modules/ai-foundry/dependencies.bicep @@ -208,6 +208,9 @@ resource cognitiveService_deployments 'Microsoft.CognitiveServices/accounts/depl size: sku.?size family: sku.?family } + dependsOn: [ + aiProject + ] } ] diff --git a/src/backend/common/database/cosmosdb.py b/src/backend/common/database/cosmosdb.py index 7ff043f4..8c56286a 100644 --- a/src/backend/common/database/cosmosdb.py +++ b/src/backend/common/database/cosmosdb.py @@ -1,3 +1,4 @@ +import asyncio from datetime import datetime, timezone from typing import Dict, List, Optional from uuid import UUID, uuid4 @@ -5,7 +6,8 @@ from azure.cosmos.aio import CosmosClient from azure.cosmos.aio._database import DatabaseProxy from azure.cosmos.exceptions import ( - CosmosResourceExistsError + CosmosResourceExistsError, + CosmosResourceNotFoundError, ) from common.database.database_base import DatabaseBase @@ -85,9 +87,28 @@ async def create_batch(self, user_id: str, batch_id: UUID) -> BatchRecord: await self.batch_container.create_item(body=batch.dict()) return batch except 
CosmosResourceExistsError: - self.logger.info(f"Batch with ID {batch_id} already exists") - batchexists = await self.get_batch(user_id, str(batch_id)) - return batchexists + self.logger.info("Batch already exists, reading existing record", batch_id=str(batch_id)) + # Retry read with backoff to handle replication lag after 409 conflict + for attempt in range(3): + try: + batchexists = await self.batch_container.read_item( + item=str(batch_id), partition_key=str(batch_id) + ) + except CosmosResourceNotFoundError: + if attempt < 2: + self.logger.info("Batch read returned 404 after conflict, retrying", batch_id=str(batch_id), attempt=attempt + 1) + await asyncio.sleep(0.5 * (attempt + 1)) + continue + raise RuntimeError( + f"Batch {batch_id} already exists but could not be read after retries" + ) + + if batchexists.get("user_id") != user_id: + self.logger.error("Batch belongs to a different user", batch_id=str(batch_id)) + raise PermissionError("Batch not found") + + self.logger.info("Returning existing batch record", batch_id=str(batch_id)) + return BatchRecord.fromdb(batchexists) except Exception as e: self.logger.error("Failed to create batch", error=str(e)) @@ -158,7 +179,7 @@ async def get_batch(self, user_id: str, batch_id: str) -> Optional[Dict]: ] batch = None async for item in self.batch_container.query_items( - query=query, parameters=params + query=query, parameters=params, partition_key=batch_id ): batch = item @@ -173,7 +194,7 @@ async def get_file(self, file_id: str) -> Optional[Dict]: params = [{"name": "@file_id", "value": file_id}] file_entry = None async for item in self.file_container.query_items( - query=query, parameters=params + query=query, parameters=params, partition_key=file_id ): file_entry = item return file_entry @@ -209,7 +230,7 @@ async def get_batch_from_id(self, batch_id: str) -> Dict: batch = None # Store the batch async for item in self.batch_container.query_items( - query=query, parameters=params + query=query, parameters=params, 
partition_key=batch_id ): batch = item # Assign the batch to the variable @@ -335,11 +356,14 @@ async def add_file_log( raise async def update_batch_entry( - self, batch_id: str, user_id: str, status: ProcessStatus, file_count: int + self, batch_id: str, user_id: str, status: ProcessStatus, file_count: int, + existing_batch: Optional[Dict] = None ): - """Update batch status.""" + """Update batch status. If existing_batch is provided, skip the re-fetch.""" try: - batch = await self.get_batch(user_id, batch_id) + batch = existing_batch + if batch is None: + batch = await self.get_batch(user_id, batch_id) if not batch: raise ValueError("Batch not found") diff --git a/src/backend/common/database/database_factory.py b/src/backend/common/database/database_factory.py index c2f7de9d..bbc84941 100644 --- a/src/backend/common/database/database_factory.py +++ b/src/backend/common/database/database_factory.py @@ -9,25 +9,40 @@ class DatabaseFactory: _instance: Optional[DatabaseBase] = None + _lock: Optional[asyncio.Lock] = None _logger = AppLogger("DatabaseFactory") + @staticmethod + def _get_lock() -> asyncio.Lock: + if DatabaseFactory._lock is None: + DatabaseFactory._lock = asyncio.Lock() + return DatabaseFactory._lock + @staticmethod async def get_database(): + if DatabaseFactory._instance is not None: + return DatabaseFactory._instance + + async with DatabaseFactory._get_lock(): + # Double-check after acquiring the lock + if DatabaseFactory._instance is not None: + return DatabaseFactory._instance - config = Config() # Create an instance of Config + config = Config() # Create an instance of Config - cosmos_db_client = CosmosDBClient( - endpoint=config.cosmosdb_endpoint, - credential=config.get_azure_credentials(), - database_name=config.cosmosdb_database, - batch_container=config.cosmosdb_batch_container, - file_container=config.cosmosdb_file_container, - log_container=config.cosmosdb_log_container, - ) + cosmos_db_client = CosmosDBClient( + 
endpoint=config.cosmosdb_endpoint, + credential=config.get_azure_credentials(), + database_name=config.cosmosdb_database, + batch_container=config.cosmosdb_batch_container, + file_container=config.cosmosdb_file_container, + log_container=config.cosmosdb_log_container, + ) - await cosmos_db_client.initialize_cosmos() + await cosmos_db_client.initialize_cosmos() - return cosmos_db_client + DatabaseFactory._instance = cosmos_db_client + return cosmos_db_client # Local testing of config and code diff --git a/src/backend/common/services/batch_service.py b/src/backend/common/services/batch_service.py index 828c16d5..efce27fa 100644 --- a/src/backend/common/services/batch_service.py +++ b/src/backend/common/services/batch_service.py @@ -288,8 +288,9 @@ async def upload_file_to_batch(self, batch_id: str, user_id: str, file: UploadFi self.logger.info("File uploaded to blob storage", filename=file.filename, batch_id=batch_id) # Create file entry - await self.database.add_file(batch_id, file_id, file.filename, blob_path) - file_record = await self.database.get_file(file_id) + file_record_obj = await self.database.add_file(UUID(batch_id), UUID(file_id), file.filename, blob_path) + file_record_dict = getattr(file_record_obj, "dict", None) + file_record = file_record_dict() if callable(file_record_dict) else file_record_obj await self.database.add_file_log( UUID(file_id), @@ -308,6 +309,7 @@ async def upload_file_to_batch(self, batch_id: str, user_id: str, file: UploadFi user_id, ProcessStatus.READY_TO_PROCESS, batch["file_count"], + existing_batch=batch, ) # Return response return {"batch": batch, "file": file_record} @@ -318,7 +320,8 @@ async def upload_file_to_batch(self, batch_id: str, user_id: str, file: UploadFi batch.file_count = len(files) batch.updated_at = datetime.utcnow().isoformat() await self.database.update_batch_entry( - batch_id, user_id, ProcessStatus.READY_TO_PROCESS, batch.file_count + batch_id, user_id, ProcessStatus.READY_TO_PROCESS, batch.file_count, + 
existing_batch=batch.dict(), ) # Return response return {"batch": batch, "file": file_record} diff --git a/src/tests/backend/common/database/cosmosdb_test.py b/src/tests/backend/common/database/cosmosdb_test.py index 3405d85e..faaaaa09 100644 --- a/src/tests/backend/common/database/cosmosdb_test.py +++ b/src/tests/backend/common/database/cosmosdb_test.py @@ -11,7 +11,7 @@ from uuid import uuid4 # noqa: E402 from azure.cosmos.aio import CosmosClient # noqa: E402 -from azure.cosmos.exceptions import CosmosResourceExistsError # noqa: E402 +from azure.cosmos.exceptions import CosmosResourceExistsError, CosmosResourceNotFoundError # noqa: E402 from common.database.cosmosdb import ( # noqa: E402 CosmosDBClient, @@ -158,21 +158,22 @@ async def test_create_batch_exists(cosmos_db_client, mocker): user_id = "user_1" batch_id = uuid4() - # Mock container creation and get_batch + # Mock container creation and read_item mock_batch_container = mock.MagicMock() mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) mock_batch_container.create_item = AsyncMock(side_effect=CosmosResourceExistsError) - # Mock the get_batch method - mock_get_batch = AsyncMock(return_value=BatchRecord( - batch_id=batch_id, - user_id=user_id, - file_count=0, - created_at=datetime.now(timezone.utc), - updated_at=datetime.now(timezone.utc), - status=ProcessStatus.READY_TO_PROCESS - )) - mocker.patch.object(cosmos_db_client, 'get_batch', mock_get_batch) + # Mock read_item to return the existing batch record + existing_batch = { + "id": str(batch_id), + "batch_id": str(batch_id), + "user_id": user_id, + "file_count": 0, + "created_at": datetime.now(timezone.utc).isoformat(), + "updated_at": datetime.now(timezone.utc).isoformat(), + "status": ProcessStatus.READY_TO_PROCESS, + } + mock_batch_container.read_item = AsyncMock(return_value=existing_batch) # Call the method batch = await cosmos_db_client.create_batch(user_id, batch_id) @@ -182,7 +183,87 @@ async def 
test_create_batch_exists(cosmos_db_client, mocker): assert batch.user_id == user_id assert batch.status == ProcessStatus.READY_TO_PROCESS - mock_get_batch.assert_called_once_with(user_id, str(batch_id)) + mock_batch_container.read_item.assert_called_once_with( + item=str(batch_id), partition_key=str(batch_id) + ) + + +@pytest.mark.asyncio +async def test_create_batch_conflict_retry_on_404(cosmos_db_client, mocker): + """Test that read_item is retried when it returns 404 after a 409 conflict.""" + user_id = "user_1" + batch_id = uuid4() + + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + mock_batch_container.create_item = AsyncMock(side_effect=CosmosResourceExistsError) + + existing_batch = { + "id": str(batch_id), + "batch_id": str(batch_id), + "user_id": user_id, + "file_count": 0, + "created_at": datetime.now(timezone.utc).isoformat(), + "updated_at": datetime.now(timezone.utc).isoformat(), + "status": ProcessStatus.READY_TO_PROCESS, + } + + # First call raises 404, second call succeeds + mock_batch_container.read_item = AsyncMock( + side_effect=[CosmosResourceNotFoundError(message="Not found"), existing_batch] + ) + + batch = await cosmos_db_client.create_batch(user_id, batch_id) + + assert batch.batch_id == batch_id + assert batch.user_id == user_id + assert mock_batch_container.read_item.call_count == 2 + + +@pytest.mark.asyncio +async def test_create_batch_conflict_cross_user(cosmos_db_client, mocker): + """Test that a PermissionError is raised when the batch belongs to a different user.""" + user_id = "user_1" + batch_id = uuid4() + + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + mock_batch_container.create_item = AsyncMock(side_effect=CosmosResourceExistsError) + + existing_batch = { + "id": str(batch_id), + "batch_id": str(batch_id), + "user_id": "different_user", + "file_count": 0, + "created_at": 
datetime.now(timezone.utc).isoformat(), + "updated_at": datetime.now(timezone.utc).isoformat(), + "status": ProcessStatus.READY_TO_PROCESS, + } + mock_batch_container.read_item = AsyncMock(return_value=existing_batch) + + with pytest.raises(PermissionError, match="Batch not found"): + await cosmos_db_client.create_batch(user_id, batch_id) + + +@pytest.mark.asyncio +async def test_create_batch_conflict_exhausted_retries(cosmos_db_client, mocker): + """Test that RuntimeError is raised when read_item returns 404 after all retries.""" + user_id = "user_1" + batch_id = uuid4() + + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + mock_batch_container.create_item = AsyncMock(side_effect=CosmosResourceExistsError) + + # All 3 attempts raise 404 + mock_batch_container.read_item = AsyncMock( + side_effect=CosmosResourceNotFoundError(message="Not found") + ) + + with pytest.raises(RuntimeError, match="already exists but could not be read after retries"): + await cosmos_db_client.create_batch(user_id, batch_id) + + assert mock_batch_container.read_item.call_count == 3 @pytest.mark.asyncio @@ -404,7 +485,7 @@ async def test_get_batch(cosmos_db_client, mocker): } # We define the async generator function that will yield the expected batch - async def mock_query_items(query, parameters): + async def mock_query_items(query, parameters, **kwargs): yield expected_batch # Assign the async generator to query_items mock @@ -422,6 +503,7 @@ async def mock_query_items(query, parameters): {"name": "@batch_id", "value": batch_id}, {"name": "@user_id", "value": user_id}, ], + partition_key=batch_id, ) @@ -468,8 +550,8 @@ async def test_get_file(cosmos_db_client, mocker): "blob_path": "/path/to/file" } - # We define the async generator function that will yield the expected batch - async def mock_query_items(query, parameters): + # We define the async generator function that will yield the expected file + async def 
mock_query_items(query, parameters, **kwargs): yield expected_file # Assign the async generator to query_items mock @@ -483,6 +565,8 @@ async def mock_query_items(query, parameters): assert file["status"] == ProcessStatus.READY_TO_PROCESS mock_file_container.query_items.assert_called_once() + call_kwargs = mock_file_container.query_items.call_args + assert call_kwargs.kwargs.get("partition_key") == file_id @pytest.mark.asyncio @@ -594,7 +678,7 @@ async def test_get_batch_from_id(cosmos_db_client, mocker): } # Define the async generator function that will yield the expected batch - async def mock_query_items(query, parameters): + async def mock_query_items(query, parameters, **kwargs): yield expected_batch # Assign the async generator to query_items mock @@ -608,6 +692,8 @@ async def mock_query_items(query, parameters): assert batch["status"] == ProcessStatus.READY_TO_PROCESS mock_batch_container.query_items.assert_called_once() + call_kwargs = mock_batch_container.query_items.call_args + assert call_kwargs.kwargs.get("partition_key") == batch_id @pytest.mark.asyncio