|
27 | 27 | import java.nio.file.Path; |
28 | 28 | import java.util.ArrayList; |
29 | 29 | import java.util.Arrays; |
| 30 | +import java.util.Collections; |
30 | 31 | import java.util.HashMap; |
31 | 32 | import java.util.HashSet; |
32 | 33 | import java.util.List; |
| 34 | +import java.util.Locale; |
33 | 35 | import java.util.Map; |
| 36 | +import java.util.Set; |
| 37 | +import java.util.function.Predicate; |
34 | 38 | import java.util.stream.Collectors; |
35 | 39 | import javax.annotation.Nullable; |
36 | 40 | import org.reactivestreams.Publisher; |
@@ -67,6 +71,12 @@ public class KernelPluginFactory { |
67 | 71 | COMMON_CLASS_NAMES.put("map", HashMap.class); |
68 | 72 | COMMON_CLASS_NAMES.put("set", HashSet.class); |
69 | 73 |
|
| 74 | + COMMON_CLASS_NAMES.put(Integer.class.getName(), int.class); |
| 75 | + COMMON_CLASS_NAMES.put(String.class.getName(), String.class); |
| 76 | + COMMON_CLASS_NAMES.put(List.class.getName(), ArrayList.class); |
| 77 | + COMMON_CLASS_NAMES.put(Map.class.getName(), HashMap.class); |
| 78 | + COMMON_CLASS_NAMES.put(Set.class.getName(), HashSet.class); |
| 79 | + |
70 | 80 | BOXED_FROM_PRIMITIVE.put(void.class, Void.class); |
71 | 81 | BOXED_FROM_PRIMITIVE.put(int.class, Integer.class); |
72 | 82 | BOXED_FROM_PRIMITIVE.put(double.class, Double.class); |
@@ -240,28 +250,27 @@ public static Class<?> getTypeForName(String className) { |
240 | 250 | return clazz; |
241 | 251 | } |
242 | 252 |
|
| 253 | + if (!checkClassName(className)) { |
| 254 | + throw new SKException("Requested type is not allowed: " + className); |
| 255 | + } |
| 256 | + |
243 | 257 | try { |
244 | 258 | clazz = Thread.currentThread().getContextClassLoader().loadClass(className); |
245 | 259 | } catch (ClassNotFoundException e) { |
246 | 260 | // ignore |
247 | 261 | } |
248 | 262 |
|
249 | | - if (clazz == null) { |
250 | | - try { |
251 | | - // Seems that in tests specifically we need to use the class loader of the class itself |
252 | | - clazz = KernelPluginFactory.class.getClassLoader().loadClass(className); |
253 | | - } catch (ClassNotFoundException e) { |
254 | | - // ignore |
255 | | - } |
256 | | - } |
257 | | - |
258 | 263 | if (clazz == null) { |
259 | 264 | throw new SKException("Requested type could not be found: " + className |
260 | 265 | + ". This needs to be a fully qualified class name, e.g. 'java.lang.String'."); |
261 | 266 | } |
262 | 267 | return clazz; |
263 | 268 | } |
264 | 269 |
|
| 270 | + public static boolean checkClassName(String className) { |
| 271 | + return ClassFilter.CLASS_CHECKER.test(className); |
| 272 | + } |
| 273 | + |
265 | 274 | /** |
266 | 275 | * Creates a plugin from the provided name and function collection. |
267 | 276 | * |
@@ -429,6 +438,7 @@ private static <T> KernelFunction<T> getKernelFunction( |
429 | 438 |
|
430 | 439 | /** |
431 | 440 | * Imports a plugin from a resource directory on the filesystem. |
| 441 | + * |
432 | 442 | * @param parentDirectory The parent directory containing the plugin directories. |
433 | 443 | * @param pluginDirectoryName The name of the plugin directory. |
434 | 444 | * @param functionName The name of the function to import. |
@@ -552,4 +562,107 @@ private static PromptTemplateConfig getPromptTemplateConfig( |
552 | 562 | return null; |
553 | 563 | } |
554 | 564 | } |
| 565 | + |
| 566 | + // Filters allowed classes that can be used as types in plugins |
| 567 | + public static class ClassFilter { |
| 568 | + |
| 569 | + // Selects which filter type to use, allow list or ban list |
| 570 | + public static final String CLASS_BLOCK_TYPE_PROPERTY_NAME = "semantic-kernel.class-block-type"; |
| 571 | + public static final String CLASS_BLOCK_LIST_PROPERTY_NAME = "semantic-kernel.class-block-list"; |
| 572 | + public static final String CLASS_ALLOW_LIST_PROPERTY_NAME = "semantic-kernel.class-allow-list"; |
| 573 | + |
| 574 | + // allow nothing by default (other than java primitives and collections) |
| 575 | + private static final List<String> CLASS_ALLOW_LIST; |
| 576 | + private static final List<String> CLASS_ALLOW_LIST_DEFAULT = Collections.emptyList(); |
| 577 | + |
| 578 | + // block Java classes by default (other than java primitives and collections) |
| 579 | + private static final List<String> CLASS_BLOCK_LIST; |
| 580 | + private static final List<String> CLASS_BLOCK_LIST_DEFAULT = Arrays.asList( |
| 581 | + "java\\..*", |
| 582 | + "com\\.sun\\..*", |
| 583 | + "javax\\..*", |
| 584 | + "jdk\\..*", |
| 585 | + "org\\.xml\\..*", |
| 586 | + "org\\.w3c\\..*" |
| 587 | + ); |
| 588 | + |
| 589 | + static Predicate<String> CLASS_CHECKER; |
| 590 | + |
| 591 | + private enum BlockType { |
| 592 | + BLOCK, |
| 593 | + ALLOW |
| 594 | + } |
| 595 | + |
| 596 | + static { |
| 597 | + // Default to blocking type |
| 598 | + String classFilterType = System.getProperty(CLASS_BLOCK_TYPE_PROPERTY_NAME, |
| 599 | + BlockType.BLOCK.name()); |
| 600 | + CLASS_BLOCK_LIST = getList(CLASS_BLOCK_LIST_PROPERTY_NAME, CLASS_BLOCK_LIST_DEFAULT); |
| 601 | + CLASS_ALLOW_LIST = getList(CLASS_ALLOW_LIST_PROPERTY_NAME, CLASS_ALLOW_LIST_DEFAULT); |
| 602 | + |
| 603 | + BlockType type; |
| 604 | + |
| 605 | + try { |
| 606 | + type = BlockType.valueOf(classFilterType.toUpperCase(Locale.ROOT)); |
| 607 | + } catch (IllegalArgumentException e) { |
| 608 | + type = BlockType.BLOCK; |
| 609 | + } |
| 610 | + |
| 611 | + switch (type) { |
| 612 | + case ALLOW: |
| 613 | + CLASS_CHECKER = ClassFilter::evaluateAllow; |
| 614 | + break; |
| 615 | + case BLOCK: |
| 616 | + default: |
| 617 | + CLASS_CHECKER = ClassFilter::evaluateBlock; |
| 618 | + break; |
| 619 | + } |
| 620 | + } |
| 621 | + |
| 622 | + private static List<String> getList(String propertyName, List<String> defaultList) { |
| 623 | + String blockList = System.getProperty(propertyName); |
| 624 | + |
| 625 | + if (blockList != null) { |
| 626 | + return Arrays.asList(blockList.split(",")); |
| 627 | + } else { |
| 628 | + return defaultList; |
| 629 | + } |
| 630 | + } |
| 631 | + |
| 632 | + // Block classes/packages classes (other than common Java primitives and collections) |
| 633 | + private static boolean evaluateBlock(String className) { |
| 634 | + if (className == null || className.isEmpty()) { |
| 635 | + return false; |
| 636 | + } |
| 637 | + |
| 638 | + for (String ban : CLASS_BLOCK_LIST) { |
| 639 | + if (className.matches(ban)) { |
| 640 | + LOGGER.warn( |
| 641 | + "Skipping class not allowed by class block list {}, if you wish to unblock this class update the property: {}", |
| 642 | + className, CLASS_BLOCK_LIST_PROPERTY_NAME); |
| 643 | + return false; |
| 644 | + } |
| 645 | + } |
| 646 | + |
| 647 | + return true; |
| 648 | + } |
| 649 | + |
| 650 | + // Only allow explicitly allowed classes/packages (other than common Java primitives and collections) |
| 651 | + private static boolean evaluateAllow(String className) { |
| 652 | + if (className == null || className.isEmpty()) { |
| 653 | + return false; |
| 654 | + } |
| 655 | + |
| 656 | + for (String allow : CLASS_ALLOW_LIST) { |
| 657 | + if (className.matches(allow)) { |
| 658 | + return true; |
| 659 | + } |
| 660 | + } |
| 661 | + |
| 662 | + LOGGER.warn( |
| 663 | + "Skipping class not allowed by class allow list {}, if you wish to allow this class update the property: {}", |
| 664 | + className, CLASS_ALLOW_LIST_DEFAULT); |
| 665 | + return false; |
| 666 | + } |
| 667 | + } |
555 | 668 | } |
0 commit comments