Skip to content

Commit c9254ef

Browse files
Enhance Azure credential management in AppConfig
- Updated get_azure_credential and get_azure_credential_async methods to use exclude_environment_credential=True for dev environment. - Refactored MCPEnabledBase to acquire credentials using centralized config method. - Added unit tests for async credential retrieval in both dev and production environments.
1 parent eca6d7b commit c9254ef

4 files changed

Lines changed: 90 additions & 10 deletions

File tree

src/backend/common/config/app_config.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
from azure.ai.projects.aio import AIProjectClient
77
from azure.cosmos import CosmosClient
88
from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
9+
from azure.identity.aio import (
10+
DefaultAzureCredential as DefaultAzureCredentialAsync,
11+
ManagedIdentityCredential as ManagedIdentityCredentialAsync,
12+
)
913
from dotenv import load_dotenv
1014

1115

@@ -113,7 +117,8 @@ def get_azure_credential(self, client_id=None):
113117
"""
114118
Returns an Azure credential based on the application environment.
115119
116-
If the environment is 'dev', it uses DefaultAzureCredential.
120+
If the environment is 'dev', it uses DefaultAzureCredential with exclude_environment_credential=True
121+
to avoid EnvironmentCredential exceptions in Application Insights traces.
117122
Otherwise, it uses ManagedIdentityCredential.
118123
119124
Args:
@@ -123,10 +128,29 @@ def get_azure_credential(self, client_id=None):
123128
Credential object: Either DefaultAzureCredential or ManagedIdentityCredential.
124129
"""
125130
if self.APP_ENV == "dev":
126-
return DefaultAzureCredential() # CodeQL [SM05139]: DefaultAzureCredential is safe here
131+
return DefaultAzureCredential(exclude_environment_credential=True) # CodeQL [SM05139]: DefaultAzureCredential is safe here
127132
else:
128133
return ManagedIdentityCredential(client_id=client_id)
129134

135+
def get_azure_credential_async(self, client_id=None):
136+
"""
137+
Returns an async Azure credential based on the application environment.
138+
139+
If the environment is 'dev', it uses DefaultAzureCredential (async) with exclude_environment_credential=True
140+
to avoid EnvironmentCredential exceptions in Application Insights traces.
141+
Otherwise, it uses ManagedIdentityCredential (async).
142+
143+
Args:
144+
client_id (str, optional): The client ID for the Managed Identity Credential.
145+
146+
Returns:
147+
Async Credential object: Either DefaultAzureCredentialAsync or ManagedIdentityCredentialAsync.
148+
"""
149+
if self.APP_ENV == "dev":
150+
return DefaultAzureCredentialAsync(exclude_environment_credential=True)
151+
else:
152+
return ManagedIdentityCredentialAsync(client_id=client_id)
153+
130154
def get_azure_credentials(self):
131155
"""Retrieve Azure credentials, either from environment variables or managed identity."""
132156
if self._azure_credentials is None:

src/backend/v4/magentic_agents/common/lifecycle.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# from agent_framework.azure import AzureAIClient
1414
from agent_framework_azure_ai import AzureAIClient
1515
from azure.ai.agents.aio import AgentsClient
16-
from azure.identity.aio import DefaultAzureCredential
16+
from common.config.app_config import config
1717
from common.database.database_base import DatabaseBase
1818
from common.models.messages_af import TeamConfiguration
1919
from common.utils.utils_agents import (
@@ -52,7 +52,7 @@ def __init__(
5252
self.team_config: TeamConfiguration | None = team_config
5353
self.client: Optional[AgentsClient] = None
5454
self.project_endpoint = project_endpoint
55-
self.creds: Optional[DefaultAzureCredential] = None
55+
self.creds = None
5656
self.memory_store: Optional[DatabaseBase] = memory_store
5757
self.agent_name: str | None = agent_name
5858
self.agent_description: str | None = agent_description
@@ -66,8 +66,8 @@ async def open(self) -> "MCPEnabledBase":
6666
return self
6767
self._stack = AsyncExitStack()
6868

69-
# Acquire credential
70-
self.creds = DefaultAzureCredential()
69+
# Acquire credential using centralized config method
70+
self.creds = config.get_azure_credential_async(config.AZURE_CLIENT_ID)
7171
if self._stack:
7272
await self._stack.enter_async_context(self.creds)
7373
# Create AgentsClient

src/tests/backend/common/config/test_app_config.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,15 +251,16 @@ def _get_minimal_env(self):
251251

252252
@patch('backend.common.config.app_config.DefaultAzureCredential')
253253
def test_get_azure_credential_dev_environment(self, mock_default_credential):
254-
"""Test get_azure_credential method in dev environment."""
254+
"""Test get_azure_credential method in dev environment with exclude_environment_credential."""
255255
mock_credential = MagicMock()
256256
mock_default_credential.return_value = mock_credential
257257

258258
with patch.dict(os.environ, self._get_minimal_env()):
259259
config = AppConfig()
260260
result = config.get_azure_credential()
261261

262-
mock_default_credential.assert_called_once()
262+
# Verify it's called with exclude_environment_credential=True in dev
263+
mock_default_credential.assert_called_once_with(exclude_environment_credential=True)
263264
assert result == mock_credential
264265

265266
@patch('backend.common.config.app_config.ManagedIdentityCredential')
@@ -333,6 +334,55 @@ def test_get_access_token_failure(self, mock_default_credential):
333334
with pytest.raises(Exception, match="Token retrieval failed"):
334335
credential.get_token(config.AZURE_COGNITIVE_SERVICES)
335336

337+
@patch('backend.common.config.app_config.DefaultAzureCredentialAsync')
338+
def test_get_azure_credential_async_dev_environment(self, mock_default_credential_async):
339+
"""Test get_azure_credential_async method in dev environment with exclude_environment_credential."""
340+
mock_credential = MagicMock()
341+
mock_default_credential_async.return_value = mock_credential
342+
343+
with patch.dict(os.environ, self._get_minimal_env()):
344+
config = AppConfig()
345+
result = config.get_azure_credential_async()
346+
347+
# Verify it's called with exclude_environment_credential=True in dev
348+
mock_default_credential_async.assert_called_once_with(exclude_environment_credential=True)
349+
assert result == mock_credential
350+
351+
@patch('backend.common.config.app_config.ManagedIdentityCredentialAsync')
352+
def test_get_azure_credential_async_prod_environment(self, mock_managed_credential_async):
353+
"""Test get_azure_credential_async method in production environment."""
354+
mock_credential = MagicMock()
355+
mock_managed_credential_async.return_value = mock_credential
356+
357+
env = self._get_minimal_env()
358+
env["APP_ENV"] = "prod"
359+
env["AZURE_CLIENT_ID"] = "test-client-id"
360+
361+
with patch.dict(os.environ, env):
362+
config = AppConfig()
363+
result = config.get_azure_credential_async("test-client-id")
364+
365+
mock_managed_credential_async.assert_called_once_with(client_id="test-client-id")
366+
assert result == mock_credential
367+
368+
@patch('backend.common.config.app_config.ManagedIdentityCredentialAsync')
369+
def test_get_azure_credential_async_prod_uppercase(self, mock_managed_credential_async):
370+
"""Test get_azure_credential_async handles uppercase Prod environment value."""
371+
mock_credential = MagicMock()
372+
mock_managed_credential_async.return_value = mock_credential
373+
374+
env = self._get_minimal_env()
375+
env["APP_ENV"] = "Prod" # Bicep sets it as "Prod" with capital P
376+
env["AZURE_CLIENT_ID"] = "test-client-id"
377+
378+
with patch.dict(os.environ, env):
379+
config = AppConfig()
380+
result = config.get_azure_credential_async("test-client-id")
381+
382+
# Should use ManagedIdentityCredential even with capital "Prod"
383+
mock_managed_credential_async.assert_called_once_with(client_id="test-client-id")
384+
assert result == mock_credential
385+
336386

337387
class TestAppConfigClientMethods:
338388
"""Test cases for client creation methods in AppConfig class."""

src/tests/backend/v4/magentic_agents/common/test_lifecycle.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,9 @@ async def test_open_method_success(self):
171171
mock_mcp_tool = AsyncMock()
172172

173173
with patch('backend.v4.magentic_agents.common.lifecycle.AsyncExitStack', return_value=mock_stack):
174-
with patch('backend.v4.magentic_agents.common.lifecycle.DefaultAzureCredential', return_value=mock_creds):
174+
with patch('backend.v4.magentic_agents.common.lifecycle.config') as mock_config:
175+
mock_config.get_azure_credential_async.return_value = mock_creds
176+
mock_config.AZURE_CLIENT_ID = "test-client-id"
175177
with patch('backend.v4.magentic_agents.common.lifecycle.AgentsClient', return_value=mock_client):
176178
with patch('backend.v4.magentic_agents.common.lifecycle.MCPStreamableHTTPTool', return_value=mock_mcp_tool):
177179
with patch.object(base, '_after_open', new_callable=AsyncMock) as mock_after_open:
@@ -182,6 +184,7 @@ async def test_open_method_success(self):
182184
assert base._stack is mock_stack
183185
assert base.creds is mock_creds
184186
assert base.client is mock_client
187+
mock_config.get_azure_credential_async.assert_called_once_with("test-client-id")
185188
mock_after_open.assert_called_once()
186189
mock_agent_registry.register_agent.assert_called_once_with(base)
187190

@@ -207,7 +210,9 @@ async def test_open_method_registration_failure(self):
207210
mock_client = AsyncMock()
208211

209212
with patch('backend.v4.magentic_agents.common.lifecycle.AsyncExitStack', return_value=mock_stack):
210-
with patch('backend.v4.magentic_agents.common.lifecycle.DefaultAzureCredential', return_value=mock_creds):
213+
with patch('backend.v4.magentic_agents.common.lifecycle.config') as mock_config:
214+
mock_config.get_azure_credential_async.return_value = mock_creds
215+
mock_config.AZURE_CLIENT_ID = "test-client-id"
211216
with patch('backend.v4.magentic_agents.common.lifecycle.AgentsClient', return_value=mock_client):
212217
with patch.object(base, '_after_open', new_callable=AsyncMock):
213218
mock_agent_registry.register_agent.side_effect = Exception("Registration failed")
@@ -216,6 +221,7 @@ async def test_open_method_registration_failure(self):
216221
result = await base.open()
217222

218223
assert result is base
224+
mock_config.get_azure_credential_async.assert_called_once_with("test-client-id")
219225
mock_agent_registry.register_agent.assert_called_once_with(base)
220226

221227
@pytest.mark.asyncio

0 commit comments

Comments
 (0)