diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -16,6 +16,20 @@
 template <typename T>
 class OperationPass;
 
+/// Options to control Vector to LLVM lowering.
+///
+/// This should be kept in sync with VectorToLLVM options defined for the
+/// ConvertVectorToLLVM pass in include/mlir/Conversion/Passes.td
+struct LowerVectorToLLVMOptions {
+  /// Value copied into the pass's `reassociateFPReductions` option.
+  bool reassociateFPReductions = false;
+  /// Fluent setter; returns `*this` so option setters can be chained.
+  LowerVectorToLLVMOptions &setReassociateFPReductions(bool r) {
+    reassociateFPReductions = r;
+    return *this;
+  }
+};
+
 /// Collect a set of patterns to convert from Vector contractions to LLVM Matrix
 /// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics
 /// will be needed when invoking LLVM.
@@ -28,7 +42,8 @@
     bool reassociateFPReductions = false);
 
 /// Create a pass to convert vector operations to the LLVMIR dialect.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass();
+std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass(
+    const LowerVectorToLLVMOptions &options = LowerVectorToLLVMOptions());
 
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1180,6 +1180,10 @@
 namespace {
 struct LowerVectorToLLVMPass
     : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
+  /// Seeds the pass option from the programmatic `options` struct.
+  explicit LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
+    this->reassociateFPReductions = options.reassociateFPReductions;
+  }
   void runOnOperation() override;
 };
 } // namespace
@@ -1210,6 +1214,7 @@
   }
 }
 
-std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertVectorToLLVMPass() {
-  return std::make_unique<LowerVectorToLLVMPass>();
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
+  return std::make_unique<LowerVectorToLLVMPass>(options);
 }