Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
42 changes: 29 additions & 13 deletions src/apm_cli/commands/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
list = builtins.list
dict = builtins.dict

from ..core.auth import AuthResolver

# APM Dependencies (conditional import for graceful degradation)
APM_DEPS_AVAILABLE = False
_APM_IMPORT_ERROR = None
Expand All @@ -56,7 +58,7 @@
# ---------------------------------------------------------------------------


def _validate_and_add_packages_to_apm_yml(packages, dry_run=False, dev=False, logger=None):
def _validate_and_add_packages_to_apm_yml(packages, dry_run=False, dev=False, logger=None, auth_resolver=None):
"""Validate packages exist and can be accessed, then add to apm.yml dependencies section.

Implements normalize-on-write: any input form (HTTPS URL, SSH URL, FQDN, shorthand)
Expand All @@ -68,6 +70,7 @@ def _validate_and_add_packages_to_apm_yml(packages, dry_run=False, dev=False, lo
dry_run: If True, only show what would be added.
dev: If True, write to devDependencies instead of dependencies.
logger: InstallLogger for structured output.
auth_resolver: Shared auth resolver for caching credentials.

Returns:
Tuple of (validated_packages list, _ValidationOutcome).
Expand Down Expand Up @@ -148,7 +151,7 @@ def _validate_and_add_packages_to_apm_yml(packages, dry_run=False, dev=False, lo

# Validate package exists and is accessible
verbose = bool(logger and logger.verbose)
if _validate_package_exists(package, verbose=verbose):
if _validate_package_exists(package, verbose=verbose, auth_resolver=auth_resolver):
valid_outcomes.append((canonical, already_in_deps))
if logger:
logger.validation_pass(canonical, already_present=already_in_deps)
Expand Down Expand Up @@ -214,15 +217,17 @@ def _validate_and_add_packages_to_apm_yml(packages, dry_run=False, dev=False, lo
return validated_packages, outcome


def _validate_package_exists(package, verbose=False):
def _validate_package_exists(package, verbose=False, auth_resolver=None):
"""Validate that a package exists and is accessible on GitHub, Azure DevOps, or locally."""
import os
import subprocess
import tempfile
from apm_cli.core.auth import AuthResolver

verbose_log = (lambda msg: _rich_echo(f" {msg}", color="dim")) if verbose else None
auth_resolver = AuthResolver()
# Use provided resolver or create new one if not in a CLI session context
if auth_resolver is None:
auth_resolver = AuthResolver()

try:
# Parse the package to check if it's a virtual package or ADO
Expand Down Expand Up @@ -252,8 +257,8 @@ def _validate_package_exists(package, verbose=False):
org = dep_ref.repo_url.split('/')[0] if dep_ref.repo_url and '/' in dep_ref.repo_url else None
if verbose_log:
verbose_log(f"Auth resolved: host={host}, org={org}, source={ctx.source}, type={ctx.token_type}")
downloader = GitHubPackageDownloader(auth_resolver=auth_resolver)
result = downloader.validate_virtual_package_exists(dep_ref)
virtual_downloader = GitHubPackageDownloader(auth_resolver=auth_resolver)
result = virtual_downloader.validate_virtual_package_exists(dep_ref)
if not result and verbose_log:
try:
err_ctx = auth_resolver.build_error_context(host, f"accessing {package}", org=org)
Expand All @@ -268,13 +273,13 @@ def _validate_package_exists(package, verbose=False):
if dep_ref.is_azure_devops() or (dep_ref.host and dep_ref.host != "github.com"):
from apm_cli.utils.github_host import is_github_hostname, is_azure_devops_hostname

downloader = GitHubPackageDownloader()
ado_downloader = GitHubPackageDownloader(auth_resolver=auth_resolver)
# Set the host
if dep_ref.host:
downloader.github_host = dep_ref.host
ado_downloader.github_host = dep_ref.host

# Build authenticated URL using downloader's auth
package_url = downloader._build_repo_url(
package_url = ado_downloader._build_repo_url(
dep_ref.repo_url, use_ssh=False, dep_ref=dep_ref
)

Expand All @@ -283,11 +288,11 @@ def _validate_package_exists(package, verbose=False):
# This mirrors _clone_with_fallback() which does the same relaxation.
is_generic = not is_github_hostname(dep_ref.host) and not is_azure_devops_hostname(dep_ref.host)
if is_generic:
validate_env = {k: v for k, v in downloader.git_env.items()
validate_env = {k: v for k, v in ado_downloader.git_env.items()
if k not in ('GIT_ASKPASS', 'GIT_CONFIG_GLOBAL', 'GIT_CONFIG_NOSYSTEM')}
validate_env['GIT_TERMINAL_PROMPT'] = '0'
else:
validate_env = {**os.environ, **downloader.git_env}
validate_env = {**os.environ, **ado_downloader.git_env}

if verbose_log:
verbose_log(f"Trying git ls-remote for {dep_ref.host}")
Expand Down Expand Up @@ -491,6 +496,10 @@ def install(ctx, packages, runtime, exclude, only, update, dry_run, force, verbo
is_partial = bool(packages)
logger = InstallLogger(verbose=verbose, dry_run=dry_run, partial=is_partial)

# Create shared auth resolver for all downloads in this CLI invocation
# to ensure credentials are cached and reused (prevents duplicate auth popups)
auth_resolver = AuthResolver()

# Check if apm.yml exists
apm_yml_exists = Path(APM_YML_FILENAME).exists()

Expand All @@ -512,7 +521,7 @@ def install(ctx, packages, runtime, exclude, only, update, dry_run, force, verbo
# If packages are specified, validate and add them to apm.yml first
if packages:
validated_packages, outcome = _validate_and_add_packages_to_apm_yml(
packages, dry_run, dev=dev, logger=logger,
packages, dry_run, dev=dev, logger=logger, auth_resolver=auth_resolver,
)
# Short-circuit: all packages failed validation — nothing to install
if outcome.all_failed:
Expand Down Expand Up @@ -613,6 +622,7 @@ def install(ctx, packages, runtime, exclude, only, update, dry_run, force, verbo
apm_package, update, verbose, only_pkgs, force=force,
parallel_downloads=parallel_downloads,
logger=logger,
auth_resolver=auth_resolver,
)
apm_count = install_result.installed_count
prompt_count = install_result.prompts_integrated
Expand Down Expand Up @@ -1023,6 +1033,7 @@ def _install_apm_dependencies(
force: bool = False,
parallel_downloads: int = 4,
logger: "InstallLogger" = None,
auth_resolver: "AuthResolver" = None,
):
"""Install APM package dependencies.

