diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -762,11 +762,11 @@ PatternRewriter &rewriter) const override { if (!sliceOp.source().getDefiningOp()) return failure(); + // ExtractSliceOp may be rank-reducing; its dynamic sizes must be preserved + // as well as its result type. rewriter.replaceOpWithNewOp( sliceOp, sliceOp.sizes(), - llvm::to_vector<4>(llvm::map_range( - sliceOp.static_sizes(), - [](Attribute attr) { return attr.cast().getInt(); })), + sliceOp.result().getType().cast().getShape(), sliceOp.getSourceType().getElementType()); return success(); } diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -890,3 +890,15 @@ return } + +// ----- + +// CHECK-LABEL: func @rank_reducing_init_extract +func @rank_reducing_init_extract(%sz : index, %idx : index) -> tensor<2xf32> { + // CHECK: linalg.init_tensor [2] : tensor<2xf32> + %a = linalg.init_tensor [%sz, 2] : tensor + + // CHECK-NOT: extract + %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor to tensor<2xf32> + return %r: tensor<2xf32> +}