diff --git a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
--- a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
@@ -1509,11 +1509,23 @@
 
 Optional<int64_t> IntegerPolyhedron::getConstantBound(BoundType type,
                                                       unsigned pos) const {
-  assert(type != BoundType::EQ && "EQ not implemented");
-  IntegerPolyhedron tmpCst(*this);
   if (type == BoundType::LB)
-    return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
-  return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
+    return IntegerPolyhedron(*this)
+        .computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
+  if (type == BoundType::UB)
+    return IntegerPolyhedron(*this)
+        .computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
+
+  assert(type == BoundType::EQ && "expected EQ");
+  // Report a constant EQ bound only if a constant lower bound and a constant
+  // upper bound were both found and they coincide.
+  Optional<int64_t> lb =
+      IntegerPolyhedron(*this)
+          .computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
+  Optional<int64_t> ub =
+      IntegerPolyhedron(*this)
+          .computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
+  return (lb && ub && *lb == *ub) ? Optional<int64_t>(*ub) : None;
 }
 
 // A simple (naive and conservative) check for hyper-rectangularity.
diff --git a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
--- a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
@@ -183,6 +183,19 @@
   AffineMap newMap = alignedBoundMap;
   SmallVector<Value> newOperands;
   unpackOptionalValues(constraints.getMaybeDimAndSymbolValues(), newOperands);
+  // Materialize any new constants right before `op`, so that they dominate the
+  // replacement op created below.
+  rewriter.setInsertionPoint(op);
+  // If dims/symbols have known constant values, use those in order to simplify
+  // the affine map further. Only dims and symbols have operand values (local
+  // ids do not), so iterate over those only.
+  for (int64_t i = 0, e = constraints.getNumDimAndSymbolIds(); i < e; ++i) {
+    // Skip unused operands and operands that are already constants.
+    if (!newOperands[i] || getConstantIntValue(newOperands[i]))
+      continue;
+    if (auto bound = constraints.getConstantBound(FlatAffineConstraints::EQ, i))
+      newOperands[i] =
+          rewriter.create<arith::ConstantIndexOp>(op->getLoc(), *bound);
+  }
   mlir::canonicalizeMapAndOperands(&newMap, &newOperands);
-  rewriter.setInsertionPoint(op);
   rewriter.replaceOpWithNewOp<AffineApplyOp>(op, newMap, newOperands);
@@ -211,19 +224,29 @@
   if (ubInt)
     constraints.addBound(FlatAffineConstraints::EQ, dimUb, *ubInt);
 
-  // iv >= lb (equiv.: iv - lb >= 0)
+  // Lower bound: iv >= lb (equiv.: iv - lb >= 0)
   SmallVector<int64_t> ineqLb(constraints.getNumCols(), 0);
   ineqLb[dimIv] = 1;
   ineqLb[dimLb] = -1;
   constraints.addInequality(ineqLb);
 
-  // iv < lb + step * ((ub - lb - 1) floorDiv step) + 1
-  AffineExpr exprLb = lbInt ? rewriter.getAffineConstantExpr(*lbInt)
-                            : rewriter.getAffineDimExpr(dimLb);
-  AffineExpr exprUb = ubInt ? rewriter.getAffineConstantExpr(*ubInt)
-                            : rewriter.getAffineDimExpr(dimUb);
-  AffineExpr ivUb =
-      exprLb + 1 + (*stepInt * ((exprUb - exprLb - 1).floorDiv(*stepInt)));
+  // Upper bound
+  AffineExpr ivUb;
+  if (lbInt && ubInt && (*lbInt + *stepInt >= *ubInt)) {
+    // The loop has at most one iteration (iv == lb), so:
+    // iv < lb + 1
+    // TODO: Try to derive this constraint by simplifying the expression in
+    // the else-branch.
+    ivUb = rewriter.getAffineDimExpr(dimLb) + 1;
+  } else {
+    // The loop may have more than one iteration.
+    // iv < lb + step * ((ub - lb - 1) floorDiv step) + 1
+    AffineExpr exprLb = lbInt ? rewriter.getAffineConstantExpr(*lbInt)
+                              : rewriter.getAffineDimExpr(dimLb);
+    AffineExpr exprUb = ubInt ? rewriter.getAffineConstantExpr(*ubInt)
+                              : rewriter.getAffineDimExpr(dimUb);
+    ivUb = exprLb + 1 + (*stepInt * ((exprUb - exprLb - 1).floorDiv(*stepInt)));
+  }
   auto map = AffineMap::get(
       /*dimCount=*/constraints.getNumDimIds(),
       /*symbolCount=*/constraints.getNumSymbolIds(), /*result=*/ivUb);
diff --git a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
--- a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
@@ -349,3 +349,21 @@
   %dim = tensor.dim %1, %c0 : tensor<?x?xf32>
   return %dim : index
 }
+
+// -----
+
+// CHECK-LABEL: func @one_trip_scf_for_canonicalize_min
+// CHECK: %[[C4:.*]] = arith.constant 4 : i64
+// CHECK: scf.for
+// CHECK: memref.store %[[C4]], %{{.*}}[] : memref<i64>
+func @one_trip_scf_for_canonicalize_min(%A : memref<i64>) {
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+
+  scf.for %i = %c0 to %c4 step %c4 {
+    %1 = affine.min affine_map<(d0, d1)[] -> (4, d1 - d0)> (%i, %c4)
+    %2 = arith.index_cast %1: index to i64
+    memref.store %2, %A[]: memref<i64>
+  }
+  return
+}