From 8b7939d262c711259c0d8cef7d8da8ffc3725060 Mon Sep 17 00:00:00 2001 From: Peter Robicheaux Date: Wed, 6 May 2026 17:21:13 -0400 Subject: [PATCH 1/5] nas: 4 new CLI commands + --group on model list (G4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wraps the new public API routes from the agentic-surface-area onsite plan: roboflow train cancel / [--continue-if-no-refund] roboflow train stop / roboflow train results / roboflow model star [--unstar] Plus extends `roboflow model list -p ` with `-g/--group `, the canonical "list NAS models per run" path. When --group is set, the list command hits the public /models endpoint (full enriched projection: hardware, latency, map5095, paretoOptimalFor, recommended ★) instead of walking versions via the SDK. Adapter additions in roboflow/adapters/rfapi.py: cancel_version_training, stop_version_training, get_training_results, list_project_models (with optional group), get_model_by_url, favorite_nas_model Backend companions: - roboflow#11603 (G1, validator) - roboflow#11605 (G6, projection + ?group=) - roboflow#11610 (G2, public train cancel/stop + favorite) - roboflow#11612 (G3, training results) Tests: +13 cases across test_train_handler.py and test_model_handler.py covering register, success paths, 409 + MODEL_NOT_NAS hint surfacing, unstar flow, and the --group endpoint switch. All 298 CLI tests pass locally; ruff check + ruff format clean. CLI-COMMANDS.md updated with two new sections (train lifecycle + NAS list/star/deploy). E2E: driven against staging (api.roboflow.one) on peter-robicheaux/beer-can-hackathon: - `train results .../410` returned full NAS bundle (52 models, recommendedByHardware, modelGroup) - `model list -p ... -g ` rendered 53-row leaderboard table with HARDWARE / LATENCY / MAP50 / MAP5095 / REC columns - `model star 14CwSGmGetWh6rB0EnjL` → success, favorites reflected - `model star --unstar` flips state - `train cancel .../318` (finished version) → 409 surfaces hint "Cancel only applies to in-flight runs." Co-Authored-By: Claude Opus 4.7 (1M context) --- CLI-COMMANDS.md | 44 ++++++++ roboflow/adapters/rfapi.py | 86 +++++++++++++++ roboflow/cli/handlers/model.py | 140 +++++++++++++++++++++++- roboflow/cli/handlers/train.py | 187 ++++++++++++++++++++++++++++++++ tests/cli/test_model_handler.py | 129 ++++++++++++++++++++++ tests/cli/test_train_handler.py | 104 ++++++++++++++++++ 6 files changed, 688 insertions(+), 2 deletions(-) diff --git a/CLI-COMMANDS.md b/CLI-COMMANDS.md index 34539368..aa961485 100644 --- a/CLI-COMMANDS.md +++ b/CLI-COMMANDS.md @@ -47,6 +47,50 @@ roboflow download my-workspace/my-project/3 -f coco # alias roboflow infer photo.jpg -m my-project/3 ``` +### Train, monitor, cancel, stop + +```bash +# Start training (any architecture). For NAS sweeps, use a NAS parent modelType: +roboflow train start -p my-project -v 3 --type rfdetr-base +roboflow train start -p my-project -v 3 --type rfdetr-nas-parent # NAS sweep +roboflow train start -p my-project -v 3 --type rfdetr-nas-base-parent # NAS Base sweep +roboflow train start -p my-project -v 3 --type rfdetr-nas-seg-parent # NAS instance-segmentation + +# Cancel an in-flight training (any architecture; NAS-aware): +roboflow train cancel my-project/3 +# Pass --continue-if-no-refund to cancel even past the refund window: +roboflow train cancel my-project/3 --continue-if-no-refund + +# Graceful early-stop: +roboflow train stop my-project/3 + +# Run-level training results bundle (NAS leaderboard for NAS runs, +# minimal bundle for non-NAS): +roboflow train results my-project/3 +``` + +NAS sweeps require the version's validation split to have at least 15 images; +the server returns `code: "insufficient_validation_images_for_nas"` otherwise. + +### NAS models — list, star, deploy + +```bash +# Get a NAS run's modelGroup from training results: +roboflow --json train results my-project/3 | jq -r .modelGroup +# → rfdetrNasGroup-3 + +# List every model from one NAS run, with hardware/latency/mAP columns: +roboflow model list -p my-project --group rfdetrNasGroup-3 + +# Star a NAS-trained model (triggers TRT compile for its recommended hardware): +# --json train results … gives you the modelId per row. +roboflow model star +roboflow model star --unstar +``` + +`model star` is NAS-only by server-side design; non-NAS modelTypes return +`code: "MODEL_NOT_NAS"`. + ### Search and export ```bash diff --git a/roboflow/adapters/rfapi.py b/roboflow/adapters/rfapi.py index c0b00f39..717eee64 100644 --- a/roboflow/adapters/rfapi.py +++ b/roboflow/adapters/rfapi.py @@ -84,6 +84,92 @@ def start_version_training( return True +def cancel_version_training( + api_key: str, + workspace_url: str, + project_url: str, + version: str, + *, + continue_if_no_refund: bool = False, +): + """Cancel an in-flight training run. + + Backend handler is canonical for both vanilla and NAS trainings — it + accepts ``mining`` status, so this works for NAS sweeps too. + """ + url = f"{API_URL}/{workspace_url}/{project_url}/{version}/train/cancel?api_key={api_key}" + body: Dict[str, Union[str, int, bool]] = {} + if continue_if_no_refund: + body["continueIfNoRefund"] = True + response = requests.post(url, json=body) + if not response.ok: + raise RoboflowError(response.text) + return response.json() if response.content else {"success": True} + + +def stop_version_training(api_key: str, workspace_url: str, project_url: str, version: str): + """Request an early stop on an in-flight training run. + + The backend flips ``train.requestedStop``; the run finishes the current + phase gracefully (mining or training). + """ + url = f"{API_URL}/{workspace_url}/{project_url}/{version}/train/stop?api_key={api_key}" + response = requests.post(url, json={}) + if not response.ok: + raise RoboflowError(response.text) + return response.json() if response.content else {"success": True} + + +def get_training_results(api_key: str, workspace_url: str, project_url: str, version: str): + """Run-level training results bundle. + + For NAS runs returns ``{ trainingId, status, modelGroup, modelCount, + recommendedByHardware, mining?, models: [...] }``. For non-NAS runs + returns a minimal bundle with the produced model(s). + """ + url = f"{API_URL}/{workspace_url}/{project_url}/{version}/training/results?api_key={api_key}" + response = requests.get(url) + if not response.ok: + raise RoboflowError(response.text) + return response.json() + + +def list_project_models( + api_key: str, + workspace_url: str, + project_url: str, + *, + group: Optional[str] = None, +): + """List models for a project; pass ``group`` to scope to one NAS run.""" + url = f"{API_URL}/{workspace_url}/{project_url}/models?api_key={api_key}" + if group: + url += f"&group={urllib.parse.quote(group, safe='')}" + response = requests.get(url) + if not response.ok: + raise RoboflowError(response.text) + return response.json() + + +def get_model_by_url(api_key: str, workspace_url: str, model_url: str): + """Fetch a single model by its URL slug.""" + encoded = urllib.parse.quote(model_url, safe="/") + url = f"{API_URL}/models/{workspace_url}/{encoded}?api_key={api_key}" + response = requests.get(url) + if not response.ok: + raise RoboflowError(response.text) + return response.json() + + +def favorite_nas_model(api_key: str, workspace_url: str, model_id: str, *, starred: bool = True): + """Star or unstar a NAS-trained model. NAS-only on the server side.""" + url = f"{API_URL}/{workspace_url}/models/{model_id}/favorite?api_key={api_key}" + response = requests.post(url, json={"starred": bool(starred)}) + if not response.ok: + raise RoboflowError(response.text) + return response.json() + + def get_version(api_key: str, workspace_url: str, project_url: str, version: str, nocache: bool = False): """ Fetch detailed information about a specific dataset version. diff --git a/roboflow/cli/handlers/model.py b/roboflow/cli/handlers/model.py index 9cf04617..e2fedfd3 100644 --- a/roboflow/cli/handlers/model.py +++ b/roboflow/cli/handlers/model.py @@ -15,9 +15,23 @@ def list_models( ctx: typer.Context, project: Annotated[str, typer.Option("-p", "--project", help="Project ID or shorthand (e.g. my-ws/my-project)")], + group: Annotated[ + Optional[str], + typer.Option( + "-g", + "--group", + help=( + "NAS modelGroup to scope the list to a single NAS run. " + "Get the value from 'roboflow train results /'." + ), + ), + ] = None, ) -> None: - """List trained models for a project.""" - args = ctx_to_args(ctx, project=project) + """List trained models for a project. + + Pass --group to filter to a single NAS run. + """ + args = ctx_to_args(ctx, project=project, group=group) _list_models(args) @@ -31,6 +45,30 @@ def get_model( _get_model(args) +@model_app.command("star") +def star_model( + ctx: typer.Context, + model_id: Annotated[ + str, + typer.Argument( + help=( + "NAS-trained model id (Firestore document id). " + "Get it from 'roboflow train results /' (models[].modelId)." + ), + ), + ], + unstar: Annotated[bool, typer.Option("--unstar", help="Unstar instead of starring")] = False, +) -> None: + """Star or unstar a NAS-trained model. + + NAS-only by design — the server rejects non-NAS modelTypes with a + MODEL_NOT_NAS error. Starring triggers TRT compilation for the model's + recommended hardware so the model becomes deployable as an edge target. + """ + args = ctx_to_args(ctx, model_id=model_id, starred=not unstar) + _star_model(args) + + @model_app.command("infer") def model_infer( ctx: typer.Context, @@ -86,9 +124,11 @@ def upload_model( def _list_models(args): # noqa: ANN001 import roboflow + from roboflow.adapters import rfapi from roboflow.cli._output import output, output_error, suppress_sdk_output from roboflow.cli._resolver import resolve_resource from roboflow.cli._table import format_table + from roboflow.config import load_roboflow_api_key try: workspace_url, project_slug, _version = resolve_resource(args.project, workspace_override=args.workspace) @@ -96,6 +136,52 @@ def _list_models(args): # noqa: ANN001 output_error(args, str(exc)) return + group = getattr(args, "group", None) + + if group: + # NAS path — hit the public /models endpoint with ?group= filter. + # Surfaces full per-row NAS metadata (nasFamily, group, + # train.results.{hardware,latency,map5095,paretoOptimalFor}, + # favorites, recommended). + api_key = args.api_key or load_roboflow_api_key(workspace_url) + if not api_key: + output_error( + args, + "No API key found.", + hint="Set ROBOFLOW_API_KEY or run 'roboflow auth login'.", + exit_code=2, + ) + return + try: + rows = rfapi.list_project_models(api_key, workspace_url, project_slug, group=group) + except rfapi.RoboflowError as exc: + output_error(args, str(exc), exit_code=3) + return + if not isinstance(rows, list): + rows = [] + # Project a leaderboard view for the text table; full row stays in JSON. + table_rows = [] + for r in rows: + metrics = r.get("metrics") or {} + table_rows.append( + { + "url": r.get("url", ""), + "type": r.get("modelType", ""), + "hardware": metrics.get("hardware", ""), + "latency": metrics.get("latency", ""), + "map50": metrics.get("map50", ""), + "map5095": metrics.get("map5095", ""), + "recommended": "★" if r.get("recommended") else "", + } + ) + table = format_table( + table_rows, + columns=["url", "type", "hardware", "latency", "map50", "map5095", "recommended"], + headers=["URL", "TYPE", "HARDWARE", "LATENCY", "MAP50", "MAP5095", "REC"], + ) + output(args, rows, text=table) + return + api_key = args.api_key or None try: @@ -130,6 +216,56 @@ def _list_models(args): # noqa: ANN001 output(args, models, text=table) +def _star_model(args): # noqa: ANN001 + from roboflow.adapters import rfapi + from roboflow.cli._output import output, output_error + from roboflow.config import load_roboflow_api_key + + workspace_url = args.workspace + if not workspace_url: + # Need a workspace; fall back to whatever the api key points at. + from roboflow.cli._resolver import resolve_default_workspace + + workspace_url = resolve_default_workspace(args.api_key) + if not workspace_url: + output_error( + args, + "Could not determine workspace.", + hint="Pass -w/--workspace or run 'roboflow auth set-workspace '.", + exit_code=2, + ) + return + + api_key = args.api_key or load_roboflow_api_key(workspace_url) + if not api_key: + output_error( + args, + "No API key found.", + hint="Set ROBOFLOW_API_KEY or run 'roboflow auth login'.", + exit_code=2, + ) + return + + try: + result = rfapi.favorite_nas_model(api_key, workspace_url, args.model_id, starred=args.starred) + except rfapi.RoboflowError as exc: + msg = str(exc) + hint = None + if "MODEL_NOT_NAS" in msg or "non-NAS" in msg: + hint = "Star is NAS-only. Use 'roboflow train results' to find NAS model ids (models[].modelId)." + elif "MODEL_NOT_IN_WORKSPACE" in msg: + hint = "Verify the model id and workspace; ids are Firestore doc ids, not URL slugs." + output_error(args, msg, hint=hint, exit_code=3) + return + + verb = "starred" if args.starred else "unstarred" + output( + args, + result, + text=f"Model {args.model_id} {verb}.", + ) + + def _get_model(args): # noqa: ANN001 import json diff --git a/roboflow/cli/handlers/train.py b/roboflow/cli/handlers/train.py index 6c66c9c3..f5cbd449 100644 --- a/roboflow/cli/handlers/train.py +++ b/roboflow/cli/handlers/train.py @@ -76,6 +76,75 @@ def start_training( _start(args) +@train_app.command("cancel") +def cancel_training( + ctx: typer.Context, + target: Annotated[ + str, + typer.Argument( + help="Training to cancel as 'project/version' (e.g. 'my-project/3' or 'workspace/my-project/3')" + ), + ], + continue_if_no_refund: Annotated[ + bool, + typer.Option( + "--continue-if-no-refund", + help=( + "Cancel even if the run is past the refund window. " + "Default: false (server replies refund:false without cancelling)." + ), + ), + ] = False, +) -> None: + """Cancel an in-flight training run. + + Works for any architecture, including NAS sweeps in the mining or + training phase. Server-side gate: only valid while the run is in-flight; + a finished/failed run returns 409 CANNOT_CANCEL. + """ + args = ctx_to_args(ctx, target=target, continue_if_no_refund=continue_if_no_refund) + _cancel(args) + + +@train_app.command("stop") +def stop_training( + ctx: typer.Context, + target: Annotated[ + str, + typer.Argument(help="Training to stop as 'project/version'"), + ], +) -> None: + """Request a graceful early-stop on an in-flight training run. + + Distinct from cancel: the run finishes the current phase (mining or + training) instead of terminating immediately. Idempotent — calling + stop on an already-stopped run is a no-op. + """ + args = ctx_to_args(ctx, target=target) + _stop(args) + + +@train_app.command("results") +def training_results( + ctx: typer.Context, + target: Annotated[ + str, + typer.Argument(help="Training to inspect as 'project/version'"), + ], +) -> None: + """Run-level training results bundle. + + For NAS sweeps returns { trainingId, status, modelGroup, modelCount, + recommendedByHardware, mining?, models: [...] }. For non-NAS trainings + returns a minimal bundle with the produced model. + + Pass the returned `modelGroup` to `roboflow model list --group ...` to + list every NAS model from that run with full metadata. + """ + args = ctx_to_args(ctx, target=target) + _results(args) + + # --------------------------------------------------------------------------- # Business logic (unchanged from argparse version) # --------------------------------------------------------------------------- @@ -202,3 +271,121 @@ def _ensure_export(args, api_key, workspace_url, project_slug, version_str, mode return except rfapi.RoboflowError: pass + + +def _resolve_train_target(args): + """Parse '/' (or full 'workspace//') and resolve api key. + + Returns (api_key, workspace_url, project_slug, version_str) or None if validation fails. + """ + from roboflow.cli._output import output_error + from roboflow.cli._resolver import resolve_resource + from roboflow.config import load_roboflow_api_key + + try: + workspace_url, project_slug, version = resolve_resource(args.target, workspace_override=args.workspace) + except ValueError as exc: + output_error(args, str(exc)) + return None + if version is None: + output_error( + args, + "Version is required.", + hint="Pass it as 'project/version' or 'workspace/project/version'.", + ) + return None + api_key = args.api_key or load_roboflow_api_key(workspace_url) + if not api_key: + output_error( + args, + "No API key found.", + hint="Set ROBOFLOW_API_KEY or run 'roboflow auth login'.", + exit_code=2, + ) + return None + return api_key, workspace_url, project_slug, str(version) + + +def _cancel(args): # noqa: ANN001 + from roboflow.adapters import rfapi + from roboflow.cli._output import output, output_error + + resolved = _resolve_train_target(args) + if resolved is None: + return + api_key, workspace_url, project_slug, version_str = resolved + + try: + result = rfapi.cancel_version_training( + api_key, + workspace_url, + project_slug, + version_str, + continue_if_no_refund=getattr(args, "continue_if_no_refund", False), + ) + except rfapi.RoboflowError as exc: + msg = str(exc) + # 409 from server lands here as a RoboflowError carrying the JSON + # body; surface it with code "CANNOT_CANCEL" if present. + hint = None + if "non-running" in msg or "Cannot cancel" in msg: + hint = ( + "Cancel only applies to in-flight runs. Check status with 'roboflow train results /'." + ) + output_error(args, msg, hint=hint, exit_code=3) + return + + output( + args, + {"status": "cancelled", "project": project_slug, "version": version_str, **(result or {})}, + text=f"Training cancelled for {project_slug} version {version_str}.", + ) + + +def _stop(args): # noqa: ANN001 + from roboflow.adapters import rfapi + from roboflow.cli._output import output, output_error + + resolved = _resolve_train_target(args) + if resolved is None: + return + api_key, workspace_url, project_slug, version_str = resolved + + try: + result = rfapi.stop_version_training(api_key, workspace_url, project_slug, version_str) + except rfapi.RoboflowError as exc: + output_error(args, str(exc), exit_code=3) + return + + output( + args, + {"status": "stop_requested", "project": project_slug, "version": version_str, **(result or {})}, + text=f"Early-stop requested for {project_slug} version {version_str}.", + ) + + +def _results(args): # noqa: ANN001 + + from roboflow.adapters import rfapi + from roboflow.cli._output import output, output_error + + resolved = _resolve_train_target(args) + if resolved is None: + return + api_key, workspace_url, project_slug, version_str = resolved + + try: + result = rfapi.get_training_results(api_key, workspace_url, project_slug, version_str) + except rfapi.RoboflowError as exc: + output_error(args, str(exc), exit_code=3) + return + + job_type = result.get("jobType", "unknown") + model_count = result.get("modelCount", 0) + model_group = result.get("modelGroup") + text_summary = ( + f"{job_type} run for {project_slug} v{version_str}: status={result.get('status')}, models={model_count}" + ) + if model_group: + text_summary += f", group={model_group}" + output(args, result, text=text_summary) diff --git a/tests/cli/test_model_handler.py b/tests/cli/test_model_handler.py index 0bbe9a7d..2511cded 100644 --- a/tests/cli/test_model_handler.py +++ b/tests/cli/test_model_handler.py @@ -242,5 +242,134 @@ def test_list_models_project_not_found(self, mock_rf_cls: MagicMock) -> None: self.assertIn("error", result) +class TestModelStarRegister(unittest.TestCase): + """model star subcommand registers.""" + + def test_star_help(self) -> None: + result = runner.invoke(app, ["model", "star", "--help"]) + self.assertEqual(result.exit_code, 0) + self.assertIn("nas", result.output.lower()) + + def test_list_help_mentions_group(self) -> None: + result = runner.invoke(app, ["model", "list", "--help"]) + self.assertEqual(result.exit_code, 0) + self.assertIn("group", result.output.lower()) + + +class TestModelStar(unittest.TestCase): + """_star_model business logic.""" + + def _args(self, **kwargs: object) -> types.SimpleNamespace: + defaults = { + "json": True, + "api_key": "test-key", + "workspace": "test-ws", + "model_id": "abc-firestore-id", + "starred": True, + "quiet": True, + } + defaults.update(kwargs) + return types.SimpleNamespace(**defaults) + + def _capture_stdout(self, fn, args): + buf = io.StringIO() + old = sys.stdout + sys.stdout = buf + try: + fn(args) + finally: + sys.stdout = old + return buf.getvalue() + + @patch("roboflow.adapters.rfapi.favorite_nas_model") + def test_star_success(self, mock_fav: MagicMock) -> None: + from roboflow.cli.handlers.model import _star_model + + mock_fav.return_value = {"success": True, "model": {"id": "abc-firestore-id"}} + out = self._capture_stdout(_star_model, self._args()) + + mock_fav.assert_called_once_with("test-key", "test-ws", "abc-firestore-id", starred=True) + result = json.loads(out) + self.assertTrue(result.get("success")) + + @patch("roboflow.adapters.rfapi.favorite_nas_model") + def test_star_unstar_path(self, mock_fav: MagicMock) -> None: + from roboflow.cli.handlers.model import _star_model + + mock_fav.return_value = {"success": True, "model": {"id": "abc-firestore-id"}} + self._capture_stdout(_star_model, self._args(starred=False)) + + mock_fav.assert_called_once_with("test-key", "test-ws", "abc-firestore-id", starred=False) + + @patch("roboflow.adapters.rfapi.favorite_nas_model") + def test_star_non_nas_surfaces_hint(self, mock_fav: MagicMock) -> None: + from roboflow.adapters import rfapi + from roboflow.cli.handlers.model import _star_model + + mock_fav.side_effect = rfapi.RoboflowError( + '{"code":"MODEL_NOT_NAS","message":"Starring is only supported for NAS-trained models."}' + ) + buf = io.StringIO() + old = sys.stderr + sys.stderr = buf + try: + with self.assertRaises(SystemExit) as cm: + _star_model(self._args()) + finally: + sys.stderr = old + self.assertEqual(cm.exception.code, 3) + err = json.loads(buf.getvalue()) + # output_error parses the JSON body; the code surfaces alongside the message. + self.assertEqual(err["error"].get("code"), "MODEL_NOT_NAS") + self.assertIn("NAS-only", err["error"].get("hint", "")) + + +class TestModelListGroupFilter(unittest.TestCase): + """_list_models with --group hits the public /models endpoint.""" + + def _args(self, **kwargs: object) -> types.SimpleNamespace: + defaults = { + "json": True, + "api_key": "test-key", + "workspace": "test-ws", + "project": "my-project", + "group": None, + "quiet": True, + } + defaults.update(kwargs) + return types.SimpleNamespace(**defaults) + + @patch("roboflow.adapters.rfapi.list_project_models") + def test_list_with_group_uses_public_endpoint(self, mock_list: MagicMock) -> None: + from roboflow.cli.handlers.model import _list_models + + mock_list.return_value = [ + { + "url": "my-ws/my-proj-3-nas-gpu-abc", + "modelType": "rfdetr-nas", + "metrics": { + "map50": 87.3, + "map5095": 57.6, + "hardware": "gpu", + "latency": 8.7, + }, + "recommended": True, + } + ] + + buf = io.StringIO() + old = sys.stdout + sys.stdout = buf + try: + _list_models(self._args(group="rfdetrNasGroup-3")) + finally: + sys.stdout = old + + mock_list.assert_called_once_with("test-key", "test-ws", "my-project", group="rfdetrNasGroup-3") + rows = json.loads(buf.getvalue()) + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0]["metrics"]["hardware"], "gpu") + + if __name__ == "__main__": unittest.main() diff --git a/tests/cli/test_train_handler.py b/tests/cli/test_train_handler.py index 63fef691..3f50f57c 100644 --- a/tests/cli/test_train_handler.py +++ b/tests/cli/test_train_handler.py @@ -143,5 +143,109 @@ def test_start_json_error_not_double_encoded(self, mock_train: MagicMock) -> Non self.assertEqual(result["error"]["message"], "Unsupported request") +class TestTrainSubcommandsRegister(unittest.TestCase): + """train cancel/stop/results subcommands register correctly.""" + + def test_cancel_help(self) -> None: + result = runner.invoke(app, ["train", "cancel", "--help"]) + self.assertEqual(result.exit_code, 0) + self.assertIn("cancel", result.output.lower()) + + def test_stop_help(self) -> None: + result = runner.invoke(app, ["train", "stop", "--help"]) + self.assertEqual(result.exit_code, 0) + + def test_results_help(self) -> None: + result = runner.invoke(app, ["train", "results", "--help"]) + self.assertEqual(result.exit_code, 0) + + +class TestTrainCancelStopResults(unittest.TestCase): + """_cancel / _stop / _results business logic.""" + + def _args(self, **kwargs: object) -> types.SimpleNamespace: + defaults = { + "json": True, + "api_key": "test-key", + "workspace": "test-ws", + "target": "my-project/3", + "continue_if_no_refund": False, + "quiet": True, + } + defaults.update(kwargs) + return types.SimpleNamespace(**defaults) + + def _capture_stdout(self, fn, args): + buf = io.StringIO() + old = sys.stdout + sys.stdout = buf + try: + fn(args) + finally: + sys.stdout = old + return buf.getvalue() + + @patch("roboflow.adapters.rfapi.cancel_version_training") + def test_cancel_success(self, mock_cancel: MagicMock) -> None: + from roboflow.cli.handlers.train import _cancel + + mock_cancel.return_value = {"refund": True} + out = self._capture_stdout(_cancel, self._args()) + + mock_cancel.assert_called_once_with("test-key", "test-ws", "my-project", "3", continue_if_no_refund=False) + result = json.loads(out) + self.assertEqual(result["status"], "cancelled") + self.assertEqual(result["project"], "my-project") + self.assertEqual(result["version"], "3") + self.assertTrue(result.get("refund")) + + @patch("roboflow.adapters.rfapi.cancel_version_training") + def test_cancel_409_surfaces_hint(self, mock_cancel: MagicMock) -> None: + from roboflow.adapters import rfapi + from roboflow.cli.handlers.train import _cancel + + mock_cancel.side_effect = rfapi.RoboflowError("Cannot cancel non-running train job.") + buf = io.StringIO() + old = sys.stderr + sys.stderr = buf + try: + with self.assertRaises(SystemExit) as cm: + _cancel(self._args()) + finally: + sys.stderr = old + self.assertEqual(cm.exception.code, 3) + err = json.loads(buf.getvalue()) + self.assertIn("Cannot cancel", err["error"]["message"]) + self.assertIn("in-flight", err["error"].get("hint", "")) + + @patch("roboflow.adapters.rfapi.stop_version_training") + def test_stop_success(self, mock_stop: MagicMock) -> None: + from roboflow.cli.handlers.train import _stop + + mock_stop.return_value = {"success": True} + out = self._capture_stdout(_stop, self._args()) + + mock_stop.assert_called_once_with("test-key", "test-ws", "my-project", "3") + result = json.loads(out) + self.assertEqual(result["status"], "stop_requested") + + @patch("roboflow.adapters.rfapi.get_training_results") + def test_results_nas_run(self, mock_get: MagicMock) -> None: + from roboflow.cli.handlers.train import _results + + mock_get.return_value = { + "trainingId": "ds/3", + "status": "finished", + "jobType": "nas", + "modelGroup": "rfdetrNasGroup-3", + "modelCount": 5, + "models": [{"modelId": "m1"}], + } + out = self._capture_stdout(_results, self._args()) + result = json.loads(out) + self.assertEqual(result["jobType"], "nas") + self.assertEqual(result["modelCount"], 5) + + if __name__ == "__main__": unittest.main() From 0731c1b9429ca4d67f96bb2f43684649c4cea59e Mon Sep 17 00:00:00 2001 From: Peter Robicheaux Date: Thu, 7 May 2026 10:56:03 -0400 Subject: [PATCH 2/5] nas: 'model star' takes URL slug instead of Firestore doc id MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The public favorite endpoint now accepts the model URL slug (roboflow#11646), so the CLI can drop the Firestore-doc-id wart. Changes: - star_model argument is now `model_url`, accepting either the bare slug (when -w is set / a default workspace exists) or the workspace-prefixed form `/` — same shape as `model get`. - rfapi.favorite_nas_model parameter renamed `model_id` → `model_url` with urllib.parse.quote() for safety, since the slug is now what appears in the path. - Hints updated to point at models[].modelUrl instead of modelId, and the workspace fallback hint mentions the prefix form. Tests: +2 cases for the new parsing (workspace-prefixed URL vs bare slug + -w fallback). 22/22 model handler tests pass; 36/36 across model + train. Co-Authored-By: Claude Opus 4.7 (1M context) --- roboflow/adapters/rfapi.py | 13 +++++++++--- roboflow/cli/handlers/model.py | 34 +++++++++++++++++++++----------- tests/cli/test_model_handler.py | 35 ++++++++++++++++++++++++++++----- 3 files changed, 63 insertions(+), 19 deletions(-) diff --git a/roboflow/adapters/rfapi.py b/roboflow/adapters/rfapi.py index 717eee64..3e3c9a75 100644 --- a/roboflow/adapters/rfapi.py +++ b/roboflow/adapters/rfapi.py @@ -161,9 +161,16 @@ def get_model_by_url(api_key: str, workspace_url: str, model_url: str): return response.json() -def favorite_nas_model(api_key: str, workspace_url: str, model_id: str, *, starred: bool = True): - """Star or unstar a NAS-trained model. NAS-only on the server side.""" - url = f"{API_URL}/{workspace_url}/models/{model_id}/favorite?api_key={api_key}" +def favorite_nas_model(api_key: str, workspace_url: str, model_url: str, *, starred: bool = True): + """Star or unstar a NAS-trained model. + + ``model_url`` is the public model URL slug (e.g. ``my-project-3-nas-gpu-b``), + the same value the public API returns as ``models[].modelUrl`` on + ``GET /:workspace/:project/:version/training/results``. NAS-only on the + server side. + """ + encoded = urllib.parse.quote(model_url, safe="") + url = f"{API_URL}/{workspace_url}/models/{encoded}/favorite?api_key={api_key}" response = requests.post(url, json={"starred": bool(starred)}) if not response.ok: raise RoboflowError(response.text) diff --git a/roboflow/cli/handlers/model.py b/roboflow/cli/handlers/model.py index e2fedfd3..0b9fa414 100644 --- a/roboflow/cli/handlers/model.py +++ b/roboflow/cli/handlers/model.py @@ -48,12 +48,12 @@ def get_model( @model_app.command("star") def star_model( ctx: typer.Context, - model_id: Annotated[ + model_url: Annotated[ str, typer.Argument( help=( - "NAS-trained model id (Firestore document id). " - "Get it from 'roboflow train results /' (models[].modelId)." + "Model URL (e.g. workspace/model-slug, or just the slug if -w is set). " + "Get it from 'roboflow train results /' (models[].modelUrl)." ), ), ], @@ -65,7 +65,7 @@ def star_model( MODEL_NOT_NAS error. Starring triggers TRT compilation for the model's recommended hardware so the model becomes deployable as an edge target. """ - args = ctx_to_args(ctx, model_id=model_id, starred=not unstar) + args = ctx_to_args(ctx, model_url=model_url, starred=not unstar) _star_model(args) @@ -221,9 +221,16 @@ def _star_model(args): # noqa: ANN001 from roboflow.cli._output import output, output_error from roboflow.config import load_roboflow_api_key - workspace_url = args.workspace + # Accept either "workspace/model-slug" or just "model-slug" (when -w is + # set). Mirrors the parsing pattern used by `roboflow model get`. + raw = args.model_url.strip("/") + if "/" in raw: + ws_from_arg, _sep, model_slug = raw.partition("/") + else: + ws_from_arg, model_slug = None, raw + + workspace_url = args.workspace or ws_from_arg if not workspace_url: - # Need a workspace; fall back to whatever the api key points at. from roboflow.cli._resolver import resolve_default_workspace workspace_url = resolve_default_workspace(args.api_key) @@ -231,7 +238,9 @@ def _star_model(args): # noqa: ANN001 output_error( args, "Could not determine workspace.", - hint="Pass -w/--workspace or run 'roboflow auth set-workspace '.", + hint=( + "Pass -w/--workspace, prefix the model URL (workspace/slug), or run 'roboflow auth set-workspace '." + ), exit_code=2, ) return @@ -247,14 +256,17 @@ def _star_model(args): # noqa: ANN001 return try: - result = rfapi.favorite_nas_model(api_key, workspace_url, args.model_id, starred=args.starred) + result = rfapi.favorite_nas_model(api_key, workspace_url, model_slug, starred=args.starred) except rfapi.RoboflowError as exc: msg = str(exc) hint = None if "MODEL_NOT_NAS" in msg or "non-NAS" in msg: - hint = "Star is NAS-only. Use 'roboflow train results' to find NAS model ids (models[].modelId)." + hint = "Star is NAS-only. Use 'roboflow train results' to find NAS model URLs (models[].modelUrl)." elif "MODEL_NOT_IN_WORKSPACE" in msg: - hint = "Verify the model id and workspace; ids are Firestore doc ids, not URL slugs." + hint = ( + "Verify the model URL and workspace. The slug is the same value " + "'roboflow train results' returns as models[].modelUrl." + ) output_error(args, msg, hint=hint, exit_code=3) return @@ -262,7 +274,7 @@ def _star_model(args): # noqa: ANN001 output( args, result, - text=f"Model {args.model_id} {verb}.", + text=f"Model {workspace_url}/{model_slug} {verb}.", ) diff --git a/tests/cli/test_model_handler.py b/tests/cli/test_model_handler.py index 2511cded..58e65e2f 100644 --- a/tests/cli/test_model_handler.py +++ b/tests/cli/test_model_handler.py @@ -264,7 +264,7 @@ def _args(self, **kwargs: object) -> types.SimpleNamespace: "json": True, "api_key": "test-key", "workspace": "test-ws", - "model_id": "abc-firestore-id", + "model_url": "my-proj-3-nas-gpu-b", "starred": True, "quiet": True, } @@ -285,21 +285,46 @@ def _capture_stdout(self, fn, args): def test_star_success(self, mock_fav: MagicMock) -> None: from roboflow.cli.handlers.model import _star_model - mock_fav.return_value = {"success": True, "model": {"id": "abc-firestore-id"}} + mock_fav.return_value = {"success": True, "model": {"url": "my-proj-3-nas-gpu-b"}} out = self._capture_stdout(_star_model, self._args()) - mock_fav.assert_called_once_with("test-key", "test-ws", "abc-firestore-id", starred=True) + # Bare slug + -w, so workspace comes from args.workspace. + mock_fav.assert_called_once_with("test-key", "test-ws", "my-proj-3-nas-gpu-b", starred=True) result = json.loads(out) self.assertTrue(result.get("success")) + @patch("roboflow.adapters.rfapi.favorite_nas_model") + def test_star_workspace_prefixed_url(self, mock_fav: MagicMock) -> None: + """When the URL is `/`, the workspace flag overrides anyway.""" + from roboflow.cli.handlers.model import _star_model + + mock_fav.return_value = {"success": True, "model": {"url": "my-proj-3-nas-gpu-b"}} + self._capture_stdout(_star_model, self._args(model_url="some-ws/my-proj-3-nas-gpu-b")) + + # -w wins over the prefix, slug is stripped of the workspace segment. + mock_fav.assert_called_once_with("test-key", "test-ws", "my-proj-3-nas-gpu-b", starred=True) + + @patch("roboflow.adapters.rfapi.favorite_nas_model") + def test_star_workspace_inferred_from_prefix(self, mock_fav: MagicMock) -> None: + """No -w but `/` argument: workspace comes from the prefix.""" + from roboflow.cli.handlers.model import _star_model + + mock_fav.return_value = {"success": True, "model": {"url": "my-proj-3-nas-gpu-b"}} + self._capture_stdout( + _star_model, + self._args(workspace=None, model_url="some-ws/my-proj-3-nas-gpu-b"), + ) + + mock_fav.assert_called_once_with("test-key", "some-ws", "my-proj-3-nas-gpu-b", starred=True) + @patch("roboflow.adapters.rfapi.favorite_nas_model") def test_star_unstar_path(self, mock_fav: MagicMock) -> None: from roboflow.cli.handlers.model import _star_model - mock_fav.return_value = {"success": True, "model": {"id": "abc-firestore-id"}} + mock_fav.return_value = {"success": True, "model": {"url": "my-proj-3-nas-gpu-b"}} self._capture_stdout(_star_model, self._args(starred=False)) - mock_fav.assert_called_once_with("test-key", "test-ws", "abc-firestore-id", starred=False) + mock_fav.assert_called_once_with("test-key", "test-ws", "my-proj-3-nas-gpu-b", starred=False) @patch("roboflow.adapters.rfapi.favorite_nas_model") def test_star_non_nas_surfaces_hint(self, mock_fav: MagicMock) -> None: From d9431b9cd5b5242e831aa48bf3ba65c05d954021 Mon Sep 17 00:00:00 2001 From: Peter Robicheaux Date: Thu, 7 May 2026 11:45:12 -0400 Subject: [PATCH 3/5] test: train_results fixture uses URL slugs, not doc ids MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mirrors the backend cleanup in roboflow#11646. The training-results fixture now uses the public shape (trainingId is workspace/project/ version, models[].modelUrl, recommendedByHardware values are URL slugs). No behavior change in the CLI handler — it passes the response through. --- tests/cli/test_train_handler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/cli/test_train_handler.py b/tests/cli/test_train_handler.py index 3f50f57c..7aafa30e 100644 --- a/tests/cli/test_train_handler.py +++ b/tests/cli/test_train_handler.py @@ -234,12 +234,13 @@ def test_results_nas_run(self, mock_get: MagicMock) -> None: from roboflow.cli.handlers.train import _results mock_get.return_value = { - "trainingId": "ds/3", + "trainingId": "test-ws/my-project/3", "status": "finished", "jobType": "nas", "modelGroup": "rfdetrNasGroup-3", "modelCount": 5, - "models": [{"modelId": "m1"}], + "recommendedByHardware": {"gpu": "my-project-3-nas-gpu-a"}, + "models": [{"modelUrl": "my-project-3-nas-gpu-a"}], } out = self._capture_stdout(_results, self._args()) result = json.loads(out) From 7e8784eb1a3d3af961652da623161a55a1ab4256 Mon Sep 17 00:00:00 2001 From: Peter Robicheaux Date: Thu, 7 May 2026 12:20:41 -0400 Subject: [PATCH 4/5] =?UTF-8?q?nas:=20'model=20star'=20arg=20+=20favorite?= =?UTF-8?q?=5Fnas=5Fmodel=20param:=20model=5Furl=20=E2=86=92=20model=5Fid?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mirrors the wire rename in roboflow#11646. The public API field for the opaque model identifier is now `modelId` (the value is still the URL slug; that's an implementation detail callers shouldn't have to reason about). Changes: - `roboflow model star` argument: `model_url` → `model_id`. Help text and error hints updated to point at `models[].modelId`. - `rfapi.favorite_nas_model(model_url=...)` → `favorite_nas_model( model_id=...)`. Internal local var becomes `public_model_id` to keep the call-site readable. - Test fixtures: `model_url` arg → `model_id`, `models[].modelUrl` → `models[].modelId`. 36/36 tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) --- roboflow/adapters/rfapi.py | 8 ++++---- roboflow/cli/handlers/model.py | 30 ++++++++++++++++-------------- tests/cli/test_model_handler.py | 14 +++++++------- tests/cli/test_train_handler.py | 2 +- 4 files changed, 28 insertions(+), 26 deletions(-) diff --git a/roboflow/adapters/rfapi.py b/roboflow/adapters/rfapi.py index 3e3c9a75..1a38697e 100644 --- a/roboflow/adapters/rfapi.py +++ b/roboflow/adapters/rfapi.py @@ -161,15 +161,15 @@ def get_model_by_url(api_key: str, workspace_url: str, model_url: str): return response.json() -def favorite_nas_model(api_key: str, workspace_url: str, model_url: str, *, starred: bool = True): +def favorite_nas_model(api_key: str, workspace_url: str, model_id: str, *, starred: bool = True): """Star or unstar a NAS-trained model. - ``model_url`` is the public model URL slug (e.g. ``my-project-3-nas-gpu-b``), - the same value the public API returns as ``models[].modelUrl`` on + ``model_id`` is the opaque public model id (e.g. ``my-project-3-nas-gpu-b``), + the same value the public API returns as ``models[].modelId`` on ``GET /:workspace/:project/:version/training/results``. NAS-only on the server side. """ - encoded = urllib.parse.quote(model_url, safe="") + encoded = urllib.parse.quote(model_id, safe="") url = f"{API_URL}/{workspace_url}/models/{encoded}/favorite?api_key={api_key}" response = requests.post(url, json={"starred": bool(starred)}) if not response.ok: diff --git a/roboflow/cli/handlers/model.py b/roboflow/cli/handlers/model.py index 0b9fa414..8dd4d439 100644 --- a/roboflow/cli/handlers/model.py +++ b/roboflow/cli/handlers/model.py @@ -48,12 +48,12 @@ def get_model( @model_app.command("star") def star_model( ctx: typer.Context, - model_url: Annotated[ + model_id: Annotated[ str, typer.Argument( help=( - "Model URL (e.g. workspace/model-slug, or just the slug if -w is set). " - "Get it from 'roboflow train results /' (models[].modelUrl)." + "Model id (e.g. workspace/model-id, or just the bare id if -w is set). " + "Get it from 'roboflow train results /' (models[].modelId)." ), ), ], @@ -65,7 +65,7 @@ def star_model( MODEL_NOT_NAS error. Starring triggers TRT compilation for the model's recommended hardware so the model becomes deployable as an edge target. """ - args = ctx_to_args(ctx, model_url=model_url, starred=not unstar) + args = ctx_to_args(ctx, model_id=model_id, starred=not unstar) _star_model(args) @@ -221,13 +221,13 @@ def _star_model(args): # noqa: ANN001 from roboflow.cli._output import output, output_error from roboflow.config import load_roboflow_api_key - # Accept either "workspace/model-slug" or just "model-slug" (when -w is + # Accept either "workspace/model-id" or just "model-id" (when -w is # set). Mirrors the parsing pattern used by `roboflow model get`. - raw = args.model_url.strip("/") + raw = args.model_id.strip("/") if "/" in raw: - ws_from_arg, _sep, model_slug = raw.partition("/") + ws_from_arg, _sep, public_model_id = raw.partition("/") else: - ws_from_arg, model_slug = None, raw + ws_from_arg, public_model_id = None, raw workspace_url = args.workspace or ws_from_arg if not workspace_url: @@ -239,7 +239,7 @@ def _star_model(args): # noqa: ANN001 args, "Could not determine workspace.", hint=( - "Pass -w/--workspace, prefix the model URL (workspace/slug), or run 'roboflow auth set-workspace '." + "Pass -w/--workspace, prefix the model id (workspace/id), or run 'roboflow auth set-workspace '." ), exit_code=2, ) @@ -256,16 +256,18 @@ def _star_model(args): # noqa: ANN001 return try: - result = rfapi.favorite_nas_model(api_key, workspace_url, model_slug, starred=args.starred) + result = rfapi.favorite_nas_model( + api_key, workspace_url, public_model_id, starred=args.starred + ) except rfapi.RoboflowError as exc: msg = str(exc) hint = None if "MODEL_NOT_NAS" in msg or "non-NAS" in msg: - hint = "Star is NAS-only. Use 'roboflow train results' to find NAS model URLs (models[].modelUrl)." + hint = "Star is NAS-only. Use 'roboflow train results' to find NAS model ids (models[].modelId)." elif "MODEL_NOT_IN_WORKSPACE" in msg: hint = ( - "Verify the model URL and workspace. The slug is the same value " - "'roboflow train results' returns as models[].modelUrl." + "Verify the model id and workspace. The id is the same value " + "'roboflow train results' returns as models[].modelId." ) output_error(args, msg, hint=hint, exit_code=3) return @@ -274,7 +276,7 @@ def _star_model(args): # noqa: ANN001 output( args, result, - text=f"Model {workspace_url}/{model_slug} {verb}.", + text=f"Model {workspace_url}/{public_model_id} {verb}.", ) diff --git a/tests/cli/test_model_handler.py b/tests/cli/test_model_handler.py index 58e65e2f..f4ca2f2d 100644 --- a/tests/cli/test_model_handler.py +++ b/tests/cli/test_model_handler.py @@ -264,7 +264,7 @@ def _args(self, **kwargs: object) -> types.SimpleNamespace: "json": True, "api_key": "test-key", "workspace": "test-ws", - "model_url": "my-proj-3-nas-gpu-b", + "model_id": "my-proj-3-nas-gpu-b", "starred": True, "quiet": True, } @@ -294,25 +294,25 @@ def test_star_success(self, mock_fav: MagicMock) -> None: self.assertTrue(result.get("success")) @patch("roboflow.adapters.rfapi.favorite_nas_model") - def test_star_workspace_prefixed_url(self, mock_fav: MagicMock) -> None: - """When the URL is `/`, the workspace flag overrides anyway.""" + def test_star_workspace_prefixed_id(self, mock_fav: MagicMock) -> None: + """When the id is `/`, the workspace flag overrides anyway.""" from roboflow.cli.handlers.model import _star_model mock_fav.return_value = {"success": True, "model": {"url": "my-proj-3-nas-gpu-b"}} - self._capture_stdout(_star_model, self._args(model_url="some-ws/my-proj-3-nas-gpu-b")) + self._capture_stdout(_star_model, self._args(model_id="some-ws/my-proj-3-nas-gpu-b")) - # -w wins over the prefix, slug is stripped of the workspace segment. + # -w wins over the prefix, id is stripped of the workspace segment. mock_fav.assert_called_once_with("test-key", "test-ws", "my-proj-3-nas-gpu-b", starred=True) @patch("roboflow.adapters.rfapi.favorite_nas_model") def test_star_workspace_inferred_from_prefix(self, mock_fav: MagicMock) -> None: - """No -w but `/` argument: workspace comes from the prefix.""" + """No -w but `/` argument: workspace comes from the prefix.""" from roboflow.cli.handlers.model import _star_model mock_fav.return_value = {"success": True, "model": {"url": "my-proj-3-nas-gpu-b"}} self._capture_stdout( _star_model, - self._args(workspace=None, model_url="some-ws/my-proj-3-nas-gpu-b"), + self._args(workspace=None, model_id="some-ws/my-proj-3-nas-gpu-b"), ) mock_fav.assert_called_once_with("test-key", "some-ws", "my-proj-3-nas-gpu-b", starred=True) diff --git a/tests/cli/test_train_handler.py b/tests/cli/test_train_handler.py index 7aafa30e..7d826e6c 100644 --- a/tests/cli/test_train_handler.py +++ b/tests/cli/test_train_handler.py @@ -240,7 +240,7 @@ def test_results_nas_run(self, mock_get: MagicMock) -> None: "modelGroup": "rfdetrNasGroup-3", "modelCount": 5, "recommendedByHardware": {"gpu": "my-project-3-nas-gpu-a"}, - "models": [{"modelUrl": "my-project-3-nas-gpu-a"}], + "models": [{"modelId": "my-project-3-nas-gpu-a"}], } out = self._capture_stdout(_results, self._args()) result = json.loads(out) From 7169478915580d3b04f2cb334be463e1ae927bbe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 May 2026 16:20:52 +0000 Subject: [PATCH 5/5] =?UTF-8?q?fix(pre=5Fcommit):=20=F0=9F=8E=A8=20auto=20?= =?UTF-8?q?format=20pre-commit=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- roboflow/cli/handlers/model.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/roboflow/cli/handlers/model.py b/roboflow/cli/handlers/model.py index 8dd4d439..033d52dd 100644 --- a/roboflow/cli/handlers/model.py +++ b/roboflow/cli/handlers/model.py @@ -256,9 +256,7 @@ def _star_model(args): # noqa: ANN001 return try: - result = rfapi.favorite_nas_model( - api_key, workspace_url, public_model_id, starred=args.starred - ) + result = rfapi.favorite_nas_model(api_key, workspace_url, public_model_id, starred=args.starred) except rfapi.RoboflowError as exc: msg = str(exc) hint = None