diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2374,7 +2374,7 @@
     ```
   }];
 
-  // TODO: Support multiple results and passthru values.
+  // TODO: Support multiple passthru values.
   let arguments = (ins VectorOf<[I1]>:$mask,
                    Optional<AnyType>:$passthru);
   let results = (outs Variadic<AnyType>:$results);
@@ -2393,10 +2393,21 @@
 
   let extraClassDeclaration = [{
     Block *getMaskBlock() { return &getMaskRegion().front(); }
-    static void ensureTerminator(Region &region, Builder &builder, Location loc);
+
+    /// Returns true if mask op is not masking any operation.
+    bool isEmpty() {
+      Block *block = getMaskBlock();
+      if (block->getOperations().size() > 1)
+        return false;
+      return true;
+    }
+
+    static void ensureTerminator(Region &region, Builder &builder,
+                                 Location loc);
   }];
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5641,6 +5641,25 @@
   return success();
 }
 
+/// Folds vector.mask ops with an all-true mask by hoisting the maskable op out
+/// of the region and forwarding all of its results, which may be zero or many.
+LogicalResult MaskOp::fold(FoldAdaptor adaptor,
+                           SmallVectorImpl<OpFoldResult> &results) {
+  if (isEmpty() || getMaskFormat(getMask()) != MaskFormat::AllTrue)
+    return failure();
+
+  // Move maskable operation outside of the `vector.mask` region. Its results
+  // are consumed by the region's vector.yield, so drop those uses first.
+  Operation *maskableOp = getMaskableOp();
+  maskableOp->dropAllUses();
+  maskableOp->moveBefore(getOperation());
+
+  // Forward every result: the maskable op may have zero results or several,
+  // so indexing `getResult(0)` is not safe here.
+  llvm::append_range(results, maskableOp->getResults());
+  return success();
+}
+
 // Elides empty vector.mask operations with or without return values. Propagates
 // the yielded values by the vector.yield terminator, if any, or erases the op,
 // otherwise.
@@ -5653,10 +5672,10 @@ if (maskingOp.getMaskableOp()) return failure(); - Block *block = maskOp.getMaskBlock(); - if (block->getOperations().size() > 1) + if (!maskOp.isEmpty()) return failure(); + Block *block = maskOp.getMaskBlock(); auto terminator = cast<vector::YieldOp>(block->front()); if (terminator.getNumOperands() == 0) rewriter.eraseOp(maskOp); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -2257,4 +2257,16 @@ return %0 : vector<8xf32> } +// ----- + +// CHECK-LABEL: func @all_true_vector_mask +// CHECK-SAME: %[[IN:.*]]: vector<3x4xf32> +func.func @all_true_vector_mask(%a : vector<3x4xf32>) -> vector<3x4xf32> { +// CHECK-NOT: vector.mask +// CHECK: %[[ADD:.*]] = arith.addf %[[IN]], %[[IN]] : vector<3x4xf32> +// CHECK: return %[[ADD]] : vector<3x4xf32> + %all_true = vector.constant_mask [3, 4] : vector<3x4xi1> + %0 = vector.mask %all_true { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32> + return %0 : vector<3x4xf32> +}