diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -318,6 +318,12 @@
   /// Returns `*this` if `numResults` >= `this->getNumResults()`.
   AffineMap getMinorSubMap(unsigned numResults) const;
 
+  /// Get the largest known divisor of all map expressions.
+  /// For example, for (d0, d1) -> (8*d0 + 4, 4*d1 + 2), the result is 2.
+  /// In the case of maps with no expressions or all zero constant expressions,
+  /// the largest known divisor is trivially the max uint64_t value.
+  uint64_t getLargestKnownDivisorOfMapExprs();
+
   friend ::llvm::hash_code hash_value(AffineMap arg);
 
   /// Methods supporting C API.
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -21,6 +21,7 @@
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+#include <numeric>
 
 using namespace mlir;
 
@@ -583,6 +584,171 @@
   return result[0];
 }
 
+/// Returns the largest known divisor of `e`. Exploits information from the
+/// values in `operands`.
+static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
+  // This method isn't aware of `operands`.
+  int64_t div = e.getLargestKnownDivisor();
+
+  // We now make use of operands for the case `e` is a dim expression.
+  // TODO: More powerful simplification would have to modify
+  // getLargestKnownDivisor to take `operands` and exploit that information as
+  // well for dim/sym expressions, but in that case, getLargestKnownDivisor
+  // can't be part of the IR library but of the `Analysis` library. The IR
+  // library can only really depend on simple O(1) checks.
+  auto dimExpr = e.dyn_cast<AffineDimExpr>();
+  // If it's not a dim expr, `div` is the best we have.
+  if (!dimExpr)
+    return div;
+
+  // We simply exploit information from loop IVs.
+  // We don't need to use mlir::getLargestKnownDivisorOfValue since the other
+  // desired simplifications are expected to be part of other
+  // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
+  // LoopAnalysis library.
+  Value operand = operands[dimExpr.getPosition()];
+  int64_t operandDivisor = 1;
+  // TODO: With the right accessors, this can be extended to
+  // LoopLikeOpInterface.
+  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
+    if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
+      operandDivisor = forOp.getStep();
+    } else {
+      uint64_t lbLargestKnownDivisor =
+          forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
+      operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStep());
+    }
+  }
+  return operandDivisor;
+}
+
+/// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
+/// being an affine dim expression or a constant.
+static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef<Value> operands,
+                                   int64_t k) {
+  if (auto constExpr = e.dyn_cast<AffineConstantExpr>()) {
+    int64_t constVal = constExpr.getValue();
+    return constVal >= 0 && constVal < k;
+  }
+  auto dimExpr = e.dyn_cast<AffineDimExpr>();
+  if (!dimExpr)
+    return false;
+  Value operand = operands[dimExpr.getPosition()];
+  // TODO: With the right accessors, this can be extended to
+  // LoopLikeOpInterface.
+  if (AffineForOp forOp = getForInductionVarOwner(operand))
+    if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
+        forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k)
+      return true;
+  // We don't consider other cases like `operand` being defined by a constant or
+  // an affine.apply op since such cases will already be handled by other
+  // patterns and propagation of loop IVs or constant would happen.
+  return false;
+}
+
+/// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
+/// Set `div` to `d`, `quotientTimesDiv` to d*e_1 and `rem` to e_2 if the
+/// expression is in that form.
+static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
+                           AffineExpr &quotientTimesDiv, AffineExpr &rem) {
+  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
+  if (!bin || bin.getKind() != AffineExprKind::Add)
+    return false;
+
+  AffineExpr lhs = bin.getLHS();
+  AffineExpr rhs = bin.getRHS();
+  div = getLargestKnownDivisor(lhs, operands);
+  if (isNonNegativeBoundedBy(rhs, operands, div)) {
+    quotientTimesDiv = lhs;
+    rem = rhs;
+    return true;
+  }
+  div = getLargestKnownDivisor(rhs, operands);
+  if (isNonNegativeBoundedBy(lhs, operands, div)) {
+    quotientTimesDiv = rhs;
+    rem = lhs;
+    return true;
+  }
+  return false;
+}
+
+/// Simplify `expr` while exploiting information from the values in `operands`.
+static void simplifyExprAndOperands(AffineExpr &expr,
+                                    ArrayRef<Value> operands) {
+  // We do this only for certain floordiv/mod expressions.
+  auto binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
+  if (!binExpr)
+    return;
+
+  // Simplify the child expressions first.
+  auto lhs = binExpr.getLHS();
+  auto rhs = binExpr.getRHS();
+  simplifyExprAndOperands(lhs, operands);
+  simplifyExprAndOperands(rhs, operands);
+  expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);
+
+  binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
+  if (!binExpr || (binExpr.getKind() != AffineExprKind::FloorDiv &&
+                   binExpr.getKind() != AffineExprKind::Mod))
+    return;
+
+  // Constructing `expr` may have simplified it further, e.g.
+  // (d0 floordiv 4) floordiv 8 -> d0 floordiv 32, so re-read its operands.
+  lhs = binExpr.getLHS();
+  rhs = binExpr.getRHS();
+  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+  if (!rhsConst)
+    return;
+
+  int64_t rhsConstVal = rhsConst.getValue();
+  // Floordiv/mod by a non-positive constant has no defined affine semantics,
+  // yet such expressions may still appear in valid IR: leave them untouched.
+  // This also keeps the `%` computations below free of division by zero.
+  if (rhsConstVal <= 0)
+    return;
+  AffineExpr quotientTimesDiv, rem;
+  int64_t divisor;
+
+  // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
+  // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
+  // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
+  // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
+  if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
+    if (rhsConstVal % divisor == 0 &&
+        binExpr.getKind() == AffineExprKind::FloorDiv) {
+      expr = quotientTimesDiv.floorDiv(rhsConst);
+    } else if (divisor % rhsConstVal == 0 &&
+               binExpr.getKind() == AffineExprKind::Mod) {
+      expr = rem % rhsConst;
+    }
+    return;
+  }
+
+  // Handle the simple case when the LHS expression can be either upper
+  // bounded or is a known multiple of RHS constant.
+  // lhs floordiv c -> 0 if 0 <= lhs < c,
+  // lhs mod c -> 0 if lhs % c = 0.
+  if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) &&
+       binExpr.getKind() == AffineExprKind::FloorDiv) ||
+      (getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 &&
+       binExpr.getKind() == AffineExprKind::Mod)) {
+    expr = getAffineConstantExpr(0, expr.getContext());
+  }
+}
+
+/// Simplify the map while exploiting information on the values in `operands`.
+static void simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
+  assert(map.getNumInputs() == operands.size() && "invalid operands for map");
+  SmallVector<AffineExpr> newResults;
+  newResults.reserve(map.getNumResults());
+  for (AffineExpr expr : map.getResults()) {
+    simplifyExprAndOperands(expr, operands);
+    newResults.push_back(expr);
+  }
+  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
+                       map.getContext());
+}
+
 /// Replace all occurrences of AffineExpr at position `pos` in `map` by the
 /// defining AffineApplyOp expression and operands.
 /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
