Source code for ot.gmm

# -*- coding: utf-8 -*-
"""
Optimal transport for Gaussian Mixtures
"""

# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
#         Remi Flamary <remi.flamary@polytechnique.edu>
#         Julie Delon <julie.delon@math.cnrs.fr>
#
# License: MIT License

from .backend import get_backend
from .lp import emd2, emd
import numpy as np
from .utils import dist
from .gaussian import bures_wasserstein_mapping, bures_wasserstein_barycenter


[docs] def gaussian_logpdf(x, m, C): r""" Compute the log of the probability density function of a multivariate Gaussian distribution. Parameters ---------- x : array-like, shape (..., d) The input samples. m : array-like, shape (d,) The mean vector of the Gaussian distribution. C : array-like, shape (d, d) The covariance matrix of the Gaussian distribution. Returns ------- pdf : array-like, shape (...,) The probability density function evaluated at each sample. """ assert ( x.shape[-1] == m.shape[-1] == C.shape[-1] == C.shape[-2] ), "Dimension mismatch" nx = get_backend(x, m, C) d = m.shape[0] diff = x - m inv_C = nx.inv(C) z = nx.sum(diff * (diff @ inv_C), axis=-1) _, log_det_C = nx.slogdet(C) return -0.5 * (d * np.log(2 * np.pi) + log_det_C + z)
[docs] def gaussian_pdf(x, m, C): r""" Compute the probability density function of a multivariate Gaussian distribution. Parameters ---------- x : array-like, shape (..., d) The input samples. m : array-like, shape (d,) The mean vector of the Gaussian distribution. C : array-like, shape (d, d) The covariance matrix of the Gaussian distribution. Returns ------- pdf : array-like, shape (...,) The probability density function evaluated at each sample. """ return get_backend(x, m, C).exp(gaussian_logpdf(x, m, C))
[docs] def gmm_pdf(x, m, C, w): r""" Compute the probability density function (PDF) of a Gaussian Mixture Model (GMM) at given points. Parameters ---------- x : array-like, shape (..., d) The input samples. m : array-like, shape (n_components, d) The means of the Gaussian components. C : array-like, shape (n_components, d, d) The covariance matrices of the Gaussian components. w : array-like, shape (n_components,) The weights of the Gaussian components. Returns ------- out : array-like, shape (...,) The PDF values at the given points. """ assert ( m.shape[0] == C.shape[0] == w.shape[0] ), "All GMM parameters must have the same amount of components" nx = get_backend(x, m, C, w) out = nx.zeros((x.shape[:-1])) for k in range(m.shape[0]): out = out + w[k] * gaussian_pdf(x, m[k], C[k]) return out
[docs] def dist_bures_squared(m_s, m_t, C_s, C_t): r""" Compute the matrix of the squared Bures distances between the components of two Gaussian Mixture Models (GMMs). Used to compute the GMM Optimal Transport distance [69]. Parameters ---------- m_s : array-like, shape (k_s, d) Mean vectors of the source GMM. m_t : array-like, shape (k_t, d) Mean vectors of the target GMM. C_s : array-like, shape (k_s, d, d) Covariance matrices of the source GMM. C_t : array-like, shape (k_t, d, d) Covariance matrices of the target GMM. Returns ------- dist : array-like, shape (k_s, k_t) Matrix of squared Bures distances between the components of the source and target GMMs. References ---------- .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. """ nx = get_backend(m_s, C_s, m_t, C_t) assert m_s.shape[0] == C_s.shape[0], "Source GMM has different amount of components" assert m_t.shape[0] == C_t.shape[0], "Target GMM has different amount of components" assert ( m_s.shape[-1] == m_t.shape[-1] == C_s.shape[-1] == C_t.shape[-1] ), "All GMMs must have the same dimension" D_means = dist(m_s, m_t, metric="sqeuclidean") # C2[i, j] = Cs12[i] @ C_t[j] @ Cs12[i], shape (k_s, k_t, d, d) Cs12 = nx.sqrtm(C_s) # broadcasts matrix sqrt over (k_s,) C2 = nx.einsum("ikl,jlm,imn->ijkn", Cs12, C_t, Cs12) C = nx.sqrtm(C2) # broadcasts matrix sqrt over (k_s, k_t) # D_covs[i,j] = trace(C_s[i] + C_t[j] - 2C[i,j]) trace_C_s = nx.einsum("ikk->i", C_s)[:, None] # (k_s, 1) trace_C_t = nx.einsum("ikk->i", C_t)[None, :] # (1, k_t) D_covs = trace_C_s + trace_C_t # broadcasts to (k_s, k_t) D_covs -= 2 * nx.einsum("ijkk->ij", C) return nx.maximum(D_means + D_covs, 0)
[docs] def gmm_ot_loss(m_s, m_t, C_s, C_t, w_s, w_t, log=False): r""" Compute the Gaussian Mixture Model (GMM) Optimal Transport distance between two GMMs introduced in [69]. Parameters ---------- m_s : array-like, shape (k_s, d) Mean vectors of the source GMM. m_t : array-like, shape (k_t, d) Mean vectors of the target GMM. C_s : array-like, shape (k_s, d, d) Covariance matrices of the source GMM. C_t : array-like, shape (k_t, d, d) Covariance matrices of the target GMM. w_s : array-like, shape (k_s,) Weights of the source GMM components. w_t : array-like, shape (k_t,) Weights of the target GMM components. log: bool, optional (default=False) If True, returns a dictionary containing the cost and dual variables. Otherwise returns only the GMM optimal transportation cost. Returns ------- loss : float or array-like The GMM-OT loss. log : dict, optional If input log is true, a dictionary containing the cost and dual variables and exit status References ---------- .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. """ get_backend(m_s, C_s, w_s, m_t, C_t, w_t) assert m_s.shape[0] == w_s.shape[0], "Source GMM has different amount of components" assert m_t.shape[0] == w_t.shape[0], "Target GMM has different amount of components" D = dist_bures_squared(m_s, m_t, C_s, C_t) return emd2(w_s, w_t, D, log=log)
[docs] def gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t, log=False): r""" Compute the Gaussian Mixture Model (GMM) Optimal Transport plan between two GMMs introduced in [69]. Parameters ---------- m_s : array-like, shape (k_s, d) Mean vectors of the source GMM. m_t : array-like, shape (k_t, d) Mean vectors of the target GMM. C_s : array-like, shape (k_s, d, d) Covariance matrices of the source GMM. C_t : array-like, shape (k_t, d, d) Covariance matrices of the target GMM. w_s : array-like, shape (k_s,) Weights of the source GMM components. w_t : array-like, shape (k_t,) Weights of the target GMM components. log : bool, optional (default=False) If True, returns a dictionary containing the cost and dual variables. Otherwise returns only the GMM optimal transportation matrix. Returns ------- plan : array-like, shape (k_s, k_t) The GMM-OT plan. log : dict, optional If input log is true, a dictionary containing the cost and dual variables and exit status References ---------- .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. """ get_backend(m_s, C_s, w_s, m_t, C_t, w_t) assert m_s.shape[0] == w_s.shape[0], "Source GMM has different amount of components" assert m_t.shape[0] == w_t.shape[0], "Target GMM has different amount of components" D = dist_bures_squared(m_s, m_t, C_s, C_t) return emd(w_s, w_t, D, log=log)
[docs] def gmm_ot_apply_map( x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, method="bary", seed=None ): r""" Apply Gaussian Mixture Model (GMM) optimal transport (OT) mapping to input data. The 'barycentric' mapping corresponds to the barycentric projection of the GMM-OT plan, and is called T_bary in [69]. The 'random' mapping takes for each input point a random pair (i,j) of components of the GMMs and applied the Gaussian map, it is called T_rand in [69]. Parameters ---------- x : array-like, shape (n_samples, d) Input data points. m_s : array-like, shape (k_s, d) Mean vectors of the source GMM components. m_t : array-like, shape (k_t, d) Mean vectors of the target GMM components. C_s : array-like, shape (k_s, d, d) Covariance matrices of the source GMM components. C_t : array-like, shape (k_t, d, d) Covariance matrices of the target GMM components. w_s : array-like, shape (k_s,) Weights of the source GMM components. w_t : array-like, shape (k_t,) Weights of the target GMM components. plan : array-like, shape (k_s, k_t), optional Optimal transport plan between the source and target GMM components. If not provided, it will be computed internally. method : {'bary', 'rand'}, optional Method for applying the GMM OT mapping. 'bary' uses barycentric mapping, while 'rand' uses random sampling. Default is 'bary'. seed : int, optional Seed for the random number generator. Only used when method='rand'. Returns ------- out : array-like, shape (n_samples, d) Output data points after applying the GMM OT mapping. References ---------- .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. """ if plan is None: plan = gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t) nx = get_backend(x, m_s, m_t, C_s, C_t, w_s, w_t) else: nx = get_backend(x, m_s, m_t, C_s, C_t, w_s, w_t, plan) k_s, k_t = m_s.shape[0], m_t.shape[0] d = m_s.shape[1] n_samples = x.shape[0] if method == "bary": out = nx.zeros(x.shape) logpdf = nx.stack( [gaussian_logpdf(x, m_s[k], C_s[k])[:, None] for k in range(k_s)] ) # only need to compute for non-zero plan entries for i, j in zip(*nx.where(plan > 0)): Cs12 = nx.sqrtm(C_s[i]) Cs12inv = nx.inv(Cs12) M0 = nx.sqrtm(Cs12 @ C_t[j] @ Cs12) A = Cs12inv @ M0 @ Cs12inv b = m_t[j] - A @ m_s[i] # gaussian mapping between components i and j applied to x T_ij_x = x @ A + b z = w_s[:, None, None] * nx.exp(logpdf - logpdf[i][None, :, :]) denom = nx.sum(z, axis=0) out = out + plan[i, j] * T_ij_x / denom return out else: # rand # A[i, j] is the linear part of the gaussian mapping between components # i and j, b[i, j] is the translation part rng = np.random.RandomState(seed) A = nx.zeros((k_s, k_t, d, d)) b = nx.zeros((k_s, k_t, d)) # only need to compute for non-zero plan entries for i, j in zip(*nx.where(plan > 0)): Cs12 = nx.sqrtm(C_s[i]) Cs12inv = nx.inv(Cs12) M0 = nx.sqrtm(Cs12 @ C_t[j] @ Cs12) A[i, j] = Cs12inv @ M0 @ Cs12inv b[i, j] = m_t[j] - A[i, j] @ m_s[i] logpdf = nx.stack( [gaussian_logpdf(x, m_s[k], C_s[k]) for k in range(k_s)], axis=-1 ) # (n_samples, k_s) out = nx.zeros(x.shape) for i_sample in range(n_samples): log_g = logpdf[i_sample] log_diff = log_g[:, None] - log_g[None, :] weighted_exp = w_s[:, None] * nx.exp(log_diff) denom = nx.sum(weighted_exp, axis=0)[:, None] * nx.ones(plan.shape[1]) p_mat = plan / denom p = p_mat.reshape(k_s * k_t) # stack line-by-line # sample between 0 and k_s * k_t - 1 ij_mat = rng.choice(k_s * k_t, p=p) i = ij_mat // k_t j = ij_mat % k_t out[i_sample] = A[i, j] @ x[i_sample] + b[i, j] return out
[docs] def gmm_ot_plan_density(x, y, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, atol=1e-2): """ Compute the density of the Gaussian Mixture Model - Optimal Transport coupling between GMMS at given points, as introduced in [69]. Given two arrays of points x and y, the function computes the density at each point `(x[i], y[i])` of the product space. Parameters ---------- x : array-like, shape (n, d) Entry points in source space for density computation. y : array-like, shape (m, d) Entry points in target space for density computation. m_s : array-like, shape (k_s, d) The means of the source GMM components. m_t : array-like, shape (k_t, d) The means of the target GMM components. C_s : array-like, shape (k_s, d, d) The covariance matrices of the source GMM components. C_t : array-like, shape (k_t, d, d) The covariance matrices of the target GMM components. w_s : array-like, shape (k_s,) The weights of the source GMM components. w_t : array-like, shape (k_t,) The weights of the target GMM components. plan : array-like, shape (k_s, k_t), optional The optimal transport plan between the source and target GMMs. If not provided, it will be computed using `gmm_ot_plan`. atol : float, optional The absolute tolerance used to determine the support of the GMM-OT coupling. Returns ------- density : array-like, shape (n, m) The density of the GMM-OT coupling between the two GMMs. References ---------- .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. """ assert ( x.shape[-1] == y.shape[-1] ), "x (n, d) and y (m, d) must have the same dimension d" n, m = x.shape[0], y.shape[0] nx = get_backend(x, y, m_s, m_t, C_s, C_t, w_s, w_t) # hand-made d-variate meshgrid in ij indexing xx = x[:, None, :] * nx.ones((1, m, 1)) # shapes (n, m, d) yy = y[None, :, :] * nx.ones((n, 1, 1)) # shapes (n, m, d) if plan is None: plan = gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t) def Tk0k1(k0, k1): A, b = bures_wasserstein_mapping(m_s[k0], m_t[k1], C_s[k0], C_t[k1]) Tx = xx @ A + b g = gaussian_pdf(xx, m_s[k0], C_s[k0]) out = plan[k0, k1] * g norms = nx.norm(Tx - yy, axis=-1) out = out * ((norms < atol) * 1.0) return out mat = nx.stack( [ nx.stack([Tk0k1(k0, k1) for k1 in range(m_t.shape[0])]) for k0 in range(m_s.shape[0]) ] ) return nx.sum(mat, axis=(0, 1))
[docs] def gmm_barycenter_fixed_point( means_list, covs_list, w_list, means_init, covs_init, weights, w_bar=None, iterations=100, log=False, barycentric_proj_method="euclidean", ): r""" Solves the Gaussian Mixture Model OT barycenter problem (defined in [69]) using the fixed point algorithm (proposed in [77]). The weights of the barycenter are not optimized, and stay the same as the input `w_list` or are initialized to uniform. The algorithm uses barycentric projections of GMM-OT plans, and these can be computed either through Bures Barycenters (slow but accurate, barycentric_proj_method='bures') or by convex combination (fast, barycentric_proj_method='euclidean', default). This is a special case of the generic free-support barycenter solver `ot.lp.free_support_barycenter_generic_costs`. Parameters ---------- means_list : list of array-like List of K (m_k, d) GMM means. covs_list : list of array-like List of K (m_k, d, d) GMM covariances. w_list : list of array-like List of K (m_k) arrays of weights. means_init : array-like Initial (n, d) GMM means. covs_init : array-like Initial (n, d, d) GMM covariances. weights : array-like Array (K,) of the barycentre coefficients. w_bar : array-like, optional Initial weights (n) of the barycentre GMM. If None, initialized to uniform. iterations : int, optional Number of iterations (default is 100). log : bool, optional Whether to return the list of iterations (default is False). barycentric_proj_method : str, optional Method to project the barycentre weights: 'euclidean' (default) or 'bures'. Returns ------- means : array-like (n, d) barycentre GMM means. covs : array-like (n, d, d) barycentre GMM covariances. log_dict : dict, optional Dictionary containing the list of iterations if log is True. References ---------- .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. .. [77] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaƫl (2024). Computing barycenters of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024) See Also -------- ot.lp.free_support_barycenter_generic_costs : Compute barycenter of measures for generic transport costs. """ nx = get_backend( means_init, covs_init, means_list[0], covs_list[0], w_list[0], weights ) K = len(means_list) n = means_init.shape[0] d = means_init.shape[1] means_its = [nx.copy(means_init)] covs_its = [nx.copy(covs_init)] means, covs = means_init, covs_init if w_bar is None: w_bar = nx.ones(n, type_as=means) / n for _ in range(iterations): pi_list = [ gmm_ot_plan(means, means_list[k], covs, covs_list[k], w_bar, w_list[k]) for k in range(K) ] # filled in the euclidean case means_selection, covs_selection = None, None # in the euclidean case, the selection of Gaussians from each K sources # comes from a barycentric projection: it is a convex combination of the # selected means and covariances, which can be computed without a # for loop on i = 0, ..., n -1 if barycentric_proj_method == "euclidean": means_selection = nx.zeros((n, K, d), type_as=means) covs_selection = nx.zeros((n, K, d, d), type_as=means) for k in range(K): means_selection[:, k, :] = n * pi_list[k] @ means_list[k] covs_selection[:, k, :, :] = ( nx.einsum("ij,jab->iab", pi_list[k], covs_list[k]) * n ) # each component i of the barycentre will be a Bures barycentre of the # selected components of the K GMMs. In the 'bures' barycentric # projection option, the selected components are also Bures barycentres. for i in range(n): # means_selection_i (K, d) is the selected means, each comes from a # Gaussian barycentre along the disintegration of pi_k at i # covs_selection_i (K, d, d) are the selected covariances means_selection_i = None covs_selection_i = None # use previous computation (convex combination) if barycentric_proj_method == "euclidean": means_selection_i = means_selection[i] covs_selection_i = covs_selection[i] # compute Bures barycentre of certain components to get the # selection at i elif barycentric_proj_method == "bures": means_selection_i = nx.zeros((K, d), type_as=means) covs_selection_i = nx.zeros((K, d, d), type_as=means) for k in range(K): w = (1 / w_bar[i]) * pi_list[k][i, :] m, C = bures_wasserstein_barycenter(means_list[k], covs_list[k], w) means_selection_i[k] = m covs_selection_i[k] = C else: raise ValueError("Unknown barycentric_proj_method") means[i], covs[i] = bures_wasserstein_barycenter( means_selection_i, covs_selection_i, weights ) if log: means_its.append(nx.copy(means)) covs_its.append(nx.copy(covs)) if log: return means, covs, {"means_its": means_its, "covs_its": covs_its} return means, covs