diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -403,6 +403,9 @@
 /// Simplifies an affine map by simplifying its underlying AffineExpr results.
 AffineMap simplifyAffineMap(AffineMap map);
 
+/// Drop the dims that are listed in `unusedDims`.
+AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims);
+
 /// Drop the dims that are not used.
 AffineMap compressUnusedDims(AffineMap map);
 
@@ -411,8 +414,9 @@
 /// dims and symbols.
 SmallVector<AffineMap> compressUnusedDims(ArrayRef<AffineMap> maps);
 
-/// Drop the dims that are not listed in `unusedDims`.
-AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims);
+/// Drop the symbols that are listed in `unusedSymbols`.
+AffineMap compressSymbols(AffineMap map,
+                          const llvm::SmallBitVector &unusedSymbols);
 
 /// Drop the symbols that are not used.
 AffineMap compressUnusedSymbols(AffineMap map);
@@ -422,10 +426,6 @@
 /// dims and symbols.
 SmallVector<AffineMap> compressUnusedSymbols(ArrayRef<AffineMap> maps);
 
-/// Drop the symbols that are not listed in `unusedSymbols`.
-AffineMap compressSymbols(AffineMap map,
-                          const llvm::SmallBitVector &unusedSymbols);
-
 /// Returns a map with the same dimension and symbol count as `map`, but whose
 /// results are the unique affine expressions of `map`.
 AffineMap removeDuplicateExprs(AffineMap map);
@@ -469,7 +469,7 @@
 /// Return the reverse map of a projected permutation where the projected
 /// dimensions are transformed into 0s.
 ///
-/// Prerequisites: `map` must be a projected permuation.
+/// Prerequisites: `map` must be a projected permutation.
 ///
 /// Example 1:
 ///
@@ -559,9 +559,38 @@
 ///    projected_dimensions : {1}
 ///    result               : affine_map<(d0, d1) -> (d0, 0)>
 ///
-/// This function also compresses unused symbols away.
+/// If `compressDimsFlag` is true, the projected dims are also dropped.
+AffineMap projectDims(AffineMap map,
+                      const llvm::SmallBitVector &projectedDimensions,
+                      bool compressDimsFlag = false);
+/// Symbol counterpart of `projectDims`.
+/// If `compressSymbolsFlag` is true, the projected symbols are also dropped.
+AffineMap projectSymbols(AffineMap map,
+                         const llvm::SmallBitVector &projectedSymbols,
+                         bool compressSymbolsFlag = false);
+/// Calls `projectDims(map, projectedDimensions, compressDimsFlag)`.
+/// If `compressSymbolsFlag` is true, additionally call `compressUnusedSymbols`.
 AffineMap getProjectedMap(AffineMap map,
-                          const llvm::SmallBitVector &projectedDimensions);
+                          const llvm::SmallBitVector &projectedDimensions,
+                          bool compressDimsFlag = true,
+                          bool compressSymbolsFlag = true);
+
+// Return a bitvector where each bit set indicates a dimension that is not used
+// by any of the maps in the input array `maps`, which must be non-empty.
+llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef<AffineMap> maps);
+
+// Return a bitvector where each bit set indicates a symbol that is not used
+// by any of the maps in the input array `maps`, which must be non-empty.
+llvm::SmallBitVector getUnusedSymbolsBitVector(ArrayRef<AffineMap> maps);
+
+inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
+  map.print(os);
+  return os;
+}
+
+//===----------------------------------------------------------------------===//
+// Templated helper functions.
+//===----------------------------------------------------------------------===//
 
 /// Apply a permutation from `map` to `source` and return the result.
 template <typename T>
@@ -584,7 +613,7 @@
   return result;
 }
 
-/// Calculates maxmimum dimension and symbol positions from the expressions
+/// Calculates maximum dimension and symbol positions from the expressions
 /// in `exprsLists` and stores them in `maxDim` and `maxSym` respectively.
 template <typename AffineExprContainer>
 static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,
@@ -601,15 +630,6 @@
   }
 }
 
-inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
-  map.print(os);
-  return os;
-}
-
-// Return a bitvector where each bit set indicates a dimension that is not used
-// by any of the maps in the input array `maps`.
-llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef<AffineMap> maps);
-
 } // namespace mlir
 
 namespace llvm {
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -12,12 +12,14 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/raw_ostream.h"
 #include <iterator>
 #include <numeric>
+#include <type_traits>
 
 using namespace mlir;
 
@@ -569,32 +571,13 @@
   return getSliceMap(getNumResults() - numResults, numResults);
 }
 
