Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Perturbo_reproducibility
Submodule Perturbo_reproducibility updated from 7a20b9 to 34328f
230 changes: 158 additions & 72 deletions src/perturbo/models/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,30 @@ class PERTURBO(PyroSviTrainMixin, PyroSampleMixin, BaseModelClass):
def __init__(
self,
mdata: AnnOrMuData,
control_guides=None,
load_sparse_tensors=False,
dispersion_smoothing="none",
smoothing_factor=0.3,
control_guides: list | None = None,
load_sparse_tensors: bool = False,
dispersion_smoothing: str = "none",
smoothing_factor: float = 0.3,
**model_kwargs,
):
"""
Initialize the PERTURBO model.

Parameters
----------
mdata : AnnOrMuData
MuData or AnnData object containing the data.
control_guides : list or None
List of control guide indices (optional, only used for setting initial values).
load_sparse_tensors : bool
Whether to load sparse tensors.
dispersion_smoothing : str
Smoothing method for dispersion estimation ("none", "linear", "isotonic").
smoothing_factor : float
Smoothing factor for dispersion smoothing.
model_kwargs : dict
Additional keyword arguments for the model.
"""
super().__init__(mdata)

# data fields that will be loaded/mini-batched into the module
Expand Down Expand Up @@ -121,7 +139,20 @@ def __init__(

logger.info("The model has been initialized")

def read_matrix_from_registry(self, registry_key):
def read_matrix_from_registry(self, registry_key: str) -> torch.Tensor:
"""
Reads a matrix from the AnnDataManager registry and converts it to a torch.Tensor.

Parameters
----------
registry_key : str
Key to retrieve the matrix from the registry.

Returns
-------
torch.Tensor
The matrix as a torch tensor.
"""
data = self.adata_manager.get_from_registry(registry_key)
if isinstance(data, DataFrame):
data = data.values
Expand Down Expand Up @@ -151,44 +182,45 @@ def setup_mudata(
guide_by_element_key: str | None = None,
library_size_key: str | None = None,
size_factor_key: str | None = None,
gene_mean_key: str | None = None,
gene_mean_key: str | None = None, # not used, supported for legacy reasons
continuous_covariates_keys: str | None = None,
categorical_covariates_keys: str | None = None,
categorical_covariates_keys: str | None = None, # not used, supported for legacy reasons
modalities: dict[str, str] | None = None,
**kwargs,
):
"""Registers data from a MuData object with the model.
"""
Registers data from a MuData object with the model.

Parameters
----------
mdata
mdata : MuData
A MuData object containing the perturbations and observational data.
rna_layer
The key of the MuData modality containing the RNA counts
perturbation_layer
The key of the MuData modality containing the perturbations
batch_key
Key within the RNA AnnData .obs corresponding to the experimental batch
gene_by_element_key
.varm key within the RNA AnnData object containing a mask of which genes can be affected by which genetic elements
guide_by_element_key
.varm key within the perturbation AnnData object containing which perturbations target which genetic elements
rna_element_uns_key
.uns key within the RNA AnnData object containing names of perturbed elements (if using GENE_BY_ELEMENT_KEY),
otherwise automatically inferred from column names if .varm object is a DataFrame
guide_element_uns_key
.uns key within the perturbation AnnData object containing names of perturbed elements
(if using GUIDE_BY_ELEMENT_KEY), otherwise automatically inferred from column names if .varm object is a DataFrame
library_size_key
.obs key within the RNA AnnData object containing raw (not log-scaled) library size factors for each sample
size_factor_key
.obs key within the RNA AnnData object containing library size factors for each sample (e.g. log-library size)
continuous_covariates_keys
list of .obs keys within the RNA AnnData object containing other continuous covariates to be "regressed out"
modalities
A dict containing these same setup arguments
kwargs
Additional keyword arguments
rna_layer : str or None
The key of the MuData modality containing the RNA counts.
perturbation_layer : str or None
The key of the MuData modality containing the perturbations.
batch_key : str or None
Key within the RNA AnnData .obs corresponding to the experimental batch.
gene_by_element_key : str or None
.varm key within the RNA AnnData object containing a mask of which genes can be affected by which genetic elements.
rna_element_uns_key : str or None
.uns key within the RNA AnnData object containing names of perturbed elements.
guide_element_uns_key : str or None
.uns key within the perturbation AnnData object containing names of perturbed elements.
guide_by_element_key : str or None
.varm key within the perturbation AnnData object containing which perturbations target which genetic elements.
library_size_key : str or None
.obs key within the RNA AnnData object containing raw library size factors for each sample.
size_factor_key : str or None
.obs key within the RNA AnnData object containing library size factors for each sample.
gene_mean_key : str or None
.var key for gene mean (legacy, for simulator).
continuous_covariates_keys : str or None
List of .obs keys within the RNA AnnData object containing other continuous covariates.
modalities : dict[str, str] or None
A dict containing these same setup arguments.
kwargs : dict
Additional keyword arguments.
"""
setup_method_args = cls._get_setup_method_args(**locals())

Expand Down Expand Up @@ -326,36 +358,37 @@ def train(

Parameters
----------
max_epochs
Number of passes through the dataset. If `None`, defaults to
`np.min([round((20000 / n_cells) * 400), 400])`
%(param_use_gpu)s
%(param_accelerator)s
%(param_device)s
train_size
Size of training set in the range [0.0, 1.0].
validation_size
Size of the test set. If `None`, defaults to 1 - `train_size`. If
`train_size + validation_size < 1`, the remaining cells belong to a test set.
shuffle_set_split
Whether to shuffle indices before splitting. If `False`, the val, train, and test set are split in the
sequential order of the data according to `validation_size` and `train_size` percentages.
batch_size
Minibatch size to use during training. If `None`, no minibatching occurs and all
data is copied to device (e.g., GPU).
early_stopping
Perform early stopping. Additional arguments can be passed in `**kwargs`.
See :class:`~scvi.train.Trainer` for further options.
lr
Optimiser learning rate (default optimiser is :class:`~pyro.optim.ClippedAdam`).
Specifying optimiser via plan_kwargs overrides this choice of lr.
training_plan
Training plan :class:`~scvi.train.PyroTrainingPlan`.
plan_kwargs
Keyword args for :class:`~scvi.train.PyroTrainingPlan`. Keyword arguments passed to
`train()` will overwrite values present in `plan_kwargs`, when appropriate.
**trainer_kwargs
Other keyword args for :class:`~scvi.train.Trainer`.
max_epochs : int
Number of passes through the dataset.
accelerator : str
Accelerator type ("cpu", "gpu", etc.).
device : int or str
Device identifier.
train_size : float
Size of training set in the range [0.0, 1.0]. All cells by default.
validation_size : float or None
Size of the validation set. Zero cells by default.
shuffle_set_split : bool
Whether to shuffle indices before splitting.
batch_size : int
Minibatch size to use during training.
early_stopping : bool
Perform early stopping.
lr : float or None
Optimizer learning rate.
training_plan : type
Training plan class.
plan_kwargs : dict or None
Keyword args for the training plan.
data_splitter_kwargs : dict or None
Keyword args for the data splitter.
trainer_kwargs : dict
Other keyword args for the Trainer.

Returns
-------
Any
The result of the training runner.
"""
plan_kwargs = plan_kwargs if plan_kwargs is not None else {}
if len(self.module.discrete_sites) > 0:
Expand Down Expand Up @@ -407,15 +440,30 @@ def train(
)
return runner()

def get_element_names(self):
def get_element_names(self) -> list:
    """
    Return the names of the targeted elements.

    Prefers the guide-by-element registry when it has been registered;
    otherwise falls back to the perturbation registry's column names.

    Returns
    -------
    list
        List of element names.
    """
    manager = self.adata_manager
    # Choose which state registry holds the element names.
    if REGISTRY_KEYS.GUIDE_BY_ELEMENT_KEY in manager.data_registry:
        registry_key = REGISTRY_KEYS.GUIDE_BY_ELEMENT_KEY
    else:
        registry_key = REGISTRY_KEYS.PERTURBATION_KEY
    return manager.get_state_registry(registry_key).column_names

def get_element_effects(self):
"""Return a DataFrame summary of the effects for targeted elements on each gene."""
def get_element_effects(self) -> pd.DataFrame:
"""
Return a DataFrame summary of the effects for targeted elements on each gene.

Returns
-------
pd.DataFrame
DataFrame with columns for effect location, scale, element, gene, z-value, and q-value.
"""
element_ids = self.get_element_names()
gene_ids = self.adata_manager.get_state_registry("X").column_names
for guide in self.module.guide:
Expand Down Expand Up @@ -461,7 +509,20 @@ def make_long_df(mat, value_name):

return element_effects.sort_values("z_value")

def get_map_labels(self, indices: list | None = None):
def get_map_labels(self, indices: list | None = None) -> np.ndarray:
"""
Get MAP (maximum a posteriori) labels for the discrete latent variable "perturbed".

Parameters
----------
indices : list or None
Indices of the data subset to use.

Returns
-------
np.ndarray
Array of MAP labels.
"""
args, kwargs = self._get_data_subset(indices)
self.module.guide(*args, **kwargs)
guide_trace = poutine.trace(self.module.guide).get_trace(*args, **kwargs) # record the globals
Expand All @@ -474,7 +535,20 @@ def get_map_labels(self, indices: list | None = None):
map_labels = serving_model_trace.nodes["perturbed"]["value"].squeeze().cpu().numpy()
return map_labels

def _get_data_subset(self, indices: list | None = None):
def _get_data_subset(self, indices: list | None = None) -> tuple:
"""
Get a data subset for inference.

Parameters
----------
indices : list or None
Indices of the data subset.

Returns
-------
tuple
Tuple of (args, kwargs) for the model.
"""
loader = AnnDataLoader(
adata_manager=self.adata_manager,
indices=indices,
Expand Down Expand Up @@ -512,11 +586,23 @@ def _get_data_subset(self, indices: list | None = None):
# return samples


def estimate_nb_params(X, smoothing="isotonic"):
def estimate_nb_params(
X: np.ndarray | sp.spmatrix, smoothing: str = "isotonic"
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Estimate NB mean and dispersion (MoM) for each gene (column) in count matrix X.

With optional smoothing: 'linear', 'quadratic', or 'isotonic' (monotonic).
Parameters
----------
X : np.ndarray or scipy.sparse.spmatrix
Count matrix (cells x genes).
smoothing : str
Smoothing method: 'linear', 'quadratic', or 'isotonic'.

Returns
-------
tuple of np.ndarray
log_means, log_disp, log_disp_smoothed
"""
if sp.issparse(X):
means = np.array(X.mean(axis=0)).flatten()
Expand Down
Loading
Loading