diff --git a/src/perturbo/models/_model.py b/src/perturbo/models/_model.py index d4f5594..43b6866 100644 --- a/src/perturbo/models/_model.py +++ b/src/perturbo/models/_model.py @@ -89,7 +89,7 @@ def __init__( log_disp = np.where(np.isfinite(log_disp), log_disp, log_disp_smoothed) if dispersion_smoothing != "none": log_disp_smoothed = smoothing_factor * log_disp_smoothed + (1 - smoothing_factor) * log_disp - log_means = np.clip(log_means, 1 / X.shape[0], None) + log_means = np.clip(log_means, a_min=np.log(1 / X.shape[0]), a_max=None) # if control_guides is not None and "n_factors" in model_kwargs and guide_by_element is not None: # # control_guides, _ = torch.max(guide_by_element[:, control_elements], dim=-1) @@ -106,8 +106,8 @@ def __init__( n_genes=self.summary_stats.n_vars, n_cont_covariates=n_extra_continuous_covs, n_elements=n_elements, - log_gene_mean_init=torch.tensor(log_means), - log_gene_dispersion_init=torch.tensor(log_disp_smoothed), + log_gene_mean_init=torch.tensor(log_means, dtype=torch.float32), + log_gene_dispersion_init=torch.tensor(log_disp_smoothed, dtype=torch.float32), guide_by_element=guide_by_element, gene_by_element=gene_by_element, # n_cats_per_cov=n_cats_per_cov, @@ -425,7 +425,7 @@ def get_element_effects(self): # scale_values = loc_plus_scale_values - loc_values # loc_values, scale_values = self.module.guide._get_loc_and_scale("element_effects") - if self.module.local_effects: + if hasattr(self.module, "element_by_gene_idx"): # loc/scale_values are the nonzero elements of a sparse matrix of elements by genes i, j = self.module.element_by_gene_idx.detach().cpu().numpy().astype(int) diff --git a/src/perturbo/models/_module.py b/src/perturbo/models/_module.py index cab1542..d3fe940 100644 --- a/src/perturbo/models/_module.py +++ b/src/perturbo/models/_module.py @@ -41,6 +41,7 @@ def __init__( # use_interactions: bool = False, efficiency_mode: Literal["mixture", "scaled"] = "scaled", # dispersion_effects: bool = False, + sparse_local_effects=False, fit_guide_efficacy: bool = True, prior_param_dict: Mapping[str, torch.Tensor] | None = None, **module_kwargs, @@ -77,7 +78,7 @@ def __init__( super().__init__() # set user-defined options for model behavior # self.dispersion_effects = dispersion_effects - for k, v in module_kwargs.items(): + for k in module_kwargs: warnings.warn(f"Unused module_kwargs: {k}", stacklevel=2) self.likelihood = likelihood @@ -88,7 +89,7 @@ def __init__( self.effect_prior_dist = effect_prior_dist # self.use_interactions = use_interactions self.efficiency_mode = efficiency_mode - self.local_effects = gene_by_element is not None + self.local_effects = sparse_local_effects # self.guide_noise = guide_noise # copy data summary stats @@ -114,7 +115,6 @@ def __init__( self.n_batches = n_batches - # self.delta_sites = ["log_gene_mean"] self.delta_sites = [] # self.delta_sites = ["cell_factors"] # self.delta_sites = ["cell_factors", "cell_loadings", "pert_factors", "pert_loadings"] @@ -150,12 +150,17 @@ def __init__( # guide_by_element encoding self.register_buffer("guide_by_element", guide_by_element) + if gene_by_element is not None: + self.register_buffer("element_by_gene", gene_by_element.T) + else: + self.element_by_gene = None + if self.local_effects: assert gene_by_element.shape[1] == self.n_elements - self.register_buffer("element_by_gene", gene_by_element.T) self.register_buffer("element_by_gene_idx", gene_by_element.T.to_sparse_coo().indices()) - # self.register_buffer("guide_by_gene_idx", (guide_by_element @ gene_by_element.T).to_sparse_coo().indices()) + self.register_buffer("guide_by_gene_idx", (guide_by_element @ gene_by_element.T).to_sparse_coo().indices()) self.n_element_effects = self.element_by_gene_idx.shape[1] if self.local_effects else 1 + self.n_guide_effects = self.guide_by_gene_idx.shape[1] if self.local_effects else 1 # global hyperparams self.register_buffer("zero", torch.tensor(0.0)) @@ -236,7 +241,8 @@ def create_plates(self, idx, **tensor_dict): pyro.plate("Genes", self.n_genes, dim=-1), pyro.plate("Covariates", self.n_cont_covariates, dim=-2), pyro.plate("Elements_sparse", self.n_element_effects, dim=-1), - # pyro.plate("Cell_factors", self.n_factors, dim=-3), + pyro.plate("Guides_sparse", self.n_guide_effects, dim=-1), + pyro.plate("Cell_factors", self.n_factors, dim=-3), pyro.plate("Pert_factors", self.n_pert_factors, dim=-3), ) @@ -250,7 +256,8 @@ def model(self, idx, **tensor_dict): gene_plate, cont_covariate_plate, element_effects_plate, # sparse mode - # cell_factor_plate, + guide_plate_sparse, + cell_factor_plate, pert_factor_plate, ) = self.create_plates(idx) @@ -286,9 +293,20 @@ def model(self, idx, **tensor_dict): # Pool guide information based on user-specified strategy if not self.fit_guide_efficacy: - guide_efficacy = self.one.expand((self.n_perturbations, 1)) + guide_efficacy = self.one.expand((self.n_perturbations, self.n_genes)) + # elif self.local_effects: + # with guide_plate_sparse: + # guide_efficacy_sparse = pyro.sample( + # "guide_efficacy", + # dist.Beta(self.logit_efficacy_alpha, self.logit_efficacy_beta), + # ) + # guide_efficacy = torch.sparse_coo_tensor( + # self.guide_by_gene_idx, + # guide_efficacy_sparse, + # size=(self.n_guides, self.n_genes), + # ) else: - with guide_plate: + with guide_plate, gene_plate: guide_efficacy = pyro.sample( "guide_efficacy", dist.Beta(self.logit_efficacy_alpha, self.logit_efficacy_beta), @@ -309,34 +327,34 @@ def model(self, idx, **tensor_dict): element_factor_effects = torch.einsum("fei,fjg->eg", pert_factors, pert_loadings) # # Sample cell-specific factors (linear unobserved confounders) if using - # if self.n_factors is not None: - # with cell_factor_plate, cell_plate: - # cell_factors = pyro.sample( - # "cell_factors", - # dist.Laplace(0.0, self.cell_factor_prior_scale), - # ) - # with cell_factor_plate, gene_plate: - # cell_loadings = pyro.sample( - # "cell_loadings", - # dist.Laplace(0.0, self.cell_loading_prior_scale), - # ) - # # cell_factor_scale_term = pyro.sample( - # # "cell_factor_scale_term", - # # dist.LogNormal(self.zero, self.one), - # # ) - # cell_factor_effects = torch.einsum("fci,fjg->cg", cell_factors, cell_loadings) - - # # if self.use_interactions and self.n_pert_factors is not None: - # # with cell_factor_plate, element_plate: - # # pert_cell_factors = pyro.sample( - # # "pert_cell_factors", - # # dist.Laplace(0.0, self.cell_factor_prior_scale), - # # ) - # # element_factor_effects = ( - # # torch.einsum("fei,fjg->eg", pert_cell_factors, cell_loadings) + element_factor_effects - # # ) - # else: - # cell_factor_effects = 0 + if self.n_factors is not None: + with cell_factor_plate, cell_plate: + cell_factors = pyro.sample( + "cell_factors", + dist.Laplace(0.0, self.cell_factor_prior_scale), + ) + with cell_factor_plate, gene_plate: + cell_loadings = pyro.sample( + "cell_loadings", + dist.Laplace(0.0, self.cell_loading_prior_scale), + ) + # cell_factor_scale_term = pyro.sample( + # "cell_factor_scale_term", + # dist.LogNormal(self.zero, self.one), + # ) + cell_factor_effects = torch.einsum("fci,fjg->cg", cell_factors, cell_loadings) + + # if self.use_interactions and self.n_pert_factors is not None: + # with cell_factor_plate, element_plate: + # pert_cell_factors = pyro.sample( + # "pert_cell_factors", + # dist.Laplace(0.0, self.cell_factor_prior_scale), + # ) + # element_factor_effects = ( + # torch.einsum("fei,fjg->eg", pert_cell_factors, cell_loadings) + element_factor_effects + # ) + else: + cell_factor_effects = 0 if self.local_effects: # override factor effects @@ -344,6 +362,10 @@ def model(self, idx, **tensor_dict): # guide_factor_efects = self.guide_by_element @ ((1 - self.element_by_gene) * element_factor_effects) # guide_local_effects = (guide_efficacy * self.guide_by_element) @ element_local_effects # guide_effects = guide_factor_efects + guide_local_effects + elif self.element_by_gene is not None: + element_effects = ( + self.element_by_gene * element_local_effects + (1 - self.element_by_gene) * element_factor_effects + ) else: element_effects = element_factor_effects + element_local_effects @@ -360,20 +382,25 @@ def model(self, idx, **tensor_dict): # guide_effects = pyro.deterministic("guide_effects", guide_effects) # Account for cell-specific latent "perturbation status" variable(s) - with cell_plate: - if self.efficiency_mode == "scaled": - perturbed = guide_efficacy.T - elif self.efficiency_mode == "mixture": - pert_prob = guides_observed @ guide_efficacy + + if self.efficiency_mode == "scaled": + mean_perturbation_effect = guides_observed @ (guide_efficacy * guide_effects) + elif self.efficiency_mode == "mixture": + pert_prob = guides_observed @ guide_efficacy + assert pert_prob.shape[0] == self.n_cells + assert (pert_prob.shape[1] == 1) or (pert_prob.shape[1] == self.n_genes) + with cell_plate, gene_plate: perturbed = pyro.sample("perturbed", dist.Bernoulli(pert_prob), infer={"enumerate": "parallel"}) - elif self.efficiency_mode == "mixture_high_moi": # for simulation only! - pert_prob = guide_efficacy.expand((self.n_cells, -1, -1)).transpose(-3, -2) - assert pert_prob.shape == (self.n_perturbations, self.n_cells, 1) + mean_perturbation_effect = perturbed * (guides_observed @ guide_effects) + elif self.efficiency_mode == "mixture_high_moi": # for simulation only! + pert_prob = guide_efficacy.expand((self.n_cells, -1, -1)).transpose(-3, -2) + assert pert_prob.shape == (self.n_perturbations, self.n_cells, 1) + with cell_plate: perturbed = pyro.sample("perturbed", dist.Bernoulli(pert_prob)).squeeze(-1).T - else: - raise Exception("efficiency_mode must be either 'scaled' or 'mixture'") + mean_perturbation_effect = perturbed * guides_observed @ guide_effects + else: + raise Exception("efficiency_mode must be either 'scaled' or 'mixture'") - mean_perturbation_effect = perturbed * guides_observed @ guide_effects # if self.use_crispr_factor: # log_guide_counts = torch.log2(1 + guides_observed.sum(dim=-1, keepdim=True)) # with gene_plate: @@ -419,10 +446,10 @@ def model(self, idx, **tensor_dict): ) covariate_effects = cont_covariates @ cont_covariate_effect_size - nb_log_mean_ctrl = gene_base_log_mean + size_factor + batch_effects + covariate_effects - # nb_log_mean_ctrl = ( - # gene_base_log_mean + size_factor + batch_effects + covariate_effects + cell_factor_effects - # ) + # nb_log_mean_ctrl = gene_base_log_mean + size_factor + batch_effects + covariate_effects + nb_log_mean_ctrl = ( + gene_base_log_mean + size_factor + batch_effects + covariate_effects + cell_factor_effects + ) # if not self.dispersion_effects: # nb_log_dispersion = gene_log_dispersion diff --git a/src/perturbo/simulation/_pyro_simulator.py b/src/perturbo/simulation/_pyro_simulator.py index 1a92a63..5039e08 100644 --- a/src/perturbo/simulation/_pyro_simulator.py +++ b/src/perturbo/simulation/_pyro_simulator.py @@ -139,7 +139,7 @@ def simulate_data_from_trained_model( module_kwargs = { "n_batches": model.module.n_batches, "n_cont_covariates": model.module.n_cont_covariates - 1, # size factor auto included - # "n_factors": model.module.n_factors, + "n_factors": model.module.n_factors, # "dispersion_effects": model.module.dispersion_effects, "likelihood": model.module.likelihood, "effect_prior_dist": model.module.effect_prior_dist,