Skip to content

Commit

Permalink
Lower polygeist.subindex through memref.reinterpret_cast
Browse files Browse the repository at this point in the history
This should be a (hopefully) foolproof method of performing indexing into a memref. A reinterpret_cast is inserted with a dynamic offset calculated as the subindex index operand multiplied by the product of the sizes of the target type.

This has been added as a separate conversion pass instead of through the canonicalization drivers. When added as a canonicalization, the conversion may apply preemptively, resulting in sub-par IR. Nevertheless, I think there is merit in having a polygeist op lowering pass that can be used as a fallback to convert the dialect operations if canonicalization fails.

For now, just added support for statically shaped memrefs (enough to fix the regression on my side) but should be possible for dynamically shaped as well.
  • Loading branch information
mortbopet committed Jan 7, 2022
1 parent 4325b34 commit 3e0cdb7
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 72 deletions.
1 change: 1 addition & 0 deletions include/polygeist/Passes/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ std::unique_ptr<Pass> createParallelLowerPass();
std::unique_ptr<Pass>
createConvertPolygeistToLLVMPass(const LowerToLLVMOptions &options);
std::unique_ptr<Pass> createConvertPolygeistToLLVMPass();
std::unique_ptr<Pass> createLowerPolygeistOpsPass();

} // namespace polygeist
} // namespace mlir
Expand Down
6 changes: 6 additions & 0 deletions include/polygeist/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ def RemoveTrivialUse : FunctionPass<"trivialuse"> {
let constructor = "mlir::polygeist::createRemoveTrivialUsePass()";
}

// Function pass, registered as "lower-polygeist-ops", that lowers polygeist
// dialect operations to memref operations.
def LowerPolygeistOps : FunctionPass<"lower-polygeist-ops"> {
  let summary = "Lower polygeist ops to memref operations";
  let constructor = "mlir::polygeist::createLowerPolygeistOpsPass()";
  // Every dialect whose operations the pass may create has to be declared
  // here so that it is loaded before the pass runs. The lowering emits
  // arith.constant / arith.muli for the offset computation in addition to
  // memref.reinterpret_cast.
  let dependentDialects = [
    "::mlir::arith::ArithmeticDialect",
    "::mlir::memref::MemRefDialect"
  ];
}

def ConvertPolygeistToLLVM : Pass<"convert-polygeist-to-llvm", "mlir::ModuleOp"> {
let summary = "Convert scalar and vector operations from the Standard to the "
"LLVM dialect";
Expand Down
44 changes: 1 addition & 43 deletions lib/polygeist/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,48 +176,6 @@ class SubToCast final : public OpRewritePattern<SubIndexOp> {
}
};

// Simplify polygeist.subindex to memref.subview.
//
// Only applies when both the source and the result memref have a static shape
// and the result has a strictly lower rank than the source. The generated
// subview selects a single slice of the leading dimension:
//   offsets = [index, 0, ..., 0]
//   sizes   = [1, d1, ..., dN]
//   strides = [1, 1, ..., 1]
class SubToSubView final : public OpRewritePattern<SubIndexOp> {
public:
  using OpRewritePattern<SubIndexOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubIndexOp op,
                                PatternRewriter &rewriter) const override {
    auto srcMemRefType = op.source().getType().cast<MemRefType>();
    auto resMemRefType = op.result().getType().cast<MemRefType>();
    ArrayRef<int64_t> srcShape = srcMemRefType.getShape();
    const size_t dims = srcShape.size();

    // For now, restrict subview lowering to statically defined memref's.
    if (!srcMemRefType.hasStaticShape() || !resMemRefType.hasStaticShape())
      return failure();

    // For now, restrict to simple rank-reducing indexing. This also
    // guarantees `dims >= 1`, so accessing element 0 below is safe.
    if (dims <= resMemRefType.getShape().size())
      return failure();

    // Offsets: the dynamic index into the leading dimension, zero elsewhere.
    SmallVector<OpFoldResult> offsets(dims, rewriter.getIndexAttr(0));
    offsets[0] = op.index();

    // Sizes: one element of the leading dimension, the full extent of every
    // remaining dimension.
    SmallVector<OpFoldResult> sizes;
    sizes.reserve(dims);
    sizes.push_back(rewriter.getIndexAttr(1));
    for (int64_t extent : srcShape.drop_front())
      sizes.push_back(rewriter.getIndexAttr(extent));

    // Unit strides in every dimension.
    SmallVector<OpFoldResult> strides(dims, rewriter.getIndexAttr(1));

    // Generate the appropriate return type: the source shape without its
    // leading dimension.
    // NOTE(review): this type carries the default layout although the offset
    // is dynamic; confirm that memref.subview accepts it.
    auto subMemRefType =
        MemRefType::get(srcShape.drop_front(), srcMemRefType.getElementType());

    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        op, subMemRefType, op.source(), offsets, sizes, strides);

    return success();
  }
};

