Restored from a whitespace-collapsed copy of this patch. Repairs made:
  * `®istry` (a decoded `&reg` entity) is written back as `&registry`; the
    function bodies refer to the parameter as `registry`.
  * `addInterfaces<...DialectInterface>()` names the struct defined directly
    above each call; the template argument had been stripped.
NOTE(review): `loadDialect<LLVM::LLVMDialect>()`, `#include <memory>` and
`patterns.add<PoisonOpLowering>(converter)` could not be recovered from the
collapsed text and are presumed from the surrounding code. Blank context lines
and indentation were re-derived from the hunk line counts. Confirm against the
target tree before applying.

diff --git a/mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h b/mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h
--- a/mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h
+++ b/mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h
@@ -12,6 +12,7 @@
 #include <memory>
 
 namespace mlir {
+class DialectRegistry;
 class LLVMTypeConverter;
 class RewritePatternSet;
 class Pass;
@@ -22,6 +23,9 @@
 namespace index {
 void populateIndexToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                            RewritePatternSet &patterns);
+
+void registerConvertIndexToLLVMInterface(DialectRegistry &registry);
+
 } // namespace index
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
--- a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
+++ b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
@@ -13,6 +13,7 @@
 
 namespace mlir {
 
+class DialectRegistry;
 class LLVMTypeConverter;
 class RewritePatternSet;
 class Pass;
@@ -23,6 +24,9 @@
 void populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                           RewritePatternSet &patterns,
                                           bool approximateLog1p = true);
+
+void registerConvertMathToLLVMInterface(DialectRegistry &registry);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H
diff --git a/mlir/include/mlir/Conversion/UBToLLVM/UBToLLVM.h b/mlir/include/mlir/Conversion/UBToLLVM/UBToLLVM.h
--- a/mlir/include/mlir/Conversion/UBToLLVM/UBToLLVM.h
+++ b/mlir/include/mlir/Conversion/UBToLLVM/UBToLLVM.h
@@ -13,6 +13,7 @@
 
 namespace mlir {
 
+class DialectRegistry;
 class LLVMTypeConverter;
 class RewritePatternSet;
 class Pass;
@@ -23,6 +24,8 @@
 namespace ub {
 void populateUBToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                         RewritePatternSet &patterns);
+
+void registerConvertUBToLLVMInterface(DialectRegistry &registry);
 } // namespace ub
 } // namespace mlir
 
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -17,8 +17,11 @@
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
+#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Dialect/Func/Extensions/AllExtensions.h"
 #include "mlir/Target/LLVM/NVVM/Target.h"
 
@@ -36,8 +39,11 @@
   registerConvertComplexToLLVMInterface(registry);
   cf::registerConvertControlFlowToLLVMInterface(registry);
   func::registerAllExtensions(registry);
+  index::registerConvertIndexToLLVMInterface(registry);
+  registerConvertMathToLLVMInterface(registry);
   registerConvertMemRefToLLVMInterface(registry);
   registerConvertNVVMToLLVMInterface(registry);
+  ub::registerConvertUBToLLVMInterface(registry);
   registerNVVMTarget(registry);
 }
 
diff --git a/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp
--- a/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp
+++ b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/Index/IR/IndexAttrs.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
@@ -365,3 +367,32 @@
           applyPartialConversion(getOperation(), target, std::move(patterns))))
     return signalPassFailure();
 }
+
+//===----------------------------------------------------------------------===//
+// ConvertToLLVMPatternInterface implementation
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Implement the interface to convert Index to LLVM.
+struct IndexToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
+  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<LLVM::LLVMDialect>();
+  }
+
+  /// Hook for derived dialect interface to provide conversion patterns
+  /// and mark dialect legal for the conversion target.
+  void populateConvertToLLVMConversionPatterns(
+      ConversionTarget &target, LLVMTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    populateIndexToLLVMConversionPatterns(typeConverter, patterns);
+  }
+};
+} // namespace
+
+void mlir::index::registerConvertIndexToLLVMInterface(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, index::IndexDialect *dialect) {
+    dialect->addInterfaces<IndexToLLVMDialectInterface>();
+  });
+}
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 
 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
@@ -331,3 +332,31 @@
   >(converter);
   // clang-format on
 }
+
+//===----------------------------------------------------------------------===//
+// ConvertToLLVMPatternInterface implementation
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Implement the interface to convert Math to LLVM.
+struct MathToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
+  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<LLVM::LLVMDialect>();
+  }
+
+  /// Hook for derived dialect interface to provide conversion patterns
+  /// and mark dialect legal for the conversion target.
+  void populateConvertToLLVMConversionPatterns(
+      ConversionTarget &target, LLVMTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    populateMathToLLVMConversionPatterns(typeConverter, patterns);
+  }
+};
+} // namespace
+
+void mlir::registerConvertMathToLLVMInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, math::MathDialect *dialect) {
+    dialect->addInterfaces<MathToLLVMDialectInterface>();
+  });
+}
diff --git a/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp b/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp
--- a/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp
+++ b/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -91,3 +92,31 @@
                                                   RewritePatternSet &patterns) {
   patterns.add<PoisonOpLowering>(converter);
 }
+
+//===----------------------------------------------------------------------===//
+// ConvertToLLVMPatternInterface implementation
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Implement the interface to convert UB to LLVM.
+struct UBToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
+  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<LLVM::LLVMDialect>();
+  }
+
+  /// Hook for derived dialect interface to provide conversion patterns
+  /// and mark dialect legal for the conversion target.
+  void populateConvertToLLVMConversionPatterns(
+      ConversionTarget &target, LLVMTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    ub::populateUBToLLVMConversionPatterns(typeConverter, patterns);
+  }
+};
+} // namespace
+
+void mlir::ub::registerConvertUBToLLVMInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, ub::UBDialect *dialect) {
+    dialect->addInterfaces<UBToLLVMDialectInterface>();
+  });
+}
diff --git a/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir b/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir
--- a/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir
+++ b/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir
@@ -2,6 +2,10 @@
 // RUN: mlir-opt %s -convert-index-to-llvm=index-bitwidth=32 | FileCheck %s --check-prefix=INDEX32
 // RUN: mlir-opt %s -convert-index-to-llvm=index-bitwidth=64 | FileCheck %s --check-prefix=INDEX64
 
+// Same below, but using the `ConvertToLLVMPatternInterface` entry point
+// and the generic `convert-to-llvm` pass.
+// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
+
 // CHECK-LABEL: @trivial_ops
 func.func @trivial_ops(%a: index, %b: index) {
   // CHECK: llvm.add
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s -split-input-file -pass-pipeline="builtin.module(func.func(convert-math-to-llvm))" | FileCheck %s
 
+// Same below, but using the `ConvertToLLVMPatternInterface` entry point
+// and the generic `convert-to-llvm` pass.
+// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
+
 // CHECK-LABEL: @ops
 func.func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64) {
   // CHECK: = llvm.intr.exp(%{{.*}}) : (f32) -> f32
diff --git a/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir b/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir
--- a/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir
+++ b/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt -pass-pipeline="builtin.module(func.func(convert-ub-to-llvm))" %s -split-input-file | FileCheck %s
 
+// Same below, but using the `ConvertToLLVMPatternInterface` entry point
+// and the generic `convert-to-llvm` pass.
+// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
+
 // CHECK-LABEL: @check_poison
 func.func @check_poison() {
   // CHECK: {{.*}} = llvm.mlir.poison : i64