Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/perturbo/models/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from scvi.dataloaders import AnnDataLoader, DeviceBackedDataSplitter
from scvi.model.base import (
BaseModelClass,
PyroJitGuideWarmup,
PyroSampleMixin,
PyroSviTrainMixin,
)
Expand Down Expand Up @@ -85,10 +86,9 @@ def __init__(
sample_var = (X**2).mean(axis=0).squeeze() - sample_mean_squared

theta_hat = torch.tensor(sample_mean_squared / (sample_var - sample_mean)).clamp(min=1e-1)
init_values = {
"log_gene_mean": torch.tensor(sample_mean, dtype=torch.float32).log(),
"log_gene_dispersion": torch.tensor(theta_hat).log(),
}
log_gene_mean_init = torch.tensor(sample_mean, dtype=torch.float32).log()
log_gene_dispersion_init = theta_hat.float().log()

# if control_guides is not None and "n_factors" in model_kwargs and guide_by_element is not None:
# # control_guides, _ = torch.max(guide_by_element[:, control_elements], dim=-1)
# control_mask = self.read_matrix_from_registry(REGISTRY_KEYS.PERTURBATION_KEY)[:, control_guides].sum(dim=-1)
Expand All @@ -104,7 +104,8 @@ def __init__(
n_genes=self.summary_stats.n_vars,
n_cont_covariates=n_extra_continuous_covs,
n_elements=n_elements,
init_values=init_values,
log_gene_mean_init=log_gene_mean_init,
log_gene_dispersion_init=log_gene_dispersion_init,
guide_by_element=guide_by_element,
gene_by_element=gene_by_element,
# n_cats_per_cov=n_cats_per_cov,
Expand Down Expand Up @@ -230,7 +231,6 @@ def setup_mudata(
mod_key=modalities.rna_layer,
)


batch_field = fields.MuDataCategoricalObsField(
REGISTRY_KEYS.BATCH_KEY,
batch_key,
Expand Down Expand Up @@ -390,9 +390,9 @@ def train(
es = "early_stopping"
trainer_kwargs[es] = early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es]

# if "callbacks" not in trainer_kwargs.keys():
# trainer_kwargs["callbacks"] = []
# trainer_kwargs["callbacks"].append(PyroJitGuideWarmup())
if "callbacks" not in trainer_kwargs.keys():
trainer_kwargs["callbacks"] = []
trainer_kwargs["callbacks"].append(PyroJitGuideWarmup())

runner = self._train_runner_cls(
self,
Expand Down
38 changes: 27 additions & 11 deletions src/perturbo/models/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pyro.distributions as dist
import torch
from pyro import poutine
from pyro.infer.autoguide import AutoDelta, AutoGuideList, AutoNormal, init_to_median, init_to_value
from pyro.infer.autoguide import AutoDelta, AutoGuideList, AutoNormal, init_to_median
from scvi.module.base import PyroBaseModuleClass

from ._constants import REGISTRY_KEYS
Expand All @@ -19,6 +19,19 @@ def sample(self, sample_shape=torch.Size()):
return dist.NegativeBinomial(total_count=self.total_count, logits=self.logits + normals).sample()


class LazyInitToValue:
    """Callable init-location strategy that lazily reads initial values off a module.

    ``name_map`` maps Pyro sample-site names to attribute names on ``module``
    (e.g. ``{"my_param": "my_init"}``).  For a mapped site, the stored value is
    fetched at call time and moved to the site's device; for any other site the
    call returns ``None`` so the caller can fall back to Pyro's default
    initializer.
    """

    def __init__(self, module, name_map):
        # Keep a live reference to the module so values are resolved lazily,
        # at init time, rather than captured at construction.
        self.module = module
        self.name_map = name_map  # site name -> attribute name on ``module``

    def __call__(self, site):
        target = site["name"]
        if target not in self.name_map:
            return None  # unmapped site: defer to Pyro's default init
        value = getattr(self.module, self.name_map[target])
        # NOTE(review): assumes ``site["fn"].support`` exposes a ``.device``
        # attribute — plain torch constraints generally do not; verify against
        # the Pyro version in use.
        return value.to(site["fn"].support.device)


class PerTurboPyroModule(PyroBaseModuleClass):
def __init__(
self,
Expand All @@ -28,7 +41,8 @@ def __init__(
n_elements: int | None = None,
n_cont_covariates: int | None = None,
n_batches: int | None = 1,
init_values: dict[torch.Tensor] | None = None,
log_gene_mean_init: torch.Tensor | None = None,
log_gene_dispersion_init: torch.Tensor | None = None,
guide_by_element: torch.Tensor | None = None,
gene_by_element: torch.Tensor | None = None,
# guide_noise: bool = False,
Expand Down Expand Up @@ -112,27 +126,29 @@ def __init__(
# self.delta_sites = ["cell_factors"]
# self.delta_sites = ["cell_factors", "cell_loadings", "pert_factors", "pert_loadings"]

self._guide = AutoGuideList(self.model, create_plates=self.create_plates)
# init_values = init_values or {}
if log_gene_mean_init is None:
log_gene_mean_init = torch.zeros(self.n_genes)

if log_gene_dispersion_init is None:
log_gene_dispersion_init = torch.ones(self.n_genes)

# if control_pcs is not None and n_factors is not None:
# init_values["cell_loadings"] = control_pcs

self._guide = AutoGuideList(self.model, create_plates=self.create_plates)

self._guide.append(
AutoNormal(
poutine.block(self.model, hide=self.delta_sites + self.discrete_sites),
init_loc_fn=init_to_value(values=init_values, fallback=init_to_median),
init_loc_fn=lambda x: init_to_median(x, num_samples=100),
),
)

if self.delta_sites:
self._guide.append(
AutoDelta(
poutine.block(self.model, expose=self.delta_sites),
init_loc_fn=init_to_value(
values=init_values,
fallback=init_to_median,
),
init_loc_fn=lambda x: init_to_median(x, num_samples=100),
)
)

Expand All @@ -153,8 +169,8 @@ def __init__(
self.register_buffer("one", torch.tensor(1.0))

# per-gene hyperparams
self.register_buffer("gene_mean_prior_loc", torch.tensor(0.0))
self.register_buffer("gene_disp_prior_loc", torch.tensor(1.0))
self.register_buffer("gene_mean_prior_loc", log_gene_mean_init)
self.register_buffer("gene_disp_prior_loc", log_gene_dispersion_init)

self.register_buffer("gene_mean_prior_scale", torch.tensor(3.0))
self.register_buffer("gene_disp_prior_scale", torch.tensor(3.0))
Expand Down
Loading