diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -18,6 +18,17 @@ namespace mlir { +/// This enum controls which ops are put on the worklist during a greedy +/// pattern rewrite. +enum class GreedyRewriteStrictness { + /// No restrictions wrt. which ops are processed. + AnyOp, + /// Only pre-existing and newly created ops are processed. + ExistingAndNewOps, + /// Only pre-existing ops are processed. + ExistingOps +}; + /// This class allows control over how the GreedyPatternRewriteDriver works. class GreedyRewriteConfig { public: @@ -88,22 +99,29 @@ bool *erased = nullptr); /// Applies the specified rewrite patterns on `ops` while also trying to fold -/// these ops as well as any other ops that were in turn created due to such -/// rewrites. Furthermore, any pre-existing ops in the IR outside of `ops` -/// remain completely unmodified if `strict` is set to true. If `strict` is -/// false, other operations that use results of rewritten ops or supply operands -/// to such ops are in turn simplified; any other ops still remain unmodified -/// (i.e., regardless of `strict`). Note that ops in `ops` could be erased as a -/// result of folding, becoming dead, or via pattern rewrites. If more far -/// reaching simplification is desired, applyPatternsAndFoldGreedily should be -/// used. +/// these ops. +/// +/// Newly created ops and other pre-existing ops that use results of rewritten +/// ops or supply operands to such ops are simplified, unless such ops are +/// excluded via `strictMode`. Any other ops remain unmodified (i.e., regardless +/// of `strictMode`). +/// +/// * GreedyRewriteStrictness::AnyOp: No ops are excluded. +/// * GreedyRewriteStrictness::ExistingAndNewOps: Only ops in `ops` and newly +/// created ops are simplified. All other ops are excluded.
+/// * GreedyRewriteStrictness::ExistingOps: Only ops in `ops` are +/// simplified. All other ops are excluded. +/// +/// Note that ops in `ops` could be erased as a result of folding, becoming +/// dead, or via pattern rewrites. If more far-reaching simplification is +/// desired, applyPatternsAndFoldGreedily should be used. /// -/// `changed` is set to true if the IR was modified at all. Returns success if -/// Returns success if the iterative process converged and no more patterns can -/// be matched. +/// `changed` is set to true if the IR was modified at all. Returns success if +/// the iterative process converged and no more patterns can be matched. LogicalResult applyOpPatternsAndFold(ArrayRef ops, const FrozenRewritePatternSet &patterns, - bool strict, bool *changed = nullptr); + GreedyRewriteStrictness strictMode, + bool *changed = nullptr); } // namespace mlir diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -132,7 +132,8 @@ FrozenRewritePatternSet frozenPatterns(std::move(patterns)); // Apply the simplification pattern to a fixpoint.
if (failed( - applyOpPatternsAndFold(targets, frozenPatterns, /*strict=*/true))) { + applyOpPatternsAndFold(targets, frozenPatterns, + GreedyRewriteStrictness::ExistingAndNewOps))) { auto diag = emitDefiniteFailure() << "affine.min/max simplification did not converge"; return diag; diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -238,5 +238,6 @@ AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext()); AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); - (void)applyOpPatternsAndFold(copyOps, frozenPatterns, /*strict=*/true); + (void)applyOpPatternsAndFold(copyOps, frozenPatterns, + GreedyRewriteStrictness::ExistingAndNewOps); } diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp @@ -105,5 +105,6 @@ if (isa(op)) opsToSimplify.push_back(op); }); - (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, /*strict=*/true); + (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, + GreedyRewriteStrictness::ExistingAndNewOps); } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -570,66 +570,60 @@ public: explicit MultiOpPatternRewriteDriver(MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - bool strict) + GreedyRewriteStrictness strictMode) : GreedyPatternRewriteDriver(ctx, patterns,
GreedyRewriteConfig()), - strictMode(strict) {} + strictMode(strictMode) {} + /// Performs the specified rewrites on `ops` while also trying to fold these + /// ops. `strictMode` controls which other ops are simplified. + /// + /// Note that ops in `ops` could be erased as a result of folding, becoming + /// dead, or via pattern rewrites. The return value indicates convergence. LogicalResult simplifyLocally(ArrayRef op, bool *changed = nullptr); void addToWorklist(Operation *op) override { - if (!strictMode || strictModeFilteredOps.contains(op)) + if (strictMode == GreedyRewriteStrictness::AnyOp || + strictModeFilteredOps.contains(op)) GreedyPatternRewriteDriver::addSingleOpToWorklist(op); } private: void notifyOperationInserted(Operation *op) override { - if (strictMode) + if (strictMode == GreedyRewriteStrictness::ExistingAndNewOps) strictModeFilteredOps.insert(op); GreedyPatternRewriteDriver::notifyOperationInserted(op); } void notifyOperationRemoved(Operation *op) override { GreedyPatternRewriteDriver::notifyOperationRemoved(op); - if (strictMode) + if (strictMode != GreedyRewriteStrictness::AnyOp) strictModeFilteredOps.erase(op); } - /// If `strictMode` is true, any pre-existing ops outside of - /// `strictModeFilteredOps` remain completely untouched by the rewrite driver. - /// If `strictMode` is false, operations that use results of (or supply - /// operands to) any rewritten ops stemming from the simplification of the - /// provided ops are in turn simplified; any other ops still remain untouched - /// (i.e., regardless of `strictMode`). - bool strictMode = false; - - /// The list of ops we are restricting our rewrites to if `strictMode` is on. - /// These include the supplied set of ops as well as new ops created while - /// rewriting those ops. This set is not maintained when strictMode is off. + /// `strictMode` controls which ops are added to the worklist during + /// simplification.
+ GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp; + + /// The list of ops we are restricting our rewrites to. These include the + /// supplied set of ops and, in ExistingAndNewOps mode, the new ops created + /// while rewriting them. This set is not maintained when `strictMode` is + /// GreedyRewriteStrictness::AnyOp. llvm::SmallDenseSet strictModeFilteredOps; }; } // namespace -/// Performs the specified rewrites on `ops` while also trying to fold these ops -/// as well as any other ops that were in turn created due to these rewrite -/// patterns. Any pre-existing ops outside of `ops` remain completely -/// unmodified if `strictMode` is true. If `strictMode` is false, other -/// operations that use results of rewritten ops or supply operands to such ops -/// are in turn simplified; any other ops still remain unmodified (i.e., -/// regardless of `strictMode`). Note that ops in `ops` could be erased as a -/// result of folding, becoming dead, or via pattern rewrites. Returns true if -/// at all any changes happened. // Unlike `OpPatternRewriteDriver::simplifyLocally` which works on a single op // or GreedyPatternRewriteDriver::simplify, this method just iterates until // the worklist is empty. As our objective is to keep simplification "local", // there is no strong rationale to re-add all operations into the worklist and // rerun until an iteration changes nothing. If more widereaching simplification // is desired, GreedyPatternRewriteDriver should be used.
LogicalResult MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef ops, bool *changed) { - if (strictMode) { + if (strictMode != GreedyRewriteStrictness::AnyOp) { strictModeFilteredOps.clear(); strictModeFilteredOps.insert(ops.begin(), ops.end()); } @@ -654,7 +648,8 @@ if (op == nullptr) continue; - assert((!strictMode || strictModeFilteredOps.contains(op)) && + assert((strictMode == GreedyRewriteStrictness::AnyOp || + strictModeFilteredOps.contains(op)) && "unexpected op was inserted under strict mode"); // If the operation is trivially dead - remove it. @@ -713,9 +708,6 @@ return success(worklist.empty()); } -/// Rewrites only `op` using the supplied canonicalization patterns and -/// folding. `erased` is set to true if the op is erased as a result of being -/// folded, replaced, or dead. LogicalResult mlir::applyOpPatternsAndFold( Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) { // Start the pattern driver. @@ -733,10 +725,9 @@ return converged; } -LogicalResult -mlir::applyOpPatternsAndFold(ArrayRef ops, - const FrozenRewritePatternSet &patterns, - bool strict, bool *changed) { +LogicalResult mlir::applyOpPatternsAndFold( + ArrayRef ops, const FrozenRewritePatternSet &patterns, + GreedyRewriteStrictness strictMode, bool *changed) { if (ops.empty()) { if (changed) *changed = false; @@ -745,6 +736,6 @@ // Start the pattern driver.
MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, - strict); + strictMode); return driver.simplifyLocally(ops, changed); } diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir --- a/mlir/test/Transforms/test-strict-pattern-driver.mlir +++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir @@ -1,9 +1,15 @@ -// RUN: mlir-opt -allow-unregistered-dialect -test-strict-pattern-driver %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect \ +// RUN: -test-strict-pattern-driver="strictness=ExistingAndNewOps" \ +// RUN: --split-input-file %s | FileCheck %s --check-prefix=CHECK-EN -// CHECK-LABEL: func @test_erase -// CHECK: test.arg0 -// CHECK: test.arg1 -// CHECK-NOT: test.erase_op +// RUN: mlir-opt -allow-unregistered-dialect \ +// RUN: -test-strict-pattern-driver="strictness=ExistingOps" \ +// RUN: --split-input-file %s | FileCheck %s --check-prefix=CHECK-EX + +// CHECK-EN-LABEL: func @test_erase +// CHECK-EN: test.arg0 +// CHECK-EN: test.arg1 +// CHECK-EN-NOT: test.erase_op func.func @test_erase() { %0 = "test.arg0"() : () -> (i32) %1 = "test.arg1"() : () -> (i32) @@ -11,18 +17,22 @@ return } -// CHECK-LABEL: func @test_insert_same_op -// CHECK: "test.insert_same_op"() {skip = true} -// CHECK: "test.insert_same_op"() {skip = true} +// ----- + +// CHECK-EN-LABEL: func @test_insert_same_op +// CHECK-EN: "test.insert_same_op"() {skip = true} +// CHECK-EN: "test.insert_same_op"() {skip = true} func.func @test_insert_same_op() { %0 = "test.insert_same_op"() : () -> (i32) return } -// CHECK-LABEL: func @test_replace_with_new_op -// CHECK: %[[n:.*]] = "test.new_op" -// CHECK: "test.dummy_user"(%[[n]]) -// CHECK: "test.dummy_user"(%[[n]]) +// ----- + +// CHECK-EN-LABEL: func @test_replace_with_new_op +// CHECK-EN: %[[n:.*]] = "test.new_op" +// CHECK-EN: "test.dummy_user"(%[[n]]) +// CHECK-EN: "test.dummy_user"(%[[n]]) func.func @test_replace_with_new_op() { %0 =
"test.replace_with_new_op"() : () -> (i32) %1 = "test.dummy_user"(%0) : (i32) -> (i32) @@ -30,9 +40,15 @@ return } -// CHECK-LABEL: func @test_replace_with_erase_op -// CHECK-NOT: test.replace_with_new_op -// CHECK-NOT: test.erase_op +// ----- + +// CHECK-EN-LABEL: func @test_replace_with_erase_op +// CHECK-EN-NOT: test.replace_with_new_op +// CHECK-EN-NOT: test.erase_op + +// CHECK-EX-LABEL: func @test_replace_with_erase_op +// CHECK-EX-NOT: test.replace_with_new_op +// CHECK-EX: test.erase_op func.func @test_replace_with_erase_op() { "test.replace_with_new_op"() {create_erase_op} : () -> () return diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp --- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp +++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp @@ -132,7 +132,8 @@ AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); } } - (void)applyOpPatternsAndFold(copyOps, std::move(patterns), /*strict=*/true); + (void)applyOpPatternsAndFold(copyOps, std::move(patterns), + GreedyRewriteStrictness::ExistingAndNewOps); } namespace mlir { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -244,11 +244,14 @@ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver) TestStrictPatternDriver() = default; - TestStrictPatternDriver(const TestStrictPatternDriver &other) = default; + TestStrictPatternDriver(const TestStrictPatternDriver &other) + : PassWrapper(other) { + strictMode = other.strictMode; + } StringRef getArgument() const final { return "test-strict-pattern-driver"; } StringRef getDescription() const final { - return "Run strict mode of pattern driver"; + return "Test strict mode of pattern driver"; } void runOnOperation() override { @@ -263,13 +266,30 @@ } }); + GreedyRewriteStrictness mode; + if (strictMode == "AnyOp") { + mode =
GreedyRewriteStrictness::AnyOp; + } else if (strictMode == "ExistingAndNewOps") { + mode = GreedyRewriteStrictness::ExistingAndNewOps; + } else if (strictMode == "ExistingOps") { + mode = GreedyRewriteStrictness::ExistingOps; + } else { + getOperation()->emitError() + << "invalid strictness option: " << strictMode.getValue(); + return signalPassFailure(); + } + - // Check if these transformations introduce visiting of operations that - // are not in the `ops` set (The new created ops are valid). An invalid - // operation will trigger the assertion while processing. - (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), - /*strict=*/true); + // Check if these transformations introduce visiting of operations that + // are not permitted by `mode` (e.g., newly created ops under ExistingOps). + // Such an operation will trigger the assertion while processing. + (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode); } + Option strictMode{ + *this, "strictness", + llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"), + llvm::cl::init("AnyOp")}; + private: // New inserted operation is valid for further transformation. class InsertSameOp : public RewritePattern {