diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -378,7 +378,8 @@ # !if(!gt(numResults, 0), "$res = inst;", ""); string mlirBuilder = [{ - FailureOr> mlirOperands = convertValues(llvmOperands); + FailureOr> mlirOperands = + moduleImport.convertValues(llvmOperands); if (failed(mlirOperands)) return failure(); SmallVector resultTypes = @@ -386,7 +387,7 @@ auto op = $_builder.create<$_qualCppClassName>( $_location, resultTypes, *mlirOperands); }] # !if(!gt(requiresFastmath, 0), - "setFastmathFlagsAttr(inst, op);", "") + "moduleImport.setFastmathFlagsAttr(inst, op);", "") # !if(!gt(numResults, 0), "$res = op;", "(void)op;"); } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -62,7 +62,7 @@ let arguments = !con(commonArgs, fmfArg); string mlirBuilder = [{ auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); - setFastmathFlagsAttr(inst, op); + moduleImport.setFastmathFlagsAttr(inst, op); $res = op; }]; } @@ -82,7 +82,7 @@ string llvmInstName = instName; string mlirBuilder = [{ auto op = $_builder.create<$_qualCppClassName>($_location, $operand); - setFastmathFlagsAttr(inst, op); + moduleImport.setFastmathFlagsAttr(inst, op); $res = op; }]; } @@ -157,7 +157,7 @@ auto *fCmpInst = cast(inst); auto op = $_builder.create<$_qualCppClassName>( $_location, getFCmpPredicate(fCmpInst->getPredicate()), $lhs, $rhs); - setFastmathFlagsAttr(inst, op); + moduleImport.setFastmathFlagsAttr(inst, op); $res = op; }]; // Set the $predicate index to -1 to indicate there is no matching operand @@ -227,7 +227,8 @@ // FIXME: Import attributes.
string mlirBuilder = [{ auto *allocaInst = cast(inst); - Type allocatedType = convertType(allocaInst->getAllocatedType()); + Type allocatedType = + moduleImport.convertType(allocaInst->getAllocatedType()); unsigned alignment = allocaInst->getAlign().value(); $res = $_builder.create( $_location, $_resultType, allocatedType, $arraySize, alignment); @@ -825,7 +826,8 @@ builder.CreateRetVoid(); }]; string mlirBuilder = [{ - FailureOr> mlirOperands = convertValues(llvmOperands); + FailureOr> mlirOperands = + moduleImport.convertValues(llvmOperands); if (failed(mlirOperands)) return failure(); $_builder.create($_location, *mlirOperands); diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h --- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h @@ -17,6 +17,7 @@ #include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMImport.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h" @@ -40,6 +41,13 @@ registerROCDLDialectTranslation(registry); registerX86VectorDialectTranslation(registry); } + +/// Registers all dialects that can be translated from LLVM IR and the +/// corresponding translation interfaces.
+static inline void +registerAllFromLLVMIRTranslations(DialectRegistry ®istry) { + registerLLVMDialectImport(registry); +} } // namespace mlir #endif // MLIR_TARGET_LLVMIR_DIALECT_ALL_H diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMImport.h b/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMImport.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMImport.h @@ -0,0 +1,31 @@ +//===- LLVMIRToLLVMImport.h - LLVM IR to LLVM Dialect -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides registration calls for the LLVM IR to LLVM dialect import. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_DIALECT_LLVMIR_LLVMIRTOLLVMIMPORT_H +#define MLIR_TARGET_LLVMIR_DIALECT_LLVMIR_LLVMIRTOLLVMIMPORT_H + +namespace mlir { + +class DialectRegistry; +class MLIRContext; + +/// Registers the LLVM dialect and its import from LLVM IR in the given +/// registry. +void registerLLVMDialectImport(DialectRegistry ®istry); + +/// Registers the LLVM dialect and its import from LLVM IR with the given +/// context. +void registerLLVMDialectImport(MLIRContext &context); + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_DIALECT_LLVMIR_LLVMIRTOLLVMIMPORT_H diff --git a/mlir/include/mlir/Target/LLVMIR/Import.h b/mlir/include/mlir/Target/LLVMIR/Import.h --- a/mlir/include/mlir/Target/LLVMIR/Import.h +++ b/mlir/include/mlir/Target/LLVMIR/Import.h @@ -30,10 +30,11 @@ class MLIRContext; class ModuleOp; -/// Convert the given LLVM module into MLIR's LLVM dialect. The LLVM context is -/// extracted from the registered LLVM IR dialect.
In case of error, report it -/// to the error handler registered with the MLIR context, if any (obtained from -/// the MLIR module), and return `{}`. +/// Translates the LLVM module into an MLIR module living in the given +/// context. The translation supports operations from any dialect that has a +/// registered implementation of the LLVMImportDialectInterface. It returns +/// nullptr if the translation fails and reports errors using the error handler +/// registered with the MLIR context. OwningOpRef translateLLVMIRToModule(std::unique_ptr llvmModule, MLIRContext *context); diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h @@ -0,0 +1,117 @@ +//===- LLVMImportInterface.h - Import from LLVM interface -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines dialect interfaces for the LLVM IR import.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_LLVMIMPORTINTERFACE_H +#define MLIR_TARGET_LLVMIR_LLVMIMPORTINTERFACE_H + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectInterface.h" +#include "mlir/IR/Location.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/Support/FormatVariadic.h" + +namespace llvm { +class IRBuilderBase; +} // namespace llvm + +namespace mlir { +namespace LLVM { +class ModuleImport; +} // namespace LLVM + +/// Base class for dialect interfaces used to import LLVM IR. Dialects that can +/// be imported should provide an implementation of this interface for the +/// supported intrinsics. The interface may be implemented in a separate library +/// to avoid the "main" dialect library depending on LLVM IR. The interface can +/// be attached using the delayed registration mechanism available in +/// DialectRegistry. +class LLVMImportDialectInterface + : public DialectInterface::Base { +public: + LLVMImportDialectInterface(Dialect *dialect) : Base(dialect) {} + + /// Hook for derived dialect interfaces to implement the import of + /// intrinsics into MLIR. + virtual LogicalResult + convertIntrinsic(OpBuilder &builder, llvm::CallInst *inst, + LLVM::ModuleImport &moduleImport) const { + return failure(); + } + + /// Hook for derived dialect interfaces to publish the supported intrinsics. + virtual SmallVector getSupportedIntrinsics() const { return {}; } +}; + +/// Interface collection for the import of LLVM IR that dispatches to a concrete +/// dialect interface implementation. Queries the dialect interfaces to obtain a +/// list of the supported LLVM IR constructs and then builds a mapping for the +/// efficient dispatch.
+class LLVMImportInterface + : public DialectInterfaceCollection { +public: + using Base::Base; + + /// Builds the map from intrinsic identifiers to the dialect interface that + /// supports importing them, discarding any previously built mapping. Returns + /// failure if two dialect interfaces claim the same intrinsic. + LogicalResult querySupportedIntrinsics() { + intrinsicToDialect.clear(); + for (const LLVMImportDialectInterface &iface : *this) { + for (unsigned id : iface.getSupportedIntrinsics()) { + if (intrinsicToDialect.count(id)) { + Location loc = UnknownLoc::get(iface.getContext()); + return emitError( + loc, llvm::formatv( + "found that the {0} and {1} dialect interfaces provide " + "conflicting conversions for the intrinsic {2}", + iface.getDialect()->getNamespace(), + intrinsicToDialect.lookup(id)->getNamespace(), id)); + } + intrinsicToDialect[id] = iface.getDialect(); + } + } + return success(); + } + + /// Converts the LLVM intrinsic to an MLIR operation if a conversion exists. + /// Returns failure otherwise. + LogicalResult convertIntrinsic(OpBuilder &builder, llvm::CallInst *inst, + LLVM::ModuleImport &moduleImport) const { + // Lookup the dialect interface for the given intrinsic. + Dialect *dialect = intrinsicToDialect.lookup(inst->getIntrinsicID()); + if (!dialect) + return failure(); + + // Dispatch the conversion to the dialect interface. + const LLVMImportDialectInterface *iface = getInterfaceFor(dialect); + assert(iface && "expected to find a dialect interface"); + return iface->convertIntrinsic(builder, inst, moduleImport); + } + + /// Returns true if the given LLVM IR intrinsic is convertible to an MLIR + /// operation.
+ bool isConvertibleIntrinsic(llvm::Intrinsic::ID id) const { + return intrinsicToDialect.count(id); + } + +private: + DenseMap intrinsicToDialect; +}; + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_LLVMIMPORTINTERFACE_H diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -17,6 +17,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Target/LLVMIR/Import.h" +#include "mlir/Target/LLVMIR/LLVMImportInterface.h" #include "mlir/Target/LLVMIR/TypeFromLLVM.h" namespace llvm { @@ -44,6 +45,17 @@ public: ModuleImport(ModuleOp mlirModule, std::unique_ptr llvmModule); + /// Queries the dialect interfaces for the supported LLVM IR intrinsics. + LogicalResult querySupportedIntrinsics() { + return iface.querySupportedIntrinsics(); + } + + /// Converts all functions of the LLVM module to MLIR functions. + LogicalResult convertFunctions(); + + /// Converts all global variables of the LLVM module to MLIR global variables. + LogicalResult convertGlobals(); + /// Stores the mapping between an LLVM value and its MLIR counterpart. void mapValue(llvm::Value *llvm, Value mlir) { mapValue(llvm) = mlir; } @@ -95,16 +107,6 @@ return typeTranslator.translateType(type); } - /// Converts an LLVM intrinsic to an MLIR LLVM dialect operation if an MLIR - /// counterpart exists. Otherwise, returns failure. - LogicalResult convertIntrinsic(OpBuilder &odsBuilder, llvm::CallInst *inst, - llvm::Intrinsic::ID intrinsicID); - - /// Converts an LLVM instruction to an MLIR LLVM dialect operation if an MLIR - /// counterpart exists. Otherwise, returns failure. - LogicalResult convertOperation(OpBuilder &odsBuilder, - llvm::Instruction *inst); - /// Imports `func` into the current module.
LogicalResult processFunction(llvm::Function *func); @@ -115,11 +117,10 @@ /// Imports `globalVar` as a GlobalOp, creating it if it doesn't exist. GlobalOp processGlobal(llvm::GlobalVariable *globalVar); - /// Converts all functions of the LLVM module to MLIR functions. - LogicalResult convertFunctions(); - - /// Converts all global variables of the LLVM module to MLIR global variables. - LogicalResult convertGlobals(); + /// Sets the fastmath flags attribute for the imported operation `op` given + /// the original instruction `inst`. Asserts if the operation does not + /// implement the fastmath interface. + void setFastmathFlagsAttr(llvm::Instruction *inst, Operation *op) const; private: /// Clears the block and value mapping before processing a new region. @@ -133,14 +134,17 @@ constantInsertionOp = nullptr; } - /// Sets the fastmath flags attribute for the imported operation `op` given - /// the original instruction `inst`. Asserts if the operation does not - /// implement the fastmath interface. - void setFastmathFlagsAttr(llvm::Instruction *inst, Operation *op) const; /// Returns personality of `func` as a FlatSymbolRefAttr. FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *func); /// Imports `bb` into `block`, which must be initially empty. LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block); + /// Converts an LLVM intrinsic to an MLIR LLVM dialect operation if an MLIR + /// counterpart exists. Otherwise, returns failure. + LogicalResult convertIntrinsic(OpBuilder &odsBuilder, llvm::CallInst *inst); + /// Converts an LLVM instruction to an MLIR LLVM dialect operation if an MLIR + /// counterpart exists. Otherwise, returns failure. + LogicalResult convertInstruction(OpBuilder &odsBuilder, + llvm::Instruction *inst); /// Imports `inst` and populates valueMapping[inst] with the result of the /// imported operation. LogicalResult processInstruction(llvm::Instruction *inst); @@ -192,6 +196,10 @@ /// The LLVM module being imported.
std::unique_ptr llvmModule; + /// A dialect interface collection used for dispatching the import to specific + /// dialects. + LLVMImportInterface iface; + /// Function-local mapping between original and imported block. DenseMap blockMapping; /// Function-local mapping between original and imported values. diff --git a/mlir/include/mlir/Tools/mlir-translate/Translation.h b/mlir/include/mlir/Tools/mlir-translate/Translation.h --- a/mlir/include/mlir/Tools/mlir-translate/Translation.h +++ b/mlir/include/mlir/Tools/mlir-translate/Translation.h @@ -104,14 +104,20 @@ TranslateToMLIRRegistration( llvm::StringRef name, llvm::StringRef description, const TranslateSourceMgrToMLIRFunction &function, + const std::function &dialectRegistration = + [](DialectRegistry &) {}, Optional inputAlignment = std::nullopt); TranslateToMLIRRegistration( llvm::StringRef name, llvm::StringRef description, const TranslateRawSourceMgrToMLIRFunction &function, + const std::function &dialectRegistration = + [](DialectRegistry &) {}, Optional inputAlignment = std::nullopt); TranslateToMLIRRegistration( llvm::StringRef name, llvm::StringRef description, const TranslateStringRefToMLIRFunction &function, + const std::function &dialectRegistration = + [](DialectRegistry &) {}, Optional inputAlignment = std::nullopt); }; diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -51,7 +51,6 @@ ) add_mlir_translation_library(MLIRTargetLLVMIRImport - ConvertFromLLVMIR.cpp DebugImporter.cpp ModuleImport.cpp TypeFromLLVM.cpp @@ -68,3 +67,10 @@ MLIRLLVMDialect MLIRTranslateLib ) + +add_mlir_translation_library(MLIRFromLLVMIRTranslationRegistration + ConvertFromLLVMIR.cpp + + LINK_LIBS PUBLIC + MLIRLLVMIRToLLVMImport + ) diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++
b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -11,7 +11,9 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Target/LLVMIR/Dialect/All.h" #include "mlir/Target/LLVMIR/Import.h" #include "mlir/Tools/mlir-translate/Translation.h" #include "llvm/IR/Module.h" @@ -39,6 +41,10 @@ return {}; } return translateLLVMIRToModule(std::move(llvmModule), context); + }, + [](DialectRegistry ®istry) { + registry.insert(); + registerAllFromLLVMIRTranslations(registry); }); } } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/CMakeLists.txt @@ -1,3 +1,21 @@ +set(LLVM_OPTIONAL_SOURCES + LLVMIRToLLVMImport.cpp + LLVMToLLVMIRTranslation.cpp + ) + +add_mlir_translation_library(MLIRLLVMIRToLLVMImport + LLVMIRToLLVMImport.cpp + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMDialect + MLIRSupport + MLIRTargetLLVMIRImport + ) + add_mlir_translation_library(MLIRLLVMToLLVMIRTranslation LLVMToLLVMIRTranslation.cpp diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMImport.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMImport.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMImport.cpp @@ -0,0 +1,96 @@ +//===- LLVMIRToLLVMImport.cpp - Translate LLVM IR to LLVM dialect ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the import of LLVM IR into the MLIR LLVM dialect.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMImport.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/ModuleImport.h" + +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/InlineAsm.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" + +using namespace mlir; +using namespace mlir::LLVM; +using namespace mlir::LLVM::detail; + +#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc" + +/// Returns true if the LLVM IR intrinsic is convertible to an MLIR LLVM dialect +/// intrinsic. Returns false otherwise. +static bool isConvertibleIntrinsic(llvm::Intrinsic::ID id) { + static const DenseSet convertibleIntrinsics = { +#include "mlir/Dialect/LLVMIR/LLVMConvertibleLLVMIRIntrinsics.inc" + }; + return convertibleIntrinsics.contains(id); +} + +/// Converts the LLVM intrinsic to an MLIR LLVM dialect operation if a +/// conversion exists. Returns failure otherwise. +static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder, + llvm::CallInst *inst, + LLVM::ModuleImport &moduleImport) { + llvm::Intrinsic::ID intrinsicID = inst->getIntrinsicID(); + + // Check if the intrinsic is convertible to an MLIR dialect counterpart and + // copy the arguments to an LLVM operands array reference for conversion. + if (isConvertibleIntrinsic(intrinsicID)) { + SmallVector args(inst->args()); + ArrayRef llvmOperands(args); +#include "mlir/Dialect/LLVMIR/LLVMIntrinsicFromLLVMIRConversions.inc" + } + + return failure(); +} + +namespace { + +/// Implementation of the dialect interface that converts LLVM IR intrinsics
+/// into operations of the MLIR LLVM dialect. +class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface { +public: + using LLVMImportDialectInterface::LLVMImportDialectInterface; + + /// Converts the LLVM intrinsic to an MLIR LLVM dialect operation if a + /// conversion exists. Returns failure otherwise. + LogicalResult convertIntrinsic(OpBuilder &builder, llvm::CallInst *inst, + LLVM::ModuleImport &moduleImport) const final { + return convertIntrinsicImpl(builder, inst, moduleImport); + } + + /// Returns the list of LLVM IR intrinsic identifiers that may be converted to + /// MLIR LLVM dialect operations. + SmallVector getSupportedIntrinsics() const final { + SmallVector supportedIntrinsics = { +#include "mlir/Dialect/LLVMIR/LLVMConvertibleLLVMIRIntrinsics.inc" + }; + return supportedIntrinsics; + } +}; +} // namespace + +void mlir::registerLLVMDialectImport(DialectRegistry ®istry) { + registry.insert(); + registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) { + dialect->addInterfaces(); + }); +} + +void mlir::registerLLVMDialectImport(MLIRContext &context) { + DialectRegistry registry; + registerLLVMDialectImport(registry); + context.appendDialectRegistry(registry); +} diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -37,15 +37,6 @@ #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc" -/// Returns true if the LLVM IR intrinsic is convertible to an MLIR LLVM dialect -/// intrinsic, or false if no counterpart exists. -static bool isConvertibleIntrinsic(llvm::Intrinsic::ID id) { - static const DenseSet convertibleIntrinsics = { -#include "mlir/Dialect/LLVMIR/LLVMConvertibleLLVMIRIntrinsics.inc" - }; - return convertibleIntrinsics.contains(id); -} - // Utility to print an LLVM value as a string for passing to emitError().
// FIXME: Diagnostic should be able to natively handle types that have // operator << (raw_ostream&) defined. @@ -58,7 +49,7 @@ /// Creates an attribute containing ABI and preferred alignment numbers parsed /// a string. The string may be either "abi:preferred" or just "abi". In the -/// latter case, the prefrred alignment is considered equal to ABI alignment. +/// latter case, the preferred alignment is considered equal to ABI alignment. static DenseIntElementsAttr parseDataLayoutAlignment(MLIRContext &ctx, StringRef spec) { auto i32 = IntegerType::get(&ctx, 32); @@ -320,6 +311,7 @@ std::unique_ptr llvmModule) : builder(mlirModule->getContext()), context(mlirModule->getContext()), mlirModule(mlirModule), llvmModule(std::move(llvmModule)), + iface(mlirModule->getContext()), typeTranslator(*mlirModule->getContext()), debugImporter(std::make_unique(mlirModule->getContext())) { builder.setInsertionPointToStart(mlirModule.getBody()); @@ -807,26 +799,20 @@ } LogicalResult ModuleImport::convertIntrinsic(OpBuilder &odsBuilder, - llvm::CallInst *inst, - llvm::Intrinsic::ID intrinsicID) { - Location loc = translateLoc(inst->getDebugLoc()); - - // Check if the intrinsic is convertible to an MLIR dialect counterpart and - // copy the arguments to an an LLVM operands array reference for conversion. - if (isConvertibleIntrinsic(intrinsicID)) { - SmallVector args(inst->args()); - ArrayRef llvmOperands(args); -#include "mlir/Dialect/LLVMIR/LLVMIntrinsicFromLLVMIRConversions.inc" - } + llvm::CallInst *inst) { + if (succeeded(iface.convertIntrinsic(odsBuilder, inst, *this))) + return success(); + Location loc = translateLoc(inst->getDebugLoc()); return emitError(loc) << "unhandled intrinsic " << diag(*inst); } -LogicalResult ModuleImport::convertOperation(OpBuilder &odsBuilder, - llvm::Instruction *inst) { +LogicalResult ModuleImport::convertInstruction(OpBuilder &odsBuilder, + llvm::Instruction *inst) { // Copy the operands to an LLVM operands array reference for conversion.
SmallVector operands(inst->operands()); ArrayRef llvmOperands(operands); + ModuleImport &moduleImport = *this; // Convert all instructions that provide an MLIR builder. #include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc" @@ -1006,11 +992,11 @@ if (auto *callInst = dyn_cast(inst)) { llvm::Function *callee = callInst->getCalledFunction(); if (callee && callee->isIntrinsic()) - return convertIntrinsic(builder, callInst, callInst->getIntrinsicID()); + return convertIntrinsic(builder, callInst); } // Convert all remaining LLVM instructions to MLIR operations. - return convertOperation(builder, inst); + return convertInstruction(builder, inst); } FlatSymbolRefAttr ModuleImport::getPersonalityAsAttr(llvm::Function *f) { @@ -1049,7 +1035,8 @@ auto functionType = convertType(func->getFunctionType()).dyn_cast(); - if (func->isIntrinsic() && isConvertibleIntrinsic(func->getIntrinsicID())) + if (func->isIntrinsic() && + iface.isConvertibleIntrinsic(func->getIntrinsicID())) return success(); bool dsoLocal = func->hasLocalLinkage(); @@ -1166,6 +1153,8 @@ module.get()->setAttr(DLTIDialect::kDataLayoutAttrName, dlSpec); ModuleImport moduleImport(module.get(), std::move(llvmModule)); + if (failed(moduleImport.querySupportedIntrinsics())) + return {}; if (failed(moduleImport.convertGlobals())) return {}; if (failed(moduleImport.convertFunctions())) diff --git a/mlir/lib/Tools/mlir-translate/Translation.cpp b/mlir/lib/Tools/mlir-translate/Translation.cpp --- a/mlir/lib/Tools/mlir-translate/Translation.cpp +++ b/mlir/lib/Tools/mlir-translate/Translation.cpp @@ -73,10 +73,16 @@ // Puts `function` into the to-MLIR translation registry unless there is already // a function registered for the same name.
static void registerTranslateToMLIRFunction( - StringRef name, StringRef description, Optional inputAlignment, + StringRef name, StringRef description, + const std::function &dialectRegistration, + Optional inputAlignment, const TranslateSourceMgrToMLIRFunction &function) { - auto wrappedFn = [function](const std::shared_ptr &sourceMgr, - raw_ostream &output, MLIRContext *context) { + auto wrappedFn = [function, dialectRegistration]( + const std::shared_ptr &sourceMgr, + raw_ostream &output, MLIRContext *context) { + DialectRegistry registry; + dialectRegistration(registry); + context->appendDialectRegistry(registry); OwningOpRef op = function(sourceMgr, context); if (!op || failed(verify(*op))) return failure(); @@ -89,15 +95,18 @@ TranslateToMLIRRegistration::TranslateToMLIRRegistration( StringRef name, StringRef description, const TranslateSourceMgrToMLIRFunction &function, + const std::function &dialectRegistration, Optional inputAlignment) { - registerTranslateToMLIRFunction(name, description, inputAlignment, function); + registerTranslateToMLIRFunction(name, description, dialectRegistration, + inputAlignment, function); } TranslateToMLIRRegistration::TranslateToMLIRRegistration( StringRef name, StringRef description, const TranslateRawSourceMgrToMLIRFunction &function, + const std::function &dialectRegistration, Optional inputAlignment) { registerTranslateToMLIRFunction( - name, description, inputAlignment, + name, description, dialectRegistration, inputAlignment, [function](const std::shared_ptr &sourceMgr, MLIRContext *ctx) { return function(*sourceMgr, ctx); }); } @@ -106,9 +115,10 @@ TranslateToMLIRRegistration::TranslateToMLIRRegistration( StringRef name, StringRef description, const TranslateStringRefToMLIRFunction &function, + const std::function &dialectRegistration, Optional inputAlignment) { registerTranslateToMLIRFunction( - name, description, inputAlignment, + name, description, dialectRegistration, inputAlignment, [function](const std::shared_ptr
&sourceMgr, MLIRContext *ctx) { const llvm::MemoryBuffer *buffer = diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -246,11 +246,11 @@ if (isVariadicOperandName(op, name)) { as << formatv( "FailureOr> _llvmir_gen_operand_{0} = " - "convertValues(llvmOperands.drop_front({1}));\n", + "moduleImport.convertValues(llvmOperands.drop_front({1}));\n", name, idx); } else { as << formatv("FailureOr _llvmir_gen_operand_{0} = " - "convertValue(llvmOperands[{1}]);\n", + "moduleImport.convertValue(llvmOperands[{1}]);\n", name, idx); } as << formatv("if (failed(_llvmir_gen_operand_{0}))\n" @@ -261,15 +261,15 @@ } else if (isResultName(op, name)) { if (op.getNumResults() != 1) return emitError(record, "expected op to have one result"); - bs << "mapValue(inst)"; + bs << "moduleImport.mapValue(inst)"; } else if (name == "_int_attr") { - bs << "matchIntegerAttr"; + bs << "moduleImport.matchIntegerAttr"; } else if (name == "_var_attr") { - bs << "matchLocalVariableAttr"; + bs << "moduleImport.matchLocalVariableAttr"; } else if (name == "_resultType") { - bs << "convertType(inst->getType())"; + bs << "moduleImport.convertType(inst->getType())"; } else if (name == "_location") { - bs << "translateLoc(inst->getDebugLoc())"; + bs << "moduleImport.translateLoc(inst->getDebugLoc())"; } else if (name == "_builder") { bs << "odsBuilder"; } else if (name == "_qualCppClassName") {