diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h --- a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h +++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h @@ -40,13 +40,15 @@ /// Hook for derived dialect interface to provide conversion patterns /// and mark dialect legal for the conversion target. virtual void populateConvertToLLVMConversionPatterns( - ConversionTarget &target, RewritePatternSet &patterns) const = 0; + ConversionTarget &target, LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) const = 0; }; /// Recursively walk the IR and collect all dialects implementing the interface, /// and populate the conversion patterns. void populateConversionTargetFromOperation(Operation *op, ConversionTarget &target, + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h b/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h --- a/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h +++ b/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h @@ -12,6 +12,7 @@ #include namespace mlir { +class DialectRegistry; class Pass; class LLVMTypeConverter; class RewritePatternSet; @@ -23,6 +24,9 @@ /// MemRef dialect to the LLVM dialect. 
void populateFinalizeMemRefToLLVMConversionPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns); + +void registerConvertMemRefToLLVMInterface(DialectRegistry &registry); + } // namespace mlir #endif // MLIR_CONVERSION_MEMREFTOLLVM_MEMREFTOLLVM_H diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -14,6 +14,7 @@ #ifndef MLIR_INITALLEXTENSIONS_H_ #define MLIR_INITALLEXTENSIONS_H_ +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" #include "mlir/Target/LLVM/NVVM/Target.h" @@ -29,6 +30,7 @@ /// pipelines and transformations you are using. inline void registerAllExtensions(DialectRegistry &registry) { func::registerAllExtensions(registry); + registerConvertMemRefToLLVMInterface(registry); registerConvertNVVMToLLVMInterface(registry); registerNVVMTarget(registry); } diff --git a/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt @@ -22,6 +22,7 @@ LINK_LIBS PUBLIC MLIRConvertToLLVMInterface MLIRIR + MLIRLLVMCommonConversion MLIRLLVMDialect MLIRPass MLIRRewrite diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp --- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp +++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -62,6 +63,7 @@ : 
public impl::ConvertToLLVMPassBase { std::shared_ptr patterns; std::shared_ptr target; + std::shared_ptr typeConverter; public: using impl::ConvertToLLVMPassBase::ConvertToLLVMPassBase; @@ -72,23 +74,26 @@ ConvertToLLVMPass(const ConvertToLLVMPass &other) : ConvertToLLVMPassBase(other), patterns(other.patterns), - target(other.target) {} + target(other.target), typeConverter(other.typeConverter) {} LogicalResult initialize(MLIRContext *context) final { RewritePatternSet tempPatterns(context); auto target = std::make_shared(*context); target->addLegalDialect(); + auto typeConverter = std::make_shared(context); for (Dialect *dialect : context->getLoadedDialects()) { // First time we encounter this dialect: if it implements the interface, // let's populate patterns ! auto iface = dyn_cast(dialect); if (!iface) continue; - iface->populateConvertToLLVMConversionPatterns(*target, tempPatterns); + iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter, + tempPatterns); } patterns = std::make_unique(std::move(tempPatterns)); this->target = target; + this->typeConverter = typeConverter; return success(); } diff --git a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp --- a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp +++ b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp @@ -13,9 +13,9 @@ using namespace mlir; -void mlir::populateConversionTargetFromOperation(Operation *root, - ConversionTarget &target, - RewritePatternSet &patterns) { +void mlir::populateConversionTargetFromOperation( + Operation *root, ConversionTarget &target, LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) { DenseSet dialects; root->walk([&](Operation *op) { Dialect *dialect = op->getDialect(); @@ -26,6 +26,7 @@ auto iface = dyn_cast(dialect); if (!iface) return; - iface->populateConvertToLLVMConversionPatterns(target, patterns); + iface->populateConvertToLLVMConversionPatterns(target, 
typeConverter, + patterns); }); } diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Analysis/DataLayoutAnalysis.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -1935,4 +1936,27 @@ signalPassFailure(); } }; + +/// Implement the interface to convert MemRef to LLVM. +struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface { + using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; + void loadDependentDialects(MLIRContext *context) const final { + context->loadDialect<LLVM::LLVMDialect>(); + } + + /// Hook for derived dialect interface to provide conversion patterns + /// and mark dialect legal for the conversion target. + void populateConvertToLLVMConversionPatterns( + ConversionTarget &target, LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) const final { + populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns); + } +}; + } // namespace + +void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { + dialect->addInterfaces<MemRefToLLVMDialectInterface>(); + }); +} diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp --- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp +++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp @@ -201,7 +201,8 @@ /// Hook for derived dialect interface to provide conversion patterns /// and mark dialect legal for the conversion target. 
void populateConvertToLLVMConversionPatterns( - ConversionTarget &target, RewritePatternSet &patterns) const final { + ConversionTarget &target, LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) const final { populateNVVMToLLVMConversionPatterns(patterns); } }; diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -1,6 +1,12 @@ // RUN: mlir-opt -finalize-memref-to-llvm='use-opaque-pointers=1' %s -split-input-file | FileCheck %s // RUN: mlir-opt -finalize-memref-to-llvm='index-bitwidth=32 use-opaque-pointers=1' %s -split-input-file | FileCheck --check-prefix=CHECK32 %s +// Same below, but using the `ConvertToLLVMPatternInterface` entry point +// and the generic `convert-to-llvm` pass. This produces slightly different IR +// because the conversion target is set up differently. Only one test case is +// checked. 
+// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck --check-prefix=CHECK-INTERFACE %s + // CHECK-LABEL: func @view( // CHECK: %[[ARG0F:.*]]: index, %[[ARG1F:.*]]: index, %[[ARG2F:.*]]: index func.func @view(%arg0 : index, %arg1 : index, %arg2 : index) { @@ -88,6 +94,10 @@ // CHECK-LABEL: func @view_empty_memref( // CHECK: %[[ARG0:.*]]: index, // CHECK: %[[ARG1:.*]]: memref<0xi8>) + +// CHECK-INTERFACE-LABEL: func @view_empty_memref( +// CHECK-INTERFACE: %[[ARG0:.*]]: index, +// CHECK-INTERFACE: %[[ARG1:.*]]: memref<0xi8>) func.func @view_empty_memref(%offset: index, %mem: memref<0xi8>) { // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> @@ -101,6 +111,18 @@ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.mlir.constant(4 : index) : i64 // CHECK: = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + + // CHECK-INTERFACE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-INTERFACE: llvm.mlir.constant(0 : index) : i64 + // CHECK-INTERFACE: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-INTERFACE: llvm.mlir.constant(4 : index) : i64 + // CHECK-INTERFACE: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-INTERFACE: llvm.mlir.constant(1 : index) : i64 + // CHECK-INTERFACE: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-INTERFACE: llvm.mlir.constant(0 : index) : i64 + // CHECK-INTERFACE: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-INTERFACE: llvm.mlir.constant(4 : index) : i64 + // CHECK-INTERFACE: = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x 
i64>, array<2 x i64>)> %0 = memref.view %mem[%offset][] : memref<0xi8> to memref<0x4xf32> return