Skip to content
8 changes: 4 additions & 4 deletions src/perturbo/models/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
log_disp = np.where(np.isfinite(log_disp), log_disp, log_disp_smoothed)
if dispersion_smoothing != "none":
log_disp_smoothed = smoothing_factor * log_disp_smoothed + (1 - smoothing_factor) * log_disp
log_means = np.clip(log_means, 1 / X.shape[0], None)
log_means = np.clip(log_means, a_min=np.log(1 / X.shape[0]), a_max=None)

# if control_guides is not None and "n_factors" in model_kwargs and guide_by_element is not None:
# # control_guides, _ = torch.max(guide_by_element[:, control_elements], dim=-1)
Expand All @@ -106,8 +106,8 @@ def __init__(
n_genes=self.summary_stats.n_vars,
n_cont_covariates=n_extra_continuous_covs,
n_elements=n_elements,
log_gene_mean_init=torch.tensor(log_means),
log_gene_dispersion_init=torch.tensor(log_disp_smoothed),
log_gene_mean_init=torch.tensor(log_means, dtype=torch.float32),
log_gene_dispersion_init=torch.tensor(log_disp_smoothed, dtype=torch.float32),
guide_by_element=guide_by_element,
gene_by_element=gene_by_element,
# n_cats_per_cov=n_cats_per_cov,
Expand Down Expand Up @@ -425,7 +425,7 @@ def get_element_effects(self):
# scale_values = loc_plus_scale_values - loc_values
# loc_values, scale_values = self.module.guide._get_loc_and_scale("element_effects")

if self.module.local_effects:
if hasattr(self.module, "element_by_gene_idx"):
# loc/scale_values are the nonzero elements of a sparse matrix of elements by genes
i, j = self.module.element_by_gene_idx.detach().cpu().numpy().astype(int)

Expand Down
131 changes: 79 additions & 52 deletions src/perturbo/models/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
# use_interactions: bool = False,
efficiency_mode: Literal["mixture", "scaled"] = "scaled",
# dispersion_effects: bool = False,
sparse_local_effects=False,
fit_guide_efficacy: bool = True,
prior_param_dict: Mapping[str, torch.Tensor] | None = None,
**module_kwargs,
Expand Down Expand Up @@ -77,7 +78,7 @@ def __init__(
super().__init__()
# set user-defined options for model behavior
# self.dispersion_effects = dispersion_effects
for k, v in module_kwargs.items():
for k in module_kwargs:
warnings.warn(f"Unused module_kwargs: {k}", stacklevel=2)

self.likelihood = likelihood
Expand All @@ -88,7 +89,7 @@ def __init__(
self.effect_prior_dist = effect_prior_dist
# self.use_interactions = use_interactions
self.efficiency_mode = efficiency_mode
self.local_effects = gene_by_element is not None
self.local_effects = sparse_local_effects
# self.guide_noise = guide_noise

# copy data summary stats
Expand All @@ -114,7 +115,6 @@ def __init__(

self.n_batches = n_batches

# self.delta_sites = ["log_gene_mean"]
self.delta_sites = []
# self.delta_sites = ["cell_factors"]
# self.delta_sites = ["cell_factors", "cell_loadings", "pert_factors", "pert_loadings"]
Expand Down Expand Up @@ -150,12 +150,17 @@ def __init__(
# guide_by_element encoding
self.register_buffer("guide_by_element", guide_by_element)

if gene_by_element is not None:
self.register_buffer("element_by_gene", gene_by_element.T)
else:
self.element_by_gene = None

if self.local_effects:
assert gene_by_element.shape[1] == self.n_elements
self.register_buffer("element_by_gene", gene_by_element.T)
self.register_buffer("element_by_gene_idx", gene_by_element.T.to_sparse_coo().indices())
# self.register_buffer("guide_by_gene_idx", (guide_by_element @ gene_by_element.T).to_sparse_coo().indices())
self.register_buffer("guide_by_gene_idx", (guide_by_element @ gene_by_element.T).to_sparse_coo().indices())
self.n_element_effects = self.element_by_gene_idx.shape[1] if self.local_effects else 1
self.n_guide_effects = self.guide_by_gene_idx.shape[1] if self.local_effects else 1

# global hyperparams
self.register_buffer("zero", torch.tensor(0.0))
Expand Down Expand Up @@ -236,7 +241,8 @@ def create_plates(self, idx, **tensor_dict):
pyro.plate("Genes", self.n_genes, dim=-1),
pyro.plate("Covariates", self.n_cont_covariates, dim=-2),
pyro.plate("Elements_sparse", self.n_element_effects, dim=-1),
# pyro.plate("Cell_factors", self.n_factors, dim=-3),
pyro.plate("Guides_sparse", self.n_guide_effects, dim=-1),
pyro.plate("Cell_factors", self.n_factors, dim=-3),
pyro.plate("Pert_factors", self.n_pert_factors, dim=-3),
)

Expand All @@ -250,7 +256,8 @@ def model(self, idx, **tensor_dict):
gene_plate,
cont_covariate_plate,
element_effects_plate, # sparse mode
# cell_factor_plate,
guide_plate_sparse,
cell_factor_plate,
pert_factor_plate,
) = self.create_plates(idx)

Expand Down Expand Up @@ -286,9 +293,20 @@ def model(self, idx, **tensor_dict):

# Pool guide information based on user-specified strategy
if not self.fit_guide_efficacy:
guide_efficacy = self.one.expand((self.n_perturbations, 1))
guide_efficacy = self.one.expand((self.n_perturbations, self.n_genes))
# elif self.local_effects:
# with guide_plate_sparse:
# guide_efficacy_sparse = pyro.sample(
# "guide_efficacy",
# dist.Beta(self.logit_efficacy_alpha, self.logit_efficacy_beta),
# )
# guide_efficacy = torch.sparse_coo_tensor(
# self.guide_by_gene_idx,
# guide_efficacy_sparse,
# size=(self.n_guides, self.n_genes),
# )
else:
with guide_plate:
with guide_plate, gene_plate:
guide_efficacy = pyro.sample(
"guide_efficacy",
dist.Beta(self.logit_efficacy_alpha, self.logit_efficacy_beta),
Expand All @@ -309,41 +327,45 @@ def model(self, idx, **tensor_dict):
element_factor_effects = torch.einsum("fei,fjg->eg", pert_factors, pert_loadings)

# # Sample cell-specific factors (linear unobserved confounders) if using
# if self.n_factors is not None:
# with cell_factor_plate, cell_plate:
# cell_factors = pyro.sample(
# "cell_factors",
# dist.Laplace(0.0, self.cell_factor_prior_scale),
# )
# with cell_factor_plate, gene_plate:
# cell_loadings = pyro.sample(
# "cell_loadings",
# dist.Laplace(0.0, self.cell_loading_prior_scale),
# )
# # cell_factor_scale_term = pyro.sample(
# # "cell_factor_scale_term",
# # dist.LogNormal(self.zero, self.one),
# # )
# cell_factor_effects = torch.einsum("fci,fjg->cg", cell_factors, cell_loadings)

# # if self.use_interactions and self.n_pert_factors is not None:
# # with cell_factor_plate, element_plate:
# # pert_cell_factors = pyro.sample(
# # "pert_cell_factors",
# # dist.Laplace(0.0, self.cell_factor_prior_scale),
# # )
# # element_factor_effects = (
# # torch.einsum("fei,fjg->eg", pert_cell_factors, cell_loadings) + element_factor_effects
# # )
# else:
# cell_factor_effects = 0
if self.n_factors is not None:
with cell_factor_plate, cell_plate:
cell_factors = pyro.sample(
"cell_factors",
dist.Laplace(0.0, self.cell_factor_prior_scale),
)
with cell_factor_plate, gene_plate:
cell_loadings = pyro.sample(
"cell_loadings",
dist.Laplace(0.0, self.cell_loading_prior_scale),
)
# cell_factor_scale_term = pyro.sample(
# "cell_factor_scale_term",
# dist.LogNormal(self.zero, self.one),
# )
cell_factor_effects = torch.einsum("fci,fjg->cg", cell_factors, cell_loadings)

# if self.use_interactions and self.n_pert_factors is not None:
# with cell_factor_plate, element_plate:
# pert_cell_factors = pyro.sample(
# "pert_cell_factors",
# dist.Laplace(0.0, self.cell_factor_prior_scale),
# )
# element_factor_effects = (
# torch.einsum("fei,fjg->eg", pert_cell_factors, cell_loadings) + element_factor_effects
# )
else:
cell_factor_effects = 0

if self.local_effects:
# override factor effects
element_effects = (1 - self.element_by_gene) * element_factor_effects + element_local_effects
# guide_factor_efects = self.guide_by_element @ ((1 - self.element_by_gene) * element_factor_effects)
# guide_local_effects = (guide_efficacy * self.guide_by_element) @ element_local_effects
# guide_effects = guide_factor_efects + guide_local_effects
elif self.element_by_gene is not None:
element_effects = (
self.element_by_gene * element_local_effects + (1 - self.element_by_gene) * element_factor_effects
)
else:
element_effects = element_factor_effects + element_local_effects

Expand All @@ -360,20 +382,25 @@ def model(self, idx, **tensor_dict):
# guide_effects = pyro.deterministic("guide_effects", guide_effects)

# Account for cell-specific latent "perturbation status" variable(s)
with cell_plate:
if self.efficiency_mode == "scaled":
perturbed = guide_efficacy.T
elif self.efficiency_mode == "mixture":
pert_prob = guides_observed @ guide_efficacy

if self.efficiency_mode == "scaled":
mean_perturbation_effect = guides_observed @ (guide_efficacy * guide_effects)
elif self.efficiency_mode == "mixture":
pert_prob = guides_observed @ guide_efficacy
assert pert_prob.shape[0] == self.n_cells
assert (pert_prob.shape[1] == 1) or (pert_prob.shape[1] == self.n_genes)
with cell_plate, gene_plate:
perturbed = pyro.sample("perturbed", dist.Bernoulli(pert_prob), infer={"enumerate": "parallel"})
elif self.efficiency_mode == "mixture_high_moi": # for simulation only!
pert_prob = guide_efficacy.expand((self.n_cells, -1, -1)).transpose(-3, -2)
assert pert_prob.shape == (self.n_perturbations, self.n_cells, 1)
mean_perturbation_effect = perturbed * (guides_observed @ guide_effects)
elif self.efficiency_mode == "mixture_high_moi": # for simulation only!
pert_prob = guide_efficacy.expand((self.n_cells, -1, -1)).transpose(-3, -2)
assert pert_prob.shape == (self.n_perturbations, self.n_cells, 1)
with cell_plate:
perturbed = pyro.sample("perturbed", dist.Bernoulli(pert_prob)).squeeze(-1).T
else:
raise Exception("efficiency_mode must be either 'scaled' or 'mixture'")
mean_perturbation_effect = perturbed * guides_observed @ guide_effects
else:
raise Exception("efficiency_mode must be either 'scaled' or 'mixture'")

mean_perturbation_effect = perturbed * guides_observed @ guide_effects
# if self.use_crispr_factor:
# log_guide_counts = torch.log2(1 + guides_observed.sum(dim=-1, keepdim=True))
# with gene_plate:
Expand Down Expand Up @@ -419,10 +446,10 @@ def model(self, idx, **tensor_dict):
)
covariate_effects = cont_covariates @ cont_covariate_effect_size

nb_log_mean_ctrl = gene_base_log_mean + size_factor + batch_effects + covariate_effects
# nb_log_mean_ctrl = (
# gene_base_log_mean + size_factor + batch_effects + covariate_effects + cell_factor_effects
# )
# nb_log_mean_ctrl = gene_base_log_mean + size_factor + batch_effects + covariate_effects
nb_log_mean_ctrl = (
gene_base_log_mean + size_factor + batch_effects + covariate_effects + cell_factor_effects
)

# if not self.dispersion_effects:
# nb_log_dispersion = gene_log_dispersion
Expand Down
2 changes: 1 addition & 1 deletion src/perturbo/simulation/_pyro_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def simulate_data_from_trained_model(
module_kwargs = {
"n_batches": model.module.n_batches,
"n_cont_covariates": model.module.n_cont_covariates - 1, # size factor auto included
# "n_factors": model.module.n_factors,
"n_factors": model.module.n_factors,
# "dispersion_effects": model.module.dispersion_effects,
"likelihood": model.module.likelihood,
"effect_prior_dist": model.module.effect_prior_dist,
Expand Down
Loading