diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -602,6 +602,22 @@
   /// must already have a corresponding dim/symbol in this constraint system.
   AffineMap computeAlignedMap(AffineMap map, ValueRange operands) const;
 
+  /// Given an affine map that is aligned with this constraint system:
+  /// * Flatten the map.
+  /// * Add newly introduced local columns at the beginning of this constraint
+  ///   system (local column pos 0).
+  /// * Add the constraints that define the new local columns to this
+  ///   constraint system.
+  /// * Return the flattened expressions via `flattenedExprs`.
+  ///
+  /// Returns failure if `map` cannot be flattened (e.g., it is semi-affine);
+  /// this constraint system is left unmodified in that case.
+  ///
+  /// Note: This is a shared helper function of `addLowerOrUpperBound` and
+  ///       `composeMatchingMap`.
+  LogicalResult flattenAlignedMapAndMergeLocals(
+      AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs);
+
   // Eliminates a single identifier at 'position' from equality and inequality
   // constraints. Returns 'success' if the identifier was eliminated, and
   // 'failure' otherwise.
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -400,28 +400,10 @@
   assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
 
   std::vector<SmallVector<int64_t, 8>> flatExprs;
-  FlatAffineConstraints localCst;
-  if (failed(getFlattenedAffineExprs(other, &flatExprs, &localCst))) {
-    LLVM_DEBUG(llvm::dbgs()
-               << "composition unimplemented for semi-affine maps\n");
+  if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs)))
     return failure();
-  }
   assert(flatExprs.size() == other.getNumResults());
 
-  // Add localCst information.
-  if (localCst.getNumLocalIds() > 0) {
-    unsigned numLocalIds = getNumLocalIds();
-    // Insert local dims of localCst at the beginning.
-    for (unsigned l = 0, e = localCst.getNumLocalIds(); l < e; ++l)
-      addLocalId(0);
-    // Insert local dims of `this` at the end of localCst.
-    for (unsigned l = 0; l < numLocalIds; ++l)
-      localCst.addLocalId(localCst.getNumLocalIds());
-    // Dimensions of localCst and this constraint set match. Append localCst to
-    // this constraint set.
-    append(localCst);
-  }
-
   // Add dimensions corresponding to the map's results.
   for (unsigned t = 0, e = other.getNumResults(); t < e; t++) {
     addDimId(0);
@@ -429,25 +411,24 @@
 
   // We add one equality for each result connecting the result dim of the map to
   // the other identifiers.
-  // For eg: if the expression is 16*i0 + i1, and this is the r^th
+  // E.g.: if the expression is 16*i0 + i1, and this is the r^th
   // iteration/result of the value map, we are adding the equality:
-  // d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we
-  // add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
+  // d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we
+  // add two equalities: d_0 - i0 - 1 == 0, d_1 - i0 - 8*i2 == 0.
   for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
     const auto &flatExpr = flatExprs[r];
     assert(flatExpr.size() >= other.getNumInputs() + 1);
 
-    // eqToAdd is the equality corresponding to the flattened affine expression.
     SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
     // Set the coefficient for this result to one.
     eqToAdd[r] = 1;
 
     // Dims and symbols.
     for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
-      // Negate 'eq[r]' since the newly added dimension will be set to this one.
+      // Negate `flatExpr` since the newly added dimension is set equal to it.
       eqToAdd[e + i] = -flatExpr[i];
     }
-    // Local vars common to eq and localCst are at the beginning.
+    // Local columns of `eqToAdd` are at the beginning.
     unsigned j = getNumDimIds() + getNumSymbolIds();
     unsigned end = flatExpr.size() - 1;
     for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
@@ -1872,27 +1853,14 @@
   }
 }
 
-LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos,
-                                                          AffineMap boundMap,
-                                                          bool eq, bool lower) {
-  assert(boundMap.getNumDims() == getNumDimIds() && "dim mismatch");
-  assert(boundMap.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
-  assert(pos < getNumDimAndSymbolIds() && "invalid position");
-
-  // Equality follows the logic of lower bound except that we add an equality
-  // instead of an inequality.
-  assert((!eq || boundMap.getNumResults() == 1) && "single result expected");
-  if (eq)
-    lower = true;
-
-  std::vector<SmallVector<int64_t, 8>> flatExprs;
+LogicalResult FlatAffineConstraints::flattenAlignedMapAndMergeLocals(
+    AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs) {
   FlatAffineConstraints localCst;
-  if (failed(getFlattenedAffineExprs(boundMap, &flatExprs, &localCst))) {
+  if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst))) {
     LLVM_DEBUG(llvm::dbgs()
-               << "composition unimplemented for semi-affine maps\n");
+               << "flattening unimplemented for semi-affine maps\n");
     return failure();
   }
-  assert(flatExprs.size() == boundMap.getNumResults());
 
   // Add localCst information.
   if (localCst.getNumLocalIds() > 0) {
@@ -1908,6 +1876,27 @@
     append(localCst);
   }
 
+  return success();
+}
+
+LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos,
+                                                          AffineMap boundMap,
+                                                          bool eq, bool lower) {
+  assert(boundMap.getNumDims() == getNumDimIds() && "dim mismatch");
+  assert(boundMap.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
+  assert(pos < getNumDimAndSymbolIds() && "invalid position");
+
+  // Equality follows the logic of lower bound except that we add an equality
+  // instead of an inequality.
+  assert((!eq || boundMap.getNumResults() == 1) && "single result expected");
+  if (eq)
+    lower = true;
+
+  std::vector<SmallVector<int64_t, 8>> flatExprs;
+  if (failed(flattenAlignedMapAndMergeLocals(boundMap, &flatExprs)))
+    return failure();
+  assert(flatExprs.size() == boundMap.getNumResults());
+
   // Add one (in)equality for each result.
   for (const auto &flatExpr : flatExprs) {
     SmallVector<int64_t, 8> ineq(getNumCols(), 0);
@@ -1921,7 +1910,7 @@
     if (ineq[pos] != 0)
       continue;
     ineq[pos] = lower ? 1 : -1;
-    // Local vars common to eq and localCst are at the beginning.
+    // Local columns of `ineq` are at the beginning.
     unsigned j = getNumDimIds() + getNumSymbolIds();
     unsigned end = flatExpr.size() - 1;
     for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) {