diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -26,6 +26,7 @@ class AffineForOp; class AffineValueMap; class FlatAffineValueConstraints; +class FlatAffineRelation; class Operation; /// A description of a (parallelizable) reduction in an affine loop. @@ -85,6 +86,29 @@ // Returns true if this access is of a store op. bool isStore() const; + /// Creates an access relation for the access. An access relation maps + /// elements of an iteration domain to the element(s) of an array domain + /// accessed by that iteration of the associated statement through some array + /// reference. For example, given the MLIR code: + /// + /// affine.for %i0 = 0 to 10 { + /// affine.for %i1 = 0 to 10 { + /// %a = affine.load %arr[%i0 + %i1, %i0 + 2 * %i1] : memref<100x100xf32> + /// } + /// } + /// + /// The access relation, assuming that the memory locations for %arr are + /// represented as %m0, %m1 would be: + /// + /// (%i0, %i1) -> (%m0, %m1) + /// %m0 = %i0 + %i1 + /// %m1 = %i0 + 2 * %i1 + /// 0 <= %i0 < 10 + /// 0 <= %i1 < 10 + /// + /// Returns failure if the access relation could not be created. + LogicalResult getAccessRelation(FlatAffineRelation &accessRel) const; + /// Populates 'accessMap' with composition of AffineApplyOps reachable from /// 'indices'. void getAccessMap(AffineValueMap *accessMap) const; diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -416,6 +416,22 @@ /// O(VC) time. void removeRedundantConstraints(); + /// Removes local variables using equalities. Each equality is checked if it + /// can be reduced to the form: `e = affine-expr`, where `e` is a local + /// variable and `affine-expr` is an affine expression not containing `e`. 
+ /// If an equality satisfies this form, the local variable is replaced in + /// each constraint and then removed. + void removeRedundantLocalVars(); + + /// Converts identifiers in the column range [dimStart, dimLimit) to local + /// variables. + void convertDimToLocal(unsigned dimStart, unsigned dimLimit); + + /// Merge local ids of `this` and `other`. This is done by appending local ids + /// of `other` to `this` and inserting local ids of `this` to `other` at start + /// of its local ids. + void toCommonLocalSpace(FlatAffineConstraints &other); + /// Removes all equalities and inequalities. void clearConstraints(); @@ -552,6 +568,16 @@ numLocals + 1, numDims, numSymbols, numLocals, valArgs) {} + FlatAffineValueConstraints(const FlatAffineConstraints &fac, + ArrayRef> valArgs = {}) + : FlatAffineConstraints(fac) { + assert(valArgs.empty() || valArgs.size() == numIds); + if (valArgs.empty()) + values.resize(numIds, None); + else + values.append(valArgs.begin(), valArgs.end()); + } + /// Create a flat affine constraint system from an AffineValueMap or a list of /// these. The constructed system will only include equalities. explicit FlatAffineValueConstraints(const AffineValueMap &avm); @@ -678,7 +704,8 @@ using FlatAffineConstraints::insertDimId; unsigned insertSymbolId(unsigned pos, ValueRange vals); using FlatAffineConstraints::insertSymbolId; - unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override; + virtual unsigned insertId(IdKind kind, unsigned pos, + unsigned num = 1) override; unsigned insertId(IdKind kind, unsigned pos, ValueRange vals); /// Append identifiers of the specified kind after the last identifier of that @@ -825,6 +852,11 @@ setValue(i, values[i - start]); } + /// Merge and align symbols of `this` and `other` such that both get the union + /// of symbols that are unique. Symbols with Value as `None` are considered + /// to be unequal to all other symbols. 
+ void toCommonSymbolSpace(FlatAffineValueConstraints &other); + protected: /// Returns false if the fields corresponding to various identifier counts, or /// equality/inequality buffer sizes aren't consistent; true otherwise. This @@ -834,7 +866,7 @@ /// Removes identifiers in the column range [idStart, idLimit), and copies any /// remaining valid data into place, updates member variables, and resizes /// arrays as needed. - void removeIdRange(unsigned idStart, unsigned idLimit) override; + virtual void removeIdRange(unsigned idStart, unsigned idLimit) override; /// Eliminates the identifier at the specified position using Fourier-Motzkin /// variable elimination, but uses Gaussian elimination if there is an @@ -853,6 +885,64 @@ SmallVector, 8> values; }; +/// A FlatAffineRelation represents a set of ordered pairs (domain -> range) +/// where "domain" and "range" are tuples of identifiers. +class FlatAffineRelation : public FlatAffineValueConstraints { +public: + FlatAffineRelation(unsigned numReservedInequalities, + unsigned numReservedEqualities, unsigned numReservedCols, + unsigned numDomainDims, unsigned numRangeDims, + unsigned numSymbols, unsigned numLocals, + ArrayRef> valArgs = {}) + : FlatAffineValueConstraints( + numReservedInequalities, numReservedEqualities, numReservedCols, + numDomainDims + numRangeDims, numSymbols, numLocals, valArgs), + numDomainDims(numDomainDims), numRangeDims(numRangeDims) {} + + FlatAffineRelation(unsigned numDomainDims = 0, unsigned numRangeDims = 0, + unsigned numSymbols = 0, unsigned numLocals = 0) + : FlatAffineValueConstraints(numDomainDims + numRangeDims, numSymbols, + numLocals), + numDomainDims(numDomainDims), numRangeDims(numRangeDims) {} + + FlatAffineRelation(unsigned numDomainDims, unsigned numRangeDims, + FlatAffineValueConstraints &fac) + : FlatAffineValueConstraints(fac), numDomainDims(numDomainDims), + numRangeDims(numRangeDims) {} + + FlatAffineRelation(unsigned numDomainDims, unsigned numRangeDims, + 
FlatAffineConstraints &fac) + : FlatAffineValueConstraints(fac), numDomainDims(numDomainDims), + numRangeDims(numRangeDims) {} + + /// Returns a set corresponding to the domain/range of the affine relation. + FlatAffineValueConstraints getDomainSet() const; + FlatAffineValueConstraints getRangeSet() const; + + inline unsigned getNumDomainDims() const { return numDomainDims; } + inline unsigned getNumRangeDims() const { return numRangeDims; } + + /// Given affine relations `other: (domainOther -> rangeOther)` and + /// `this: (domainThis -> rangeThis)`, this operation takes the composition of + /// `other` on `this`: `rel(this): (domainOther -> rangeThis)` + void compose(const FlatAffineRelation &other); + + /// Swap domain and range of the relation + /// (domain -> range) is converted to (range -> domain) + void inverse(); + + void appendDomainId(unsigned num = 1); + void appendRangeId(unsigned num = 1); + +protected: + unsigned numDomainDims, numRangeDims; + + /// Removes identifiers in the column range [idStart, idLimit), and copies any + /// remaining valid data into place, updates member variables, and resizes + /// arrays as needed. + void removeIdRange(unsigned idStart, unsigned idLimit) override; +}; + /// Flattens 'expr' into 'flattenedExpr', which contains the coefficients of the /// dimensions, symbols, and additional variables that represent floor divisions /// of dimensions, symbols, and in turn other floor divisions. Returns failure diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -261,460 +261,6 @@ return getIndexSet(ops, indexSet); } -namespace { -// ValuePositionMap manages the mapping from Values which represent dimension -// and symbol identifiers from 'src' and 'dst' access functions to positions -// in new space where some Values are kept separate (using addSrc/DstValue) -// and some Values are merged (addSymbolValue). 
-// Position lookups return the absolute position in the new space which -// has the following format: -// -// [src-dim-identifiers] [dst-dim-identifiers] [symbol-identifiers] -// -// Note: access function non-IV dimension identifiers (that have 'dimension' -// positions in the access function position space) are assigned as symbols -// in the output position space. Convenience access functions which lookup -// an Value in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle -// the common case of resolving positions for all access function operands. -// -// TODO: Generalize this: could take a template parameter for the number of maps -// (3 in the current case), and lookups could take indices of maps to check. So -// getSrcDimOrSymPos would be "getPos(value, {0, 2})". -class ValuePositionMap { -public: - void addSrcValue(Value value) { - if (addValueAt(value, &srcDimPosMap, numSrcDims)) - ++numSrcDims; - } - void addDstValue(Value value) { - if (addValueAt(value, &dstDimPosMap, numDstDims)) - ++numDstDims; - } - void addSymbolValue(Value value) { - if (addValueAt(value, &symbolPosMap, numSymbols)) - ++numSymbols; - } - unsigned getSrcDimOrSymPos(Value value) const { - return getDimOrSymPos(value, srcDimPosMap, 0); - } - unsigned getDstDimOrSymPos(Value value) const { - return getDimOrSymPos(value, dstDimPosMap, numSrcDims); - } - unsigned getSymPos(Value value) const { - auto it = symbolPosMap.find(value); - assert(it != symbolPosMap.end()); - return numSrcDims + numDstDims + it->second; - } - - unsigned getNumSrcDims() const { return numSrcDims; } - unsigned getNumDstDims() const { return numDstDims; } - unsigned getNumDims() const { return numSrcDims + numDstDims; } - unsigned getNumSymbols() const { return numSymbols; } - -private: - bool addValueAt(Value value, DenseMap *posMap, - unsigned position) { - auto it = posMap->find(value); - if (it == posMap->end()) { - (*posMap)[value] = position; - return true; - } - return false; - } - unsigned 
getDimOrSymPos(Value value, - const DenseMap &dimPosMap, - unsigned dimPosOffset) const { - auto it = dimPosMap.find(value); - if (it != dimPosMap.end()) { - return dimPosOffset + it->second; - } - it = symbolPosMap.find(value); - assert(it != symbolPosMap.end()); - return numSrcDims + numDstDims + it->second; - } - - unsigned numSrcDims = 0; - unsigned numDstDims = 0; - unsigned numSymbols = 0; - DenseMap srcDimPosMap; - DenseMap dstDimPosMap; - DenseMap symbolPosMap; -}; -} // namespace - -// Builds a map from Value to identifier position in a new merged identifier -// list, which is the result of merging dim/symbol lists from src/dst -// iteration domains, the format of which is as follows: -// -// [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers, const_term] -// -// This method populates 'valuePosMap' with mappings from operand Values in -// 'srcAccessMap'/'dstAccessMap' (as well as those in 'srcDomain'/'dstDomain') -// to the position of these values in the merged list. -static void buildDimAndSymbolPositionMaps( - const FlatAffineValueConstraints &srcDomain, - const FlatAffineValueConstraints &dstDomain, - const AffineValueMap &srcAccessMap, const AffineValueMap &dstAccessMap, - ValuePositionMap *valuePosMap, - FlatAffineValueConstraints *dependenceConstraints) { - - // IsDimState is a tri-state boolean. It is used to distinguish three - // different cases of the values passed to updateValuePosMap. - // - When it is TRUE, we are certain that all values are dim values. - // - When it is FALSE, we are certain that all values are symbol values. - // - When it is UNKNOWN, we need to further check whether the value is from a - // loop IV to determine its type (dim or symbol). - - // We need this enumeration because sometimes we cannot determine whether a - // Value is a symbol or a dim by the information from the Value itself. 
If a - // Value appears in an affine map of a loop, we can determine whether it is a - // dim or not by the function `isForInductionVar`. But when a Value is in the - // affine set of an if-statement, there is no way to identify its category - // (dim/symbol) by itself. Fortunately, the Values to be inserted into - // `valuePosMap` come from `srcDomain` and `dstDomain`, and they hold such - // information of Value category: `srcDomain` and `dstDomain` organize Values - // by their category, such that the position of each Value stored in - // `srcDomain` and `dstDomain` marks which category that a Value belongs to. - // Therefore, we can separate Values into dim and symbol groups before passing - // them to the function `updateValuePosMap`. Specifically, when passing the - // dim group, we set IsDimState to TRUE; otherwise, we set it to FALSE. - // However, Values from the operands of `srcAccessMap` and `dstAccessMap` are - // not explicitly categorized into dim or symbol, and we have to rely on - // `isForInductionVar` to make the decision. IsDimState is set to UNKNOWN in - // this case. - enum IsDimState { TRUE, FALSE, UNKNOWN }; - - // This function places each given Value (in `values`) under a respective - // category in `valuePosMap`. Specifically, the placement rules are: - // 1) If `isDim` is FALSE, then every value in `values` are inserted into - // `valuePosMap` as symbols. - // 2) If `isDim` is UNKNOWN and the value of the current iteration is NOT an - // induction variable of a for-loop, we treat it as symbol as well. - // 3) For other cases, we decide whether to add a value to the `src` or the - // `dst` section of the dim category simply by the boolean value `isSrc`. 
- auto updateValuePosMap = [&](ArrayRef values, bool isSrc, - IsDimState isDim) { - for (unsigned i = 0, e = values.size(); i < e; ++i) { - auto value = values[i]; - if (isDim == FALSE || (isDim == UNKNOWN && !isForInductionVar(value))) { - assert(isValidSymbol(value) && - "access operand has to be either a loop IV or a symbol"); - valuePosMap->addSymbolValue(value); - } else { - if (isSrc) - valuePosMap->addSrcValue(value); - else - valuePosMap->addDstValue(value); - } - } - }; - - // Collect values from the src and dst domains. For each domain, we separate - // the collected values into dim and symbol parts. - SmallVector srcDimValues, dstDimValues, srcSymbolValues, - dstSymbolValues; - srcDomain.getValues(0, srcDomain.getNumDimIds(), &srcDimValues); - dstDomain.getValues(0, dstDomain.getNumDimIds(), &dstDimValues); - srcDomain.getValues(srcDomain.getNumDimIds(), - srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues); - dstDomain.getValues(dstDomain.getNumDimIds(), - dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues); - - // Update value position map with dim values from src iteration domain. - updateValuePosMap(srcDimValues, /*isSrc=*/true, /*isDim=*/TRUE); - // Update value position map with dim values from dst iteration domain. - updateValuePosMap(dstDimValues, /*isSrc=*/false, /*isDim=*/TRUE); - // Update value position map with symbols from src iteration domain. - updateValuePosMap(srcSymbolValues, /*isSrc=*/true, /*isDim=*/FALSE); - // Update value position map with symbols from dst iteration domain. - updateValuePosMap(dstSymbolValues, /*isSrc=*/false, /*isDim=*/FALSE); - // Update value position map with identifiers from src access function. - updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true, - /*isDim=*/UNKNOWN); - // Update value position map with identifiers from dst access function. 
- updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false, - /*isDim=*/UNKNOWN); -} - -// Sets up dependence constraints columns appropriately, in the format: -// [src-dim-ids, dst-dim-ids, symbol-ids, local-ids, const_term] -static void -initDependenceConstraints(const FlatAffineValueConstraints &srcDomain, - const FlatAffineValueConstraints &dstDomain, - const AffineValueMap &srcAccessMap, - const AffineValueMap &dstAccessMap, - const ValuePositionMap &valuePosMap, - FlatAffineValueConstraints *dependenceConstraints) { - // Calculate number of equalities/inequalities and columns required to - // initialize FlatAffineValueConstraints for 'dependenceDomain'. - unsigned numIneq = - srcDomain.getNumInequalities() + dstDomain.getNumInequalities(); - AffineMap srcMap = srcAccessMap.getAffineMap(); - assert(srcMap.getNumResults() == dstAccessMap.getAffineMap().getNumResults()); - unsigned numEq = srcMap.getNumResults(); - unsigned numDims = srcDomain.getNumDimIds() + dstDomain.getNumDimIds(); - unsigned numSymbols = valuePosMap.getNumSymbols(); - unsigned numLocals = srcDomain.getNumLocalIds() + dstDomain.getNumLocalIds(); - unsigned numIds = numDims + numSymbols + numLocals; - unsigned numCols = numIds + 1; - - // Set flat affine constraints sizes and reserving space for constraints. - dependenceConstraints->reset(numIneq, numEq, numCols, numDims, numSymbols, - numLocals); - - // Set values corresponding to dependence constraint identifiers. - SmallVector srcLoopIVs, dstLoopIVs; - srcDomain.getValues(0, srcDomain.getNumDimIds(), &srcLoopIVs); - dstDomain.getValues(0, dstDomain.getNumDimIds(), &dstLoopIVs); - - dependenceConstraints->setValues(0, srcLoopIVs.size(), srcLoopIVs); - dependenceConstraints->setValues( - srcLoopIVs.size(), srcLoopIVs.size() + dstLoopIVs.size(), dstLoopIVs); - - // Set values for the symbolic identifier dimensions. `isSymbolDetermined` - // indicates whether we are certain that the `values` passed in are all - // symbols. 
If `isSymbolDetermined` is true, then we treat every Value in - // `values` as a symbol; otherwise, we let the function `isForInductionVar` to - // distinguish whether a Value in `values` is a symbol or not. - auto setSymbolIds = [&](ArrayRef values, - bool isSymbolDetermined = true) { - for (auto value : values) { - if (isSymbolDetermined || !isForInductionVar(value)) { - assert(isValidSymbol(value) && "expected symbol"); - dependenceConstraints->setValue(valuePosMap.getSymPos(value), value); - } - } - }; - - // We are uncertain about whether all operands in `srcAccessMap` and - // `dstAccessMap` are symbols, so we set `isSymbolDetermined` to false. - setSymbolIds(srcAccessMap.getOperands(), /*isSymbolDetermined=*/false); - setSymbolIds(dstAccessMap.getOperands(), /*isSymbolDetermined=*/false); - - SmallVector srcSymbolValues, dstSymbolValues; - srcDomain.getValues(srcDomain.getNumDimIds(), - srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues); - dstDomain.getValues(dstDomain.getNumDimIds(), - dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues); - // Since we only take symbol Values out of `srcDomain` and `dstDomain`, - // `isSymbolDetermined` is kept to its default value: true. - setSymbolIds(srcSymbolValues); - setSymbolIds(dstSymbolValues); - - for (unsigned i = 0, e = dependenceConstraints->getNumDimAndSymbolIds(); - i < e; i++) - assert(dependenceConstraints->hasValue(i)); -} - -// Adds iteration domain constraints from 'srcDomain' and 'dstDomain' into -// 'dependenceDomain'. -// Uses 'valuePosMap' to determine the position in 'dependenceDomain' to which a -// srcDomain/dstDomain Value maps. 
-static void addDomainConstraints(const FlatAffineValueConstraints &srcDomain, - const FlatAffineValueConstraints &dstDomain, - const ValuePositionMap &valuePosMap, - FlatAffineValueConstraints *dependenceDomain) { - unsigned depNumDimsAndSymbolIds = dependenceDomain->getNumDimAndSymbolIds(); - - SmallVector cst(dependenceDomain->getNumCols()); - - auto addDomain = [&](bool isSrc, bool isEq, unsigned localOffset) { - const FlatAffineValueConstraints &domain = isSrc ? srcDomain : dstDomain; - unsigned numCsts = - isEq ? domain.getNumEqualities() : domain.getNumInequalities(); - unsigned numDimAndSymbolIds = domain.getNumDimAndSymbolIds(); - auto at = [&](unsigned i, unsigned j) -> int64_t { - return isEq ? domain.atEq(i, j) : domain.atIneq(i, j); - }; - auto map = [&](unsigned i) -> int64_t { - return isSrc ? valuePosMap.getSrcDimOrSymPos(domain.getValue(i)) - : valuePosMap.getDstDimOrSymPos(domain.getValue(i)); - }; - - for (unsigned i = 0; i < numCsts; ++i) { - // Zero fill. - std::fill(cst.begin(), cst.end(), 0); - // Set coefficients for identifiers corresponding to domain. - for (unsigned j = 0; j < numDimAndSymbolIds; ++j) - cst[map(j)] = at(i, j); - // Local terms. - for (unsigned j = 0, e = domain.getNumLocalIds(); j < e; j++) - cst[depNumDimsAndSymbolIds + localOffset + j] = - at(i, numDimAndSymbolIds + j); - // Set constant term. - cst[cst.size() - 1] = at(i, domain.getNumCols() - 1); - // Add constraint. - if (isEq) - dependenceDomain->addEquality(cst); - else - dependenceDomain->addInequality(cst); - } - }; - - // Add equalities from src domain. - addDomain(/*isSrc=*/true, /*isEq=*/true, /*localOffset=*/0); - // Add inequalities from src domain. - addDomain(/*isSrc=*/true, /*isEq=*/false, /*localOffset=*/0); - // Add equalities from dst domain. - addDomain(/*isSrc=*/false, /*isEq=*/true, - /*localOffset=*/srcDomain.getNumLocalIds()); - // Add inequalities from dst domain. 
- addDomain(/*isSrc=*/false, /*isEq=*/false, - /*localOffset=*/srcDomain.getNumLocalIds()); -} - -// Adds equality constraints that equate src and dst access functions -// represented by 'srcAccessMap' and 'dstAccessMap' for each result. -// Requires that 'srcAccessMap' and 'dstAccessMap' have the same results count. -// For example, given the following two accesses functions to a 2D memref: -// -// Source access function: -// (a0 * d0 + a1 * s0 + a2, b0 * d0 + b1 * s0 + b2) -// -// Destination access function: -// (c0 * d0 + c1 * s0 + c2, f0 * d0 + f1 * s0 + f2) -// -// This method constructs the following equality constraints in -// 'dependenceDomain', by equating the access functions for each result -// (i.e. each memref dim). Notice that 'd0' for the destination access function -// is mapped into 'd0' in the equality constraint: -// -// d0 d1 s0 c -// -- -- -- -- -// a0 -c0 (a1 - c1) (a2 - c2) = 0 -// b0 -f0 (b1 - f1) (b2 - f2) = 0 -// -// Returns failure if any AffineExpr cannot be flattened (due to it being -// semi-affine). Returns success otherwise. -static LogicalResult -addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, - const AffineValueMap &dstAccessMap, - const ValuePositionMap &valuePosMap, - FlatAffineValueConstraints *dependenceDomain) { - AffineMap srcMap = srcAccessMap.getAffineMap(); - AffineMap dstMap = dstAccessMap.getAffineMap(); - assert(srcMap.getNumResults() == dstMap.getNumResults()); - unsigned numResults = srcMap.getNumResults(); - - unsigned srcNumIds = srcMap.getNumDims() + srcMap.getNumSymbols(); - ArrayRef srcOperands = srcAccessMap.getOperands(); - - unsigned dstNumIds = dstMap.getNumDims() + dstMap.getNumSymbols(); - ArrayRef dstOperands = dstAccessMap.getOperands(); - - std::vector> srcFlatExprs; - std::vector> destFlatExprs; - FlatAffineValueConstraints srcLocalVarCst, destLocalVarCst; - // Get flattened expressions for the source destination maps. 
- if (failed(getFlattenedAffineExprs(srcMap, &srcFlatExprs, &srcLocalVarCst)) || - failed(getFlattenedAffineExprs(dstMap, &destFlatExprs, &destLocalVarCst))) - return failure(); - - unsigned domNumLocalIds = dependenceDomain->getNumLocalIds(); - unsigned srcNumLocalIds = srcLocalVarCst.getNumLocalIds(); - unsigned dstNumLocalIds = destLocalVarCst.getNumLocalIds(); - unsigned numLocalIdsToAdd = srcNumLocalIds + dstNumLocalIds; - dependenceDomain->appendLocalId(numLocalIdsToAdd); - - unsigned numDims = dependenceDomain->getNumDimIds(); - unsigned numSymbols = dependenceDomain->getNumSymbolIds(); - unsigned numSrcLocalIds = srcLocalVarCst.getNumLocalIds(); - unsigned newLocalIdOffset = numDims + numSymbols + domNumLocalIds; - - // Equality to add. - SmallVector eq(dependenceDomain->getNumCols()); - for (unsigned i = 0; i < numResults; ++i) { - // Zero fill. - std::fill(eq.begin(), eq.end(), 0); - - // Flattened AffineExpr for src result 'i'. - const auto &srcFlatExpr = srcFlatExprs[i]; - // Set identifier coefficients from src access function. - for (unsigned j = 0, e = srcOperands.size(); j < e; ++j) - eq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = srcFlatExpr[j]; - // Local terms. - for (unsigned j = 0, e = srcNumLocalIds; j < e; j++) - eq[newLocalIdOffset + j] = srcFlatExpr[srcNumIds + j]; - // Set constant term. - eq[eq.size() - 1] = srcFlatExpr[srcFlatExpr.size() - 1]; - - // Flattened AffineExpr for dest result 'i'. - const auto &destFlatExpr = destFlatExprs[i]; - // Set identifier coefficients from dst access function. - for (unsigned j = 0, e = dstOperands.size(); j < e; ++j) - eq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] -= destFlatExpr[j]; - // Local terms. - for (unsigned j = 0, e = dstNumLocalIds; j < e; j++) - eq[newLocalIdOffset + numSrcLocalIds + j] = -destFlatExpr[dstNumIds + j]; - // Set constant term. - eq[eq.size() - 1] -= destFlatExpr[destFlatExpr.size() - 1]; - - // Add equality constraint. 
- dependenceDomain->addEquality(eq); - } - - // Add equality constraints for any operands that are defined by constant ops. - auto addEqForConstOperands = [&](ArrayRef operands) { - for (unsigned i = 0, e = operands.size(); i < e; ++i) { - if (isForInductionVar(operands[i])) - continue; - auto symbol = operands[i]; - assert(isValidSymbol(symbol)); - // Check if the symbol is a constant. - if (auto cOp = symbol.getDefiningOp()) - dependenceDomain->addBound(FlatAffineConstraints::EQ, - valuePosMap.getSymPos(symbol), - cOp.getValue()); - } - }; - - // Add equality constraints for any src symbols defined by constant ops. - addEqForConstOperands(srcOperands); - // Add equality constraints for any dst symbols defined by constant ops. - addEqForConstOperands(dstOperands); - - // By construction (see flattener), local var constraints will not have any - // equalities. - assert(srcLocalVarCst.getNumEqualities() == 0 && - destLocalVarCst.getNumEqualities() == 0); - // Add inequalities from srcLocalVarCst and destLocalVarCst into the - // dependence domain. - SmallVector ineq(dependenceDomain->getNumCols()); - for (unsigned r = 0, e = srcLocalVarCst.getNumInequalities(); r < e; r++) { - std::fill(ineq.begin(), ineq.end(), 0); - - // Set identifier coefficients from src local var constraints. - for (unsigned j = 0, e = srcOperands.size(); j < e; ++j) - ineq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = - srcLocalVarCst.atIneq(r, j); - // Local terms. - for (unsigned j = 0, e = srcNumLocalIds; j < e; j++) - ineq[newLocalIdOffset + j] = srcLocalVarCst.atIneq(r, srcNumIds + j); - // Set constant term. - ineq[ineq.size() - 1] = - srcLocalVarCst.atIneq(r, srcLocalVarCst.getNumCols() - 1); - dependenceDomain->addInequality(ineq); - } - - for (unsigned r = 0, e = destLocalVarCst.getNumInequalities(); r < e; r++) { - std::fill(ineq.begin(), ineq.end(), 0); - // Set identifier coefficients from dest local var constraints. 
- for (unsigned j = 0, e = dstOperands.size(); j < e; ++j) - ineq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] = - destLocalVarCst.atIneq(r, j); - // Local terms. - for (unsigned j = 0, e = dstNumLocalIds; j < e; j++) - ineq[newLocalIdOffset + numSrcLocalIds + j] = - destLocalVarCst.atIneq(r, dstNumIds + j); - // Set constant term. - ineq[ineq.size() - 1] = - destLocalVarCst.atIneq(r, destLocalVarCst.getNumCols() - 1); - - dependenceDomain->addInequality(ineq); - } - return success(); -} - // Returns the number of outer loop common to 'src/dstDomain'. // Loops common to 'src/dst' domains are added to 'commonLoops' if non-null. static unsigned @@ -893,6 +439,103 @@ } } +/// Add local variables and their constraints contained in `cst` obtained by +/// flattening accessValueMap to accessRel. The newly added local variables are +/// added at the end. Returns the offset from which new ids were added. +static unsigned addAccessLocalVars(FlatAffineRelation &accessRel, + const AffineValueMap &accessValueMap, + FlatAffineValueConstraints &cst) { + // Set Values in cst + for (unsigned i = 0, e = accessValueMap.getNumOperands(); i < e; ++i) + cst.setValue(i, accessValueMap.getOperand(i)); + // Add local variables to accessRel + unsigned localOffset = accessRel.getNumIds(); + accessRel.appendLocalId(cst.getNumLocalIds()); + + // Add inequalities from cst to accessRel + for (unsigned i = 0, e = cst.getNumInequalities(); i < e; ++i) { + SmallVector newIneq(accessRel.getNumCols(), 0); + // Set identifier coefficients + for (unsigned j = 0, e = cst.getNumDimAndSymbolIds(); j < e; ++j) { + unsigned operandPos; + accessRel.findId(cst.getValue(j), &operandPos); + newIneq[operandPos] = cst.atIneq(i, j); + } + + // Local terms. + for (unsigned j = 0, e = cst.getNumLocalIds(); j < e; j++) + newIneq[localOffset + j] = cst.atIneq(i, cst.getNumDimAndSymbolIds() + j); + // Set constant term. 
+ newIneq[newIneq.size() - 1] = cst.atIneq(i, cst.getNumCols() - 1); + accessRel.addInequality(newIneq); + } + + return localOffset; +} + +LogicalResult MemRefAccess::getAccessRelation(FlatAffineRelation &ret) const { + // Create domain of access + FlatAffineValueConstraints domain; + if (failed(getOpIndexSet(opInst, &domain))) + return failure(); + + // Create range of access + AffineValueMap accessValueMap; + getAccessMap(&accessValueMap); + FlatAffineValueConstraints range(accessValueMap.getNumResults(), + accessValueMap.getNumSymbols()); + for (unsigned i = 0, e = accessValueMap.getNumSymbols(); i < e; ++i) + range.setValue(range.getNumDimIds() + i, + accessValueMap.getOperand(i + accessValueMap.getNumDims())); + + // Build access relation + // accessRel: empty -> range + // domainRel: domain -> empty + // accessRel compose domainRel: domain -> range + FlatAffineRelation accessRel(0, range.getNumDimIds(), range); + FlatAffineRelation domainRel(domain.getNumDimIds(), 0, domain); + accessRel.compose(domainRel); + + // Get flattened expressions + std::vector> flatExprs; + FlatAffineValueConstraints localVarCst; + if (failed(getFlattenedAffineExprs(accessValueMap.getAffineMap(), &flatExprs, + &localVarCst))) + return failure(); + + // Add local ids from access map to accessRelation + unsigned newLocalIdOffset = + addAccessLocalVars(accessRel, accessValueMap, localVarCst); + + // Add access constraints to relation as equalities + ArrayRef operands = accessValueMap.getOperands(); + SmallVector eq(accessRel.getNumCols()); + for (unsigned i = 0, e = accessValueMap.getNumResults(); i < e; ++i) { + // Zero fill. + std::fill(eq.begin(), eq.end(), 0); + // Flattened AffineExpr for i^th range result. + const auto &flatExpr = flatExprs[i]; + // Set identifier coefficients from access map + for (unsigned j = 0, e = operands.size(); j < e; ++j) { + unsigned operandPos; + accessRel.findId(operands[j], &operandPos); + eq[operandPos] = flatExpr[j]; + } + + // Local terms. 
+ for (unsigned j = 0, e = localVarCst.getNumLocalIds(); j < e; j++) + eq[newLocalIdOffset + j] = + flatExpr[localVarCst.getNumDimAndSymbolIds() + j]; + // Set constant term. + eq[eq.size() - 1] = flatExpr[flatExpr.size() - 1]; + // Set this to the i^th range identifier + eq[accessRel.getNumDomainDims() + i] = -1; + accessRel.addEquality(eq); + } + ret = accessRel; + return success(); +} + // Populates 'accessMap' with composition of AffineApplyOps reachable from // indices of MemRefAccess. void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { @@ -921,17 +564,15 @@ // common to both accesses (see Dependence in AffineAnalysis.h for details). // // The memref access dependence check is comprised of the following steps: -// *) Compute access functions for each access. Access functions are computed -// using AffineValueMaps initialized with the indices from an access, then -// composed with AffineApplyOps reachable from operands of that access, -// until operands of the AffineValueMap are loop IVs or symbols. -// *) Build iteration domain constraints for each access. Iteration domain -// constraints are pairs of inequality constraints representing the -// upper/lower loop bounds for each AffineForOp in the loop nest associated -// with each access. -// *) Build dimension and symbol position maps for each access, which map -// Values from access functions and iteration domains to their position -// in the merged constraint system built by this method. +// *) Build access relation for each access. An access relation maps elements +// of an iteration domain to the element(s) of an array domain accessed by +// that iteration of the associated statement through some array reference. +// *) Compute the dependence relation by composing access relation of +// `srcAccess` with the inverse of access relation of `dstAccess`. 
+// Doing this builds a relation between iteration domain of `srcAccess` +// to the iteration domain of `dstAccess` which access the same memory +// location. +// *) Add ordering constraints so `srcAccess` is accessed before `dstAccess`. // // This method builds a constraint system with the following column format: // @@ -958,34 +599,34 @@ // } // } // -// The access functions would be the following: -// -// src: (%i0 * 2 - %i1 * 4 + %N, %i1 * 3 - %M) -// dst: (%i2 * 7 + %i3 * 9 - %M, %i3 * 11 - %K) -// -// The iteration domains for the src/dst accesses would be the following: +// The access relation for `srcAccess` would be the following: // -// src: 0 <= %i0 <= 100, 0 <= %i1 <= 50 -// dst: 0 <= %i2 <= 100, 0 <= %i3 <= 50 +// [src_dim0, src_dim1, mem_dim0, mem_dim1, %N, %M, const] +// 2 -4 -1 0 1 0 0 = 0 +// 0 3 0 -1 0 -1 0 = 0 +// 1 0 0 0 0 0 0 >= 0 +// -1 0 0 0 0 0 100 >= 0 +// 0 1 0 0 0 0 0 >= 0 +// 0 -1 0 0 0 0 50 >= 0 // -// The symbols by both accesses would be assigned to a canonical position order -// which will be used in the dependence constraint system: +// The access relation for `dstAccess` would be the following: // -// symbol name: %M %N %K -// symbol pos: 0 1 2 +// [dst_dim0, dst_dim1, mem_dim0, mem_dim1, %M, %K, const] +// 7 9 -1 0 -1 0 0 = 0 +// 0 11 0 -1 0 -1 0 = 0 +// 1 0 0 0 0 0 0 >= 0 +// -1 0 0 0 0 0 100 >= 0 +// 0 1 0 0 0 0 0 >= 0 +// 0 -1 0 0 0 0 50 >= 0 // -// Equality constraints are built by equating each result of src/destination -// access functions. For this example, the following two equality constraints -// will be added to the dependence constraint system: +// The equalities in the above relations correspond to the access maps while +// the inequalities correspond to the iteration domain constraints. 
// -// [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const] -// 2 -4 -7 -9 1 1 0 0 = 0 -// 0 3 0 -11 -1 0 1 0 = 0 +// The dependence relation formed: // -// Inequality constraints from the iteration domain will be meged into -// the dependence constraint system -// -// [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const] +// [src_dim0, src_dim1, dst_dim0, dst_dim1, %M, %N, %K, const] +// 2 -4 -7 -9 1 1 0 0 = 0 +// 0 3 0 -11 -1 0 1 0 = 0 // 1 0 0 0 0 0 0 0 >= 0 // -1 0 0 0 0 0 0 100 >= 0 // 0 1 0 0 0 0 0 0 >= 0 @@ -1016,24 +657,16 @@ !isa(dstAccess.opInst)) return DependenceResult::NoDependence; - // Get composed access function for 'srcAccess'. - AffineValueMap srcAccessMap; - srcAccess.getAccessMap(&srcAccessMap); - - // Get composed access function for 'dstAccess'. - AffineValueMap dstAccessMap; - dstAccess.getAccessMap(&dstAccessMap); - - // Get iteration domain for the 'srcAccess' operation. - FlatAffineValueConstraints srcDomain; - if (failed(getOpIndexSet(srcAccess.opInst, &srcDomain))) + // Create access relations + FlatAffineRelation srcRel, dstRel; + if (failed(srcAccess.getAccessRelation(srcRel))) return DependenceResult::Failure; - - // Get iteration domain for 'dstAccess' operation. - FlatAffineValueConstraints dstDomain; - if (failed(getOpIndexSet(dstAccess.opInst, &dstDomain))) + if (failed(dstAccess.getAccessRelation(dstRel))) return DependenceResult::Failure; + FlatAffineValueConstraints srcDomain = srcRel.getDomainSet(); + FlatAffineValueConstraints dstDomain = dstRel.getDomainSet(); + // Return 'NoDependence' if loopDepth > numCommonLoops and if the ancestor // operation of 'srcAccess' does not properly dominate the ancestor // operation of 'dstAccess' in the same common operation block. @@ -1046,42 +679,24 @@ numCommonLoops)) { return DependenceResult::NoDependence; } - // Build dim and symbol position maps for each access from access operand - // Value to position in merged constraint system. 
- ValuePositionMap valuePosMap; - buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap, - dstAccessMap, &valuePosMap, - dependenceConstraints); - initDependenceConstraints(srcDomain, dstDomain, srcAccessMap, dstAccessMap, - valuePosMap, dependenceConstraints); - - assert(valuePosMap.getNumDims() == - srcDomain.getNumDimIds() + dstDomain.getNumDimIds()); - - // Create memref access constraint by equating src/dst access functions. - // Note that this check is conservative, and will fail in the future when - // local variables for mod/div exprs are supported. - if (failed(addMemRefAccessConstraints(srcAccessMap, dstAccessMap, valuePosMap, - dependenceConstraints))) - return DependenceResult::Failure; + + // Create dependence relation + dstRel.inverse(); + dstRel.compose(srcRel); + *dependenceConstraints = dstRel; // Add 'src' happens before 'dst' ordering constraints. addOrderingConstraints(srcDomain, dstDomain, loopDepth, dependenceConstraints); - // Add src and dst domain constraints. - addDomainConstraints(srcDomain, dstDomain, valuePosMap, - dependenceConstraints); // Return 'NoDependence' if the solution space is empty: no dependence. - if (dependenceConstraints->isEmpty()) { + if (dependenceConstraints->isEmpty()) return DependenceResult::NoDependence; - } // Compute dependence direction vector and return true. - if (dependenceComponents != nullptr) { + if (dependenceComponents != nullptr) computeDirectionVector(srcDomain, dstDomain, loopDepth, dependenceConstraints, dependenceComponents); - } LLVM_DEBUG(llvm::dbgs() << "Dependence polyhedron:\n"); LLVM_DEBUG(dependenceConstraints->dump()); diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -417,13 +417,11 @@ b->getMaybeValues().begin() + b->getNumDimAndSymbolIds(), [](Optional id) { return id.hasValue(); })); - // Place local id's of A after local id's of B. 
-  b->insertLocalId(/*pos=*/0, /*num=*/a->getNumLocalIds());
-  a->appendLocalId(/*num=*/b->getNumLocalIds() - a->getNumLocalIds());
+  // Bring A and B to common local space
+  a->toCommonLocalSpace(*b);
 
