diff --git a/mlir/examples/toy/Ch6/toyc.cpp b/mlir/examples/toy/Ch6/toyc.cpp --- a/mlir/examples/toy/Ch6/toyc.cpp +++ b/mlir/examples/toy/Ch6/toyc.cpp @@ -189,6 +189,9 @@ } int dumpLLVMIR(mlir::ModuleOp module) { + // Register the translation to LLVM IR with the MLIR context. + mlir::registerLLVMDialectTranslation(*module->getContext()); + // Convert the module to LLVM IR in a new LLVM IR context. llvm::LLVMContext llvmContext; auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext); @@ -219,6 +222,10 @@ llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); + // Register the translation from MLIR to LLVM IR, which must happen before we + // can JIT-compile. + mlir::registerLLVMDialectTranslation(*module->getContext()); + // An optimization pipeline to use within the execution engine. auto optPipeline = mlir::makeOptimizingTransformer( /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, diff --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp --- a/mlir/examples/toy/Ch7/toyc.cpp +++ b/mlir/examples/toy/Ch7/toyc.cpp @@ -190,6 +190,9 @@ } int dumpLLVMIR(mlir::ModuleOp module) { + // Register the translation to LLVM IR with the MLIR context. + mlir::registerLLVMDialectTranslation(*module->getContext()); + // Convert the module to LLVM IR in a new LLVM IR context. llvm::LLVMContext llvmContext; auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext); @@ -220,6 +223,10 @@ llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); + // Register the translation from MLIR to LLVM IR, which must happen before we + // can JIT-compile. + mlir::registerLLVMDialectTranslation(*module->getContext()); + // An optimization pipeline to use within the execution engine. auto optPipeline = mlir::makeOptimizingTransformer( /*optLevel=*/enableOpt ? 
3 : 0, /*sizeLevel=*/0, diff --git a/mlir/include/mlir/Target/LLVMIR.h b/mlir/include/mlir/Target/LLVMIR.h --- a/mlir/include/mlir/Target/LLVMIR.h +++ b/mlir/include/mlir/Target/LLVMIR.h @@ -25,6 +25,7 @@ namespace mlir { +class DialectRegistry; class OwningModuleRef; class MLIRContext; class ModuleOp; @@ -45,6 +46,15 @@ translateLLVMIRToModule(std::unique_ptr llvmModule, MLIRContext *context); +/// Register the LLVM dialect and the translation from it to the LLVM IR in the +/// given registry. +void registerLLVMDialectTranslation(DialectRegistry &registry); + +/// Register the LLVM dialect and the translation from it in the registry +/// associated with the given context. This checks if the interface is already +/// registered and avoids double registration. +void registerLLVMDialectTranslation(MLIRContext &context); + } // namespace mlir #endif // MLIR_TARGET_LLVMIR_H diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h @@ -0,0 +1,37 @@ +//===- LLVMToLLVMIRTranslation.h - LLVM Dialect to LLVM IR-------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the dialect interface for translating the LLVM dialect +// to LLVM IR.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_DIALECT_LLVMIR_LLVMTOLLVMIRTRANSLATION_H +#define MLIR_TARGET_LLVMIR_DIALECT_LLVMIR_LLVMTOLLVMIRTRANSLATION_H + +#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" + +namespace mlir { + +/// Implementation of the dialect interface that converts operations +/// belonging to the LLVM dialect to LLVM IR. +class LLVMDialectLLVMIRTranslationInterface + : public LLVMTranslationDialectInterface { +public: + using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; + + /// Translates the given operation to LLVM IR using the provided IR builder + /// and saving the state in `moduleTranslation`. + LogicalResult + convertOperation(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) const final; +}; + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_DIALECT_LLVMIR_LLVMTOLLVMIRTRANSLATION_H diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h @@ -0,0 +1,68 @@ +//===- LLVMTranslationInterface.h - Translation to LLVM iface ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines dialect interfaces for translation to LLVM IR.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H +#define MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H + +#include "mlir/IR/DialectInterface.h" +#include "mlir/Support/LogicalResult.h" + +namespace llvm { +class IRBuilderBase; +} + +namespace mlir { +namespace LLVM { +class ModuleTranslation; +} // namespace LLVM + +/// Base class for dialect interfaces providing translation to LLVM IR. +/// Dialects that can be translated should provide an implementation of this +/// interface for the supported operations. The interface may be implemented in +/// a separate library to avoid the "main" dialect library depending on LLVM IR. +/// The interface can be attached using the delayed registration mechanism +/// available in DialectRegistry. +class LLVMTranslationDialectInterface + : public DialectInterface::Base { +public: + LLVMTranslationDialectInterface(Dialect *dialect) : Base(dialect) {} + + /// Hook for derived dialect interface to provide translation of the + /// operations to LLVM IR. + virtual LogicalResult + convertOperation(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) const { + return failure(); + } +}; + +/// Interface collection for translation to LLVM IR, dispatches to a concrete +/// interface implementation based on the dialect to which the given op belongs. +class LLVMTranslationInterface + : public DialectInterfaceCollection { +public: + using Base::Base; + + /// Translates the given operation to LLVM IR using the interface implemented + /// by the op's dialect. 
+ virtual LogicalResult + convertOperation(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) const { + if (const LLVMTranslationDialectInterface *iface = getInterfaceFor(op)) + return iface->convertOperation(op, builder, moduleTranslation); + return failure(); + } +}; + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -19,6 +19,7 @@ #include "mlir/IR/Block.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Value.h" +#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" #include "mlir/Target/LLVMIR/TypeTranslation.h" #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" @@ -134,6 +135,31 @@ return branchMapping.lookup(op); } + /// Converts the type from MLIR LLVM dialect to LLVM. + llvm::Type *convertType(Type type); + + /// Looks up a list of remapped values. + SmallVector lookupValues(ValueRange values); + + /// Creates an LLVM IR constant of `llvmType` from the MLIR attribute `attr`. + /// This currently supports integer, floating point, splat and dense element + /// attributes and combinations thereof. In case of error, reports it to `loc` + /// and returns nullptr. + llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr, + Location loc); + + /// Returns the MLIR context of the module being translated. + MLIRContext &getContext() { return *mlirModule->getContext(); } + + /// Returns the LLVM context in which the IR is being constructed. + llvm::LLVMContext &getLLVMContext() { return llvmModule->getContext(); } + + /// Finds an LLVM IR global value that corresponds to the given MLIR operation + /// defining a global value.
+ llvm::GlobalValue *lookupGlobal(Operation *op) { + return globalsMapping.lookup(op); + } + protected: /// Translate the given MLIR module expressed in MLIR LLVM IR dialect into an /// LLVM IR module. The MLIR LLVM IR dialect holds a pointer to an @@ -158,16 +184,10 @@ virtual LogicalResult convertOmpWsLoop(Operation &opInst, llvm::IRBuilder<> &builder); - /// Converts the type from MLIR LLVM dialect to LLVM. - llvm::Type *convertType(Type type); - static std::unique_ptr prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext, StringRef name); - /// A helper to look up remapped operands in the value remapping table. - SmallVector lookupValues(ValueRange values); - private: /// Check whether the module contains only supported ops directly in its body. static LogicalResult checkSupportedModuleOps(Operation *m); @@ -179,9 +199,6 @@ LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilder<> &builder); - llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr, - Location loc); - /// Original and translated module. Operation *mlirModule; std::unique_ptr llvmModule; @@ -202,7 +219,8 @@ /// A stateful object used to translate types. TypeToLLVMIRTranslator typeTranslator; -private: + LLVMTranslationInterface iface; + /// Mappings between original and translated values, used for lookups. 
llvm::StringMap functionMapping; DenseMap valueMapping; diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt --- a/mlir/lib/Target/CMakeLists.txt +++ b/mlir/lib/Target/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(SPIRV) +add_subdirectory(LLVMIR) add_mlir_translation_library(MLIRTargetLLVMIRModuleTranslation LLVMIR/DebugTranslation.cpp @@ -39,6 +40,7 @@ MLIRIR MLIRLLVMAVX512 MLIRLLVMIR + MLIRTargetLLVMIR MLIRTargetLLVMIRModuleTranslation ) @@ -54,6 +56,7 @@ IRReader LINK_LIBS PUBLIC + MLIRLLVMToLLVMIRTranslation MLIRTargetLLVMIRModuleTranslation ) @@ -73,6 +76,7 @@ MLIRIR MLIRLLVMArmNeon MLIRLLVMIR + MLIRTargetLLVMIR MLIRTargetLLVMIRModuleTranslation ) @@ -92,6 +96,7 @@ MLIRIR MLIRLLVMArmSVE MLIRLLVMIR + MLIRTargetLLVMIR MLIRTargetLLVMIRModuleTranslation ) @@ -112,6 +117,7 @@ MLIRIR MLIRLLVMIR MLIRNVVMIR + MLIRTargetLLVMIR MLIRTargetLLVMIRModuleTranslation ) @@ -132,5 +138,6 @@ MLIRIR MLIRLLVMIR MLIRROCDLIR + MLIRTargetLLVMIR MLIRTargetLLVMIRModuleTranslation ) diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Dialect) diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -13,6 +13,7 @@ #include "mlir/Target/LLVMIR.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "mlir/Translation.h" @@ -35,6 +36,24 @@ return llvmModule; } +void mlir::registerLLVMDialectTranslation(DialectRegistry ®istry) { + registry.insert(); + registry.addDialectInterface(); +} + +void mlir::registerLLVMDialectTranslation(MLIRContext &context) { + auto *dialect = context.getLoadedDialect(); + if (!dialect || dialect->getRegisteredInterface< + 
LLVMDialectLLVMIRTranslationInterface>() == nullptr) { + DialectRegistry registry; + registry.insert(); + registry.addDialectInterface(); + context.appendDialectRegistry(registry); + } +} + namespace mlir { void registerToLLVMIRTranslation() { TranslateFromMLIRRegistration registration( @@ -50,7 +69,8 @@ return success(); }, [](DialectRegistry ®istry) { - registry.insert(); + registry.insert(); + registerLLVMDialectTranslation(registry); }); } } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Target/LLVMIR.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "mlir/Translation.h" @@ -68,6 +69,11 @@ std::unique_ptr mlir::translateModuleToNVVMIR(Operation *m, llvm::LLVMContext &llvmContext, StringRef name) { + // Register the translation to LLVM IR if nobody else did before. This may + // happen if this translation is called inside a pass pipeline that converts + // GPU dialects to binary blobs without translating the rest of the code. 
+ registerLLVMDialectTranslation(*m->getContext()); + auto llvmModule = LLVM::ModuleTranslation::translateModule( m, llvmContext, name); if (!llvmModule) @@ -111,7 +117,8 @@ return success(); }, [](DialectRegistry &registry) { - registry.insert(); + registry.insert(); + registerLLVMDialectTranslation(registry); }); } } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Target/LLVMIR.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "mlir/Translation.h" @@ -77,11 +78,16 @@ std::unique_ptr mlir::translateModuleToROCDLIR(Operation *m, llvm::LLVMContext &llvmContext, StringRef name) { - // lower MLIR (with RODL Dialect) to LLVM IR (with ROCDL intrinsics) + // Register the translation to LLVM IR if nobody else did before. This may + // happen if this translation is called inside a pass pipeline that converts + // GPU dialects to binary blobs without translating the rest of the code. + registerLLVMDialectTranslation(*m->getContext()); + + // Lower MLIR (with ROCDL dialect) to LLVM IR (with ROCDL intrinsics). auto llvmModule = LLVM::ModuleTranslation::translateModule( m, llvmContext, name); - // foreach GPU kernel + // For each GPU kernel: // 1. Insert AMDGPU_KERNEL calling convention. // 2. Insert amdgpu-flat-workgroup-size(1, 1024) attribute.
for (auto func : @@ -114,7 +120,8 @@ return success(); }, [](DialectRegistry ®istry) { - registry.insert(); + registry.insert(); + registerLLVMDialectTranslation(registry); }); } } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(LLVMIR) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_translation_library(MLIRLLVMToLLVMIRTranslation + LLVMToLLVMIRTranslation.cpp + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMIR + MLIRSupport + MLIRTargetLLVMIRModuleTranslation + ) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -0,0 +1,405 @@ +//===- LLVMToLLVMIRTranslation.cpp - Translate LLVM dialect to LLVM IR ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the MLIR LLVM dialect and LLVM IR. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InlineAsm.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Operator.h" + +using namespace mlir; +using namespace mlir::LLVM; + +#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc" + +/// Convert MLIR integer comparison predicate to LLVM IR comparison predicate. +static llvm::CmpInst::Predicate getLLVMCmpPredicate(ICmpPredicate p) { + switch (p) { + case LLVM::ICmpPredicate::eq: + return llvm::CmpInst::Predicate::ICMP_EQ; + case LLVM::ICmpPredicate::ne: + return llvm::CmpInst::Predicate::ICMP_NE; + case LLVM::ICmpPredicate::slt: + return llvm::CmpInst::Predicate::ICMP_SLT; + case LLVM::ICmpPredicate::sle: + return llvm::CmpInst::Predicate::ICMP_SLE; + case LLVM::ICmpPredicate::sgt: + return llvm::CmpInst::Predicate::ICMP_SGT; + case LLVM::ICmpPredicate::sge: + return llvm::CmpInst::Predicate::ICMP_SGE; + case LLVM::ICmpPredicate::ult: + return llvm::CmpInst::Predicate::ICMP_ULT; + case LLVM::ICmpPredicate::ule: + return llvm::CmpInst::Predicate::ICMP_ULE; + case LLVM::ICmpPredicate::ugt: + return llvm::CmpInst::Predicate::ICMP_UGT; + case LLVM::ICmpPredicate::uge: + return llvm::CmpInst::Predicate::ICMP_UGE; + } + llvm_unreachable("incorrect comparison predicate"); +} + +static llvm::CmpInst::Predicate getLLVMCmpPredicate(FCmpPredicate p) { + switch (p) { + case LLVM::FCmpPredicate::_false: + return llvm::CmpInst::Predicate::FCMP_FALSE; + case LLVM::FCmpPredicate::oeq: + return llvm::CmpInst::Predicate::FCMP_OEQ; + case LLVM::FCmpPredicate::ogt: + return llvm::CmpInst::Predicate::FCMP_OGT; + case LLVM::FCmpPredicate::oge: + return llvm::CmpInst::Predicate::FCMP_OGE; + case 
LLVM::FCmpPredicate::olt: + return llvm::CmpInst::Predicate::FCMP_OLT; + case LLVM::FCmpPredicate::ole: + return llvm::CmpInst::Predicate::FCMP_OLE; + case LLVM::FCmpPredicate::one: + return llvm::CmpInst::Predicate::FCMP_ONE; + case LLVM::FCmpPredicate::ord: + return llvm::CmpInst::Predicate::FCMP_ORD; + case LLVM::FCmpPredicate::ueq: + return llvm::CmpInst::Predicate::FCMP_UEQ; + case LLVM::FCmpPredicate::ugt: + return llvm::CmpInst::Predicate::FCMP_UGT; + case LLVM::FCmpPredicate::uge: + return llvm::CmpInst::Predicate::FCMP_UGE; + case LLVM::FCmpPredicate::ult: + return llvm::CmpInst::Predicate::FCMP_ULT; + case LLVM::FCmpPredicate::ule: + return llvm::CmpInst::Predicate::FCMP_ULE; + case LLVM::FCmpPredicate::une: + return llvm::CmpInst::Predicate::FCMP_UNE; + case LLVM::FCmpPredicate::uno: + return llvm::CmpInst::Predicate::FCMP_UNO; + case LLVM::FCmpPredicate::_true: + return llvm::CmpInst::Predicate::FCMP_TRUE; + } + llvm_unreachable("incorrect comparison predicate"); +} + +static llvm::AtomicRMWInst::BinOp getLLVMAtomicBinOp(AtomicBinOp op) { + switch (op) { + case LLVM::AtomicBinOp::xchg: + return llvm::AtomicRMWInst::BinOp::Xchg; + case LLVM::AtomicBinOp::add: + return llvm::AtomicRMWInst::BinOp::Add; + case LLVM::AtomicBinOp::sub: + return llvm::AtomicRMWInst::BinOp::Sub; + case LLVM::AtomicBinOp::_and: + return llvm::AtomicRMWInst::BinOp::And; + case LLVM::AtomicBinOp::nand: + return llvm::AtomicRMWInst::BinOp::Nand; + case LLVM::AtomicBinOp::_or: + return llvm::AtomicRMWInst::BinOp::Or; + case LLVM::AtomicBinOp::_xor: + return llvm::AtomicRMWInst::BinOp::Xor; + case LLVM::AtomicBinOp::max: + return llvm::AtomicRMWInst::BinOp::Max; + case LLVM::AtomicBinOp::min: + return llvm::AtomicRMWInst::BinOp::Min; + case LLVM::AtomicBinOp::umax: + return llvm::AtomicRMWInst::BinOp::UMax; + case LLVM::AtomicBinOp::umin: + return llvm::AtomicRMWInst::BinOp::UMin; + case LLVM::AtomicBinOp::fadd: + return llvm::AtomicRMWInst::BinOp::FAdd; + case 
LLVM::AtomicBinOp::fsub: + return llvm::AtomicRMWInst::BinOp::FSub; + } + llvm_unreachable("incorrect atomic binary operator"); +} + +static llvm::AtomicOrdering getLLVMAtomicOrdering(AtomicOrdering ordering) { + switch (ordering) { + case LLVM::AtomicOrdering::not_atomic: + return llvm::AtomicOrdering::NotAtomic; + case LLVM::AtomicOrdering::unordered: + return llvm::AtomicOrdering::Unordered; + case LLVM::AtomicOrdering::monotonic: + return llvm::AtomicOrdering::Monotonic; + case LLVM::AtomicOrdering::acquire: + return llvm::AtomicOrdering::Acquire; + case LLVM::AtomicOrdering::release: + return llvm::AtomicOrdering::Release; + case LLVM::AtomicOrdering::acq_rel: + return llvm::AtomicOrdering::AcquireRelease; + case LLVM::AtomicOrdering::seq_cst: + return llvm::AtomicOrdering::SequentiallyConsistent; + } + llvm_unreachable("incorrect atomic ordering"); +} + +static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) { + using llvmFMF = llvm::FastMathFlags; + using FuncT = void (llvmFMF::*)(bool); + const std::pair handlers[] = { + // clang-format off + {FastmathFlags::nnan, &llvmFMF::setNoNaNs}, + {FastmathFlags::ninf, &llvmFMF::setNoInfs}, + {FastmathFlags::nsz, &llvmFMF::setNoSignedZeros}, + {FastmathFlags::arcp, &llvmFMF::setAllowReciprocal}, + {FastmathFlags::contract, &llvmFMF::setAllowContract}, + {FastmathFlags::afn, &llvmFMF::setApproxFunc}, + {FastmathFlags::reassoc, &llvmFMF::setAllowReassoc}, + {FastmathFlags::fast, &llvmFMF::setFast}, + // clang-format on + }; + llvm::FastMathFlags ret; + auto fmf = op.fastmathFlags(); + for (auto it : handlers) + if (bitEnumContains(fmf, it.first)) + (ret.*(it.second))(true); + return ret; +} + +namespace { +/// Dispatcher functional object targeting different overloads of +/// ModuleTranslation::mapValue. +// TODO: this is only necessary for compatibility with the code emitted from +// ODS, remove when ODS is updated (after all dialects have migrated to the new +// translation mechanism). 
+struct MapValueDispatcher { + explicit MapValueDispatcher(ModuleTranslation &mt) : moduleTranslation(mt) {} + + llvm::Value *&operator()(mlir::Value v) { + return moduleTranslation.mapValue(v); + } + + void operator()(mlir::Value m, llvm::Value *l) { + moduleTranslation.mapValue(m, l); + } + + LLVM::ModuleTranslation &moduleTranslation; +}; +} // end namespace + +static LogicalResult +convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + auto extractPosition = [](ArrayAttr attr) { + SmallVector position; + position.reserve(attr.size()); + for (Attribute v : attr) + position.push_back(v.cast().getValue().getZExtValue()); + return position; + }; + + llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder); + if (auto fmf = dyn_cast(opInst)) + builder.setFastMathFlags(getFastmathFlags(fmf)); + + // TODO: these are necessary for compatibility with the code emitted from ODS, + // remove them when ODS is updated (after all dialects have migrated to the + // new translation mechanism). + MapValueDispatcher mapValue(moduleTranslation); + auto lookupValue = [&](mlir::Value v) { + return moduleTranslation.lookupValue(v); + }; + auto convertType = [&](Type ty) { return moduleTranslation.convertType(ty); }; + auto lookupValues = [&](ValueRange vs) { + return moduleTranslation.lookupValues(vs); + }; + auto getLLVMConstant = [&](llvm::Type *ty, Attribute attr, Location loc) { + return moduleTranslation.getLLVMConstant(ty, attr, loc); + }; + +#include "mlir/Dialect/LLVMIR/LLVMConversions.inc" + + // Emit function calls. If the "callee" attribute is present, this is a + // direct function call and we also need to look up the remapped function + // itself. Otherwise, this is an indirect call and the callee is the first + // operand, look it up as a normal value. Return the llvm::Value representing + // the function result, which may be of llvm::VoidTy type. 
+ auto convertCall = [&](Operation &op) -> llvm::Value * { + auto operands = moduleTranslation.lookupValues(op.getOperands()); + ArrayRef operandsRef(operands); + if (auto attr = op.getAttrOfType("callee")) + return builder.CreateCall( + moduleTranslation.lookupFunction(attr.getValue()), operandsRef); + auto *calleePtrType = + cast(operandsRef.front()->getType()); + auto *calleeType = + cast(calleePtrType->getElementType()); + return builder.CreateCall(calleeType, operandsRef.front(), + operandsRef.drop_front()); + }; + + // Emit calls. If the called function has a result, remap the corresponding + // value. Note that LLVM IR dialect CallOp has either 0 or 1 result. + if (isa(opInst)) { + llvm::Value *result = convertCall(opInst); + if (opInst.getNumResults() != 0) { + mapValue(opInst.getResult(0), result); + return success(); + } + // Check that LLVM call returns void for 0-result functions. + return success(result->getType()->isVoidTy()); + } + + if (auto inlineAsmOp = dyn_cast(opInst)) { + // TODO: refactor function type creation which usually occurs in std-LLVM + // conversion. + SmallVector operandTypes; + operandTypes.reserve(inlineAsmOp.operands().size()); + for (auto t : inlineAsmOp.operands().getTypes()) + operandTypes.push_back(t); + + Type resultType; + if (inlineAsmOp.getNumResults() == 0) { + resultType = LLVM::LLVMVoidType::get(&moduleTranslation.getContext()); + } else { + assert(inlineAsmOp.getNumResults() == 1); + resultType = inlineAsmOp.getResultTypes()[0]; + } + auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes); + llvm::InlineAsm *inlineAsmInst = + inlineAsmOp.asm_dialect().hasValue() + ? 
llvm::InlineAsm::get( + static_cast(convertType(ft)), + inlineAsmOp.asm_string(), inlineAsmOp.constraints(), + inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack(), + convertAsmDialectToLLVM(*inlineAsmOp.asm_dialect())) + : llvm::InlineAsm::get( + static_cast(convertType(ft)), + inlineAsmOp.asm_string(), inlineAsmOp.constraints(), + inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack()); + llvm::Value *result = + builder.CreateCall(inlineAsmInst, lookupValues(inlineAsmOp.operands())); + if (opInst.getNumResults() != 0) + mapValue(opInst.getResult(0), result); + return success(); + } + + if (auto invOp = dyn_cast(opInst)) { + auto operands = lookupValues(opInst.getOperands()); + ArrayRef operandsRef(operands); + if (auto attr = opInst.getAttrOfType("callee")) { + builder.CreateInvoke(moduleTranslation.lookupFunction(attr.getValue()), + moduleTranslation.lookupBlock(invOp.getSuccessor(0)), + moduleTranslation.lookupBlock(invOp.getSuccessor(1)), + operandsRef); + } else { + auto *calleePtrType = + cast(operandsRef.front()->getType()); + auto *calleeType = + cast(calleePtrType->getElementType()); + builder.CreateInvoke(calleeType, operandsRef.front(), + moduleTranslation.lookupBlock(invOp.getSuccessor(0)), + moduleTranslation.lookupBlock(invOp.getSuccessor(1)), + operandsRef.drop_front()); + } + return success(); + } + + if (auto lpOp = dyn_cast(opInst)) { + llvm::Type *ty = convertType(lpOp.getType()); + llvm::LandingPadInst *lpi = + builder.CreateLandingPad(ty, lpOp.getNumOperands()); + + // Add clauses + for (llvm::Value *operand : lookupValues(lpOp.getOperands())) { + // All operands should be constant - checked by verifier + if (auto *constOperand = dyn_cast(operand)) + lpi->addClause(constOperand); + } + mapValue(lpOp.getResult(), lpi); + return success(); + } + + // Emit branches. We need to look up the remapped blocks and ignore the block + // arguments that were transformed into PHI nodes. 
+ if (auto brOp = dyn_cast(opInst)) { + llvm::BranchInst *branch = + builder.CreateBr(moduleTranslation.lookupBlock(brOp.getSuccessor())); + moduleTranslation.mapBranch(&opInst, branch); + return success(); + } + if (auto condbrOp = dyn_cast(opInst)) { + auto weights = condbrOp.branch_weights(); + llvm::MDNode *branchWeights = nullptr; + if (weights) { + // Map weight attributes to LLVM metadata. + auto trueWeight = + weights.getValue().getValue(0).cast().getInt(); + auto falseWeight = + weights.getValue().getValue(1).cast().getInt(); + branchWeights = + llvm::MDBuilder(moduleTranslation.getLLVMContext()) + .createBranchWeights(static_cast(trueWeight), + static_cast(falseWeight)); + } + llvm::BranchInst *branch = builder.CreateCondBr( + moduleTranslation.lookupValue(condbrOp.getOperand(0)), + moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)), + moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)), branchWeights); + moduleTranslation.mapBranch(&opInst, branch); + return success(); + } + if (auto switchOp = dyn_cast(opInst)) { + llvm::MDNode *branchWeights = nullptr; + if (auto weights = switchOp.branch_weights()) { + llvm::SmallVector weightValues; + weightValues.reserve(weights->size()); + for (llvm::APInt weight : weights->cast()) + weightValues.push_back(weight.getLimitedValue()); + branchWeights = llvm::MDBuilder(moduleTranslation.getLLVMContext()) + .createBranchWeights(weightValues); + } + + llvm::SwitchInst *switchInst = builder.CreateSwitch( + moduleTranslation.lookupValue(switchOp.value()), + moduleTranslation.lookupBlock(switchOp.defaultDestination()), + switchOp.caseDestinations().size(), branchWeights); + + auto *ty = + llvm::cast(convertType(switchOp.value().getType())); + for (auto i : + llvm::zip(switchOp.case_values()->cast(), + switchOp.caseDestinations())) + switchInst->addCase( + llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()), + moduleTranslation.lookupBlock(std::get<1>(i))); + + moduleTranslation.mapBranch(&opInst, 
switchInst); + return success(); + } + + // Emit addressof. We need to look up the global value referenced by the + // operation and store it in the MLIR-to-LLVM value mapping. This does not + // emit any LLVM instruction. + if (auto addressOfOp = dyn_cast(opInst)) { + LLVM::GlobalOp global = addressOfOp.getGlobal(); + LLVM::LLVMFuncOp function = addressOfOp.getFunction(); + + // The verifier should not have allowed this. + assert((global || function) && + "referencing an undefined global or function"); + + mapValue(addressOfOp.getResult(), + global ? moduleTranslation.lookupGlobal(global) + : moduleTranslation.lookupFunction(function.getName())); + return success(); + } + + return failure(); +} + +LogicalResult mlir::LLVMDialectLLVMIRTranslationInterface::convertOperation( + Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) const { + return convertOperationImpl(*op, builder, moduleTranslation); +} diff --git a/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp b/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp --- a/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp +++ b/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" +#include "mlir/Target/LLVMIR.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "mlir/Translation.h" #include "llvm/IR/IntrinsicsX86.h" @@ -57,7 +58,8 @@ return success(); }, [](DialectRegistry ®istry) { - registry.insert(); + registry.insert(); + registerLLVMDialectTranslation(registry); }); } } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp b/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp --- a/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp +++ b/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h" +#include 
"mlir/Target/LLVMIR.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "mlir/Translation.h" #include "llvm/IR/IntrinsicsAArch64.h" @@ -57,7 +58,8 @@ return success(); }, [](DialectRegistry ®istry) { - registry.insert(); + registry.insert(); + registerLLVMDialectTranslation(registry); }); } } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp b/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp --- a/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp +++ b/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" +#include "mlir/Target/LLVMIR.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "mlir/Translation.h" #include "llvm/IR/IntrinsicsAArch64.h" @@ -57,7 +58,8 @@ return success(); }, [](DialectRegistry ®istry) { - registry.insert(); + registry.insert(); + registerLLVMDialectTranslation(registry); }); } } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/RegionGraphTraits.h" #include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" #include "mlir/Target/LLVMIR/TypeTranslation.h" #include "llvm/ADT/TypeSwitch.h" @@ -183,130 +184,14 @@ return nullptr; } -/// Convert MLIR integer comparison predicate to LLVM IR comparison predicate. 
-static llvm::CmpInst::Predicate getLLVMCmpPredicate(ICmpPredicate p) { - switch (p) { - case LLVM::ICmpPredicate::eq: - return llvm::CmpInst::Predicate::ICMP_EQ; - case LLVM::ICmpPredicate::ne: - return llvm::CmpInst::Predicate::ICMP_NE; - case LLVM::ICmpPredicate::slt: - return llvm::CmpInst::Predicate::ICMP_SLT; - case LLVM::ICmpPredicate::sle: - return llvm::CmpInst::Predicate::ICMP_SLE; - case LLVM::ICmpPredicate::sgt: - return llvm::CmpInst::Predicate::ICMP_SGT; - case LLVM::ICmpPredicate::sge: - return llvm::CmpInst::Predicate::ICMP_SGE; - case LLVM::ICmpPredicate::ult: - return llvm::CmpInst::Predicate::ICMP_ULT; - case LLVM::ICmpPredicate::ule: - return llvm::CmpInst::Predicate::ICMP_ULE; - case LLVM::ICmpPredicate::ugt: - return llvm::CmpInst::Predicate::ICMP_UGT; - case LLVM::ICmpPredicate::uge: - return llvm::CmpInst::Predicate::ICMP_UGE; - } - llvm_unreachable("incorrect comparison predicate"); -} - -static llvm::CmpInst::Predicate getLLVMCmpPredicate(FCmpPredicate p) { - switch (p) { - case LLVM::FCmpPredicate::_false: - return llvm::CmpInst::Predicate::FCMP_FALSE; - case LLVM::FCmpPredicate::oeq: - return llvm::CmpInst::Predicate::FCMP_OEQ; - case LLVM::FCmpPredicate::ogt: - return llvm::CmpInst::Predicate::FCMP_OGT; - case LLVM::FCmpPredicate::oge: - return llvm::CmpInst::Predicate::FCMP_OGE; - case LLVM::FCmpPredicate::olt: - return llvm::CmpInst::Predicate::FCMP_OLT; - case LLVM::FCmpPredicate::ole: - return llvm::CmpInst::Predicate::FCMP_OLE; - case LLVM::FCmpPredicate::one: - return llvm::CmpInst::Predicate::FCMP_ONE; - case LLVM::FCmpPredicate::ord: - return llvm::CmpInst::Predicate::FCMP_ORD; - case LLVM::FCmpPredicate::ueq: - return llvm::CmpInst::Predicate::FCMP_UEQ; - case LLVM::FCmpPredicate::ugt: - return llvm::CmpInst::Predicate::FCMP_UGT; - case LLVM::FCmpPredicate::uge: - return llvm::CmpInst::Predicate::FCMP_UGE; - case LLVM::FCmpPredicate::ult: - return llvm::CmpInst::Predicate::FCMP_ULT; - case LLVM::FCmpPredicate::ule: - return 
llvm::CmpInst::Predicate::FCMP_ULE; - case LLVM::FCmpPredicate::une: - return llvm::CmpInst::Predicate::FCMP_UNE; - case LLVM::FCmpPredicate::uno: - return llvm::CmpInst::Predicate::FCMP_UNO; - case LLVM::FCmpPredicate::_true: - return llvm::CmpInst::Predicate::FCMP_TRUE; - } - llvm_unreachable("incorrect comparison predicate"); -} - -static llvm::AtomicRMWInst::BinOp getLLVMAtomicBinOp(AtomicBinOp op) { - switch (op) { - case LLVM::AtomicBinOp::xchg: - return llvm::AtomicRMWInst::BinOp::Xchg; - case LLVM::AtomicBinOp::add: - return llvm::AtomicRMWInst::BinOp::Add; - case LLVM::AtomicBinOp::sub: - return llvm::AtomicRMWInst::BinOp::Sub; - case LLVM::AtomicBinOp::_and: - return llvm::AtomicRMWInst::BinOp::And; - case LLVM::AtomicBinOp::nand: - return llvm::AtomicRMWInst::BinOp::Nand; - case LLVM::AtomicBinOp::_or: - return llvm::AtomicRMWInst::BinOp::Or; - case LLVM::AtomicBinOp::_xor: - return llvm::AtomicRMWInst::BinOp::Xor; - case LLVM::AtomicBinOp::max: - return llvm::AtomicRMWInst::BinOp::Max; - case LLVM::AtomicBinOp::min: - return llvm::AtomicRMWInst::BinOp::Min; - case LLVM::AtomicBinOp::umax: - return llvm::AtomicRMWInst::BinOp::UMax; - case LLVM::AtomicBinOp::umin: - return llvm::AtomicRMWInst::BinOp::UMin; - case LLVM::AtomicBinOp::fadd: - return llvm::AtomicRMWInst::BinOp::FAdd; - case LLVM::AtomicBinOp::fsub: - return llvm::AtomicRMWInst::BinOp::FSub; - } - llvm_unreachable("incorrect atomic binary operator"); -} - -static llvm::AtomicOrdering getLLVMAtomicOrdering(AtomicOrdering ordering) { - switch (ordering) { - case LLVM::AtomicOrdering::not_atomic: - return llvm::AtomicOrdering::NotAtomic; - case LLVM::AtomicOrdering::unordered: - return llvm::AtomicOrdering::Unordered; - case LLVM::AtomicOrdering::monotonic: - return llvm::AtomicOrdering::Monotonic; - case LLVM::AtomicOrdering::acquire: - return llvm::AtomicOrdering::Acquire; - case LLVM::AtomicOrdering::release: - return llvm::AtomicOrdering::Release; - case LLVM::AtomicOrdering::acq_rel: - 
return llvm::AtomicOrdering::AcquireRelease; - case LLVM::AtomicOrdering::seq_cst: - return llvm::AtomicOrdering::SequentiallyConsistent; - } - llvm_unreachable("incorrect atomic ordering"); -} - ModuleTranslation::ModuleTranslation(Operation *module, std::unique_ptr llvmModule) : mlirModule(module), llvmModule(std::move(llvmModule)), debugTranslation( std::make_unique(module, *this->llvmModule)), ompDialect(module->getContext()->getLoadedDialect("omp")), - typeTranslator(this->llvmModule->getContext()) { + typeTranslator(this->llvmModule->getContext()), + iface(module->getContext()) { assert(satisfiesLLVMModule(mlirModule) && "mlirModule should honor LLVM's module semantics."); } @@ -658,221 +543,17 @@ }); } -static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) { - using llvmFMF = llvm::FastMathFlags; - using FuncT = void (llvmFMF::*)(bool); - const std::pair handlers[] = { - // clang-format off - {FastmathFlags::nnan, &llvmFMF::setNoNaNs}, - {FastmathFlags::ninf, &llvmFMF::setNoInfs}, - {FastmathFlags::nsz, &llvmFMF::setNoSignedZeros}, - {FastmathFlags::arcp, &llvmFMF::setAllowReciprocal}, - {FastmathFlags::contract, &llvmFMF::setAllowContract}, - {FastmathFlags::afn, &llvmFMF::setApproxFunc}, - {FastmathFlags::reassoc, &llvmFMF::setAllowReassoc}, - {FastmathFlags::fast, &llvmFMF::setFast}, - // clang-format on - }; - llvm::FastMathFlags ret; - auto fmf = op.fastmathFlags(); - for (auto it : handlers) - if (bitEnumContains(fmf, it.first)) - (ret.*(it.second))(true); - return ret; -} - /// Given a single MLIR operation, create the corresponding LLVM IR operation /// using the `builder`. LLVM IR Builder does not have a generic interface so /// this has to be a long chain of `if`s calling different functions with a /// different number of arguments. 
LogicalResult ModuleTranslation::convertOperation(Operation &opInst, llvm::IRBuilder<> &builder) { - auto extractPosition = [](ArrayAttr attr) { - SmallVector position; - position.reserve(attr.size()); - for (Attribute v : attr) - position.push_back(v.cast().getValue().getZExtValue()); - return position; - }; - - llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder); - if (auto fmf = dyn_cast(opInst)) - builder.setFastMathFlags(getFastmathFlags(fmf)); - -#include "mlir/Dialect/LLVMIR/LLVMConversions.inc" - - // Emit function calls. If the "callee" attribute is present, this is a - // direct function call and we also need to look up the remapped function - // itself. Otherwise, this is an indirect call and the callee is the first - // operand, look it up as a normal value. Return the llvm::Value representing - // the function result, which may be of llvm::VoidTy type. - auto convertCall = [this, &builder](Operation &op) -> llvm::Value * { - auto operands = lookupValues(op.getOperands()); - ArrayRef operandsRef(operands); - if (auto attr = op.getAttrOfType("callee")) - return builder.CreateCall(lookupFunction(attr.getValue()), operandsRef); - auto *calleePtrType = - cast(operandsRef.front()->getType()); - auto *calleeType = - cast(calleePtrType->getElementType()); - return builder.CreateCall(calleeType, operandsRef.front(), - operandsRef.drop_front()); - }; - - // Emit calls. If the called function has a result, remap the corresponding - // value. Note that LLVM IR dialect CallOp has either 0 or 1 result. - if (isa(opInst)) { - llvm::Value *result = convertCall(opInst); - if (opInst.getNumResults() != 0) { - mapValue(opInst.getResult(0), result); - return success(); - } - // Check that LLVM call returns void for 0-result functions. - return success(result->getType()->isVoidTy()); - } - if (auto inlineAsmOp = dyn_cast(opInst)) { - // TODO: refactor function type creation which usually occurs in std-LLVM - // conversion. 
- SmallVector operandTypes; - operandTypes.reserve(inlineAsmOp.operands().size()); - for (auto t : inlineAsmOp.operands().getTypes()) - operandTypes.push_back(t); - - Type resultType; - if (inlineAsmOp.getNumResults() == 0) { - resultType = LLVM::LLVMVoidType::get(mlirModule->getContext()); - } else { - assert(inlineAsmOp.getNumResults() == 1); - resultType = inlineAsmOp.getResultTypes()[0]; - } - auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes); - llvm::InlineAsm *inlineAsmInst = - inlineAsmOp.asm_dialect().hasValue() - ? llvm::InlineAsm::get( - static_cast(convertType(ft)), - inlineAsmOp.asm_string(), inlineAsmOp.constraints(), - inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack(), - convertAsmDialectToLLVM(*inlineAsmOp.asm_dialect())) - : llvm::InlineAsm::get( - static_cast(convertType(ft)), - inlineAsmOp.asm_string(), inlineAsmOp.constraints(), - inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack()); - llvm::Value *result = - builder.CreateCall(inlineAsmInst, lookupValues(inlineAsmOp.operands())); - if (opInst.getNumResults() != 0) - mapValue(opInst.getResult(0), result); + // TODO(zinenko): this should be the "main" conversion here, remove the + // dispatch below. 
+ if (succeeded(iface.convertOperation(&opInst, builder, *this))) return success(); - } - - if (auto invOp = dyn_cast(opInst)) { - auto operands = lookupValues(opInst.getOperands()); - ArrayRef operandsRef(operands); - if (auto attr = opInst.getAttrOfType("callee")) { - builder.CreateInvoke(lookupFunction(attr.getValue()), - lookupBlock(invOp.getSuccessor(0)), - lookupBlock(invOp.getSuccessor(1)), operandsRef); - } else { - auto *calleePtrType = - cast(operandsRef.front()->getType()); - auto *calleeType = - cast(calleePtrType->getElementType()); - builder.CreateInvoke( - calleeType, operandsRef.front(), lookupBlock(invOp.getSuccessor(0)), - lookupBlock(invOp.getSuccessor(1)), operandsRef.drop_front()); - } - return success(); - } - - if (auto lpOp = dyn_cast(opInst)) { - llvm::Type *ty = convertType(lpOp.getType()); - llvm::LandingPadInst *lpi = - builder.CreateLandingPad(ty, lpOp.getNumOperands()); - - // Add clauses - for (llvm::Value *operand : lookupValues(lpOp.getOperands())) { - // All operands should be constant - checked by verifier - if (auto *constOperand = dyn_cast(operand)) - lpi->addClause(constOperand); - } - mapValue(lpOp.getResult(), lpi); - return success(); - } - - // Emit branches. We need to look up the remapped blocks and ignore the block - // arguments that were transformed into PHI nodes. - if (auto brOp = dyn_cast(opInst)) { - llvm::BranchInst *branch = - builder.CreateBr(lookupBlock(brOp.getSuccessor())); - mapBranch(&opInst, branch); - return success(); - } - if (auto condbrOp = dyn_cast(opInst)) { - auto weights = condbrOp.branch_weights(); - llvm::MDNode *branchWeights = nullptr; - if (weights) { - // Map weight attributes to LLVM metadata. 
- auto trueWeight = - weights.getValue().getValue(0).cast().getInt(); - auto falseWeight = - weights.getValue().getValue(1).cast().getInt(); - branchWeights = - llvm::MDBuilder(llvmModule->getContext()) - .createBranchWeights(static_cast(trueWeight), - static_cast(falseWeight)); - } - llvm::BranchInst *branch = builder.CreateCondBr( - lookupValue(condbrOp.getOperand(0)), - lookupBlock(condbrOp.getSuccessor(0)), - lookupBlock(condbrOp.getSuccessor(1)), branchWeights); - mapBranch(&opInst, branch); - return success(); - } - if (auto switchOp = dyn_cast(opInst)) { - llvm::MDNode *branchWeights = nullptr; - if (auto weights = switchOp.branch_weights()) { - llvm::SmallVector weightValues; - weightValues.reserve(weights->size()); - for (llvm::APInt weight : weights->cast()) - weightValues.push_back(weight.getLimitedValue()); - branchWeights = llvm::MDBuilder(llvmModule->getContext()) - .createBranchWeights(weightValues); - } - - llvm::SwitchInst *switchInst = - builder.CreateSwitch(lookupValue(switchOp.value()), - lookupBlock(switchOp.defaultDestination()), - switchOp.caseDestinations().size(), branchWeights); - - auto *ty = - llvm::cast(convertType(switchOp.value().getType())); - for (auto i : - llvm::zip(switchOp.case_values()->cast(), - switchOp.caseDestinations())) - switchInst->addCase( - llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()), - lookupBlock(std::get<1>(i))); - - mapBranch(&opInst, switchInst); - return success(); - } - - // Emit addressof. We need to look up the global value referenced by the - // operation and store it in the MLIR-to-LLVM value mapping. This does not - // emit any LLVM instruction. - if (auto addressOfOp = dyn_cast(opInst)) { - LLVM::GlobalOp global = addressOfOp.getGlobal(); - LLVM::LLVMFuncOp function = addressOfOp.getFunction(); - - // The verifier should not have allowed this. - assert((global || function) && - "referencing an undefined global or function"); - - mapValue(addressOfOp.getResult(), global - ? 
globalsMapping.lookup(global) - : lookupFunction(function.getName())); - return success(); - } if (ompDialect && opInst.getDialect() == ompDialect) return convertOmpOperation(opInst, builder); diff --git a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp --- a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp +++ b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp @@ -17,6 +17,8 @@ #include "mlir/ExecutionEngine/JitRunner.h" #include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/IR/Dialect.h" +#include "mlir/Target/LLVMIR.h" + #include "llvm/Support/InitLLVM.h" #include "llvm/Support/TargetSelect.h" @@ -29,5 +31,7 @@ mlir::DialectRegistry registry; registry.insert(); + mlir::registerLLVMDialectTranslation(registry); + return mlir::JitRunnerMain(argc, argv, registry); } diff --git a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp --- a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp +++ b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp @@ -31,9 +31,11 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR.h" #include "mlir/Target/NVVMIR.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" + #include "llvm/Support/InitLLVM.h" #include "llvm/Support/TargetSelect.h" @@ -154,6 +156,7 @@ registry.insert(); + mlir::registerLLVMDialectTranslation(registry); return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig); } diff --git a/mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp b/mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp --- a/mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp +++ b/mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp @@ -30,6 +30,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/FileUtilities.h" +#include "mlir/Target/LLVMIR.h" #include "mlir/Target/ROCDLIR.h" #include "mlir/Transforms/DialectConversion.h" 
#include "mlir/Transforms/Passes.h" @@ -340,6 +341,7 @@ mlir::DialectRegistry registry; registry.insert(); + mlir::registerLLVMDialectTranslation(registry); return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig); } diff --git a/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp b/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp --- a/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp +++ b/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp @@ -96,6 +96,7 @@ mlir::DialectRegistry registry; registry.insert(); + mlir::registerLLVMDialectTranslation(registry); return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig); } diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp --- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp +++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp @@ -27,6 +27,7 @@ #include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/TargetSelect.h" @@ -67,6 +68,7 @@ mlir::DialectRegistry registry; registry.insert(); + mlir::registerLLVMDialectTranslation(registry); return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig); } diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp --- a/mlir/unittests/ExecutionEngine/Invoke.cpp +++ b/mlir/unittests/ExecutionEngine/Invoke.cpp @@ -19,6 +19,7 @@ #include "mlir/InitAllDialects.h" #include "mlir/Parser.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Support/raw_ostream.h" @@ -51,8 +52,10 @@ return %res : i32 } )mlir"; - MLIRContext context; - registerAllDialects(context); + DialectRegistry registry; + registerAllDialects(registry); + registerLLVMDialectTranslation(registry); + MLIRContext context(registry); 
OwningModuleRef module = parseSourceString(moduleStr, &context); ASSERT_TRUE(!!module); ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); @@ -74,8 +77,10 @@ return %res : f32 } )mlir"; - MLIRContext context; - registerAllDialects(context); + DialectRegistry registry; + registerAllDialects(registry); + registerLLVMDialectTranslation(registry); + MLIRContext context(registry); OwningModuleRef module = parseSourceString(moduleStr, &context); ASSERT_TRUE(!!module); ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); @@ -102,8 +107,10 @@ return } )mlir"; - MLIRContext context; - registerAllDialects(context); + DialectRegistry registry; + registerAllDialects(registry); + registerLLVMDialectTranslation(registry); + MLIRContext context(registry); auto module = parseSourceString(moduleStr, &context); ASSERT_TRUE(!!module); ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); @@ -135,8 +142,10 @@ return } )mlir"; - MLIRContext context; - registerAllDialects(context); + DialectRegistry registry; + registerAllDialects(registry); + registerLLVMDialectTranslation(registry); + MLIRContext context(registry); auto module = parseSourceString(moduleStr, &context); ASSERT_TRUE(!!module); ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); @@ -187,8 +196,10 @@ return } )mlir"; - MLIRContext context; - registerAllDialects(context); + DialectRegistry registry; + registerAllDialects(registry); + registerLLVMDialectTranslation(registry); + MLIRContext context(registry); OwningModuleRef module = parseSourceString(moduleStr, &context); ASSERT_TRUE(!!module); ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); @@ -227,8 +238,10 @@ return } )mlir"; - MLIRContext context; - registerAllDialects(context); + DialectRegistry registry; + registerAllDialects(registry); + registerLLVMDialectTranslation(registry); + MLIRContext context(registry); auto module = parseSourceString(moduleStr, &context); ASSERT_TRUE(!!module); ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));