[MRG] FIX: fix MLKR cost and gradient #111

Merged

Conversation

wdevazelhes (Member)

Running scipy.optimize.check_grad on the MLKR loss function returned a large error.
This PR fixes the gradient and adds a non-regression test. I also modified the cost function (filling the diagonal of K with zeros); see the comments in the code below.

denom = np.maximum(K.sum(axis=0), EPS)
yhat = K.dot(y) / denom
ydiff = yhat - y
cost = (ydiff**2).sum()

# also compute the gradient
np.fill_diagonal(K, 1)
@wdevazelhes (Member Author) commented on Aug 10, 2018:

Why was the diagonal of K set to 1 here?

@@ -98,14 +98,14 @@ def _loss(flatA, X, y, dX):
A = flatA.reshape((-1, X.shape[1]))
dist = pdist(X, metric='mahalanobis', VI=A.T.dot(A))
K = squareform(np.exp(-dist**2))
np.fill_diagonal(K, 0)
@wdevazelhes (Member Author) commented on Aug 10, 2018:

I think we need to set K's diagonal to 0 here, to avoid taking Kii into account in the sum called denom and in the dot product with y (for the cost function). The diagonal of K is also involved in the gradient, but the gradient should still be correct: whatever the diagonal of K is, it is canceled out by the zeros in dX.
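
To make the point above concrete, here is a minimal standalone sketch (toy data, not the PR's code) of the behaviour obtained with a zero diagonal: yhat[i] becomes the leave-one-out prediction sum_{j != i} Kij * yj / sum_{j != i} Kij, so a point no longer predicts itself with weight Kii = exp(0) = 1.

import numpy as np
from scipy.spatial.distance import pdist, squareform

EPS = np.finfo(float).eps
rng = np.random.RandomState(0)
X = rng.randn(6, 3)
y = rng.randn(6)

K = squareform(np.exp(-pdist(X, 'sqeuclidean')))  # squareform already leaves the diagonal at 0
denom = np.maximum(K.sum(axis=0), EPS)
yhat = K.dot(y) / denom                           # leave-one-out kernel regression predictions
cost = ((yhat - y) ** 2).sum()
print(cost)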

@wdevazelhes (Member Author)

I just pushed a version that uses logsumexp instead of adding a small constant inside the exponential, which seems to be more numerically stable. Indeed, on the following example (the example from #104), the loss after the commit is better than before it (found by adding disp=1 in MLKR's call to minimize). It seems this numerical error caused MLKR to think it was already optimized, when in fact it was not (maybe a plateau?)

from sklearn.datasets import make_regression
from metric_learn import MLKR
import numpy as np
np.random.seed(0)
X, y = make_regression()
mlkr = MLKR()
mlkr.fit(X, y)

Before commit 1143e59

Optimization terminated successfully.
         Current function value: 1753134.478289
         Iterations: 0
         Function evaluations: 1
         Gradient evaluations: 1

After commit 1143e59

Warning: Maximum number of iterations has been exceeded.
         Current function value: 1215299.266080
         Iterations: 1000
         Function evaluations: 1792
         Gradient evaluations: 1792
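
To illustrate the kind of underflow that the logsumexp trick avoids, here is a toy sketch (made-up distances, not MLKR's actual code): with large squared distances, np.exp(-d**2) underflows to zero and the naive normalization (even with an EPS floor on the denominator) returns meaningless all-zero weights, while the softmax computed through logsumexp stays well defined.

import numpy as np
from scipy.special import logsumexp

d2 = np.array([800., 810., 820.])    # hypothetical squared distances to the other points

K = np.exp(-d2)                                        # underflows to [0., 0., 0.]
naive = K / np.maximum(K.sum(), np.finfo(float).eps)   # all zeros: the weights are lost

log_w = -d2 - logsumexp(-d2)                           # softmax computed in log space
stable = np.exp(log_w)                                 # [~1.0, ~4.5e-05, ~2.1e-09], sums to 1

print(naive, stable)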

@wdevazelhes changed the title from "FIX: fix mlkr cost and gradient" to "[MRG] FIX: fix MLKR cost and gradient" on Aug 10, 2018
@wdevazelhes (Member Author)

I just added a commit that makes MLKR more memory efficient, in a way very similar to NCA. Running it on the previous example, it may introduce some small numerical errors (see below: the objective at 1000 iterations seems slightly worse than before), but the algorithm still runs (whereas before the fix it stopped after 0 iterations), and the memory gain is large, so I think it is still fine.

Code:

from sklearn.datasets import make_regression
from metric_learn import MLKR
import numpy as np
np.random.seed(0)
X, y = make_regression()
mlkr = MLKR()
mlkr.fit(X, y)

Result (with setting disp=1 inside minimization):

Warning: Maximum number of iterations has been exceeded.
         Current function value: 1248855.833869
         Iterations: 1000
         Function evaluations: 1791
         Gradient evaluations: 1791

@wdevazelhes wdevazelhes added this to the v0.4.0 milestone Aug 13, 2018
@bellet (Member) commented on Aug 14, 2018

Can you explain the last change a bit more, and why it introduces some numerical differences with the previous case? Did you track down the extent of these differences? Maybe it would be useful to compare both versions by showing the objective value at each iteration, and maybe for several random seeds to check that the behavior is consistent. Maybe also use a standard small regression dataset (make_regression introduces noisy features which can make the problem difficult for a method based on the Gaussian kernel).

@wdevazelhes (Member Author)

The change amounts to expanding some squared differences, in a way similar to (a-b)*(a-b) = a**2 + b**2 - 2*a*b, which can introduce some numerical errors (see scikit-learn/scikit-learn#9354 (comment), or https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/2). Here we expand a weighted sum of outer products of the differences Xi - Xj (with a coefficient Wij for each term) into a weighted sum of outer products Xi Xi.T (i.e. the dot product between X.T and X, with an elementwise multiplication by W.sum(axis=0) + W.sum(axis=1)) minus the cross terms Xi Xj.T (giving the triple dot product between X.T, W + W.T, and X).
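
To make the expansion concrete, here is a small self-contained check (toy data and hypothetical variable names, not the PR's code) that the two forms agree, i.e. that sum_ij Wij (Xi - Xj)(Xi - Xj).T equals X.T.dot(diag(W.sum(0) + W.sum(1)) - (W + W.T)).dot(X):

import numpy as np

rng = np.random.RandomState(0)
n, d = 20, 5
X = rng.randn(n, d)
W = rng.rand(n, n)

# naive form: one outer product of differences per pair (i, j)
naive = sum(W[i, j] * np.outer(X[i] - X[j], X[i] - X[j])
            for i in range(n) for j in range(n))

# expanded form: only n x n and d x d intermediates, no pairwise differences
expanded = (X.T * (W.sum(axis=0) + W.sum(axis=1))).dot(X) - X.T.dot(W + W.T).dot(X)

print(np.abs(naive - expanded).max())  # tiny: equal up to floating-point round-off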

Besides, similarly to NCA (see scikit-learn/scikit-learn#10058 (comment)), W.sum(axis=1) contains only zeros, so there is no need to include it in the sum. I don't think removing this zero (in practice, almost zero) term would introduce numerical errors; on the contrary, it should remove some very small terms that should be zero in theory but are not in practice (unless those terms compensate for some other errors, but I don't see why they would).

Finally, we transform A.dot(X.T) into X_emb_t like in NCA, because I think the point was that for a large number of features and a small output space, it can be cheaper to do (A.dot(X.T)).dot(X) than A.dot(X.T.dot(X)) (complexity 2 * n_samples * n_features * n_components vs n_features**2 * (n_components + n_samples)). But now that I think about it, it really depends: if n_features = n_components and n_samples >> n_features, this can be less efficient, so maybe we should use something like np.linalg.multi_dot? Also, maybe the order of the dot products can influence the precision?
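
As a side note, here is a tiny sketch (assumed shapes, hypothetical variable names) of the np.linalg.multi_dot idea mentioned above:

import numpy as np

n_samples, n_features, n_components = 10000, 50, 5
rng = np.random.RandomState(0)
A = rng.randn(n_components, n_features)
X = rng.randn(n_samples, n_features)

# (A @ X.T) @ X costs about 2 * n_samples * n_features * n_components operations,
# while A @ (X.T @ X) costs about n_features**2 * (n_components + n_samples);
# multi_dot picks the cheaper parenthesization for the given shapes.
out = np.linalg.multi_dot([A, X.T, X])   # shape (n_components, n_features)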

In any case I agree; I'll run a quick benchmark to see the extent of these numerical errors.

@bellet (Member) commented on Aug 16, 2018

Yes, what I meant to ask is whether you verified that, for several matrices A, the cost and gradient you find are the same (up to some small difference) with/without your optimization.

Regarding the last point: you could use an if condition to choose between the two ways of computing the gradient depending on the relative values of n_features, n_components and n_samples. But if that complicates the code too much, maybe we can just skip this for now and stick to the option you have now.

@wdevazelhes (Member Author)

Maybe it would be useful to compare both versions by showing the objective value at each iteration, and maybe for several random seeds to check that the behavior is consistent. Maybe also use a standard small regression dataset (make_regression introduces noisy features which can make the problem difficult for a method based on the Gaussian kernel).

I didn't run an extensive benchmark yet, but I ran the snippet above again with L-BFGS-B instead of CG (printing the cost function and gradient value at each iteration is easier there, see #105 (comment)). We see that around iteration 20 the two implementations start to diverge. I agree that the make_regression dataset is probably not a good idea; I'll try the same with a more standard dataset like Boston (even make_regression with a small number of features seemed to work better).

With expanded form (new form)

Machine precision = 2.220D-16
 N =        10000     M =           10

At X0         0 variables are exactly at the bounds

At iterate    0    f=  2.78046D+06    |proj g|=  1.08130D+06

At iterate    1    f=  2.13809D+06    |proj g|=  1.16530D+06

At iterate    2    f=  1.67627D+06    |proj g|=  2.83815D+05

At iterate    3    f=  1.61780D+06    |proj g|=  1.47791D+05

At iterate    4    f=  1.56950D+06    |proj g|=  1.12790D+05

At iterate    5    f=  1.55094D+06    |proj g|=  7.18991D+04

At iterate    6    f=  1.55073D+06    |proj g|=  1.60675D+05

At iterate    7    f=  1.54365D+06    |proj g|=  1.70773D+05

At iterate    8    f=  1.53857D+06    |proj g|=  7.51888D+04

At iterate    9    f=  1.53701D+06    |proj g|=  2.99976D+04

At iterate   10    f=  1.53659D+06    |proj g|=  4.10903D+04

At iterate   11    f=  1.53611D+06    |proj g|=  2.62526D+04

At iterate   12    f=  1.52952D+06    |proj g|=  2.48966D+04

At iterate   13    f=  1.52891D+06    |proj g|=  1.62781D+04

At iterate   14    f=  1.52844D+06    |proj g|=  3.02756D+04

At iterate   15    f=  1.52797D+06    |proj g|=  2.98060D+04

At iterate   16    f=  1.52775D+06    |proj g|=  1.97663D+04

At iterate   17    f=  1.52757D+06    |proj g|=  7.49897D+03

At iterate   18    f=  1.52741D+06    |proj g|=  1.59429D+04

At iterate   19    f=  1.36466D+06    |proj g|=  9.50598D+04

At iterate   20    f=  1.36221D+06    |proj g|=  7.87909D+04

At iterate   21    f=  1.17814D+06    |proj g|=  1.96582D+05

At iterate   22    f=  1.13251D+06    |proj g|=  3.27058D+05

At iterate   23    f=  1.08014D+06    |proj g|=  1.62554D+05

At iterate   24    f=  1.06896D+06    |proj g|=  1.40173D+05

At iterate   25    f=  1.06061D+06    |proj g|=  3.99786D+04

At iterate   26    f=  1.04681D+06    |proj g|=  3.36043D+04

At iterate   27    f=  1.04636D+06    |proj g|=  2.11968D+04

At iterate   28    f=  1.04507D+06    |proj g|=  4.72693D+03

At iterate   29    f=  1.04470D+06    |proj g|=  5.06310D+03

At iterate   30    f=  1.03635D+06    |proj g|=  3.29010D+03

At iterate   31    f=  1.00658D+06    |proj g|=  1.98405D+04

At iterate   32    f=  1.00500D+06    |proj g|=  1.66036D+04

At iterate   33    f=  1.00483D+06    |proj g|=  1.29732D+04

At iterate   34    f=  1.00434D+06    |proj g|=  5.97226D+03

At iterate   35    f=  1.00427D+06    |proj g|=  4.32432D+03

           * * *

Tit   = total number of iterations
Tnf   = total number of function evaluations
Tnint = total number of segments explored during Cauchy searches
Skip  = number of BFGS updates skipped
Nact  = number of active bounds at final generalized Cauchy point
Projg = norm of the final projected gradient
F     = final function value

           * * *

   N    Tit     Tnf  Tnint  Skip  Nact     Projg        F
10000     35     57      1     0     0   4.324D+03   1.004D+06
  F =   1004266.4009399187     

CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH             

 Cauchy                time 0.000E+00 seconds.
 Subspace minimization time 0.000E+00 seconds.
 Line search           time 0.000E+00 seconds.

 Total User time 0.000E+00 seconds.

With factorized form (old form)

Machine precision = 2.220D-16
 N =        10000     M =           10

At X0         0 variables are exactly at the bounds

At iterate    0    f=  2.78046D+06    |proj g|=  1.08130D+06

At iterate    1    f=  2.13809D+06    |proj g|=  1.16530D+06

At iterate    2    f=  1.67627D+06    |proj g|=  2.83815D+05

At iterate    3    f=  1.61780D+06    |proj g|=  1.47791D+05

At iterate    4    f=  1.56950D+06    |proj g|=  1.12790D+05

At iterate    5    f=  1.55094D+06    |proj g|=  7.18991D+04

At iterate    6    f=  1.55073D+06    |proj g|=  1.60675D+05

At iterate    7    f=  1.54365D+06    |proj g|=  1.70773D+05

At iterate    8    f=  1.53857D+06    |proj g|=  7.51888D+04

At iterate    9    f=  1.53701D+06    |proj g|=  2.99976D+04

At iterate   10    f=  1.53659D+06    |proj g|=  4.10903D+04

At iterate   11    f=  1.53611D+06    |proj g|=  2.62526D+04

At iterate   12    f=  1.52952D+06    |proj g|=  2.48966D+04

At iterate   13    f=  1.52891D+06    |proj g|=  1.62780D+04

At iterate   14    f=  1.52844D+06    |proj g|=  3.02755D+04

At iterate   15    f=  1.52797D+06    |proj g|=  2.98061D+04

At iterate   16    f=  1.52775D+06    |proj g|=  1.97664D+04

At iterate   17    f=  1.52757D+06    |proj g|=  7.49893D+03

At iterate   18    f=  1.52741D+06    |proj g|=  1.59428D+04

At iterate   19    f=  1.36466D+06    |proj g|=  9.50575D+04

At iterate   20    f=  1.36222D+06    |proj g|=  7.88659D+04

At iterate   21    f=  1.17375D+06    |proj g|=  2.05227D+05

At iterate   22    f=  1.12573D+06    |proj g|=  3.35049D+05

At iterate   23    f=  1.10346D+06    |proj g|=  1.93905D+05

At iterate   24    f=  1.06300D+06    |proj g|=  2.89046D+04

At iterate   25    f=  1.05854D+06    |proj g|=  1.69905D+04

At iterate   26    f=  1.04497D+06    |proj g|=  6.35560D+04

At iterate   27    f=  1.04478D+06    |proj g|=  5.18243D+04

At iterate   28    f=  1.04413D+06    |proj g|=  1.42317D+04

At iterate   29    f=  1.04381D+06    |proj g|=  9.17901D+03

At iterate   30    f=  1.04360D+06    |proj g|=  6.25093D+03

At iterate   31    f=  1.01181D+06    |proj g|=  4.76286D+04

At iterate   32    f=  1.01020D+06    |proj g|=  4.33296D+04

At iterate   33    f=  1.00973D+06    |proj g|=  3.74927D+04

At iterate   34    f=  1.00848D+06    |proj g|=  2.17611D+04

At iterate   35    f=  1.00768D+06    |proj g|=  3.95831D+04

At iterate   36    f=  1.00668D+06    |proj g|=  1.14577D+04

At iterate   37    f=  1.00656D+06    |proj g|=  3.96136D+03

At iterate   38    f=  1.00654D+06    |proj g|=  1.60814D+03

           * * *

Tit   = total number of iterations
Tnf   = total number of function evaluations
Tnint = total number of segments explored during Cauchy searches
Skip  = number of BFGS updates skipped
Nact  = number of active bounds at final generalized Cauchy point
Projg = norm of the final projected gradient
F     = final function value

           * * *

   N    Tit     Tnf  Tnint  Skip  Nact     Projg        F
10000     38     59      1     0     0   1.608D+03   1.007D+06
  F =   1006543.5290385620     

CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH             

 Cauchy                time 0.000E+00 seconds.
 Subspace minimization time 0.000E+00 seconds.
 Line search           time 0.000E+00 seconds.

 Total User time 0.000E+00 seconds.

@wdevazelhes (Member Author)

Yes, what I meant to ask is whether you verified that, for several matrices A, the cost and gradient you find are the same (up to some small difference) with/without your optimization.

No, but indeed that makes sense; I'll run a quick snippet to check this.

Regarding the last point: you could use an if condition to choose between the two ways of computing the gradient depending on the relative values of n_features, n_components and n_samples. But if that complicates the code too much, maybe we can just skip this for now and stick to the option you have now.

Why not use np.linalg.multi_dot (it is quite recent though: it was introduced in numpy 1.13)?

@bellet (Member) commented on Aug 16, 2018

Yes, I think comparing for the same matrix A makes more sense (in your comparison above, the small differences accumulate along the iterations). Maybe you can run the new code and, along the iterations, also compute the objective/gradient at the current point with the previous version, to compare. Or simply generate a bunch of random PSD matrices A and compare.

Indeed, np.linalg.multi_dot sounds like the way to go; I did not know about it.

@wdevazelhes (Member Author)

Here is a gist comparing the gradients/losses of the optimized and non-optimized versions: https://gist.github.com/wdevazelhes/8c9f5fdc53ed6e7bbe8d8f958351db85
I ran it for a bunch of matrices A (not necessarily PSD, since A is the embedding matrix here, right?)

I tried it with make_regression to see the influence of the number of features: a higher-dimensional feature space appears to increase the difference between the two implementations.

@wdevazelhes (Member Author)

Here is another snippet, this time with load_boston. Here the difference seems quite significant, doesn't it?

from metric_learn import MLKR
from sklearn.utils import check_random_state
import numpy as np
from losses import _loss_non_optimized, _loss_optimized
from collections import defaultdict
from sklearn.datasets import load_boston

X, y = load_boston(return_X_y=True)
for seed in range(5):
    rng = check_random_state(seed)
    A  = rng.randn(X.shape[1], X.shape[1])
    print('gradient differences:')
    print(np.linalg.norm(_loss_optimized(A, X, y)[1] 
        - _loss_non_optimized(A, X, y)[1]))
    print('loss differences:')
    print(_loss_optimized(A, X, y)[0] 
            - _loss_non_optimized(A, X, y)[0])

Results:

gradient differences:
0.0002306125088543176
loss differences:
0.0
gradient differences:
0.00022796999902744875
loss differences:
0.0
gradient differences:
0.0001696791546658249
loss differences:
0.0
gradient differences:
0.0004592041477002949
loss differences:
0.0
gradient differences:
0.00042782581177926474
loss differences:
0.0

@wdevazelhes (Member Author)

I tried the above snippet putting back the supposedly null W.sum(axis=1) term in the sum, so that the gradient expression becomes

grad = (4 * (X_emb_t * (W.sum(axis=0) + W.sum(axis=1))  - X_emb_t.dot(W + W.T)).dot(X))

And this time the difference is much lower (see below), so I guess the W.sum(axis=1) term was maybe compensating some errors? I will look into it further, but I thought it was supposed to be 0...

gradient differences:
2.021742238365209e-09
loss differences:
0.0
gradient differences:
6.28459815257491e-10
loss differences:
0.0
gradient differences:
6.62783153933958e-10
loss differences:
0.0
gradient differences:
1.4157154963156227e-09
loss differences:
0.0
gradient differences:
2.5947199373104065e-08
loss differences:
0.0

@bellet (Member) commented on Aug 16, 2018

You are looking at the absolute difference between the two gradients, right? This can be misleading, as it ignores the norm of the gradient. You should divide the difference by the norm of the first gradient to get a relative error.
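
(As a tiny illustration with made-up numbers: an absolute difference that looks non-negligible can still be a very small relative error once divided by the gradient's norm.)

import numpy as np
g_new = np.array([1e6, -2e6])
g_old = g_new + 1e-4                                           # absolute difference ~1.4e-4
rel = np.linalg.norm(g_new - g_old) / np.linalg.norm(g_new)    # relative error ~6e-11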

@bellet (Member) commented on Aug 16, 2018

About W.sum(axis=1): I guess the old version uses it, so it makes sense that it reduces the difference. It seems to have some impact on the difference you observe with respect to the old code, but it does not mean that this gives a "more correct" gradient.

@wdevazelhes (Member Author)

You are looking at the absolute difference between the two gradients, right? This can be misleading, as it ignores the norm of the gradient. You should divide the difference by the norm of the first gradient to get a relative error.

True, here are the new results with relative differences:

from metric_learn import MLKR
from sklearn.utils import check_random_state
import numpy as np
from losses import _loss_non_optimized, _loss_optimized
from collections import defaultdict
from sklearn.datasets import load_boston

X, y = load_boston(return_X_y=True)
for seed in range(5):
    rng = check_random_state(seed)
    A  = rng.randn(X.shape[1], X.shape[1])
    print('gradient differences:')
    print(np.linalg.norm((_loss_optimized(A, X, y)[1] 
        - _loss_non_optimized(A, X, y)[1])/np.linalg.norm(_loss_optimized(A, X, y)[1])))
    print('loss differences:')
    print(_loss_optimized(A, X, y)[0] 
            - _loss_non_optimized(A, X, y)[0])

Results:

gradient differences:
8.359448158866233e-08
loss differences:
0.0
gradient differences:
1.0784946143039721e-07
loss differences:
0.0
gradient differences:
8.204760940504895e-07
loss differences:
0.0
gradient differences:
1.4688860756792364e-06
loss differences:
0.0
gradient differences:
1.2904076343770319e-08
loss differences:
0.0

@wdevazelhes (Member Author)

About W.sum(axis=1): I guess the old version uses it, so it makes sense that it reduces the difference. It seems to have some impact on the difference you observe with respect to the old code, but it does not mean that this gives a "more correct" gradient.

You're right. I'll try to compare the two computations of the gradient against its finite-difference approximation (considered as the "true" gradient); maybe it can give more insight.

@wdevazelhes (Member Author)

Here is a snippet that prints the relative difference between the gradient and its finite-difference approximation, for the optimized and non-optimized versions:

from sklearn.datasets import load_boston
from scipy.optimize import check_grad
import numpy as np
from losses import _loss_optimized, _loss_non_optimized
from collections import defaultdict
from sklearn.utils import check_random_state



X, y = load_boston(return_X_y=True)

def finite_diff(loss, seed):
    rng = check_random_state(seed)
    M = rng.randn(rng.randint(1, X.shape[1] + 1), X.shape[1])

    def fun(M):
      return loss(M, X, y)[0]

    def grad(M):
      return loss(M, X, y)[1].ravel()

    rel_diff = check_grad(fun, grad, M.ravel()) / np.linalg.norm(grad(M))
    return rel_diff

if __name__ == '__main__':
    differences = defaultdict(list)
    for seed in range(3):
        for loss in [_loss_optimized, _loss_non_optimized]:
            differences[loss.__name__].append(finite_diff(loss, seed))
    
    means = dict()
    variances = dict()
    for loss in [_loss_optimized, _loss_non_optimized]:
        means[loss.__name__] = np.mean(differences[loss.__name__])
        variances[loss.__name__] = np.var(differences[loss.__name__])
    print(differences)
    print('means: {}'.format(means))
    print('variances: {}'.format(variances))

Results:

defaultdict(<class 'list'>, {'_loss_optimized': [3.992401292465381e-05, 7.768899993233868e-05, 0.0001933318430358699], '_loss_non_optimized': [3.992009297948192e-05, 7.766764656230469e-05, 0.00019337770049358163]})
means: {'_loss_optimized': 0.0001036482852976208, '_loss_non_optimized': 0.00010365514667845608}
variances: {'_loss_optimized': 4.2592693049100326e-09, '_loss_non_optimized': 4.2625479651353965e-09}

@bellet (Member) left a review comment:

These results support the fact that your optimized method computes a correct gradient (the variation with respect to the previous version is negligible). I think we can merge

@perimosocordiae perimosocordiae merged commit b60b72b into scikit-learn-contrib:master Aug 16, 2018
@perimosocordiae (Contributor)

Looks good, especially with the new test case. Merged.

@wdevazelhes wdevazelhes deleted the fix/mlkr_loss_grad branch August 22, 2018 06:48