diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -367,8 +367,6 @@
                                     getNumCols());
   }
 
-  AffineExpr toAffineExpr(unsigned idx, MLIRContext *context);
-
   /// Adds constraints (lower and upper bounds) for the specified 'affine.for'
   /// operation's Value using IR information stored in its bound maps. The
   /// right identifier is first looked up using forOp's Value. Asserts if the
diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -212,11 +212,12 @@
 /// Constructs an affine expression from a flat ArrayRef. If there are local
 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
-/// products expression, 'localExprs' is expected to have the AffineExpr
-/// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the
-/// format [dims, symbols, locals, constant term].
-AffineExpr toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
-                        unsigned numSymbols, ArrayRef<AffineExpr> localExprs,
-                        MLIRContext *context);
+/// products expression, `localExprs` is expected to have the AffineExpr
+/// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
+/// in the format [dims, symbols, locals, constant term].
+AffineExpr getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
+                                     unsigned numDims, unsigned numSymbols,
+                                     ArrayRef<AffineExpr> localExprs,
+                                     MLIRContext *context);
 
 raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr);
 
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1573,7 +1573,8 @@
     // - 1.
     addCoeffs(ineq, lb);
     std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
-    auto expr = mlir::toAffineExpr(lb, dimCount, symCount, localExprs, context);
+    auto expr =
+        getAffineExprFromFlatForm(lb, dimCount, symCount, localExprs, context);
     // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor
     int64_t divisor = std::abs(ineq[pos + offset]);
     expr = (expr + divisor - 1).floorDiv(divisor);
@@ -1587,7 +1588,8 @@
     auto ineq = getInequality(idx);
     // Extract the upper bound (in terms of other coeff's + const).
     addCoeffs(ineq, ub);
-    auto expr = mlir::toAffineExpr(ub, dimCount, symCount, localExprs, context);
+    auto expr =
+        getAffineExprFromFlatForm(ub, dimCount, symCount, localExprs, context);
     expr = expr.floorDiv(std::abs(ineq[pos + offset]));
     // Upper bound is exclusive.
     ubExprs.push_back(expr + 1);
@@ -1602,12 +1604,14 @@
       std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>());
 
     // Extract the upper bound (in terms of other coeff's + const).
-    auto expr = mlir::toAffineExpr(b, dimCount, symCount, localExprs, context);
+    auto expr =
+        getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
     expr = expr.floorDiv(std::abs(eq[pos + offset]));
     // Upper bound is exclusive.
     ubExprs.push_back(expr + 1);
     // Lower bound.
-    expr = mlir::toAffineExpr(b, dimCount, symCount, localExprs, context);
+    expr =
+        getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
     expr = expr.ceilDiv(std::abs(eq[pos + offset]));
     lbExprs.push_back(expr);
   }
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -588,39 +588,39 @@
 
 /// Constructs an affine expression from a flat ArrayRef. If there are local
 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
-/// products expression, 'localExprs' is expected to have the AffineExpr
-/// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the
-/// format [dims, symbols, locals, constant term].
-AffineExpr mlir::toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
-                              unsigned numSymbols,
-                              ArrayRef<AffineExpr> localExprs,
-                              MLIRContext *context) {
-  // Assert expected numLocals = eq.size() - numDims - numSymbols - 1
-  assert(eq.size() - numDims - numSymbols - 1 == localExprs.size() &&
+/// products expression, `localExprs` is expected to have the AffineExpr
+/// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
+/// in the format [dims, symbols, locals, constant term].
+AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
+                                           unsigned numDims,
+                                           unsigned numSymbols,
+                                           ArrayRef<AffineExpr> localExprs,
+                                           MLIRContext *context) {
+  // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
+  assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
          "unexpected number of local expressions");
 
   auto expr = getAffineConstantExpr(0, context);
   // Dimensions and symbols.
   for (unsigned j = 0; j < numDims + numSymbols; j++) {
-    if (eq[j] == 0) {
+    if (flatExprs[j] == 0)
       continue;
-    }
     auto id = j < numDims ? getAffineDimExpr(j, context)
                           : getAffineSymbolExpr(j - numDims, context);
-    expr = expr + id * eq[j];
+    expr = expr + id * flatExprs[j];
   }
 
   // Local identifiers.
-  for (unsigned j = numDims + numSymbols, e = eq.size() - 1; j < e; j++) {
-    if (eq[j] == 0) {
+  for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
+       j++) {
+    if (flatExprs[j] == 0)
       continue;
-    }
-    auto term = localExprs[j - numDims - numSymbols] * eq[j];
+    auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
     expr = expr + term;
   }
 
   // Constant term.
-  int64_t constTerm = eq[eq.size() - 1];
+  int64_t constTerm = flatExprs[flatExprs.size() - 1];
   if (constTerm != 0)
     expr = expr + constTerm;
   return expr;
@@ -703,8 +703,8 @@
 
   // Construct the AffineExpr form of the floordiv to store in localExprs.
   MLIRContext *context = expr.getContext();
-  auto dividendExpr =
-      toAffineExpr(floorDividend, numDims, numSymbols, localExprs, context);
+  auto dividendExpr = getAffineExprFromFlatForm(
+      floorDividend, numDims, numSymbols, localExprs, context);
   auto divisorExpr = getAffineConstantExpr(floorDivisor, context);
   auto floorDivExpr = dividendExpr.floorDiv(divisorExpr);
   int loc;
@@ -787,7 +787,8 @@
   // quantifier to express its result, i.e., expr1 div expr2 is replaced
   // by a new identifier, q.
   MLIRContext *context = expr.getContext();
-  auto a = toAffineExpr(lhs, numDims, numSymbols, localExprs, context);
+  auto a =
+      getAffineExprFromFlatForm(lhs, numDims, numSymbols, localExprs, context);
   auto b = getAffineConstantExpr(divisor, context);
 
   int loc;
@@ -846,8 +847,9 @@
   SimpleAffineExprFlattener flattener(numDims, numSymbols);
   flattener.walkPostOrder(expr);
   ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
-  auto simplifiedExpr = toAffineExpr(flattenedExpr, numDims, numSymbols,
-                                     flattener.localExprs, expr.getContext());
+  auto simplifiedExpr =
+      getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
+                                flattener.localExprs, expr.getContext());
   flattener.operandExprStack.pop_back();
   assert(flattener.operandExprStack.empty());
 