diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -59,6 +59,13 @@
     if (!unpackOp)
       return failure();
 
+    // Only slices that keep the rank of the unpack destination are folded;
+    // rank-reducing slices are rejected.
+    if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "rank-reduced folding is not supported");
+    }
+
     // Check all offsets are zeros, and all strides are ones.
     if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
         !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -45,6 +45,19 @@
 
 // -----
 
+func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : index, %arg3 : index) -> tensor<f32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
+      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0[0, 0] [1, 1] [1, 1] : tensor<?x?xf32> to tensor<f32>
+  return %1 : tensor<f32>
+}
+// CHECK-LABEL: func @nofold_unpack_slice_rank_reduced(
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack
+// CHECK:         tensor.extract_slice %[[UNPACK]]
+
+// -----
+
 func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32