diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -29,16 +29,20 @@
 // General utilities
 //===----------------------------------------------------------------------===//
 
+/// Check if `permutation` is a permutation of the range
+/// `[0, permutation.size())`.
+bool isPermutation(ArrayRef<int64_t> permutation);
+
 /// Apply the permutation defined by `permutation` to `inVec`.
-/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
+/// Element `j = permutation[i]` in `inVec` is mapped to location `i`.
 /// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
 /// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
 template <typename T, unsigned N>
 void applyPermutationToVector(SmallVector<T, N> &inVec,
-                              ArrayRef<unsigned> permutation) {
+                              ArrayRef<int64_t> permutation) {
   SmallVector<T, N> auxVec(inVec.size());
-  for (unsigned i = 0; i < permutation.size(); ++i)
-    auxVec[i] = inVec[permutation[i]];
+  for (auto en : enumerate(permutation))
+    auxVec[en.index()] = inVec[en.value()];
   inVec = auxVec;
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -367,6 +367,8 @@
     ArrayRef<int64_t> tileInterchange) {
   assert(tileSizes.size() == tileInterchange.size() &&
          "expect the number of tile sizes and interchange dims to match");
+  assert(isPermutation(tileInterchange) &&
+         "expect tile interchange is a permutation");
 
   // Create an empty tile loop nest.
   TileLoopNest tileLoopNest(consumerOp);
@@ -375,9 +377,7 @@
   // inner reduction dimensions.
   SmallVector<StringAttr> iterTypes =
       llvm::to_vector<6>(consumerOp.iterator_types().getAsRange<StringAttr>());
-  applyPermutationToVector(
-      iterTypes,
-      SmallVector<unsigned>(tileInterchange.begin(), tileInterchange.end()));
+  applyPermutationToVector(iterTypes, tileInterchange);
   auto *it = find_if(iterTypes, [&](StringAttr iterType) {
     return !isParallelIterator(iterType);
   });
@@ -459,14 +459,10 @@
                                        tileInterchange.begin() +
                                            rootOp.getNumLoops());
 
-  // As a tiling can only tile a loop dimension once, `rootInterchange` has to
-  // be a permutation of the `rootOp` loop dimensions.
-  SmallVector<AffineExpr> rootInterchangeExprs;
-  transform(rootInterchange, std::back_inserter(rootInterchangeExprs),
-            [&](int64_t dim) { return b.getAffineDimExpr(dim); });
-  AffineMap rootInterchangeMap = AffineMap::get(
-      rootOp.getNumLoops(), 0, rootInterchangeExprs, funcOp.getContext());
-  if (!rootInterchangeMap.isPermutation())
+  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
+  // It has to be a permutation since the tiling cannot tile the same loop
+  // dimension multiple times.
+  if (!isPermutation(rootInterchange))
     return notifyFailure(
         "expect the tile interchange permutes the root loops");
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
@@ -69,7 +69,9 @@
   ArrayRef<Attribute> itTypes = genericOp.iterator_types().getValue();
   SmallVector<Attribute> itTypesVector;
   llvm::append_range(itTypesVector, itTypes);
-  applyPermutationToVector(itTypesVector, interchangeVector);
+  SmallVector<int64_t> permutation(interchangeVector.begin(),
+                                   interchangeVector.end());
+  applyPermutationToVector(itTypesVector, permutation);
   genericOp->setAttr(getIteratorTypesAttrName(),
                      ArrayAttr::get(context, itTypesVector));
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -206,8 +206,10 @@
     invPermutationMap = inversePermutation(
         AffineMap::getPermutationMap(interchangeVector, b.getContext()));
     assert(invPermutationMap);
-    applyPermutationToVector(loopRanges, interchangeVector);
-    applyPermutationToVector(iteratorTypes, interchangeVector);
+    SmallVector<int64_t> permutation(interchangeVector.begin(),
+                                     interchangeVector.end());
+    applyPermutationToVector(loopRanges, permutation);
+    applyPermutationToVector(iteratorTypes, permutation);
   }
 
   // 2. Create the tiled loops.
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -138,6 +138,19 @@
 namespace mlir {
 namespace linalg {
 
+bool isPermutation(ArrayRef<int64_t> permutation) {
+  // Count the number of appearances for all indices.
+  SmallVector<int64_t> indexCounts(permutation.size(), 0);
+  for (auto index : permutation) {
+    // Exit if the index is out-of-range.
+    if (index < 0 || index >= static_cast<int64_t>(permutation.size()))
+      return false;
+    indexCounts[index]++;
+  }
+  // Return true if all indices appear once.
+  return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
+}
+
 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
 /// the type of `source`.
 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {