diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -25,6 +25,7 @@ class AffineApplyOp; class AffineForOp; class AffineValueMap; +class FlatAffineRelation; class FlatAffineValueConstraints; class Operation; @@ -85,6 +86,30 @@ // Returns true if this access is of a store op. bool isStore() const; + /// Creates an access relation for the access. An access relation maps + /// elements of an iteration domain to the element(s) of an array domain + /// accessed by that iteration of the associated statement through some array + /// reference. For example, given the MLIR code: + /// + /// affine.for %i0 = 0 to 10 { + /// affine.for %i1 = 0 to 10 { + /// %a = affine.load %arr[%i0 + %i1, %i0 + 2 * %i1] : memref<100x100xf32> + /// } + /// } + /// + /// The access relation, assuming that the memory locations for %arr are + /// represented as %m0, %m1 would be: + /// + /// (%i0, %i1) -> (%m0, %m1) + /// %m0 = %i0 + %i1 + /// %m1 = %i0 + 2 * %i1 + /// 0 <= %i0 < 10 + /// 0 <= %i1 < 10 + /// + /// Returns failure for yet unimplemented/unsupported cases (see docs of + /// mlir::getIndexSet and mlir::getRelationFromMap for these cases). + LogicalResult getAccessRelation(FlatAffineRelation &rel) const; + /// Populates 'accessMap' with composition of AffineApplyOps reachable from /// 'indices'. void getAccessMap(AffineValueMap *accessMap) const; diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -426,6 +426,10 @@ /// O(VC) time. void removeRedundantConstraints(); + /// Converts identifiers in the column range [dimStart, dimLimit) to local + /// variables. + void convertDimToLocal(unsigned dimStart, unsigned dimLimit); + /// Merge local ids of `this` and `other`.
This is done by appending local ids /// of `other` to `this` and inserting local ids of `this` to `other` at start /// of its local ids. @@ -581,6 +585,16 @@ numLocals + 1, numDims, numSymbols, numLocals, valArgs) {} + FlatAffineValueConstraints(const FlatAffineConstraints &fac, + ArrayRef> valArgs = {}) + : FlatAffineConstraints(fac) { + assert(valArgs.empty() || valArgs.size() == numIds); + if (valArgs.empty()) + values.resize(numIds, None); + else + values.append(valArgs.begin(), valArgs.end()); + } + /// Create a flat affine constraint system from an AffineValueMap or a list of /// these. The constructed system will only include equalities. explicit FlatAffineValueConstraints(const AffineValueMap &avm); @@ -887,6 +901,86 @@ SmallVector, 8> values; }; +/// A FlatAffineRelation represents a set of ordered pairs (domain -> range) +/// where "domain" and "range" are tuples of identifiers. The relation is +/// represented as a FlatAffineValueConstraints with separation of dimension +/// identifiers into domain and range.
The identifiers are stored as: +/// [domainIds, rangeIds, symbolIds, localIds, constant]. +class FlatAffineRelation : public FlatAffineValueConstraints { +public: + FlatAffineRelation(unsigned numReservedInequalities, + unsigned numReservedEqualities, unsigned numReservedCols, + unsigned numDomainDims, unsigned numRangeDims, + unsigned numSymbols, unsigned numLocals, + ArrayRef> valArgs = {}) + : FlatAffineValueConstraints( + numReservedInequalities, numReservedEqualities, numReservedCols, + numDomainDims + numRangeDims, numSymbols, numLocals, valArgs), + numDomainDims(numDomainDims), numRangeDims(numRangeDims) {} + + FlatAffineRelation(unsigned numDomainDims = 0, unsigned numRangeDims = 0, + unsigned numSymbols = 0, unsigned numLocals = 0) + : FlatAffineValueConstraints(numDomainDims + numRangeDims, numSymbols, + numLocals), + numDomainDims(numDomainDims), numRangeDims(numRangeDims) {} + + /// Creates a relation from the constraints of `fac`, whose dims are split + /// into `numDomainDims` domain dims followed by `numRangeDims` range dims; + /// `fac` is expected to have exactly `numDomainDims + numRangeDims` dims. + FlatAffineRelation(unsigned numDomainDims, unsigned numRangeDims, + const FlatAffineValueConstraints &fac) + : FlatAffineValueConstraints(fac), numDomainDims(numDomainDims), + numRangeDims(numRangeDims) {} + + FlatAffineRelation(unsigned numDomainDims, unsigned numRangeDims, + const FlatAffineConstraints &fac) + : FlatAffineValueConstraints(fac), numDomainDims(numDomainDims), + numRangeDims(numRangeDims) {} + + /// Returns a set corresponding to the domain/range of the affine relation. + FlatAffineValueConstraints getDomainSet() const; + FlatAffineValueConstraints getRangeSet() const; + + /// Returns the number of identifiers corresponding to domain/range of + /// relation. + inline unsigned getNumDomainDims() const { return numDomainDims; } + inline unsigned getNumRangeDims() const { return numRangeDims; } + + /// Given affine relation `other: (domainOther -> rangeOther)`, this operation + /// takes the composition of `other` on `this: (domainThis -> rangeThis)`. + /// The resulting relation represents tuples of the form: `domainOther -> + /// rangeThis`.
+ void compose(const FlatAffineRelation &other); + + /// Swap domain and range of the relation. + /// `(domain -> range)` is converted to `(range -> domain)`. + void inverse(); + + /// Insert `num` domain/range identifiers at position `pos`, where `pos` is + /// relative to the identifiers of that kind. The coefficient columns + /// corresponding to the added identifiers are initialized to zero. + void insertDomainId(unsigned pos, unsigned num = 1); + void insertRangeId(unsigned pos, unsigned num = 1); + + /// Append `num` identifiers of the specified kind after the last identifier + /// of that kind. The coefficient columns corresponding to the added + /// identifiers are initialized to zero. + void appendDomainId(unsigned num = 1); + void appendRangeId(unsigned num = 1); + +protected: + /// Number of dimension identifiers corresponding to domain identifiers. + unsigned numDomainDims; + + /// Number of dimension identifiers corresponding to range identifiers. + unsigned numRangeDims; + + /// Removes identifiers in the column range [idStart, idLimit), and copies any + /// remaining valid data into place, updates member variables, and resizes + /// arrays as needed. + void removeIdRange(unsigned idStart, unsigned idLimit) override; +}; + /// Flattens 'expr' into 'flattenedExpr', which contains the coefficients of the /// dimensions, symbols, and additional variables that represent floor divisions /// of dimensions, symbols, and in turn other floor divisions. Returns failure @@ -943,6 +1037,26 @@ ValueRange dims, ValueRange syms, SmallVector *newSyms = nullptr); +/// Builds a relation from the given AffineMap/AffineValueMap `map`, containing +/// all pairs of the form `operands -> result` that satisfy `map`. `rel` is set +/// to the relation built.
For example, given the AffineMap: +/// +/// (d0, d1)[s0] -> (d0 + s0, d1 - s0) +/// +/// the resulting relation formed is: +/// +/// (d0, d1) -> (r1, r2) +/// [d0 d1 r1 r2 s0 const] +/// 1 0 -1 0 1 0 = 0 +/// 0 1 0 -1 -1 0 = 0 +/// +/// For AffineValueMap, the domain and symbols have Value set corresponding to +/// the Value in `map`. Returns failure if the AffineMap could not be flattened +/// (i.e., semi-affine is not yet handled). +LogicalResult getRelationFromMap(AffineMap &map, FlatAffineRelation &rel); +LogicalResult getRelationFromMap(const AffineValueMap &map, + FlatAffineRelation &rel); + } // end namespace mlir. #endif // MLIR_ANALYSIS_AFFINESTRUCTURES_H diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -232,460 +232,6 @@ return getIndexSet(ops, indexSet); } -namespace { -// ValuePositionMap manages the mapping from Values which represent dimension -// and symbol identifiers from 'src' and 'dst' access functions to positions -// in new space where some Values are kept separate (using addSrc/DstValue) -// and some Values are merged (addSymbolValue). -// Position lookups return the absolute position in the new space which -// has the following format: -// -// [src-dim-identifiers] [dst-dim-identifiers] [symbol-identifiers] -// -// Note: access function non-IV dimension identifiers (that have 'dimension' -// positions in the access function position space) are assigned as symbols -// in the output position space. Convenience access functions which lookup -// an Value in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle -// the common case of resolving positions for all access function operands. -// -// TODO: Generalize this: could take a template parameter for the number of maps -// (3 in the current case), and lookups could take indices of maps to check. So -// getSrcDimOrSymPos would be "getPos(value, {0, 2})".
-class ValuePositionMap { -public: - void addSrcValue(Value value) { - if (addValueAt(value, &srcDimPosMap, numSrcDims)) - ++numSrcDims; - } - void addDstValue(Value value) { - if (addValueAt(value, &dstDimPosMap, numDstDims)) - ++numDstDims; - } - void addSymbolValue(Value value) { - if (addValueAt(value, &symbolPosMap, numSymbols)) - ++numSymbols; - } - unsigned getSrcDimOrSymPos(Value value) const { - return getDimOrSymPos(value, srcDimPosMap, 0); - } - unsigned getDstDimOrSymPos(Value value) const { - return getDimOrSymPos(value, dstDimPosMap, numSrcDims); - } - unsigned getSymPos(Value value) const { - auto it = symbolPosMap.find(value); - assert(it != symbolPosMap.end()); - return numSrcDims + numDstDims + it->second; - } - - unsigned getNumSrcDims() const { return numSrcDims; } - unsigned getNumDstDims() const { return numDstDims; } - unsigned getNumDims() const { return numSrcDims + numDstDims; } - unsigned getNumSymbols() const { return numSymbols; } - -private: - bool addValueAt(Value value, DenseMap *posMap, - unsigned position) { - auto it = posMap->find(value); - if (it == posMap->end()) { - (*posMap)[value] = position; - return true; - } - return false; - } - unsigned getDimOrSymPos(Value value, - const DenseMap &dimPosMap, - unsigned dimPosOffset) const { - auto it = dimPosMap.find(value); - if (it != dimPosMap.end()) { - return dimPosOffset + it->second; - } - it = symbolPosMap.find(value); - assert(it != symbolPosMap.end()); - return numSrcDims + numDstDims + it->second; - } - - unsigned numSrcDims = 0; - unsigned numDstDims = 0; - unsigned numSymbols = 0; - DenseMap srcDimPosMap; - DenseMap dstDimPosMap; - DenseMap symbolPosMap; -}; -} // namespace - -// Builds a map from Value to identifier position in a new merged identifier -// list, which is the result of merging dim/symbol lists from src/dst -// iteration domains, the format of which is as follows: -// -// [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers, const_term] -// -// This 
method populates 'valuePosMap' with mappings from operand Values in -// 'srcAccessMap'/'dstAccessMap' (as well as those in 'srcDomain'/'dstDomain') -// to the position of these values in the merged list. -static void buildDimAndSymbolPositionMaps( - const FlatAffineValueConstraints &srcDomain, - const FlatAffineValueConstraints &dstDomain, - const AffineValueMap &srcAccessMap, const AffineValueMap &dstAccessMap, - ValuePositionMap *valuePosMap, - FlatAffineValueConstraints *dependenceConstraints) { - - // IsDimState is a tri-state boolean. It is used to distinguish three - // different cases of the values passed to updateValuePosMap. - // - When it is TRUE, we are certain that all values are dim values. - // - When it is FALSE, we are certain that all values are symbol values. - // - When it is UNKNOWN, we need to further check whether the value is from a - // loop IV to determine its type (dim or symbol). - - // We need this enumeration because sometimes we cannot determine whether a - // Value is a symbol or a dim by the information from the Value itself. If a - // Value appears in an affine map of a loop, we can determine whether it is a - // dim or not by the function `isForInductionVar`. But when a Value is in the - // affine set of an if-statement, there is no way to identify its category - // (dim/symbol) by itself. Fortunately, the Values to be inserted into - // `valuePosMap` come from `srcDomain` and `dstDomain`, and they hold such - // information of Value category: `srcDomain` and `dstDomain` organize Values - // by their category, such that the position of each Value stored in - // `srcDomain` and `dstDomain` marks which category that a Value belongs to. - // Therefore, we can separate Values into dim and symbol groups before passing - // them to the function `updateValuePosMap`. Specifically, when passing the - // dim group, we set IsDimState to TRUE; otherwise, we set it to FALSE. 
- // However, Values from the operands of `srcAccessMap` and `dstAccessMap` are - // not explicitly categorized into dim or symbol, and we have to rely on - // `isForInductionVar` to make the decision. IsDimState is set to UNKNOWN in - // this case. - enum IsDimState { TRUE, FALSE, UNKNOWN }; - - // This function places each given Value (in `values`) under a respective - // category in `valuePosMap`. Specifically, the placement rules are: - // 1) If `isDim` is FALSE, then every value in `values` are inserted into - // `valuePosMap` as symbols. - // 2) If `isDim` is UNKNOWN and the value of the current iteration is NOT an - // induction variable of a for-loop, we treat it as symbol as well. - // 3) For other cases, we decide whether to add a value to the `src` or the - // `dst` section of the dim category simply by the boolean value `isSrc`. - auto updateValuePosMap = [&](ArrayRef values, bool isSrc, - IsDimState isDim) { - for (unsigned i = 0, e = values.size(); i < e; ++i) { - auto value = values[i]; - if (isDim == FALSE || (isDim == UNKNOWN && !isForInductionVar(value))) { - assert(isValidSymbol(value) && - "access operand has to be either a loop IV or a symbol"); - valuePosMap->addSymbolValue(value); - } else { - if (isSrc) - valuePosMap->addSrcValue(value); - else - valuePosMap->addDstValue(value); - } - } - }; - - // Collect values from the src and dst domains. For each domain, we separate - // the collected values into dim and symbol parts. - SmallVector srcDimValues, dstDimValues, srcSymbolValues, - dstSymbolValues; - srcDomain.getValues(0, srcDomain.getNumDimIds(), &srcDimValues); - dstDomain.getValues(0, dstDomain.getNumDimIds(), &dstDimValues); - srcDomain.getValues(srcDomain.getNumDimIds(), - srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues); - dstDomain.getValues(dstDomain.getNumDimIds(), - dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues); - - // Update value position map with dim values from src iteration domain. 
- updateValuePosMap(srcDimValues, /*isSrc=*/true, /*isDim=*/TRUE); - // Update value position map with dim values from dst iteration domain. - updateValuePosMap(dstDimValues, /*isSrc=*/false, /*isDim=*/TRUE); - // Update value position map with symbols from src iteration domain. - updateValuePosMap(srcSymbolValues, /*isSrc=*/true, /*isDim=*/FALSE); - // Update value position map with symbols from dst iteration domain. - updateValuePosMap(dstSymbolValues, /*isSrc=*/false, /*isDim=*/FALSE); - // Update value position map with identifiers from src access function. - updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true, - /*isDim=*/UNKNOWN); - // Update value position map with identifiers from dst access function. - updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false, - /*isDim=*/UNKNOWN); -} - -// Sets up dependence constraints columns appropriately, in the format: -// [src-dim-ids, dst-dim-ids, symbol-ids, local-ids, const_term] -static void -initDependenceConstraints(const FlatAffineValueConstraints &srcDomain, - const FlatAffineValueConstraints &dstDomain, - const AffineValueMap &srcAccessMap, - const AffineValueMap &dstAccessMap, - const ValuePositionMap &valuePosMap, - FlatAffineValueConstraints *dependenceConstraints) { - // Calculate number of equalities/inequalities and columns required to - // initialize FlatAffineValueConstraints for 'dependenceDomain'. 
- unsigned numIneq = - srcDomain.getNumInequalities() + dstDomain.getNumInequalities(); - AffineMap srcMap = srcAccessMap.getAffineMap(); - assert(srcMap.getNumResults() == dstAccessMap.getAffineMap().getNumResults()); - unsigned numEq = srcMap.getNumResults(); - unsigned numDims = srcDomain.getNumDimIds() + dstDomain.getNumDimIds(); - unsigned numSymbols = valuePosMap.getNumSymbols(); - unsigned numLocals = srcDomain.getNumLocalIds() + dstDomain.getNumLocalIds(); - unsigned numIds = numDims + numSymbols + numLocals; - unsigned numCols = numIds + 1; - - // Set flat affine constraints sizes and reserving space for constraints. - dependenceConstraints->reset(numIneq, numEq, numCols, numDims, numSymbols, - numLocals); - - // Set values corresponding to dependence constraint identifiers. - SmallVector srcLoopIVs, dstLoopIVs; - srcDomain.getValues(0, srcDomain.getNumDimIds(), &srcLoopIVs); - dstDomain.getValues(0, dstDomain.getNumDimIds(), &dstLoopIVs); - - dependenceConstraints->setValues(0, srcLoopIVs.size(), srcLoopIVs); - dependenceConstraints->setValues( - srcLoopIVs.size(), srcLoopIVs.size() + dstLoopIVs.size(), dstLoopIVs); - - // Set values for the symbolic identifier dimensions. `isSymbolDetermined` - // indicates whether we are certain that the `values` passed in are all - // symbols. If `isSymbolDetermined` is true, then we treat every Value in - // `values` as a symbol; otherwise, we let the function `isForInductionVar` to - // distinguish whether a Value in `values` is a symbol or not. - auto setSymbolIds = [&](ArrayRef values, - bool isSymbolDetermined = true) { - for (auto value : values) { - if (isSymbolDetermined || !isForInductionVar(value)) { - assert(isValidSymbol(value) && "expected symbol"); - dependenceConstraints->setValue(valuePosMap.getSymPos(value), value); - } - } - }; - - // We are uncertain about whether all operands in `srcAccessMap` and - // `dstAccessMap` are symbols, so we set `isSymbolDetermined` to false. 
- setSymbolIds(srcAccessMap.getOperands(), /*isSymbolDetermined=*/false); - setSymbolIds(dstAccessMap.getOperands(), /*isSymbolDetermined=*/false); - - SmallVector srcSymbolValues, dstSymbolValues; - srcDomain.getValues(srcDomain.getNumDimIds(), - srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues); - dstDomain.getValues(dstDomain.getNumDimIds(), - dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues); - // Since we only take symbol Values out of `srcDomain` and `dstDomain`, - // `isSymbolDetermined` is kept to its default value: true. - setSymbolIds(srcSymbolValues); - setSymbolIds(dstSymbolValues); - - for (unsigned i = 0, e = dependenceConstraints->getNumDimAndSymbolIds(); - i < e; i++) - assert(dependenceConstraints->hasValue(i)); -} - -// Adds iteration domain constraints from 'srcDomain' and 'dstDomain' into -// 'dependenceDomain'. -// Uses 'valuePosMap' to determine the position in 'dependenceDomain' to which a -// srcDomain/dstDomain Value maps. -static void addDomainConstraints(const FlatAffineValueConstraints &srcDomain, - const FlatAffineValueConstraints &dstDomain, - const ValuePositionMap &valuePosMap, - FlatAffineValueConstraints *dependenceDomain) { - unsigned depNumDimsAndSymbolIds = dependenceDomain->getNumDimAndSymbolIds(); - - SmallVector cst(dependenceDomain->getNumCols()); - - auto addDomain = [&](bool isSrc, bool isEq, unsigned localOffset) { - const FlatAffineValueConstraints &domain = isSrc ? srcDomain : dstDomain; - unsigned numCsts = - isEq ? domain.getNumEqualities() : domain.getNumInequalities(); - unsigned numDimAndSymbolIds = domain.getNumDimAndSymbolIds(); - auto at = [&](unsigned i, unsigned j) -> int64_t { - return isEq ? domain.atEq(i, j) : domain.atIneq(i, j); - }; - auto map = [&](unsigned i) -> int64_t { - return isSrc ? valuePosMap.getSrcDimOrSymPos(domain.getValue(i)) - : valuePosMap.getDstDimOrSymPos(domain.getValue(i)); - }; - - for (unsigned i = 0; i < numCsts; ++i) { - // Zero fill. 
- std::fill(cst.begin(), cst.end(), 0); - // Set coefficients for identifiers corresponding to domain. - for (unsigned j = 0; j < numDimAndSymbolIds; ++j) - cst[map(j)] = at(i, j); - // Local terms. - for (unsigned j = 0, e = domain.getNumLocalIds(); j < e; j++) - cst[depNumDimsAndSymbolIds + localOffset + j] = - at(i, numDimAndSymbolIds + j); - // Set constant term. - cst[cst.size() - 1] = at(i, domain.getNumCols() - 1); - // Add constraint. - if (isEq) - dependenceDomain->addEquality(cst); - else - dependenceDomain->addInequality(cst); - } - }; - - // Add equalities from src domain. - addDomain(/*isSrc=*/true, /*isEq=*/true, /*localOffset=*/0); - // Add inequalities from src domain. - addDomain(/*isSrc=*/true, /*isEq=*/false, /*localOffset=*/0); - // Add equalities from dst domain. - addDomain(/*isSrc=*/false, /*isEq=*/true, - /*localOffset=*/srcDomain.getNumLocalIds()); - // Add inequalities from dst domain. - addDomain(/*isSrc=*/false, /*isEq=*/false, - /*localOffset=*/srcDomain.getNumLocalIds()); -} - -// Adds equality constraints that equate src and dst access functions -// represented by 'srcAccessMap' and 'dstAccessMap' for each result. -// Requires that 'srcAccessMap' and 'dstAccessMap' have the same results count. -// For example, given the following two accesses functions to a 2D memref: -// -// Source access function: -// (a0 * d0 + a1 * s0 + a2, b0 * d0 + b1 * s0 + b2) -// -// Destination access function: -// (c0 * d0 + c1 * s0 + c2, f0 * d0 + f1 * s0 + f2) -// -// This method constructs the following equality constraints in -// 'dependenceDomain', by equating the access functions for each result -// (i.e. each memref dim). Notice that 'd0' for the destination access function -// is mapped into 'd0' in the equality constraint: -// -// d0 d1 s0 c -// -- -- -- -- -// a0 -c0 (a1 - c1) (a2 - c2) = 0 -// b0 -f0 (b1 - f1) (b2 - f2) = 0 -// -// Returns failure if any AffineExpr cannot be flattened (due to it being -// semi-affine). Returns success otherwise. 
-static LogicalResult -addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, - const AffineValueMap &dstAccessMap, - const ValuePositionMap &valuePosMap, - FlatAffineValueConstraints *dependenceDomain) { - AffineMap srcMap = srcAccessMap.getAffineMap(); - AffineMap dstMap = dstAccessMap.getAffineMap(); - assert(srcMap.getNumResults() == dstMap.getNumResults()); - unsigned numResults = srcMap.getNumResults(); - - unsigned srcNumIds = srcMap.getNumDims() + srcMap.getNumSymbols(); - ArrayRef srcOperands = srcAccessMap.getOperands(); - - unsigned dstNumIds = dstMap.getNumDims() + dstMap.getNumSymbols(); - ArrayRef dstOperands = dstAccessMap.getOperands(); - - std::vector> srcFlatExprs; - std::vector> destFlatExprs; - FlatAffineValueConstraints srcLocalVarCst, destLocalVarCst; - // Get flattened expressions for the source destination maps. - if (failed(getFlattenedAffineExprs(srcMap, &srcFlatExprs, &srcLocalVarCst)) || - failed(getFlattenedAffineExprs(dstMap, &destFlatExprs, &destLocalVarCst))) - return failure(); - - unsigned domNumLocalIds = dependenceDomain->getNumLocalIds(); - unsigned srcNumLocalIds = srcLocalVarCst.getNumLocalIds(); - unsigned dstNumLocalIds = destLocalVarCst.getNumLocalIds(); - unsigned numLocalIdsToAdd = srcNumLocalIds + dstNumLocalIds; - dependenceDomain->appendLocalId(numLocalIdsToAdd); - - unsigned numDims = dependenceDomain->getNumDimIds(); - unsigned numSymbols = dependenceDomain->getNumSymbolIds(); - unsigned numSrcLocalIds = srcLocalVarCst.getNumLocalIds(); - unsigned newLocalIdOffset = numDims + numSymbols + domNumLocalIds; - - // Equality to add. - SmallVector eq(dependenceDomain->getNumCols()); - for (unsigned i = 0; i < numResults; ++i) { - // Zero fill. - std::fill(eq.begin(), eq.end(), 0); - - // Flattened AffineExpr for src result 'i'. - const auto &srcFlatExpr = srcFlatExprs[i]; - // Set identifier coefficients from src access function. 
- for (unsigned j = 0, e = srcOperands.size(); j < e; ++j) - eq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = srcFlatExpr[j]; - // Local terms. - for (unsigned j = 0, e = srcNumLocalIds; j < e; j++) - eq[newLocalIdOffset + j] = srcFlatExpr[srcNumIds + j]; - // Set constant term. - eq[eq.size() - 1] = srcFlatExpr[srcFlatExpr.size() - 1]; - - // Flattened AffineExpr for dest result 'i'. - const auto &destFlatExpr = destFlatExprs[i]; - // Set identifier coefficients from dst access function. - for (unsigned j = 0, e = dstOperands.size(); j < e; ++j) - eq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] -= destFlatExpr[j]; - // Local terms. - for (unsigned j = 0, e = dstNumLocalIds; j < e; j++) - eq[newLocalIdOffset + numSrcLocalIds + j] = -destFlatExpr[dstNumIds + j]; - // Set constant term. - eq[eq.size() - 1] -= destFlatExpr[destFlatExpr.size() - 1]; - - // Add equality constraint. - dependenceDomain->addEquality(eq); - } - - // Add equality constraints for any operands that are defined by constant ops. - auto addEqForConstOperands = [&](ArrayRef operands) { - for (unsigned i = 0, e = operands.size(); i < e; ++i) { - if (isForInductionVar(operands[i])) - continue; - auto symbol = operands[i]; - assert(isValidSymbol(symbol)); - // Check if the symbol is a constant. - if (auto cOp = symbol.getDefiningOp()) - dependenceDomain->addBound(FlatAffineConstraints::EQ, - valuePosMap.getSymPos(symbol), - cOp.getValue()); - } - }; - - // Add equality constraints for any src symbols defined by constant ops. - addEqForConstOperands(srcOperands); - // Add equality constraints for any dst symbols defined by constant ops. - addEqForConstOperands(dstOperands); - - // By construction (see flattener), local var constraints will not have any - // equalities. - assert(srcLocalVarCst.getNumEqualities() == 0 && - destLocalVarCst.getNumEqualities() == 0); - // Add inequalities from srcLocalVarCst and destLocalVarCst into the - // dependence domain. 
- SmallVector ineq(dependenceDomain->getNumCols()); - for (unsigned r = 0, e = srcLocalVarCst.getNumInequalities(); r < e; r++) { - std::fill(ineq.begin(), ineq.end(), 0); - - // Set identifier coefficients from src local var constraints. - for (unsigned j = 0, e = srcOperands.size(); j < e; ++j) - ineq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = - srcLocalVarCst.atIneq(r, j); - // Local terms. - for (unsigned j = 0, e = srcNumLocalIds; j < e; j++) - ineq[newLocalIdOffset + j] = srcLocalVarCst.atIneq(r, srcNumIds + j); - // Set constant term. - ineq[ineq.size() - 1] = - srcLocalVarCst.atIneq(r, srcLocalVarCst.getNumCols() - 1); - dependenceDomain->addInequality(ineq); - } - - for (unsigned r = 0, e = destLocalVarCst.getNumInequalities(); r < e; r++) { - std::fill(ineq.begin(), ineq.end(), 0); - // Set identifier coefficients from dest local var constraints. - for (unsigned j = 0, e = dstOperands.size(); j < e; ++j) - ineq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] = - destLocalVarCst.atIneq(r, j); - // Local terms. - for (unsigned j = 0, e = dstNumLocalIds; j < e; j++) - ineq[newLocalIdOffset + numSrcLocalIds + j] = - destLocalVarCst.atIneq(r, dstNumIds + j); - // Set constant term. - ineq[ineq.size() - 1] = - destLocalVarCst.atIneq(r, destLocalVarCst.getNumCols() - 1); - - dependenceDomain->addInequality(ineq); - } - return success(); -} - // Returns the number of outer loop common to 'src/dstDomain'. // Loops common to 'src/dst' domains are added to 'commonLoops' if non-null. static unsigned @@ -864,6 +410,43 @@ } } +LogicalResult MemRefAccess::getAccessRelation(FlatAffineRelation &rel) const { + // Create set corresponding to domain of access. + FlatAffineValueConstraints domain; + if (failed(getOpIndexSet(opInst, &domain))) + return failure(); + + // Get access relation from access map.
+ AffineValueMap accessValueMap; + getAccessMap(&accessValueMap); + if (failed(getRelationFromMap(accessValueMap, rel))) + return failure(); + + // Merge and align domain ids of `rel` and ids of `domain`. Since the domain + // of the access map is a subset of the domain of access, the domain ids of + // `rel` are guaranteed to be a subset of ids of `domain`. + for (unsigned i = 0, e = domain.getNumDimIds(); i < e; ++i) { + unsigned loc; + if (rel.findId(domain.getValue(i), &loc)) { + rel.swapId(i, loc); + } else { + rel.insertDomainId(i); + rel.setValue(i, domain.getValue(i)); + } + } + + // Append domain constraints to `rel`. Every dim of `domain` is a domain dim + // of `domainRel`, matching the domain of `rel` after the alignment above. + FlatAffineRelation domainRel(domain.getNumDimIds(), /*numRangeDims=*/0, + domain); + domainRel.appendRangeId(rel.getNumRangeDims()); + domainRel.mergeLocalIds(rel); + domainRel.mergeSymbolIds(rel); + rel.append(domainRel); + + return success(); +} + // Populates 'accessMap' with composition of AffineApplyOps reachable from // indices of MemRefAccess. void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { @@ -892,17 +475,16 @@ // common to both accesses (see Dependence in AffineAnalysis.h for details). // // The memref access dependence check is comprised of the following steps: -// *) Compute access functions for each access. Access functions are computed -// using AffineValueMaps initialized with the indices from an access, then -// composed with AffineApplyOps reachable from operands of that access, -// until operands of the AffineValueMap are loop IVs or symbols. -// *) Build iteration domain constraints for each access. Iteration domain -// constraints are pairs of inequality constraints representing the -// upper/lower loop bounds for each AffineForOp in the loop nest associated -// with each access. -// *) Build dimension and symbol position maps for each access, which map -// Values from access functions and iteration domains to their position -// in the merged constraint system built by this method.
+// *) Build access relation for each access. An access relation maps elements +// of an iteration domain to the element(s) of an array domain accessed by +// that iteration of the associated statement through some array reference. +// *) Compute the dependence relation by composing access relation of +// `srcAccess` with the inverse of access relation of `dstAccess`. +// Doing this builds a relation between iteration domain of `srcAccess` +// to the iteration domain of `dstAccess` which access the same memory +// location. +// *) Add ordering constraints for `srcAccess` to be accessed before +// `dstAccess`. // // This method builds a constraint system with the following column format: // @@ -929,34 +511,34 @@ // } // } // -// The access functions would be the following: -// -// src: (%i0 * 2 - %i1 * 4 + %N, %i1 * 3 - %M) -// dst: (%i2 * 7 + %i3 * 9 - %M, %i3 * 11 - %K) -// -// The iteration domains for the src/dst accesses would be the following: +// The access relation for `srcAccess` would be the following: // -// src: 0 <= %i0 <= 100, 0 <= %i1 <= 50 -// dst: 0 <= %i2 <= 100, 0 <= %i3 <= 50 +// [src_dim0, src_dim1, mem_dim0, mem_dim1, %N, %M, const] +// 2 -4 -1 0 1 0 0 = 0 +// 0 3 0 -1 0 -1 0 = 0 +// 1 0 0 0 0 0 0 >= 0 +// -1 0 0 0 0 0 100 >= 0 +// 0 1 0 0 0 0 0 >= 0 +// 0 -1 0 0 0 0 50 >= 0 // -// The symbols by both accesses would be assigned to a canonical position order -// which will be used in the dependence constraint system: +// The access relation for `dstAccess` would be the following: // -// symbol name: %M %N %K -// symbol pos: 0 1 2 +// [dst_dim0, dst_dim1, mem_dim0, mem_dim1, %M, %K, const] +// 7 9 -1 0 -1 0 0 = 0 +// 0 11 0 -1 0 -1 0 = 0 +// 1 0 0 0 0 0 0 >= 0 +// -1 0 0 0 0 0 100 >= 0 +// 0 1 0 0 0 0 0 >= 0 +// 0 -1 0 0 0 0 50 >= 0 // -// Equality constraints are built by equating each result of src/destination -// access functions. 
For this example, the following two equality constraints -// will be added to the dependence constraint system: +// The equalities in the above relations correspond to the access maps while +// the inequalities correspond to the iteration domain constraints. // -// [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const] -// 2 -4 -7 -9 1 1 0 0 = 0 -// 0 3 0 -11 -1 0 1 0 = 0 +// The resulting dependence relation is: // -// Inequality constraints from the iteration domain will be meged into -// the dependence constraint system -// -// [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const] +// [src_dim0, src_dim1, dst_dim0, dst_dim1, %M, %N, %K, const] +// 2 -4 -7 -9 1 1 0 0 = 0 +// 0 3 0 -11 -1 0 1 0 = 0 // 1 0 0 0 0 0 0 0 >= 0 // -1 0 0 0 0 0 0 100 >= 0 // 0 1 0 0 0 0 0 0 >= 0 @@ -987,24 +569,16 @@ !isa(dstAccess.opInst)) return DependenceResult::NoDependence; - // Get composed access function for 'srcAccess'. - AffineValueMap srcAccessMap; - srcAccess.getAccessMap(&srcAccessMap); - - // Get composed access function for 'dstAccess'. - AffineValueMap dstAccessMap; - dstAccess.getAccessMap(&dstAccessMap); - - // Get iteration domain for the 'srcAccess' operation. - FlatAffineValueConstraints srcDomain; - if (failed(getOpIndexSet(srcAccess.opInst, &srcDomain))) + // Create access relation from each MemRefAccess. + FlatAffineRelation srcRel, dstRel; + if (failed(srcAccess.getAccessRelation(srcRel))) return DependenceResult::Failure; - - // Get iteration domain for 'dstAccess' operation.
- FlatAffineValueConstraints dstDomain; - if (failed(getOpIndexSet(dstAccess.opInst, &dstDomain))) + if (failed(dstAccess.getAccessRelation(dstRel))) return DependenceResult::Failure; + FlatAffineValueConstraints srcDomain = srcRel.getDomainSet(); + FlatAffineValueConstraints dstDomain = dstRel.getDomainSet(); + // Return 'NoDependence' if loopDepth > numCommonLoops and if the ancestor // operation of 'srcAccess' does not properly dominate the ancestor // operation of 'dstAccess' in the same common operation block. @@ -1017,42 +591,27 @@ numCommonLoops)) { return DependenceResult::NoDependence; } - // Build dim and symbol position maps for each access from access operand - // Value to position in merged constraint system. - ValuePositionMap valuePosMap; - buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap, - dstAccessMap, &valuePosMap, - dependenceConstraints); - initDependenceConstraints(srcDomain, dstDomain, srcAccessMap, dstAccessMap, - valuePosMap, dependenceConstraints); - - assert(valuePosMap.getNumDims() == - srcDomain.getNumDimIds() + dstDomain.getNumDimIds()); - // Create memref access constraint by equating src/dst access functions. - // Note that this check is conservative, and will fail in the future when - // local variables for mod/div exprs are supported. - if (failed(addMemRefAccessConstraints(srcAccessMap, dstAccessMap, valuePosMap, - dependenceConstraints))) - return DependenceResult::Failure; + // Compute the dependence relation by composing `srcRel` with the inverse of + // `dstRel`. Doing this builds a relation between iteration domain of + // `srcAccess` to the iteration domain of `dstAccess` which access the same + // memory locations. + dstRel.inverse(); + dstRel.compose(srcRel); + *dependenceConstraints = dstRel; // Add 'src' happens before 'dst' ordering constraints. addOrderingConstraints(srcDomain, dstDomain, loopDepth, dependenceConstraints); - // Add src and dst domain constraints. 
- addDomainConstraints(srcDomain, dstDomain, valuePosMap, - dependenceConstraints); // Return 'NoDependence' if the solution space is empty: no dependence. - if (dependenceConstraints->isEmpty()) { + if (dependenceConstraints->isEmpty()) return DependenceResult::NoDependence; - } // Compute dependence direction vector and return true. - if (dependenceComponents != nullptr) { + if (dependenceComponents != nullptr) computeDirectionVector(srcDomain, dstDomain, loopDepth, dependenceConstraints, dependenceComponents); - } LLVM_DEBUG(llvm::dbgs() << "Dependence polyhedron:\n"); LLVM_DEBUG(dependenceConstraints->dump()); diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1851,6 +1851,26 @@ } } +void FlatAffineConstraints::convertDimToLocal(unsigned dimStart, + unsigned dimLimit) { + assert(dimLimit <= getNumDimIds() && "Invalid dim pos range"); + + if (dimStart >= dimLimit) + return; + + // Append new local variables corresponding to the dimensions to be converted. + unsigned convertCount = dimLimit - dimStart; + unsigned newLocalIdStart = getNumIds(); + appendLocalId(convertCount); + + // Swap the new local variables with dimensions. + for (unsigned i = 0; i < convertCount; ++i) + swapId(i + dimStart, i + newLocalIdStart); + + // Remove dimensions converted to local variables. + removeIdRange(dimStart, dimLimit); +} + std::pair FlatAffineConstraints::getLowerAndUpperBound( unsigned pos, unsigned offset, unsigned num, unsigned symStartPos, ArrayRef localExprs, MLIRContext *context) const { @@ -3542,3 +3562,168 @@ return map.replaceDimsAndSymbols(dimReplacements, symReplacements, dims.size(), numSymbols); } + +FlatAffineValueConstraints FlatAffineRelation::getDomainSet() const { + FlatAffineValueConstraints domain = *this; + // Convert all range variables to local variables. 
+ domain.convertDimToLocal(getNumDomainDims(), + getNumDomainDims() + getNumRangeDims()); + return domain; +} + +FlatAffineValueConstraints FlatAffineRelation::getRangeSet() const { + FlatAffineValueConstraints range = *this; + // Convert all domain variables to local variables. + range.convertDimToLocal(0, getNumDomainDims()); + return range; +} + +void FlatAffineRelation::compose(const FlatAffineRelation &other) { + assert(getNumDomainDims() == other.getNumRangeDims() && + "Domain of this and range of other do not match"); + assert(std::equal(values.begin(), values.begin() + getNumDomainDims(), + other.values.begin() + other.getNumDomainDims()) && + "Domain of this and range of other do not match"); + + FlatAffineRelation rel = other; + mergeSymbolIds(rel); + mergeLocalIds(rel); + + // Convert domain of `this` and range of `rel` to local identifiers. + convertDimToLocal(0, getNumDomainDims()); + rel.convertDimToLocal(rel.getNumDomainDims(), rel.getNumDimIds()); + // Add dimensions such that both relations become `domainRel -> rangeThis`. + appendDomainId(rel.getNumDomainDims()); + rel.appendRangeId(getNumRangeDims()); + + auto thisMaybeValues = getMaybeDimValues(); + auto relMaybeValues = rel.getMaybeDimValues(); + + // Add and match domain of `rel` to domain of `this`. + for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i) + if (relMaybeValues[i].hasValue()) + setValue(i, relMaybeValues[i].getValue()); + // Add and match range of `this` to range of `rel`. + for (unsigned i = 0, e = getNumRangeDims(); i < e; ++i) { + unsigned rangeIdx = rel.getNumDomainDims() + i; + if (thisMaybeValues[rangeIdx].hasValue()) + rel.setValue(rangeIdx, thisMaybeValues[rangeIdx].getValue()); + } + + // Append `this` to `rel` and simplify constraints. + rel.append(*this); + rel.removeRedundantLocalVars(); + + *this = rel; +} + +void FlatAffineRelation::inverse() { + unsigned oldDomain = getNumDomainDims(); + unsigned oldRange = getNumRangeDims(); + // Add new range ids. 
+ appendRangeId(oldDomain); + // Swap new ids with domain. + for (unsigned i = 0; i < oldDomain; ++i) + swapId(i, oldDomain + oldRange + i); + // Remove the swapped domain. + removeIdRange(0, oldDomain); + // Set domain and range as inverse. + numDomainDims = oldRange; + numRangeDims = oldDomain; +} + +void FlatAffineRelation::insertDomainId(unsigned pos, unsigned num) { + assert(pos <= getNumDomainDims() && + "Id cannot be inserted at invalid position"); + insertDimId(pos, num); + numDomainDims += num; +} + +void FlatAffineRelation::insertRangeId(unsigned pos, unsigned num) { + assert(pos <= getNumRangeDims() && + "Id cannot be inserted at invalid position"); + insertDimId(getNumDomainDims() + pos, num); + numRangeDims += num; +} + +void FlatAffineRelation::appendDomainId(unsigned num) { + insertDimId(getNumDomainDims(), num); + numDomainDims += num; +} + +void FlatAffineRelation::appendRangeId(unsigned num) { + insertDimId(getNumDimIds(), num); + numRangeDims += num; +} + +void FlatAffineRelation::removeIdRange(unsigned idStart, unsigned idLimit) { + if (idStart >= idLimit) + return; + + // Compute number of domain and range identifiers to remove. This is done by + // intersecting the range of domain/range ids with range of ids to remove. + unsigned intersectDomainLHS = std::min(idLimit, getNumDomainDims()); + unsigned intersectDomainRHS = idStart; + unsigned intersectRangeLHS = std::min(idLimit, getNumDimIds()); + unsigned intersectRangeRHS = std::max(idStart, getNumDomainDims()); + + FlatAffineValueConstraints::removeIdRange(idStart, idLimit); + + if (intersectDomainLHS > intersectDomainRHS) + numDomainDims -= intersectDomainLHS - intersectDomainRHS; + if (intersectRangeLHS > intersectRangeRHS) + numRangeDims -= intersectRangeLHS - intersectRangeRHS; +} + +LogicalResult mlir::getRelationFromMap(AffineMap &map, + FlatAffineRelation &rel) { + // Get flattened affine expressions. 
+ std::vector> flatExprs; + FlatAffineConstraints localVarCst; + if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst))) + return failure(); + + unsigned oldDimNum = localVarCst.getNumDimIds(); + unsigned oldCols = localVarCst.getNumCols(); + unsigned numRangeIds = map.getNumResults(); + unsigned numDomainIds = map.getNumDims(); + + // Add range as the new expressions. + localVarCst.appendDimId(numRangeIds); + + // Add equalities between source and range. + SmallVector eq(localVarCst.getNumCols()); + for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) { + // Zero fill. + std::fill(eq.begin(), eq.end(), 0); + // Fill equality. + for (unsigned j = 0, f = oldDimNum; j < f; ++j) + eq[j] = flatExprs[i][j]; + for (unsigned j = oldDimNum, f = oldCols; j < f; ++j) + eq[j + numRangeIds] = flatExprs[i][j]; + // Set this dimension to -1 to equate lhs and rhs and add equality. + eq[numDomainIds + i] = -1; + localVarCst.addEquality(eq); + } + + // Create relation and return success. + rel = FlatAffineRelation(numDomainIds, numRangeIds, localVarCst); + return success(); +} + +LogicalResult mlir::getRelationFromMap(const AffineValueMap &map, + FlatAffineRelation &rel) { + + AffineMap affineMap = map.getAffineMap(); + if (failed(getRelationFromMap(affineMap, rel))) + return failure(); + + // Set symbol values for domain dimensions and symbols. + for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i) + rel.setValue(i, map.getOperand(i)); + for (unsigned i = rel.getNumDimIds(), e = rel.getNumDimAndSymbolIds(); i < e; + ++i) + rel.setValue(i, map.getOperand(i - rel.getNumRangeDims())); + + return success(); +}