diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -35,6 +35,18 @@ void populateVectorToVectorTransformationPatterns( OwningRewritePatternList &patterns, MLIRContext *context); +/// Collect a set of patterns to split transfer read/write ops. +/// +/// These patterns unroll transfer read/write ops if the vector consumers/ +/// producers are extract/insert slices ops. Transfer ops can map to hardware +/// load/store functionalities, where the vector size matters for bandwidth +/// considerations. So these patterns should be collected separately, instead +/// of being generic canonicalization patterns. Also, one can make +/// `ignoreFilter` return true for an op to exclude it from splitting. +void populateSplitVectorTransferPatterns( + OwningRewritePatternList &patterns, MLIRContext *context, + std::function ignoreFilter = nullptr); + /// Collect a set of leading one dimension removal patterns. /// /// These patterns insert vector.shape_cast to remove leading one dimensions diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -705,22 +705,33 @@ namespace { -// Splits vector TransferReadOp into smaller TransferReadOps based on slicing -// scheme of its unique ExtractSlicesOp user. -struct SplitTransferReadOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +// Splits a TransferReadOp into smaller TransferReadOps based on slicing +// scheme of its unique ExtractSlicesOp user.
+class SplitTransferReadOp : public OpRewritePattern { +public: + SplitTransferReadOp(MLIRContext *context, + std::function ignoreFilter = nullptr, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), ignoreFilter(ignoreFilter) {} - LogicalResult matchAndRewrite(vector::TransferReadOp xferReadOp, + LogicalResult matchAndRewrite(vector::TransferReadOp readOp, PatternRewriter &rewriter) const override { - // TODO: Support splitting TransferReadOp with non-identity - // permutation maps. Repurpose code from MaterializeVectors transformation. - if (!isIdentitySuffix(xferReadOp.permutation_map())) + if (ignoreFilter && ignoreFilter(readOp)) + return failure(); + + // TODO: Support splitting TransferReadOp with non-identity permutation + // maps. Repurpose code from MaterializeVectors transformation. + if (!isIdentitySuffix(readOp.permutation_map())) + return failure(); + + // Fail to match unless the single use is by an ExtractSlicesOp. + Value readResult = readOp.getResult(); + if (!readResult.hasOneUse()) return failure(); - // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp. - Value xferReadResult = xferReadOp.getResult(); + auto extractSlicesOp = - dyn_cast(*xferReadResult.getUsers().begin()); - if (!xferReadResult.hasOneUse() || !extractSlicesOp) + dyn_cast(readResult.use_begin()->getOwner()); + if (!extractSlicesOp) return failure(); // Get 'sizes' and 'strides' parameters from ExtractSlicesOp user. @@ -730,37 +741,48 @@ extractSlicesOp.getStrides(strides); assert(llvm::all_of(strides, [](int64_t s) { return s == 1; })); - Value newVec = unrollTransferReadOp(xferReadOp, sizes, rewriter); + Value newVec = unrollTransferReadOp(readOp, sizes, rewriter); if (!newVec) return failure(); - rewriter.replaceOp(xferReadOp, newVec); + rewriter.replaceOp(readOp, newVec); return success(); } + +private: + std::function ignoreFilter; }; -// Splits vector TransferWriteOp into smaller TransferWriteOps for each source. -struct SplitTransferWriteOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +// Splits a TransferWriteOp into smaller TransferWriteOps, one per slice.
-struct SplitTransferWriteOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +// Splits a TransferWriteOp into smaller TransferWriteOps for each source. +class SplitTransferWriteOp : public OpRewritePattern { +public: + SplitTransferWriteOp(MLIRContext *context, + std::function ignoreFilter = nullptr, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), ignoreFilter(ignoreFilter) {} - LogicalResult matchAndRewrite(vector::TransferWriteOp xferWriteOp, + LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, PatternRewriter &rewriter) const override { - // TODO: Support splitting TransferWriteOp with non-identity - // permutation maps. Repurpose code from MaterializeVectors transformation. - if (!isIdentitySuffix(xferWriteOp.permutation_map())) + if (ignoreFilter && ignoreFilter(writeOp)) + return failure(); + + // TODO: Support splitting TransferWriteOp with non-identity permutation + // maps. Repurpose code from MaterializeVectors transformation. + if (!isIdentitySuffix(writeOp.permutation_map())) return failure(); - // Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'. - auto *vectorDefOp = xferWriteOp.vector().getDefiningOp(); - auto insertSlicesOp = dyn_cast_or_null(vectorDefOp); + + // Fail to match unless this is writing a vector resulting from an + // InsertSlicesOp. + auto insertSlicesOp = + writeOp.vector().getDefiningOp(); if (!insertSlicesOp) return failure(); - // Get TupleOp operand of 'insertSlicesOp'. - auto tupleOp = dyn_cast_or_null( - insertSlicesOp.vectors().getDefiningOp()); + // Get the TupleOp operand of the InsertSlicesOp. + auto tupleOp = insertSlicesOp.vectors().getDefiningOp(); if (!tupleOp) return failure(); - // Get 'sizes' and 'strides' parameters from InsertSlicesOp user. + // Get 'sizes' and 'strides' parameters from the InsertSlicesOp user. 
auto sourceTupleType = insertSlicesOp.getSourceTupleType(); auto resultVectorType = insertSlicesOp.getResultVectorType(); SmallVector sizes; @@ -768,21 +790,20 @@ SmallVector strides; insertSlicesOp.getStrides(strides); - Location loc = xferWriteOp.getLoc(); + Location loc = writeOp.getLoc(); auto shapedElementType = - xferWriteOp.source().getType().cast().getElementType(); - SmallVector indices(xferWriteOp.indices().begin(), - xferWriteOp.indices().end()); + writeOp.source().getType().cast().getElementType(); + auto indices = llvm::to_vector<4>(writeOp.indices()); Value resultTensor; auto createSlice = [&](unsigned index, ArrayRef sliceIndices) { // Create split TransferWriteOp for source vector 'tupleOp.operand[i]'. - // `masked` attribute propagates conservatively: if the coarse op didn't + // 'masked' attribute propagates conservatively: if the coarse op didn't // need masking, the fine op doesn't either. Operation *write = rewriter.create( loc, tupleOp.getOperand(index), - resultTensor ? resultTensor : xferWriteOp.source(), sliceIndices, - xferWriteOp.permutation_map(), - xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr()); + resultTensor ? resultTensor : writeOp.source(), sliceIndices, + writeOp.permutation_map(), + writeOp.masked() ? *writeOp.masked() : ArrayAttr()); if (!write->getResults().empty()) resultTensor = write->getResult(0); }; @@ -790,13 +811,15 @@ sourceTupleType, sizes, strides, indices, rewriter, createSlice); - // Erase old 'xferWriteOp'. if (resultTensor) - rewriter.replaceOp(xferWriteOp, ArrayRef(resultTensor)); + rewriter.replaceOp(writeOp, ArrayRef(resultTensor)); else - rewriter.eraseOp(xferWriteOp); + rewriter.eraseOp(writeOp); return success(); } + +private: + std::function ignoreFilter; }; /// Decomposes ShapeCastOp on tuple-of-vectors to multiple ShapeCastOps, each @@ -3028,15 +3051,16 @@ // TODO: Add this as DRR pattern. 
 void mlir::vector::populateVectorToVectorTransformationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { - // clang-format off - patterns.insert(context); - // clang-format on + patterns.insert( + context); +} + +void mlir::vector::populateSplitVectorTransferPatterns( + OwningRewritePatternList &patterns, MLIRContext *context, + std::function ignoreFilter) { + patterns.insert(context, + ignoreFilter); } void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -47,6 +47,7 @@ populateVectorToVectorTransformationPatterns(patterns, ctx); populateBubbleVectorBitCastOpPatterns(patterns, ctx); populateCastAwayVectorLeadingOneDimPatterns(patterns, ctx); + populateSplitVectorTransferPatterns(patterns, ctx); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); }