From 3e0cdb7cbd34d8c9f2f00cb84d3ce2bb482554d7 Mon Sep 17 00:00:00 2001 From: Morten Borup Petersen Date: Fri, 7 Jan 2022 10:51:20 +0100 Subject: [PATCH] Lower polygeist.subindex through memref.reinterpret_cast This should be a (hopefully) foolproof method of performing indexing into a memref. A reintrepret_cast is inserted with a dynamic index calculated from the subindex index operand + the product of the sizes of the target type. This has been added as a separate conversion pass instead of through the canonicalization drivers. When added as a canonicalization, the conversion may preemptively apply, resulting in sub-par IR. Nevertheless, i think it has its merits to have a polygeist op lowering pass which can be used as a fallback to convert the dialect operations, if canonicalization fails. For now, just added support for statically shaped memrefs (enough to fix the regression on my side) but should be possible for dynamically shaped as well. --- include/polygeist/Passes/Passes.h | 1 + include/polygeist/Passes/Passes.td | 6 ++ lib/polygeist/Ops.cpp | 44 +---------- lib/polygeist/Passes/CMakeLists.txt | 1 + lib/polygeist/Passes/LowerPolygeistOps.cpp | 88 +++++++++++++++++++++ test/polygeist-opt/canonicalization.mlir | 29 ------- test/polygeist-opt/lower_polygeist_ops.mlir | 17 ++++ 7 files changed, 114 insertions(+), 72 deletions(-) create mode 100644 lib/polygeist/Passes/LowerPolygeistOps.cpp delete mode 100644 test/polygeist-opt/canonicalization.mlir create mode 100644 test/polygeist-opt/lower_polygeist_ops.mlir diff --git a/include/polygeist/Passes/Passes.h b/include/polygeist/Passes/Passes.h index 7a4fe27c5be4..76721b4fc638 100644 --- a/include/polygeist/Passes/Passes.h +++ b/include/polygeist/Passes/Passes.h @@ -19,6 +19,7 @@ std::unique_ptr createParallelLowerPass(); std::unique_ptr createConvertPolygeistToLLVMPass(const LowerToLLVMOptions &options); std::unique_ptr createConvertPolygeistToLLVMPass(); +std::unique_ptr createLowerPolygeistOpsPass(); } // namespace polygeist } // namespace mlir diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td index 86eed651d7d6..d209bcd0ffee 100644 --- a/include/polygeist/Passes/Passes.td +++ b/include/polygeist/Passes/Passes.td @@ -61,6 +61,12 @@ def RemoveTrivialUse : FunctionPass<"trivialuse"> { let constructor = "mlir::polygeist::createRemoveTrivialUsePass()"; } +def LowerPolygeistOps : FunctionPass<"lower-polygeist-ops"> { + let summary = "Lower polygeist ops to memref operations"; + let constructor = "mlir::polygeist::createLowerPolygeistOpsPass()"; + let dependentDialects = ["::mlir::memref::MemRefDialect"]; +} + def ConvertPolygeistToLLVM : Pass<"convert-polygeist-to-llvm", "mlir::ModuleOp"> { let summary = "Convert scalar and vector operations from the Standard to the " "LLVM dialect"; diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index 9d40ad88d610..0ecb833b38b6 100644 --- a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -176,48 +176,6 @@ class SubToCast final : public OpRewritePattern { } }; -// Simplify polygeist.subindex to memref.subview. -class SubToSubView final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(SubIndexOp op, - PatternRewriter &rewriter) const override { - auto srcMemRefType = op.source().getType().cast(); - auto resMemRefType = op.result().getType().cast(); - auto dims = srcMemRefType.getShape().size(); - - // For now, restrict subview lowering to statically defined memref's - if (!srcMemRefType.hasStaticShape() | !resMemRefType.hasStaticShape()) - return failure(); - - // For now, restrict to simple rank-reducing indexing - if (srcMemRefType.getShape().size() <= resMemRefType.getShape().size()) - return failure(); - - // Build offset, sizes and strides - SmallVector sizes(dims, rewriter.getIndexAttr(0)); - sizes[0] = op.index(); - SmallVector offsets(dims); - for (auto dim : llvm::enumerate(srcMemRefType.getShape())) { - if (dim.index() == 0) - offsets[0] = rewriter.getIndexAttr(1); - else - offsets[dim.index()] = rewriter.getIndexAttr(dim.value()); - } - SmallVector strides(dims, rewriter.getIndexAttr(1)); - - // Generate the appropriate return type: - auto subMemRefType = MemRefType::get(srcMemRefType.getShape().drop_front(), - srcMemRefType.getElementType()); - - rewriter.replaceOpWithNewOp( - op, subMemRefType, op.source(), sizes, offsets, strides); - - return success(); - } -}; - // Simplify redundant dynamic subindex patterns which tries to represent // rank-reducing indexing: // %3 = "polygeist.subindex"(%1, %arg0) : (memref<2x1000xi32>, index) -> @@ -678,7 +636,7 @@ void SubIndexOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); - // Disabled: SubToSubView + // Disabled: } /// Simplify pointer2memref(memref2pointer(x)) to cast(x) diff --git a/lib/polygeist/Passes/CMakeLists.txt b/lib/polygeist/Passes/CMakeLists.txt index 19f5ec443855..371c5cef2bee 100644 --- a/lib/polygeist/Passes/CMakeLists.txt +++ b/lib/polygeist/Passes/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRPolygeistTransforms ParallelLower.cpp TrivialUse.cpp ConvertPolygeistToLLVM.cpp + LowerPolygeistOps.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine diff --git a/lib/polygeist/Passes/LowerPolygeistOps.cpp b/lib/polygeist/Passes/LowerPolygeistOps.cpp new file mode 100644 index 000000000000..be3152b0d513 --- /dev/null +++ b/lib/polygeist/Passes/LowerPolygeistOps.cpp @@ -0,0 +1,88 @@ +//===- TrivialUse.cpp - Remove trivial use instruction ---------------- -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to lower gpu kernels in NVVM/gpu dialects into +// a generic parallel for representation +//===----------------------------------------------------------------------===// +#include "PassDetails.h" + +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/StandardOps/Transforms/Passes.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Transforms/DialectConversion.h" +#include "polygeist/Dialect.h" +#include "polygeist/Ops.h" + +using namespace mlir; +using namespace polygeist; +using namespace mlir::arith; + +namespace { + +struct SubIndexToReinterpretCast + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(polygeist::SubIndexOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcMemRefType = op.source().getType().cast(); + auto resMemRefType = op.result().getType().cast(); + auto shape = srcMemRefType.getShape(); + + if (!resMemRefType.hasStaticShape()) + return failure(); + + int64_t innerSize = resMemRefType.getNumElements(); + auto offset = rewriter.create( + op.getLoc(), op.index(), + rewriter.create(op.getLoc(), innerSize)); + + llvm::SmallVector sizes, strides; + for (auto dim : shape.drop_front()) { + sizes.push_back(rewriter.getIndexAttr(dim)); + strides.push_back(rewriter.getIndexAttr(1)); + } + + rewriter.replaceOpWithNewOp( + op, resMemRefType, op.source(), offset.getResult(), sizes, strides); + + return success(); + } +}; + +struct LowerPolygeistOpsPass + : public LowerPolygeistOpsBase { + + void runOnFunction() override { + auto op = getOperation(); + auto ctx = op.getContext(); + RewritePatternSet patterns(ctx); + patterns.insert(ctx); + + ConversionTarget target(*ctx); + target.addIllegalDialect(); + target.addLegalDialect(); + + if (failed(applyPartialConversion(op, target, std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +namespace mlir { +namespace polygeist { +std::unique_ptr createLowerPolygeistOpsPass() { + return std::make_unique(); +} + +} // namespace polygeist +} // namespace mlir diff --git a/test/polygeist-opt/canonicalization.mlir b/test/polygeist-opt/canonicalization.mlir deleted file mode 100644 index d68b8c40dc34..000000000000 --- a/test/polygeist-opt/canonicalization.mlir +++ /dev/null @@ -1,29 +0,0 @@ -// RUN: polygeist-opt --canonicalize --split-input-file %s | FileCheck %s -// XFAIL: * -// CHECK: func @main(%arg0: index) -> memref<30xi32> { -// CHECK: %0 = memref.alloca() : memref<30x30xi32> -// CHECK: %1 = memref.subview %0[%arg0, 0] [1, 30] [1, 1] : memref<30x30xi32> to memref<30xi32> -// CHECK: return %1 : memref<30xi32> -// CHECK: } -module { - func @main(%arg0 : index) -> memref<30xi32> { - %0 = memref.alloca() : memref<30x30xi32> - %1 = "polygeist.subindex"(%0, %arg0) : (memref<30x30xi32>, index) -> memref<30xi32> - return %1 : memref<30xi32> - } -} - -// ----- - -// CHECK: func @main(%arg0: index) -> memref<1000xi32> { -// CHECK: %0 = memref.alloca() : memref<2x1000xi32> -// CHECK: %1 = memref.subview %0[%arg0, 0] [1, 1000] [1, 1] : memref<2x1000xi32> to memref<1000xi32> -// CHECK: return %1 : memref<1000xi32> -// CHECK: } -func @main(%arg0 : index) -> memref<1000xi32> { - %c0 = arith.constant 0 : index - %1 = memref.alloca() : memref<2x1000xi32> - %3 = "polygeist.subindex"(%1, %arg0) : (memref<2x1000xi32>, index) -> memref - %4 = "polygeist.subindex"(%3, %c0) : (memref, index) -> memref<1000xi32> - return %4 : memref<1000xi32> -} diff --git a/test/polygeist-opt/lower_polygeist_ops.mlir b/test/polygeist-opt/lower_polygeist_ops.mlir new file mode 100644 index 000000000000..cd84039e637b --- /dev/null +++ b/test/polygeist-opt/lower_polygeist_ops.mlir @@ -0,0 +1,17 @@ +// RUN: polygeist-opt --lower-polygeist-ops --split-input-file %s | FileCheck %s + +// CHECK-LABEL: func @main( +// CHECK-SAME: %[[VAL_0:.*]]: index) -> memref<30xi32> { +// CHECK: %[[VAL_1:.*]] = memref.alloca() : memref<30x30xi32> +// CHECK: %[[VAL_2:.*]] = arith.constant 30 : index +// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_2]] : index +// CHECK: %[[VAL_4:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_3]]], sizes: [30], strides: [1] : memref<30x30xi32> to memref<30xi32> +// CHECK: return %[[VAL_4]] : memref<30xi32> +// CHECK: } +module { + func @main(%arg0 : index) -> memref<30xi32> { + %0 = memref.alloca() : memref<30x30xi32> + %1 = "polygeist.subindex"(%0, %arg0) : (memref<30x30xi32>, index) -> memref<30xi32> + return %1 : memref<30xi32> + } +}