diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -47,6 +47,13 @@
   /// Note: Only applicable when simplifying entire regions.
   bool enableRegionSimplification = true;
 
+  /// Allow applying patterns to a region whose parent operation is not
+  /// `IsolatedFromAbove`. When this is off, such an application emits an error
+  /// on the parent operation and fails instead of rewriting anything.
+  ///
+  /// Note: This is off by default.
+  bool enableRewritesForNonIsolatedFromAboveOps = false;
+
   /// This specifies the maximum number of times the rewriter will iterate
   /// between applying patterns and simplifying regions. Use `kNoLimit` to
   /// disable this iteration limit.
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -717,11 +717,15 @@
 mlir::applyPatternsAndFoldGreedily(Region &region,
                                    const FrozenRewritePatternSet &patterns,
                                    GreedyRewriteConfig config, bool *changed) {
-  // The top-level operation must be known to be isolated from above to
-  // prevent performing canonicalizations on operations defined at or above
-  // the region containing 'op'.
-  assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
-         "patterns can only be applied to operations IsolatedFromAbove");
+  // The parent operation is normally required to be isolated from above, to
+  // prevent canonicalizations on operations defined at or above this region.
+  // Callers have to opt in explicitly to lift that requirement.
+  if (!region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+      !config.enableRewritesForNonIsolatedFromAboveOps) {
+    return region.getParentOp()->emitOpError(
+        "invalid application of patterns to operation not IsolatedFromAbove. "
+        "To enable set `enableRewritesForNonIsolatedFromAboveOps` to `true`");
+  }
 
   // Set scope if not specified.
   if (!config.scope)
diff --git a/mlir/test/IR/pattern-rewriter-non-isolated-from-above.mlir b/mlir/test/IR/pattern-rewriter-non-isolated-from-above.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/pattern-rewriter-non-isolated-from-above.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt --test-pattern-non-isolated-from-above %s | FileCheck %s
+
+func.func @test(%arg0 : i32) {
+  %0 = "test.cast"(%arg0) : (i32) -> f32
+  "test.one_region_op"()({
+    %1 = "test.cast"(%0) : (f32) -> i32
+    "test.region_yield"(%1) : (i32) -> ()
+  }) : () -> ()
+  return
+}
+// CHECK-LABEL: func @test(
+//  CHECK-SAME:     %[[ARG0:.+]]: i32)
+//       CHECK:   test.one_region_op
+//       CHECK:     test.region_yield
+//  CHECK-SAME:         %[[ARG0]]
+
+// Two sibling regions: each of them must be rewritten, i.e. the patterns have
+// to remain available after the first `test.one_region_op` was processed.
+func.func @test_two_regions(%arg0 : i32) {
+  %0 = "test.cast"(%arg0) : (i32) -> f32
+  "test.one_region_op"()({
+    %1 = "test.cast"(%0) : (f32) -> i32
+    "test.region_yield"(%1) : (i32) -> ()
+  }) : () -> ()
+  "test.one_region_op"()({
+    %2 = "test.cast"(%0) : (f32) -> i32
+    "test.region_yield"(%2) : (i32) -> ()
+  }) : () -> ()
+  return
+}
+// CHECK-LABEL: func @test_two_regions(
+//  CHECK-SAME:     %[[ARG0:.+]]: i32)
+//       CHECK:   test.one_region_op
+//       CHECK:     test.region_yield
+//  CHECK-SAME:         %[[ARG0]]
+//       CHECK:   test.one_region_op
+//       CHECK:     test.region_yield
+//  CHECK-SAME:         %[[ARG0]]
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1696,6 +1696,79 @@
 };
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// Test Not isolated from above
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Folds `test.cast(test.cast(x))` to `x` when the outer cast produces the
+/// type of `x`. The inner cast is looked up through the operand's defining op,
+/// so it may live outside of the region that holds the outer cast.
+struct TestFoldCastOpPattern : public OpRewritePattern<TestCastOp> {
+  using OpRewritePattern<TestCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TestCastOp op,
+                                PatternRewriter &rewriter) const final {
+    // Only casts with exactly one operand are candidates.
+    if (op.getNumOperands() != 1)
+      return failure();
+    auto definingCastOp = op->getOperand(0).getDefiningOp<TestCastOp>();
+    if (!definingCastOp || definingCastOp->getNumOperands() != 1)
+      return failure();
+    // Fold only if the value entering the inner cast already has the type
+    // produced by the outer cast.
+    Value source = definingCastOp.getOperand(0);
+    if (source.getType() != op.getResult().getType())
+      return failure();
+    rewriter.replaceOp(op, source);
+    return success();
+  }
+};
+
+/// Test pass that greedily applies `TestFoldCastOpPattern` to every
+/// `OneRegionOp` nested under the pass root, with
+/// `enableRewritesForNonIsolatedFromAboveOps` switched on.
+///
+/// NOTE(review): assumes the pass is not anchored to a specific operation
+/// type (`OperationPass<>`) -- TODO confirm.
+struct TestNonIsolatedFromAbovePatternDriver
+    : public PassWrapper<TestNonIsolatedFromAbovePatternDriver,
+                         OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestNonIsolatedFromAbovePatternDriver)
+
+  StringRef getArgument() const final {
+    return "test-pattern-non-isolated-from-above";
+  }
+  StringRef getDescription() const final {
+    return "Test pattern application within operation that is not isolated "
+           "from above";
+  }
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    mlir::RewritePatternSet patterns(context);
+    patterns.insert<TestFoldCastOpPattern>(context);
+    // Freeze the set once, before the loop. Passing `std::move(patterns)` on
+    // every iteration would hand a moved-from set to all calls but the first.
+    FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+
+    // Gather the roots first so that no rewrite runs during the walk.
+    SmallVector<OneRegionOp> oneRegionOps;
+    getOperation()->walk(
+        [&](OneRegionOp oneRegionOp) { oneRegionOps.push_back(oneRegionOp); });
+
+    GreedyRewriteConfig config;
+    config.enableRewritesForNonIsolatedFromAboveOps = true;
+    for (OneRegionOp oneRegionOp : oneRegionOps) {
+      if (failed(applyPatternsAndFoldGreedily(oneRegionOp, frozenPatterns,
+                                              config))) {
+        return signalPassFailure();
+      }
+    }
+  }
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // PassRegistration
 //===----------------------------------------------------------------------===//
@@ -1725,6 +1798,8 @@
 
   PassRegistration<TestMergeBlocksPatternDriver>();
   PassRegistration<TestSelectiveReplacementPatternDriver>();
+
+  PassRegistration<TestNonIsolatedFromAbovePatternDriver>();
 }
 } // namespace test
 } // namespace mlir