diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -590,15 +590,6 @@ return success(); } -// Turn a dimension into a symbol. -static void turnDimIntoSymbol(FlatAffineValueConstraints *cst, Value id) { - unsigned pos; - if (cst->findId(id, &pos) && pos < cst->getNumDimIds()) { - cst->swapId(pos, cst->getNumDimIds() - 1); - cst->setDimSymbolSeparation(cst->getNumSymbolIds() + 1); - } -} - // Turn a symbol into a dimension. static void turnSymbolIntoDim(FlatAffineValueConstraints *cst, Value id) { unsigned pos; @@ -2081,13 +2072,6 @@ LogicalResult FlatAffineValueConstraints::addLowerOrUpperBound( unsigned pos, AffineMap boundMap, ValueRange boundOperands, bool eq, bool lower) { - assert(pos < getNumDimAndSymbolIds() && "invalid position"); - // Equality follows the logic of lower bound except that we add an equality - // instead of an inequality. - assert((!eq || boundMap.getNumResults() == 1) && "single result expected"); - if (eq) - lower = true; - // Fully compose map and operands; canonicalize and simplify so that we // transitively get to terminal symbols or loop IVs. auto map = boundMap; @@ -2098,70 +2082,28 @@ for (auto operand : operands) addInductionVarOrTerminalSymbol(operand); - FlatAffineValueConstraints localVarCst; - std::vector> flatExprs; - if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst))) { - LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n"); - return failure(); - } - - // Merge and align with localVarCst. - if (localVarCst.getNumLocalIds() > 0) { - // Set values for localVarCst. - localVarCst.setIdValues(0, localVarCst.getNumDimAndSymbolIds(), operands); - for (auto operand : operands) { - unsigned pos; - if (findId(operand, &pos)) { - if (pos >= getNumDimIds() && pos < getNumDimAndSymbolIds()) { - // If the local var cst has this as a dim, turn it into its symbol. - turnDimIntoSymbol(&localVarCst, operand); - } else if (pos < getNumDimIds()) { - // Or vice versa. - turnSymbolIntoDim(&localVarCst, operand); - } - } - } - mergeAndAlignIds(/*offset=*/0, this, &localVarCst); - append(localVarCst); - } - - // Record positions of the operands in the constraint system. Need to do - // this here since the constraint system changes after a bound is added. - SmallVector positions; - unsigned numOperands = operands.size(); - for (auto operand : operands) { - unsigned pos; - if (!findId(operand, &pos)) - assert(0 && "expected to be found"); - positions.push_back(pos); - } - - for (const auto &flatExpr : flatExprs) { - // Invalid bound: pos appears among the operands. - if (llvm::find(positions, pos) != positions.end()) - continue; - - SmallVector ineq(getNumCols(), 0); - ineq[pos] = lower ? 1 : -1; - // Dims and symbols. - for (unsigned j = 0, e = map.getNumInputs(); j < e; j++) { - ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j]; - } - // Copy over the local id coefficients. - unsigned numLocalIds = flatExpr.size() - 1 - numOperands; - for (unsigned jj = 0, j = getNumIds() - numLocalIds; jj < numLocalIds; - jj++, j++) { - ineq[j] = - lower ? -flatExpr[numOperands + jj] : flatExpr[numOperands + jj]; - } - // Constant term. - ineq[getNumCols() - 1] = - lower ? -flatExpr[flatExpr.size() - 1] - // Upper bound in flattenedExpr is an exclusive one. - : flatExpr[flatExpr.size() - 1] - 1; - eq ? addEquality(ineq) : addInequality(ineq); - } - return success(); + SmallVector dims, syms; +#ifndef NDEBUG + SmallVector newSyms; + SmallVector *newSymsPtr = &newSyms; +#else + SmallVector *newSymsPtr = nullptr; +#endif // NDEBUG + + dims.reserve(numDims); + syms.reserve(numSymbols); + for (unsigned i = 0; i < numDims; ++i) + dims.push_back(ids[i] ? *ids[i] : Value()); + for (unsigned i = numDims; i < numDims + numSymbols; ++i) + syms.push_back(ids[i] ? *ids[i] : Value()); + + AffineMap alignedMap = + alignAffineMapWithValues(map, operands, dims, syms, newSymsPtr); + // All symbols are already part of this FlatAffineConstraints. + assert(syms.size() == newSymsPtr->size()); + assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin())); + + return addLowerOrUpperBound(pos, alignedMap, eq, lower); } // Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper