diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -293,10 +293,12 @@ return U(expr); } -/// Simplify an affine expression by flattening and some amount of -/// simple analysis. This has complexity linear in the number of nodes in -/// 'expr'. Returns the simplified expression, which is the same as the input -/// expression if it can't be simplified. +/// Simplify an affine expression by flattening and some amount of simple +/// analysis. This has complexity linear in the number of nodes in 'expr'. +/// Returns the simplified expression, which is the same as the input expression +/// if it can't be simplified. When `expr` is semi-affine, a simplified +/// semi-affine expression is constructed in the sorted order of dimension and +/// symbol positions. AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols); diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h --- a/mlir/include/mlir/IR/AffineExprVisitor.h +++ b/mlir/include/mlir/IR/AffineExprVisitor.h @@ -299,7 +299,26 @@ virtual void addLocalFloorDivId(ArrayRef dividend, int64_t divisor, AffineExpr localExpr); + /// Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul + /// expr) when the rhs is a symbolic expression. The local identifier added + /// may be a floordiv, ceildiv, mul or mod of a pure affine/semi-affine + /// function of other identifiers, coefficients of which are specified in the + /// lhs of the mod, floordiv, ceildiv or mul expression and with respect to a + /// symbolic rhs expression. `localExpr` is the simplified tree expression + /// (AffineExpr) corresponding to the quantifier. + virtual void addLocalIdSemiAffine(AffineExpr localExpr); + private: + /// Adds `expr`, which may be mod, ceildiv, floordiv or mod expression + /// representing the affine expression corresponding to the quantifier + /// introduced as the local variable corresponding to `expr`. If the + /// quantifier is already present, we put the coefficient in the proper index + /// of `result`, otherwise we add a new local variable and put the coefficient + /// there. + void addLocalVariableSemiAffine(AffineExpr expr, + SmallVectorImpl &result, + unsigned long resultSize); + // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1 // A floordiv is thus flattened by introducing a new local variable q, and // replacing that expression with 'q' while adding the constraints diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -903,21 +903,213 @@ return expr; } +/// Constructs a semi-affine expression from a flat ArrayRef. If there are +/// local identifiers (neither dimensional nor symbolic) that appear in the sum +/// of products expression, `localExprs` is expected to have the AffineExprs for +/// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in +/// the format [dims, symbols, locals, constant term]. The semi-affine +/// expression is constructed in the sorted order of dimension and symbol +/// position numbers. Note: local expressions/ids are used for mod, div as well +/// as symbolic RHS terms for terms that are not pure affine. +static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef flatExprs, + unsigned numDims, + unsigned numSymbols, + ArrayRef localExprs, + MLIRContext *context) { + assert(!flatExprs.empty() && "flatExprs cannot be empty"); + + // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1. + assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() && + "unexpected number of local expressions"); + + AffineExpr expr = getAffineConstantExpr(0, context); + + // We design indices as a pair which help us present the semi-affine map as + // sum of product where terms are sorted based on dimension or symbol + // position: for expressions of the form dimension * symbol, + // where keyA is the position number of the dimension and keyB is the + // position number of the symbol. For dimensional expressions we set the index + // as (position number of the dimension, -1), as we want dimensional + // expressions to appear before symbolic and product of dimensional and + // symbolic expressions having the dimension with the same position number. + // For symbolic expression set the index as (position number of the symbol, + // maximum of last dimension and symbol position) number. For example, we want + // the expression we are constructing to look something like: d0 + d0 * s0 + + // s0 + d1*s1 + s1. + + // Stores the affine expression corresponding to a given index. + DenseMap, AffineExpr> indexToExprMap; + // Stores the constant coefficient value corresponding to a given + // dimension, symbol or a non-pure affine expression stored in `localExprs`. + DenseMap, int64_t> coefficients; + // Stores the indices as defined above, and later sorted to produce + // the semi-affine expression in the desired form. + SmallVector, 8> indices; + + // Example: expression = d0 + d0 * s0 + 2 * s0. + // indices = [{0,-1}, {0, 0}, {0, 1}] + // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}] + // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}] + + // Adds entries to `indexToExprMap`, `coefficients` and `indices`. + auto addEntry = [&](std::pair index, int64_t coefficient, + AffineExpr expr) { + assert(std::find(indices.begin(), indices.end(), index) == indices.end() && + "Key is already present in indices vector and overwriting will " + "happen in `indexToExprMap` and `coefficients`!"); + + indices.push_back(index); + coefficients.insert({index, coefficient}); + indexToExprMap.insert({index, expr}); + }; + + // Design indices for dimensional or symbolic terms, and store the indices, + // constant coefficient corresponding to the indices in `coefficients` map, + // and affine expression corresponding to indices in `indexToExprMap` map. + + for (unsigned j = 0; j < numDims; ++j) { + if (flatExprs[j] == 0) + continue; + // For dimensional expressions we set the index as , as we want dimensional expressions to appear before + // symbolic ones and products of dimensional and symbolic expressions + // having the dimension with the same position number. + std::pair indexEntry(j, -1); + addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context)); + } + for (unsigned j = numDims; j < numDims + numSymbols; ++j) { + if (flatExprs[j] == 0) + continue; + // For symbolic expression set the index as number, + // as we want symbolic expressions with the same positional number to + // appear after dimensional expressions having the same positional number. + std::pair indexEntry(j - numDims, + std::max(numDims, numSymbols)); + addEntry(indexEntry, flatExprs[j], + getAffineSymbolExpr(j - numDims, context)); + } + + // Denotes semi-affine product, modulo or division terms, which has been added + // to the `indexToExpr` map. + SmallVector addedToMap(flatExprs.size() - numDims - numSymbols - 1, + false); + unsigned lhsPos, rhsPos; + // Construct indices for product terms involving dimension, symbol or constant + // as lhs/rhs, and store the indices, constant coefficient corresponding to + // the indices in `coefficients` map, and affine expression corresponding to + // in indices in `indexToExprMap` map. + for (auto it : llvm::enumerate(localExprs)) { + AffineExpr expr = it.value(); + if (flatExprs[numDims + numSymbols + it.index()] == 0) + continue; + AffineExpr lhs = expr.cast().getLHS(); + AffineExpr rhs = expr.cast().getRHS(); + if (!((lhs.isa() || lhs.isa()) && + (rhs.isa() || rhs.isa() || + rhs.isa()))) { + continue; + } + if (rhs.isa()) { + // For product/modulo/division expressions, when rhs of modulo/division + // expression is constant, we put 0 in place of keyB, because we want + // them to appear earlier in the semi-affine expression we are + // constructing. When rhs is constant, we place 0 in place of keyB. + if (lhs.isa()) { + lhsPos = lhs.cast().getPosition(); + std::pair indexEntry(lhsPos, -1); + addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], + expr); + } else { + lhsPos = lhs.cast().getPosition(); + std::pair indexEntry(lhsPos, + std::max(numDims, numSymbols)); + addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], + expr); + } + } else if (lhs.isa()) { + // For product/modulo/division expressions having lhs as dimension and rhs + // as symbol, we order the terms in the semi-affine expression based on + // the pair: for expressions of the form dimension * symbol, + // where keyA is the position number of the dimension and keyB is the + // position number of the symbol. + lhsPos = lhs.cast().getPosition(); + rhsPos = rhs.cast().getPosition(); + std::pair indexEntry(lhsPos, rhsPos); + addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr); + } else { + // For product/modulo/division expressions having both lhs and rhs as + // symbol, we design indices as a pair: for expressions + // of the form dimension * symbol, where keyA is the position number of + // the dimension and keyB is the position number of the symbol. + lhsPos = lhs.cast().getPosition(); + rhsPos = rhs.cast().getPosition(); + std::pair indexEntry(lhsPos, rhsPos); + addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr); + } + addedToMap[it.index()] = true; + } + + // Constructing the simplified semi-affine sum of product/division/mod + // expression from the flattened form in the desired sorted order of indices + // of the various individual product/division/mod expressions. + std::sort(indices.begin(), indices.end()); + for (const std::pair index : indices) { + assert(indexToExprMap.lookup(index) && + "cannot find key in `indexToExprMap` map"); + expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index); + } + + // Local identifiers. + for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e; + j++) { + // If the coefficient of the local expression is 0, continue as we need not + // add it in out final expression. + if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols]) + continue; + auto term = localExprs[j - numDims - numSymbols] * flatExprs[j]; + expr = expr + term; + } + + // Constant term. + int64_t constTerm = flatExprs.back(); + if (constTerm != 0) + expr = expr + constTerm; + return expr; +} + SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols) : numDims(numDims), numSymbols(numSymbols), numLocals(0) { operandExprStack.reserve(8); } +// In pure affine t = expr * c, we multiply each coefficient of lhs with c. +// +// In case of semi affine multiplication expressions, t = expr * symbolic_expr, +// introduce a local variable p (= expr * symbolic_expr), and the affine +// expression expr * symbolic_expr is added to `localExprs`. void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) { assert(operandExprStack.size() >= 2); - // This is a pure affine expr; the RHS will be a constant. - assert(expr.getRHS().isa()); - // Get the RHS constant. - auto rhsConst = operandExprStack.back()[getConstantIndex()]; + SmallVector rhs = operandExprStack.back(); operandExprStack.pop_back(); - // Update the LHS in place instead of pop and push. - auto &lhs = operandExprStack.back(); + SmallVector &lhs = operandExprStack.back(); + + // Flatten semi-affine multiplication expressions by introducing a local + // variable in place of the product; the affine expression + // corresponding to the quantifier is added to `localExprs`. + if (!expr.getRHS().isa()) { + MLIRContext *context = expr.getContext(); + AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols, + localExprs, context); + AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols, + localExprs, context); + addLocalVariableSemiAffine(a * b, lhs, lhs.size()); + return; + } + + // Get the RHS constant. + auto rhsConst = rhs[getConstantIndex()]; for (unsigned i = 0, e = lhs.size(); i < e; i++) { lhs[i] *= rhsConst; } @@ -942,13 +1134,32 @@ // A mod expression "expr mod c" is thus flattened by introducing a new local // variable q (= expr floordiv c), such that expr mod c is replaced with // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst. +// +// In case of semi-affine modulo expressions, t = expr mod symbolic_expr, +// introduce a local variable m (= expr mod symbolic_expr), and the affine +// expression expr mod symbolic_expr is added to `localExprs`. void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) { assert(operandExprStack.size() >= 2); - // This is a pure affine expr; the RHS will be a constant. - assert(expr.getRHS().isa()); - auto rhsConst = operandExprStack.back()[getConstantIndex()]; + + SmallVector rhs = operandExprStack.back(); operandExprStack.pop_back(); - auto &lhs = operandExprStack.back(); + SmallVector &lhs = operandExprStack.back(); + MLIRContext *context = expr.getContext(); + + // Flatten semi affine modulo expressions by introducing a local + // variable in place of the modulo value, and the affine expression + // corresponding to the quantifier is added to `localExprs`. + if (!expr.getRHS().isa()) { + AffineExpr dividendExpr = getAffineExprFromFlatForm( + lhs, numDims, numSymbols, localExprs, context); + AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols, + localExprs, context); + AffineExpr modExpr = dividendExpr % divisorExpr; + addLocalVariableSemiAffine(modExpr, lhs, lhs.size()); + return; + } + + int64_t rhsConst = rhs[getConstantIndex()]; // TODO: handle modulo by zero case when this issue is fixed // at the other places in the IR. assert(rhsConst > 0 && "RHS constant has to be positive"); @@ -979,11 +1190,11 @@ int64_t floorDivisor = rhsConst / static_cast(gcd); // Construct the AffineExpr form of the floordiv to store in localExprs. - MLIRContext *context = expr.getContext(); - auto dividendExpr = getAffineExprFromFlatForm( + + AffineExpr dividendExpr = getAffineExprFromFlatForm( floorDividend, numDims, numSymbols, localExprs, context); - auto divisorExpr = getAffineConstantExpr(floorDivisor, context); - auto floorDivExpr = dividendExpr.floorDiv(divisorExpr); + AffineExpr divisorExpr = getAffineConstantExpr(floorDivisor, context); + AffineExpr floorDivExpr = dividendExpr.floorDiv(divisorExpr); int loc; if ((loc = findLocalId(floorDivExpr)) == -1) { addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr); @@ -1022,6 +1233,21 @@ eq[getConstantIndex()] = expr.getValue(); } +void SimpleAffineExprFlattener::addLocalVariableSemiAffine( + AffineExpr expr, SmallVectorImpl &result, + unsigned long resultSize) { + assert(result.size() == resultSize && + "`result` vector passed is not of correct size"); + int loc; + if ((loc = findLocalId(expr)) == -1) + addLocalIdSemiAffine(expr); + std::fill(result.begin(), result.end(), 0); + if (loc == -1) + result[getLocalVarStartIndex() + numLocals - 1] = 1; + else + result[getLocalVarStartIndex() + loc] = 1; +} + // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1 // A floordiv is thus flattened by introducing a new local variable q, and // replacing that expression with 'q' while adding the constraints @@ -1030,18 +1256,38 @@ // // A ceildiv is similarly flattened: // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c +// +// In case of semi affine division expressions, t = expr floordiv symbolic_expr +// or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr +// floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to +// `localExprs`. void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr, bool isCeil) { assert(operandExprStack.size() >= 2); - assert(expr.getRHS().isa()); + + MLIRContext *context = expr.getContext(); + SmallVector rhs = operandExprStack.back(); + operandExprStack.pop_back(); + SmallVector &lhs = operandExprStack.back(); + + // Flatten semi affine division expressions by introducing a local + // variable in place of the quotient, and the affine expression corresponding + // to the quantifier is added to `localExprs`. + if (!expr.getRHS().isa()) { + AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols, + localExprs, context); + AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols, + localExprs, context); + AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b); + addLocalVariableSemiAffine(divExpr, lhs, lhs.size()); + return; + } // This is a pure affine expr; the RHS is a positive constant. - int64_t rhsConst = operandExprStack.back()[getConstantIndex()]; + int64_t rhsConst = rhs[getConstantIndex()]; // TODO: handle division by zero at the same time the issue is // fixed at other places. assert(rhsConst > 0 && "RHS constant has to be positive"); - operandExprStack.pop_back(); - auto &lhs = operandExprStack.back(); // Simplify the floordiv, ceildiv if possible by canceling out the greatest // common divisors of the numerator and denominator. @@ -1063,13 +1309,12 @@ // the ceil/floor expr (simplified up until here). Add an existential // quantifier to express its result, i.e., expr1 div expr2 is replaced // by a new identifier, q. - MLIRContext *context = expr.getContext(); - auto a = + AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols, localExprs, context); - auto b = getAffineConstantExpr(divisor, context); + AffineExpr b = getAffineConstantExpr(divisor, context); int loc; - auto divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b); + AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b); if ((loc = findLocalId(divExpr)) == -1) { if (!isCeil) { SmallVector dividend(lhs); @@ -1099,13 +1344,20 @@ int64_t divisor, AffineExpr localExpr) { assert(divisor > 0 && "positive constant divisor expected"); - for (auto &subExpr : operandExprStack) + for (SmallVector &subExpr : operandExprStack) subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0); localExprs.push_back(localExpr); numLocals++; // dividend and divisor are not used here; an override of this method uses it. } +void SimpleAffineExprFlattener::addLocalIdSemiAffine(AffineExpr localExpr) { + for (SmallVector &subExpr : operandExprStack) + subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0); + localExprs.push_back(localExpr); + ++numLocals; +} + int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) { SmallVectorImpl::iterator it; if ((it = llvm::find(localExprs, localExpr)) == localExprs.end()) @@ -1119,17 +1371,25 @@ // Simplify semi-affine expressions separately. if (!expr.isPureAffine()) expr = simplifySemiAffine(expr); - if (!expr.isPureAffine()) - return expr; SimpleAffineExprFlattener flattener(numDims, numSymbols); flattener.walkPostOrder(expr); ArrayRef flattenedExpr = flattener.operandExprStack.back(); - auto simplifiedExpr = - getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols, - flattener.localExprs, expr.getContext()); + if (!expr.isPureAffine() && + expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols, + flattener.localExprs, + expr.getContext())) + return expr; + AffineExpr simplifiedExpr = + expr.isPureAffine() + ? getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols, + flattener.localExprs, expr.getContext()) + : getSemiAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols, + flattener.localExprs, + expr.getContext()); + flattener.operandExprStack.pop_back(); assert(flattener.operandExprStack.empty()); - + simplifiedExpr.dump(); return simplifiedExpr; } diff --git a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir --- a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir +++ b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir @@ -479,3 +479,54 @@ } return %res#0, %res#1, %res#2 : index, index, index } + +// ----- + +// Test simplification of mod expressions. +// CHECK-DAG: #[[MOD:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s3 + s4 * s1 + (s0 - s1) mod s2)> +// CHECK-DAG: #[[SIMPLIFIED_MOD_RHS:.*]] = affine_map<()[s0, s1, s2, s3] -> (s3 mod (s2 - s0 * s1))> +// CHECK-DAG: #[[MODULO_AND_PRODUCT:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s3 - (-s0 + s3) mod s2)> +// CHECK: func @semiaffine_simplification_mod(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index) +func @semiaffine_simplification_mod(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> (index, index, index) { + %a = affine.apply affine_map<(d0, d1)[s0, s1, s2, s3] -> ((-(d1 * s0 - (s0 - s1) mod s2) + s3) + (d0 * s1 + d1 * s0))>(%arg0, %arg1)[%arg2, %arg3, %arg4, %arg5] + %b = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> (d0 mod (s0 - s1 * s2 + s3 - s0))>(%arg0)[%arg0, %arg1, %arg2, %arg3] + %c = affine.apply affine_map<(d0)[s0, s1, s2] -> (d0 + (d0 + s0) mod s2 + s0 * s1 - (d0 + s0) mod s2 - (d0 - s0) mod s2)>(%arg0)[%arg1, %arg2, %arg3] + return %a, %b, %c : index, index, index +} +// CHECK-NEXT: %[[RESULT0:.*]] = affine.apply #[[MOD]]()[%[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]], %[[ARG0]]] +// CHECK-NEXT: %[[RESULT1:.*]] = affine.apply #[[SIMPLIFIED_MOD_RHS]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG0]]] +// CHECK-NEXT: %[[RESULT2:.*]] = affine.apply #[[MODULO_AND_PRODUCT]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG0]]] +// CHECK-NEXT: return %[[RESULT0]], %[[RESULT1]], %[[RESULT2]] + +// ----- + +// Test simplification of floordiv and ceildiv expressions. +// CHECK-DAG: #[[SIMPLIFIED_FLOORDIV_RHS:.*]] = affine_map<()[s0, s1, s2, s3] -> (s3 floordiv (s2 - s0 * s1))> +// CHECK-DAG: #[[FLOORDIV:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 + s3 + (s0 - s1) floordiv s2)> +// CHECK-DAG: #[[SIMPLIFIED_CEILDIV_RHS:.*]] = affine_map<()[s0, s1, s2, s3] -> (s3 ceildiv (s2 - s0 * s1))> +// CHECK: func @semiaffine_simplification_floordiv_and_ceildiv(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) +func @semiaffine_simplification_floordiv_and_ceildiv(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) { + %a = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> (d0 floordiv (s0 - s1 * s2 + s3 - s0))>(%arg0)[%arg0, %arg1, %arg2, %arg3] + %b = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> ((-(d0 * s1 - (s0 - s1) floordiv s2) + s3) + (d0 * s1 + s0))>(%arg0)[%arg1, %arg2, %arg3, %arg4] + %c = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> (d0 ceildiv (s0 - s1 * s2 + s3 - s0))>(%arg0)[%arg0, %arg1, %arg2, %arg3] + return %a, %b, %c : index, index, index +} +// CHECK-NEXT: %[[RESULT0:.*]] = affine.apply #[[SIMPLIFIED_FLOORDIV_RHS]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG0]]] +// CHECK-NEXT: %[[RESULT1:.*]] = affine.apply #[[FLOORDIV]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]] +// CHECK-NEXT: %[[RESULT2:.*]] = affine.apply #[[SIMPLIFIED_CEILDIV_RHS]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG0]]] +// CHECK-NEXT: return %[[RESULT0]], %[[RESULT1]], %[[RESULT2]] + +// ----- + +// Test simplification of product expressions. +// CHECK-DAG: #[[PRODUCT:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s3 + s4 + (s0 - s1) * s2)> +// CHECK-DAG: #[[SUM_OF_PRODUCTS:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s2 * s0 + s2 + s3 * s0 + s3 * s1 + s3 + s4 * s1 + s4)> +// CHECK: func @semiaffine_simplification_product(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index) +func @semiaffine_simplification_product(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> (index, index) { + %a = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> ((-(s0 - (s0 - s1) * s2) + s3) + (d0 + s0))>(%arg0)[%arg1, %arg2, %arg3, %arg4] + %b = affine.apply affine_map<(d0, d1, d2)[s0, s1] -> (d0 + d1 * s1 + d1 + d0 * s0 + d1 * s0 + d2 * s1 + d2)>(%arg0, %arg1, %arg2)[%arg3, %arg4] + return %a, %b : index, index +} +// CHECK-NEXT: %[[RESULT0:.*]] = affine.apply #[[PRODUCT]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG0]]] +// CHECK-NEXT: %[[RESULT1:.*]] = affine.apply #[[SUM_OF_PRODUCTS]]()[%[[ARG3]], %[[ARG4]], %[[ARG0]], %[[ARG1]], %[[ARG2]]] +// CHECK-NEXT: return %[[RESULT0]], %[[RESULT1]] diff --git a/mlir/test/Dialect/SCF/for-loop-peeling.mlir b/mlir/test/Dialect/SCF/for-loop-peeling.mlir --- a/mlir/test/Dialect/SCF/for-loop-peeling.mlir +++ b/mlir/test/Dialect/SCF/for-loop-peeling.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s -for-loop-peeling -canonicalize -split-input-file | FileCheck %s // RUN: mlir-opt %s -for-loop-peeling=skip-partial=false -canonicalize -split-input-file | FileCheck %s -check-prefix=CHECK-NO-SKIP -// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 - (s1 - s0) mod s2)> +// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 - (-s0 + s1) mod s2)> // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0)> // CHECK: func @fully_dynamic_bounds( // CHECK-SAME: %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index