-  SmallVector aDimValues, aSymValues;
+  SmallVector aDimValues;
   a->getValues(offset, a->getNumDimIds(), &aDimValues);
-  a->getValues(a->getNumDimIds(), a->getNumDimAndSymbolIds(), &aSymValues);
 
   {
     // Merge dims from A into B.
@@ -448,29 +446,8 @@
            "expected same number of dims");
   }
 
-  {
-    // Merge symbols: merge A's symbols into B first.
-    unsigned s = 0;
-    for (auto aSymValue : aSymValues) {
-      unsigned loc;
-      if (b->findId(aSymValue, &loc)) {
-        assert(loc >= b->getNumDimIds() && loc < b->getNumDimAndSymbolIds() &&
-               "A's symbol appears in B's non-symbol position");
-        b->swapId(s + b->getNumDimIds(), loc);
-      } else {
-        b->insertSymbolId(s, aSymValue);
-      }
-      s++;
-    }
-    // Symbols that are in B, but not in A, are added at the end.
-    for (unsigned t = a->getNumDimAndSymbolIds(),
-                  e = b->getNumDimAndSymbolIds();
-         t < e; t++) {
-      a->appendSymbolId(b->getValue(t));
-    }
-    assert(a->getNumDimAndSymbolIds() == b->getNumDimAndSymbolIds() &&
-           "expected same number of dims and symbols");
-  }
+  // Merge and align symbols of A and B
+  a->toCommonSymbolSpace(*b);
 
   assert(areIdsAligned(*a, *b) && "IDs expected to be aligned");
 }
