diff --git a/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h b/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h --- a/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h +++ b/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h @@ -9,21 +9,15 @@ #ifndef MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_ #define MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_ -#include - namespace mlir { + class LLVMTypeConverter; -class ModuleOp; -template class OperationPass; class OwningRewritePatternList; /// Collect a set of patterns to convert from the AVX512 dialect to LLVM. void populateAVX512ToLLVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); -/// Create a pass to convert AVX512 operations to the LLVMIR dialect. -std::unique_ptr> createConvertAVX512ToLLVMPass(); - } // namespace mlir #endif // MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_ diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -9,7 +9,6 @@ #ifndef MLIR_CONVERSION_PASSES_H #define MLIR_CONVERSION_PASSES_H -#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -73,17 +73,6 @@ ]; } -//===----------------------------------------------------------------------===// -// AVX512ToLLVM -//===----------------------------------------------------------------------===// - -def ConvertAVX512ToLLVM : Pass<"convert-avx512-to-llvm", "ModuleOp"> { - let summary = "Convert the operations from the avx512 dialect into the LLVM " - "dialect"; - let constructor = "mlir::createConvertAVX512ToLLVMPass()"; - let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMAVX512Dialect"]; -} - //===----------------------------------------------------------------------===// // AsyncToLLVM //===----------------------------------------------------------------------===// @@ -402,14 +391,17 @@ let summary = "Lower the operations from the vector dialect into the LLVM " "dialect"; let constructor = "mlir::createConvertVectorToLLVMPass()"; - let dependentDialects = ["LLVM::LLVMDialect"]; + let dependentDialects = ["LLVM::LLVMDialect","LLVM::LLVMAVX512Dialect"]; let options = [ Option<"reassociateFPReductions", "reassociate-fp-reductions", "bool", /*default=*/"false", "Allows llvm to reassociate floating-point reductions for speed">, Option<"enableIndexOptimizations", "enable-index-optimizations", "bool", /*default=*/"true", - "Allows compiler to assume indices fit in 32-bit if that yields faster code"> + "Allows compiler to assume indices fit in 32-bit if that yields faster code">, + Option<"enableAVX512", "enable-avx512", + "bool", /*default=*/"false", + "Augments the vector dialect with the AVX512 dialect during lowering"> ]; } diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -23,6 +23,7 @@ struct LowerVectorToLLVMOptions { bool reassociateFPReductions = false; bool enableIndexOptimizations = true; + bool enableAVX512 = false; LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) { reassociateFPReductions = b; return *this; @@ -31,6 +32,10 @@ enableIndexOptimizations = b; return *this; } + LowerVectorToLLVMOptions &setEnableAVX512(bool b) { + enableAVX512 = b; + return *this; + } }; /// Collect a set of patterns to convert from Vector contractions to LLVM Matrix diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp --- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp +++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp @@ -8,10 +8,7 @@ #include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h" -#include "../PassDetail.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" -#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/AVX512/AVX512Dialect.h" #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -19,7 +16,6 @@ #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" using namespace mlir; using namespace mlir::vector; @@ -157,32 +153,3 @@ ScaleFOpPD512Conversion>(ctx, converter); // clang-format on } - -namespace { -struct ConvertAVX512ToLLVMPass - : public ConvertAVX512ToLLVMBase { - void runOnOperation() override; -}; -} // namespace - -void ConvertAVX512ToLLVMPass::runOnOperation() { - // Convert to the LLVM IR dialect. - OwningRewritePatternList patterns; - LLVMTypeConverter converter(&getContext()); - populateAVX512ToLLVMConversionPatterns(converter, patterns); - populateVectorToLLVMConversionPatterns(converter, patterns); - populateStdToLLVMConversionPatterns(converter, patterns); - - ConversionTarget target(getContext()); - target.addLegalDialect(); - target.addLegalDialect(); - target.addIllegalDialect(); - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) { - signalPassFailure(); - } -} - -std::unique_ptr> mlir::createConvertAVX512ToLLVMPass() { - return std::make_unique(); -} diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -12,6 +12,9 @@ Core LINK_LIBS PUBLIC + MLIRAVX512 + MLIRAVX512ToLLVM + MLIRLLVMAVX512 MLIRLLVMIR MLIRStandardToLLVM MLIRTargetLLVMIRModuleTranslation diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -9,8 +9,12 @@ #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "../PassDetail.h" + +#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/AVX512/AVX512Dialect.h" +#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" @@ -1606,6 +1610,7 @@ LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) { this->reassociateFPReductions = options.reassociateFPReductions; this->enableIndexOptimizations = options.enableIndexOptimizations; + this->enableAVX512 = options.enableAVX512; } void runOnOperation() override; }; @@ -1631,7 +1636,14 @@ populateVectorToLLVMMatrixConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns); + // Architecture specific augmentations. LLVMConversionTarget target(getContext()); + if (enableAVX512) { + target.addLegalDialect(); + target.addIllegalDialect(); + populateAVX512ToLLVMConversionPatterns(converter, patterns); + } + if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/lib/Dialect/AVX512/CMakeLists.txt b/mlir/lib/Dialect/AVX512/CMakeLists.txt --- a/mlir/lib/Dialect/AVX512/CMakeLists.txt +++ b/mlir/lib/Dialect/AVX512/CMakeLists.txt @@ -10,5 +10,4 @@ LINK_LIBS PUBLIC MLIRIR MLIRSideEffectInterfaces - MLIRVectorToLLVM ) diff --git a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir --- a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir @@ -1,7 +1,7 @@ -// RUN: mlir-opt %s -convert-avx512-to-llvm | mlir-opt | FileCheck %s +// RUN: mlir-opt %s -convert-vector-to-llvm="enable-avx512" | mlir-opt | FileCheck %s func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8) - -> (vector<16xf32>, vector<8xf64>) + -> (vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>) { // CHECK: llvm_avx512.mask.rndscale.ps.512 %0 = avx512.mask.rndscale %a, %i32, %a, %i16, %i32: vector<16xf32> @@ -9,9 +9,10 @@ %1 = avx512.mask.rndscale %b, %i32, %b, %i8, %i32: vector<8xf64> // CHECK: llvm_avx512.mask.scalef.ps.512 - %a0 = avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32> + %2 = avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32> // CHECK: llvm_avx512.mask.scalef.pd.512 - %a1 = avx512.mask.scalef %b, %b, %b, %i8, %i32: vector<8xf64> + %3 = avx512.mask.scalef %b, %b, %b, %i8, %i32: vector<8xf64> - return %a0, %a1: vector<16xf32>, vector<8xf64> + // Keep results alive. + return %0, %1, %2, %3 : vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64> }