Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[webgpu] Implement Split operator #23198

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 162 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/split.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/tensor/split.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {

namespace {

// Helper function to calculate the output index based on the input index and the sizes of the splits.
void CalculateOutputIndex(std::ostream& os, size_t output_count) {
os << "fn calculate_output_index(index: u32) -> u32 {\n"
<< " for (var i: u32 = 0u; i < " << output_count << "u; i += 1u ) {\n"
<< " if (index < " << GetElementAt("uniforms.sizes_in_split_axis", "i", output_count) << ") {\n"
<< " return i;\n"
<< " }\n"
<< " }\n"
<< " return " << output_count << "u;\n"
<< "}\n";
}

// Helper function to write the buffer data for each output.
void WriteBufferData(std::ostream& os, const ShaderVariableHelper& input,
gsl::span<const ShaderVariableHelper*> outputs) {
os << "fn write_buffer_data(output_number: u32, global_idx: u32, indices: output_0_indices_t) {\n";
for (size_t i = 0; i < outputs.size(); ++i) {
const auto buffer_write = outputs[i]->SetByIndices("indices", input.GetByOffset("global_idx"));
if (outputs.size() == 1) {
os << buffer_write;
} else if (i == 0) {
os << " if (output_number == 0u) {\n"
<< " " << buffer_write << "\n";
} else if (i == outputs.size() - 1) {
os << " } else {\n"
<< " " << buffer_write << "\n";
} else {
os << " } else if (output_number == " << i << "u) {\n"
<< " " << buffer_write << "\n";
}
}
os << " }\n"
<< "}\n";
}

} // namespace

Status SplitProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);

size_t output_count = Outputs().size();
std::vector<const ShaderVariableHelper*> outputs;
outputs.reserve(output_count);
for (size_t i = 0; i < output_count; ++i) {
outputs.push_back(
&shader.AddOutput("output_" + std::to_string(i), ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias));
}

// Add implementation of fn calculate_output_index.
CalculateOutputIndex(shader.AdditionalImplementation(), output_count);
// Add implementation of fn write_buffer_data.
WriteBufferData(shader.AdditionalImplementation(), input, outputs);

shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.input_size")
<< " var indices = " << input.OffsetToIndices("global_idx") << ";\n"
<< " var index = indices[" << axis_ << "];\n"
<< " let output_number = calculate_output_index(index);\n"
<< " if (output_number != 0u) {\n"
<< " index -= uniforms.sizes_in_split_axis[output_number - 1u];\n"
<< " indices[" << axis_ << "] = index;\n"
<< " }\n"
<< " write_buffer_data(output_number, global_idx, indices);\n";

return Status::OK();
}

Status Split::ComputeInternal(ComputeContext& context) const {
const Tensor* input = context.Input<Tensor>(0);
auto& input_shape = input->Shape();
auto num_outputs = context.OutputCount();

int64_t axis = axis_;
std::vector<int64_t> split_sizes;

split_sizes.assign(split_sizes_.begin(), split_sizes_.end());
// Compute split_sizes from the 'split' input tensor.
if (split_sizes_.size() == 0 && context.InputCount() > 1) {
const Tensor* split_tensor = context.Input<Tensor>(1);
// Check if split_tensor is valid.
if (split_tensor != nullptr) {
ORT_ENFORCE(split_tensor->Shape().NumDimensions() == 1, "The split tensor must be a vector tensor.");
// Get split_sizes from the input tensor.
auto nDims = static_cast<size_t>(split_tensor->Shape()[0]);
const auto* data = split_tensor->Data<int64_t>();
split_sizes.assign(data, data + nDims);
}
}

// The variables below are not actually used in the current implementation.
int before_dims = 0;
int after_dims_including_split_axis = 0;
int after_dims_excluding_split = 0;
// This handles the case where the axis is negative. It also splits outputs evenly according to num_ouputs if
// split_sizes is empty.
ORT_RETURN_IF_ERROR(PrepareForCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis,
after_dims_excluding_split, split_sizes));

SplitProgram program{gsl::narrow_cast<uint32_t>(axis)};
program.AddInput({input, ProgramTensorMetadataDependency::TypeAndRank});

auto output_dimensions = input_shape.AsShapeVector();
for (int i = 0; i < num_outputs; ++i) {
// Update the size of dimension for axis we're splitting on.
auto split_size = narrow<int>(split_sizes[i]);
output_dimensions[narrow<size_t>(axis)] = split_size;

Tensor* output = context.Output(i, TensorShape{output_dimensions});
program.AddOutput({output, ProgramTensorMetadataDependency::Rank});
}

uint32_t input_size = gsl::narrow<uint32_t>(input_shape.Size());
// Early return if the input tensor is empty.
if (input_size == 0) {
return Status::OK();
}

uint32_t previous_sum = 0;
std::vector<uint32_t> sizes_in_split_axis;
// sizes_in_split_axis are the cumulative sizes of the splits in the split axis.
for (auto split_size : split_sizes) {
previous_sum += gsl::narrow<uint32_t>(split_size);
sizes_in_split_axis.push_back(previous_sum);
}

