feat: MULTIVI Mudata Minification #3039

Open
wants to merge 96 commits into
base: main
Changes from 1 commit
96 commits
61bbedb
wip
martinkim0 Jun 26, 2024
7929c44
Merge branch 'main' into minified-mode
martinkim0 Jul 2, 2024
1bb2f2f
wip
martinkim0 Jul 3, 2024
245c710
keep empty layers for registry
martinkim0 Jul 3, 2024
b02040d
Merge branch 'main' into minified-mode
canergen Jul 26, 2024
5fea40e
Merge branch 'main' into minified-mode
ori-kron-wis Sep 15, 2024
b02405e
Merge branch 'main' into minified-mode
ori-kron-wis Nov 6, 2024
8d9d012
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 6, 2024
243c206
updated conflics during merge
ori-kron-wis Nov 6, 2024
23c60a5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 6, 2024
d0ad5f5
Added mudata support for MULTIVI as well as tests
ori-kron-wis Nov 6, 2024
9450546
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 6, 2024
079faff
needed muon
ori-kron-wis Nov 6, 2024
5e8bc5f
Merge remote-tracking branch 'origin/Ori-MultiVI-MuData' into Ori-Mul…
ori-kron-wis Nov 6, 2024
b420037
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 6, 2024
815555c
Added ATC/PROTEIN + RNA capability for MultiVI + more tests like in t…
ori-kron-wis Nov 7, 2024
e03b006
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2024
ea59fd1
small fix
ori-kron-wis Nov 7, 2024
278353f
Merge remote-tracking branch 'origin/Ori-MultiVI-MuData' into Ori-Mul…
ori-kron-wis Nov 7, 2024
9118af8
small fix
ori-kron-wis Nov 7, 2024
870b0fc
small fix
ori-kron-wis Nov 7, 2024
4f08c1b
small fix
ori-kron-wis Nov 7, 2024
9622f88
small fix
ori-kron-wis Nov 7, 2024
3299355
Added mudata minification models for MULTIVI & TOTALVI as well as tests
ori-kron-wis Nov 12, 2024
e9b72d9
fix typos
ori-kron-wis Nov 12, 2024
987ebd3
Merge remote-tracking branch 'origin/main' into Ori-MultiVI-MuData
ori-kron-wis Nov 13, 2024
371ef7a
fixed comments
ori-kron-wis Nov 13, 2024
8f969a9
Merge remote-tracking branch 'origin/main' into Ori-MultiVI-TotalVi-M…
ori-kron-wis Nov 13, 2024
3cf0108
fixed comments
ori-kron-wis Nov 13, 2024
ff82b28
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2024
33cbf84
merge with base branch
ori-kron-wis Nov 13, 2024
1a00720
Merge remote-tracking branch 'origin/Ori-MultiVI-MuData' into Ori-Mul…
ori-kron-wis Nov 13, 2024
8326895
Merge remote-tracking branch 'origin/main' into Ori-MultiVI-MuData
ori-kron-wis Nov 13, 2024
591a9ee
Merge remote-tracking branch 'origin/main' into Ori-MultiVI-TotalVi-M…
ori-kron-wis Nov 13, 2024
b00d439
Merge remote-tracking branch 'origin/main' into Ori-MultiVI-TotalVi-M…
ori-kron-wis Nov 14, 2024
e670511
Merge remote-tracking branch 'origin/main' into Ori-MultiVI-MuData
ori-kron-wis Nov 14, 2024
b6d2028
fixed typos
ori-kron-wis Nov 14, 2024
211d344
Merge remote-tracking branch 'origin/Ori-MultiVI-MuData' into Ori-Mul…
ori-kron-wis Nov 14, 2024
a903359
fix comments
ori-kron-wis Nov 18, 2024
6c08f1c
Merge remote-tracking branch 'origin/Ori-MultiVI-MuData' into Ori-Mul…
ori-kron-wis Nov 18, 2024
8aecdf8
fix some tests
ori-kron-wis Nov 18, 2024
c957c94
Merge remote-tracking branch 'origin/main' into Ori-MultiVI-MuData
ori-kron-wis Nov 19, 2024
7b5f22f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2024
43b360b
Merge remote-tracking branch 'origin/Ori-MultiVI-MuData' into Ori-Mul…
ori-kron-wis Nov 19, 2024
5709249
Merge branch 'main' into minified-mode
ori-kron-wis Nov 19, 2024
aab2e09
Merge remote-tracking branch 'origin/main' into minified-mode
ori-kron-wis Nov 19, 2024
fe12fac
Merge remote-tracking branch 'origin/main' into Ori-MultiVI-MuData
ori-kron-wis Nov 19, 2024
44e736f
Merge remote-tracking branch 'origin/main' into Ori-MultiVI-TotalVi-M…
ori-kron-wis Nov 19, 2024
951035b
Merge branch 'main' into minified-mode
ori-kron-wis Nov 19, 2024
af8ef83
Merge remote-tracking branch 'origin/Ori-MultiVI-MuData' into Ori-Mul…
ori-kron-wis Nov 19, 2024
b770de3
Merge remote-tracking branch 'origin/minified-mode' into minified-mode
ori-kron-wis Nov 19, 2024
2a99751
added atac registry field
ori-kron-wis Nov 19, 2024
51077d0
Merge remote-tracking branch 'origin/Ori-MultiVI-MuData' into Ori-Mul…
ori-kron-wis Nov 19, 2024
3906a6a
merged with the minification refactor branch
ori-kron-wis Nov 19, 2024
a3de92e
Refactored minified models
canergen Nov 19, 2024
bda6b30
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2024
c5678e1
added adata minification for multi/total vi and fixed several tests
ori-kron-wis Nov 19, 2024
108175d
Fixed loss computation for keep count models
canergen Nov 19, 2024
9706947
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2024
a0ed985
Increase tolerance
canergen Nov 20, 2024
1adc452
Merge branch 'minified-mode' of https://github.com/scverse/scvi-tools…
canergen Nov 20, 2024
0bd790a
Typo
canergen Nov 20, 2024
c971e91
Changed keep adata when keep_counts
canergen Nov 20, 2024
ded463b
Merge remote-tracking branch 'origin/minified-mode' into Ori-MultiVI-…
ori-kron-wis Nov 20, 2024
24177da
Fixed multiVI mudata
canergen Nov 20, 2024
c7b0a3b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2024
101641a
following can's fixes
ori-kron-wis Nov 20, 2024
3b154de
Merge remote-tracking branch 'origin/Ori-MultiVI-TotalVi-Minification…
ori-kron-wis Nov 20, 2024
b32a92f
update branch
ori-kron-wis Nov 20, 2024
effc2f6
Merge remote-tracking branch 'origin/Ori-MultiVI-MuData' into Ori-Mul…
ori-kron-wis Nov 20, 2024
e369eb1
Merge remote-tracking branch 'origin/main' into Ori-MultiVI-MuData
ori-kron-wis Nov 20, 2024
488ed67
Merge remote-tracking branch 'origin/main' into Ori-MultiVI-TotalVi-M…
ori-kron-wis Nov 20, 2024
6629923
Merge remote-tracking branch 'origin/main' into minified-mode
ori-kron-wis Nov 20, 2024
a0cd0bd
fix get_accessibility was using gene indices, should have used region…
ori-kron-wis Nov 20, 2024
711ec54
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2024
93dfda4
Merge remote-tracking branch 'origin/minified-mode' into Ori-MultiVI-…
ori-kron-wis Nov 20, 2024
4a9dca9
Merge remote-tracking branch 'origin/Ori-MultiVI-MuData' into Ori-Mul…
ori-kron-wis Nov 20, 2024
afaf1b3
some fixes
ori-kron-wis Nov 20, 2024
02a8b8f
update tests
ori-kron-wis Nov 20, 2024
69c06e4
fix tests, save/load
ori-kron-wis Nov 20, 2024
b5eb0a6
more fixes.
ori-kron-wis Nov 20, 2024
cd9c5b2
Refactor code and add tests
canergen Nov 21, 2024
165633c
updated multivi tutorials with mudata and minification
ori-kron-wis Nov 21, 2024
8d91791
Merge branch 'Ori-MultiVI-TotalVi-Minification-MuData' of https://git…
canergen Nov 21, 2024
0143471
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 21, 2024
237ff13
Fixed multiVI
canergen Nov 22, 2024
fa7b457
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 22, 2024
79c1b96
Merge branch 'main' into Ori-MultiVI-TotalVi-Minification-MuData
ori-kron-wis Nov 27, 2024
a13f29d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 27, 2024
7587475
merged with main
ori-kron-wis Dec 30, 2024
26808d4
Merge remote-tracking branch 'origin/main' into Ori-MultiVI-TotalVi-M…
ori-kron-wis Dec 31, 2024
1a37d8b
Merge remote-tracking branch 'origin/main' into Ori-MultiVI-TotalVi-M…
ori-kron-wis Dec 31, 2024
74ba244
updates
ori-kron-wis Dec 31, 2024
7ada062
updates
ori-kron-wis Dec 31, 2024
35a2d02
Merge remote-tracking branch 'origin/main' into Ori-MultiVI-TotalVi-M…
ori-kron-wis Jan 8, 2025
37d7996
update with main
ori-kron-wis Jan 8, 2025
Added mudata minification models for MULTIVI & TOTALVI as well as tests
ori-kron-wis committed Nov 12, 2024
commit 3299355e7a9884ad170d20f77bf7ce65bef7bf93
1 change: 1 addition & 0 deletions CHANGELOG.md
Collaborator Author

will need to update the changelog

@@ -10,6 +10,7 @@ to [Semantic Versioning]. Full commit history is available in the

#### Added

- Add MuData Minification option to {class}`~scvi.model.MULTIVI` and {class}`~scvi.model.TOTALVI` {pr}`30XX`.
- Experimental MuData support for {class}`~scvi.model.MULTIVI` via the method
{meth}`~scvi.model.MULTIVI.setup_mudata` {pr}`3038`.

4 changes: 3 additions & 1 deletion src/scvi/data/_utils.py
@@ -311,10 +311,12 @@ def _get_adata_minify_type(adata: AnnData) -> MinifiedDataType | None:
return adata.uns.get(_constants._ADATA_MINIFY_TYPE_UNS_KEY, None)


def _is_minified(adata: AnnData | str) -> bool:
def _is_minified(adata: AnnOrMuData | str) -> bool:
uns_key = _constants._ADATA_MINIFY_TYPE_UNS_KEY
if isinstance(adata, AnnData):
return adata.uns.get(uns_key, None) is not None
elif isinstance(adata, MuData):
return adata.uns.get(uns_key, None) is not None
elif isinstance(adata, str):
with h5py.File(adata) as fp:
return uns_key in read_elem(fp["uns"]).keys()
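The updated check treats AnnData and MuData the same way: both expose an `uns` mapping, and the object counts as minified when the minify-type key maps to a non-`None` value. A minimal sketch of that dispatch follows. It uses hypothetical stand-in classes (`FakeAnnData`, `FakeMuData`) and a placeholder key instead of the real anndata/mudata objects and the real `_ADATA_MINIFY_TYPE_UNS_KEY` value, and it leaves out the file-path branch.

```python
from dataclasses import dataclass, field

# Placeholder for _constants._ADATA_MINIFY_TYPE_UNS_KEY (assumption: real value not shown here).
UNS_KEY = "_minify_type_placeholder"


@dataclass
class FakeAnnData:
    """Hypothetical stand-in for anndata.AnnData; only `uns` is modeled."""

    uns: dict = field(default_factory=dict)


@dataclass
class FakeMuData:
    """Hypothetical stand-in for mudata.MuData; only `uns` is modeled."""

    uns: dict = field(default_factory=dict)


def is_minified(data) -> bool:
    """Return True when the container stores a non-None minify type in `uns`."""
    # Both container types need the same lookup, so a tuple isinstance check covers them.
    if isinstance(data, (FakeAnnData, FakeMuData)):
        return data.uns.get(UNS_KEY) is not None
    raise TypeError(f"Unsupported type: {type(data).__name__}")
```

Since the two in-memory branches in the hunk are identical, merging them into one `isinstance` check with a tuple, as in the sketch, would be one way to avoid the duplication.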
121 changes: 117 additions & 4 deletions src/scvi/model/_multivi.py
@@ -13,13 +13,17 @@

from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager, fields
from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE
from scvi.data._utils import _get_adata_minify_type
from scvi.data.fields import (
CategoricalJointObsField,
CategoricalObsField,
LayerField,
NumericalJointObsField,
NumericalObsField,
ObsmField,
ProteinObsmField,
StringUnsField,
)
from scvi.model._utils import (
_get_batch_code_from_category,
@@ -28,11 +32,12 @@
)
from scvi.model.base import (
ArchesMixin,
BaseModelClass,
BaseMudataMinifiedModeModelClass,
UnsupervisedTrainingMixin,
VAEMixin,
)
from scvi.model.base._de_core import _de_core
from scvi.model.utils import get_minified_mudata
from scvi.module import MULTIVAE
from scvi.train import AdversarialTrainingPlan
from scvi.train._callbacks import SaveBestState
@@ -45,12 +50,19 @@
from anndata import AnnData
from mudata import MuData

from scvi._types import AnnOrMuData, Number
from scvi._types import AnnOrMuData, MinifiedDataType, Number
from scvi.data.fields import (
BaseAnnDataField,
)

_MULTIVI_LATENT_QZM = "_multivi_latent_qzm"
_MULTIVI_LATENT_QZV = "_multivi_latent_qzv"
_MULTIVI_OBSERVED_LIB_SIZE = "_multivi_observed_lib_size"

logger = logging.getLogger(__name__)


class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin):
class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, ArchesMixin, BaseMudataMinifiedModeModelClass):
"""Integration of multi-modal and single-modality data :cite:p:`AshuachGabitto21`.

MultiVI is used to integrate multiomic datasets with single-modality (expression
@@ -174,6 +186,10 @@ def __init__(

use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry

# TODO: ADD MINIFICATION CONSIDERATION HERE?
# if not use_size_factor_key and self.minified_data_type is None:
# library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch)

if "n_proteins" in self.summary_stats:
n_proteins = self.summary_stats.n_proteins
else:
@@ -224,6 +240,7 @@ def __init__(
self.n_genes = n_genes
self.n_regions = n_regions
self.n_proteins = n_proteins
self.module.minified_data_type = self.minified_data_type

@devices_dsp.dedent
def train(
@@ -414,6 +431,7 @@ def get_latent_representation(
indices: Sequence[int] | None = None,
give_mean: bool = True,
batch_size: int | None = None,
return_dist: bool = False,
) -> np.ndarray:
r"""Return the latent representation for each cell.

@@ -430,6 +448,9 @@ def get_latent_representation(
Give mean of distribution or sample from it.
batch_size
Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
return_dist
If ``True``, returns a tuple with the mean and variance of the latent distribution.
Otherwise, returns the mean of the latent distribution or a sample from it, depending
on ``give_mean``.

Returns
-------
@@ -457,6 +478,8 @@ def get_latent_representation(
adata = self._validate_anndata(adata)
scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
latent = []
qz_means = []
qz_vars = []
for tensors in scdl:
inference_inputs = self.module._get_inference_input(tensors)
outputs = self.module.inference(**inference_inputs)
@@ -473,8 +496,17 @@ def get_latent_representation(
else:
z = qz_m

if return_dist:
qz_means.append(qz_m.cpu())
qz_vars.append(qz_v.cpu())
continue

latent += [z.cpu()]
return torch.cat(latent).numpy()

if return_dist:
return torch.cat(qz_means).numpy(), torch.cat(qz_vars).numpy()
else:
return torch.cat(latent).numpy()

@torch.inference_mode()
def get_accessibility_estimates(
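The `return_dist` branch added to `get_latent_representation` changes what is accumulated per minibatch: posterior means and variances are collected instead of `z`, and the method returns a pair instead of a single array. A minimal sketch of that control flow, using plain Python lists in place of tensors; `collect_latent` and the `(qz_m, qz_v, z)` batch layout are hypothetical:

```python
def collect_latent(batches, return_dist=False):
    """Accumulate per-minibatch outputs the way the return_dist branch does.

    `batches` is a hypothetical iterable of (qz_m, qz_v, z) tuples, each holding
    one list entry per cell; the real method works on torch tensors.
    """
    latent, qz_means, qz_vars = [], [], []
    for qz_m, qz_v, z in batches:
        if return_dist:
            # Keep the distribution parameters and skip the latent sample entirely.
            qz_means.extend(qz_m)
            qz_vars.extend(qz_v)
            continue
        latent.extend(z)
    if return_dist:
        return qz_means, qz_vars
    return latent
```

Note that the signature in the hunk is still annotated `-> np.ndarray`, while the `return_dist=True` path returns a tuple of two arrays.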
@@ -1113,6 +1145,87 @@ def setup_mudata(
mod_required=True,
)
)
# TODO: register new fields if the adata is minified
mdata_minify_type = _get_adata_minify_type(mdata)
if mdata_minify_type is not None:
mudata_fields += cls._get_fields_for_mudata_minification(mdata_minify_type)
adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(mdata, **kwargs)
cls.register_manager(adata_manager)

@staticmethod
def _get_fields_for_mudata_minification(
minified_data_type: MinifiedDataType,
) -> list[BaseAnnDataField]:
"""Return the fields required for mudata minification of the given minified_data_type."""
if minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
fields = [
ObsmField(
REGISTRY_KEYS.LATENT_QZM_KEY,
_MULTIVI_LATENT_QZM,
),
ObsmField(
REGISTRY_KEYS.LATENT_QZV_KEY,
_MULTIVI_LATENT_QZV,
),
NumericalObsField(
REGISTRY_KEYS.OBSERVED_LIB_SIZE,
_MULTIVI_OBSERVED_LIB_SIZE,
),
]
else:
raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")
fields.append(
StringUnsField(
REGISTRY_KEYS.MINIFY_TYPE_KEY,
_ADATA_MINIFY_TYPE_UNS_KEY,
),
)
return fields
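The helper above follows a simple pattern: map a minification type to the extra fields that must be registered, and fail loudly for any type that is not supported. A minimal sketch of that pattern, with a hypothetical `FieldSpec` tuple standing in for the scvi field classes; the registry-key strings and the `uns` storage key are placeholders, not the real `REGISTRY_KEYS` values:

```python
from typing import NamedTuple

# Value named in the minify_mudata docstring.
LATENT_POSTERIOR = "latent_posterior_parameters"


class FieldSpec(NamedTuple):
    """Hypothetical stand-in for a scvi field: where it lives and under which keys."""

    kind: str          # "obsm", "obs" or "uns"
    registry_key: str  # placeholder for a REGISTRY_KEYS entry
    attr_key: str      # key inside the data container


def fields_for_minification(minified_data_type: str) -> list:
    """Return the extra fields needed for the given minification type."""
    if minified_data_type != LATENT_POSTERIOR:
        raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")
    return [
        FieldSpec("obsm", "latent_qzm", "_multivi_latent_qzm"),
        FieldSpec("obsm", "latent_qzv", "_multivi_latent_qzv"),
        FieldSpec("obs", "observed_lib_size", "_multivi_observed_lib_size"),
        # The minify type itself is always registered, as in the hunk above.
        FieldSpec("uns", "minify_type", "_minify_type_placeholder"),
    ]
```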

def minify_mudata(
self,
minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR,
use_latent_qzm_key: str = "X_latent_qzm",
use_latent_qzv_key: str = "X_latent_qzv",
) -> None:
"""Minifies the model's mudata.

Minifies the mudata, and registers new mudata fields: latent qzm, latent qzv, adata uns
containing minified-adata type, and library size.
This also sets the appropriate property on the module to indicate that the mudata is
minified.

Parameters
----------
minified_data_type
How to minify the data. Currently only supports `latent_posterior_parameters`.
If minified_data_type == `latent_posterior_parameters`:

* the original count data is removed (`adata.X`, `adata.raw`, and any layers)
* the parameters of the latent representation of the original data are stored
* everything else is left untouched
use_latent_qzm_key
Key to use in `adata.obsm` where the latent qzm params are stored
use_latent_qzv_key
Key to use in `adata.obsm` where the latent qzv params are stored

Notes
-----
The modification is not done in place; instead, the model is assigned a new (minified)
version of the mudata.
"""
# without removing the original counts.
if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")

# if self.module.use_observed_lib_size is False:
# raise ValueError("Cannot minify the data if `use_observed_lib_size` is False")

minified_adata = get_minified_mudata(self.adata, minified_data_type)
minified_adata.obsm[_MULTIVI_LATENT_QZM] = self.adata.obsm[use_latent_qzm_key]
minified_adata.obsm[_MULTIVI_LATENT_QZV] = self.adata.obsm[use_latent_qzv_key]
counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
minified_adata.obs[_MULTIVI_OBSERVED_LIB_SIZE] = np.squeeze(np.asarray(counts.sum(axis=1)))
self._update_mudata_and_manager_post_minification(minified_adata, minified_data_type)
self.module.minified_data_type = minified_data_type
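Conceptually, `minify_mudata` swaps the count matrix for three per-cell summaries: the posterior mean, the posterior variance, and the observed library size (the row sums of the counts). A minimal sketch of that transformation on plain Python lists; the function `minify` and the output dictionary layout are hypothetical, and the `uns` key is a placeholder:

```python
def minify(counts, qzm, qzv):
    """Replace per-cell counts with latent posterior parameters and library sizes.

    counts: list of per-cell count rows; qzm / qzv: per-cell posterior mean / variance.
    The returned dict only imitates the obsm / obs / uns layout of a real container.
    """
    if not (len(counts) == len(qzm) == len(qzv)):
        raise ValueError("counts, qzm and qzv must describe the same cells")
    return {
        "obsm": {"_multivi_latent_qzm": qzm, "_multivi_latent_qzv": qzv},
        # Observed library size is the total count per cell (sum over features).
        "obs": {"_multivi_observed_lib_size": [sum(row) for row in counts]},
        # Placeholder uns key (assumption); the counts themselves are not copied over.
        "uns": {"_minify_type_placeholder": "latent_posterior_parameters"},
    }
```

For example, `minify([[1, 2, 3], [0, 4, 0]], [[0.1], [0.2]], [[1.0], [1.0]])` stores library sizes `[6, 4]` and no count matrix.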
103 changes: 98 additions & 5 deletions src/scvi/model/_totalvi.py
@@ -12,7 +12,9 @@

from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager, fields
from scvi.data._utils import _check_nonnegative_integers
from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE
from scvi.data._utils import _check_nonnegative_integers, _get_adata_minify_type
from scvi.data.fields import NumericalObsField, ObsmField, StringUnsField
from scvi.dataloaders import DataSplitter
from scvi.model._utils import (
_get_batch_code_from_category,
@@ -22,11 +24,12 @@
get_max_epochs_heuristic,
)
from scvi.model.base._de_core import _de_core
from scvi.model.utils import get_minified_mudata
from scvi.module import TOTALVAE
from scvi.train import AdversarialTrainingPlan, TrainRunner
from scvi.utils._docstrings import de_dsp, devices_dsp, setup_anndata_dsp

from .base import ArchesMixin, BaseModelClass, RNASeqMixin, VAEMixin
from .base import ArchesMixin, BaseMudataMinifiedModeModelClass, RNASeqMixin, VAEMixin

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
@@ -35,12 +38,19 @@
from anndata import AnnData
from mudata import MuData

from scvi._types import AnnOrMuData, Number
from scvi._types import AnnOrMuData, MinifiedDataType, Number
from scvi.data.fields import (
BaseAnnDataField,
)

_TOTALVI_LATENT_QZM = "_totalvi_latent_qzm"
_TOTALVI_LATENT_QZV = "_totalvi_latent_qzv"
_TOTALVI_OBSERVED_LIB_SIZE = "_totalvi_observed_lib_size"

logger = logging.getLogger(__name__)


class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass):
class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMudataMinifiedModeModelClass):
"""total Variational Inference :cite:p:`GayosoSteier21`.

Parameters
@@ -162,7 +172,8 @@ def __init__(
n_batch = self.summary_stats.n_batch
use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
library_log_means, library_log_vars = None, None
if not use_size_factor_key:
# TODO: ADD MINIFICATION CONSIDERATION
if not use_size_factor_key and self.minified_data_type is None:
library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch)

self.module = self._module_cls(
@@ -184,6 +195,7 @@ def __init__(
library_log_vars=library_log_vars,
**model_kwargs,
)
self.module.minified_data_type = self.minified_data_type
self._model_summary_string = (
f"TotalVI Model with the following params: \nn_latent: {n_latent}, "
f"gene_dispersion: {gene_dispersion}, protein_dispersion: {protein_dispersion}, "
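The new guard `if not use_size_factor_key and self.minified_data_type is None` skips the library-size prior when the data is already minified, presumably because the counts needed to compute it are no longer available. A minimal sketch of that decision, assuming that `_init_library_size` returns the per-batch mean and variance of the log total counts per cell; both function names below are hypothetical, and the real helper works on the registered data rather than nested lists:

```python
import math
from statistics import fmean, pvariance


def init_library_size(counts_by_batch):
    """Assumed behavior: per-batch mean and variance of log(total counts per cell)."""
    means, variances = [], []
    for cells in counts_by_batch:
        log_totals = [math.log(sum(cell)) for cell in cells]
        means.append(fmean(log_totals))
        variances.append(pvariance(log_totals))
    return means, variances


def library_priors(counts_by_batch, use_size_factor_key, minified_data_type):
    """Compute library priors only for non-minified data without a size factor."""
    if not use_size_factor_key and minified_data_type is None:
        return init_library_size(counts_by_batch)
    # Minified data (or a registered size factor): leave both priors unset.
    return None, None
```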
@@ -1331,6 +1343,87 @@ def setup_mudata(
mod_required=True,
),
]
# TODO: register new fields if the mudata is minified
mdata_minify_type = _get_adata_minify_type(mdata)
if mdata_minify_type is not None:
mudata_fields += cls._get_fields_for_mudata_minification(mdata_minify_type)
adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(mdata, **kwargs)
cls.register_manager(adata_manager)

@staticmethod
def _get_fields_for_mudata_minification(
minified_data_type: MinifiedDataType,
) -> list[BaseAnnDataField]:
"""Return the fields required for mudata minification of the given minified_data_type."""
if minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
fields = [
ObsmField(
REGISTRY_KEYS.LATENT_QZM_KEY,
_TOTALVI_LATENT_QZM,
),
ObsmField(
REGISTRY_KEYS.LATENT_QZV_KEY,
_TOTALVI_LATENT_QZV,
),
NumericalObsField(
REGISTRY_KEYS.OBSERVED_LIB_SIZE,
_TOTALVI_OBSERVED_LIB_SIZE,
),
]
else:
raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")
fields.append(
StringUnsField(
REGISTRY_KEYS.MINIFY_TYPE_KEY,
_ADATA_MINIFY_TYPE_UNS_KEY,
),
)
return fields

def minify_mudata(
self,
minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR,
use_latent_qzm_key: str = "X_latent_qzm",
use_latent_qzv_key: str = "X_latent_qzv",
) -> None:
"""Minifies the model's mudata.

Minifies the mudata, and registers new mudata fields: latent qzm, latent qzv, adata uns
containing minified-adata type, and library size.
This also sets the appropriate property on the module to indicate that the mudata is
minified.

Parameters
----------
minified_data_type
How to minify the data. Currently only supports `latent_posterior_parameters`.
If minified_data_type == `latent_posterior_parameters`:

* the original count data is removed (`adata.X`, `adata.raw`, and any layers)
* the parameters of the latent representation of the original data are stored
* everything else is left untouched
use_latent_qzm_key
Key to use in `adata.obsm` where the latent qzm params are stored
use_latent_qzv_key
Key to use in `adata.obsm` where the latent qzv params are stored

Notes
-----
The modification is not done in place; instead, the model is assigned a new (minified)
version of the mudata.
"""
# without removing the original counts.
if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")

if self.module.use_observed_lib_size is False:
raise ValueError("Cannot minify the data if `use_observed_lib_size` is False")

minified_adata = get_minified_mudata(self.adata, minified_data_type)
minified_adata.obsm[_TOTALVI_LATENT_QZM] = self.adata.obsm[use_latent_qzm_key]
minified_adata.obsm[_TOTALVI_LATENT_QZV] = self.adata.obsm[use_latent_qzv_key]
counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
minified_adata.obs[_TOTALVI_OBSERVED_LIB_SIZE] = np.squeeze(np.asarray(counts.sum(axis=1)))
self._update_mudata_and_manager_post_minification(minified_adata, minified_data_type)
self.module.minified_data_type = minified_data_type
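In the TOTALVI version the `use_observed_lib_size` guard is active, whereas the MULTIVI hunk leaves the same check commented out. A minimal sketch of the two preconditions in the order they are tested; `check_minify_preconditions` is a hypothetical helper, not part of scvi-tools:

```python
# Value named in the minify_mudata docstring.
LATENT_POSTERIOR = "latent_posterior_parameters"


def check_minify_preconditions(minified_data_type: str, use_observed_lib_size: bool) -> None:
    """Mirror the guards at the top of TOTALVI.minify_mudata (hypothetical helper)."""
    # Only one minification type is supported at this commit.
    if minified_data_type != LATENT_POSTERIOR:
        raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")
    # The minified data keeps the observed library size, so the module must use it.
    if use_observed_lib_size is False:
        raise ValueError("Cannot minify the data if `use_observed_lib_size` is False")
```

Checking the type first means an unsupported `minified_data_type` is reported even when the library-size setting is also incompatible.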
7 changes: 6 additions & 1 deletion src/scvi/model/base/__init__.py
@@ -1,5 +1,9 @@
from ._archesmixin import ArchesMixin
from ._base_model import BaseMinifiedModeModelClass, BaseModelClass
from ._base_model import (
BaseMinifiedModeModelClass,
BaseModelClass,
BaseMudataMinifiedModeModelClass,
)
from ._differential import DifferentialComputation
from ._embedding_mixin import EmbeddingMixin
from ._jaxmixin import JaxTrainingMixin
@@ -26,5 +30,6 @@
"DifferentialComputation",
"JaxTrainingMixin",
"BaseMinifiedModeModelClass",
"BaseMudataMinifiedModeModelClass",
"EmbeddingMixin",
]