diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2287,10 +2287,13 @@ The `vector.mask` is a `MaskingOpInterface` operation that predicates the execution of another operation. It takes an `i1` vector mask and an optional passthru vector as arguments. - A `vector.yield`-terminated region encloses the operation to be masked. - Values used within the region are captured from above. Only one *maskable* - operation can be masked with a `vector.mask` operation at a time. An - operation is *maskable* if it implements the `MaskableOpInterface`. + + An implicitly `vector.yield`-terminated region encloses the operation to be + masked. Values used within the region are captured from above. Only one + *maskable* operation can be masked with a `vector.mask` operation at a time. + An operation is *maskable* if it implements the `MaskableOpInterface`. The + terminator yields all results of the maskable operation to the results of + this operation. The vector mask argument holds a bit for each vector lane and determines which vector lanes should execute the maskable operation and which ones @@ -2321,12 +2324,16 @@ ``` vector.mask %mask { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, memref<?xf32> } : vector<16xi1> ``` + + ``` + vector.mask %mask { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32> + ``` }]; - // TODO: Support multiple results and passthru values. + // TODO: Support multiple passthru values.
let arguments = (ins VectorOf<[I1]>:$mask, Optional<AnyType>:$passthru); - let results = (outs Optional<AnyType>:$results); + let results = (outs Variadic<AnyType>:$results); let regions = (region SizedRegion<1>:$maskRegion); let skipDefaultBuilders = 1; @@ -2334,10 +2341,10 @@ OpBuilder<(ins "Value":$mask, CArg<"function_ref<void(OpBuilder &, Location)>", "buildTerminatedBody">:$maskRegion)>, - OpBuilder<(ins "Type":$resultType, "Value":$mask, + OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask, CArg<"function_ref<void(OpBuilder &, Location)>", "buildTerminatedBody">:$maskRegion)>, - OpBuilder<(ins "Type":$resultType, "Value":$mask, + OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask, "Value":$passthru, CArg<"function_ref<void(OpBuilder &, Location)>", "buildTerminatedBody">:$maskRegion)> diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5288,20 +5288,22 @@ } +/// Builds a MaskOp producing `resultTypes` with no passthru operand. void MaskOp::build( - OpBuilder &builder, OperationState &result, Type resultType, Value mask, - function_ref<void(OpBuilder &, Location)> maskRegionBuilder) { - build(builder, result, resultType, mask, /*passthru=*/Value(), + OpBuilder &builder, OperationState &result, TypeRange resultTypes, + Value mask, function_ref<void(OpBuilder &, Location)> maskRegionBuilder) { + build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskRegionBuilder); } +/// Builds a MaskOp producing `resultTypes`; `passthru` is added only if set. void MaskOp::build( - OpBuilder &builder, OperationState &result, Type resultType, Value mask, - Value passthru, + OpBuilder &builder, OperationState &result, TypeRange resultTypes, + Value mask, Value passthru, function_ref<void(OpBuilder &, Location)> maskRegionBuilder) { build(builder, result, mask, maskRegionBuilder); if (passthru) result.addOperands(passthru); - result.addTypes(resultType); + result.addTypes(resultTypes); } ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {