diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -29,16 +29,20 @@ // General utilities //===----------------------------------------------------------------------===// +/// Check if `permutation` is a permutation of the range +/// `[0, permutation.size())`. +bool isPermutation(ArrayRef permutation, MLIRContext *context); + /// Apply the permutation defined by `permutation` to `inVec`. /// Element `i` of the result is taken from `inVec[permutation[i]]`. /// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector /// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`. template void applyPermutationToVector(SmallVector &inVec, - ArrayRef permutation) { + ArrayRef permutation) { SmallVector auxVec(inVec.size()); - for (unsigned i = 0; i < permutation.size(); ++i) - auxVec[i] = inVec[permutation[i]]; + for (auto en : enumerate(permutation)) + auxVec[en.index()] = inVec[en.value()]; inVec = auxVec; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -363,6 +363,8 @@ ArrayRef tileInterchange) { assert(tileSizes.size() == tileInterchange.size() && "expect the number of tile sizes and interchange dims to match"); + assert(isPermutation(tileInterchange, b.getContext()) && + "expect tile interchange is a permutation"); // Create an empty tile loop nest. TileLoopNest tileLoopNest(consumerOp); @@ -371,9 +373,7 @@ // inner reduction dimensions. 
SmallVector iterTypes = llvm::to_vector<6>(consumerOp.iterator_types().getAsRange()); - applyPermutationToVector( - iterTypes, - SmallVector(tileInterchange.begin(), tileInterchange.end())); + applyPermutationToVector(iterTypes, tileInterchange); auto *it = find_if(iterTypes, [&](StringAttr iterType) { return !isParallelIterator(iterType); }); @@ -455,14 +455,10 @@ tileInterchange.begin() + rootOp.getNumLoops()); - // As a tiling can only tile a loop dimension once, `rootInterchange` has to - // be a permutation of the `rootOp` loop dimensions. - SmallVector rootInterchangeExprs; - transform(rootInterchange, std::back_inserter(rootInterchangeExprs), - [&](int64_t dim) { return b.getAffineDimExpr(dim); }); - AffineMap rootInterchangeMap = AffineMap::get( - rootOp.getNumLoops(), 0, rootInterchangeExprs, funcOp.getContext()); - if (!rootInterchangeMap.isPermutation()) + // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions. + // It has to be a permutation since the tiling cannot tile the same loop + // dimension multiple times. 
+ if (!isPermutation(rootInterchange, funcOp.getContext())) return notifyFailure( "expect the tile interchange permutes the root loops"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -69,7 +69,9 @@ ArrayRef itTypes = genericOp.iterator_types().getValue(); SmallVector itTypesVector; llvm::append_range(itTypesVector, itTypes); - applyPermutationToVector(itTypesVector, interchangeVector); + SmallVector permutation(interchangeVector.begin(), + interchangeVector.end()); + applyPermutationToVector(itTypesVector, permutation); genericOp->setAttr(getIteratorTypesAttrName(), ArrayAttr::get(context, itTypesVector)); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -205,9 +205,10 @@ // `inversePermutation` must succeed. invPermutationMap = inversePermutation( AffineMap::getPermutationMap(interchangeVector, b.getContext())); - assert(invPermutationMap); - applyPermutationToVector(loopRanges, interchangeVector); - applyPermutationToVector(iteratorTypes, interchangeVector); + SmallVector permutation(interchangeVector.begin(), + interchangeVector.end()); + applyPermutationToVector(loopRanges, permutation); + applyPermutationToVector(iteratorTypes, permutation); } // 2. Create the tiled loops. 
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -138,6 +138,28 @@
 namespace mlir {
 namespace linalg {
 
+/// Returns true iff `permutation` contains every index of the range
+/// `[0, permutation.size())` exactly once. Negative, out-of-range, and
+/// repeated entries all yield false.
+///
+/// The entries are checked directly rather than by building an AffineMap
+/// from them, so an invalid entry is reported through the return value and
+/// never reaches AffineMap construction. `context` is therefore unused; the
+/// parameter is kept so existing callers compile unchanged.
+bool isPermutation(ArrayRef<int64_t> permutation, MLIRContext *context) {
+  (void)context;
+  const int64_t numDims = static_cast<int64_t>(permutation.size());
+  // seen[i] is set once index `i` has been encountered.
+  SmallVector<bool, 8> seen(permutation.size(), false);
+  for (int64_t dim : permutation) {
+    if (dim < 0 || dim >= numDims || seen[dim])
+      return false;
+    seen[dim] = true;
+  }
+  // `numDims` distinct in-range entries necessarily cover the whole range.
+  return true;
+}
+
 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
 /// the type of `source`.
 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {