diff --git a/src/perturbo/models/_model.py b/src/perturbo/models/_model.py index 5faef88..85c82a9 100644 --- a/src/perturbo/models/_model.py +++ b/src/perturbo/models/_model.py @@ -32,6 +32,7 @@ def __init__( self, mdata: AnnOrMuData, control_guides=None, + load_sparse_tensors=False, **model_kwargs, ): super().__init__(mdata) @@ -45,6 +46,8 @@ def __init__( REGISTRY_KEYS.INDICES_KEY: np.int64, } + self.load_sparse_tensors = load_sparse_tensors + n_extra_continuous_covs = 0 if "n_extra_continuous_covs" in self.summary_stats: n_extra_continuous_covs = self.summary_stats.n_extra_continuous_covs @@ -478,6 +481,7 @@ def _get_data_subset(self, indices: list | None = None): indices=indices, batch_size=len(indices) if indices is not None else len(self.adata), data_and_attributes=self.data_and_attrs, + load_sparse_tensor=self.load_sparse_tensors, ) return self.module._get_fn_args_from_batch(next(iter(loader))) diff --git a/src/perturbo/models/_module.py b/src/perturbo/models/_module.py index a043a0b..1936c45 100644 --- a/src/perturbo/models/_module.py +++ b/src/perturbo/models/_module.py @@ -1,3 +1,4 @@ +import warnings from collections.abc import Mapping from typing import Literal @@ -19,19 +20,6 @@ def sample(self, sample_shape=torch.Size()): return dist.NegativeBinomial(total_count=self.total_count, logits=self.logits + normals).sample() -class LazyInitToValue: - def __init__(self, module, name_map): - self.module = module - self.name_map = name_map # e.g., {"my_param": "my_init"} - - def __call__(self, site): - name = site["name"] - if name in self.name_map: - buf = getattr(self.module, self.name_map[name]) - return buf.to(site["fn"].support.device) - return None # fallback to Pyro default - - class PerTurboPyroModule(PyroBaseModuleClass): def __init__( self, @@ -55,6 +43,7 @@ def __init__( # dispersion_effects: bool = False, fit_guide_efficacy: bool = True, prior_param_dict: Mapping[str, torch.Tensor] | None = None, + **module_kwargs, ) -> None: """ Pyro module underlying perturbo. @@ -88,6 +77,9 @@ def __init__( super().__init__() # set user-defined options for model behavior # self.dispersion_effects = dispersion_effects + for k, v in module_kwargs.items(): + warnings.warn(f"Unused module_kwargs: {k}", stacklevel=2) + self.likelihood = likelihood self.fit_guide_efficacy = fit_guide_efficacy self.lnnb_quad_points = 8 @@ -257,6 +249,7 @@ def model(self, idx, **tensor_dict): gene_plate, cont_covariate_plate, element_effects_plate, # sparse mode + # cell_factor_plate, pert_factor_plate, ) = self.create_plates(idx) @@ -346,7 +339,9 @@ def model(self, idx, **tensor_dict): if self.local_effects: # override factor effects - element_effects = (1 - self.element_by_gene) * element_factor_effects + element_local_effects + element_effects = ( + torch.ones(self.element_by_gene.shape) - self.element_by_gene + ) * element_factor_effects + element_local_effects # guide_factor_efects = self.guide_by_element @ ((1 - self.element_by_gene) * element_factor_effects) # guide_local_effects = (guide_efficacy * self.guide_by_element) @ element_local_effects # guide_effects = guide_factor_efects + guide_local_effects diff --git a/src/perturbo/simulation/_support_functions.py b/src/perturbo/simulation/_support_functions.py index 7af75d0..90dbfdf 100644 --- a/src/perturbo/simulation/_support_functions.py +++ b/src/perturbo/simulation/_support_functions.py @@ -37,7 +37,7 @@ def mudata_filtering( print("Please indicate the correct guide_by_element_key.") return - rna = mdata[rna_modality].X.toarray() + rna = mdata[rna_modality].X # grna = mdata[grna_modality].X.toarray() guide_by_element = mdata[grna_modality].varm[guide_by_element_key] if isinstance(guide_by_element, pd.DataFrame): @@ -63,7 +63,10 @@ def mudata_filtering( for col in cols: # Extract the column from rna and convert to dense format - rna_col = rna[:, col].flatten() + rna_col = rna[:, col] + if issparse(rna_col): + rna_col = rna_col.toarray() + rna_col = rna_col.flatten() # print(rna_col.shape) # Condition 1: Number of non-zero values in 'rna' where 'element' has entry 1 should be >= n_nonzero_trt_thresh