diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -172,6 +172,10 @@ in which patterns are applied is unspecified; i.e., the ordering of ops in the region of this op is irrelevant. + If `apply_cse` is set, the greedy pattern rewrite is interleaved with + common subexpression elimination (CSE): both are repeated until a fixpoint + is reached. + This transform only reads the target handle and modifies the payload. If a pattern erases or replaces a tracked op, the mapping is updated accordingly. @@ -188,7 +192,7 @@ }]; let arguments = (ins - TransformHandleTypeInterface:$target); + TransformHandleTypeInterface:$target, UnitAttr:$apply_cse); let results = (outs); let regions = (region MaxSizedRegion<1>:$region); diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -317,32 +317,52 @@ GreedyRewriteConfig config; config.listener = static_cast<RewriterBase::Listener *>(rewriter.getListener()); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + + // Apply patterns and CSE repeatedly until a fixpoint is reached. If no CSE + // was requested, apply the greedy pattern rewrite only once. (The greedy + // pattern rewrite driver already iterates to a fixpoint internally.) + bool cseChanged = false; + // One or two iterations should be sufficient. Stop iterating after a certain + // threshold to make debugging easier. + static const int64_t kNumMaxIterations = 50; + int64_t iteration = 0; + do { + LogicalResult result = failure(); + if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) { + // Op is isolated from above. Apply patterns and also perform region + // simplification. 
+ result = applyPatternsAndFoldGreedily(target, frozenPatterns, config); + } else { + // Manually gather list of ops because the other + // GreedyPatternRewriteDriver overloads only accept ops that are isolated + // from above. This way, patterns can be applied to ops that are not + // isolated from above. Regions are not being simplified. Furthermore, + // only a single greedy rewrite iteration is performed. + SmallVector<Operation *> ops; + target->walk([&](Operation *nestedOp) { + if (target != nestedOp) + ops.push_back(nestedOp); + }); + result = applyOpPatternsAndFold(ops, frozenPatterns, config); + } - LogicalResult result = failure(); - if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) { - // Op is isolated from above. Apply patterns and also perform region - // simplification. - result = applyPatternsAndFoldGreedily(target, std::move(patterns), config); - } else { - // Manually gather list of ops because the other GreedyPatternRewriteDriver - // overloads only accepts ops that are isolated from above. This way, - // patterns can be applied to ops that are not isolated from above. Regions - // are not being simplified. Furthermore, only a single greedy rewrite - // iteration is performed. - SmallVector<Operation *> ops; - target->walk([&](Operation *nestedOp) { - if (target != nestedOp) - ops.push_back(nestedOp); - }); - result = applyOpPatternsAndFold(ops, std::move(patterns), config); - } + // A failure typically indicates that the pattern application did not + // converge. + if (failed(result)) { + return emitSilenceableFailure(target) + << "greedy pattern application failed"; + } - // A failure typically indicates that the pattern application did not - // converge. 
- if (failed(result)) { - return emitSilenceableFailure(target) - << "greedy pattern application failed"; - } + if (getApplyCse()) { + DominanceInfo domInfo; + mlir::eliminateCommonSubExpressions(rewriter, domInfo, target, + &cseChanged); + } + } while (cseChanged && ++iteration < kNumMaxIterations); + + if (iteration == kNumMaxIterations) + return emitDefiniteFailure() << "fixpoint iteration did not converge"; return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir --- a/mlir/test/Dialect/Transform/test-pattern-application.mlir +++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir @@ -210,3 +210,24 @@ } } } + +// ----- + +// CHECK-LABEL: func @canonicalization_and_cse( +// CHECK-NOT: memref.subview +// CHECK-NOT: memref.copy +func.func @canonicalization_and_cse(%m: memref<5xf32>) { + %c2 = arith.constant 2 : index + %s0 = memref.subview %m[1] [2] [1] : memref<5xf32> to memref<2xf32, strided<[1], offset: 1>> + %s1 = memref.subview %m[1] [%c2] [1] : memref<5xf32> to memref<?xf32, strided<[1], offset: 1>> + memref.copy %s0, %s1 : memref<2xf32, strided<[1], offset: 1>> to memref<?xf32, strided<[1], offset: 1>> + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %1 { + transform.apply_patterns.canonicalization + } {apply_cse} : !transform.any_op +}