|
1 | | -from azure.identity import DefaultAzureCredential, get_bearer_token_provider |
2 | | -from openai import AzureOpenAI |
| 1 | +from azure.identity import DefaultAzureCredential |
| 2 | +from azure.ai.inference import EmbeddingsClient |
| 3 | +from urllib.parse import urlparse |
3 | 4 | import re |
4 | 5 | import time |
5 | 6 | from pypdf import PdfReader |
|
10 | 11 | import requests |
11 | 12 |
|
12 | 13 | search_endpoint = os.getenv("SEARCH_ENDPOINT") |
13 | | -openai_endpoint = os.getenv("OPEN_AI_ENDPOINT_URL") |
| 14 | +ai_project_endpoint = os.getenv("AZURE_AI_AGENT_ENDPOINT") # AI Foundry Project endpoint |
14 | 15 | embedding_model_name = os.getenv("EMBEDDING_MODEL_NAME") |
15 | 16 | embedding_model_api_version = os.getenv("EMBEDDING_MODEL_API_VERSION") |
16 | 17 | use_local_files = (os.getenv("USE_LOCAL_FILES") == "true") |
17 | 18 | index_name = "ai_app_index" |
18 | 19 |
|
19 | 20 | print(f"Creating search index at {search_endpoint} with index name {index_name}") |
20 | | -print(f"Using OpenAI endpoint: {openai_endpoint}") |
| 21 | +print(f"Using AI Foundry Project endpoint: {ai_project_endpoint}") |
21 | 22 | print(f"Using embedding model: {embedding_model_name} with API version: {embedding_model_api_version}") |
22 | 23 |
|
23 | | -# Function: Get Embeddings |
24 | | -def get_embeddings(text: str, openai_endpoint: str, embedding_model_api_version: str): |
| 24 | +# Function: Get Embeddings using Azure AI Inference SDK with Foundry endpoint |
| 25 | +def get_embeddings(text: str, ai_project_endpoint: str, embedding_model_api_version: str): |
25 | 26 | credential = DefaultAzureCredential() |
26 | | - token_provider = get_bearer_token_provider(credential, |
27 | | - "https://cognitiveservices.azure.com/.default") |
28 | | - client = AzureOpenAI( |
29 | | - api_version=embedding_model_api_version, |
30 | | - azure_endpoint=openai_endpoint, |
31 | | - azure_ad_token_provider=token_provider |
| 27 | + |
| 28 | + # Construct inference endpoint with /models path for Azure AI Foundry |
| 29 | + inference_endpoint = f"https://{urlparse(ai_project_endpoint).netloc}/models" |
| 30 | + |
| 31 | + # Create embeddings client using Azure AI Inference SDK |
| 32 | + embeddings_client = EmbeddingsClient( |
| 33 | + endpoint=inference_endpoint, |
| 34 | + credential=credential, |
 | 35 | +        credential_scopes=["https://cognitiveservices.azure.com/.default"], api_version=embedding_model_api_version
| 36 | + ) |
| 37 | + |
| 38 | + # Create embeddings using the model name from environment |
| 39 | + response = embeddings_client.embed( |
| 40 | + model=embedding_model_name, |
| 41 | + input=[text] |
32 | 42 | ) |
33 | 43 |
|
34 | | - embedding = client.embeddings.create(input=text, model=embedding_model_name).data[0].embedding |
| 44 | + embedding = response.data[0].embedding |
35 | 45 | return embedding |
36 | 46 |
|
37 | 47 | # Function: Clean Spaces with Regex - |
@@ -92,12 +102,12 @@ def prepare_search_doc(content, document_id, filename): |
92 | 102 | chunk_id = document_id + '_' + str(chunk_num).zfill(2) |
93 | 103 |
|
94 | 104 | try: |
95 | | - v_contentVector = get_embeddings(str(chunk), openai_endpoint, "2023-05-15") |
| 105 | + v_contentVector = get_embeddings(str(chunk), ai_project_endpoint, embedding_model_api_version) |
96 | 106 | except Exception as e: |
97 | 107 | print(f"Error occurred: {e}. Retrying after 30 seconds...") |
98 | 108 | time.sleep(30) |
99 | 109 | try: |
100 | | - v_contentVector = get_embeddings(str(chunk), openai_endpoint, "1") |
| 110 | + v_contentVector = get_embeddings(str(chunk), ai_project_endpoint, embedding_model_api_version) |
101 | 111 | except Exception as e: |
102 | 112 | print(f"Retry failed: {e}. Setting v_contentVector to an empty list.") |
103 | 113 | v_contentVector = [] |
|
0 commit comments