|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +from typing import Any |
3 | 4 | import collections |
| 5 | +import copy |
4 | 6 | from collections.abc import Sequence |
5 | 7 |
|
6 | 8 | import numpy as np |
|
10 | 12 |
|
11 | 13 |
|
12 | 14 | class QuantumEvolutionKernel: |
| 15 | + """QuantumEvolutionKernel class. |
| 16 | +
|
| 17 | + Attributes: |
| 18 | + - params (dict): Dictionary of training parameters. |
| 19 | + - X (Sequence[ProcessedData]): Training data used for fitting the kernel |
| 20 | + - kernel_matrix (np.ndarray): Kernel matrix. This is assigned in the `fit()` method |
| 21 | +
|
| 22 | +
|
| 23 | + """ |
| 24 | + |
13 | 25 | def __init__(self, mu: float): |
14 | | - self.mu = mu |
| 26 | + """Initialize the QuantumEvolutionKernel. |
| 27 | +
|
| 28 | + Args: |
| 29 | + mu (float): Scaling factor for the Jensen-Shannon divergence |
| 30 | + """ |
| 31 | + self.params: dict[str, Any] = {"mu": mu} |
| 32 | + self.X: Sequence[ProcessedData] |
| 33 | + self.kernel_matrix: np.ndarray |
15 | 34 |
|
16 | 35 | def __call__( |
17 | 36 | self, graph_1: ProcessedData, graph_2: ProcessedData, size_max: int = 100 |
@@ -49,7 +68,46 @@ class from qek_os.data_io.dataset. The size_max parameter controls the maximum |
49 | 68 | js = ( |
50 | 69 | jensenshannon(p=dist_graph_1, q=dist_graph_2) ** 2 |
51 | 70 | ) # Because the distance is the square root of the divergence |
52 | | - return float(np.exp(-self.mu * js)) |
| 71 | + return float(np.exp(-self.params["mu"] * js)) |
| 72 | + |
| 73 | + def fit(self, X: Sequence[ProcessedData], y: list = None) -> None: |
| 74 | + """Fit the kernel to the training dataset by storing the dataset. |
| 75 | +
|
| 76 | + Args: |
| 77 | + X (Sequence[ProcessedData]): The training dataset. |
| 78 | + y (list): Target variable for the dataset sequence. Defaults to None. |
| 79 | + """ |
| 80 | + self.X = X |
| 81 | + self.kernel_matrix = self.create_train_kernel_matrix(self.X) |
| 82 | + |
| 83 | + def transform(self, X_test: Sequence[ProcessedData], y_test: list = None) -> np.ndarray: |
| 84 | + """Transform the dataset into the kernel space with respect to the training dataset. |
| 85 | +
|
| 86 | + Args: |
| 87 | + X_test (Sequence[ProcessedData]): The dataset to transform. |
| 88 | + y_test (list): Target variable for the dataset sequence. Defaults to None. |
| 89 | +
|
| 90 | + Returns: |
| 91 | + np.ndarray: Kernel matrix where each entry represents the similarity between |
| 92 | + the given dataset and the training dataset. |
| 93 | + """ |
| 94 | + if self.X is None: |
| 95 | + raise ValueError("The kernel must be fit to a training dataset before transforming.") |
| 96 | + |
| 97 | + return self.create_test_kernel_matrix(X_test, self.X) |
| 98 | + |
| 99 | + def fit_transform(self, X: Sequence[ProcessedData], y: list = None) -> np.ndarray: |
| 100 | + """Fit the kernel to the training dataset and transform it. |
| 101 | +
|
| 102 | + Args: |
| 103 | + X (Sequence[ProcessedData]): The dataset to fit and transform. |
| 104 | + y (list): Target variable for the dataset sequence. Defaults to None. |
| 105 | +
|
| 106 | + Returns: |
| 107 | + np.ndarray: Kernel matrix for the training dataset. |
| 108 | + """ |
| 109 | + self.fit(X) |
| 110 | + return self.kernel_matrix |
53 | 111 |
|
54 | 112 | def create_train_kernel_matrix(self, train_dataset: Sequence[ProcessedData]) -> np.ndarray: |
55 | 113 | """Compute a kernel matrix for a given training dataset. |
@@ -102,6 +160,28 @@ def create_test_kernel_matrix( |
102 | 160 | kernel_mat[i][j] = self(test_dataset[i], train_dataset[j]) |
103 | 161 | return kernel_mat |
104 | 162 |
|
| 163 | + def set_params(self, **kwargs: dict[str, Any]) -> None: |
| 164 | + """Set multiple parameters for the kernel. |
| 165 | +
|
| 166 | + Args: |
| 167 | + **kwargs: Arbitrary keyword arguments where keys are parameter names |
| 168 | + and values are the new values stored for those parameters. |
| 169 | + """ |
| 170 | + for key, value in kwargs.items(): |
| 171 | + self.params[key] = value |
| 172 | + |
| 173 | + def get_params(self, deep: bool = True) -> dict: |
| 174 | + """Retrieve the value of all parameters. |
| 175 | +
|
| 176 | + Args: |
| 177 | + deep (bool): Ignored. Added for compatibility with various machine learning libraries, |
| 178 | + such as scikit-learn. |
| 179 | +
|
| 180 | + Returns: |
| 181 | + dict: A dictionary of parameters and their respective values. |
| 182 | + """ |
| 183 | + return copy.deepcopy(self.params) |
| 184 | + |
105 | 185 |
|
106 | 186 | def count_occupation_from_bitstring(bitstring: str) -> int: |
107 | 187 | """Counts the number of '1' bits in a binary string. |
|
0 commit comments