diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -28,6 +28,11 @@
 /// Create a pass to convert vector operations to the LLVMIR dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass();
 
+/// Perform progressive lowering of vector operations on "slices" and of all
+/// vector contraction operations within `op`, using greedy pattern
+/// application. Also applies folding and DCE.
+void lowerVectorToVector(Operation *op, MLIRContext *context);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -569,6 +569,10 @@
 void ConvertLinalgToLLVMPass::runOnOperation() {
   auto module = getOperation();
 
+  // Progressively lower vector slice and contraction ops first, as
+  // LowerVectorToLLVMPass does, before converting to the LLVM IR dialect.
+  lowerVectorToVector(module, &getContext());
+
   // Convert to the LLVM IR dialect using the converter defined above.
   OwningRewritePatternList patterns;
   LLVMTypeConverter converter(&getContext());
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1230,6 +1230,15 @@
   patterns.insert<VectorMatmulOpConversion>(ctx, converter);
 }
 
+void mlir::lowerVectorToVector(Operation *op, MLIRContext *context) {
+  // Perform progressive lowering of operations on slices and
+  // all contraction operations. Also applies folding and DCE.
+  OwningRewritePatternList patterns;
+  populateVectorSlicesLoweringPatterns(patterns, context);
+  populateVectorContractLoweringPatterns(patterns, context);
+  applyPatternsAndFoldGreedily(op, patterns);
+}
+
 namespace {
 struct LowerVectorToLLVMPass
     : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
@@ -1238,14 +1247,7 @@
 } // namespace
 
 void LowerVectorToLLVMPass::runOnOperation() {
-  // Perform progressive lowering of operations on slices and
-  // all contraction operations. Also applies folding and DCE.
-  {
-    OwningRewritePatternList patterns;
-    populateVectorSlicesLoweringPatterns(patterns, &getContext());
-    populateVectorContractLoweringPatterns(patterns, &getContext());
-    applyPatternsAndFoldGreedily(getOperation(), patterns);
-  }
+  lowerVectorToVector(getOperation(), &getContext());
 
   // Convert to the LLVM IR dialect.
   LLVMTypeConverter converter(&getContext());