
Triton Inference Server support for RunInference transform #36369

Closed
SaiShashank12 wants to merge 9 commits into apache:master from SaiShashank12:triton_handler

Conversation

@SaiShashank12

Title: Add Triton Inference Server support for RunInference transform (#36368)

Issue Description:

Summary

This PR adds support for Triton Inference Server in Apache Beam’s RunInference transform by implementing a TritonModelHandler class.

What does this PR do?

  • Implements TritonModelHandler that extends ModelHandler[str, PredictionResult, Model]
  • Enables inference on text data using Triton Inference Server models
  • Supports batch processing of text strings through the Beam pipeline
  • Handles model loading, initialization, and inference execution with Triton server

Key Features

  • Model Loading: Initializes Triton server with configurable model repository and model name
  • Batch Inference: Processes sequences of text strings efficiently
  • Result Handling: Parses JSON responses from Triton and returns structured PredictionResult objects
  • Flexible Configuration: Supports custom inference arguments
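
For illustration, here is a minimal sketch of how the handler could plug into a pipeline. The constructor parameter names (model_repository, model_name) are assumptions based on the feature list above rather than the confirmed API; the import path matches the one used in this PR's tests.

```python
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
# Import path as it appears in this PR's test file.
from apache_beam.ml.inference.TritonModelHandler import TritonModelHandler

# Hypothetical sketch: the keyword arguments below are assumed from the
# feature list and may not match the handler's actual signature.
handler = TritonModelHandler(
    model_repository="/path/to/model_repository",  # assumed parameter name
    model_name="text_classifier")  # assumed parameter name

with beam.Pipeline() as p:
  _ = (
      p
      | "CreateInputs" >> beam.Create(["first document", "second document"])
      | "RunInference" >> RunInference(handler)
      | "PrintResults" >> beam.Map(print))
```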

Use Case

This handler allows users to leverage Triton Inference Server’s optimized inference capabilities within Apache Beam pipelines, particularly useful for:

  • Text classification tasks
  • Document processing pipelines
  • Real-time and batch ML inference workloads

Testing

  • Unit tests added
  • Integration tests with Triton server
  • Documentation updated

@github-actions Bot added the python label Oct 3, 2025
@gemini-code-assist
Contributor

Summary of Changes

Hello @SaiShashank12, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates Triton Inference Server capabilities into Apache Beam's RunInference transform by introducing a dedicated TritonModelHandler. This new handler allows Beam pipelines to efficiently perform scalable machine learning inference, particularly for text-based models, by abstracting the complexities of model loading, batch processing, and result handling with a Triton server.

Highlights

  • Triton Integration: Introduces support for Triton Inference Server within Apache Beam's RunInference transform.
  • New Model Handler: Implements a TritonModelHandler class to manage Triton models, extending ModelHandler.
  • Text Inference: Enables batch inference on text data using Triton models, processing sequences of text strings efficiently.
  • Core Functionality: Handles model loading, initialization, and execution of inference requests with the Triton server, including parsing JSON responses into PredictionResult objects.

@github-actions
Contributor

github-actions Bot commented Oct 3, 2025

Checks are failing. Will not request review until checks are succeeding. If you'd like to override that behavior, comment `assign set of reviewers`.

@SaiShashank12 changed the title from "Intial commit" to "Triton Inference Server support for RunInference transform" Oct 3, 2025
- Add TritonModelWrapper for server cleanup on deletion
- Make tensor names configurable via constructor and inference_args
- Add comprehensive error handling with descriptive messages
- Complete all docstrings with usage examples
- Add custom output parsing function support
- Create full unit test suite with mocks
- Apply Apache license header and yapf formatting
@SaiShashank12
Author

assign set of reviewers

@github-actions
Contributor

github-actions Bot commented Oct 5, 2025

Assigning reviewers:

R: @tvalentyn for label python.

Note: If you would like to opt out of this review, comment `assign to next reviewer`.

Available commands:

  • `stop reviewer notifications` - opt out of the automated review tooling
  • `remind me after tests pass` - tag the comment author after tests pass
  • `waiting on author` - shift the attention set back to the author (any comment or push by the author will return the attention set to the reviewers)

The PR bot will only process comments in the main thread (not review comments).

```python
if len(predictions) != len(batch):
  LOGGER.warning(
      f"Prediction count ({len(predictions)}) doesn't match "
      f"batch size ({len(batch)}). Truncating or padding.")
```
Contributor


I think `zip` always truncates to the length of the shortest list. Can we be losing some predictions here? Also, is the correspondence between examples in `batch` and `predictions` still meaningful, given that we seem to be flattening inferences that return a list in `predictions.extend(parsed if isinstance(parsed, list) else [parsed])`?
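
For reference, a quick standalone illustration of both behaviors the comment points at (plain Python, independent of this PR):

```python
batch = ["a", "b", "c"]
predictions = ["p1", "p2"]  # suppose one prediction went missing upstream

# zip stops at the shorter sequence, so "c" is silently dropped from the
# pairing instead of raising an error:
print(list(zip(batch, predictions)))  # [('a', 'p1'), ('b', 'p2')]

# Flattening list-valued outputs breaks the one-to-one correspondence
# between batch elements and predictions:
predictions = []
for parsed in [["p1a", "p1b"], "p2"]:  # first inference returned a list
  predictions.extend(parsed if isinstance(parsed, list) else [parsed])
print(len(predictions))  # 3 predictions for a 2-element batch
```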

```python
  from apache_beam.ml.inference.TritonModelHandler import TritonModelHandler
  from apache_beam.ml.inference.TritonModelHandler import TritonModelWrapper
except ImportError:
  raise unittest.SkipTest('Triton dependencies are not installed')
```
Contributor


To make these tests run, we can add this dep to the `ml_test` extra. I tried to pip install `tritonserver` and it worked for me on py310, but didn't work on py311 or py312. I think we could introduce a `py310_ml_test` extra, similarly to the existing `py312_ml_test` extra, to exercise these tests; see: https://github.com/apache/beam/blob/e8b41d7664aee65fdf98b990a205b17b361ed222/sdks/python/tox.ini#L119C10-L119C15 and https://github.com/apache/beam/blame/master/sdks/python/setup.py#L552C18-L552C18
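
A rough sketch of what that wiring might look like in sdks/python/setup.py; the dependency list is a placeholder and the extra name is modeled on the existing py312_ml_test pattern, not the actual Beam configuration:

```python
# Hypothetical sketch only; Beam's real setup.py defines these lists elsewhere.
from setuptools import setup

ml_test_deps = []  # placeholder for the existing ml_test dependency list

setup(
    name='apache-beam',
    extras_require={
        'ml_test': ml_test_deps,
        # tritonserver installed cleanly only on Python 3.10 in the
        # experiment above, so gate it behind a version-specific extra,
        # mirroring the existing py312_ml_test extra:
        'py310_ml_test': ml_test_deps + ['tritonserver'],
    },
)
```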

@tvalentyn
Contributor

There are also test and lint errors related to this PR (visible in the test logs), for example:

```text
___________ TritonModelHandlerTest.test_handler_custom_tensor_names ____________
[gw2] linux -- Python 3.13.3 /runner/_work/beam/beam/sdks/python/test-suites/tox/py313/build/srcs/sdks/python/target/.tox-py313-ml/py313-ml/bin/python

self = <apache_beam.ml.inference.triton_inference_test.TritonModelHandlerTest testMethod=test_handler_custom_tensor_names>

...

apache_beam/ml/inference/TritonModelHandler.py:86: ImportError
---------------------------- Captured stdout setup -----------------------------
Warning: 155 active threads detected before test
______________ TritonModelHandlerTest.test_handler_initialization ______________
...
            "Install it with: pip install tritonserver")
E       ImportError: tritonserver is not installed. Install it with: pip install tritonserver
```

@tvalentyn
Contributor

cc: @damccorm as well

@tvalentyn
Contributor

/gemini review

Contributor

@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces support for Triton Inference Server in Apache Beam's RunInference transform by adding a TritonModelHandler. The implementation includes model loading, batch inference, and result parsing. The changes are well-structured and accompanied by a comprehensive set of unit tests. My review includes suggestions to improve resource management reliability, refine exception handling for more precise error reporting, and clean up unused imports in both the new handler and its test file.

Comment on lines +42 to +48
```python
def __del__(self):
  """Cleanup server when model is garbage collected."""
  try:
    if self.server:
      self.server.stop()
  except Exception as e:
    LOGGER.warning(f"Error stopping Triton server: {e}")
```
Contributor


high

Using __del__ for resource cleanup, such as stopping the Triton server, is generally not recommended as its execution is not guaranteed. For instance, it may not be called when the interpreter exits, potentially leaving the server process running. While this might be a limitation of the ModelHandler interface, it's a potential source of resource leaks. A more robust approach would be to use an explicit cleanup mechanism if the framework allows for it, for example, through the teardown() method of a DoFn.
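
A minimal sketch of the explicit-cleanup shape being suggested, assuming the wrapper holds the server handle (the PR's later commits add a cleanup() method and a _cleaned_up flag along these lines):

```python
class TritonModelWrapper:
  """Sketch: wraps a Triton server and model with explicit cleanup."""
  def __init__(self, server, model):
    self.server = server
    self.model = model
    self._cleaned_up = False

  def cleanup(self):
    """Stop the server exactly once; safe to call repeatedly."""
    if self._cleaned_up:
      return
    if self.server:
      self.server.stop()
    self._cleaned_up = True

  def __del__(self):
    # Best-effort fallback only: __del__ is not guaranteed to run at
    # interpreter exit, so callers should prefer explicit cleanup(),
    # e.g. from a DoFn.teardown().
    try:
      self.cleanup()
    except Exception:
      pass
```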

Comment on lines +179 to +181
```python
except Exception:
  # If JSON parsing fails, return raw output
  parsed = output_tensor.to_bytes_array().tolist()
```
Contributor


high

Catching a broad Exception can hide unexpected errors and make debugging more difficult. It's better to catch more specific exceptions. Since you are handling potential errors from json.loads, you should catch json.JSONDecodeError specifically.

Suggested change
```diff
-except Exception:
+except json.JSONDecodeError:
   # If JSON parsing fails, return raw output
   parsed = output_tensor.to_bytes_array().tolist()
```

```python
from typing import Sequence, Dict, Any, Iterable, Optional
import logging
import json
import atexit
```
Contributor


medium

The atexit module is imported but never used in the file. It should be removed to improve code clarity and cleanliness.

```python
# pytype: skip-file

import unittest
from unittest.mock import Mock, MagicMock, patch
```
Contributor


medium

The MagicMock class is imported from unittest.mock but is not used in this file. It can be removed to keep the imports clean.

Suggested change
```diff
-from unittest.mock import Mock, MagicMock, patch
+from unittest.mock import Mock, patch
```

Comment on lines +23 to +26
```python
import apache_beam as beam
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
```
Contributor


medium

These imports (apache_beam, TestPipeline, assert_that, equal_to) are not used. Since the tests in this file are standard unit tests that don't construct or run a Beam pipeline, these imports can be removed.

- Remove unused atexit import
- Fix logging f-string usage (use lazy % formatting)
- Add explicit cleanup() method to TritonModelWrapper for reliable resource cleanup
- Improve __del__ to prevent double cleanup
- Remove unused imports from test file (MagicMock, beam, TestPipeline, assert_that, equal_to, RunInference)
- Add tests for explicit cleanup and idempotent behavior

All files now pass pylint 10.00/10 and yapf formatting checks.
- Add tritonserver to ml_test and p312_ml_test extras in setup.py
- Create triton_tests_requirements.txt for test-specific dependencies

This ensures CI environments have tritonserver installed when running tests.
@codecov

codecov Bot commented Oct 11, 2025

Codecov Report

❌ Patch coverage is 21.17647% with 67 lines in your changes missing coverage. Please review.
✅ Project coverage is 40.16%. Comparing base (9cd994e) to head (cf06fa6).
⚠️ Report is 122 commits behind head on master.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...hon/apache_beam/ml/inference/TritonModelHandler.py | 21.17% | 67 Missing ⚠️ |
Additional details and impacted files
```text
@@              Coverage Diff              @@
##             master   #36369       +/-   ##
=============================================
- Coverage     56.84%   40.16%   -16.69%
  Complexity     3386     3386
=============================================
  Files          1220     1222        +2
  Lines        185898   186316      +418
  Branches       3523     3523
=============================================
- Hits         105672    74828    -30844
- Misses        76885   108147    +31262
  Partials       3341     3341
```
| Flag | Coverage Δ |
| --- | --- |
| python | 40.50% <21.17%> (-40.44%) ⬇️ |

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

@SaiShashank12
Author

@neilbhutada please have a look into it

@SaiShashank12
Author

/gemini review

Contributor

@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces support for Triton Inference Server in Apache Beam's RunInference transform by adding a TritonModelHandler. The implementation is well-structured, covering model loading, batch inference, and lifecycle management of the Triton server. The accompanying unit tests are comprehensive.

My review includes a few suggestions for improvement:

  • Making the handling of mismatched batch and prediction sizes safer by raising an error instead of silently truncating data.
  • Improving the robustness of resource cleanup in the TritonModelWrapper.
  • Refining exception handling to be more specific.
  • Adding a test case to ensure cleanup idempotency.

Overall, this is a great addition to Beam's ML capabilities.

Comment on lines +210 to +215
```python
if len(predictions) != len(batch):
  LOGGER.warning(
      "Prediction count (%d) doesn't match "
      "batch size (%d). Truncating or padding.",
      len(predictions),
      len(batch))
```
Contributor


high

Silently truncating data when the number of predictions does not match the batch size can lead to data loss and hard-to-debug issues. It's safer to raise an exception in this case. The log message is also slightly misleading as it mentions "padding", which is not implemented.

```python
    if len(predictions) != len(batch):
      raise RuntimeError(
          f"Prediction count ({len(predictions)}) doesn't match "
          f"batch size ({len(batch)}).")
```

Comment on lines +67 to +68
```python
if self.server:
  self.server.stop()
```
Contributor


medium

The __del__ method should set self._cleaned_up = True after stopping the server. This ensures that if cleanup() is called after __del__ has been invoked (e.g., during a complex shutdown sequence), it remains idempotent and does not attempt to stop an already-stopped server, which could lead to errors.

Suggested change
```diff
 if self.server:
   self.server.stop()
+self._cleaned_up = True
```

```python
    json.loads(val)
    for val in output_tensor.to_string_array().tolist()
]
except Exception:
```
Contributor


medium

The except Exception: clause is too broad and may catch and hide unexpected errors beyond JSON parsing issues. It's better to be more specific and catch only the expected exceptions, such as json.JSONDecodeError and TypeError.

Suggested change
```diff
-except Exception:
+except (json.JSONDecodeError, TypeError):
```


```python
# __del__ should not call stop again
wrapper.__del__()
mock_server.stop.assert_called_once()
```
Contributor


medium

To ensure the idempotency of the cleanup logic, it would be beneficial to add a test case that verifies cleanup() does not re-attempt to stop the server if __del__() has already been called. This complements test_wrapper_cleanup_idempotent and validates the fix for making __del__ set the _cleaned_up flag.

Suggested change
```diff
     mock_server.stop.assert_called_once()
+
+  def test_wrapper_cleanup_after_del(self):
+    """Test that cleanup after __del__ doesn't cause double cleanup."""
+    mock_server = Mock()
+    mock_model = Mock()
+    wrapper = TritonModelWrapper(server=mock_server, model=mock_model)
+    # Call __del__ first
+    wrapper.__del__()
+    mock_server.stop.assert_called_once()
+    # cleanup() should not call stop again
+    wrapper.cleanup()
+    mock_server.stop.assert_called_once()
```

- Reorder imports: stdlib (json, logging) before typing before apache_beam
- Separate each typing import on its own line
- Separate base imports on individual lines
- Remove unused type: ignore comment on line 30
- Apply isort formatting for consistency
- Set _cleaned_up flag in __del__ to ensure idempotent cleanup
- Use specific exception types (JSONDecodeError, TypeError, AttributeError)
  instead of broad Exception catch in JSON parsing
- Improves resource management reliability and error handling precision
- Add tritonserver to ml_base list for consistency with upstream refactoring
- ml_base is now shared across ml_test, p312_ml_test, and p313_ml_test
- Resolves conflict from upstream introducing ml_base and p313_ml_test
- Change Optional[callable] to Optional[Callable] in TritonModelHandler
- Add Triton Inference Server entry to CHANGES.md
@damccorm
Contributor

/gemini review

@damccorm
Contributor

Could you also please take a look at the failing precommits?

Contributor

@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces Triton Inference Server support for the RunInference transform in Apache Beam. It includes the implementation of a TritonModelHandler class, enabling inference on text data using Triton Inference Server models within Beam pipelines. The changes also add necessary testing files and modify the setup.py file to include tritonserver as a dependency. I have identified some areas for improvement, mainly focusing on error handling and code clarity.

Comment on lines +55 to +57
```python
except Exception as e:
  LOGGER.warning("Error stopping Triton server: %s", e)
  raise
```
Contributor


high

Instead of re-raising the bare exception, consider logging the error at the error level first; this provides more detailed information for debugging. Consider also adding exception chaining to preserve the original traceback.

Suggested change
```diff
 except Exception as e:
-  LOGGER.warning("Error stopping Triton server: %s", e)
-  raise
+  # Log the error with exc_info for the traceback.
+  LOGGER.error("Error stopping Triton server: %s", e, exc_info=True)
+  # Re-raise with exception chaining.
+  raise RuntimeError("Error stopping Triton server") from e
```

Comment on lines +131 to +134
```python
except Exception as e:
  raise RuntimeError(
      f"Failed to start Triton server with repository "
      f"'{self._model_repository}': {e}") from e
```
Contributor


high

High: Consider logging the exception with exc_info=True to include the traceback in the logs, which can be helpful for debugging.

Suggested change
```diff
 except Exception as e:
+  LOGGER.error("Failed to start Triton server", exc_info=True)
   raise RuntimeError(
       f"Failed to start Triton server with repository "
       f"'{self._model_repository}': {e}") from e
```

Comment on lines +141 to +144
```python
except Exception as e:
  server.stop()
  raise RuntimeError(
      f"Failed to load model '{self._model_name}': {e}") from e
```
Contributor


high

High: Consider logging the exception with exc_info=True to include the traceback in the logs, which can be helpful for debugging.

Suggested change
```diff
 except Exception as e:
   server.stop()
+  LOGGER.error("Failed to load model", exc_info=True)
   raise RuntimeError(
       f"Failed to load model '{self._model_name}': {e}") from e
```

Comment on lines +177 to +179
```python
except Exception as e:
  raise RuntimeError(
      f"Triton inference failed for model '{self._model_name}': {e}") from e
```
Contributor


high

High: Consider logging the exception with exc_info=True to include the traceback in the logs, which can be helpful for debugging.

Suggested change
```diff
 except Exception as e:
+  LOGGER.error("Triton inference failed", exc_info=True)
   raise RuntimeError(
       f"Triton inference failed for model '{self._model_name}': {e}") from e
```

Comment on lines +208 to +209
```python
except Exception as e:
  raise RuntimeError(f"Failed to parse model outputs: {e}") from e
```
Contributor


high

High: Consider logging the exception with exc_info=True to include the traceback in the logs, which can be helpful for debugging.

Suggested change
```diff
 except Exception as e:
+  LOGGER.error("Failed to parse model outputs", exc_info=True)
   raise RuntimeError(f"Failed to parse model outputs: {e}") from e
```

Comment on lines +211 to +216
```python
if len(predictions) != len(batch):
  LOGGER.warning(
      "Prediction count (%d) doesn't match "
      "batch size (%d). Truncating or padding.",
      len(predictions),
      len(batch))
```
Contributor


medium

Medium: It might be beneficial to include the actual prediction and batch content in the warning message for debugging purposes. Consider truncating the content if it's too large.

Suggested change
```diff
 if len(predictions) != len(batch):
   LOGGER.warning(
       "Prediction count (%d) doesn't match "
-      "batch size (%d). Truncating or padding.",
+      "batch size (%d). Truncating or padding. "
+      "Predictions: %s, Batch: %s",
       len(predictions),
-      len(batch))
+      len(batch),
+      str(predictions[:100]),  # Truncate for large content
+      str(batch[:100]))  # Truncate for large content
```

@github-actions
Contributor

Reminder, please take a look at this PR: @tvalentyn

@damccorm
Contributor

waiting on author

@github-actions
Contributor

This pull request has been marked as stale due to 60 days of inactivity. It will be closed in 1 week if no further activity occurs. If you think that’s incorrect or this pull request requires a review, please simply write any comment. If closed, you can revive the PR at any time and @mention a reviewer or discuss it on the dev@beam.apache.org list. Thank you for your contributions.

@github-actions Bot added the stale label Dec 21, 2025
@github-actions
Contributor

This pull request has been closed due to lack of activity. If you think that is incorrect, or the pull request requires review, you can revive the PR at any time.

@github-actions Bot closed this Dec 28, 2025