diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1430,88 +1430,141 @@
   return posLimit - posStart;
 }
 
-// Detect the identifier at 'pos' (say id_r) as modulo of another identifier
-// (say id_n) w.r.t a constant. When this happens, another identifier (say id_q)
-// could be detected as the floordiv of n. For eg:
-// id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3 <=>
-//                  id_r = id_n mod 4, id_q = id_n floordiv 4.
-// lbConst and ubConst are the constant lower and upper bounds for 'pos' -
-// pre-detected at the caller.
+// Determine whether the identifier at 'pos' (say id_r) can be expressed as
+// modulo of another known identifier (say id_n) w.r.t a constant. For example,
+// if the following constraints hold true:
+// ```
+// 0 <= id_r <= divisor - 1
+// id_n - (divisor * q_expr) = id_r
+// ```
+// where `id_n` is a known identifier (called dividend), and `q_expr` is an
+// `AffineExpr` (called the quotient expression), `id_r` can be written as:
+//
+// `id_r = id_n mod divisor`.
+//
+// Additionally, in a special case of the above constraints where `q_expr` is
+// `k * id_q` for a single identifier `id_q` that is not yet known and a
+// non-zero integer `k`, `id_q` can be written as a floordiv in the following
+// way:
+//
+// `id_q = (id_n floordiv (|k| * divisor)) * sign(k)`.
+//
+// The `num` identifiers starting at `offset` are the ones being solved for.
+// The dimensional expressions in 'memo' are over the remaining dimensional
+// identifiers, so a dim position `p >= offset` in 'memo' refers to column
+// `p + num` of `cst`.
+//
+// Returns true if the above mod or floordiv are detected, updating 'memo' with
+// these new expressions. Returns false otherwise.
 static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
