diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2393,6 +2393,15 @@ let extraClassDeclaration = [{ Block *getMaskBlock() { return &getMaskRegion().front(); } + + /// Returns true if mask op is not masking any operation. + bool isEmpty() { + Block *block = getMaskBlock(); + if (block->getOperations().size() > 1) + return false; + return true; + } + static void ensureTerminator(Region &region, Builder &builder, Location loc); }]; diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5653,10 +5653,10 @@ if (maskingOp.getMaskableOp()) return failure(); - Block *block = maskOp.getMaskBlock(); - if (block->getOperations().size() > 1) + if (!maskOp.isEmpty()) return failure(); + Block *block = maskOp.getMaskBlock(); auto terminator = cast<vector::YieldOp>(block->front()); if (terminator.getNumOperands() == 0) rewriter.eraseOp(maskOp); @@ -5667,9 +5667,29 @@ } }; +/// Folds vector.mask ops with an all-true mask into the masked operation. +class FoldAllTrueMaskOp : public OpRewritePattern<MaskOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MaskOp maskOp, + PatternRewriter &rewriter) const override { + MaskFormat maskFormat = getMaskFormat(maskOp.getMask()); + if (maskFormat != MaskFormat::AllTrue) + return failure(); + + if (maskOp.isEmpty()) + return failure(); + + Operation *maskableOpClone = rewriter.clone(*maskOp.getMaskableOp()); + // Forward every result: the masked op may have zero or several. + rewriter.replaceOp(maskOp, maskableOpClone->getResults()); + return success(); + } +}; + void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<ElideEmptyMaskOp>(context); + results.add<ElideEmptyMaskOp, FoldAllTrueMaskOp>(context); } // MaskingOpInterface definitions.
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -2257,4 +2257,16 @@ return %0 : vector<8xf32> } +// ----- + +// CHECK-LABEL: func @all_true_vector_mask +// CHECK-SAME: %[[IN:.*]]: vector<3x4xf32> +func.func @all_true_vector_mask(%a : vector<3x4xf32>) -> vector<3x4xf32> { +// CHECK-NOT: vector.mask +// CHECK: %[[ADD:.*]] = arith.addf %[[IN]], %[[IN]] : vector<3x4xf32> +// CHECK: return %[[ADD]] : vector<3x4xf32> + %all_true = vector.constant_mask [3, 4] : vector<3x4xi1> // Mask bounds [3, 4] span the whole 3x4 shape, so every lane is set. + %0 = vector.mask %all_true { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32> // Expected to fold away, leaving only the arith.addf. + return %0 : vector<3x4xf32> +}