-
Notifications
You must be signed in to change notification settings - Fork 627
Expand file tree
/
Copy path: semantic_kernel.py
More file actions
157 lines (131 loc) · 6.24 KB
/
semantic_kernel.py
File metadata and controls
157 lines (131 loc) · 6.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import json
import logging
from semantic_kernel import Kernel
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.contents import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.utils.finish_reason import FinishReason
from ..common.answer import Answer
from ..helpers.llm_helper import LLMHelper
from ..helpers.env_helper import EnvHelper
from ..helpers.prompt_utils import get_current_date_suffix
from ..plugins.chat_plugin import ChatPlugin
from ..plugins.post_answering_plugin import PostAnsweringPlugin
from .orchestrator_base import OrchestratorBase
logger = logging.getLogger(__name__)
class SemanticKernelOrchestrator(OrchestratorBase):
    """Orchestrator that answers user questions via Semantic Kernel.

    Wires an Azure OpenAI chat-completion service and a post-answering
    plugin into a :class:`Kernel`, then lets the model decide — through
    ``FunctionChoiceBehavior.Auto`` with ``auto_invoke=False`` — whether to
    call a ``Chat`` plugin function (which this class then invokes itself)
    or to answer the user directly.
    """

    def __init__(self) -> None:
        super().__init__()
        self.kernel = Kernel()
        self.llm_helper = LLMHelper()
        self.env_helper = EnvHelper()

        # Add the Azure OpenAI service to the kernel
        self.chat_service = self.llm_helper.get_sk_chat_completion_service("cwyd")
        self.kernel.add_service(self.chat_service)

        # Registered once here; the "Chat" plugin is added per-request in
        # orchestrate() because it captures the question and history.
        self.kernel.add_plugin(
            plugin=PostAnsweringPlugin(), plugin_name="PostAnswering"
        )

    async def orchestrate(
        self, user_message: str, chat_history: list[dict], **kwargs: dict
    ) -> list[dict]:
        """Run one conversation turn and return UI-formatted messages.

        Args:
            user_message: The raw question typed by the user.
            chat_history: Prior conversation turns as message dicts.
            **kwargs: Unused; accepted for interface compatibility with
                other orchestrators.

        Returns:
            Messages produced by ``self.output_parser``, or the content
            safety response when the input or output is flagged.
        """
        logger.info("Method orchestrate of semantic_kernel started")

        # Call Content Safety tool on the incoming user message.
        if self.config.prompts.enable_content_safety:
            if response := self.call_content_safety_input(user_message):
                return response

        system_message = self.env_helper.SEMANTIC_KERNEL_SYSTEM_PROMPT
        if not system_message:
            system_message = """You help employees to navigate only private information sources.
You **must always** call the search_documents function first for any user question before deciding if the information is available or not. Never decide a question is out of domain without searching first.
You should prioritize the function call over your general knowledge for any question by calling the search_documents function.
Call the text_processing function when the user requests an operation on the current context, such as translate, summarize, or paraphrase. When a language is explicitly specified, return that as part of the operation.
When directly replying to the user, always reply in the language the user is speaking.
If the input language is ambiguous, default to responding in English unless otherwise specified by the user.
Do not list all documents in your repository if asked.
"""
        # Append current date so the LLM is aware of today's date
        system_message += get_current_date_suffix()

        # Per-request plugin: captures this turn's question and history.
        self.kernel.add_plugin(
            plugin=ChatPlugin(question=user_message, chat_history=chat_history),
            plugin_name="Chat",
        )

        settings = self.llm_helper.get_sk_service_settings(self.chat_service)
        # Let the model pick a Chat-plugin function, but do not auto-invoke:
        # the tool call is inspected and executed manually below.
        settings.function_choice_behavior = FunctionChoiceBehavior.Auto(
            auto_invoke=False, filters={"included_plugins": ["Chat"]}
        )

        orchestrate_function = self.kernel.add_function(
            plugin_name="Main",
            function_name="orchestrate",
            prompt="{{$chat_history}}{{$user_message}}",
            prompt_execution_settings=settings,
        )

        history = ChatHistory(system_message=system_message)
        for message in chat_history.copy():
            history.add_message(message)

        # Flatten the history into a plain-text transcript for the prompt.
        # join() avoids the quadratic cost of += concatenation in a loop.
        chat_history_str = "".join(
            f"{message.role}: {message.content}\n" for message in history.messages
        )

        result: ChatMessageContent = (
            await self.kernel.invoke(
                function=orchestrate_function,
                chat_history=chat_history_str,
                user_message=user_message,
            )
        ).value[0]
        self.log_tokens(
            prompt_tokens=result.metadata["usage"].prompt_tokens,
            completion_tokens=result.metadata["usage"].completion_tokens,
        )

        if result.finish_reason == FinishReason.TOOL_CALLS:
            logger.info("Semantic Kernel function call detected")
            function_name = result.items[0].name
            logger.info("%s function detected", function_name)
            function = self.kernel.get_function_from_fully_qualified_function_name(
                function_name
            )
            arguments = json.loads(result.items[0].arguments)
            answer: Answer = (
                await self.kernel.invoke(function=function, **arguments)
            ).value
            self.log_tokens(
                prompt_tokens=answer.prompt_tokens,
                completion_tokens=answer.completion_tokens,
            )

            # Run post prompt if needed
            if (
                self.config.prompts.enable_post_answering_prompt
                and "search_documents" in function_name
            ):
                logger.debug("Running post answering prompt")
                answer = (
                    await self.kernel.invoke(
                        function_name="validate_answer",
                        plugin_name="PostAnswering",
                        answer=answer,
                    )
                ).value
                self.log_tokens(
                    prompt_tokens=answer.prompt_tokens,
                    completion_tokens=answer.completion_tokens,
                )
        else:
            logger.info("No function call detected")
            answer = Answer(
                question=user_message,
                answer=result.content,
                prompt_tokens=result.metadata["usage"].prompt_tokens,
                completion_tokens=result.metadata["usage"].completion_tokens,
            )

        # Call Content Safety tool on the generated answer.
        if self.config.prompts.enable_content_safety:
            if response := self.call_content_safety_output(user_message, answer.answer):
                return response

        # Format the output for the UI
        messages = self.output_parser.parse(
            question=answer.question,
            answer=answer.answer,
            source_documents=answer.source_documents,
        )
        logger.info("Method orchestrate of semantic_kernel ended")
        return messages