diff --git a/core/runtime/BUILD b/core/runtime/BUILD index 432a5c4380..4983716251 100644 --- a/core/runtime/BUILD +++ b/core/runtime/BUILD @@ -68,6 +68,7 @@ cc_library( "RTDevice.cpp", "TRTEngine.cpp", "TRTEngineProfiler.cpp", + "TRTRuntimeConfig.cpp", "execute_engine.cpp", "register_jit_hooks.cpp", "runtime.cpp", @@ -77,12 +78,34 @@ cc_library( "RTDevice.h", "TRTEngine.h", "TRTEngineProfiler.h", + "TRTRuntimeConfig.h", "runtime.h", ], copts = if_torch_nccl(["-DUSE_C10D_NCCL"]), + defines = select({ + # nvinfer1::IRuntimeConfig (and the matching ICudaEngine::createRuntimeConfig + # / createExecutionContext(IRuntimeConfig*) overloads) was introduced in + # TensorRT 10.11. The TensorRT shipped with the Jetpack l4t-r36.4 toolchain + # (@tensorrt_l4t) predates 10.11 and does not export this type. Every other + # configuration here (RTX, SBSA, Windows, default x86_64 Linux) is on a + # TensorRT >= 10.11 bundle, so it gets the macro. + # + # Gate every IRuntimeConfig-using site in core/runtime with + # `#ifdef TRT_HAS_IRUNTIME_CONFIG`; the Jetpack path falls back to the + # legacy createExecutionContext() no-arg overload. + ":jetpack": [], + "//conditions:default": ["TRT_HAS_IRUNTIME_CONFIG"], + }), linkopts = [ "-lstdc++fs", ], + local_defines = select({ + # TensorRT-RTX builds: opt into feature-gated APIs that the runtime layer + # depends on (e.g. IExecutionContext::isStreamCapturable). + ":rtx_win": ["ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION"], + ":rtx_x86_64": ["ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION"], + "//conditions:default": [], + }), deps = [ "//core/plugins:torch_tensorrt_plugins", "//core/util:prelude", @@ -110,6 +133,7 @@ filegroup( "RTDevice.h", "TRTEngine.h", "TRTEngineProfiler.h", + "TRTRuntimeConfig.h", "runtime.h", ], visibility = ["//visibility:public"], diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index ae5232bb6f..522fb5976f 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -1,4 +1,5 @@ #include +#include #include "NvInfer.h" @@ -61,26 +62,28 @@ void DynamicOutputAllocator::notifyShape(char const* tensorName, nvinfer1::Dims } TRTEngine::TRTEngine( - const std::string& serialized_engine, + std::string serialized_engine, const RTDevice& cuda_device, const std::vector<std::string>& _in_binding_names, const std::vector<std::string>& _out_binding_names, const Platform& target_platform, bool hardware_compatible, bool requires_output_allocator, - const std::string& serialized_metadata, - const ResourceAllocationStrategy resource_allocation_strategy) + std::string serialized_metadata, + const ResourceAllocationStrategy resource_allocation_strategy, + TRTRuntimeConfig runtime_cfg) : TRTEngine( "deserialized_trt", - serialized_engine, + std::move(serialized_engine), cuda_device, _in_binding_names, _out_binding_names, target_platform, hardware_compatible, requires_output_allocator, - serialized_metadata, - resource_allocation_strategy) {} + std::move(serialized_metadata), + resource_allocation_strategy, + std::move(runtime_cfg)) {} TRTEngine::TRTEngine(std::vector<std::string> serialized_info) : TRTEngine( @@ -95,7 +98,8 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info) serialized_info[SERIALIZED_METADATA_IDX], (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? 
ResourceAllocationStrategy::kDynamic - : ResourceAllocationStrategy::kStatic)) { + : ResourceAllocationStrategy::kStatic), + make_runtime_config_from_serialized(serialized_info)) { this->requires_native_multidevice = std::stoi(serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]); if (this->requires_native_multidevice) { LOG_INFO("Loaded distributed TRT engine (contains NCCL collectives); NCCL comm will be bound on first execution"); @@ -103,16 +107,18 @@ } TRTEngine::TRTEngine( - const std::string& mod_name, - const std::string& serialized_engine, + std::string mod_name, + std::string serialized_engine, const RTDevice& cuda_device, const std::vector<std::string>& _in_binding_names, const std::vector<std::string>& _out_binding_names, const Platform& target_platform, bool hardware_compatible, bool requires_output_allocator, - const std::string& serialized_metadata, - const ResourceAllocationStrategy resource_allocation_strategy) { + std::string serialized_metadata, + const ResourceAllocationStrategy resource_allocation_strategy, + TRTRuntimeConfig runtime_cfg) { + this->runtime_cfg = std::move(runtime_cfg); TORCHTRT_CHECK( is_supported_on_current_platform(target_platform), "This engine was not built to run on this platform (built for: " << target_platform << ", current platform: " @@ -123,7 +129,7 @@ TRTEngine::TRTEngine( auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible); TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine"); - this->serialized_metadata = serialized_metadata; + this->serialized_metadata = std::move(serialized_metadata); this->requires_output_allocator = requires_output_allocator; device_info = most_compatible_device.value(); multi_gpu_device_check(); @@ -131,7 +137,7 @@ rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger())); - name = slugify(mod_name); + name = slugify(std::move(mod_name)); cuda_engine = make_trt(rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size())); TORCHTRT_CHECK((cuda_engine.get() != nullptr), "Unable to deserialize the TensorRT engine"); @@ -146,13 +152,7 @@ TRTEngine::TRTEngine( LOG_DEBUG( "Resource allocation strategy: " << (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static")); - if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) { - this->exec_ctx = - make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); - } else { - this->exec_ctx = make_trt(cuda_engine->createExecutionContext()); - } - TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context"); + recreate_execution_context(); // Pre-allocate placeholder for empty tensors (TensorRT requires non-null addresses) cudaMalloc(&empty_tensor_placeholder, 1); @@ -288,6 +288,9 @@ TRTEngine::TRTEngine( } TRTEngine::~TRTEngine() { + // save_runtime_cache() is noexcept, so it is safe to call from the destructor + // without explicit try/catch; any I/O error is logged internally. 
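Aside: the pattern the destructor comment above relies on — a noexcept save wrapper around throwing file I/O — in a minimal standalone sketch. `Cache` and its members are hypothetical stand-ins, not the Torch-TensorRT types:

```cpp
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>

struct Cache {
  std::string path;

  // Throwing implementation: easy to unit-test, propagates I/O failures.
  void save_impl() const {
    std::ofstream out(path, std::ios::binary);
    if (!out) {
      throw std::runtime_error("cannot open " + path);
    }
    out << "serialized-bytes";
  }

  // noexcept wrapper: swallows and logs, so callers (including destructors,
  // where a thrown exception would reach std::terminate) are always safe.
  void save() noexcept {
    try {
      save_impl();
    } catch (const std::exception& e) {
      std::cerr << "cache save failed: " << e.what() << '\n';
    } catch (...) {
      std::cerr << "cache save failed: unknown error\n";
    }
  }

  ~Cache() { save(); }  // safe: save() cannot throw
};

int main() {
  Cache c{"/tmp/demo.cache"};
}  // destructor flushes without risking terminate()
```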
+ runtime_cfg.save_runtime_cache(); trt_engine_profiler.reset(); exec_ctx.reset(); cuda_engine.reset(); @@ -301,8 +304,7 @@ void TRTEngine::disable_profiling() { torch::cuda::synchronize(device_info.id); profile_execution = false; trt_engine_profiler.reset(); - exec_ctx = make_trt(cuda_engine->createExecutionContext()); - TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to recreate TensorRT execution context"); + recreate_execution_context(); } void TRTEngine::dump_engine_layer_info_to_file(const std::string& path) { @@ -399,10 +401,7 @@ bool TRTEngine::set_device_memory_budget(int64_t budget) { trt_engine_profiler.reset(); } bool result = cuda_engine->setWeightStreamingBudgetV2(budget); - exec_ctx = make_trt(cuda_engine->createExecutionContext()); - TORCHTRT_CHECK( - (exec_ctx.get() != nullptr), - "Unable to recreate TensorRT execution context after setting new device memory budget"); + recreate_execution_context(); if (profile_execution) { enable_profiling(); } @@ -459,6 +458,7 @@ std::string TRTEngine::to_str() const { ss << " Target Platform: " << target_platform << std::endl; ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl; ss << " Multi-Device Engine: " << (requires_native_multidevice) << std::endl; + ss << runtime_cfg.to_str(); // clang-format on return ss.str(); } @@ -495,7 +495,14 @@ FlattenedState TRTEngine::__obj_flatten__() { std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]), std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]), std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]), - std::tuple("requires_native_multidevice", serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX])); + std::tuple("requires_native_multidevice", serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]) +#ifdef TRT_MAJOR_RTX + , + std::tuple("runtime_cache_path", serialized_info[RUNTIME_CACHE_PATH_IDX]), + std::tuple("dynamic_shapes_kernel_strategy", serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX]), + std::tuple("cuda_graph_strategy", serialized_info[CUDA_GRAPH_STRATEGY_IDX]) +#endif + ); } std::vector<std::string> TRTEngine::serialize() { @@ -522,6 +529,13 @@ std::vector<std::string> TRTEngine::serialize() { this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0"; serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX] = this->requires_native_multidevice ? "1" : "0"; // rank/world_size are runtime facts (may differ at load time); not serialized. 
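Aside: serialize() above and make_runtime_config_from_serialized() below store these enums as stringified underlying integers and range-check them on the way back in. A self-contained sketch of that round trip, with an illustrative enum rather than the real helpers:

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>
#include <type_traits>

enum class CudaGraphStrategyOption : int32_t { kDisabled = 0, kWholeGraphCapture = 1 };

// Serialize: cast through the underlying type so the on-disk format is a
// stable integer string, independent of the enum class identity.
std::string serialize(CudaGraphStrategyOption s) {
  return std::to_string(static_cast<std::underlying_type_t<CudaGraphStrategyOption>>(s));
}

// Deserialize: parse, then range-check before casting back. An unchecked
// static_cast would happily produce an out-of-range enumerator.
CudaGraphStrategyOption deserialize(const std::string& field) {
  const int v = std::stoi(field);
  if (v < 0 || v > 1) {
    throw std::invalid_argument("invalid CUDA graph strategy value: " + field);
  }
  return static_cast<CudaGraphStrategyOption>(v);
}

int main() {
  auto s = serialize(CudaGraphStrategyOption::kWholeGraphCapture);  // "1"
  auto e = deserialize(s);
  return e == CudaGraphStrategyOption::kWholeGraphCapture ? 0 : 1;
}
```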
+#ifdef TRT_MAJOR_RTX + serialized_info[RUNTIME_CACHE_PATH_IDX] = runtime_cfg.runtime_cache_path; + serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = std::to_string( + static_cast<std::underlying_type_t<DynamicShapesKernelStrategy>>(runtime_cfg.dynamic_shapes_kernel_strategy)); + serialized_info[CUDA_GRAPH_STRATEGY_IDX] = + std::to_string(static_cast<std::underlying_type_t<CudaGraphStrategyOption>>(runtime_cfg.cuda_graph_strategy)); +#endif return serialized_info; } @@ -533,14 +547,11 @@ void TRTEngine::reset_captured_graph() { void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationStrategy new_strategy) { if (new_strategy != this->resource_allocation_strategy) { this->resource_allocation_strategy = new_strategy; - if (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) { - LOG_DEBUG("Setting resource allocation strategy to dynamic"); - this->exec_ctx = - make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); - } else { - LOG_DEBUG("Setting resource allocation strategy to static"); - this->exec_ctx = make_trt(cuda_engine->createExecutionContext()); - } + LOG_DEBUG( + "Setting resource allocation strategy to " + << (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic ? "dynamic" : "static")); + recreate_execution_context(); } } @@ -637,19 +648,40 @@ void TRTEngine::release_nccl_comm() { LOG_INFO("Releasing NCCL communicator from engine '" << this->name << "'"); torch::cuda::synchronize(device_info.id); this->exec_ctx.reset(); - if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) { - this->exec_ctx = - make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); - } else { - this->exec_ctx = make_trt(cuda_engine->createExecutionContext()); - } - TORCHTRT_CHECK( - (exec_ctx.get() != nullptr), "Unable to recreate TensorRT execution context after releasing NCCL comm"); + recreate_execution_context(); this->nccl_initialized = false; LOG_INFO("NCCL communicator released from engine '" << this->name << "'"); } #endif // ENABLE_TRT_NCCL_COLLECTIVES +bool TRTEngine::is_monolithic_capturable(cudaStream_t stream) const { + return runtime_cfg.is_monolithic_capturable(exec_ctx.get(), stream); +} + +void TRTEngine::disable_rtx_native_cudagraphs() { + bool was_disabled = runtime_cfg.rtx_native_cudagraphs_disabled; + runtime_cfg.disable_rtx_native_cudagraphs(name); + if (!was_disabled && runtime_cfg.rtx_native_cudagraphs_disabled) { + // The CUDA graph strategy on the IRuntimeConfig has been flipped; rebuild exec_ctx + // so the new strategy takes effect for subsequent enqueueV3 calls. + recreate_execution_context(); + } +} + +void TRTEngine::recreate_execution_context() { + // Flush any kernels the previous execution context may have compiled into the + // runtime cache before creating the replacement. The destructor also saves, but + // doing it here guards against losing compiled kernels across profiling toggles + // and allocator changes, or if the process is killed before teardown. No-op on + // standard TensorRT or when no cache path is configured. + runtime_cfg.save_runtime_cache(); + const auto allocation_strategy = resource_allocation_strategy == ResourceAllocationStrategy::kDynamic + ? 
nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED + : nvinfer1::ExecutionContextAllocationStrategy::kSTATIC; + exec_ctx = runtime_cfg.create_execution_context(cuda_engine.get(), allocation_strategy); + TORCHTRT_CHECK(exec_ctx.get() != nullptr, "Unable to (re)create TensorRT execution context"); +} + } // namespace runtime } // namespace core } // namespace torch_tensorrt diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index d851cda07e..30a1320ad0 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -14,6 +14,7 @@ #include "torch/custom_class.h" #include "core/runtime/TRTEngineProfiler.h" +#include "core/runtime/TRTRuntimeConfig.h" #include "core/util/prelude.h" // TensorRT 10.16+ has native NCCL collective support via IExecutionContext::setCommunicator() @@ -45,7 +46,14 @@ using FlattenedState = std::tuple< std::tuple<std::string, std::string>, // serialized metadata std::tuple<std::string, std::string>, // Platform std::tuple<std::string, std::string>, // Resource Allocation Strategy - std::tuple<std::string, std::string>>; // requires_native_multidevice + std::tuple<std::string, std::string> // requires_native_multidevice +#ifdef TRT_MAJOR_RTX + , + std::tuple<std::string, std::string>, // Runtime Cache Path (TRT-RTX) + std::tuple<std::string, std::string>, // Dynamic Shapes Kernel Strategy (TRT-RTX) + std::tuple<std::string, std::string> // CUDA Graph Strategy (TRT-RTX) +#endif + >; struct TorchTRTRuntimeStates { // Indicates whether CUDAGraphs were enabled in the previous execute_engine @@ -140,31 +148,33 @@ struct TRTEngine : torch::CustomClassHolder { ~TRTEngine(); TRTEngine( - const std::string& serialized_engine, + std::string serialized_engine, const RTDevice& cuda_device, const std::vector<std::string>& in_binding_names, const std::vector<std::string>& out_binding_names, const Platform& target_platform = get_current_platform(), bool hardware_compatible = false, bool requires_output_allocator = false, - const std::string& serialized_metadata = "", + std::string serialized_metadata = "", const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = - TRTEngine::ResourceAllocationStrategy::kStatic); + TRTEngine::ResourceAllocationStrategy::kStatic, + TRTRuntimeConfig runtime_cfg = TRTRuntimeConfig{}); TRTEngine(std::vector<std::string> serialized_info); TRTEngine( - const std::string& mod_name, - const std::string& serialized_engine, + std::string mod_name, + std::string serialized_engine, const RTDevice& cuda_device, const std::vector<std::string>& in_binding_names, const std::vector<std::string>& out_binding_names, const Platform& target_platform = get_current_platform(), bool hardware_compatible = false, bool requires_output_allocator = false, - const std::string& serialized_metadata = "", + std::string serialized_metadata = "", const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = - TRTEngine::ResourceAllocationStrategy::kStatic); + TRTEngine::ResourceAllocationStrategy::kStatic, + TRTRuntimeConfig runtime_cfg = TRTRuntimeConfig{}); std::string to_str() const; static void verify_serialization_fmt(const std::vector<std::string>& serialized_info); @@ -257,6 +267,23 @@ struct TRTEngine : torch::CustomClassHolder { ResourceAllocationStrategy resource_allocation_strategy = kStatic; void set_resource_allocation_strategy(ResourceAllocationStrategy new_strategy); ResourceAllocationStrategy get_resource_allocation_strategy(); + + // Owns the IRuntimeConfig (where supported) and TRT-RTX runtime state. On older TRT + // without IRuntimeConfig (e.g. Jetpack) this just carries strategy values that get + // passed to the legacy createExecutionContext overload. 
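Aside: the member introduced below follows a facade pattern — the feature `#ifdef` lives in exactly one struct, and every call site goes through a single entry point. A toy sketch of that shape (the macro, `Engine`, and the int "handles" are stand-ins, not the NvInfer API):

```cpp
#include <iostream>

struct Engine {  // stand-in for nvinfer1::ICudaEngine
  int create_context_legacy() { return 1; }
#ifdef HAS_NEW_API
  int create_context_with_config(int cfg) { return cfg + 1; }
#endif
};

struct RuntimeConfigFacade {
#ifdef HAS_NEW_API
  int config = 41;  // stand-in for the owned IRuntimeConfig handle
#endif
  // Callers use this single entry point on every build; the feature
  // branch is confined to this one function body.
  int create_execution_context(Engine& e) {
#ifdef HAS_NEW_API
    return e.create_context_with_config(config);
#else
    return e.create_context_legacy();  // older-toolkit fallback
#endif
  }
};

int main() {
  Engine e;
  RuntimeConfigFacade cfg;
  std::cout << cfg.create_execution_context(e) << '\n';
}
```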
+ TRTRuntimeConfig runtime_cfg; + + // Monolithic-capturability check used when this engine is wrapped by an outer whole-graph + // capture (e.g. CudaGraphsTorchTensorRTModule). Non-RTX builds always return true. + bool is_monolithic_capturable(cudaStream_t stream) const; + + // Disable TensorRT-RTX native CUDA graph capture on this engine (one-shot, invoked when + // an outer stream capture is detected around execute_engine). No-op on non-RTX. + void disable_rtx_native_cudagraphs(); + + private: + // Single entry point that (re)creates exec_ctx via runtime_cfg.create_execution_context. + void recreate_execution_context(); }; } // namespace runtime diff --git a/core/runtime/TRTRuntimeConfig.cpp b/core/runtime/TRTRuntimeConfig.cpp new file mode 100644 index 0000000000..c0f6e8c37e --- /dev/null +++ b/core/runtime/TRTRuntimeConfig.cpp @@ -0,0 +1,254 @@ +#include "core/runtime/TRTRuntimeConfig.h" + +#include +#include +#include +#include +#include + +#include "core/runtime/runtime.h" +#include "core/util/prelude.h" + +namespace torch_tensorrt { namespace core { namespace runtime { + +// File-local helpers. Kept out of the header because they are only used by this +// translation unit -- TRTEngine now consumes a TRTRuntimeConfig directly and does not +// need the enum conversion helpers. +namespace { + +[[nodiscard]] std::string to_string(DynamicShapesKernelStrategy s) { + switch (s) { + case DynamicShapesKernelStrategy::kLazy: + return "lazy"; + case DynamicShapesKernelStrategy::kEager: + return "eager"; + case DynamicShapesKernelStrategy::kNone: + return "none"; + } + TORCHTRT_CHECK( + false, + "Unexpected DynamicShapesKernelStrategy value: " + << static_cast<std::underlying_type_t<DynamicShapesKernelStrategy>>(s)); +} + +[[nodiscard]] std::string to_string(CudaGraphStrategyOption s) { + switch (s) { + case CudaGraphStrategyOption::kDisabled: + return "disabled"; + case CudaGraphStrategyOption::kWholeGraphCapture: + return "whole_graph_capture"; + } + TORCHTRT_CHECK( + false, + "Unexpected CudaGraphStrategyOption value: " << static_cast<std::underlying_type_t<CudaGraphStrategyOption>>(s)); +} + +[[nodiscard]] DynamicShapesKernelStrategy to_dynamic_shapes_kernel_strategy( + std::underlying_type_t<DynamicShapesKernelStrategy> v) { + TORCHTRT_CHECK( + v >= 0 && v <= 2, + "Invalid dynamic shapes kernel strategy value: " << v << ". Expected 0 (lazy), 1 (eager), or 2 (none)."); + return static_cast<DynamicShapesKernelStrategy>(v); +} + +[[nodiscard]] CudaGraphStrategyOption to_cuda_graph_strategy_option(std::underlying_type_t<CudaGraphStrategyOption> v) { + TORCHTRT_CHECK( + v >= 0 && v <= 1, + "Invalid CUDA graph strategy value: " << v << ". Expected 0 (disabled) or 1 (whole_graph_capture)."); + return static_cast<CudaGraphStrategyOption>(v); +} + +#ifdef TRT_MAJOR_RTX +// Raw cache I/O helpers. Exception-propagating; the caller wraps in try/catch at the +// TRTRuntimeConfig member level. Kept file-local because the IRuntimeCache type is +// itself TensorRT-RTX-only and tests reach this path through the member wrappers. 
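Aside: the save path below uses the classic write-to-temp-then-rename trick; because rename() within a single POSIX filesystem is atomic, a concurrent reader (or a crash mid-write) never observes a half-written cache. Stripped of the TensorRT types, the pattern is:

```cpp
#include <filesystem>
#include <fstream>
#include <string>

// Atomically replace `path` with `bytes`: a crash mid-write leaves either the
// old file or a stray .tmp, never a truncated cache.
void atomic_write(const std::filesystem::path& path, const std::string& bytes) {
  if (path.has_parent_path()) {
    std::filesystem::create_directories(path.parent_path());
  }
  std::filesystem::path tmp = path;
  tmp += ".tmp";
  {
    std::ofstream out(tmp, std::ios::binary);  // scoped: flushed and closed here
    out.write(bytes.data(), static_cast<std::streamsize>(bytes.size()));
  }
  // Same-filesystem rename is an atomic replace on POSIX.
  std::filesystem::rename(tmp, path);
}

int main() {
  atomic_write("/tmp/demo/cache.bin", "serialized-kernels");
}
```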
+void load_runtime_cache(const std::string& path, nvinfer1::IRuntimeCache* cache) { + TORCHTRT_CHECK(cache != nullptr, "load_runtime_cache requires a non-null IRuntimeCache"); + if (!std::filesystem::exists(path)) { + LOG_DEBUG("No existing runtime cache at " << path); + return; + } + std::ifstream f(path, std::ios::binary); + std::vector<char> buf((std::istreambuf_iterator<char>(f)), std::istreambuf_iterator<char>()); + if (buf.empty()) { + return; + } + TORCHTRT_CHECK(cache->deserialize(buf.data(), buf.size()), "IRuntimeCache::deserialize returned false for " << path); + LOG_INFO("Loaded runtime cache from " << path << " (" << buf.size() << " bytes)"); +} + +void save_runtime_cache_impl(const std::string& path, nvinfer1::IRuntimeCache* cache) { + TORCHTRT_CHECK(cache != nullptr, "save_runtime_cache requires a non-null IRuntimeCache"); + auto host_mem = make_trt(cache->serialize()); + if (!host_mem || host_mem->size() == 0) { + return; + } + std::filesystem::path fs_path(path); + if (fs_path.has_parent_path()) { + std::filesystem::create_directories(fs_path.parent_path()); + } + std::filesystem::path tmp_path = fs_path; + tmp_path += ".tmp"; + { + std::ofstream out(tmp_path, std::ios::binary); + out.write(reinterpret_cast<const char*>(host_mem->data()), host_mem->size()); + } + std::filesystem::rename(tmp_path, fs_path); + LOG_INFO("Saved runtime cache to " << path << " (" << host_mem->size() << " bytes)"); +} +#endif // TRT_MAJOR_RTX + +} // namespace + +void TRTRuntimeConfig::ensure_initialized(TORCHTRT_UNUSED nvinfer1::ICudaEngine* cuda_engine) { +#ifdef TRT_HAS_IRUNTIME_CONFIG + if (config) { + return; + } + TORCHTRT_CHECK(cuda_engine != nullptr, "Cannot initialize TRTRuntimeConfig without a live ICudaEngine"); + config = make_trt(cuda_engine->createRuntimeConfig()); + TORCHTRT_CHECK(config.get() != nullptr, "Unable to create TensorRT IRuntimeConfig"); + +#ifdef TRT_MAJOR_RTX + // Runtime cache -- TRT-RTX only. + if (!runtime_cache_path.empty()) { + runtime_cache = make_trt(config->createRuntimeCache()); + if (runtime_cache.get() == nullptr) { + LOG_WARNING("Failed to create TensorRT IRuntimeCache; runtime cache will be skipped."); + } else { + try { + load_runtime_cache(runtime_cache_path, runtime_cache.get()); + } catch (const std::exception& e) { + LOG_WARNING("Failed to load runtime cache from " << runtime_cache_path << ": " << e.what()); + } + if (config->setRuntimeCache(*runtime_cache)) { + LOG_DEBUG("TensorRT-RTX runtime cache configured at " << runtime_cache_path); + } else { + LOG_WARNING("Failed to attach runtime cache to IRuntimeConfig; cache will be unused."); + runtime_cache.reset(); + } + } + } else { + LOG_DEBUG("Runtime cache disabled (no path configured)."); + } + + // Dynamic shapes kernel specialization strategy -- TRT-RTX only. + config->setDynamicShapesKernelSpecializationStrategy( + static_cast<nvinfer1::DynamicShapesKernelSpecializationStrategy>(dynamic_shapes_kernel_strategy)); + LOG_DEBUG("Dynamic shapes kernel specialization strategy set to " << to_string(dynamic_shapes_kernel_strategy)); + + // CUDA graph strategy -- TRT-RTX only. + if (!config->setCudaGraphStrategy( + cuda_graph_strategy == CudaGraphStrategyOption::kWholeGraphCapture + ? 
nvinfer1::CudaGraphStrategy::kWHOLE_GRAPH_CAPTURE + : nvinfer1::CudaGraphStrategy::kDISABLED)) { + LOG_WARNING("Failed to set CUDA graph strategy; continuing with default."); + } +#endif +#endif // TRT_HAS_IRUNTIME_CONFIG +} + +std::shared_ptr<nvinfer1::IExecutionContext> TRTRuntimeConfig::create_execution_context( + nvinfer1::ICudaEngine* cuda_engine, + nvinfer1::ExecutionContextAllocationStrategy allocation_strategy) { + ensure_initialized(cuda_engine); +#ifdef TRT_HAS_IRUNTIME_CONFIG + config->setExecutionContextAllocationStrategy(allocation_strategy); + return make_trt(cuda_engine->createExecutionContext(config.get())); +#else + // Pre-10.11 TRT (e.g. Jetpack): use the legacy strategy overload directly. + return make_trt(cuda_engine->createExecutionContext(allocation_strategy)); +#endif +} + +bool TRTRuntimeConfig::uses_internal_capture(TORCHTRT_UNUSED bool cudagraphs_enabled) const { +#ifdef TRT_MAJOR_RTX + // On TRT-RTX the internal runtime handles capture/replay whenever a non-disabled + // strategy is set, or when subgraph cudagraphs are enabled globally. In both cases the + // caller should skip its manual at::cuda::CUDAGraph wrapper because TRT-RTX's internal + // capture would collide with it. + return cuda_graph_strategy != CudaGraphStrategyOption::kDisabled || cudagraphs_enabled; +#else + return false; +#endif +} + +void TRTRuntimeConfig::disable_rtx_native_cudagraphs(TORCHTRT_UNUSED const std::string& engine_name) noexcept { +#ifdef TRT_MAJOR_RTX + if (rtx_native_cudagraphs_disabled || cuda_graph_strategy == CudaGraphStrategyOption::kDisabled) { + return; + } + LOG_WARNING( + "Outer CUDA stream capture detected; disabling TensorRT-RTX native CUDA graph strategy on engine " + << engine_name << " for the remainder of its lifetime."); + // Persist any kernels the engine-internal capture has compiled so far; the outer + // capture will run without them otherwise, and we want future reloads to reuse them. + save_runtime_cache(); + cuda_graph_strategy = CudaGraphStrategyOption::kDisabled; + if (config && !config->setCudaGraphStrategy(nvinfer1::CudaGraphStrategy::kDISABLED)) { + LOG_WARNING("Failed to update CUDA graph strategy on IRuntimeConfig after disable."); + } + rtx_native_cudagraphs_disabled = true; +#endif +} + +bool TRTRuntimeConfig::is_monolithic_capturable( + TORCHTRT_UNUSED nvinfer1::IExecutionContext* exec_ctx, + TORCHTRT_UNUSED cudaStream_t stream) const { +#ifdef TRT_MAJOR_RTX + TORCHTRT_ASSERT(exec_ctx != nullptr, "is_monolithic_capturable requires a live IExecutionContext"); + // "lazy" kernel specialization swaps specialized kernels in mid-run, which invalidates + // captured graphs. Other strategies (eager/none) are safe when the context reports the + // stream capturable. + return exec_ctx->isStreamCapturable(stream) && dynamic_shapes_kernel_strategy != DynamicShapesKernelStrategy::kLazy; +#else + return true; +#endif
} + +void TRTRuntimeConfig::save_runtime_cache() noexcept { +#ifdef TRT_MAJOR_RTX + if (!runtime_cache || runtime_cache_path.empty()) { + return; + } + try { + save_runtime_cache_impl(runtime_cache_path, runtime_cache.get()); + } catch (const std::exception& e) { + LOG_WARNING("Failed to save runtime cache to " << runtime_cache_path << ": " << e.what()); + } catch (...) { + LOG_WARNING("Failed to save runtime cache (unknown exception)."); + } +#endif +} + +std::string TRTRuntimeConfig::to_str() const { + std::ostringstream os; + os << "Runtime Cache Path: " << (runtime_cache_path.empty() ? 
"" : runtime_cache_path) << std::endl; + os << "Dynamic Shapes Kernel Strategy: " << to_string(dynamic_shapes_kernel_strategy) << std::endl; + os << "CUDA Graph Strategy: " << to_string(cuda_graph_strategy) << std::endl; + return os.str(); +} + +TRTRuntimeConfig make_runtime_config_from_serialized(TORCHTRT_UNUSED const std::vector& info) { + TRTRuntimeConfig cfg; +#ifdef TRT_MAJOR_RTX + cfg.runtime_cache_path = info[RUNTIME_CACHE_PATH_IDX]; + cfg.dynamic_shapes_kernel_strategy = + to_dynamic_shapes_kernel_strategy(std::stoi(info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX])); + cfg.cuda_graph_strategy = to_cuda_graph_strategy_option(std::stoi(info[CUDA_GRAPH_STRATEGY_IDX])); +#endif + return cfg; +} + +std::ostream& operator<<(std::ostream& os, const TRTRuntimeConfig& cfg) { + os << "Runtime cfg {" << std::endl; + os << cfg.to_str(); + os << "}" << std::endl; + return os; +} + +} // namespace runtime +} // namespace core +} // namespace torch_tensorrt diff --git a/core/runtime/TRTRuntimeConfig.h b/core/runtime/TRTRuntimeConfig.h new file mode 100644 index 0000000000..489d59fcd0 --- /dev/null +++ b/core/runtime/TRTRuntimeConfig.h @@ -0,0 +1,99 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "NvInfer.h" + +namespace torch_tensorrt { +namespace core { +namespace runtime { + +// TensorRT-RTX-only configuration for how shape-specialized kernels are compiled. +enum class DynamicShapesKernelStrategy : int32_t { + kLazy = 0, + kEager = 1, + kNone = 2, +}; + +// TensorRT-RTX-only configuration for how CUDA graph capture/replay is handled. +enum class CudaGraphStrategyOption : int32_t { + kDisabled = 0, + kWholeGraphCapture = 1, +}; + +// Encapsulates the IRuntimeConfig and TRT-RTX runtime state for a TRTEngine. +// IRuntimeConfig and runtime-cache `#ifdef`s are confined to this TU; serialization- +// index plumbing keeps its own RTX gates elsewhere. +struct TRTRuntimeConfig { + // Settings - typically populated from engine deserialization before `ensure_initialized`. + std::string runtime_cache_path = ""; + DynamicShapesKernelStrategy dynamic_shapes_kernel_strategy = DynamicShapesKernelStrategy::kLazy; + CudaGraphStrategyOption cuda_graph_strategy = CudaGraphStrategyOption::kDisabled; + + // One-shot: set to true once an outer stream capture has been detected and the + // engine-internal CUDA graph strategy has been disabled for the remainder of the + // owning engine's lifetime. + bool rtx_native_cudagraphs_disabled = false; + + // Live resources. The IRuntimeConfig is lazy-constructed on first `ensure_initialized` + // and is unavailable on TensorRT versions older than 10.11 (e.g. Jetpack). +#ifdef TRT_HAS_IRUNTIME_CONFIG + std::shared_ptr config; +#endif +#ifdef TRT_MAJOR_RTX + std::shared_ptr runtime_cache; +#endif + + // Lazily construct the IRuntimeConfig and apply RTX-specific settings. Idempotent. + // No-op on builds without IRuntimeConfig (e.g. Jetpack). + void ensure_initialized(nvinfer1::ICudaEngine* cuda_engine); + + // Lazy-initialize the IRuntimeConfig if needed and create an IExecutionContext that + // honors `allocation_strategy`. Selects the right `createExecutionContext` overload + // (IRuntimeConfig* vs ExecutionContextAllocationStrategy) so callers stay free of + // any TRT_HAS_IRUNTIME_CONFIG branching. 
[[nodiscard]] std::shared_ptr<nvinfer1::IExecutionContext> create_execution_context( + nvinfer1::ICudaEngine* cuda_engine, + nvinfer1::ExecutionContextAllocationStrategy allocation_strategy); + + // Returns true if the TensorRT-RTX runtime owns capture/replay for this engine so the + // caller should bypass its own at::cuda::CUDAGraph capture around enqueueV3. Always + // false on non-RTX builds. + [[nodiscard]] bool uses_internal_capture(bool cudagraphs_enabled) const; + + // One-shot: disable engine-internal CUDA graph capture. Invoked when an outer stream + // capture is detected around execute_engine, so the outer capture can contain the + // kernel launches directly. Saves the runtime cache before recreating the context so + // compiled kernels from the present run are preserved for future reloads. + void disable_rtx_native_cudagraphs(const std::string& engine_name) noexcept; + + // Whether the execution context is safe to include in an outer monolithic capture. + // Non-RTX builds always return true. + [[nodiscard]] bool is_monolithic_capturable(nvinfer1::IExecutionContext* exec_ctx, cudaStream_t stream) const; + + // Save the runtime cache to disk. Declared `noexcept` so it is safe to call from a + // destructor. The underlying file I/O lives in file-local helpers in the .cpp + // (non-noexcept, so failures propagate to this wrapper); this member wraps them and + // swallows and logs any exceptions. + void save_runtime_cache() noexcept; + + // Returns a human-readable summary of the runtime config. + [[nodiscard]] std::string to_str() const; +}; + +// Construct a TRTRuntimeConfig from a flattened serialization vector. Reads the +// RTX-only indices only on RTX builds; standard TRT builds return a default-initialized +// struct. +[[nodiscard]] TRTRuntimeConfig make_runtime_config_from_serialized(const std::vector<std::string>& info); + +std::ostream& operator<<(std::ostream& os, const TRTRuntimeConfig& cfg); + +} // namespace runtime } // namespace core } // namespace torch_tensorrt diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index ffefa2c742..d66fa69baa 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -236,11 +236,28 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr auto run_standard_execution = [&]() { bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS); + // effective_cudagraphs controls the manual at::cuda::CUDAGraph path below. On TRT-RTX + // builds the engine-internal runtime owns capture/replay inside enqueueV3 whenever the + // engine has a cuda_graph_strategy set or subgraph cudagraphs are enabled; the struct + // reports that via `uses_internal_capture` so the caller skips its manual wrapper. If + // an outer stream capture is already in progress (e.g. the caller wraps this module in + // CudaGraphsTorchTensorRTModule for whole-graph capture), engine-internal capture would + // collide, so we disable it one-shot here. 
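Aside: the detection below hinges on cudaStreamIsCapturing, which reports whether an outer caller has begun CUDA graph capture on the stream. In isolation (CUDA runtime API only; error handling trimmed to the essentials):

```cpp
#include <cuda_runtime_api.h>
#include <cstdio>

// Returns true if `stream` is currently inside a CUDA graph capture started
// by some outer caller. This mirrors the check execute_engine performs before
// deciding whether engine-internal capture must be disabled.
bool outer_capture_in_progress(cudaStream_t stream) {
  cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
  cudaError_t err = cudaStreamIsCapturing(stream, &status);
  if (err != cudaSuccess) {
    std::fprintf(stderr, "cudaStreamIsCapturing failed: %s\n", cudaGetErrorString(err));
    return false;  // conservatively assume no capture
  }
  return status != cudaStreamCaptureStatusNone;
}

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  std::printf("capturing: %d\n", outer_capture_in_progress(stream));  // prints 0
  cudaStreamDestroy(stream);
}
```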
+ bool effective_cudagraphs = cudagraphs_enabled; + if (compiled_engine->runtime_cfg.uses_internal_capture(cudagraphs_enabled)) { + effective_cudagraphs = false; + cudaStreamCaptureStatus capture_status; + cudaStreamIsCapturing(compiled_engine->engine_stream.stream(), &capture_status); + if (capture_status != cudaStreamCaptureStatusNone) { + compiled_engine->disable_rtx_native_cudagraphs(); + } + } + bool shape_changed = _validate_shapes(inputs, compiled_engine); // Whether cudagraphs needs to record the graph on this pass auto result = compiled_engine->runtime_states.set_runtime_states( - cudagraphs_enabled, compiled_engine->use_pre_allocated_outputs, shape_changed); + effective_cudagraphs, compiled_engine->use_pre_allocated_outputs, shape_changed); bool need_cudagraphs_record = std::get<0>(result); bool can_use_pre_allocated_outputs = std::get<1>(result); @@ -263,7 +280,8 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr std::make_unique(compiled_engine->input_profile_path); } - setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record, inputShapeTensorValues); + setup_input_tensors( inputs, compiled_engine, effective_cudagraphs, need_cudagraphs_record, inputShapeTensorValues); // Check if input shapes can be inferred. int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()}; std::vector<std::string> names(io_size); @@ -295,7 +313,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone()); } - if (cudagraphs_enabled) { + if (effective_cudagraphs) { TORCHTRT_CHECK( compiled_engine->exec_ctx->setTensorAddress( name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()), @@ -335,8 +353,10 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr caller_exec_complete.record(compiled_engine->caller_stream); caller_exec_complete.block(compiled_engine->engine_stream); - if (!cudagraphs_enabled) { - // Direct execution uses the caller buffers directly + if (!effective_cudagraphs) { + // Direct execution uses the caller buffers directly. On TRT-RTX with a + // cuda_graph_strategy set, the engine captures/replays internally during + // this enqueueV3 call. 
compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream); } else { if (need_cudagraphs_record) { @@ -369,7 +389,7 @@ std::vector execute_engine(std::vector inputs, c10::intr trt_exec_complete.record(compiled_engine->engine_stream); trt_exec_complete.block(compiled_engine->caller_stream); - if (cudagraphs_enabled) { + if (effective_cudagraphs) { // If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream) for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) { outputs[o].copy_(compiled_engine->output_buffers[o], false); diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index e9ceff2a3e..5bbce8d72f 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -199,6 +199,11 @@ TORCH_LIBRARY(tensorrt, m) { return false; #endif }); +#ifdef TRT_MAJOR_RTX + m.def("RUNTIME_CACHE_PATH_IDX", []() -> int64_t { return RUNTIME_CACHE_PATH_IDX; }); + m.def("DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX", []() -> int64_t { return DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX; }); + m.def("CUDA_GRAPH_STRATEGY_IDX", []() -> int64_t { return CUDA_GRAPH_STRATEGY_IDX; }); +#endif m.def("_platform_linux_x86_64", []() -> std::string { auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64); return it->second; diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index e3a675cb05..636138c1d5 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -40,6 +40,11 @@ typedef enum { REQUIRES_OUTPUT_ALLOCATOR_IDX, RESOURCE_ALLOCATION_STRATEGY_IDX, REQUIRES_NATIVE_MULTIDEVICE_IDX, +#ifdef TRT_MAJOR_RTX + RUNTIME_CACHE_PATH_IDX, + DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX, + CUDA_GRAPH_STRATEGY_IDX, +#endif SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO } SerializedInfoIndex; diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 4f9f61ca5f..416e9ae2a9 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -109,6 +109,7 @@ def cross_compile_for_windows( enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING, cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET, dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES, + cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL, **kwargs: Any, @@ -171,6 +172,7 @@ def cross_compile_for_windows( timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". + cuda_graph_strategy (str): TensorRT-RTX CUDA graph strategy. Options: "disabled", "whole_graph_capture". Default: "disabled". Not used for standard TensorRT. lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. 
cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -333,6 +335,7 @@ def cross_compile_for_windows( "enable_resource_partitioning": enable_resource_partitioning, "cpu_memory_budget": cpu_memory_budget, "dynamically_allocate_resources": dynamically_allocate_resources, + "cuda_graph_strategy": cuda_graph_strategy, "decompose_attention": decompose_attention, "attn_bias_is_causal": attn_bias_is_causal, } @@ -464,6 +467,7 @@ def compile( cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET, enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING, dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES, + cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL, **kwargs: Any, @@ -528,6 +532,7 @@ def compile( timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". + cuda_graph_strategy (str): TensorRT-RTX CUDA graph strategy. Options: "disabled", "whole_graph_capture". Default: "disabled". Not used for standard TensorRT. lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -731,6 +736,7 @@ def compile( "enable_resource_partitioning": enable_resource_partitioning, "cpu_memory_budget": cpu_memory_budget, "dynamically_allocate_resources": dynamically_allocate_resources, + "cuda_graph_strategy": cuda_graph_strategy, "decompose_attention": decompose_attention, "attn_bias_is_causal": attn_bias_is_causal, } @@ -791,7 +797,7 @@ def _insert_complex_io_adapters( Outputs: insert view_as_complex before the output node for each originally-complex output that comes from a TRT block. - Leverages metadata that was captued when the complex rewriter pass was run + Leverages metadata that was captured when the complex rewriter pass was run """ complex_input_names = gm.meta.get("complex_input_names", []) complex_input_dtypes = gm.meta.get("complex_input_dtypes", {}) @@ -1242,6 +1248,7 @@ def convert_exported_program_to_serialized_trt_engine( l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, + cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL, **kwargs: Any, @@ -1303,6 +1310,7 @@ def convert_exported_program_to_serialized_trt_engine( timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. 
Not used for standard TensorRT. dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". + cuda_graph_strategy (str): TensorRT-RTX CUDA graph strategy. Options: "disabled", "whole_graph_capture". Default: "disabled". Not used for standard TensorRT. lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -1473,6 +1481,7 @@ def convert_exported_program_to_serialized_trt_engine( "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, "use_distributed_mode_trace": use_distributed_mode_trace, + "cuda_graph_strategy": cuda_graph_strategy, "decompose_attention": decompose_attention, "attn_bias_is_causal": attn_bias_is_causal, } diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 007b07db31..0bea37805a 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -68,9 +68,10 @@ ENABLE_RESOURCE_PARTITIONING = False CPU_MEMORY_BUDGET = None DYNAMICALLY_ALLOCATE_RESOURCES = False +DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy" +CUDA_GRAPH_STRATEGY = "disabled" DECOMPOSE_ATTENTION = False ATTN_BIAS_IS_CAUSAL = True -DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy" if platform.system() == "Linux": import pwd diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index c7ef3eed9b..4665182d2c 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -17,6 +17,7 @@ AUTOCAST_MAX_OUTPUT_THRESHOLD, CACHE_BUILT_ENGINES, CPU_MEMORY_BUDGET, + CUDA_GRAPH_STRATEGY, DECOMPOSE_ATTENTION, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, @@ -96,7 +97,7 @@ class CompilationSettings: output to a file if a string path is specified hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX (no autotuning). - runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. The cache is loaded on engine setup and saved on module cleanup. Uses file locking for concurrent access safety. Not used for standard TensorRT. + runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Loaded on engine setup, saved on module cleanup. Not used for standard TensorRT. dynamic_shapes_kernel_specialization_strategy (str): Strategy for compiling shape-specialized kernels at runtime for dynamic shapes (TensorRT-RTX only). Options: "lazy" (compile in background, use fallback until ready), "eager" (compile immediately, blocking), "none" (always use fallback kernels). Default: "lazy". 
cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -119,6 +120,7 @@ class CompilationSettings: autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None. offload_module_to_cpu (bool): Offload the model to CPU to reduce memory footprint during compilation dynamically_allocate_resources (bool): Dynamically allocate resources for TensorRT engines + cuda_graph_strategy (str): TensorRT-RTX CUDA graph strategy: "disabled" (default) or "whole_graph_capture" (let TensorRT-RTX manage CUDA graph capture/replay internally). When set and combined with `torch_tensorrt.runtime.set_cudagraphs_mode(True)` on RTX, overrides manual capture. Not used for standard TensorRT. decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True. attn_bias_is_causal (bool): Whether the attn_bias in efficient SDPA is causal. Default is True. This can accelerate models from HF because attn_bias is always a causal mask in HF. If you want to use non-causal attn_bias, you can set this to False. """ @@ -182,6 +184,7 @@ class CompilationSettings: enable_resource_partitioning: bool = ENABLE_RESOURCE_PARTITIONING cpu_memory_budget: Optional[int] = CPU_MEMORY_BUDGET dynamically_allocate_resources: bool = DYNAMICALLY_ALLOCATE_RESOURCES + cuda_graph_strategy: str = CUDA_GRAPH_STRATEGY decompose_attention: bool = DECOMPOSE_ATTENTION attn_bias_is_causal: bool = ATTN_BIAS_IS_CAUSAL diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 833fdee639..8d3027cb8d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -35,6 +35,10 @@ SERIALIZED_METADATA_IDX = -1 # Not implemented TARGET_PLATFORM_IDX = -1 # Not implemented REQUIRES_OUTPUT_ALLOCATOR_IDX = -1 # Not implemented +RESOURCE_ALLOCATION_STRATEGY_IDX = -1 # Not implemented +RUNTIME_CACHE_PATH_IDX = -1 # Not implemented +DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX = -1 # Not implemented +CUDA_GRAPH_STRATEGY_IDX = -1 # Not implemented SERIALIZATION_LEN = -1 # Not implemented REQUIRES_NATIVE_MULTIDEVICE_IDX = -1 # Not implemented @@ -57,7 +61,25 @@ REQUIRES_NATIVE_MULTIDEVICE_IDX = ( torch.ops.tensorrt.REQUIRES_NATIVE_MULTIDEVICE_IDX() ) # 11 - SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 12 + if ENABLED_FEATURES.tensorrt_rtx: + RUNTIME_CACHE_PATH_IDX = torch.ops.tensorrt.RUNTIME_CACHE_PATH_IDX() # 12 + DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX = ( + torch.ops.tensorrt.DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX() + ) # 13 + CUDA_GRAPH_STRATEGY_IDX = torch.ops.tensorrt.CUDA_GRAPH_STRATEGY_IDX() # 14 + SERIALIZATION_LEN = ( + torch.ops.tensorrt.SERIALIZATION_LEN() + ) # 15 (RTX) / 12 (standard) + +_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP: Dict[str, int] = { + "lazy": 0, + "eager": 1, + "none": 2, +} +_CUDA_GRAPH_STRATEGY_MAP: Dict[str, int] = { + "disabled": 0, + "whole_graph_capture": 1, +} @for_all_methods(needs_torch_tensorrt_runtime) @@ -151,6 +173,28 @@ def __init__( self.engine = None self.requires_output_allocator = requires_output_allocator self.dynamically_allocate_resources = settings.dynamically_allocate_resources + # TensorRT-RTX-only runtime config mirror. 
The engine-info serialization slots + # only exist on RTX builds (see below), but we validate the strategy names on + # every build so typos are caught regardless of backend. + self.runtime_cache_path = settings.runtime_cache_path + self.dynamic_shapes_kernel_specialization_strategy = ( + settings.dynamic_shapes_kernel_specialization_strategy + ) + if ( + self.dynamic_shapes_kernel_specialization_strategy + not in _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP + ): + raise ValueError( + f"Invalid dynamic_shapes_kernel_specialization_strategy " + f"{self.dynamic_shapes_kernel_specialization_strategy!r}; expected one of " + f"{list(_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP.keys())}" + ) + self.cuda_graph_strategy = settings.cuda_graph_strategy + if self.cuda_graph_strategy not in _CUDA_GRAPH_STRATEGY_MAP: + raise ValueError( + f"Invalid cuda_graph_strategy {self.cuda_graph_strategy!r}; expected one of " + f"{list(_CUDA_GRAPH_STRATEGY_MAP.keys())}" + ) self.symbolic_shape_expressions = symbolic_shape_expressions self.requires_native_multidevice = requires_native_multidevice @@ -229,6 +273,18 @@ def _pack_engine_info(self) -> List[str | bytes]: int(self.requires_native_multidevice) ) # rank/world_size are runtime facts; queried from ProcessGroup at execution time + # Strategy names are validated at __init__ time so typos fail fast on every + # build; the index slots themselves only exist on RTX. + if ENABLED_FEATURES.tensorrt_rtx: + engine_info[RUNTIME_CACHE_PATH_IDX] = self.runtime_cache_path or "" + engine_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = str( + _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP[ + self.dynamic_shapes_kernel_specialization_strategy + ] + ) + engine_info[CUDA_GRAPH_STRATEGY_IDX] = str( + _CUDA_GRAPH_STRATEGY_MAP[self.cuda_graph_strategy] + ) return engine_info diff --git a/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py b/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py index 965307e57f..cd824e1746 100644 --- a/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py +++ b/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py @@ -3,10 +3,40 @@ import torch import torch_tensorrt as torchtrt +from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity +# Combinations of (strategy, runtime_name, use_python_runtime). Tests use parameterized +# so the strategy sweep runs on both runtimes with a single test body. 
+_STRATEGY_RUNTIMES = [ + ("lazy_python", "lazy", True), + ("eager_python", "eager", True), + ("none_python", "none", True), + ("lazy_cpp", "lazy", False), + ("eager_cpp", "eager", False), + ("none_cpp", "none", False), +] + + +def _skip_if_cpp_unavailable(testcase, use_python_runtime): + if not use_python_runtime and not ENABLED_FEATURES.torch_tensorrt_runtime: + testcase.skipTest("C++ runtime is not available") + + +def _compile_with_strategy(model, inputs, *, use_python_runtime, strategy): + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + use_python_runtime=use_python_runtime, + min_block_size=1, + dynamic_shapes_kernel_specialization_strategy=strategy, + ) + torch._dynamo.reset() + return compiled + @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx, @@ -17,17 +47,18 @@ "torchvision is not installed", ) class TestDynamicShapesKernelStrategyModels(TestCase): - """End-to-end model tests with different kernel specialization strategies.""" + """End-to-end model tests with each strategy across both runtimes.""" - def tearDown(self): - torch._dynamo.reset() + @parameterized.expand(_STRATEGY_RUNTIMES) + def test_resnet18_strategy(self, _name, strategy, use_python_runtime): + _skip_if_cpp_unavailable(self, use_python_runtime) + import torchvision.models as models - def _compile_and_verify(self, model, strategy): + model = models.resnet18(pretrained=True).eval().cuda() input_tensor = torch.randn(4, 3, 224, 224).cuda() - compiled = torchtrt.compile( + compiled = _compile_with_strategy( model, - ir="dynamo", - inputs=[ + [ torchtrt.Input( min_shape=(1, 3, 224, 224), opt_shape=(4, 3, 224, 224), @@ -35,9 +66,8 @@ def _compile_and_verify(self, model, strategy): dtype=torch.float32, ) ], - use_python_runtime=True, - min_block_size=1, - dynamic_shapes_kernel_specialization_strategy=strategy, + use_python_runtime=use_python_runtime, + strategy=strategy, ) ref_output = model(input_tensor) trt_output = compiled(input_tensor) @@ -45,39 +75,21 @@ def _compile_and_verify(self, model, strategy): self.assertTrue( cos_sim > COSINE_THRESHOLD, f"Cosine similarity {cos_sim} below threshold {COSINE_THRESHOLD} " - f"with strategy={strategy}", + f"(strategy={strategy}, python_runtime={use_python_runtime})", ) - def test_resnet18_lazy_strategy(self): - import torchvision.models as models - - model = models.resnet18(pretrained=True).eval().cuda() - self._compile_and_verify(model, "lazy") - - def test_resnet18_eager_strategy(self): - import torchvision.models as models - - model = models.resnet18(pretrained=True).eval().cuda() - self._compile_and_verify(model, "eager") - - def test_resnet18_none_strategy(self): - import torchvision.models as models - - model = models.resnet18(pretrained=True).eval().cuda() - self._compile_and_verify(model, "none") - @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx, "Dynamic shapes kernel specialization strategy requires TensorRT-RTX", ) class TestDynamicShapesKernelStrategyDynamic(TestCase): - """Tests kernel specialization strategies with dynamic input shapes.""" + """Tests kernel specialization strategies with dynamic input shapes, both runtimes.""" - def tearDown(self): - torch._dynamo.reset() + @parameterized.expand(_STRATEGY_RUNTIMES) + def test_dynamic_batch_with_strategy(self, _name, strategy, use_python_runtime): + _skip_if_cpp_unavailable(self, use_python_runtime) - def _test_dynamic_batch_with_strategy(self, strategy): class ConvModel(torch.nn.Module): def __init__(self): super().__init__() @@ -89,10 +101,9 @@ def forward(self, x): model = 
ConvModel().eval().cuda() - compiled = torchtrt.compile( + compiled = _compile_with_strategy( model, - ir="dynamo", - inputs=[ + [ torchtrt.Input( min_shape=(1, 3, 32, 32), opt_shape=(4, 3, 32, 32), @@ -100,31 +111,20 @@ def forward(self, x): dtype=torch.float32, ) ], - use_python_runtime=True, - min_block_size=1, - dynamic_shapes_kernel_specialization_strategy=strategy, + use_python_runtime=use_python_runtime, + strategy=strategy, ) for batch_size in (1, 4, 8): - with self.subTest(batch_size=batch_size, strategy=strategy): - input_tensor = torch.randn(batch_size, 3, 32, 32).cuda() - ref_output = model(input_tensor) - trt_output = compiled(input_tensor) - cos_sim = cosine_similarity(ref_output, trt_output) - self.assertTrue( - cos_sim > COSINE_THRESHOLD, - f"BS={batch_size}, strategy={strategy}: cosine similarity " - f"{cos_sim} below threshold {COSINE_THRESHOLD}", - ) - - def test_dynamic_batch_lazy(self): - self._test_dynamic_batch_with_strategy("lazy") - - def test_dynamic_batch_eager(self): - self._test_dynamic_batch_with_strategy("eager") - - def test_dynamic_batch_none(self): - self._test_dynamic_batch_with_strategy("none") + input_tensor = torch.randn(batch_size, 3, 32, 32).cuda() + ref_output = model(input_tensor) + trt_output = compiled(input_tensor) + cos_sim = cosine_similarity(ref_output, trt_output) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"BS={batch_size}, strategy={strategy}, python_runtime={use_python_runtime}: " + f"cosine similarity {cos_sim} below threshold {COSINE_THRESHOLD}", + ) if __name__ == "__main__": diff --git a/tests/py/dynamo/models/test_runtime_cache_models.py b/tests/py/dynamo/models/test_runtime_cache_models.py index c4aeeef547..55b11b623e 100644 --- a/tests/py/dynamo/models/test_runtime_cache_models.py +++ b/tests/py/dynamo/models/test_runtime_cache_models.py @@ -8,10 +8,32 @@ import torch import torch_tensorrt as torchtrt +from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity +# Parameterize end-to-end cache tests over both runtime paths. The C++ variant is +# skipped inside the test body when the C++ runtime is not available. 
+_RUNTIMES = [("python", True), ("cpp", False)] + + +def _compile(model, inputs, *, use_python_runtime, runtime_cache_path): + kwargs = { + "ir": "dynamo", + "inputs": inputs, + "enabled_precisions": {torch.float32}, + "use_python_runtime": use_python_runtime, + "min_block_size": 1, + "runtime_cache_path": runtime_cache_path, + } + return torchtrt.compile(model, **kwargs) + + +def _skip_if_cpp_unavailable(testcase, use_python_runtime): + if not use_python_runtime and not ENABLED_FEATURES.torch_tensorrt_runtime: + testcase.skipTest("C++ runtime is not available") + @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx, @@ -22,7 +44,7 @@ "torchvision is not installed", ) class TestRuntimeCacheModels(TestCase): - """End-to-end model tests with runtime cache enabled.""" + """End-to-end model tests with runtime cache enabled — both runtimes.""" def setUp(self): self.cache_dir = tempfile.mkdtemp() @@ -32,18 +54,18 @@ def tearDown(self): shutil.rmtree(self.cache_dir, ignore_errors=True) torch._dynamo.reset() - def test_resnet18_with_runtime_cache(self): + @parameterized.expand(_RUNTIMES) + def test_resnet18_with_runtime_cache(self, _name, use_python_runtime): + _skip_if_cpp_unavailable(self, use_python_runtime) import torchvision.models as models model = models.resnet18(pretrained=True).eval().cuda() input_tensor = torch.randn(1, 3, 224, 224).cuda() - compiled = torchtrt.compile( + compiled = _compile( model, - ir="dynamo", - inputs=[torchtrt.Input(input_tensor.shape, dtype=torch.float32)], - use_python_runtime=True, - min_block_size=1, + [torchtrt.Input(input_tensor.shape, dtype=torch.float32)], + use_python_runtime=use_python_runtime, runtime_cache_path=self.cache_path, ) @@ -56,7 +78,6 @@ def test_resnet18_with_runtime_cache(self): f"ResNet18 cosine similarity {cos_sim} below threshold {COSINE_THRESHOLD}", ) - # Verify runtime cache is saved on cleanup del compiled gc.collect() self.assertTrue( @@ -64,8 +85,10 @@ def test_resnet18_with_runtime_cache(self): "Runtime cache should be saved after ResNet18 inference", ) - def test_resnet18_cache_reuse(self): - """Compile + infer twice with same cache path. Second run should load cached data.""" + @parameterized.expand(_RUNTIMES) + def test_resnet18_cache_reuse(self, _name, use_python_runtime): + """Compile + infer twice with same cache path. 
+        _skip_if_cpp_unavailable(self, use_python_runtime)
         import torchvision.models as models

         model = models.resnet18(pretrained=True).eval().cuda()
@@ -73,15 +96,13 @@ def test_resnet18_cache_reuse(self):
         ref_output = model(input_tensor)

         compile_kwargs = {
-            "ir": "dynamo",
             "inputs": [torchtrt.Input(input_tensor.shape, dtype=torch.float32)],
-            "use_python_runtime": True,
-            "min_block_size": 1,
+            "use_python_runtime": use_python_runtime,
             "runtime_cache_path": self.cache_path,
         }

         # First compilation — cold cache
-        compiled1 = torchtrt.compile(model, **compile_kwargs)
+        compiled1 = _compile(model, **compile_kwargs)
         _ = compiled1(input_tensor)
         del compiled1
         gc.collect()
@@ -90,7 +111,7 @@
         cache_size_1 = os.path.getsize(self.cache_path)

         # Second compilation — warm cache
-        compiled2 = torchtrt.compile(model, **compile_kwargs)
+        compiled2 = _compile(model, **compile_kwargs)
         output2 = compiled2(input_tensor)

         cos_sim = cosine_similarity(ref_output, output2)
@@ -102,22 +123,21 @@
         del compiled2
         gc.collect()
         cache_size_2 = os.path.getsize(self.cache_path)
-        # Cache should exist and be non-empty after both runs
         self.assertGreater(cache_size_1, 0)
         self.assertGreater(cache_size_2, 0)

-    def test_mobilenet_v2_with_runtime_cache(self):
+    @parameterized.expand(_RUNTIMES)
+    def test_mobilenet_v2_with_runtime_cache(self, _name, use_python_runtime):
+        _skip_if_cpp_unavailable(self, use_python_runtime)
         import torchvision.models as models

         model = models.mobilenet_v2(pretrained=True).eval().cuda()
         input_tensor = torch.randn(1, 3, 224, 224).cuda()

-        compiled = torchtrt.compile(
+        compiled = _compile(
             model,
-            ir="dynamo",
-            inputs=[torchtrt.Input(input_tensor.shape, dtype=torch.float32)],
-            use_python_runtime=True,
-            min_block_size=1,
+            [torchtrt.Input(input_tensor.shape, dtype=torch.float32)],
+            use_python_runtime=use_python_runtime,
             runtime_cache_path=self.cache_path,
         )

@@ -140,7 +160,7 @@
     "Runtime cache is only available with TensorRT-RTX",
 )
 class TestRuntimeCacheDynamicShapes(TestCase):
-    """Tests runtime cache with dynamic input shapes."""
+    """Tests runtime cache with dynamic input shapes, exercised on both runtimes."""

     def setUp(self):
         self.cache_dir = tempfile.mkdtemp()
@@ -150,7 +170,10 @@ def tearDown(self):
         shutil.rmtree(self.cache_dir, ignore_errors=True)
         torch._dynamo.reset()

-    def test_dynamic_batch_with_cache(self):
+    @parameterized.expand(_RUNTIMES)
+    def test_dynamic_batch_with_cache(self, _name, use_python_runtime):
+        _skip_if_cpp_unavailable(self, use_python_runtime)
+
         class ConvModel(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -162,10 +185,9 @@ def forward(self, x):

         model = ConvModel().eval().cuda()

-        compiled = torchtrt.compile(
+        compiled = _compile(
             model,
-            ir="dynamo",
-            inputs=[
+            [
                 torchtrt.Input(
                     min_shape=(1, 3, 32, 32),
                     opt_shape=(4, 3, 32, 32),
@@ -173,38 +195,28 @@ def forward(self, x):
                     dtype=torch.float32,
                 )
             ],
-            use_python_runtime=True,
-            min_block_size=1,
+            use_python_runtime=use_python_runtime,
             runtime_cache_path=self.cache_path,
         )

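+        # Every batch size probed below lies inside the single dynamic profile
+        # compiled above, so one engine serves the whole loop without
+        # recompilation.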
-        # Test with batch size 1
-        input_bs1 = torch.randn(1, 3, 32, 32).cuda()
-        ref_bs1 = model(input_bs1)
-        out_bs1 = compiled(input_bs1)
-        cos_sim_1 = cosine_similarity(ref_bs1, out_bs1)
-        self.assertTrue(
-            cos_sim_1 > COSINE_THRESHOLD,
-            f"BS=1 cosine similarity {cos_sim_1} below threshold",
-        )
-
-        # Test with batch size 4
-        input_bs4 = torch.randn(4, 3, 32, 32).cuda()
-        ref_bs4 = model(input_bs4)
-        out_bs4 = compiled(input_bs4)
-        cos_sim_4 = cosine_similarity(ref_bs4, out_bs4)
-        self.assertTrue(
-            cos_sim_4 > COSINE_THRESHOLD,
-            f"BS=4 cosine similarity {cos_sim_4} below threshold",
-        )
+        for batch_size in (1, 4):
+            input_tensor = torch.randn(batch_size, 3, 32, 32).cuda()
+            ref_output = model(input_tensor)
+            out = compiled(input_tensor)
+            cos_sim = cosine_similarity(ref_output, out)
+            self.assertTrue(
+                cos_sim > COSINE_THRESHOLD,
+                f"BS={batch_size} cosine similarity {cos_sim} below threshold",
+            )

-        # Verify cache is saved
         del compiled
         gc.collect()
         self.assertTrue(os.path.isfile(self.cache_path))

-    def test_cache_valid_across_shapes(self):
+    @parameterized.expand(_RUNTIMES)
+    def test_cache_valid_across_shapes(self, _name, use_python_runtime):
         """Save cache from one shape, load and verify it works with another shape in range."""
+        _skip_if_cpp_unavailable(self, use_python_runtime)

         class SimpleConv(torch.nn.Module):
             def __init__(self):
@@ -217,7 +229,6 @@ def forward(self, x):

         model = SimpleConv().eval().cuda()
         compile_kwargs = {
-            "ir": "dynamo",
             "inputs": [
                 torchtrt.Input(
                     min_shape=(1, 3, 16, 16),
@@ -226,13 +237,12 @@ def forward(self, x):
                     dtype=torch.float32,
                 )
             ],
-            "use_python_runtime": True,
-            "min_block_size": 1,
+            "use_python_runtime": use_python_runtime,
             "runtime_cache_path": self.cache_path,
         }

         # First run with batch=2 — saves cache
-        compiled1 = torchtrt.compile(model, **compile_kwargs)
+        compiled1 = _compile(model, **compile_kwargs)
         input_bs2 = torch.randn(2, 3, 16, 16).cuda()
         _ = compiled1(input_bs2)
         del compiled1
@@ -241,7 +251,7 @@
         self.assertTrue(os.path.isfile(self.cache_path))

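+        # batch=3 was never executed during the first session, so correctness
+        # here shows the cached data is valid across the profile range rather
+        # than being shape-exact.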
         # Second run with batch=3 — loads same cache
-        compiled2 = torchtrt.compile(model, **compile_kwargs)
+        compiled2 = _compile(model, **compile_kwargs)
         input_bs3 = torch.randn(3, 3, 16, 16).cuda()
         ref_bs3 = model(input_bs3)
         out_bs3 = compiled2(input_bs3)
@@ -268,8 +278,10 @@ def tearDown(self):
         shutil.rmtree(self.cache_dir, ignore_errors=True)
         torch._dynamo.reset()

-    def test_warmup_timing(self):
-        """Measure cold vs warm cache inference time. Informational only — no strict pass/fail."""
+    @parameterized.expand(_RUNTIMES)
+    def test_warmup_timing(self, _name, use_python_runtime):
+        """Measure cold vs warm cache inference time. Informational — no strict assertion."""
+        _skip_if_cpp_unavailable(self, use_python_runtime)

         class MLP(torch.nn.Module):
             def __init__(self):
@@ -285,15 +297,12 @@ def forward(self, x):

         input_tensor = torch.randn(16, 256).cuda()

         compile_kwargs = {
-            "ir": "dynamo",
             "inputs": [torchtrt.Input(input_tensor.shape, dtype=torch.float32)],
-            "use_python_runtime": True,
-            "min_block_size": 1,
+            "use_python_runtime": use_python_runtime,
             "runtime_cache_path": self.cache_path,
         }

-        # Cold cache compilation + inference
-        compiled1 = torchtrt.compile(model, **compile_kwargs)
+        compiled1 = _compile(model, **compile_kwargs)
         torch.cuda.synchronize()
         start = time.perf_counter()
         _ = compiled1(input_tensor)
@@ -303,19 +312,16 @@
         gc.collect()
         torch._dynamo.reset()

-        # Warm cache compilation + inference
-        compiled2 = torchtrt.compile(model, **compile_kwargs)
+        compiled2 = _compile(model, **compile_kwargs)
         torch.cuda.synchronize()
         start = time.perf_counter()
         _ = compiled2(input_tensor)
         torch.cuda.synchronize()
         warm_time = time.perf_counter() - start

-        print(f"\n  Cold cache first inference: {cold_time*1000:.1f}ms")
-        print(f"  Warm cache first inference: {warm_time*1000:.1f}ms")
-        print(f"  Speedup: {cold_time/warm_time:.2f}x")
-
-        # No strict assertion — just log for visibility
+        print(f"\n  [{_name}] Cold cache first inference: {cold_time*1000:.1f}ms")
+        print(f"  [{_name}] Warm cache first inference: {warm_time*1000:.1f}ms")
+        print(f"  [{_name}] Speedup: {cold_time/warm_time:.2f}x")
         self.assertTrue(True, "Timing test completed (informational)")


 if __name__ == "__main__":
diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache.py b/tests/py/dynamo/runtime/test_000_runtime_cache.py
index 8db2b9412c..784a852a36 100644
--- a/tests/py/dynamo/runtime/test_000_runtime_cache.py
+++ b/tests/py/dynamo/runtime/test_000_runtime_cache.py
@@ -1,5 +1,4 @@
 import gc
-import logging
 import os
 import shutil
 import tempfile
@@ -7,10 +6,11 @@
 import torch
 import torch_tensorrt as torchtrt
+from parameterized import parameterized
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt.dynamo._defaults import RUNTIME_CACHE_PATH, TIMING_CACHE_PATH
-from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity


 class SimpleModel(torch.nn.Module):
@@ -18,30 +18,49 @@ def forward(self, x):
         return torch.relu(x) + 1.0


-class TwoLayerModel(torch.nn.Module):
+class ConvModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.linear = torch.nn.Linear(8, 8)
+        self.conv = torch.nn.Conv2d(3, 8, 3, padding=1)

     def forward(self, x):
-        return torch.relu(self.linear(x))
+        return torch.relu(self.conv(x))


-def _compile_simple(runtime_cache_path=None):
-    """Helper: compile SimpleModel with Python runtime, return (compiled_module, inputs)."""
-    model = SimpleModel().eval().cuda()
-    inputs = [torch.randn(2, 3).cuda()]
+def _fresh_conv_model_and_inputs(seed=0):
+    """Deterministic ConvModel + input pair for end-to-end cache tests on either runtime."""
+    torch.manual_seed(seed)
+    return ConvModel().eval().cuda(), [torch.randn(2, 3, 16, 16).cuda()]
+
+
+def _compile(model, inputs, *, use_python_runtime, runtime_cache_path=None):
+    """Compile `model` through either runtime. Returns the compiled module."""
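+    # min_block_size=1 below forces even single-op graphs into TRT, so every
+    # test in this file actually exercises the runtime cache path.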
     kwargs = {
         "ir": "dynamo",
         "inputs": inputs,
-        "use_python_runtime": True,
+        "use_python_runtime": use_python_runtime,
         "min_block_size": 1,
     }
     if runtime_cache_path is not None:
         kwargs["runtime_cache_path"] = runtime_cache_path
     compiled = torchtrt.compile(model, **kwargs)
     torch._dynamo.reset()
-    return compiled, inputs
+    return compiled
+
+
+def _compile_simple(runtime_cache_path=None):
+    """Compile the SimpleModel on the Python runtime (used by Python-only setup tests)."""
+    model = SimpleModel().eval().cuda()
+    inputs = [torch.randn(2, 3).cuda()]
+    return (
+        _compile(
+            model,
+            inputs,
+            use_python_runtime=True,
+            runtime_cache_path=runtime_cache_path,
+        ),
+        inputs,
+    )


 def _find_python_trt_module(compiled):
@@ -50,18 +69,23 @@
         PythonTorchTensorRTModule,
     )

-    for name, mod in compiled.named_modules():
+    for _name, mod in compiled.named_modules():
         if isinstance(mod, PythonTorchTensorRTModule):
             return mod
     return None


+# Parameterize end-to-end cache persistence tests over both runtime paths. The C++
+# variant is skipped inside the test body when the C++ runtime is not available.
+_RUNTIMES = [("python", True), ("cpp", False)]
+
+
 @unittest.skipIf(
     not ENABLED_FEATURES.tensorrt_rtx,
     "Runtime cache is only available with TensorRT-RTX",
 )
 class TestRuntimeCacheSetup(TestCase):
-    """Tests that runtime config and cache are correctly created for RTX."""
+    """Python-runtime-only setup checks: the compiled module exposes a live runtime cache."""

     def test_runtime_config_created(self):
         compiled, _ = _compile_simple()
@@ -76,7 +100,6 @@ def test_context_created_successfully(self):
         compiled, inputs = _compile_simple()
         mod = _find_python_trt_module(compiled)
         self.assertIsNotNone(mod.context, "execution context should be created")
-        # Verify inference works
         output = compiled(*[inp.clone() for inp in inputs])
         self.assertEqual(output.shape, inputs[0].shape)

@@ -101,7 +124,7 @@ def test_runtime_cache_path_custom(self):
     "Runtime cache is only available with TensorRT-RTX",
 )
 class TestRuntimeCachePersistence(TestCase):
-    """Tests that runtime cache is correctly saved to and loaded from disk."""
+    """Load-on-setup / save-on-destructor contract, exercised on both runtimes."""

     def setUp(self):
         self.cache_dir = tempfile.mkdtemp()
@@ -110,9 +133,20 @@ def setUp(self):
     def tearDown(self):
         shutil.rmtree(self.cache_dir, ignore_errors=True)

-    def test_cache_saved_on_del(self):
-        compiled, inputs = _compile_simple(runtime_cache_path=self.cache_path)
-        # Run inference to populate the cache
+    def _skip_if_cpp_unavailable(self, use_python_runtime):
+        if not use_python_runtime and not ENABLED_FEATURES.torch_tensorrt_runtime:
+            self.skipTest("C++ runtime is not available")
+
+    @parameterized.expand(_RUNTIMES)
+    def test_cache_saved_on_del(self, _name, use_python_runtime):
+        self._skip_if_cpp_unavailable(use_python_runtime)
+        model, inputs = _fresh_conv_model_and_inputs()
+        compiled = _compile(
+            model,
+            inputs,
+            use_python_runtime=use_python_runtime,
+            runtime_cache_path=self.cache_path,
+        )
         _ = compiled(*[inp.clone() for inp in inputs])
         self.assertFalse(
             os.path.isfile(self.cache_path),
@@ -125,8 +159,16 @@ def test_cache_saved_on_del(self):
             "Cache file should be created after module cleanup",
         )

-    def test_cache_file_nonempty(self):
-        compiled, inputs = _compile_simple(runtime_cache_path=self.cache_path)
+    @parameterized.expand(_RUNTIMES)
+    def test_cache_file_nonempty(self, _name, use_python_runtime):
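+        # An empty file would mean the destructor serialized nothing; this
+        # guards against silent truncation on either runtime.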
+        self._skip_if_cpp_unavailable(use_python_runtime)
+        model, inputs = _fresh_conv_model_and_inputs()
+        compiled = _compile(
+            model,
+            inputs,
+            use_python_runtime=use_python_runtime,
+            runtime_cache_path=self.cache_path,
+        )
         _ = compiled(*[inp.clone() for inp in inputs])
         del compiled
         gc.collect()
@@ -136,30 +178,54 @@
             "Cache file should have nonzero size",
         )

-    def test_cache_roundtrip(self):
-        """Compile, infer, save. Then compile again with same cache path and verify correctness."""
-        model = SimpleModel().eval().cuda()
-        inputs = [torch.randn(2, 3).cuda()]
-        ref_output = model(*inputs)
-
-        # First compilation — populates and saves cache
-        compiled1, _ = _compile_simple(runtime_cache_path=self.cache_path)
-        _ = compiled1(*[inp.clone() for inp in inputs])
+    @parameterized.expand(_RUNTIMES)
+    def test_cache_roundtrip(self, _name, use_python_runtime):
+        """Populate + save, then recompile and confirm correctness against eager output."""
+        self._skip_if_cpp_unavailable(use_python_runtime)
+        model, inputs = _fresh_conv_model_and_inputs()
+        with torch.no_grad():
+            ref_output = model(*inputs)
+
+        compiled1 = _compile(
+            model,
+            inputs,
+            use_python_runtime=use_python_runtime,
+            runtime_cache_path=self.cache_path,
+        )
+        out1 = compiled1(*[inp.clone() for inp in inputs])
+        self.assertGreater(
+            cosine_similarity(ref_output, out1),
+            COSINE_THRESHOLD,
+            "First compiled output should match eager",
+        )
         del compiled1
         gc.collect()
         self.assertTrue(os.path.isfile(self.cache_path))

-        # Second compilation — should load cached data
-        compiled2, _ = _compile_simple(runtime_cache_path=self.cache_path)
-        output = compiled2(*[inp.clone() for inp in inputs])
-        max_diff = float(torch.max(torch.abs(ref_output - output)))
-        self.assertAlmostEqual(
-            max_diff, 0, places=3, msg="Output mismatch after cache roundtrip"
+        compiled2 = _compile(
+            model,
+            inputs,
+            use_python_runtime=use_python_runtime,
+            runtime_cache_path=self.cache_path,
+        )
+        out2 = compiled2(*[inp.clone() for inp in inputs])
+        self.assertGreater(
+            cosine_similarity(ref_output, out2),
+            COSINE_THRESHOLD,
+            "Second compiled output (warm cache) should still match eager",
         )

-    def test_save_creates_directory(self):
+    @parameterized.expand(_RUNTIMES)
+    def test_save_creates_directory(self, _name, use_python_runtime):
+        self._skip_if_cpp_unavailable(use_python_runtime)
         nested_path = os.path.join(self.cache_dir, "a", "b", "c", "runtime_cache.bin")
-        compiled, inputs = _compile_simple(runtime_cache_path=nested_path)
+        model, inputs = _fresh_conv_model_and_inputs()
+        compiled = _compile(
+            model,
+            inputs,
+            use_python_runtime=use_python_runtime,
+            runtime_cache_path=nested_path,
+        )
         _ = compiled(*[inp.clone() for inp in inputs])
         del compiled
         gc.collect()
@@ -174,7 +240,7 @@
     "Runtime cache is only available with TensorRT-RTX",
 )
 class TestRuntimeCacheConcurrency(TestCase):
-    """Tests that file locking works for concurrent access."""
+    """Tests that file locking works for concurrent access (Python runtime only)."""

     def setUp(self):
         self.cache_dir = tempfile.mkdtemp()
@@ -190,7 +256,6 @@ def test_filelock_works(self):
         del compiled
         gc.collect()
         self.assertTrue(os.path.isfile(self.cache_path))
-        # Verify we can acquire a lock on the same path (no deadlock)
         from filelock import FileLock

         lock = FileLock(self.cache_path + ".lock")
@@ -200,14 +265,12 @@

     def test_sequential_save_load(self):
         """Two modules saving and loading from the same path should not corrupt data."""
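+        # Saves are expected to be serialized via the <cache_path>.lock file
+        # checked in test_filelock_works above; the second writer simply
+        # overwrites the first snapshot.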
data.""" - # First module saves compiled1, inputs = _compile_simple(runtime_cache_path=self.cache_path) _ = compiled1(*[inp.clone() for inp in inputs]) del compiled1 gc.collect() size1 = os.path.getsize(self.cache_path) - # Second module saves (overwrites) compiled2, inputs = _compile_simple(runtime_cache_path=self.cache_path) _ = compiled2(*[inp.clone() for inp in inputs]) del compiled2 @@ -226,7 +289,6 @@ class TestTimingCacheSkipped(TestCase): """Tests that timing cache is correctly skipped for RTX builds.""" def setUp(self): - # Clean up any pre-existing timing cache if os.path.isfile(TIMING_CACHE_PATH): os.remove(TIMING_CACHE_PATH) @@ -271,7 +333,6 @@ def test_no_runtime_config_for_standard_trt(self): ) def test_timing_cache_still_created(self): - # Clean up any pre-existing timing cache if os.path.isfile(TIMING_CACHE_PATH): os.remove(TIMING_CACHE_PATH) compiled, inputs = _compile_simple() @@ -282,5 +343,26 @@ def test_timing_cache_still_created(self): ) +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "New serialization indices are registered only on TensorRT-RTX builds", +) +class TestCppSerializationIndices(TestCase): + """Verify the new RTX-only C++ serialization indices are registered by the runtime.""" + + def test_new_indices_registered(self): + self.assertEqual(int(torch.ops.tensorrt.ABI_VERSION()), 9) + self.assertEqual(int(torch.ops.tensorrt.SERIALIZATION_LEN()), 15) + self.assertEqual(int(torch.ops.tensorrt.RUNTIME_CACHE_PATH_IDX()), 12) + self.assertEqual( + int(torch.ops.tensorrt.DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX()), 13 + ) + self.assertEqual(int(torch.ops.tensorrt.CUDA_GRAPH_STRATEGY_IDX()), 14) + + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py b/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py new file mode 100644 index 0000000000..f3c5c32ba0 --- /dev/null +++ b/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py @@ -0,0 +1,116 @@ +import unittest + +import torch +import torch_tensorrt as torchtrt +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt.dynamo._defaults import CUDA_GRAPH_STRATEGY +from torch_tensorrt.dynamo._settings import CompilationSettings + + +class CudaGraphModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 8, 3, padding=1) + + def forward(self, x): + return torch.relu(self.conv(x)) + + +def _compile_cpp(strategy): + model = CudaGraphModel().eval().cuda() + inputs = [torch.randn(2, 3, 16, 16).cuda()] + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=False, + min_block_size=1, + cuda_graph_strategy=strategy, + ) + torch._dynamo.reset() + return compiled, inputs + + +class TestCudaGraphStrategySettings(TestCase): + """Setting-level validation that runs on every build (RTX and non-RTX).""" + + def test_default_value(self): + settings = CompilationSettings() + self.assertEqual(settings.cuda_graph_strategy, CUDA_GRAPH_STRATEGY) + + def test_settable_values(self): + for value in ("disabled", "whole_graph_capture"): + settings = CompilationSettings(cuda_graph_strategy=value) + self.assertEqual(settings.cuda_graph_strategy, value) + + +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +@unittest.skipIf( + 
+    def test_new_indices_registered(self):
+        self.assertEqual(int(torch.ops.tensorrt.ABI_VERSION()), 9)
+        self.assertEqual(int(torch.ops.tensorrt.SERIALIZATION_LEN()), 15)
+        self.assertEqual(int(torch.ops.tensorrt.RUNTIME_CACHE_PATH_IDX()), 12)
+        self.assertEqual(
+            int(torch.ops.tensorrt.DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX()), 13
+        )
+        self.assertEqual(int(torch.ops.tensorrt.CUDA_GRAPH_STRATEGY_IDX()), 14)
+
+
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py b/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py
new file mode 100644
index 0000000000..f3c5c32ba0
--- /dev/null
+++ b/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py
@@ -0,0 +1,116 @@
+import unittest
+
+import torch
+import torch_tensorrt as torchtrt
+from torch.testing._internal.common_utils import TestCase, run_tests
+from torch_tensorrt._features import ENABLED_FEATURES
+from torch_tensorrt.dynamo._defaults import CUDA_GRAPH_STRATEGY
+from torch_tensorrt.dynamo._settings import CompilationSettings
+
+
+class CudaGraphModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(3, 8, 3, padding=1)
+
+    def forward(self, x):
+        return torch.relu(self.conv(x))
+
+
+def _compile_cpp(strategy):
+    model = CudaGraphModel().eval().cuda()
+    inputs = [torch.randn(2, 3, 16, 16).cuda()]
+    compiled = torchtrt.compile(
+        model,
+        ir="dynamo",
+        inputs=inputs,
+        enabled_precisions={torch.float32},
+        use_python_runtime=False,
+        min_block_size=1,
+        cuda_graph_strategy=strategy,
+    )
+    torch._dynamo.reset()
+    return compiled, inputs
+
+
+class TestCudaGraphStrategySettings(TestCase):
+    """Setting-level validation that runs on every build (RTX and non-RTX)."""
+
+    def test_default_value(self):
+        settings = CompilationSettings()
+        self.assertEqual(settings.cuda_graph_strategy, CUDA_GRAPH_STRATEGY)
+
+    def test_settable_values(self):
+        for value in ("disabled", "whole_graph_capture"):
+            settings = CompilationSettings(cuda_graph_strategy=value)
+            self.assertEqual(settings.cuda_graph_strategy, value)
+
+
+@unittest.skipIf(
+    not ENABLED_FEATURES.torch_tensorrt_runtime,
+    "C++ runtime is not available",
+)
+@unittest.skipIf(
+    not ENABLED_FEATURES.tensorrt_rtx,
+    "CUDA graph strategy is a TensorRT-RTX feature",
+)
+class TestCudaGraphStrategyCpp(TestCase):
+    """End-to-end: compile + infer through the C++ runtime with each strategy."""
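+
+    # Strategy semantics as exercised here (a reading of the test behavior,
+    # not a spec): "disabled" issues plain per-call enqueues, while
+    # "whole_graph_capture" lets the RTX runtime capture one full engine
+    # execution into a CUDA graph and replay it on later same-shape launches
+    # (see test_repeated_inference).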
+
+    def tearDown(self):
+        torchtrt.runtime.set_cudagraphs_mode(False)
+
+    def test_disabled(self):
+        compiled, inputs = _compile_cpp("disabled")
+        y = compiled(*[inp.clone() for inp in inputs])
+        self.assertEqual(tuple(y.shape), (2, 8, 16, 16))
+        self.assertTrue(torch.isfinite(y).all().item())
+
+    def test_whole_graph_capture(self):
+        compiled, inputs = _compile_cpp("whole_graph_capture")
+        y = compiled(*[inp.clone() for inp in inputs])
+        self.assertEqual(tuple(y.shape), (2, 8, 16, 16))
+        self.assertTrue(torch.isfinite(y).all().item())
+
+    def test_whole_graph_capture_with_subgraph_cudagraphs(self):
+        """Subgraph cudagraph mode + RTX strategy: RTX-native should take over without errors."""
+        compiled, inputs = _compile_cpp("whole_graph_capture")
+        torchtrt.runtime.set_cudagraphs_mode(True)
+        y = compiled(*[inp.clone() for inp in inputs])
+        self.assertEqual(tuple(y.shape), (2, 8, 16, 16))
+        self.assertTrue(torch.isfinite(y).all().item())
+
+    def test_repeated_inference(self):
+        """Repeated inference exercises the RTX-native capture/replay path."""
+        compiled, inputs = _compile_cpp("whole_graph_capture")
+        ref = compiled(*[inp.clone() for inp in inputs])
+        for _ in range(4):
+            out = compiled(*[inp.clone() for inp in inputs])
+            self.assertEqual(out.shape, ref.shape)
+            self.assertTrue(torch.isfinite(out).all().item())
+
+
+@unittest.skipIf(
+    not ENABLED_FEATURES.torch_tensorrt_runtime,
+    "C++ runtime is not available",
+)
+class TestCudaGraphStrategyInvalidValue(TestCase):
+    """Invalid strategy names are rejected (ValueError or RuntimeError)."""
+
+    def test_invalid_strategy_raises(self):
+        model = CudaGraphModel().eval().cuda()
+        inputs = [torch.randn(2, 3, 16, 16).cuda()]
+        with self.assertRaises((ValueError, RuntimeError)):
+            torchtrt.compile(
+                model,
+                ir="dynamo",
+                inputs=inputs,
+                enabled_precisions={torch.float32},
+                use_python_runtime=False,
+                min_block_size=1,
+                cuda_graph_strategy="not_a_real_strategy",
+            )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py
index ef16284438..66de1f9512 100644
--- a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py
+++ b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py
@@ -2,16 +2,29 @@

 import torch
 import torch_tensorrt as torchtrt
+from parameterized import parameterized
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt.dynamo._settings import CompilationSettings

+_STRATEGIES = [("lazy",), ("eager",), ("none",)]
+

 class SimpleModel(torch.nn.Module):
     def forward(self, x):
         return torch.relu(x) + 1.0


+class DynamicConvModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = torch.nn.Conv2d(3, 16, 3, padding=1)
+        self.conv2 = torch.nn.Conv2d(16, 8, 3, padding=1)
+
+    def forward(self, x):
+        return torch.relu(self.conv2(torch.relu(self.conv1(x))))
+
+
 def _compile_simple(**extra_kwargs):
     """Helper: compile SimpleModel with dynamic shapes and Python runtime."""
     model = SimpleModel().eval().cuda()
@@ -35,13 +48,34 @@
     return compiled


+def _compile_cpp(strategy):
+    model = DynamicConvModel().eval().cuda()
+    inp = torchtrt.Input(
+        min_shape=(1, 3, 16, 16),
+        opt_shape=(2, 3, 16, 16),
+        max_shape=(4, 3, 16, 16),
+        dtype=torch.float32,
+    )
+    compiled = torchtrt.compile(
+        model,
+        ir="dynamo",
+        inputs=[inp],
+        enabled_precisions={torch.float32},
+        use_python_runtime=False,
+        min_block_size=1,
+        dynamic_shapes_kernel_specialization_strategy=strategy,
+    )
+    torch._dynamo.reset()
+    return compiled
+
+
 def _find_python_trt_module(compiled):
     """Walk the compiled graph module to find PythonTorchTensorRTModule instances."""
     from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import (
         PythonTorchTensorRTModule,
     )

-    for name, mod in compiled.named_modules():
+    for _name, mod in compiled.named_modules():
         if isinstance(mod, PythonTorchTensorRTModule):
             return mod
     return None
@@ -54,6 +88,12 @@
 class TestDynamicShapesKernelStrategySetup(TestCase):
     """Tests that the dynamic shapes kernel specialization strategy is correctly applied."""

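+    # Maps the user-facing strings to the TensorRT enum members the runtime
+    # config should carry. Reading the enum names (not a documented contract
+    # here): LAZY specializes kernels for a concrete shape on first use,
+    # EAGER ahead of time, NONE keeps the generic kernels.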
_compile_cpp("eager") + for batch in (1, 2, 3, 4): + x = torch.randn(batch, 3, 16, 16, device="cuda") + y = compiled(x) + self.assertEqual(tuple(y.shape), (batch, 8, 16, 16)) + + +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +class TestDynamicShapesKernelStrategyCppInvalidValue(TestCase): + """Invalid strategy names raise ValueError on the C++ runtime path.""" + + def test_invalid_strategy_raises(self): + model = DynamicConvModel().eval().cuda() + inp = torchtrt.Input( + min_shape=(1, 3, 16, 16), + opt_shape=(2, 3, 16, 16), + max_shape=(4, 3, 16, 16), + dtype=torch.float32, + ) + with self.assertRaises((ValueError, RuntimeError)): + torchtrt.compile( + model, + ir="dynamo", + inputs=[inp], + enabled_precisions={torch.float32}, + use_python_runtime=False, + min_block_size=1, + dynamic_shapes_kernel_specialization_strategy="not_a_real_strategy", + ) + + if __name__ == "__main__": run_tests()