// Simplify redundant dynamic subindex patterns which tries to represent
// rank-reducing indexing:
// %3 = "polygeist.subindex"(%1, %arg0) : (memref<2x1000xi32>, index) ->
Expand Down Expand Up @@ -678,7 +636,7 @@ void SubIndexOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<CastOfSubIndex, SubIndex2, SubToCast, SimplifySubViewUsers,
SimplifySubIndexUsers, SelectOfCast, SelectOfSubIndex,
RedundantDynSubIndex>(context);
// Disabled: SubToSubView
// Disabled:
}

/// Simplify pointer2memref(memref2pointer(x)) to cast(x)
Expand Down
1 change: 1 addition & 0 deletions lib/polygeist/Passes/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRPolygeistTransforms
ParallelLower.cpp
TrivialUse.cpp
ConvertPolygeistToLLVM.cpp
LowerPolygeistOps.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
Expand Down
88 changes: 88 additions & 0 deletions lib/polygeist/Passes/LowerPolygeistOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
//===- LowerPolygeistOps.cpp - Lower polygeist ops to memref ---------- -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to lower polygeist dialect operations
// (currently polygeist.subindex) to memref dialect operations.
//===----------------------------------------------------------------------===//
#include "PassDetails.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
#include "polygeist/Dialect.h"
#include "polygeist/Ops.h"

using namespace mlir;
using namespace polygeist;
using namespace mlir::arith;

namespace {

/// Lowers `polygeist.subindex` to `memref.reinterpret_cast`.
///
/// Supported case: the result memref has a static shape that equals the source
/// shape with its leading dimension dropped (rank-reducing indexing by one).
/// The result is described by
///   offset  = index * <number of elements of the result type>
///   sizes   = result shape
///   strides = row-major strides of the result shape
/// Any other combination is rejected with `failure()`.
///
/// NOTE(review): the row-major strides assume the source memref is contiguous
/// with the default layout; confirm whether sources with a custom layout can
/// reach this pattern.
struct SubIndexToReinterpretCast
    : public OpConversionPattern<polygeist::SubIndexOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(polygeist::SubIndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcMemRefType = op.source().getType().cast<MemRefType>();
    auto resMemRefType = op.result().getType().cast<MemRefType>();

    if (!resMemRefType.hasStaticShape())
      return failure();

    // The offset computed below is only meaningful when the result is exactly
    // one slice of the source's leading dimension. The emptiness check also
    // protects the `drop_front()` call for rank-0 sources.
    ArrayRef<int64_t> srcShape = srcMemRefType.getShape();
    ArrayRef<int64_t> resShape = resMemRefType.getShape();
    if (srcShape.empty() || !srcShape.drop_front().equals(resShape))
      return failure();

    // offset = index * (number of elements in one slice). The constant is
    // created before the multiplication that consumes it.
    Location loc = op.getLoc();
    int64_t innerSize = resMemRefType.getNumElements();
    Value innerSizeCst =
        rewriter.create<arith::ConstantIndexOp>(loc, innerSize);
    Value offset =
        rewriter.create<arith::MulIOp>(loc, adaptor.index(), innerSizeCst);

    // Sizes mirror the result shape; strides are the row-major strides of
    // that shape (innermost stride 1, each outer stride the product of all
    // inner extents). Walk the dimensions from innermost to outermost.
    const size_t rank = resShape.size();
    llvm::SmallVector<OpFoldResult> sizes(rank), strides(rank);
    int64_t runningStride = 1;
    for (size_t i = rank; i-- > 0;) {
      sizes[i] = rewriter.getIndexAttr(resShape[i]);
      strides[i] = rewriter.getIndexAttr(runningStride);
      runningStride *= resShape[i];
    }

    // Use the adaptor's (possibly remapped) source operand, as required for
    // patterns run by the dialect conversion driver.
    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        op, resMemRefType, adaptor.source(), offset, sizes, strides);

