Triton Inference Server support for RunInference transform #36369
SaiShashank12 wants to merge 9 commits into apache:master
Conversation
Summary of Changes: Hello @SaiShashank12, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request integrates Triton Inference Server capabilities into Apache Beam's RunInference transform.
Checks are failing. Will not request review until checks are succeeding. If you'd like to override that behavior, comment assign set of reviewers.
- Add TritonModelWrapper for server cleanup on deletion
- Make tensor names configurable via constructor and inference_args
- Add comprehensive error handling with descriptive messages
- Complete all docstrings with usage examples
- Add custom output parsing function support
- Create full unit test suite with mocks
- Apply Apache license header and yapf formatting
assign set of reviewers

Assigning reviewers: R: @tvalentyn for label python. Note: If you would like to opt out of this review, comment. Available commands:

The PR bot will only process comments in the main thread (not review comments).
if len(predictions) != len(batch):
  LOGGER.warning(
      f"Prediction count ({len(predictions)}) doesn't match "
      f"batch size ({len(batch)}). Truncating or padding.")
I think zip always truncates to the length of the shortest list. Can we be losing some predictions here? Also, is the correspondence between examples in batch and predictions still meaningful, given that we seem to be flattening inferences that return a list in predictions.extend(parsed if isinstance(parsed, list) else [parsed])?
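To make the concern concrete, here is a tiny illustration with hypothetical values (not taken from the PR) of how flattening per-example outputs and then pairing with zip can silently drop and misalign predictions:

# Hypothetical example: a batch of 3 inputs and per-example outputs that are lists.
batch = ["a", "b", "c"]
raw_outputs = [[0.1, 0.2], [0.3], [0.4]]

# Flattening with extend() (as in the handler) yields 4 predictions for 3 examples.
predictions = []
for parsed in raw_outputs:
  predictions.extend(parsed if isinstance(parsed, list) else [parsed])

# zip() truncates to the shorter sequence, and the pairing no longer lines up:
print(list(zip(batch, predictions)))
# [('a', 0.1), ('b', 0.2), ('c', 0.3)] -- 0.4 is dropped, and 'b'/'c' get values from 'a'/'b'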
  from apache_beam.ml.inference.TritonModelHandler import TritonModelHandler
  from apache_beam.ml.inference.TritonModelHandler import TritonModelWrapper
except ImportError:
  raise unittest.SkipTest('Triton dependencies are not installed')
to make these tests run, we can add this dep to the ml_test extra. I tried to pip install tritonserver and it worked for me on py310, but didn't work on py311 or py312. I think we could introduce a py310_ml_test extra, similarly to the existing py312_ml_test extra, see: https://github.com/apache/beam/blob/e8b41d7664aee65fdf98b990a205b17b361ed222/sdks/python/tox.ini#L119C10-L119C15 https://github.com/apache/beam/blame/master/sdks/python/setup.py#L552C18-L552C18 to exercise these tests.
there are also test and lint errors related to this PR (can be seen in test logs), for example:

cc: @damccorm as well
/gemini review |
Code Review
This pull request introduces support for Triton Inference Server in Apache Beam's RunInference transform by adding a TritonModelHandler. The implementation includes model loading, batch inference, and result parsing. The changes are well-structured and accompanied by a comprehensive set of unit tests. My review includes suggestions to improve resource management reliability, refine exception handling for more precise error reporting, and clean up unused imports in both the new handler and its test file.
def __del__(self):
  """Cleanup server when model is garbage collected."""
  try:
    if self.server:
      self.server.stop()
  except Exception as e:
    LOGGER.warning(f"Error stopping Triton server: {e}")
Using __del__ for resource cleanup, such as stopping the Triton server, is generally not recommended as its execution is not guaranteed. For instance, it may not be called when the interpreter exits, potentially leaving the server process running. While this might be a limitation of the ModelHandler interface, it's a potential source of resource leaks. A more robust approach would be to use an explicit cleanup mechanism if the framework allows for it, for example, through the teardown() method of a DoFn.
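As a rough sketch of the explicit-cleanup idea (the cleanup() method and _cleaned_up flag mirror what later commits on this PR describe; the rest is an assumption, not the actual handler code):

import logging

LOGGER = logging.getLogger(__name__)


class TritonModelWrapper:
  """Holds the Triton server/model pair and stops the server at most once."""
  def __init__(self, server, model):
    self.server = server
    self.model = model
    self._cleaned_up = False

  def cleanup(self):
    # Explicit, idempotent cleanup that e.g. a DoFn.teardown() could call.
    if self._cleaned_up:
      return
    try:
      if self.server:
        self.server.stop()
    except Exception as e:
      LOGGER.warning("Error stopping Triton server: %s", e)
    finally:
      self._cleaned_up = True

  def __del__(self):
    # Best-effort fallback only; __del__ is not guaranteed to run at interpreter exit.
    self.cleanup()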
except Exception:
  # If JSON parsing fails, return raw output
  parsed = output_tensor.to_bytes_array().tolist()
Catching a broad Exception can hide unexpected errors and make debugging more difficult. It's better to catch more specific exceptions. Since you are handling potential errors from json.loads, you should catch json.JSONDecodeError specifically.
Suggested change:
-  except Exception:
-    # If JSON parsing fails, return raw output
-    parsed = output_tensor.to_bytes_array().tolist()
+  except json.JSONDecodeError:
+    # If JSON parsing fails, return raw output
+    parsed = output_tensor.to_bytes_array().tolist()
from typing import Sequence, Dict, Any, Iterable, Optional
import logging
import json
import atexit
# pytype: skip-file

import unittest
from unittest.mock import Mock, MagicMock, patch
import apache_beam as beam
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
- Remove unused atexit import
- Fix logging f-string usage (use lazy % formatting)
- Add explicit cleanup() method to TritonModelWrapper for reliable resource cleanup
- Improve __del__ to prevent double cleanup
- Remove unused imports from test file (MagicMock, beam, TestPipeline, assert_that, equal_to, RunInference)
- Add tests for explicit cleanup and idempotent behavior

All files now pass pylint 10.00/10 and yapf formatting checks.
- Add tritonserver to ml_test and p312_ml_test extras in setup.py
- Create triton_tests_requirements.txt for test-specific dependencies

This ensures CI environments have tritonserver installed when running tests.
Codecov Report: ❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #36369 +/- ##
=============================================
- Coverage 56.84% 40.16% -16.69%
Complexity 3386 3386
=============================================
Files 1220 1222 +2
Lines 185898 186316 +418
Branches 3523 3523
=============================================
- Hits 105672 74828 -30844
- Misses 76885 108147 +31262
Partials 3341 3341
Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
@neilbhutada please have a look into it
/gemini review |
Code Review
This pull request introduces support for Triton Inference Server in Apache Beam's RunInference transform by adding a TritonModelHandler. The implementation is well-structured, covering model loading, batch inference, and lifecycle management of the Triton server. The accompanying unit tests are comprehensive.
My review includes a few suggestions for improvement:
- Making the handling of mismatched batch and prediction sizes safer by raising an error instead of silently truncating data.
- Improving the robustness of resource cleanup in the TritonModelWrapper.
- Refining exception handling to be more specific.
- Adding a test case to ensure cleanup idempotency.
Overall, this is a great addition to Beam's ML capabilities.
if len(predictions) != len(batch):
  LOGGER.warning(
      "Prediction count (%d) doesn't match "
      "batch size (%d). Truncating or padding.",
      len(predictions),
      len(batch))
Silently truncating data when the number of predictions does not match the batch size can lead to data loss and hard-to-debug issues. It's safer to raise an exception in this case. The log message is also slightly misleading as it mentions "padding", which is not implemented.
if len(predictions) != len(batch):
  raise RuntimeError(
      f"Prediction count ({len(predictions)}) doesn't match "
      f"batch size ({len(batch)}).")

if self.server:
  self.server.stop()
The __del__ method should set self._cleaned_up = True after stopping the server. This ensures that if cleanup() is called after __del__ has been invoked (e.g., during a complex shutdown sequence), it remains idempotent and does not attempt to stop an already-stopped server, which could lead to errors.
Suggested change:
-      if self.server:
-        self.server.stop()
+      if self.server:
+        self.server.stop()
+      self._cleaned_up = True
      json.loads(val)
      for val in output_tensor.to_string_array().tolist()
  ]
except Exception:
The except Exception: clause is too broad and may catch and hide unexpected errors beyond JSON parsing issues. It's better to be more specific and catch only the expected exceptions, such as json.JSONDecodeError and TypeError.
Suggested change:
-  except Exception:
+  except (json.JSONDecodeError, TypeError):
# __del__ should not call stop again
wrapper.__del__()
mock_server.stop.assert_called_once()
To ensure the idempotency of the cleanup logic, it would be beneficial to add a test case that verifies cleanup() does not re-attempt to stop the server if __del__() has already been called. This complements test_wrapper_cleanup_idempotent and validates the fix for making __del__ set the _cleaned_up flag.
Suggested change:
   mock_server.stop.assert_called_once()
+
+  def test_wrapper_cleanup_after_del(self):
+    """Test that cleanup after __del__ doesn't cause double cleanup."""
+    mock_server = Mock()
+    mock_model = Mock()
+    wrapper = TritonModelWrapper(server=mock_server, model=mock_model)
+    # Call __del__ first
+    wrapper.__del__()
+    mock_server.stop.assert_called_once()
+    # cleanup() should not call stop again
+    wrapper.cleanup()
+    mock_server.stop.assert_called_once()
- Reorder imports: stdlib (json, logging) before typing before apache_beam
- Separate each typing import on its own line
- Separate base imports on individual lines
- Remove unused type: ignore comment on line 30
- Apply isort formatting for consistency
- Set _cleaned_up flag in __del__ to ensure idempotent cleanup
- Use specific exception types (JSONDecodeError, TypeError, AttributeError) instead of broad Exception catch in JSON parsing
- Improves resource management reliability and error handling precision
- Add tritonserver to ml_base list for consistency with upstream refactoring
- ml_base is now shared across ml_test, p312_ml_test, and p313_ml_test
- Resolves conflict from upstream introducing ml_base and p313_ml_test
- Change Optional[callable] to Optional[Callable] in TritonModelHandler
- Add Triton Inference Server entry to CHANGES.md
/gemini review

could you also please take a look at the failing precommits?
Code Review
This pull request introduces Triton Inference Server support for the RunInference transform in Apache Beam. It includes the implementation of a TritonModelHandler class, enabling inference on text data using Triton Inference Server models within Beam pipelines. The changes also add necessary testing files and modify the setup.py file to include tritonserver as a dependency. I have identified some areas for improvement, mainly focusing on error handling and code clarity.
except Exception as e:
  LOGGER.warning("Error stopping Triton server: %s", e)
  raise
High: Instead of re-raising the exception bare, consider logging the error at the error level and then re-raising. This provides more detailed information for debugging purposes.
Consider adding exception chaining to preserve the original traceback.
Suggested change:
-    except Exception as e:
-      LOGGER.warning("Error stopping Triton server: %s", e)
-      raise
+    except Exception as e:
+      LOGGER.error("Error stopping Triton server: %s", e, exc_info=True)  # Log the error with exc_info for traceback
+      raise RuntimeError("Error stopping Triton server") from e  # Re-raise with exception chaining
except Exception as e:
  raise RuntimeError(
      f"Failed to start Triton server with repository "
      f"'{self._model_repository}': {e}") from e
High: Consider logging the exception with exc_info=True to include the traceback in the logs, which can be helpful for debugging.
Suggested change:
-    except Exception as e:
-      raise RuntimeError(
-          f"Failed to start Triton server with repository "
-          f"'{self._model_repository}': {e}") from e
+    except Exception as e:
+      LOGGER.error("Failed to start Triton server", exc_info=True)
+      raise RuntimeError(
+          f"Failed to start Triton server with repository "
+          f"'{self._model_repository}': {e}") from e
except Exception as e:
  server.stop()
  raise RuntimeError(
      f"Failed to load model '{self._model_name}': {e}") from e
High: Consider logging the exception with exc_info=True to include the traceback in the logs, which can be helpful for debugging.
Suggested change:
-    except Exception as e:
-      server.stop()
-      raise RuntimeError(
-          f"Failed to load model '{self._model_name}': {e}") from e
+    except Exception as e:
+      server.stop()
+      LOGGER.error("Failed to load model", exc_info=True)
+      raise RuntimeError(
+          f"Failed to load model '{self._model_name}': {e}") from e
except Exception as e:
  raise RuntimeError(
      f"Triton inference failed for model '{self._model_name}': {e}") from e
High: Consider logging the exception with exc_info=True to include the traceback in the logs, which can be helpful for debugging.
Suggested change:
-    except Exception as e:
-      raise RuntimeError(
-          f"Triton inference failed for model '{self._model_name}': {e}") from e
+    except Exception as e:
+      LOGGER.error("Triton inference failed", exc_info=True)
+      raise RuntimeError(
+          f"Triton inference failed for model '{self._model_name}': {e}") from e
except Exception as e:
  raise RuntimeError(f"Failed to parse model outputs: {e}") from e
High: Consider logging the exception with exc_info=True to include the traceback in the logs, which can be helpful for debugging.
Suggested change:
-    except Exception as e:
-      raise RuntimeError(f"Failed to parse model outputs: {e}") from e
+    except Exception as e:
+      LOGGER.error("Failed to parse model outputs", exc_info=True)
+      raise RuntimeError(f"Failed to parse model outputs: {e}") from e
if len(predictions) != len(batch):
  LOGGER.warning(
      "Prediction count (%d) doesn't match "
      "batch size (%d). Truncating or padding.",
      len(predictions),
      len(batch))
Medium: It might be beneficial to include the actual prediction and batch content in the warning message for debugging purposes. Consider truncating the content if it's too large.
Suggested change:
   if len(predictions) != len(batch):
-    LOGGER.warning(
-        "Prediction count (%d) doesn't match "
-        "batch size (%d). Truncating or padding.",
-        len(predictions),
-        len(batch))
+    LOGGER.warning(
+        "Prediction count (%d) doesn't match "
+        "batch size (%d). Truncating or padding. "
+        "Predictions: %s, Batch: %s",
+        len(predictions),
+        len(batch),
+        str(predictions[:100]),  # Truncate for large content
+        str(batch[:100]))  # Truncate for large content
Reminder, please take a look at this pr: @tvalentyn

waiting on author

This pull request has been marked as stale due to 60 days of inactivity. It will be closed in 1 week if no further activity occurs. If you think that's incorrect or this pull request requires a review, please simply write any comment. If closed, you can revive the PR at any time and @mention a reviewer or discuss it on the dev@beam.apache.org list. Thank you for your contributions.

This pull request has been closed due to lack of activity. If you think that is incorrect, or the pull request requires review, you can revive the PR at any time.
Title: #36368
Add Triton Inference Server support for RunInference transform
Issue Description:
Summary
This PR adds support for Triton Inference Server in Apache Beam's RunInference transform by implementing a TritonModelHandler class.

What does this PR do?
- Adds a TritonModelHandler that extends ModelHandler[str, PredictionResult, Model]

Key Features
- Wraps inference results in PredictionResult objects

Use Case
This handler allows users to leverage Triton Inference Server’s optimized inference capabilities within Apache Beam pipelines, particularly useful for:
Testing
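For context, here is a minimal sketch of how a handler like this might be wired into a pipeline. The import path matches the test file quoted above; the constructor arguments model_repository and model_name are inferred from the error messages in the review and should be treated as assumptions, not the handler's confirmed signature.

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.TritonModelHandler import TritonModelHandler

# Assumed constructor arguments; see the quoted error messages referencing
# self._model_repository and self._model_name.
model_handler = TritonModelHandler(
    model_repository='/path/to/model_repository',
    model_name='my_text_model')

with beam.Pipeline() as p:
  _ = (
      p
      | 'CreateExamples' >> beam.Create(['some input text', 'another example'])
      | 'RunInference' >> RunInference(model_handler)
      | 'PrintResults' >> beam.Map(print))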