diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -578,6 +578,13 @@
   // 2. Transpose the tile to match the inner tile order.
   SmallVector<int64_t> perm =
       getPackUnpackNormalizedInnerPerm(srcRank, packOp.getInnerDimsPos());
+  // The permutation is inverted when normalizing so invert back to match the
+  // ordering in the pack op.
+  perm = invertPermutationVector(perm);
+
+  LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
+             llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
+
   SmallVector<int64_t> transpShape = readShape;
   applyPermutationToVector(transpShape, perm);
 
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -58,3 +58,21 @@
 // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
 // CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
 // CHECK: return %[[INSERT]]
+
+// -----
+
+func.func @simple_CHW_to_CHWhwc(%arg0: tensor<3x5x7xf32>, %arg1: tensor<1x1x1x5x7x3xf32>) -> tensor<1x1x1x5x7x3xf32> {
+  %0 = tensor.pack %arg0 inner_dims_pos = [1, 2, 0] inner_tiles = [5, 7, 3] into %arg1 : tensor<3x5x7xf32> -> tensor<1x1x1x5x7x3xf32>
+  return %0 : tensor<1x1x1x5x7x3xf32>
+}
+// CHECK-LABEL: func.func @simple_CHW_to_CHWhwc
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<5x7x3xf32>
+// CHECK: %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[SRC]] : tensor<3x5x7xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<5x7x3xf32>)
+// CHECK-SAME: permutation = [1, 2, 0]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 5, 7, 3] [1, 1, 1, 1, 1, 1]
+// CHECK: return %[[INSERT]]
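
For reference, below is a minimal standalone sketch of the permutation arithmetic this patch fixes, worked through on the new @simple_CHW_to_CHWhwc test case. It is plain C++ with no LLVM dependencies; normalizedInnerPerm and invertPermutation are hypothetical stand-ins for getPackUnpackNormalizedInnerPerm and invertPermutationVector, and the sketch assumes the normalization helper returns the inverse mapping from tiled source dims to inner-tile slots, which is why lowerPack now has to invert it before using it as the linalg.transpose permutation.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Assumed behavior of the normalization helper: for each entry of
// inner_dims_pos, record which inner-tile slot that source dim maps to,
// then drop untiled dims. The result is the *inverse* of the transpose
// permutation the pack op actually needs.
static std::vector<int64_t>
normalizedInnerPerm(int64_t srcRank, const std::vector<int64_t> &innerDimsPos) {
  std::vector<int64_t> slotOfDim(srcRank, -1);
  for (int64_t i = 0, e = innerDimsPos.size(); i < e; ++i)
    slotOfDim[innerDimsPos[i]] = i;
  std::vector<int64_t> perm;
  for (int64_t v : slotOfDim)
    if (v != -1)
      perm.push_back(v);
  return perm;
}

// Stand-in for invertPermutationVector.
static std::vector<int64_t> invertPermutation(const std::vector<int64_t> &perm) {
  std::vector<int64_t> inv(perm.size());
  for (size_t i = 0; i < perm.size(); ++i)
    inv[perm[i]] = static_cast<int64_t>(i);
  return inv;
}

int main() {
  // Mirrors @simple_CHW_to_CHWhwc: inner_dims_pos = [1, 2, 0].
  std::vector<int64_t> innerDimsPos = {1, 2, 0};
  std::vector<int64_t> normalized = normalizedInnerPerm(3, innerDimsPos); // {2, 0, 1}
  std::vector<int64_t> perm = invertPermutation(normalized);              // {1, 2, 0}
  assert((perm == std::vector<int64_t>{1, 2, 0}));
  // Transposing tensor<3x5x7> with [1, 2, 0] yields tensor<5x7x3>, matching
  // the inner tile of the tensor<1x1x1x5x7x3> destination.
  for (int64_t v : perm)
    std::cout << v << " ";
  std::cout << "\n";
  return 0;
}

Under that assumption, the normalized perm for inner_dims_pos = [1, 2, 0] is [2, 0, 1]; using it directly would transpose the 3x5x7 source to 7x3x5 rather than the 5x7x3 inner tile the pack op describes, which is the mismatch the new CHECK lines (permutation = [1, 2, 0]) guard against.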