@@ -1787,3 +1787,55 @@ def is_stationary(self):
1787
1787
def __repr__ (self ):
1788
1788
return "{0}(gamma={1}, metric={2})" .format (
1789
1789
self .__class__ .__name__ , self .gamma , self .metric )
1790
+
1791
+
1792
+ class SelectDimensionKernel (Kernel ):
1793
+ def __init__ (self , kernel , active_dims ):
1794
+ self .kernel = kernel
1795
+ if not isinstance (self .kernel , Kernel ):
1796
+ raise ValueError ("Expected kernel to be a Kernel instance, got "
1797
+ "%s" % self .kernel )
1798
+ self .active_dims = np .asarray (active_dims )
1799
+ self .active_kernel = clone (kernel )
1800
+ params = self .kernel .get_params ()
1801
+
1802
+ if self .active_dims .ndim != 1 :
1803
+ raise ValueError ("active_dims should be 1-dimensional, got %d"
1804
+ % self .active_dims .ndim )
1805
+ n_active_dims = self .active_dims .shape [0 ]
1806
+
1807
+ new_hyperparameters = []
1808
+ for hyperparam in kernel .hyperparameters :
1809
+ n_elements = hyperparam .n_elements
1810
+ if n_elements != 1 and n_elements != n_active_dims :
1811
+ raise ValueError ("Expected %d number of elements in "
1812
+ "hyperparameter %s of kernel %s, got %d" %
1813
+ (n_active_dims , hyperparameter ,
1814
+ kernel , n_elements ))
1815
+ if n_elements == 1 :
1816
+ new_hyperparam = hyperparam
1817
+ else :
1818
+ hyperparam_value = getattr (
1819
+ self .active_kernel , hyperparam .name )
1820
+ setattr (self .active_kernel , hyperparam .name , hyperparam_value )
1821
+ new_hyperparam = hyperparam ._replace (
1822
+ bounds = hyperparam .bounds [self .active_dims ])
1823
+
1824
+ new_hyperparameters .append (new_hyperparam )
1825
+ self .hyperparameters_ = new_hyperparameters
1826
+
1827
+ @property
1828
+ def hyperparameters (self ):
1829
+ return self .hyperparameters_
1830
+
1831
+ def diag (self , X ):
1832
+ return self .active_kernel .diag (X [self .active_dims ])
1833
+
1834
+ def is_stationary (self ):
1835
+ return self .active_kernel .is_stationary ()
1836
+
1837
+ def __call__ (self , X , Y = None , eval_gradient = False ):
1838
+ if Y is None :
1839
+ return self .active_kernel (X [self .active_dims ], X [self .active_dims ])
1840
+ else :
1841
+ return self .active_kernel (X [self .active_dims ], Y [self .active_dims ])
0 commit comments