Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/perturbo/models/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
self,
mdata: AnnOrMuData,
control_guides=None,
load_sparse_tensors=False,
**model_kwargs,
):
super().__init__(mdata)
Expand All @@ -45,6 +46,8 @@ def __init__(
REGISTRY_KEYS.INDICES_KEY: np.int64,
}

self.load_sparse_tensors = load_sparse_tensors

n_extra_continuous_covs = 0
if "n_extra_continuous_covs" in self.summary_stats:
n_extra_continuous_covs = self.summary_stats.n_extra_continuous_covs
Expand Down Expand Up @@ -478,6 +481,7 @@ def _get_data_subset(self, indices: list | None = None):
indices=indices,
batch_size=len(indices) if indices is not None else len(self.adata),
data_and_attributes=self.data_and_attrs,
load_sparse_tensor=self.load_sparse_tensors,
)
return self.module._get_fn_args_from_batch(next(iter(loader)))

Expand Down
23 changes: 9 additions & 14 deletions src/perturbo/models/_module.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from collections.abc import Mapping
from typing import Literal

Expand All @@ -19,19 +20,6 @@ def sample(self, sample_shape=torch.Size()):
return dist.NegativeBinomial(total_count=self.total_count, logits=self.logits + normals).sample()


class LazyInitToValue:
def __init__(self, module, name_map):
self.module = module
self.name_map = name_map # e.g., {"my_param": "my_init"}

def __call__(self, site):
name = site["name"]
if name in self.name_map:
buf = getattr(self.module, self.name_map[name])
return buf.to(site["fn"].support.device)
return None # fallback to Pyro default


class PerTurboPyroModule(PyroBaseModuleClass):
def __init__(
self,
Expand All @@ -55,6 +43,7 @@ def __init__(
# dispersion_effects: bool = False,
fit_guide_efficacy: bool = True,
prior_param_dict: Mapping[str, torch.Tensor] | None = None,
**module_kwargs,
) -> None:
"""
Pyro module underlying perturbo.
Expand Down Expand Up @@ -88,6 +77,9 @@ def __init__(
super().__init__()
# set user-defined options for model behavior
# self.dispersion_effects = dispersion_effects
for k, v in module_kwargs.items():
warnings.warn(f"Unused module_kwargs: {k}", stacklevel=2)

self.likelihood = likelihood
self.fit_guide_efficacy = fit_guide_efficacy
self.lnnb_quad_points = 8
Expand Down Expand Up @@ -257,6 +249,7 @@ def model(self, idx, **tensor_dict):
gene_plate,
cont_covariate_plate,
element_effects_plate, # sparse mode
# cell_factor_plate,
pert_factor_plate,
) = self.create_plates(idx)

Expand Down Expand Up @@ -346,7 +339,9 @@ def model(self, idx, **tensor_dict):

if self.local_effects:
# override factor effects
element_effects = (1 - self.element_by_gene) * element_factor_effects + element_local_effects
element_effects = (
torch.ones(self.element_by_gene.shape) - self.element_by_gene
) * element_factor_effects + element_local_effects
# guide_factor_efects = self.guide_by_element @ ((1 - self.element_by_gene) * element_factor_effects)
# guide_local_effects = (guide_efficacy * self.guide_by_element) @ element_local_effects
# guide_effects = guide_factor_efects + guide_local_effects
Expand Down
7 changes: 5 additions & 2 deletions src/perturbo/simulation/_support_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def mudata_filtering(
print("Please indicate the correct guide_by_element_key.")
return

rna = mdata[rna_modality].X.toarray()
rna = mdata[rna_modality].X
# grna = mdata[grna_modality].X.toarray()
guide_by_element = mdata[grna_modality].varm[guide_by_element_key]
if isinstance(guide_by_element, pd.DataFrame):
Expand All @@ -63,7 +63,10 @@ def mudata_filtering(

for col in cols:
# Extract the column from rna and convert to dense format
rna_col = rna[:, col].flatten()
rna_col = rna[:, col]
if issparse(rna_col):
rna_col = rna_col.toarray()
rna_col = rna_col.flatten()
# print(rna_col.shape)

# Condition 1: Number of non-zero values in 'rna' where 'element' has entry 1 should be >= n_nonzero_trt_thresh
Expand Down
Loading