diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -224,6 +224,11 @@
         .enableHoistRedundantVectorTransfersOnTensor = val;
     return *this;
   }
+  /// Set the maximum rank of vector transfer ops to lower (default: 1).
+  CodegenStrategy &setMaxTransferRank(int64_t val) {
+    this->lateCodegenStrategyOptions.maxTransferRank = val;
+    return *this;
+  }
   CodegenStrategy &setEnableVectorTransferPartialRewrite(bool val) {
     this->lateCodegenStrategyOptions.enableVectorTransferPartialRewrite = val;
     return *this;
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -835,6 +835,8 @@
   bool enableHoistRedundantVectorTransfersOnTensor = true;
+  /// Only vector transfer ops of rank <= maxTransferRank are lowered.
+  int64_t maxTransferRank = 1;
   /// Vector lowering operations may result in surprising behavior when
   /// composing multiple codegen strategies and must be enabled explicitly.
   bool enableVectorTransferPartialRewrite = false;
   bool enableVectorContractLowering = false;
   bool enableVectorToSCFConversion = false;
@@ -852,6 +854,8 @@
 /// Vector lowering options control how ops are lowered down to 1-D and scf.for
 /// form.
 struct LinalgVectorLoweringOptions {
+  /// Only vector transfer ops of rank <= maxTransferRank are lowered.
+  int64_t maxTransferRank = 1;
   bool enableVectorTransferPartialRewrite = false;
   bool enableVectorContractLowering = false;
   bool enableVectorToSCFConversion = false;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -47,6 +47,8 @@
       pm.addPass(createLinalgStrategyEnablePass());
   }
   LinalgVectorLoweringOptions vectorLoweringOptions;
+  vectorLoweringOptions.maxTransferRank =
+      lateCodegenStrategyOptions.maxTransferRank;
   vectorLoweringOptions.enableVectorTransferPartialRewrite =
       lateCodegenStrategyOptions.enableVectorTransferPartialRewrite;
   vectorLoweringOptions.enableVectorContractLowering =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -260,6 +260,10 @@
 
     MLIRContext *context = funcOp.getContext();
     RewritePatternSet patterns(context);
+    // Transfer lowering is registered unconditionally (no enable flag); it is
+    // bounded only by `options.maxTransferRank`.
+    vector::populateVectorTransferLoweringPatterns(patterns,
+                                                   options.maxTransferRank);
     if (options.enableVectorTransferPartialRewrite) {
       patterns.add<vector::VectorTransferFullPartialRewriter>(
           context, options.vectorTransformOptions);