From 6be7b1607a1d988a766658d63c6d75170aacf656 Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Sun, 25 May 2025 19:51:23 -0500 Subject: [PATCH 01/12] update documentation for module and module files --- Perturbo_reproducibility | 2 +- src/perturbo/models/_model.py | 230 ++++++++++++++++++++++----------- src/perturbo/models/_module.py | 88 ++++++++----- 3 files changed, 216 insertions(+), 104 deletions(-) diff --git a/Perturbo_reproducibility b/Perturbo_reproducibility index 7a20b99..34328f0 160000 --- a/Perturbo_reproducibility +++ b/Perturbo_reproducibility @@ -1 +1 @@ -Subproject commit 7a20b99c570551d7e301724764528a7912ad81e2 +Subproject commit 34328f0716305aaeaffb065668b95bc3ed3f4951 diff --git a/src/perturbo/models/_model.py b/src/perturbo/models/_model.py index 43b6866..1d862e6 100644 --- a/src/perturbo/models/_model.py +++ b/src/perturbo/models/_model.py @@ -34,12 +34,30 @@ class PERTURBO(PyroSviTrainMixin, PyroSampleMixin, BaseModelClass): def __init__( self, mdata: AnnOrMuData, - control_guides=None, - load_sparse_tensors=False, - dispersion_smoothing="none", - smoothing_factor=0.3, + control_guides: list | None = None, + load_sparse_tensors: bool = False, + dispersion_smoothing: str = "none", + smoothing_factor: float = 0.3, **model_kwargs, ): + """ + Initialize the PERTURBO model. + + Parameters + ---------- + mdata : AnnOrMuData + MuData or AnnData object containing the data. + control_guides : list or None + List of control guide indices (optional, only used for setting initial values). + load_sparse_tensors : bool + Whether to load sparse tensors. + dispersion_smoothing : str + Smoothing method for dispersion estimation ("none", "linear", "isotonic"). + smoothing_factor : float + Smoothing factor for dispersion smoothing. + model_kwargs : dict + Additional keyword arguments for the model. 
+ """ super().__init__(mdata) # data fields that will be loaded/mini-batched into the module @@ -121,7 +139,20 @@ def __init__( logger.info("The model has been initialized") - def read_matrix_from_registry(self, registry_key): + def read_matrix_from_registry(self, registry_key: str) -> torch.Tensor: + """ + Reads a matrix from the AnnDataManager registry and converts it to a torch.Tensor. + + Parameters + ---------- + registry_key : str + Key to retrieve the matrix from the registry. + + Returns + ------- + torch.Tensor + The matrix as a torch tensor. + """ data = self.adata_manager.get_from_registry(registry_key) if isinstance(data, DataFrame): data = data.values @@ -151,44 +182,45 @@ def setup_mudata( guide_by_element_key: str | None = None, library_size_key: str | None = None, size_factor_key: str | None = None, - gene_mean_key: str | None = None, + gene_mean_key: str | None = None, # not used, supported for legacy reasons continuous_covariates_keys: str | None = None, - categorical_covariates_keys: str | None = None, + # categorical_covariates_keys: str | None = None, # not currently supported modalities: dict[str, str] | None = None, **kwargs, ): - """Registers data from a MuData object with the model. + """ + Registers data from a MuData object with the model. Parameters ---------- - mdata + mdata : MuData A MuData object containing the perturbations and observational data. 
- rna_layer - The key of the MuData modality containing the RNA counts - perturbation_layer - The key of the MuData modality containing the perturbations - batch_key - Key within the RNA AnnData .obs corresponding to the experimental batch - gene_by_element_key - .varm key within the RNA AnnData object containing a mask of which genes can be affected by which genetic elements - guide_by_element_key - .varm key within the perturbation AnnData object containing which perturbations target which genetic elements - rna_element_uns_key - .uns key within the RNA AnnData object containing names of perturbed elements (if using GENE_BY_ELEMENT_KEY), - otherwise automatically inferred from column names if .varm object is a DataFrame - guide_element_uns_key - .uns key within the perturbation AnnData object containing names of perturbed elements - (if using GUIDE_BY_ELEMENT_KEY), otherwise automatically inferred from column names if .varm object is a DataFrame - library_size_key - .obs key within the RNA AnnData object containing raw (not log-scaled) library size factors for each sample - size_factor_key - .obs key within the RNA AnnData object containing library size factors for each sample (e.g. log-library size) - continuous_covariates_keys - list of .obs keys within the RNA AnnData object containing other continuous covariates to be "regressed out" - modalities - A dict containing these same setup arguments - kwargs - Additional keyword arguments + rna_layer : str or None + The key of the MuData modality containing the RNA counts. + perturbation_layer : str or None + The key of the MuData modality containing the perturbations. + batch_key : str or None + Key within the RNA AnnData .obs corresponding to the experimental batch. + gene_by_element_key : str or None + .varm key within the RNA AnnData object containing a mask of which genes can be affected by which genetic elements. 
+ rna_element_uns_key : str or None + .uns key within the RNA AnnData object containing names of perturbed elements. + guide_element_uns_key : str or None + .uns key within the perturbation AnnData object containing names of perturbed elements. + guide_by_element_key : str or None + .varm key within the perturbation AnnData object containing which perturbations target which genetic elements. + library_size_key : str or None + .obs key within the RNA AnnData object containing raw library size factors for each sample. + size_factor_key : str or None + .obs key within the RNA AnnData object containing library size factors for each sample. + gene_mean_key : str or None + .var key for gene mean (legacy, for simulator). + continuous_covariates_keys : str or None + List of .obs keys within the RNA AnnData object containing other continuous covariates. + modalities : dict[str, str] or None + A dict containing these same setup arguments. + kwargs : dict + Additional keyword arguments. """ setup_method_args = cls._get_setup_method_args(**locals()) @@ -326,36 +358,37 @@ def train( Parameters ---------- - max_epochs - Number of passes through the dataset. If `None`, defaults to - `np.min([round((20000 / n_cells) * 400), 400])` - %(param_use_gpu)s - %(param_accelerator)s - %(param_device)s - train_size - Size of training set in the range [0.0, 1.0]. - validation_size - Size of the test set. If `None`, defaults to 1 - `train_size`. If - `train_size + validation_size < 1`, the remaining cells belong to a test set. - shuffle_set_split - Whether to shuffle indices before splitting. If `False`, the val, train, and test set are split in the - sequential order of the data according to `validation_size` and `train_size` percentages. - batch_size - Minibatch size to use during training. If `None`, no minibatching occurs and all - data is copied to device (e.g., GPU). - early_stopping - Perform early stopping. Additional arguments can be passed in `**kwargs`. 
- See :class:`~scvi.train.Trainer` for further options. - lr - Optimiser learning rate (default optimiser is :class:`~pyro.optim.ClippedAdam`). - Specifying optimiser via plan_kwargs overrides this choice of lr. - training_plan - Training plan :class:`~scvi.train.PyroTrainingPlan`. - plan_kwargs - Keyword args for :class:`~scvi.train.PyroTrainingPlan`. Keyword arguments passed to - `train()` will overwrite values present in `plan_kwargs`, when appropriate. - **trainer_kwargs - Other keyword args for :class:`~scvi.train.Trainer`. + max_epochs : int + Number of passes through the dataset. + accelerator : str + Accelerator type ("cpu", "gpu", etc.). + device : int or str + Device identifier. + train_size : float + Size of training set in the range [0.0, 1.0]. All cells by default. + validation_size : float or None + Size of the validation set. Zero cells by default. + shuffle_set_split : bool + Whether to shuffle indices before splitting. + batch_size : int + Minibatch size to use during training. + early_stopping : bool + Perform early stopping. + lr : float or None + Optimizer learning rate. + training_plan : type + Training plan class. + plan_kwargs : dict or None + Keyword args for the training plan. + data_splitter_kwargs : dict or None + Keyword args for the data splitter. + trainer_kwargs : dict + Other keyword args for the Trainer. + + Returns + ------- + Any + The result of the training runner. """ plan_kwargs = plan_kwargs if plan_kwargs is not None else {} if len(self.module.discrete_sites) > 0: @@ -407,15 +440,30 @@ def train( ) return runner() - def get_element_names(self): + def get_element_names(self) -> list: + """ + Returns the names of the targeted elements. + + Returns + ------- + list + List of element names. 
+ """ if REGISTRY_KEYS.GUIDE_BY_ELEMENT_KEY in self.adata_manager.data_registry: element_ids = self.adata_manager.get_state_registry(REGISTRY_KEYS.GUIDE_BY_ELEMENT_KEY).column_names else: element_ids = self.adata_manager.get_state_registry(REGISTRY_KEYS.PERTURBATION_KEY).column_names return element_ids - def get_element_effects(self): - """Return a DataFrame summary of the effects for targeted elements on each gene.""" + def get_element_effects(self) -> pd.DataFrame: + """ + Return a DataFrame summary of the effects for targeted elements on each gene. + + Returns + ------- + pd.DataFrame + DataFrame with columns for effect location, scale, element, gene, z-value, and q-value. + """ element_ids = self.get_element_names() gene_ids = self.adata_manager.get_state_registry("X").column_names for guide in self.module.guide: @@ -461,7 +509,20 @@ def make_long_df(mat, value_name): return element_effects.sort_values("z_value") - def get_map_labels(self, indices: list | None = None): + def get_map_labels(self, indices: list | None = None) -> np.ndarray: + """ + Get MAP (maximum a posteriori) labels for the discrete latent variable "perturbed". + + Parameters + ---------- + indices : list or None + Indices of the data subset to use. + + Returns + ------- + np.ndarray + Array of MAP labels. + """ args, kwargs = self._get_data_subset(indices) self.module.guide(*args, **kwargs) guide_trace = poutine.trace(self.module.guide).get_trace(*args, **kwargs) # record the globals @@ -474,7 +535,20 @@ def get_map_labels(self, indices: list | None = None): map_labels = serving_model_trace.nodes["perturbed"]["value"].squeeze().cpu().numpy() return map_labels - def _get_data_subset(self, indices: list | None = None): + def _get_data_subset(self, indices: list | None = None) -> tuple: + """ + Get a data subset for inference. + + Parameters + ---------- + indices : list or None + Indices of the data subset. + + Returns + ------- + tuple + Tuple of (args, kwargs) for the model. 
+ """ loader = AnnDataLoader( adata_manager=self.adata_manager, indices=indices, @@ -512,11 +586,23 @@ def _get_data_subset(self, indices: list | None = None): # return samples -def estimate_nb_params(X, smoothing="isotonic"): +def estimate_nb_params( + X: np.ndarray | sp.spmatrix, smoothing: str = "isotonic" +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ Estimate NB mean and dispersion (MoM) for each gene (column) in count matrix X. - With optional smoothing: 'linear', 'quadratic', or 'isotonic' (monotonic). + Parameters + ---------- + X : np.ndarray or scipy.sparse.spmatrix + Count matrix (cells x genes). + smoothing : str + Smoothing method: 'linear', 'quadratic', or 'isotonic'. + + Returns + ------- + tuple of np.ndarray + log_means, log_disp, log_disp_smoothed """ if sp.issparse(X): means = np.array(X.mean(axis=0)).flatten() diff --git a/src/perturbo/models/_module.py b/src/perturbo/models/_module.py index d3fe940..d8496ec 100644 --- a/src/perturbo/models/_module.py +++ b/src/perturbo/models/_module.py @@ -33,15 +33,12 @@ def __init__( log_gene_dispersion_init: torch.Tensor | None = None, guide_by_element: torch.Tensor | None = None, gene_by_element: torch.Tensor | None = None, - # guide_noise: bool = False, likelihood: Literal["nb", "lnnb"] = "nb", effect_prior_dist: Literal["cauchy", "normal_mixture", "normal", "laplace"] = "laplace", n_factors: int | None = None, n_pert_factors: int | None = None, - # use_interactions: bool = False, efficiency_mode: Literal["mixture", "scaled"] = "scaled", - # dispersion_effects: bool = False, - sparse_local_effects=False, + sparse_local_effects: bool = False, fit_guide_efficacy: bool = True, prior_param_dict: Mapping[str, torch.Tensor] | None = None, **module_kwargs, @@ -51,29 +48,45 @@ def __init__( Parameters ---------- - guide_by_element: + n_cells : int or None + Number of cells in the dataset. Provided by model constructor. + n_genes : int or None + Number of genes in the dataset. Provided by model constructor. 
+ n_perturbations : int or None + Number of perturbations (typically CRISPR gRNAs). Provided by model constructor. + n_elements : int or None + Number of targeted elements (typically genomic targets of gRNAs). Provided by model constructor. + n_cont_covariates : int or None + Number of continuous covariates. Provided by model constructor. + n_batches : int or None + Number of batches. Provided by model constructor. + log_gene_mean_init : torch.Tensor or None + Initial values for log gene means. Provided by model constructor. + log_gene_dispersion_init : torch.Tensor or None + Initial values for log gene dispersions. Provided by model constructor. + guide_by_element : torch.Tensor or None Binary array encoding which element(s) are targeted by each guide. - gene_by_element: + gene_by_element : torch.Tensor or None Binary array encoding which element(s) may target each gene *a priori*. - likelihood: + likelihood : Literal["nb", "lnnb"] Observation likelihood, either NegativeBinomial ("nb") or LogNormalNegativeBinomial ("lnnb"). - effect_prior_dist: - Effect size prior, either Cauchy or NormalMixture ("soft" spike & slab) - n_factors: - Number of cell-specific latent factors - n_pert_factors: - Number of perturbation-specific latent factors - use_interactions: - If using cell factors, allow interactions between perturbations and cell factors - efficiency_mode: - Guide efficiency is fraction of cells perturbed ("mixture") or linear scaling of effect size ("scaled"). - "Mixture" mode currently requires (at most) one guide per cell. - dispersion_effects: - Allow for different gene-level dispersion by batch? - prior_params: - dict containing hyperparameter names and tensors to set prior values - merge_guides_mode: - "shared" treats ("partial") across guides targeting same element? + effect_prior_dist : Literal["cauchy", "normal_mixture", "normal", "laplace"] + Effect size prior, either Cauchy, NormalMixture ("soft" spike & slab), Normal, or Laplace. 
+ n_factors : int or None + EXPERIMENTAL: Number of cell-specific latent factors. + n_pert_factors : int or None + EXPERIMENTAL: Number of perturbation-specific latent factors. + efficiency_mode : Literal["mixture", "scaled"] + Guide efficiency is fraction of cells perturbed ("mixture") or fractional scaling of max per-element effect size ("scaled"). + "Mixture" mode currently requires at most one guide observation per cell. + sparse_local_effects : bool + EXPERIMENTAL: If True, use sparse PyTorch matrix for local effects. + fit_guide_efficacy : bool + If True, fit guide efficacy. If False, assume guide efficacy = 1. + prior_param_dict : Mapping[str, torch.Tensor] or None + Dict containing hyperparameter names and tensors to set prior values. + module_kwargs : dict + Additional keyword arguments (unused). """ super().__init__() # set user-defined options for model behavior @@ -87,10 +100,8 @@ def __init__( self.n_factors = n_factors self.n_pert_factors = n_pert_factors self.effect_prior_dist = effect_prior_dist - # self.use_interactions = use_interactions self.efficiency_mode = efficiency_mode self.local_effects = sparse_local_effects - # self.guide_noise = guide_noise # copy data summary stats self.n_cells = n_cells @@ -115,6 +126,7 @@ def __init__( self.n_batches = n_batches + # Sites to approximate with Delta distribution instead of default Normal distribution. 
self.delta_sites = [] # self.delta_sites = ["cell_factors"] # self.delta_sites = ["cell_factors", "cell_loadings", "pert_factors", "pert_loadings"] @@ -213,7 +225,7 @@ def __init__( self.register_buffer(k, v) @staticmethod - def _get_fn_args_from_batch(tensor_dict): + def _get_fn_args_from_batch(tensor_dict: dict) -> tuple[tuple[torch.Tensor], dict]: fit_size_factor_covariate = False if fit_size_factor_covariate: @@ -231,7 +243,7 @@ def _get_fn_args_from_batch(tensor_dict): # return indices and then the rest of the tensors return (tensor_dict[REGISTRY_KEYS.INDICES_KEY].squeeze(),), tensor_dict - def create_plates(self, idx, **tensor_dict): + def create_plates(self, idx: torch.Tensor, **tensor_dict) -> tuple: # dims = self.infer_data_dims(idx, **tensor_dict) return ( pyro.plate("Cells", self.n_cells, dim=-2, subsample=idx), @@ -246,7 +258,21 @@ def create_plates(self, idx, **tensor_dict): pyro.plate("Pert_factors", self.n_pert_factors, dim=-3), ) - def model(self, idx, **tensor_dict): + def model(self, idx: torch.Tensor, **tensor_dict) -> None: + """ + The probabilistic model definition for perturbo. + + Parameters + ---------- + idx : torch.Tensor + Indices for subsampling cells. + tensor_dict : dict + Dictionary containing all required tensors for the model. 
+ + Returns + ------- + None + """ pyro.module("perturbo", self) ( cell_plate, @@ -255,7 +281,7 @@ def model(self, idx, **tensor_dict): batch_plate, gene_plate, cont_covariate_plate, - element_effects_plate, # sparse mode + element_effects_plate, guide_plate_sparse, cell_factor_plate, pert_factor_plate, @@ -485,5 +511,5 @@ def model(self, idx, **tensor_dict): raise NotImplementedError(f"'{self.likelihood}' likelihood not implemented") @property - def guide(self): + def guide(self) -> AutoGuideList: return self._guide From 11a3e6a327b51779890bd6ccd58060b54f003d20 Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Sun, 25 May 2025 21:01:41 -0500 Subject: [PATCH 02/12] disabled categorical covariates --- src/perturbo/models/_model.py | 2 +- tests/test_basic.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/perturbo/models/_model.py b/src/perturbo/models/_model.py index 1d862e6..39295a3 100644 --- a/src/perturbo/models/_model.py +++ b/src/perturbo/models/_model.py @@ -184,7 +184,7 @@ def setup_mudata( size_factor_key: str | None = None, gene_mean_key: str | None = None, # not used, supported for legacy reasons continuous_covariates_keys: str | None = None, - # categorical_covariates_keys: str | None = None, # not currently supported + categorical_covariates_keys: str | None = None, # not used, supported for legacy reasons modalities: dict[str, str] | None = None, **kwargs, ): diff --git a/tests/test_basic.py b/tests/test_basic.py index 26d76cb..45f0649 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -39,11 +39,10 @@ def test_model_mdata( pyro.clear_param_store() perturbo.PERTURBO.setup_mudata( mdata, - # size_factor_key="lib_size", + library_size_key="lib_size", batch_key="batch_id", guide_element_uns_key="elements" if use_guide_by_element else None, rna_element_uns_key="elements" if use_gene_by_element else None, - categorical_covariates_keys=["lib_size"], guide_by_element_key=guide_by_element_key if use_guide_by_element else 
None, gene_by_element_key=gene_by_element_key if use_gene_by_element else None, modalities={ From 1bbdae32a204be69905224a2de0da499420450a6 Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Mon, 26 May 2025 09:47:31 -0500 Subject: [PATCH 03/12] add warning for potentially problematic continuous covariates --- src/perturbo/models/_model.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/perturbo/models/_model.py b/src/perturbo/models/_model.py index 39295a3..64a36f5 100644 --- a/src/perturbo/models/_model.py +++ b/src/perturbo/models/_model.py @@ -271,6 +271,32 @@ def setup_mudata( mod_key=modalities.rna_layer, ) + # Check continuous covariates for potential issues + obs_df = mdata[modalities.rna_layer].obs + if continuous_covariates_keys is not None: + for cov in continuous_covariates_keys: + values = obs_df[cov].values + unique_vals = np.unique(values) + std = np.std(values) + is_all_int = np.all(np.equal(np.mod(values, 1), 0)) + is_all_same = len(unique_vals) == 1 + is_binary = np.array_equal(unique_vals, [0, 1]) or np.array_equal(unique_vals, [1, 0]) + + if is_all_same: + logger.warning( + f"Continuous covariate '{cov}' has the same value for all observations. " + "Consider removing this covariate." + ) + elif is_all_int and len(unique_vals) > 1 and not is_binary: + logger.warning( + f"Continuous covariate '{cov}' contains only discrete counts. " + "Consider applying log1p transform followed by z-scoring." + ) + elif std > 10 or std < 0.1: + logger.warning( + f"Continuous covariate '{cov}' has standard deviation {std:.3g}. Consider z-scoring." 
+ ) + covariates_field = fields.MuDataNumericalJointObsField( REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariates_keys, From 7ed9d07ca70394a697ecdb998e31d0d2f423ea1f Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Mon, 26 May 2025 09:50:49 -0500 Subject: [PATCH 04/12] add test for continuous covariates --- tests/conftest.py | 1 + tests/test_basic.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 3fb8650..5203882 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,6 +23,7 @@ def adata(): { "lib_size": np.random.lognormal(10, 1, size=(n_cells)), "batch_id": np.random.choice(["batch_1", "batch_2", "batch_3"], size=(n_cells)), + "cov1": np.random.normal(size=n_cells), } ) rna_counts = np.random.negative_binomial(100, 0.9, size=(n_cells, n_genes)).astype(np.float32) diff --git a/tests/test_basic.py b/tests/test_basic.py index 45f0649..6dd044b 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -41,6 +41,7 @@ def test_model_mdata( mdata, library_size_key="lib_size", batch_key="batch_id", + continuous_covariates_keys=["cov1"], guide_element_uns_key="elements" if use_guide_by_element else None, rna_element_uns_key="elements" if use_gene_by_element else None, guide_by_element_key=guide_by_element_key if use_guide_by_element else None, gene_by_element_key=gene_by_element_key if use_gene_by_element else None, modalities={ From e16bc1fdb67b7265e4fa5b6ecec1be4b0b5e30e0 Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Mon, 26 May 2025 12:35:37 -0500 Subject: [PATCH 05/12] output factor effects when no cis effects are present --- src/perturbo/models/_model.py | 87 ++++++++++++++++++++++------------- 1 file changed, 55 insertions(+), 32 deletions(-) diff --git a/src/perturbo/models/_model.py b/src/perturbo/models/_model.py index 64a36f5..12def12 100644 --- a/src/perturbo/models/_model.py +++ b/src/perturbo/models/_model.py @@ -492,48 +492,55 @@ def get_element_effects(self) -> pd.DataFrame: """ element_ids = self.get_element_names() gene_ids = 
self.adata_manager.get_state_registry("X").column_names - for guide in self.module.guide: - if "element_effects" in guide.median(): - loc_values, scale_values = guide._get_loc_and_scale("element_effects") - # loc_values, loc_plus_scale_values = guide.quantiles([0.5, 0.841])["element_effects"] - # scale_values = loc_plus_scale_values - loc_values - # loc_values, scale_values = self.module.guide._get_loc_and_scale("element_effects") - - if hasattr(self.module, "element_by_gene_idx"): - # loc/scale_values are the nonzero elements of a sparse matrix of elements by genes - i, j = self.module.element_by_gene_idx.detach().cpu().numpy().astype(int) - - # pert_ids = self.adata_manager.get_state_registry("perturbations").column_names - element_effects = pd.DataFrame( - { - "loc": loc_values.detach().cpu().numpy(), - "scale": scale_values.detach().cpu().numpy(), - "element": [element_ids[idx] for idx in i], - "gene": [gene_ids[idx] for idx in j], - } + + def make_long_df(mat, value_name): + return ( + pd.DataFrame(data=mat, index=element_ids, columns=gene_ids) + .melt(var_name="gene", value_name=value_name, ignore_index=False) + .reset_index(names="element") ) - else: - # loc/scale_values are dense matrices of elements by genes - def make_long_df(mat, value_name): - return ( - pd.DataFrame(data=mat, index=element_ids, columns=gene_ids) - .melt(var_name="gene", value_name=value_name, ignore_index=False) - .reset_index(names="element") + # If all element effects are factorized, log a warning and fall back to the factorized effects + if "element_effects" not in self.module.guide.median(): + logger.warning("All element effects are factorized. 
Using 'get_factorized_element_effects' instead.") + pert_factors, pert_loadings = self.get_factorized_element_effects() + element_effects = make_long_df(pert_factors.T @ pert_loadings, "loc") + element_effects["scale"] = np.nan + else: + for guide in self.module.guide: + if "element_effects" in guide.median(): + loc_values, scale_values = guide._get_loc_and_scale("element_effects") + # loc_values, loc_plus_scale_values = guide.quantiles([0.5, 0.841])["element_effects"] + # scale_values = loc_plus_scale_values - loc_values + # loc_values, scale_values = self.module.guide._get_loc_and_scale("element_effects") + + if hasattr(self.module, "element_by_gene_idx"): + # loc/scale_values are the nonzero elements of a sparse matrix of elements by genes + i, j = self.module.element_by_gene_idx.detach().cpu().numpy().astype(int) + + # pert_ids = self.adata_manager.get_state_registry("perturbations").column_names + element_effects = pd.DataFrame( + { + "loc": loc_values.detach().cpu().numpy(), + "scale": scale_values.detach().cpu().numpy(), + "element": [element_ids[idx] for idx in i], + "gene": [gene_ids[idx] for idx in j], + } ) + else: + # loc/scale_values are dense matrices of elements by genes - element_effects = pd.merge( - make_long_df(loc_values.detach().cpu().numpy(), "loc"), - make_long_df(scale_values.detach().cpu().numpy(), "scale"), - ) + element_effects = pd.merge( + make_long_df(loc_values.detach().cpu().numpy(), "loc"), + make_long_df(scale_values.detach().cpu().numpy(), "scale"), + ) element_effects = element_effects.assign( z_value=lambda x: x["loc"] / x["scale"], q_value=lambda x: chi2.sf(x["z_value"] * x["z_value"], df=1), - # is_target=lambda x: x["element"].str.split("_", expand=True)[0] == x["gene"], ) - return element_effects.sort_values("z_value") + return element_effects def get_map_labels(self, indices: list | None = None) -> np.ndarray: """ @@ -611,6 +618,22 @@ def _get_data_subset(self, indices: list | None = None) -> tuple: # ) # return samples + 
def get_factorized_element_effects(self): + element_ids = self.get_element_names() + gene_ids = self.adata_manager.get_state_registry("X").column_names + + medians = self.module.guide.median() + if "pert_factors" not in medians or "pert_loadings" not in medians: + raise RuntimeError("No perturbation factors found. Use get_element_effects instead") + + pert_factors_2d = medians["pert_factors"].squeeze(-1).detach().cpu().numpy() + assert pert_factors_2d.shape == (self.module.n_pert_factors, self.module.n_elements) + pert_factors_df = pd.DataFrame(pert_factors_2d, columns=element_ids) + pert_loadings_2d = medians["pert_loadings"].squeeze(-2).detach().cpu().numpy() + assert pert_loadings_2d.shape == (self.module.n_pert_factors, self.module.n_genes) + pert_loadings_df = pd.DataFrame(pert_loadings_2d, columns=gene_ids) + return pert_factors_df, pert_loadings_df + def estimate_nb_params( X: np.ndarray | sp.spmatrix, smoothing: str = "isotonic" From e27ae6f1707ca6b7d7a06823634cbb61ed680364 Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Mon, 26 May 2025 12:47:25 -0500 Subject: [PATCH 06/12] allow sparse guide x gene efficiency --- src/perturbo/models/_module.py | 195 ++++++++++++++++++++++----------- tests/test_basic.py | 5 +- 2 files changed, 134 insertions(+), 66 deletions(-) diff --git a/src/perturbo/models/_module.py b/src/perturbo/models/_module.py index d8496ec..afd4327 100644 --- a/src/perturbo/models/_module.py +++ b/src/perturbo/models/_module.py @@ -38,7 +38,7 @@ def __init__( n_factors: int | None = None, n_pert_factors: int | None = None, efficiency_mode: Literal["mixture", "scaled"] = "scaled", - sparse_local_effects: bool = False, + sparse_effect_tensors: bool | Literal["auto"] = "auto", fit_guide_efficacy: bool = True, prior_param_dict: Mapping[str, torch.Tensor] | None = None, **module_kwargs, @@ -79,7 +79,7 @@ def __init__( efficiency_mode : Literal["mixture", "scaled"] Guide efficiency is fraction of cells perturbed ("mixture") or fractional scaling of 
max per-element effect size ("scaled"). "Mixture" mode currently requires at most one guide observation per cell. - sparse_local_effects : bool + sparse_effect_tensors : True | False | "auto" EXPERIMENTAL: If True, use sparse PyTorch matrix for local effects. fit_guide_efficacy : bool If True, fit guide efficacy. If False, assume guide efficacy = 1. @@ -101,7 +101,16 @@ def __init__( self.n_pert_factors = n_pert_factors self.effect_prior_dist = effect_prior_dist self.efficiency_mode = efficiency_mode - self.local_effects = sparse_local_effects + self.local_effects = gene_by_element is not None + + if sparse_effect_tensors == "auto": + if gene_by_element is not None: + sparsity = 1.0 - (gene_by_element.count_nonzero().item() / gene_by_element.numel()) + self.sparse_tensors = sparsity > 0.9 + else: + self.sparse_tensors = False + else: + self.sparse_tensors = sparse_effect_tensors and self.local_effects # copy data summary stats self.n_cells = n_cells @@ -121,6 +130,12 @@ def __init__( assert n_elements is not None, "n_elements must be specified if not equal to n_guides" self.n_elements = guide_by_element.shape[1] + # if self.sparse_tensors: + # self.register_buffer("guide_by_element", guide_by_element.to_sparse_coo()) + # else: + # self.register_buffer("guide_by_element", guide_by_element) + self.register_buffer("guide_by_element", guide_by_element) + if n_cont_covariates is not None: self.n_cont_covariates += n_cont_covariates @@ -160,19 +175,21 @@ def __init__( ## register hyperparameters as buffers so they get automatically moved to GPU by scvi-tools # guide_by_element encoding - self.register_buffer("guide_by_element", guide_by_element) - - if gene_by_element is not None: - self.register_buffer("element_by_gene", gene_by_element.T) - else: - self.element_by_gene = None + if self.sparse_tensors: + assert gene_by_element.shape[1] == self.n_elements + self.register_buffer("element_by_gene_idx", gene_by_element.T.to_sparse_coo().indices()) + 
self.register_buffer("guide_by_gene_idx", (guide_by_element @ gene_by_element.T).to_sparse_coo().indices()) - if self.local_effects: assert gene_by_element.shape[1] == self.n_elements self.register_buffer("element_by_gene_idx", gene_by_element.T.to_sparse_coo().indices()) self.register_buffer("guide_by_gene_idx", (guide_by_element @ gene_by_element.T).to_sparse_coo().indices()) - self.n_element_effects = self.element_by_gene_idx.shape[1] if self.local_effects else 1 - self.n_guide_effects = self.guide_by_gene_idx.shape[1] if self.local_effects else 1 + self.n_element_effects = self.element_by_gene_idx.shape[1] + self.n_guide_effects = self.guide_by_gene_idx.shape[1] + + else: + if self.local_effects: + self.register_buffer("element_by_gene", gene_by_element.T) + self.n_element_effects = self.n_guide_effects = 1 # global hyperparams self.register_buffer("zero", torch.tensor(0.0)) @@ -198,7 +215,7 @@ def __init__( ## element effect size hyperparams # Normal/Laplace/Cauchy prior - effect_prior_scales = {"cauchy": 0.1, "laplace": 0.5, "normal": 2.0} + effect_prior_scales = {"cauchy": 0.1, "laplace": 0.5, "normal": 1.0} model_effect_prior_scale = effect_prior_scales[effect_prior_dist] self.register_buffer("element_effects_prior_scale", torch.tensor(model_effect_prior_scale)) self.register_buffer("guide_effects_prior_scale", torch.tensor(model_effect_prior_scale)) @@ -282,7 +299,7 @@ def model(self, idx: torch.Tensor, **tensor_dict) -> None: gene_plate, cont_covariate_plate, element_effects_plate, - guide_plate_sparse, + guide_effects_plate, cell_factor_plate, pert_factor_plate, ) = self.create_plates(idx) @@ -304,22 +321,76 @@ def model(self, idx: torch.Tensor, **tensor_dict) -> None: elif self.effect_prior_dist == "laplace": effects_dist = dist.Laplace(0.0, self.element_effects_prior_scale) - # Sample cis/trans effect sizes - if self.local_effects: - with element_effects_plate: - element_local_effects_values = pyro.sample("element_effects", effects_dist) - 
element_local_effects = torch.sparse_coo_tensor( - self.element_by_gene_idx, - element_local_effects_values, - size=(self.n_elements, self.n_genes), + # sample pert factors (if using) + if self.n_pert_factors is not None: + with pert_factor_plate, element_plate: + pert_factors = pyro.sample( + "pert_factors", + dist.Laplace(0.0, self.pert_factor_prior_scale), ) - else: + with pert_factor_plate, gene_plate: + pert_loadings = pyro.sample( + "pert_loadings", + dist.Laplace(0.0, self.pert_loading_prior_scale), + ) + + # Sample either sparse cis effects or dense cis/trans effects + + # option 1: sparse cis effects + if self.local_effects: + if self.sparse_tensors: + with element_effects_plate: + element_local_effects_values = pyro.sample("element_effects", effects_dist) + element_local_effects = torch.sparse_coo_tensor( + self.element_by_gene_idx, + element_local_effects_values, + size=(self.n_elements, self.n_genes), + ) + else: + with element_plate, gene_plate: + element_local_effects = pyro.sample("element_effects", effects_dist) + element_local_effects *= self.element_by_gene + + if self.n_pert_factors is None: + element_effects = element_local_effects + # option 1b: cis effects with factorized trans effects + else: + element_factor_effects = torch.einsum("fei,fjg->eg", pert_factors, pert_loadings) + element_effects = element_factor_effects + element_local_effects + + # option 2: trans effects + elif not self.n_pert_factors: with element_plate, gene_plate: - element_local_effects = pyro.sample("element_effects", effects_dist) + element_effects = pyro.sample("element_effects", effects_dist) + if self.local_effects: + element_effects *= self.element_by_gene + + # option 3: factorized cis + trans effects + else: + element_effects = torch.einsum("fei,fjg->eg", pert_factors, pert_loadings) + if self.local_effects: + element_effects *= self.element_by_gene # Pool guide information based on user-specified strategy - if not self.fit_guide_efficacy: - guide_efficacy = 
self.one.expand((self.n_perturbations, self.n_genes)) + if self.fit_guide_efficacy: + if self.sparse_tensors: + with guide_effects_plate: + guide_efficiency_values = pyro.sample( + "guide_efficacy", dist.Beta(self.logit_efficacy_alpha, self.logit_efficacy_beta) + ) + guide_efficiency = torch.sparse_coo_tensor( + self.guide_by_gene_idx, + guide_efficiency_values, + size=(self.n_perturbations, self.n_genes), + ) + else: + with guide_plate, gene_plate: + guide_efficiency = pyro.sample( + "guide_efficacy", dist.Beta(self.logit_efficacy_alpha, self.logit_efficacy_beta) + ) + else: + guide_efficiency = self.one.expand((self.n_perturbations, self.n_genes)) + # elif self.local_effects: # with guide_plate_sparse: # guide_efficacy_sparse = pyro.sample( @@ -331,26 +402,26 @@ def model(self, idx: torch.Tensor, **tensor_dict) -> None: # guide_efficacy_sparse, # size=(self.n_guides, self.n_genes), # ) - else: - with guide_plate, gene_plate: - guide_efficacy = pyro.sample( - "guide_efficacy", - dist.Beta(self.logit_efficacy_alpha, self.logit_efficacy_beta), - ) + # else: + # with guide_plate, gene_plate: + # guide_efficacy = pyro.sample( + # "guide_efficacy", + # dist.Beta(self.logit_efficacy_alpha, self.logit_efficacy_beta), + # ) - # Sample dense or factorized perturbation effects - if self.n_pert_factors is None: - element_factor_effects = 0 - else: - with pert_factor_plate, element_plate: - pert_factors = pyro.sample("pert_factors", dist.Laplace(self.zero, self.pert_factor_prior_scale)) - with pert_factor_plate, gene_plate: - pert_loadings = pyro.sample("pert_loadings", dist.Laplace(self.zero, self.pert_loading_prior_scale)) - # pert_factor_scale_term = pyro.sample( - # "pert_factor_scale_term", - # dist.LogNormal(-self.one, self.one), - # ) - element_factor_effects = torch.einsum("fei,fjg->eg", pert_factors, pert_loadings) + # # Sample dense or factorized perturbation effects + # if self.n_pert_factors is None: + # element_factor_effects = 0 + # else: + # with 
pert_factor_plate, element_plate: + # pert_factors = pyro.sample("pert_factors", dist.Laplace(self.zero, self.pert_factor_prior_scale)) + # with pert_factor_plate, gene_plate: + # pert_loadings = pyro.sample("pert_loadings", dist.Laplace(self.zero, self.pert_loading_prior_scale)) + # # pert_factor_scale_term = pyro.sample( + # # "pert_factor_scale_term", + # # dist.LogNormal(-self.one, self.one), + # # ) + # element_factor_effects = torch.einsum("fei,fjg->eg", pert_factors, pert_loadings) # # Sample cell-specific factors (linear unobserved confounders) if using if self.n_factors is not None: @@ -364,10 +435,6 @@ def model(self, idx: torch.Tensor, **tensor_dict) -> None: "cell_loadings", dist.Laplace(0.0, self.cell_loading_prior_scale), ) - # cell_factor_scale_term = pyro.sample( - # "cell_factor_scale_term", - # dist.LogNormal(self.zero, self.one), - # ) cell_factor_effects = torch.einsum("fci,fjg->cg", cell_factors, cell_loadings) # if self.use_interactions and self.n_pert_factors is not None: @@ -382,19 +449,17 @@ def model(self, idx: torch.Tensor, **tensor_dict) -> None: else: cell_factor_effects = 0 - if self.local_effects: - # override factor effects - element_effects = (1 - self.element_by_gene) * element_factor_effects + element_local_effects - # guide_factor_efects = self.guide_by_element @ ((1 - self.element_by_gene) * element_factor_effects) - # guide_local_effects = (guide_efficacy * self.guide_by_element) @ element_local_effects - # guide_effects = guide_factor_efects + guide_local_effects - elif self.element_by_gene is not None: - element_effects = ( - self.element_by_gene * element_local_effects + (1 - self.element_by_gene) * element_factor_effects - ) - else: - element_effects = element_factor_effects + element_local_effects - + # if self.local_effects and self.sparse_tensors: + # element_effects = (1 - self.element_by_gene) * element_factor_effects + element_local_effects + # # guide_factor_efects = self.guide_by_element @ ((1 - 
self.element_by_gene) * element_factor_effects) + # # guide_local_effects = (guide_efficacy * self.guide_by_element) @ element_local_effects + # # guide_effects = guide_factor_efects + guide_local_effects + # if self.local_effects and not self.sparse_tensors: + # element_effects = ( + # self.element_by_gene * element_local_effects + (1 - self.element_by_gene) * element_factor_effects + # ) + # else: + # element_effects = element_factor_effects + element_local_effects guide_effects = self.guide_by_element @ element_effects # # compute/sample guide effects as function of element effects @@ -410,16 +475,16 @@ def model(self, idx: torch.Tensor, **tensor_dict) -> None: # Account for cell-specific latent "perturbation status" variable(s) if self.efficiency_mode == "scaled": - mean_perturbation_effect = guides_observed @ (guide_efficacy * guide_effects) + mean_perturbation_effect = guides_observed @ (guide_efficiency * guide_effects) elif self.efficiency_mode == "mixture": - pert_prob = guides_observed @ guide_efficacy + pert_prob = guides_observed @ guide_efficiency assert pert_prob.shape[0] == self.n_cells assert (pert_prob.shape[1] == 1) or (pert_prob.shape[1] == self.n_genes) with cell_plate, gene_plate: perturbed = pyro.sample("perturbed", dist.Bernoulli(pert_prob), infer={"enumerate": "parallel"}) mean_perturbation_effect = perturbed * (guides_observed @ guide_effects) elif self.efficiency_mode == "mixture_high_moi": # for simulation only! 
- pert_prob = guide_efficacy.expand((self.n_cells, -1, -1)).transpose(-3, -2) + pert_prob = guide_efficiency.expand((self.n_cells, -1, -1)).transpose(-3, -2) assert pert_prob.shape == (self.n_perturbations, self.n_cells, 1) with cell_plate: perturbed = pyro.sample("perturbed", dist.Bernoulli(pert_prob)).squeeze(-1).T diff --git a/tests/test_basic.py b/tests/test_basic.py index 6dd044b..884d076 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -20,7 +20,8 @@ def test_package_has_version(): @pytest.mark.parametrize("efficiency_mode", ["mixture", "scaled"]) @pytest.mark.parametrize("use_guide_by_element", [True, False]) @pytest.mark.parametrize("use_gene_by_element", [True, False]) -@pytest.mark.parametrize("fit_guide_efficacy", ["partial", "shared"]) +@pytest.mark.parametrize("fit_guide_efficacy", [True, False]) +@pytest.mark.parametrize("sparse_tensors", [True, False]) # @pytest.mark.parametrize("n_factors", [None, 2]) @pytest.mark.parametrize("n_pert_factors", [None, 2]) def test_model_mdata( @@ -30,6 +31,7 @@ def test_model_mdata( use_guide_by_element, use_gene_by_element, fit_guide_efficacy, + sparse_tensors, n_pert_factors, ): """Check that we can register our MuData object with our model and perform training""" @@ -57,6 +59,7 @@ def test_model_mdata( n_pert_factors=n_pert_factors, efficiency_mode=efficiency_mode, fit_guide_efficacy=fit_guide_efficacy, + sparse_effect_tensors=sparse_tensors, ) assert model.summary_stats.n_cells == len(mdata) assert model.summary_stats.n_vars == len(mdata[rna_key].var) From 041046d23f6369876134b0af2d0a88b9cf704ab7 Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Mon, 26 May 2025 13:07:50 -0500 Subject: [PATCH 07/12] allow sparse guide by element --- src/perturbo/models/_module.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/perturbo/models/_module.py b/src/perturbo/models/_module.py index afd4327..136f223 100644 --- a/src/perturbo/models/_module.py +++ 
b/src/perturbo/models/_module.py @@ -130,11 +130,11 @@ def __init__( assert n_elements is not None, "n_elements must be specified if not equal to n_guides" self.n_elements = guide_by_element.shape[1] - # if self.sparse_tensors: - # self.register_buffer("guide_by_element", guide_by_element.to_sparse_coo()) - # else: - # self.register_buffer("guide_by_element", guide_by_element) - self.register_buffer("guide_by_element", guide_by_element) + if self.sparse_tensors: + self.register_buffer("guide_by_element", guide_by_element.to_sparse_coo()) + else: + self.register_buffer("guide_by_element", guide_by_element) + # self.register_buffer("guide_by_element", guide_by_element) if n_cont_covariates is not None: self.n_cont_covariates += n_cont_covariates @@ -349,7 +349,7 @@ def model(self, idx: torch.Tensor, **tensor_dict) -> None: else: with element_plate, gene_plate: element_local_effects = pyro.sample("element_effects", effects_dist) - element_local_effects *= self.element_by_gene + element_local_effects = self.element_by_gene * element_local_effects if self.n_pert_factors is None: element_effects = element_local_effects @@ -472,10 +472,13 @@ def model(self, idx: torch.Tensor, **tensor_dict) -> None: # else: # guide_effects = pyro.deterministic("guide_effects", guide_effects) - # Account for cell-specific latent "perturbation status" variable(s) - + # Account for guide efficiency/efficacy if self.efficiency_mode == "scaled": + # Ensure dense for matmul (should only trigger if using factors with sparse cis effects) + if guide_efficiency.is_sparse and not guide_effects.is_sparse: + guide_efficiency = guide_efficiency.to_dense() mean_perturbation_effect = guides_observed @ (guide_efficiency * guide_effects) + elif self.efficiency_mode == "mixture": pert_prob = guides_observed @ guide_efficiency assert pert_prob.shape[0] == self.n_cells From 80edfebb2690e2db53c5a674b4c8e48cc1c95a6b Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Mon, 26 May 2025 16:11:16 -0500 Subject: 
[PATCH 08/12] test loading sparse tensors onto gpu --- src/perturbo/models/_model.py | 10 ++++------ src/perturbo/models/_module.py | 16 +++++++++++++--- tests/test_basic.py | 14 ++++++++++++-- 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/src/perturbo/models/_model.py b/src/perturbo/models/_model.py index 12def12..e910456 100644 --- a/src/perturbo/models/_model.py +++ b/src/perturbo/models/_model.py @@ -12,7 +12,7 @@ from scipy.stats import chi2 from scvi._types import AnnOrMuData from scvi.data import AnnDataManager, fields -from scvi.dataloaders import AnnDataLoader, DeviceBackedDataSplitter +from scvi.dataloaders import AnnDataLoader, DataSplitter, DeviceBackedDataSplitter from scvi.model.base import ( BaseModelClass, PyroJitGuideWarmup, @@ -35,7 +35,6 @@ def __init__( self, mdata: AnnOrMuData, control_guides: list | None = None, - load_sparse_tensors: bool = False, dispersion_smoothing: str = "none", smoothing_factor: float = 0.3, **model_kwargs, @@ -69,8 +68,6 @@ def __init__( REGISTRY_KEYS.INDICES_KEY: np.int64, } - self.load_sparse_tensors = load_sparse_tensors - n_extra_continuous_covs = 0 if "n_extra_continuous_covs" in self.summary_stats: n_extra_continuous_covs = self.summary_stats.n_extra_continuous_covs @@ -374,6 +371,7 @@ def train( batch_size: int = 1024, early_stopping: bool = False, lr: float | None = 0.005, + load_sparse_tensor: bool = False, training_plan: PyroTrainingPlan = PyroTrainingPlan, plan_kwargs: dict | None = None, data_splitter_kwargs: dict | None = None, @@ -437,12 +435,13 @@ def train( **data_splitter_kwargs, ) else: - data_splitter = self._data_splitter_cls( + data_splitter = DataSplitter( self.adata_manager, train_size=train_size, validation_size=validation_size, shuffle_set_split=shuffle_set_split, batch_size=batch_size, + load_sparse_tensor=load_sparse_tensor, **data_splitter_kwargs, ) @@ -587,7 +586,6 @@ def _get_data_subset(self, indices: list | None = None) -> tuple: indices=indices, batch_size=len(indices) 
if indices is not None else len(self.adata), data_and_attributes=self.data_and_attrs, - load_sparse_tensor=self.load_sparse_tensors, ) return self.module._get_fn_args_from_batch(next(iter(loader))) diff --git a/src/perturbo/models/_module.py b/src/perturbo/models/_module.py index 136f223..4ac8b4f 100644 --- a/src/perturbo/models/_module.py +++ b/src/perturbo/models/_module.py @@ -117,7 +117,9 @@ def __init__( self.n_genes = n_genes self.n_perturbations = n_perturbations self.n_cont_covariates = 1 # include (inferred) size factor as covariate always - self.on_load_kwargs = {"max_epochs": 1} # fixes new bug from ipywidgets loading bar on model load + self.on_load_kwargs = { + "max_epochs": 1, # fixes new bug from ipywidgets loading bar on model load + } self.discrete_sites = [] if efficiency_mode == "mixture": @@ -257,6 +259,14 @@ def _get_fn_args_from_batch(tensor_dict: dict) -> tuple[tuple[torch.Tensor], dic else: tensor_dict[REGISTRY_KEYS.CONT_COVS_KEY] = size_factor + X = tensor_dict[REGISTRY_KEYS.X_KEY] + if X is not None and (X.layout == torch.sparse_csc or X.layout == torch.sparse_csr or X.is_sparse): + tensor_dict[REGISTRY_KEYS.X_KEY] = X.to_dense() + + Y = tensor_dict[REGISTRY_KEYS.PERTURBATION_KEY] + if Y is not None and (Y.layout == torch.sparse_csc or Y.layout == torch.sparse_csr or Y.is_sparse): + tensor_dict[REGISTRY_KEYS.PERTURBATION_KEY] = Y.to_dense() + # return indices and then the rest of the tensors return (tensor_dict[REGISTRY_KEYS.INDICES_KEY].squeeze(),), tensor_dict @@ -481,8 +491,8 @@ def model(self, idx: torch.Tensor, **tensor_dict) -> None: elif self.efficiency_mode == "mixture": pert_prob = guides_observed @ guide_efficiency - assert pert_prob.shape[0] == self.n_cells - assert (pert_prob.shape[1] == 1) or (pert_prob.shape[1] == self.n_genes) + # assert pert_prob.shape[0] == self.n_cells + # assert (pert_prob.shape[1] == 1) or (pert_prob.shape[1] == self.n_genes) with cell_plate, gene_plate: perturbed = pyro.sample("perturbed", 
dist.Bernoulli(pert_prob), infer={"enumerate": "parallel"}) mean_perturbation_effect = perturbed * (guides_observed @ guide_effects) diff --git a/tests/test_basic.py b/tests/test_basic.py index 884d076..328190b 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -65,10 +65,20 @@ def test_model_mdata( assert model.summary_stats.n_vars == len(mdata[rna_key].var) assert model.summary_stats.n_perturbations == len(mdata[perturb_key].var) - model.train(max_epochs=10, lr=0.1, batch_size=20) + model.train( + max_epochs=5, + lr=0.1, + batch_size=2, + load_sparse_tensor=sparse_tensors, + ) + model.train( + max_epochs=5, + lr=0.1, + batch_size=None, + load_sparse_tensor=sparse_tensors, + ) element_effects = model.get_element_effects() assert isinstance(element_effects, pd.DataFrame) - assert len(model.history["elbo_train"]) == 10 assert isinstance(model.history["elbo_train"], pd.DataFrame) # test model save/load From 2e87be0589730776e22fe3121807c8ceeeb56f62 Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Mon, 26 May 2025 16:20:24 -0500 Subject: [PATCH 09/12] turn on sparse tensor loading by default for gpu --- src/perturbo/models/_model.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/perturbo/models/_model.py b/src/perturbo/models/_model.py index e910456..c0e920e 100644 --- a/src/perturbo/models/_model.py +++ b/src/perturbo/models/_model.py @@ -48,8 +48,6 @@ def __init__( MuData or AnnData object containing the data. control_guides : list or None List of control guide indices (optional, only used for setting initial values). - load_sparse_tensors : bool - Whether to load sparse tensors. dispersion_smoothing : str Smoothing method for dispersion estimation ("none", "linear", "isotonic"). 
smoothing_factor : float @@ -371,7 +369,7 @@ def train( batch_size: int = 1024, early_stopping: bool = False, lr: float | None = 0.005, - load_sparse_tensor: bool = False, + load_sparse_tensor: bool = "auto", training_plan: PyroTrainingPlan = PyroTrainingPlan, plan_kwargs: dict | None = None, data_splitter_kwargs: dict | None = None, @@ -400,6 +398,9 @@ def train( Perform early stopping. lr : float or None Optimizer learning rate. + load_sparse_tensor : bool | "auto" + Whether to transfer data to GPU as sparse tensors (may speed up GPU transfer). + On by default for "gpu" accelerator, otherwise off. training_plan : type Training plan class. plan_kwargs : dict or None @@ -423,7 +424,8 @@ def train( data_splitter_kwargs = {} if "data_and_attributes" not in data_splitter_kwargs: data_splitter_kwargs["data_and_attributes"] = self.data_and_attrs - + if load_sparse_tensor == "auto": + load_sparse_tensor = accelerator == "gpu" if batch_size is None: # use data splitter which moves data to GPU once data_splitter = DeviceBackedDataSplitter( From 866a440cc29af2ebfa8c3ef484a5f4e81b49271b Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Mon, 26 May 2025 16:47:16 -0500 Subject: [PATCH 10/12] use sparse element by gene matrix --- src/perturbo/models/_module.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/perturbo/models/_module.py b/src/perturbo/models/_module.py index 4ac8b4f..d260457 100644 --- a/src/perturbo/models/_module.py +++ b/src/perturbo/models/_module.py @@ -176,6 +176,12 @@ def __init__( ## register hyperparameters as buffers so they get automatically moved to GPU by scvi-tools + if self.local_effects: + if self.sparse_tensors: + self.register_buffer("element_by_gene", gene_by_element.T.to_sparse_coo()) + else: + self.register_buffer("element_by_gene", gene_by_element.T.to_sparse_coo()) + # guide_by_element encoding if self.sparse_tensors: assert gene_by_element.shape[1] == self.n_elements @@ -189,9 +195,7 @@ def __init__( 
self.n_guide_effects = self.guide_by_gene_idx.shape[1] else: - if self.local_effects: - self.register_buffer("element_by_gene", gene_by_element.T) - self.n_element_effects = self.n_guide_effects = 1 + self.n_element_effects = self.n_guide_effects = 1 # for setting plate sizes # global hyperparams self.register_buffer("zero", torch.tensor(0.0)) @@ -366,7 +370,8 @@ def model(self, idx: torch.Tensor, **tensor_dict) -> None: # option 1b: cis effects with factorized trans effects else: element_factor_effects = torch.einsum("fei,fjg->eg", pert_factors, pert_loadings) - element_effects = element_factor_effects + element_local_effects + one = self.one.expand(self.n_elements, self.n_genes) + element_effects = (one - self.element_by_gene) * element_factor_effects + element_local_effects # option 2: trans effects elif not self.n_pert_factors: From ac188bef1485c7997c415c3799164932998fdae1 Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Mon, 26 May 2025 22:11:07 -0500 Subject: [PATCH 11/12] disable efficiency if using n_pert_factors --- src/perturbo/models/_module.py | 11 ++++++++--- tests/test_basic.py | 5 ++++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/perturbo/models/_module.py b/src/perturbo/models/_module.py index d260457..e0bf2a6 100644 --- a/src/perturbo/models/_module.py +++ b/src/perturbo/models/_module.py @@ -121,6 +121,11 @@ def __init__( "max_epochs": 1, # fixes new bug from ipywidgets loading bar on model load } + if self.n_pert_factors and self.local_effects: + assert not self.fit_guide_efficacy, ( + "fit_guide_efficacy must be false if using n_pert_factors and gene_by_element" + ) + self.discrete_sites = [] if efficiency_mode == "mixture": self.discrete_sites.append("perturbed") @@ -403,8 +408,6 @@ def model(self, idx: torch.Tensor, **tensor_dict) -> None: guide_efficiency = pyro.sample( "guide_efficacy", dist.Beta(self.logit_efficacy_alpha, self.logit_efficacy_beta) ) - else: - guide_efficiency = self.one.expand((self.n_perturbations, 
self.n_genes)) # elif self.local_effects: # with guide_plate_sparse: @@ -488,7 +491,9 @@ def model(self, idx: torch.Tensor, **tensor_dict) -> None: # guide_effects = pyro.deterministic("guide_effects", guide_effects) # Account for guide efficiency/efficacy - if self.efficiency_mode == "scaled": + if not self.fit_guide_efficacy: + mean_perturbation_effect = guides_observed @ guide_effects + elif self.efficiency_mode == "scaled": # Ensure dense for matmul (should only trigger if using factors with sparse cis effects) if guide_efficiency.is_sparse and not guide_effects.is_sparse: guide_efficiency = guide_efficiency.to_dense() diff --git a/tests/test_basic.py b/tests/test_basic.py index 328190b..3369d1a 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -36,7 +36,10 @@ def test_model_mdata( ): """Check that we can register our MuData object with our model and perform training""" if use_gene_by_element and not use_guide_by_element: - pytest.skip("gene_by_element without guide_by_element test not implemented!") + pytest.skip("gene_by_element without guide_by_element not implemented!") + + if n_pert_factors and fit_guide_efficacy: + pytest.skip("cannot fit guide efficacy if using n_pert_factors!") pyro.clear_param_store() perturbo.PERTURBO.setup_mudata( From 33686036dcf85bd819449b11a8bb10653781f75d Mon Sep 17 00:00:00 2001 From: Logan Blaine Date: Mon, 26 May 2025 22:23:45 -0500 Subject: [PATCH 12/12] disable efficiency if using n_pert_factors (for now) --- src/perturbo/models/_module.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/perturbo/models/_module.py b/src/perturbo/models/_module.py index e0bf2a6..d78c881 100644 --- a/src/perturbo/models/_module.py +++ b/src/perturbo/models/_module.py @@ -116,15 +116,13 @@ def __init__( self.n_cells = n_cells self.n_genes = n_genes self.n_perturbations = n_perturbations - self.n_cont_covariates = 1 # include (inferred) size factor as covariate always + self.n_cont_covariates = 1 # 
include size factor as covariate always self.on_load_kwargs = { "max_epochs": 1, # fixes new bug from ipywidgets loading bar on model load } - if self.n_pert_factors and self.local_effects: - assert not self.fit_guide_efficacy, ( - "fit_guide_efficacy must be false if using n_pert_factors and gene_by_element" - ) + if self.n_pert_factors: + assert not self.fit_guide_efficacy, "fit_guide_efficacy must be False if using n_pert_factors" self.discrete_sites = [] if efficiency_mode == "mixture": @@ -560,7 +558,6 @@ def model(self, idx: torch.Tensor, **tensor_dict) -> None: ) covariate_effects = cont_covariates @ cont_covariate_effect_size - # nb_log_mean_ctrl = gene_base_log_mean + size_factor + batch_effects + covariate_effects nb_log_mean_ctrl = ( gene_base_log_mean + size_factor + batch_effects + covariate_effects + cell_factor_effects )