diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -252,9 +252,9 @@ auto map = TransferReadOp::getTransferMinorIdentityMap( xferOp.getMemRefType(), minorVectorType); ArrayAttr masked; - if (xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) { + if (!xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) { OpBuilder &b = ScopedContext::getBuilderRef(); - masked = b.getBoolArrayAttr({true}); + masked = b.getBoolArrayAttr({false}); } return vector_transfer_read(minorVectorType, memref, indexing, AffineMapAttr::get(map), xferOp.padding(), @@ -356,9 +356,9 @@ auto map = TransferWriteOp::getTransferMinorIdentityMap( xferOp.getMemRefType(), minorVectorType); ArrayAttr masked; - if (xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) { + if (!xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) { OpBuilder &b = ScopedContext::getBuilderRef(); - masked = b.getBoolArrayAttr({true}); + masked = b.getBoolArrayAttr({false}); } vector_transfer_write(result, xferOp.memref(), indexing, AffineMapAttr::get(map), masked); diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir --- a/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir +++ b/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir @@ -353,15 +353,15 @@ // FULL-UNROLL-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 + 1)> // FULL-UNROLL-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 2)> -// CHECK-LABEL: transfer_write_progressive_not_masked( +// CHECK-LABEL: transfer_write_progressive_unmasked( // CHECK-SAME: %[[A:[a-zA-Z0-9]+]]: memref, // CHECK-SAME: %[[base:[a-zA-Z0-9]+]]: index, // CHECK-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<3x15xf32> -// FULL-UNROLL-LABEL: transfer_write_progressive_not_masked( +// FULL-UNROLL-LABEL: transfer_write_progressive_unmasked( // FULL-UNROLL-SAME: %[[A:[a-zA-Z0-9]+]]: memref, // FULL-UNROLL-SAME: %[[base:[a-zA-Z0-9]+]]: index, // FULL-UNROLL-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<3x15xf32> -func @transfer_write_progressive_not_masked(%A : memref, %base: index, %vec: vector<3x15xf32>) { +func @transfer_write_progressive_unmasked(%A : memref, %base: index, %vec: vector<3x15xf32>) { // CHECK-NOT: scf.if // CHECK-NEXT: %[[alloc:.*]] = alloca() {alignment = 128 : i64} : memref<3xvector<15xf32>> // CHECK-NEXT: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<3xvector<15xf32>> to memref> @@ -369,16 +369,16 @@ // CHECK-NEXT: affine.for %[[I:.*]] = 0 to 3 { // CHECK-NEXT: %[[add:.*]] = affine.apply #[[$MAP0]](%[[I]])[%[[base]]] // CHECK-NEXT: %[[vec_1d:.*]] = load %0[%[[I]]] : memref<3xvector<15xf32>> - // CHECK-NEXT: vector.transfer_write %[[vec_1d]], %[[A]][%[[add]], %[[base]]] : vector<15xf32>, memref + // CHECK-NEXT: vector.transfer_write %[[vec_1d]], %[[A]][%[[add]], %[[base]]] {masked = [false]} : vector<15xf32>, memref // FULL-UNROLL: %[[VEC0:.*]] = vector.extract %[[vec]][0] : vector<3x15xf32> - // FULL-UNROLL: vector.transfer_write %[[VEC0]], %[[A]][%[[base]], %[[base]]] : vector<15xf32>, memref + // FULL-UNROLL: vector.transfer_write %[[VEC0]], %[[A]][%[[base]], %[[base]]] {masked = [false]} : vector<15xf32>, memref // FULL-UNROLL: %[[I1:.*]] = affine.apply #[[$MAP1]]()[%[[base]]] // FULL-UNROLL: %[[VEC1:.*]] = vector.extract %[[vec]][1] : vector<3x15xf32> - // FULL-UNROLL: vector.transfer_write %2, %[[A]][%[[I1]], %[[base]]] : vector<15xf32>, memref + // FULL-UNROLL: vector.transfer_write %2, %[[A]][%[[I1]], %[[base]]] {masked = [false]} : vector<15xf32>, memref // FULL-UNROLL: %[[I2:.*]] = affine.apply #[[$MAP2]]()[%[[base]]] // FULL-UNROLL: %[[VEC2:.*]] = vector.extract %[[vec]][2] : vector<3x15xf32> - // FULL-UNROLL: vector.transfer_write %[[VEC2:.*]], %[[A]][%[[I2]], %[[base]]] : vector<15xf32>, memref + // FULL-UNROLL: vector.transfer_write %[[VEC2:.*]], %[[A]][%[[I2]], %[[base]]] {masked = [false]} : vector<15xf32>, memref vector.transfer_write %vec, %A[%base, %base] {masked = [false, false]} : vector<3x15xf32>, memref return