Skip to content

Commit ca64454

Browse files
committed
Add allow/block list
Fix typo Tidy up code Further cleanup
1 parent 5377046 commit ca64454

2 files changed

Lines changed: 124 additions & 13 deletions

File tree

aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIFunction.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,20 @@
77
import com.fasterxml.jackson.core.JsonProcessingException;
88
import com.fasterxml.jackson.databind.JsonNode;
99
import com.fasterxml.jackson.databind.ObjectMapper;
10-
import com.microsoft.semantickernel.exceptions.SKException;
1110
import com.microsoft.semantickernel.orchestration.responseformat.ResponseSchemaGenerator;
11+
import com.microsoft.semantickernel.plugin.KernelPluginFactory;
1212
import com.microsoft.semantickernel.semanticfunctions.InputVariable;
1313
import com.microsoft.semantickernel.semanticfunctions.KernelFunctionMetadata;
14-
import org.apache.commons.lang3.StringUtils;
1514
import java.util.ArrayList;
1615
import java.util.Collections;
1716
import java.util.HashMap;
1817
import java.util.List;
1918
import java.util.Locale;
2019
import java.util.Map;
21-
import java.util.Objects;
22-
import java.util.concurrent.ConcurrentHashMap;
2320
import java.util.stream.Collectors;
2421
import javax.annotation.Nonnull;
2522
import javax.annotation.Nullable;
23+
import org.apache.commons.lang3.StringUtils;
2624

2725
class OpenAIFunction {
2826

semantickernel-api/src/main/java/com/microsoft/semantickernel/plugin/KernelPluginFactory.java

Lines changed: 122 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,14 @@
2727
import java.nio.file.Path;
2828
import java.util.ArrayList;
2929
import java.util.Arrays;
30+
import java.util.Collections;
3031
import java.util.HashMap;
3132
import java.util.HashSet;
3233
import java.util.List;
34+
import java.util.Locale;
3335
import java.util.Map;
36+
import java.util.Set;
37+
import java.util.function.Predicate;
3438
import java.util.stream.Collectors;
3539
import javax.annotation.Nullable;
3640
import org.reactivestreams.Publisher;
@@ -67,6 +71,12 @@ public class KernelPluginFactory {
6771
COMMON_CLASS_NAMES.put("map", HashMap.class);
6872
COMMON_CLASS_NAMES.put("set", HashSet.class);
6973

74+
COMMON_CLASS_NAMES.put(Integer.class.getName(), int.class);
75+
COMMON_CLASS_NAMES.put(String.class.getName(), String.class);
76+
COMMON_CLASS_NAMES.put(List.class.getName(), ArrayList.class);
77+
COMMON_CLASS_NAMES.put(Map.class.getName(), HashMap.class);
78+
COMMON_CLASS_NAMES.put(Set.class.getName(), HashSet.class);
79+
7080
BOXED_FROM_PRIMITIVE.put(void.class, Void.class);
7181
BOXED_FROM_PRIMITIVE.put(int.class, Integer.class);
7282
BOXED_FROM_PRIMITIVE.put(double.class, Double.class);
@@ -240,28 +250,27 @@ public static Class<?> getTypeForName(String className) {
240250
return clazz;
241251
}
242252

253+
if (!checkClassName(className)) {
254+
throw new SKException("Requested type is not allowed: " + className);
255+
}
256+
243257
try {
244258
clazz = Thread.currentThread().getContextClassLoader().loadClass(className);
245259
} catch (ClassNotFoundException e) {
246260
// ignore
247261
}
248262

249-
if (clazz == null) {
250-
try {
251-
// Seems that in tests specifically we need to use the class loader of the class itself
252-
clazz = KernelPluginFactory.class.getClassLoader().loadClass(className);
253-
} catch (ClassNotFoundException e) {
254-
// ignore
255-
}
256-
}
257-
258263
if (clazz == null) {
259264
throw new SKException("Requested type could not be found: " + className
260265
+ ". This needs to be a fully qualified class name, e.g. 'java.lang.String'.");
261266
}
262267
return clazz;
263268
}
264269

270+
public static boolean checkClassName(String className) {
271+
return ClassFilter.CLASS_CHECKER.test(className);
272+
}
273+
265274
/**
266275
* Creates a plugin from the provided name and function collection.
267276
*
@@ -429,6 +438,7 @@ private static <T> KernelFunction<T> getKernelFunction(
429438

430439
/**
431440
* Imports a plugin from a resource directory on the filesystem.
441+
*
432442
* @param parentDirectory The parent directory containing the plugin directories.
433443
* @param pluginDirectoryName The name of the plugin directory.
434444
* @param functionName The name of the function to import.
@@ -552,4 +562,107 @@ private static PromptTemplateConfig getPromptTemplateConfig(
552562
return null;
553563
}
554564
}
565+
566+
// Filters allowed classes that can be used as types in plugins
567+
public static class ClassFilter {
568+
569+
// Selects which filter type to use, allow list or ban list
570+
public static final String CLASS_BLOCK_TYPE_PROPERTY_NAME = "semantic-kernel.class-block-type";
571+
public static final String CLASS_BLOCK_LIST_PROPERTY_NAME = "semantic-kernel.class-block-list";
572+
public static final String CLASS_ALLOW_LIST_PROPERTY_NAME = "semantic-kernel.class-allow-list";
573+
574+
// allow nothing by default (other than java primitives and collections)
575+
private static final List<String> CLASS_ALLOW_LIST;
576+
private static final List<String> CLASS_ALLOW_LIST_DEFAULT = Collections.emptyList();
577+
578+
// block Java classes by default (other than java primitives and collections)
579+
private static final List<String> CLASS_BLOCK_LIST;
580+
private static final List<String> CLASS_BLOCK_LIST_DEFAULT = Arrays.asList(
581+
"java\\..*",
582+
"com\\.sun\\..*",
583+
"javax\\..*",
584+
"jdk\\..*",
585+
"org\\.xml\\..*",
586+
"org\\.w3c\\..*"
587+
);
588+
589+
static Predicate<String> CLASS_CHECKER;
590+
591+
private enum BlockType {
592+
BLOCK,
593+
ALLOW
594+
}
595+
596+
static {
597+
// Default to blocking type
598+
String classFilterType = System.getProperty(CLASS_BLOCK_TYPE_PROPERTY_NAME,
599+
BlockType.BLOCK.name());
600+
CLASS_BLOCK_LIST = getList(CLASS_BLOCK_LIST_PROPERTY_NAME, CLASS_BLOCK_LIST_DEFAULT);
601+
CLASS_ALLOW_LIST = getList(CLASS_ALLOW_LIST_PROPERTY_NAME, CLASS_ALLOW_LIST_DEFAULT);
602+
603+
BlockType type;
604+
605+
try {
606+
type = BlockType.valueOf(classFilterType.toUpperCase(Locale.ROOT));
607+
} catch (IllegalArgumentException e) {
608+
type = BlockType.BLOCK;
609+
}
610+
611+
switch (type) {
612+
case ALLOW:
613+
CLASS_CHECKER = ClassFilter::evaluateAllow;
614+
break;
615+
case BLOCK:
616+
default:
617+
CLASS_CHECKER = ClassFilter::evaluateBlock;
618+
break;
619+
}
620+
}
621+
622+
private static List<String> getList(String propertyName, List<String> defaultList) {
623+
String blockList = System.getProperty(propertyName);
624+
625+
if (blockList != null) {
626+
return Arrays.asList(blockList.split(","));
627+
} else {
628+
return defaultList;
629+
}
630+
}
631+
632+
// Block classes/packages classes (other than common Java primitives and collections)
633+
private static boolean evaluateBlock(String className) {
634+
if (className == null || className.isEmpty()) {
635+
return false;
636+
}
637+
638+
for (String ban : CLASS_BLOCK_LIST) {
639+
if (className.matches(ban)) {
640+
LOGGER.warn(
641+
"Skipping class not allowed by class block list {}, if you wish to unblock this class update the property: {}",
642+
className, CLASS_BLOCK_LIST_PROPERTY_NAME);
643+
return false;
644+
}
645+
}
646+
647+
return true;
648+
}
649+
650+
// Only allow explicitly allowed classes/packages (other than common Java primitives and collections)
651+
private static boolean evaluateAllow(String className) {
652+
if (className == null || className.isEmpty()) {
653+
return false;
654+
}
655+
656+
for (String allow : CLASS_ALLOW_LIST) {
657+
if (className.matches(allow)) {
658+
return true;
659+
}
660+
}
661+
662+
LOGGER.warn(
663+
"Skipping class not allowed by class allow list {}, if you wish to allow this class update the property: {}",
664+
className, CLASS_ALLOW_LIST_DEFAULT);
665+
return false;
666+
}
667+
}
555668
}

0 commit comments

Comments
 (0)