diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -30,7 +30,8 @@ enum { ImpossibleToMatchSentinel = 65535 }; public: - /*implicit*/ PatternBenefit(unsigned benefit); + PatternBenefit() : representation(ImpossibleToMatchSentinel) {} + PatternBenefit(unsigned benefit); PatternBenefit(const PatternBenefit &) = default; PatternBenefit &operator=(const PatternBenefit &) = default; @@ -48,9 +49,11 @@ bool operator<(const PatternBenefit &rhs) const { return representation < rhs.representation; } + bool operator>(const PatternBenefit &rhs) const { return rhs < *this; } + bool operator<=(const PatternBenefit &rhs) const { return !(*this > rhs); } + bool operator>=(const PatternBenefit &rhs) const { return !(*this < rhs); } private: - PatternBenefit() : representation(ImpossibleToMatchSentinel) {} unsigned short representation; }; @@ -384,6 +387,9 @@ // Pattern-driven rewriters //===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// OwningRewritePatternList + class OwningRewritePatternList { using PatternListT = std::vector>; @@ -401,6 +407,7 @@ PatternListT::iterator end() { return patterns.end(); } PatternListT::const_iterator begin() const { return patterns.begin(); } PatternListT::const_iterator end() const { return patterns.end(); } + PatternListT::size_type size() const { return patterns.size(); } void clear() { patterns.clear(); } //===--------------------------------------------------------------------===// @@ -419,60 +426,100 @@ // types 'Ts'. This magic is necessary due to a limitation in the places // that a parameter pack can be expanded in c++11. // FIXME: In c++17 this can be simplified by using 'fold expressions'. 
- using dummy = int[]; - (void)dummy{ + (void)std::initializer_list{ 0, (patterns.emplace_back(std::make_unique(arg, args...)), 0)...}; return *this; } + /// Add the given pattern to the pattern list. + void insert(std::unique_ptr pattern) { + patterns.emplace_back(std::move(pattern)); + } + private: PatternListT patterns; }; -/// This class manages optimization and execution of a group of rewrite -/// patterns, providing an API for finding and applying, the best match against -/// a given node. -/// -class RewritePatternMatcher { +//===----------------------------------------------------------------------===// +// PatternApplicator + +/// This class manages the application of a group of rewrite patterns, with a +/// user-provided cost model. +class PatternApplicator { public: - /// Create a RewritePatternMatcher with the specified set of patterns. - explicit RewritePatternMatcher(const OwningRewritePatternList &patterns); + /// The cost model dynamically assigns a PatternBenefit to a particular + /// pattern. Users can query contained patterns and pass analysis results to + /// applyCostModel. Patterns to be discarded should have a benefit of + /// `impossibleToMatch`. + using CostModel = function_ref; + + explicit PatternApplicator(const OwningRewritePatternList &owningPatternList) + : owningPatternList(owningPatternList) {} + + /// Attempt to match and rewrite the given op with any pattern, allowing a + /// predicate to decide if a pattern can be applied or not, and hooks for if + /// the pattern match was a success or failure. + /// + /// canApply: called before each match and rewrite attempt; return false to + /// skip pattern. + /// onFailure: called when a pattern fails to match to perform cleanup. + /// onSuccess: called when a pattern match succeeds; return failure() to + /// invalidate the match and try another pattern. 
+ LogicalResult matchAndRewrite( + Operation *op, PatternRewriter &rewriter, + function_ref canApply = {}, + function_ref onFailure = {}, + function_ref onSuccess = {}); + + /// Apply a cost model to the patterns within this applicator. + void applyCostModel(CostModel model); + + /// Apply the default cost model that solely uses the pattern's static + /// benefit. + void applyDefaultCostModel() { + applyCostModel( + [](const RewritePattern &pattern) { return pattern.getBenefit(); }); + } - /// Try to match the given operation to a pattern and rewrite it. Return - /// true if any pattern matches. - bool matchAndRewrite(Operation *op, PatternRewriter &rewriter); + /// Walk all of the rewrite patterns within the applicator. + void walkAllPatterns(function_ref walk); private: - RewritePatternMatcher(const RewritePatternMatcher &) = delete; - void operator=(const RewritePatternMatcher &) = delete; + /// The list that owns the patterns used within this applicator. + const OwningRewritePatternList &owningPatternList; - /// The group of patterns that are matched for optimization through this - /// matcher. - std::vector patterns; + /// The set of patterns to match for each operation, stable sorted by benefit. + DenseMap> patterns; }; +//===----------------------------------------------------------------------===// +// applyPatternsGreedily +//===----------------------------------------------------------------------===// + /// Rewrite the regions of the specified operation, which must be isolated from /// above, by repeatedly applying the highest benefit patterns in a greedy -/// work-list driven manner. Return true if no more patterns can be matched in -/// the result operation regions. -/// Note: This does not apply patterns to the top-level operation itself. -/// Note: These methods also perform folding and simple dead-code elimination +/// work-list driven manner. Return success if no more patterns can be matched +/// in the result operation regions. 
+/// Note: This does not apply patterns to the top-level operation itself. +/// Note: These methods also perform folding and simple dead-code elimination /// before attempting to match any of the provided patterns. /// -bool applyPatternsAndFoldGreedily(Operation *op, - const OwningRewritePatternList &patterns); +LogicalResult +applyPatternsAndFoldGreedily(Operation *op, + const OwningRewritePatternList &patterns); /// Rewrite the given regions, which must be isolated from above. -bool applyPatternsAndFoldGreedily(MutableArrayRef regions, - const OwningRewritePatternList &patterns); +LogicalResult +applyPatternsAndFoldGreedily(MutableArrayRef regions, + const OwningRewritePatternList &patterns); /// Applies the specified patterns on `op` alone while also trying to fold it, -/// by selecting the highest benefits patterns in a greedy manner. Returns true -/// if no more patterns can be matched. `erased` is set to true if `op` was -/// folded away or erased as a result of becoming dead. Note: This does not +/// by selecting the highest-benefit patterns in a greedy manner. Returns +/// success if no more patterns can be matched. `erased` is set to true if `op` +/// was folded away or erased as a result of becoming dead. Note: This does not /// apply any patterns recursively to the regions of `op`.
-bool applyOpPatternsAndFold(Operation *op, - const OwningRewritePatternList &patterns, - bool *erased = nullptr); +LogicalResult applyOpPatternsAndFold(Operation *op, + const OwningRewritePatternList &patterns, + bool *erased = nullptr); } // end namespace mlir #endif // MLIR_PATTERN_MATCH_H diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -214,13 +214,13 @@ for (const auto &patterns : stage1Patterns) { LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n" << *op); - if (!applyPatternsAndFoldGreedily(op, patterns)) { + if (failed(applyPatternsAndFoldGreedily(op, patterns))) { LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge"); return failure(); } LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n" << *op); - if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) { + if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) { LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge"); return failure(); } diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -19,8 +19,7 @@ } unsigned short PatternBenefit::getBenefit() const { - assert(representation != ImpossibleToMatchSentinel && - "Pattern doesn't match"); + assert(!isImpossibleToMatch() && "Pattern doesn't match"); return representation; } @@ -171,31 +170,72 @@ // PatternMatcher implementation //===----------------------------------------------------------------------===// -RewritePatternMatcher::RewritePatternMatcher( - const OwningRewritePatternList &patterns) { - for (auto &pattern : patterns) - this->patterns.push_back(pattern.get()); +void PatternApplicator::applyCostModel(CostModel model) { + // Separate patterns by root kind to simplify lookup later on. 
+ patterns.clear(); + for (const auto &pat : owningPatternList) + patterns[pat->getRootKind()].push_back(pat.get()); + + // Sort the patterns using the provided cost model. + llvm::SmallDenseMap benefits; + auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) { + return benefits[lhs] > benefits[rhs]; + }; + for (auto &it : patterns) { + SmallVectorImpl &list = it.second; + + // Special case for one pattern in the list, which is the most common case. + if (list.size() == 1) { + if (model(*list.front()).isImpossibleToMatch()) + list.clear(); + continue; + } + + // Collect the dynamic benefits for the current pattern list. + benefits.clear(); + for (RewritePattern *pat : list) + benefits.try_emplace(pat, model(*pat)); + + // Sort patterns with highest benefit first, and remove those that are + // impossible to match. + std::stable_sort(list.begin(), list.end(), cmp); + while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) + list.pop_back(); + } +} - // Sort the patterns by benefit to simplify the matching logic. - std::stable_sort(this->patterns.begin(), this->patterns.end(), - [](RewritePattern *l, RewritePattern *r) { - return r->getBenefit() < l->getBenefit(); - }); +void PatternApplicator::walkAllPatterns( + function_ref walk) { + for (auto &it : owningPatternList) + walk(*it); } /// Try to match the given operation to a pattern and rewrite it. -bool RewritePatternMatcher::matchAndRewrite(Operation *op, - PatternRewriter &rewriter) { - for (auto *pattern : patterns) { - // Ignore patterns that are for the wrong root or are impossible to match. 
- if (pattern->getRootKind() != op->getName() || - pattern->getBenefit().isImpossibleToMatch()) +LogicalResult PatternApplicator::matchAndRewrite( + Operation *op, PatternRewriter &rewriter, + function_ref canApply, + function_ref onFailure, + function_ref onSuccess) { + auto patternIt = patterns.find(op->getName()); + if (patternIt == patterns.end()) + return failure(); + + for (auto *pattern : patternIt->second) { + // Check that the pattern can be applied. + if (canApply && !canApply(*pattern)) continue; // Try to match and rewrite this pattern. The patterns are sorted by - // benefit, so if we match we can immediately rewrite and return. - if (succeeded(pattern->matchAndRewrite(op, rewriter))) - return true; + // benefit, so if we match we can immediately rewrite. + rewriter.setInsertionPoint(op); + if (succeeded(pattern->matchAndRewrite(op, rewriter))) { + if (!onSuccess || succeeded(onSuccess(*pattern))) + return success(); + continue; + } + + if (onFailure) + onFailure(*pattern); } - return false; + return failure(); } diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -1122,7 +1122,7 @@ namespace { /// A set of rewrite patterns that can be used to legalize a given operation. -using LegalizationPatterns = SmallVector; +using LegalizationPatterns = SmallVector; /// This class defines a recursive operation legalizer. class OperationLegalizer { @@ -1130,11 +1130,7 @@ using LegalizationAction = ConversionTarget::LegalizationAction; OperationLegalizer(ConversionTarget &targetInfo, - const OwningRewritePatternList &patterns) - : target(targetInfo) { - buildLegalizationGraph(patterns); - computeLegalizationGraphBenefit(); - } + const OwningRewritePatternList &patterns); /// Returns if the given operation is known to be illegal on the target. 
bool isIllegal(Operation *op) const; @@ -1151,16 +1147,28 @@ LogicalResult legalizeWithFold(Operation *op, ConversionPatternRewriter &rewriter); - /// Attempt to legalize the given operation by applying the provided pattern. - /// Returns success if the operation was legalized, failure otherwise. - LogicalResult legalizePattern(Operation *op, RewritePattern *pattern, - ConversionPatternRewriter &rewriter); + /// Attempt to legalize the given operation by applying a pattern. Returns + /// success if the operation was legalized, failure otherwise. + LogicalResult legalizeWithPattern(Operation *op, + ConversionPatternRewriter &rewriter); + + /// Return true if the given pattern may be applied to the given operation, + /// false otherwise. + bool canApplyPattern(Operation *op, const RewritePattern &pattern, + ConversionPatternRewriter &rewriter); + + /// Legalize the resultant IR after successfully applying the given pattern. + LogicalResult legalizePatternResult(Operation *op, + const RewritePattern &pattern, + ConversionPatternRewriter &rewriter, + RewriterState &curState); /// Build an optimistic legalization graph given the provided patterns. This /// function populates 'legalizerPatterns' with the operations that are not /// directly legal, but may be transitively legal for the current target given /// the provided patterns. - void buildLegalizationGraph(const OwningRewritePatternList &patterns); + void buildLegalizationGraph( + DenseMap &legalizerPatterns); /// Compute the benefit of each node within the computed legalization graph. /// This orders the patterns within 'legalizerPatterns' based upon two @@ -1170,20 +1178,31 @@ /// 2) When comparing patterns with the same legalization depth, prefer the /// pattern with the highest PatternBenefit. This allows for users to /// prefer specific legalizations over others. 
- void computeLegalizationGraphBenefit(); + void computeLegalizationGraphBenefit( + DenseMap &legalizerPatterns); /// The current set of patterns that have been applied. - SmallPtrSet appliedPatterns; - - /// The set of legality information for operations transitively supported by - /// the target. - DenseMap legalizerPatterns; + SmallPtrSet appliedPatterns; /// The legalization information provided by the target. ConversionTarget ⌖ + + /// The pattern applicator to use for conversions. + PatternApplicator applicator; }; } // namespace +OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo, + const OwningRewritePatternList &patterns) + : target(targetInfo), applicator(patterns) { + // The set of legality information for operations transitively supported by + // the target. + DenseMap legalizerPatterns; + + buildLegalizationGraph(legalizerPatterns); + computeLegalizationGraphBenefit(legalizerPatterns); +} + bool OperationLegalizer::isIllegal(Operation *op) const { // Check if the target explicitly marked this operation as illegal. return target.getOpAction(op->getName()) == LegalizationAction::Illegal; @@ -1253,24 +1272,12 @@ } // Otherwise, we need to apply a legalization pattern to this operation. - auto it = legalizerPatterns.find(op->getName()); - if (it == legalizerPatterns.end()) { + if (succeeded(legalizeWithPattern(op, rewriter))) { LLVM_DEBUG({ - logFailure(rewriterImpl.logger, "no known legalization path"); + logSuccess(rewriterImpl.logger, ""); rewriterImpl.logger.startLine() << logLineComment; }); - return failure(); - } - - // The patterns are sorted by expected benefit, so try to apply each in-order. 
- for (auto *pattern : it->second) { - if (succeeded(legalizePattern(op, pattern, rewriter))) { - LLVM_DEBUG({ - logSuccess(rewriterImpl.logger, ""); - rewriterImpl.logger.startLine() << logLineComment; - }); - return success(); - } + return success(); } LLVM_DEBUG({ @@ -1320,46 +1327,70 @@ } LogicalResult -OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern, - ConversionPatternRewriter &rewriter) { +OperationLegalizer::legalizeWithPattern(Operation *op, + ConversionPatternRewriter &rewriter) { auto &rewriterImpl = rewriter.getImpl(); + + // Functor that returns if the given pattern may be applied. + auto canApply = [&](const RewritePattern &pattern) { + return canApplyPattern(op, pattern, rewriter); + }; + + // Functor that cleans up the rewriter state after a pattern failed to match. + RewriterState curState = rewriterImpl.getCurrentState(); + auto onFailure = [&](const RewritePattern &pattern) { + LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match")); + rewriterImpl.resetState(curState); + appliedPatterns.erase(&pattern); + }; + + // Functor that performs additional legalization when a pattern is + // successfully applied. + auto onSuccess = [&](const RewritePattern &pattern) { + auto result = legalizePatternResult(op, pattern, rewriter, curState); + appliedPatterns.erase(&pattern); + if (failed(result)) + rewriterImpl.resetState(curState); + return result; + }; + + // Try to match and rewrite a pattern on this operation. 
+ return applicator.matchAndRewrite(op, rewriter, canApply, onFailure, + onSuccess); +} + +bool OperationLegalizer::canApplyPattern(Operation *op, + const RewritePattern &pattern, + ConversionPatternRewriter &rewriter) { LLVM_DEBUG({ - auto &os = rewriterImpl.logger; + auto &os = rewriter.getImpl().logger; os.getOStream() << "\n"; - os.startLine() << "* Pattern : '" << pattern->getRootKind() << " -> ("; - llvm::interleaveComma(pattern->getGeneratedOps(), llvm::dbgs()); + os.startLine() << "* Pattern : '" << pattern.getRootKind() << " -> ("; + llvm::interleaveComma(pattern.getGeneratedOps(), llvm::dbgs()); os.getOStream() << ")' {\n"; os.indent(); }); // Ensure that we don't cycle by not allowing the same pattern to be // applied twice in the same recursion stack if it is not known to be safe. - if (!pattern->hasBoundedRewriteRecursion() && - !appliedPatterns.insert(pattern).second) { - LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern was already applied")); - return failure(); + if (!pattern.hasBoundedRewriteRecursion() && + !appliedPatterns.insert(&pattern).second) { + LLVM_DEBUG( + logFailure(rewriter.getImpl().logger, "pattern was already applied")); + return false; } + return true; +} - RewriterState curState = rewriterImpl.getCurrentState(); - auto cleanupFailure = [&] { - // Reset the rewriter state and pop this pattern. - rewriterImpl.resetState(curState); - appliedPatterns.erase(pattern); - return failure(); - }; +LogicalResult OperationLegalizer::legalizePatternResult( + Operation *op, const RewritePattern &pattern, + ConversionPatternRewriter &rewriter, RewriterState &curState) { + auto &rewriterImpl = rewriter.getImpl(); - // Try to rewrite with the given pattern. 
- rewriter.setInsertionPoint(op); - LogicalResult matchedPattern = pattern->matchAndRewrite(op, rewriter); #ifndef NDEBUG assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); #endif - if (failed(matchedPattern)) { - LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match")); - return cleanupFailure(); - } - // If the pattern moved or created any blocks, try to legalize their types. // This ensures that the types of the block arguments are legal for the region // they were moved into. @@ -1376,7 +1407,7 @@ if (failed(rewriterImpl.convertBlockSignature(action.block))) { LLVM_DEBUG(logFailure(rewriterImpl.logger, "failed to convert types of moved block")); - return cleanupFailure(); + return failure(); } } @@ -1414,7 +1445,7 @@ LLVM_DEBUG(logFailure(rewriterImpl.logger, "operation updated in-place '{0}' was illegal", op->getName())); - return cleanupFailure(); + return failure(); } } @@ -1426,42 +1457,42 @@ LLVM_DEBUG(logFailure(rewriterImpl.logger, "generated operation '{0}'({1}) was illegal", op->getName(), op)); - return cleanupFailure(); + return failure(); } } LLVM_DEBUG(logSuccess(rewriterImpl.logger, "pattern applied successfully")); - appliedPatterns.erase(pattern); return success(); } void OperationLegalizer::buildLegalizationGraph( - const OwningRewritePatternList &patterns) { + DenseMap &legalizerPatterns) { // A mapping between an operation and a set of operations that can be used to // generate it. DenseMap> parentOps; // A mapping between an operation and any currently invalid patterns it has. - DenseMap> invalidPatterns; + DenseMap> + invalidPatterns; // A worklist of patterns to consider for legality. - llvm::SetVector patternWorklist; + llvm::SetVector patternWorklist; // Build the mapping from operations to the parent ops that may generate them. 
- for (auto &pattern : patterns) { - auto root = pattern->getRootKind(); + applicator.walkAllPatterns([&](const RewritePattern &pattern) { + OperationName root = pattern.getRootKind(); // Skip operations that are always known to be legal. if (target.getOpAction(root) == LegalizationAction::Legal) - continue; + return; // Add this pattern to the invalid set for the root op and record this root // as a parent for any generated operations. - invalidPatterns[root].insert(pattern.get()); - for (auto op : pattern->getGeneratedOps()) + invalidPatterns[root].insert(&pattern); + for (auto op : pattern.getGeneratedOps()) parentOps[op].insert(root); // Add this pattern to the worklist. - patternWorklist.insert(pattern.get()); - } + patternWorklist.insert(&pattern); + }); while (!patternWorklist.empty()) { auto *pattern = patternWorklist.pop_back_val(); @@ -1486,7 +1517,8 @@ } } -void OperationLegalizer::computeLegalizationGraphBenefit() { +void OperationLegalizer::computeLegalizationGraphBenefit( + DenseMap &legalizerPatterns) { // The smallest pattern depth, when legalizing an operation. DenseMap minPatternDepth; @@ -1511,9 +1543,9 @@ minPatternDepth.try_emplace(op, minDepth); // Compute the depth for each pattern used to legalize this operation. - SmallVector, 4> patternsByDepth; + SmallVector, 4> patternsByDepth; patternsByDepth.reserve(opPatternsIt->second.size()); - for (RewritePattern *pattern : opPatternsIt->second) { + for (const RewritePattern *pattern : opPatternsIt->second) { unsigned depth = 0; for (auto generatedOp : pattern->getGeneratedOps()) depth = std::max(depth, computeDepth(generatedOp) + 1); @@ -1534,8 +1566,8 @@ // Sort the patterns by those likely to be the most beneficial. llvm::array_pod_sort( patternsByDepth.begin(), patternsByDepth.end(), - [](const std::pair *lhs, - const std::pair *rhs) { + [](const std::pair *lhs, + const std::pair *rhs) { // First sort by the smaller pattern legalization depth. 
if (lhs->second != rhs->second) return llvm::array_pod_sort_comparator(&lhs->second, @@ -1560,6 +1592,21 @@ for (auto &opIt : legalizerPatterns) if (!minPatternDepth.count(opIt.first)) computeDepth(opIt.first); + + // Apply a cost model to the pattern applicator. We order patterns first by + // depth then benefit. `legalizerPatterns` contains per-op patterns by + // decreasing benefit. + applicator.applyCostModel([&](const RewritePattern &p) { + auto &list = legalizerPatterns[p.getRootKind()]; + + // If the pattern is not found, then it was removed and cannot be matched. + LegalizationPatterns::iterator it = llvm::find(list, &p); + if (it == list.end()) + return PatternBenefit::impossibleToMatch(); + + // Patterns found earlier in the list have higher benefit. + return PatternBenefit(std::distance(it, list.end())); + }); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -39,6 +39,9 @@ const OwningRewritePatternList &patterns) : PatternRewriter(ctx), matcher(patterns), folder(ctx) { worklist.reserve(64); + + // Apply a simple cost model based solely on pattern benefit. + matcher.applyDefaultCostModel(); } bool simplify(MutableArrayRef regions, int maxIterations); @@ -103,8 +106,7 @@ // be re-added to the worklist. This function should be called when an // operation is modified or removed, as it may trigger further // simplifications. - template - void addToWorklist(Operands &&operands) { + template void addToWorklist(Operands &&operands) { for (Value operand : operands) { // If the use count of this operand is now < 2, we re-add the defining // operation to the worklist. @@ -118,8 +120,8 @@ } } - /// The low-level pattern matcher. 
- RewritePatternMatcher matcher; + /// The low-level pattern applicator. + PatternApplicator matcher; /// The worklist for this transformation keeps track of the operations that /// need to be revisited, plus their index in the worklist. This allows us to @@ -192,12 +194,9 @@ continue; } - // Make sure that any new operations are inserted at this point. - setInsertionPoint(op); - // Try to match one of the patterns. The rewriter is automatically // notified of any necessary changes, so there is nothing else to do here. - changed |= matcher.matchAndRewrite(op, *this); + changed |= succeeded(matcher.matchAndRewrite(op, *this)); } // After applying patterns, make sure that the CFG of each of the regions is @@ -213,20 +212,21 @@ /// Rewrite the regions of the specified operation, which must be isolated from /// above, by repeatedly applying the highest benefit patterns in a greedy -/// work-list driven manner. Return true if no more patterns can be matched in -/// the result operation regions. -/// Note: This does not apply patterns to the top-level operation itself. +/// work-list driven manner. Return success if no more patterns can be matched +/// in the result operation regions. Note: This does not apply patterns to the +/// top-level operation itself. /// -bool mlir::applyPatternsAndFoldGreedily( - Operation *op, const OwningRewritePatternList &patterns) { +LogicalResult +mlir::applyPatternsAndFoldGreedily(Operation *op, + const OwningRewritePatternList &patterns) { return applyPatternsAndFoldGreedily(op->getRegions(), patterns); } - /// Rewrite the given regions, which must be isolated from above. 
-bool mlir::applyPatternsAndFoldGreedily( - MutableArrayRef regions, const OwningRewritePatternList &patterns) { +LogicalResult +mlir::applyPatternsAndFoldGreedily(MutableArrayRef regions, + const OwningRewritePatternList &patterns) { if (regions.empty()) - return true; + return success(); // The top-level operation must be known to be isolated from above to // prevent performing canonicalizations on operations defined at or above @@ -245,7 +245,7 @@ llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " << maxPatternMatchIterations << " times"; }); - return converged; + return success(converged); } //===----------------------------------------------------------------------===// @@ -259,9 +259,17 @@ public: explicit OpPatternRewriteDriver(MLIRContext *ctx, const OwningRewritePatternList &patterns) - : PatternRewriter(ctx), matcher(patterns), folder(ctx) {} + : PatternRewriter(ctx), matcher(patterns), folder(ctx) { + // Apply a simple cost model based solely on pattern benefit. + matcher.applyDefaultCostModel(); + } - bool simplifyLocally(Operation *op, int maxIterations, bool &erased); + /// Performs the rewrites and folding only on `op`. The simplification + /// converges if the op is erased as a result of being folded, replaced, or + /// dead, or no more changes happen in an iteration. Returns success if the + /// rewrite converges in `maxIterations`. `erased` is set to true if `op` gets + /// erased. + LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased); // These are hooks implemented for PatternRewriter. protected: @@ -276,8 +284,8 @@ void notifyRootReplaced(Operation *op) override {} private: - /// The low-level pattern matcher. - RewritePatternMatcher matcher; + /// The low-level pattern applicator. + PatternApplicator matcher; /// Non-pattern based folder for operations. OperationFolder folder; @@ -288,12 +296,9 @@ } // anonymous namespace -/// Performs the rewrites and folding only on `op`. 
The simplification converges -/// if the op is erased as a result of being folded, replaced, or dead, or no -/// more changes happen in an iteration. Returns true if the rewrite converges -/// in `maxIterations`. `erased` is set to true if `op` gets erased. -bool OpPatternRewriteDriver::simplifyLocally(Operation *op, int maxIterations, - bool &erased) { +LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op, + int maxIterations, + bool &erased) { bool changed = false; erased = false; opErasedViaPatternRewrites = false; @@ -305,7 +310,7 @@ if (isOpTriviallyDead(op)) { op->erase(); erased = true; - return true; + return success(); } // Try to fold this op. @@ -316,38 +321,34 @@ changed = true; if (!inPlaceUpdate) { erased = true; - return true; + return success(); } } - // Make sure that any new operations are inserted at this point. - setInsertionPoint(op); - // Try to match one of the patterns. The rewriter is automatically // notified of any necessary changes, so there is nothing else to do here. - changed |= matcher.matchAndRewrite(op, *this); + changed |= succeeded(matcher.matchAndRewrite(op, *this)); if ((erased = opErasedViaPatternRewrites)) - return true; + return success(); } while (changed && ++i < maxIterations); // Whether the rewrite converges, i.e. wasn't changed in the last iteration. - return !changed; + return failure(changed); } /// Rewrites only `op` using the supplied canonicalization patterns and /// folding. `erased` is set to true if the op is erased as a result of being /// folded, replaced, or dead. -bool mlir::applyOpPatternsAndFold(Operation *op, - const OwningRewritePatternList &patterns, - bool *erased) { +LogicalResult mlir::applyOpPatternsAndFold( + Operation *op, const OwningRewritePatternList &patterns, bool *erased) { // Start the pattern driver. 
OpPatternRewriteDriver driver(op->getContext(), patterns); bool opErased; - bool converged = + LogicalResult converged = driver.simplifyLocally(op, maxPatternMatchIterations, opErased); if (erased) *erased = opErased; - LLVM_DEBUG(if (!converged) { + LLVM_DEBUG(if (failed(converged)) { llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " << maxPatternMatchIterations << " times"; });