diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
+#include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -673,8 +674,168 @@
   return false;
 }
 
+/// Gets the constant lower bound on an `iv`.
+static std::optional<int64_t> getLowerBound(Value iv) {
+  AffineForOp forOp = getForInductionVarOwner(iv);
+  if (forOp && forOp.hasConstantLowerBound())
+    return forOp.getConstantLowerBound();
+  return std::nullopt;
+}
+
+/// Gets the constant upper bound on an affine.for `iv`.
+static std::optional<int64_t> getUpperBound(Value iv) {
+  AffineForOp forOp = getForInductionVarOwner(iv);
+  if (!forOp || !forOp.hasConstantUpperBound())
+    return std::nullopt;
+
+  // If its lower bound is also known, we can get a more precise bound
+  // whenever the step is not one.
+  if (forOp.hasConstantLowerBound()) {
+    return forOp.getConstantUpperBound() - 1 -
+           (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
+               forOp.getStep();
+  }
+  return forOp.getConstantUpperBound() - 1;
+}
+
+/// Get a lower or upper (depending on `isUpper`) bound for `expr` while using
+/// the constant lower and upper bounds for its inputs provided in
+/// `constLowerBounds` and `constUpperBounds`. Return std::nullopt if such a
+/// bound can't be computed. This method only handles simple sum of product
+/// expressions (w.r.t constant coefficients) so as to not depend on anything
+/// heavyweight in `Analysis`. Expressions of the form: c0*d0 + c1*d1 + c2*s0 +
+/// ... + c_n are handled. Expressions involving floordiv, ceildiv, mod or
+/// semi-affine ones will lead to std::nullopt being returned.
+static std::optional<int64_t>
+getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
+                ArrayRef<std::optional<int64_t>> constLowerBounds,
+                ArrayRef<std::optional<int64_t>> constUpperBounds,
+                bool isUpper) {
+  // Handle divs and mods.
+  if (auto binOpExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
+    // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
+    // can compute an upper bound.
+    if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (!rhsConst || rhsConst.getValue() < 1)
+        return std::nullopt;
+      auto bound = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                   constLowerBounds, constUpperBounds, isUpper);
+      if (!bound)
+        return std::nullopt;
+      return mlir::floorDiv(*bound, rhsConst.getValue());
+    }
+    if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (rhsConst && rhsConst.getValue() >= 1) {
+        auto bound =
+            getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                            constLowerBounds, constUpperBounds, isUpper);
+        if (!bound)
+          return std::nullopt;
+        return mlir::ceilDiv(*bound, rhsConst.getValue());
+      }
+      return std::nullopt;
+    }
+    if (binOpExpr.getKind() == AffineExprKind::Mod) {
+      // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
+      // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
+      // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (rhsConst && rhsConst.getValue() >= 1) {
+        int64_t rhsConstVal = rhsConst.getValue();
+        auto lb = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                  constLowerBounds, constUpperBounds,
+                                  /*isUpper=*/false);
+        auto ub = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                  constLowerBounds, constUpperBounds, isUpper);
+        if (ub && lb &&
+            floorDiv(*lb, rhsConstVal) == floorDiv(*ub, rhsConstVal))
+          return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
+        return isUpper ? rhsConstVal - 1 : 0;
+      }
+    }
+  }
+  // Flatten the expression.
+  SimpleAffineExprFlattener flattener(numDims, numSymbols);
+  flattener.walkPostOrder(expr);
+  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
+  // TODO: Handle local variables. We can get hold of flattener.localExprs and
+  // get bound on the local expr recursively.
+  if (flattener.numLocals > 0)
+    return std::nullopt;
+  int64_t bound = 0;
+  // Substitute the constant lower or upper bound for the dimensional or
+  // symbolic input depending on `isUpper` to determine the bound.
+  for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
+    if (flattenedExpr[i] > 0) {
+      auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
+      if (!constBound)
+        return std::nullopt;
+      bound += *constBound * flattenedExpr[i];
+    } else if (flattenedExpr[i] < 0) {
+      auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
+      if (!constBound)
+        return std::nullopt;
+      bound += *constBound * flattenedExpr[i];
+    }
+  }
+  // Constant term.
+  bound += flattenedExpr.back();
+  return bound;
+}
+
+/// Determine a constant upper bound for `expr` if one exists while exploiting
+/// values in `operands`. Note that the upper bound is an inclusive one. `expr`
+/// is guaranteed to be less than or equal to it.
+static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
+                                            unsigned numSymbols,
+                                            ArrayRef<Value> operands) {
+  // Get the constant lower or upper bounds on the operands.
+  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
+  constLowerBounds.reserve(operands.size());
+  constUpperBounds.reserve(operands.size());
+  for (Value operand : operands) {
+    constLowerBounds.push_back(getLowerBound(operand));
+    constUpperBounds.push_back(getUpperBound(operand));
+  }
+
+  if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
+    return constExpr.getValue();
+
+  return getBoundForExpr(expr, numDims, numSymbols, constLowerBounds,
+                         constUpperBounds,
+                         /*isUpper=*/true);
+}
+
+/// Determine a constant lower bound for `expr` if one exists while exploiting
+/// values in `operands`. Note that the lower bound is an inclusive one. `expr`
+/// is guaranteed to be greater than or equal to it.
+static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
+                                            unsigned numSymbols,
+                                            ArrayRef<Value> operands) {
+  // Get the constant lower or upper bounds on the operands.
+  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
+  constLowerBounds.reserve(operands.size());
+  constUpperBounds.reserve(operands.size());
+  for (Value operand : operands) {
+    constLowerBounds.push_back(getLowerBound(operand));
+    constUpperBounds.push_back(getUpperBound(operand));
+  }
+
+  std::optional<int64_t> lowerBound;
+  if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+    lowerBound = constExpr.getValue();
+  } else {
+    lowerBound = getBoundForExpr(expr, numDims, numSymbols, constLowerBounds,
+                                 constUpperBounds,
+                                 /*isUpper=*/false);
+  }
+  return lowerBound;
+}
+
 /// Simplify `expr` while exploiting information from the values in `operands`.