    return success();
  }
};

/// Function pass that runs a partial dialect conversion in which the whole
/// polygeist dialect is illegal and the arith, standard and memref dialects
/// are legal. The pass signals failure when the conversion fails.
struct LowerPolygeistOpsPass
    : public LowerPolygeistOpsBase<LowerPolygeistOpsPass> {

  void runOnFunction() override {
    auto function = getOperation();
    auto context = function.getContext();

    // Legality: everything from the polygeist dialect has to be rewritten;
    // the dialects the lowering emits into are explicitly allowed.
    ConversionTarget legality(*context);
    legality.addLegalDialect<arith::ArithmeticDialect, mlir::StandardOpsDialect,
                             memref::MemRefDialect>();
    legality.addIllegalDialect<polygeist::PolygeistDialect>();

    // Lowering patterns available to the conversion driver.
    RewritePatternSet loweringPatterns(context);
    loweringPatterns.insert<SubIndexToReinterpretCast>(context);

    const LogicalResult status = applyPartialConversion(
        function, legality, std::move(loweringPatterns));
    if (failed(status))
      signalPassFailure();
  }
};
} // namespace

namespace mlir {
namespace polygeist {
/// Factory for the polygeist op lowering pass; returns a newly allocated
/// LowerPolygeistOpsPass owned by the caller.
std::unique_ptr<Pass> createLowerPolygeistOpsPass() {
  std::unique_ptr<Pass> pass = std::make_unique<LowerPolygeistOpsPass>();
  return pass;
}

} // namespace polygeist
} // namespace mlir
29 changes: 0 additions & 29 deletions test/polygeist-opt/canonicalization.mlir

This file was deleted.

17 changes: 17 additions & 0 deletions test/polygeist-opt/lower_polygeist_ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: polygeist-opt --lower-polygeist-ops --split-input-file %s | FileCheck %s

// CHECK-LABEL: func @main(
// CHECK-SAME: %[[VAL_0:.*]]: index) -> memref<30xi32> {
// CHECK: %[[VAL_1:.*]] = memref.alloca() : memref<30x30xi32>
// CHECK: %[[VAL_2:.*]] = arith.constant 30 : index
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_2]] : index
// CHECK: %[[VAL_4:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_3]]], sizes: [30], strides: [1] : memref<30x30xi32> to memref<30xi32>
// CHECK: return %[[VAL_4]] : memref<30xi32>
// CHECK: }
// Test input: a polygeist.subindex that indexes a statically shaped
// memref<30x30xi32> with a dynamic index and yields a memref<30xi32>.
// The expected output above replaces it with a memref.reinterpret_cast whose
// offset is %arg0 * 30.
module {
func @main(%arg0 : index) -> memref<30xi32> {
%0 = memref.alloca() : memref<30x30xi32>
// Operation under test.
%1 = "polygeist.subindex"(%0, %arg0) : (memref<30x30xi32>, index) -> memref<30xi32>
return %1 : memref<30xi32>
}
}

0 comments on commit 3e0cdb7

Please sign in to comment.