@@ -548,6 +525,38 @@
   }
 }
 
+/// Merge and align symbols of `this` and `other` such that both end up with
+/// the union of their symbols, in the same order. Symbols with Value as `None`
+/// are considered to be unequal to all other symbols.
+void FlatAffineValueConstraints::toCommonSymbolSpace(
+    FlatAffineValueConstraints &other) {
+  SmallVector<Value, 4> aSymValues;
+  getValues(getNumDimIds(), getNumDimAndSymbolIds(), &aSymValues);
+
+  // Merge symbols: merge symbols into `other` first from `this`.
+  unsigned s = other.getNumDimIds();
+  for (auto aSymValue : aSymValues) {
+    unsigned loc;
+    // If the id is a symbol in `other`, then align it, otherwise assume that
+    // it is a new symbol. `loc` is a column of `other`, so use its bounds.
+    if (other.findId(aSymValue, &loc) && loc >= other.getNumDimIds() &&
+        loc < other.getNumDimAndSymbolIds())
+      other.swapId(s, loc);
+    else
+      other.insertSymbolId(s - other.getNumDimIds(), aSymValue);
+    s++;
+  }
+
+  // Symbols that are in other, but not in this, are added at the end.
+  for (unsigned t = other.getNumDimIds() + getNumSymbolIds(),
+                e = other.getNumDimAndSymbolIds();
+       t < e; t++)
+    insertSymbolId(getNumSymbolIds(), other.getValue(t));
+
+  assert(getNumSymbolIds() == other.getNumSymbolIds() &&
+         "expected same number of symbols");
+}
+
 // Changes all symbol identifiers which are loop IVs to dim identifiers.
 void FlatAffineValueConstraints::convertLoopIVSymbolsToDims() {
   // Gather all symbols which are loop IVs.
@@ -951,6 +960,7 @@
     if (tmpCst.hasInvalidConstraint())
       return true;
   }
