diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -228,6 +228,12 @@
     this->lateCodegenStrategyOptions.maxTransferRank = val;
     return *this;
   }
+  /// Enable or disable the late vector transfer lowering patterns, which are
+  /// parameterized by the max transfer rank. Off by default.
+  CodegenStrategy &setEnableVectorTransferLowering(bool val) {
+    this->lateCodegenStrategyOptions.enableVectorTransferLowering = val;
+    return *this;
+  }
   CodegenStrategy &setEnableVectorTransferPartialRewrite(bool val) {
     this->lateCodegenStrategyOptions.enableVectorTransferPartialRewrite = val;
     return *this;
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -836,6 +836,9 @@
   /// Vector lowering operations may result in surprising behavior when
   /// composing multiple codegen strategies and must be enabled explicitly.
   int64_t maxTransferRank = 1;
+  /// Whether to populate the vector transfer lowering patterns (parameterized
+  /// by `maxTransferRank`); off by default.
+  bool enableVectorTransferLowering = false;
   bool enableVectorTransferPartialRewrite = false;
   bool enableVectorContractLowering = false;
   bool enableVectorToSCFConversion = false;
@@ -854,6 +857,9 @@
 /// form.
 struct LinalgVectorLoweringOptions {
   int64_t maxTransferRank = 1;
+  /// Whether to populate the vector transfer lowering patterns (parameterized
+  /// by `maxTransferRank`); off by default.
+  bool enableVectorTransferLowering = false;
   bool enableVectorTransferPartialRewrite = false;
   bool enableVectorContractLowering = false;
   bool enableVectorToSCFConversion = false;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -49,6 +49,8 @@
   LinalgVectorLoweringOptions vectorLoweringOptions;
   vectorLoweringOptions.maxTransferRank =
       lateCodegenStrategyOptions.maxTransferRank;
+  vectorLoweringOptions.enableVectorTransferLowering =
+      lateCodegenStrategyOptions.enableVectorTransferLowering;
   vectorLoweringOptions.enableVectorTransferPartialRewrite =
       lateCodegenStrategyOptions.enableVectorTransferPartialRewrite;
   vectorLoweringOptions.enableVectorContractLowering =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -260,8 +260,10 @@
     MLIRContext *context = funcOp.getContext();
     RewritePatternSet patterns(context);
 
-    vector::populateVectorTransferLoweringPatterns(patterns,
-                                                   options.maxTransferRank);
+    if (options.enableVectorTransferLowering) {
+      vector::populateVectorTransferLoweringPatterns(patterns,
+                                                     options.maxTransferRank);
+    }
     if (options.enableVectorTransferPartialRewrite) {
       patterns.add<vector::VectorTransferFullPartialRewriter>(
           context, options.vectorTransformOptions);