diff --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
@@ -76,6 +76,11 @@
     if (!sliceOp.hasUnitStride())
       return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
 
+    if (sliceOp.getType().cast<ShapedType>().getRank() !=
+        linalgOp->getResult(0).getType().cast<ShapedType>().getRank()) {
+      return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction");
+    }
+
     OpOperand *outOperand = linalgOp.getOutputOperand(0);
     AffineMap indexingMap = linalgOp.getTiedIndexingMap(outOperand);
     if (!indexingMap.isProjectedPermutation()) {
diff --git a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
--- a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
+++ b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
@@ -156,3 +156,22 @@
 // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST:.+]] : f32) outs(%[[SLICE2]] : tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32>
 // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[SLICE0]], %[[SLICE1]] : tensor<1x65x65x3xf32>, tensor<3x3x3x16xf32>) outs(%[[FILL]] : tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32>
 // CHECK: return %[[CONV]] : tensor<1x32x32x16xf32>
+
+// -----
+
+// The slice must not be bubbled up when it is rank-reducing.
+func @rank_reducing_slice(%width : index) -> tensor<1x1x1x?xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %init = linalg.init_tensor [1, %width] : tensor<1x?xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x?xf32>) -> tensor<1x?xf32>
+  %slice = tensor.extract_slice %fill[0, 0] [1, %width] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
+  %expand = tensor.expand_shape %slice [[0, 1, 2, 3]] : tensor<?xf32> into tensor<1x1x1x?xf32>
+  return %expand : tensor<1x1x1x?xf32>
+}
+
+// CHECK-LABEL: func @rank_reducing_slice
+// CHECK: %[[INIT:.+]] = linalg.init_tensor
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}} : f32) outs(%[[INIT]] : tensor<1x?xf32>)
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[FILL]]
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]]
+// CHECK: return %[[EXPAND]]