diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -194,14 +194,23 @@ return inequalities.getRow(idx); } - /// Adds a lower or an upper bound for the identifier at the specified - /// position with constraints being drawn from the specified bound map. If - /// `eq` is true, add a single equality equal to the bound map's first result - /// expr. + /// The type of bound: equal, lower bound or upper bound. + enum BoundType { EQ, LB, UB }; + + /// Adds a bound for the identifier at the specified position with constraints + /// being drawn from the specified bound map. In case of an EQ bound, the + /// bound map is expected to have exactly one result. In case of a LB/UB, the + /// bound map may have more than one result, for each of which an inequality + /// is added. /// Note: The dimensions/symbols of this FlatAffineConstraints must match the /// dimensions/symbols of the affine map. - LogicalResult addLowerOrUpperBound(unsigned pos, AffineMap boundMap, bool eq, - bool lower = true); + LogicalResult addBound(BoundType type, unsigned pos, AffineMap boundMap); + + /// Adds a constant bound for the specified identifier. + void addBound(BoundType type, unsigned pos, int64_t value); + + /// Adds a constant bound for the specified expression. + void addBound(BoundType type, ArrayRef expr, int64_t value); /// Returns the constraint system as an integer set. Returns a null integer /// set if the system has no constraints, or if an integer set couldn't be @@ -224,11 +233,6 @@ /// Adds an equality from the coefficients specified in `eq`. void addEquality(ArrayRef eq); - /// Adds a constant lower bound constraint for the specified identifier. - void addConstantLowerBound(unsigned pos, int64_t lb); - /// Adds a constant upper bound constraint for the specified identifier. - void addConstantUpperBound(unsigned pos, int64_t ub); - /// Adds a new local identifier as the floordiv of an affine function of other /// identifiers, the coefficients of which are provided in `dividend` and with /// respect to a positive constant `divisor`. Two constraints are added to the @@ -236,14 +240,6 @@ /// q = dividend floordiv c <=> c*q <= dividend <= c*q + c - 1. void addLocalFloorDiv(ArrayRef dividend, int64_t divisor); - /// Adds a constant lower bound constraint for the specified expression. - void addConstantLowerBound(ArrayRef expr, int64_t lb); - /// Adds a constant upper bound constraint for the specified expression. - void addConstantUpperBound(ArrayRef expr, int64_t ub); - - /// Sets the identifier at the specified position to a constant. - void setIdToConstant(unsigned pos, int64_t val); - /// Swap the posA^th identifier with the posB^th identifier. virtual void swapId(unsigned posA, unsigned posB); @@ -349,13 +345,10 @@ SmallVectorImpl *ub = nullptr, unsigned *minLbPos = nullptr, unsigned *minUbPos = nullptr) const; - /// Returns the constant lower bound for the pos^th identifier if there is - /// one; None otherwise. - Optional getConstantLowerBound(unsigned pos) const; - - /// Returns the constant upper bound for the pos^th identifier if there is - /// one; None otherwise. - Optional getConstantUpperBound(unsigned pos) const; + /// Returns the constant bound for the pos^th identifier if there is one; + /// None otherwise. + // TODO: Support EQ bounds. + Optional getConstantBound(BoundType type, unsigned pos) const; /// Gets the lower and upper bound of the `offset` + `pos`th identifier /// treating [0, offset) U [offset + num, symStartPos) as dimensions and @@ -611,14 +604,18 @@ /// the columns in the current one regarding numbers and values. void addAffineIfOpDomain(AffineIfOp ifOp); - /// Adds a lower or an upper bound for the identifier at the specified - /// position with constraints being drawn from the specified bound map and - /// operands. If `eq` is true, add a single equality equal to the bound map's - /// first result expr. - LogicalResult addLowerOrUpperBound(unsigned pos, AffineMap boundMap, - ValueRange operands, bool eq, - bool lower = true); - using FlatAffineConstraints::addLowerOrUpperBound; + /// Adds a bound for the identifier at the specified position with constraints + /// being drawn from the specified bound map and operands. In case of an + /// EQ bound, the bound map is expected to have exactly one result. In case + /// of a LB/UB, the bound map may have more than one result, for each of which + /// an inequality is added. + LogicalResult addBound(BoundType type, unsigned pos, AffineMap boundMap, + ValueRange operands); + + /// Adds a constant bound for the identifier associated with the given Value. + void addBound(BoundType type, Value val, int64_t value); + + using FlatAffineConstraints::addBound; /// Returns the bound for the identifier at `pos` from the inequality at /// `ineqPos` as a 1-d affine value map (affine map + operands). The returned @@ -640,11 +637,6 @@ ArrayRef ubMaps, ArrayRef operands); - /// Sets the identifier corresponding to the specified Value `value` to a - /// constant. Asserts if the `value` is not found. - void setIdToConstant(Value value, int64_t val); - using FlatAffineConstraints::setIdToConstant; - /// Looks up the position of the identifier with the specified Value. Returns /// true if found (false otherwise). `pos` is set to the (column) position of /// the identifier. diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -664,8 +664,9 @@ assert(isValidSymbol(symbol)); // Check if the symbol is a constant. if (auto cOp = symbol.getDefiningOp()) - dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol), - cOp.getValue()); + dependenceDomain->addBound(FlatAffineConstraints::EQ, + valuePosMap.getSymPos(symbol), + cOp.getValue()); } }; @@ -885,10 +886,12 @@ dependenceComponents->resize(numCommonLoops); for (unsigned j = 0; j < numCommonLoops; ++j) { (*dependenceComponents)[j].op = commonLoops[j].getOperation(); - auto lbConst = dependenceDomain->getConstantLowerBound(j); + auto lbConst = + dependenceDomain->getConstantBound(FlatAffineConstraints::LB, j); (*dependenceComponents)[j].lb = lbConst.getValueOr(std::numeric_limits::min()); - auto ubConst = dependenceDomain->getConstantUpperBound(j); + auto ubConst = + dependenceDomain->getConstantBound(FlatAffineConstraints::UB, j); (*dependenceComponents)[j].ub = ubConst.getValueOr(std::numeric_limits::max()); } diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -550,7 +550,7 @@ addSymbolId(getNumSymbolIds(), val); // Check if the symbol is a constant. if (auto constOp = val.getDefiningOp()) - setIdToConstant(val, constOp.getValue()); + addBound(BoundType::EQ, val, constOp.getValue()); } LogicalResult @@ -588,23 +588,21 @@ } if (forOp.hasConstantLowerBound()) { - addConstantLowerBound(pos, forOp.getConstantLowerBound()); + addBound(BoundType::LB, pos, forOp.getConstantLowerBound()); } else { // Non-constant lower bound case. - if (failed(addLowerOrUpperBound(pos, forOp.getLowerBoundMap(), - forOp.getLowerBoundOperands(), - /*eq=*/false, /*lower=*/true))) + if (failed(addBound(BoundType::LB, pos, forOp.getLowerBoundMap(), + forOp.getLowerBoundOperands()))) return failure(); } if (forOp.hasConstantUpperBound()) { - addConstantUpperBound(pos, forOp.getConstantUpperBound() - 1); + addBound(BoundType::UB, pos, forOp.getConstantUpperBound() - 1); return success(); } // Non-constant upper bound case. - return addLowerOrUpperBound(pos, forOp.getUpperBoundMap(), - forOp.getUpperBoundOperands(), - /*eq=*/false, /*lower=*/false); + return addBound(BoundType::UB, pos, forOp.getUpperBoundMap(), + forOp.getUpperBoundOperands()); } LogicalResult @@ -649,12 +647,9 @@ // This slice refers to a loop that doesn't exist in the IR yet. Add its // bounds to the system assuming its dimension identifier position is the // same as the position of the loop in the loop nest. - if (lbMap && failed(addLowerOrUpperBound(i, lbMap, operands, /*eq=*/false, - /*lower=*/true))) + if (lbMap && failed(addBound(BoundType::LB, i, lbMap, operands))) return failure(); - - if (ubMap && failed(addLowerOrUpperBound(i, ubMap, operands, /*eq=*/false, - /*lower=*/false))) + if (ubMap && failed(addBound(BoundType::UB, i, ubMap, operands))) return failure(); } return success(); @@ -1393,7 +1388,8 @@ // Express `id_r` as `id_n % divisor` and store the expression in `memo`. if (quotientCount >= 1) { - auto ub = cst.getConstantUpperBound(dimExpr.getPosition()); + auto ub = cst.getConstantBound(FlatAffineConstraints::BoundType::UB, + dimExpr.getPosition()); // If `id_n` has an upperbound that is less than the divisor, mod can be // eliminated altogether. if (ub.hasValue() && ub.getValue() < divisor) @@ -1768,8 +1764,8 @@ if (memo[pos]) continue; - auto lbConst = getConstantLowerBound(pos); - auto ubConst = getConstantUpperBound(pos); + auto lbConst = getConstantBound(BoundType::LB, pos); + auto ubConst = getConstantBound(BoundType::UB, pos); if (lbConst.hasValue() && ubConst.hasValue()) { // Detect equality to a constant. if (lbConst.getValue() == ubConst.getValue()) { @@ -1878,7 +1874,7 @@ if (!lbMap || lbMap.getNumResults() > 1) { LLVM_DEBUG(llvm::dbgs() << "WARNING: Potentially over-approximating slice lb\n"); - auto lbConst = getConstantLowerBound(pos + offset); + auto lbConst = getConstantBound(BoundType::LB, pos + offset); if (lbConst.hasValue()) { lbMap = AffineMap::get( numMapDims, numMapSymbols, @@ -1888,7 +1884,7 @@ if (!ubMap || ubMap.getNumResults() > 1) { LLVM_DEBUG(llvm::dbgs() << "WARNING: Potentially over-approximating slice ub\n"); - auto ubConst = getConstantUpperBound(pos + offset); + auto ubConst = getConstantBound(BoundType::UB, pos + offset); if (ubConst.hasValue()) { (ubMap) = AffineMap::get( numMapDims, numMapSymbols, @@ -1931,18 +1927,17 @@ return success(); } -LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, - AffineMap boundMap, - bool eq, bool lower) { +LogicalResult FlatAffineConstraints::addBound(BoundType type, unsigned pos, + AffineMap boundMap) { assert(boundMap.getNumDims() == getNumDimIds() && "dim mismatch"); assert(boundMap.getNumSymbols() == getNumSymbolIds() && "symbol mismatch"); assert(pos < getNumDimAndSymbolIds() && "invalid position"); // Equality follows the logic of lower bound except that we add an equality // instead of an inequality. - assert((!eq || boundMap.getNumResults() == 1) && "single result expected"); - if (eq) - lower = true; + assert((type != BoundType::EQ || boundMap.getNumResults() == 1) && + "single result expected"); + bool lower = type == BoundType::LB || type == BoundType::EQ; std::vector> flatExprs; if (failed(flattenAlignedMapAndMergeLocals(boundMap, &flatExprs))) @@ -1973,7 +1968,7 @@ lower ? -flatExpr[flatExpr.size() - 1] // Upper bound in flattenedExpr is an exclusive one. : flatExpr[flatExpr.size() - 1] - 1; - eq ? addEquality(ineq) : addInequality(ineq); + type == BoundType::EQ ? addEquality(ineq) : addInequality(ineq); } return success(); @@ -2008,9 +2003,9 @@ return alignedMap; } -LogicalResult FlatAffineValueConstraints::addLowerOrUpperBound( - unsigned pos, AffineMap boundMap, ValueRange boundOperands, bool eq, - bool lower) { +LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos, + AffineMap boundMap, + ValueRange boundOperands) { // Fully compose map and operands; canonicalize and simplify so that we // transitively get to terminal symbols or loop IVs. auto map = boundMap; @@ -2020,7 +2015,7 @@ canonicalizeMapAndOperands(&map, &operands); for (auto operand : operands) addInductionVarOrTerminalSymbol(operand); - return addLowerOrUpperBound(pos, computeAlignedMap(map, operands), eq, lower); + return addBound(type, pos, computeAlignedMap(map, operands)); } // Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper @@ -2052,8 +2047,7 @@ if (lbMap && ubMap && lbMap.getNumResults() == 1 && ubMap.getNumResults() == 1 && lbMap.getResult(0) + 1 == ubMap.getResult(0)) { - if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/true, - /*lower=*/true))) + if (failed(addBound(BoundType::EQ, pos, lbMap, operands))) return failure(); continue; } @@ -2063,11 +2057,9 @@ // part of the slice. if (lbMap && lbMap.getNumResults() != 0 && ubMap && ubMap.getNumResults() != 0) { - if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false, - /*lower=*/true))) + if (failed(addBound(BoundType::LB, pos, lbMap, operands))) return failure(); - if (failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false, - /*lower=*/false))) + if (failed(addBound(BoundType::UB, pos, ubMap, operands))) return failure(); } else { auto loop = getForInductionVarOwner(values[i]); @@ -2092,33 +2084,30 @@ inequalities(row, i) = inEq[i]; } -void FlatAffineConstraints::addConstantLowerBound(unsigned pos, int64_t lb) { - assert(pos < getNumCols()); - unsigned row = inequalities.appendExtraRow(); - inequalities(row, pos) = 1; - inequalities(row, getNumCols() - 1) = -lb; -} - -void FlatAffineConstraints::addConstantUpperBound(unsigned pos, int64_t ub) { +void FlatAffineConstraints::addBound(BoundType type, unsigned pos, + int64_t value) { assert(pos < getNumCols()); - unsigned row = inequalities.appendExtraRow(); - inequalities(row, pos) = -1; - inequalities(row, getNumCols() - 1) = ub; -} - -void FlatAffineConstraints::addConstantLowerBound(ArrayRef expr, - int64_t lb) { - addInequality(expr); - inequalities(inequalities.getNumRows() - 1, getNumCols() - 1) += -lb; + if (type == BoundType::EQ) { + unsigned row = equalities.appendExtraRow(); + equalities(row, pos) = 1; + equalities(row, getNumCols() - 1) = -value; + } else { + unsigned row = inequalities.appendExtraRow(); + inequalities(row, pos) = type == BoundType::LB ? 1 : -1; + inequalities(row, getNumCols() - 1) = + type == BoundType::LB ? -value : value; + } } -void FlatAffineConstraints::addConstantUpperBound(ArrayRef expr, - int64_t ub) { +void FlatAffineConstraints::addBound(BoundType type, ArrayRef expr, + int64_t value) { + assert(type != BoundType::EQ && "EQ not implemented"); assert(expr.size() == getNumCols()); unsigned row = inequalities.appendExtraRow(); for (unsigned i = 0, e = expr.size(); i < e; ++i) - inequalities(row, i) = -expr[i]; - inequalities(inequalities.getNumRows() - 1, getNumCols() - 1) += ub; + inequalities(row, i) = type == BoundType::LB ? expr[i] : -expr[i]; + inequalities(inequalities.getNumRows() - 1, getNumCols() - 1) += + type == BoundType::LB ? -value : value; } /// Adds a new local identifier as the floordiv of an affine function of other @@ -2193,22 +2182,13 @@ numSymbols = newSymbolCount; } -/// Sets the specified identifier to a constant value. -void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) { - equalities.resizeVertically(equalities.getNumRows() + 1); - unsigned row = equalities.getNumRows() - 1; - equalities(row, pos) = 1; - equalities(row, getNumCols() - 1) = -val; -} - -/// Sets the specified identifier to a constant value; asserts if the id is not -/// found. -void FlatAffineValueConstraints::setIdToConstant(Value value, int64_t val) { +void FlatAffineValueConstraints::addBound(BoundType type, Value val, + int64_t value) { unsigned pos; - if (!findId(value, &pos)) + if (!findId(val, &pos)) // This is a pre-condition for this method. assert(0 && "id not found"); - setIdToConstant(pos, val); + addBound(type, pos, value); } void FlatAffineConstraints::removeEquality(unsigned pos) { @@ -2485,15 +2465,12 @@ return minOrMaxConst; } -Optional -FlatAffineConstraints::getConstantLowerBound(unsigned pos) const { - FlatAffineConstraints tmpCst(*this); - return tmpCst.computeConstantLowerOrUpperBound(pos); -} - -Optional -FlatAffineConstraints::getConstantUpperBound(unsigned pos) const { +Optional FlatAffineConstraints::getConstantBound(BoundType type, + unsigned pos) const { + assert(type != BoundType::EQ && "EQ not implemented"); FlatAffineConstraints tmpCst(*this); + if (type == BoundType::LB) + return tmpCst.computeConstantLowerOrUpperBound(pos); return tmpCst.computeConstantLowerOrUpperBound(pos); } @@ -3042,8 +3019,8 @@ minLb.back() -= otherLbFloorDivisor - 1; } else { // Uncomparable - check for constant lower/upper bounds. - auto constLb = getConstantLowerBound(d); - auto constOtherLb = otherCst.getConstantLowerBound(d); + auto constLb = getConstantBound(BoundType::LB, d); + auto constOtherLb = otherCst.getConstantBound(BoundType::LB, d); if (!constLb.hasValue() || !constOtherLb.hasValue()) return failure(); std::fill(minLb.begin(), minLb.end(), 0); @@ -3058,8 +3035,8 @@ maxUb = otherUb; } else { // Uncomparable - check for constant lower/upper bounds. - auto constUb = getConstantUpperBound(d); - auto constOtherUb = otherCst.getConstantUpperBound(d); + auto constUb = getConstantBound(BoundType::UB, d); + auto constOtherUb = otherCst.getConstantBound(BoundType::UB, d); if (!constUb.hasValue() || !constOtherUb.hasValue()) return failure(); std::fill(maxUb.begin(), maxUb.end(), 0); diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -99,7 +99,7 @@ if (isValidSymbol(value)) { // Check if the symbol is a constant. if (auto cOp = value.getDefiningOp()) - cst->setIdToConstant(value, cOp.getValue()); + cst->addBound(FlatAffineConstraints::EQ, value, cOp.getValue()); } else if (auto loop = getForInductionVarOwner(value)) { if (failed(cst->addAffineForOpDomain(loop))) return failure(); @@ -357,11 +357,11 @@ // that will need non-trivials means to eliminate. FlatAffineConstraints cstWithShapeBounds(cst); for (unsigned r = 0; r < rank; r++) { - cstWithShapeBounds.addConstantLowerBound(r, 0); + cstWithShapeBounds.addBound(FlatAffineConstraints::LB, r, 0); int64_t dimSize = memRefType.getDimSize(r); if (ShapedType::isDynamic(dimSize)) continue; - cstWithShapeBounds.addConstantUpperBound(r, dimSize - 1); + cstWithShapeBounds.addBound(FlatAffineConstraints::UB, r, dimSize - 1); } // Find a constant upper bound on the extent of this memref region along each @@ -518,7 +518,7 @@ // Check if the symbol is a constant. if (auto *op = symbol.getDefiningOp()) { if (auto constOp = dyn_cast(op)) { - cst.setIdToConstant(symbol, constOp.getValue()); + cst.addBound(FlatAffineConstraints::EQ, symbol, constOp.getValue()); } } } @@ -583,10 +583,11 @@ if (addMemRefDimBounds) { auto memRefType = memref.getType().cast(); for (unsigned r = 0; r < rank; r++) { - cst.addConstantLowerBound(/*pos=*/r, /*lb=*/0); + cst.addBound(FlatAffineConstraints::LB, /*pos=*/r, /*value=*/0); if (memRefType.isDynamicDim(r)) continue; - cst.addConstantUpperBound(/*pos=*/r, memRefType.getDimSize(r) - 1); + cst.addBound(FlatAffineConstraints::UB, /*pos=*/r, + memRefType.getDimSize(r) - 1); } } cst.removeTrivialRedundancy(); @@ -688,7 +689,7 @@ continue; // Check for overflow: d_i >= memref dim size. - ucst.addConstantLowerBound(r, dimSize); + ucst.addBound(FlatAffineConstraints::LB, r, dimSize); outOfBounds = !ucst.isEmpty(); if (outOfBounds && emitError) { loadOrStoreOp.emitOpError() @@ -699,7 +700,7 @@ FlatAffineConstraints lcst(*region.getConstraints()); std::fill(ineq.begin(), ineq.end(), 0); // d_i <= -1; - lcst.addConstantUpperBound(r, -1); + lcst.addBound(FlatAffineConstraints::UB, r, -1); outOfBounds = !lcst.isEmpty(); if (outOfBounds && emitError) { loadOrStoreOp.emitOpError() diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -2695,8 +2695,8 @@ for (unsigned d = 0; d < rank; d++) { auto dimSize = memRefType.getDimSize(d); assert(dimSize > 0 && "filtered dynamic shapes above"); - regionCst->addConstantLowerBound(d, 0); - regionCst->addConstantUpperBound(d, dimSize - 1); + regionCst->addBound(FlatAffineConstraints::LB, d, 0); + regionCst->addBound(FlatAffineConstraints::UB, d, dimSize - 1); } return true; } diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -722,8 +722,8 @@ for (unsigned d = 0; d < rank; ++d) { // Use constraint system only in static dimensions. if (shape[d] > 0) { - fac.addConstantLowerBound(d, 0); - fac.addConstantUpperBound(d, shape[d] - 1); + fac.addBound(FlatAffineConstraints::LB, d, 0); + fac.addBound(FlatAffineConstraints::UB, d, shape[d] - 1); } else { memrefTypeDynDims.emplace_back(d); } @@ -746,7 +746,7 @@ newShape[d] = -1; } else { // The lower bound for the shape is always zero. - auto ubConst = fac.getConstantUpperBound(d); + auto ubConst = fac.getConstantBound(FlatAffineConstraints::UB, d); // For a static memref and an affine map with no symbols, this is // always bounded. assert(ubConst.hasValue() && "should always have an upper bound");