diff --git a/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h b/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h --- a/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h +++ b/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h @@ -9,21 +9,15 @@ #ifndef MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_ #define MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_ -#include - namespace mlir { + class LLVMTypeConverter; -class ModuleOp; -template class OperationPass; class OwningRewritePatternList; /// Collect a set of patterns to convert from the AVX512 dialect to LLVM. void populateAVX512ToLLVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); -/// Create a pass to convert AVX512 operations to the LLVMIR dialect. -std::unique_ptr> createConvertAVX512ToLLVMPass(); - } // namespace mlir #endif // MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_ diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -9,7 +9,6 @@ #ifndef MLIR_CONVERSION_PASSES_H #define MLIR_CONVERSION_PASSES_H -#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -73,17 +73,6 @@ ]; } -//===----------------------------------------------------------------------===// -// AVX512ToLLVM -//===----------------------------------------------------------------------===// - -def ConvertAVX512ToLLVM : Pass<"convert-avx512-to-llvm", "ModuleOp"> { - let summary = "Convert the operations from the avx512 dialect into the LLVM " - "dialect"; - 
let constructor = "mlir::createConvertAVX512ToLLVMPass()"; - let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMAVX512Dialect"]; -} - //===----------------------------------------------------------------------===// // AsyncToLLVM //===----------------------------------------------------------------------===// @@ -401,15 +390,28 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> { let summary = "Lower the operations from the vector dialect into the LLVM " "dialect"; + let description = [{ + + Convert operations from the vector dialect into the LLVM IR dialect + operations. The lowering pass provides several options to control + the kind of optimizations that are allowed. It also provides options + that augment the architectural-neutral vector dialect with + architectural-specific dialects (AVX512, Neon, etc.). + + }]; let constructor = "mlir::createConvertVectorToLLVMPass()"; - let dependentDialects = ["LLVM::LLVMDialect"]; + let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMAVX512Dialect"]; let options = [ Option<"reassociateFPReductions", "reassociate-fp-reductions", "bool", /*default=*/"false", "Allows llvm to reassociate floating-point reductions for speed">, Option<"enableIndexOptimizations", "enable-index-optimizations", "bool", /*default=*/"true", - "Allows compiler to assume indices fit in 32-bit if that yields faster code"> + "Allows compiler to assume indices fit in 32-bit if that yields " + "faster code">, + Option<"enableAVX512", "enable-avx512", + "bool", /*default=*/"false", + "Augments the vector dialect with the AVX512 dialect during lowering"> ]; } diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -23,6 +23,7 @@ struct LowerVectorToLLVMOptions { bool reassociateFPReductions = false; 
bool enableIndexOptimizations = true; + bool enableAVX512 = false; LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) { reassociateFPReductions = b; return *this; @@ -31,6 +32,10 @@ enableIndexOptimizations = b; return *this; } + LowerVectorToLLVMOptions &setEnableAVX512(bool b) { + enableAVX512 = b; + return *this; + } }; /// Collect a set of patterns to convert from Vector contractions to LLVM Matrix diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp --- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp +++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp @@ -8,10 +8,7 @@ #include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h" -#include "../PassDetail.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" -#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/AVX512/AVX512Dialect.h" #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -19,7 +16,6 @@ #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" using namespace mlir; using namespace mlir::vector; @@ -157,32 +153,3 @@ ScaleFOpPD512Conversion>(ctx, converter); // clang-format on } - -namespace { -struct ConvertAVX512ToLLVMPass - : public ConvertAVX512ToLLVMBase { - void runOnOperation() override; -}; -} // namespace - -void ConvertAVX512ToLLVMPass::runOnOperation() { - // Convert to the LLVM IR dialect. 
- OwningRewritePatternList patterns; - LLVMTypeConverter converter(&getContext()); - populateAVX512ToLLVMConversionPatterns(converter, patterns); - populateVectorToLLVMConversionPatterns(converter, patterns); - populateStdToLLVMConversionPatterns(converter, patterns); - - ConversionTarget target(getContext()); - target.addLegalDialect(); - target.addLegalDialect(); - target.addIllegalDialect(); - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) { - signalPassFailure(); - } -} - -std::unique_ptr> mlir::createConvertAVX512ToLLVMPass() { - return std::make_unique(); -} diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_conversion_library(MLIRVectorToLLVM ConvertVectorToLLVM.cpp + ConvertVectorToLLVMPass.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToLLVM @@ -12,6 +13,9 @@ Core LINK_LIBS PUBLIC + MLIRAVX512 + MLIRAVX512ToLLVM + MLIRLLVMAVX512 MLIRLLVMIR MLIRStandardToLLVM MLIRTargetLLVMIRModuleTranslation diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -8,25 +8,14 @@ #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" -#include "../PassDetail.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Target/LLVMIR/TypeTranslation.h" #include 
"mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/Passes.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" -#include "llvm/Support/Allocator.h" -#include "llvm/Support/ErrorHandling.h" using namespace mlir; using namespace mlir::vector; @@ -1599,45 +1588,3 @@ patterns.insert(ctx, converter); patterns.insert(ctx, converter); } - -namespace { -struct LowerVectorToLLVMPass - : public ConvertVectorToLLVMBase { - LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) { - this->reassociateFPReductions = options.reassociateFPReductions; - this->enableIndexOptimizations = options.enableIndexOptimizations; - } - void runOnOperation() override; -}; -} // namespace - -void LowerVectorToLLVMPass::runOnOperation() { - // Perform progressive lowering of operations on slices and - // all contraction operations. Also applies folding and DCE. - { - OwningRewritePatternList patterns; - populateVectorToVectorCanonicalizationPatterns(patterns, &getContext()); - populateVectorSlicesLoweringPatterns(patterns, &getContext()); - populateVectorContractLoweringPatterns(patterns, &getContext()); - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - } - - // Convert to the LLVM IR dialect. 
- LLVMTypeConverter converter(&getContext()); - OwningRewritePatternList patterns; - populateVectorToLLVMMatrixConversionPatterns(converter, patterns); - populateVectorToLLVMConversionPatterns( - converter, patterns, reassociateFPReductions, enableIndexOptimizations); - populateVectorToLLVMMatrixConversionPatterns(converter, patterns); - populateStdToLLVMConversionPatterns(converter, patterns); - - LLVMConversionTarget target(getContext()); - if (failed( - applyPartialConversion(getOperation(), target, std::move(patterns)))) - signalPassFailure(); -} - -std::unique_ptr> -mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) { - return std::make_unique(options); -} diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -0,0 +1,73 @@ +//===- ConvertVectorToLLVMPass.cpp - Vector to LLVM dialect lowering ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" + +#include "../PassDetail.h" + +#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/AVX512/AVX512Dialect.h" +#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::vector; + +namespace { +struct LowerVectorToLLVMPass + : public ConvertVectorToLLVMBase { + LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) { + this->reassociateFPReductions = options.reassociateFPReductions; + this->enableIndexOptimizations = options.enableIndexOptimizations; + this->enableAVX512 = options.enableAVX512; + } + void runOnOperation() override; +}; +} // namespace + +void LowerVectorToLLVMPass::runOnOperation() { + // Perform progressive lowering of operations on slices and + // all contraction operations. Also applies folding and DCE. + { + OwningRewritePatternList patterns; + populateVectorToVectorCanonicalizationPatterns(patterns, &getContext()); + populateVectorSlicesLoweringPatterns(patterns, &getContext()); + populateVectorContractLoweringPatterns(patterns, &getContext()); + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } + + // Convert to the LLVM IR dialect. 
+ LLVMTypeConverter converter(&getContext()); + OwningRewritePatternList patterns; + populateVectorToLLVMMatrixConversionPatterns(converter, patterns); + populateVectorToLLVMConversionPatterns( + converter, patterns, reassociateFPReductions, enableIndexOptimizations); + populateVectorToLLVMMatrixConversionPatterns(converter, patterns); + populateStdToLLVMConversionPatterns(converter, patterns); + + // Architecture specific augmentations. + LLVMConversionTarget target(getContext()); + if (enableAVX512) { + target.addLegalDialect(); + target.addIllegalDialect(); + populateAVX512ToLLVMConversionPatterns(converter, patterns); + } + + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); +} + +std::unique_ptr> +mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) { + return std::make_unique(options); +} diff --git a/mlir/lib/Dialect/AVX512/CMakeLists.txt b/mlir/lib/Dialect/AVX512/CMakeLists.txt --- a/mlir/lib/Dialect/AVX512/CMakeLists.txt +++ b/mlir/lib/Dialect/AVX512/CMakeLists.txt @@ -10,5 +10,4 @@ LINK_LIBS PUBLIC MLIRIR MLIRSideEffectInterfaces - MLIRVectorToLLVM ) diff --git a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir --- a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir @@ -1,7 +1,7 @@ -// RUN: mlir-opt %s -convert-avx512-to-llvm | mlir-opt | FileCheck %s +// RUN: mlir-opt %s -convert-vector-to-llvm="enable-avx512" | mlir-opt | FileCheck %s func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8) - -> (vector<16xf32>, vector<8xf64>) + -> (vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>) { // CHECK: llvm_avx512.mask.rndscale.ps.512 %0 = avx512.mask.rndscale %a, %i32, %a, %i16, %i32: vector<16xf32> @@ -9,9 +9,10 @@ %1 = avx512.mask.rndscale %b, %i32, %b, %i8, %i32: vector<8xf64> // CHECK: 
llvm_avx512.mask.scalef.ps.512 - %a0 = avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32> + %2 = avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32> // CHECK: llvm_avx512.mask.scalef.pd.512 - %a1 = avx512.mask.scalef %b, %b, %b, %i8, %i32: vector<8xf64> + %3 = avx512.mask.scalef %b, %b, %b, %i8, %i32: vector<8xf64> - return %a0, %a1: vector<16xf32>, vector<8xf64> + // Keep results alive. + return %0, %1, %2, %3 : vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64> }