diff --git a/CHANGES.md b/CHANGES.md index 07f3b4f5accc..59794ca8f777 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -78,6 +78,7 @@ * Python examples added for CloudSQL enrichment handler on [Beam website](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment-cloudsql/) (Python) ([#35473](https://github.com/apache/beam/issues/36095)). * Support for batch mode execution in WriteToPubSub transform added (Python) ([#35990](https://github.com/apache/beam/issues/35990)). * Added official support for Python 3.13 ([#34869](https://github.com/apache/beam/issues/34869)). +* Added Triton Inference Server ModelHandler for ML inference (Python) ([#36369](https://github.com/apache/beam/issues/36369)). ## Breaking Changes diff --git a/sdks/python/apache_beam/ml/inference/TritonModelHandler.py b/sdks/python/apache_beam/ml/inference/TritonModelHandler.py new file mode 100644 index 000000000000..fc83d0aa8202 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/TritonModelHandler.py @@ -0,0 +1,222 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Apache Beam ModelHandler implementation for Triton Inference Server.""" + +import json +import logging +from typing import Any, Callable, Dict, Iterable, Optional, Sequence + +from apache_beam.ml.inference.base import ModelHandler, PredictionResult + +try: + import tritonserver + from tritonserver import Model, Server +except ImportError: + tritonserver = None + +LOGGER = logging.getLogger(__name__) + + +class TritonModelWrapper: + """Wrapper to manage Triton Server lifecycle with the model.""" + def __init__(self, server: 'Server', model: 'Model'): + self.server = server + self.model = model + self._cleaned_up = False + + def cleanup(self): + """Explicitly cleanup server resources. + + This method should be called when the model is no longer needed. + It's safe to call multiple times. + """ + if self._cleaned_up: + return + + try: + if self.server: + self.server.stop() + self._cleaned_up = True + except Exception as e: + LOGGER.warning("Error stopping Triton server: %s", e) + raise + + def __del__(self): + """Cleanup server when model is garbage collected. + + Note: __del__ is not guaranteed to be called. Prefer using cleanup() + explicitly when possible. + """ + if not self._cleaned_up: + try: + if self.server: + self.server.stop() + self._cleaned_up = True + except Exception as e: + LOGGER.warning("Error stopping Triton server in __del__: %s", e) + + +class TritonModelHandler(ModelHandler[Any, PredictionResult, + TritonModelWrapper]): + """Beam ModelHandler for Triton Inference Server. + + This handler supports loading models from a Triton model repository and + running inference using the Triton Python API. 
+ + Example usage:: + + pcoll | RunInference( + TritonModelHandler( + model_repository="/workspace/models", + model_name="my_model", + input_tensor_name="input", + output_tensor_name="output" + ) + ) + + Args: + model_repository: Path to the Triton model repository directory. + model_name: Name of the model to load from the repository. + input_tensor_name: Name of the input tensor (default: "INPUT"). + output_tensor_name: Name of the output tensor (default: "OUTPUT"). + parse_output_fn: Optional custom function to parse model outputs. + Should take (outputs_dict, output_tensor_name) and return parsed result. + """ + def __init__( + self, + model_repository: str, + model_name: str, + input_tensor_name: str = "INPUT", + output_tensor_name: str = "OUTPUT", + parse_output_fn: Optional[Callable] = None, + ): + if tritonserver is None: + raise ImportError( + "tritonserver is not installed. " + "Install it with: pip install tritonserver") + + self._model_repository = model_repository + self._model_name = model_name + self._input_tensor_name = input_tensor_name + self._output_tensor_name = output_tensor_name + self._parse_output_fn = parse_output_fn + + def load_model(self) -> TritonModelWrapper: + """Loads and initializes a Triton model for processing. + + Returns: + TritonModelWrapper containing the server and model instances. + + Raises: + RuntimeError: If server fails to start or model fails to load. + """ + try: + server = tritonserver.Server(model_repository=self._model_repository) + server.start() + except Exception as e: + raise RuntimeError( + f"Failed to start Triton server with repository " + f"'{self._model_repository}': {e}") from e + + try: + model = server.model(self._model_name) + if model is None: + raise RuntimeError( + f"Model '{self._model_name}' not found in repository") + except Exception as e: + server.stop() + raise RuntimeError( + f"Failed to load model '{self._model_name}': {e}") from e + + return TritonModelWrapper(server, model) + + def run_inference( + self, + batch: Sequence[Any], + model: TritonModelWrapper, + inference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[PredictionResult]: + """Runs inferences on a batch of inputs. + + Args: + batch: A sequence of examples (can be strings, arrays, etc.). + model: TritonModelWrapper returned by load_model(). + inference_args: Optional dict with 'input_tensor_name' and/or + 'output_tensor_name' to override defaults for this batch. + + Returns: + An Iterable of PredictionResult objects. + + Raises: + RuntimeError: If inference fails. + """ + # Allow per-batch tensor name overrides + input_name = self._input_tensor_name + output_name = self._output_tensor_name + if inference_args: + input_name = inference_args.get('input_tensor_name', input_name) + output_name = inference_args.get('output_tensor_name', output_name) + + try: + responses = model.model.infer(inputs={input_name: batch}) + except Exception as e: + raise RuntimeError( + f"Triton inference failed for model '{self._model_name}': {e}") from e + + # Parse outputs + predictions = [] + try: + for response in responses: + if output_name not in response.outputs: + raise RuntimeError( + f"Output tensor '{output_name}' not found in response. 
" + f"Available outputs: {list(response.outputs.keys())}") + + output_tensor = response.outputs[output_name] + + # Use custom parser if provided + if self._parse_output_fn: + parsed = self._parse_output_fn(response.outputs, output_name) + else: + # Default parsing: try string array, fallback to raw + try: + parsed = [ + json.loads(val) + for val in output_tensor.to_string_array().tolist() + ] + except (json.JSONDecodeError, TypeError, AttributeError): + # If JSON parsing fails, return raw output + parsed = output_tensor.to_bytes_array().tolist() + + predictions.extend(parsed if isinstance(parsed, list) else [parsed]) + + except Exception as e: + raise RuntimeError(f"Failed to parse model outputs: {e}") from e + + if len(predictions) != len(batch): + LOGGER.warning( + "Prediction count (%d) doesn't match " + "batch size (%d). Truncating or padding.", + len(predictions), + len(batch)) + + return [PredictionResult(x, y) for x, y in zip(batch, predictions)] + + def get_metrics_namespace(self) -> str: + """Returns namespace for metrics.""" + return "BeamML_Triton" diff --git a/sdks/python/apache_beam/ml/inference/triton_inference_test.py b/sdks/python/apache_beam/ml/inference/triton_inference_test.py new file mode 100644 index 000000000000..bab96f3a70e6 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/triton_inference_test.py @@ -0,0 +1,267 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# pytype: skip-file + +import unittest +from unittest.mock import Mock, patch + +# Protect against environments where tritonserver library is not available. 
+# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports +try: + from apache_beam.ml.inference.base import PredictionResult + from apache_beam.ml.inference.TritonModelHandler import TritonModelHandler + from apache_beam.ml.inference.TritonModelHandler import TritonModelWrapper +except ImportError: + raise unittest.SkipTest('Triton dependencies are not installed') + + +class TritonModelHandlerTest(unittest.TestCase): + def test_handler_initialization(self): + """Test that handler initializes correctly with required parameters.""" + handler = TritonModelHandler( + model_repository="/workspace/models", model_name="test_model") + self.assertEqual(handler._model_repository, "/workspace/models") + self.assertEqual(handler._model_name, "test_model") + self.assertEqual(handler._input_tensor_name, "INPUT") + self.assertEqual(handler._output_tensor_name, "OUTPUT") + + def test_handler_custom_tensor_names(self): + """Test handler with custom tensor names.""" + handler = TritonModelHandler( + model_repository="/workspace/models", + model_name="test_model", + input_tensor_name="custom_input", + output_tensor_name="custom_output") + self.assertEqual(handler._input_tensor_name, "custom_input") + self.assertEqual(handler._output_tensor_name, "custom_output") + + def test_handler_missing_tritonserver(self): + """Test that handler raises ImportError if tritonserver is not available.""" + with patch('apache_beam.ml.inference.TritonModelHandler.tritonserver', + None): + with self.assertRaises(ImportError): + TritonModelHandler( + model_repository="/workspace/models", model_name="test_model") + + @patch('apache_beam.ml.inference.TritonModelHandler.tritonserver') + def test_load_model_success(self, mock_tritonserver): + """Test successful model loading.""" + mock_server = Mock() + mock_model = Mock() + mock_tritonserver.Server.return_value = mock_server + mock_server.model.return_value = mock_model + + handler = TritonModelHandler( + model_repository="/workspace/models", model_name="test_model") + wrapper = handler.load_model() + + self.assertIsInstance(wrapper, TritonModelWrapper) + self.assertEqual(wrapper.server, mock_server) + self.assertEqual(wrapper.model, mock_model) + mock_server.start.assert_called_once() + mock_server.model.assert_called_once_with("test_model") + + @patch('apache_beam.ml.inference.TritonModelHandler.tritonserver') + def test_load_model_server_start_fails(self, mock_tritonserver): + """Test model loading when server fails to start.""" + mock_tritonserver.Server.side_effect = Exception("Server start failed") + + handler = TritonModelHandler( + model_repository="/workspace/models", model_name="test_model") + + with self.assertRaises(RuntimeError) as context: + handler.load_model() + self.assertIn("Failed to start Triton server", str(context.exception)) + + @patch('apache_beam.ml.inference.TritonModelHandler.tritonserver') + def test_load_model_model_not_found(self, mock_tritonserver): + """Test model loading when model is not found.""" + mock_server = Mock() + mock_tritonserver.Server.return_value = mock_server + mock_server.model.return_value = None + + handler = TritonModelHandler( + model_repository="/workspace/models", model_name="test_model") + + with self.assertRaises(RuntimeError) as context: + handler.load_model() + self.assertIn("Model 'test_model' not found", str(context.exception)) + mock_server.stop.assert_called_once() + + @patch('apache_beam.ml.inference.TritonModelHandler.tritonserver') + def test_run_inference_success(self, mock_tritonserver): + 
"""Test successful inference.""" + # Setup mocks + mock_model = Mock() + mock_server = Mock() + wrapper = TritonModelWrapper(server=mock_server, model=mock_model) + + mock_response = Mock() + mock_output = Mock() + mock_output.to_string_array.return_value.tolist.return_value = [ + '{"result": 1}', '{"result": 2}' + ] + mock_response.outputs = {"OUTPUT": mock_output} + mock_model.infer.return_value = [mock_response] + + handler = TritonModelHandler( + model_repository="/workspace/models", model_name="test_model") + + batch = ["input1", "input2"] + results = list(handler.run_inference(batch, wrapper)) + + self.assertEqual(len(results), 2) + self.assertIsInstance(results[0], PredictionResult) + mock_model.infer.assert_called_once_with(inputs={"INPUT": batch}) + + @patch('apache_beam.ml.inference.TritonModelHandler.tritonserver') + def test_run_inference_with_custom_tensor_names(self, mock_tritonserver): + """Test inference with custom tensor names via inference_args.""" + mock_model = Mock() + mock_server = Mock() + wrapper = TritonModelWrapper(server=mock_server, model=mock_model) + + mock_response = Mock() + mock_output = Mock() + mock_output.to_string_array.return_value.tolist.return_value = [ + '{"result": 1}' + ] + mock_response.outputs = {"custom_out": mock_output} + mock_model.infer.return_value = [mock_response] + + handler = TritonModelHandler( + model_repository="/workspace/models", model_name="test_model") + + batch = ["input1"] + inference_args = { + 'input_tensor_name': 'custom_in', 'output_tensor_name': 'custom_out' + } + results = list(handler.run_inference(batch, wrapper, inference_args)) + + self.assertEqual(len(results), 1) + mock_model.infer.assert_called_once_with(inputs={"custom_in": batch}) + + @patch('apache_beam.ml.inference.TritonModelHandler.tritonserver') + def test_run_inference_fails(self, mock_tritonserver): + """Test inference failure handling.""" + mock_model = Mock() + mock_server = Mock() + wrapper = TritonModelWrapper(server=mock_server, model=mock_model) + mock_model.infer.side_effect = Exception("Inference error") + + handler = TritonModelHandler( + model_repository="/workspace/models", model_name="test_model") + + batch = ["input1"] + with self.assertRaises(RuntimeError) as context: + list(handler.run_inference(batch, wrapper)) + self.assertIn("Triton inference failed", str(context.exception)) + + @patch('apache_beam.ml.inference.TritonModelHandler.tritonserver') + def test_run_inference_output_not_found(self, mock_tritonserver): + """Test error when expected output tensor is not in response.""" + mock_model = Mock() + mock_server = Mock() + wrapper = TritonModelWrapper(server=mock_server, model=mock_model) + + mock_response = Mock() + mock_response.outputs = {} # Empty outputs + mock_model.infer.return_value = [mock_response] + + handler = TritonModelHandler( + model_repository="/workspace/models", model_name="test_model") + + batch = ["input1"] + with self.assertRaises(RuntimeError) as context: + list(handler.run_inference(batch, wrapper)) + self.assertIn("Output tensor 'OUTPUT' not found", str(context.exception)) + + @patch('apache_beam.ml.inference.TritonModelHandler.tritonserver') + def test_custom_parse_function(self, mock_tritonserver): + """Test using a custom output parsing function.""" + def custom_parser(outputs, output_name): + return ["custom_parsed_result"] + + mock_model = Mock() + mock_server = Mock() + wrapper = TritonModelWrapper(server=mock_server, model=mock_model) + + mock_response = Mock() + mock_response.outputs = {"OUTPUT": Mock()} 
+ mock_model.infer.return_value = [mock_response] + + handler = TritonModelHandler( + model_repository="/workspace/models", + model_name="test_model", + parse_output_fn=custom_parser) + + batch = ["input1"] + results = list(handler.run_inference(batch, wrapper)) + + self.assertEqual(len(results), 1) + self.assertEqual(results[0].inference, "custom_parsed_result") + + def test_get_metrics_namespace(self): + """Test metrics namespace.""" + handler = TritonModelHandler( + model_repository="/workspace/models", model_name="test_model") + self.assertEqual(handler.get_metrics_namespace(), "BeamML_Triton") + + def test_wrapper_cleanup(self): + """Test that TritonModelWrapper cleans up server on deletion.""" + mock_server = Mock() + mock_model = Mock() + + wrapper = TritonModelWrapper(server=mock_server, model=mock_model) + wrapper.__del__() + + mock_server.stop.assert_called_once() + + def test_wrapper_explicit_cleanup(self): + """Test explicit cleanup method.""" + mock_server = Mock() + mock_model = Mock() + + wrapper = TritonModelWrapper(server=mock_server, model=mock_model) + wrapper.cleanup() + + mock_server.stop.assert_called_once() + self.assertTrue(wrapper._cleaned_up) + + # Calling cleanup again should not call stop again + wrapper.cleanup() + mock_server.stop.assert_called_once() + + def test_wrapper_cleanup_idempotent(self): + """Test that cleanup after __del__ doesn't cause double cleanup.""" + mock_server = Mock() + mock_model = Mock() + + wrapper = TritonModelWrapper(server=mock_server, model=mock_model) + + # Call cleanup explicitly first + wrapper.cleanup() + mock_server.stop.assert_called_once() + + # __del__ should not call stop again + wrapper.__del__() + mock_server.stop.assert_called_once() + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/inference/triton_tests_requirements.txt b/sdks/python/apache_beam/ml/inference/triton_tests_requirements.txt new file mode 100644 index 000000000000..614b2fac4956 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/triton_tests_requirements.txt @@ -0,0 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +tritonserver>=2.41.0 diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 719d188ed266..510eece10d70 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -179,6 +179,7 @@ def cythonize(*args, **kwargs): 'tf2onnx', 'torch', 'transformers', + 'tritonserver', ]
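
A minimal end-to-end usage sketch for reviewers (not part of the patch). It wires the new TritonModelHandler into RunInference together with a custom parse_output_fn, which the handler accepts but whose use isn't shown in the docstring example above. The repository path /workspace/models, the model name my_model, and the tensor names INPUT/OUTPUT mirror the defaults used in this patch and are assumptions to adapt to your model's config.pbtxt; it also assumes the model's output tensor holds JSON-encoded strings, as in the handler's default parsing path::

    # Usage sketch only: assumes a Triton model repository at /workspace/models
    # containing a model named "my_model" with tensors "INPUT"/"OUTPUT" whose
    # output values are JSON strings. Adjust all of these to your deployment.
    import json

    import apache_beam as beam
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.TritonModelHandler import TritonModelHandler


    def parse_json_outputs(outputs, output_name):
      # Custom parser: `outputs` is the response's output-tensor dict keyed by
      # tensor name, matching the (outputs_dict, output_tensor_name) contract
      # described in the handler's docstring.
      tensor = outputs[output_name]
      return [json.loads(v) for v in tensor.to_string_array().tolist()]


    def run():
      handler = TritonModelHandler(
          model_repository="/workspace/models",  # assumed repository path
          model_name="my_model",                 # assumed model name
          input_tensor_name="INPUT",
          output_tensor_name="OUTPUT",
          parse_output_fn=parse_json_outputs)

      with beam.Pipeline() as p:
        _ = (
            p
            | "CreateInputs" >> beam.Create(["first example", "second example"])
            | "RunInference" >> RunInference(handler)
            | "FormatResults" >> beam.Map(
                lambda result: f"{result.example} -> {result.inference}")
            | "Print" >> beam.Map(print))


    if __name__ == "__main__":
      run()

Each element of the output PCollection is a PredictionResult pairing the original input (result.example) with the parsed model output (result.inference), so downstream transforms can join predictions back to their inputs without extra bookkeeping.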