diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -209,6 +209,19 @@
   /// Adds a constant upper bound constraint for the specified identifier.
   void addConstantUpperBound(unsigned pos, int64_t ub);
 
+  /// Adds a lower or an upper bound for the identifier at the specified
+  /// position with constraints being drawn from the specified bound map. If
+  /// `eq` is true, the map must have exactly one result and a single equality
+  /// equal to that result expr is added (`lower` is then ignored). Otherwise,
+  /// one lower (`lower` = true) or upper (`lower` = false) bound is added per
+  /// result expr; upper bound exprs are treated as exclusive. Result exprs
+  /// that reference the identifier at `pos` itself are skipped.
+  /// Returns failure if the bound map cannot be flattened (e.g. semi-affine).
+  /// Note: The dimensions/symbols of this FlatAffineConstraints must match the
+  /// dimensions/symbols of the affine map.
+  LogicalResult addLowerOrUpperBound(unsigned pos, AffineMap boundMap, bool eq,
+                                     bool lower = true);
+
   /// Adds a new local identifier as the floordiv of an affine function of other
   /// identifiers, the coefficients of which are provided in 'dividend' and with
   /// respect to a positive constant 'divisor'. Two constraints are added to the
@@ -573,6 +586,9 @@
   LogicalResult addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
                                      ValueRange operands, bool eq,
                                      bool lower = true);
+  // Keep the base class overload (without operands) visible; it would
+  // otherwise be hidden by the overload declared above.
+  using FlatAffineConstraints::addLowerOrUpperBound;
 
   /// Returns the bound for the identifier at `pos` from the inequality at
   /// `ineqPos` as a 1-d affine value map (affine map + operands). The returned
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -2015,7 +2015,74 @@
   }
 }
 
-// TODO: Add FlatAffineConstraints::addLowerOrUpperBound
+LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos,
+                                                          AffineMap boundMap,
+                                                          bool eq, bool lower) {
+  assert(boundMap.getNumDims() == getNumDimIds() && "dim mismatch");
+  assert(boundMap.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
+  assert(pos < getNumDimAndSymbolIds() && "invalid position");
+  // Equality follows the logic of lower bound except that we add an equality
+  // instead of an inequality.
+  assert((!eq || boundMap.getNumResults() == 1) && "single result expected");
+  if (eq)
+    lower = true;
+
+  // Each flattened expr is laid out as [dims, symbols, locals, constant],
+  // where the locals correspond to the local ids of `localCst`.
+  std::vector<SmallVector<int64_t, 8>> flatExprs;
+  FlatAffineConstraints localCst;
+  if (failed(getFlattenedAffineExprs(boundMap, &flatExprs, &localCst))) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "composition unimplemented for semi-affine maps\n");
+    return failure();
+  }
+  assert(flatExprs.size() == boundMap.getNumResults());
+
+  // Add localCst information.
+  if (localCst.getNumLocalIds() > 0) {
+    unsigned numLocalIds = getNumLocalIds();
+    // Insert local dims of localCst at the beginning.
+    for (unsigned l = 0, e = localCst.getNumLocalIds(); l < e; ++l)
+      addLocalId(0);
+    // Insert local dims of `this` at the end of localCst.
+    for (unsigned l = 0; l < numLocalIds; ++l)
+      localCst.addLocalId(localCst.getNumLocalIds());
+    // Dimensions of localCst and this constraint set match. Append localCst to
+    // this constraint set.
+    append(localCst);
+  }
+
+  // Add one (in)equality for each result.
+  for (const auto &flatExpr : flatExprs) {
+    SmallVector<int64_t, 8> ineq(getNumCols(), 0);
+    // Dims and symbols.
+    for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) {
+      ineq[j] = lower ? -flatExpr[j] : flatExpr[j];
+    }
+    // Invalid bound: pos appears in `boundMap`. Such a result is skipped: no
+    // (in)equality is added for it and success() is still returned.
+    // TODO: Consider asserting or returning failure() instead.
+    if (ineq[pos] != 0)
+      continue;
+    ineq[pos] = lower ? 1 : -1;
+    // Local columns of `ineq` start right after the dims and symbols. The
+    // locals of localCst were inserted first, so they line up with the local
+    // terms of `flatExpr`.
+    unsigned j = getNumDimIds() + getNumSymbolIds();
+    unsigned end = flatExpr.size() - 1;
+    for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) {
+      ineq[j] = lower ? -flatExpr[i] : flatExpr[i];
+    }
+    // Constant term.
+    ineq[getNumCols() - 1] =
+        lower ? -flatExpr[flatExpr.size() - 1]
+              // Upper bound in flattenedExpr is an exclusive one.
+              : flatExpr[flatExpr.size() - 1] - 1;
+    eq ? addEquality(ineq) : addInequality(ineq);
+  }
+
+  return success();
+}
 
 LogicalResult FlatAffineValueConstraints::addLowerOrUpperBound(
     unsigned pos, AffineMap boundMap, ValueRange boundOperands, bool eq,