+
   return false;
 }
 
@@ -1757,6 +1767,78 @@
   equalities.resizeVertically(pos);
 }
 
+/// Removes local variables using equalities. Each equality is checked if it
+/// can be reduced to the form: `e = affine-expr`, where `e` is a local
+/// variable and `affine-expr` is an affine expression not containing `e`.
+/// If an equality satisfies this form, the local variable is replaced in
+/// each constraint and then removed.
+void FlatAffineConstraints::removeRedundantLocalVars() {
+  // Every elimination removes a column and an equality, so the scan is
+  // restarted from the first equality until no candidate remains.
+  bool changed = true;
+  while (changed) {
+    changed = false;
+    for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
+      // Look for a local id whose coefficient in equality `i` is +1 or -1.
+      unsigned eliminateVar = getNumIds();
+      for (unsigned j = getNumDimAndSymbolIds(), f = getNumIds(); j < f; ++j) {
+        if (std::abs(atEq(i, j)) == 1) {
+          eliminateVar = j;
+          break;
+        }
+      }
+      if (eliminateVar == getNumIds())
+        continue;
+
+      // Use equality to simplify other constraints
+      for (unsigned j = 0, f = getNumEqualities(); j < f; ++j)
+        eliminateFromConstraint(this, j, i, eliminateVar, eliminateVar,
+                                /*isEq=*/true);
+      for (unsigned j = 0, f = getNumInequalities(); j < f; ++j)
+        eliminateFromConstraint(this, j, i, eliminateVar, eliminateVar,
+                                /*isEq=*/false);
+      // The local id has been substituted out; drop it and the equality.
+      removeId(eliminateVar);
+      removeEquality(i);
+      normalizeConstraintsByGCD();
+      changed = true;
+      break;
+    }
+  }
+}
+
+/// Converts the dim identifiers in the column range [dimStart, dimLimit) to
+/// local variables. Does nothing when the range is empty.
+void FlatAffineConstraints::convertDimToLocal(unsigned dimStart,
+                                              unsigned dimLimit) {
+  // `dimStart` is unsigned, so only the upper limit needs a range check.
+  assert(dimLimit <= getNumDimIds() && "Invalid dim pos range");
+
+  if (dimStart >= dimLimit)
+    return;
+
+  // Append new local variables
+  unsigned convertCount = dimLimit - dimStart;
+  unsigned newLocalIdStart = getNumIds();
+  appendLocalId(convertCount);
+
+  // Swap the new local variables with dimensions
+  for (unsigned i = 0; i < convertCount; ++i)
+    swapId(i + dimStart, i + newLocalIdStart);
+
+  // Remove dimensions
+  removeIdRange(dimStart, dimLimit);
+}
+
+/// Merge local ids of `this` and `other`. This is done by appending local ids
+/// of `other` to `this` and inserting local ids of `this` to `other` at start
+/// of its local ids.
+void FlatAffineConstraints::toCommonLocalSpace(FlatAffineConstraints &other) {
+  unsigned initLocals = getNumLocalIds(); // Count before `this` is extended.
+  insertLocalId(getNumLocalIds(), other.getNumLocalIds());
+  other.insertLocalId(0, initLocals);
+}
+
 std::pair FlatAffineConstraints::getLowerAndUpperBound(
     unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
     ArrayRef localExprs, MLIRContext *context) const {
@@ -3435,3 +3517,109 @@
   return map.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                    dims.size(), numSymbols);
 }
