Skip to content

Commit 61517c0

Browse files
author
Harmanpreet Kaur
committed
added agent_config file
1 parent 9ed8584 commit 61517c0

1 file changed

Lines changed: 42 additions & 0 deletions

File tree

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import importlib
2+
from unittest.mock import AsyncMock, patch
3+
4+
import pytest
5+
6+
7+
@pytest.fixture
8+
def mock_project_client():
9+
return AsyncMock()
10+
11+
12+
@patch.dict("os.environ", {
13+
"MIGRATOR_AGENT_MODEL_DEPLOY": "migrator-model",
14+
"PICKER_AGENT_MODEL_DEPLOY": "picker-model",
15+
"FIXER_AGENT_MODEL_DEPLOY": "fixer-model",
16+
"SEMANTIC_VERIFIER_AGENT_MODEL_DEPLOY": "semantic-verifier-model",
17+
"SYNTAX_CHECKER_AGENT_MODEL_DEPLOY": "syntax-checker-model",
18+
"SELECTION_MODEL_DEPLOY": "selection-model",
19+
"TERMINATION_MODEL_DEPLOY": "termination-model",
20+
})
21+
def test_agent_model_type_mapping_and_instance(mock_project_client):
22+
# Re-import to re-evaluate class variable with patched env
23+
from sql_agents.agents import agent_config
24+
importlib.reload(agent_config)
25+
26+
AgentType = agent_config.AgentType
27+
AgentBaseConfig = agent_config.AgentBaseConfig
28+
29+
# Test model_type mapping
30+
assert AgentBaseConfig.model_type[AgentType.MIGRATOR] == "migrator-model"
31+
assert AgentBaseConfig.model_type[AgentType.PICKER] == "picker-model"
32+
assert AgentBaseConfig.model_type[AgentType.FIXER] == "fixer-model"
33+
assert AgentBaseConfig.model_type[AgentType.SEMANTIC_VERIFIER] == "semantic-verifier-model"
34+
assert AgentBaseConfig.model_type[AgentType.SYNTAX_CHECKER] == "syntax-checker-model"
35+
assert AgentBaseConfig.model_type[AgentType.SELECTION] == "selection-model"
36+
assert AgentBaseConfig.model_type[AgentType.TERMINATION] == "termination-model"
37+
38+
# Test __init__ stores params correctly
39+
config = AgentBaseConfig(mock_project_client, sql_from="sql1", sql_to="sql2")
40+
assert config.ai_project_client == mock_project_client
41+
assert config.sql_from == "sql1"
42+
assert config.sql_to == "sql2"

0 commit comments

Comments
 (0)