diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -28,6 +28,10 @@ /// Create a pass to convert vector operations to the LLVMIR dialect. std::unique_ptr> createConvertVectorToLLVMPass(); +/// Greedily apply to `op` the progressive lowering of operations on "slices" +/// and of all contraction operations. Also applies folding and DCE. +void lowerVectorSlicesToLLVM(Operation *op, MLIRContext *context); + } // namespace mlir #endif // MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_ diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -580,6 +580,7 @@ populateLoopToStdConversionPatterns(patterns, &getContext()); populateStdToLLVMConversionPatterns(converter, patterns); populateVectorToLLVMMatrixConversionPatterns(converter, patterns); + lowerVectorSlicesToLLVM(getModule(), &getContext()); populateVectorToLLVMConversionPatterns(converter, patterns); populateLinalgToStandardConversionPatterns(patterns, &getContext()); populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext()); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1117,6 +1117,13 @@ patterns.insert(ctx, converter); } +void mlir::lowerVectorSlicesToLLVM(Operation *op, MLIRContext *context) { + OwningRewritePatternList patterns; + populateVectorSlicesLoweringPatterns(patterns, context); + populateVectorContractLoweringPatterns(patterns, context); + applyPatternsGreedily(op, patterns); +} + namespace { 
struct LowerVectorToLLVMPass : public ModulePass { /// Include the generated pass utilities. @@ -1128,14 +1135,7 @@ } // namespace void LowerVectorToLLVMPass::runOnModule() { - // Perform progressive lowering of operations on slices and - // all contraction operations. Also applies folding and DCE. - { - OwningRewritePatternList patterns; - populateVectorSlicesLoweringPatterns(patterns, &getContext()); - populateVectorContractLoweringPatterns(patterns, &getContext()); - applyPatternsGreedily(getModule(), patterns); - } + lowerVectorSlicesToLLVM(getModule(), &getContext()); // Convert to the LLVM IR dialect. LLVMTypeConverter converter(&getContext());