PERF Avoid repetitively allocating large temporary arrays when fitting GaussianMixture
#30614
@ogrisel This looks nice. Would it make sense to move the splitting logic into its own PR and keep all the other changes in this one? I think the error is caused by the splitting part only, is that correct?
@ogrisel Do you think we should update this PR? It currently has conflicts, and the build is too old for me to see the errors that occurred. As far as the tests in the mixture module are concerned, all of them passed on my local Windows system.
I don't plan to work on it soon. Feel free to take over, or to extract the easy-to-merge parts into a new PR.
# to convert it to bytes
bytes_per_sample = max(X.dtype.itemsize * X.shape[1], 1)
batch_size = max(int(get_config()["working_memory"] * 1e6) // bytes_per_sample, 1)
float_dtype = precisions_chol.dtype
Note: For now, we need to take the dtype from precisions_chol, because it is used below in the in-place computation of squared_diff, which requires the dtypes to match. BayesianGaussianMixture does not currently support float32, so using X.dtype directly causes issues. The common tests also check cases where X has an integer dtype.
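To make the note above concrete, here is a minimal NumPy-only sketch of a chunked, in-place computation that takes its working dtype from the Cholesky factor rather than from X. It is not the code of this PR: `batch_size_for`, `chunked_sq_mahalanobis` and the `working_memory_mb` parameter are hypothetical names that stand in for the `get_config()["working_memory"]` lookup quoted in the diff, and only a single full-covariance component is handled.

```python
import numpy as np


def batch_size_for(X, working_memory_mb=1024):
    # Hypothetical helper mirroring the arithmetic in the diff excerpt:
    # a memory budget divided by the size of one row of X, at least 1.
    bytes_per_sample = max(X.dtype.itemsize * X.shape[1], 1)
    return max(int(working_memory_mb * 1e6) // bytes_per_sample, 1)


def chunked_sq_mahalanobis(X, mean, prec_chol, batch_size):
    # Squared Mahalanobis distances to one component, computed chunk by
    # chunk into a single reusable buffer instead of one large temporary.
    float_dtype = prec_chol.dtype  # not X.dtype: X may be integer or float32
    n_samples, n_features = X.shape
    out = np.empty(n_samples, dtype=float_dtype)
    buf = np.empty((min(batch_size, n_samples), n_features), dtype=float_dtype)
    mean_proj = mean @ prec_chol
    for start in range(0, n_samples, batch_size):
        stop = min(start + batch_size, n_samples)
        y = buf[: stop - start]  # leading-axis slice stays C-contiguous
        # np.dot(..., out=y) needs y to have exactly the result dtype,
        # so the chunk is cast to float_dtype first (no copy if it matches).
        chunk = X[start:stop].astype(float_dtype, copy=False)
        np.dot(chunk, prec_chol, out=y)
        y -= mean_proj
        np.square(y, out=y)
        np.sum(y, axis=1, out=out[start:stop])
    return out


rng = np.random.default_rng(0)
X_int = rng.integers(0, 10, size=(1000, 3))  # integer input, as in the common tests
mean = X_int.mean(axis=0)
prec_chol = np.linalg.cholesky(np.linalg.inv(np.cov(X_int.T)))
dist = chunked_sq_mahalanobis(X_int, mean, prec_chol, batch_size=128)
reference = (((X_int - mean) @ prec_chol) ** 2).sum(axis=1)
```

With this layout the only per-chunk allocation is the dtype cast, and that cast is skipped whenever X already has the dtype of `prec_chol`.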
CC: @lesteve @betatim @antoinebaker @jeremiedbb for reviews
@ogrisel Do you think we should close this PR, now that the PR incorporating the array API is close to being finalized?
Once the array API support is merged, this PR will have to be adapted to optimize either the NumPy case or the generic array API case when possible. I think the individual optimizations will have to be reviewed one by one, maybe by splitting the PR into sub-PRs.
I agree. I think it would be better to open new PRs, though, instead of trying to adjust this one. What do you think?
While profiling the memory usage of GaussianMixture as part of #30415 (comment), I realized that there were many other possible improvements, independently of the use of float32 data. So here is a WIP PR with a snapshot of the things I found with the help of scalene and memray.
On float64 data, chunking and more liberal use of in-place operations + the a-posteriori covariance matrix centering trick make it possible to reduce fit time by ~40% and trim peak memory usage by 60% on a 400 MB dataset.
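One plausible reading of the a-posteriori centering trick, sketched here under assumptions and not taken from the PR's diff: instead of materializing the centered copy `X - mean` (as large as X itself), accumulate the weighted second moment chunk by chunk and subtract the outer product of the mean at the end, using the identity sum_i r_i (x_i - m)(x_i - m)^T / n_k = sum_i r_i x_i x_i^T / n_k - m m^T. The function name `chunked_weighted_covariance` and its parameters are hypothetical.

```python
import numpy as np


def chunked_weighted_covariance(X, resp_k, batch_size, reg_covar=1e-6):
    # Weighted mean and covariance of one component without ever
    # allocating an array of the same size as X.
    n_samples, n_features = X.shape
    nk = resp_k.sum()
    weighted_sum = np.zeros(n_features)
    second_moment = np.zeros((n_features, n_features))
    for start in range(0, n_samples, batch_size):
        X_chunk = X[start : start + batch_size]
        r_chunk = resp_k[start : start + batch_size]
        weighted_sum += r_chunk @ X_chunk
        # The temporary here is only chunk-sized.
        second_moment += (X_chunk * r_chunk[:, np.newaxis]).T @ X_chunk
    mean = weighted_sum / nk
    # Center after the fact, in place: E[x x^T] - m m^T.
    second_moment /= nk
    second_moment -= np.outer(mean, mean)
    second_moment.flat[:: n_features + 1] += reg_covar
    return mean, second_moment


rng = np.random.default_rng(0)
X = rng.normal(size=(5000, 4))
resp_k = rng.uniform(size=5000)
mean, cov = chunked_weighted_covariance(X, resp_k, batch_size=512)

# Reference: the straightforward version that allocates X - mean.
diff = X - mean
cov_ref = (resp_k[:, np.newaxis] * diff).T @ diff / resp_k.sum()
cov_ref.flat[:: X.shape[1] + 1] += 1e-6
```

Subtracting `mean @ mean.T` afterwards is less numerically robust than centering first when the mean is large relative to the spread of the data, so a real implementation would need to check accuracy, particularly in float32.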
TODO:
ValueError: output array is not acceptable (must have the right datatype, number of dimensions, and be a C-Array)
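For context, NumPy raises this error when the array passed as `out=` to `np.dot` does not have exactly the dtype, shape and memory layout of the result. The snippet below reproduces it with a float32 output buffer for a float64 product. Whether this is the exact call that failed in the PR's CI is an assumption; the array names are illustrative only.

```python
import numpy as np

a = np.ones((4, 3), dtype=np.float64)
b = np.ones((3, 3), dtype=np.float64)

# float32 buffer for a float64 product: rejected by np.dot.
out32 = np.empty((4, 3), dtype=np.float32)
try:
    np.dot(a, b, out=out32)
    raised = False
except ValueError as exc:
    raised = True
    message = str(exc)

# Buffer allocated with the result dtype: accepted, written in place.
out64 = np.empty((4, 3), dtype=np.result_type(a, b))
np.dot(a, b, out=out64)
```

Allocating the buffer with the dtype of the operands (or casting the operands to the buffer's dtype beforehand) avoids the error.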