diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1325,7 +1325,14 @@
       /*source=*/emptyOp,
       /*indices=*/SmallVector<Value>(rank, zero),
       /*inBounds=*/SmallVector<bool>(rank, true));
-  write = mlir::vector::maskOperation(rewriter, write, mask);
+  bool needMaskForWrite = llvm::any_of(
+      llvm::zip_equal(inputVectorSizes, padOp.getResultType().getShape()),
+      [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+  if (needMaskForWrite) {
+    Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
+        loc, maskType, reifiedReturnShapes[0]);
+    write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
+  }
   newResults.push_back(write->getResult(0));
   return success();
 }
diff --git a/mlir/test/Dialect/Linalg/vectorization-masked.mlir b/mlir/test/Dialect/Linalg/vectorization-masked.mlir
--- a/mlir/test/Dialect/Linalg/vectorization-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-masked.mlir
@@ -405,14 +405,17 @@
 
 // -----
 
-// CHECK-LABEL: func @test_masked_vectorize_dynamic_pad
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s1 + s0)>
+// CHECK: func @test_masked_vectorize_dynamic_pad
 func.func @test_masked_vectorize_dynamic_pad(
   %0 : tensor<?x?xf32>, %h0 : index, %h1 : index)
     -> tensor<?x?xf32>
 {
   // CHECK-DAG: %[[c42:.*]] = arith.constant 4.243000e+01 : f32
   // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
-  // CHECK: %[[empty:.*]] = tensor.empty({{.+}}) : tensor<?x?xf32>
+  // CHECK-DAG: %[[res_d0:.+]] = affine.apply #[[MAP]]()
+  // CHECK-DAG: %[[res_d1:.+]] = affine.apply #[[MAP]]()
+  // CHECK-DAG: %[[empty:.*]] = tensor.empty(%[[res_d0]], %[[res_d1]]) : tensor<?x?xf32>
   // CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   // CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   // CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
@@ -421,7 +424,8 @@
   // CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[c42]]
   // CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
   // CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
-  // CHECK: %[[masked_write:.*]] = vector.mask %[[mask]] {
+  // CHECK: %[[mask_2:.*]] = vector.create_mask %[[res_d0]], %[[res_d1]] : vector<2x4xi1>
+  // CHECK: %[[masked_write:.*]] = vector.mask %[[mask_2]] {
   // CHECK-SAME: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_2]], %[[c0_2]]]
   // CHECK-SAME: {in_bounds = [true, true]} : vector<2x4xf32>, tensor<?x?xf32>
   // CHECK: return %[[masked_write]] : tensor<?x?xf32>
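
Note on the C++ hunk above: the write is now only masked when some requested vector size differs from the corresponding static dimension of the pad result, and the write mask is built from the reified result shape rather than the source dims. The following is a minimal standalone sketch (plain C++17, no LLVM/MLIR dependencies) of that predicate; the function name mirrors the patch, while the use of std::vector<int64_t> for shapes and the kDynamic sentinel are simplifying assumptions for illustration only, not the library's API.

// needMaskForWrite.cpp -- illustrative sketch, not part of the patch.
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Stand-in for MLIR's dynamic-dimension sentinel (assumption for this sketch).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// True when any requested vector size differs from the static result dim;
// element-wise comparison, in the spirit of the patch's
// llvm::any_of(llvm::zip_equal(inputVectorSizes, resultShape), ...).
static bool needMaskForWrite(const std::vector<int64_t> &inputVectorSizes,
                             const std::vector<int64_t> &resultShape) {
  for (size_t i = 0; i < inputVectorSizes.size(); ++i)
    if (inputVectorSizes[i] != resultShape[i])
      return true;
  return false;
}

int main() {
  // Static pad result equal to the vector sizes: the transfer_write stays unmasked.
  std::cout << needMaskForWrite({2, 4}, {2, 4}) << "\n";           // prints 0
  // Dynamic result dims never compare equal, so a vector.create_mask built
  // from the reified result shape guards the transfer_write (as in the test).
  std::cout << needMaskForWrite({2, 4}, {kDynamic, kDynamic}) << "\n"; // prints 1
}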