diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -77,6 +77,20 @@ const FrozenRewritePatternSet &patterns, bool *erased = nullptr); +/// Applies the specified rewrite patterns on `ops` while also trying to fold +/// these ops as well as any other ops that were in turn created due to such +/// rewrites. Furthermore, any pre-existing ops in the IR outside of `ops` +/// remain completely unmodified if `strict` is set to true. If `strict` is +/// false, other operations that use results of rewritten ops or supply operands +/// to such ops are in turn simplified; any other ops still remain unmodified +/// (i.e., regardless of `strict`). Note that ops in `ops` could be erased as a +/// result of folding, becoming dead, or via pattern rewrites. If more +/// far-reaching simplification is desired, applyPatternsAndFoldGreedily should +/// be used. Returns true if any IR was rewritten. 
+bool applyOpPatternsAndFold(ArrayRef ops, + const FrozenRewritePatternSet &patterns, + bool strict); + } // end namespace mlir #endif // MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_ diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -231,6 +231,5 @@ AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext()); AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); - for (Operation *op : copyOps) - (void)applyOpPatternsAndFold(op, frozenPatterns); + (void)applyOpPatternsAndFold(copyOps, frozenPatterns, /*strict=*/true); } diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp @@ -80,10 +80,14 @@ auto func = getFunction(); simplifiedAttributes.clear(); RewritePatternSet patterns(func.getContext()); + AffineApplyOp::getCanonicalizationPatterns(patterns, func.getContext()); AffineForOp::getCanonicalizationPatterns(patterns, func.getContext()); AffineIfOp::getCanonicalizationPatterns(patterns, func.getContext()); - AffineApplyOp::getCanonicalizationPatterns(patterns, func.getContext()); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + + // The simplification of affine attributes will likely simplify the op. Try to + // fold/apply canonicalization patterns when we have affine dialect ops. 
+ SmallVector opsToSimplify; func.walk([&](Operation *op) { for (auto attr : op->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast()) @@ -92,9 +96,8 @@ simplifyAndUpdateAttribute(op, attr.first, setAttr); } - // The simplification of the attribute will likely simplify the op. Try to - // fold / apply canonicalization patterns when we have affine dialect ops. if (isa(op)) - (void)applyOpPatternsAndFold(op, frozenPatterns); + opsToSimplify.push_back(op); }); + (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, /*strict=*/true); } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -81,6 +81,25 @@ // inserted ops are added to the worklist for processing. void notifyOperationInserted(Operation *op) override { addToWorklist(op); } + // Look over the provided operands for any defining operations that should + // be re-added to the worklist. This function should be called when an + // operation is modified or removed, as it may trigger further + // simplifications. + template + void addToWorklist(Operands &&operands) { + for (Value operand : operands) { + // If the use count of this operand is now < 2, we re-add the defining + // operation to the worklist. + // TODO: This is based on the fact that zero use operations + // may be deleted, and that single use values often have more + // canonicalization opportunities. + if (!operand || (!operand.use_empty() && !operand.hasOneUse())) + continue; + if (auto *defOp = operand.getDefiningOp()) + addToWorklist(defOp); + } + } + // If an operation is about to be removed, make sure it is not in our // worklist anymore because we'd get dangling references to it. 
void notifyOperationRemoved(Operation *op) override { @@ -100,26 +119,6 @@ addToWorklist(user); } -private: - // Look over the provided operands for any defining operations that should - // be re-added to the worklist. This function should be called when an - // operation is modified or removed, as it may trigger further - // simplifications. - template - void addToWorklist(Operands &&operands) { - for (Value operand : operands) { - // If the use count of this operand is now < 2, we re-add the defining - // operation to the worklist. - // TODO: This is based on the fact that zero use operations - // may be deleted, and that single use values often have more - // canonicalization opportunities. - if (!operand || (!operand.use_empty() && !operand.hasOneUse())) - continue; - if (auto *defInst = operand.getDefiningOp()) - addToWorklist(defInst); - } - } - /// The low-level pattern applicator. PatternApplicator matcher; @@ -133,6 +132,7 @@ /// Non-pattern based folder for operations. OperationFolder folder; +private: /// Configuration information for how to simplify. GreedyRewriteConfig config; }; @@ -277,11 +277,6 @@ matcher.applyDefaultCostModel(); } - /// Performs the rewrites and folding only on `op`. The simplification - /// converges if the op is erased as a result of being folded, replaced, or - /// dead, or no more changes happen in an iteration. Returns success if the - /// rewrite converges in `maxIterations`. `erased` is set to true if `op` gets - /// erased. LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased); // These are hooks implemented for PatternRewriter. @@ -309,13 +304,18 @@ } // anonymous namespace +/// Performs the rewrites and folding only on `op`. The simplification +/// converges if the op is erased as a result of being folded, replaced, or +/// becoming dead, or no more changes happen in an iteration. Returns success if +/// the rewrite converges in `maxIterations`. `erased` is set to true if `op` +/// gets erased. 
LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op, int maxIterations, bool &erased) { bool changed = false; erased = false; opErasedViaPatternRewrites = false; - int i = 0; + int iterations = 0; // Iterate until convergence or until maxIterations. Deletion of the op as // a result of being dead or folded is convergence. do { @@ -345,12 +345,162 @@ changed |= succeeded(matcher.matchAndRewrite(op, *this)); if ((erased = opErasedViaPatternRewrites)) return success(); - } while (changed && ++i < maxIterations); + } while (changed && ++iterations < maxIterations); // Whether the rewrite converges, i.e. wasn't changed in the last iteration. return failure(changed); } +//===----------------------------------------------------------------------===// +// MultiOpPatternRewriteDriver +//===----------------------------------------------------------------------===// + +namespace { + +/// This is a specialized GreedyPatternRewriteDriver to apply patterns and +/// perform folding for a supplied set of ops. It repeatedly simplifies while +/// restricting the rewrites to only the provided set of ops or optionally +/// to those directly affected by it (result users or operand providers). +class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver { +public: + explicit MultiOpPatternRewriteDriver(MLIRContext *ctx, + const FrozenRewritePatternSet &patterns, + bool strict) + : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()), + strictMode(strict) {} + + bool simplifyLocally(ArrayRef op); + +private: + // Look over the provided operands for any defining operations that should + // be re-added to the worklist. This function should be called when an + // operation is modified or removed, as it may trigger further + // simplifications. If `strictMode` is true, only ops in + // `strictModeFilteredOps` are considered. 
+ template + void addOperandsToWorklist(Operands &&operands) { + for (Value operand : operands) { + if (auto *defOp = operand.getDefiningOp()) { + if (!strictMode || strictModeFilteredOps.contains(defOp)) + addToWorklist(defOp); + } + } + } + + void notifyOperationRemoved(Operation *op) override { + GreedyPatternRewriteDriver::notifyOperationRemoved(op); + if (strictMode) + strictModeFilteredOps.erase(op); + } + + /// If `strictMode` is true, any pre-existing ops outside of + /// `strictModeFilteredOps` remain completely untouched by the rewrite driver. + /// If `strictMode` is false, operations that use results of (or supply + /// operands to) any rewritten ops stemming from the simplification of the + /// provided ops are in turn simplified; any other ops still remain untouched + /// (i.e., regardless of `strictMode`). + bool strictMode = false; + + /// The set of ops we are restricting our rewrites to if `strictMode` is on. + /// These include the supplied set of ops as well as new ops created while + /// rewriting those ops. This set is not maintained when `strictMode` is off. + llvm::SmallDenseSet strictModeFilteredOps; +}; + +} // end anonymous namespace + +/// Performs the specified rewrites on `ops` while also trying to fold these ops +/// as well as any other ops that were in turn created due to these rewrite +/// patterns. Any pre-existing ops outside of `ops` remain completely +/// unmodified if `strictMode` is true. If `strictMode` is false, other +/// operations that use results of rewritten ops or supply operands to such ops +/// are in turn simplified; any other ops still remain unmodified (i.e., +/// regardless of `strictMode`). Note that ops in `ops` could be erased as a +/// result of folding, becoming dead, or via pattern rewrites. Returns true if +/// any changes happened. 
+// Unlike `OpPatternRewriteDriver::simplifyLocally` which works on a single op +// or `GreedyPatternRewriteDriver::simplify`, this method just iterates until +// the worklist is empty. As our objective is to keep simplification "local", +// there is no strong rationale to re-add all operations into the worklist and +// rerun until an iteration changes nothing. If more wide-reaching +// simplification is desired, GreedyPatternRewriteDriver should be used. +bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef ops) { + if (strictMode) { + strictModeFilteredOps.clear(); + strictModeFilteredOps.insert(ops.begin(), ops.end()); + } + + bool changed = false; + worklist.clear(); + worklistMap.clear(); + for (Operation *op : ops) + addToWorklist(op); + + // These are scratch vectors used in the folding loop below. + SmallVector originalOperands, resultValues; + while (!worklist.empty()) { + Operation *op = popFromWorklist(); + + // Nulls get added to the worklist when operations are removed, ignore + // them. + if (op == nullptr) + continue; + + // If the operation is trivially dead - remove it. + if (isOpTriviallyDead(op)) { + notifyOperationRemoved(op); + op->erase(); + changed = true; + continue; + } + + // Collects all the operands and result uses of the given `op` into work + // list. Also remove `op` and nested ops from worklist. + originalOperands.assign(op->operand_begin(), op->operand_end()); + auto preReplaceAction = [&](Operation *op) { + // Add the operands to the worklist for visitation. + addOperandsToWorklist(originalOperands); + + // Add all the users of the result to the worklist so we make sure + // to revisit them. + for (Value result : op->getResults()) + for (Operation *userOp : result.getUsers()) { + if (!strictMode || strictModeFilteredOps.contains(userOp)) + addToWorklist(userOp); + } + notifyOperationRemoved(op); + }; + + // Add the given operation generated by the folder to the worklist. 
+ auto processGeneratedConstants = [this](Operation *op) { + // Newly created ops are also simplified -- these are also "local". + addToWorklist(op); + // When strict mode is off, we don't need to maintain + // strictModeFilteredOps. + if (strictMode) + strictModeFilteredOps.insert(op); + }; + + // Try to fold this op. + bool inPlaceUpdate; + if (succeeded(folder.tryToFold(op, processGeneratedConstants, + preReplaceAction, &inPlaceUpdate))) { + changed = true; + if (!inPlaceUpdate) { + // Op has been erased. + continue; + } + } + + // Try to match one of the patterns. The rewriter is automatically + // notified of any necessary changes, so there is nothing else to do + // here. + changed |= succeeded(matcher.matchAndRewrite(op, *this)); + } + + return changed; +} + /// Rewrites only `op` using the supplied canonicalization patterns and /// folding. `erased` is set to true if the op is erased as a result of being /// folded, replaced, or dead. @@ -370,3 +520,15 @@ }); return converged; } + +bool mlir::applyOpPatternsAndFold(ArrayRef ops, + const FrozenRewritePatternSet &patterns, + bool strict) { + if (ops.empty()) + return false; + + // Start the pattern driver. + MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, + strict); + return driver.simplifyLocally(ops); +} diff --git a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir --- a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir +++ b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir @@ -261,12 +261,12 @@ // CHECK-DAG: -> (s0 * 2 + 1) // Test "op local" simplification on affine.apply. DCE on addi will not happen. 
-func @affine.apply(%N : index) { +func @affine.apply(%N : index) -> index { %v = affine.apply affine_map<(d0, d1) -> (d0 + d1 + 1)>(%N, %N) - addi %v, %v : index + %res = addi %v, %v : index // CHECK: affine.apply #map{{.*}}()[%arg0] // CHECK-NEXT: addi - return + return %res: index } // ----- diff --git a/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir b/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir --- a/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir +++ b/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir @@ -10,10 +10,9 @@ %c16 = constant 16 : index %c1024 = constant 1024 : index + // CHECK: %[[C2:.*]] = constant 2 : i64 // CHECK: scf.for - // CHECK-NEXT: %[[C2:.*]] = constant 2 : index - // CHECK-NEXT: %[[C2I64:.*]] = index_cast %[[C2:.*]] - // CHECK-NEXT: memref.store %[[C2I64]], %{{.*}}[] : memref + // CHECK-NEXT: memref.store %[[C2]], %{{.*}}[] : memref scf.for %i = %c0 to %c4 step %c2 { %1 = affine.min affine_map<(d0, d1)[] -> (2, d1 - d0)> (%i, %c4) %2 = index_cast %1: index to i64 @@ -21,9 +20,7 @@ } // CHECK: scf.for - // CHECK-NEXT: %[[C2:.*]] = constant 2 : index - // CHECK-NEXT: %[[C2I64:.*]] = index_cast %[[C2:.*]] - // CHECK-NEXT: memref.store %[[C2I64]], %{{.*}}[] : memref + // CHECK-NEXT: memref.store %[[C2]], %{{.*}}[] : memref scf.for %i = %c1 to %c7 step %c2 { %1 = affine.min affine_map<(d0)[s0] -> (s0 - d0, 2)> (%i)[%c7] %2 = index_cast %1: index to i64 @@ -93,10 +90,9 @@ %c7 = constant 7 : index %c4 = constant 4 : index + // CHECK: %[[C2:.*]] = constant 2 : i64 // CHECK: scf.parallel - // CHECK-NEXT: %[[C2:.*]] = constant 2 : index - // CHECK-NEXT: %[[C2I64:.*]] = index_cast %[[C2:.*]] - // CHECK-NEXT: memref.store %[[C2I64]], %{{.*}}[] : memref + // CHECK-NEXT: memref.store %[[C2]], %{{.*}}[] : memref scf.parallel (%i) = (%c0) to (%c4) step (%c2) { %1 = affine.min affine_map<(d0, d1)[] -> (2, d1 - d0)> (%i, %c4) %2 = index_cast %1: index to i64 @@ -104,9 +100,7 @@ } // CHECK: scf.parallel - // CHECK-NEXT: %[[C2:.*]] = constant 2 : 
index - // CHECK-NEXT: %[[C2I64:.*]] = index_cast %[[C2:.*]] - // CHECK-NEXT: memref.store %[[C2I64]], %{{.*}}[] : memref + // CHECK-NEXT: memref.store %[[C2]], %{{.*}}[] : memref scf.parallel (%i) = (%c1) to (%c7) step (%c2) { %1 = affine.min affine_map<(d0)[s0] -> (2, s0 - d0)> (%i)[%c7] %2 = index_cast %1: index to i64 diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp --- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp +++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp @@ -126,8 +126,8 @@ assert(isa(op) && "expected affine store op"); AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); } - (void)applyOpPatternsAndFold(op, std::move(patterns)); } + (void)applyOpPatternsAndFold(copyOps, std::move(patterns), /*strict=*/true); } namespace mlir { diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -551,11 +551,11 @@ foldPattern.add(funcOp.getContext()); FrozenRewritePatternSet frozenPatterns(std::move(foldPattern)); - // Explicitly walk and apply the pattern locally to avoid more general folding + // Explicitly apply the pattern on affected ops to avoid more general folding // on the rest of the IR. - funcOp.walk([&frozenPatterns](AffineMinOp minOp) { - (void)applyOpPatternsAndFold(minOp, frozenPatterns); - }); + SmallVector minOps; + funcOp.walk([&](AffineMinOp minOp) { minOps.push_back(minOp); }); + (void)applyOpPatternsAndFold(minOps, frozenPatterns, /*strict=*/false); } // For now, just assume it is the zero of type.