diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -509,6 +509,8 @@ let options = [ Option<"fullUnroll", "full-unroll", "bool", /*default=*/"false", "Perform full unrolling when converting vector transfers to SCF">, + Option<"targetRank", "target-rank", "unsigned", /*default=*/"1", + "Target vector rank to which transfer ops should be lowered">, ]; } diff --git a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h --- a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h +++ b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h @@ -49,10 +49,17 @@ struct VectorTransferToSCFOptions { bool unroll = false; + unsigned targetRank = 1; + VectorTransferToSCFOptions &setUnroll(bool u) { unroll = u; return *this; } + + VectorTransferToSCFOptions &setTargetRank(unsigned r) { + targetRank = r; + return *this; + } }; /// Collect a set of patterns to convert from the Vector dialect to SCF + std. diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -38,8 +38,16 @@ /// Attribute name used for labeling transfer ops during progressive lowering. static const char kPassLabel[] = "__vector_to_scf_lowering__"; -/// Lower to 1D transfer ops. Target-specific lowering will lower those. -static const int64_t kTargetRank = 1; +/// Patterns that inherit from this struct have access to +/// VectorTransferToSCFOptions. +template +struct VectorToSCFPattern : public OpRewritePattern { + explicit VectorToSCFPattern(MLIRContext *context, + VectorTransferToSCFOptions opt) + : OpRewritePattern(context), options(opt) {} + + VectorTransferToSCFOptions options; +}; /// Given a MemRefType with VectorType element type, unpack one dimension from /// the VectorType into the MemRefType. @@ -270,8 +278,9 @@ /// Add the pass label to a vector transfer op if its rank is not the target /// rank. template -static void maybeApplyPassLabel(OpBuilder &builder, OpTy newXferOp) { - if (newXferOp.getVectorType().getRank() > kTargetRank) +static void maybeApplyPassLabel(OpBuilder &builder, OpTy newXferOp, + unsigned targetRank) { + if (newXferOp.getVectorType().getRank() > targetRank) newXferOp->setAttr(kPassLabel, builder.getUnitAttr()); } @@ -347,8 +356,10 @@ /// Note: The loop and type cast are generated in TransferOpConversion. /// The original TransferReadOp and store op are deleted in `cleanup`. /// Note: The `mask` operand is set in TransferOpConversion. - static TransferReadOp rewriteOp(OpBuilder &builder, TransferReadOp xferOp, - Value buffer, Value iv) { + static TransferReadOp rewriteOp(OpBuilder &builder, + VectorTransferToSCFOptions options, + TransferReadOp xferOp, Value buffer, + Value iv) { SmallVector storeIndices; getBufferIndices(xferOp, storeIndices); storeIndices.push_back(iv); @@ -367,7 +378,8 @@ .value; maybeApplyPassLabel(builder, - dyn_cast(newXfer.getDefiningOp())); + dyn_cast(newXfer.getDefiningOp()), + options.targetRank); memref_store(newXfer, buffer, storeIndices); return newXfer.getDefiningOp(); @@ -428,8 +440,10 @@ /// to memory. /// /// Note: For more details, see comments on Strategy. - static TransferWriteOp rewriteOp(OpBuilder &builder, TransferWriteOp xferOp, - Value buffer, Value iv) { + static TransferWriteOp rewriteOp(OpBuilder &builder, + VectorTransferToSCFOptions options, + TransferWriteOp xferOp, Value buffer, + Value iv) { SmallVector loadIndices; getBufferIndices(xferOp, loadIndices); loadIndices.push_back(iv); @@ -444,7 +458,7 @@ AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)), Value(), inBoundsAttr); - maybeApplyPassLabel(builder, newXfer.op); + maybeApplyPassLabel(builder, newXfer.op, options.targetRank); return newXfer; } @@ -460,10 +474,10 @@ }; template -LogicalResult checkPrepareXferOp(OpTy xferOp) { +LogicalResult checkPrepareXferOp(OpTy xferOp, unsigned targetRank) { if (xferOp->hasAttr(kPassLabel)) return failure(); - if (xferOp.getVectorType().getRank() <= kTargetRank) + if (xferOp.getVectorType().getRank() <= targetRank) return failure(); return success(); } @@ -491,12 +505,13 @@ /// ``` /// /// Note: A second temporary buffer may be allocated for the `mask` operand. -struct PrepareTransferReadConversion : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct PrepareTransferReadConversion + : public VectorToSCFPattern { + using VectorToSCFPattern::VectorToSCFPattern; LogicalResult matchAndRewrite(TransferReadOp xferOp, PatternRewriter &rewriter) const override { - if (checkPrepareXferOp(xferOp).failed()) + if (checkPrepareXferOp(xferOp, options.targetRank).failed()) return failure(); ScopedContext scope(rewriter, xferOp.getLoc()); @@ -539,12 +554,12 @@ /// /// Note: A second temporary buffer may be allocated for the `mask` operand. struct PrepareTransferWriteConversion - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + : public VectorToSCFPattern { + using VectorToSCFPattern::VectorToSCFPattern; LogicalResult matchAndRewrite(TransferWriteOp xferOp, PatternRewriter &rewriter) const override { - if (checkPrepareXferOp(xferOp).failed()) + if (checkPrepareXferOp(xferOp, options.targetRank).failed()) return failure(); ScopedContext scope(rewriter, xferOp.getLoc()); @@ -583,8 +598,8 @@ /// out-of-bounds, generate an if-check and handle both cases separately. /// 3. Clean up according to the corresponding Strategy. template -struct TransferOpConversion : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct TransferOpConversion : public VectorToSCFPattern { + using VectorToSCFPattern::VectorToSCFPattern; LogicalResult matchAndRewrite(OpTy xferOp, PatternRewriter &rewriter) const override { @@ -635,8 +650,8 @@ /*inBoundsCase=*/ [&](OpBuilder &b, Location /*loc*/) { // Create new transfer op. - OpTy newXfer = - Strategy::rewriteOp(b, xferOp, castedDataBuffer, iv); + OpTy newXfer = Strategy::rewriteOp( + b, this->options, xferOp, castedDataBuffer, iv); // If old transfer op has a mask: Set mask on new transfer op. // Special case: If the mask of the old transfer op is 1D and @@ -722,8 +737,9 @@ /// /// Note: A pass label is attached to new TransferReadOps, so that subsequent /// applications of this pattern do not create an additional %v_init vector. -struct UnrollTransferReadConversion : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct UnrollTransferReadConversion + : public VectorToSCFPattern { + using VectorToSCFPattern::VectorToSCFPattern; /// Find the result vector %v_init or create a new vector if this the first /// application of the pattern. @@ -758,7 +774,7 @@ /// accesses, and broadcasts and transposes in permutation maps. LogicalResult matchAndRewrite(TransferReadOp xferOp, PatternRewriter &rewriter) const override { - if (xferOp.getVectorType().getRank() <= kTargetRank) + if (xferOp.getVectorType().getRank() <= options.targetRank) return failure(); if (xferOp->hasAttr(kPassLabel) && !getInsertOp(xferOp)) { @@ -809,7 +825,7 @@ dyn_cast(newXferOpVal.getDefiningOp()); maybeAssignMask(b, xferOp, newXferOp, i); - maybeApplyPassLabel(b, newXferOp); + maybeApplyPassLabel(b, newXferOp, options.targetRank); return vector_insert(newXferOp, vec, insertionIndices).value; }, @@ -856,8 +872,8 @@ /// applications of this pattern can read the indices of previously generated /// vector.extract ops. struct UnrollTransferWriteConversion - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + : public VectorToSCFPattern { + using VectorToSCFPattern::VectorToSCFPattern; /// If this is not the first application of the pattern, find the original /// vector %vec that is written by this transfer op. Otherwise, return the @@ -887,7 +903,7 @@ /// accesses, and broadcasts and transposes in permutation maps. LogicalResult matchAndRewrite(TransferWriteOp xferOp, PatternRewriter &rewriter) const override { - if (xferOp.getVectorType().getRank() <= kTargetRank) + if (xferOp.getVectorType().getRank() <= options.targetRank) return failure(); if (xferOp->hasAttr(kPassLabel) && !getExtractOp(xferOp)) { @@ -934,7 +950,7 @@ .op; maybeAssignMask(b, xferOp, newXferOp, i); - maybeApplyPassLabel(b, newXferOp); + maybeApplyPassLabel(b, newXferOp, options.targetRank); }); } @@ -1067,8 +1083,8 @@ /// } /// ``` template -struct TransferOp1dConversion : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct TransferOp1dConversion : public VectorToSCFPattern { + using VectorToSCFPattern::VectorToSCFPattern; LogicalResult matchAndRewrite(OpTy xferOp, PatternRewriter &rewriter) const override { @@ -1111,17 +1127,18 @@ RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) { if (options.unroll) { patterns.add( - patterns.getContext()); + patterns.getContext(), options); } else { patterns.add, - TransferOpConversion>(patterns.getContext()); + TransferOpConversion>(patterns.getContext(), + options); } - if (kTargetRank == 1) { + if (options.targetRank == 1) { patterns.add, - TransferOp1dConversion>( - patterns.getContext()); + TransferOp1dConversion>(patterns.getContext(), + options); } } @@ -1134,12 +1151,16 @@ ConvertVectorToSCFPass() = default; ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) { this->fullUnroll = options.unroll; + this->targetRank = options.targetRank; } void runOnFunction() override { + VectorTransferToSCFOptions options; + options.setUnroll(fullUnroll); + options.setTargetRank(targetRank); + RewritePatternSet patterns(getFunction().getContext()); - populateVectorToSCFConversionPatterns( - patterns, VectorTransferToSCFOptions().setUnroll(fullUnroll)); + populateVectorToSCFConversionPatterns(patterns, options); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } };