
PERF Avoid repetitively allocating large temporary arrays when fitting GaussianMixture #30614


Open · wants to merge 6 commits into main
16 changes: 10 additions & 6 deletions sklearn/mixture/_base.py
@@ -407,7 +407,7 @@ def predict_proba(self, X):
check_is_fitted(self)
X = validate_data(self, X, reset=False)
_, log_resp = self._estimate_log_prob_resp(X)
return np.exp(log_resp)
return np.exp(log_resp, out=log_resp)

def sample(self, n_samples=1):
"""Generate random samples from the fitted Gaussian distribution.
@@ -482,7 +482,9 @@ def _estimate_weighted_log_prob(self, X):
-------
weighted_log_prob : array, shape (n_samples, n_component)
"""
return self._estimate_log_prob(X) + self._estimate_log_weights()
result = self._estimate_log_prob(X)
result += self._estimate_log_weights()
return result

@abstractmethod
def _estimate_log_weights(self):
@@ -529,11 +531,13 @@ def _estimate_log_prob_resp(self, X):
log_responsibilities : array, shape (n_samples, n_components)
logarithm of the responsibilities
"""
weighted_log_prob = self._estimate_weighted_log_prob(X)
log_prob_norm = logsumexp(weighted_log_prob, axis=1)
# Inplace normalize the weighted log probabilities into log
# responsibilities to avoid a memory copy.
log_resp = self._estimate_weighted_log_prob(X)
log_prob_norm = logsumexp(log_resp, axis=1)
with np.errstate(under="ignore"):
# ignore underflow
log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis]
# Ignore underflow
log_resp -= log_prob_norm[:, np.newaxis]
return log_prob_norm, log_resp

def _print_verbose_msg_init_beg(self, n_init):
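
For context, the in-place normalization above is the usual log-sum-exp trick: the buffer returned by _estimate_weighted_log_prob is reused for the log responsibilities instead of allocating a second (n_samples, n_components) array. A minimal standalone sketch of the same pattern (illustrative only, not part of the diff; variable names are made up):

    import numpy as np
    from scipy.special import logsumexp

    rng = np.random.default_rng(0)
    log_resp = rng.normal(size=(5, 3))           # stand-in for _estimate_weighted_log_prob(X)

    log_prob_norm = logsumexp(log_resp, axis=1)  # per-sample normalizer
    log_resp -= log_prob_norm[:, np.newaxis]     # normalize in place, no extra copy

    # Each row of the responsibilities now sums to 1.
    assert np.allclose(np.exp(log_resp).sum(axis=1), 1.0)
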
75 changes: 63 additions & 12 deletions sklearn/mixture/_gaussian_mixture.py
@@ -6,7 +6,9 @@
import numpy as np
from scipy import linalg

from .._config import get_config
from ..utils import check_array
from ..utils._chunking import gen_batches
from ..utils._param_validation import StrOptions
from ..utils.extmath import row_norms
from ._base import BaseMixture, _check_shape
@@ -173,8 +175,26 @@ def _estimate_gaussian_covariances_full(resp, X, nk, means, reg_covar):
n_components, n_features = means.shape
covariances = np.empty((n_components, n_features, n_features), dtype=X.dtype)
for k in range(n_components):
diff = X - means[k]
covariances[k] = np.dot(resp[:, k] * diff.T, diff) / nk[k]
if X.dtype == np.float32:
diff = X - means[k]
covariances[k] = np.dot(resp[:, k] * diff.T, diff) / nk[k]
else:
# Compute the covariance matrix of the k-th component by first forming
# the responsibilities-weighted Gram matrix using the uncentered data.
# Then, we subtract the outer product of the responsibilities-weighted
# mean. This avoids an explicit centering of the data, which would
# cause a significant waste of memory.
#
# XXX: We could further optimize memory usage and computation speed if
# we had access to a fused-implementation of the sandwich product
# kernel. A similar pattern occurs in the computation of the Hessian
# matrix in the "newton-cholesky" solver of the LogisticRegression
# class.
np.dot((resp[:, k] / nk[k]) * X.T, X, out=covariances[k])
covariances[k] -= np.outer(means[k], means[k])

# Apply covariance regularization on the diagonal of the covariance
# matrix:
covariances[k].flat[:: n_features + 1] += reg_covar
return covariances
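
The uncentered Gram-matrix form works because, with nk[k] = sum_i resp[i, k] and means[k] = sum_i resp[i, k] * X[i] / nk[k], the weighted covariance sum_i resp[i, k] * (X[i] - means[k]) (X[i] - means[k])^T / nk[k] equals (sum_i resp[i, k] * X[i] X[i]^T) / nk[k] - means[k] means[k]^T. A small numerical check of this identity (illustrative only, not part of the diff):

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.normal(size=(100, 4))
    resp_k = rng.random(100)                 # responsibilities of one component
    nk = resp_k.sum()
    mean_k = resp_k @ X / nk

    # Centered form: allocates a full (n_samples, n_features) temporary `diff`.
    diff = X - mean_k
    cov_centered = (resp_k * diff.T) @ diff / nk

    # Uncentered Gram-matrix form used in the diff: no centered copy of X.
    cov_gram = ((resp_k / nk) * X.T) @ X - np.outer(mean_k, mean_k)

    assert np.allclose(cov_centered, cov_gram)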

@@ -480,17 +500,38 @@ def _estimate_log_gaussian_prob(X, means, precisions_chol, covariance_type):
# In short: det(precision_chol) = - det(precision) / 2
log_det = _compute_log_det_cholesky(precisions_chol, covariance_type, n_features)

# Chunk the input X to avoid allocating overly large temporary arrays when
# n_samples is large, while keeping chunks large enough to benefit from
# BLAS-level parallelism. "working_memory" is expressed in MB; we need to
# convert it to bytes.
bytes_per_sample = max(X.dtype.itemsize * X.shape[1], 1)
batch_size = max(int(get_config()["working_memory"] * 1e6) // bytes_per_sample, 1)
float_dtype = precisions_chol.dtype
Review comment from @OmarManzoor (Contributor), Apr 9, 2025:

Note: For now, we need to take the dtype from precisions_chol because it is used below in the in-place computation of squared_diff, which requires the dtypes to match. This is needed because BayesianGaussianMixture does not currently support float32, so using X.dtype directly causes issues; the common tests also cover cases where X has an integer dtype.
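
To illustrate the point above (an illustrative sketch, not part of the diff): np.dot with an out= buffer requires the buffer's dtype to match the dtype of the product, which is why the buffer dtype has to come from precisions_chol rather than from X when X is integer-typed (or float32) while the model parameters are float64:

    import numpy as np

    X_int = np.arange(6).reshape(3, 2)              # e.g. integer-typed input data
    prec_chol = np.eye(2, dtype=np.float64)

    # int @ float64 produces float64, so the pre-allocated buffer must be float64;
    # deriving the dtype from prec_chol (as the diff does) guarantees this.
    buf = np.empty((3, 2), dtype=prec_chol.dtype)
    np.dot(X_int, prec_chol, out=buf)               # OK: dtypes are consistent

    # With buf = np.empty((3, 2), dtype=X_int.dtype), the same call raises a ValueError.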


# Pre-allocate a reusable buffer to store feature-wise squared diff.
if covariance_type in ["full", "tied"]:
squared_diff = np.empty((batch_size, n_features), dtype=float_dtype)

if covariance_type == "full":
log_prob = np.empty((n_samples, n_components), dtype=X.dtype)
for k, (mu, prec_chol) in enumerate(zip(means, precisions_chol)):
y = np.dot(X, prec_chol) - np.dot(mu, prec_chol)
log_prob[:, k] = np.sum(np.square(y), axis=1)
log_prob = np.empty((n_samples, n_components), dtype=float_dtype)
for k, (mean_k, prec_chol) in enumerate(zip(means, precisions_chol)):
mean_k_prec_chol = mean_k @ prec_chol
for batch_slice in gen_batches(X.shape[0], batch_size):
X_batch = X[batch_slice]
np.dot(X_batch, prec_chol, out=squared_diff[: len(X_batch)])
squared_diff[: len(X_batch)] -= mean_k_prec_chol
squared_diff[: len(X_batch)] **= 2
log_prob[batch_slice, k] = np.sum(squared_diff[: len(X_batch)], axis=1)

elif covariance_type == "tied":
log_prob = np.empty((n_samples, n_components), dtype=X.dtype)
for k, mu in enumerate(means):
y = np.dot(X, precisions_chol) - np.dot(mu, precisions_chol)
log_prob[:, k] = np.sum(np.square(y), axis=1)
log_prob = np.empty((n_samples, n_components), dtype=float_dtype)
for k, mean_k in enumerate(means):
mean_k_precisions_chol = mean_k @ precisions_chol
for batch_slice in gen_batches(X.shape[0], batch_size):
squared_diff = X[batch_slice] @ precisions_chol
squared_diff -= mean_k_precisions_chol
squared_diff **= 2
log_prob[batch_slice, k] = np.sum(squared_diff, axis=1)

elif covariance_type == "diag":
precisions = precisions_chol**2
@@ -509,7 +550,11 @@ def _estimate_log_gaussian_prob(X, means, precisions_chol, covariance_type):
)
# Since we are using the precision of the Cholesky decomposition,
# `- 0.5 * log_det_precision` becomes `+ log_det_precision_chol`
return -0.5 * (n_features * np.log(2 * np.pi).astype(X.dtype) + log_prob) + log_det
result = log_prob
result += n_features * np.log(2 * np.pi).astype(float_dtype)
result *= -0.5
result += log_det
return result
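
For reference, the chunk size is driven by scikit-learn's global working_memory option, which users can tune with sklearn.set_config or config_context. The same arithmetic in isolation, using the public sklearn.utils.gen_batches helper (an illustrative sketch, not part of the diff):

    import numpy as np
    from sklearn import config_context, get_config
    from sklearn.utils import gen_batches

    X = np.zeros((100_000, 16), dtype=np.float64)

    with config_context(working_memory=8):  # budget in MB, as in the diff
        bytes_per_sample = max(X.dtype.itemsize * X.shape[1], 1)
        batch_size = max(int(get_config()["working_memory"] * 1e6) // bytes_per_sample, 1)
        # Each chunk of X holds at most roughly working_memory MB.
        for batch_slice in gen_batches(X.shape[0], batch_size):
            X_batch = X[batch_slice]
            _ = X_batch.sum(axis=1)  # placeholder for the per-batch work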


class GaussianMixture(BaseMixture):
@@ -824,8 +869,14 @@ def _m_step(self, X, log_resp):
Logarithm of the posterior probabilities (or responsibilities) of
the point of each sample in X.
"""
# XXX: in-place mutation of the input argument. This is done to reduce
# the number of large memory allocations of temporary arrays. We know
# that log_resp is not used after this function, but it would probably
# be better to refactor the code to make that explicit in the caller by
# passing the result of the exponentiation to _m_step.
resp = np.exp(log_resp, out=log_resp)
self.weights_, self.means_, self.covariances_ = _estimate_gaussian_parameters(
X, np.exp(log_resp), self.reg_covar, self.covariance_type
X, resp, self.reg_covar, self.covariance_type
)
self.weights_ /= self.weights_.sum()
self.precisions_cholesky_ = _compute_precision_cholesky(
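
A standalone illustration of the in-place exponentiation used in _m_step (not part of the diff): np.exp(..., out=...) writes the result into the existing buffer instead of allocating a new (n_samples, n_components) array, at the cost of mutating the caller's log_resp:

    import numpy as np

    log_resp = np.log(np.array([[0.25, 0.75], [0.5, 0.5]]))

    resp = np.exp(log_resp, out=log_resp)  # reuses the buffer: resp and log_resp are the same array
    assert resp is log_resp
    assert np.allclose(resp.sum(axis=1), 1.0)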