diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -7,6 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -730,15 +732,25 @@ return success(); } - auto lb = op.getLowerBound().getDefiningOp(); - auto ub = op.getUpperBound().getDefiningOp(); - if (!lb || !ub) + auto alb = op.getLowerBound().getDefiningOp(); + auto aub = op.getUpperBound().getDefiningOp(); + + // ID map: (d0)->d0 + auto id = AffineMap::getMultiDimIdentityMap(1, op.getContext()); + auto lb = + alb ? alb.getAffineValueMap() : AffineValueMap(id, op.getLowerBound()); + auto ub = + aub ? aub.getAffineValueMap() : AffineValueMap(id, op.getUpperBound()); + + AffineValueMap diffAffine; + AffineValueMap::difference(ub, lb, &diffAffine); + AffineExpr res = diffAffine.getResult(0); + if (!res.isa()) return failure(); + int64_t diff = res.cast().getValue(); // If the loop is known to have 0 iterations, remove it. - llvm::APInt lbValue = lb.getValue().cast().getValue(); - llvm::APInt ubValue = ub.getValue().cast().getValue(); - if (lbValue.sge(ubValue)) { + if (diff <= 0) { rewriter.replaceOp(op, op.getIterOperands()); return success(); } @@ -750,7 +762,7 @@ // If the loop is known to have 1 iteration, inline its body and remove the // loop. llvm::APInt stepValue = step.getValue().cast().getValue(); - if ((lbValue + stepValue).sge(ubValue)) { + if (stepValue.sge(diff)) { SmallVector blockArgs; blockArgs.reserve(op.getNumIterOperands() + 1); blockArgs.push_back(op.getLowerBound());