-                        int64_t lbConst, int64_t ubConst,
-                        SmallVectorImpl<AffineExpr> *memo) {
+                        unsigned offset, unsigned num, int64_t lbConst,
+                        int64_t ubConst, SmallVectorImpl<AffineExpr> &memo,
+                        MLIRContext *context) {
   assert(pos < cst.getNumIds() && "invalid position");
 
-  // Check if 0 <= id_r <= divisor - 1 and if id_r is equal to
-  // id_n - divisor * id_q. If these are true, then id_n becomes the dividend
-  // and id_q the quotient when dividing id_n by the divisor.
-
+  // Check if a divisor satisfying the condition `0 <= id_r <= divisor - 1` can
+  // be determined.
   if (lbConst != 0 || ubConst < 1)
     return false;
-
   int64_t divisor = ubConst + 1;
 
-  // Now check for: id_r = id_n - divisor * id_q. As an example, we
-  // are looking r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0.
-  unsigned seenQuotient = 0, seenDividend = 0;
-  int quotientPos = -1, dividendPos = -1;
-  for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
-    // id_n should have coeff 1 or -1.
-    if (std::abs(cst.atEq(r, pos)) != 1)
+  // Check for the aforementioned conditions in each equality.
+  for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities();
+       curEquality < numEqualities; curEquality++) {
+    int64_t coefficientAtPos = cst.atEq(curEquality, pos);
+    // If current equality does not involve `id_r`, continue to the next
+    // equality.
+    if (coefficientAtPos == 0)
       continue;
-    // constant term should be 0.
-    if (cst.atEq(r, cst.getNumCols() - 1) != 0)
+
+    // Constant term should be 0 in this equality.
+    if (cst.atEq(curEquality, cst.getNumCols() - 1) != 0)
       continue;
-    unsigned c, f;
-    int quotientSign = 1, dividendSign = 1;
-    for (c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) {
-      if (c == pos)
+
+    // Traverse through the equality and construct the dividend expression
+    // `dividendExpr`, to contain all the identifiers which are known and are
+    // not divisible by `(coefficientAtPos * divisor)`. The hope here is that
+    // the `dividendExpr` gets simplified into a single identifier `id_n`
+    // discussed above.
+    auto dividendExpr = getAffineConstantExpr(0, context);
+
+    // Track the terms that go into quotient expression, later used to detect
+    // additional floordiv.
+    unsigned quotientCount = 0;
+    int quotientPosition = -1;
+    int quotientSign = 1;
+    // |coefficient of the quotient id / coefficientAtPos|, a multiple of
+    // `divisor`; only meaningful when exactly one quotient id is seen.
+    int64_t quotientMultiplier = divisor;
+
+    // Consider each term in the current equality.
+    unsigned curId, e;
+    for (curId = 0, e = cst.getNumDimAndSymbolIds(); curId < e; ++curId) {
+      // Ignore id_r.
+      if (curId == pos)
+        continue;
+      int64_t coefficientOfCurId = cst.atEq(curEquality, curId);
+      // Ignore ids that do not contribute to the current equality.
+      if (coefficientOfCurId == 0)
+        continue;
+      // Check if the current id goes into the quotient expression.
+      if (coefficientOfCurId % (divisor * coefficientAtPos) == 0) {
+        quotientCount++;
+        quotientPosition = curId;
+        quotientSign = (coefficientOfCurId * coefficientAtPos) > 0 ? 1 : -1;
+        quotientMultiplier = std::abs(coefficientOfCurId / coefficientAtPos);
         continue;
-      // The coefficient of the quotient should be +/-divisor.
-      // TODO: could be extended to detect an affine function for the quotient
-      // (i.e., the coeff could be a non-zero multiple of divisor).
-      int64_t v = cst.atEq(r, c) * cst.atEq(r, pos);
-      if (v == divisor || v == -divisor) {
-        seenQuotient++;
-        quotientPos = c;
-        quotientSign = v > 0 ? 1 : -1;
       }
-      // The coefficient of the dividend should be +/-1.
-      // TODO: could be extended to detect an affine function of the other
-      // identifiers as the dividend.
-      else if (v == -1 || v == 1) {
-        seenDividend++;
-        dividendPos = c;
-        dividendSign = v < 0 ? 1 : -1;
-      } else if (cst.atEq(r, c) != 0) {
-        // Cannot be inferred as a mod since the constraint has a coefficient
-        // for an identifier that's neither a unit nor the divisor (see TODOs
-        // above).
+      // Identifiers that are part of dividendExpr should be known.
+      if (!memo[curId])
         break;
-      }
+      // Append the current identifier to the dividend expression.
+      dividendExpr = dividendExpr + memo[curId] * coefficientOfCurId;
     }
-    if (c < f)
-      // Cannot be inferred as a mod since the constraint has a coefficient for
-      // an identifier that's neither a unit nor the divisor (see TODOs above).
+
+    // Can't construct expression as it depends on a yet uncomputed id.
+    if (curId < e)
       continue;
 
-    // We are looking for exactly one identifier as the dividend.
-    if (seenDividend == 1 && seenQuotient >= 1) {
-      if (!(*memo)[dividendPos])
-        return false;
-      // Successfully detected a mod.
-      (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
-      auto ub = cst.getConstantUpperBound(dividendPos);
-      if (ub.hasValue() && ub.getValue() < divisor)
-        // The mod can be optimized away.
-        (*memo)[pos] = (*memo)[dividendPos] * dividendSign;
+    // Express `id_r` in terms of the other ids collected so far.
+    if (coefficientAtPos > 0)
+      dividendExpr = (-dividendExpr).floorDiv(coefficientAtPos);
+    else
+      dividendExpr = dividendExpr.floorDiv(-coefficientAtPos);
+
+    // Simplify the expression.
+    dividendExpr = simplifyAffineExpr(dividendExpr, cst.getNumDimIds(),
+                                      cst.getNumSymbolIds());
+    // Only if the final dividend expression is just a single id (which we call
+    // `id_n`), we can proceed.
+    // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it
+    // to dims themselves.
+    auto dimExpr = dividendExpr.dyn_cast<AffineDimExpr>();
+    if (!dimExpr)
+      continue;
+
+    // Express `id_r` as `id_n % divisor` and store the expression in `memo`.
+    if (quotientCount >= 1) {
+      // Positions in 'memo' skip the `num` ids being solved for; map the
+      // position of `dimExpr` back to its column in `cst`.
+      unsigned dimExprCol = dimExpr.getPosition();
+      if (dimExprCol >= offset)
+        dimExprCol += num;
+      auto lb = cst.getConstantLowerBound(dimExprCol);
+      auto ub = cst.getConstantUpperBound(dimExprCol);
+      // If `id_n` is known to lie in [0, divisor - 1], the mod can be
+      // eliminated altogether.
+      if (lb.hasValue() && lb.getValue() >= 0 && ub.hasValue() &&
+          ub.getValue() < divisor)
+        memo[pos] = dimExpr;
       else
-        (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
+        memo[pos] = dimExpr % divisor;
+      // If a unique quotient `id_q` was seen, it can be expressed as
+      // `(id_n floordiv quotientMultiplier) * quotientSign`.
+      if (quotientCount == 1 && !memo[quotientPosition])
+        memo[quotientPosition] =
+            dimExpr.floorDiv(quotientMultiplier) * quotientSign;
 
-      if (seenQuotient == 1 && !(*memo)[quotientPos])
-        // Successfully detected a floordiv as well.
-        (*memo)[quotientPos] =
-            (*memo)[dividendPos].floorDiv(divisor) * quotientSign;
       return true;
     }
   }
@@ -1885,7 +1938,7 @@
         // Detect an identifier as modulo of another identifier w.r.t a
         // constant.
-        if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(),
-                        &memo)) {
+        if (detectAsMod(*this, pos, offset, num, lbConst.getValue(),
+                        ubConst.getValue(), memo, context)) {
           changed = true;
           continue;
         }
diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -3330,3 +3330,69 @@
 // CHECK:         affine.for
 // CHECK:           affine.for
 // CHECK-NOT:       affine.for
+
+// -----
+
+// Expects fusion of producer into consumer at depth 4 and subsequent removal of
+// source loop.
+// CHECK-LABEL: func @unflatten4d
+func @unflatten4d(%arg1: memref<7x8x9x10xf32>) {
+  %m = memref.alloc() : memref<5040xf32>
+  %cf7 = constant 7.0 : f32
+
+  affine.for %i0 = 0 to 7 {
+    affine.for %i1 = 0 to 8 {
+      affine.for %i2 = 0 to 9 {
+        affine.for %i3 = 0 to 10 {
+          affine.store %cf7, %m[720 * %i0 + 90 * %i1 + 10 * %i2 + %i3] : memref<5040xf32>
+        }
+      }
+    }
+  }
+  affine.for %i0 = 0 to 7 {
+    affine.for %i1 = 0 to 8 {
+      affine.for %i2 = 0 to 9 {
+        affine.for %i3 = 0 to 10 {
+          %v0 = affine.load %m[720 * %i0 + 90 * %i1 + 10 * %i2 + %i3] : memref<5040xf32>
+          affine.store %v0, %arg1[%i0, %i1, %i2, %i3] : memref<7x8x9x10xf32>
+        }
+      }
+    }
+  }
+  return
+}
+
+// CHECK:        affine.for
+// CHECK-NEXT:     affine.for
+// CHECK-NEXT:       affine.for
+// CHECK-NEXT:         affine.for
+// CHECK-NOT:    affine.for
+// CHECK:        return
+
+// -----
+
+// Expects fusion of producer into consumer at depth 2 and subsequent removal of
+// source loop.
+// CHECK-LABEL: func @unflatten2d_with_transpose
+func @unflatten2d_with_transpose(%arg1: memref<8x7xf32>) {
+  %m = memref.alloc() : memref<56xf32>
+  %cf7 = constant 7.0 : f32
+
+  affine.for %i0 = 0 to 7 {
+    affine.for %i1 = 0 to 8 {
+      affine.store %cf7, %m[8 * %i0 + %i1] : memref<56xf32>
+    }
+  }
+  affine.for %i0 = 0 to 8 {
+    affine.for %i1 = 0 to 7 {
+      %v0 = affine.load %m[%i0 + 8 * %i1] : memref<56xf32>
+      affine.store %v0, %arg1[%i0, %i1] : memref<8x7xf32>
+    }
+  }
+  return
+}
+
+// CHECK:       affine.for
+// CHECK-NEXT:    affine.for
+// CHECK-NOT:   affine.for
+// CHECK:       return