Skip to content

Commit c7daa37

Browse files
committed
SelectDimensionKernel
1 parent e2a2b4d commit c7daa37

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed

sklearn/gaussian_process/kernels.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1787,3 +1787,55 @@ def is_stationary(self):
17871787
def __repr__(self):
17881788
return "{0}(gamma={1}, metric={2})".format(
17891789
self.__class__.__name__, self.gamma, self.metric)
1790+
1791+
1792+
class SelectDimensionKernel(Kernel):
1793+
def __init__(self, kernel, active_dims):
1794+
self.kernel = kernel
1795+
if not isinstance(self.kernel, Kernel):
1796+
raise ValueError("Expected kernel to be a Kernel instance, got "
1797+
"%s" % self.kernel)
1798+
self.active_dims = np.asarray(active_dims)
1799+
self.active_kernel = clone(kernel)
1800+
params = self.kernel.get_params()
1801+
1802+
if self.active_dims.ndim != 1:
1803+
raise ValueError("active_dims should be 1-dimensional, got %d"
1804+
% self.active_dims.ndim)
1805+
n_active_dims = self.active_dims.shape[0]
1806+
1807+
new_hyperparameters = []
1808+
for hyperparam in kernel.hyperparameters:
1809+
n_elements = hyperparam.n_elements
1810+
if n_elements != 1 and n_elements != n_active_dims:
1811+
raise ValueError("Expected %d number of elements in "
1812+
"hyperparameter %s of kernel %s, got %d" %
1813+
(n_active_dims, hyperparameter,
1814+
kernel, n_elements))
1815+
if n_elements == 1:
1816+
new_hyperparam = hyperparam
1817+
else:
1818+
hyperparam_value = getattr(
1819+
self.active_kernel, hyperparam.name)
1820+
setattr(self.active_kernel, hyperparam.name, hyperparam_value)
1821+
new_hyperparam = hyperparam._replace(
1822+
bounds=hyperparam.bounds[self.active_dims])
1823+
1824+
new_hyperparameters.append(new_hyperparam)
1825+
self.hyperparameters_ = new_hyperparameters
1826+
1827+
@property
1828+
def hyperparameters(self):
1829+
return self.hyperparameters_
1830+
1831+
def diag(self, X):
1832+
return self.active_kernel.diag(X[self.active_dims])
1833+
1834+
def is_stationary(self):
1835+
return self.active_kernel.is_stationary()
1836+
1837+
def __call__(self, X, Y=None, eval_gradient=False):
1838+
if Y is None:
1839+
return self.active_kernel(X[self.active_dims], X[self.active_dims])
1840+
else:
1841+
return self.active_kernel(X[self.active_dims], Y[self.active_dims])

0 commit comments

Comments
 (0)