diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3f245b24..34a2f24d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -95,7 +95,7 @@ jobs: - name: Run Backend Tests with Coverage if: env.skip_backend_tests == 'false' run: | - cd src/tests/backend + cd src pytest --cov=. --cov-report=term-missing --cov-report=xml diff --git a/src/backend/common/database/database_base.py b/src/backend/common/database/database_base.py index 156d6bdf..66d36f42 100644 --- a/src/backend/common/database/database_base.py +++ b/src/backend/common/database/database_base.py @@ -16,55 +16,55 @@ class DatabaseBase(ABC): @abstractmethod async def initialize_cosmos(self) -> None: - """Initialize the cosmosdb client and create container if needed.""" - pass + """Initialize the cosmosdb client and create container if needed""" + pass # pragma: no cover @abstractmethod async def create_batch(self, user_id: str, batch_id: uuid.UUID) -> BatchRecord: - """Create a new conversion batch.""" - pass + """Create a new conversion batch""" + pass # pragma: no cover @abstractmethod async def get_file_logs(self, file_id: str) -> Dict: - """Retrieve all logs for a file.""" - pass + """Retrieve all logs for a file""" + pass # pragma: no cover @abstractmethod async def get_batch_from_id(self, batch_id: str) -> Dict: - """Retrieve all logs for a file.""" - pass + """Retrieve all logs for a file""" + pass # pragma: no cover @abstractmethod async def get_batch_files(self, batch_id: str) -> List[Dict]: - """Retrieve all files for a batch.""" - pass + """Retrieve all files for a batch""" + pass # pragma: no cover @abstractmethod async def delete_file_logs(self, file_id: str) -> None: - """Delete all logs for a file.""" - pass + """Delete all logs for a file""" + pass # pragma: no cover @abstractmethod async def get_user_batches(self, user_id: str) -> Dict: - """Retrieve all batches for a user.""" - pass + """Retrieve all batches for a user""" + pass # pragma: no cover @abstractmethod async def add_file( self, batch_id: uuid.UUID, file_id: uuid.UUID, file_name: str, storage_path: str ) -> FileRecord: - """Add a file entry to the database.""" - pass + """Add a file entry to the database""" + pass # pragma: no cover @abstractmethod async def get_batch(self, user_id: str, batch_id: str) -> Optional[Dict]: - """Retrieve a batch and its associated files.""" - pass + """Retrieve a batch and its associated files""" + pass # pragma: no cover @abstractmethod async def get_file(self, file_id: str) -> Optional[Dict]: - """Retrieve a file entry along with its logs.""" - pass + """Retrieve a file entry along with its logs""" + pass # pragma: no cover @abstractmethod async def add_file_log( @@ -76,39 +76,40 @@ async def add_file_log( agent_type: AgentType, author_role: AuthorRole, ) -> None: - """Log a file status update.""" - pass + """Log a file status update""" + pass # pragma: no cover @abstractmethod async def update_file(self, file_record: FileRecord) -> None: - """Update file record.""" - pass + """Update file record""" + pass # pragma: no cover @abstractmethod async def update_batch(self, batch_record: BatchRecord) -> BatchRecord: """Update a batch record""" + pass # pragma: no cover @abstractmethod async def delete_all(self, user_id: str) -> None: - """Delete all batches, files, and logs for a user.""" - pass + """Delete all batches, files, and logs for a user""" + pass # pragma: no cover @abstractmethod async def delete_batch(self, user_id: str, batch_id: str) -> None: - """Delete a batch along with its files and logs.""" - pass + """Delete a batch along with its files and logs""" + pass # pragma: no cover @abstractmethod async def delete_file(self, user_id: str, batch_id: str, file_id: str) -> None: - """Delete a file and its logs, and update batch file count.""" - pass + """Delete a file and its logs, and update batch file count""" + pass # pragma: no cover @abstractmethod async def get_batch_history(self, user_id: str, batch_id: str) -> List[Dict]: - """Retrieve all logs for a batch.""" - pass + """Retrieve all logs for a batch""" + pass # pragma: no cover @abstractmethod async def close(self) -> None: - """Close database connection.""" - pass + """Close database connection""" + pass # pragma: no cover diff --git a/src/backend/common/database/database_factory.py b/src/backend/common/database/database_factory.py index ee92677f..c2f7de9d 100644 --- a/src/backend/common/database/database_factory.py +++ b/src/backend/common/database/database_factory.py @@ -1,3 +1,4 @@ +import asyncio from typing import Optional from common.config.config import Config @@ -33,25 +34,20 @@ async def get_database(): # Note that you have to assign yourself data plane access to Cosmos in script for this to work locally. See # https://learn.microsoft.com/en-us/azure/cosmos-db/table/security/how-to-grant-data-plane-role-based-access?tabs=built-in-definition%2Ccsharp&pivots=azure-interface-cli # Note that your principal id is your entra object id for your user account. -if __name__ == "__main__": - # Example usage - import asyncio - - async def main(): - database = await DatabaseFactory.get_database() - # Use the database instance... - await database.initialize_cosmos() - await database.create_batch("mark1", "123e4567-e89b-12d3-a456-426614174000") - await database.add_file( - "123e4567-e89b-12d3-a456-426614174000", - "123e4567-e89b-12d3-a456-426614174001", - "q1_informix.sql", - "https://cmsamarktaylstor.blob.core.windows.net/cmsablob", - ) - tstbatch = await database.get_batch( - "mark1", "123e4567-e89b-12d3-a456-426614174000" - ) - print(tstbatch) - await database.close() +async def main(): + database = await DatabaseFactory.get_database() + await database.initialize_cosmos() + await database.create_batch("mark1", "123e4567-e89b-12d3-a456-426614174000") + await database.add_file( + "123e4567-e89b-12d3-a456-426614174000", + "123e4567-e89b-12d3-a456-426614174001", + "q1_informix.sql", + "https://cmsamarktaylstor.blob.core.windows.net/cmsablob", + ) + tstbatch = await database.get_batch("mark1", "123e4567-e89b-12d3-a456-426614174000") + print(tstbatch) + await database.close() + +if __name__ == "__main__": asyncio.run(main()) diff --git a/src/backend/common/storage/blob_factory.py b/src/backend/common/storage/blob_factory.py index fc855635..d20c2de8 100644 --- a/src/backend/common/storage/blob_factory.py +++ b/src/backend/common/storage/blob_factory.py @@ -1,3 +1,4 @@ +import asyncio from typing import Optional from common.config.config import Config # Load config @@ -31,15 +32,14 @@ async def close_storage() -> None: # Local testing of config and code -if __name__ == "__main__": - # Example usage - import asyncio +async def main(): + storage = await BlobStorageFactory.get_storage() + + # Use the storage instance + blob = await storage.get_file("q1_informix.sql") + print("Blob content:", blob) - async def main(): - storage = await BlobStorageFactory.get_storage() - # Use the storage instance... - blob = await storage.get_file("q1_informix.sql") - print(blob) - await BlobStorageFactory.close_storage() + await BlobStorageFactory.close_storage() +if __name__ == "__main__": asyncio.run(main()) diff --git a/src/backend/requirements.txt b/src/backend/requirements.txt index 9b9a37c0..c5d6b636 100644 --- a/src/backend/requirements.txt +++ b/src/backend/requirements.txt @@ -20,6 +20,7 @@ azure-functions # Development tools pytest +pytest-mock black pylint flake8 diff --git a/src/backend/sql_agents/helpers/agents_manager.py b/src/backend/sql_agents/helpers/agents_manager.py index 8767b796..af5d6365 100644 --- a/src/backend/sql_agents/helpers/agents_manager.py +++ b/src/backend/sql_agents/helpers/agents_manager.py @@ -2,7 +2,7 @@ import logging -from semantic_kernel.agents import AzureAIAgent # pylint: disable=E0611 +from semantic_kernel.agents.azure_ai.azure_ai_agent import AzureAIAgent # pylint: disable=E0611 from sql_agents.agents.agent_config import AgentBaseConfig from sql_agents.agents.fixer.setup import setup_fixer_agent diff --git a/src/backend/sql_agents/process_batch.py b/src/backend/sql_agents/process_batch.py index 132c574f..1434fba5 100644 --- a/src/backend/sql_agents/process_batch.py +++ b/src/backend/sql_agents/process_batch.py @@ -23,11 +23,10 @@ from fastapi import HTTPException -from semantic_kernel.agents import AzureAIAgent # pylint: disable=E0611 +from semantic_kernel.agents.azure_ai.azure_ai_agent import AzureAIAgent # pylint: disable=E0611 from semantic_kernel.contents import AuthorRole from semantic_kernel.exceptions.service_exceptions import ServiceResponseException - from sql_agents.agents.agent_config import AgentBaseConfig from sql_agents.convert_script import convert_script from sql_agents.helpers.agents_manager import SqlAgents diff --git a/src/tests/backend/app_test.py b/src/tests/backend/app_test.py new file mode 100644 index 00000000..610e36c3 --- /dev/null +++ b/src/tests/backend/app_test.py @@ -0,0 +1,33 @@ +from backend.app import create_app + +from fastapi import FastAPI + +from httpx import ASGITransport +from httpx import AsyncClient + +import pytest + + +@pytest.fixture +def app() -> FastAPI: + """Fixture to create a test app instance.""" + return create_app() + + +@pytest.mark.asyncio +async def test_health_check(app: FastAPI): + """Test the /health endpoint returns a healthy status.""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + response = await ac.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "healthy"} + + +@pytest.mark.asyncio +async def test_backend_routes_exist(app: FastAPI): + """Ensure /api routes are available (smoke test).""" + # Check available routes include /api prefix from backend_router + routes = [route.path for route in app.router.routes] + backend_routes = [r for r in routes if r.startswith("/api")] + assert backend_routes, "No backend routes found under /api prefix" diff --git a/src/tests/backend/common/config/config_test.py b/src/tests/backend/common/config/config_test.py index 87531bbc..6984ae8f 100644 --- a/src/tests/backend/common/config/config_test.py +++ b/src/tests/backend/common/config/config_test.py @@ -1,62 +1,67 @@ -import unittest -from unittest.mock import patch - -# from config import Config -from common.config.config import Config - - -class TestConfigInitialization(unittest.TestCase): - @patch.dict( - "os.environ", - { - "AZURE_TENANT_ID": "test-tenant-id", - "AZURE_CLIENT_ID": "test-client-id", - "AZURE_CLIENT_SECRET": "test-client-secret", - "COSMOSDB_DATABASE": "test-database", - "COSMOSDB_BATCH_CONTAINER": "test-batch-container", - "COSMOSDB_FILE_CONTAINER": "test-file-container", - "COSMOSDB_LOG_CONTAINER": "test-log-container", - "AZURE_BLOB_CONTAINER_NAME": "test-blob-container-name", - "AZURE_BLOB_ACCOUNT_NAME": "test-blob-account-name", - }, - clear=True, - ) - def test_config_initialization(self): - """Test if all attributes are correctly assigned from environment variables.""" - config = Config() - - # Ensure every attribute is accessed - self.assertEqual(config.azure_tenant_id, "test-tenant-id") - self.assertEqual(config.azure_client_id, "test-client-id") - self.assertEqual(config.azure_client_secret, "test-client-secret") - - self.assertEqual(config.cosmosdb_endpoint, "test-cosmosdb-endpoint") - self.assertEqual(config.cosmosdb_database, "test-database") - self.assertEqual(config.cosmosdb_batch_container, "test-batch-container") - self.assertEqual(config.cosmosdb_file_container, "test-file-container") - self.assertEqual(config.cosmosdb_log_container, "test-log-container") - - self.assertEqual(config.azure_blob_container_name, "test-blob-container-name") - self.assertEqual(config.azure_blob_account_name, "test-blob-account-name") - - @patch.dict( - "os.environ", - { - "COSMOSDB_ENDPOINT": "test-cosmosdb-endpoint", - "COSMOSDB_DATABASE": "test-database", - "COSMOSDB_BATCH_CONTAINER": "test-batch-container", - "COSMOSDB_FILE_CONTAINER": "test-file-container", - "COSMOSDB_LOG_CONTAINER": "test-log-container", - }, - ) - def test_cosmosdb_config_initialization(self): - config = Config() - self.assertEqual(config.cosmosdb_endpoint, "test-cosmosdb-endpoint") - self.assertEqual(config.cosmosdb_database, "test-database") - self.assertEqual(config.cosmosdb_batch_container, "test-batch-container") - self.assertEqual(config.cosmosdb_file_container, "test-file-container") - self.assertEqual(config.cosmosdb_log_container, "test-log-container") - - -if __name__ == "__main__": - unittest.main() +import pytest + + +@pytest.fixture(autouse=True) +def clear_env(monkeypatch): + # Clear environment variables that might affect tests. + keys = [ + "AZURE_TENANT_ID", + "AZURE_CLIENT_ID", + "AZURE_CLIENT_SECRET", + "COSMOSDB_ENDPOINT", + "COSMOSDB_DATABASE", + "COSMOSDB_BATCH_CONTAINER", + "COSMOSDB_FILE_CONTAINER", + "COSMOSDB_LOG_CONTAINER", + "AZURE_BLOB_CONTAINER_NAME", + "AZURE_BLOB_ACCOUNT_NAME", + ] + for key in keys: + monkeypatch.delenv(key, raising=False) + + +def test_config_initialization(monkeypatch): + # Set the full configuration environment variables. + monkeypatch.setenv("AZURE_TENANT_ID", "test-tenant-id") + monkeypatch.setenv("AZURE_CLIENT_ID", "test-client-id") + monkeypatch.setenv("AZURE_CLIENT_SECRET", "test-client-secret") + monkeypatch.setenv("COSMOSDB_ENDPOINT", "test-cosmosdb-endpoint") + monkeypatch.setenv("COSMOSDB_DATABASE", "test-database") + monkeypatch.setenv("COSMOSDB_BATCH_CONTAINER", "test-batch-container") + monkeypatch.setenv("COSMOSDB_FILE_CONTAINER", "test-file-container") + monkeypatch.setenv("COSMOSDB_LOG_CONTAINER", "test-log-container") + monkeypatch.setenv("AZURE_BLOB_CONTAINER_NAME", "test-blob-container-name") + monkeypatch.setenv("AZURE_BLOB_ACCOUNT_NAME", "test-blob-account-name") + + # Local import to avoid triggering circular imports during module collection. + from common.config.config import Config + config = Config() + + assert config.azure_tenant_id == "test-tenant-id" + assert config.azure_client_id == "test-client-id" + assert config.azure_client_secret == "test-client-secret" + assert config.cosmosdb_endpoint == "test-cosmosdb-endpoint" + assert config.cosmosdb_database == "test-database" + assert config.cosmosdb_batch_container == "test-batch-container" + assert config.cosmosdb_file_container == "test-file-container" + assert config.cosmosdb_log_container == "test-log-container" + assert config.azure_blob_container_name == "test-blob-container-name" + assert config.azure_blob_account_name == "test-blob-account-name" + + +def test_cosmosdb_config_initialization(monkeypatch): + # Set only cosmosdb-related environment variables. + monkeypatch.setenv("COSMOSDB_ENDPOINT", "test-cosmosdb-endpoint") + monkeypatch.setenv("COSMOSDB_DATABASE", "test-database") + monkeypatch.setenv("COSMOSDB_BATCH_CONTAINER", "test-batch-container") + monkeypatch.setenv("COSMOSDB_FILE_CONTAINER", "test-file-container") + monkeypatch.setenv("COSMOSDB_LOG_CONTAINER", "test-log-container") + + from common.config.config import Config + config = Config() + + assert config.cosmosdb_endpoint == "test-cosmosdb-endpoint" + assert config.cosmosdb_database == "test-database" + assert config.cosmosdb_batch_container == "test-batch-container" + assert config.cosmosdb_file_container == "test-file-container" + assert config.cosmosdb_log_container == "test-log-container" diff --git a/src/tests/backend/common/database/cosmosdb_test.py b/src/tests/backend/common/database/cosmosdb_test.py index 7ef364a6..df53fde1 100644 --- a/src/tests/backend/common/database/cosmosdb_test.py +++ b/src/tests/backend/common/database/cosmosdb_test.py @@ -1,618 +1,1117 @@ -import asyncio -import enum -import uuid -from datetime import datetime +import os +import sys +# Add backend directory to sys.path +sys.path.insert( + 0, + os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../..", "backend")), +) +from datetime import datetime, timezone # noqa: E402 +from unittest import mock # noqa: E402 +from unittest.mock import AsyncMock # noqa: E402 +from uuid import uuid4 # noqa: E402 + +from azure.cosmos.aio import CosmosClient # noqa: E402 +from azure.cosmos.exceptions import CosmosResourceExistsError # noqa: E402 + +from common.database.cosmosdb import ( # noqa: E402 + CosmosDBClient, +) +from common.models.api import ( # noqa: E402 + AgentType, + AuthorRole, + BatchRecord, + FileRecord, + LogType, + ProcessStatus, +) # noqa: E402 + +import pytest # noqa: E402 + +# Mocked data for the test +endpoint = "https://fake.cosmosdb.azure.com" +credential = "fake_credential" +database_name = "test_database" +batch_container = "batch_container" +file_container = "file_container" +log_container = "log_container" -from azure.cosmos import PartitionKey, exceptions -from common.database.cosmosdb import CosmosDBClient -from common.logger.app_logger import AppLogger -from common.models.api import ProcessStatus +@pytest.fixture +def cosmos_db_client(): + return CosmosDBClient( + endpoint=endpoint, + credential=credential, + database_name=database_name, + batch_container=batch_container, + file_container=file_container, + log_container=log_container, + ) -import pytest +@pytest.mark.asyncio +async def test_initialize_cosmos(cosmos_db_client, mocker): + # Mocking CosmosClient and its methods + mock_client = mocker.patch.object(CosmosClient, 'get_database_client', return_value=mock.MagicMock()) + mock_database = mock_client.return_value + + # Use AsyncMock for asynchronous methods + mock_batch_container = mock.MagicMock() + mock_file_container = mock.MagicMock() + mock_log_container = mock.MagicMock() + + # Use AsyncMock to mock asynchronous container creation + mock_database.create_container = AsyncMock(side_effect=[ + mock_batch_container, + mock_file_container, + mock_log_container + ]) + + # Call the initialize_cosmos method + await cosmos_db_client.initialize_cosmos() + + # Assert that the containers were created or fetched successfully + mock_database.create_container.assert_any_call(id=batch_container, partition_key=mock.ANY) + mock_database.create_container.assert_any_call(id=file_container, partition_key=mock.ANY) + mock_database.create_container.assert_any_call(id=log_container, partition_key=mock.ANY) + + # Check the client and containers were set + assert cosmos_db_client.client is not None + assert cosmos_db_client.batch_container == mock_batch_container + assert cosmos_db_client.file_container == mock_file_container + assert cosmos_db_client.log_container == mock_log_container -# --- Enums for Testing --- -class DummyProcessStatus(enum.Enum): - READY_TO_PROCESS = "READY" - PROCESSING = "PROCESSING" +@pytest.mark.asyncio +async def test_initialize_cosmos_with_error(cosmos_db_client, mocker): + # Mocking CosmosClient and its methods + mock_client = mocker.patch.object(CosmosClient, 'get_database_client', return_value=mock.MagicMock()) + mock_database = mock_client.return_value + + # Simulate a general exception during container creation + mock_database.create_container = AsyncMock(side_effect=Exception("Failed to create container")) -class DummyLogType(enum.Enum): - INFO = "INFO" - ERROR = "ERROR" + # Call the initialize_cosmos method and expect it to raise an error + with pytest.raises(Exception) as exc_info: + await cosmos_db_client.initialize_cosmos() + # Assert that the exception message matches the expected message + assert str(exc_info.value) == "Failed to create container" -@pytest.fixture(autouse=True) -def patch_enums(monkeypatch): - monkeypatch.setattr("common.models.api.ProcessStatus", DummyProcessStatus) - monkeypatch.setattr("common.models.api.LogType", DummyLogType) - - -# --- implementations to simulate Cosmos DB behavior --- -async def async_query_generator(items): - for item in items: - yield item - - -async def async_query_error_generator(*args, **kwargs): - raise Exception("Error in query") - if False: - yield - - -class DummyContainerClient: - def __init__(self, container_name): - self.container_name = container_name - self.created_items = [] - self.deleted_items = [] - self._query_items_func = None - - async def create_item(self, body): - self.created_items.append(body) - - async def replace_item(self, item, body): - return body - - async def delete_item(self, item, partition_key=None): - self.deleted_items.append((item, partition_key)) - - async def delete_items(self, key): - self.deleted_items.append(key) - - async def query_items(self, query, parameters): - if self._query_items_func: - async for item in self._query_items_func(query, parameters): - yield item - else: - if False: - yield - - def set_query_items(self, func): - self._query_items_func = func - - -class DummyDatabase: - def __init__(self, database_name): - self.database_name = database_name - self.containers = {} - - async def create_container(self, id, partition_key): - if id in self.containers: - raise exceptions.CosmosResourceExistsError(404, "Container exists") - container = DummyContainerClient(id) - self.containers[id] = container - return container - - def get_container_client(self, container_name): - return self.containers.get(container_name, DummyContainerClient(container_name)) - - -class DummyCosmosClient: - def __init__(self, url, credential): - self.url = url - self.credential = credential - self._database = DummyDatabase("dummy_db") - self.closed = False - - def get_database_client(self, database_name): - return self._database - - def close(self): - self.closed = True - - -class FakeCosmosDBClient(CosmosDBClient): - async def _async_init( - self, - endpoint: str, - credential: any, - database_name: str, - batch_container: str, - file_container: str, - log_container: str, - ): - self.endpoint = endpoint - self.credential = credential - self.database_name = database_name - self.batch_container_name = batch_container - self.file_container_name = file_container - self.log_container_name = log_container - self.logger = AppLogger("CosmosDB") - self.client = DummyCosmosClient(endpoint, credential) - db = self.client.get_database_client(database_name) - self.batch_container = await db.create_container( - batch_container, PartitionKey(path="/batch_id") - ) - self.file_container = await db.create_container( - file_container, PartitionKey(path="/file_id") - ) - self.log_container = await db.create_container( - log_container, PartitionKey(path="/log_id") - ) - - @classmethod - async def create( - cls, - endpoint, - credential, - database_name, - batch_container, - file_container, - log_container, - ): - instance = cls.__new__(cls) - await instance._async_init( - endpoint, - credential, - database_name, - batch_container, - file_container, - log_container, - ) - return instance - - # Minimal implementations for abstract methods not under test. - async def delete_file_logs(self, file_id: str) -> None: - await self.log_container.delete_items(file_id) - - async def log_batch_status( - self, batch_id: str, status: ProcessStatus, processed_files: int - ) -> None: - return - - -# --- Fixture --- -@pytest.fixture -def cosmosdb_client(event_loop): - client = event_loop.run_until_complete( - FakeCosmosDBClient.create( - endpoint="dummy_endpoint", - credential="dummy_credential", - database_name="dummy_db", - batch_container="batch", - file_container="file", - log_container="log", - ) - ) - return client +@pytest.mark.asyncio +async def test_initialize_cosmos_container_exists_error(cosmos_db_client, mocker): + # Mocking CosmosClient and its methods + mock_client = mocker.patch.object(CosmosClient, 'get_database_client', return_value=mock.MagicMock()) + mock_database = mock_client.return_value -# --- Test Cases --- + # Simulating CosmosResourceExistsError for container creation + mock_database.create_container = AsyncMock(side_effect=CosmosResourceExistsError) + # Use AsyncMock for asynchronous methods + mock_batch_container = mock.MagicMock() + mock_file_container = mock.MagicMock() + mock_log_container = mock.MagicMock() -@pytest.mark.asyncio -async def test_initialization_success(cosmosdb_client): - assert cosmosdb_client.client is not None - assert cosmosdb_client.batch_container is not None - assert cosmosdb_client.file_container is not None - assert cosmosdb_client.log_container is not None + # Use AsyncMock to mock asynchronous container creation + mock_database.create_container = AsyncMock(side_effect=[ + mock_batch_container, + mock_file_container, + mock_log_container + ]) + + # Call the initialize_cosmos method + await cosmos_db_client.initialize_cosmos() + + # Assert that the container creation method was called with the correct arguments + mock_database.create_container.assert_any_call(id='batch_container', partition_key=mock.ANY) + mock_database.create_container.assert_any_call(id='file_container', partition_key=mock.ANY) + mock_database.create_container.assert_any_call(id='log_container', partition_key=mock.ANY) + + # Check that existing containers are returned (mocked containers) + assert cosmos_db_client.batch_container == mock_batch_container + assert cosmos_db_client.file_container == mock_file_container + assert cosmos_db_client.log_container == mock_log_container @pytest.mark.asyncio -async def test_init_error(monkeypatch): - async def fake_async_init(*args, **kwargs): - raise Exception("client error") +async def test_create_batch_new(cosmos_db_client, mocker): + user_id = "user_1" + batch_id = uuid4() - monkeypatch.setattr(FakeCosmosDBClient, "_async_init", fake_async_init) - with pytest.raises(Exception) as exc_info: - await FakeCosmosDBClient.create("dummy", "dummy", "dummy", "a", "b", "c") - assert "client error" in str(exc_info.value) + # Mock container creation + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + + # Mock the method to return the batch + mock_batch_container.create_item = AsyncMock(return_value=None) + + # Call the method + batch = await cosmos_db_client.create_batch(user_id, batch_id) + + # Assert that the batch is created + assert batch.batch_id == batch_id + assert batch.user_id == user_id + assert batch.status == ProcessStatus.READY_TO_PROCESS + + mock_batch_container.create_item.assert_called_once_with(body=batch.dict()) @pytest.mark.asyncio -async def test_get_or_create_container_existing(monkeypatch, cosmosdb_client): - db = DummyDatabase("dummy_db") - existing = DummyContainerClient("existing") - db.containers["existing"] = existing +async def test_create_batch_exists(cosmos_db_client, mocker): + user_id = "user_1" + batch_id = uuid4() + + # Mock container creation and get_batch + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + mock_batch_container.create_item = AsyncMock(side_effect=CosmosResourceExistsError) - async def fake_create_container(id, partition_key): - raise exceptions.CosmosResourceExistsError(404, "Container exists") + # Mock the get_batch method + mock_get_batch = AsyncMock(return_value=BatchRecord( + batch_id=batch_id, + user_id=user_id, + file_count=0, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + status=ProcessStatus.READY_TO_PROCESS + )) + mocker.patch.object(cosmos_db_client, 'get_batch', mock_get_batch) - monkeypatch.setattr(db, "create_container", fake_create_container) - monkeypatch.setattr(db, "get_container_client", lambda name: existing) + # Call the method + batch = await cosmos_db_client.create_batch(user_id, batch_id) - # Directly call _get_or_create_container on a new instance. - instance = FakeCosmosDBClient.__new__(FakeCosmosDBClient) - instance.logger = AppLogger("CosmosDB") - result = await instance._get_or_create_container(db, "existing", "/id") - assert result is existing + # Assert that batch was fetched (not created) due to already existing + assert batch.batch_id == batch_id + assert batch.user_id == user_id + assert batch.status == ProcessStatus.READY_TO_PROCESS + + mock_get_batch.assert_called_once_with(user_id, str(batch_id)) @pytest.mark.asyncio -async def test_create_batch_success(monkeypatch, cosmosdb_client): - called = False +async def test_create_batch_exception(cosmos_db_client, mocker): + user_id = "user_1" + batch_id = uuid4() - async def fake_create_item(body): - nonlocal called - called = True + # Mock the batch_container and make create_item raise a general Exception + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + mock_batch_container.create_item = AsyncMock(side_effect=Exception("Unexpected Error")) - monkeypatch.setattr( - cosmosdb_client.batch_container, "create_item", fake_create_item - ) - bid = uuid.uuid4() - batch = await cosmosdb_client.create_batch("user1", bid) - assert batch.batch_id == bid - assert batch.user_id == "user1" - assert called + # Mock the logger to verify logging + mock_logger = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'logger', mock_logger) + + # Call the method and assert it raises the exception + with pytest.raises(Exception, match="Unexpected Error"): + await cosmos_db_client.create_batch(user_id, batch_id) + + # Ensure logger.error was called with expected message and error + mock_logger.error.assert_called_once() + called_args, called_kwargs = mock_logger.error.call_args + assert called_args[0] == "Failed to create batch" + assert "error" in called_kwargs + assert "Unexpected Error" in called_kwargs["error"] @pytest.mark.asyncio -async def test_create_batch_error(monkeypatch, cosmosdb_client): - async def fake_create_item(body): - raise Exception("Batch creation error") +async def test_add_file(cosmos_db_client, mocker): + batch_id = uuid4() + file_id = uuid4() + file_name = "file.txt" + storage_path = "/path/to/storage" - monkeypatch.setattr( - cosmosdb_client.batch_container, "create_item", fake_create_item - ) - with pytest.raises(Exception) as exc_info: - await cosmosdb_client.create_batch("user1", uuid.uuid4()) - assert "Batch creation error" in str(exc_info.value) + # Mock file container creation + mock_file_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container) + + # Mock the create_item method + mock_file_container.create_item = AsyncMock(return_value=None) + + # Call the method + file_record = await cosmos_db_client.add_file(batch_id, file_id, file_name, storage_path) + + # Assert that the file record is created + assert file_record.file_id == file_id + assert file_record.batch_id == batch_id + assert file_record.original_name == file_name + assert file_record.blob_path == storage_path + assert file_record.status == ProcessStatus.READY_TO_PROCESS + + mock_file_container.create_item.assert_called_once_with(body=file_record.dict()) @pytest.mark.asyncio -async def test_add_file_success(monkeypatch, cosmosdb_client): - called = False +async def test_add_file_exception(cosmos_db_client, mocker): + batch_id = uuid4() + file_id = uuid4() + file_name = "document.pdf" + storage_path = "/files/document.pdf" - async def fake_create_item(body): - nonlocal called - called = True + # Mock file_container.create_item to raise a general exception + mock_file_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container) + mock_file_container.create_item = AsyncMock(side_effect=Exception("Insert failed")) - monkeypatch.setattr(cosmosdb_client.file_container, "create_item", fake_create_item) - bid = uuid.uuid4() - fid = uuid.uuid4() - fs = await cosmosdb_client.add_file(bid, fid, "test.txt", "path/to/blob") - assert fs.file_id == fid - assert fs.original_name == "test.txt" - assert fs.blob_path == "path/to/blob" - assert called + # Mock logger to capture error logs + mock_logger = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'logger', mock_logger) + + # Expect an exception when calling add_file + with pytest.raises(Exception, match="Insert failed"): + await cosmos_db_client.add_file(batch_id, file_id, file_name, storage_path) + + # Check that logger.error was called properly + called_args, called_kwargs = mock_logger.error.call_args + assert called_args[0] == "Failed to add file" + assert "error" in called_kwargs + assert "Insert failed" in called_kwargs["error"] @pytest.mark.asyncio -async def test_add_file_error(monkeypatch, cosmosdb_client): - async def fake_create_item(body): - raise Exception("Add file error") +async def test_update_file(cosmos_db_client, mocker): + file_id = uuid4() + file_record = FileRecord( + file_id=file_id, + batch_id=uuid4(), + original_name="file.txt", + blob_path="/path/to/storage", + translated_path="", + status=ProcessStatus.READY_TO_PROCESS, + error_count=0, + syntax_count=0, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc) + ) + + # Mock file container replace_item method + mock_file_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container) + mock_file_container.replace_item = AsyncMock(return_value=None) + + # Call the method + updated_file_record = await cosmos_db_client.update_file(file_record) - monkeypatch.setattr( - cosmosdb_client.file_container, - "create_item", - lambda *args, **kwargs: fake_create_item(*args, **kwargs), + # Assert that the file record is updated + assert updated_file_record.file_id == file_id + + mock_file_container.replace_item.assert_called_once_with(item=str(file_id), body=file_record.dict()) + + +@pytest.mark.asyncio +async def test_update_file_exception(cosmos_db_client, mocker): + # Create a sample FileRecord + file_record = FileRecord( + file_id=uuid4(), + batch_id=uuid4(), + original_name="file.txt", + blob_path="/storage/file.txt", + translated_path="", + status=ProcessStatus.READY_TO_PROCESS, + error_count=0, + syntax_count=0, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), ) - with pytest.raises(Exception) as exc_info: - await cosmosdb_client.add_file( - uuid.uuid4(), uuid.uuid4(), "test.txt", "path/to/blob" - ) - assert "Add file error" in str(exc_info.value) + + # Mock file_container.replace_item to raise an exception + mock_file_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container) + mock_file_container.replace_item = AsyncMock(side_effect=Exception("Update failed")) + + # Mock logger + mock_logger = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'logger', mock_logger) + + # Expect an exception when update_file is called + with pytest.raises(Exception, match="Update failed"): + await cosmos_db_client.update_file(file_record) + + called_args, called_kwargs = mock_logger.error.call_args + assert called_args[0] == "Failed to update file" + assert "error" in called_kwargs + assert "Update failed" in called_kwargs["error"] @pytest.mark.asyncio -async def test_get_batch_success(monkeypatch, cosmosdb_client): - batch_item = { - "id": "batch1", - "user_id": "user1", - "created_at": datetime.utcnow().isoformat(), - } - file_item = {"file_id": "file1", "batch_id": "batch1"} +async def test_update_batch(cosmos_db_client, mocker): + batch_record = BatchRecord( + batch_id=uuid4(), + user_id="user_1", + file_count=0, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + status=ProcessStatus.READY_TO_PROCESS + ) + + # Mock batch container replace_item method + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + mock_batch_container.replace_item = AsyncMock(return_value=None) - async def fake_query_items_batch(*args, **kwargs): - for item in [batch_item]: - yield item + # Call the method + updated_batch_record = await cosmos_db_client.update_batch(batch_record) - async def fake_query_items_files(*args, **kwargs): - for item in [file_item]: - yield item + # Assert that the batch record is updated + assert updated_batch_record.batch_id == batch_record.batch_id - cosmosdb_client.batch_container.set_query_items(fake_query_items_batch) - cosmosdb_client.file_container.set_query_items(fake_query_items_files) - result = await cosmosdb_client.get_batch("user1", "batch1") - assert result is not None - assert result.get("id") == "batch1" + mock_batch_container.replace_item.assert_called_once_with(item=str(batch_record.batch_id), body=batch_record.dict()) @pytest.mark.asyncio -async def test_get_batch_not_found(monkeypatch, cosmosdb_client): - async def fake_query_items(*args, **kwargs): - if False: - yield +async def test_update_batch_exception(cosmos_db_client, mocker): + # Create a sample BatchRecord + batch_record = BatchRecord( + batch_id=uuid4(), + user_id="user_1", + file_count=3, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + status=ProcessStatus.READY_TO_PROCESS, + ) + + # Mock batch_container.replace_item to raise an exception + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + mock_batch_container.replace_item = AsyncMock(side_effect=Exception("Update batch failed")) - cosmosdb_client.batch_container.set_query_items(fake_query_items) - result = await cosmosdb_client.get_batch("user1", "nonexistent") - assert result is None + # Mock logger to verify logging + mock_logger = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'logger', mock_logger) + + # Expect an exception when update_batch is called + with pytest.raises(Exception, match="Update batch failed"): + await cosmos_db_client.update_batch(batch_record) + + called_args, called_kwargs = mock_logger.error.call_args + assert called_args[0] == "Failed to update batch" + assert "error" in called_kwargs + assert "Update batch failed" in called_kwargs["error"] @pytest.mark.asyncio -async def test_get_batch_error(monkeypatch, cosmosdb_client): - async def fake_query_items(*args, **kwargs): - raise Exception("Query batch error") - if False: - yield +async def test_get_batch(cosmos_db_client, mocker): + user_id = "user_1" + batch_id = str(uuid4()) + + # Mock batch container query_items method + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, "batch_container", mock_batch_container) + + # Simulate the query result + expected_batch = { + "batch_id": batch_id, + "user_id": user_id, + "file_count": 0, + "status": ProcessStatus.READY_TO_PROCESS, + } - monkeypatch.setattr( - cosmosdb_client.batch_container, - "query_items", - lambda *args, **kwargs: fake_query_items(*args, **kwargs), + # We define the async generator function that will yield the expected batch + async def mock_query_items(query, parameters): + yield expected_batch + + # Assign the async generator to query_items mock + mock_batch_container.query_items.side_effect = mock_query_items + # Call the method + batch = await cosmos_db_client.get_batch(user_id, batch_id) + + # Assert the batch is returned correctly + assert batch["batch_id"] == batch_id + assert batch["user_id"] == user_id + + mock_batch_container.query_items.assert_called_once_with( + query="SELECT * FROM c WHERE c.batch_id = @batch_id and c.user_id = @user_id", + parameters=[ + {"name": "@batch_id", "value": batch_id}, + {"name": "@user_id", "value": user_id}, + ], ) - with pytest.raises(Exception) as exc_info: - await cosmosdb_client.get_batch("user1", "batch1") - assert "Query batch error" in str(exc_info.value) @pytest.mark.asyncio -async def test_get_file_success(monkeypatch, cosmosdb_client): - file_item = {"file_id": "file1", "original_name": "test.txt"} +async def test_get_batch_exception(cosmos_db_client, mocker): + user_id = "user_1" + batch_id = str(uuid4()) + + # Mock batch_container.query_items to raise an exception + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + mock_batch_container.query_items = mock.MagicMock( + side_effect=Exception("Get batch failed") + ) + + # Patch logger + mock_logger = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'logger', mock_logger) - async def fake_query_items(*args, **kwargs): - for item in [file_item]: - yield item + # Call get_batch and expect it to raise an exception + with pytest.raises(Exception, match="Get batch failed"): + await cosmos_db_client.get_batch(user_id, batch_id) - cosmosdb_client.file_container.set_query_items(fake_query_items) - result = await cosmosdb_client.get_file("file1") - assert result == file_item + # Ensure logger.error was called with the expected error message + called_args, called_kwargs = mock_logger.error.call_args + assert called_args[0] == "Failed to get batch" + assert "error" in called_kwargs + assert "Get batch failed" in called_kwargs["error"] @pytest.mark.asyncio -async def test_get_file_error(monkeypatch, cosmosdb_client): - async def fake_query_items(*args, **kwargs): - raise Exception("Query file error") - if False: - yield +async def test_get_file(cosmos_db_client, mocker): + file_id = str(uuid4()) + + # Mock file container query_items method + mock_file_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container) + + # Simulate the query result + expected_file = { + "file_id": file_id, + "status": ProcessStatus.READY_TO_PROCESS, + "original_name": "file.txt", + "blob_path": "/path/to/file" + } - monkeypatch.setattr( - cosmosdb_client.file_container, - "query_items", - lambda *args, **kwargs: fake_query_items(*args, **kwargs), - ) - with pytest.raises(Exception) as exc_info: - await cosmosdb_client.get_file("file1") - assert "Query file error" in str(exc_info.value) + # We define the async generator function that will yield the expected batch + async def mock_query_items(query, parameters): + yield expected_file + # Assign the async generator to query_items mock + mock_file_container.query_items.side_effect = mock_query_items -@pytest.mark.asyncio -async def test_get_batch_files_success(monkeypatch, cosmosdb_client): - file_item = {"file_id": "file1", "batch_id": "batch1"} + # Call the method + file = await cosmos_db_client.get_file(file_id) - async def fake_query_items(*args, **kwargs): - for item in [file_item]: - yield item + # Assert the file is returned correctly + assert file["file_id"] == file_id + assert file["status"] == ProcessStatus.READY_TO_PROCESS - cosmosdb_client.file_container.set_query_items(fake_query_items) - files = await cosmosdb_client.get_batch_files("user1", "batch1") - assert files == [file_item] + mock_file_container.query_items.assert_called_once() @pytest.mark.asyncio -async def test_get_user_batches_success(monkeypatch, cosmosdb_client): - batch_item1 = {"id": "batch1", "user_id": "user1"} - batch_item2 = {"id": "batch2", "user_id": "user1"} +async def test_get_file_exception(cosmos_db_client, mocker): + file_id = str(uuid4()) + + # Mock file_container.query_items to raise an exception + mock_file_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container) + mock_file_container.query_items = mock.MagicMock( + side_effect=Exception("Get file failed") + ) - async def fake_query_items(*args, **kwargs): - for item in [batch_item1, batch_item2]: - yield item + # Mock logger to verify logging + mock_logger = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'logger', mock_logger) - cosmosdb_client.batch_container.set_query_items(fake_query_items) - result = await cosmosdb_client.get_user_batches("user1") - assert result == [batch_item1, batch_item2] + # Call get_file and expect an exception + with pytest.raises(Exception, match="Get file failed"): + await cosmos_db_client.get_file(file_id) + + called_args, called_kwargs = mock_logger.error.call_args + assert called_args[0] == "Failed to get file" + assert "error" in called_kwargs + assert "Get file failed" in called_kwargs["error"] @pytest.mark.asyncio -async def test_get_user_batches_error(monkeypatch, cosmosdb_client): - async def fake_query_items(*args, **kwargs): - raise Exception("User batches error") - if False: - yield +async def test_get_batch_files(cosmos_db_client, mocker): + batch_id = str(uuid4()) + + # Mock file container query_items method + mock_file_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container) + + # Simulate the query result for multiple files + expected_files = [ + { + "file_id": str(uuid4()), + "status": ProcessStatus.READY_TO_PROCESS, + "original_name": "file1.txt", + "blob_path": "/path/to/file1" + }, + { + "file_id": str(uuid4()), + "status": ProcessStatus.IN_PROGRESS, + "original_name": "file2.txt", + "blob_path": "/path/to/file2" + } + ] + + # Define the async generator function to yield the expected files + async def mock_query_items(query, parameters): + for file in expected_files: + yield file + + # Set the side_effect of query_items to simulate async iteration + mock_file_container.query_items.side_effect = mock_query_items + + # Call the method + files = await cosmos_db_client.get_batch_files(batch_id) + + # Assert the files list contains the correct files + assert len(files) == len(expected_files) + assert files[0]["file_id"] == expected_files[0]["file_id"] + assert files[1]["file_id"] == expected_files[1]["file_id"] + + mock_file_container.query_items.assert_called_once() + - monkeypatch.setattr( - cosmosdb_client.batch_container, - "query_items", - lambda *args, **kwargs: fake_query_items(*args, **kwargs), +@pytest.mark.asyncio +async def test_get_batch_files_exception(cosmos_db_client, mocker): + batch_id = str(uuid4()) + + # Mock file_container.query_items to raise an exception + mock_file_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container) + mock_file_container.query_items = mock.MagicMock( + side_effect=Exception("Get batch file failed") ) - with pytest.raises(Exception) as exc_info: - await cosmosdb_client.get_user_batches("user1") - assert "User batches error" in str(exc_info.value) + + # Mock logger to verify logging + mock_logger = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'logger', mock_logger) + + # Expect the exception to be raised + with pytest.raises(Exception, match="Get batch file failed"): + await cosmos_db_client.get_batch_files(batch_id) + + called_args, called_kwargs = mock_logger.error.call_args + assert called_args[0] == "Failed to get files" + assert "error" in called_kwargs + assert "Get batch file failed" in called_kwargs["error"] @pytest.mark.asyncio -async def test_get_file_logs_success(monkeypatch, cosmosdb_client): - log_item = { - "file_id": "file1", - "description": "log", - "timestamp": datetime.utcnow().isoformat(), +async def test_get_batch_from_id(cosmos_db_client, mocker): + batch_id = str(uuid4()) + + # Mock batch container query_items method + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + + # Simulate the query result + expected_batch = { + "batch_id": batch_id, + "status": ProcessStatus.READY_TO_PROCESS, + "user_id": "user_123", } - async def fake_query_items(*args, **kwargs): - for item in [log_item]: - yield item + # Define the async generator function that will yield the expected batch + async def mock_query_items(query, parameters): + yield expected_batch - cosmosdb_client.log_container.set_query_items(fake_query_items) - result = await cosmosdb_client.get_file_logs("file1") - assert result == [log_item] + # Assign the async generator to query_items mock + mock_batch_container.query_items.side_effect = mock_query_items + # Call the method + batch = await cosmos_db_client.get_batch_from_id(batch_id) + + # Assert the batch is returned correctly + assert batch["batch_id"] == batch_id + assert batch["status"] == ProcessStatus.READY_TO_PROCESS + + mock_batch_container.query_items.assert_called_once() -@pytest.mark.asyncio -async def test_get_file_logs_error(monkeypatch, cosmosdb_client): - async def fake_query_items(*args, **kwargs): - raise Exception("Log query error") - if False: - yield - monkeypatch.setattr( - cosmosdb_client.log_container, - "query_items", - lambda *args, **kwargs: fake_query_items(*args, **kwargs), +@pytest.mark.asyncio +async def test_get_batch_from_id_exception(cosmos_db_client, mocker): + batch_id = str(uuid4()) + + # Mock batch_container.query_items to raise an exception + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + mock_batch_container.query_items = mock.MagicMock( + side_effect=Exception("Get batch from id failed") ) - with pytest.raises(Exception) as exc_info: - await cosmosdb_client.get_file_logs("file1") - assert "Log query error" in str(exc_info.value) + + # Mock logger to verify logging + mock_logger = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'logger', mock_logger) + + # Call the method and expect it to raise an exception + with pytest.raises(Exception, match="Get batch from id failed"): + await cosmos_db_client.get_batch_from_id(batch_id) + + called_args, called_kwargs = mock_logger.error.call_args + assert called_args[0] == "Failed to get batch from ID" + assert "error" in called_kwargs + assert "Get batch from id failed" in called_kwargs["error"] @pytest.mark.asyncio -async def test_delete_all_success(monkeypatch, cosmosdb_client): - async def fake_delete_items(key): - return +async def test_get_user_batches(cosmos_db_client, mocker): + user_id = "user_123" - monkeypatch.setattr( - cosmosdb_client.batch_container, "delete_items", fake_delete_items - ) - monkeypatch.setattr( - cosmosdb_client.file_container, "delete_items", fake_delete_items - ) - monkeypatch.setattr( - cosmosdb_client.log_container, "delete_items", fake_delete_items + # Mock batch container query_items method + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + + # Simulate the query result + expected_batches = [ + {"batch_id": str(uuid4()), "status": ProcessStatus.READY_TO_PROCESS, "user_id": user_id}, + {"batch_id": str(uuid4()), "status": ProcessStatus.IN_PROGRESS, "user_id": user_id} + ] + + # Define the async generator function that will yield the expected batches + async def mock_query_items(query, parameters): + for batch in expected_batches: + yield batch + + # Assign the async generator to query_items mock + mock_batch_container.query_items.side_effect = mock_query_items + + # Call the method + batches = await cosmos_db_client.get_user_batches(user_id) + + # Assert the batches are returned correctly + assert len(batches) == 2 + assert batches[0]["status"] == ProcessStatus.READY_TO_PROCESS + assert batches[1]["status"] == ProcessStatus.IN_PROGRESS + + mock_batch_container.query_items.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_user_batches_exception(cosmos_db_client, mocker): + user_id = "user_" + str(uuid4()) + + # Mock batch_container.query_items to raise an exception + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + mock_batch_container.query_items = mock.MagicMock( + side_effect=Exception("Get user batch failed") ) - await cosmosdb_client.delete_all("user1") + + # Mock logger to capture the error + mock_logger = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'logger', mock_logger) + + # Call the method and expect it to raise the exception + with pytest.raises(Exception, match="Get user batch failed"): + await cosmos_db_client.get_user_batches(user_id) + + # Ensure logger.error was called with the expected message and error + called_args, called_kwargs = mock_logger.error.call_args + assert called_args[0] == "Failed to get user batches" + assert "error" in called_kwargs + assert "Get user batch failed" in called_kwargs["error"] @pytest.mark.asyncio -async def test_delete_all_error(monkeypatch, cosmosdb_client): - async def fake_delete_items(key): - raise Exception("Delete all error") +async def test_get_file_logs(cosmos_db_client, mocker): + file_id = str(uuid4()) + + # Mock log container query_items method + mock_log_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'log_container', mock_log_container) + + # Simulate the query result with new log structure + expected_logs = [ + { + "log_id": str(uuid4()), + "file_id": file_id, + "description": "Log entry 1", + "last_candidate": "candidate_1", + "log_type": LogType.INFO, + "agent_type": AgentType.FIXER, + "author_role": AuthorRole.ASSISTANT, + "timestamp": datetime(2025, 4, 7, 12, 0, 0) + }, + { + "log_id": str(uuid4()), + "file_id": file_id, + "description": "Log entry 2", + "last_candidate": "candidate_2", + "log_type": LogType.ERROR, + "agent_type": AgentType.HUMAN, + "author_role": AuthorRole.USER, + "timestamp": datetime(2025, 4, 7, 12, 5, 0) + } + ] + + # Define the async generator function that will yield the expected logs + async def mock_query_items(query, parameters): + for log in expected_logs: + yield log + + # Assign the async generator to query_items mock + mock_log_container.query_items.side_effect = mock_query_items + + # Call the method + logs = await cosmos_db_client.get_file_logs(file_id) + + # Assert the logs are returned correctly + assert len(logs) == 2 + assert logs[0]["description"] == "Log entry 1" + assert logs[1]["description"] == "Log entry 2" + assert logs[0]["log_type"] == LogType.INFO + assert logs[1]["log_type"] == LogType.ERROR + assert logs[0]["timestamp"] == datetime(2025, 4, 7, 12, 0, 0) + assert logs[1]["timestamp"] == datetime(2025, 4, 7, 12, 5, 0) + + mock_log_container.query_items.assert_called_once() - monkeypatch.setattr( - cosmosdb_client.batch_container, "delete_items", fake_delete_items + +@pytest.mark.asyncio +async def test_get_file_logs_exception(cosmos_db_client, mocker): + file_id = str(uuid4()) + + # Mock log_container.query_items to raise an exception + mock_log_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'log_container', mock_log_container) + mock_log_container.query_items = mock.MagicMock( + side_effect=Exception("Get file log failed") ) - with pytest.raises(Exception) as exc_info: - await cosmosdb_client.delete_all("user1") - assert "Delete all error" in str(exc_info.value) + + # Mock logger to verify error logging + mock_logger = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'logger', mock_logger) + + # Call the method and expect it to raise the exception + with pytest.raises(Exception, match="Get file log failed"): + await cosmos_db_client.get_file_logs(file_id) + + # Assert logger.error was called with correct arguments + called_args, called_kwargs = mock_logger.error.call_args + assert called_args[0] == "Failed to get file logs" + assert "error" in called_kwargs + assert "Get file log failed" in called_kwargs["error"] @pytest.mark.asyncio -async def test_delete_logs_success(monkeypatch, cosmosdb_client): - async def fake_delete_items(key): - return +async def test_delete_all(cosmos_db_client, mocker): + user_id = str(uuid4()) - monkeypatch.setattr( - cosmosdb_client.log_container, "delete_items", fake_delete_items + # Mock containers with AsyncMock + mock_batch_container = AsyncMock() + mock_file_container = AsyncMock() + mock_log_container = AsyncMock() + + # Patching the containers with mock objects + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container) + mocker.patch.object(cosmos_db_client, 'log_container', mock_log_container) + + # Mock the delete_item method for all containers + mock_batch_container.delete_item = AsyncMock(return_value=None) + mock_file_container.delete_item = AsyncMock(return_value=None) + mock_log_container.delete_item = AsyncMock(return_value=None) + + # Call the delete_all method + await cosmos_db_client.delete_all(user_id) + + mock_batch_container.delete_item.assert_called_once() + mock_file_container.delete_item.assert_called_once() + mock_log_container.delete_item.assert_called_once() + + +@pytest.mark.asyncio +async def test_delete_all_exception(cosmos_db_client, mocker): + user_id = f"user_{uuid4()}" + + # Mock batch_container to raise an exception on delete + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + mock_batch_container.delete_item = mock.AsyncMock( + side_effect=Exception("Delete failed") ) - await cosmosdb_client.delete_logs("file1") + + # Also mock file_container and log_container to avoid accidental execution + mock_file_container = mock.MagicMock() + mock_log_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container) + mocker.patch.object(cosmos_db_client, 'log_container', mock_log_container) + + # Mock logger to verify error handling + mock_logger = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'logger', mock_logger) + + # Call the method and expect it to raise the exception + with pytest.raises(Exception, match="Delete failed"): + await cosmos_db_client.delete_all(user_id) + + # Check that logger.error was called with expected error message + called_args, called_kwargs = mock_logger.error.call_args + assert called_args[0] == "Failed to delete all user data" + assert "error" in called_kwargs + assert "Delete failed" in called_kwargs["error"] @pytest.mark.asyncio -async def test_delete_batch_success(monkeypatch, cosmosdb_client): - delete_calls = [] +async def test_delete_logs(cosmos_db_client, mocker): + file_id = str(uuid4()) - async def fake_delete_items(key): - delete_calls.append(key) + # Mock the log container with AsyncMock + mock_log_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'log_container', mock_log_container) - async def fake_delete_item(item, partition_key): - delete_calls.append((item, partition_key)) + # Simulate the query result for logs + log_ids = [str(uuid4()), str(uuid4())] - monkeypatch.setattr( - cosmosdb_client.file_container, "delete_items", fake_delete_items - ) - monkeypatch.setattr( - cosmosdb_client.log_container, "delete_items", fake_delete_items - ) - monkeypatch.setattr( - cosmosdb_client.batch_container, "delete_item", fake_delete_item + # Define the async generator function to simulate query result + async def mock_query_items(query, parameters): + for log_id in log_ids: + yield {"id": log_id} + + # Assign the async generator to query_items mock + mock_log_container.query_items.side_effect = mock_query_items + + # Mock delete_item method for log_container + mock_log_container.delete_item = AsyncMock(return_value=None) + + # Call the delete_logs method + await cosmos_db_client.delete_logs(file_id) + + # Assert delete_item is called for each log id + for log_id in log_ids: + mock_log_container.delete_item.assert_any_call(log_id, partition_key=log_id) + + mock_log_container.query_items.assert_called_once() + + +@pytest.mark.asyncio +async def test_delete_logs_exception(cosmos_db_client, mocker): + file_id = str(uuid4()) + + # Mock log_container.query_items to raise an exception + mock_log_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'log_container', mock_log_container) + mock_log_container.query_items = mock.MagicMock( + side_effect=Exception("Query failed") ) - await cosmosdb_client.delete_batch("user1", "batch1") - assert len(delete_calls) == 3 + + # Mock logger to verify error handling + mock_logger = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'logger', mock_logger) + + # Call the method and expect it to raise the exception + with pytest.raises(Exception, match="Query failed"): + await cosmos_db_client.delete_logs(file_id) + + # Check that logger.error was called with expected error message + called_args, called_kwargs = mock_logger.error.call_args + assert called_args[0] == "Failed to delete all user data" + assert "error" in called_kwargs + assert "Query failed" in called_kwargs["error"] @pytest.mark.asyncio -async def test_delete_file_success(monkeypatch, cosmosdb_client): - calls = [] +async def test_delete_batch(cosmos_db_client, mocker): + user_id = str(uuid4()) + batch_id = str(uuid4()) - async def fake_delete_items(key): - calls.append(("log_delete", key)) + # Mock the batch container with AsyncMock + mock_batch_container = AsyncMock() + mocker.patch.object(cosmos_db_client, "batch_container", mock_batch_container) - async def fake_delete_item(file_id): - calls.append(("file_delete", file_id)) + # Call the delete_batch method + await cosmos_db_client.delete_batch(user_id, batch_id) - monkeypatch.setattr( - cosmosdb_client.log_container, "delete_items", fake_delete_items + mock_batch_container.delete_item.assert_called_once() + + +@pytest.mark.asyncio +async def test_delete_batch_exception(cosmos_db_client, mocker): + user_id = f"user_{uuid4()}" + batch_id = str(uuid4()) + + # Mock batch_container.delete_item to raise an exception + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + mock_batch_container.delete_item = mock.AsyncMock( + side_effect=Exception("Delete failed") ) - monkeypatch.setattr(cosmosdb_client.file_container, "delete_item", fake_delete_item) - await cosmosdb_client.delete_file("user1", "batch1", "file1") - assert ("log_delete", "file1") in calls - assert ("file_delete", "file1") in calls + + # Mock logger to verify error logging + mock_logger = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'logger', mock_logger) + + # Expect the exception to be raised from the inner try block + with pytest.raises(Exception, match="Delete failed"): + await cosmos_db_client.delete_batch(user_id, batch_id) + + # Check that both error logs were triggered + assert mock_logger.error.call_count == 2 + + # First log: failed to delete the specific batch + first_call_args, first_call_kwargs = mock_logger.error.call_args_list[0] + assert f"Failed to delete batch with ID: {batch_id}" in first_call_args[0] + assert "error" in first_call_kwargs + assert "Delete failed" in first_call_kwargs["error"] + + # Second log: higher-level operation failed + second_call_args, second_call_kwargs = mock_logger.error.call_args_list[1] + assert second_call_args[0] == "Failed to perform delete batch operation" + assert "error" in second_call_kwargs + assert "Delete failed" in second_call_kwargs["error"] @pytest.mark.asyncio -async def test_log_file_status_success(monkeypatch, cosmosdb_client): - called = False +async def test_delete_file(cosmos_db_client, mocker): + user_id = str(uuid4()) + file_id = str(uuid4()) - async def fake_create_item(body): - nonlocal called - called = True + # Mock containers with AsyncMock + mock_file_container = AsyncMock() + mock_log_container = AsyncMock() - monkeypatch.setattr(cosmosdb_client.log_container, "create_item", fake_create_item) - await cosmosdb_client.log_file_status( - "file1", DummyProcessStatus.READY_TO_PROCESS, "desc", DummyLogType.INFO - ) - assert called + # Patching the containers with mock objects + mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container) + mocker.patch.object(cosmos_db_client, 'log_container', mock_log_container) + + # Mock the delete_logs method (since it's called in delete_file) + mocker.patch.object(cosmos_db_client, 'delete_logs', return_value=None) + + # Call the delete_file method + await cosmos_db_client.delete_file(user_id, file_id) + + cosmos_db_client.delete_logs.assert_called_once_with(file_id) + + mock_file_container.delete_item.assert_called_once_with(file_id, partition_key=file_id) @pytest.mark.asyncio -async def test_log_file_status_error(monkeypatch, cosmosdb_client): - async def fake_create_item(body): - raise Exception("Log error") +async def test_delete_file_exception(cosmos_db_client, mocker): + user_id = f"user_{uuid4()}" + file_id = str(uuid4()) + + # Mock delete_logs to raise an exception + mocker.patch.object( + cosmos_db_client, + 'delete_logs', + mock.AsyncMock(side_effect=Exception("Delete file failed")) + ) - monkeypatch.setattr( - cosmosdb_client.log_container, - "create_item", - lambda *args, **kwargs: fake_create_item(*args, **kwargs), + # Mock file_container to ensure delete_item is not accidentally called + mock_file_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'file_container', mock_file_container) + + # Mock logger to verify error logging + mock_logger = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'logger', mock_logger) + + # Expect an exception to be raised from delete_logs + with pytest.raises(Exception, match="Delete file failed"): + await cosmos_db_client.delete_file(user_id, file_id) + + mock_logger.error.assert_called_once() + called_args, _ = mock_logger.error.call_args + assert f"Failed to delete file and logs for file_id {file_id}" in called_args[0] + + +@pytest.mark.asyncio +async def test_add_file_log(cosmos_db_client, mocker): + file_id = uuid4() + description = "File processing started" + last_candidate = "candidate_123" + log_type = LogType.INFO + agent_type = AgentType.MIGRATOR + author_role = AuthorRole.ASSISTANT + + # Mock log container create_item method + mock_log_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'log_container', mock_log_container) + + # Mock the create_item method + mock_log_container.create_item = AsyncMock(return_value=None) + + # Call the method + await cosmos_db_client.add_file_log( + file_id, description, last_candidate, log_type, agent_type, author_role ) - with pytest.raises(Exception) as exc_info: - await cosmosdb_client.log_file_status( - "file1", DummyProcessStatus.READY_TO_PROCESS, "desc", DummyLogType.INFO - ) - assert "Log error" in str(exc_info.value) + + mock_log_container.create_item.assert_called_once() @pytest.mark.asyncio -async def test_update_batch_entry_success(monkeypatch, cosmosdb_client): - dummy_batch = { - "id": "batch1", - "user_id": "user1", - "status": DummyProcessStatus.READY_TO_PROCESS, - "updated_at": datetime.utcnow().isoformat(), +async def test_update_batch_entry(cosmos_db_client, mocker): + batch_id = "batch_123" + user_id = "user_123" + status = ProcessStatus.IN_PROGRESS + file_count = 5 + + # Mock batch container replace_item method + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + + # Mock the get_batch method + mocker.patch.object(cosmos_db_client, 'get_batch', return_value={ + "batch_id": batch_id, + "status": ProcessStatus.READY_TO_PROCESS.value, + "user_id": user_id, "file_count": 0, - } + "updated_at": "2025-04-07T00:00:00Z" + }) - async def fake_get_batch(user_id, batch_id): - return dummy_batch.copy() + # Mock the replace_item method + mock_batch_container.replace_item = AsyncMock(return_value=None) - monkeypatch.setattr(cosmosdb_client, "get_batch", fake_get_batch) - updated_body = None + # Call the method + updated_batch = await cosmos_db_client.update_batch_entry(batch_id, user_id, status, file_count) - async def fake_replace_item(item, body): - nonlocal updated_body - updated_body = body - return body + # Assert that replace_item was called with the correct arguments + mock_batch_container.replace_item.assert_called_once_with(item=batch_id, body={ + "batch_id": batch_id, + "status": status.value, + "user_id": user_id, + "file_count": file_count, + "updated_at": updated_batch["updated_at"] + }) - monkeypatch.setattr( - cosmosdb_client.batch_container, "replace_item", fake_replace_item - ) - new_status = DummyProcessStatus.PROCESSING - file_count = 5 - result = await cosmosdb_client.update_batch_entry( - "batch1", "user1", new_status, file_count - ) - assert result["file_count"] == file_count - assert result["status"] == new_status.value - assert updated_body is not None + # Assert the returned batch matches expected values + assert updated_batch["batch_id"] == batch_id + assert updated_batch["status"] == status.value + assert updated_batch["file_count"] == file_count @pytest.mark.asyncio -async def test_update_batch_entry_not_found(monkeypatch, cosmosdb_client): - monkeypatch.setattr( - cosmosdb_client, "get_batch", lambda u, b: asyncio.sleep(0, result=None) - ) - with pytest.raises(ValueError, match="Batch not found"): - await cosmosdb_client.update_batch_entry( - "nonexistent", "user1", DummyProcessStatus.READY_TO_PROCESS, 0 - ) +async def test_close(cosmos_db_client, mocker): + # Mock the client and logger + mock_client = mock.MagicMock() + mock_logger = mock.MagicMock() + cosmos_db_client.client = mock_client + cosmos_db_client.logger = mock_logger + # Call the method + await cosmos_db_client.close() -@pytest.mark.asyncio -async def test_close(monkeypatch, cosmosdb_client): - closed = False + # Assert that the client was closed + mock_client.close.assert_called_once() - def fake_close(): - nonlocal closed - closed = True + # Assert that logger's info method was called + mock_logger.info.assert_called_once_with("Closed Cosmos DB connection") - monkeypatch.setattr(cosmosdb_client.client, "close", fake_close) - await cosmosdb_client.close() - assert closed + +@pytest.mark.asyncio +async def test_get_batch_history(cosmos_db_client, mocker): + user_id = "user_123" + limit = 5 + offset = 0 + sort_order = "DESC" + + # Mock batch container query_items method + mock_batch_container = mock.MagicMock() + mocker.patch.object(cosmos_db_client, 'batch_container', mock_batch_container) + + # Simulate the query result for batches + expected_batches = [ + {"batch_id": "batch_1", "status": ProcessStatus.IN_PROGRESS.value, "user_id": user_id, "file_count": 5}, + {"batch_id": "batch_2", "status": ProcessStatus.COMPLETED.value, "user_id": user_id, "file_count": 3}, + ] + + # Define the async generator function to simulate query result + async def mock_query_items(query, parameters): + for batch in expected_batches: + yield batch + + # Assign the async generator to query_items mock + mock_batch_container.query_items.side_effect = mock_query_items + + # Call the method + batches = await cosmos_db_client.get_batch_history(user_id, limit, sort_order, offset) + + # Assert the returned batches are correct + assert len(batches) == len(expected_batches) + assert batches[0]["batch_id"] == expected_batches[0]["batch_id"] + + mock_batch_container.query_items.assert_called_once() diff --git a/src/tests/backend/common/database/database_base_test.py b/src/tests/backend/common/database/database_base_test.py index 0e9d1fec..325cf7e9 100644 --- a/src/tests/backend/common/database/database_base_test.py +++ b/src/tests/backend/common/database/database_base_test.py @@ -1,59 +1,61 @@ import uuid from enum import Enum -# Import the abstract base class and related models/enums. + from common.database.database_base import DatabaseBase from common.models.api import ProcessStatus import pytest + +# Allow instantiation of the abstract base class by clearing its abstract methods. DatabaseBase.__abstractmethods__ = set() @pytest.fixture def db_instance(): - # Instantiate the DatabaseBase directly. + # Create a concrete implementation of DatabaseBase using async methods. class ConcreteDatabase(DatabaseBase): - def create_batch(self, user_id, batch_id): + async def create_batch(self, user_id, batch_id): pass - def get_file_logs(self, file_id): + async def get_file_logs(self, file_id): pass - def get_batch_files(self, user_id, batch_id): + async def get_batch_files(self, user_id, batch_id): pass - def delete_file_logs(self, file_id): + async def delete_file_logs(self, file_id): pass - def get_user_batches(self, user_id): + async def get_user_batches(self, user_id): pass - def add_file(self, batch_id, file_id, file_name, file_path): + async def add_file(self, batch_id, file_id, file_name, file_path): pass - def get_batch(self, user_id, batch_id): + async def get_batch(self, user_id, batch_id): pass - def get_file(self, file_id): + async def get_file(self, file_id): pass - def log_file_status(self, file_id, status, description, log_type): + async def log_file_status(self, file_id, status, description, log_type): pass - def log_batch_status(self, batch_id, status, file_count): + async def log_batch_status(self, batch_id, status, file_count): pass - def delete_all(self, user_id): + async def delete_all(self, user_id): pass - def delete_batch(self, user_id, batch_id): + async def delete_batch(self, user_id, batch_id): pass - def delete_file(self, user_id, batch_id, file_id): + async def delete_file(self, user_id, batch_id, file_id): pass - def close(self): + async def close(self): pass return ConcreteDatabase() @@ -62,7 +64,6 @@ def close(self): def get_dummy_status(): """ Try to use a specific ProcessStatus value (e.g. PROCESSING). - If that member is not available, just return the first member in the enum. """ try: @@ -71,7 +72,7 @@ def get_dummy_status(): members = list(ProcessStatus) if members: return members[0] - # If the enum is empty, create a dummy one + # If the enum is empty, create a dummy one. DummyStatus = Enum("DummyStatus", {"DUMMY": "dummy"}) return DummyStatus.DUMMY @@ -79,7 +80,7 @@ def get_dummy_status(): @pytest.mark.asyncio async def test_create_batch(db_instance): result = await db_instance.create_batch("user1", uuid.uuid4()) - # Since the method is abstract (and implemented as pass), result is None. + # Since the method is implemented as pass, result is None. assert result is None @@ -109,9 +110,7 @@ async def test_get_user_batches(db_instance): @pytest.mark.asyncio async def test_add_file(db_instance): - result = await db_instance.add_file( - uuid.uuid4(), uuid.uuid4(), "test.txt", "/dummy/path" - ) + result = await db_instance.add_file(uuid.uuid4(), uuid.uuid4(), "test.txt", "/dummy/path") assert result is None @@ -129,10 +128,8 @@ async def test_get_file(db_instance): @pytest.mark.asyncio async def test_log_file_status(db_instance): - # Use an existing member for file status—here we use COMPLETED. - result = await db_instance.log_file_status( - "file1", ProcessStatus.COMPLETED, "desc", "log_type" - ) + # Using ProcessStatus.COMPLETED as an example. + result = await db_instance.log_file_status("file1", ProcessStatus.COMPLETED, "desc", "log_type") assert result is None diff --git a/src/tests/backend/common/database/database_factory_test.py b/src/tests/backend/common/database/database_factory_test.py index bdf99d35..27d98105 100644 --- a/src/tests/backend/common/database/database_factory_test.py +++ b/src/tests/backend/common/database/database_factory_test.py @@ -1,63 +1,79 @@ -from common.config.config import Config -from common.database.database_factory import DatabaseFactory +from unittest.mock import AsyncMock, patch + import pytest -class DummyConfig: - cosmosdb_endpoint = "dummy_endpoint" - cosmosdb_database = "dummy_database" - cosmosdb_batch_container = "dummy_batch" - cosmosdb_file_container = "dummy_file" - cosmosdb_log_container = "dummy_log" +@pytest.fixture(autouse=True) +def patch_config(monkeypatch): + """Patch Config class to use dummy values.""" + from common.config.config import Config + def dummy_init(self): + """Mocked __init__ method for Config to set dummy values.""" + self.cosmosdb_endpoint = "dummy_endpoint" + self.cosmosdb_database = "dummy_database" + self.cosmosdb_batch_container = "dummy_batch" + self.cosmosdb_file_container = "dummy_file" + self.cosmosdb_log_container = "dummy_log" + self.get_azure_credentials = lambda: "dummy_credential" -class DummyCosmosDBClient: - def __init__(self, endpoint, credential, database_name, batch_container, file_container, log_container): - self.endpoint = endpoint - self.credential = credential - self.database_name = database_name - self.batch_container = batch_container - self.file_container = file_container - self.log_container = log_container + monkeypatch.setattr(Config, "__init__", dummy_init) # Replace the init method -def dummy_config_init(self): - self.cosmosdb_endpoint = DummyConfig.cosmosdb_endpoint - self.cosmosdb_database = DummyConfig.cosmosdb_database - self.cosmosdb_batch_container = DummyConfig.cosmosdb_batch_container - self.cosmosdb_file_container = DummyConfig.cosmosdb_file_container - self.cosmosdb_log_container = DummyConfig.cosmosdb_log_container - # Provide a dummy method for credentials. - self.get_azure_credentials = lambda: "dummy_credential" +@pytest.fixture(autouse=True) +def patch_cosmosdb_client(monkeypatch): + """Patch CosmosDBClient to use a dummy implementation.""" + class DummyCosmosDBClient: + def __init__(self, endpoint, credential, database_name, batch_container, file_container, log_container): + self.endpoint = endpoint + self.credential = credential + self.database_name = database_name + self.batch_container = batch_container + self.file_container = file_container + self.log_container = log_container -@pytest.fixture(autouse=True) -def patch_config(monkeypatch): - # Patch the __init__ of Config so that an instance will have the required attributes. - monkeypatch.setattr(Config, "__init__", dummy_config_init) + async def initialize_cosmos(self): + pass + async def create_batch(self, *args, **kwargs): + pass + + async def add_file(self, *args, **kwargs): + pass + + async def get_batch(self, *args, **kwargs): + return "mock_batch" + + async def close(self): + pass -@pytest.fixture(autouse=True) -def patch_cosmosdb_client(monkeypatch): - # Patch CosmosDBClient in the module under test to use our dummy client. monkeypatch.setattr("common.database.database_factory.CosmosDBClient", DummyCosmosDBClient) -def test_get_database(): - """ - Test that DatabaseFactory.get_database() correctly returns an instance of the. +@pytest.mark.asyncio +async def test_get_database(): + """Test database retrieval using the factory.""" + from common.database.database_factory import DatabaseFactory - dummy CosmosDB client with the expected configuration values. - """ - # When get_database() is called, it creates a new Config() instance. - db_instance = DatabaseFactory.get_database() + db_instance = await DatabaseFactory.get_database() - # Verify that the returned instance is our dummy client with the expected attributes. - assert isinstance(db_instance, DummyCosmosDBClient) - assert db_instance.endpoint == DummyConfig.cosmosdb_endpoint + assert db_instance.endpoint == "dummy_endpoint" assert db_instance.credential == "dummy_credential" - assert db_instance.database_name == DummyConfig.cosmosdb_database - assert db_instance.batch_container == DummyConfig.cosmosdb_batch_container - assert db_instance.file_container == DummyConfig.cosmosdb_file_container - assert db_instance.log_container == DummyConfig.cosmosdb_log_container + assert db_instance.database_name == "dummy_database" + assert db_instance.batch_container == "dummy_batch" + assert db_instance.file_container == "dummy_file" + assert db_instance.log_container == "dummy_log" + + +@pytest.mark.asyncio +async def test_main_function(): + """Test the main function in database factory.""" + with patch("common.database.database_factory.DatabaseFactory.get_database", new_callable=AsyncMock, return_value=AsyncMock()) as mock_get_database, patch("builtins.print") as mock_print: + + from common.database.database_factory import main + await main() + + mock_get_database.assert_called_once() + mock_print.assert_called() # Ensures print is executed diff --git a/src/tests/backend/common/logger/app_logger_test.py b/src/tests/backend/common/logger/app_logger_test.py new file mode 100644 index 00000000..9301eb30 --- /dev/null +++ b/src/tests/backend/common/logger/app_logger_test.py @@ -0,0 +1,94 @@ +import json +import logging +from unittest.mock import MagicMock, patch + +from common.logger.app_logger import AppLogger, LogLevel # Adjust the import based on your actual path + +import pytest + + +@pytest.fixture +def logger_name(): + return "test_logger" + + +@pytest.fixture +def logger_instance(logger_name): + """Fixture to return AppLogger with mocked handler""" + with patch("common.logger.app_logger.logging.getLogger") as mock_get_logger: + mock_logger = MagicMock() + mock_get_logger.return_value = mock_logger + yield AppLogger(logger_name) + + +def test_log_levels(): + """Ensure log levels are set correctly""" + assert LogLevel.NONE == logging.NOTSET + assert LogLevel.DEBUG == logging.DEBUG + assert LogLevel.INFO == logging.INFO + assert LogLevel.WARNING == logging.WARNING + assert LogLevel.ERROR == logging.ERROR + assert LogLevel.CRITICAL == logging.CRITICAL + + +def test_format_message_basic(logger_instance): + result = logger_instance._format_message("Test message") + parsed = json.loads(result) + assert parsed["message"] == "Test message" + assert "context" not in parsed + + +def test_format_message_with_context(logger_instance): + result = logger_instance._format_message("Contextual message", key1="value1", key2="value2") + parsed = json.loads(result) + assert parsed["message"] == "Contextual message" + assert parsed["context"] == {"key1": "value1", "key2": "value2"} + + +def test_debug_log(logger_instance): + with patch.object(logger_instance.logger, "debug") as mock_debug: + logger_instance.debug("Debug log", user="tester") + mock_debug.assert_called_once() + log_json = json.loads(mock_debug.call_args[0][0]) + assert log_json["message"] == "Debug log" + assert log_json["context"]["user"] == "tester" + + +def test_info_log(logger_instance): + with patch.object(logger_instance.logger, "info") as mock_info: + logger_instance.info("Info log", module="log_module") + mock_info.assert_called_once() + log_json = json.loads(mock_info.call_args[0][0]) + assert log_json["message"] == "Info log" + assert log_json["context"]["module"] == "log_module" + + +def test_warning_log(logger_instance): + with patch.object(logger_instance.logger, "warning") as mock_warning: + logger_instance.warning("Warning log") + mock_warning.assert_called_once() + + +def test_error_log(logger_instance): + with patch.object(logger_instance.logger, "error") as mock_error: + logger_instance.error("Error log", error_code=500) + mock_error.assert_called_once() + log_json = json.loads(mock_error.call_args[0][0]) + assert log_json["message"] == "Error log" + assert log_json["context"]["error_code"] == 500 + + +def test_critical_log(logger_instance): + with patch.object(logger_instance.logger, "critical") as mock_critical: + logger_instance.critical("Critical log") + mock_critical.assert_called_once() + + +def test_set_min_log_level(): + with patch("common.logger.app_logger.logging.getLogger") as mock_get_logger: + mock_logger = MagicMock() + mock_get_logger.return_value = mock_logger + + AppLogger.set_min_log_level(LogLevel.ERROR) + + mock_logger.setLevel.assert_called_once_with(LogLevel.ERROR) diff --git a/src/tests/backend/common/models/api_test.py b/src/tests/backend/common/models/api_test.py new file mode 100644 index 00000000..b338efc0 --- /dev/null +++ b/src/tests/backend/common/models/api_test.py @@ -0,0 +1,123 @@ +from datetime import datetime +from uuid import uuid4 + +from backend.common.models.api import AgentType, BatchRecord, FileLog, FileProcessUpdate, FileProcessUpdateJSONEncoder, FileRecord, FileResult, ProcessStatus, QueueBatch, TranslateType + +import pytest + + +@pytest.fixture +def common_datetime(): + return datetime.now() + + +@pytest.fixture +def uuid_pair(): + return str(uuid4()), str(uuid4()) + + +def test_filelog_fromdb_and_dict(uuid_pair, common_datetime): + log_id, file_id = uuid_pair + data = { + "log_id": log_id, + "file_id": file_id, + "description": "test log", + "last_candidate": "some_candidate", + "log_type": "SUCCESS", + "agent_type": "migrator", + "author_role": "user", + "timestamp": common_datetime.isoformat(), + } + log = FileLog.fromdb(data) + assert log.log_id.hex == log_id.replace("-", "") + assert log.dict()["log_type"] == "info" + + assert log.dict()["author_role"] == "user" + + +def test_filerecord_fromdb_and_dict(uuid_pair, common_datetime): + file_id, batch_id = uuid_pair + data = { + "file_id": file_id, + "batch_id": batch_id, + "original_name": "file.sql", + "blob_path": "/blob/file.sql", + "translated_path": "/translated/file.sql", + "status": "in_progress", + "file_result": "warning", + "error_count": 2, + "syntax_count": 5, + "created_at": common_datetime.isoformat(), + "updated_at": common_datetime.isoformat(), + } + record = FileRecord.fromdb(data) + assert record.file_id.hex == file_id.replace("-", "") + assert record.dict()["status"] == "ready_to_process" + assert record.dict()["file_result"] == "warning" + + +def test_fileprocessupdate_dict(uuid_pair): + file_id, batch_id = uuid_pair + update = FileProcessUpdate( + file_id=file_id, + batch_id=batch_id, + process_status=ProcessStatus.COMPLETED, + file_result=FileResult.SUCCESS, + agent_type=AgentType.FIXER, + agent_message="Translation done", + ) + result = update.dict() + assert result["process_status"] == "completed" + assert result["file_result"] == "success" + assert result["agent_type"] == "fixer" + assert result["agent_message"] == "Translation done" + + +def test_fileprocessupdate_json_encoder(uuid_pair): + file_id, batch_id = uuid_pair + update = FileProcessUpdate( + file_id=file_id, + batch_id=batch_id, + process_status=ProcessStatus.FAILED, + file_result=FileResult.ERROR, + agent_type=AgentType.HUMAN, + agent_message="Something failed", + ) + json_string = FileProcessUpdateJSONEncoder().encode(update) + assert "failed" in json_string + assert "human" in json_string + + +def test_queuebatch_dict(uuid_pair, common_datetime): + batch_id, _ = uuid_pair + batch = QueueBatch( + batch_id=batch_id, + user_id="user123", + translate_from="en", + translate_to="tsql", + created_at=common_datetime, + updated_at=common_datetime, + status=ProcessStatus.IN_PROGRESS, + ) + result = batch.dict() + assert result["status"] == "in_process" + assert result["user_id"] == "user123" + + +def test_batchrecord_fromdb_and_dict(uuid_pair, common_datetime): + batch_id, _ = uuid_pair + data = { + "batch_id": batch_id, + "user_id": "user123", + "file_count": 3, + "created_at": common_datetime.isoformat(), + "updated_at": common_datetime.isoformat(), + "status": "completed", + "from_language": "Informix", + "to_language": "T-SQL" + } + record = BatchRecord.fromdb(data) + assert record.status == ProcessStatus.COMPLETED + assert record.from_language == TranslateType.INFORMIX + assert record.to_language == TranslateType.TSQL + assert record.dict()["status"] == "completed" diff --git a/src/tests/backend/common/services/__init__.py b/src/tests/backend/common/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/tests/backend/common/services/batch_service_test.py b/src/tests/backend/common/services/batch_service_test.py new file mode 100644 index 00000000..21fd3a67 --- /dev/null +++ b/src/tests/backend/common/services/batch_service_test.py @@ -0,0 +1,785 @@ +from io import BytesIO +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +from common.models.api import AgentType, AuthorRole, BatchRecord, FileResult, LogType, ProcessStatus +from common.services.batch_service import BatchService + +from fastapi import HTTPException, UploadFile + +import pytest + +import pytest_asyncio + + +@pytest.fixture +def mock_service(mocker): + service = BatchService() + service.logger = mocker.Mock() + service.database = MagicMock() + + return service + + +@pytest_asyncio.fixture +async def service(): + svc = BatchService() + svc.logger = MagicMock() + return svc + + +def batch_service(): + service = BatchService() # Correct constructor + service.database = MagicMock() # Inject mock database + return service + + +@pytest.mark.asyncio +@patch("common.services.batch_service.DatabaseFactory.get_database", new_callable=AsyncMock) +async def test_initialize_database(mock_get_db, service): + mock_db = AsyncMock() + mock_get_db.return_value = mock_db + await service.initialize_database() + assert service.database == mock_db + + +@pytest.mark.asyncio +async def test_get_batch_found(service): + service.database = AsyncMock() + batch_id = uuid4() + user_id = "user123" + service.database.get_batch.return_value = {"id": str(batch_id)} + service.database.get_batch_files.return_value = [{"file_id": "f1"}] + result = await service.get_batch(batch_id, user_id) + assert result["batch"] == {"id": str(batch_id)} + assert result["files"] == [{"file_id": "f1"}] + + +@pytest.mark.asyncio +async def test_get_batch_not_found(service): + service.database = AsyncMock() + batch_id = uuid4() + user_id = "user123" + service.database.get_batch.return_value = None + result = await service.get_batch(batch_id, user_id) + assert result is None + + +@pytest.mark.asyncio +async def test_get_file_found(service): + service.database = AsyncMock() + service.database.get_file.return_value = {"file_id": "file123"} + result = await service.get_file("file123") + assert result == {"file": {"file_id": "file123"}} + + +@pytest.mark.asyncio +async def test_get_file_not_found(service): + service.database = AsyncMock() + service.database.get_file.return_value = None + result = await service.get_file("notfound") + assert result is None + + +@pytest.mark.asyncio +@patch("common.services.batch_service.BlobStorageFactory.get_storage", new_callable=AsyncMock) +@patch("common.models.api.FileRecord.fromdb") +@patch("common.models.api.BatchRecord.fromdb") +async def test_get_file_report_success(mock_batch_fromdb, mock_file_fromdb, mock_get_storage, service): + service.database = AsyncMock() + file_id = "file123" + mock_file = {"batch_id": uuid4(), "translated_path": "some/path"} + mock_batch = {"batch_id": "batch123"} + mock_logs = [{"log": "log1"}] + mock_translated = "translated content" + service.database.get_file.return_value = mock_file + service.database.get_batch_from_id.return_value = mock_batch + service.database.get_file_logs.return_value = mock_logs + mock_file_fromdb.return_value = MagicMock(dict=lambda: mock_file, batch_id=mock_file["batch_id"], translated_path="some/path") + mock_batch_fromdb.return_value = MagicMock(dict=lambda: mock_batch) + mock_storage = AsyncMock() + mock_storage.get_file.return_value = mock_translated + mock_get_storage.return_value = mock_storage + result = await service.get_file_report(file_id) + assert result["file"] == mock_file + assert result["batch"] == mock_batch + assert result["logs"] == mock_logs + assert result["translated_content"] == mock_translated + + +@pytest.mark.asyncio +@patch("common.services.batch_service.BlobStorageFactory.get_storage", new_callable=AsyncMock) +async def test_get_file_translated_success(mock_get_storage, service): + file = {"translated_path": "some/path"} + mock_storage = AsyncMock() + mock_storage.get_file.return_value = "translated" + mock_get_storage.return_value = mock_storage + result = await service.get_file_translated(file) + assert result == "translated" + + +@pytest.mark.asyncio +@patch("common.services.batch_service.BlobStorageFactory.get_storage", new_callable=AsyncMock) +async def test_get_file_translated_error(mock_get_storage, service): + file = {"translated_path": "some/path"} + mock_storage = AsyncMock() + mock_storage.get_file.side_effect = IOError("Failed to download") + mock_get_storage.return_value = mock_storage + result = await service.get_file_translated(file) + assert result == "" + + +@pytest.mark.asyncio +async def test_get_batch_for_zip(service): + service.database = AsyncMock() + service.get_file_translated = AsyncMock(return_value="file-content") + service.database.get_batch_files.return_value = [ + {"original_name": "doc1.txt", "translated_path": "path1"}, + {"original_name": "doc2.txt", "translated_path": "path2"}, + ] + result = await service.get_batch_for_zip("batch1") + assert len(result) == 2 + assert result[0][0] == "rslt_doc1.txt" + assert result[0][1] == "file-content" + + +@pytest.mark.asyncio +@patch("common.models.api.BatchRecord.fromdb") +async def test_get_batch_summary_success(mock_batch_fromdb, service): + service.database = AsyncMock() + mock_batch = {"batch_id": "batch1"} + mock_batch_record = MagicMock(dict=lambda: {"batch_id": "batch1"}) + mock_batch_fromdb.return_value = mock_batch_record + service.database.get_batch.return_value = mock_batch + service.database.get_batch_files.return_value = [ + {"file_id": "file1", "translated_path": "path1"}, + {"file_id": "file2", "translated_path": None}, + ] + service.database.get_file_logs.return_value = ["log1"] + service.get_file_translated = AsyncMock(return_value="translated") + result = await service.get_batch_summary("batch1", "user1") + assert "files" in result + assert "batch" in result + assert result["files"][0]["logs"] == ["log1"] + assert result["files"][0]["translated_content"] == "translated" + + +@pytest.mark.asyncio +async def test_batch_zip_with_no_files(service): + service.database = AsyncMock() + service.database.get_batch_files.return_value = [] + service.get_file_translated = AsyncMock() + result = await service.get_batch_for_zip("batch_empty") + assert result == [] + + +def test_is_valid_uuid(): + service = BatchService() + valid = str(uuid4()) + invalid = "not-a-uuid" + assert service.is_valid_uuid(valid) + assert not service.is_valid_uuid(invalid) + + +def test_generate_file_path(): + service = BatchService() + path = service.generate_file_path("batch1", "user1", "file1", "test@file.pdf") + assert path == "user1/batch1/file1/test_file.pdf" + + +@pytest.mark.asyncio +async def test_delete_batch_existing(): + service = BatchService() + service.database = AsyncMock() + batch_id = uuid4() + service.database.get_batch.return_value = {"id": str(batch_id)} + service.database.delete_batch.return_value = None + result = await service.delete_batch(batch_id, "user1") + assert result["message"] == "Batch deleted successfully" + assert result["batch_id"] == str(batch_id) + + +@pytest.mark.asyncio +async def test_delete_file_success(): + service = BatchService() + service.database = AsyncMock() + file_id = uuid4() + batch_id = uuid4() + mock_file = MagicMock() + mock_file.batch_id = batch_id + mock_file.blob_path = "some/path/file.pdf" + mock_file.translated_path = "some/path/file_translated.pdf" + with patch("common.storage.blob_factory.BlobStorageFactory.get_storage", new_callable=AsyncMock) as mock_storage: + mock_storage.return_value.delete_file.return_value = True + service.database.get_file.return_value = mock_file + service.database.get_batch.return_value = {"id": str(batch_id)} + service.database.get_batch_files.return_value = [1, 2] + with patch("common.models.api.FileRecord.fromdb", return_value=mock_file), \ + patch("common.models.api.BatchRecord.fromdb") as mock_batch_record: + mock_record = MagicMock() + mock_record.file_count = 1 + service.database.update_batch.return_value = None + mock_batch_record.return_value = mock_record + result = await service.delete_file(file_id, "user1") + assert result["message"] == "File deleted successfully" + assert result["file_id"] == str(file_id) + + +@pytest.mark.asyncio +async def test_upload_file_to_batch_dict_batch(): + service = BatchService() + service.database = AsyncMock() + file = UploadFile(filename="hello@file.txt", file=BytesIO(b"test content")) + batch_id = str(uuid4()) + file_id = str(uuid4()) + with patch("common.storage.blob_factory.BlobStorageFactory.get_storage", new_callable=AsyncMock) as mock_storage, \ + patch("uuid.uuid4", return_value=file_id), \ + patch("common.models.api.FileRecord.fromdb", return_value={"blob_path": "path"}): + + mock_storage.return_value.upload_file.return_value = None + service.database.get_batch.side_effect = [None, {"file_count": 0}] + service.database.create_batch.return_value = {} + service.database.get_batch_files.return_value = ["file1", "file2"] + service.database.get_file.return_value = {"filename": file.filename} + service.database.update_batch_entry.return_value = {"batch_id": batch_id, "file_count": 2} + result = await service.upload_file_to_batch(batch_id, "user1", file) + assert "batch" in result + assert "file" in result + + +@pytest.mark.asyncio +async def test_upload_file_to_batch_invalid_storage(): + service = BatchService() + service.database = AsyncMock() + file = UploadFile(filename="file.txt", file=BytesIO(b"data")) + with patch("common.storage.blob_factory.BlobStorageFactory.get_storage", return_value=None): + with pytest.raises(RuntimeError) as exc_info: + await service.upload_file_to_batch(str(uuid4()), "user1", file) + # Check outer exception message + assert str(exc_info.value) == "File upload failed" + + # Check original cause of the exception + assert isinstance(exc_info.value.__cause__, RuntimeError) + assert str(exc_info.value.__cause__) == "Storage service not initialized" + + +def test_generate_file_path_only_filename(): + service = BatchService() + path = service.generate_file_path(None, None, None, "weird@name!.txt") + assert path.endswith("weird_name_.txt") + + +def test_is_valid_uuid_empty_string(): + service = BatchService() + assert not service.is_valid_uuid("") + + +def test_is_valid_uuid_partial_uuid(): + service = BatchService() + assert not service.is_valid_uuid("1234abcd") + + +@pytest.mark.asyncio +async def test_delete_file_file_not_found(): + service = BatchService() + service.database = AsyncMock() + file_id = str(uuid4()) + + service.database.get_file.return_value = None + result = await service.delete_file(file_id, "user1") + assert result is None + + +@pytest.mark.asyncio +async def test_upload_file_to_batch_storage_upload_fails(): + service = BatchService() + service.database = AsyncMock() + file = UploadFile(filename="test.txt", file=BytesIO(b"abc")) + file_id = str(uuid4()) + + with patch("common.storage.blob_factory.BlobStorageFactory.get_storage") as mock_get_storage, \ + patch("uuid.uuid4", return_value=file_id): + mock_storage = AsyncMock() + mock_storage.upload_file.side_effect = RuntimeError("upload failed") + mock_get_storage.return_value = mock_storage + + service.database.get_batch.side_effect = [None, {"file_count": 0}] + service.database.create_batch.return_value = {} + service.database.get_batch_files.return_value = [] + service.database.update_batch_entry.return_value = {} + + with pytest.raises(RuntimeError, match="File upload failed"): + await service.upload_file_to_batch("batch123", "user1", file) + + @pytest.mark.asyncio + async def test_update_file_counts_success(service): + service.database = AsyncMock() + file_id = str(uuid4()) + mock_file = {"file_id": file_id} + mock_logs = [ + {"log_type": LogType.ERROR.value}, + {"log_type": LogType.WARNING.value}, + {"log_type": LogType.WARNING.value}, + ] + service.database.get_file.return_value = mock_file + service.database.get_file_logs.return_value = mock_logs + with patch("common.models.api.FileRecord.fromdb", return_value=MagicMock()) as mock_file_record: + await service.update_file_counts(file_id) + mock_file_record.assert_called_once() + service.database.update_file.assert_called_once() + + @pytest.mark.asyncio + async def test_update_file_counts_no_logs(service): + service.database = AsyncMock() + file_id = str(uuid4()) + mock_file = {"file_id": file_id} + service.database.get_file.return_value = mock_file + service.database.get_file_logs.return_value = [] + with patch("common.models.api.FileRecord.fromdb", return_value=MagicMock()) as mock_file_record: + await service.update_file_counts(file_id) + mock_file_record.assert_called_once() + service.database.update_file.assert_called_once() + + @pytest.mark.asyncio + async def test_get_file_counts_success(service): + service.database = AsyncMock() + file_id = str(uuid4()) + mock_logs = [ + {"log_type": LogType.ERROR.value}, + {"log_type": LogType.WARNING.value}, + {"log_type": LogType.WARNING.value}, + ] + service.database.get_file_logs.return_value = mock_logs + error_count, syntax_count = await service.get_file_counts(file_id) + assert error_count == 1 + assert syntax_count == 2 + + @pytest.mark.asyncio + async def test_get_file_counts_no_logs(service): + service.database = AsyncMock() + file_id = str(uuid4()) + service.database.get_file_logs.return_value = [] + error_count, syntax_count = await service.get_file_counts(file_id) + assert error_count == 0 + assert syntax_count == 0 + + @pytest.mark.asyncio + async def test_get_batch_history_success(service): + service.database = AsyncMock() + user_id = "user123" + mock_history = [{"batch_id": "batch1"}, {"batch_id": "batch2"}] + service.database.get_batch_history.return_value = mock_history + result = await service.get_batch_history(user_id, limit=10, offset=0) + assert result == mock_history + + @pytest.mark.asyncio + async def test_get_batch_history_no_history(service): + service.database = AsyncMock() + user_id = "user123" + service.database.get_batch_history.return_value = [] + result = await service.get_batch_history(user_id, limit=10, offset=0) + assert result == [] + + @pytest.mark.asyncio + @patch("common.services.batch_service.DatabaseFactory.get_database", new_callable=AsyncMock) + async def test_initialize_database_success(mock_get_database, service): + # Arrange + mock_database = AsyncMock() + mock_get_database.return_value = mock_database + + # Act + await service.initialize_database() + + # Assert + assert service.database == mock_database + mock_get_database.assert_called_once() + + @pytest.mark.asyncio + @patch("common.services.batch_service.DatabaseFactory.get_database", new_callable=AsyncMock) + async def test_initialize_database_failure(mock_get_database, service): + # Arrange + mock_get_database.side_effect = RuntimeError("Database initialization failed") + + # Act & Assert + with pytest.raises(RuntimeError, match="Database initialization failed"): + await service.initialize_database() + mock_get_database.assert_called_once() + + @pytest.mark.asyncio + @patch("common.services.batch_service.DatabaseFactory.get_database", new_callable=AsyncMock) + async def test_initialize_database_success(mock_get_database, service): + # Arrange + mock_database = AsyncMock() + mock_get_database.return_value = mock_database + + # Act + await service.initialize_database() + + # Assert + assert service.database == mock_database + mock_get_database.assert_called_once() + + @pytest.mark.asyncio + @patch("common.services.batch_service.DatabaseFactory.get_database", new_callable=AsyncMock) + async def test_initialize_database_failure(mock_get_database, service): + # Arrange + mock_get_database.side_effect = RuntimeError("Database initialization failed") + + # Act & Assert + with pytest.raises(RuntimeError, match="Database initialization failed"): + await service.initialize_database() + mock_get_database.assert_called_once() + + +@pytest.mark.asyncio +async def test_update_file_success(): + service = BatchService() + service.database = AsyncMock() + file_id = str(uuid4()) + mock_file = {"file_id": file_id} + mock_record = MagicMock() + mock_record.error_count = 0 + mock_record.syntax_count = 0 + + service.database.get_file.return_value = mock_file + with patch("common.models.api.FileRecord.fromdb", return_value=mock_record): + await service.update_file(file_id, ProcessStatus.COMPLETED, FileResult.SUCCESS, 1, 2) + assert mock_record.error_count == 1 + assert mock_record.syntax_count == 2 + service.database.update_file.assert_called_once() + + +@pytest.mark.asyncio +async def test_update_file_record(): + service = BatchService() + service.database = AsyncMock() + mock_file_record = MagicMock() + await service.update_file_record(mock_file_record) + service.database.update_file.assert_called_once_with(mock_file_record) + + +@pytest.mark.asyncio +async def test_create_file_log(): + service = BatchService() + service.database = AsyncMock() + file_id = str(uuid4()) + await service.create_file_log( + file_id=file_id, + description="test log", + last_candidate="candidate", + log_type=LogType.SUCCESS, + agent_type=AgentType.HUMAN, + author_role=AuthorRole.USER + ) + service.database.add_file_log.assert_called_once() + + +@pytest.mark.asyncio +async def test_update_batch_success(): + service = BatchService() + service.database = AsyncMock() + batch_id = str(uuid4()) + mock_batch = {"batch_id": batch_id} + mock_batch_record = MagicMock() + service.database.get_batch_from_id.return_value = mock_batch + with patch("common.models.api.BatchRecord.fromdb", return_value=mock_batch_record): + await service.update_batch(batch_id, ProcessStatus.COMPLETED) + service.database.update_batch.assert_called_once_with(mock_batch_record) + + +@pytest.mark.asyncio +async def test_delete_batch_and_files_success(): + service = BatchService() + service.database = AsyncMock() + batch_id = str(uuid4()) + user_id = "user" + mock_file = MagicMock() + mock_file.file_id = uuid4() + mock_file.blob_path = "blob/file" + mock_file.translated_path = "blob/translated" + service.database.get_batch.return_value = {"batch_id": batch_id} + service.database.get_batch_files.return_value = [mock_file] + + with patch("common.models.api.FileRecord.fromdb", return_value=mock_file), \ + patch("common.storage.blob_factory.BlobStorageFactory.get_storage", new_callable=AsyncMock) as mock_storage: + mock_storage.return_value.delete_file.return_value = True + result = await service.delete_batch_and_files(batch_id, user_id) + assert result["message"] == "Files deleted successfully" + + +@pytest.mark.asyncio +async def test_batch_files_final_update(): + service = BatchService() + service.database = AsyncMock() + file_id = str(uuid4()) + file = { + "file_id": file_id, + "translated_path": "", + "status": "IN_PROGRESS" + } + service.database.get_batch_files.return_value = [file] + with patch("common.models.api.FileRecord.fromdb", return_value=MagicMock(file_id=file_id, translated_path="", status=None)), \ + patch.object(service, "get_file_counts", return_value=(1, 1)), \ + patch.object(service, "create_file_log", new_callable=AsyncMock), \ + patch.object(service, "update_file_record", new_callable=AsyncMock): + await service.batch_files_final_update("batch1") + + +@pytest.mark.asyncio +async def test_delete_all_from_storage_cosmos_success(): + service = BatchService() + service.database = AsyncMock() + user_id = "user123" + file_id = str(uuid4()) + batch_id = str(uuid4()) + mock_file = { + "translated_path": "translated/path" + } + + service.get_all_batches = AsyncMock(return_value=[{"batch_id": batch_id}]) + service.database.get_file.return_value = mock_file + service.database.list_files = AsyncMock(return_value=[{"name": f"user/{batch_id}/{file_id}/file.txt"}]) + + with patch("common.storage.blob_factory.BlobStorageFactory.get_storage", new_callable=AsyncMock) as mock_storage: + mock_storage.return_value.list_files.return_value = [{"name": f"user/{batch_id}/{file_id}/file.txt"}] + mock_storage.return_value.delete_file.return_value = True + result = await service.delete_all_from_storage_cosmos(user_id) + assert result["message"] == "All user data deleted successfully" + + +@pytest.mark.asyncio +async def test_create_candidate_success(): + service = BatchService() + service.database = AsyncMock() + file_id = str(uuid4()) + batch_id = str(uuid4()) + user_id = "user123" + mock_file = {"batch_id": batch_id, "original_name": "doc.txt"} + mock_batch = {"user_id": user_id} + + with patch("common.models.api.FileRecord.fromdb", return_value=MagicMock(original_name="doc.txt", batch_id=batch_id)), \ + patch("common.models.api.BatchRecord.fromdb", return_value=MagicMock(user_id=user_id)), \ + patch.object(service, "get_file_counts", return_value=(0, 1)), \ + patch.object(service, "update_file_record", new_callable=AsyncMock), \ + patch("common.storage.blob_factory.BlobStorageFactory.get_storage", new_callable=AsyncMock) as mock_storage: + + mock_storage.return_value.upload_file.return_value = None + service.database.get_file.return_value = mock_file + service.database.get_batch_from_id.return_value = mock_batch + await service.create_candidate(file_id, "Some content") + + +@pytest.mark.asyncio +async def test_batch_files_final_update_success_path(): + service = BatchService() + service.database = AsyncMock() + file_id = str(uuid4()) + file = { + "file_id": file_id, + "translated_path": "some/path", + "status": "IN_PROGRESS" + } + + mock_file_record = MagicMock(translated_path="some/path", file_id=file_id) + service.database.get_batch_files.return_value = [file] + + with patch("common.models.api.FileRecord.fromdb", return_value=mock_file_record), \ + patch.object(service, "update_file_record", new_callable=AsyncMock): + await service.batch_files_final_update("batch123") + + +@pytest.mark.asyncio +async def test_get_file_counts_logs_none(): + service = BatchService() + service.database = AsyncMock() + service.database.get_file_logs.return_value = None + error_count, syntax_count = await service.get_file_counts("file_id") + assert error_count == 0 + assert syntax_count == 0 + + +@pytest.mark.asyncio +async def test_create_candidate_upload_error(): + service = BatchService() + service.database = AsyncMock() + file_id = str(uuid4()) + mock_file = {"batch_id": str(uuid4()), "original_name": "doc.txt"} + mock_batch = {"user_id": "user1"} + + with patch("common.models.api.FileRecord.fromdb", return_value=MagicMock(original_name="doc.txt", batch_id=mock_file["batch_id"])), \ + patch("common.models.api.BatchRecord.fromdb", return_value=MagicMock(user_id="user1")), \ + patch("common.storage.blob_factory.BlobStorageFactory.get_storage", new_callable=AsyncMock) as mock_storage, \ + patch.object(service, "get_file_counts", return_value=(1, 1)), \ + patch.object(service, "update_file_record", new_callable=AsyncMock): + + mock_storage.return_value.upload_file.side_effect = Exception("Upload fail") + service.database.get_file.return_value = mock_file + service.database.get_batch_from_id.return_value = mock_batch + + await service.create_candidate(file_id, "candidate content") + + +@pytest.mark.asyncio +async def test_get_batch_history_failure(): + service = BatchService() + service.logger = MagicMock() + service.database = AsyncMock() + + service.database.get_batch_history.side_effect = RuntimeError("DB failure") + + with pytest.raises(RuntimeError, match="Error retrieving batch history"): + await service.get_batch_history("user1", limit=5, offset=0) + + +@pytest.mark.asyncio +async def test_delete_file_logs_exception(): + service = BatchService() + service.database = AsyncMock() + file_id = str(uuid4()) + batch_id = str(uuid4()) + mock_file = MagicMock() + mock_file.batch_id = batch_id + mock_file.blob_path = "blob" + mock_file.translated_path = "translated" + with patch("common.storage.blob_factory.BlobStorageFactory.get_storage", new_callable=AsyncMock) as mock_storage: + mock_storage.return_value.delete_file.return_value = True + service.database.get_file.return_value = mock_file + service.database.get_batch.return_value = {"id": str(batch_id)} + service.database.get_batch_files.return_value = [1, 2] + + with patch("common.models.api.FileRecord.fromdb", return_value=mock_file), \ + patch("common.models.api.BatchRecord.fromdb") as mock_batch_record: + mock_record = MagicMock() + mock_record.file_count = 2 + mock_batch_record.return_value = mock_record + service.database.update_batch.side_effect = Exception("Update failed") + + result = await service.delete_file(file_id, "user1") + assert result["message"] == "File deleted successfully" + + +@pytest.mark.asyncio +async def test_upload_file_to_batch_batchrecord(): + service = BatchService() + service.database = AsyncMock() + file = UploadFile(filename="test.txt", file=BytesIO(b"test content")) + batch_id = str(uuid4()) + file_id = str(uuid4()) + + # Create a mock BatchRecord instance + mock_batch_record = MagicMock(spec=BatchRecord) + mock_batch_record.file_count = 0 + mock_batch_record.updated_at = None + + with patch("uuid.uuid4", return_value=file_id), \ + patch("common.storage.blob_factory.BlobStorageFactory.get_storage", new_callable=AsyncMock) as mock_storage, \ + patch("common.models.api.FileRecord.fromdb", return_value={"blob_path": "blob/path"}), \ + patch("common.models.api.BatchRecord.fromdb", return_value=mock_batch_record): + + mock_storage.return_value.upload_file.return_value = None + # This will trigger the BatchRecord path + service.database.get_batch.side_effect = [mock_batch_record] + service.database.get_batch_files.return_value = ["file1", "file2"] + service.database.get_file.return_value = {"file_id": file_id} + service.database.update_batch_entry.return_value = mock_batch_record + + result = await service.upload_file_to_batch(batch_id, "user1", file) + assert "batch" in result + assert "file" in result + + +@pytest.mark.asyncio +async def test_upload_file_to_batch_unknown_type(): + service = BatchService() + service.database = AsyncMock() + file = UploadFile(filename="file.txt", file=BytesIO(b"data")) + file_id = str(uuid4()) + + with patch("uuid.uuid4", return_value=file_id), \ + patch("common.storage.blob_factory.BlobStorageFactory.get_storage", new_callable=AsyncMock) as mock_storage, \ + patch("common.models.api.FileRecord.fromdb", return_value={"blob_path": "path"}): + + mock_storage.return_value.upload_file.return_value = None + service.database.get_batch.side_effect = [object()] # Unknown type + service.database.get_batch_files.return_value = [] + service.database.get_file.return_value = {"file_id": file_id} + + with pytest.raises(RuntimeError, match="File upload failed"): + await service.upload_file_to_batch("batch123", "user1", file) + + +@pytest.mark.asyncio +@patch("common.services.batch_service.BlobStorageFactory.get_storage", new_callable=AsyncMock) +@patch("common.models.api.FileRecord.fromdb") +@patch("common.models.api.BatchRecord.fromdb") +async def test_get_file_report_ioerror(mock_batch_fromdb, mock_file_fromdb, mock_get_storage): + service = BatchService() + service.database = AsyncMock() + file_id = "file123" + mock_file = {"batch_id": uuid4(), "translated_path": "some/path"} + mock_batch = {"batch_id": "batch123"} + mock_logs = [{"log": "log1"}] + + mock_file_fromdb.return_value = MagicMock(dict=lambda: mock_file, batch_id=mock_file["batch_id"], translated_path="some/path") + mock_batch_fromdb.return_value = MagicMock(dict=lambda: mock_batch) + service.database.get_file.return_value = mock_file + service.database.get_batch_from_id.return_value = mock_batch + service.database.get_file_logs.return_value = mock_logs + + mock_storage = AsyncMock() + mock_storage.get_file.side_effect = IOError("Boom") + mock_get_storage.return_value = mock_storage + + result = await service.get_file_report(file_id) + assert result["translated_content"] == "" + + +@pytest.mark.asyncio +@patch("common.models.api.BatchRecord.fromdb") +async def test_get_batch_summary_log_exception(mock_batch_fromdb): + service = BatchService() + service.database = AsyncMock() + mock_batch = {"batch_id": "batch1"} + mock_batch_record = MagicMock(dict=lambda: {"batch_id": "batch1"}) + mock_batch_fromdb.return_value = mock_batch_record + + service.database.get_batch.return_value = mock_batch + service.database.get_batch_files.return_value = [{"file_id": "file1", "translated_path": None}] + service.database.get_file_logs.side_effect = Exception("DB log fail") + + result = await service.get_batch_summary("batch1", "user1") + assert result["files"][0]["logs"] == [] + + +@pytest.mark.asyncio +async def test_update_file_not_found(): + service = BatchService() + service.database = AsyncMock() + service.database.get_file.return_value = None + with pytest.raises(HTTPException) as exc: + await service.update_file("invalid_id", ProcessStatus.COMPLETED, FileResult.SUCCESS, 0, 0) + assert exc.value.status_code == 404 + + +@pytest.mark.asyncio +async def test_create_candidate_success_flow(): + service = BatchService() + service.database = AsyncMock() + file_id = str(uuid4()) + batch_id = str(uuid4()) + user_id = "user1" + + mock_file = {"batch_id": batch_id, "original_name": "test.txt"} + mock_batch = {"user_id": user_id} + + with patch("common.models.api.FileRecord.fromdb", return_value=MagicMock(original_name="test.txt", batch_id=batch_id)), \ + patch("common.models.api.BatchRecord.fromdb", return_value=MagicMock(user_id=user_id)), \ + patch("common.storage.blob_factory.BlobStorageFactory.get_storage", new_callable=AsyncMock) as mock_storage, \ + patch.object(service, "get_file_counts", return_value=(0, 0)), \ + patch.object(service, "update_file_record", new_callable=AsyncMock): + + service.database.get_file.return_value = mock_file + service.database.get_batch_from_id.return_value = mock_batch + mock_storage.return_value.upload_file.return_value = None + + await service.create_candidate(file_id, "candidate content") diff --git a/src/tests/backend/common/storage/blob_azure_test.py b/src/tests/backend/common/storage/blob_azure_test.py index 2f743020..68e5ad0d 100644 --- a/src/tests/backend/common/storage/blob_azure_test.py +++ b/src/tests/backend/common/storage/blob_azure_test.py @@ -1,10 +1,7 @@ -# blob_azure_test.py +import json +from io import BytesIO +from unittest.mock import MagicMock, patch -from datetime import datetime -from unittest.mock import AsyncMock, MagicMock, patch - -# Import the class under test -from azure.core.exceptions import ResourceExistsError from common.storage.blob_azure import AzureBlobStorage @@ -12,217 +9,217 @@ import pytest -class DummyBlob: - """A dummy blob item returned by list_blobs.""" - - def __init__(self, name, size, creation_time, content_type, metadata): - self.name = name - self.size = size - self.creation_time = creation_time - self.content_settings = MagicMock(content_type=content_type) - self.metadata = metadata +@pytest.fixture +def mock_blob_service(): + """Fixture to mock Azure Blob Storage service client""" + with patch("common.storage.blob_azure.BlobServiceClient") as mock_service: + mock_service_instance = MagicMock() + mock_container_client = MagicMock() + mock_blob_client = MagicMock() + # Set up mock methods + mock_service.return_value = mock_service_instance + mock_service_instance.get_container_client.return_value = mock_container_client + mock_container_client.get_blob_client.return_value = mock_blob_client -class DummyAsyncIterator: - """A dummy async iterator that yields the given items.""" + yield mock_service_instance, mock_container_client, mock_blob_client - def __init__(self, items): - self.items = items - self.index = 0 - def __aiter__(self): - return self +@pytest.fixture +def blob_storage(mock_blob_service): + """Fixture to initialize AzureBlobStorage with mocked dependencies""" + service_client, container_client, blob_client = mock_blob_service + return AzureBlobStorage(account_name="test_account", container_name="test_container") - async def __anext__(self): - if self.index >= len(self.items): - raise StopAsyncIteration - item = self.items[self.index] - self.index += 1 - return item +@pytest.mark.asyncio +async def test_upload_file(blob_storage, mock_blob_service): + """Test uploading a file""" + _, _, mock_blob_client = mock_blob_service + mock_blob_client.upload_blob.return_value = MagicMock() + mock_blob_client.get_blob_properties.return_value = MagicMock( + size=1024, + content_settings=MagicMock(content_type="text/plain"), + creation_time="2024-03-15T12:00:00Z", + etag="dummy_etag", + ) + + file_content = BytesIO(b"dummy data") + + result = await blob_storage.upload_file(file_content, "test_blob.txt", "text/plain") + + assert result["path"] == "test_blob.txt" + assert result["size"] == 1024 + assert result["content_type"] == "text/plain" + assert result["created_at"] == "2024-03-15T12:00:00Z" + assert result["etag"] == "dummy_etag" + assert "url" in result -class DummyDownloadStream: - """A dummy download stream whose content_as_bytes method returns a fixed byte string.""" - async def content_as_bytes(self): - return b"file content" +@pytest.mark.asyncio +async def test_upload_file_exception(blob_storage, mock_blob_service): + """Test upload_file when an exception occurs""" + _, _, mock_blob_client = mock_blob_service + mock_blob_client.upload_blob.side_effect = Exception("Upload failed") -# --- Fixtures --- + with pytest.raises(Exception, match="Upload failed"): + await blob_storage.upload_file(BytesIO(b"dummy data"), "test_blob.txt") -@pytest.fixture -def dummy_storage(): - # Create an instance with dummy connection string and container name. - return AzureBlobStorage("dummy_connection_string", "dummy_container") +@pytest.mark.asyncio +async def test_get_file(blob_storage, mock_blob_service): + """Test downloading a file""" + _, _, mock_blob_client = mock_blob_service + mock_blob_client.download_blob.return_value.readall.return_value = b"dummy data" + result = await blob_storage.get_file("test_blob.txt") -@pytest.fixture -def dummy_container_client(): - container = MagicMock() - container.create_container = AsyncMock() - container.list_blobs = MagicMock() # Will be overridden per test. - container.get_blob_client = MagicMock() - return container + assert result == "dummy data" -@pytest.fixture -def dummy_service_client(dummy_container_client): - service = MagicMock() - service.get_container_client.return_value = dummy_container_client - return service +@pytest.mark.asyncio +async def test_get_file_exception(blob_storage, mock_blob_service): + """Test get_file when an exception occurs""" + _, _, mock_blob_client = mock_blob_service + mock_blob_client.download_blob.side_effect = Exception("Download failed") + with pytest.raises(Exception, match="Download failed"): + await blob_storage.get_file("test_blob.txt") -@pytest.fixture -def dummy_blob_client(): - blob_client = MagicMock() - blob_client.upload_blob = AsyncMock() - blob_client.get_blob_properties = AsyncMock() - blob_client.download_blob = AsyncMock() - blob_client.delete_blob = AsyncMock() - blob_client.url = "https://dummy.blob.core.windows.net/dummy_container/dummy_blob" - return blob_client -# --- Tests for AzureBlobStorage methods --- +@pytest.mark.asyncio +async def test_delete_file(blob_storage, mock_blob_service): + """Test deleting a file""" + _, _, mock_blob_client = mock_blob_service + mock_blob_client.delete_blob.return_value = None + result = await blob_storage.delete_file("test_blob.txt") -@pytest.mark.asyncio -async def test_initialize_creates_container(dummy_storage, dummy_service_client, dummy_container_client): - with patch("common.storage.blob_azure.BlobServiceClient.from_connection_string", return_value=dummy_service_client) as mock_from_conn: - # Simulate normal container creation. - dummy_container_client.create_container = AsyncMock() - await dummy_storage.initialize() - mock_from_conn.assert_called_once_with("dummy_connection_string") - dummy_service_client.get_container_client.assert_called_once_with("dummy_container") - dummy_container_client.create_container.assert_awaited_once() + assert result is True @pytest.mark.asyncio -async def test_initialize_container_already_exists(dummy_storage, dummy_service_client, dummy_container_client): - with patch("common.storage.blob_azure.BlobServiceClient.from_connection_string", return_value=dummy_service_client): - # Simulate container already existing. - dummy_container_client.create_container = AsyncMock(side_effect=ResourceExistsError("Container exists")) - with patch.object(dummy_storage.logger, "debug") as mock_debug: - await dummy_storage.initialize() - dummy_container_client.create_container.assert_awaited_once() - mock_debug.assert_called_with("Container dummy_container already exists") +async def test_delete_file_exception(blob_storage, mock_blob_service): + """Test delete_file when an exception occurs""" + _, _, mock_blob_client = mock_blob_service + mock_blob_client.delete_blob.side_effect = Exception("Delete failed") + result = await blob_storage.delete_file("test_blob.txt") -@pytest.mark.asyncio -async def test_initialize_failure(dummy_storage): - # Simulate failure during initialization. - with patch("common.storage.blob_azure.BlobServiceClient.from_connection_string", side_effect=Exception("Init error")): - with patch.object(dummy_storage.logger, "error") as mock_error: - with pytest.raises(Exception, match="Init error"): - await dummy_storage.initialize() - mock_error.assert_called() + assert result is False @pytest.mark.asyncio -async def test_upload_file_success(dummy_storage, dummy_blob_client): - # Patch get_blob_client to return our dummy blob client. - dummy_storage.container_client = MagicMock() - dummy_storage.container_client.get_blob_client.return_value = dummy_blob_client - - # Create a dummy properties object. - dummy_properties = MagicMock() - dummy_properties.size = 1024 - dummy_properties.content_settings = MagicMock(content_type="text/plain") - dummy_properties.creation_time = datetime(2023, 1, 1) - dummy_properties.etag = "dummy_etag" - dummy_blob_client.get_blob_properties = AsyncMock(return_value=dummy_properties) - - file_content = b"Hello, world!" - result = await dummy_storage.upload_file(file_content, "dummy_blob.txt", "text/plain", {"key": "value"}) - dummy_storage.container_client.get_blob_client.assert_called_once_with("dummy_blob.txt") - dummy_blob_client.upload_blob.assert_awaited_with(file_content, content_type="text/plain", metadata={"key": "value"}, overwrite=True) - dummy_blob_client.get_blob_properties.assert_awaited() - assert result["path"] == "dummy_blob.txt" - assert result["size"] == 1024 - assert result["content_type"] == "text/plain" - assert result["url"] == dummy_blob_client.url - assert result["etag"] == "dummy_etag" +async def test_list_files(blob_storage, mock_blob_service): + """Test listing files in a container""" + _, mock_container_client, _ = mock_blob_service + class AsyncIterator: + """Helper class to create an async iterator""" -@pytest.mark.asyncio -async def test_upload_file_error(dummy_storage, dummy_blob_client): - dummy_storage.container_client = MagicMock() - dummy_storage.container_client.get_blob_client.return_value = dummy_blob_client - dummy_blob_client.upload_blob = AsyncMock(side_effect=Exception("Upload failed")) - with pytest.raises(Exception, match="Upload failed"): - await dummy_storage.upload_file(b"data", "blob.txt", "text/plain", {}) + def __init__(self, items): + self._items = items + def __aiter__(self): + self._iter = iter(self._items) + return self -@pytest.mark.asyncio -async def test_get_file_success(dummy_storage, dummy_blob_client): - dummy_storage.container_client = MagicMock() - dummy_storage.container_client.get_blob_client.return_value = dummy_blob_client - # Make download_blob return a DummyDownloadStream (not wrapped in extra coroutine) - dummy_blob_client.download_blob = AsyncMock(return_value=DummyDownloadStream()) - result = await dummy_storage.get_file("blob.txt") - dummy_storage.container_client.get_blob_client.assert_called_once_with("blob.txt") - dummy_blob_client.download_blob.assert_awaited() - assert result == b"file content" + async def __anext__(self): + try: + return next(self._iter) + except StopIteration: + raise StopAsyncIteration + mock_blobs = [ + MagicMock(name="file1.txt"), + MagicMock(name="file2.txt"), + ] -@pytest.mark.asyncio -async def test_get_file_error(dummy_storage, dummy_blob_client): - dummy_storage.container_client = MagicMock() - dummy_storage.container_client.get_blob_client.return_value = dummy_blob_client - dummy_blob_client.download_blob = AsyncMock(side_effect=Exception("Download error")) - with pytest.raises(Exception, match="Download error"): - await dummy_storage.get_file("nonexistent.txt") + # Explicitly set attributes to avoid MagicMock issues + mock_blobs[0].name = "file1.txt" + mock_blobs[0].size = 123 + mock_blobs[0].creation_time = "2024-03-15T12:00:00Z" + mock_blobs[0].content_settings = MagicMock(content_type="text/plain") + mock_blobs[0].metadata = {} + mock_blobs[1].name = "file2.txt" + mock_blobs[1].size = 456 + mock_blobs[1].creation_time = "2024-03-16T12:00:00Z" + mock_blobs[1].content_settings = MagicMock(content_type="application/json") + mock_blobs[1].metadata = {} -@pytest.mark.asyncio -async def test_delete_file_success(dummy_storage, dummy_blob_client): - dummy_storage.container_client = MagicMock() - dummy_storage.container_client.get_blob_client.return_value = dummy_blob_client - dummy_blob_client.delete_blob = AsyncMock() - result = await dummy_storage.delete_file("blob.txt") - dummy_storage.container_client.get_blob_client.assert_called_once_with("blob.txt") - dummy_blob_client.delete_blob.assert_awaited() - assert result is True + mock_container_client.list_blobs = MagicMock(return_value=AsyncIterator(mock_blobs)) + result = await blob_storage.list_files() -@pytest.mark.asyncio -async def test_delete_file_error(dummy_storage, dummy_blob_client): - dummy_storage.container_client = MagicMock() - dummy_storage.container_client.get_blob_client.return_value = dummy_blob_client - dummy_blob_client.delete_blob = AsyncMock(side_effect=Exception("Delete error")) - result = await dummy_storage.delete_file("blob.txt") - assert result is False + assert len(result) == 2 + assert result[0]["name"] == "file1.txt" + assert result[0]["size"] == 123 + assert result[0]["created_at"] == "2024-03-15T12:00:00Z" + assert result[0]["content_type"] == "text/plain" + assert result[0]["metadata"] == {} + + assert result[1]["name"] == "file2.txt" + assert result[1]["size"] == 456 + assert result[1]["created_at"] == "2024-03-16T12:00:00Z" + assert result[1]["content_type"] == "application/json" + assert result[1]["metadata"] == {} @pytest.mark.asyncio -async def test_list_files_success(dummy_storage): - dummy_storage.container_client = MagicMock() - # Create two dummy blobs. - blob1 = DummyBlob("file1.txt", 100, datetime(2023, 1, 1), "text/plain", {"a": "1"}) - blob2 = DummyBlob("file2.txt", 200, datetime(2023, 1, 2), "text/plain", {"b": "2"}) - async_iterator = DummyAsyncIterator([blob1, blob2]) - dummy_storage.container_client.list_blobs.return_value = async_iterator - result = await dummy_storage.list_files("file") - assert len(result) == 2 - names = {item["name"] for item in result} - assert names == {"file1.txt", "file2.txt"} +async def test_list_files_exception(blob_storage, mock_blob_service): + """Test list_files when an exception occurs""" + _, mock_container_client, _ = mock_blob_service + mock_container_client.list_blobs.side_effect = Exception("List failed") + + with pytest.raises(Exception, match="List failed"): + await blob_storage.list_files() @pytest.mark.asyncio -async def test_list_files_failure(dummy_storage): - dummy_storage.container_client = MagicMock() - # Define list_blobs to return an invalid object (simulate error) +async def test_close(blob_storage, mock_blob_service): + """Test closing the storage client""" + service_client, _, _ = mock_blob_service + + await blob_storage.close() - async def invalid_list_blobs(*args, **kwargs): - # Return a plain string (which does not implement __aiter__) - return "invalid" - dummy_storage.container_client.list_blobs = invalid_list_blobs - with pytest.raises(Exception): # noqa B017 - await dummy_storage.list_files("") + service_client.close.assert_called_once() @pytest.mark.asyncio -async def test_close(dummy_storage): - dummy_storage.service_client = MagicMock() - dummy_storage.service_client.close = AsyncMock() - await dummy_storage.close() - dummy_storage.service_client.close.assert_awaited() +async def test_blob_storage_init_exception(): + """Test that an exception during initialization logs the error message""" + with patch("common.storage.blob_azure.BlobServiceClient") as mock_service, \ + patch("logging.getLogger") as mock_logger: # Patch logging globally + + # Mock logger instance + mock_logger_instance = MagicMock() + mock_logger.return_value = mock_logger_instance + + # Simulate an exception when creating BlobServiceClient + mock_service.side_effect = Exception("Connection failed") + + # Try to initialize AzureBlobStorage + try: + AzureBlobStorage(account_name="test_account", container_name="test_container") + except Exception: + pass # Prevent test failure due to the exception + + # Construct the expected JSON log format + expected_error_log = json.dumps({ + "message": "Failed to initialize Azure Blob Storage", + "context": { + "error": "Connection failed", + "account_name": "test_account" + } + }) + + expected_debug_log = json.dumps({ + "message": "Container test_container already exists" + }) + + # Assert that error logging happened with the expected JSON string + mock_logger_instance.error.assert_called_once_with(expected_error_log) + + # Assert that debug log is written for container existence + mock_logger_instance.debug.assert_called_once_with(expected_debug_log) diff --git a/src/tests/backend/common/storage/blob_base_test.py b/src/tests/backend/common/storage/blob_base_test.py index 561007ed..d7e2383d 100644 --- a/src/tests/backend/common/storage/blob_base_test.py +++ b/src/tests/backend/common/storage/blob_base_test.py @@ -1,128 +1,86 @@ -from datetime import datetime -from typing import Any, BinaryIO, Dict +from io import BytesIO +from typing import Any, BinaryIO, Dict, Optional + + +from common.storage.blob_base import BlobStorageBase # Adjust import path as needed -# Import the abstract base class from the production code. -from common.storage.blob_base import BlobStorageBase import pytest -# Create a dummy concrete subclass of BlobStorageBase that calls the parent's abstract methods. -class DummyBlobStorage(BlobStorageBase): - async def initialize(self) -> None: - # Call the parent (which is just a pass) - await super().initialize() - # Return a dummy value so we can verify our override is called. - return "initialized" +class MockBlobStorage(BlobStorageBase): + """Mock implementation of BlobStorageBase for testing""" async def upload_file( self, file_content: BinaryIO, blob_path: str, - content_type: str = None, - metadata: Dict[str, str] = None, + content_type: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, ) -> Dict[str, Any]: - await super().upload_file(file_content, blob_path, content_type, metadata) - # Return a dummy dictionary that simulates upload details. return { - "url": "https://dummy.blob.core.windows.net/dummy_container/" + blob_path, - "size": len(file_content), - "etag": "dummy_etag", + "path": blob_path, + "size": len(file_content.read()), + "content_type": content_type or "application/octet-stream", + "metadata": metadata or {}, + "url": f"https://mockstorage.com/{blob_path}", } async def get_file(self, blob_path: str) -> BinaryIO: - await super().get_file(blob_path) - # Return dummy binary content. - return b"dummy content" + return BytesIO(b"mock data") async def delete_file(self, blob_path: str) -> bool: - await super().delete_file(blob_path) - # Simulate a successful deletion. return True - async def list_files(self, prefix: str = None) -> list[Dict[str, Any]]: - await super().list_files(prefix) + async def list_files(self, prefix: Optional[str] = None) -> list[Dict[str, Any]]: return [ - { - "name": "dummy.txt", - "size": 123, - "created_at": datetime.now(), - "content_type": "text/plain", - "metadata": {"dummy": "value"}, - } + {"name": "file1.txt", "size": 100, "content_type": "text/plain"}, + {"name": "file2.jpg", "size": 200, "content_type": "image/jpeg"}, ] -# tests cases with each method. +@pytest.fixture +def mock_blob_storage(): + """Fixture to provide a MockBlobStorage instance""" + return MockBlobStorage() @pytest.mark.asyncio -async def test_initialize(): - storage = DummyBlobStorage() - result = await storage.initialize() - # Since the dummy override returns "initialized" after calling super(), - # we assert that the result equals that string. - assert result == "initialized" +async def test_upload_file(mock_blob_storage): + """Test upload_file method""" + file_content = BytesIO(b"dummy data") + result = await mock_blob_storage.upload_file(file_content, "test_blob.txt", "text/plain") - -@pytest.mark.asyncio -async def test_upload_file(): - storage = DummyBlobStorage() - content = b"hello world" - blob_path = "folder/hello.txt" - content_type = "text/plain" - metadata = {"key": "value"} - result = await storage.upload_file(content, blob_path, content_type, metadata) - # Verify that our dummy return value is as expected. - assert ( - result["url"] - == "https://dummy.blob.core.windows.net/dummy_container/" + blob_path - ) - assert result["size"] == len(content) - assert result["etag"] == "dummy_etag" + assert result["path"] == "test_blob.txt" + assert result["size"] == len(b"dummy data") + assert result["content_type"] == "text/plain" + assert "url" in result @pytest.mark.asyncio -async def test_get_file(): - storage = DummyBlobStorage() - result = await storage.get_file("folder/hello.txt") - # Verify that we get the dummy binary content. - assert result == b"dummy content" +async def test_get_file(mock_blob_storage): + """Test get_file method""" + result = await mock_blob_storage.get_file("test_blob.txt") - -@pytest.mark.asyncio -async def test_delete_file(): - storage = DummyBlobStorage() - result = await storage.delete_file("folder/hello.txt") - # Verify that deletion returns True. - assert result is True + assert isinstance(result, BytesIO) + assert result.read() == b"mock data" @pytest.mark.asyncio -async def test_list_files(): - storage = DummyBlobStorage() - result = await storage.list_files("dummy") - # Verify that we receive a list with one item having a 'name' key. - assert isinstance(result, list) - assert len(result) == 1 - assert "dummy.txt" in result[0]["name"] - assert result[0]["size"] == 123 - assert result[0]["content_type"] == "text/plain" - assert result[0]["metadata"] == {"dummy": "value"} +async def test_delete_file(mock_blob_storage): + """Test delete_file method""" + result = await mock_blob_storage.delete_file("test_blob.txt") + + assert result is True @pytest.mark.asyncio -async def test_smoke_all_methods(): - storage = DummyBlobStorage() - init_val = await storage.initialize() - assert init_val == "initialized" - upload_val = await storage.upload_file( - b"data", "file.txt", "text/plain", {"a": "b"} - ) - assert upload_val["size"] == 4 - file_val = await storage.get_file("file.txt") - assert file_val == b"dummy content" - delete_val = await storage.delete_file("file.txt") - assert delete_val is True - list_val = await storage.list_files("file") - assert isinstance(list_val, list) +async def test_list_files(mock_blob_storage): + """Test list_files method""" + result = await mock_blob_storage.list_files() + + assert len(result) == 2 + assert result[0]["name"] == "file1.txt" + assert result[1]["name"] == "file2.jpg" + assert result[0]["size"] == 100 + assert result[1]["size"] == 200 diff --git a/src/tests/backend/common/storage/blob_factory_test.py b/src/tests/backend/common/storage/blob_factory_test.py index 47e344ff..70ed7ecf 100644 --- a/src/tests/backend/common/storage/blob_factory_test.py +++ b/src/tests/backend/common/storage/blob_factory_test.py @@ -1,284 +1,78 @@ -import asyncio -import os -import sys -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock, patch -# Adjust sys.path so that the project root is found. -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) -# Set required environment variables (dummy values) -os.environ["COSMOSDB_ENDPOINT"] = "https://dummy-endpoint" -os.environ["COSMOSDB_KEY"] = "dummy-key" -os.environ["COSMOSDB_DATABASE"] = "dummy-database" -os.environ["COSMOSDB_CONTAINER"] = "dummy-container" -os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = "dummy-deployment" -os.environ["AZURE_OPENAI_API_VERSION"] = "2023-01-01" -os.environ["AZURE_OPENAI_ENDPOINT"] = "https://dummy-openai-endpoint" +from common.storage.blob_factory import BlobStorageFactory -# Patch missing azure module so that event_utils imports without error. -sys.modules["azure.monitor.events.extension"] = MagicMock() -# --- Import the module under test --- -from common.storage.blob_base import BlobStorageBase # noqa: E402 -from common.storage.blob_factory import BlobStorageFactory # noqa: E402 - -import pytest # noqa: E402 - -# --- Dummy configuration for testing --- - - -class DummyConfig: - azure_blob_connection_string = "dummy_connection_string" - azure_blob_container_name = "dummy_container" - -# --- Fixture to patch Config in our tests --- - - -@pytest.fixture(autouse=True) -def patch_config(monkeypatch): - # Import the real Config from your project. - from common.config.config import Config - - def dummy_init(self): - self.azure_blob_connection_string = DummyConfig.azure_blob_connection_string - self.azure_blob_container_name = DummyConfig.azure_blob_container_name - monkeypatch.setattr(Config, "__init__", dummy_init) - # Reset the BlobStorageFactory singleton before each test. - BlobStorageFactory._instance = None - - -class DummyAzureBlobStorage(BlobStorageBase): - def __init__(self, connection_string: str, container_name: str): - self.connection_string = connection_string - self.container_name = container_name - self.initialized = False - self.files = {} # maps blob_path to tuple(file_content, content_type, metadata) - - async def initialize(self): - self.initialized = True - - async def upload_file(self, file_content: bytes, blob_path: str, content_type: str, metadata: dict): - self.files[blob_path] = (file_content, content_type, metadata) - return { - "url": f"https://dummy.blob.core.windows.net/{self.container_name}/{blob_path}", - "size": len(file_content), - "etag": "dummy_etag" - } - - async def get_file(self, blob_path: str): - if blob_path in self.files: - return self.files[blob_path][0] - else: - raise FileNotFoundError(f"File {blob_path} not found") - - async def delete_file(self, blob_path: str): - if blob_path in self.files: - del self.files[blob_path] - # No error if file does not exist. - - async def list_files(self, prefix: str = ""): - return [path for path in self.files if path.startswith(prefix)] - - async def close(self): - self.initialized = False - -# --- Fixture to patch AzureBlobStorage --- - - -@pytest.fixture(autouse=True) -def patch_azure_blob_storage(monkeypatch): - monkeypatch.setattr("common.storage.blob_factory.AzureBlobStorage", DummyAzureBlobStorage) - BlobStorageFactory._instance = None - -# -------------------- Tests for BlobStorageFactory -------------------- +import pytest @pytest.mark.asyncio -async def test_get_storage_success(): - """Test that get_storage returns an initialized DummyAzureBlobStorage instance and is a singleton.""" - storage = await BlobStorageFactory.get_storage() - assert isinstance(storage, DummyAzureBlobStorage) - assert storage.initialized is True - - # Call get_storage again; it should return the same instance. - storage2 = await BlobStorageFactory.get_storage() - assert storage is storage2 +async def test_get_storage_logs_on_init(): + """Test that logger logs on initialization""" + # Force reset the singleton before test + BlobStorageFactory._instance = None + mock_storage_instance = MagicMock() -@pytest.mark.asyncio -async def test_get_storage_missing_config(monkeypatch): - """ - Test that get_storage raises a ValueError when configuration is missing. + with patch("common.storage.blob_factory.AzureBlobStorage", return_value=mock_storage_instance), \ + patch("common.storage.blob_factory.Config") as mock_config, \ + patch.object(BlobStorageFactory, "_logger") as mock_logger: - We simulate missing connection string and container name. - """ - from common.config.config import Config + mock_config_instance = MagicMock() + mock_config_instance.azure_blob_account_name = "account" + mock_config_instance.azure_blob_container_name = "container" + mock_config.return_value = mock_config_instance - def dummy_init_missing(self): - self.azure_blob_connection_string = "" - self.azure_blob_container_name = "" - monkeypatch.setattr(Config, "__init__", dummy_init_missing) - with pytest.raises(ValueError, match="Azure Blob Storage configuration is missing"): await BlobStorageFactory.get_storage() - -@pytest.mark.asyncio -async def test_close_storage_success(): - """Test that close_storage calls close() on the storage instance and resets the singleton.""" - storage = await BlobStorageFactory.get_storage() - # Patch close() method with an async mock. - storage.close = AsyncMock() - await BlobStorageFactory.close_storage() - storage.close.assert_called_once() - assert BlobStorageFactory._instance is None - -# -------------------- File Upload Tests -------------------- - - -@pytest.mark.asyncio -async def test_upload_file_success(): - """Test that upload_file successfully uploads a file and returns metadata.""" - storage = DummyAzureBlobStorage("dummy", "container") - await storage.initialize() - file_content = b"Hello, Blob!" - blob_path = "folder/blob.txt" - content_type = "text/plain" - metadata = {"meta": "data"} - result = await storage.upload_file(file_content, blob_path, content_type, metadata) - assert "url" in result - assert result["size"] == len(file_content) - assert blob_path in storage.files - - -@pytest.mark.asyncio -async def test_upload_file_error(monkeypatch): - """Test that an exception during file upload is propagated.""" - storage = DummyAzureBlobStorage("dummy", "container") - await storage.initialize() - monkeypatch.setattr(storage, "upload_file", AsyncMock(side_effect=Exception("Upload failed"))) - with pytest.raises(Exception, match="Upload failed"): - await storage.upload_file(b"data", "file.txt", "text/plain", {}) - -# -------------------- File Retrieval Tests -------------------- + mock_logger.info.assert_called_once_with("Initialized Azure Blob Storage: container") @pytest.mark.asyncio -async def test_get_file_success(): - """Test that get_file retrieves the correct file content.""" - storage = DummyAzureBlobStorage("dummy", "container") - await storage.initialize() - blob_path = "folder/data.bin" - file_content = b"BinaryData" - storage.files[blob_path] = (file_content, "application/octet-stream", {}) - result = await storage.get_file(blob_path) - assert result == file_content - +async def test_close_storage_resets_instance(): + """Test that close_storage resets the singleton instance""" + # Setup instance first + mock_storage_instance = MagicMock() -@pytest.mark.asyncio -async def test_get_file_not_found(): - """Test that get_file raises FileNotFoundError when file does not exist.""" - storage = DummyAzureBlobStorage("dummy", "container") - await storage.initialize() - with pytest.raises(FileNotFoundError): - await storage.get_file("nonexistent.file") + with patch("common.storage.blob_factory.AzureBlobStorage", return_value=mock_storage_instance), \ + patch("common.storage.blob_factory.Config") as mock_config: -# -------------------- File Deletion Tests -------------------- + mock_config_instance = MagicMock() + mock_config_instance.azure_blob_account_name = "account" + mock_config_instance.azure_blob_container_name = "container" + mock_config.return_value = mock_config_instance + instance = await BlobStorageFactory.get_storage() + assert instance is not None -@pytest.mark.asyncio -async def test_delete_file_success(): - """Test that delete_file removes an existing file.""" - storage = DummyAzureBlobStorage("dummy", "container") - await storage.initialize() - blob_path = "folder/remove.txt" - storage.files[blob_path] = (b"To remove", "text/plain", {}) - await storage.delete_file(blob_path) - assert blob_path not in storage.files - - -@pytest.mark.asyncio -async def test_delete_file_nonexistent(): - """Test that deleting a non-existent file does not raise an error.""" - storage = DummyAzureBlobStorage("dummy", "container") - await storage.initialize() - # Should not raise any exception. - await storage.delete_file("nonexistent.file") - assert True + await BlobStorageFactory.close_storage() -# -------------------- File Listing Tests -------------------- + assert BlobStorageFactory._instance is None @pytest.mark.asyncio -async def test_list_files_with_prefix(): - """Test that list_files returns files that match the given prefix.""" - storage = DummyAzureBlobStorage("dummy", "container") - await storage.initialize() - storage.files = { - "folder/a.txt": (b"A", "text/plain", {}), - "folder/b.txt": (b"B", "text/plain", {}), - "other/c.txt": (b"C", "text/plain", {}), - } - result = await storage.list_files("folder/") - assert set(result) == {"folder/a.txt", "folder/b.txt"} - - -@pytest.mark.asyncio -async def test_list_files_no_files(): - """Test that list_files returns an empty list when no files match the prefix.""" - storage = DummyAzureBlobStorage("dummy", "container") - await storage.initialize() - storage.files = {} - result = await storage.list_files("prefix/") - assert result == [] - -# -------------------- Additional Basic Tests -------------------- - - -@pytest.mark.asyncio -async def test_dummy_azure_blob_storage_initialize(): - """Test that initializing DummyAzureBlobStorage sets the initialized flag.""" - storage = DummyAzureBlobStorage("dummy_conn", "dummy_container") - assert storage.initialized is False - await storage.initialize() - assert storage.initialized is True - - -@pytest.mark.asyncio -async def test_dummy_azure_blob_storage_upload_and_retrieve(): - """Test that a file uploaded to DummyAzureBlobStorage can be retrieved.""" - storage = DummyAzureBlobStorage("dummy_conn", "dummy_container") - await storage.initialize() - content = b"Sample file content" - blob_path = "folder/sample.txt" - metadata = {"author": "tester"} - result = await storage.upload_file(content, blob_path, "text/plain", metadata) - assert "url" in result - assert result["size"] == len(content) - retrieved = await storage.get_file(blob_path) - assert retrieved == content - +async def test_get_storage_after_close_reinitializes(): + """Test that get_storage reinitializes after close_storage is called""" + # Force reset before test + BlobStorageFactory._instance = None -@pytest.mark.asyncio -async def test_dummy_azure_blob_storage_close(): - """Test that close() sets initialized to False.""" - storage = DummyAzureBlobStorage("dummy_conn", "dummy_container") - await storage.initialize() - await storage.close() - assert storage.initialized is False + with patch("common.storage.blob_factory.AzureBlobStorage") as mock_storage, \ + patch("common.storage.blob_factory.Config") as mock_config: -# -------------------- Test for BlobStorageFactory Singleton Usage -------------------- + mock_storage.side_effect = [MagicMock(name="instance1"), MagicMock(name="instance2")] + mock_config_instance = MagicMock() + mock_config_instance.azure_blob_account_name = "account" + mock_config_instance.azure_blob_container_name = "container" + mock_config.return_value = mock_config_instance -def test_common_usage_of_blob_factory(): - """Test that manually setting the singleton in BlobStorageFactory works as expected.""" - # Create a dummy storage instance. - dummy_storage = DummyAzureBlobStorage("dummy", "container") - dummy_storage.initialized = True - BlobStorageFactory._instance = dummy_storage - storage = asyncio.run(BlobStorageFactory.get_storage()) - assert storage is dummy_storage + # First init + instance1 = await BlobStorageFactory.get_storage() + await BlobStorageFactory.close_storage() + # Re-init + instance2 = await BlobStorageFactory.get_storage() -if __name__ == "__main__": - # Run tests when this file is executed directly. - asyncio.run(pytest.main()) + assert instance1 is not instance2 + assert mock_storage.call_count == 2 diff --git a/src/tests/backend/sql_agents/agents/agent_config_test.py b/src/tests/backend/sql_agents/agents/agent_config_test.py new file mode 100644 index 00000000..8250a235 --- /dev/null +++ b/src/tests/backend/sql_agents/agents/agent_config_test.py @@ -0,0 +1,42 @@ +import importlib +from unittest.mock import AsyncMock, patch + +import pytest + + +@pytest.fixture +def mock_project_client(): + return AsyncMock() + + +@patch.dict("os.environ", { + "MIGRATOR_AGENT_MODEL_DEPLOY": "migrator-model", + "PICKER_AGENT_MODEL_DEPLOY": "picker-model", + "FIXER_AGENT_MODEL_DEPLOY": "fixer-model", + "SEMANTIC_VERIFIER_AGENT_MODEL_DEPLOY": "semantic-verifier-model", + "SYNTAX_CHECKER_AGENT_MODEL_DEPLOY": "syntax-checker-model", + "SELECTION_MODEL_DEPLOY": "selection-model", + "TERMINATION_MODEL_DEPLOY": "termination-model", +}) +def test_agent_model_type_mapping_and_instance(mock_project_client): + # Re-import to re-evaluate class variable with patched env + from sql_agents.agents import agent_config + importlib.reload(agent_config) + + AgentType = agent_config.AgentType + AgentBaseConfig = agent_config.AgentBaseConfig + + # Test model_type mapping + assert AgentBaseConfig.model_type[AgentType.MIGRATOR] == "migrator-model" + assert AgentBaseConfig.model_type[AgentType.PICKER] == "picker-model" + assert AgentBaseConfig.model_type[AgentType.FIXER] == "fixer-model" + assert AgentBaseConfig.model_type[AgentType.SEMANTIC_VERIFIER] == "semantic-verifier-model" + assert AgentBaseConfig.model_type[AgentType.SYNTAX_CHECKER] == "syntax-checker-model" + assert AgentBaseConfig.model_type[AgentType.SELECTION] == "selection-model" + assert AgentBaseConfig.model_type[AgentType.TERMINATION] == "termination-model" + + # Test __init__ stores params correctly + config = AgentBaseConfig(mock_project_client, sql_from="sql1", sql_to="sql2") + assert config.ai_project_client == mock_project_client + assert config.sql_from == "sql1" + assert config.sql_to == "sql2" diff --git a/src/tests/conftest.py b/src/tests/conftest.py new file mode 100644 index 00000000..cad4e268 --- /dev/null +++ b/src/tests/conftest.py @@ -0,0 +1,12 @@ +import os +import sys + +# Determine the project root relative to this conftest.py file. +# This file is at: /src/tests/conftest.py +# We want to add: /src/backend to sys.path. +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.abspath(os.path.join(current_dir, "..")) # Goes from tests to src +backend_path = os.path.join(project_root, "backend") +sys.path.insert(0, backend_path) + +print("Adjusted sys.path:", sys.path)