[MRG+1] Allows KMeans/MiniBatchKMeans to use float32 internally by using cython fused types #6846

Merged
merged 17 commits into scikit-learn:master from yenchenlin:fused_types on Jun 21, 2016

Conversation

yenchenlin
Contributor

@yenchenlin yenchenlin commented May 31, 2016

This is a follow-up PR for #6430: it uses Cython fused types so that KMeans/MiniBatchKMeans can work with np.float32 inputs directly, without wasting memory by converting them internally to np.float64.

Since the sparse utility functions now support Cython fused types, we can also avoid the zig-zag memory usage shown in #6430.
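For readers unfamiliar with the mechanism, here is a minimal, illustrative Cython sketch (not the actual scikit-learn code; row_sums is a made-up helper) of how a fused type lets a single routine accept both np.float32 and np.float64 buffers without an up-front cast to double:

# illustrative only: one function, compiled for both float and double inputs
from cython cimport floating  # Cython's built-in fused type: float or double

def row_sums(floating[:, ::1] X):
    """Sum each row of a C-contiguous float32 or float64 2-D array."""
    cdef Py_ssize_t i, j
    cdef double total  # accumulate in double to limit precision loss
    sums = []
    for i in range(X.shape[0]):
        total = 0.0
        for j in range(X.shape[1]):
            total += X[i, j]
        sums.append(total)
    return sums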

Updated on 6/19

  • Dense np.float32 Data
    • Before this PR: [memory profile plot: master]
    • After this PR: [memory profile plot: fused_types]
  • Sparse np.float32 Data
    • Before this PR: [memory profile plot: master_sparse_32]
    • After this PR: [memory profile plot: fused_types_sparse_32]

Here is the test script used to generate the figures above (thanks @ssaeger):

import numpy as np
from scipy import sparse as sp
from sklearn.cluster import KMeans
from memory_profiler import profile  # supplies the @profile decorator

@profile
def fit_est():
    estimator.fit(X)

np.random.seed(5)
X = np.random.rand(200000, 20)
X = np.float32(X)

# X = sp.csr_matrix(X)  # uncomment to benchmark the sparse case

estimator = KMeans()
fit_est()
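(The memory-over-time plots above were presumably produced with memory_profiler's mprof tooling, e.g. `mprof run test_script.py` followed by `mprof plot`, which is what the @profile marker is for.)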

@yenchenlin
Contributor Author

It seems that we are facing precision issues again ...

@jnothman
Member

Some imprecision in inertia may be tolerable. However, having not looked this over in detail, it seems as if you've changed all the DOUBLEs into floatings. If changing some of these back to DOUBLEs will only increase memory usage in O(n) or O(m) but not O(mn) (n=samples, m=features) I would think it acceptable if it did not also substantially increase runtime. Do you think this is worth investigating?

@yenchenlin
Contributor Author

yenchenlin commented Jun 1, 2016

Hello @jnothman, if I understand you correctly, you mean that changing some of the variables back to DOUBLE may only increase memory usage by O(n) or O(m), without sacrificing precision, right?

@jnothman
Member

jnothman commented Jun 1, 2016

Yes.

@yenchenlin
Contributor Author

I get it, will do.

However, it's a little bit weird that CI only fails for precision reasons on Python 2.6 and Python 2.7.
Are there any substantial differences in precision between Python 2 and Python 3?

@jnothman
Member

jnothman commented Jun 1, 2016

Are there any substantial differences in precision between Python 2 and Python 3?

I recall differences in the handling of large ints, but I don't recall anything widely publicised about float formats. Feel free to investigate a little for your own knowledge, but it's not essential to getting the job done here.

@jnothman
Member

jnothman commented Jun 2, 2016

Please ping when you've looked at potential changes, whether or not we can make any beneficial changes.

@yenchenlin
Contributor Author

yenchenlin commented Jun 4, 2016

Hello @jnothman and @MechCoder ,

I've fixed the test precision issue.

The only change I made here is to the dtype used when computing inertia, i.e., in np.sum.

Since inertia is the sum of the distances from points to their centers, using an np.float64 accumulator preserves precision for np.float32 data.
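As a small illustration (not code from the PR itself), passing dtype=np.float64 to np.sum makes NumPy accumulate in double precision even for float32 input, without the caller having to create a float64 copy of the data first:

import numpy as np

rng = np.random.RandomState(0)
# stand-in for the per-sample squared distances that make up inertia
distances = rng.rand(1000000).astype(np.float32)

inertia_f32 = np.sum(distances)                    # float32 accumulator/result
inertia_f64 = np.sum(distances, dtype=np.float64)  # float64 accumulator/result

# at this magnitude the float32 result only carries ~7 significant digits,
# while the float64 accumulator keeps the extra digits the tests rely on
print(inertia_f32, inertia_f64)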

@yenchenlin yenchenlin changed the title [WIP] Allows KMeans/MiniBatchKMeans to use float32 internally by using cython fused types [MRG] Allows KMeans/MiniBatchKMeans to use float32 internally by using cython fused types Jun 4, 2016
@jnothman
Member

jnothman commented Jun 4, 2016

It's a bit unfortunate that it involves O(mn) conversions to float64, especially since it's done in every iteration. Can you get an estimate of the runtime cost, relative to the lower-precision version?

@jnothman
Member

jnothman commented Jun 4, 2016

So cumsum(..., dtype=np.float64) is our best option, perhaps?

centers[center_idx] /= counts[center_idx]
# Note: numpy >= 1.10 does not support '/=' for the following
# expression for a mixture of int and float (see numpy issue #6464)
centers[center_idx] = centers[center_idx]/counts[center_idx]
Member

space around / please

@jnothman
Member

jnothman commented Jun 4, 2016

Could you please refactor those tests into one? I think the code duplication doesn't give much value.

centers = init
# ensure that the centers have the same dtype as X
# this is a requirement of fused types of cython
centers = np.array(init, dtype=X.dtype)
Member

@ogrisel ogrisel Jun 4, 2016

This will also trigger a copy of init even if init is already an array with the right dtype, which was not the case before.

I think it is good to systematically copy, as it is probably a bad idea to silently mutate the user-provided input array. I think you should add a test to check this behavior.

Member

Has this been addressed?

Member

@MechCoder MechCoder Jun 17, 2016

There is a block of code under the function k_means, starting with if hasattr(init, '__array__'), that also explicitly converts init to higher precision.

That should also be fixed, and we should test the attributes cluster_centers_ and inertia_ for precision when init is an ndarray ...

Contributor Author

@yenchenlin yenchenlin Jun 17, 2016

I've added a test to make sure that we do copy init here.

@MechCoder, since the function _init_centroids will always be called after the code block you mentioned, and thus ensures centers has the same dtype as the input, I think maybe we don't need to touch that part of the code?

Member

Ouch, that means a copy is already triggered before _init_centroids is called, right?

Correct me if I'm wrong, but in other words, if X is of dtype np.float32, two copies of init are made:

  1. One in kmeans which makes init of dtype np.float64
  2. And then one in _init_centroids that converts it back to np.float32

Contributor Author

You are right!

But I think it is reasonable to copy the data twice, because in step 1,
init = check_array(init, dtype=np.float64, copy=True) is used to make sure KMeans.init is a copy of the array provided by the user.

And in step 2, centers = np.array(init, dtype=X.dtype)
is used to make sure the centers we compute won't alter its argument init.

Contributor Author

Yeah, but I still changed this line to init = check_array(init, dtype=X.dtype.type, copy=True), since it is more consistent to keep it as the same datatype as X, and it also does no harm to precision! 😃
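For illustration, a rough sketch of the kind of test discussed in this thread might look like the following (hypothetical code, not the test actually added in the PR):

import numpy as np
from sklearn.cluster import KMeans

rng = np.random.RandomState(0)
X = rng.rand(100, 4).astype(np.float32)
init_centers = X[:3].copy()        # user-provided initial centers
init_before = init_centers.copy()

km = KMeans(n_clusters=3, init=init_centers, n_init=1).fit(X)

# the user's init array must not be mutated (KMeans copies it internally)
np.testing.assert_array_equal(init_centers, init_before)
# and with this PR the fitted centers keep the input dtype
assert km.cluster_centers_.dtype == np.float32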

@ogrisel
Member

ogrisel commented Jun 4, 2016

Great work @yenchenlin. I added some further comments to address, on top of @jnothman's. Looking forward to merging this once they are addressed.

@ogrisel ogrisel added this to the 0.18 milestone Jun 4, 2016
@yenchenlin
Contributor Author

yenchenlin commented Jun 4, 2016

Thanks a lot for your comments, guys!
I've started addressing them.

@yenchenlin yenchenlin force-pushed the fused_types branch 2 times, most recently from 60fb242 to 5f5a81e, on June 14, 2016 at 00:58
@jnothman
Member

Please avoid amending commits and force-pushing your changes. It's much easier to review changes incrementally, especially weeks after the previous change, if the commits show precisely what has changed. There are other reasons too, but commits with clear messages were designed that way for a reason.

for dtype in [np.int32, np.int64, np.float32, np.float64]:
    X_test = dtype(X_small)
    init_centers_test = dtype(init_centers)
    assert_equal(X_test.dtype, init_centers_test.dtype)
Member

This line does not really test anything, does it?

Contributor Author

I did it yesterday to emphasize that they are of the same type, but it looks stupid to me now. Thanks!

@MechCoder
Member

That should be it from me as well. LGTM pending comments. Great work!

@MechCoder MechCoder changed the title [MRG] Allows KMeans/MiniBatchKMeans to use float32 internally by using cython fused types [MRG+1] Allows KMeans/MiniBatchKMeans to use float32 internally by using cython fused types Jun 18, 2016
@yenchenlin
Contributor Author

Thanks @MechCoder for the inputs, please have a look when you have time!

@jnothman
Member

In terms of benchmarking, I mostly meant to make sure that we're actually reducing memory consumption in the sparse and dense cases from what it is at master...

@yenchenlin
Contributor Author

yenchenlin commented Jun 19, 2016

Hello @jnothman, sorry for the misunderstanding ...
What I wanted to emphasize in that comment is that we can keep precision without increasing overhead.

About the benchmarking, I've updated the results for sparse input in the main description of this PR; the results for the dense case remain the same.

@@ -305,7 +305,7 @@ def k_means(X, n_clusters, init='k-means++', precompute_distances='auto',
X -= X_mean

if hasattr(init, '__array__'):
init = check_array(init, dtype=np.float64, copy=True)
init = check_array(init, dtype=X.dtype.type, copy=True)
Member

(I'm surprised that we need this .type here and perhaps check_array should allow a dtype object to be passed.)

Contributor Author

I think so.

@jnothman
Member

Let us know when you've dealt with those cosmetic things, after which I think this looks good for merge.

@yenchenlin
Contributor Author

@jnothman Thanks for the check, done!

@MechCoder
Member

Thanks for the updates. My +1 still holds.

@jnothman
Member

Good work, @yenchenlin !

@jnothman jnothman merged commit 6bea11b into scikit-learn:master Jun 21, 2016
@yenchenlin
Contributor Author

yenchenlin commented Jun 21, 2016

Thanks for your review!
Also thanks @ssaeger !

@ssaeger

ssaeger commented Jun 21, 2016

Thanks @yenchenlin , great work!

imaculate pushed a commit to imaculate/scikit-learn that referenced this pull request Jun 23, 2016
agramfort pushed a commit that referenced this pull request Jun 23, 2016
olologin pushed a commit to olologin/scikit-learn that referenced this pull request Aug 24, 2016
olologin pushed a commit to olologin/scikit-learn that referenced this pull request Aug 24, 2016
TomDLT pushed a commit to TomDLT/scikit-learn that referenced this pull request Oct 3, 2016
TomDLT pushed a commit to TomDLT/scikit-learn that referenced this pull request Oct 3, 2016