[mlir][tensor] Use applyPermutationToVector in PackOp::inferPackedType

Drop the file-local `interchange` helper from TensorOps.cpp and permute the
outer dimensions of the packed shape with `applyPermutationToVector` from
mlir/Dialect/Utils/IndexingUtils.h instead. The call is skipped when
`outerDimsPerm` is empty, matching the old helper, which left the shape
unchanged for an empty permutation.

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BlockAndValueMapping.h"
@@ -3200,19 +3201,6 @@
   return success();
 }
 
-/// Returns a vector that interchanges `elements` starting at offset `offset`
-/// based on the indexes in `interchangeVector`.
-template <typename T>
-SmallVector<T> interchange(ArrayRef<T> elements,
-                           ArrayRef<int64_t> interchangeVector,
-                           int offset = 0) {
-  SmallVector<T> vec = llvm::to_vector(elements);
-  for (auto en : llvm::enumerate(interchangeVector))
-    vec[en.index() + offset] = elements[en.value() + offset];
-
-  return vec;
-}
-
 /// Get the expected packed type based on source type, tile factors, position of
 /// the inner tiles and permutation of the outer tiled loop.
 ShapedType PackOp::inferPackedType(ShapedType sourceType,
@@ -3231,7 +3219,8 @@
                                             innerTileSizes[tiledDim.index()]);
   }
 
-  resultShape = interchange(resultShape, outerDimsPerm);
+  if (!outerDimsPerm.empty())
+    applyPermutationToVector(resultShape, outerDimsPerm);
 
   // Append the inner tile dimensions.
   resultShape.append(innerTileSizes.begin(), innerTileSizes.end());