feat: TorchTRT Cuda generated kernels plugin support#4199
Open
Conversation
BowenFu
reviewed
Apr 23, 2026
narendasan
reviewed
May 1, 2026
Collaborator
narendasan
left a comment
Cool, I think this is getting really close. I think we just have a few naming things to make this more user-friendly, and I think we should let users provide PTX directly in addition to the CUDA APIs. Also, did you add nvrtc as an optional dependency in the pyproject.toml (maybe under an extras group called kernels)?
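For reference, an optional-dependency group along these lines could express that. The wheel name below is an assumption (NVIDIA publishes versioned NVRTC wheels per CUDA release); the project would pin whichever matches its CUDA target:

```toml
# Hypothetical pyproject.toml fragment: expose NVRTC as an opt-in extra
# so `pip install torch-tensorrt[kernels]` pulls in the runtime compiler.
[project.optional-dependencies]
kernels = [
    "nvidia-cuda-nvrtc-cu12",  # assumed wheel name; adjust to the CUDA version actually targeted
]
```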
# Numel("x"): pass x.numel() to the kernel as an int extra.
# Elementwise(flat): 1-D launch over the flattened output; any input rank works.

tta.auto_cuda_kernel_plugin(
Collaborator
maybe we can call this something like torch_tensorrt.kernels.cuda_kernel_op
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py 2026-05-11 16:39:32.423025+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py 2026-05-11 16:39:55.856400+00:00
@@ -36,11 +36,13 @@
def _coerce_plugin_attr_for_qdp(value: Any, attr_annotation: Any) -> Any:
"""Convert Python scalars to the serialized type expected by QDP."""
if _is_numpy_attr_annotation(attr_annotation):
- return np.asarray(_unwrap_scalar_attr(value), dtype=_numpy_attr_dtype(attr_annotation))
+ return np.asarray(
+ _unwrap_scalar_attr(value), dtype=_numpy_attr_dtype(attr_annotation)
+ )
return value
def _is_numpy_attr_annotation(annotation: Any) -> bool:
return annotation is np.ndarray or typing.get_origin(annotation) is np.ndarray
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/kernels/_kernel_plugin.py 2026-05-11 16:39:32.436980+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/kernels/_kernel_plugin.py 2026-05-11 16:39:57.457501+00:00
@@ -371,16 +371,14 @@
annotations[d.name] = d.py_type
else:
sig_pieces.append(f"{d.name}: 'torch.Tensor'")
annotations[d.name] = torch.Tensor
sig_src = ", ".join(sig_pieces)
- body = textwrap.dedent(
- f"""
+ body = textwrap.dedent(f"""
def _wrapper({sig_src}) -> 'torch.Tensor':
return _fn({", ".join(param_names)})
- """
- )
+ """)
ns: Dict[str, Any] = {"_fn": fn, "torch": torch}
exec(compile(body, "<cuda_kernel_op>", "exec"), ns)
wrapper: Callable[..., Any] = ns["_wrapper"]
wrapper.__annotations__ = dict(annotations)
wrapper.__annotations__["return"] = torch.Tensor
@@ -713,8 +711,6 @@
precompiled_ptx=ptx,
use_aot_if_available=not any(
isinstance(input_spec, ScalarInput) for input_spec in spec.inputs
),
)
- _LOGGER.info(
- "cuda_kernel_op '%s' registered (schema: %s)", op_name, schema
- )
+ _LOGGER.info("cuda_kernel_op '%s' registered (schema: %s)", op_name, schema)
Description
This PR introduces torch_tensorrt.kernels, an experimental module for registering hand-written CUDA C++ kernels as both PyTorch custom ops (for eager execution) and TensorRT Quickly Deployable Plugins (QDPs) with AOT support (for torch_tensorrt.compile).
Usage
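The registration below is an illustrative sketch: the kernel source is real CUDA C++, but the import path, KernelSpec fields, and decorator signature are assumptions modeled on this PR's description, not the merged API.

```python
# Hypothetical usage sketch for the declarative path. All torch_tensorrt
# names here (auto_cuda_kernel_plugin, KernelSpec, SameAs, Numel,
# Elementwise) are assumptions based on the PR description.
import math

# CUDA C++ source handed to the plugin; NVRTC compiles it at registration time.
SIGMOID_CUDA_SRC = r"""
extern "C" __global__ void sigmoid_kernel(const float* x, float* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = 1.0f / (1.0f + expf(-x[i]));
}
"""

def reference_sigmoid(v: float) -> float:
    """Pure-Python reference for what the kernel computes per element."""
    return 1.0 / (1.0 + math.exp(-v))

try:
    import torch_tensorrt.kernels as tta  # available only with this PR applied

    # Declarative registration: output shape mirrors the input (SameAs),
    # x.numel() is passed as the int `n` extra, launch is 1-D elementwise.
    tta.auto_cuda_kernel_plugin(
        op_name="ann_ex::sigmoid",
        cuda_source=SIGMOID_CUDA_SRC,
        spec=...,  # e.g. KernelSpec(inputs=..., outputs=[SameAs("x")], extras=[Numel("x")], geometry=Elementwise(flat=True))
    )
except ImportError:
    pass  # torch_tensorrt with this PR is not installed; the sketch still documents intent
```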
After this call, torch.ops.ann_ex.sigmoid is available in eager mode and is embedded as a TensorRT plugin during torch_tensorrt.compile. The meta function, eager launch, AOT implementation, and PyTorch schema are all derived from the KernelSpec.
API Surface
The module exposes two primary entry points, layered by declarativeness:
auto_cuda_kernel_plugin is the recommended default. The caller supplies a KernelSpec dataclass describing the kernel's inputs, outputs (with a shape relation such as SameAs or ReduceDims), scalar extras (Numel, DimSize), and launch geometry (Elementwise or Reduction). The framework derives the meta function, eager CUDA launch, TensorRT AOT implementation, and PyTorch schema. This path covers pointwise kernels (1-D flat or N-D grid launches), reductions (with optional keepdim), multi-input kernels, and scalar (non-tensor) kernel arguments via ScalarInput.
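The shape relations are what make the meta function derivable without running the kernel. A minimal model of that derivation follows; the class and field names mirror the description above, but the PR's actual definitions may differ:

```python
# Toy model of deriving a meta (shape) function from declarative shape
# relations. SameAs/ReduceDims are stand-ins for the names described in
# this PR, not its real types.
from dataclasses import dataclass
from typing import Tuple

@dataclass
class SameAs:
    input_name: str  # output shape copies this input's shape

@dataclass
class ReduceDims:
    input_name: str
    dims: Tuple[int, ...]
    keepdim: bool = False

def infer_output_shape(relation, shapes: dict) -> Tuple[int, ...]:
    """Derive the output shape from a dict of input shapes; no launch needed."""
    src = shapes[relation.input_name]
    if isinstance(relation, SameAs):
        return tuple(src)
    if isinstance(relation, ReduceDims):
        if relation.keepdim:
            return tuple(1 if d in relation.dims else s for d, s in enumerate(src))
        return tuple(s for d, s in enumerate(src) if d not in relation.dims)
    raise TypeError(f"unknown shape relation: {relation!r}")
```

A pointwise op declared with SameAs("x") then gets its meta function for free: for an input of shape (4, 8) the output is (4, 8), while a reduction over dim 1 yields (4,), or (4, 1) with keepdim.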
manual_cuda_kernel_plugin is the lower-level alternative for kernels outside the declarative DSL: shape-changing outputs, multi-output kernels, or non-standard launch geometries. The caller provides eager_fn and aot_fn directly; the decorator still registers the PyTorch op, TRT plugin, AOT implementation, and converter in a single call.
A Custom(fn=...) geometry is also available for callers who want the declarative path's schema/meta derivation but need to hand-write the TRT KernelLaunchParams.
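For the Custom(fn=...) geometry, the hand-written part mostly boils down to choosing grid and block dimensions. A sketch of the arithmetic such a callback typically performs; KernelLaunchParams here is a stand-in dataclass, not the PR's actual type:

```python
# Sketch of computing launch parameters for a flat elementwise kernel.
# KernelLaunchParams is a hypothetical stand-in for the TRT-side type.
from dataclasses import dataclass
from math import prod

@dataclass
class KernelLaunchParams:
    grid: tuple   # (grid_x, grid_y, grid_z)
    block: tuple  # (block_x, block_y, block_z)

def elementwise_launch(shape, threads_per_block: int = 256) -> KernelLaunchParams:
    """Classic 1-D flat launch: one thread per element, grid rounded up."""
    n = prod(shape)
    blocks = (n + threads_per_block - 1) // threads_per_block  # ceiling division
    return KernelLaunchParams(grid=(blocks, 1, 1), block=(threads_per_block, 1, 1))
```

For example, elementwise_launch((4, 1000)) covers 4000 elements with 16 blocks of 256 threads, matching the Elementwise(flat) behavior described earlier.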
Type of change
Checklist: