diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1015,6 +1015,7 @@
 ///   (Implies that sizes of `insertOp` are all static.)
 /// - Only unit strides in `insertOp`.
 /// - Single, scalar padding value.
+/// - `padOp` result not used as destination.
 struct PadTensorOpVectorizationWithInsertSlicePattern
     : public VectorizePadTensorOpUserPattern<tensor::InsertSliceOp> {
   using VectorizePadTensorOpUserPattern<
@@ -1035,6 +1036,9 @@
     // Dynamic shapes not supported.
     if (!padOp.result().getType().cast<ShapedType>().hasStaticShape())
       return failure();
+    // Pad result not used as destination.
+    if (insertOp.dest() == padOp.result())
+      return failure();
 
     auto vecType = VectorType::get(padOp.getType().getShape(),
                                    padOp.getType().getElementType());
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -684,7 +684,7 @@
 
 func private @make_vector() -> tensor<12x13xf32>
 
-// CHECK-LABEL: func @pad_and_insert_slice
+// CHECK-LABEL: func @pad_and_insert_slice_source
 // CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
 // CHECK-NOT: linalg.pad_tensor
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
@@ -693,7 +693,7 @@
 // CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
 // CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[VEC0]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<7x9xf32>, tensor<12x13xf32>
 // CHECK: return %[[WRITE]]
-func @pad_and_insert_slice(
+func @pad_and_insert_slice_source(
     %arg0: tensor<5x6xf32>) -> tensor<12x13xf32> {
   %c0 = arith.constant 0 : index
   %c5 = arith.constant 5.0 : f32
@@ -708,6 +708,26 @@
 
 // -----
 
+func private @make_vector() -> tensor<12x13xf32>
+
+// CHECK-LABEL: func @pad_and_insert_slice_dest
+// Check the insert slice is not rewritten if the padded result is used by the destination operand.
+// CHECK: %[[T1:.*]] = call @make_vector() : () -> tensor<12x13xf32>
+// CHECK: = tensor.insert_slice %[[T1]] into
+func @pad_and_insert_slice_dest(
+    %arg0: tensor<1x5x6xf32>) -> tensor<1x12x13xf32> {
+  %c5 = arith.constant 5.0 : f32
+  %0 = linalg.pad_tensor %arg0 low[0, 0, 0] high[0, 7, 7] {
+    ^bb0(%arg2: index, %arg3: index, %arg4: index):
+      linalg.yield %c5 : f32
+  } : tensor<1x5x6xf32> to tensor<1x12x13xf32>
+  %1 = call @make_vector() : () -> tensor<12x13xf32>
+  %r = tensor.insert_slice %1 into %0[0, 0, 0][1, 12, 13][1, 1, 1] : tensor<12x13xf32> into tensor<1x12x13xf32>
+  return %r : tensor<1x12x13xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @pad_tensor_non_const_pad_value
 // CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
 // CHECK-NOT: linalg.pad_tensor
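
For context, a minimal sketch of the two `tensor.insert_slice` cases the pattern can
encounter, assuming illustrative function and value names (not taken from the patch or
the test suite):

  // Pad result used as the insert_slice *source*: the pattern may soundly rewrite
  // this pair into a padding vector.transfer_read of %src followed by a
  // vector.transfer_write into %dest, as the CHECK lines above verify.
  func @pad_used_as_source(%src: tensor<5x6xf32>, %dest: tensor<12x13xf32>)
      -> tensor<12x13xf32> {
    %cst = arith.constant 0.0 : f32
    %0 = linalg.pad_tensor %src low[0, 0] high[2, 3] {
      ^bb0(%i: index, %j: index):
        linalg.yield %cst : f32
    } : tensor<5x6xf32> to tensor<7x9xf32>
    %r = tensor.insert_slice %0 into %dest[0, 0] [7, 9] [1, 1]
        : tensor<7x9xf32> into tensor<12x13xf32>
    return %r : tensor<12x13xf32>
  }

  // Pad result used as the insert_slice *destination*: the insert overwrites only a
  // 5x6 sub-rectangle of %0, so the padded values of %0 outside that rectangle must
  // still be materialized. A single transfer_read/transfer_write pair rooted at %src
  // cannot express this, hence the new bail-out when insertOp.dest() == padOp.result().
  func @pad_used_as_dest(%src: tensor<5x6xf32>, %other: tensor<5x6xf32>)
      -> tensor<7x9xf32> {
    %cst = arith.constant 0.0 : f32
    %0 = linalg.pad_tensor %src low[0, 0] high[2, 3] {
      ^bb0(%i: index, %j: index):
        linalg.yield %cst : f32
    } : tensor<5x6xf32> to tensor<7x9xf32>
    %r = tensor.insert_slice %other into %0[0, 0] [5, 6] [1, 1]
        : tensor<5x6xf32> into tensor<7x9xf32>
    return %r : tensor<7x9xf32>
  }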