diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -16,6 +16,7 @@
 include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"
 include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"
 include "mlir/IR/EnumAttr.td"
+include "mlir/IR/RegionKindInterface.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -2279,19 +2280,18 @@
 }
 
 def Vector_MaskOp : Vector_Op<"mask", [
-  SingleBlockImplicitTerminator<"vector::YieldOp">,
   DeclareOpInterfaceMethods<MaskingOpInterface>,
   RecursiveMemoryEffects, NoRegionArguments
-]> {
+] # GraphRegionNoTerminator.traits> {
   let summary = "Predicates a maskable vector operation";
   let description = [{
     The `vector.mask` is a `MaskingOpInterface` operation that predicates the
     execution of another operation. It takes an `i1` vector mask and an
     optional passthru vector as arguments.
-    A `vector.yield`-terminated region encloses the operation to be masked.
-    Values used within the region are captured from above. Only one *maskable*
-    operation can be masked with a `vector.mask` operation at a time. An
-    operation is *maskable* if it implements the `MaskableOpInterface`.
+    A region encloses exactly one operation to be masked. Values used within
+    the region are captured from above. Only one *maskable* operation can be
+    masked with a `vector.mask` operation at a time. An operation is *maskable*
+    if it implements the `MaskableOpInterface`.
 
     The vector mask argument holds a bit for each vector lane and determines
     which vector lanes should execute the maskable operation and which ones
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5304,8 +5304,6 @@
   result.regions.reserve(1);
   Region &maskRegion = *result.addRegion();
 
-  auto &builder = parser.getBuilder();
-
   // Parse all the operands.
   OpAsmParser::UnresolvedOperand mask;
   if (parser.parseOperand(mask))
@@ -5320,8 +5318,10 @@
   // Parse op region.
   if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
     return failure();
-
-  MaskOp::ensureTerminator(maskRegion, builder, result.location);
+  // Make sure that there is at least one block so that we generate a better
+  // error message if the region is empty.
+  if (maskRegion.empty())
+    maskRegion.emplaceBlock();
 
   // Parse the optional attribute list.
   if (parser.parseOptionalAttrDict(result.attributes))
@@ -5352,11 +5352,11 @@
   p << " " << getMask();
   if (getPassthru())
     p << ", " << getPassthru();
 
-  // Print single masked operation and skip terminator.
+  // Print the single masked operation. The region has no terminator.
   p << " { ";
   Block *singleBlock = &getMaskRegion().getBlocks().front();
-  if (singleBlock && singleBlock->getOperations().size() > 1)
+  if (singleBlock && !singleBlock->empty())
     p.printCustomOrGenericOp(&singleBlock->front());
   p << " }";
 
@@ -5367,33 +5367,12 @@
   p << " -> " << getResultTypes();
 }
 
-void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
-  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
-      MaskOp>::ensureTerminator(region, builder, loc);
-  // Keep the default yield terminator if the number of masked operations is not
-  // the expected. This case will trigger a verification failure.
-  if (region.front().getOperations().size() != 2)
-    return;
-
-  // Replace default yield terminator with a new one that returns the results
-  // from the masked operation.
-  OpBuilder opBuilder(builder.getContext());
-  Operation *maskedOp = &region.front().front();
-  Operation *oldYieldOp = &region.front().back();
-  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
-
-  opBuilder.setInsertionPoint(oldYieldOp);
-  opBuilder.create<vector::YieldOp>(maskedOp->getLoc(), maskedOp->getResults());
-  oldYieldOp->dropAllReferences();
-  oldYieldOp->erase();
-}
-
 LogicalResult MaskOp::verify() {
-  // Structural checks.
+  // Structural checks: the region has no terminator, so it must hold exactly
+  // one operation (the masked one).
   Block &block = getMaskRegion().getBlocks().front();
-  if (block.getOperations().size() < 2)
+  if (block.empty())
     return emitOpError("expects an operation to mask");
-  if (block.getOperations().size() > 2)
+  if (block.getOperations().size() > 1)
     return emitOpError("expects only one operation to mask");
 
   auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());