diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h @@ -0,0 +1,54 @@ +//===- ToLLVMInterface.h - Conversion to LLVM iface ---*- C++ -*-=============// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_H +#define MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_H + +#include "mlir/IR/DialectInterface.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +class ConversionTarget; +class LLVMTypeConverter; +class MLIRContext; +class Operation; +class RewritePatternSet; + +/// Base class for dialect interfaces providing conversion patterns to the +/// LLVM dialect, consumed by the generic `convert-to-llvm` pass. Dialects +/// whose operations can be converted should provide an implementation of +/// this interface. The interface may be implemented in a separate library to +/// avoid the "main" dialect library depending on the conversion code. It can +/// be attached using the delayed registration mechanism in DialectRegistry. +class ConvertToLLVMPatternInterface + : public DialectInterface::Base<ConvertToLLVMPatternInterface> { +public: + explicit ConvertToLLVMPatternInterface(Dialect *dialect) : Base(dialect) {} + + /// Hook for derived dialect interface to load the dialects they + /// target. The LLVMDialect is implicitly already loaded, but this + /// method allows loading other intermediate dialects used in the + /// conversion, or target dialects like NVVM for example.
+ virtual void loadDependentDialects(MLIRContext *context) const {} + + /// Hook for derived dialect interface to provide conversion patterns + /// and, if needed, to update legality on the conversion target. + virtual void populateConvertToLLVMConversionPatterns( + ConversionTarget &target, RewritePatternSet &patterns) const = 0; +}; + +/// Walk `op` and all IR nested in it; for each distinct dialect implementing +/// the interface, let it populate the conversion patterns and target. +void populateConversionTargetFromOperation(Operation *op, + ConversionTarget &target, + RewritePatternSet &patterns); + +} // namespace mlir + +#endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_H diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h @@ -0,0 +1,23 @@ +//===- ToLLVMPass.h - Conversion to LLVM pass ---*- C++ -*-===================// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_CONVERTTOLLVM_TOLLVM_PASS_H +#define MLIR_CONVERSION_CONVERTTOLLVM_TOLLVM_PASS_H + +#include <memory> + +namespace mlir { +class Pass; + +/// Create a pass that performs dialect conversion to LLVM for all dialects +/// implementing `ConvertToLLVMPatternInterface`.
+std::unique_ptr<Pass> createConvertToLLVMPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVM_PASS_H diff --git a/mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h b/mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h --- a/mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h +++ b/mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h @@ -11,7 +11,7 @@ #include <memory> namespace mlir { - +class DialectRegistry; class LLVMTypeConverter; class RewritePatternSet; class Pass; @@ -21,6 +21,8 @@ void populateNVVMToLLVMConversionPatterns(RewritePatternSet &patterns); +void registerConvertNVVMToLLVMInterface(DialectRegistry &registry); + } // namespace mlir #endif // MLIR_CONVERSION_NVVMTOLLVM_NVVMTOLLVMPASS_H_ diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -24,6 +24,7 @@ #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h" #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h" #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -11,6 +11,20 @@ include "mlir/Pass/PassBase.td" + +//===----------------------------------------------------------------------===// +// ToLLVM +//===----------------------------------------------------------------------===// + +def ConvertToLLVMPass : Pass<"convert-to-llvm"> { + let summary = "Convert to LLVM via dialect interfaces found in the input IR"; + let description = [{ + This is a generic pass to convert to LLVM, it uses the
+ `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects + the injection of conversion patterns. + }]; +} + //===----------------------------------------------------------------------===// // AffineToStandard //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h --- a/mlir/include/mlir/IR/DialectRegistry.h +++ b/mlir/include/mlir/IR/DialectRegistry.h @@ -44,7 +44,8 @@ virtual ~DialectExtensionBase(); - /// Return the dialects that our required by this extension to be loaded - /// before applying. + /// Return the dialects that are required by this extension to be loaded + /// before applying. If empty, the extension applies to every loaded dialect, + /// including dialects loaded later. ArrayRef<StringRef> getRequiredDialects() const { return dialectNames; } /// Apply this extension to the given context and the required dialects. @@ -55,12 +56,11 @@ virtual std::unique_ptr<DialectExtensionBase> clone() const = 0; protected: - /// Initialize the extension with a set of required dialects. Note that there - /// should always be at least one affected dialect. + /// Initialize the extension with a set of required dialects. + /// If the list is empty, the extension is applied to every loaded dialect + /// instead of waiting for a specific set of dialects to be loaded. DialectExtensionBase(ArrayRef<StringRef> dialectNames) - : dialectNames(dialectNames.begin(), dialectNames.end()) { - assert(!dialectNames.empty() && "expected at least one affected dialect"); - } + : dialectNames(dialectNames.begin(), dialectNames.end()) {} private: /// The names of the dialects affected by this extension.
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -14,6 +14,7 @@ #ifndef MLIR_INITALLEXTENSIONS_H_ #define MLIR_INITALLEXTENSIONS_H_ +#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" #include <cstdlib> @@ -27,6 +28,7 @@ /// pipelines and transformations you are using. inline void registerAllExtensions(DialectRegistry &registry) { func::registerAllExtensions(registry); + registerConvertNVVMToLLVMInterface(registry); } } // namespace mlir diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -13,6 +13,7 @@ add_subdirectory(ComplexToStandard) add_subdirectory(ControlFlowToLLVM) add_subdirectory(ControlFlowToSPIRV) +add_subdirectory(ConvertToLLVM) add_subdirectory(FuncToLLVM) add_subdirectory(FuncToSPIRV) add_subdirectory(GPUCommon) diff --git a/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt @@ -0,0 +1,31 @@ +set(LLVM_OPTIONAL_SOURCES + ConvertToLLVMPass.cpp + ToLLVMInterface.cpp +) + +add_mlir_library(MLIRConvertToLLVMInterface + ToLLVMInterface.cpp + + DEPENDS + + LINK_LIBS PUBLIC + MLIRIR + MLIRSupport +) + +add_mlir_conversion_library(MLIRConvertToLLVMPass + ConvertToLLVMPass.cpp + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRConvertToLLVMInterface + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRPass + MLIRRewrite + MLIRSupport + MLIRTransformUtils + ) diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp @@ -0,0 +1,102 @@ +//===- ConvertToLLVMPass.cpp - MLIR LLVM Conversion -----------------------===// +// +// Part of
the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/Debug.h" +#include <memory> + +#define DEBUG_TYPE "convert-to-llvm" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTTOLLVMPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { + +/// This DialectExtension can be attached to the context, which will invoke the +/// `apply()` method for every loaded dialect. If a dialect implements the +/// `ConvertToLLVMPatternInterface` interface, we load dependent dialects +/// through the interface. The extension is added to the registry by +/// `ConvertToLLVMPass::getDependentDialects()`. +class LoadDependentDialectExtension : public DialectExtensionBase { +public: + LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {} + + void apply(MLIRContext *context, + MutableArrayRef<Dialect *> dialects) const final { + LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM extension load\n"); + for (Dialect *dialect : dialects) { + auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); + if (!iface) + continue; + LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM found dialect interface for " + << dialect->getNamespace() << "\n"); + iface->loadDependentDialects(context); + } + } + + /// Return a copy of this extension.
+ std::unique_ptr<DialectExtensionBase> clone() const final { + return std::make_unique<LoadDependentDialectExtension>(*this); + } +}; + +/// Generic pass converting to LLVM: it uses `ConvertToLLVMPatternInterface` +/// to let each loaded dialect inject its conversion patterns. Patterns and +/// target are built once in `initialize()` and shared across pass clones. +class ConvertToLLVMPass + : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> { + std::shared_ptr<const FrozenRewritePatternSet> patterns; + std::shared_ptr<const ConversionTarget> target; + +public: + using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase; + void getDependentDialects(DialectRegistry &registry) const final { + registry.insert<LLVM::LLVMDialect>(); + registry.addExtension(std::make_unique<LoadDependentDialectExtension>()); + } + + ConvertToLLVMPass(const ConvertToLLVMPass &other) + : ConvertToLLVMPassBase(other), patterns(other.patterns), + target(other.target) {} + + LogicalResult initialize(MLIRContext *context) final { + RewritePatternSet tempPatterns(context); + auto tempTarget = std::make_shared<ConversionTarget>(*context); + tempTarget->addLegalDialect<LLVM::LLVMDialect>(); + for (Dialect *dialect : context->getLoadedDialects()) { + // Let every loaded dialect implementing the interface contribute its + // conversion patterns (and legality rules) to the shared set. + auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); + if (!iface) + continue; + iface->populateConvertToLLVMConversionPatterns(*tempTarget, tempPatterns); + } + patterns = + std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns)); + target = std::move(tempTarget); + return success(); + } + + void runOnOperation() final { + if (failed(applyPartialConversion(getOperation(), *target, *patterns))) + signalPassFailure(); + } +}; + +} // namespace diff --git a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp @@ -0,0 +1,31 @@ +//===- ToLLVMInterface.cpp - MLIR LLVM Conversion -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" +#include "llvm/ADT/DenseSet.h" + +using namespace mlir; + +void mlir::populateConversionTargetFromOperation(Operation *root, + ConversionTarget &target, + RewritePatternSet &patterns) { + DenseSet<Dialect *> dialects; + root->walk([&](Operation *op) { + Dialect *dialect = op->getDialect(); + // Skip ops whose dialect is not loaded (`getDialect()` returns null and + // `dyn_cast` must not see null), and dialects that were already handled. + if (!dialect || !dialects.insert(dialect).second) + return; + auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); + if (!iface) + return; + iface->populateConvertToLLVMConversionPatterns(target, patterns); + }); +} diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp --- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp +++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -190,8 +191,29 @@ } }; +/// Implement the interface to convert NVVM to LLVM. +struct NVVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface { + using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; + void loadDependentDialects(MLIRContext *context) const final { + context->loadDialect<NVVMDialect>(); + } + + /// Provide the NVVM to LLVM conversion patterns. The conversion target is + /// left unchanged.
+ void populateConvertToLLVMConversionPatterns( + ConversionTarget &target, RewritePatternSet &patterns) const final { + populateNVVMToLLVMConversionPatterns(patterns); + } +}; + } // namespace void mlir::populateNVVMToLLVMConversionPatterns(RewritePatternSet &patterns) { patterns.add<PtxLowering>(patterns.getContext()); } + +void mlir::registerConvertNVVMToLLVMInterface(DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) { + dialect->addInterfaces<NVVMToLLVMDialectInterface>(); + }); +} diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" @@ -721,6 +722,7 @@ // Support unknown operations because not all NVVM operations are // registered. allowUnknownOperations(); + declarePromisedInterface<ConvertToLLVMPatternInterface>(); } LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op, diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -209,6 +209,11 @@ // Functor used to try to apply the given extension. auto applyExtension = [&](const DialectExtensionBase &extension) { ArrayRef<StringRef> dialectNames = extension.getRequiredDialects(); + // An empty set means the extension is applied to every dialect on load. + if (dialectNames.empty()) { + extension.apply(ctx, dialect); + return; + } // Handle the simple case of a single dialect name. In this case, the // required dialect should be the current dialect. @@ -251,6 +256,12 @@ // Functor used to try to apply the given extension.
 auto applyExtension = [&](const DialectExtensionBase &extension) { ArrayRef<StringRef> dialectNames = extension.getRequiredDialects(); + // An empty set means the extension applies to all loaded dialects at once. + if (dialectNames.empty()) { + auto loadedDialects = ctx->getLoadedDialects(); + extension.apply(ctx, loadedDialects); + return; + } // Check to see if all of the dialects for this extension are loaded. SmallVector requiredDialects; diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -1,4 +1,7 @@ // RUN: mlir-opt --convert-nvvm-to-llvm --split-input-file %s | FileCheck %s +// Same as the RUN line above, but using the `ConvertToLLVMPatternInterface` +// entry point and the generic `convert-to-llvm` pass. +// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s // CHECK-LABEL : @init_mbarrier_arrive_expect_tx llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) {