diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -22,6 +22,8 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 namespace mlir {
+class PatternRewriter;
+
 namespace shape {
 
 namespace ShapeTypes {
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -372,7 +372,7 @@
   let hasFolder = 1;
 }
 
-def Shape_YieldOp : Shape_Op<"yield", 
+def Shape_YieldOp : Shape_Op<"yield",
     [HasParent<"ReduceOp">,
      NoSideEffect,
      ReturnLike,
@@ -528,6 +528,17 @@
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
 
+  let extraClassDeclaration = [{
+    // Inline the body of the AssumingOp into the block containing it, replace
+    // the AssumingOp's results with the operands of its AssumingYieldOp, and
+    // erase both ops.
+    //
+    // This does no checks on the inputs to the AssumingOp: the witness operand
+    // is ignored, so the caller is responsible for deciding that inlining is
+    // valid.
+    static void inlineRegion(AssumingOp op, PatternRewriter &rewriter);
+  }];
+
   let hasCanonicalizer = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -18,6 +18,9 @@
 
 namespace mlir {
 
+class FunctionPass;
+class MLIRContext;
+class OwningRewritePatternList;
 class Pass;
 
 /// Creates an instance of the ShapeToShapeLowering pass that legalizes Shape
@@ -25,6 +28,16 @@
 /// transformed to `shape.reduce`, which can be lowered to SCF and Standard.
 std::unique_ptr<Pass> createShapeToShapeLowering();
 
+/// Collects patterns that drop shape constraints: `shape.cstr_broadcastable`,
+/// `shape.cstr_eq` and `shape.assuming_all` witnesses are treated as passing,
+/// and the bodies of `shape.assuming` ops are inlined unconditionally.
+void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns,
+                                            MLIRContext *ctx);
+
+/// Creates a function pass that greedily applies the patterns collected by
+/// `populateRemoveShapeConstraintsPatterns`.
+std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
+
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
@@ -11,6 +11,11 @@
 
 include "mlir/Pass/PassBase.td"
 
+def RemoveShapeConstraints : FunctionPass<"remove-shape-constraints"> {
+  let summary = "Remove all cstr_ and assuming_ ops";
+  let constructor = "mlir::createRemoveShapeConstraintsPass()";
+}
+
 def ShapeToShapeLowering : FunctionPass<"shape-to-shape-lowering"> {
   let summary = "Legalize Shape dialect to be convertible to Standard";
   let constructor = "mlir::createShapeToShapeLowering()";
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -168,22 +168,7 @@
     if (!witness || !witness.passingAttr())
       return failure();
 
-    auto *blockBeforeAssuming = rewriter.getInsertionBlock();
-    auto *assumingBlock = op.getBody();
-    auto initPosition = rewriter.getInsertionPoint();
-    auto *blockAfterAssuming =
-        rewriter.splitBlock(blockBeforeAssuming, initPosition);
-
-    // Remove the AssumingOp and AssumingYieldOp.
-    auto &yieldOp = assumingBlock->back();
-    rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
-    rewriter.replaceOp(op, yieldOp.getOperands());
-    rewriter.eraseOp(&yieldOp);
-
-    // Merge blocks together as there was no branching behavior from the
-    // AssumingOp.
-    rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
-    rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
+    AssumingOp::inlineRegion(op, rewriter);
     return success();
   }
 };
@@ -195,6 +180,27 @@
   patterns.insert<AssumingWithTrue>(context);
 }
 
+void AssumingOp::inlineRegion(AssumingOp op, PatternRewriter &rewriter) {
+  // Split at `op` itself rather than at the rewriter's current insertion
+  // point, so callers do not have to position the rewriter first.
+  auto *blockBeforeAssuming = op->getBlock();
+  auto *assumingBlock = op.getBody();
+  auto initPosition = Block::iterator(op.getOperation());
+  auto *blockAfterAssuming =
+      rewriter.splitBlock(blockBeforeAssuming, initPosition);
+
+  // Remove the AssumingOp and AssumingYieldOp.
+  auto &yieldOp = assumingBlock->back();
+  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
+  rewriter.replaceOp(op, yieldOp.getOperands());
+  rewriter.eraseOp(&yieldOp);
+
+  // Merge blocks together as there was no branching behavior from the
+  // AssumingOp.
+  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
+  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
+}
+
 //===----------------------------------------------------------------------===//
 // AssumingAllOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRShapeOpsTransforms
+  RemoveShapeConstraints.cpp
   ShapeToShapeLowering.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
@@ -0,0 +1,84 @@
+//===-- RemoveShapeConstraints.cpp - Remove Shape Cstr and Assuming Ops ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace {
+
+/// Replaces a witness-producing op with `shape.const_witness true`.
+///
+/// The op is replaced rather than erased because its witness result may still
+/// have users, e.g. a `shape.assuming_all` or `shape.assuming` that has not
+/// been rewritten yet. `PatternRewriter::eraseOp` requires an op without uses;
+/// the constant witness keeps such users valid.
+template <typename OpTy>
+class ReplaceWithTrueWitness : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
+    return success();
+  }
+};
+
+/// Assume every `shape.cstr_broadcastable` holds.
+using RemoveCstrBroadcastableOp =
+    ReplaceWithTrueWitness<shape::CstrBroadcastableOp>;
+
+/// Assume every `shape.cstr_eq` holds.
+using RemoveCstrEqOp = ReplaceWithTrueWitness<shape::CstrEqOp>;
+
+/// Assume every `shape.assuming_all` holds, regardless of its operands.
+using RemoveAssumingAllOp = ReplaceWithTrueWitness<shape::AssumingAllOp>;
+
+/// Unconditionally inlines the body of a `shape.assuming` op into the
+/// enclosing block; the witness operand is not inspected.
+class RemoveAssumingOp : public OpRewritePattern<shape::AssumingOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::AssumingOp op,
+                                PatternRewriter &rewriter) const override {
+    shape::AssumingOp::inlineRegion(op, rewriter);
+    return success();
+  }
+};
+
+/// Removal pass: greedily applies the removal patterns to each function.
+class RemoveShapeConstraintsPass
+    : public RemoveShapeConstraintsBase<RemoveShapeConstraintsPass> {
+
+  void runOnFunction() override {
+    MLIRContext &ctx = getContext();
+
+    OwningRewritePatternList patterns;
+    populateRemoveShapeConstraintsPatterns(patterns, &ctx);
+
+    applyPatternsAndFoldGreedily(getFunction(), patterns);
+  }
+};
+
+} // namespace
+
+void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns,
+                                            MLIRContext *ctx) {
+  patterns.insert<RemoveAssumingAllOp, RemoveAssumingOp,
+                  RemoveCstrBroadcastableOp, RemoveCstrEqOp>(ctx);
+}
+
+std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass() {
+  return std::make_unique<RemoveShapeConstraintsPass>();
+}
+
+} // namespace mlir
diff --git a/mlir/test/Dialect/Shape/remove-shape-constraints.mlir b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -remove-shape-constraints <%s | FileCheck %s --dump-input=fail
+
+// -----
+// With a non-const value, we cannot fold away the code, but all constraints
+// should be removed still.
+//
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
+  // CHECK-NEXT: test.source
+  // CHECK-NEXT: return
+  %0 = shape.cstr_broadcastable %arg0, %arg1
+  %1 = shape.cstr_eq %arg0, %arg1
+  %2 = shape.assuming_all %0, %1
+  %3 = shape.assuming %0 -> index {
+    %4 = "test.source"() : () -> (index)
+    shape.assuming_yield %4 : index
+  }
+  return %3 : index
+}
+
+// -----
+// A witness whose user is not removed by the pass must be replaced, not
+// erased, so that the user keeps a valid operand.
+//
+// CHECK-LABEL: func @escaping_witness
+func @escaping_witness(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.witness {
+  // CHECK-NEXT: %[[W:.*]] = shape.const_witness true
+  // CHECK-NEXT: return %[[W]]
+  %0 = shape.cstr_eq %arg0, %arg1
+  return %0 : !shape.witness
+}