Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebGPU] Support PIX Capture for WebGPU EP #23192

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ option(onnxruntime_USE_WEBGPU "Build with WebGPU support. Enable WebGPU via C/C+
option(onnxruntime_USE_EXTERNAL_DAWN "Build with treating Dawn as external dependency. Will not link Dawn at build time." OFF)
option(onnxruntime_CUSTOM_DAWN_SRC_PATH "Path to custom Dawn src dir.")
option(onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY "Build Dawn as a monolithic library" OFF)
option(onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP "Adding frame present for PIX to capture a frame" OFF)
# The following 2 options are only for Windows
option(onnxruntime_ENABLE_DAWN_BACKEND_VULKAN "Enable Vulkan backend for Dawn (on Windows)" OFF)
option(onnxruntime_ENABLE_DAWN_BACKEND_D3D12 "Enable D3D12 backend for Dawn (on Windows)" ON)
Expand Down Expand Up @@ -970,6 +971,14 @@ if (onnxruntime_USE_WEBGPU)
if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12)
list(APPEND ORT_PROVIDER_FLAGS -DDAWN_ENABLE_D3D12=1)
endif()
if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP)
shaoboyan091 marked this conversation as resolved.
Show resolved Hide resolved
if (NOT onnxruntime_ENABLE_DAWN_BACKEND_D3D12 OR NOT WIN32)
message(
FATAL_ERROR
"Option onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP can only be set on windows with onnxruntime_ENABLE_DAWN_BACKEND_D3D12 is enabled.")
endif()
add_compile_definitions(ENABLE_PIX_FOR_WEBGPU_EP)
endif()
endif()
if (onnxruntime_USE_CANN)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_CANN=1)
Expand Down
29 changes: 22 additions & 7 deletions cmake/external/onnxruntime_external_deps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -653,17 +653,28 @@ if (onnxruntime_USE_WEBGPU)

# disable things we don't use
set(DAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG OFF)
set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE)
set(DAWN_ENABLE_OPENGLES OFF CACHE BOOL "" FORCE)
set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING OFF CACHE BOOL "" FORCE)
set(DAWN_USE_GLFW OFF CACHE BOOL "" FORCE)
set(DAWN_USE_WINDOWS_UI OFF CACHE BOOL "" FORCE)
set(DAWN_USE_X11 OFF CACHE BOOL "" FORCE)

if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP)
set(DAWN_ENABLE_DESKTOP_GL ON CACHE BOOL "" FORCE)
set(DAWN_ENABLE_OPENGLES ON CACHE BOOL "" FORCE)
set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING ON CACHE BOOL "" FORCE)
set(DAWN_USE_GLFW ON CACHE BOOL "" FORCE)
set(DAWN_USE_WINDOWS_UI ON CACHE BOOL "" FORCE)
set(TINT_BUILD_GLSL_WRITER ON CACHE BOOL "" FORCE)
set(TINT_BUILD_GLSL_VALIDATOR ON CACHE BOOL "" FORCE)
else()
set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE)
set(DAWN_ENABLE_OPENGLES OFF CACHE BOOL "" FORCE)
set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING OFF CACHE BOOL "" FORCE)
set(DAWN_USE_GLFW OFF CACHE BOOL "" FORCE)
set(DAWN_USE_WINDOWS_UI OFF CACHE BOOL "" FORCE)
set(TINT_BUILD_GLSL_WRITER OFF CACHE BOOL "" FORCE)
set(TINT_BUILD_GLSL_VALIDATOR OFF CACHE BOOL "" FORCE)
endif()

set(TINT_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(TINT_BUILD_CMD_TOOLS OFF CACHE BOOL "" FORCE)
set(TINT_BUILD_GLSL_WRITER OFF CACHE BOOL "" FORCE)
set(TINT_BUILD_GLSL_VALIDATOR OFF CACHE BOOL "" FORCE)
set(TINT_BUILD_IR_BINARY OFF CACHE BOOL "" FORCE)
set(TINT_BUILD_SPV_READER OFF CACHE BOOL "" FORCE) # don't need. disabling is a large binary size saving
set(TINT_BUILD_WGSL_WRITER ON CACHE BOOL "" FORCE) # needed to create cache key. runtime error if not enabled.
Expand Down Expand Up @@ -704,6 +715,10 @@ if (onnxruntime_USE_WEBGPU)
endif()
list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_proc)
endif()

if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES glfw webgpu_glfw)
endif()
endif()

set(onnxruntime_LINK_DIRS)
Expand Down
71 changes: 71 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
#include "dawn/native/DawnNative.h"
#endif

