diff --git a/CLI-COMMANDS.md b/CLI-COMMANDS.md index 34539368..aa961485 100644 --- a/CLI-COMMANDS.md +++ b/CLI-COMMANDS.md @@ -47,6 +47,50 @@ roboflow download my-workspace/my-project/3 -f coco # alias roboflow infer photo.jpg -m my-project/3 ``` +### Train, monitor, cancel, stop + +```bash +# Start training (any architecture). For NAS sweeps, use a NAS parent modelType: +roboflow train start -p my-project -v 3 --type rfdetr-base +roboflow train start -p my-project -v 3 --type rfdetr-nas-parent # NAS sweep +roboflow train start -p my-project -v 3 --type rfdetr-nas-base-parent # NAS Base sweep +roboflow train start -p my-project -v 3 --type rfdetr-nas-seg-parent # NAS instance-segmentation + +# Cancel an in-flight training (any architecture; NAS-aware): +roboflow train cancel my-project/3 +# Pass --continue-if-no-refund to cancel even past the refund window: +roboflow train cancel my-project/3 --continue-if-no-refund + +# Graceful early-stop: +roboflow train stop my-project/3 + +# Run-level training results bundle (NAS leaderboard for NAS runs, +# minimal bundle for non-NAS): +roboflow train results my-project/3 +``` + +NAS sweeps require the version's validation split to have at least 15 images; +the server returns `code: "insufficient_validation_images_for_nas"` otherwise. + +### NAS models — list, star, deploy + +```bash +# Get a NAS run's modelGroup from training results: +roboflow --json train results my-project/3 | jq -r .modelGroup +# → rfdetrNasGroup-3 + +# List every model from one NAS run, with hardware/latency/mAP columns: +roboflow model list -p my-project --group rfdetrNasGroup-3 + +# Star a NAS-trained model (triggers TRT compile for its recommended hardware): +# --json train results … gives you the modelId per row. +roboflow model star +roboflow model star --unstar +``` + +`model star` is NAS-only by server-side design; non-NAS modelTypes return +`code: "MODEL_NOT_NAS"`. + ### Search and export ```bash diff --git a/roboflow/adapters/rfapi.py b/roboflow/adapters/rfapi.py index c0b00f39..1a38697e 100644 --- a/roboflow/adapters/rfapi.py +++ b/roboflow/adapters/rfapi.py @@ -84,6 +84,99 @@ def start_version_training( return True +def cancel_version_training( + api_key: str, + workspace_url: str, + project_url: str, + version: str, + *, + continue_if_no_refund: bool = False, +): + """Cancel an in-flight training run. + + Backend handler is canonical for both vanilla and NAS trainings — it + accepts ``mining`` status, so this works for NAS sweeps too. + """ + url = f"{API_URL}/{workspace_url}/{project_url}/{version}/train/cancel?api_key={api_key}" + body: Dict[str, Union[str, int, bool]] = {} + if continue_if_no_refund: + body["continueIfNoRefund"] = True + response = requests.post(url, json=body) + if not response.ok: + raise RoboflowError(response.text) + return response.json() if response.content else {"success": True} + + +def stop_version_training(api_key: str, workspace_url: str, project_url: str, version: str): + """Request an early stop on an in-flight training run. + + The backend flips ``train.requestedStop``; the run finishes the current + phase gracefully (mining or training). + """ + url = f"{API_URL}/{workspace_url}/{project_url}/{version}/train/stop?api_key={api_key}" + response = requests.post(url, json={}) + if not response.ok: + raise RoboflowError(response.text) + return response.json() if response.content else {"success": True} + + +def get_training_results(api_key: str, workspace_url: str, project_url: str, version: str): + """Run-level training results bundle. + + For NAS runs returns ``{ trainingId, status, modelGroup, modelCount, + recommendedByHardware, mining?, models: [...] }``. For non-NAS runs + returns a minimal bundle with the produced model(s). + """ + url = f"{API_URL}/{workspace_url}/{project_url}/{version}/training/results?api_key={api_key}" + response = requests.get(url) + if not response.ok: + raise RoboflowError(response.text) + return response.json() + + +def list_project_models( + api_key: str, + workspace_url: str, + project_url: str, + *, + group: Optional[str] = None, +): + """List models for a project; pass ``group`` to scope to one NAS run.""" + url = f"{API_URL}/{workspace_url}/{project_url}/models?api_key={api_key}" + if group: + url += f"&group={urllib.parse.quote(group, safe='')}" + response = requests.get(url) + if not response.ok: + raise RoboflowError(response.text) + return response.json() + + +def get_model_by_url(api_key: str, workspace_url: str, model_url: str): + """Fetch a single model by its URL slug.""" + encoded = urllib.parse.quote(model_url, safe="/") + url = f"{API_URL}/models/{workspace_url}/{encoded}?api_key={api_key}" + response = requests.get(url) + if not response.ok: + raise RoboflowError(response.text) + return response.json() + + +def favorite_nas_model(api_key: str, workspace_url: str, model_id: str, *, starred: bool = True): + """Star or unstar a NAS-trained model. + + ``model_id`` is the opaque public model id (e.g. ``my-project-3-nas-gpu-b``), + the same value the public API returns as ``models[].modelId`` on + ``GET /:workspace/:project/:version/training/results``. NAS-only on the + server side. + """ + encoded = urllib.parse.quote(model_id, safe="") + url = f"{API_URL}/{workspace_url}/models/{encoded}/favorite?api_key={api_key}" + response = requests.post(url, json={"starred": bool(starred)}) + if not response.ok: + raise RoboflowError(response.text) + return response.json() + + def get_version(api_key: str, workspace_url: str, project_url: str, version: str, nocache: bool = False): """ Fetch detailed information about a specific dataset version. diff --git a/roboflow/cli/handlers/model.py b/roboflow/cli/handlers/model.py index 9cf04617..033d52dd 100644 --- a/roboflow/cli/handlers/model.py +++ b/roboflow/cli/handlers/model.py @@ -15,9 +15,23 @@ def list_models( ctx: typer.Context, project: Annotated[str, typer.Option("-p", "--project", help="Project ID or shorthand (e.g. my-ws/my-project)")], + group: Annotated[ + Optional[str], + typer.Option( + "-g", + "--group", + help=( + "NAS modelGroup to scope the list to a single NAS run. " + "Get the value from 'roboflow train results /'." + ), + ), + ] = None, ) -> None: - """List trained models for a project.""" - args = ctx_to_args(ctx, project=project) + """List trained models for a project. + + Pass --group to filter to a single NAS run. + """ + args = ctx_to_args(ctx, project=project, group=group) _list_models(args) @@ -31,6 +45,30 @@ def get_model( _get_model(args) +@model_app.command("star") +def star_model( + ctx: typer.Context, + model_id: Annotated[ + str, + typer.Argument( + help=( + "Model id (e.g. workspace/model-id, or just the bare id if -w is set). " + "Get it from 'roboflow train results /' (models[].modelId)." + ), + ), + ], + unstar: Annotated[bool, typer.Option("--unstar", help="Unstar instead of starring")] = False, +) -> None: + """Star or unstar a NAS-trained model. + + NAS-only by design — the server rejects non-NAS modelTypes with a + MODEL_NOT_NAS error. Starring triggers TRT compilation for the model's + recommended hardware so the model becomes deployable as an edge target. + """ + args = ctx_to_args(ctx, model_id=model_id, starred=not unstar) + _star_model(args) + + @model_app.command("infer") def model_infer( ctx: typer.Context, @@ -86,9 +124,11 @@ def upload_model( def _list_models(args): # noqa: ANN001 import roboflow + from roboflow.adapters import rfapi from roboflow.cli._output import output, output_error, suppress_sdk_output from roboflow.cli._resolver import resolve_resource from roboflow.cli._table import format_table + from roboflow.config import load_roboflow_api_key try: workspace_url, project_slug, _version = resolve_resource(args.project, workspace_override=args.workspace) @@ -96,6 +136,52 @@ def _list_models(args): # noqa: ANN001 output_error(args, str(exc)) return + group = getattr(args, "group", None) + + if group: + # NAS path — hit the public /models endpoint with ?group= filter. + # Surfaces full per-row NAS metadata (nasFamily, group, + # train.results.{hardware,latency,map5095,paretoOptimalFor}, + # favorites, recommended). + api_key = args.api_key or load_roboflow_api_key(workspace_url) + if not api_key: + output_error( + args, + "No API key found.", + hint="Set ROBOFLOW_API_KEY or run 'roboflow auth login'.", + exit_code=2, + ) + return + try: + rows = rfapi.list_project_models(api_key, workspace_url, project_slug, group=group) + except rfapi.RoboflowError as exc: + output_error(args, str(exc), exit_code=3) + return + if not isinstance(rows, list): + rows = [] + # Project a leaderboard view for the text table; full row stays in JSON. + table_rows = [] + for r in rows: + metrics = r.get("metrics") or {} + table_rows.append( + { + "url": r.get("url", ""), + "type": r.get("modelType", ""), + "hardware": metrics.get("hardware", ""), + "latency": metrics.get("latency", ""), + "map50": metrics.get("map50", ""), + "map5095": metrics.get("map5095", ""), + "recommended": "★" if r.get("recommended") else "", + } + ) + table = format_table( + table_rows, + columns=["url", "type", "hardware", "latency", "map50", "map5095", "recommended"], + headers=["URL", "TYPE", "HARDWARE", "LATENCY", "MAP50", "MAP5095", "REC"], + ) + output(args, rows, text=table) + return + api_key = args.api_key or None try: @@ -130,6 +216,68 @@ def _list_models(args): # noqa: ANN001 output(args, models, text=table) +def _star_model(args): # noqa: ANN001 + from roboflow.adapters import rfapi + from roboflow.cli._output import output, output_error + from roboflow.config import load_roboflow_api_key + + # Accept either "workspace/model-id" or just "model-id" (when -w is + # set). Mirrors the parsing pattern used by `roboflow model get`. + raw = args.model_id.strip("/") + if "/" in raw: + ws_from_arg, _sep, public_model_id = raw.partition("/") + else: + ws_from_arg, public_model_id = None, raw + + workspace_url = args.workspace or ws_from_arg + if not workspace_url: + from roboflow.cli._resolver import resolve_default_workspace + + workspace_url = resolve_default_workspace(args.api_key) + if not workspace_url: + output_error( + args, + "Could not determine workspace.", + hint=( + "Pass -w/--workspace, prefix the model id (workspace/id), or run 'roboflow auth set-workspace '." + ), + exit_code=2, + ) + return + + api_key = args.api_key or load_roboflow_api_key(workspace_url) + if not api_key: + output_error( + args, + "No API key found.", + hint="Set ROBOFLOW_API_KEY or run 'roboflow auth login'.", + exit_code=2, + ) + return + + try: + result = rfapi.favorite_nas_model(api_key, workspace_url, public_model_id, starred=args.starred) + except rfapi.RoboflowError as exc: + msg = str(exc) + hint = None + if "MODEL_NOT_NAS" in msg or "non-NAS" in msg: + hint = "Star is NAS-only. Use 'roboflow train results' to find NAS model ids (models[].modelId)." + elif "MODEL_NOT_IN_WORKSPACE" in msg: + hint = ( + "Verify the model id and workspace. The id is the same value " + "'roboflow train results' returns as models[].modelId." + ) + output_error(args, msg, hint=hint, exit_code=3) + return + + verb = "starred" if args.starred else "unstarred" + output( + args, + result, + text=f"Model {workspace_url}/{public_model_id} {verb}.", + ) + + def _get_model(args): # noqa: ANN001 import json diff --git a/roboflow/cli/handlers/train.py b/roboflow/cli/handlers/train.py index 6c66c9c3..f5cbd449 100644 --- a/roboflow/cli/handlers/train.py +++ b/roboflow/cli/handlers/train.py @@ -76,6 +76,75 @@ def start_training( _start(args) +@train_app.command("cancel") +def cancel_training( + ctx: typer.Context, + target: Annotated[ + str, + typer.Argument( + help="Training to cancel as 'project/version' (e.g. 'my-project/3' or 'workspace/my-project/3')" + ), + ], + continue_if_no_refund: Annotated[ + bool, + typer.Option( + "--continue-if-no-refund", + help=( + "Cancel even if the run is past the refund window. " + "Default: false (server replies refund:false without cancelling)." + ), + ), + ] = False, +) -> None: + """Cancel an in-flight training run. + + Works for any architecture, including NAS sweeps in the mining or + training phase. Server-side gate: only valid while the run is in-flight; + a finished/failed run returns 409 CANNOT_CANCEL. + """ + args = ctx_to_args(ctx, target=target, continue_if_no_refund=continue_if_no_refund) + _cancel(args) + + +@train_app.command("stop") +def stop_training( + ctx: typer.Context, + target: Annotated[ + str, + typer.Argument(help="Training to stop as 'project/version'"), + ], +) -> None: + """Request a graceful early-stop on an in-flight training run. + + Distinct from cancel: the run finishes the current phase (mining or + training) instead of terminating immediately. Idempotent — calling + stop on an already-stopped run is a no-op. + """ + args = ctx_to_args(ctx, target=target) + _stop(args) + + +@train_app.command("results") +def training_results( + ctx: typer.Context, + target: Annotated[ + str, + typer.Argument(help="Training to inspect as 'project/version'"), + ], +) -> None: + """Run-level training results bundle. + + For NAS sweeps returns { trainingId, status, modelGroup, modelCount, + recommendedByHardware, mining?, models: [...] }. For non-NAS trainings + returns a minimal bundle with the produced model. + + Pass the returned `modelGroup` to `roboflow model list --group ...` to + list every NAS model from that run with full metadata. + """ + args = ctx_to_args(ctx, target=target) + _results(args) + + # --------------------------------------------------------------------------- # Business logic (unchanged from argparse version) # --------------------------------------------------------------------------- @@ -202,3 +271,121 @@ def _ensure_export(args, api_key, workspace_url, project_slug, version_str, mode return except rfapi.RoboflowError: pass + + +def _resolve_train_target(args): + """Parse '/' (or full 'workspace//') and resolve api key. + + Returns (api_key, workspace_url, project_slug, version_str) or None if validation fails. + """ + from roboflow.cli._output import output_error + from roboflow.cli._resolver import resolve_resource + from roboflow.config import load_roboflow_api_key + + try: + workspace_url, project_slug, version = resolve_resource(args.target, workspace_override=args.workspace) + except ValueError as exc: + output_error(args, str(exc)) + return None + if version is None: + output_error( + args, + "Version is required.", + hint="Pass it as 'project/version' or 'workspace/project/version'.", + ) + return None + api_key = args.api_key or load_roboflow_api_key(workspace_url) + if not api_key: + output_error( + args, + "No API key found.", + hint="Set ROBOFLOW_API_KEY or run 'roboflow auth login'.", + exit_code=2, + ) + return None + return api_key, workspace_url, project_slug, str(version) + + +def _cancel(args): # noqa: ANN001 + from roboflow.adapters import rfapi + from roboflow.cli._output import output, output_error + + resolved = _resolve_train_target(args) + if resolved is None: + return + api_key, workspace_url, project_slug, version_str = resolved + + try: + result = rfapi.cancel_version_training( + api_key, + workspace_url, + project_slug, + version_str, + continue_if_no_refund=getattr(args, "continue_if_no_refund", False), + ) + except rfapi.RoboflowError as exc: + msg = str(exc) + # 409 from server lands here as a RoboflowError carrying the JSON + # body; surface it with code "CANNOT_CANCEL" if present. + hint = None + if "non-running" in msg or "Cannot cancel" in msg: + hint = ( + "Cancel only applies to in-flight runs. Check status with 'roboflow train results /'." + ) + output_error(args, msg, hint=hint, exit_code=3) + return + + output( + args, + {"status": "cancelled", "project": project_slug, "version": version_str, **(result or {})}, + text=f"Training cancelled for {project_slug} version {version_str}.", + ) + + +def _stop(args): # noqa: ANN001 + from roboflow.adapters import rfapi + from roboflow.cli._output import output, output_error + + resolved = _resolve_train_target(args) + if resolved is None: + return + api_key, workspace_url, project_slug, version_str = resolved + + try: + result = rfapi.stop_version_training(api_key, workspace_url, project_slug, version_str) + except rfapi.RoboflowError as exc: + output_error(args, str(exc), exit_code=3) + return + + output( + args, + {"status": "stop_requested", "project": project_slug, "version": version_str, **(result or {})}, + text=f"Early-stop requested for {project_slug} version {version_str}.", + ) + + +def _results(args): # noqa: ANN001 + + from roboflow.adapters import rfapi + from roboflow.cli._output import output, output_error + + resolved = _resolve_train_target(args) + if resolved is None: + return + api_key, workspace_url, project_slug, version_str = resolved + + try: + result = rfapi.get_training_results(api_key, workspace_url, project_slug, version_str) + except rfapi.RoboflowError as exc: + output_error(args, str(exc), exit_code=3) + return + + job_type = result.get("jobType", "unknown") + model_count = result.get("modelCount", 0) + model_group = result.get("modelGroup") + text_summary = ( + f"{job_type} run for {project_slug} v{version_str}: status={result.get('status')}, models={model_count}" + ) + if model_group: + text_summary += f", group={model_group}" + output(args, result, text=text_summary) diff --git a/tests/cli/test_model_handler.py b/tests/cli/test_model_handler.py index 0bbe9a7d..f4ca2f2d 100644 --- a/tests/cli/test_model_handler.py +++ b/tests/cli/test_model_handler.py @@ -242,5 +242,159 @@ def test_list_models_project_not_found(self, mock_rf_cls: MagicMock) -> None: self.assertIn("error", result) +class TestModelStarRegister(unittest.TestCase): + """model star subcommand registers.""" + + def test_star_help(self) -> None: + result = runner.invoke(app, ["model", "star", "--help"]) + self.assertEqual(result.exit_code, 0) + self.assertIn("nas", result.output.lower()) + + def test_list_help_mentions_group(self) -> None: + result = runner.invoke(app, ["model", "list", "--help"]) + self.assertEqual(result.exit_code, 0) + self.assertIn("group", result.output.lower()) + + +class TestModelStar(unittest.TestCase): + """_star_model business logic.""" + + def _args(self, **kwargs: object) -> types.SimpleNamespace: + defaults = { + "json": True, + "api_key": "test-key", + "workspace": "test-ws", + "model_id": "my-proj-3-nas-gpu-b", + "starred": True, + "quiet": True, + } + defaults.update(kwargs) + return types.SimpleNamespace(**defaults) + + def _capture_stdout(self, fn, args): + buf = io.StringIO() + old = sys.stdout + sys.stdout = buf + try: + fn(args) + finally: + sys.stdout = old + return buf.getvalue() + + @patch("roboflow.adapters.rfapi.favorite_nas_model") + def test_star_success(self, mock_fav: MagicMock) -> None: + from roboflow.cli.handlers.model import _star_model + + mock_fav.return_value = {"success": True, "model": {"url": "my-proj-3-nas-gpu-b"}} + out = self._capture_stdout(_star_model, self._args()) + + # Bare slug + -w, so workspace comes from args.workspace. + mock_fav.assert_called_once_with("test-key", "test-ws", "my-proj-3-nas-gpu-b", starred=True) + result = json.loads(out) + self.assertTrue(result.get("success")) + + @patch("roboflow.adapters.rfapi.favorite_nas_model") + def test_star_workspace_prefixed_id(self, mock_fav: MagicMock) -> None: + """When the id is `/`, the workspace flag overrides anyway.""" + from roboflow.cli.handlers.model import _star_model + + mock_fav.return_value = {"success": True, "model": {"url": "my-proj-3-nas-gpu-b"}} + self._capture_stdout(_star_model, self._args(model_id="some-ws/my-proj-3-nas-gpu-b")) + + # -w wins over the prefix, id is stripped of the workspace segment. + mock_fav.assert_called_once_with("test-key", "test-ws", "my-proj-3-nas-gpu-b", starred=True) + + @patch("roboflow.adapters.rfapi.favorite_nas_model") + def test_star_workspace_inferred_from_prefix(self, mock_fav: MagicMock) -> None: + """No -w but `/` argument: workspace comes from the prefix.""" + from roboflow.cli.handlers.model import _star_model + + mock_fav.return_value = {"success": True, "model": {"url": "my-proj-3-nas-gpu-b"}} + self._capture_stdout( + _star_model, + self._args(workspace=None, model_id="some-ws/my-proj-3-nas-gpu-b"), + ) + + mock_fav.assert_called_once_with("test-key", "some-ws", "my-proj-3-nas-gpu-b", starred=True) + + @patch("roboflow.adapters.rfapi.favorite_nas_model") + def test_star_unstar_path(self, mock_fav: MagicMock) -> None: + from roboflow.cli.handlers.model import _star_model + + mock_fav.return_value = {"success": True, "model": {"url": "my-proj-3-nas-gpu-b"}} + self._capture_stdout(_star_model, self._args(starred=False)) + + mock_fav.assert_called_once_with("test-key", "test-ws", "my-proj-3-nas-gpu-b", starred=False) + + @patch("roboflow.adapters.rfapi.favorite_nas_model") + def test_star_non_nas_surfaces_hint(self, mock_fav: MagicMock) -> None: + from roboflow.adapters import rfapi + from roboflow.cli.handlers.model import _star_model + + mock_fav.side_effect = rfapi.RoboflowError( + '{"code":"MODEL_NOT_NAS","message":"Starring is only supported for NAS-trained models."}' + ) + buf = io.StringIO() + old = sys.stderr + sys.stderr = buf + try: + with self.assertRaises(SystemExit) as cm: + _star_model(self._args()) + finally: + sys.stderr = old + self.assertEqual(cm.exception.code, 3) + err = json.loads(buf.getvalue()) + # output_error parses the JSON body; the code surfaces alongside the message. + self.assertEqual(err["error"].get("code"), "MODEL_NOT_NAS") + self.assertIn("NAS-only", err["error"].get("hint", "")) + + +class TestModelListGroupFilter(unittest.TestCase): + """_list_models with --group hits the public /models endpoint.""" + + def _args(self, **kwargs: object) -> types.SimpleNamespace: + defaults = { + "json": True, + "api_key": "test-key", + "workspace": "test-ws", + "project": "my-project", + "group": None, + "quiet": True, + } + defaults.update(kwargs) + return types.SimpleNamespace(**defaults) + + @patch("roboflow.adapters.rfapi.list_project_models") + def test_list_with_group_uses_public_endpoint(self, mock_list: MagicMock) -> None: + from roboflow.cli.handlers.model import _list_models + + mock_list.return_value = [ + { + "url": "my-ws/my-proj-3-nas-gpu-abc", + "modelType": "rfdetr-nas", + "metrics": { + "map50": 87.3, + "map5095": 57.6, + "hardware": "gpu", + "latency": 8.7, + }, + "recommended": True, + } + ] + + buf = io.StringIO() + old = sys.stdout + sys.stdout = buf + try: + _list_models(self._args(group="rfdetrNasGroup-3")) + finally: + sys.stdout = old + + mock_list.assert_called_once_with("test-key", "test-ws", "my-project", group="rfdetrNasGroup-3") + rows = json.loads(buf.getvalue()) + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0]["metrics"]["hardware"], "gpu") + + if __name__ == "__main__": unittest.main() diff --git a/tests/cli/test_train_handler.py b/tests/cli/test_train_handler.py index 63fef691..7d826e6c 100644 --- a/tests/cli/test_train_handler.py +++ b/tests/cli/test_train_handler.py @@ -143,5 +143,110 @@ def test_start_json_error_not_double_encoded(self, mock_train: MagicMock) -> Non self.assertEqual(result["error"]["message"], "Unsupported request") +class TestTrainSubcommandsRegister(unittest.TestCase): + """train cancel/stop/results subcommands register correctly.""" + + def test_cancel_help(self) -> None: + result = runner.invoke(app, ["train", "cancel", "--help"]) + self.assertEqual(result.exit_code, 0) + self.assertIn("cancel", result.output.lower()) + + def test_stop_help(self) -> None: + result = runner.invoke(app, ["train", "stop", "--help"]) + self.assertEqual(result.exit_code, 0) + + def test_results_help(self) -> None: + result = runner.invoke(app, ["train", "results", "--help"]) + self.assertEqual(result.exit_code, 0) + + +class TestTrainCancelStopResults(unittest.TestCase): + """_cancel / _stop / _results business logic.""" + + def _args(self, **kwargs: object) -> types.SimpleNamespace: + defaults = { + "json": True, + "api_key": "test-key", + "workspace": "test-ws", + "target": "my-project/3", + "continue_if_no_refund": False, + "quiet": True, + } + defaults.update(kwargs) + return types.SimpleNamespace(**defaults) + + def _capture_stdout(self, fn, args): + buf = io.StringIO() + old = sys.stdout + sys.stdout = buf + try: + fn(args) + finally: + sys.stdout = old + return buf.getvalue() + + @patch("roboflow.adapters.rfapi.cancel_version_training") + def test_cancel_success(self, mock_cancel: MagicMock) -> None: + from roboflow.cli.handlers.train import _cancel + + mock_cancel.return_value = {"refund": True} + out = self._capture_stdout(_cancel, self._args()) + + mock_cancel.assert_called_once_with("test-key", "test-ws", "my-project", "3", continue_if_no_refund=False) + result = json.loads(out) + self.assertEqual(result["status"], "cancelled") + self.assertEqual(result["project"], "my-project") + self.assertEqual(result["version"], "3") + self.assertTrue(result.get("refund")) + + @patch("roboflow.adapters.rfapi.cancel_version_training") + def test_cancel_409_surfaces_hint(self, mock_cancel: MagicMock) -> None: + from roboflow.adapters import rfapi + from roboflow.cli.handlers.train import _cancel + + mock_cancel.side_effect = rfapi.RoboflowError("Cannot cancel non-running train job.") + buf = io.StringIO() + old = sys.stderr + sys.stderr = buf + try: + with self.assertRaises(SystemExit) as cm: + _cancel(self._args()) + finally: + sys.stderr = old + self.assertEqual(cm.exception.code, 3) + err = json.loads(buf.getvalue()) + self.assertIn("Cannot cancel", err["error"]["message"]) + self.assertIn("in-flight", err["error"].get("hint", "")) + + @patch("roboflow.adapters.rfapi.stop_version_training") + def test_stop_success(self, mock_stop: MagicMock) -> None: + from roboflow.cli.handlers.train import _stop + + mock_stop.return_value = {"success": True} + out = self._capture_stdout(_stop, self._args()) + + mock_stop.assert_called_once_with("test-key", "test-ws", "my-project", "3") + result = json.loads(out) + self.assertEqual(result["status"], "stop_requested") + + @patch("roboflow.adapters.rfapi.get_training_results") + def test_results_nas_run(self, mock_get: MagicMock) -> None: + from roboflow.cli.handlers.train import _results + + mock_get.return_value = { + "trainingId": "test-ws/my-project/3", + "status": "finished", + "jobType": "nas", + "modelGroup": "rfdetrNasGroup-3", + "modelCount": 5, + "recommendedByHardware": {"gpu": "my-project-3-nas-gpu-a"}, + "models": [{"modelId": "my-project-3-nas-gpu-a"}], + } + out = self._capture_stdout(_results, self._args()) + result = json.loads(out) + self.assertEqual(result["jobType"], "nas") + self.assertEqual(result["modelCount"], 5) + + if __name__ == "__main__": unittest.main()