-AffineMap mlir::compressDims(AffineMap map,
-                             const llvm::SmallBitVector &unusedDims) {
-  unsigned numDims = 0;
-  SmallVector<AffineExpr> dimReplacements;
-  dimReplacements.reserve(map.getNumDims());
-  MLIRContext *context = map.getContext();
-  for (unsigned dim = 0, e = map.getNumDims(); dim < e; ++dim) {
-    if (unusedDims.test(dim))
-      dimReplacements.push_back(getAffineConstantExpr(0, context));
-    else
-      dimReplacements.push_back(getAffineDimExpr(numDims++, context));
-  }
-  SmallVector<AffineExpr> resultExprs;
-  resultExprs.reserve(map.getNumResults());
-  for (auto e : map.getResults())
-    resultExprs.push_back(e.replaceDims(dimReplacements));
-  return AffineMap::get(numDims, map.getNumSymbols(), resultExprs, context);
-}
-
-AffineMap mlir::compressUnusedDims(AffineMap map) {
-  return compressDims(map, getUnusedDimsBitVector({map}));
-}
-
-static SmallVector<AffineMap>
-compressUnusedImpl(ArrayRef<AffineMap> maps,
-                   llvm::function_ref<AffineMap(AffineMap)> compressionFun) {
+/// Implementation detail to compress multiple affine maps with a compressionFun
+/// that is expected to be either compressUnusedDims or compressUnusedSymbols.
+/// The implementation keeps track of num dims and symbols across the different
+/// affine maps.
+static SmallVector<AffineMap> compressUnusedListImpl(
+    ArrayRef<AffineMap> maps,
+    llvm::function_ref<AffineMap(AffineMap)> compressionFun) {
   if (maps.empty())
     return SmallVector<AffineMap>();
   SmallVector<AffineExpr> allExprs;
@@ -622,41 +605,31 @@
   return res;
 }
 
+AffineMap mlir::compressDims(AffineMap map,
+                             const llvm::SmallBitVector &unusedDims) {
+  return projectDims(map, unusedDims, /*compressDimsFlag=*/true);
+}
+
+AffineMap mlir::compressUnusedDims(AffineMap map) {
+  return compressDims(map, getUnusedDimsBitVector({map}));
+}
+
 SmallVector<AffineMap> mlir::compressUnusedDims(ArrayRef<AffineMap> maps) {
-  return compressUnusedImpl(maps,
-                            [](AffineMap m) { return compressUnusedDims(m); });
+  return compressUnusedListImpl(
+      maps, [](AffineMap m) { return compressUnusedDims(m); });
 }
 
 AffineMap mlir::compressSymbols(AffineMap map,
                                 const llvm::SmallBitVector &unusedSymbols) {
-  unsigned numSymbols = 0;
-  SmallVector<AffineExpr> symReplacements;
-  symReplacements.reserve(map.getNumSymbols());
-  MLIRContext *context = map.getContext();
-  for (unsigned sym = 0, e = map.getNumSymbols(); sym < e; ++sym) {
-    if (unusedSymbols.test(sym))
-      symReplacements.push_back(getAffineConstantExpr(0, context));
-    else
-      symReplacements.push_back(getAffineSymbolExpr(numSymbols++, context));
-  }
-  SmallVector<AffineExpr> resultExprs;
-  resultExprs.reserve(map.getNumResults());
-  for (auto e : map.getResults())
-    resultExprs.push_back(e.replaceSymbols(symReplacements));
-  return AffineMap::get(map.getNumDims(), numSymbols, resultExprs, context);
+  return projectSymbols(map, unusedSymbols, /*compressSymbolsFlag=*/true);
 }
 
 AffineMap mlir::compressUnusedSymbols(AffineMap map) {
-  llvm::SmallBitVector unusedSymbols(map.getNumSymbols(), true);
-  map.walkExprs([&](AffineExpr expr) {
-    if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
-      unusedSymbols.reset(symExpr.getPosition());
-  });
-  return compressSymbols(map, unusedSymbols);
+  return compressSymbols(map, getUnusedSymbolsBitVector({map}));
 }
 
 SmallVector<AffineMap> mlir::compressUnusedSymbols(ArrayRef<AffineMap> maps) {
-  return compressUnusedImpl(
+  return compressUnusedListImpl(
       maps, [](AffineMap m) { return compressUnusedSymbols(m); });
 }
 
@@ -741,15 +714,77 @@
                         maps.front().getContext());
 }
 