+
+FlatAffineValueConstraints FlatAffineRelation::getDomainSet() const {
+  FlatAffineValueConstraints domain = *this;
+  // Turn every range dim into a local id so only domain dims stay as dims.
+  domain.convertDimToLocal(getNumDomainDims(),
+                           getNumDomainDims() + getNumRangeDims());
+  domain.removeRedundantLocalVars();
+  return domain;
+}
+
+FlatAffineValueConstraints FlatAffineRelation::getRangeSet() const {
+  FlatAffineValueConstraints range = *this;
+  // Turn every domain dim into a local id so only range dims stay as dims.
+  range.convertDimToLocal(0, getNumDomainDims());
+  range.removeRedundantLocalVars();
+  return range;
+}
+
+/// Composes `other: (domainOther -> rangeOther)` with `this: (domainThis ->
+/// rangeThis)` in place, leaving `this: (domainOther -> rangeThis)`. Requires
+/// rangeOther to match domainThis in size and in associated values.
+void FlatAffineRelation::compose(const FlatAffineRelation &other) {
+  assert(getNumDomainDims() == other.getNumRangeDims() &&
+         "Domain of this and range of other do not match");
+  assert(std::equal(values.begin(), values.begin() + getNumDomainDims(),
+                    other.values.begin() + other.getNumDomainDims()) &&
+         "Domain of this and range of other do not match");
+
+  FlatAffineRelation rel = other; // Mutable copy; `other` is left untouched.
+  // Bring `this` and `rel` to common symbol and local space
+  toCommonSymbolSpace(rel);
+  toCommonLocalSpace(rel);
+  // Convert the shared space (domain of `this`, range of `rel`) to local ids.
+  convertDimToLocal(0, getNumDomainDims());
+  rel.convertDimToLocal(rel.getNumDomainDims(),
+                        rel.getNumDomainDims() + rel.getNumRangeDims());
+  // Add dimensions such that both relations become `domainRel -> rangeThis`
+  appendDomainId(rel.getNumDomainDims());
+  rel.appendRangeId(getNumRangeDims());
+  // Snapshot the dim values of both relations before copying them across.
+  auto thisMaybeValues = getMaybeDimValues();
+  auto relMaybeValues = rel.getMaybeDimValues();
+
+  // Copy the values of `rel`'s domain dims onto the new domain dims of `this`
+  for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i)
+    if (relMaybeValues[i].hasValue())
+      setValue(i, relMaybeValues[i].getValue());
+  // Copy the values of `this`'s range dims onto the new range dims of `rel`
+  for (unsigned i = 0, e = getNumRangeDims(); i < e; ++i) {
+    unsigned rangeIdx = rel.getNumDomainDims() + i;
+    if (thisMaybeValues[rangeIdx].hasValue())
+      rel.setValue(rangeIdx, thisMaybeValues[rangeIdx].getValue());
+  }
+
+  // Append the constraints of `rel` to `this`, then simplify the local ids
+  append(rel);
+  removeRedundantLocalVars();
+}
+
+/// Swaps the domain and range of the relation: (d -> r) becomes (r -> d).
+void FlatAffineRelation::inverse() {
+  unsigned oldDomain = getNumDomainDims();
+  unsigned oldRange = getNumRangeDims();
+  // Add `oldDomain` new ids after the current range ids
+  appendRangeId(oldDomain);
+  // Swap each domain id with its newly appended counterpart
+  for (unsigned i = 0; i < oldDomain; ++i)
+    swapId(i, oldDomain + oldRange + i);
+  // Remove the swapped-out columns now sitting at the front
+  removeIdRange(0, oldDomain);
+  // The old range is the new domain and vice versa
+  numDomainDims = oldRange;
+  numRangeDims = oldDomain;
+}
+
+void FlatAffineRelation::appendDomainId(unsigned num) {
+  insertDimId(getNumDomainDims(), num); // New ids go right after the domain.
+  numDomainDims += num;
+}
+
+void FlatAffineRelation::appendRangeId(unsigned num) {
+  insertDimId(getNumDimIds(), num); // New ids go after all existing dims.
+  numRangeDims += num;
+}
+
+/// Removes identifiers in the column range [idStart, idLimit), and copies any
+/// remaining valid data into place, updates member variables, and resizes
+/// arrays as needed.
+void FlatAffineRelation::removeIdRange(unsigned idStart, unsigned idLimit) {
+  FlatAffineValueConstraints::removeIdRange(idStart, idLimit);
+
+  if (idStart >= idLimit)
+    return;
+
+  // Count how many removed columns were domain dims and how many were range
+  // dims. Both ends are clamped to the dim columns before subtracting, so a
+  // removal past them (a symbol or local id) gives 0 instead of wrapping.
+  unsigned domainEnd = getNumDomainDims();
+  unsigned rangeEnd = domainEnd + getNumRangeDims();
+  unsigned begin = std::min(idStart, rangeEnd);
+  unsigned end = std::min(idLimit, rangeEnd);
+  unsigned domainDimsToRemove =
+      std::min(end, domainEnd) - std::min(begin, domainEnd);
+  numDomainDims -= domainDimsToRemove;
+  numRangeDims -= (end - begin) - domainDimsToRemove;
+}
 diff --git a/mlir/test/Transforms/loop-fusion-2.mlir b/mlir/test/Transforms/loop-fusion-2.mlir