program
.SetDispatchGroupSize((input_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.CacheHint(std::to_string(axis))
.AddUniformVariables(
{input_size, gsl::span<const uint32_t>(sizes_in_split_axis.data(), sizes_in_split_axis.size())});
return context.RunProgram(program);
}

#define WEBGPU_SPLIT_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \
ONNX_OPERATOR_KERNEL_EX(OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \
KERNEL_CLASS);

#define WEBGPU_SPLIT_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX(OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \
KERNEL_CLASS);

WEBGPU_SPLIT_VERSIONED_KERNEL(Split, 1, 1, Split_1, WebGpuSupportedNumberTypes())
WEBGPU_SPLIT_VERSIONED_KERNEL(Split, 2, 10, Split_2_10, WebGpuSupportedNumberTypes())
WEBGPU_SPLIT_VERSIONED_KERNEL(Split, 11, 12, Split_11_12, WebGpuSupportedNumberTypes())
WEBGPU_SPLIT_VERSIONED_KERNEL(Split, 13, 17, Split_13_17, WebGpuSupportedNumberTypes())
WEBGPU_SPLIT_KERNEL(Split, 18, Split_18, WebGpuSupportedNumberTypes());

} // namespace webgpu
} // namespace onnxruntime
96 changes: 96 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/split.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/common.h"
#include "core/providers/cpu/tensor/split.h"

namespace onnxruntime {
namespace webgpu {

class SplitProgram final : public Program<SplitProgram> {
public:
SplitProgram(const uint32_t axis) : Program{"Split"}, axis_{axis} {}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32},
{"sizes_in_split_axis", ProgramUniformVariableDataType::Uint32});

private:
uint32_t axis_;
};

class Split : public WebGpuKernel, public SplitBase {
public:
Split(const OpKernelInfo& info, uint32_t opset) : WebGpuKernel(info), SplitBase(info, opset) {
jchen10 marked this conversation as resolved.
Show resolved Hide resolved
std::vector<int32_t> split_sizes;
// Check if split_sizes is provided as an attribute.
if (split_sizes_.size() > 0) {
ORT_ENFORCE(split_sizes_.size() == info.node().OutputDefs().size(), "Number of outputs (",
info.node().OutputDefs().size(), ") does not match split_sizes (", split_sizes_.size(), ")");
split_sizes.resize(split_sizes_.size());
for (size_t i = 0; i < split_sizes_.size(); ++i) {
split_sizes[i] = gsl::narrow_cast<int32_t>(split_sizes_[i]);
}
} else if (info.GetInputCount() < 2) {
// No valid split_sizes is providede as an attribute or input tensor. In this case, we try to compute it from input, output shapes and
// num_outputs.

// Handle negative axis.
const auto num_dimensions = gsl::narrow_cast<int64_t>(info.node().InputDefs()[0]->Shape()->dim_size());
const auto axis = HandleNegativeAxis(axis_, num_dimensions);

auto total_split_size = info.node().InputDefs()[0]->Shape()->dim(gsl::narrow_cast<int32_t>(axis)).dim_value();
int64_t split_size_sum = 0;
if (num_outputs_ >= 0) {
ORT_ENFORCE(num_outputs_ == gsl::narrow_cast<int64_t>(info.node().OutputDefs().size()),
"Invalid num_outputs value of ", num_outputs_, ". Size of dimension being split is ",
info.node().OutputDefs().size());
}

// Compute split_sizes from the output shapes.
for (auto output : info.node().OutputDefs()) {
auto split_size = output->Shape()->dim(gsl::narrow_cast<int32_t>(axis)).dim_value();
split_sizes.push_back(gsl::narrow_cast<int32_t>(split_size));
split_size_sum += split_size;
}
ORT_ENFORCE(split_size_sum == total_split_size, "Sum of split sizes (", split_size_sum,
") does not match input size (", total_split_size, ")");
}
}

protected:
Status ComputeInternal(ComputeContext& context) const override;
};

class Split_1 final : public Split {
public:
Split_1(const OpKernelInfo& info) : Split(info, 1) {}
};

class Split_2_10 final : public Split {
public:
Split_2_10(const OpKernelInfo& info) : Split(info, 2) {}
};

class Split_11_12 final : public Split {
public:
Split_11_12(const OpKernelInfo& info) : Split(info, 11) {}
};

class Split_13_17 final : public Split {
public:
Split_13_17(const OpKernelInfo& info) : Split(info, 13) {}
};

class Split_18 final : public Split {
public:
Split_18(const OpKernelInfo& info) : Split(info, 18) {}
};

} // namespace webgpu
} // namespace onnxruntime
11 changes: 6 additions & 5 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -637,11 +637,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Concat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Concat)>,

// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 1, Split)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 2, 10, Split)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Split)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Split)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 1, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 2, 10, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, Split)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 8, 12, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Expand)>,

Expand Down