
[MRG+2] Use fused types in sparse mean variance functions #6593

Merged

Conversation

yenchenlin (Contributor)

This is a follow-up PR to #6588; it tries to make the functions in utils/sparsefuncs_fast.pyx support Cython fused types (a minimal sketch of the idea follows the list below).

In this PR, I focus on functions listed below:

  • csr_mean_variance_axis0
  • csc_mean_variance_axis0
  • incr_mean_variance_axis0
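
For readers unfamiliar with fused types, here is a minimal, hypothetical sketch of the idea (not the actual sparsefuncs_fast.pyx code): a single fused type lets one Cython function accept both float32 and float64 buffers, so float32 input no longer has to be copied to float64 first.

# hypothetical sketch of a fused-type Cython function, not code from this PR
cimport numpy as np

ctypedef fused floating:
    np.float32_t
    np.float64_t

def _sum_data(np.ndarray[floating, ndim=1] data):
    # Cython generates one specialization per dtype in the fused type,
    # so float32 input is processed in float32 with no float64 copy.
    cdef floating total = 0
    cdef Py_ssize_t i
    for i in range(data.shape[0]):
        total += data[i]
    return total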

EDIT:
I called the mean_variance_axis function on an np.float32 array of shape (5*10^6, 20).
Here is the memory usage over time:

  • master:

(figure: memory usage over time on master)

  • this branch:

(figure: memory usage over time on this branch)

Memory usage within the bracketed region indeed decreases.

is_CSC = isinstance(X, sp.csc_matrix)
return _incr_mean_variance_axis0(X.data, X.shape, X.indices, X.indptr,
                                 last_mean, last_var, last_n,
                                 is_CSR, is_CSC)
Contributor Author

Maybe there is a better way to distinguish whether X is a csr_matrix or a csc_matrix?

Member

You could use X.format, but the present approach is better.

In any case, why did you move the logic to check the type of sparse matrix here?

Member

Ah I see why.

This might be nitpicking, but you need not pass is_CSR or is_CSC; you can get whatever information you want from len(X_indptr):

if len(X_indptr) == shape[0] + 1:
    # CSR: indptr has one entry per row, plus one
    is_CSR = True
else:
    # CSC: indptr has one entry per column, plus one
    is_CSR = False

@yenchenlin yenchenlin force-pushed the use-fused-types-in-mean-variance branch 2 times, most recently from a3831e7 to 0405f80 Compare April 2, 2016 03:02
@yenchenlin yenchenlin changed the title [WIP] Use fused types in sparse mean variance functions [MRG] Use fused types in sparse mean variance functions Apr 2, 2016
@yenchenlin (Contributor Author)

Hello @MechCoder, I've addressed the issues you pointed out.
Also, I've refactored the tests for the mean variance functions and made sure they output an np.float32 result when the X passed in is np.float32.

@yenchenlin (Contributor Author)

Would @jnothman please have a look when you have time?
Thanks!

cdef unsigned int n_samples = X.shape[0]
cdef unsigned int n_features = X.shape[1]
if X.dtype == np.int32 or X.dtype == np.int64:
    X = X.astype(np.float64)
Member

We can cast X to np.float32 if X.dtype is np.int32, right?

@MechCoder (Member)

Thanks for the refactoring!

Can you use memory_profiler to run some quick benchmarks and confirm the reduced memory usage for np.float32 input, comparing this branch against master?

@yenchenlin yenchenlin force-pushed the use-fused-types-in-mean-variance branch 2 times, most recently from acf2eba to 12739e7 Compare April 7, 2016 17:07
@yenchenlin (Contributor Author)

Hello @MechCoder,
I've fixed the problems you mentioned.

For memory profiling, I called the mean_variance_axis function on an np.float32 array of shape (5*10^6, 20).

Here is the memory usage over time:

  • master:

(figure: memory usage over time on master)

  • this branch:

(figure: memory usage over time on this branch)

As you can see, memory usage within the bracketed region decreases drastically.

@MechCoder (Member)

I would expect the peak memory usage to drop drastically. Did you forget to rebuild sklearn?

@yenchenlin (Contributor Author)

Here is my test script:

import numpy as np
import scipy.sparse as sp
from sklearn.utils.sparsefuncs import mean_variance_axis

X = np.random.rand(5000000, 20)
X = X.astype(np.float32)
X_csr = sp.csr_matrix(X)

# @profile is injected by memory_profiler when the script is run under it
@profile
def test():
    X_means, X_vars = mean_variance_axis(X_csr, axis=0)
    print(X_means.dtype)

test()

I think the peak memory usage appears when I initialize
X = np.random.rand(5000000, 20), and it should not change between the two branches.
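
As an aside, one way to produce memory-over-time curves like the figures above is memory_profiler's memory_usage API; this is a hypothetical sketch (smaller array, plotting details assumed), not the script used for the figures:

import numpy as np
import scipy.sparse as sp
import matplotlib.pyplot as plt
from memory_profiler import memory_usage
from sklearn.utils.sparsefuncs import mean_variance_axis

# build a smaller float32 CSR matrix to keep the example quick
X_csr = sp.csr_matrix(np.random.rand(100000, 20).astype(np.float32))

# sample the process memory every 10 ms while the function runs
samples = memory_usage((mean_variance_axis, (X_csr,), {"axis": 0}), interval=0.01)

plt.plot(samples)
plt.xlabel("sample index")
plt.ylabel("memory (MiB)")
plt.show()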

@MechCoder (Member)

In that case, you can use the -m flag to verify.

python -m memory_profiler test.py

which will provide line-by-line output.

@MechCoder (Member)

Oh, the graph looks good; I did not read it properly :/

@yenchenlin (Contributor Author)

Oh okay 😄

@MechCoder (Member)

LGTM. cc: @jnothman

@MechCoder MechCoder changed the title [MRG] Use fused types in sparse mean variance functions [MRG+1] Use fused types in sparse mean variance functions Apr 8, 2016
cdef np.ndarray[DOUBLE, ndim=1, mode="c"] X_data
X_data = np.asarray(X.data, dtype=np.float64) # might copy!
cdef np.ndarray[int, ndim=1] X_indices = X.indices
if X.dtype == np.int32:
Member

I'm not sure what I think of these rules... The question being, I guess: what level of precision do we need in a mean?

For instance, an int32 converted to float32 loses precision, so we're going to provide less precise answers than we used to for integer input.
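
To make the precision point concrete, a small illustration (not from the PR discussion): float32 has a 24-bit significand, so int32 values above 2**24 cannot all be represented exactly.

import numpy as np

x = np.int32(2**24 + 1)   # 16777217, exactly representable as int32
print(np.float32(x))      # 16777216.0 -- the trailing 1 is lost
print(np.float64(x))      # 16777217.0 -- exact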

Member

And is there a reason we don't support other integer sparse matrix types (including bool, although mean and variance are silly functions to run in that case)? If the code supported those dtypes before (even if untested), I think we'd better maintain support.

Member

Does this have test coverage?

Contributor Author

> And is there a reason we don't support other integer sparse matrix types

I have to admit that I don't know the reason 😢

> an int32 converted to float32 loses precision

Yeah, you are right, but there's a workaround: still convert int32 and int64 into float64, just like before.

What do you think?

Member

I think we're playing at the margins, but keeping it as float64 (at least for int32) is safest. The most important of my comments is that other types remain supported.

Indeed, I think we've actually lost some precision in calculating the mean of float32s in float32 rather than float64, but in that case I think the difference is really marginal.

Member

Not sure what that question means.

Contributor Author

Sorry for not being clear. I mean that the workaround,

> Still convert all dtypes except float32 into float64.

can already address the related issues arguing that in some cases we only need float32 precision.
So I think maybe we can merge the workaround solution first?
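
A minimal sketch of the dispatch rule being proposed (the helper name is hypothetical, not code from this PR): float32 and float64 pass through unchanged, every other dtype is converted to float64.

import numpy as np

def _as_float(X):
    # hypothetical helper: float32 and float64 pass through unchanged,
    # any other dtype (int32, int64, bool, ...) is converted to float64
    if X.dtype in (np.float32, np.float64):
        return X
    return X.astype(np.float64)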

Member

Yes, I think that's the right solution, "workaround" or otherwise.

Contributor Author

@jnothman thanks!

🐝 ping @MechCoder

Member

Yes, we should revert this. Sorry about that!

@yenchenlin yenchenlin force-pushed the use-fused-types-in-mean-variance branch from 12739e7 to 61ebeb2 Compare April 12, 2016 07:45
@yenchenlin (Contributor Author)

Hello @jnothman @MechCoder,

I've updated the code; please have a look. Thanks!

@MechCoder (Member)

I will let @jnothman do the honours.

@jnothman (Member)

It's not really an issue with this PR, but I suspect we should have a test there to ensure the result is sensible for integer dtypes. You could just have something like:

for input_dtype, expected_dtype in [(np.float32, np.float32), (np.int32, np.float64), ...]

In fact, the expected_dtype is not really of interest: the enhancement we're making here is not really testable except via memory profiling.

Thinking further about it, this change still copies data for integers, which we could avoid with a more generic fused type while still having float output. Can we assume that the mean of explicitly integer features is not actually something we're interested in often, so not worth the additional compilation time?

We also need a what's new entry boasting what we've enhanced.
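
A hedged sketch of the kind of dtype test being suggested (names and exact expectations are illustrative; the test actually added in the PR may differ):

import numpy as np
import scipy.sparse as sp
from sklearn.utils.sparsefuncs import mean_variance_axis

def test_mean_variance_output_dtypes():
    rng = np.random.RandomState(0)
    base = rng.rand(50, 4) * 10
    for input_dtype, expected_dtype in [(np.float32, np.float32),
                                        (np.float64, np.float64),
                                        (np.int32, np.float64),
                                        (np.int64, np.float64)]:
        X = sp.csr_matrix(base.astype(input_dtype))
        means, variances = mean_variance_axis(X, axis=0)
        assert means.dtype == expected_dtype
        assert variances.dtype == expected_dtype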

@yenchenlin (Contributor Author)

Hello @jnothman, thanks for your review.
I'll add the test and what's new.

> Can we assume that the mean of explicitly integer features is not actually something we're interested in often, so not worth the additional compilation time

Maybe it's my lack of knowledge, but why is the mean of integer features often not something we are interested in?

@jnothman (Member)

I guess integer features, or at least binary, are common enough in many problem spaces. So the current code is still performing a copy for integer input. Do we mind?

@ogrisel (Member)

ogrisel commented Apr 14, 2016

> I guess integer features, or at least binary, are common enough in many problem spaces. So the current code is still performing a copy for integer input. Do we mind?

+1 for not delaying this PR because of integer handling.

# Implement the function here since variables using fused types
# cannot be declared directly and can only be passed as function arguments
cdef unsigned int n_samples = shape[0]
cdef unsigned int n_features =shape[1]
Member

style: spacing around =

@ogrisel (Member)

ogrisel commented Apr 14, 2016

+1 for not delaying this PR because of integer handling.

But +1 for adding a non-regression test to check that integer dtypes are still supported.

@yenchenlin yenchenlin force-pushed the use-fused-types-in-mean-variance branch from 61ebeb2 to 043d2be Compare April 16, 2016 03:17
@yenchenlin yenchenlin force-pushed the use-fused-types-in-mean-variance branch 2 times, most recently from d98b23e to c9d6ece Compare April 16, 2016 04:30
@yenchenlin (Contributor Author)

yenchenlin commented Apr 16, 2016

Hi @jnothman @ogrisel @MechCoder,
I've fixed the issues you mentioned.

Would you please check again?
Thanks a lot!

@@ -131,6 +131,9 @@ Enhancements
- Add option to show ``indicator features`` in the output of Imputer.
By `Mani Teja`_.

- Reduce the memory usage of :func:`utils.mean_variance_axis` and :func:`utils.incr_mean_variance_axis`
Member

So I was hoping we could say that this reduced the memory usage of some estimators, rather than talking about utils here. Is that wrong? Or do we need to do more to reduce estimator memory usage?

Member

What do you think about addressing that separately (checking all estimators that use this function)?

Member

I don't mind mentioning utils. Saying "for 32-bit float arrays" might be worthwhile.

Contributor Author

I think I still need to make assign_rows_csr support fused types, which can benefit #6430 a lot.
And after that, we may add something about estimators to whats_new?

Contributor Author

Hi @jnothman, are you talking about something like the following?

Reduce the memory usage for 32-bit float input arrays of :func:`utils.mean_variance_axis` and :func:`utils.incr_mean_variance_axis`

Member

yeah, that's better

Contributor Author

done! 🔨

@yenchenlin yenchenlin force-pushed the use-fused-types-in-mean-variance branch from c9d6ece to 5bc8d1d Compare April 16, 2016 14:21
@jnothman (Member)

LGTM!

@jnothman jnothman changed the title [MRG+1] Use fused types in sparse mean variance functions [MRG+2] Use fused types in sparse mean variance functions Apr 16, 2016
@MechCoder MechCoder merged commit 28758cc into scikit-learn:master Apr 16, 2016
@MechCoder (Member)

Thanks, Yen!

mannby pushed a commit to mannby/scikit-learn that referenced this pull request Apr 22, 2016
…rn#6593)

* Use fused types in mean variance functions

* Add test for mean-variance functions using fused trpes

* Add whats_new
TomDLT pushed a commit to TomDLT/scikit-learn that referenced this pull request Oct 3, 2016
…rn#6593)

* Use fused types in mean variance functions

* Add test for mean-variance functions using fused trpes

* Add whats_new