diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -694,6 +694,33 @@
 std::vector> *flattenedExprs,
 FlatAffineConstraints *cst = nullptr);
 
+/// Re-indexes the dimensions and symbols of an affine map with given `operands`
+/// values to align with `dims` and `syms` values. `operands` must provide
+/// exactly one value per map input (`map.getNumInputs()` many).
+///
+/// Each dimension/symbol of the map, bound to an operand `o`, is replaced with
+/// dimension `i`, where `i` is the position of `o` within `dims`. If `o` is not
+/// in `dims`, it is replaced with symbol `i`, where `i` is the position of `o`
+/// within `syms`. If `o` is not in `syms` either, it is replaced with a new
+/// symbol.
+///
+/// Note: If a value appears multiple times as a dimension/symbol (or both), all
+/// corresponding dim/sym expressions are replaced with the first dimension
+/// bound to that value (or first symbol if no such dimension exists).
+///
+/// The resulting affine map has `dims.size()` many dimensions and at least
+/// `syms.size()` many symbols.
+///
+/// The SSA values of the symbols of the resulting map are optionally returned
+/// via `newSyms`. This is a concatenation of `syms` with the SSA values of the
+/// newly added symbols.
+///
+/// Note: As part of this re-indexing, dimensions may turn into symbols, or vice
+/// versa.
+AffineMap alignAffineMapWithValues(AffineMap map, ValueRange operands,
+                                   ValueRange dims, ValueRange syms,
+                                   SmallVector<Value> *newSyms = nullptr);
+
 } // end namespace mlir.
 
#endif // MLIR_ANALYSIS_AFFINESTRUCTURES_H diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -3274,3 +3274,48 @@ for (auto nbIndex : llvm::reverse(nbEqIndices)) removeEquality(nbIndex); } + +AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands, + ValueRange dims, ValueRange syms, + SmallVector *newSyms) { + assert(operands.size() == map.getNumInputs() && + "expected same number of operands and map inputs"); + MLIRContext *ctx = map.getContext(); + Builder builder(ctx); + SmallVector dimReplacements(map.getNumDims(), {}); + unsigned numSymbols = syms.size(); + SmallVector symReplacements(map.getNumSymbols(), {}); + if (newSyms) { + newSyms->clear(); + newSyms->append(syms.begin(), syms.end()); + } + + for (auto operand : llvm::enumerate(operands)) { + // Compute replacement dim/sym of operand. + AffineExpr replacement; + auto dimIt = std::find(dims.begin(), dims.end(), operand.value()); + auto symIt = std::find(syms.begin(), syms.end(), operand.value()); + if (dimIt != dims.end()) { + replacement = + builder.getAffineDimExpr(std::distance(dims.begin(), dimIt)); + } else if (symIt != syms.end()) { + replacement = + builder.getAffineSymbolExpr(std::distance(syms.begin(), symIt)); + } else { + // This operand is neither a dimension nor a symbol. Add it as a new + // symbol. + replacement = builder.getAffineSymbolExpr(numSymbols++); + if (newSyms) + newSyms->push_back(operand.value()); + } + // Add to corresponding replacements vector. + if (operand.index() < map.getNumDims()) { + dimReplacements[operand.index()] = replacement; + } else { + symReplacements[operand.index() - map.getNumDims()] = replacement; + } + } + + return map.replaceDimsAndSymbols(dimReplacements, symReplacements, + dims.size(), numSymbols); +}