
✨[Feature] Support in-place custom plugins (and multiple outputs) #4240

@peterbjorgensen

Description


Is your feature request related to a problem? Please describe.
I am trying to use this package to put models into production, and I am not sure whether I am using it wrong.
The usual route is PyTorch -> ONNX -> TensorRT, but now that I have custom Triton kernels, that route becomes a pain.

The first problem is that generate_plugin_converter only supports single outputs. I managed to hack around this by adding a number_of_outputs argument to generate_plugin_converter and returning tuple(layer.get_output(i) for i in range(number_of_outputs)) inside the custom_kernel_converter function. Having multiple outputs seems like a basic feature.

The second problem is that I am trying to make in-place custom plugins work. I don't think this is supported at all: I had to dig into the library and set builder_config.set_preview_feature(trt.PreviewFeature.ALIASED_PLUGIN_IO_10_03, True) myself.
The following code runs after setting that flag, but the resulting engine does not use aliased inputs/outputs:

from pathlib import Path
import triton
import triton.language as tl
import tensorrt.plugin as trtp
from typing import override
import torch
from torch import Tensor
from torch import nn
import torch_tensorrt
import tensorrt as trt

@triton.jit
def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    output = x + 1
    tl.store(y_ptr + offsets, output, mask=mask)


@torch.library.custom_op("my::add_one", mutates_args=())  # type: ignore[misc]
def add_one(X: torch.Tensor) -> torch.Tensor:
    assert X.is_cuda
    BLOCK_SIZE = 256
    grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),)
    Y = torch.empty_like(X)
    add_one_kernel[grid](X, X.numel(), Y, BLOCK_SIZE=BLOCK_SIZE)
    return Y


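@torch.library.register_fake("my::add_one")
def _(X: torch.Tensor) -> torch.Tensor:
    # The fake (meta) implementation must mirror the real one: return a new
    # tensor with the output's shape/dtype rather than the input itself.
    return torch.empty_like(X)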

@trtp.register("my::add_one")
def add_plugin_desc(X: trtp.TensorDesc) -> tuple[trtp.TensorDesc]:
    # Output has the same shape and dtype as the input.
    return X.aliased(),


@trtp.aot_impl("my::add_one")
def add_plugin_aot_impl(
    X: trtp.TensorDesc, outputs: tuple[trtp.TensorDesc], tactic: int
) -> tuple[
    str | bytes, str | bytes, trtp.KernelLaunchParams, trtp.SymExprs
]:
    # Choose the pointer type based on the input dtype.
    type_str = "fp32" if X.dtype == trt.float32 else "fp16"

    block_size = 256
    # Compile the Triton kernel to PTX now, at registration time.
    # ``ASTSource`` describes the kernel's input types and constexprs without
    # running it — Triton compiles it to architecture-specific PTX/CUBIN.
    src = triton.compiler.ASTSource(
        fn=add_one_kernel,
        signature={
            "x_ptr": f"*{type_str}",
            "n_elements": "i32",
            "y_ptr": f"*{type_str}",
        },
        constexprs={
            "BLOCK_SIZE": block_size,
        },
    )
    compiled_kernel = triton.compile(src)

    # Build symbolic launch parameters.
    # ``X.shape_expr.numel()`` is a symbolic expression for the total number of
    # elements — TRT will evaluate it to a concrete integer at engine runtime.
    N = X.shape_expr.numel()
    launch_params = trtp.KernelLaunchParams()
    launch_params.grid_x = trtp.cdiv(N, block_size)  # number of thread blocks
    launch_params.block_x = compiled_kernel.metadata.num_warps * 32  # threads per block
    launch_params.shared_mem = compiled_kernel.metadata.shared  # bytes of shared mem

    # ``extra_args`` are scalar arguments appended to the kernel's argument list at
    # launch. Here ``n_elements`` is passed as a 32-bit symbolic integer so TRT
    # evaluates it from the actual tensor size at runtime.
    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(N)
    print(compiled_kernel.asm["ptx"])

    return (
        compiled_kernel.metadata.name,  # kernel function name in PTX
        compiled_kernel.asm["ptx"],  # PTX source — embedded in TRT engine
        launch_params,
        extra_args,
    )


torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "my::add_one",
    supports_dynamic_shapes=False,
    requires_output_allocator=False,
    use_aot_if_available=True,
)

def load_and_test_engine(
    input_path: Path,
    input_args: tuple[Tensor, ...],
    expected_outputs: tuple[Tensor, ...],
    input_names: list[str],
    output_names: list[str],
) -> None:
    logger = trt.Logger(trt.Logger.VERBOSE)
    runtime = trt.Runtime(logger)
    with open(input_path, "rb") as f:
        model_data = f.read()
    runtime.engine_host_code_allowed = True
    engine = runtime.deserialize_cuda_engine(model_data)
    for name in output_names:
        aliased_input = engine.get_aliased_input_tensor(name)
        if aliased_input is not None:
            print(f"FOUND INPUT {aliased_input} aliased from {name}")
    print("DONE checking for aliased inputs")
    context = engine.create_execution_context()

    output_buffers = [torch.zeros_like(ref_out) for ref_out in expected_outputs]
    for name, buffer in zip(input_names, input_args, strict=True):
        context.set_tensor_address(name, buffer.data_ptr())
    for name, buffer in zip(output_names, output_buffers, strict=True):
        context.set_tensor_address(name, buffer.data_ptr())

    stream = torch.cuda.Stream()
    context.execute_async_v3(stream.cuda_stream)
    stream.synchronize()

    for name, dut, ref in zip(output_names, output_buffers, expected_outputs, strict=True):
        # ``msg`` may be a callable that receives torch's generated message;
        # prefix it with the output name so failures are easy to attribute.
        torch.testing.assert_close(
            dut, ref, rtol=1e-3, atol=1e-5,
            msg=lambda m: f"{name} comparison failed: {m}",
        )

class InplaceModel(nn.Module):
    @override
    def forward(self, x: Tensor) -> Tensor:
        y = torch.ops.my.add_one.default(x)
        return y

def main():
    device = torch.device("cuda")
    big_state = torch.ones((100,2), dtype=torch.float32, device=device)

    args = (big_state,)
    input_names = ["x"]
    output_names = ["output0"]

    mdl = InplaceModel()

    expected_out = mdl(*args)
    expected_out = (expected_out, )

    with torch_tensorrt.logging.debug():
        program = torch_tensorrt.dynamo.trace(mdl, arg_inputs=args)

        engine_bytes: bytes = torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine(
            program,
            arg_inputs=args,
            require_full_compilation=True,
            immutable_weights=True,
            tiling_optimization_level="full",
            use_python_runtime=False,
            use_explicit_typing=False,
            optimization_level=5,
            version_compatible=False,
            hardware_compatible=True,
        )
    with open("inplace_model_test.engine", "wb") as f:
        f.write(engine_bytes)

    load_and_test_engine(Path("inplace_model_test.engine"), args, expected_out, input_names, output_names)



if __name__ == "__main__":
    main()
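
For what it's worth, the kernel itself should be safe to run in place: every element is loaded and stored by exactly one program instance, and the load happens before the store, so pointing x_ptr and y_ptr at the same buffer is fine. The pure-Python emulation below (a hypothetical reference, not Triton code) mirrors add_one_kernel's block/mask logic and runs it in place on a buffer the size of the 100x2 test input.

```python
def add_one_kernel_reference(
    x: list[float], n_elements: int, y: list[float], block_size: int
) -> int:
    """Emulate add_one_kernel in pure Python; return the number of programs launched.

    Each outer loop iteration plays the role of one Triton program instance (pid).
    """
    num_programs = -(-n_elements // block_size)  # ceil division, like triton.cdiv
    for pid in range(num_programs):
        block_start = pid * block_size
        # offsets = block_start + tl.arange(0, BLOCK_SIZE), keeping only mask == True
        offsets = [o for o in range(block_start, block_start + block_size) if o < n_elements]
        # tl.load reads the whole masked block before anything is written ...
        loaded = [x[o] for o in offsets]
        # ... then tl.store writes x + 1 to exactly the same offsets.
        for o, value in zip(offsets, loaded):
            y[o] = value + 1
    return num_programs


# The 100x2 test input flattened to 200 elements, updated in place (x is y).
buf = [1.0] * 200
programs = add_one_kernel_reference(buf, len(buf), buf, block_size=256)
print(programs, all(v == 2.0 for v in buf))  # 1 True
```

With 200 elements and BLOCK_SIZE = 256, the grid is a single program, which matches trtp.cdiv(N, block_size) in the AOT implementation above.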

Also, if I set mutates_args=("X",), as I should for an in-place op, I get an error:

WARNING:torch_tensorrt.dynamo.conversion._symbolic_shape_capture:When processing symbolic shapes for TensorRT engine, found no metadata in FX Graph
ERROR:torch_tensorrt.dynamo._compiler:While interpreting the module got an error: Failed to extract symbolic shape expressions from source FX graph partition

It seems the torch convention for in-place operators is to return None, whereas the TensorRT plugin convention is to return the aliased tensor, so I know the code above may exhibit undefined behaviour.

Describe the solution you'd like
Support for multiple-output custom plugins should be pretty straightforward.

Regarding in-place ops: I would like the plugin converter to translate between the torch convention of returning None from in-place operators and the TensorRT convention of marking the outputs as aliased.
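Conceptually, that translation is small. Here is a minimal pure-Python sketch of the idea, using plain lists instead of tensors: wrap a torch-style in-place function (returns None, mutates some arguments) so that it returns the mutated arguments themselves, which is what an X.aliased() output descriptor expresses on the TensorRT side. The names as_aliased_outputs and add_one_ are hypothetical; in a real converter the mutated argument indices would presumably be derived from the op's schema rather than passed in by hand.

```python
from typing import Any, Callable, Sequence


def as_aliased_outputs(
    inplace_fn: Callable[..., None], mutated_arg_indices: Sequence[int]
) -> Callable[..., tuple[Any, ...]]:
    """Wrap a torch-style in-place op (returns None) into a TensorRT-style op.

    The wrapper calls the original function for its side effects, then returns
    the mutated arguments themselves, so each output *is* (aliases) its input.
    """

    def wrapper(*args: Any) -> tuple[Any, ...]:
        result = inplace_fn(*args)
        if result is not None:
            raise TypeError("in-place ops are expected to return None")
        return tuple(args[i] for i in mutated_arg_indices)

    return wrapper


def add_one_(x: list[float]) -> None:
    """Stand-in for an op registered with mutates_args=("X",)."""
    for i in range(len(x)):
        x[i] += 1


add_one_aliased = as_aliased_outputs(add_one_, mutated_arg_indices=[0])
buf = [0.0, 1.0]
(out,) = add_one_aliased(buf)
print(out is buf, out)  # True [1.0, 2.0]
```

Returning the very same object (out is buf) rather than a copy is the whole point: downstream consumers see the mutation without any extra buffer, which is what aliased plugin I/O is meant to give the engine.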

Describe alternatives you've considered
It seems that torch-tensorrt is not ready for my use cases, but I am not sure what the standard way of doing this is. Would I need to write ONNX operator implementations as well as TensorRT plugins to follow the proven torch -> ONNX -> TensorRT path I have used before, or am I missing something?

The docs do have an example of in-place plugins (https://github.com/NVIDIA/TensorRT/blob/main/samples/python/quickly_deployable_plugins/qdp_runner.py), but I am not sure how to use my PyTorch model with the standard TensorRT network-builder interface when the end goal is a torch/Python-free TensorRT engine that runs in the C++ runtime.

Additional context
Thanks for your time 👍
