diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5432,6 +5432,136 @@
   return success();
 }
 
+namespace {
+
+/// Returns the operand number of the vector.yield operand through which `op`
+/// is yielded, or std::nullopt if no vector.yield uses a result of `op`.
+///
+/// A result of `op` must be yielded at most once: yielding the same value
+/// multiple times is not supported yet and trips the assertion below.
+Optional<unsigned> getYieldOpUseOperandNum(Operation *op) {
+  auto isYieldOpUse = [](OpOperand &use) -> bool {
+    return isa<vector::YieldOp>(use.getOwner());
+  };
+
+  assert(llvm::count_if(op->getUses(), isYieldOpUse) <= 1 &&
+         "Yielding the same value multiple times is not supported yet");
+
+  for (OpOperand &use : op->getUses()) {
+    if (isYieldOpUse(use))
+      return use.getOperandNumber();
+  }
+
+  return std::nullopt;
+}
+
+/// Given a vector.mask operation with multiple nested operations (other than
+/// the vector.yield), hoists all the operations that do not need masking out of
+/// the vector.mask operation and creates individual vector.mask operations for
+/// each nested operation that needs masking, using the mask of the input
+/// vector.mask operation.
+struct FlattenMultiOpMaskOp : public OpRewritePattern<MaskOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MaskOp maskOp,
+                                PatternRewriter &rewriter) const override {
+    Block &block = maskOp.getMaskRegion().getBlocks().front();
+
+    // With at most one operation besides the terminator there is nothing to
+    // flatten. The IR is left untouched, so report a match failure: returning
+    // success here would claim a rewrite that never happened.
+    if (block.getOperations().size() <= 2)
+      return failure();
+
+    // Reject unsupported input before modifying anything, so that bailing out
+    // never leaves a partially flattened vector.mask behind.
+    auto hasMultipleResults = [](Operation &op) {
+      return op.getNumResults() > 1;
+    };
+    if (llvm::any_of(block, hasMultipleResults))
+      return rewriter.notifyMatchFailure(
+          maskOp, "multi-result nested ops are not supported");
+
+    PatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(maskOp);
+    Value activeMask = maskOp.getMask();
+
+    // Move every operation either into an individual vector.mask or outside
+    // of the original vector.mask operation.
+    for (Operation &op : llvm::make_early_inc_range(block)) {
+      Operation *nestedOp = &op;
+      if (isa<vector::YieldOp>(nestedOp))
+        continue;
+
+      // The use-def chains of operations that are returned by the original
+      // vector.mask need to be rewired properly. The yield operand number
+      // selects the vector.mask result to replace (assumes the i-th yield
+      // operand defines the i-th vector.mask result).
+      //
+      // The uses of `nestedOp` are intentionally left in place here: they
+      // are redirected below, and dropping them would also clear the operands
+      // of sibling operations that consume `nestedOp`.
+      Optional<unsigned> maybeResultIdxToReplace =
+          getYieldOpUseOperandNum(nestedOp);
+
+      if (auto maskableOp = dyn_cast<MaskableOpInterface>(nestedOp)) {
+        // Create a new vector.mask operation for this maskable op using the
+        // original mask. The region builder moves the op to the front of the
+        // builder's insertion block and yields its results.
+        // TODO: Look for a rewriter utility to move the op instead of moving
+        // it directly.
+        auto createRegionMask = [nestedOp](OpBuilder &builder, Location loc) {
+          Block *insBlock = builder.getInsertionBlock();
+          nestedOp->moveBefore(insBlock, insBlock->begin());
+          builder.create<vector::YieldOp>(loc, nestedOp->getResults());
+        };
+
+        auto newMaskOp =
+            maskableOp->getResults().empty()
+                ? rewriter.create<MaskOp>(maskOp.getLoc(), activeMask,
+                                          createRegionMask)
+                : rewriter.create<MaskOp>(
+                      maskOp.getLoc(), maskableOp->getResultTypes().front(),
+                      activeMask, createRegionMask);
+
+        Operation *newMaskOpTerminator =
+            &newMaskOp.getMaskRegion().front().back();
+
+        // Replace the original uses of the maskable op with the result value
+        // of the new vector.mask containing the maskable op. The terminator
+        // of the new vector.mask is excluded since it must keep yielding the
+        // maskable op itself.
+        for (auto [resIdx, resVal] : llvm::enumerate(maskableOp->getResults()))
+          rewriter.replaceAllUsesExcept(resVal, newMaskOp.getResult(resIdx),
+                                        newMaskOpTerminator);
+
+        // If the maskable op was returned by the original vector.mask, replace
+        // the original uses with the result value of the new vector.mask.
+        if (maybeResultIdxToReplace)
+          rewriter.replaceAllUsesWith(
+              maskOp.getResult(*maybeResultIdxToReplace),
+              newMaskOp.getResult(0));
+      } else {
+        // This operation doesn't need a mask. Just move it right before the
+        // original vector.mask.
+        nestedOp->moveBefore(maskOp.getOperation());
+
+        // If the operation was returned by the original vector.mask, replace
+        // the original uses with the result value of the hoisted operation.
+        if (maybeResultIdxToReplace)
+          rewriter.replaceAllUsesWith(
+              maskOp.getResult(*maybeResultIdxToReplace),
+              nestedOp->getResult(0));
+      }
+    }
+
+    rewriter.eraseOp(maskOp);
+    return success();
+  }
+};
+
+} // namespace
+
 // MaskingOpInterface definitions.
 
 /// Returns the operation masked by this 'vector.mask'.