#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
shaoboyan091 marked this conversation as resolved.
Show resolved Hide resolved
#include <webgpu/webgpu_glfw.h>
#endif // ENABLE_PIX_FOR_WEBGPU_EP

#include "core/common/common.h"
#include "core/common/path_string.h"
#include "core/platform/env.h"
Expand Down Expand Up @@ -144,6 +148,12 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
query_type_ = TimestampQueryType::None;
}
});

#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
// set pix frame generator
pix_frame_generator_ = std::make_unique<WebGpuPIXFrameGenerator>();
shaoboyan091 marked this conversation as resolved.
Show resolved Hide resolved
pix_frame_generator_->Initialize(this);
#endif // ENABLE_PIX_FOR_WEBGPU_EP
}

Status WebGpuContext::Wait(wgpu::Future f) {
Expand Down Expand Up @@ -641,6 +651,67 @@ void WebGpuContext::Flush() {
num_pending_dispatches_ = 0;
}

#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
void WebGpuPIXFrameGenerator::Initialize(WebGpuContext* context) {
shaoboyan091 marked this conversation as resolved.
Show resolved Hide resolved
// Trivial window size for surface texture creation and provide frame concept for PIX.
static constexpr uint32_t kWidth = 512u;
static constexpr uint32_t kHeight = 512u;

if (!glfwInit()) {
ORT_ENFORCE("Failed to init glfw for PIX capture");
}

glfwWindowHint(GLFW_CLIENT_API, GLFW_NO_API);

window_ =
glfwCreateWindow(kWidth, kHeight, "WebGPU window", nullptr, nullptr);

ORT_ENFORCE(window_ != nullptr, "PIX Capture: Failed to create Window for capturing frames.");

surface_ = wgpu::glfw::CreateSurfaceForWindow(context->Instance(), window_);
ORT_ENFORCE(surface_.Get() != nullptr, "PIX Capture: Failed to create surface for capturing frames.");

wgpu::TextureFormat format;
wgpu::SurfaceCapabilities capabilities;
surface_.GetCapabilities(context->Adapter(), &capabilities);
format = capabilities.formats[0];

wgpu::SurfaceConfiguration config;
config.device = context->Device();
config.format = format;
config.width = kWidth;
config.height = kHeight;

surface_.Configure(&config);
}

void WebGpuPIXFrameGenerator::GeneratePIXFrame() {
ORT_ENFORCE(surface_.Get() != nullptr, "PIX Capture: Cannot do present on null surface for capturing frames");
wgpu::SurfaceTexture surfaceTexture;
surface_.GetCurrentTexture(&surfaceTexture);

// Call present to trigger dxgi_swapchain present. PIX
// take this as a frame boundary.
surface_.Present();
}

WebGpuPIXFrameGenerator::~WebGpuPIXFrameGenerator() {
if (surface_.Get()) {
surface_.Unconfigure();
}

if (window_) {
glfwDestroyWindow(window_);
window_ = nullptr;
}
}

void WebGpuContext::GeneratePIXFrame() {
pix_frame_generator_->GeneratePIXFrame();
}

#endif // ENABLE_PIX_FOR_WEBGPU_EP

std::unordered_map<int32_t, WebGpuContextFactory::WebGpuContextInfo> WebGpuContextFactory::contexts_;
std::mutex WebGpuContextFactory::mutex_;
std::once_flag WebGpuContextFactory::init_default_flag_;
Expand Down
48 changes: 48 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
#include <emscripten/emscripten.h>
#endif

#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
#include <GLFW/glfw3.h>
#endif // ENABLE_PIX_FOR_WEBGPU_EP

#include <memory>
#include <mutex>

Expand Down Expand Up @@ -69,6 +73,39 @@ class WebGpuContextFactory {
static wgpu::Instance default_instance_;
};

#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
// PIX(https://devblogs.microsoft.com/pix/introduction/) is a profiling tool
// provides by Microsoft. It has ability to do GPU capture to profile gpu
// behavior among different GPU vendors. It works on Windows only.
//
// GPU capture(present-to-present) provided by PIX uses present as a frame boundary to
// capture and generate a valid frame infos. But ORT WebGPU EP doesn't have any present logic
// and hangs PIX GPU Capture forever.
//
// To make PIX works with ORT WebGPU EP on Windows, WebGpuPIXFrameGenerator class includes codes
// to create a trivial window through glfw, config surface with Dawn device and call present in
// proper place to trigger frame boundary for PIX GPU Capture.
//
// WebGpuPIXFrameGenerator is an friend class because:
// - It should only be used in WebGpuContext class implementation.
// - It requires instance and device from WebGpuContext.
//
// The lifecycle of WebGpuPIXFrameGenerator instance should be nested into WebGpuContext lifecycle.
// WebGpuPIXFrameGenerator instance should be created during WebGpuContext creation and be destroyed during
// WebGpuContext destruction.
class WebGpuPIXFrameGenerator {
public:
WebGpuPIXFrameGenerator() = default;
~WebGpuPIXFrameGenerator();
shaoboyan091 marked this conversation as resolved.
Show resolved Hide resolved
void Initialize(WebGpuContext* context);
shaoboyan091 marked this conversation as resolved.
Show resolved Hide resolved
void GeneratePIXFrame();

private:
wgpu::Surface surface_;
GLFWwindow* window_;
};
#endif // ENABLE_PIX_FOR_WEBGPU_EP

// Class WebGpuContext includes all necessary resources for the context.
class WebGpuContext final {
public:
Expand Down Expand Up @@ -127,6 +164,10 @@ class WebGpuContext final {

Status Run(ComputeContext& context, const ProgramBase& program);

#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
void GeneratePIXFrame();
#endif // ENABLE_PIX_FOR_WEBGPU_EP

private:
enum class TimestampQueryType {
None = 0,
Expand All @@ -144,6 +185,7 @@ class WebGpuContext final {
std::vector<wgpu::FeatureName> GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) const;
wgpu::RequiredLimits GetRequiredLimits(const wgpu::Adapter& adapter) const;
void WriteTimestamp(uint32_t query_index);
const wgpu::Instance& Instance() const { return instance_; }
shaoboyan091 marked this conversation as resolved.
Show resolved Hide resolved

struct PendingKernelInfo {
PendingKernelInfo(std::string_view kernel_name,
Expand Down Expand Up @@ -211,6 +253,12 @@ class WebGpuContext final {

uint64_t gpu_timestamp_offset_ = 0;
bool is_profiling_ = false;

#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
// Friend class to access WebGpuContext private members.
friend class WebGpuPIXFrameGenerator;
shaoboyan091 marked this conversation as resolved.
Show resolved Hide resolved
std::unique_ptr<WebGpuPIXFrameGenerator> pix_frame_generator_ = nullptr;
#endif // ENABLE_PIX_FOR_WEBGPU_EP
};

} // namespace webgpu
Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,13 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
preferred_data_layout_{config.data_layout},
force_cpu_node_names_{std::move(config.force_cpu_node_names)},
enable_graph_capture_{config.enable_graph_capture} {
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
enable_pix_capture_ = config.enable_pix_capture;
#else
if (config.enable_pix_capture) {
ORT_THROW("Support PIX capture requires extra build flags (--enable_pix_capture)");
}
#endif // ENABLE_PIX_FOR_WEBGPU_EP
}

std::vector<AllocatorPtr> WebGpuExecutionProvider::CreatePreferredAllocators() {
Expand Down Expand Up @@ -860,6 +867,10 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
context_.CollectProfilingData(profiler_->Events());
}

if (IsPIXCaptureEnabled()) {
context_.GeneratePIXFrame();
}
shaoboyan091 marked this conversation as resolved.
Show resolved Hide resolved

return Status::OK();
}

Expand Down
9 changes: 7 additions & 2 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,17 @@ class WebGpuProfiler;
} // namespace webgpu

struct WebGpuExecutionProviderConfig {
WebGpuExecutionProviderConfig(DataLayout data_layout, bool enable_graph_capture)
WebGpuExecutionProviderConfig(DataLayout data_layout, bool enable_graph_capture, bool enable_pix_capture)
: data_layout{data_layout},
enable_graph_capture{enable_graph_capture} {}
enable_graph_capture{enable_graph_capture},
enable_pix_capture{enable_pix_capture} {}
WebGpuExecutionProviderConfig(WebGpuExecutionProviderConfig&&) = default;
WebGpuExecutionProviderConfig& operator=(WebGpuExecutionProviderConfig&&) = default;
ORT_DISALLOW_COPY_AND_ASSIGNMENT(WebGpuExecutionProviderConfig);

DataLayout data_layout;
bool enable_graph_capture;
bool enable_pix_capture;
std::vector<std::string> force_cpu_node_names;
};

Expand Down Expand Up @@ -69,6 +71,8 @@ class WebGpuExecutionProvider : public IExecutionProvider {
bool IsGraphCaptured(int graph_annotation_id) const override;
Status ReplayGraph(int graph_annotation_id) override;

bool IsPIXCaptureEnabled() const { return enable_pix_capture_; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems unnecessary getter. the only usage is in OnRunEnd() and it's totally OK to just use enable_pix_capture_ there.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this one from webgpu ep to webgpu context so that we could check pix enable in onRunEnd to call GeneratePixFrame.


private:
bool IsGraphCaptureAllowed() const;
void IncrementRegularRunCountBeforeGraphCapture();
Expand All @@ -78,6 +82,7 @@ class WebGpuExecutionProvider : public IExecutionProvider {
DataLayout preferred_data_layout_;
std::vector<std::string> force_cpu_node_names_;
bool enable_graph_capture_ = false;
bool enable_pix_capture_ = false;
bool is_graph_captured_ = false;
int regular_run_count_before_graph_capture_ = 0;
const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations.
Expand Down
14 changes: 14 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
DataLayout::NHWC,
// graph capture feature is disabled by default
false,
// enable pix capture feature is diabled by default
false,
};

std::string preferred_layout_str;
Expand Down Expand Up @@ -67,6 +69,18 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
}
LOGS_DEFAULT(VERBOSE) << "WebGPU EP graph capture enable: " << webgpu_ep_config.enable_graph_capture;

std::string enable_pix_capture_str;
if (config_options.TryGetConfigEntry(kEnablePIXCapture, enable_pix_capture_str)) {
if (enable_pix_capture_str == kEnablePIXCapture_ON) {
webgpu_ep_config.enable_pix_capture = true;
} else if (enable_pix_capture_str == kEnablePIXCapture_OFF) {
webgpu_ep_config.enable_pix_capture = false;
} else {
ORT_THROW("Invalid enable pix capture: ", enable_pix_capture_str);
}
}
LOGS_DEFAULT(VERBOSE) << "WebGPU EP pix capture enable: " << webgpu_ep_config.enable_pix_capture;

// parse force CPU node names
// The force CPU node names are separated by EOL (\n or \r\n) in the config entry.
// each line is a node name that will be forced to run on CPU.
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_provider_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ constexpr const char* kDefaultBufferCacheMode = "WebGPU:defaultBufferCacheMode";
constexpr const char* kValidationMode = "WebGPU:validationMode";

constexpr const char* kForceCpuNodeNames = "WebGPU:forceCpuNodeNames";
constexpr const char* kEnablePIXCapture = "WebGPU:enablePIXCapture";

// The following are the possible values for the provider options.

Expand All @@ -41,6 +42,9 @@ constexpr const char* kPreferredLayout_NHWC = "NHWC";
constexpr const char* kEnableGraphCapture_ON = "1";
constexpr const char* kEnableGraphCapture_OFF = "0";

constexpr const char* kEnablePIXCapture_ON = "1";
constexpr const char* kEnablePIXCapture_OFF = "0";

constexpr const char* kBufferCacheMode_Disabled = "disabled";
constexpr const char* kBufferCacheMode_LazyRelease = "lazyRelease";
constexpr const char* kBufferCacheMode_Simple = "simple";
Expand Down
3 changes: 3 additions & 0 deletions tools/ci_build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,7 @@ def convert_arg_line_to_args(self, arg_line):
parser.add_argument("--use_migraphx", action="store_true", help="Build with MIGraphX")
parser.add_argument("--migraphx_home", help="Path to MIGraphX installation dir")
parser.add_argument("--use_full_protobuf", action="store_true", help="Use the full protobuf library")
parser.add_argument("--enable_pix_capture", action="store_true", help="Enable Pix Support.")

parser.add_argument(
"--skip_onnx_tests",
Expand Down Expand Up @@ -1054,6 +1055,8 @@ def generate_build_tree(
"-Donnxruntime_ARMNN_BN_USE_CPU=" + ("OFF" if args.armnn_bn else "ON"),
"-Donnxruntime_USE_JSEP=" + ("ON" if args.use_jsep else "OFF"),
"-Donnxruntime_USE_WEBGPU=" + ("ON" if args.use_webgpu else "OFF"),
"-Donnxruntime_ENABLE_PIX_FOR_WEBGPU_EP="
+ ("ON" if args.enable_pix_capture and args.use_webgpu and is_windows() else "OFF"),
shaoboyan091 marked this conversation as resolved.
Show resolved Hide resolved
"-Donnxruntime_USE_EXTERNAL_DAWN=" + ("ON" if args.use_external_dawn else "OFF"),
# Training related flags
"-Donnxruntime_ENABLE_NVTX_PROFILE=" + ("ON" if args.enable_nvtx_profile else "OFF"),
Expand Down