Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 67 additions & 6 deletions chorus/core/weights_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,53 @@ def _probe_legnet() -> Tuple[bool, List[str]]:
return (True, [])


def _hf_cache_dir() -> Path:
    """Return the directory used as the HuggingFace hub cache.

    The value comes from ``huggingface_hub.constants.HF_HUB_CACHE`` when
    that module can be imported, so any ``HF_HOME`` / ``HF_HUB_CACHE``
    overrides it applies are respected. If obtaining that value fails for
    any reason, the documented default ``~/.cache/huggingface/hub`` is
    returned instead. The chorus base env always ships huggingface_hub,
    so the fallback is purely defensive.
    """
    try:
        from huggingface_hub.constants import HF_HUB_CACHE

        hub_cache = Path(HF_HUB_CACHE)
    except Exception:
        # Deliberately broad: a probe helper must never crash just because
        # the optional lookup is unavailable or misbehaves.
        hub_cache = Path.home().joinpath(".cache", "huggingface", "hub")
    return hub_cache


def _probe_chrombpnet() -> Tuple[bool, List[str]]:
    """Accept either the 0.3.0+ slim HF mirror or the legacy ENCODE-tarball
    layout as proof of install.

    Default in 0.3.0+: weights stream from
    ``lucapinello/chorus-chrombpnet-slim`` and live in the HF hub cache.
    The local ``downloads/chrombpnet/DNASE_K562`` directory is only
    populated for users who explicitly request ``model_type='chrombpnet'``
    (bias-aware) or fold ≠ 0, both of which fall back to ENCODE tarballs.
    `chorus setup chrombpnet` succeeds on the slim path, so we must
    accept either cache as installed.

    Returns:
        ``(True, [])`` when at least one slim-mirror snapshot contains a
        ``manifest.json`` or the legacy directory is non-empty; otherwise
        ``(False, [message])`` where ``message`` names both locations
        that were checked.
    """
    # 0.3.0+ default: slim HF mirror.
    slim_snapshots = (
        _hf_cache_dir() / "models--lucapinello--chorus-chrombpnet-slim" / "snapshots"
    )
    # is_dir() rather than exists(): iterdir() raises NotADirectoryError on
    # a regular file, and a probe should report "missing" instead of crashing.
    if slim_snapshots.is_dir() and any(
        (snap / "manifest.json").exists() for snap in slim_snapshots.iterdir()
    ):
        return (True, [])

    # Legacy ENCODE-tarball cache.
    legacy = CHORUS_DOWNLOADS_DIR / "chrombpnet" / "DNASE_K562"
    if legacy.is_dir() and any(legacy.iterdir()):
        return (True, [])

    return (
        False,
        [
            f"neither HF slim mirror cache ({slim_snapshots.parent}) "
            f"nor legacy {legacy} is populated"
        ],
    )


def _probe_alphagenome() -> Tuple[bool, List[str]]:
Expand Down Expand Up @@ -88,11 +128,32 @@ def _probe_library_cached() -> Tuple[bool, List[str]]:
return (True, [])


def _probe_alphagenome_pt() -> Tuple[bool, List[str]]:
    """Check the HF cache for the upstream PyTorch port's safetensors.

    Weights live at ``models--gtca--alphagenome_pytorch/snapshots/<rev>/
    model_all_folds.safetensors`` after ``chorus setup --oracle
    alphagenome_pt`` (or the first ``load_pretrained_model()`` call).

    Returns:
        ``(True, [])`` when any snapshot directory directly contains a
        ``*.safetensors`` entry; otherwise ``(False, [message])`` where
        ``message`` names the repo cache directory that was checked.
    """
    pt_snapshots = (
        _hf_cache_dir() / "models--gtca--alphagenome_pytorch" / "snapshots"
    )
    # is_dir() rather than exists(): a regular file at this path cannot be
    # iterated and simply means the cache is not populated.
    if pt_snapshots.is_dir():
        for snap in pt_snapshots.iterdir():
            # Stray regular files (e.g. .DS_Store) are not snapshots;
            # calling iterdir() on them would raise NotADirectoryError.
            if not snap.is_dir():
                continue
            if any(p.suffix == ".safetensors" for p in snap.iterdir()):
                return (True, [])
    return (
        False,
        [f"HF cache not populated: {pt_snapshots.parent}"],
    )


# Registry of weight-artifact probes, keyed by a short name (these look
# like oracle names, e.g. the ``alphagenome_pt`` accepted by ``chorus
# setup --oracle`` -- confirm against the callers). Each value is a
# zero-argument callable returning a ``(bool, list of str)`` pair; the
# probes visible in this file return ``(True, [])`` on success and
# ``(False, [description])`` otherwise. ``enformer`` and ``borzoi``
# share the same ``_probe_library_cached`` probe.
_ARTIFACT_PROBES: Dict[str, Callable[[], Tuple[bool, List[str]]]] = {
    "sei": _probe_sei,
    "legnet": _probe_legnet,
    "chrombpnet": _probe_chrombpnet,
    "alphagenome": _probe_alphagenome,
    "alphagenome_pt": _probe_alphagenome_pt,
    "enformer": _probe_library_cached,
    "borzoi": _probe_library_cached,
}
Expand Down
Loading