@@ -1100,6 +1258,7 @@
     SmallVector<Value, 8> resultOperands(oldOperands);
     composeAffineMapAndOperands(&map, &resultOperands);
     canonicalizeMapAndOperands(&map, &resultOperands);
+    simplifyMapWithOperands(map, resultOperands);
     if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
                                     resultOperands.begin()))
       return failure();
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -16,6 +16,8 @@
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/raw_ostream.h"
+#include <limits>
+#include <numeric>
 
 using namespace mlir;
 
@@ -241,6 +243,24 @@
   return ::inferFromExprList(exprsList);
 }
 
+uint64_t AffineMap::getLargestKnownDivisorOfMapExprs() {
+  uint64_t divisor = 0;
+  for (AffineExpr resultExpr : getResults()) {
+    int64_t exprDivisor = resultExpr.getLargestKnownDivisor();
+    // Divisibility does not depend on sign: use the magnitude so that a
+    // negative value does not wrap around to a huge (wrong) unsigned divisor.
+    uint64_t exprMagnitude = exprDivisor < 0
+                                 ? -static_cast<uint64_t>(exprDivisor)
+                                 : static_cast<uint64_t>(exprDivisor);
+    divisor = std::gcd(divisor, exprMagnitude);
+  }
+  // A zero gcd means there are no results or all results are the constant 0,
+  // which every integer divides.
+  if (divisor == 0)
+    divisor = std::numeric_limits<uint64_t>::max();
+  return divisor;
+}
+
 AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
                                             MLIRContext *context) {
   SmallVector<AffineExpr, 4> dimExprs;
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -98,13 +98,13 @@
   %c4 = arith.constant 4 : index
   %c8 = arith.constant 8 : index
 
-  affine.for %i0 = 0 to 3 {
+  affine.for %i0 = 0 to 16 {
     %x0 = affine.apply affine_map<(d0)[s0] -> (d0 ceildiv s0)> (%i0)[%c4]
-    affine.for %i1 = 0 to 3 {
+    affine.for %i1 = 0 to 16 {
       %x1 = affine.apply affine_map<(d0)[s0] -> (d0 ceildiv s0)> (%i1)[%c8]
-      affine.for %i2 = 0 to 3 {
+      affine.for %i2 = 0 to 16 {
         %x2 = affine.apply affine_map<(d0)[s0] -> (d0 mod s0)> (%i2)[%c4]
-        affine.for %i3 = 0 to 3 {
+        affine.for %i3 = 0 to 16 {
           %x3 = affine.apply affine_map<(d0)[s0] -> (d0 mod s0)> (%i3)[%c8]
 
           %x40 = affine.apply affine_map<(d0, d1, d2, d3)[s0, s1] ->
@@ -1150,3 +1150,41 @@
     return %s: memref<32x64xf32>
   }
 }
+
+// -----
+
+// Simplification of maps exploiting operand info.
+
+// CHECK-LABEL: func @simplify_with_operands
+func.func @simplify_with_operands(%N: index, %A: memref<?x32xf32>) {
+  // CHECK-NEXT: affine.for %[[I:.*]] = 0 to %{{.*}}
+  affine.for %i = 0 to %N step 32 {
+    // CHECK-NEXT: affine.for %[[II:.*]] = 0 to 32
+    affine.for %ii = 0 to 32 {
+      // %ii is less than 32 and 32 divides %i.
+      // CHECK: affine.load %{{.*}}[0, 0]
+      %x = affine.load %A[%ii floordiv 32, %i mod 32] : memref<?x32xf32>
+      "test.foo"(%x) : (f32) -> ()
+
+      // %i is aligned at a 32 boundary and %ii < 32.
+      // CHECK: affine.load %{{.*}}[%[[I]] floordiv 32, %[[II]] mod 32]
+      %a = affine.load %A[(%i + %ii) floordiv 32, (%i + %ii) mod 32] : memref<?x32xf32>
+      "test.foo"(%a) : (f32) -> ()
+      // CHECK: affine.load %{{.*}}[%[[I]] floordiv 64, (%[[I]] + %[[II]]) mod 64]
+      %b = affine.load %A[(%i + %ii) floordiv 64, (%i + %ii) mod 64] : memref<?x32xf32>
+      "test.foo"(%b) : (f32) -> ()
+      // CHECK: affine.load %{{.*}}[(%[[I]] + %[[II]]) floordiv 16, %[[II]] mod 16]
+      %c = affine.load %A[(%i + %ii) floordiv 16, (%i + %ii) mod 16] : memref<?x32xf32>
+      "test.foo"(%c) : (f32) -> ()
+    }
+  }
+
+  // Should not simplify.
+  affine.for %i = -1 to 32 {
+    // CHECK: affine.load %{{.*}}[%{{.*}} floordiv {{.*}}, %{{.*}} mod {{.*}}] :
+    %x = affine.load %A[%i floordiv 32, %i mod 32] : memref<?x32xf32>
+    "test.foo"(%x) : (f32) -> ()
+  }
+
+  return
+}