Skip to content

Commit c1fe208

Browse files
mlahariya and Yoric authored
[Features] Add Fit/Transform and Set_params/Get_params Methods to QEK (#27)
* Update QEK with fit and transform methods * Update tutorial with fit_transform method * Cleanup * Cleanup * Update get/set_params * Update fit transform methods with X and y * Restore two methods for creating kernel matrix * Update get_params function * Clean up * Update qek/kernel/kernel.py Co-authored-by: David Teller <david.teller@pasqal.com> * Update add docs --------- Co-authored-by: David Teller <david.teller@pasqal.com>
1 parent 91eb3c3 commit c1fe208

2 files changed

Lines changed: 83 additions & 3 deletions

File tree

examples/tutorial.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@
771771
"metadata": {},
772772
"outputs": [],
773773
"source": [
774-
"train_kernel = kernel.create_train_kernel_matrix(processed_dataset)\n",
774+
"train_kernel = kernel.fit_transform(processed_dataset)\n",
775775
"y_tot = [data.target for data in processed_dataset]"
776776
]
777777
},

qek/kernel/kernel.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

3+
from typing import Any
34
import collections
5+
import copy
46
from collections.abc import Sequence
57

68
import numpy as np
@@ -10,8 +12,25 @@
1012

1113

1214
class QuantumEvolutionKernel:
15+
"""QuantumEvolutionKernel class.
16+
17+
Attributes:
18+
- params (dict): Dictionary of training parameters.
19+
- X (Sequence[ProcessedData]): Training data used for fitting the kernel
20+
- kernel_matrix (np.ndarray): Kernel matrix. This is assigned in the `fit()` method
21+
22+
23+
"""
24+
1325
def __init__(self, mu: float):
14-
self.mu = mu
26+
"""Initialize the QuantumEvolutionKernel.
27+
28+
Args:
29+
mu (float): Scaling factor for the Jensen-Shannon divergence
30+
"""
31+
self.params: dict[str, Any] = {"mu": mu}
32+
self.X: Sequence[ProcessedData]
33+
self.kernel_matrix: np.ndarray
1534

1635
def __call__(
1736
self, graph_1: ProcessedData, graph_2: ProcessedData, size_max: int = 100
@@ -49,7 +68,46 @@ class from qek_os.data_io.dataset. The size_max parameter controls the maximum
4968
js = (
5069
jensenshannon(p=dist_graph_1, q=dist_graph_2) ** 2
5170
) # Because the divergence is the square root of the distance
52-
return float(np.exp(-self.mu * js))
71+
return float(np.exp(-self.params["mu"] * js))
72+
73+
def fit(self, X: Sequence[ProcessedData], y: list = None) -> None:
74+
"""Fit the kernel to the training dataset by storing the dataset.
75+
76+
Args:
77+
X (Sequence[ProcessedData]): The training dataset.
78+
y: list: Target variable for the dataset sequence. defaults to None.
79+
"""
80+
self.X = X
81+
self.kernel_matrix = self.create_train_kernel_matrix(self.X)
82+
83+
def transform(self, X_test: Sequence[ProcessedData], y_test: list = None) -> np.ndarray:
84+
"""Transform the dataset into the kernel space with respect to the training dataset.
85+
86+
Args:
87+
X_test (Sequence[ProcessedData]): The dataset to transform.
88+
y_test: list: Target variable for the dataset sequence. defaults to None.
89+
90+
Returns:
91+
np.ndarray: Kernel matrix where each entry represents the similarity between
92+
the given dataset and the training dataset.
93+
"""
94+
if self.X is None:
95+
raise ValueError("The kernel must be fit to a training dataset before transforming.")
96+
97+
return self.create_test_kernel_matrix(X_test, self.X)
98+
99+
def fit_transform(self, X: Sequence[ProcessedData], y: list = None) -> np.ndarray:
100+
"""Fit the kernel to the training dataset and transform it.
101+
102+
Args:
103+
X (Sequence[ProcessedData]): The dataset to fit and transform.
104+
y: list: Target variable for the dataset sequence. defaults to None.
105+
106+
Returns:
107+
np.ndarray: Kernel matrix for the training dataset.
108+
"""
109+
self.fit(X)
110+
return self.kernel_matrix
53111

54112
def create_train_kernel_matrix(self, train_dataset: Sequence[ProcessedData]) -> np.ndarray:
55113
"""Compute a kernel matrix for a given training dataset.
@@ -102,6 +160,28 @@ def create_test_kernel_matrix(
102160
kernel_mat[i][j] = self(test_dataset[i], train_dataset[j])
103161
return kernel_mat
104162

163+
def set_params(self, **kwargs: dict[str, Any]) -> None:
164+
"""Set multiple parameters for the kernel.
165+
166+
Args:
167+
**kwargs: Arbitrary keyword dictionary where keys are attribute names
168+
and values are their respective values
169+
"""
170+
for key, value in kwargs.items():
171+
self.params[key] = value
172+
173+
def get_params(self, deep: bool = True) -> dict:
174+
"""Retrieve the value of all parameters.
175+
176+
Args:
177+
deep (bool): Ignored. Added for compatibility with various machine learning libraries,
178+
such as scikit-learn.
179+
180+
Returns
181+
dict: A dictionary of parameters and their respective values.
182+
"""
183+
return copy.deepcopy(self.params)
184+
105185

106186
def count_occupation_from_bitstring(bitstring: str) -> int:
107187
"""Counts the number of '1' bits in a binary string.

0 commit comments

Comments
 (0)