# Patch: forward the mask of vector.transfer_write through bufferization.
#
# mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp:
#   The bufferized transfer_write that writes into *resultBuffer used to be
#   built from vector, buffer, indices, permutation map and in_bounds only.
#   It was not given writeOp's mask. This change also passes
#   writeOp.getMask() to the builder, ahead of writeOp.getInBoundsAttr().
#
# mlir/test/Dialect/Vector/bufferize.mlir:
#   @transfer_write gains a %mask: vector<5x6xi1> argument.
#   The input IR now writes with that mask.
#   The CHECK-SAME / CHECK lines require the bufferized vector.transfer_write
#   to carry %[[mask]].
#
# NOTE(review): this patch text appears to have been flattened onto a single
#   line, and text inside angle brackets seems to have been stripped. Examples:
#   `rewriter.create(` has no op type, and `tensor` / `memref` have no shape or
#   element type. In this form it likely will not apply with `git apply` or
#   `patch`; regenerate it from the original commit. TODO confirm.
# (Lines before the first "diff --git" header are treated as patch preamble.)
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -106,7 +106,7 @@ rewriter.create( writeOp.getLoc(), writeOp.getVector(), *resultBuffer, writeOp.getIndices(), writeOp.getPermutationMapAttr(), - writeOp.getInBoundsAttr()); + writeOp.getMask(), writeOp.getInBoundsAttr()); replaceOpWithBufferizedValues(rewriter, op, *resultBuffer); return success(); diff --git a/mlir/test/Dialect/Vector/bufferize.mlir b/mlir/test/Dialect/Vector/bufferize.mlir --- a/mlir/test/Dialect/Vector/bufferize.mlir +++ b/mlir/test/Dialect/Vector/bufferize.mlir @@ -15,16 +15,17 @@ // ----- // CHECK-LABEL: func @transfer_write( -// CHECK-SAME: %[[t:.*]]: tensor, %[[o1:.*]]: index, %[[o2:.*]]: index, %[[vec:.*]]: vector<5x6xf32>) +// CHECK-SAME: %[[t:.*]]: tensor, %[[o1:.*]]: index, %[[o2:.*]]: index, %[[vec:.*]]: vector<5x6xf32>, %[[mask:.*]]: vector<5x6xi1>) // CHECK: %[[m:.*]] = bufferization.to_memref %[[t]] : memref // CHECK: %[[alloc:.*]] = memref.alloc(%{{.*}}, %{{.*}}) {{.*}} : memref // CHECK: memref.copy %[[m]], %[[alloc]] -// CHECK: vector.transfer_write %[[vec]], %[[alloc]][%[[o1]], %[[o2]]] {in_bounds = [true, false]} : vector<5x6xf32>, memref +// CHECK: vector.transfer_write %[[vec]], %[[alloc]][%[[o1]], %[[o2]]], %[[mask]] {in_bounds = [true, false]} : vector<5x6xf32>, memref // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]] : memref // CHECK: return %[[r]] func.func @transfer_write(%t: tensor, %o1: index, - %o2: index, %vec: vector<5x6xf32>) -> tensor { - %0 = vector.transfer_write %vec, %t[%o1, %o2] {in_bounds = [true, false]} + %o2: index, %vec: vector<5x6xf32>, + %mask: vector<5x6xi1>) -> tensor { + %0 = vector.transfer_write %vec, %t[%o1, %o2], %mask {in_bounds = [true, false]} : vector<5x6xf32>, tensor return %0 : tensor }