@@ -817,24 +817,28 @@ def _diag_dot(self, D, B):
817
817
D = D [(slice (None ), ) + (np .newaxis , ) * (len (B .shape ) - 1 )]
818
818
return D * B
819
819
820
- def _errors (self , alpha , y , v , Q , QT_y ):
821
- # don't construct matrix G, instead compute action on y & diagonal
820
+ def _errors_and_values_helper (self , alpha , y , v , Q , QT_y ):
821
+ """Helper function to avoid code duplication between self._errors and
822
+ self._values.
823
+
824
+ Notes
825
+ -----
826
+ We don't construct matrix G, instead compute action on y & diagonal.
827
+ """
822
828
w = 1.0 / (v + alpha )
823
829
c = np .dot (Q , self ._diag_dot (w , QT_y ))
824
830
G_diag = self ._decomp_diag (w , Q )
825
831
# handle case where y is 2-d
826
832
if len (y .shape ) != 1 :
827
833
G_diag = G_diag [:, np .newaxis ]
834
+ return G_diag , c
835
+
836
+ def _errors (self , alpha , y , v , Q , QT_y ):
837
+ G_diag , c = self ._errors_and_values_helper (alpha , y , v , Q , QT_y )
828
838
return (c / G_diag ) ** 2 , c
829
839
830
840
def _values (self , alpha , y , v , Q , QT_y ):
831
- # don't construct matrix G, instead compute action on y & diagonal
832
- w = 1.0 / (v + alpha )
833
- c = np .dot (Q , self ._diag_dot (w , QT_y ))
834
- G_diag = self ._decomp_diag (w , Q )
835
- # handle case where y is 2-d
836
- if len (y .shape ) != 1 :
837
- G_diag = G_diag [:, np .newaxis ]
841
+ G_diag , c = self ._errors_and_values_helper (alpha , y , v , Q , QT_y )
838
842
return y - (c / G_diag ), c
839
843
840
844
def _pre_compute_svd (self , X , y ):
@@ -845,22 +849,24 @@ def _pre_compute_svd(self, X, y):
845
849
UT_y = np .dot (U .T , y )
846
850
return v , U , UT_y
847
851
848
- def _errors_svd (self , alpha , y , v , U , UT_y ):
852
+ def _errors_and_values_svd_helper (self , alpha , y , v , U , UT_y ):
853
+ """Helper function to avoid code duplication between self._errors_svd
854
+ and self._values_svd.
855
+ """
849
856
w = ((v + alpha ) ** - 1 ) - (alpha ** - 1 )
850
857
c = np .dot (U , self ._diag_dot (w , UT_y )) + (alpha ** - 1 ) * y
851
858
G_diag = self ._decomp_diag (w , U ) + (alpha ** - 1 )
852
859
if len (y .shape ) != 1 :
853
860
# handle case where y is 2-d
854
861
G_diag = G_diag [:, np .newaxis ]
862
+ return G_diag , c
863
+
864
+ def _errors_svd (self , alpha , y , v , U , UT_y ):
865
+ G_diag , c = self ._errors_and_values_svd_helper (alpha , y , v , U , UT_y )
855
866
return (c / G_diag ) ** 2 , c
856
867
857
868
def _values_svd (self , alpha , y , v , U , UT_y ):
858
- w = ((v + alpha ) ** - 1 ) - (alpha ** - 1 )
859
- c = np .dot (U , self ._diag_dot (w , UT_y )) + (alpha ** - 1 ) * y
860
- G_diag = self ._decomp_diag (w , U ) + (alpha ** - 1 )
861
- if len (y .shape ) != 1 :
862
- # handle case when y is 2-d
863
- G_diag = G_diag [:, np .newaxis ]
869
+ G_diag , c = self ._errors_and_values_svd_helper (alpha , y , v , U , UT_y )
864
870
return y - (c / G_diag ), c
865
871
866
872
def fit (self , X , y , sample_weight = None ):
0 commit comments