From 8c8f38aa9c885893d665a727a71fbb5a502f1dda Mon Sep 17 00:00:00 2001
From: Yanlong Wang
Date: Sat, 4 Jan 2025 11:00:09 +0800
Subject: [PATCH] [js/node] allow arenaExtendStrategy and gpuMemLimit for cuda

---
 js/common/lib/inference-session.ts    |  9 +++++++++
 js/node/src/session_options_helper.cc | 16 ++++++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts
index e62c6579e8333..7a0f910b0456c 100644
--- a/js/common/lib/inference-session.ts
+++ b/js/common/lib/inference-session.ts
@@ -223,6 +223,15 @@ export declare namespace InferenceSession {
   export interface CudaExecutionProviderOption extends ExecutionProviderOption {
     readonly name: 'cuda';
     deviceId?: number;
+    gpuMemLimit?: number;
+
+    /**
+     * Arena extend strategy. See
+     * https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/arena_extend_strategy.h
+     *
+     * This setting is available only in ONNXRuntime (Node.js binding)
+     */
+    arenaExtendStrategy?: 0 | 1;
   }
   export interface DmlExecutionProviderOption extends ExecutionProviderOption {
     readonly name: 'dml';
diff --git a/js/node/src/session_options_helper.cc b/js/node/src/session_options_helper.cc
index 8c1d7ca06b8c3..36b357f1cfeeb 100644
--- a/js/node/src/session_options_helper.cc
+++ b/js/node/src/session_options_helper.cc
@@ -41,6 +41,10 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
     Napi::Value epValue = epList[i];
     std::string name;
     int deviceId = 0;
+#ifdef USE_CUDA
+    onnxruntime::ArenaExtendStrategy arenaExtendStrategy = onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo;
+    size_t gpuMemLimit = std::numeric_limits<size_t>::max();
+#endif
 #ifdef USE_COREML
     int coreMlFlags = 0;
 #endif
@@ -59,6 +63,16 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
       if (obj.Has("deviceId")) {
         deviceId = obj.Get("deviceId").As<Napi::Number>();
       }
+#ifdef USE_CUDA
+      if (obj.Has("arenaExtendStrategy")) {
+        arenaExtendStrategy = static_cast<onnxruntime::ArenaExtendStrategy>(
+            obj.Get("arenaExtendStrategy").As<Napi::Number>().Uint32Value());
+      }
+      if (obj.Has("gpuMemLimit")) {
+        gpuMemLimit = static_cast<size_t>(
+            obj.Get("gpuMemLimit").As<Napi::Number>().DoubleValue());
+      }
+#endif
 #ifdef USE_COREML
       if (obj.Has("coreMlFlags")) {
         coreMlFlags = obj.Get("coreMlFlags").As<Napi::Number>();
@@ -86,6 +100,8 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
       OrtCUDAProviderOptionsV2* options;
       Ort::GetApi().CreateCUDAProviderOptions(&options);
       options->device_id = deviceId;
+      options->arena_extend_strategy = arenaExtendStrategy;
+      options->gpu_mem_limit = gpuMemLimit;
       sessionOptions.AppendExecutionProvider_CUDA_V2(*options);
       Ort::GetApi().ReleaseCUDAProviderOptions(options);
 #endif