diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -7,6 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -715,6 +717,26 @@ } }; +/// Tries to compute `u - l` as a compile-time constant. Each value is viewed +/// as an AffineValueMap: its defining op's map (presumably affine.apply) or +/// the identity map over the value. Returns llvm::None if `u - l` is dynamic. +static Optional computeConstDiff(Value l, Value u) { + auto alb = l.getDefiningOp(); + auto aub = u.getDefiningOp(); + // Identity map (d0) -> (d0), applied to a bound with no affine producer. + auto id = AffineMap::getMultiDimIdentityMap(1, l.getContext()); + auto lb = alb ? alb.getAffineValueMap() : AffineValueMap(id, l); + auto ub = aub ? aub.getAffineValueMap() : AffineValueMap(id, u); + + AffineValueMap diffMap; + AffineValueMap::difference(ub, lb, &diffMap); + AffineExpr diff = diffMap.getResult(0); + if (!diff.isa()) + return llvm::None; + + return diff.cast().getValue(); +} + /// Rewriting pattern that erases loops that are known not to iterate, replaces /// single-iteration loops with their bodies, and removes empty loops that /// iterate at least once and only return values defined outside of the loop. @@ -730,15 +752,13 @@ return success(); } - auto lb = op.getLowerBound().getDefiningOp(); - auto ub = op.getUpperBound().getDefiningOp(); - if (!lb || !ub) + Optional diff = + computeConstDiff(op.getLowerBound(), op.getUpperBound()); + if (!diff) return failure(); // If the loop is known to have 0 iterations, remove it.
- llvm::APInt lbValue = lb.getValue().cast().getValue(); - llvm::APInt ubValue = ub.getValue().cast().getValue(); - if (lbValue.sge(ubValue)) { + if (*diff <= 0) { rewriter.replaceOp(op, op.getIterOperands()); return success(); } @@ -750,7 +770,7 @@ // If the loop is known to have 1 iteration, inline its body and remove the // loop. llvm::APInt stepValue = step.getValue().cast().getValue(); - if ((lbValue + stepValue).sge(ubValue)) { + if (stepValue.sge(*diff)) { SmallVector blockArgs; blockArgs.reserve(op.getNumIterOperands() + 1); blockArgs.push_back(op.getLowerBound()); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2678,6 +2678,7 @@ ), includes = ["include"], deps = [ + ":AffineDialect", ":ArithmeticDialect", ":ArithmeticUtils", ":BufferizationDialect",