Skip to content

Commit 1c53c31

Browse files
Unit test cases
1 parent add28a4 commit 1c53c31

5 files changed

Lines changed: 375 additions & 0 deletions

File tree

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from unittest.mock import MagicMock
2+
3+
import pytest
4+
5+
from sql_agents.agents.semantic_verifier.agent import SemanticVerifierAgent
6+
from sql_agents.agents.semantic_verifier.response import SemanticVerifierResponse
7+
from sql_agents.helpers.models import AgentType
8+
9+
10+
@pytest.fixture
11+
def mock_config():
12+
"""Fixture to create a mock configuration."""
13+
mock_config = MagicMock()
14+
mock_config.model_type = {
15+
AgentType.SEMANTIC_VERIFIER: "semantic_verifier_model"
16+
}
17+
return mock_config
18+
19+
20+
@pytest.fixture
21+
def semantic_verifier_agent(mock_config):
22+
"""Fixture to create a SemanticVerifierAgent instance."""
23+
agent = SemanticVerifierAgent(
24+
agent_type=AgentType.SEMANTIC_VERIFIER,
25+
config=mock_config
26+
)
27+
return agent
28+
29+
30+
def test_response_object(semantic_verifier_agent):
31+
"""Test that the response_object property returns SemanticVerifierResponse."""
32+
assert semantic_verifier_agent.response_object == SemanticVerifierResponse
33+
34+
35+
def test_deployment_name(semantic_verifier_agent):
36+
"""Test that the deployment_name property returns the correct model name."""
37+
assert semantic_verifier_agent.deployment_name == "semantic_verifier_model"
38+
39+
40+
def test_missing_deployment_name(mock_config):
41+
"""Test that accessing deployment_name raises a KeyError if the model type is missing."""
42+
mock_config.model_type = {}
43+
agent = SemanticVerifierAgent(
44+
agent_type=AgentType.SEMANTIC_VERIFIER,
45+
config=mock_config
46+
)
47+
with pytest.raises(KeyError):
48+
_ = agent.deployment_name
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import pytest
2+
3+
from sql_agents.agents.semantic_verifier.response import SemanticVerifierResponse
4+
5+
6+
def test_semantic_verifier_response_initialization():
7+
"""Test initializing SemanticVerifierResponse with valid data."""
8+
response = SemanticVerifierResponse(
9+
judgement="valid",
10+
differences=["difference1", "difference2"],
11+
summary="This is a summary."
12+
)
13+
assert response.judgement == "valid"
14+
assert response.differences == ["difference1", "difference2"]
15+
assert response.summary == "This is a summary."
16+
17+
18+
def test_semantic_verifier_response_empty_fields():
19+
"""Test initializing SemanticVerifierResponse with empty fields."""
20+
response = SemanticVerifierResponse(
21+
judgement="",
22+
differences=[],
23+
summary=""
24+
)
25+
assert response.judgement == ""
26+
assert response.differences == []
27+
assert response.summary == ""
28+
29+
30+
def test_semantic_verifier_response_invalid_data():
31+
"""Test initializing SemanticVerifierResponse with invalid data."""
32+
with pytest.raises(ValueError):
33+
SemanticVerifierResponse(
34+
judgement=123, # Invalid type
35+
differences="not a list", # Invalid type
36+
summary=None # Invalid type
37+
)
38+
39+
40+
def test_semantic_verifier_response_large_differences():
41+
"""Test initializing SemanticVerifierResponse with a large number of differences."""
42+
differences = [f"difference{i}" for i in range(1000)] # Large list of differences
43+
response = SemanticVerifierResponse(
44+
judgement="valid",
45+
differences=differences,
46+
summary="This is a summary."
47+
)
48+
assert len(response.differences) == 1000
49+
assert response.judgement == "valid"
50+
assert response.summary == "This is a summary."
51+
52+
53+
def test_semantic_verifier_response_special_characters():
54+
"""Test initializing SemanticVerifierResponse with special characters."""
55+
response = SemanticVerifierResponse(
56+
judgement="valid!@#$%^&*()",
57+
differences=["difference1", "difference2"],
58+
summary="This is a summary with special characters!@#$%^&*()"
59+
)
60+
assert response.judgement == "valid!@#$%^&*()"
61+
assert response.differences == ["difference1", "difference2"]
62+
assert response.summary == "This is a summary with special characters!@#$%^&*()"
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from unittest.mock import MagicMock
2+
3+
import pytest
4+
5+
from sql_agents.agents.syntax_checker.agent import SyntaxCheckerAgent
6+
from sql_agents.agents.syntax_checker.plug_ins import SyntaxCheckerPlugin
7+
from sql_agents.agents.syntax_checker.response import SyntaxCheckerResponse
8+
from sql_agents.helpers.models import AgentType
9+
10+
11+
@pytest.fixture
12+
def mock_config():
13+
"""Fixture to create a mock configuration."""
14+
mock_config = MagicMock()
15+
mock_config.model_type = {
16+
AgentType.SYNTAX_CHECKER: "syntax_checker_model"
17+
}
18+
return mock_config
19+
20+
21+
@pytest.fixture
22+
def syntax_checker_agent(mock_config):
23+
"""Fixture to create a SyntaxCheckerAgent instance."""
24+
agent = SyntaxCheckerAgent(
25+
agent_type=AgentType.SYNTAX_CHECKER,
26+
config=mock_config
27+
)
28+
return agent
29+
30+
31+
def test_response_object(syntax_checker_agent):
32+
"""Test that the response_object property returns SyntaxCheckerResponse."""
33+
assert syntax_checker_agent.response_object == SyntaxCheckerResponse
34+
35+
36+
def test_plugins(syntax_checker_agent):
37+
"""Test that the plugins property returns the correct plugins."""
38+
plugins = syntax_checker_agent.plugins
39+
assert isinstance(plugins, list)
40+
assert plugins[0] == "check_syntax"
41+
assert isinstance(plugins[1], SyntaxCheckerPlugin)
42+
43+
44+
def test_deployment_name(syntax_checker_agent):
45+
"""Test that the deployment_name property returns the correct model name."""
46+
assert syntax_checker_agent.deployment_name == "syntax_checker_model"
47+
48+
49+
def test_missing_deployment_name(mock_config):
50+
"""Test that accessing deployment_name raises a KeyError if the model type is missing."""
51+
mock_config.model_type = {} # Simulate missing AgentType in model_type
52+
agent = SyntaxCheckerAgent(
53+
agent_type=AgentType.SYNTAX_CHECKER,
54+
config=mock_config
55+
)
56+
with pytest.raises(KeyError):
57+
_ = agent.deployment_name
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
from unittest.mock import MagicMock, patch
2+
3+
import pytest
4+
5+
from sql_agents.agents.syntax_checker.plug_ins import SyntaxCheckerPlugin
6+
7+
8+
@pytest.fixture
9+
def syntax_checker_plugin():
10+
"""Fixture to create a SyntaxCheckerPlugin instance."""
11+
return SyntaxCheckerPlugin()
12+
13+
14+
@patch("sql_agents.agents.syntax_checker.plug_ins.subprocess.run")
15+
def test_check_syntax_windows_path(mock_subprocess_run, syntax_checker_plugin):
16+
"""Test the _call_tsqlparser method on Windows."""
17+
with patch("platform.system", return_value="Windows"):
18+
mock_subprocess_run.return_value = MagicMock(stdout="[]")
19+
candidate_sql = "SELECT * FROM table"
20+
result = syntax_checker_plugin.check_syntax(candidate_sql)
21+
assert result == "[]"
22+
mock_subprocess_run.assert_called_once_with(
23+
[r".\sql_agents\tools\win-x64\tsqlParser.exe", "--string", candidate_sql],
24+
capture_output=True,
25+
text=True,
26+
check=True,
27+
)
28+
29+
30+
@patch("sql_agents.agents.syntax_checker.plug_ins.subprocess.run")
31+
def test_check_syntax_linux_path(mock_subprocess_run, syntax_checker_plugin):
32+
"""Test the _call_tsqlparser method on Linux."""
33+
with patch("platform.system", return_value="Linux"):
34+
mock_subprocess_run.return_value = MagicMock(stdout="[]")
35+
candidate_sql = "SELECT * FROM table"
36+
result = syntax_checker_plugin.check_syntax(candidate_sql)
37+
assert result == "[]"
38+
mock_subprocess_run.assert_called_once_with(
39+
["./sql_agents/tools/linux-x64/tsqlParser", "--string", candidate_sql],
40+
capture_output=True,
41+
text=True,
42+
check=True,
43+
)
44+
45+
46+
@patch("sql_agents.agents.syntax_checker.plug_ins.subprocess.run")
47+
def test_check_syntax_other_os(mock_subprocess_run, syntax_checker_plugin):
48+
"""Test the _call_tsqlparser method on other OS."""
49+
with patch("platform.system", return_value="Other"):
50+
mock_subprocess_run.return_value = MagicMock(stdout="[]")
51+
candidate_sql = "SELECT * FROM table"
52+
result = syntax_checker_plugin.check_syntax(candidate_sql)
53+
assert result == "[]"
54+
mock_subprocess_run.assert_called_once_with(
55+
["./sql_agents/tools/linux-x64/tsqlParser", "--string", candidate_sql],
56+
capture_output=True,
57+
text=True,
58+
check=True,
59+
)
60+
61+
62+
@patch("sql_agents.agents.syntax_checker.plug_ins.subprocess.run")
63+
def test_check_syntax_empty_string(mock_subprocess_run, syntax_checker_plugin):
64+
"""Test the _call_tsqlparser method with an empty string."""
65+
with patch("platform.system", return_value="Windows"):
66+
mock_subprocess_run.return_value = MagicMock(stdout="[]")
67+
candidate_sql = ""
68+
result = syntax_checker_plugin.check_syntax(candidate_sql)
69+
assert result == "[]"
70+
mock_subprocess_run.assert_called_once_with(
71+
[r".\sql_agents\tools\win-x64\tsqlParser.exe", "--string", candidate_sql],
72+
capture_output=True,
73+
text=True,
74+
check=True,
75+
)
76+
77+
78+
@patch("sql_agents.agents.syntax_checker.plug_ins.subprocess.run")
79+
def test_check_syntax_empty_string_linux(mock_subprocess_run, syntax_checker_plugin):
80+
"""Test the _call_tsqlparser method with an empty string."""
81+
with patch("platform.system", return_value="Linux"):
82+
mock_subprocess_run.return_value = MagicMock(stdout="[]")
83+
candidate_sql = ""
84+
result = syntax_checker_plugin.check_syntax(candidate_sql)
85+
assert result == "[]"
86+
mock_subprocess_run.assert_called_once_with(
87+
["./sql_agents/tools/linux-x64/tsqlParser", "--string", candidate_sql],
88+
capture_output=True,
89+
text=True,
90+
check=True,
91+
)
92+
93+
94+
@patch("sql_agents.agents.syntax_checker.plug_ins.subprocess.run")
95+
def test_check_syntax_empty_string_other_os(mock_subprocess_run, syntax_checker_plugin):
96+
"""Test the _call_tsqlparser method with an empty string."""
97+
with patch("platform.system", return_value="Other"):
98+
mock_subprocess_run.return_value = MagicMock(stdout="[]")
99+
candidate_sql = ""
100+
result = syntax_checker_plugin.check_syntax(candidate_sql)
101+
assert result == "[]"
102+
mock_subprocess_run.assert_called_once_with(
103+
["./sql_agents/tools/linux-x64/tsqlParser", "--string", candidate_sql],
104+
capture_output=True,
105+
text=True,
106+
check=True,
107+
)
108+
109+
110+
@patch("sql_agents.agents.syntax_checker.plug_ins.subprocess.run")
111+
def test_check_syntax_invalid_sql(mock_subprocess_run, syntax_checker_plugin):
112+
"""Test the _call_tsqlparser method with invalid SQL."""
113+
with patch("platform.system", return_value="Windows"):
114+
mock_subprocess_run.return_value = MagicMock(stdout="[]")
115+
candidate_sql = "INVALID SQL"
116+
result = syntax_checker_plugin.check_syntax(candidate_sql)
117+
assert result == "[]"
118+
mock_subprocess_run.assert_called_once_with(
119+
[r".\sql_agents\tools\win-x64\tsqlParser.exe", "--string", candidate_sql],
120+
capture_output=True,
121+
text=True,
122+
check=True,
123+
)
124+
125+
126+
@patch("sql_agents.agents.syntax_checker.plug_ins.subprocess.run")
127+
def test_check_syntax_invalid_sql_linux(mock_subprocess_run, syntax_checker_plugin):
128+
"""Test the _call_tsqlparser method with invalid SQL."""
129+
with patch("platform.system", return_value="Linux"):
130+
mock_subprocess_run.return_value = MagicMock(stdout="[]")
131+
candidate_sql = "INVALID SQL"
132+
result = syntax_checker_plugin.check_syntax(candidate_sql)
133+
assert result == "[]"
134+
mock_subprocess_run.assert_called_once_with(
135+
["./sql_agents/tools/linux-x64/tsqlParser", "--string", candidate_sql],
136+
capture_output=True,
137+
text=True,
138+
check=True,
139+
)
140+
141+
142+
@patch("sql_agents.agents.syntax_checker.plug_ins.subprocess.run")
143+
def test_check_syntax_invalid_sql_other_os(mock_subprocess_run, syntax_checker_plugin):
144+
"""Test the _call_tsqlparser method with invalid SQL."""
145+
with patch("platform.system", return_value="Other"):
146+
mock_subprocess_run.return_value = MagicMock(stdout="[]")
147+
candidate_sql = "INVALID SQL"
148+
result = syntax_checker_plugin.check_syntax(candidate_sql)
149+
assert result == "[]"
150+
mock_subprocess_run.assert_called_once_with(
151+
["./sql_agents/tools/linux-x64/tsqlParser", "--string", candidate_sql],
152+
capture_output=True,
153+
text=True,
154+
check=True,
155+
)
156+
157+
158+
@patch("sql_agents.agents.syntax_checker.plug_ins.subprocess.run")
159+
def test_check_syntax_valid_sql(mock_subprocess_run, syntax_checker_plugin):
160+
"""Test the _call_tsqlparser method with valid SQL."""
161+
with patch("platform.system", return_value="Windows"):
162+
mock_subprocess_run.return_value = MagicMock(stdout="[]")
163+
candidate_sql = "SELECT * FROM table"
164+
result = syntax_checker_plugin.check_syntax(candidate_sql)
165+
assert result == "[]"
166+
mock_subprocess_run.assert_called_once_with(
167+
[r".\sql_agents\tools\win-x64\tsqlParser.exe", "--string", candidate_sql],
168+
capture_output=True,
169+
text=True,
170+
check=True,
171+
)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from sql_agents.agents.syntax_checker.response import SyntaxCheckerResponse, SyntaxErrorInt
2+
3+
4+
def test_syntax_error_int_initialization():
5+
"""Test initializing SyntaxErrorInt with valid data."""
6+
syntax_error = SyntaxErrorInt(line=1, column=5, error="Syntax error")
7+
assert syntax_error.line == 1
8+
assert syntax_error.column == 5
9+
assert syntax_error.error == "Syntax error"
10+
11+
12+
def test_syntax_checker_response_initialization():
13+
"""Test initializing SyntaxCheckerResponse with valid data."""
14+
syntax_error = SyntaxErrorInt(line=1, column=5, error="Syntax error")
15+
response = SyntaxCheckerResponse(
16+
thought="Analyzing SQL query",
17+
syntax_errors=[syntax_error],
18+
summary="1 syntax error found"
19+
)
20+
assert response.thought == "Analyzing SQL query"
21+
assert len(response.syntax_errors) == 1
22+
assert response.syntax_errors[0].line == 1
23+
assert response.syntax_errors[0].column == 5
24+
assert response.syntax_errors[0].error == "Syntax error"
25+
assert response.summary == "1 syntax error found"
26+
27+
28+
def test_syntax_checker_response_empty_fields():
29+
"""Test initializing SyntaxCheckerResponse with empty fields."""
30+
response = SyntaxCheckerResponse(
31+
thought="",
32+
syntax_errors=[],
33+
summary=""
34+
)
35+
assert response.thought == ""
36+
assert response.syntax_errors == []
37+
assert response.summary == ""

0 commit comments

Comments
 (0)