diff --git a/Perturbo_reproducibility b/Perturbo_reproducibility index 7a20b99..34328f0 160000 --- a/Perturbo_reproducibility +++ b/Perturbo_reproducibility @@ -1 +1 @@ -Subproject commit 7a20b99c570551d7e301724764528a7912ad81e2 +Subproject commit 34328f0716305aaeaffb065668b95bc3ed3f4951 diff --git a/src/perturbo/models/_model.py b/src/perturbo/models/_model.py index 43b6866..c0e920e 100644 --- a/src/perturbo/models/_model.py +++ b/src/perturbo/models/_model.py @@ -12,7 +12,7 @@ from scipy.stats import chi2 from scvi._types import AnnOrMuData from scvi.data import AnnDataManager, fields -from scvi.dataloaders import AnnDataLoader, DeviceBackedDataSplitter +from scvi.dataloaders import AnnDataLoader, DataSplitter, DeviceBackedDataSplitter from scvi.model.base import ( BaseModelClass, PyroJitGuideWarmup, @@ -34,12 +34,27 @@ class PERTURBO(PyroSviTrainMixin, PyroSampleMixin, BaseModelClass): def __init__( self, mdata: AnnOrMuData, - control_guides=None, - load_sparse_tensors=False, - dispersion_smoothing="none", - smoothing_factor=0.3, + control_guides: list | None = None, + dispersion_smoothing: str = "none", + smoothing_factor: float = 0.3, **model_kwargs, ): + """ + Initialize the PERTURBO model. + + Parameters + ---------- + mdata : AnnOrMuData + MuData or AnnData object containing the data. + control_guides : list or None + List of control guide indices (optional, only used for setting initial values). + dispersion_smoothing : str + Smoothing method for dispersion estimation ("none", "linear", "isotonic"). + smoothing_factor : float + Smoothing factor for dispersion smoothing. + model_kwargs : dict + Additional keyword arguments for the model. + """ super().__init__(mdata) # data fields that will be loaded/mini-batched into the module @@ -51,8 +66,6 @@ def __init__( REGISTRY_KEYS.INDICES_KEY: np.int64, } - self.load_sparse_tensors = load_sparse_tensors - n_extra_continuous_covs = 0 if "n_extra_continuous_covs" in self.summary_stats: n_extra_continuous_covs = self.summary_stats.n_extra_continuous_covs @@ -121,7 +134,20 @@ def __init__( logger.info("The model has been initialized") - def read_matrix_from_registry(self, registry_key): + def read_matrix_from_registry(self, registry_key: str) -> torch.Tensor: + """ + Reads a matrix from the AnnDataManager registry and converts it to a torch.Tensor. + + Parameters + ---------- + registry_key : str + Key to retrieve the matrix from the registry. + + Returns + ------- + torch.Tensor + The matrix as a torch tensor. + """ data = self.adata_manager.get_from_registry(registry_key) if isinstance(data, DataFrame): data = data.values @@ -151,44 +177,45 @@ def setup_mudata( guide_by_element_key: str | None = None, library_size_key: str | None = None, size_factor_key: str | None = None, - gene_mean_key: str | None = None, + gene_mean_key: str | None = None, # not used, supported for legacy reasons continuous_covariates_keys: str | None = None, - categorical_covariates_keys: str | None = None, + categorical_covariates_keys: str | None = None, # not used, supported for legacy reasons modalities: dict[str, str] | None = None, **kwargs, ): - """Registers data from a MuData object with the model. + """ + Registers data from a MuData object with the model. Parameters ---------- - mdata + mdata : MuData A MuData object containing the perturbations and observational data. - rna_layer - The key of the MuData modality containing the RNA counts - perturbation_layer - The key of the MuData modality containing the perturbations - batch_key - Key within the RNA AnnData .obs corresponding to the experimental batch - gene_by_element_key - .varm key within the RNA AnnData object containing a mask of which genes can be affected by which genetic elements - guide_by_element_key - .varm key within the perturbation AnnData object containing which perturbations target which genetic elements - rna_element_uns_key - .uns key within the RNA AnnData object containing names of perturbed elements (if using GENE_BY_ELEMENT_KEY), - otherwise automatically inferred from column names if .varm object is a DataFrame - guide_element_uns_key - .uns key within the perturbation AnnData object containing names of perturbed elements - (if using GUIDE_BY_ELEMENT_KEY), otherwise automatically inferred from column names if .varm object is a DataFrame - library_size_key - .obs key within the RNA AnnData object containing raw (not log-scaled) library size factors for each sample - size_factor_key - .obs key within the RNA AnnData object containing library size factors for each sample (e.g. log-library size) - continuous_covariates_keys - list of .obs keys within the RNA AnnData object containing other continuous covariates to be "regressed out" - modalities - A dict containing these same setup arguments - kwargs - Additional keyword arguments + rna_layer : str or None + The key of the MuData modality containing the RNA counts. + perturbation_layer : str or None + The key of the MuData modality containing the perturbations. + batch_key : str or None + Key within the RNA AnnData .obs corresponding to the experimental batch. + gene_by_element_key : str or None + .varm key within the RNA AnnData object containing a mask of which genes can be affected by which genetic elements. + rna_element_uns_key : str or None + .uns key within the RNA AnnData object containing names of perturbed elements. + guide_element_uns_key : str or None + .uns key within the perturbation AnnData object containing names of perturbed elements. + guide_by_element_key : str or None + .varm key within the perturbation AnnData object containing which perturbations target which genetic elements. + library_size_key : str or None + .obs key within the RNA AnnData object containing raw library size factors for each sample. + size_factor_key : str or None + .obs key within the RNA AnnData object containing library size factors for each sample. + gene_mean_key : str or None + .var key for gene mean (legacy, for simulator). + continuous_covariates_keys : str or None + List of .obs keys within the RNA AnnData object containing other continuous covariates. + modalities : dict[str, str] or None + A dict containing these same setup arguments. + kwargs : dict + Additional keyword arguments. """ setup_method_args = cls._get_setup_method_args(**locals()) @@ -239,6 +266,32 @@ def setup_mudata( mod_key=modalities.rna_layer, ) + # Check continuous covariates for potential issues + obs_df = mdata[modalities.rna_layer].obs + if continuous_covariates_keys is not None: + for cov in continuous_covariates_keys: + values = obs_df[cov].values + unique_vals = np.unique(values) + std = np.std(values) + is_all_int = np.all(np.equal(np.mod(values, 1), 0)) + is_all_same = len(unique_vals) == 1 + is_binary = np.array_equal(unique_vals, [0, 1]) or np.array_equal(unique_vals, [1, 0]) + + if is_all_same: + logger.warning( + f"Continuous covariate '{cov}' has the same value for all observations. " + "Consider removing this covariate." + ) + elif is_all_int and len(unique_vals) > 1 and not is_binary: + logger.warning( + f"Continuous covariate '{cov}' contains only discrete counts. " + "Consider applying log1p transform followed by z-scoring." + ) + elif std > 10 or std < 0.1: + logger.warning( + f"Continuous covariate '{cov}' has standard deviation {std:.3g}. Consider z-scoring." + ) + covariates_field = fields.MuDataNumericalJointObsField( REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariates_keys, @@ -316,6 +369,7 @@ def train( batch_size: int = 1024, early_stopping: bool = False, lr: float | None = 0.005, + load_sparse_tensor: bool = "auto", training_plan: PyroTrainingPlan = PyroTrainingPlan, plan_kwargs: dict | None = None, data_splitter_kwargs: dict | None = None, @@ -326,36 +380,40 @@ def train( Parameters ---------- - max_epochs - Number of passes through the dataset. If `None`, defaults to - `np.min([round((20000 / n_cells) * 400), 400])` - %(param_use_gpu)s - %(param_accelerator)s - %(param_device)s - train_size - Size of training set in the range [0.0, 1.0]. - validation_size - Size of the test set. If `None`, defaults to 1 - `train_size`. If - `train_size + validation_size < 1`, the remaining cells belong to a test set. - shuffle_set_split - Whether to shuffle indices before splitting. If `False`, the val, train, and test set are split in the - sequential order of the data according to `validation_size` and `train_size` percentages. - batch_size - Minibatch size to use during training. If `None`, no minibatching occurs and all - data is copied to device (e.g., GPU). - early_stopping - Perform early stopping. Additional arguments can be passed in `**kwargs`. - See :class:`~scvi.train.Trainer` for further options. - lr - Optimiser learning rate (default optimiser is :class:`~pyro.optim.ClippedAdam`). - Specifying optimiser via plan_kwargs overrides this choice of lr. - training_plan - Training plan :class:`~scvi.train.PyroTrainingPlan`. - plan_kwargs - Keyword args for :class:`~scvi.train.PyroTrainingPlan`. Keyword arguments passed to - `train()` will overwrite values present in `plan_kwargs`, when appropriate. - **trainer_kwargs - Other keyword args for :class:`~scvi.train.Trainer`. + max_epochs : int + Number of passes through the dataset. + accelerator : str + Accelerator type ("cpu", "gpu", etc.). + device : int or str + Device identifier. + train_size : float + Size of training set in the range [0.0, 1.0]. All cells by default. + validation_size : float or None + Size of the validation set. Zero cells by default. + shuffle_set_split : bool + Whether to shuffle indices before splitting. + batch_size : int + Minibatch size to use during training. + early_stopping : bool + Perform early stopping. + lr : float or None + Optimizer learning rate. + load_sparse_tensor : bool | "auto" + Whether to transfer data to GPU as sparse tensors (may speed up GPU transfer). + On by default for "gpu" accelerator, otherwise off. + training_plan : type + Training plan class. + plan_kwargs : dict or None + Keyword args for the training plan. + data_splitter_kwargs : dict or None + Keyword args for the data splitter. + trainer_kwargs : dict + Other keyword args for the Trainer. + + Returns + ------- + Any + The result of the training runner. """ plan_kwargs = plan_kwargs if plan_kwargs is not None else {} if len(self.module.discrete_sites) > 0: @@ -366,7 +424,8 @@ def train( data_splitter_kwargs = {} if "data_and_attributes" not in data_splitter_kwargs: data_splitter_kwargs["data_and_attributes"] = self.data_and_attrs - + if load_sparse_tensor == "auto": + load_sparse_tensor = accelerator == "gpu" if batch_size is None: # use data splitter which moves data to GPU once data_splitter = DeviceBackedDataSplitter( @@ -378,12 +437,13 @@ def train( **data_splitter_kwargs, ) else: - data_splitter = self._data_splitter_cls( + data_splitter = DataSplitter( self.adata_manager, train_size=train_size, validation_size=validation_size, shuffle_set_split=shuffle_set_split, batch_size=batch_size, + load_sparse_tensor=load_sparse_tensor, **data_splitter_kwargs, ) @@ -407,61 +467,96 @@ def train( ) return runner() - def get_element_names(self): + def get_element_names(self) -> list: + """ + Returns the names of the targeted elements. + + Returns + ------- + list + List of element names. + """ if REGISTRY_KEYS.GUIDE_BY_ELEMENT_KEY in self.adata_manager.data_registry: element_ids = self.adata_manager.get_state_registry(REGISTRY_KEYS.GUIDE_BY_ELEMENT_KEY).column_names else: element_ids = self.adata_manager.get_state_registry(REGISTRY_KEYS.PERTURBATION_KEY).column_names return element_ids - def get_element_effects(self): - """Return a DataFrame summary of the effects for targeted elements on each gene.""" + def get_element_effects(self) -> pd.DataFrame: + """ + Return a DataFrame summary of the effects for targeted elements on each gene. + + Returns + ------- + pd.DataFrame + DataFrame with columns for effect location, scale, element, gene, z-value, and q-value. + """ element_ids = self.get_element_names() gene_ids = self.adata_manager.get_state_registry("X").column_names - for guide in self.module.guide: - if "element_effects" in guide.median(): - loc_values, scale_values = guide._get_loc_and_scale("element_effects") - # loc_values, loc_plus_scale_values = guide.quantiles([0.5, 0.841])["element_effects"] - # scale_values = loc_plus_scale_values - loc_values - # loc_values, scale_values = self.module.guide._get_loc_and_scale("element_effects") - - if hasattr(self.module, "element_by_gene_idx"): - # loc/scale_values are the nonzero elements of a sparse matrix of elements by genes - i, j = self.module.element_by_gene_idx.detach().cpu().numpy().astype(int) - - # pert_ids = self.adata_manager.get_state_registry("perturbations").column_names - element_effects = pd.DataFrame( - { - "loc": loc_values.detach().cpu().numpy(), - "scale": scale_values.detach().cpu().numpy(), - "element": [element_ids[idx] for idx in i], - "gene": [gene_ids[idx] for idx in j], - } + + def make_long_df(mat, value_name): + return ( + pd.DataFrame(data=mat, index=element_ids, columns=gene_ids) + .melt(var_name="gene", value_name=value_name, ignore_index=False) + .reset_index(names="element") ) - else: - # loc/scale_values are dense matrices of elements by genes - def make_long_df(mat, value_name): - return ( - pd.DataFrame(data=mat, index=element_ids, columns=gene_ids) - .melt(var_name="gene", value_name=value_name, ignore_index=False) - .reset_index(names="element") + # Check if all element effects are factorized and raise an error if so + if "element_effects" not in self.module.guide.median(): + logger.warning("All element effects are factorized. Using 'get_factorized_element_effects' instead.") + pert_factors, pert_loadings = self.get_factorized_element_effects() + element_effects = make_long_df(pert_factors.T @ pert_loadings, "loc") + element_effects["scale"] = np.nan + else: + for guide in self.module.guide: + if "element_effects" in guide.median(): + loc_values, scale_values = guide._get_loc_and_scale("element_effects") + # loc_values, loc_plus_scale_values = guide.quantiles([0.5, 0.841])["element_effects"] + # scale_values = loc_plus_scale_values - loc_values + # loc_values, scale_values = self.module.guide._get_loc_and_scale("element_effects") + + if hasattr(self.module, "element_by_gene_idx"): + # loc/scale_values are the nonzero elements of a sparse matrix of elements by genes + i, j = self.module.element_by_gene_idx.detach().cpu().numpy().astype(int) + + # pert_ids = self.adata_manager.get_state_registry("perturbations").column_names + element_effects = pd.DataFrame( + { + "loc": loc_values.detach().cpu().numpy(), + "scale": scale_values.detach().cpu().numpy(), + "element": [element_ids[idx] for idx in i], + "gene": [gene_ids[idx] for idx in j], + } ) + else: + # loc/scale_values are dense matrices of elements by genes - element_effects = pd.merge( - make_long_df(loc_values.detach().cpu().numpy(), "loc"), - make_long_df(scale_values.detach().cpu().numpy(), "scale"), - ) + element_effects = pd.merge( + make_long_df(loc_values.detach().cpu().numpy(), "loc"), + make_long_df(scale_values.detach().cpu().numpy(), "scale"), + ) element_effects = element_effects.assign( z_value=lambda x: x["loc"] / x["scale"], q_value=lambda x: chi2.sf(x["z_value"] * x["z_value"], df=1), - # is_target=lambda x: x["element"].str.split("_", expand=True)[0] == x["gene"], ) - return element_effects.sort_values("z_value") + return element_effects + + def get_map_labels(self, indices: list | None = None) -> np.ndarray: + """ + Get MAP (maximum a posteriori) labels for the discrete latent variable "perturbed". + + Parameters + ---------- + indices : list or None + Indices of the data subset to use. - def get_map_labels(self, indices: list | None = None): + Returns + ------- + np.ndarray + Array of MAP labels. + """ args, kwargs = self._get_data_subset(indices) self.module.guide(*args, **kwargs) guide_trace = poutine.trace(self.module.guide).get_trace(*args, **kwargs) # record the globals @@ -474,13 +569,25 @@ def get_map_labels(self, indices: list | None = None): map_labels = serving_model_trace.nodes["perturbed"]["value"].squeeze().cpu().numpy() return map_labels - def _get_data_subset(self, indices: list | None = None): + def _get_data_subset(self, indices: list | None = None) -> tuple: + """ + Get a data subset for inference. + + Parameters + ---------- + indices : list or None + Indices of the data subset. + + Returns + ------- + tuple + Tuple of (args, kwargs) for the model. + """ loader = AnnDataLoader( adata_manager=self.adata_manager, indices=indices, batch_size=len(indices) if indices is not None else len(self.adata), data_and_attributes=self.data_and_attrs, - load_sparse_tensor=self.load_sparse_tensors, ) return self.module._get_fn_args_from_batch(next(iter(loader))) @@ -511,12 +618,40 @@ def _get_data_subset(self, indices: list | None = None): # ) # return samples + def get_factorized_element_effects(self): + element_ids = self.get_element_names() + gene_ids = self.adata_manager.get_state_registry("X").column_names + + medians = self.module.guide.median() + if "pert_factors" not in medians or "pert_loadings" not in medians: + raise RuntimeError("No perturbation factors found. Use get_element_effects instead") + + pert_factors_2d = medians["pert_factors"].squeeze(-1).detach().cpu().numpy() + assert pert_factors_2d.shape == (self.module.n_pert_factors, self.module.n_elements) + pert_factors_df = pd.DataFrame(pert_factors_2d, columns=element_ids) + pert_loadings_2d = medians["pert_loadings"].squeeze(-2).detach().cpu().numpy() + assert pert_loadings_2d.shape == (self.module.n_pert_factors, self.module.n_genes) + pert_loadings_df = pd.DataFrame(pert_loadings_2d, columns=gene_ids) + return pert_factors_df, pert_loadings_df + -def estimate_nb_params(X, smoothing="isotonic"): +def estimate_nb_params( + X: np.ndarray | sp.spmatrix, smoothing: str = "isotonic" +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ Estimate NB mean and dispersion (MoM) for each gene (column) in count matrix X. - With optional smoothing: 'linear', 'quadratic', or 'isotonic' (monotonic). + Parameters + ---------- + X : np.ndarray or scipy.sparse.spmatrix + Count matrix (cells x genes). + smoothing : str + Smoothing method: 'linear', 'quadratic', or 'isotonic'. + + Returns + ------- + tuple of np.ndarray + log_means, log_disp, log_disp_smoothed """ if sp.issparse(X): means = np.array(X.mean(axis=0)).flatten() diff --git a/src/perturbo/models/_module.py b/src/perturbo/models/_module.py index d3fe940..d78c881 100644 --- a/src/perturbo/models/_module.py +++ b/src/perturbo/models/_module.py @@ -33,15 +33,12 @@ def __init__( log_gene_dispersion_init: torch.Tensor | None = None, guide_by_element: torch.Tensor | None = None, gene_by_element: torch.Tensor | None = None, - # guide_noise: bool = False, likelihood: Literal["nb", "lnnb"] = "nb", effect_prior_dist: Literal["cauchy", "normal_mixture", "normal", "laplace"] = "laplace", n_factors: int | None = None, n_pert_factors: int | None = None, - # use_interactions: bool = False, efficiency_mode: Literal["mixture", "scaled"] = "scaled", - # dispersion_effects: bool = False, - sparse_local_effects=False, + sparse_effect_tensors: bool | Literal["auto"] = "auto", fit_guide_efficacy: bool = True, prior_param_dict: Mapping[str, torch.Tensor] | None = None, **module_kwargs, @@ -51,29 +48,45 @@ def __init__( Parameters ---------- - guide_by_element: + n_cells : int or None + Number of cells in the dataset. Provided by model constructor. + n_genes : int or None + Number of genes in the dataset. Provided by model constructor. + n_perturbations : int or None + Number of perturbations (typically CRISPR gRNAs). Provided by model constructor. + n_elements : int or None + Number of targeted elements (typically genomic targets of gRNAs). Provided by model constructor. + n_cont_covariates : int or None + Number of continuous covariates. Provided by model constructor. + n_batches : int or None + Number of batches. Provided by model constructor. + log_gene_mean_init : torch.Tensor or None + Initial values for log gene means. Provided by model constructor. + log_gene_dispersion_init : torch.Tensor or None + Initial values for log gene dispersions. Provided by model constructor. + guide_by_element : torch.Tensor or None Binary array encoding which element(s) are targeted by each guide. - gene_by_element: + gene_by_element : torch.Tensor or None Binary array encoding which element(s) may target each gene *a priori*. - likelihood: + likelihood : Literal["nb", "lnnb"] Observation likelihood, either NegativeBinomial ("nb") or LogNormalNegativeBinomial ("lnnb"). - effect_prior_dist: - Effect size prior, either Cauchy or NormalMixture ("soft" spike & slab) - n_factors: - Number of cell-specific latent factors - n_pert_factors: - Number of perturbation-specific latent factors - use_interactions: - If using cell factors, allow interactions between perturbations and cell factors - efficiency_mode: - Guide efficiency is fraction of cells perturbed ("mixture") or linear scaling of effect size ("scaled"). - "Mixture" mode currently requires (at most) one guide per cell. - dispersion_effects: - Allow for different gene-level dispersion by batch? - prior_params: - dict containing hyperparameter names and tensors to set prior values - merge_guides_mode: - "shared" treats ("partial") across guides targeting same element? + effect_prior_dist : Literal["cauchy", "normal_mixture", "normal", "laplace"] + Effect size prior, either Cauchy, NormalMixture ("soft" spike & slab), Normal, or Laplace. + n_factors : int or None + EXPERIMENTAL: Number of cell-specific latent factors. + n_pert_factors : int or None + EXPERIMENTAL: Number of perturbation-specific latent factors. + efficiency_mode : Literal["mixture", "scaled"] + Guide efficiency is fraction of cells perturbed ("mixture") or fractional scaling of max per-element effect size ("scaled"). + "Mixture" mode currently requires at most one guide observation per cell. + sparse_effect_tensors : True | False | "auto" + EXPERIMENTAL: If True, use sparse PyTorch matrix for local effects. + fit_guide_efficacy : bool + If True, fit guide efficacy. If False, assume guide efficacy = 1. + prior_param_dict : Mapping[str, torch.Tensor] or None + Dict containing hyperparameter names and tensors to set prior values. + module_kwargs : dict + Additional keyword arguments (unused). """ super().__init__() # set user-defined options for model behavior @@ -87,17 +100,29 @@ def __init__( self.n_factors = n_factors self.n_pert_factors = n_pert_factors self.effect_prior_dist = effect_prior_dist - # self.use_interactions = use_interactions self.efficiency_mode = efficiency_mode - self.local_effects = sparse_local_effects - # self.guide_noise = guide_noise + self.local_effects = gene_by_element is not None + + if sparse_effect_tensors == "auto": + if gene_by_element is not None: + sparsity = 1.0 - (gene_by_element.count_nonzero().item() / gene_by_element.numel()) + self.sparse_tensors = sparsity > 0.9 + else: + self.sparse_tensors = False + else: + self.sparse_tensors = sparse_effect_tensors and self.local_effects # copy data summary stats self.n_cells = n_cells self.n_genes = n_genes self.n_perturbations = n_perturbations - self.n_cont_covariates = 1 # include (inferred) size factor as covariate always - self.on_load_kwargs = {"max_epochs": 1} # fixes new bug from ipywidgets loading bar on model load + self.n_cont_covariates = 1 # include size factor as covariate always + self.on_load_kwargs = { + "max_epochs": 1, # fixes new bug from ipywidgets loading bar on model load + } + + if self.n_pert_factors: + assert not self.fit_guide_efficacy, "fit_guide_efficacy must be False if using n_pert_factors" self.discrete_sites = [] if efficiency_mode == "mixture": @@ -110,11 +135,18 @@ def __init__( assert n_elements is not None, "n_elements must be specified if not equal to n_guides" self.n_elements = guide_by_element.shape[1] + if self.sparse_tensors: + self.register_buffer("guide_by_element", guide_by_element.to_sparse_coo()) + else: + self.register_buffer("guide_by_element", guide_by_element) + # self.register_buffer("guide_by_element", guide_by_element) + if n_cont_covariates is not None: self.n_cont_covariates += n_cont_covariates self.n_batches = n_batches + # Sites to approximate with Delta distribution instead of default Normal distribution. self.delta_sites = [] # self.delta_sites = ["cell_factors"] # self.delta_sites = ["cell_factors", "cell_loadings", "pert_factors", "pert_loadings"] @@ -147,20 +179,26 @@ def __init__( ## register hyperparameters as buffers so they get automatically moved to GPU by scvi-tools - # guide_by_element encoding - self.register_buffer("guide_by_element", guide_by_element) + if self.local_effects: + if self.sparse_tensors: + self.register_buffer("element_by_gene", gene_by_element.T.to_sparse_coo()) + else: + self.register_buffer("element_by_gene", gene_by_element.T.to_sparse_coo()) - if gene_by_element is not None: - self.register_buffer("element_by_gene", gene_by_element.T) - else: - self.element_by_gene = None + # guide_by_element encoding + if self.sparse_tensors: + assert gene_by_element.shape[1] == self.n_elements + self.register_buffer("element_by_gene_idx", gene_by_element.T.to_sparse_coo().indices()) + self.register_buffer("guide_by_gene_idx", (guide_by_element @ gene_by_element.T).to_sparse_coo().indices()) - if self.local_effects: assert gene_by_element.shape[1] == self.n_elements self.register_buffer("element_by_gene_idx", gene_by_element.T.to_sparse_coo().indices()) self.register_buffer("guide_by_gene_idx", (guide_by_element @ gene_by_element.T).to_sparse_coo().indices()) - self.n_element_effects = self.element_by_gene_idx.shape[1] if self.local_effects else 1 - self.n_guide_effects = self.guide_by_gene_idx.shape[1] if self.local_effects else 1 + self.n_element_effects = self.element_by_gene_idx.shape[1] + self.n_guide_effects = self.guide_by_gene_idx.shape[1] + + else: + self.n_element_effects = self.n_guide_effects = 1 # for setting plate sizes # global hyperparams self.register_buffer("zero", torch.tensor(0.0)) @@ -186,7 +224,7 @@ def __init__( ## element effect size hyperparams # Normal/Laplace/Cauchy prior - effect_prior_scales = {"cauchy": 0.1, "laplace": 0.5, "normal": 2.0} + effect_prior_scales = {"cauchy": 0.1, "laplace": 0.5, "normal": 1.0} model_effect_prior_scale = effect_prior_scales[effect_prior_dist] self.register_buffer("element_effects_prior_scale", torch.tensor(model_effect_prior_scale)) self.register_buffer("guide_effects_prior_scale", torch.tensor(model_effect_prior_scale)) @@ -213,7 +251,7 @@ def __init__( self.register_buffer(k, v) @staticmethod - def _get_fn_args_from_batch(tensor_dict): + def _get_fn_args_from_batch(tensor_dict: dict) -> tuple[tuple[torch.Tensor], dict]: fit_size_factor_covariate = False if fit_size_factor_covariate: @@ -228,10 +266,18 @@ def _get_fn_args_from_batch(tensor_dict): else: tensor_dict[REGISTRY_KEYS.CONT_COVS_KEY] = size_factor + X = tensor_dict[REGISTRY_KEYS.X_KEY] + if X is not None and (X.layout == torch.sparse_csc or X.layout == torch.sparse_csr or X.is_sparse): + tensor_dict[REGISTRY_KEYS.X_KEY] = X.to_dense() + + Y = tensor_dict[REGISTRY_KEYS.PERTURBATION_KEY] + if Y is not None and (Y.layout == torch.sparse_csc or Y.layout == torch.sparse_csr or Y.is_sparse): + tensor_dict[REGISTRY_KEYS.PERTURBATION_KEY] = Y.to_dense() + # return indices and then the rest of the tensors return (tensor_dict[REGISTRY_KEYS.INDICES_KEY].squeeze(),), tensor_dict - def create_plates(self, idx, **tensor_dict): + def create_plates(self, idx: torch.Tensor, **tensor_dict) -> tuple: # dims = self.infer_data_dims(idx, **tensor_dict) return ( pyro.plate("Cells", self.n_cells, dim=-2, subsample=idx), @@ -246,7 +292,21 @@ def create_plates(self, idx, **tensor_dict): pyro.plate("Pert_factors", self.n_pert_factors, dim=-3), ) - def model(self, idx, **tensor_dict): + def model(self, idx: torch.Tensor, **tensor_dict) -> None: + """ + The probabilistic model definition for perturbo. + + Parameters + ---------- + idx : torch.Tensor + Indices for subsampling cells. + tensor_dict : dict + Dictionary containing all required tensors for the model. + + Returns + ------- + None + """ pyro.module("perturbo", self) ( cell_plate, @@ -255,8 +315,8 @@ def model(self, idx, **tensor_dict): batch_plate, gene_plate, cont_covariate_plate, - element_effects_plate, # sparse mode - guide_plate_sparse, + element_effects_plate, + guide_effects_plate, cell_factor_plate, pert_factor_plate, ) = self.create_plates(idx) @@ -278,22 +338,75 @@ def model(self, idx, **tensor_dict): elif self.effect_prior_dist == "laplace": effects_dist = dist.Laplace(0.0, self.element_effects_prior_scale) - # Sample cis/trans effect sizes - if self.local_effects: - with element_effects_plate: - element_local_effects_values = pyro.sample("element_effects", effects_dist) - element_local_effects = torch.sparse_coo_tensor( - self.element_by_gene_idx, - element_local_effects_values, - size=(self.n_elements, self.n_genes), + # sample pert factors (if using) + if self.n_pert_factors is not None: + with pert_factor_plate, element_plate: + pert_factors = pyro.sample( + "pert_factors", + dist.Laplace(0.0, self.pert_factor_prior_scale), ) - else: + with pert_factor_plate, gene_plate: + pert_loadings = pyro.sample( + "pert_loadings", + dist.Laplace(0.0, self.pert_loading_prior_scale), + ) + + # Sample either sparse cis effects or dense cis/trans effects + + # option 1: sparse cis effects + if self.local_effects: + if self.sparse_tensors: + with element_effects_plate: + element_local_effects_values = pyro.sample("element_effects", effects_dist) + element_local_effects = torch.sparse_coo_tensor( + self.element_by_gene_idx, + element_local_effects_values, + size=(self.n_elements, self.n_genes), + ) + else: + with element_plate, gene_plate: + element_local_effects = pyro.sample("element_effects", effects_dist) + element_local_effects = self.element_by_gene * element_local_effects + + if self.n_pert_factors is None: + element_effects = element_local_effects + # option 1b: cis effects with factorized trans effects + else: + element_factor_effects = torch.einsum("fei,fjg->eg", pert_factors, pert_loadings) + one = self.one.expand(self.n_elements, self.n_genes) + element_effects = (one - self.element_by_gene) * element_factor_effects + element_local_effects + + # option 2: trans effects + elif not self.n_pert_factors: with element_plate, gene_plate: - element_local_effects = pyro.sample("element_effects", effects_dist) + element_effects = pyro.sample("element_effects", effects_dist) + if self.local_effects: + element_effects *= self.element_by_gene + + # option 3: factorized cis + trans effects + else: + element_effects = torch.einsum("fei,fjg->eg", pert_factors, pert_loadings) + if self.local_effects: + element_effects *= self.element_by_gene # Pool guide information based on user-specified strategy - if not self.fit_guide_efficacy: - guide_efficacy = self.one.expand((self.n_perturbations, self.n_genes)) + if self.fit_guide_efficacy: + if self.sparse_tensors: + with guide_effects_plate: + guide_efficiency_values = pyro.sample( + "guide_efficacy", dist.Beta(self.logit_efficacy_alpha, self.logit_efficacy_beta) + ) + guide_efficiency = torch.sparse_coo_tensor( + self.guide_by_gene_idx, + guide_efficiency_values, + size=(self.n_perturbations, self.n_genes), + ) + else: + with guide_plate, gene_plate: + guide_efficiency = pyro.sample( + "guide_efficacy", dist.Beta(self.logit_efficacy_alpha, self.logit_efficacy_beta) + ) + # elif self.local_effects: # with guide_plate_sparse: # guide_efficacy_sparse = pyro.sample( @@ -305,26 +418,26 @@ def model(self, idx, **tensor_dict): # guide_efficacy_sparse, # size=(self.n_guides, self.n_genes), # ) - else: - with guide_plate, gene_plate: - guide_efficacy = pyro.sample( - "guide_efficacy", - dist.Beta(self.logit_efficacy_alpha, self.logit_efficacy_beta), - ) + # else: + # with guide_plate, gene_plate: + # guide_efficacy = pyro.sample( + # "guide_efficacy", + # dist.Beta(self.logit_efficacy_alpha, self.logit_efficacy_beta), + # ) - # Sample dense or factorized perturbation effects - if self.n_pert_factors is None: - element_factor_effects = 0 - else: - with pert_factor_plate, element_plate: - pert_factors = pyro.sample("pert_factors", dist.Laplace(self.zero, self.pert_factor_prior_scale)) - with pert_factor_plate, gene_plate: - pert_loadings = pyro.sample("pert_loadings", dist.Laplace(self.zero, self.pert_loading_prior_scale)) - # pert_factor_scale_term = pyro.sample( - # "pert_factor_scale_term", - # dist.LogNormal(-self.one, self.one), - # ) - element_factor_effects = torch.einsum("fei,fjg->eg", pert_factors, pert_loadings) + # # Sample dense or factorized perturbation effects + # if self.n_pert_factors is None: + # element_factor_effects = 0 + # else: + # with pert_factor_plate, element_plate: + # pert_factors = pyro.sample("pert_factors", dist.Laplace(self.zero, self.pert_factor_prior_scale)) + # with pert_factor_plate, gene_plate: + # pert_loadings = pyro.sample("pert_loadings", dist.Laplace(self.zero, self.pert_loading_prior_scale)) + # # pert_factor_scale_term = pyro.sample( + # # "pert_factor_scale_term", + # # dist.LogNormal(-self.one, self.one), + # # ) + # element_factor_effects = torch.einsum("fei,fjg->eg", pert_factors, pert_loadings) # # Sample cell-specific factors (linear unobserved confounders) if using if self.n_factors is not None: @@ -338,10 +451,6 @@ def model(self, idx, **tensor_dict): "cell_loadings", dist.Laplace(0.0, self.cell_loading_prior_scale), ) - # cell_factor_scale_term = pyro.sample( - # "cell_factor_scale_term", - # dist.LogNormal(self.zero, self.one), - # ) cell_factor_effects = torch.einsum("fci,fjg->cg", cell_factors, cell_loadings) # if self.use_interactions and self.n_pert_factors is not None: @@ -356,19 +465,17 @@ def model(self, idx, **tensor_dict): else: cell_factor_effects = 0 - if self.local_effects: - # override factor effects - element_effects = (1 - self.element_by_gene) * element_factor_effects + element_local_effects - # guide_factor_efects = self.guide_by_element @ ((1 - self.element_by_gene) * element_factor_effects) - # guide_local_effects = (guide_efficacy * self.guide_by_element) @ element_local_effects - # guide_effects = guide_factor_efects + guide_local_effects - elif self.element_by_gene is not None: - element_effects = ( - self.element_by_gene * element_local_effects + (1 - self.element_by_gene) * element_factor_effects - ) - else: - element_effects = element_factor_effects + element_local_effects - + # if self.local_effects and self.sparse_tensors: + # element_effects = (1 - self.element_by_gene) * element_factor_effects + element_local_effects + # # guide_factor_efects = self.guide_by_element @ ((1 - self.element_by_gene) * element_factor_effects) + # # guide_local_effects = (guide_efficacy * self.guide_by_element) @ element_local_effects + # # guide_effects = guide_factor_efects + guide_local_effects + # if self.local_effects and not self.sparse_tensors: + # element_effects = ( + # self.element_by_gene * element_local_effects + (1 - self.element_by_gene) * element_factor_effects + # ) + # else: + # element_effects = element_factor_effects + element_local_effects guide_effects = self.guide_by_element @ element_effects # # compute/sample guide effects as function of element effects @@ -381,19 +488,24 @@ def model(self, idx, **tensor_dict): # else: # guide_effects = pyro.deterministic("guide_effects", guide_effects) - # Account for cell-specific latent "perturbation status" variable(s) + # Account for guide efficiency/efficacy + if not self.fit_guide_efficacy: + mean_perturbation_effect = guides_observed @ guide_effects + elif self.efficiency_mode == "scaled": + # Ensure dense for matmul (should only trigger if using factors with sparse cis effects) + if guide_efficiency.is_sparse and not guide_effects.is_sparse: + guide_efficiency = guide_efficiency.to_dense() + mean_perturbation_effect = guides_observed @ (guide_efficiency * guide_effects) - if self.efficiency_mode == "scaled": - mean_perturbation_effect = guides_observed @ (guide_efficacy * guide_effects) elif self.efficiency_mode == "mixture": - pert_prob = guides_observed @ guide_efficacy - assert pert_prob.shape[0] == self.n_cells - assert (pert_prob.shape[1] == 1) or (pert_prob.shape[1] == self.n_genes) + pert_prob = guides_observed @ guide_efficiency + # assert pert_prob.shape[0] == self.n_cells + # assert (pert_prob.shape[1] == 1) or (pert_prob.shape[1] == self.n_genes) with cell_plate, gene_plate: perturbed = pyro.sample("perturbed", dist.Bernoulli(pert_prob), infer={"enumerate": "parallel"}) mean_perturbation_effect = perturbed * (guides_observed @ guide_effects) elif self.efficiency_mode == "mixture_high_moi": # for simulation only! - pert_prob = guide_efficacy.expand((self.n_cells, -1, -1)).transpose(-3, -2) + pert_prob = guide_efficiency.expand((self.n_cells, -1, -1)).transpose(-3, -2) assert pert_prob.shape == (self.n_perturbations, self.n_cells, 1) with cell_plate: perturbed = pyro.sample("perturbed", dist.Bernoulli(pert_prob)).squeeze(-1).T @@ -446,7 +558,6 @@ def model(self, idx, **tensor_dict): ) covariate_effects = cont_covariates @ cont_covariate_effect_size - # nb_log_mean_ctrl = gene_base_log_mean + size_factor + batch_effects + covariate_effects nb_log_mean_ctrl = ( gene_base_log_mean + size_factor + batch_effects + covariate_effects + cell_factor_effects ) @@ -485,5 +596,5 @@ def model(self, idx, **tensor_dict): raise NotImplementedError(f"'{self.likelihood}' likelihood not implemented") @property - def guide(self): + def guide(self) -> AutoGuideList: return self._guide diff --git a/tests/conftest.py b/tests/conftest.py index 3fb8650..5203882 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,6 +23,7 @@ def adata(): { "lib_size": np.random.lognormal(10, 1, size=(n_cells)), "batch_id": np.random.choice(["batch_1", "batch_2", "batch_3"], size=(n_cells)), + "cov1": np.random.normal(size=n_cells), } ) rna_counts = np.random.negative_binomial(100, 0.9, size=(n_cells, n_genes)).astype(np.float32) diff --git a/tests/test_basic.py b/tests/test_basic.py index 26d76cb..3369d1a 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -20,7 +20,8 @@ def test_package_has_version(): @pytest.mark.parametrize("efficiency_mode", ["mixture", "scaled"]) @pytest.mark.parametrize("use_guide_by_element", [True, False]) @pytest.mark.parametrize("use_gene_by_element", [True, False]) -@pytest.mark.parametrize("fit_guide_efficacy", ["partial", "shared"]) +@pytest.mark.parametrize("fit_guide_efficacy", [True, False]) +@pytest.mark.parametrize("sparse_tensors", [True, False]) # @pytest.mark.parametrize("n_factors", [None, 2]) @pytest.mark.parametrize("n_pert_factors", [None, 2]) def test_model_mdata( @@ -30,20 +31,24 @@ def test_model_mdata( use_guide_by_element, use_gene_by_element, fit_guide_efficacy, + sparse_tensors, n_pert_factors, ): """Check that we can register our MuData object with our model and perform training""" if use_gene_by_element and not use_guide_by_element: - pytest.skip("gene_by_element without guide_by_element test not implemented!") + pytest.skip("gene_by_element without guide_by_element not implemented!") + + if n_pert_factors and fit_guide_efficacy: + pytest.skip("cannot fit guide efficacy if using n_pert_factors!") pyro.clear_param_store() perturbo.PERTURBO.setup_mudata( mdata, - # size_factor_key="lib_size", + library_size_key="lib_size", batch_key="batch_id", + continuous_covariates_keys=["cov1"], guide_element_uns_key="elements" if use_guide_by_element else None, rna_element_uns_key="elements" if use_gene_by_element else None, - categorical_covariates_keys=["lib_size"], guide_by_element_key=guide_by_element_key if use_guide_by_element else None, gene_by_element_key=gene_by_element_key if use_gene_by_element else None, modalities={ @@ -57,15 +62,26 @@ def test_model_mdata( n_pert_factors=n_pert_factors, efficiency_mode=efficiency_mode, fit_guide_efficacy=fit_guide_efficacy, + sparse_effect_tensors=sparse_tensors, ) assert model.summary_stats.n_cells == len(mdata) assert model.summary_stats.n_vars == len(mdata[rna_key].var) assert model.summary_stats.n_perturbations == len(mdata[perturb_key].var) - model.train(max_epochs=10, lr=0.1, batch_size=20) + model.train( + max_epochs=5, + lr=0.1, + batch_size=2, + load_sparse_tensor=sparse_tensors, + ) + model.train( + max_epochs=5, + lr=0.1, + batch_size=None, + load_sparse_tensor=sparse_tensors, + ) element_effects = model.get_element_effects() assert isinstance(element_effects, pd.DataFrame) - assert len(model.history["elbo_train"]) == 10 assert isinstance(model.history["elbo_train"], pd.DataFrame) # test model save/load