diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -350,16 +350,15 @@ /// and with the dimensions set to the equalities specified by the value map. /// Returns failure if the composition fails (when vMap is a semi-affine map). /// The vMap's operand Value's are used to look up the right positions in - /// the FlatAffineConstraints with which to associate. The dimensional and - /// symbolic operands of vMap should match 1:1 (in the same order) with those - /// of this constraint system, but the latter could have additional trailing - /// operands. + /// the FlatAffineConstraints with which to associate. Every operand of vMap + /// should have a matching dim/symbol column in this constraint system (with + /// the same associated Value). LogicalResult composeMap(const AffineValueMap *vMap); - /// Composes an affine map whose dimensions match one to one to the - /// dimensions of this FlatAffineConstraints. The results of the map 'other' - /// are added as the leading dimensions of this constraint system. Returns - /// failure if 'other' is a semi-affine map. + /// Composes an affine map whose dimensions and symbols match one to one with + /// the dimensions and symbols of this FlatAffineConstraints. The results of + /// the map `other` are added as the leading dimensions of this constraint + /// system. Returns failure if `other` is a semi-affine map. LogicalResult composeMatchingMap(AffineMap other); /// Projects out (aka eliminates) 'num' identifiers starting at position @@ -599,6 +598,10 @@ template Optional computeConstantLowerOrUpperBound(unsigned pos); + /// Returns `map` aligned with this constraint system using `operands`; each + /// operand must already have a corresponding dim/symbol in this system.
+ AffineMap computeAlignedMap(AffineMap map, ValueRange operands) const; + // Eliminates a single identifier at 'position' from equality and inequality // constraints. Returns 'success' if the identifier was eliminated, and // 'failure' otherwise. diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -387,81 +387,15 @@ mergeAndAlignIds(offset, this, other); } -// This routine may add additional local variables if the flattened expression -// corresponding to the map has such variables due to mod's, ceildiv's, and -// floordiv's in it. +/// Aligns `vMap` via `computeAlignedMap`, then calls `composeMatchingMap`. LogicalResult FlatAffineConstraints::composeMap(const AffineValueMap *vMap) { - std::vector> flatExprs; - FlatAffineConstraints localCst; - if (failed(getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs, - &localCst))) { - LLVM_DEBUG(llvm::dbgs() - << "composition unimplemented for semi-affine maps\n"); - return failure(); - } - assert(flatExprs.size() == vMap->getNumResults()); - - // Add localCst information. - if (localCst.getNumLocalIds() > 0) { - localCst.setIdValues(0, /*end=*/localCst.getNumDimAndSymbolIds(), - /*values=*/vMap->getOperands()); - // Align localCst and this. - mergeAndAlignIds(/*offset=*/0, &localCst, this); - // Finally, append localCst to this constraint set. - append(localCst); - } - - // Add dimensions corresponding to the map's results. - for (unsigned t = 0, e = vMap->getNumResults(); t < e; t++) { - // TODO: Consider using a batched version to add a range of IDs. - addDimId(0); - } - - // We add one equality for each result connecting the result dim of the map to - // the other identifiers. - // For eg: if the expression is 16*i0 + i1, and this is the r^th - // iteration/result of the value map, we are adding the equality: - // d_r - 16*i0 - i1 = 0.
Hence, when flattening say (i0 + 1, i0 + 8*i2), we - // add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0. - for (unsigned r = 0, e = flatExprs.size(); r < e; r++) { - const auto &flatExpr = flatExprs[r]; - assert(flatExpr.size() >= vMap->getNumOperands() + 1); - - // eqToAdd is the equality corresponding to the flattened affine expression. - SmallVector eqToAdd(getNumCols(), 0); - // Set the coefficient for this result to one. - eqToAdd[r] = 1; - - // Dims and symbols. - for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) { - unsigned loc; - bool ret = findId(vMap->getOperand(i), &loc); - assert(ret && "value map's id can't be found"); - (void)ret; - // Negate 'eq[r]' since the newly added dimension will be set to this one. - eqToAdd[loc] = -flatExpr[i]; - } - // Local vars common to eq and localCst are at the beginning. - unsigned j = getNumDimIds() + getNumSymbolIds(); - unsigned end = flatExpr.size() - 1; - for (unsigned i = vMap->getNumOperands(); i < end; i++, j++) { - eqToAdd[j] = -flatExpr[i]; - } - - // Constant term. - eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1]; - - // Add the equality connecting the result of the map to this constraint set. - addEquality(eqToAdd); - } - - return success(); + return composeMatchingMap( + computeAlignedMap(vMap->getAffineMap(), vMap->getOperands())); } -// Similar to composeMap except that no Value's need be associated with the -// constraint system nor are they looked at -- since the dimensions and -// symbols of 'other' are expected to correspond 1:1 to 'this' system. It -// is thus not convenient to share code with composeMap. +/// Similar to `composeMap` except that no Values need be associated with the +/// constraint system nor are they looked at -- the dimensions and symbols of +/// `other` are expected to correspond 1:1 to `this` system.
LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) { assert(other.getNumDims() == getNumDimIds() && "dim mismatch"); assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch"); @@ -477,11 +411,15 @@ // Add localCst information. if (localCst.getNumLocalIds() > 0) { - // Place local id's of A after local id's of B. - for (unsigned l = 0, e = localCst.getNumLocalIds(); l < e; l++) { + unsigned numLocalIds = getNumLocalIds(); + // Insert localCst's local ids at the front of this system's local ids. + for (unsigned l = 0, e = localCst.getNumLocalIds(); l < e; ++l) addLocalId(0); - } - // Finally, append localCst to this constraint set. + // Pad localCst with this system's original local ids, at its end. + for (unsigned l = 0; l < numLocalIds; ++l) + localCst.addLocalId(localCst.getNumLocalIds()); + // All ids of localCst and this constraint set now line up. Append localCst + // to this constraint set. append(localCst); } @@ -2001,19 +1939,9 @@ return success(); } -LogicalResult -FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap, - ValueRange boundOperands, bool eq, - bool lower) { - // Fully compose map and operands; canonicalize and simplify so that we - // transitively get to terminal symbols or loop IVs.
- auto map = boundMap; - SmallVector operands(boundOperands.begin(), boundOperands.end()); - fullyComposeAffineMapAndOperands(&map, &operands); - map = simplifyAffineMap(map); - canonicalizeMapAndOperands(&map, &operands); - for (auto operand : operands) - addInductionVarOrTerminalSymbol(operand); +AffineMap FlatAffineConstraints::computeAlignedMap(AffineMap map, + ValueRange operands) const { + assert(map.getNumInputs() == operands.size() && "number of inputs mismatch"); SmallVector dims, syms; #ifndef NDEBUG @@ -2036,8 +1964,23 @@ assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols"); assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) && "unexpected new/missing symbols"); + return alignedMap; +} - return addLowerOrUpperBound(pos, alignedMap, eq, lower); +LogicalResult +FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap, + ValueRange boundOperands, bool eq, + bool lower) { + // Fully compose map and operands; canonicalize and simplify so that we + // transitively get to terminal symbols or loop IVs. + auto map = boundMap; + SmallVector operands(boundOperands.begin(), boundOperands.end()); + fullyComposeAffineMapAndOperands(&map, &operands); + map = simplifyAffineMap(map); + canonicalizeMapAndOperands(&map, &operands); + for (auto operand : operands) + addInductionVarOrTerminalSymbol(operand); + return addLowerOrUpperBound(pos, computeAlignedMap(map, operands), eq, lower); } // Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper