# Post-patch view of the hunks touching chorus/core/weights_probe.py.
# Definitions outside these hunks (_probe_sei, _probe_legnet,
# _probe_alphagenome, _probe_library_cached, CHORUS_DOWNLOADS_DIR and the
# module-level imports) are unchanged and intentionally not repeated here.


def _hf_cache_dir() -> Path:
    """Resolve the HuggingFace hub cache directory.

    Prefer the canonical ``huggingface_hub.constants.HF_HUB_CACHE`` (which
    honours ``HF_HOME`` / ``HF_HUB_CACHE`` env vars); fall back to the
    documented default ``~/.cache/huggingface/hub`` if the import fails
    (chorus base env always has huggingface_hub, so the fallback is
    defensive).

    Returns:
        Path to the hub cache root. The directory is not guaranteed to exist.
    """
    try:
        from huggingface_hub.constants import HF_HUB_CACHE
        return Path(HF_HUB_CACHE)
    except Exception:
        # Deliberately broad: a probe must degrade to the default location
        # rather than crash if huggingface_hub is missing or broken.
        return Path.home() / ".cache" / "huggingface" / "hub"


def _snapshot_dirs(snapshots: Path) -> List[Path]:
    """Return the per-revision directories under an HF ``snapshots`` folder.

    Non-directory entries (stray files such as ``.DS_Store``) are skipped,
    and a missing, unreadable, or non-directory ``snapshots`` path yields an
    empty list, so callers report "not installed" instead of raising.
    """
    try:
        return [entry for entry in snapshots.iterdir() if entry.is_dir()]
    except OSError:
        # FileNotFoundError / NotADirectoryError / PermissionError all mean
        # "no usable cache here" from the probe's point of view.
        return []


def _probe_chrombpnet() -> Tuple[bool, List[str]]:
    """Accept either the 0.3.0+ slim HF mirror or the legacy ENCODE-tarball
    layout as proof of install.

    Default in 0.3.0+: weights stream from
    ``lucapinello/chorus-chrombpnet-slim`` and live in the HF hub cache.
    The local ``downloads/chrombpnet/DNASE_K562`` directory is only
    populated for users who explicitly request ``model_type='chrombpnet'``
    (bias-aware) or fold ≠ 0, both of which fall back to ENCODE tarballs.
    `chorus setup chrombpnet` succeeds on the slim path, so we must
    accept either cache as installed.

    Returns:
        ``(True, [])`` when either cache is populated, otherwise
        ``(False, [message])`` naming both locations that were checked.
    """
    # 0.3.0+ default: slim HF mirror. A snapshot counts once it carries a
    # manifest.json.
    slim_snapshots = (
        _hf_cache_dir() / "models--lucapinello--chorus-chrombpnet-slim" / "snapshots"
    )
    if any(
        (snap / "manifest.json").exists() for snap in _snapshot_dirs(slim_snapshots)
    ):
        return (True, [])

    # Legacy ENCODE-tarball cache: any entry inside the directory counts.
    # is_dir() (not exists()) so a stray file at this path cannot make
    # iterdir() raise.
    legacy = CHORUS_DOWNLOADS_DIR / "chrombpnet" / "DNASE_K562"
    if legacy.is_dir() and any(legacy.iterdir()):
        return (True, [])

    return (
        False,
        [
            f"neither HF slim mirror cache ({slim_snapshots.parent}) "
            f"nor legacy {legacy} is populated"
        ],
    )


def _probe_alphagenome_pt() -> Tuple[bool, List[str]]:
    """Check the HF cache for the upstream PyTorch port's safetensors.

    Weights live at ``models--gtca--alphagenome_pytorch/snapshots/{revision}/
    model_all_folds.safetensors`` after ``chorus setup --oracle
    alphagenome_pt`` (or the first ``load_pretrained_model()`` call).

    Returns:
        ``(True, [])`` when any snapshot directory contains a
        ``*.safetensors`` entry, otherwise ``(False, [message])``.
    """
    pt_snapshots = (
        _hf_cache_dir() / "models--gtca--alphagenome_pytorch" / "snapshots"
    )
    # Only real snapshot directories are inspected; calling iterdir() on a
    # stray file would raise NotADirectoryError.
    for snap in _snapshot_dirs(pt_snapshots):
        if any(p.suffix == ".safetensors" for p in snap.iterdir()):
            return (True, [])
    return (
        False,
        [f"HF cache not populated: {pt_snapshots.parent}"],
    )


# Maps each oracle name to the zero-argument probe that reports whether its
# weights are present, as an (installed, messages) pair.
_ARTIFACT_PROBES: Dict[str, Callable[[], Tuple[bool, List[str]]]] = {
    "sei": _probe_sei,
    "legnet": _probe_legnet,
    "chrombpnet": _probe_chrombpnet,
    "alphagenome": _probe_alphagenome,
    "alphagenome_pt": _probe_alphagenome_pt,
    "enformer": _probe_library_cached,
    "borzoi": _probe_library_cached,
}