diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -394,9 +394,10 @@ Convert operations from the vector dialect into the LLVM IR dialect operations. The lowering pass provides several options to control - the kind of optimizations that are allowed. It also provides options - that augment the architectural-neutral vector dialect with - architectural-specific dialects (AVX512, Neon, etc.). + the kinds of optimizations that are allowed. It also provides options + that enable the use of one or more architecture-specific dialects + (AVX512, Neon, SVE, etc.) in combination with the architecture-neutral + vector dialect lowering. }]; let constructor = "mlir::createConvertVectorToLLVMPass()"; @@ -411,7 +412,8 @@ "faster code">, Option<"enableAVX512", "enable-avx512", "bool", /*default=*/"false", - "Augments the vector dialect with the AVX512 dialect during lowering"> + "Enables the use of the AVX512 dialect while lowering the vector " + "dialect"> ]; } diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -21,9 +21,10 @@ /// This should kept in sync with VectorToLLVM options defined for the /// ConvertVectorToLLVM pass in include/mlir/Conversion/Passes.td struct LowerVectorToLLVMOptions { - bool reassociateFPReductions = false; - bool enableIndexOptimizations = true; - bool enableAVX512 = false; + LowerVectorToLLVMOptions() + : reassociateFPReductions(false), enableIndexOptimizations(true), + enableAVX512(false) {} + LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) { reassociateFPReductions = b; return *this; @@ -36,6 +37,10 @@ enableAVX512 = b; return *this; } + + bool reassociateFPReductions; + bool
enableIndexOptimizations; + bool enableAVX512; }; /// Collect a set of patterns to convert from Vector contractions to LLVM Matrix