diff --git a/mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h b/mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h --- a/mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h +++ b/mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h @@ -13,6 +13,7 @@ namespace mlir { +class DialectRegistry; class LLVMTypeConverter; class RewritePatternSet; class Pass; @@ -23,6 +24,8 @@ namespace arith { void populateArithToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns); + +void registerConvertArithToLLVMInterface(DialectRegistry &registry); } // namespace arith } // namespace mlir diff --git a/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h b/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h --- a/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h +++ b/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h @@ -14,6 +14,7 @@ namespace mlir { class ConversionTarget; +class DialectRegistry; class Pass; class MLIRContext; class TypeConverter; @@ -34,6 +35,8 @@ TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target); +void registerConvertAsyncToLLVMInterface(DialectRegistry &registry); + } // namespace mlir #endif // MLIR_CONVERSION_ASYNCTOLLVM_ASYNCTOLLVM_H diff --git a/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h b/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h --- a/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h +++ b/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h @@ -11,6 +11,7 @@ #include "mlir/Conversion/LLVMCommon/StructBuilder.h" namespace mlir { +class DialectRegistry; class LLVMTypeConverter; class Pass; class RewritePatternSet; @@ -40,6 +41,9 @@ /// Populate the given list with patterns that convert from Complex to LLVM. 
void populateComplexToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns); + +void registerConvertComplexToLLVMInterface(DialectRegistry &registry); + } // namespace mlir #endif // MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_ diff --git a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h --- a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h +++ b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h @@ -16,6 +16,7 @@ #include <memory> namespace mlir { +class DialectRegistry; class LLVMTypeConverter; class RewritePatternSet; class Pass; @@ -37,6 +38,9 @@ void populateAssertToLLVMConversionPattern(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool abortOnFailure = true); + +void registerConvertControlFlowToLLVMInterface(DialectRegistry &registry); + } // namespace cf } // namespace mlir diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -14,6 +14,10 @@ #ifndef MLIR_INITALLEXTENSIONS_H_ #define MLIR_INITALLEXTENSIONS_H_ +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" +#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" @@ -29,6 +33,10 @@ /// individually register the specific extensions that are useful for the /// pipelines and transformations you are using. 
inline void registerAllExtensions(DialectRegistry &registry) { + arith::registerConvertArithToLLVMInterface(registry); + registerConvertAsyncToLLVMInterface(registry); + registerConvertComplexToLLVMInterface(registry); + cf::registerConvertControlFlowToLLVMInterface(registry); func::registerAllExtensions(registry); registerConvertMemRefToLLVMInterface(registry); registerConvertNVVMToLLVMInterface(registry); diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -451,6 +452,35 @@ }; } // namespace +//===----------------------------------------------------------------------===// +// ConvertToLLVMPatternInterface implementation +//===----------------------------------------------------------------------===// + +namespace { +/// Implement the interface to convert Arith to LLVM. +struct ArithToLLVMDialectInterface : public ConvertToLLVMPatternInterface { + using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; + void loadDependentDialects(MLIRContext *context) const final { + context->loadDialect<LLVM::LLVMDialect>(); + } + + /// Hook for derived dialect interface to provide conversion patterns + /// and mark dialect legal for the conversion target. 
+ void populateConvertToLLVMConversionPatterns( + ConversionTarget &target, LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) const final { + arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); + } +}; +} // namespace + +void mlir::arith::registerConvertArithToLLVMInterface( + DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) { + dialect->addInterfaces<ArithToLLVMDialectInterface>(); + }); +} + //===----------------------------------------------------------------------===// // Pattern Population //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -8,6 +8,7 @@ #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" @@ -1251,3 +1252,32 @@ target.addDynamicallyLegalOp( [&](Operation *op) { return typeConverter.isLegal(op); }); } + +//===----------------------------------------------------------------------===// +// ConvertToLLVMPatternInterface implementation +//===----------------------------------------------------------------------===// + +namespace { +/// Implement the interface to convert Async to LLVM. +struct AsyncToLLVMDialectInterface : public ConvertToLLVMPatternInterface { + using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; + void loadDependentDialects(MLIRContext *context) const final { + context->loadDialect<LLVM::LLVMDialect>(); + } + + /// Hook for derived dialect interface to provide conversion patterns + /// and mark dialect legal for the conversion target. 
+ void populateConvertToLLVMConversionPatterns( + ConversionTarget &target, LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) const final { + populateAsyncStructuralTypeConversionsAndLegality(typeConverter, patterns, + target); + } +}; +} // namespace + +void mlir::registerConvertAsyncToLLVMInterface(DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, AsyncDialect *dialect) { + dialect->addInterfaces<AsyncToLLVMDialectInterface>(); + }); +} diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -8,6 +8,7 @@ #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -342,3 +343,32 @@ applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } + +//===----------------------------------------------------------------------===// +// ConvertToLLVMPatternInterface implementation +//===----------------------------------------------------------------------===// + +namespace { +/// Implement the interface to convert Complex to LLVM. +struct ComplexToLLVMDialectInterface : public ConvertToLLVMPatternInterface { + using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; + void loadDependentDialects(MLIRContext *context) const final { + context->loadDialect<LLVM::LLVMDialect>(); + } + + /// Hook for derived dialect interface to provide conversion patterns + /// and mark dialect legal for the conversion target. 
+ void populateConvertToLLVMConversionPatterns( + ConversionTarget &target, LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) const final { + populateComplexToLLVMConversionPatterns(typeConverter, patterns); + } +}; +} // namespace + +void mlir::registerConvertComplexToLLVMInterface(DialectRegistry &registry) { + registry.addExtension( + +[](MLIRContext *ctx, complex::ComplexDialect *dialect) { + dialect->addInterfaces<ComplexToLLVMDialectInterface>(); + }); +} diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" @@ -287,3 +288,34 @@ } }; } // namespace + +//===----------------------------------------------------------------------===// +// ConvertToLLVMPatternInterface implementation +//===----------------------------------------------------------------------===// + +namespace { +/// Implement the interface to convert ControlFlow to LLVM. +struct ControlFlowToLLVMDialectInterface + : public ConvertToLLVMPatternInterface { + using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; + void loadDependentDialects(MLIRContext *context) const final { + context->loadDialect<LLVM::LLVMDialect>(); + } + + /// Hook for derived dialect interface to provide conversion patterns + /// and mark dialect legal for the conversion target. 
+ void populateConvertToLLVMConversionPatterns( + ConversionTarget &target, LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) const final { + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + patterns); + } +}; +} // namespace + +void mlir::cf::registerConvertControlFlowToLLVMInterface( + DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) { + dialect->addInterfaces<ControlFlowToLLVMDialectInterface>(); + }); +} diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -1,5 +1,9 @@ // RUN: mlir-opt -pass-pipeline="builtin.module(func.func(convert-arith-to-llvm))" %s -split-input-file | FileCheck %s +// Same below, but using the `ConvertToLLVMPatternInterface` entry point +// and the generic `convert-to-llvm` pass. +// RUN: mlir-opt --convert-to-llvm --split-input-file %s + // CHECK-LABEL: @vector_ops func.func @vector_ops(%arg0: vector<4xf32>, %arg1: vector<4xi1>, %arg2: vector<4xi64>, %arg3: vector<4xi64>) -> vector<4xf32> { // CHECK-NEXT: %0 = llvm.mlir.constant(dense<4.200000e+01> : vector<4xf32>) : vector<4xf32> diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir --- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir @@ -1,5 +1,9 @@ // RUN: mlir-opt %s -split-input-file -async-to-async-runtime -convert-async-to-llvm='use-opaque-pointers=1' | FileCheck %s +// Same below, but using the `ConvertToLLVMPatternInterface` entry point +// and the generic `convert-to-llvm` pass. 
+// RUN: mlir-opt -async-to-async-runtime --convert-to-llvm --split-input-file %s + // CHECK-LABEL: reference_counting func.func @reference_counting(%arg0: !async.token) { // CHECK: %[[C2:.*]] = arith.constant 2 : i64 diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir --- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir @@ -1,5 +1,9 @@ // RUN: mlir-opt %s -convert-complex-to-llvm | FileCheck %s +// Same below, but using the `ConvertToLLVMPatternInterface` entry point +// and the generic `convert-to-llvm` pass. +// RUN: mlir-opt --convert-to-llvm --split-input-file %s + // CHECK-LABEL: func @complex_create // CHECK-SAME: (%[[REAL0:.*]]: f32, %[[IMAG0:.*]]: f32) // CHECK-NEXT: %[[CPLX0:.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32)> diff --git a/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir b/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir --- a/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir +++ b/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir @@ -1,5 +1,9 @@ // RUN: mlir-opt %s -convert-cf-to-llvm='use-opaque-pointers=1' | FileCheck %s +// Same below, but using the `ConvertToLLVMPatternInterface` entry point +// and the generic `convert-to-llvm` pass. +// RUN: mlir-opt --convert-to-llvm --split-input-file %s + func.func @main() { %a = arith.constant 0 : i1 cf.assert %a, "assertion foo"