diff --git a/src/perturbo/models/_model.py b/src/perturbo/models/_model.py index f03a95b..5faef88 100644 --- a/src/perturbo/models/_model.py +++ b/src/perturbo/models/_model.py @@ -14,6 +14,7 @@ from scvi.dataloaders import AnnDataLoader, DeviceBackedDataSplitter from scvi.model.base import ( BaseModelClass, + PyroJitGuideWarmup, PyroSampleMixin, PyroSviTrainMixin, ) @@ -85,10 +86,9 @@ def __init__( sample_var = (X**2).mean(axis=0).squeeze() - sample_mean_squared theta_hat = torch.tensor(sample_mean_squared / (sample_var - sample_mean)).clamp(min=1e-1) - init_values = { - "log_gene_mean": torch.tensor(sample_mean, dtype=torch.float32).log(), - "log_gene_dispersion": torch.tensor(theta_hat).log(), - } + log_gene_mean_init = torch.tensor(sample_mean, dtype=torch.float32).log() + log_gene_dispersion_init = theta_hat.float().log() + # if control_guides is not None and "n_factors" in model_kwargs and guide_by_element is not None: # # control_guides, _ = torch.max(guide_by_element[:, control_elements], dim=-1) # control_mask = self.read_matrix_from_registry(REGISTRY_KEYS.PERTURBATION_KEY)[:, control_guides].sum(dim=-1) @@ -104,7 +104,8 @@ def __init__( n_genes=self.summary_stats.n_vars, n_cont_covariates=n_extra_continuous_covs, n_elements=n_elements, - init_values=init_values, + log_gene_mean_init=log_gene_mean_init, + log_gene_dispersion_init=log_gene_dispersion_init, guide_by_element=guide_by_element, gene_by_element=gene_by_element, # n_cats_per_cov=n_cats_per_cov, @@ -230,7 +231,6 @@ def setup_mudata( mod_key=modalities.rna_layer, ) - batch_field = fields.MuDataCategoricalObsField( REGISTRY_KEYS.BATCH_KEY, batch_key, @@ -390,9 +390,9 @@ def train( es = "early_stopping" trainer_kwargs[es] = early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es] - # if "callbacks" not in trainer_kwargs.keys(): - # trainer_kwargs["callbacks"] = [] - # trainer_kwargs["callbacks"].append(PyroJitGuideWarmup()) + if "callbacks" not in trainer_kwargs.keys(): + trainer_kwargs["callbacks"] = [] + trainer_kwargs["callbacks"].append(PyroJitGuideWarmup()) runner = self._train_runner_cls( self, diff --git a/src/perturbo/models/_module.py b/src/perturbo/models/_module.py index 6bdccd9..a043a0b 100644 --- a/src/perturbo/models/_module.py +++ b/src/perturbo/models/_module.py @@ -5,7 +5,7 @@ import pyro.distributions as dist import torch from pyro import poutine -from pyro.infer.autoguide import AutoDelta, AutoGuideList, AutoNormal, init_to_median, init_to_value +from pyro.infer.autoguide import AutoDelta, AutoGuideList, AutoNormal, init_to_median from scvi.module.base import PyroBaseModuleClass from ._constants import REGISTRY_KEYS @@ -19,6 +19,19 @@ def sample(self, sample_shape=torch.Size()): return dist.NegativeBinomial(total_count=self.total_count, logits=self.logits + normals).sample() +class LazyInitToValue: + def __init__(self, module, name_map): + self.module = module + self.name_map = name_map # e.g., {"my_param": "my_init"} + + def __call__(self, site): + name = site["name"] + if name in self.name_map: + buf = getattr(self.module, self.name_map[name]) + return buf.to(site["fn"].support.device) + return None # fallback to Pyro default + + class PerTurboPyroModule(PyroBaseModuleClass): def __init__( self, @@ -28,7 +41,8 @@ def __init__( n_elements: int | None = None, n_cont_covariates: int | None = None, n_batches: int | None = 1, - init_values: dict[torch.Tensor] | None = None, + log_gene_mean_init: torch.Tensor | None = None, + log_gene_dispersion_init: torch.Tensor | None = None, guide_by_element: torch.Tensor | None = None, gene_by_element: torch.Tensor | None = None, # guide_noise: bool = False, @@ -112,16 +126,21 @@ def __init__( # self.delta_sites = ["cell_factors"] # self.delta_sites = ["cell_factors", "cell_loadings", "pert_factors", "pert_loadings"] - self._guide = AutoGuideList(self.model, create_plates=self.create_plates) - # init_values = init_values or {} + if log_gene_mean_init is None: + log_gene_mean_init = torch.zeros(self.n_genes) + + if log_gene_dispersion_init is None: + log_gene_dispersion_init = torch.ones(self.n_genes) # if control_pcs is not None and n_factors is not None: # init_values["cell_loadings"] = control_pcs + self._guide = AutoGuideList(self.model, create_plates=self.create_plates) + self._guide.append( AutoNormal( poutine.block(self.model, hide=self.delta_sites + self.discrete_sites), - init_loc_fn=init_to_value(values=init_values, fallback=init_to_median), + init_loc_fn=lambda x: init_to_median(x, num_samples=100), ), ) @@ -129,10 +148,7 @@ def __init__( self._guide.append( AutoDelta( poutine.block(self.model, expose=self.delta_sites), - init_loc_fn=init_to_value( - values=init_values, - fallback=init_to_median, - ), + init_loc_fn=lambda x: init_to_median(x, num_samples=100), ) ) @@ -153,8 +169,8 @@ def __init__( self.register_buffer("one", torch.tensor(1.0)) # per-gene hyperparams - self.register_buffer("gene_mean_prior_loc", torch.tensor(0.0)) - self.register_buffer("gene_disp_prior_loc", torch.tensor(1.0)) + self.register_buffer("gene_mean_prior_loc", log_gene_mean_init) + self.register_buffer("gene_disp_prior_loc", log_gene_dispersion_init) self.register_buffer("gene_mean_prior_scale", torch.tensor(3.0)) self.register_buffer("gene_disp_prior_scale", torch.tensor(3.0))