forked from microsoft/semantic-kernel-java
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathKernelHooksTest.java
More file actions
103 lines (84 loc) · 4.15 KB
/
KernelHooksTest.java
File metadata and controls
103 lines (84 loc) · 4.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.tests;
import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo;
import com.github.tomakehurst.wiremock.junit5.WireMockTest;
import com.microsoft.semantickernel.Kernel;
import com.microsoft.semantickernel.Kernel.Builder;
import com.microsoft.semantickernel.aiservices.openai.chatcompletion.OpenAIChatCompletion;
import com.microsoft.semantickernel.hooks.KernelHook.FunctionInvokedHook;
import com.microsoft.semantickernel.hooks.KernelHook.FunctionInvokingHook;
import com.microsoft.semantickernel.orchestration.PromptExecutionSettings;
import com.microsoft.semantickernel.semanticfunctions.KernelFunction;
import com.microsoft.semantickernel.semanticfunctions.KernelArguments;
import com.microsoft.semantickernel.semanticfunctions.KernelFunctionFromPrompt;
import com.microsoft.semantickernel.semanticfunctions.OutputVariable;
import com.microsoft.semantickernel.services.chatcompletion.ChatCompletionService;
import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
/**
 * Tests global kernel hooks against a local WireMock server that stands in for the OpenAI
 * chat-completion endpoint.
 */
@WireMockTest
public class KernelHooksTest {

    /**
     * Creates a kernel builder whose chat-completion service points at the local WireMock server.
     *
     * @param wmRuntimeInfo WireMock runtime info for this test; supplies the HTTP port to target
     * @return a builder with an {@link OpenAIChatCompletion} registered as the
     *     {@link ChatCompletionService}
     */
    private static Builder getKernelBuilder(WireMockRuntimeInfo wmRuntimeInfo) {
        final OpenAIAsyncClient client = new OpenAIClientBuilder()
            .endpoint("http://localhost:" + wmRuntimeInfo.getHttpPort())
            .buildAsyncClient();

        ChatCompletionService openAIChatCompletion = OpenAIChatCompletion.builder()
            .withModelId("gpt-35-turbo")
            .withOpenAIAsyncClient(client)
            .build();

        return Kernel.builder()
            .withAIService(ChatCompletionService.class, openAIChatCompletion);
    }

    /**
     * Checks global kernel hooks around a prompt-function invocation.
     *
     * <p>It verifies that:
     * <ul>
     *   <li>a registered {@link FunctionInvokingHook} fires;</li>
     *   <li>a hook that was added under a name and then removed by that name does not fire;</li>
     *   <li>a registered {@link FunctionInvokedHook} fires, and does so after the invoking hook.</li>
     * </ul>
     *
     * <p>The method name is historical: it was copied from a usage-data sample. The test does not
     * inspect token usage.
     *
     * @param wmRuntimeInfo WireMock runtime info for this test
     */
    @Test
    public void getUsageAsync(WireMockRuntimeInfo wmRuntimeInfo) {
        // NOTE(review): presumably stubs chat-completion requests whose prompt contains the first
        // argument so that they reply with the second; confirm against WireMockUtil.
        WireMockUtil.mockChatCompletionResponse("Write a random paragraph about", "a-response");
        Kernel kernel = getKernelBuilder(wmRuntimeInfo).build();

        // The prompt text must contain the stubbed fragment above so the mock can answer it.
        String functionPrompt = "Write a random paragraph about: {{$input}}.";
        KernelFunction<String> paragraphFunction = KernelFunctionFromPrompt.builder()
            .withTemplate(functionPrompt)
            .withName("Excuse")
            .withDefaultExecutionSettings(PromptExecutionSettings
                .builder()
                .withMaxTokens(100)
                .withTemperature(0.4)
                .withTopP(1)
                .build())
            .withOutputVariable(new OutputVariable<>("result", String.class))
            .build();

        AtomicBoolean preHookTriggered = new AtomicBoolean(false);
        FunctionInvokingHook preHook = event -> {
            preHookTriggered.set(true);
            return event;
        };

        AtomicBoolean removedPreExecutionHandlerTriggered = new AtomicBoolean(false);
        FunctionInvokingHook removedPreExecutionHandler = event -> {
            removedPreExecutionHandlerTriggered.set(true);
            return event;
        };

        AtomicBoolean postExecutionHandlerTriggered = new AtomicBoolean(false);
        // Captures whether the invoking hook had already run when the invoked hook fired.
        AtomicBoolean postHookSawPreHook = new AtomicBoolean(false);
        FunctionInvokedHook postExecutionHandler = event -> {
            postExecutionHandlerTriggered.set(true);
            postHookSawPreHook.set(preHookTriggered.get());
            return event;
        };

        kernel.getGlobalKernelHooks().addHook(preHook);

        // Demonstrate the pattern for removing a handler: register under a name, then remove by name.
        kernel.getGlobalKernelHooks().addHook("pre-invoke-removed", removedPreExecutionHandler);
        kernel.getGlobalKernelHooks().removeHook("pre-invoke-removed");

        kernel.getGlobalKernelHooks().addHook(postExecutionHandler);

        kernel.invokeAsync(paragraphFunction)
            .withArguments(
                KernelArguments
                    .builder()
                    .withVariable("input", "I missed the F1 final race")
                    .build())
            .block();

        Assertions.assertTrue(preHookTriggered.get(),
            "FunctionInvokingHook should have been triggered");
        Assertions.assertFalse(removedPreExecutionHandlerTriggered.get(),
            "Removed FunctionInvokingHook should not have been triggered");
        Assertions.assertTrue(postExecutionHandlerTriggered.get(),
            "FunctionInvokedHook should have been triggered");
        Assertions.assertTrue(postHookSawPreHook.get(),
            "FunctionInvokedHook should run after FunctionInvokingHook");
    }
}