+/// Common implementation to project out dimensions or symbols from an affine
+/// map based on the template type.
+/// Additionally, if 'compress' is true, the projected out dimensions or symbols
+/// are also dropped from the resulting map.
+/// `toProject` must have a bit for every dim (resp. symbol) of `map`.
+template <typename AffineDimOrSymExpr>
+static AffineMap projectCommonImpl(AffineMap map,
+                                   const llvm::SmallBitVector &toProject,
+                                   bool compress) {
+  static_assert(llvm::is_one_of<AffineDimOrSymExpr, AffineDimExpr,
+                                AffineSymbolExpr>::value,
+                "expected AffineDimExpr or AffineSymbolExpr");
+
+  constexpr bool isDim = std::is_same<AffineDimOrSymExpr, AffineDimExpr>::value;
+  int64_t numDimOrSym = (isDim) ? map.getNumDims() : map.getNumSymbols();
+  SmallVector<AffineExpr> replacements;
+  replacements.reserve(numDimOrSym);
+
+  // `isDim` is a compile-time constant; pick the dim or symbol flavor with
+  // plain conditionals rather than a `?:` over two distinct lambda closure
+  // types, which not every host compiler accepts (notably MSVC).
+  MLIRContext *context = map.getContext();
+  int64_t newNumDimOrSym = 0;
+  for (unsigned dimOrSym = 0; dimOrSym < numDimOrSym; ++dimOrSym) {
+    if (toProject.test(dimOrSym)) {
+      replacements.push_back(getAffineConstantExpr(0, context));
+      continue;
+    }
+    int64_t newPos = compress ? newNumDimOrSym++ : dimOrSym;
+    replacements.push_back(isDim ? getAffineDimExpr(newPos, context)
+                                 : getAffineSymbolExpr(newPos, context));
+  }
+  SmallVector<AffineExpr> resultExprs;
+  resultExprs.reserve(map.getNumResults());
+  for (AffineExpr e : map.getResults())
+    resultExprs.push_back(isDim ? e.replaceDims(replacements)
+                                : e.replaceSymbols(replacements));
+
+  int64_t numDims = (compress && isDim) ? newNumDimOrSym : map.getNumDims();
+  int64_t numSyms = (compress && !isDim) ? newNumDimOrSym : map.getNumSymbols();
+  return AffineMap::get(numDims, numSyms, resultExprs, context);
+}
+
+AffineMap mlir::projectDims(AffineMap map,
+                            const llvm::SmallBitVector &projectedDimensions,
+                            bool compressDimsFlag) {
+  return projectCommonImpl<AffineDimExpr>(map, projectedDimensions,
+                                          compressDimsFlag);
+}
+
+AffineMap mlir::projectSymbols(AffineMap map,
+                               const llvm::SmallBitVector &projectedSymbols,
+                               bool compressSymbolsFlag) {
+  return projectCommonImpl<AffineSymbolExpr>(map, projectedSymbols,
+                                             compressSymbolsFlag);
+}
+
 AffineMap mlir::getProjectedMap(AffineMap map,
-                                const llvm::SmallBitVector &unusedDims) {
-  return compressUnusedSymbols(compressDims(map, unusedDims));
+                                const llvm::SmallBitVector &projectedDimensions,
+                                bool compressDimsFlag,
+                                bool compressSymbolsFlag) {
+  map = projectDims(map, projectedDimensions, compressDimsFlag);
+  if (compressSymbolsFlag)
+    map = compressUnusedSymbols(map);
+  return map;
 }
 
 llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef<AffineMap> maps) {
   unsigned numDims = maps[0].getNumDims();
   llvm::SmallBitVector numDimsBitVector(numDims, true);
-  for (const auto &m : maps) {
+  for (AffineMap m : maps) {
     for (unsigned i = 0; i < numDims; ++i) {
       if (m.isFunctionOfDim(i))
         numDimsBitVector.reset(i);
@@ -758,6 +793,18 @@
   return numDimsBitVector;
 }
 
+llvm::SmallBitVector mlir::getUnusedSymbolsBitVector(ArrayRef<AffineMap> maps) {
+  unsigned numSymbols = maps[0].getNumSymbols();
+  llvm::SmallBitVector numSymbolsBitVector(numSymbols, true);
+  for (AffineMap m : maps) {
+    for (unsigned i = 0; i < numSymbols; ++i) {
+      if (m.isFunctionOfSymbol(i))
+        numSymbolsBitVector.reset(i);
+    }
+  }
+  return numSymbolsBitVector;
+}
+
 //===----------------------------------------------------------------------===//
 // MutableAffineMap.
 //===----------------------------------------------------------------------===//
@@ -784,8 +831,8 @@
   return false;
 }
 
-// Simplifies the result affine expressions of this map. The expressions have to
-// be pure for the simplification implemented.
+// Simplifies the result affine expressions of this map. The expressions
+// have to be pure for the simplification implemented.
 void MutableAffineMap::simplify() {
   // Simplify each of the results if possible.
   // TODO: functional-style map