-static void simplifyExprAndOperands(AffineExpr &expr,
+static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
+                                    unsigned numSymbols,
                                     ArrayRef<Value> operands) {
   // We do this only for certain floordiv/mod expressions.
   auto binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
@@ -684,13 +845,14 @@
   // Simplify the child expressions first.
   AffineExpr lhs = binExpr.getLHS();
   AffineExpr rhs = binExpr.getRHS();
-  simplifyExprAndOperands(lhs, operands);
-  simplifyExprAndOperands(rhs, operands);
+  simplifyExprAndOperands(lhs, numDims, numSymbols, operands);
+  simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
   expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);
   binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
-  if (!binExpr || (binExpr.getKind() != AffineExprKind::FloorDiv &&
-                   binExpr.getKind() != AffineExprKind::Mod)) {
+  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
+                   expr.getKind() != AffineExprKind::CeilDiv &&
+                   expr.getKind() != AffineExprKind::Mod)) {
     return;
   }
 
@@ -703,16 +865,50 @@
   int64_t rhsConstVal = rhsConst.getValue();
   // Undefined exprsessions aren't touched; IR can still be valid with them.
-  if (rhsConstVal == 0)
+  if (rhsConstVal <= 0)
     return;
 
-  AffineExpr quotientTimesDiv, rem;
-  int64_t divisor;
+  // Exploit constant lower/upper bounds to simplify a floordiv or mod.
+  MLIRContext *context = expr.getContext();
+  std::optional<int64_t> lhsLbConst =
+      getLowerBound(lhs, numDims, numSymbols, operands);
+  std::optional<int64_t> lhsUbConst =
+      getUpperBound(lhs, numDims, numSymbols, operands);
+  if (lhsLbConst && lhsUbConst) {
+    int64_t lhsLbConstVal = *lhsLbConst;
+    int64_t lhsUbConstVal = *lhsUbConst;
+    // lhs floordiv c is a single value if lhs is bounded in a range `c` that
+    // has the same quotient.
+    if (binExpr.getKind() == AffineExprKind::FloorDiv &&
+        floorDiv(lhsLbConstVal, rhsConstVal) ==
+            floorDiv(lhsUbConstVal, rhsConstVal)) {
+      expr =
+          getAffineConstantExpr(floorDiv(lhsLbConstVal, rhsConstVal), context);
+      return;
+    }
+    // lhs ceildiv c is a single value if the entire range has the same ceil
+    // quotient.
+    if (binExpr.getKind() == AffineExprKind::CeilDiv &&
+        ceilDiv(lhsLbConstVal, rhsConstVal) ==
+            ceilDiv(lhsUbConstVal, rhsConstVal)) {
+      expr =
+          getAffineConstantExpr(ceilDiv(lhsLbConstVal, rhsConstVal), context);
+      return;
+    }
+    // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
+    if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
+        lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
+      expr = lhs;
+      return;
+    }
+  }
+
   // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
   // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
   // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
   // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
