diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -455,6 +455,15 @@
 /// Rewrite the given regions, which must be isolated from above.
 bool applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
                                   const OwningRewritePatternList &patterns);
+
+/// Applies the specified patterns on `op` alone while also trying to fold it,
+/// by selecting the highest benefit patterns in a greedy manner. Returns true
+/// if no more patterns can be matched. `erased` is set to true if `op` was
+/// folded away or erased as a result of becoming dead. Note: This does not
+/// apply any patterns recursively to the regions of `op`.
+bool applyOpPatternsAndFold(Operation *op,
+                            const OwningRewritePatternList &patterns,
+                            bool *erased = nullptr);
 
 } // end namespace mlir
 
 #endif // MLIR_PATTERN_MATCH_H
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -211,20 +211,25 @@
   for (auto &block : f)
     runOnBlock(&block, copyNests);
 
-  // Promote any single iteration loops in the copy nests.
+  // Promote any single iteration loops in the copy nests and collect
+  // load/stores to simplify.
+  SmallVector<Operation *, 4> copyOps;
   for (auto nest : copyNests)
-    nest->walk([](AffineForOp forOp) { promoteIfSingleIteration(forOp); });
+    // With a post order walk, the erasure of loops does not affect
+    // continuation of the walk or the collection of load/store ops.
+    nest->walk([&](Operation *op) {
+      if (auto forOp = dyn_cast<AffineForOp>(op))
+        promoteIfSingleIteration(forOp);
+      else if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op))
+        copyOps.push_back(op);
+    });
 
   // Promoting single iteration loops could lead to simplification of
-  // load's/store's. We will run canonicalization patterns on load/stores.
-  // TODO: this whole function load/store canonicalization should be replaced by
-  // canonicalization that is limited to only the load/store ops
-  // introduced/touched by this pass (those inside 'copyNests'). This would be
-  // possible once the necessary support is available in the pattern rewriter.
-  if (!copyNests.empty()) {
-    OwningRewritePatternList patterns;
-    AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
-    AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
-    applyPatternsAndFoldGreedily(f, std::move(patterns));
-  }
+  // contained load's/store's, and the latter could anyway also be
+  // canonicalized.
+  OwningRewritePatternList patterns;
+  AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
+  AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
+  for (auto op : copyOps)
+    applyOpPatternsAndFold(op, patterns);
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -6,14 +6,16 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements a pass to simplify affine structures.
+// This file implements a pass to simplify affine structures in operations.
 //
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
 #include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/Utils.h"
 
 #define DEBUG_TYPE "simplify-affine-structure"
@@ -77,13 +79,22 @@
 void SimplifyAffineStructures::runOnFunction() {
   auto func = getFunction();
   simplifiedAttributes.clear();
-  func.walk([&](Operation *opInst) {
-    for (auto attr : opInst->getAttrs()) {
+  OwningRewritePatternList patterns;
+  AffineForOp::getCanonicalizationPatterns(patterns, func.getContext());
+  AffineIfOp::getCanonicalizationPatterns(patterns, func.getContext());
+  AffineApplyOp::getCanonicalizationPatterns(patterns, func.getContext());
+  func.walk([&](Operation *op) {
+    for (auto attr : op->getAttrs()) {
       if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>())
-        simplifyAndUpdateAttribute(opInst, attr.first, mapAttr);
+        simplifyAndUpdateAttribute(op, attr.first, mapAttr);
       else if (auto setAttr = attr.second.dyn_cast<IntegerSetAttr>())
-        simplifyAndUpdateAttribute(opInst, attr.first, setAttr);
+        simplifyAndUpdateAttribute(op, attr.first, setAttr);
     }
+
+    // The simplification of the attribute will likely simplify the op. Try to
+    // fold / apply canonicalization patterns when we have affine dialect ops.
+    if (isa<AffineForOp>(op) || isa<AffineIfOp>(op) || isa<AffineApplyOp>(op))
+      applyOpPatternsAndFold(op, patterns);
   });
 
   // Turn memrefs' non-identity layouts maps into ones with identity. Collect
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -26,6 +26,10 @@
 /// The max number of iterations scanning for pattern match.
 static unsigned maxPatternMatchIterations = 10;
 
+//===----------------------------------------------------------------------===//
+// GreedyPatternRewriteDriver
+//===----------------------------------------------------------------------===//
+
 namespace {
 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
 /// applies the locally optimal patterns in a roughly "bottom up" way.
@@ -37,8 +41,6 @@
     worklist.reserve(64);
   }
 
-  /// Perform the rewrites while folding and erasing any dead ops. Return true
-  /// if the rewrite converges in `maxIterations`.
   bool simplify(MutableArrayRef<Region> regions, int maxIterations);
 
   void addToWorklist(Operation *op) {
@@ -248,3 +250,112 @@
   });
   return converged;
 }
+
+//===----------------------------------------------------------------------===//
+// OpPatternRewriteDriver
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This is a simple driver for the PatternMatcher to apply patterns and perform
+/// folding on a single op. It repeatedly applies locally optimal patterns.
+class OpPatternRewriteDriver : public PatternRewriter {
+public:
+  explicit OpPatternRewriteDriver(MLIRContext *ctx,
+                                  const OwningRewritePatternList &patterns)
+      : PatternRewriter(ctx), matcher(patterns), folder(ctx) {}
+
+  bool simplifyLocally(Operation *op, int maxIterations, bool &erased);
+
+  /// No additional action needed other than inserting the op.
+  Operation *insert(Operation *op) override { return OpBuilder::insert(op); }
+
+  // These are hooks implemented for PatternRewriter.
+protected:
+  /// If an operation is about to be removed, mark it so that we can let
+  /// clients know.
+  void notifyOperationRemoved(Operation *op) override {
+    opErasedViaPatternRewrites = true;
+  }
+
+  // When a root is going to be replaced, its removal will be notified as well.
+  // So there is nothing to do here.
+  void notifyRootReplaced(Operation *op) override {}
+
+private:
+  /// The low-level pattern matcher.
+  RewritePatternMatcher matcher;
+
+  /// Non-pattern based folder for operations.
+  OperationFolder folder;
+
+  /// Set to true if the operation has been erased via pattern rewrites.
+  bool opErasedViaPatternRewrites = false;
+};
+
+} // anonymous namespace
+
+/// Performs the rewrites and folding only on `op`. The simplification converges
+/// if the op is erased as a result of being folded, replaced, or dead, or no
+/// more changes happen in an iteration. Returns true if the rewrite converges
+/// in `maxIterations`. `erased` is set to true if `op` gets erased.
+bool OpPatternRewriteDriver::simplifyLocally(Operation *op, int maxIterations,
+                                             bool &erased) {
+  bool changed = false;
+  erased = false;
+  opErasedViaPatternRewrites = false;
+  int i = 0;
+  // Iterate until convergence or until maxIterations. Deletion of the op as
+  // a result of being dead or folded is convergence.
+  do {
+    // If the operation is trivially dead - remove it.
+    if (isOpTriviallyDead(op)) {
+      op->erase();
+      erased = true;
+      return true;
+    }
+
+    // Try to fold this op.
+    bool inPlaceUpdate;
+    if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr,
+                                   /*preReplaceAction=*/nullptr,
+                                   &inPlaceUpdate))) {
+      changed = true;
+      if (!inPlaceUpdate) {
+        erased = true;
+        return true;
+      }
+    }
+
+    // Make sure that any new operations are inserted at this point.
+    setInsertionPoint(op);
+
+    // Try to match one of the patterns. The rewriter is automatically
+    // notified of any necessary changes, so there is nothing else to do here.
+    changed |= matcher.matchAndRewrite(op, *this);
+    if ((erased = opErasedViaPatternRewrites))
+      return true;
+  } while (changed && ++i < maxIterations);
+
+  // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
+  return !changed;
+}
+
+/// Rewrites only `op` using the supplied canonicalization patterns and
+/// folding. `erased` is set to true if the op is erased as a result of being
+/// folded, replaced, or dead.
+bool mlir::applyOpPatternsAndFold(Operation *op,
+                                  const OwningRewritePatternList &patterns,
+                                  bool *erased) {
+  // Start the pattern driver.
+  OpPatternRewriteDriver driver(op->getContext(), patterns);
+  bool opErased;
+  bool converged =
+      driver.simplifyLocally(op, maxPatternMatchIterations, opErased);
+  if (erased)
+    *erased = opErased;
+  LLVM_DEBUG(if (!converged) {
+    llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
+                 << maxPatternMatchIterations << " times";
+  });
+  return converged;
+}
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -23,6 +23,7 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "mlir/Transforms/Utils.h"
 #include "llvm/ADT/DenseMap.h"
@@ -312,9 +313,19 @@
                                   opGroupQueue, /*offset=*/0, forOp, b);
         lbShift = d * step;
       }
-      if (!prologue && res)
-        prologue = res;
-      epilogue = res;
+
+      if (res) {
+        // Simplify/canonicalize the affine.for.
+        OwningRewritePatternList patterns;
+        AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
+        bool erased;
+        applyOpPatternsAndFold(res, std::move(patterns), &erased);
+
+        if (!erased && !prologue)
+          prologue = res;
+        if (!erased)
+          epilogue = res;
+      }
     } else {
       // Start of first interval.
       lbShift = d * step;
@@ -694,7 +705,8 @@
 }
 
 /// Return true if `loops` is a perfect nest.
-static bool LLVM_ATTRIBUTE_UNUSED isPerfectlyNested(ArrayRef<AffineForOp> loops) {
+static bool LLVM_ATTRIBUTE_UNUSED
+isPerfectlyNested(ArrayRef<AffineForOp> loops) {
   auto outerLoop = loops.front();
   for (auto loop : loops.drop_front()) {
     auto parentForOp = dyn_cast<AffineForOp>(loop.getParentOp());
diff --git a/mlir/test/Dialect/Affine/affine-data-copy.mlir b/mlir/test/Dialect/Affine/affine-data-copy.mlir
--- a/mlir/test/Dialect/Affine/affine-data-copy.mlir
+++ b/mlir/test/Dialect/Affine/affine-data-copy.mlir
@@ -216,7 +216,7 @@
   return %A : memref<4096xf32>
 }
 // CHECK:      affine.for %[[IV1:.*]] = 0 to 4096 step 100
-// CHECK-NEXT:   %[[BUF:.*]] = alloc() : memref<100xf32>
+// CHECK:        %[[BUF:.*]] = alloc() : memref<100xf32>
 // CHECK-NEXT:   affine.for %[[IV2:.*]] = #[[MAP_IDENTITY]](%[[IV1]]) to min #[[MAP_MIN_UB1]](%[[IV1]]) {
 // CHECK-NEXT:     affine.load %{{.*}}[%[[IV2]]] : memref<4096xf32>
 // CHECK-NEXT:     affine.store %{{.*}}, %[[BUF]][-%[[IV1]] + %[[IV2]]] : memref<100xf32>
@@ -226,7 +226,7 @@
 // CHECK-NEXT:     mulf
 // CHECK-NEXT:     affine.store %{{.*}}, %[[BUF]][-%[[IV1]] + %[[IV2]]] : memref<100xf32>
 // CHECK-NEXT:   }
-// CHECK-NEXT:   affine.for %[[IV2:.*]] = #[[MAP_IDENTITY]](%[[IV1]]) to min #[[MAP_MIN_UB1]](%[[IV1]]) {
+// CHECK:        affine.for %[[IV2:.*]] = #[[MAP_IDENTITY]](%[[IV1]]) to min #[[MAP_MIN_UB1]](%[[IV1]]) {
 // CHECK-NEXT:     affine.load %[[BUF]][-%[[IV1]] + %[[IV2]]] : memref<100xf32>
 // CHECK-NEXT:     affine.store %{{.*}}, %{{.*}}[%[[IV2]]] : memref<4096xf32>
 // CHECK-NEXT:   }
@@ -239,8 +239,8 @@
 // with multi-level tiling when the tile sizes used don't divide loop trip
 // counts.
-#lb = affine_map<(d0, d1) -> (d0 * 512, d1 * 6)>
-#ub = affine_map<(d0, d1) -> (d0 * 512 + 512, d1 * 6 + 6)>
+#lb = affine_map<()[s0, s1] -> (s0 * 512, s1 * 6)>
+#ub = affine_map<()[s0, s1] -> (s0 * 512 + 512, s1 * 6 + 6)>
 
 // CHECK-DAG: #[[LB:.*]] = affine_map<()[s0, s1] -> (s0 * 512, s1 * 6)>
 // CHECK-DAG: #[[UB:.*]] = affine_map<()[s0, s1] -> (s0 * 512 + 512, s1 * 6 + 6)>
@@ -250,7 +250,7 @@
 // CHECK-SAME: [[i:arg[0-9]+]]
 // CHECK-SAME: [[j:arg[0-9]+]]
 func @max_lower_bound(%M: memref<2048x516xf64>, %i : index, %j : index) {
   affine.for %ii = 0 to 2048 {
-    affine.for %jj = max #lb(%i, %j) to min #ub(%i, %j) {
+    affine.for %jj = max #lb()[%i, %j] to min #ub()[%i, %j] {
       affine.load %M[%ii, %jj] : memref<2048x516xf64>
     }
   }
diff --git a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
--- a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
@@ -1,19 +1,19 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -simplify-affine-structures | FileCheck %s
 
-// CHECK-DAG: #[[SET_EMPTY_2D:.*]] = affine_set<(d0, d1) : (1 == 0)>
+// CHECK-DAG: #[[SET_EMPTY:.*]] = affine_set<() : (1 == 0)>
 // CHECK-DAG: #[[SET_2D:.*]] = affine_set<(d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, -d0 + 100 >= 0, d1 >= 0)>
-// CHECK-DAG: #[[SET_EMPTY_2D_2S:.*]] = affine_set<(d0, d1)[s0, s1] : (1 == 0)>
-// CHECK-DAG: #[[SET_2D_2S:.*]] = affine_set<(d0, d1)[s0, s1] : (d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0, d0 * 5 - d1 * 11 + s0 * 7 + s1 == 0, d0 * 11 + d1 * 7 - s0 * 5 + s1 == 0, d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0)>
-// CHECK-DAG: #[[SET_EMPTY_1D:.*]] = affine_set<(d0) : (1 == 0)>
-// CHECK-DAG: #[[SET_EMPTY_1D_2S:.*]] = affine_set<(d0)[s0, s1] : (1 == 0)>
-// CHECK-DAG: #[[SET_EMPTY_3D:.*]] = affine_set<(d0, d1, d2) : (1 == 0)>
+// CHECK-DAG: #[[SET_7_11:.*]] = affine_set<(d0, d1) : (d0 * 7 + d1 * 5 + 88 == 0, d0 * 5 - d1 * 11 + 60 == 0, d0 * 11 + d1 * 7 - 24 == 0, d0 * 7 + d1 * 5 + 88 == 0)>
+
+// An external function that we will use in bodies to avoid DCE.
+func @external() -> ()
 
 // CHECK-LABEL: func @test_gaussian_elimination_empty_set0() {
 func @test_gaussian_elimination_empty_set0() {
   affine.for %arg0 = 1 to 10 {
     affine.for %arg1 = 1 to 100 {
-      // CHECK: [[SET_EMPTY_2D]](%arg0, %arg1)
+      // CHECK: affine.if #[[SET_EMPTY]]()
       affine.if affine_set<(d0, d1) : (2 == 0)>(%arg0, %arg1) {
+        call @external() : () -> ()
       }
     }
   }
@@ -24,8 +24,9 @@
 func @test_gaussian_elimination_empty_set1() {
   affine.for %arg0 = 1 to 10 {
     affine.for %arg1 = 1 to 100 {
-      // CHECK: [[SET_EMPTY_2D]](%arg0, %arg1)
+      // CHECK: affine.if #[[SET_EMPTY]]()
       affine.if affine_set<(d0, d1) : (1 >= 0, -1 >= 0)> (%arg0, %arg1) {
+        call @external() : () -> ()
       }
     }
   }
@@ -38,6 +39,7 @@
   affine.for %arg1 = 1 to 100 {
     // CHECK: #[[SET_2D]](%arg0, %arg1)
     affine.if affine_set<(d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, -d0 + 100 >= 0, d1 >= 0, d1 + 101 >= 0)>(%arg0, %arg1) {
+      call @external() : () -> ()
     }
   }
 }
@@ -50,8 +52,9 @@
   %c11 = constant 11 : index
   affine.for %arg0 = 1 to 10 {
     affine.for %arg1 = 1 to 100 {
-      // CHECK: #[[SET_EMPTY_2D_2S]](%arg0, %arg1)[%c7, %c11]
+      // CHECK: #[[SET_EMPTY]]()
      affine.if affine_set<(d0, d1)[s0, s1] : (d0 - s0 == 0, d0 + s0 == 0, s0 - 1 == 0)>(%arg0, %arg1)[%c7, %c11] {
+        call @external() : () -> ()
      }
    }
  }
@@ -70,8 +73,9 @@
   %c11 = constant 11 : index
   affine.for %arg0 = 1 to 10 {
     affine.for %arg1 = 1 to 100 {
-      // CHECK: #[[SET_2D_2S]](%arg0, %arg1)[%c7, %c11]
+      // CHECK: #[[SET_7_11]](%arg0, %arg1)
       affine.if #set_2d_non_empty(%arg0, %arg1)[%c7, %c11] {
+        call @external() : () -> ()
       }
     }
   }
@@ -79,7 +83,6 @@
 }
 
 // Add invalid constraints to previous non-empty set to make it empty.
-// Set for test case: test_gaussian_elimination_empty_set5
 #set_2d_empty = affine_set<(d0, d1)[s0, s1] : (d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0,
                                                d0 * 5 - d1 * 11 + s0 * 7 + s1 == 0,
                                                d0 * 11 + d1 * 7 - s0 * 5 + s1 == 0,
@@ -92,8 +95,9 @@
   %c11 = constant 11 : index
   affine.for %arg0 = 1 to 10 {
     affine.for %arg1 = 1 to 100 {
-      // CHECK: #[[SET_EMPTY_2D_2S]](%arg0, %arg1)[%c7, %c11]
+      // CHECK: #[[SET_EMPTY]]()
       affine.if #set_2d_empty(%arg0, %arg1)[%c7, %c11] {
+        call @external() : () -> ()
       }
     }
   }
@@ -147,6 +151,7 @@
   affine.for %arg4 = 1 to 10 {
     affine.for %arg5 = 1 to 100 {
       affine.if #set_fuzz_virus(%arg4, %arg5, %arg0, %arg1, %arg2, %arg3) {
+        call @external() : () -> ()
       }
     }
   }
@@ -157,33 +162,33 @@
 func @test_empty_set(%N : index) {
   affine.for %i = 0 to 10 {
     affine.for %j = 0 to 10 {
-      // CHECK: affine.if #[[SET_EMPTY_2D]](%arg1, %arg2)
+      // CHECK: affine.if #[[SET_EMPTY]]()
       affine.if affine_set<(d0, d1) : (d0 - d1 >= 0, d1 - d0 - 1 >= 0)>(%i, %j) {
        "foo"() : () -> ()
      }
-      // CHECK: affine.if #[[SET_EMPTY_1D]](%arg1)
+      // CHECK: affine.if #[[SET_EMPTY]]()
      affine.if affine_set<(d0) : (d0 >= 0, -d0 - 1 >= 0)>(%i) {
        "bar"() : () -> ()
      }
-      // CHECK: affine.if #[[SET_EMPTY_1D]](%arg1)
+      // CHECK: affine.if #[[SET_EMPTY]]()
      affine.if affine_set<(d0) : (d0 >= 0, -d0 - 1 >= 0)>(%i) {
        "foo"() : () -> ()
      }
-      // CHECK: affine.if #[[SET_EMPTY_1D_2S]](%arg1)[%arg0, %arg0]
+      // CHECK: affine.if #[[SET_EMPTY]]()
      affine.if affine_set<(d0)[s0, s1] : (d0 >= 0, -d0 + s0 - 1 >= 0, -s0 >= 0)>(%i)[%N, %N] {
        "bar"() : () -> ()
      }
-      // CHECK: affine.if #[[SET_EMPTY_3D]](%arg1, %arg2, %arg0)
+      // CHECK: affine.if #[[SET_EMPTY]]()
      // The set below implies d0 = d1; so d1 >= d0, but d0 >= d1 + 1.
      affine.if affine_set<(d0, d1, d2) : (d0 - d1 == 0, d2 - d0 >= 0, d0 - d1 - 1 >= 0)>(%i, %j, %N) {
        "foo"() : () -> ()
      }
-      // CHECK: affine.if #[[SET_EMPTY_2D]](%arg1, %arg2)
+      // CHECK: affine.if #[[SET_EMPTY]]()
      // The set below has rational solutions but no integer solutions; GCD test catches it.
      affine.if affine_set<(d0, d1) : (d0*2 -d1*2 - 1 == 0, d0 >= 0, -d0 + 100 >= 0, d1 >= 0, -d1 + 100 >= 0)>(%i, %j) {
        "foo"() : () -> ()
      }
-      // CHECK: affine.if #[[SET_EMPTY_2D]](%arg1, %arg2)
+      // CHECK: affine.if #[[SET_EMPTY]]()
      affine.if affine_set<(d0, d1) : (d1 == 0, d0 - 1 >= 0, - d0 - 1 >= 0)>(%i, %j) {
        "foo"() : () -> ()
      }
@@ -193,12 +198,12 @@
   affine.for %k = 0 to 10 {
     affine.for %l = 0 to 10 {
       // Empty because no multiple of 8 lies between 4 and 7.
-      // CHECK: affine.if #[[SET_EMPTY_1D]](%arg1)
+      // CHECK: affine.if #[[SET_EMPTY]]()
       affine.if affine_set<(d0) : (8*d0 - 4 >= 0, -8*d0 + 7 >= 0)>(%k) {
         "foo"() : () -> ()
       }
       // Same as above but with equalities and inequalities.
-      // CHECK: affine.if #[[SET_EMPTY_2D]](%arg1, %arg2)
+      // CHECK: affine.if #[[SET_EMPTY]]()
       affine.if affine_set<(d0, d1) : (d0 - 4*d1 == 0, 4*d1 - 5 >= 0, -4*d1 + 7 >= 0)>(%k, %l) {
         "foo"() : () -> ()
       }
@@ -206,12 +211,12 @@
       // 8*d1 here is a multiple of 4, and so can't lie between 9 and 11. GCD
       // tightening will tighten constraints to 4*d0 + 8*d1 >= 12 and 4*d0 +
       // 8*d1 <= 8; hence infeasible.
-      // CHECK: affine.if #[[SET_EMPTY_2D]](%arg1, %arg2)
+      // CHECK: affine.if #[[SET_EMPTY]]()
       affine.if affine_set<(d0, d1) : (4*d0 + 8*d1 - 9 >= 0, -4*d0 - 8*d1 + 11 >= 0)>(%k, %l) {
         "foo"() : () -> ()
       }
       // Same as above but with equalities added into the mix.
-      // CHECK: affine.if #[[SET_EMPTY_3D]](%arg1, %arg1, %arg2)
+      // CHECK: affine.if #[[SET_EMPTY]]()
       affine.if affine_set<(d0, d1, d2) : (d0 - 4*d2 == 0, d0 + 8*d1 - 9 >= 0, -d0 - 8*d1 + 11 >= 0)>(%k, %k, %l) {
         "foo"() : () -> ()
       }
@@ -219,7 +224,7 @@
   }
 
   affine.for %m = 0 to 10 {
-    // CHECK: affine.if #[[SET_EMPTY_1D]](%arg{{[0-9]+}})
+    // CHECK: affine.if #[[SET_EMPTY]]()
     affine.if affine_set<(d0) : (d0 mod 2 - 3 == 0)> (%m) {
       "foo"() : () -> ()
     }
@@ -230,20 +235,39 @@
 // -----
 
-// CHECK-DAG: #[[SET_2D:.*]] = affine_set<(d0, d1) : (d0 >= 0, -d0 + 50 >= 0)
-// CHECK-DAG: #[[SET_EMPTY:.*]] = affine_set<(d0, d1) : (1 == 0)
-// CHECK-DAG: #[[SET_UNIV:.*]] = affine_set<(d0, d1) : (0 == 0)
+// An external function that we will use in bodies to avoid DCE.
+func @external() -> ()
+
+// CHECK-DAG: #[[SET:.*]] = affine_set<()[s0] : (s0 >= 0, -s0 + 50 >= 0)
+// CHECK-DAG: #[[EMPTY_SET:.*]] = affine_set<() : (1 == 0)
+// CHECK-DAG: #[[UNIV_SET:.*]] = affine_set<() : (0 == 0)
 
 // CHECK-LABEL: func @simplify_set
 func @simplify_set(%a : index, %b : index) {
-  // CHECK: affine.if #[[SET_2D]]
+  // CHECK: affine.if #[[SET]]
   affine.if affine_set<(d0, d1) : (d0 - d1 + d1 + d0 >= 0, 2 >= 0, d0 >= 0, -d0 + 50 >= 0, -d0 + 100 >= 0)>(%a, %b) {
+    call @external() : () -> ()
   }
-  // CHECK: affine.if #[[SET_EMPTY]]
+  // CHECK: affine.if #[[EMPTY_SET]]
   affine.if affine_set<(d0, d1) : (d0 mod 2 - 1 == 0, d0 - 2 * (d0 floordiv 2) == 0)>(%a, %b) {
+    call @external() : () -> ()
  }
-  // CHECK: affine.if #[[SET_UNIV]]
+  // CHECK: affine.if #[[UNIV_SET]]
   affine.if affine_set<(d0, d1) : (1 >= 0, 3 >= 0)>(%a, %b) {
+    call @external() : () -> ()
  }
   return
 }
+
+// -----
+
+// CHECK-DAG: -> (s0 * 2 + 1)
+
+// Test "op local" simplification on affine.apply. DCE on addi will not happen.
+func @affine.apply(%N : index) {
+  %v = affine.apply affine_map<(d0, d1) -> (d0 + d1 + 1)>(%N, %N)
+  addi %v, %v : index
+  // CHECK: affine.apply #map{{.*}}()[%arg0]
+  // CHECK-NEXT: addi
+  return
+}
diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
--- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
@@ -88,16 +88,35 @@
     generateCopyForMemRegion(region, loopNest, copyOptions, result);
   }
 
-  // Promote any single iteration loops in the copy nests.
+  // Promote any single iteration loops in the copy nests and simplify
+  // load/stores.
+  SmallVector<Operation *, 4> copyOps;
   for (auto nest : copyNests)
-    nest->walk([](AffineForOp forOp) { promoteIfSingleIteration(forOp); });
-
-  // Promoting single iteration loops could lead to simplification
-  // of load's/store's. We will run the canonicalization patterns again.
+    // With a post order walk, the erasure of loops does not affect
+    // continuation of the walk or the collection of load/store ops.
+    nest->walk([&](Operation *op) {
+      if (auto forOp = dyn_cast<AffineForOp>(op))
+        promoteIfSingleIteration(forOp);
+      else if (auto loadOp = dyn_cast<AffineLoadOp>(op))
+        copyOps.push_back(loadOp);
+      else if (auto storeOp = dyn_cast<AffineStoreOp>(op))
+        copyOps.push_back(storeOp);
+    });
+
+  // Promoting single iteration loops could lead to simplification of
+  // generated load's/store's, and the latter could anyway also be
+  // canonicalized.
   OwningRewritePatternList patterns;
-  AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
-  AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
-  applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  for (auto op : copyOps) {
+    patterns.clear();
+    if (isa<AffineLoadOp>(op)) {
+      AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
+    } else {
+      assert(isa<AffineStoreOp>(op) && "expected affine store op");
+      AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
+    }
+    applyOpPatternsAndFold(op, patterns);
+  }
 }
 
 namespace mlir {
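Illustrative usage (an editor's sketch, not part of the patch): a pass that wants op-local
canonicalization can drive the new applyOpPatternsAndFold entry point roughly as below. The
helper name simplifyAffineLoads and the restriction to affine.load are assumptions made for
the example; the only new API relied on is the declaration added to PatternMatch.h above.

  #include "mlir/Dialect/Affine/IR/AffineOps.h"
  #include "mlir/IR/Function.h"
  #include "mlir/IR/PatternMatch.h"

  using namespace mlir;

  /// Fold/canonicalize each affine.load in `func` in place, leaving every
  /// other operation untouched (unlike applyPatternsAndFoldGreedily, which
  /// rewrites the whole region).
  static void simplifyAffineLoads(FuncOp func) {
    OwningRewritePatternList patterns;
    AffineLoadOp::getCanonicalizationPatterns(patterns, func.getContext());
    // Collect the ops first so that erasing one does not invalidate the walk.
    SmallVector<Operation *, 8> loadOps;
    func.walk([&](AffineLoadOp loadOp) { loadOps.push_back(loadOp); });
    for (Operation *op : loadOps) {
      bool erased;
      applyOpPatternsAndFold(op, patterns, &erased);
      // If `erased` is true, `op` was folded away or became dead and must not
      // be touched again.
      (void)erased;
    }
  }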