diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -374,7 +374,8 @@
 
   /// Returns the constant bound for the pos^th identifier if there is one;
   /// None otherwise.
-  // TODO: Support EQ bounds.
+  /// For `BoundType::EQ`, a value is returned only if constant lower and upper
+  /// bounds both exist and are equal.
   Optional<int64_t> getConstantBound(BoundType type, unsigned pos) const;
 
   /// Gets the lower and upper bound of the `offset` + `pos`th identifier
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -2828,11 +2828,24 @@
 
 Optional<int64_t>
 FlatAffineConstraints::getConstantBound(BoundType type, unsigned pos) const {
-  assert(type != BoundType::EQ && "EQ not implemented");
-  FlatAffineConstraints tmpCst(*this);
+  // Every bound computation below runs on a fresh copy of this system.
   if (type == BoundType::LB)
-    return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
-  return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
+    return FlatAffineConstraints(*this)
+        .computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
+  if (type == BoundType::UB)
+    return FlatAffineConstraints(*this)
+        .computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
+
+  // EQ: a constant value exists only if the constant lower and upper bounds
+  // both exist and coincide.
+  assert(type == BoundType::EQ && "expected EQ");
+  Optional<int64_t> lb =
+      FlatAffineConstraints(*this)
+          .computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
+  Optional<int64_t> ub =
+      FlatAffineConstraints(*this)
+          .computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
+  return (lb && ub && *lb == *ub) ? Optional<int64_t>(*ub) : None;
 }
 
 // A simple (naive and conservative) check for hyper-rectangularity.
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -305,6 +305,16 @@
   AffineMap newMap = alignedBoundMap;
   SmallVector<Value> newOperands;
   unpackOptionalValues(constraints.getMaybeDimAndSymbolValues(), newOperands);
+  // If dims/symbols have known constant values, use those in order to simplify
+  // the affine map further.
+  for (unsigned i = 0, e = constraints.getNumDimAndSymbolIds(); i < e; ++i) {
+    // Skip unused operands and operands that are already constants.
+    if (!newOperands[i] || getConstantIntValue(newOperands[i]))
+      continue;
+    if (auto bound = constraints.getConstantBound(FlatAffineConstraints::EQ, i))
+      newOperands[i] =
+          rewriter.create<arith::ConstantIndexOp>(op->getLoc(), *bound);
+  }
   mlir::canonicalizeMapAndOperands(&newMap, &newOperands);
   rewriter.setInsertionPoint(op);
   rewriter.replaceOpWithNewOp<AffineApplyOp>(op, newMap, newOperands);
@@ -457,19 +467,30 @@
   if (ubInt)
     constraints.addBound(FlatAffineConstraints::EQ, dimUb, *ubInt);
 
-  // iv >= lb (equiv.: iv - lb >= 0)
+  // Lower bound: iv >= lb (equiv.: iv - lb >= 0)
   SmallVector<int64_t> ineqLb(constraints.getNumCols(), 0);
   ineqLb[dimIv] = 1;
   ineqLb[dimLb] = -1;
   constraints.addInequality(ineqLb);
 
-  // iv < lb + step * ((ub - lb - 1) floorDiv step) + 1
-  AffineExpr exprLb = lbInt ? rewriter.getAffineConstantExpr(*lbInt)
-                            : rewriter.getAffineDimExpr(dimLb);
-  AffineExpr exprUb = ubInt ? rewriter.getAffineConstantExpr(*ubInt)
-                            : rewriter.getAffineDimExpr(dimUb);
-  AffineExpr ivUb =
-      exprLb + 1 + (*stepInt * ((exprUb - exprLb - 1).floorDiv(*stepInt)));
+  // Upper bound
+  AffineExpr ivUb;
+  if (lbInt && ubInt && (*lbInt + *stepInt >= *ubInt)) {
+    // The loop has at most one iteration.
+    // iv < lb + 1
+    // TODO: Try to derive this constraint by simplifying the expression in
+    // the else-branch.
+    ivUb = rewriter.getAffineDimExpr(dimLb) + 1;
+  } else {
+    // The loop may have more than one iteration.
+    // iv < lb + step * ((ub - lb - 1) floorDiv step) + 1
+    AffineExpr exprLb = lbInt ? rewriter.getAffineConstantExpr(*lbInt)
+                              : rewriter.getAffineDimExpr(dimLb);
+    AffineExpr exprUb = ubInt ? rewriter.getAffineConstantExpr(*ubInt)
+                              : rewriter.getAffineDimExpr(dimUb);
+    ivUb =
+        exprLb + 1 + (*stepInt * ((exprUb - exprLb - 1).floorDiv(*stepInt)));
+  }
   auto map = AffineMap::get(
       /*dimCount=*/constraints.getNumDimIds(),
       /*symbolCount=*/constraints.getNumSymbolIds(), /*result=*/ivUb);
diff --git a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
--- a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
@@ -348,3 +348,22 @@
   %dim = tensor.dim %1, %c0 : tensor<?x?xf32>
   return %dim : index
 }
+
+// -----
+
+// CHECK-LABEL: func @one_trip_scf_for_canonicalize_min
+//       CHECK:   %[[C4:.*]] = arith.constant 4 : i64
+//       CHECK:   scf.for
+//       CHECK:     memref.store %[[C4]], %{{.*}}[] : memref<i64>
+func @one_trip_scf_for_canonicalize_min(%A : memref<i64>) {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+
+  scf.for %i = %c0 to %c4 step %c4 {
+    %1 = affine.min affine_map<(d0, d1)[] -> (4, d1 - d0)> (%i, %c4)
+    %2 = arith.index_cast %1: index to i64
+    memref.store %2, %A[]: memref<i64>
+  }
+  return
+}