diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -403,6 +403,9 @@ /// Simplifies an affine map by simplifying its underlying AffineExpr results. AffineMap simplifyAffineMap(AffineMap map); +/// Drop the dims that are listed in `unusedDims`. +AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims); + /// Drop the dims that are not used. AffineMap compressUnusedDims(AffineMap map); @@ -411,9 +414,6 @@ /// dims and symbols. SmallVector compressUnusedDims(ArrayRef maps); -/// Drop the dims that are not listed in `unusedDims`. -AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims); - /// Drop the symbols that are not used. AffineMap compressUnusedSymbols(AffineMap map); @@ -469,7 +469,7 @@ /// Return the reverse map of a projected permutation where the projected /// dimensions are transformed into 0s. /// -/// Prerequisites: `map` must be a projected permuation. +/// Prerequisites: `map` must be a projected permutation. /// /// Example 1: /// @@ -559,9 +559,21 @@ /// projected_dimensions : {1} /// result : affine_map<(d0, d1) -> (d0, 0)> /// -/// This function also compresses unused symbols away. +/// If `compressDimsFlag` is true, projected dims are compressed away. +AffineMap projectDims(AffineMap map, + const llvm::SmallBitVector &projectedDimensions, + bool compressDimsFlag = false); +/// Symbol counterpart of `projectDims`. +/// If `compressSymbolsFlag` is true, projected symbols are compressed away. +AffineMap projectSymbols(AffineMap map, + const llvm::SmallBitVector &projectedSymbols, + bool compressSymbolsFlag = false); +/// Calls `projectDims(map, projectedDimensions, compressDimsFlag)`. +/// If `compressSymbolsFlag` is true, additionally call `compressUnusedSymbols`.
AffineMap getProjectedMap(AffineMap map, - const llvm::SmallBitVector &projectedDimensions); + const llvm::SmallBitVector &projectedDimensions, + bool compressDimsFlag = true, + bool compressSymbolsFlag = true); /// Apply a permutation from `map` to `source` and return the result. template @@ -610,6 +622,10 @@ // by any of the maps in the input array `maps`. llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef maps); +// Return a bitvector where each bit set indicates a symbol that is not used +// by any of the maps in the input array `maps`. +llvm::SmallBitVector getUnusedSymbolsBitVector(ArrayRef maps); + } // namespace mlir namespace llvm { diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -18,6 +18,7 @@ #include "llvm/Support/raw_ostream.h" #include #include +#include using namespace mlir; @@ -569,32 +570,9 @@ return getSliceMap(getNumResults() - numResults, numResults); } -AffineMap mlir::compressDims(AffineMap map, - const llvm::SmallBitVector &unusedDims) { - unsigned numDims = 0; - SmallVector dimReplacements; - dimReplacements.reserve(map.getNumDims()); - MLIRContext *context = map.getContext(); - for (unsigned dim = 0, e = map.getNumDims(); dim < e; ++dim) { - if (unusedDims.test(dim)) - dimReplacements.push_back(getAffineConstantExpr(0, context)); - else - dimReplacements.push_back(getAffineDimExpr(numDims++, context)); - } - SmallVector resultExprs; - resultExprs.reserve(map.getNumResults()); - for (auto e : map.getResults()) - resultExprs.push_back(e.replaceDims(dimReplacements)); - return AffineMap::get(numDims, map.getNumSymbols(), resultExprs, context); -} - -AffineMap mlir::compressUnusedDims(AffineMap map) { - return compressDims(map, getUnusedDimsBitVector({map})); -} - -static SmallVector -compressUnusedImpl(ArrayRef maps, - llvm::function_ref compressionFun) { +static SmallVector compressUnusedListImpl( + ArrayRef maps, + llvm::function_ref compressionFun) { if
(maps.empty()) return SmallVector(); SmallVector allExprs; @@ -622,41 +600,31 @@ return res; } +AffineMap mlir::compressDims(AffineMap map, + const llvm::SmallBitVector &unusedDims) { + return projectDims(map, unusedDims, /*compressDimsFlag=*/true); +} + +AffineMap mlir::compressUnusedDims(AffineMap map) { + return compressDims(map, getUnusedDimsBitVector({map})); +} + SmallVector mlir::compressUnusedDims(ArrayRef maps) { - return compressUnusedImpl(maps, - [](AffineMap m) { return compressUnusedDims(m); }); + return compressUnusedListImpl( + maps, [](AffineMap m) { return compressUnusedDims(m); }); } AffineMap mlir::compressSymbols(AffineMap map, const llvm::SmallBitVector &unusedSymbols) { - unsigned numSymbols = 0; - SmallVector symReplacements; - symReplacements.reserve(map.getNumSymbols()); - MLIRContext *context = map.getContext(); - for (unsigned sym = 0, e = map.getNumSymbols(); sym < e; ++sym) { - if (unusedSymbols.test(sym)) - symReplacements.push_back(getAffineConstantExpr(0, context)); - else - symReplacements.push_back(getAffineSymbolExpr(numSymbols++, context)); - } - SmallVector resultExprs; - resultExprs.reserve(map.getNumResults()); - for (auto e : map.getResults()) - resultExprs.push_back(e.replaceSymbols(symReplacements)); - return AffineMap::get(map.getNumDims(), numSymbols, resultExprs, context); + return projectSymbols(map, unusedSymbols, /*compressSymbolsFlag=*/true); } AffineMap mlir::compressUnusedSymbols(AffineMap map) { - llvm::SmallBitVector unusedSymbols(map.getNumSymbols(), true); - map.walkExprs([&](AffineExpr expr) { - if (auto symExpr = expr.dyn_cast()) - unusedSymbols.reset(symExpr.getPosition()); - }); - return compressSymbols(map, unusedSymbols); + return compressSymbols(map, getUnusedSymbolsBitVector({map})); } SmallVector mlir::compressUnusedSymbols(ArrayRef maps) { - return compressUnusedImpl( + return compressUnusedListImpl( maps, [](AffineMap m) { return compressUnusedSymbols(m); }); } @@ -741,9 +709,74 @@
maps.front().getContext()); } +/// Shared implementation of `projectDims` and `projectSymbols`. Every dim +/// (or symbol, per `AffineDimOrSymExpr`) whose bit is set in `toProject` is +/// replaced by the constant 0. If `compress` is true, the remaining dims (or +/// symbols) are renumbered densely and the count shrinks accordingly; +/// otherwise positions and counts are left unchanged. +template <typename AffineDimOrSymExpr> +static AffineMap projectCommonImpl(AffineMap map, + const llvm::SmallBitVector &toProject, + bool compress) { + static_assert(std::is_same<AffineDimOrSymExpr, AffineDimExpr>::value || + std::is_same<AffineDimOrSymExpr, AffineSymbolExpr>::value, + "expected AffineDimExpr or AffineSymbolExpr"); + constexpr bool isDim = + std::is_same<AffineDimOrSymExpr, AffineDimExpr>::value; + int64_t numDimOrSym = (isDim) ? map.getNumDims() : map.getNumSymbols(); + int64_t newNumDimOrSym = 0; + SmallVector<AffineExpr> replacements; + replacements.reserve(numDimOrSym); + + auto createNewDimOrSym = (isDim) ? getAffineDimExpr : getAffineSymbolExpr; + + MLIRContext *context = map.getContext(); + for (unsigned dimOrSym = 0; dimOrSym < numDimOrSym; ++dimOrSym) { + if (toProject.test(dimOrSym)) { + replacements.push_back(getAffineConstantExpr(0, context)); + continue; + } + // Renumber kept entries only when compressing; otherwise the map keeps + // its arity, so every kept entry must stay at its original position. + int64_t newPos = compress ? newNumDimOrSym++ : dimOrSym; + replacements.push_back(createNewDimOrSym(newPos, context)); + } + SmallVector<AffineExpr> resultExprs; + resultExprs.reserve(map.getNumResults()); + // Dispatch on `isDim` directly: a conditional between two distinct lambda + // closure types is not portable across compilers. + for (auto e : map.getResults()) + resultExprs.push_back(isDim ? e.replaceDims(replacements) + : e.replaceSymbols(replacements)); + + int64_t numDims = (compress && isDim) ? newNumDimOrSym : map.getNumDims(); + int64_t numSyms = + (compress && !isDim) ?
newNumDimOrSym : map.getNumSymbols(); + return AffineMap::get(numDims, numSyms, resultExprs, context); +} + +AffineMap mlir::projectDims(AffineMap map, + const llvm::SmallBitVector &projectedDimensions, + bool compressDimsFlag) { + return projectCommonImpl<AffineDimExpr>(map, projectedDimensions, + compressDimsFlag); +} + +AffineMap mlir::projectSymbols(AffineMap map, + const llvm::SmallBitVector &projectedSymbols, + bool compressSymbolsFlag) { + return projectCommonImpl<AffineSymbolExpr>(map, projectedSymbols, + compressSymbolsFlag); +} + AffineMap mlir::getProjectedMap(AffineMap map, - const llvm::SmallBitVector &unusedDims) { - return compressUnusedSymbols(compressDims(map, unusedDims)); + const llvm::SmallBitVector &projectedDimensions, + bool compressDimsFlag, + bool compressSymbolsFlag) { + map = projectDims(map, projectedDimensions, compressDimsFlag); + if (compressSymbolsFlag) + map = compressUnusedSymbols(map); + return map; } llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef maps) { @@ -758,6 +791,21 @@ return numDimsBitVector; } +llvm::SmallBitVector mlir::getUnusedSymbolsBitVector(ArrayRef<AffineMap> maps) { + // Guard against indexing `maps[0]` when `maps` is empty. + if (maps.empty()) + return llvm::SmallBitVector(); + unsigned numSymbols = maps[0].getNumSymbols(); + llvm::SmallBitVector numSymbolsBitVector(numSymbols, true); + for (const auto &m : maps) { + for (unsigned i = 0; i < numSymbols; ++i) { + if (m.isFunctionOfSymbol(i)) + numSymbolsBitVector.reset(i); + } + } + return numSymbolsBitVector; +} + //===----------------------------------------------------------------------===// // MutableAffineMap. //===----------------------------------------------------------------------===// @@ -784,8 +832,8 @@ return false; } -// Simplifies the result affine expressions of this map. The expressions have to -// be pure for the simplification implemented. +// Simplifies the result affine expressions of this map. The expressions +// have to be pure for the simplification implemented. void MutableAffineMap::simplify() { // Simplify each of the results if possible. // TODO: functional-style map