diff --git a/code/create_app.py b/code/create_app.py index 512d29bf5..a0d6c691e 100644 --- a/code/create_app.py +++ b/code/create_app.py @@ -742,15 +742,32 @@ def speech_config(): """Get the speech config for Azure Speech.""" try: logger.info("Method speech_config started") - speech_key = env_helper.AZURE_SPEECH_KEY or get_speech_key(env_helper) - response = requests.post( - f"{env_helper.AZURE_SPEECH_REGION_ENDPOINT}sts/v1.0/issueToken", - headers={ - "Ocp-Apim-Subscription-Key": speech_key, - }, - timeout=5, - ) + if env_helper.AZURE_AUTH_TYPE == "rbac": + credential = get_azure_credential( + env_helper.MANAGED_IDENTITY_CLIENT_ID + ) + token = credential.get_token( + "https://cognitiveservices.azure.com/.default" + ) + response = requests.post( + f"{env_helper.AZURE_SPEECH_REGION_ENDPOINT}sts/v1.0/issueToken", + headers={ + "Authorization": f"Bearer {token.token}", + }, + timeout=5, + ) + else: + speech_key = env_helper.AZURE_SPEECH_KEY or get_speech_key( + env_helper + ) + response = requests.post( + f"{env_helper.AZURE_SPEECH_REGION_ENDPOINT}sts/v1.0/issueToken", + headers={ + "Ocp-Apim-Subscription-Key": speech_key, + }, + timeout=5, + ) if response.status_code == 200: return { diff --git a/code/tests/test_create_app.py b/code/tests/test_create_app.py index 397dae4b5..98a037fdd 100644 --- a/code/tests/test_create_app.py +++ b/code/tests/test_create_app.py @@ -92,6 +92,7 @@ def env_helper_mock(): AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG ) env_helper.SHOULD_STREAM = True + env_helper.AZURE_AUTH_TYPE = "keys" env_helper.is_auth_type_keys.return_value = True env_helper.CONVERSATION_FLOW = ConversationFlow.CUSTOM.value @@ -128,25 +129,24 @@ def test_returns_speech_token_using_keys( timeout=5, ) - @patch("create_app.CognitiveServicesManagementClient") + @patch("create_app.get_azure_credential") @patch("create_app.requests") def test_returns_speech_token_using_rbac( self, requests: MagicMock, - CognitiveServicesManagementClientMock: MagicMock, + get_azure_credential_mock: MagicMock, env_helper_mock: MagicMock, client: FlaskClient, ): """Test that the speech token is returned correctly when using RBAC.""" # given + env_helper_mock.AZURE_AUTH_TYPE = "rbac" env_helper_mock.AZURE_SPEECH_KEY = None + env_helper_mock.MANAGED_IDENTITY_CLIENT_ID = "mock-client-id" - mock_cognitive_services_client_mock = ( - CognitiveServicesManagementClientMock.return_value - ) - mock_cognitive_services_client_mock.accounts.list_keys.return_value = MagicMock( - key1="mock-key1", key2="mock-key2" - ) + mock_credential = MagicMock() + mock_credential.get_token.return_value = MagicMock(token="mock-aad-token") + get_azure_credential_mock.return_value = mock_credential mock_response: MagicMock = requests.post.return_value mock_response.text = "speech-token" @@ -163,10 +163,14 @@ def test_returns_speech_token_using_rbac( "languages": AZURE_SPEECH_RECOGNIZER_LANGUAGES, } + get_azure_credential_mock.assert_called_once_with("mock-client-id") + mock_credential.get_token.assert_called_once_with( + "https://cognitiveservices.azure.com/.default" + ) requests.post.assert_called_once_with( f"{AZURE_SPEECH_REGION_ENDPOINT}sts/v1.0/issueToken", headers={ - "Ocp-Apim-Subscription-Key": "mock-key1", + "Authorization": "Bearer mock-aad-token", }, timeout=5, )