Skip to content

fix: infer fanout stages from process return type#1964

Open
nightcityblade wants to merge 1 commit into
NVIDIA-NeMo:mainfrom
nightcityblade:fix/issue-1613
Open

fix: infer fanout stages from process return type#1964
nightcityblade wants to merge 1 commit into
NVIDIA-NeMo:mainfrom
nightcityblade:fix/issue-1613

Conversation

@nightcityblade
Copy link
Copy Markdown
Contributor

Description

Closes #1613.

Adds default ProcessingStage.ray_stage_spec() logic that marks stages as fanout stages when their process() return annotation can produce a list[...]. This covers both always-fanout returns and union returns such as Task | list[Task].

Usage

N/A

Checklist

  • I am familiar with the Contributing Guide.
  • New or Existing tests cover these changes.
  • The documentation is up to date with these changes.

Testing

  • uv run --group linting ruff check nemo_curator/stages/base.py tests/stages/common/test_base.py
  • uv run pytest tests/stages/common/test_base.py (blocked locally: NeMo-Curator currently supports Linux only, but this runner is macOS/Darwin)

Signed-off-by: nightcityblade <nightcityblade@gmail.com>
@nightcityblade nightcityblade requested a review from a team as a code owner May 11, 2026 03:08
@nightcityblade nightcityblade requested review from ayushdg and removed request for a team May 11, 2026 03:08
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 11, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 11, 2026

Greptile Summary

This PR adds automatic fanout-stage detection to ProcessingStage by inspecting the process() return annotation at runtime and setting is_fanout_stage: True in the default ray_stage_spec() when the annotation is list[Y] or a union containing list[Y], eliminating the need to manually override ray_stage_spec() for every fanout stage.

  • nemo_curator/stages/base.py: New _process_returns_list() helper uses get_type_hints with a __annotations__ fallback to detect list-returning process() methods; the default ray_stage_spec() now returns {\"is_fanout_stage\": True} when the helper fires, and {} otherwise.
  • tests/stages/common/test_base.py: Adds FanoutProcessingStage and OptionalFanoutProcessingStage fixtures plus a TestProcessingStageRayStageSpec suite covering the direct-list and union-list detection paths; existing stages that already override ray_stage_spec() are unaffected because their overrides bypass the new base-class logic entirely.

Confidence Score: 4/5

Safe to merge — the new logic only affects stages that do not override ray_stage_spec(), and the introspection is narrowly scoped to the concrete class's process() annotation.

The introspection logic is correct for the annotated cases covered by the tests. The one gap is the NameError fallback path: when get_type_hints() fails to resolve a string annotation it silently falls back to raw __annotations__, which returns a plain string that get_origin cannot parse, so the stage is quietly treated as non-fanout. This is a false-negative rather than a false-positive, meaning no stage will be incorrectly promoted to fanout.

nemo_curator/stages/base.py — specifically the _process_returns_list fallback branch; tests/stages/common/test_base.py for the missing unannotated-process edge case.

Important Files Changed

Filename Overview
nemo_curator/stages/base.py Adds _process_returns_list() introspection helper and wires it into the default ray_stage_spec() to automatically set is_fanout_stage: True when process() return annotation is list[Y] or a union containing list[Y]; fallback to raw __annotations__ on NameError/TypeError silently returns False for unresolvable string annotations.
tests/stages/common/test_base.py Adds FanoutProcessingStage, OptionalFanoutProcessingStage, and TestProcessingStageRayStageSpec covering the two main fanout detection paths; missing coverage for unannotated process() and the NameError fallback path in _process_returns_list().

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["ray_stage_spec() called"] --> B["_process_returns_list()"]
    B --> C["get_type_hints(type(self).process)"]
    C -->|Success| D{"return hint?"}
    C -->|NameError / TypeError| E["fallback: __annotations__.get('return')"]
    E --> F{"return hint?"}
    F -->|None| G["return False"]
    F -->|string annotation| H["get_origin(str) returns None, return False"]
    D -->|None| G
    D -->|hint present| I{"get_origin is list?"}
    I -->|Yes| J["return True"]
    I -->|No| K{"any arg in get_args has origin == list?"}
    K -->|Yes| J
    K -->|No| G
    J --> L["ray_stage_spec returns is_fanout_stage True"]
    G --> M["ray_stage_spec returns empty dict"]
Loading

Reviews (1): Last reviewed commit: "fix: infer fanout stages from process re..." | Re-trigger Greptile

Comment on lines +297 to +311
def _process_returns_list(self) -> bool:
"""Return whether the process return annotation can produce a list."""
try:
return_hint = get_type_hints(type(self).process).get("return")
except (NameError, TypeError):
return_hint = type(self).process.__annotations__.get("return")

if return_hint is None:
return False

origin = get_origin(return_hint)
if origin is list:
return True

return any(get_origin(arg) is list for arg in get_args(return_hint))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Silent false-negative when from __future__ import annotations is active

When a module uses from __future__ import annotations (PEP 563), all annotations are stored as strings. get_type_hints() evaluates those strings against the module's namespace, which usually succeeds. However, when it does raise NameError (e.g., a forward reference that can't be resolved at introspection time), the fallback reads raw __annotations__, which returns the un-evaluated string "list[MockTask]". get_origin("list[MockTask]") returns None and get_args(...) returns (), so any such stage silently falls through to returning {} instead of {"is_fanout_stage": True}. Adding a test case that exercises the NameError fallback path (or documents this limitation) would prevent silent misclassification if the pattern appears in production stages.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @nightcityblade WDYT?

Comment on lines +72 to +80
class TestProcessingStageRayStageSpec:
"""Test the default Ray stage spec for ProcessingStage."""

def test_single_output_process_is_not_fanout(self):
assert ConcreteProcessingStage().ray_stage_spec() == {}

@pytest.mark.parametrize("stage_cls", [FanoutProcessingStage, OptionalFanoutProcessingStage])
def test_list_output_process_is_fanout(self, stage_cls: type[ProcessingStage]) -> None:
assert stage_cls().ray_stage_spec() == {"is_fanout_stage": True}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Missing edge-case test: unannotated process() method

The helper _process_returns_list() explicitly handles the case where get_type_hints() returns no "return" key (return_hint is None → False), but there is no test that verifies a stage whose process() has no return annotation is correctly treated as non-fanout. Adding a fixture class with def process(self, task: MockTask) (no return annotation) and asserting ray_stage_spec() == {} would close this gap.

@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-maintainers Waiting on maintainers to respond label May 13, 2026
Copy link
Copy Markdown
Contributor

@sarahyurick sarahyurick left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @nightcityblade !

Comment on lines +297 to +311
def _process_returns_list(self) -> bool:
"""Return whether the process return annotation can produce a list."""
try:
return_hint = get_type_hints(type(self).process).get("return")
except (NameError, TypeError):
return_hint = type(self).process.__annotations__.get("return")

if return_hint is None:
return False

origin = get_origin(return_hint)
if origin is list:
return True

return any(get_origin(arg) is list for arg in get_args(return_hint))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @nightcityblade WDYT?


Returns (dict[str, Any]):
Dictionary containing Ray-specific configuration
"""
if self._process_returns_list():
return {"is_fanout_stage": True}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we remove is_fanout_stage from existing stages to make sure it works (And make sure those existing stages have tests for it)?

@svcnvidia-nemo-ci svcnvidia-nemo-ci added waiting-on-customer Waiting on the original author to respond and removed waiting-on-maintainers Waiting on maintainers to respond labels May 13, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-request waiting-on-customer Waiting on the original author to respond

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Automatically detect when IS_FANOUT_STAGE should be set to True

3 participants