diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h --- a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h +++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h @@ -11,8 +11,12 @@ #include <memory> +#include "mlir/Pass/Pass.h" + +#define GEN_PASS_DECL_CONVERTTOLLVMPASS +#include "mlir/Conversion/Passes.h.inc" + namespace mlir { -class Pass; /// Create a pass that performs dialect conversion to LLVM for all dialects /// implementing `ConvertToLLVMPatternInterface`. diff --git a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h --- a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h +++ b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h @@ -16,6 +16,7 @@ namespace mlir { +class DialectRegistry; class LLVMTypeConverter; class RewritePatternSet; @@ -33,6 +34,8 @@ void populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns); +void registerConvertFuncToLLVMInterface(DialectRegistry &registry); + } // namespace mlir #endif // MLIR_CONVERSION_FUNCTOLLVM_CONVERTFUNCTOLLVM_H diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -23,6 +23,12 @@ `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects the injection of conversion patterns.
}]; + + let constructor = "mlir::createConvertToLLVMPass()"; + let options = [ + ListOption<"filterDialects", "filter-dialects", "std::string", + "Test conversion patterns of only the specified dialects">, + ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -17,6 +17,7 @@ #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" @@ -39,6 +40,7 @@ registerConvertComplexToLLVMInterface(registry); cf::registerConvertControlFlowToLLVMInterface(registry); func::registerAllExtensions(registry); + registerConvertFuncToLLVMInterface(registry); index::registerConvertIndexToLLVMInterface(registry); registerConvertMathToLLVMInterface(registry); registerConvertMemRefToLLVMInterface(registry); diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp --- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp +++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp @@ -72,25 +72,44 @@ registry.addExtensions<LoadDependentDialectExtension>(); } - ConvertToLLVMPass(const ConvertToLLVMPass &other) - : ConvertToLLVMPassBase(other), patterns(other.patterns), - target(other.target), typeConverter(other.typeConverter) {} - LogicalResult initialize(MLIRContext *context) final { RewritePatternSet tempPatterns(context); auto target = std::make_shared<ConversionTarget>(*context); target->addLegalDialect<LLVM::LLVMDialect>(); auto typeConverter = std::make_shared<LLVMTypeConverter>(context); - for (Dialect *dialect : context->getLoadedDialects()) { - // First
time we encounter this dialect: if it implements the interface, - // let's populate patterns ! - auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); - if (!iface) - continue; - iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter, - tempPatterns); + + if (!filterDialects.empty()) { + // Test mode: Populate only patterns from the specified dialects. Produce + // an error if the dialect is not loaded or does not implement the + // interface. + for (const std::string &dialectName : filterDialects) { + Dialect *dialect = context->getLoadedDialect(dialectName); + if (!dialect) + return emitError(UnknownLoc::get(context)) + << "dialect not loaded: " << dialectName; + auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); + if (!iface) + return emitError(UnknownLoc::get(context)) + << "dialect does not implement ConvertToLLVMPatternInterface: " + << dialectName; + iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter, + tempPatterns); + } + } else { + // Normal mode: Populate all patterns from all dialects that implement the + // interface. + for (Dialect *dialect : context->getLoadedDialects()) { + // If the dialect implements the interface, let it populate its + // conversion patterns.
+ auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); + if (!iface) + continue; + iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter, + tempPatterns); + } } - patterns = + + this->patterns = std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns)); this->target = target; this->typeConverter = typeConverter; @@ -104,3 +123,7 @@ }; } // namespace + +std::unique_ptr<Pass> mlir::createConvertToLLVMPass() { + return std::make_unique<ConvertToLLVMPass>(); +} diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -16,6 +16,7 @@ #include "mlir/Analysis/DataLayoutAnalysis.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" @@ -783,3 +784,27 @@ } }; } // namespace + +//===----------------------------------------------------------------------===// +// ConvertToLLVMPatternInterface implementation +//===----------------------------------------------------------------------===// + +namespace { +/// Implement the interface to convert Func to LLVM. +struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface { + using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; + /// Hook for derived dialect interface to provide conversion patterns + /// and mark dialect legal for the conversion target.
+ void populateConvertToLLVMConversionPatterns( + ConversionTarget &target, LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) const final { + populateFuncToLLVMConversionPatterns(typeConverter, patterns); + } +}; +} // namespace + +void mlir::registerConvertFuncToLLVMInterface(DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) { + dialect->addInterfaces<FuncToLLVMDialectInterface>(); + }); +} diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -2,7 +2,7 @@ // Same below, but using the `ConvertToLLVMPatternInterface` entry point // and the generic `convert-to-llvm` pass. -// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s +// RUN: mlir-opt --convert-to-llvm="filter-dialects=arith" --split-input-file %s | FileCheck %s // CHECK-LABEL: @vector_ops func.func @vector_ops(%arg0: vector<4xf32>, %arg1: vector<4xi1>, %arg2: vector<4xi64>, %arg3: vector<4xi64>) -> vector<4xf32> { diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir --- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir @@ -2,7 +2,7 @@ // Same below, but using the `ConvertToLLVMPatternInterface` entry point // and the generic `convert-to-llvm` pass.
-// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s +// RUN: mlir-opt --convert-to-llvm="filter-dialects=complex" --split-input-file %s | FileCheck %s // CHECK-LABEL: func @complex_create // CHECK-SAME: (%[[REAL0:.*]]: f32, %[[IMAG0:.*]]: f32) diff --git a/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir b/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir --- a/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir +++ b/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir @@ -2,7 +2,7 @@ // Same below, but using the `ConvertToLLVMPatternInterface` entry point // and the generic `convert-to-llvm` pass. -// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s +// RUN: mlir-opt --convert-to-llvm="filter-dialects=cf" --split-input-file %s | FileCheck %s func.func @main() { %a = arith.constant 0 : i1 diff --git a/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir b/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir --- a/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir +++ b/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir @@ -1,9 +1,13 @@ -// RUN: mlir-opt -pass-pipeline="builtin.module(func.func(convert-math-to-llvm,convert-arith-to-llvm),convert-func-to-llvm{use-opaque-pointers=1},reconcile-unrealized-casts)" %s -split-input-file | FileCheck %s +// RUN: mlir-opt -pass-pipeline="builtin.module(func.func(convert-math-to-llvm,convert-arith-to-llvm),convert-func-to-llvm{use-opaque-pointers=1},reconcile-unrealized-casts)" %s | FileCheck %s -// RUN: mlir-opt -pass-pipeline="builtin.module(func.func(convert-math-to-llvm,convert-arith-to-llvm{index-bitwidth=32}),convert-func-to-llvm{index-bitwidth=32 use-opaque-pointers=1},reconcile-unrealized-casts)" %s -split-input-file | FileCheck --check-prefix=CHECK32 %s +// RUN: mlir-opt -pass-pipeline="builtin.module(func.func(convert-math-to-llvm,convert-arith-to-llvm{index-bitwidth=32}),convert-func-to-llvm{index-bitwidth=32 use-opaque-pointers=1},reconcile-unrealized-casts)" %s | FileCheck 
--check-prefix=CHECK32 %s // RUN: mlir-opt -test-transform-dialect-interpreter %s | FileCheck --check-prefix=CHECK32 %s +// Same below, but using the `ConvertToLLVMPatternInterface` entry point +// and the generic `convert-to-llvm` pass. +// RUN: mlir-opt --convert-to-llvm="filter-dialects=arith,cf,func,math" %s | FileCheck %s + // CHECK-LABEL: func @empty() { // CHECK-NEXT: llvm.return // CHECK-NEXT: } diff --git a/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir b/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir --- a/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir +++ b/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir @@ -4,7 +4,7 @@ // Same below, but using the `ConvertToLLVMPatternInterface` entry point // and the generic `convert-to-llvm` pass. -// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s +// RUN: mlir-opt --convert-to-llvm="filter-dialects=index" --split-input-file %s | FileCheck %s // CHECK-LABEL: @trivial_ops func.func @trivial_ops(%a: index, %b: index) { diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir --- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir +++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir @@ -2,7 +2,7 @@ // Same below, but using the `ConvertToLLVMPatternInterface` entry point // and the generic `convert-to-llvm` pass. -// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s +// RUN: mlir-opt --convert-to-llvm="filter-dialects=math" --split-input-file %s | FileCheck %s // CHECK-LABEL: @ops func.func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64) { diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -5,7 +5,7 @@ // and the generic `convert-to-llvm` pass. 
This produces slightly different IR // because the conversion target is set up differently. Only one test case is // checked. -// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck --check-prefix=CHECK-INTERFACE %s +// RUN: mlir-opt --convert-to-llvm="filter-dialects=memref" --split-input-file %s | FileCheck --check-prefix=CHECK-INTERFACE %s // CHECK-LABEL: func @view( // CHECK: %[[ARG0F:.*]]: index, %[[ARG1F:.*]]: index, %[[ARG2F:.*]]: index diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt --convert-nvvm-to-llvm --convert-arith-to-llvm --split-input-file %s | FileCheck %s + // Same below, but using the `ConvertToLLVMPatternInterface` entry point // and the generic `convert-to-llvm` pass. // RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s diff --git a/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir b/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir --- a/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir +++ b/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir @@ -2,7 +2,7 @@ // Same below, but using the `ConvertToLLVMPatternInterface` entry point // and the generic `convert-to-llvm` pass. 
-// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s +// RUN: mlir-opt --convert-to-llvm="filter-dialects=ub" --split-input-file %s | FileCheck %s // CHECK-LABEL: @check_poison func.func @check_poison() { diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -4111,6 +4111,7 @@ ":ComplexToLLVM", ":ControlFlowToLLVM", ":FuncExtensions", + ":FuncToLLVM", ":IndexToLLVM", ":MathToLLVM", ":MemRefToLLVM", @@ -7185,6 +7186,7 @@ ":ArithToLLVM", ":ControlFlowToLLVM", ":ConversionPassIncGen", + ":ConvertToLLVM", ":DataLayoutInterfaces", ":DialectUtils", ":FuncDialect",