fix: infer fanout stages from process return type#1964
Conversation
Signed-off-by: nightcityblade <nightcityblade@gmail.com>
Greptile SummaryThis PR adds automatic fanout-stage detection to
Confidence Score: 4/5Safe to merge — the new logic only affects stages that do not override The introspection logic is correct for the annotated cases covered by the tests. The one gap is the
Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["ray_stage_spec() called"] --> B["_process_returns_list()"]
B --> C["get_type_hints(type(self).process)"]
C -->|Success| D{"return hint?"}
C -->|NameError / TypeError| E["fallback: __annotations__.get('return')"]
E --> F{"return hint?"}
F -->|None| G["return False"]
F -->|string annotation| H["get_origin(str) returns None, return False"]
D -->|None| G
D -->|hint present| I{"get_origin is list?"}
I -->|Yes| J["return True"]
I -->|No| K{"any arg in get_args has origin == list?"}
K -->|Yes| J
K -->|No| G
J --> L["ray_stage_spec returns is_fanout_stage True"]
G --> M["ray_stage_spec returns empty dict"]
Reviews (1): Last reviewed commit: "fix: infer fanout stages from process re..." | Re-trigger Greptile |
| def _process_returns_list(self) -> bool: | ||
| """Return whether the process return annotation can produce a list.""" | ||
| try: | ||
| return_hint = get_type_hints(type(self).process).get("return") | ||
| except (NameError, TypeError): | ||
| return_hint = type(self).process.__annotations__.get("return") | ||
|
|
||
| if return_hint is None: | ||
| return False | ||
|
|
||
| origin = get_origin(return_hint) | ||
| if origin is list: | ||
| return True | ||
|
|
||
| return any(get_origin(arg) is list for arg in get_args(return_hint)) |
There was a problem hiding this comment.
Silent false-negative when
from __future__ import annotations is active
When a module uses from __future__ import annotations (PEP 563), all annotations are stored as strings. get_type_hints() evaluates those strings against the module's namespace, which usually succeeds. However, when it does raise NameError (e.g., a forward reference that can't be resolved at introspection time), the fallback reads raw __annotations__, which returns the un-evaluated string "list[MockTask]". get_origin("list[MockTask]") returns None and get_args(...) returns (), so any such stage silently falls through to returning {} instead of {"is_fanout_stage": True}. Adding a test case that exercises the NameError fallback path (or documents this limitation) would prevent silent misclassification if the pattern appears in production stages.
| class TestProcessingStageRayStageSpec: | ||
| """Test the default Ray stage spec for ProcessingStage.""" | ||
|
|
||
| def test_single_output_process_is_not_fanout(self): | ||
| assert ConcreteProcessingStage().ray_stage_spec() == {} | ||
|
|
||
| @pytest.mark.parametrize("stage_cls", [FanoutProcessingStage, OptionalFanoutProcessingStage]) | ||
| def test_list_output_process_is_fanout(self, stage_cls: type[ProcessingStage]) -> None: | ||
| assert stage_cls().ray_stage_spec() == {"is_fanout_stage": True} |
There was a problem hiding this comment.
Missing edge-case test: unannotated
process() method
The helper _process_returns_list() explicitly handles the case where get_type_hints() returns no "return" key (return_hint is None → False), but there is no test that verifies a stage whose process() has no return annotation is correctly treated as non-fanout. Adding a fixture class with def process(self, task: MockTask) (no return annotation) and asserting ray_stage_spec() == {} would close this gap.
sarahyurick
left a comment
There was a problem hiding this comment.
Thanks @nightcityblade !
| def _process_returns_list(self) -> bool: | ||
| """Return whether the process return annotation can produce a list.""" | ||
| try: | ||
| return_hint = get_type_hints(type(self).process).get("return") | ||
| except (NameError, TypeError): | ||
| return_hint = type(self).process.__annotations__.get("return") | ||
|
|
||
| if return_hint is None: | ||
| return False | ||
|
|
||
| origin = get_origin(return_hint) | ||
| if origin is list: | ||
| return True | ||
|
|
||
| return any(get_origin(arg) is list for arg in get_args(return_hint)) |
|
|
||
| Returns (dict[str, Any]): | ||
| Dictionary containing Ray-specific configuration | ||
| """ | ||
| if self._process_returns_list(): | ||
| return {"is_fanout_stage": True} |
There was a problem hiding this comment.
Should we remove is_fanout_stage from existing stages to make sure it works (And make sure those existing stages have tests for it)?
Description
Closes #1613.
Adds default
ProcessingStage.ray_stage_spec()logic that marks stages as fanout stages when theirprocess()return annotation can produce alist[...]. This covers both always-fanout returns and union returns such asTask | list[Task].Usage
N/A
Checklist
Testing
uv run --group linting ruff check nemo_curator/stages/base.py tests/stages/common/test_base.pyuv run pytest tests/stages/common/test_base.py(blocked locally: NeMo-Curator currently supports Linux only, but this runner is macOS/Darwin)