--- a/mlir/test/Transforms/loop-fusion-2.mlir
+++ b/mlir/test/Transforms/loop-fusion-2.mlir
@@ -225,9 +225,6 @@
 
 // The fused slice has 16 iterations from along %i0.
 
-// CHECK-DAG: [[$MAP_LB:#map[0-9]+]] = affine_map<(d0) -> (d0 * 16)>
-// CHECK-DAG: [[$MAP_UB:#map[0-9]+]] = affine_map<(d0) -> (d0 * 16 + 16)>
-
 // CHECK-LABEL: slice_tile
 func @slice_tile(%arg0: memref<128x8xf32>, %arg1: memref<32x8xf32>, %0 : f32) -> memref<32x8xf32> {
   affine.for %i0 = 0 to 32 {
@@ -252,9 +249,9 @@
   }
   return %arg1 : memref<32x8xf32>
 }
-// CHECK: affine.for %{{.*}} = 0 to 2 {
-// CHECK-NEXT: affine.for %{{.*}} = 0 to 8 {
-// CHECK-NEXT: affine.for %{{.*}} = [[$MAP_LB]](%{{.*}}) to [[$MAP_UB]](%{{.*}}) {
+// CHECK: affine.for %{{.*}} = 0 to 8 {
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 2 {
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 32 {
 // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<32x8xf32>
 // CHECK-NEXT: }
 // CHECK-NEXT: affine.for %{{.*}} = 0 to 8 {