+  AffineExpr quotientTimesDiv, rem;
+  int64_t divisor;
   if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
     if (rhsConstVal % divisor == 0 &&
         binExpr.getKind() == AffineExprKind::FloorDiv) {
@@ -745,7 +941,8 @@
   SmallVector<AffineExpr> newResults;
   newResults.reserve(map.getNumResults());
   for (AffineExpr expr : map.getResults()) {
-    simplifyExprAndOperands(expr, operands);
+    simplifyExprAndOperands(expr, map.getNumDims(), map.getNumSymbols(),
+                            operands);
     newResults.push_back(expr);
   }
   map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1170,8 +1170,8 @@
   "test.foo"(%x) : (f32) -> ()
 
   // %i is aligned at 32 boundary and %ii < 32.
-  // CHECK: affine.load %{{.*}}[%[[I]] floordiv 32, %[[II]] mod 32]
-  %a = affine.load %A[(%i + %ii) floordiv 32, (%i + %ii) mod 32] : memref<?x?xf32>
+  // CHECK: affine.load %{{.*}}[%[[I]] floordiv 32, %[[II]] mod 16]
+  %a = affine.load %A[(%i + %ii) floordiv 32, (%i + %ii) mod 16] : memref<?x?xf32>
   "test.foo"(%a) : (f32) -> ()
   // CHECK: affine.load %{{.*}}[%[[I]] floordiv 64, (%[[I]] + %[[II]]) mod 64]
   %b = affine.load %A[(%i + %ii) floordiv 64, (%i + %ii) mod 64] : memref<?x?xf32>
@@ -1202,6 +1202,66 @@
   return
 }
 
+// CHECK-LABEL: func @simplify_div_mod_with_operands
+func.func @simplify_div_mod_with_operands(%N: index, %A: memref<64xf32>, %unknown: index) {
+  // CHECK: affine.for %[[I:.*]] = 0 to 32
+  %cst = arith.constant 1.0 : f32
+  affine.for %i = 0 to 32 {
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0]
+    affine.store %cst, %A[%i floordiv 32] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[1]
+    affine.store %cst, %A[(%i + 1) ceildiv 32] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[%[[I]]]
+    affine.store %cst, %A[%i mod 32] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0]
+    affine.store %cst, %A[2 * %i floordiv 64] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0]
+    affine.store %cst, %A[(%i mod 16) floordiv 16] : memref<64xf32>
+
+    // The ones below can't be simplified.
+    affine.store %cst, %A[%i floordiv 16] : memref<64xf32>
+    affine.store %cst, %A[%i mod 16] : memref<64xf32>
+    affine.store %cst, %A[(%i mod 16) floordiv 15] : memref<64xf32>
+    affine.store %cst, %A[%i mod 31] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[%{{.*}} floordiv 16] : memref<64xf32>
+    // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}} mod 16] : memref<64xf32>
+    // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[(%{{.*}} mod 16) floordiv 15] : memref<64xf32>
+    // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}} mod 31] : memref<64xf32>
+  }
+
+  affine.for %i = -8 to 32 {
+    // Can't be simplified.
+ // CHECK: affine.store %{{.*}}, %{{.*}}[%{{.*}} floordiv 32] : memref<64xf32> + affine.store %cst, %A[%i floordiv 32] : memref<64xf32> + // CHECK: affine.store %{{.*}}, %{{.*}}[%{{.*}} mod 32] : memref<64xf32> + affine.store %cst, %A[%i mod 32] : memref<64xf32> + // floordiv rounds toward -inf; (%i - 96) floordiv 64 will be -2. + // CHECK: affine.store %{{.*}}, %{{.*}}[0] : memref<64xf32> + affine.store %cst, %A[2 + (%i - 96) floordiv 64] : memref<64xf32> + } + + // CHECK: affine.for %[[II:.*]] = 8 to 16 + affine.for %i = 8 to 16 { + // CHECK: affine.store %{{.*}}, %{{.*}}[1] : memref<64xf32> + affine.store %cst, %A[%i floordiv 8] : memref<64xf32> + // CHECK: affine.store %{{.*}}, %{{.*}}[2] : memref<64xf32> + affine.store %cst, %A[(%i + 1) ceildiv 8] : memref<64xf32> + // CHECK: affine.store %{{.*}}, %{{.*}}[%[[II]] mod 8] : memref<64xf32> + affine.store %cst, %A[%i mod 8] : memref<64xf32> + // CHECK: affine.store %{{.*}}, %{{.*}}[%[[II]]] : memref<64xf32> + affine.store %cst, %A[%i mod 32] : memref<64xf32> + // Upper bound on the mod 32 expression will be 15. + // CHECK: affine.store %{{.*}}, %{{.*}}[0] : memref<64xf32> + affine.store %cst, %A[(%i mod 32) floordiv 16] : memref<64xf32> + // Lower bound on the mod 16 expression will be 8. + // CHECK: affine.store %{{.*}}, %{{.*}}[1] : memref<64xf32> + affine.store %cst, %A[(%i mod 16) floordiv 8] : memref<64xf32> + // CHECK: affine.store %{{.*}}, %{{.*}}[0] : memref<64xf32> + affine.store %cst, %A[(%unknown mod 16) floordiv 16] : memref<64xf32> + } + return +} + // ----- // CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 * ((-s0 + 40961) ceildiv 512))>