diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -80,6 +80,9 @@ /// success if no more patterns can be matched. `erased` is set to true if `op` /// was folded away or erased as a result of becoming dead. Note: This does not /// apply any patterns recursively to the regions of `op`. +/// +/// Returns success if the iterative process converged and no more patterns can +/// be matched. LogicalResult applyOpPatternsAndFold(Operation *op, const FrozenRewritePatternSet &patterns, bool *erased = nullptr); @@ -93,10 +96,14 @@ /// (i.e., regardless of `strict`). Note that ops in `ops` could be erased as a /// result of folding, becoming dead, or via pattern rewrites. If more far /// reaching simplification is desired, applyPatternsAndFoldGreedily should be -/// used. Returns true if at all any IR was rewritten. -bool applyOpPatternsAndFold(ArrayRef<Operation *> ops, - const FrozenRewritePatternSet &patterns, - bool strict); +/// used. +/// +/// `changed`, if non-null, is set to true if the IR was modified at all. +/// Returns success if the iterative process converged and no more patterns can +/// be matched. +LogicalResult applyOpPatternsAndFold(ArrayRef<Operation *> ops, + const FrozenRewritePatternSet &patterns, + bool strict, bool *changed = nullptr); } // namespace mlir diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -131,7 +131,12 @@ SimplifyAffineMinMaxOp>(getContext(), cstr); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); // Apply the simplification pattern to a fixpoint.
- (void)applyOpPatternsAndFold(targets, frozenPatterns, /*strict=*/true); + if (failed( + applyOpPatternsAndFold(targets, frozenPatterns, /*strict=*/true))) { + auto diag = emitDefiniteFailure() + << "affine.min/max simplification did not converge"; + return diag; + } return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -574,7 +574,8 @@ : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()), strictMode(strict) {} - bool simplifyLocally(ArrayRef<Operation *> op); + LogicalResult simplifyLocally(ArrayRef<Operation *> op, + bool *changed = nullptr); void addToWorklist(Operation *op) override { if (!strictMode || strictModeFilteredOps.contains(op)) @@ -625,13 +626,16 @@ // there is no strong rationale to re-add all operations into the worklist and // rerun until an iteration changes nothing. If more widereaching simplification // is desired, GreedyPatternRewriteDriver should be used. -bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) { +LogicalResult +MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops, + bool *changed) { if (strictMode) { strictModeFilteredOps.clear(); strictModeFilteredOps.insert(ops.begin(), ops.end()); } - bool changed = false; + if (changed) + *changed = false; worklist.clear(); worklistMap.clear(); for (Operation *op : ops) @@ -657,7 +661,8 @@ if (isOpTriviallyDead(op)) { notifyOperationRemoved(op); op->erase(); - changed = true; + if (changed) + *changed = true; continue; } @@ -687,7 +692,8 @@ bool inPlaceUpdate; if (succeeded(folder.tryToFold(op, processGeneratedConstants, preReplaceAction, &inPlaceUpdate))) { - changed = true; + if (changed) + *changed = true; if (!inPlaceUpdate) { // Op has been erased.
continue; @@ -698,12 +704,15 @@ // notified of any necessary changes, so there is nothing else to do // here. if (succeeded(matcher.matchAndRewrite(op, *this))) { - changed = true; + if (changed) + *changed = true; ++numRewrites; } } - return changed; + // NOTE(review): if the loop above only exits once `worklist` is empty, this + // always returns success, so callers' non-convergence paths are unreachable. + return success(worklist.empty()); } /// Rewrites only `op` using the supplied canonicalization patterns and @@ -726,14 +735,18 @@ return converged; } -bool mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops, - const FrozenRewritePatternSet &patterns, - bool strict) { - if (ops.empty()) - return false; +LogicalResult +mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops, + const FrozenRewritePatternSet &patterns, + bool strict, bool *changed) { + if (ops.empty()) { + if (changed) + *changed = false; + return success(); + } // Start the pattern driver. MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, strict); - return driver.simplifyLocally(ops); + return driver.simplifyLocally(ops, changed); }