diff --git a/Perturbo_reproducibility b/Perturbo_reproducibility index 7a20b99..34328f0 160000 --- a/Perturbo_reproducibility +++ b/Perturbo_reproducibility @@ -1 +1 @@ -Subproject commit 7a20b99c570551d7e301724764528a7912ad81e2 +Subproject commit 34328f0716305aaeaffb065668b95bc3ed3f4951 diff --git a/src/perturbo/models/_model.py b/src/perturbo/models/_model.py index 43b6866..39295a3 100644 --- a/src/perturbo/models/_model.py +++ b/src/perturbo/models/_model.py @@ -34,12 +34,30 @@ class PERTURBO(PyroSviTrainMixin, PyroSampleMixin, BaseModelClass): def __init__( self, mdata: AnnOrMuData, - control_guides=None, - load_sparse_tensors=False, - dispersion_smoothing="none", - smoothing_factor=0.3, + control_guides: list | None = None, + load_sparse_tensors: bool = False, + dispersion_smoothing: str = "none", + smoothing_factor: float = 0.3, **model_kwargs, ): + """ + Initialize the PERTURBO model. + + Parameters + ---------- + mdata : AnnOrMuData + MuData or AnnData object containing the data. + control_guides : list or None + List of control guide indices (optional, only used for setting initial values). + load_sparse_tensors : bool + Whether to load sparse tensors. + dispersion_smoothing : str + Smoothing method for dispersion estimation ("none", "linear", "isotonic"). + smoothing_factor : float + Smoothing factor for dispersion smoothing. + model_kwargs : dict + Additional keyword arguments for the model. + """ super().__init__(mdata) # data fields that will be loaded/mini-batched into the module @@ -121,7 +139,20 @@ def __init__( logger.info("The model has been initialized") - def read_matrix_from_registry(self, registry_key): + def read_matrix_from_registry(self, registry_key: str) -> torch.Tensor: + """ + Reads a matrix from the AnnDataManager registry and converts it to a torch.Tensor. + + Parameters + ---------- + registry_key : str + Key to retrieve the matrix from the registry. + + Returns + ------- + torch.Tensor + The matrix as a torch tensor. + """ data = self.adata_manager.get_from_registry(registry_key) if isinstance(data, DataFrame): data = data.values @@ -151,44 +182,45 @@ def setup_mudata( guide_by_element_key: str | None = None, library_size_key: str | None = None, size_factor_key: str | None = None, - gene_mean_key: str | None = None, + gene_mean_key: str | None = None, # not used, supported for legacy reasons continuous_covariates_keys: str | None = None, - categorical_covariates_keys: str | None = None, + categorical_covariates_keys: str | None = None, # not used, supported for legacy reasons modalities: dict[str, str] | None = None, **kwargs, ): - """Registers data from a MuData object with the model. + """ + Registers data from a MuData object with the model. Parameters ---------- - mdata + mdata : MuData A MuData object containing the perturbations and observational data. - rna_layer - The key of the MuData modality containing the RNA counts - perturbation_layer - The key of the MuData modality containing the perturbations - batch_key - Key within the RNA AnnData .obs corresponding to the experimental batch - gene_by_element_key - .varm key within the RNA AnnData object containing a mask of which genes can be affected by which genetic elements - guide_by_element_key - .varm key within the perturbation AnnData object containing which perturbations target which genetic elements - rna_element_uns_key - .uns key within the RNA AnnData object containing names of perturbed elements (if using GENE_BY_ELEMENT_KEY), - otherwise automatically inferred from column names if .varm object is a DataFrame - guide_element_uns_key - .uns key within the perturbation AnnData object containing names of perturbed elements - (if using GUIDE_BY_ELEMENT_KEY), otherwise automatically inferred from column names if .varm object is a DataFrame - library_size_key - .obs key within the RNA AnnData object containing raw (not log-scaled) library size factors for each sample - size_factor_key - .obs key within the RNA AnnData object containing library size factors for each sample (e.g. log-library size) - continuous_covariates_keys - list of .obs keys within the RNA AnnData object containing other continuous covariates to be "regressed out" - modalities - A dict containing these same setup arguments - kwargs - Additional keyword arguments + rna_layer : str or None + The key of the MuData modality containing the RNA counts. + perturbation_layer : str or None + The key of the MuData modality containing the perturbations. + batch_key : str or None + Key within the RNA AnnData .obs corresponding to the experimental batch. + gene_by_element_key : str or None + .varm key within the RNA AnnData object containing a mask of which genes can be affected by which genetic elements. + rna_element_uns_key : str or None + .uns key within the RNA AnnData object containing names of perturbed elements. + guide_element_uns_key : str or None + .uns key within the perturbation AnnData object containing names of perturbed elements. + guide_by_element_key : str or None + .varm key within the perturbation AnnData object containing which perturbations target which genetic elements. + library_size_key : str or None + .obs key within the RNA AnnData object containing raw library size factors for each sample. + size_factor_key : str or None + .obs key within the RNA AnnData object containing library size factors for each sample. + gene_mean_key : str or None + .var key for gene mean (legacy, for simulator). + continuous_covariates_keys : str or None + List of .obs keys within the RNA AnnData object containing other continuous covariates. + modalities : dict[str, str] or None + A dict containing these same setup arguments. + kwargs : dict + Additional keyword arguments. """ setup_method_args = cls._get_setup_method_args(**locals()) @@ -326,36 +358,37 @@ def train( Parameters ---------- - max_epochs - Number of passes through the dataset. If `None`, defaults to - `np.min([round((20000 / n_cells) * 400), 400])` - %(param_use_gpu)s - %(param_accelerator)s - %(param_device)s - train_size - Size of training set in the range [0.0, 1.0]. - validation_size - Size of the test set. If `None`, defaults to 1 - `train_size`. If - `train_size + validation_size < 1`, the remaining cells belong to a test set. - shuffle_set_split - Whether to shuffle indices before splitting. If `False`, the val, train, and test set are split in the - sequential order of the data according to `validation_size` and `train_size` percentages. - batch_size - Minibatch size to use during training. If `None`, no minibatching occurs and all - data is copied to device (e.g., GPU). - early_stopping - Perform early stopping. Additional arguments can be passed in `**kwargs`. - See :class:`~scvi.train.Trainer` for further options. - lr - Optimiser learning rate (default optimiser is :class:`~pyro.optim.ClippedAdam`). - Specifying optimiser via plan_kwargs overrides this choice of lr. - training_plan - Training plan :class:`~scvi.train.PyroTrainingPlan`. - plan_kwargs - Keyword args for :class:`~scvi.train.PyroTrainingPlan`. Keyword arguments passed to - `train()` will overwrite values present in `plan_kwargs`, when appropriate. - **trainer_kwargs - Other keyword args for :class:`~scvi.train.Trainer`. + max_epochs : int + Number of passes through the dataset. + accelerator : str + Accelerator type ("cpu", "gpu", etc.). + device : int or str + Device identifier. + train_size : float + Size of training set in the range [0.0, 1.0]. All cells by default. + validation_size : float or None + Size of the validation set. Zero cells by default. + shuffle_set_split : bool + Whether to shuffle indices before splitting. + batch_size : int + Minibatch size to use during training. + early_stopping : bool + Perform early stopping. + lr : float or None + Optimizer learning rate. + training_plan : type + Training plan class. + plan_kwargs : dict or None + Keyword args for the training plan. + data_splitter_kwargs : dict or None + Keyword args for the data splitter. + trainer_kwargs : dict + Other keyword args for the Trainer. + + Returns + ------- + Any + The result of the training runner. """ plan_kwargs = plan_kwargs if plan_kwargs is not None else {} if len(self.module.discrete_sites) > 0: @@ -407,15 +440,30 @@ def train( ) return runner() - def get_element_names(self): + def get_element_names(self) -> list: + """ + Returns the names of the targeted elements. + + Returns + ------- + list + List of element names. + """ if REGISTRY_KEYS.GUIDE_BY_ELEMENT_KEY in self.adata_manager.data_registry: element_ids = self.adata_manager.get_state_registry(REGISTRY_KEYS.GUIDE_BY_ELEMENT_KEY).column_names else: element_ids = self.adata_manager.get_state_registry(REGISTRY_KEYS.PERTURBATION_KEY).column_names return element_ids - def get_element_effects(self): - """Return a DataFrame summary of the effects for targeted elements on each gene.""" + def get_element_effects(self) -> pd.DataFrame: + """ + Return a DataFrame summary of the effects for targeted elements on each gene. + + Returns + ------- + pd.DataFrame + DataFrame with columns for effect location, scale, element, gene, z-value, and q-value. + """ element_ids = self.get_element_names() gene_ids = self.adata_manager.get_state_registry("X").column_names for guide in self.module.guide: @@ -461,7 +509,20 @@ def make_long_df(mat, value_name): return element_effects.sort_values("z_value") - def get_map_labels(self, indices: list | None = None): + def get_map_labels(self, indices: list | None = None) -> np.ndarray: + """ + Get MAP (maximum a posteriori) labels for the discrete latent variable "perturbed". + + Parameters + ---------- + indices : list or None + Indices of the data subset to use. + + Returns + ------- + np.ndarray + Array of MAP labels. + """ args, kwargs = self._get_data_subset(indices) self.module.guide(*args, **kwargs) guide_trace = poutine.trace(self.module.guide).get_trace(*args, **kwargs) # record the globals @@ -474,7 +535,20 @@ def get_map_labels(self, indices: list | None = None): map_labels = serving_model_trace.nodes["perturbed"]["value"].squeeze().cpu().numpy() return map_labels - def _get_data_subset(self, indices: list | None = None): + def _get_data_subset(self, indices: list | None = None) -> tuple: + """ + Get a data subset for inference. + + Parameters + ---------- + indices : list or None + Indices of the data subset. + + Returns + ------- + tuple + Tuple of (args, kwargs) for the model. + """ loader = AnnDataLoader( adata_manager=self.adata_manager, indices=indices, @@ -512,11 +586,23 @@ def _get_data_subset(self, indices: list | None = None): # return samples -def estimate_nb_params(X, smoothing="isotonic"): +def estimate_nb_params( + X: np.ndarray | sp.spmatrix, smoothing: str = "isotonic" +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ Estimate NB mean and dispersion (MoM) for each gene (column) in count matrix X. - With optional smoothing: 'linear', 'quadratic', or 'isotonic' (monotonic). + Parameters + ---------- + X : np.ndarray or scipy.sparse.spmatrix + Count matrix (cells x genes). + smoothing : str + Smoothing method: 'linear', 'quadratic', or 'isotonic'. + + Returns + ------- + tuple of np.ndarray + log_means, log_disp, log_disp_smoothed """ if sp.issparse(X): means = np.array(X.mean(axis=0)).flatten() diff --git a/src/perturbo/models/_module.py b/src/perturbo/models/_module.py index d3fe940..d8496ec 100644 --- a/src/perturbo/models/_module.py +++ b/src/perturbo/models/_module.py @@ -33,15 +33,12 @@ def __init__( log_gene_dispersion_init: torch.Tensor | None = None, guide_by_element: torch.Tensor | None = None, gene_by_element: torch.Tensor | None = None, - # guide_noise: bool = False, likelihood: Literal["nb", "lnnb"] = "nb", effect_prior_dist: Literal["cauchy", "normal_mixture", "normal", "laplace"] = "laplace", n_factors: int | None = None, n_pert_factors: int | None = None, - # use_interactions: bool = False, efficiency_mode: Literal["mixture", "scaled"] = "scaled", - # dispersion_effects: bool = False, - sparse_local_effects=False, + sparse_local_effects: bool = False, fit_guide_efficacy: bool = True, prior_param_dict: Mapping[str, torch.Tensor] | None = None, **module_kwargs, @@ -51,29 +48,45 @@ def __init__( Parameters ---------- - guide_by_element: + n_cells : int or None + Number of cells in the dataset. Provided by model constructor. + n_genes : int or None + Number of genes in the dataset. Provided by model constructor. + n_perturbations : int or None + Number of perturbations (typically CRISPR gRNAs). Provided by model constructor. + n_elements : int or None + Number of targeted elements (typically genomic targets of gRNAs). Provided by model constructor. + n_cont_covariates : int or None + Number of continuous covariates. Provided by model constructor. + n_batches : int or None + Number of batches. Provided by model constructor. + log_gene_mean_init : torch.Tensor or None + Initial values for log gene means. Provided by model constructor. + log_gene_dispersion_init : torch.Tensor or None + Initial values for log gene dispersions. Provided by model constructor. + guide_by_element : torch.Tensor or None Binary array encoding which element(s) are targeted by each guide. - gene_by_element: + gene_by_element : torch.Tensor or None Binary array encoding which element(s) may target each gene *a priori*. - likelihood: + likelihood : Literal["nb", "lnnb"] Observation likelihood, either NegativeBinomial ("nb") or LogNormalNegativeBinomial ("lnnb"). - effect_prior_dist: - Effect size prior, either Cauchy or NormalMixture ("soft" spike & slab) - n_factors: - Number of cell-specific latent factors - n_pert_factors: - Number of perturbation-specific latent factors - use_interactions: - If using cell factors, allow interactions between perturbations and cell factors - efficiency_mode: - Guide efficiency is fraction of cells perturbed ("mixture") or linear scaling of effect size ("scaled"). - "Mixture" mode currently requires (at most) one guide per cell. - dispersion_effects: - Allow for different gene-level dispersion by batch? - prior_params: - dict containing hyperparameter names and tensors to set prior values - merge_guides_mode: - "shared" treats ("partial") across guides targeting same element? + effect_prior_dist : Literal["cauchy", "normal_mixture", "normal", "laplace"] + Effect size prior, either Cauchy, NormalMixture ("soft" spike & slab), Normal, or Laplace. + n_factors : int or None + EXPERIMENTAL: Number of cell-specific latent factors. + n_pert_factors : int or None + EXPERIMENTAL: Number of perturbation-specific latent factors. + efficiency_mode : Literal["mixture", "scaled"] + Guide efficiency is fraction of cells perturbed ("mixture") or fractional scaling of max per-element effect size ("scaled"). + "Mixture" mode currently requires at most one guide observation per cell. + sparse_local_effects : bool + EXPERIMENTAL: If True, use sparse PyTorch matrix for local effects. + fit_guide_efficacy : bool + If True, fit guide efficacy. If False, assume guide efficacy = 1. + prior_param_dict : Mapping[str, torch.Tensor] or None + Dict containing hyperparameter names and tensors to set prior values. + module_kwargs : dict + Additional keyword arguments (unused). """ super().__init__() # set user-defined options for model behavior @@ -87,10 +100,8 @@ def __init__( self.n_factors = n_factors self.n_pert_factors = n_pert_factors self.effect_prior_dist = effect_prior_dist - # self.use_interactions = use_interactions self.efficiency_mode = efficiency_mode self.local_effects = sparse_local_effects - # self.guide_noise = guide_noise # copy data summary stats self.n_cells = n_cells @@ -115,6 +126,7 @@ def __init__( self.n_batches = n_batches + # Sites to approximate with Delta distribution instead of default Normal distribution. self.delta_sites = [] # self.delta_sites = ["cell_factors"] # self.delta_sites = ["cell_factors", "cell_loadings", "pert_factors", "pert_loadings"] @@ -213,7 +225,7 @@ def __init__( self.register_buffer(k, v) @staticmethod - def _get_fn_args_from_batch(tensor_dict): + def _get_fn_args_from_batch(tensor_dict: dict) -> tuple[tuple[torch.Tensor], dict]: fit_size_factor_covariate = False if fit_size_factor_covariate: @@ -231,7 +243,7 @@ def _get_fn_args_from_batch(tensor_dict): # return indices and then the rest of the tensors return (tensor_dict[REGISTRY_KEYS.INDICES_KEY].squeeze(),), tensor_dict - def create_plates(self, idx, **tensor_dict): + def create_plates(self, idx: torch.Tensor, **tensor_dict) -> tuple: # dims = self.infer_data_dims(idx, **tensor_dict) return ( pyro.plate("Cells", self.n_cells, dim=-2, subsample=idx), @@ -246,7 +258,21 @@ def create_plates(self, idx, **tensor_dict): pyro.plate("Pert_factors", self.n_pert_factors, dim=-3), ) - def model(self, idx, **tensor_dict): + def model(self, idx: torch.Tensor, **tensor_dict) -> None: + """ + The probabilistic model definition for perturbo. + + Parameters + ---------- + idx : torch.Tensor + Indices for subsampling cells. + tensor_dict : dict + Dictionary containing all required tensors for the model. + + Returns + ------- + None + """ pyro.module("perturbo", self) ( cell_plate, @@ -255,7 +281,7 @@ def model(self, idx, **tensor_dict): batch_plate, gene_plate, cont_covariate_plate, - element_effects_plate, # sparse mode + element_effects_plate, guide_plate_sparse, cell_factor_plate, pert_factor_plate, @@ -485,5 +511,5 @@ def model(self, idx, **tensor_dict): raise NotImplementedError(f"'{self.likelihood}' likelihood not implemented") @property - def guide(self): + def guide(self) -> AutoGuideList: return self._guide diff --git a/tests/test_basic.py b/tests/test_basic.py index 26d76cb..45f0649 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -39,11 +39,10 @@ def test_model_mdata( pyro.clear_param_store() perturbo.PERTURBO.setup_mudata( mdata, - # size_factor_key="lib_size", + library_size_key="lib_size", batch_key="batch_id", guide_element_uns_key="elements" if use_guide_by_element else None, rna_element_uns_key="elements" if use_gene_by_element else None, - categorical_covariates_keys=["lib_size"], guide_by_element_key=guide_by_element_key if use_guide_by_element else None, gene_by_element_key=gene_by_element_key if use_gene_by_element else None, modalities={