Expand All @@ -1034,6 +1045,7 @@ def _install_apm_dependencies(
force: Whether to overwrite locally-authored files on collision
parallel_downloads: Max concurrent downloads (0 disables parallelism)
logger: InstallLogger for structured output
auth_resolver: Shared auth resolver for caching credentials
"""
if not APM_DEPS_AVAILABLE:
raise RuntimeError("APM dependency system not available")
Expand Down Expand Up @@ -1066,8 +1078,12 @@ def _install_apm_dependencies(
apm_modules_dir = project_root / APM_MODULES_DIR
apm_modules_dir.mkdir(exist_ok=True)

# Use provided resolver or create new one if not in a CLI session context
if auth_resolver is None:
auth_resolver = AuthResolver()

# Create downloader early so it can be used for transitive dependency resolution
downloader = GitHubPackageDownloader()
downloader = GitHubPackageDownloader(auth_resolver=auth_resolver)

# Track direct dependency keys so the download callback can distinguish them from transitive
direct_dep_keys = builtins.set(dep.get_unique_key() for dep in apm_deps)
Expand Down
35 changes: 18 additions & 17 deletions src/apm_cli/core/auth.py
Original file line number Diff line number Diff line change
def resolve(self, host: str, org: Optional[str] = None) -> AuthContext:
    """Resolve auth for *(host, org)*. Cached & thread-safe.

    The fast path reads the cache without taking the lock; the slow path
    re-checks under the lock so that concurrent first-time callers for the
    same key trigger at most one credential lookup (single-flight).

    Args:
        host: Hostname to resolve credentials for (case-insensitive).
        org: Optional organization scoping the credentials (case-insensitive).

    Returns:
        The cached or newly-built AuthContext for (host, org).
    """
    key = (host.lower() if host else host, org.lower() if org else org)

    # Lock-free fast path: a plain dict read is safe under CPython and
    # avoids serializing every resolve() call on the lock.
    cached = self._cache.get(key)
    if cached is not None:
        return cached

    with self._lock:
        # Double-check inside the lock: another thread may have filled
        # the cache while we were waiting to acquire it. Without this
        # re-check, every thread that missed the fast path would re-run
        # the expensive _resolve_token (which can pop an OS
        # credential-helper prompt) and overwrite the cache entry.
        cached = self._cache.get(key)
        if cached is not None:
            return cached

        host_info = self.classify_host(host)
        token, source = self._resolve_token(host_info, org)
        token_type = self.detect_token_type(token) if token else "unknown"
        git_env = self._build_git_env(token)

        ctx = AuthContext(
            token=token,
            source=source,
            token_type=token_type,
            host_info=host_info,
            git_env=git_env,
        )
        self._cache[key] = ctx
        return ctx

def resolve_for_dep(self, dep_ref: "DependencyReference") -> AuthContext:
"""Resolve auth from a ``DependencyReference``."""
Expand Down
27 changes: 13 additions & 14 deletions src/apm_cli/deps/github_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,20 +202,19 @@ def _setup_git_environment(self) -> Dict[str, Any]:
else:
env['GIT_CONFIG_GLOBAL'] = '/dev/null'

# Resolve default host tokens via AuthResolver (backward compat properties)
default_ctx = self.auth_resolver.resolve(default_host())
self._default_github_ctx = default_ctx
self.github_token = default_ctx.token
self.has_github_token = default_ctx.token is not None
self._github_token_from_credential_fill = (
self.has_github_token
and self.token_manager.get_token_for_purpose('modules', env) is None
)

# Azure DevOps
ado_ctx = self.auth_resolver.resolve("dev.azure.com")
self.ado_token = ado_ctx.token
self.has_ado_token = ado_ctx.token is not None
# IMPORTANT: Do not resolve credentials via helpers at construction time.
# AuthResolver.resolve(...) can trigger OS credential helper UI. If we do
# this eagerly (host-only key) and later resolve per-dependency (host+org),
# users can see duplicate auth prompts. Keep constructor token state env-only
# and resolve lazily per dependency during clone/validate flows.
self._default_github_ctx = None
self.github_token = self.token_manager.get_token_for_purpose('modules', env)
self.has_github_token = self.github_token is not None
self._github_token_from_credential_fill = False

# Azure DevOps (env-only at init; lazy auth resolution happens per dep)
self.ado_token = self.token_manager.get_token_for_purpose('ado_modules', env)
self.has_ado_token = self.ado_token is not None

# JFrog Artifactory (not host-based, uses dedicated env var)
self.artifactory_token = self.token_manager.get_token_for_purpose('artifactory_modules', env)
Expand Down
9 changes: 9 additions & 0 deletions tests/test_github_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ def test_setup_git_environment_no_token(self):
# Should not have GitHub tokens in environment
assert 'GITHUB_TOKEN' not in env or not env['GITHUB_TOKEN']
assert 'GH_TOKEN' not in env or not env['GH_TOKEN']

def test_setup_git_environment_does_not_eagerly_call_credential_helper(self):
    """Constructing the downloader must not hit the git credential helper.

    Auth is resolved lazily per dependency; an eager helper call at
    construction time could surface a duplicate OS credential prompt.
    """
    helper_target = 'apm_cli.core.token_manager.GitHubTokenManager.resolve_credential_from_git'
    with patch.dict(os.environ, {}, clear=True), patch(helper_target) as helper_mock:
        GitHubPackageDownloader()
    helper_mock.assert_not_called()

@patch('apm_cli.deps.github_downloader.Repo')
@patch('tempfile.mkdtemp')
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Unit tests for AuthResolver, HostInfo, and AuthContext."""

import os
import time
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -134,6 +136,25 @@ def test_caching(self):
ctx2 = resolver.resolve("github.com", org="microsoft")
assert ctx1 is ctx2

def test_caching_is_singleflight_under_concurrency(self):
    """Concurrent resolve() calls for the same key should populate cache once."""
    resolver = AuthResolver()

    def _delayed_token(host_info, org):
        # Sleep long enough that all workers miss the cache fast path.
        time.sleep(0.05)
        return ("cred-token", "git-credential-fill")

    with patch.object(AuthResolver, "_resolve_token", side_effect=_delayed_token) as token_mock:
        with ThreadPoolExecutor(max_workers=8) as pool:
            pending = [
                pool.submit(resolver.resolve, "github.com", "microsoft")
                for _ in range(8)
            ]
            resolved = [future.result() for future in pending]

    # Single-flight: exactly one credential lookup, one shared context object.
    assert token_mock.call_count == 1
    assert all(ctx is resolved[0] for ctx in resolved)

def test_different_orgs_different_cache(self):
"""Different orgs get different cache entries."""
with patch.dict(os.environ, {
Expand Down