diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -18,7 +18,7 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Target/LLVMIR/Export.h"
 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
-#include "mlir/Target/LLVMIR/TypeTranslation.h"
+#include "mlir/Target/LLVMIR/TypeToLLVM.h"
 
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
diff --git a/mlir/include/mlir/Target/LLVMIR/TypeFromLLVM.h b/mlir/include/mlir/Target/LLVMIR/TypeFromLLVM.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/TypeFromLLVM.h
@@ -0,0 +1,53 @@
+//===- TypeFromLLVM.h - Translate types from LLVM to MLIR -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the utility class translating LLVM IR types to the MLIR
+// LLVM dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_TYPEFROMLLVM_H
+#define MLIR_TARGET_LLVMIR_TYPEFROMLLVM_H
+
+#include <memory>
+
+namespace llvm {
+class Type;
+} // namespace llvm
+
+namespace mlir {
+
+class Type;
+class MLIRContext;
+
+namespace LLVM {
+
+namespace detail {
+class TypeFromLLVMIRTranslatorImpl;
+} // namespace detail
+
+/// Utility class to translate LLVM IR types to the MLIR LLVM dialect. Stores
+/// the translation state, in particular any identified structure types that are
+/// reused across translations.
+class TypeFromLLVMIRTranslator {
+public:
+  TypeFromLLVMIRTranslator(MLIRContext &context);
+  ~TypeFromLLVMIRTranslator();
+
+  /// Translates the given LLVM IR type to the MLIR LLVM dialect.
+  Type translateType(llvm::Type *type);
+
+private:
+  /// Private implementation.
+  std::unique_ptr<detail::TypeFromLLVMIRTranslatorImpl> impl;
+};
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_TYPEFROMLLVM_H
diff --git a/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h b/mlir/include/mlir/Target/LLVMIR/TypeToLLVM.h
rename from mlir/include/mlir/Target/LLVMIR/TypeTranslation.h
rename to mlir/include/mlir/Target/LLVMIR/TypeToLLVM.h
--- a/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/TypeToLLVM.h
@@ -1,4 +1,4 @@
-//===- TypeTranslation.h - Translate types between MLIR & LLVM --*- C++ -*-===//
+//===- TypeToLLVM.h - Translate types from MLIR to LLVM ---------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -11,8 +11,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_TARGET_LLVMIR_TYPETRANSLATION_H
-#define MLIR_TARGET_LLVMIR_TYPETRANSLATION_H
+#ifndef MLIR_TARGET_LLVMIR_TYPETOLLVM_H
+#define MLIR_TARGET_LLVMIR_TYPETOLLVM_H
 
 #include <memory>
 
@@ -58,4 +58,4 @@
 } // namespace LLVM
 } // namespace mlir
 
-#endif // MLIR_TARGET_LLVMIR_TYPETRANSLATION_H
+#endif // MLIR_TARGET_LLVMIR_TYPETOLLVM_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -17,7 +17,7 @@
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Support/MathExtras.h"
-#include "mlir/Target/LLVMIR/TypeTranslation.h"
+#include "mlir/Target/LLVMIR/TypeToLLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -5,14 +5,15 @@
   ConvertToLLVMIR.cpp
   DebugTranslation.cpp
   ModuleTranslation.cpp
-  TypeTranslation.cpp
+  TypeToLLVM.cpp
+  TypeFromLLVM.cpp
   )
 
 
 add_mlir_translation_library(MLIRTargetLLVMIRExport
   DebugTranslation.cpp
   ModuleTranslation.cpp
-  TypeTranslation.cpp
+  TypeToLLVM.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR
@@ -50,6 +51,7 @@
 
 add_mlir_translation_library(MLIRTargetLLVMIRImport
   ConvertFromLLVMIR.cpp
+  TypeFromLLVM.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR
diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/Target/LLVMIR/Import.h"
+#include "mlir/Target/LLVMIR/TypeFromLLVM.h"
 #include "mlir/Translation.h"
 
 #include "llvm/ADT/TypeSwitch.h"
@@ -45,167 +46,6 @@
   return os.str();
 }
 
-namespace mlir {
-namespace LLVM {
-namespace detail {
-/// Support for translating LLVM IR types to MLIR LLVM dialect types.
-class TypeFromLLVMIRTranslatorImpl {
-public:
-  /// Constructs a class creating types in the given MLIR context.
-  TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {}
-
-  /// Translates the given type.
-  Type translateType(llvm::Type *type) {
-    if (knownTranslations.count(type))
-      return knownTranslations.lookup(type);
-
-    Type translated =
-        llvm::TypeSwitch<llvm::Type *, Type>(type)
-            .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
-                  llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
-                  llvm::ScalableVectorType>(
-                [this](auto *type) { return this->translate(type); })
-            .Default([this](llvm::Type *type) {
-              return translatePrimitiveType(type);
-            });
-    knownTranslations.try_emplace(type, translated);
-    return translated;
-  }
-
-private:
-  /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
-  /// type.
-  Type translatePrimitiveType(llvm::Type *type) {
-    if (type->isVoidTy())
-      return LLVM::LLVMVoidType::get(&context);
-    if (type->isHalfTy())
-      return Float16Type::get(&context);
-    if (type->isBFloatTy())
-      return BFloat16Type::get(&context);
-    if (type->isFloatTy())
-      return Float32Type::get(&context);
-    if (type->isDoubleTy())
-      return Float64Type::get(&context);
-    if (type->isFP128Ty())
-      return Float128Type::get(&context);
-    if (type->isX86_FP80Ty())
-      return Float80Type::get(&context);
-    if (type->isPPC_FP128Ty())
-      return LLVM::LLVMPPCFP128Type::get(&context);
-    if (type->isX86_MMXTy())
-      return LLVM::LLVMX86MMXType::get(&context);
-    if (type->isLabelTy())
-      return LLVM::LLVMLabelType::get(&context);
-    if (type->isMetadataTy())
-      return LLVM::LLVMMetadataType::get(&context);
-    llvm_unreachable("not a primitive type");
-  }
-
-  /// Translates the given array type.
-  Type translate(llvm::ArrayType *type) {
-    return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
-                                    type->getNumElements());
-  }
-
-  /// Translates the given function type.
-  Type translate(llvm::FunctionType *type) {
-    SmallVector<Type, 8> paramTypes;
-    translateTypes(type->params(), paramTypes);
-    return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
-                                       paramTypes, type->isVarArg());
-  }
-
-  /// Translates the given integer type.
-  Type translate(llvm::IntegerType *type) {
-    return IntegerType::get(&context, type->getBitWidth());
-  }
-
-  /// Translates the given pointer type.
-  Type translate(llvm::PointerType *type) {
-    return LLVM::LLVMPointerType::get(translateType(type->getElementType()),
-                                      type->getAddressSpace());
-  }
-
-  /// Translates the given structure type.
-  Type translate(llvm::StructType *type) {
-    SmallVector<Type, 8> subtypes;
-    if (type->isLiteral()) {
-      translateTypes(type->subtypes(), subtypes);
-      return LLVM::LLVMStructType::getLiteral(&context, subtypes,
-                                              type->isPacked());
-    }
-
-    if (type->isOpaque())
-      return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
-
-    LLVM::LLVMStructType translated =
-        LLVM::LLVMStructType::getIdentified(&context, type->getName());
-    knownTranslations.try_emplace(type, translated);
-    translateTypes(type->subtypes(), subtypes);
-    LogicalResult bodySet = translated.setBody(subtypes, type->isPacked());
-    assert(succeeded(bodySet) &&
-           "could not set the body of an identified struct");
-    (void)bodySet;
-    return translated;
-  }
-
-  /// Translates the given fixed-vector type.
-  Type translate(llvm::FixedVectorType *type) {
-    return LLVM::getFixedVectorType(translateType(type->getElementType()),
-                                    type->getNumElements());
-  }
-
-  /// Translates the given scalable-vector type.
-  Type translate(llvm::ScalableVectorType *type) {
-    return LLVM::LLVMScalableVectorType::get(
-        translateType(type->getElementType()), type->getMinNumElements());
-  }
-
-  /// Translates a list of types.
-  void translateTypes(ArrayRef<llvm::Type *> types,
-                      SmallVectorImpl<Type> &result) {
-    result.reserve(result.size() + types.size());
-    for (llvm::Type *type : types)
-      result.push_back(translateType(type));
-  }
-
-  /// Map of known translations. Serves as a cache and as recursion stopper for
-  /// translating recursive structs.
-  llvm::DenseMap<llvm::Type *, Type> knownTranslations;
-
-  /// The context in which MLIR types are created.
-  MLIRContext &context;
-};
-} // end namespace detail
-
-/// Utility class to translate LLVM IR types to the MLIR LLVM dialect. Stores
-/// the translation state, in particular any identified structure types that are
-/// reused across translations.
-class TypeFromLLVMIRTranslator {
-public:
-  TypeFromLLVMIRTranslator(MLIRContext &context);
-  ~TypeFromLLVMIRTranslator();
-
-  /// Translates the given LLVM IR type to the MLIR LLVM dialect.
-  Type translateType(llvm::Type *type);
-
-private:
-  /// Private implementation.
-  std::unique_ptr<detail::TypeFromLLVMIRTranslatorImpl> impl;
-};
-
-} // end namespace LLVM
-} // end namespace mlir
-
-LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
-    : impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {}
-
-LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() {}
-
-Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
-  return impl->translateType(type);
-}
-
 // Handles importing globals and functions from an LLVM module.
 namespace {
 class Importer {
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -23,7 +23,7 @@
 #include "mlir/IR/RegionGraphTraits.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
-#include "mlir/Target/LLVMIR/TypeTranslation.h"
+#include "mlir/Target/LLVMIR/TypeToLLVM.h"
 
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/ADT/PostOrderIterator.h"
diff --git a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
@@ -0,0 +1,163 @@
+//===- TypeFromLLVM.cpp - type translation from LLVM to MLIR IR -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/TypeFromLLVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Type.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace LLVM {
+namespace detail {
+/// Support for translating LLVM IR types to MLIR LLVM dialect types.
+class TypeFromLLVMIRTranslatorImpl {
+public:
+  /// Constructs a class creating types in the given MLIR context.
+  TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {}
+
+  /// Translates the given type.
+  Type translateType(llvm::Type *type) {
+    if (knownTranslations.count(type))
+      return knownTranslations.lookup(type);
+
+    Type translated =
+        llvm::TypeSwitch<llvm::Type *, Type>(type)
+            .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
+                  llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
+                  llvm::ScalableVectorType>(
+                [this](auto *type) { return this->translate(type); })
+            .Default([this](llvm::Type *type) {
+              return translatePrimitiveType(type);
+            });
+    knownTranslations.try_emplace(type, translated);
+    return translated;
+  }
+
+private:
+  /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
+  /// type.
+  Type translatePrimitiveType(llvm::Type *type) {
+    if (type->isVoidTy())
+      return LLVM::LLVMVoidType::get(&context);
+    if (type->isHalfTy())
+      return Float16Type::get(&context);
+    if (type->isBFloatTy())
+      return BFloat16Type::get(&context);
+    if (type->isFloatTy())
+      return Float32Type::get(&context);
+    if (type->isDoubleTy())
+      return Float64Type::get(&context);
+    if (type->isFP128Ty())
+      return Float128Type::get(&context);
+    if (type->isX86_FP80Ty())
+      return Float80Type::get(&context);
+    if (type->isPPC_FP128Ty())
+      return LLVM::LLVMPPCFP128Type::get(&context);
+    if (type->isX86_MMXTy())
+      return LLVM::LLVMX86MMXType::get(&context);
+    if (type->isLabelTy())
+      return LLVM::LLVMLabelType::get(&context);
+    if (type->isMetadataTy())
+      return LLVM::LLVMMetadataType::get(&context);
+    llvm_unreachable("not a primitive type");
+  }
+
+  /// Translates the given array type.
+  Type translate(llvm::ArrayType *type) {
+    return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
+                                    type->getNumElements());
+  }
+
+  /// Translates the given function type.
+  Type translate(llvm::FunctionType *type) {
+    SmallVector<Type, 8> paramTypes;
+    translateTypes(type->params(), paramTypes);
+    return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
+                                       paramTypes, type->isVarArg());
+  }
+
+  /// Translates the given integer type.
+  Type translate(llvm::IntegerType *type) {
+    return IntegerType::get(&context, type->getBitWidth());
+  }
+
+  /// Translates the given pointer type.
+  Type translate(llvm::PointerType *type) {
+    return LLVM::LLVMPointerType::get(translateType(type->getElementType()),
+                                      type->getAddressSpace());
+  }
+
+  /// Translates the given structure type.
+  Type translate(llvm::StructType *type) {
+    SmallVector<Type, 8> subtypes;
+    if (type->isLiteral()) {
+      translateTypes(type->subtypes(), subtypes);
+      return LLVM::LLVMStructType::getLiteral(&context, subtypes,
+                                              type->isPacked());
+    }
+
+    if (type->isOpaque())
+      return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
+
+    LLVM::LLVMStructType translated =
+        LLVM::LLVMStructType::getIdentified(&context, type->getName());
+    knownTranslations.try_emplace(type, translated);
+    translateTypes(type->subtypes(), subtypes);
+    LogicalResult bodySet = translated.setBody(subtypes, type->isPacked());
+    assert(succeeded(bodySet) &&
+           "could not set the body of an identified struct");
+    (void)bodySet;
+    return translated;
+  }
+
+  /// Translates the given fixed-vector type.
+  Type translate(llvm::FixedVectorType *type) {
+    return LLVM::getFixedVectorType(translateType(type->getElementType()),
+                                    type->getNumElements());
+  }
+
+  /// Translates the given scalable-vector type.
+  Type translate(llvm::ScalableVectorType *type) {
+    return LLVM::LLVMScalableVectorType::get(
+        translateType(type->getElementType()), type->getMinNumElements());
+  }
+
+  /// Translates a list of types.
+  void translateTypes(ArrayRef<llvm::Type *> types,
+                      SmallVectorImpl<Type> &result) {
+    result.reserve(result.size() + types.size());
+    for (llvm::Type *type : types)
+      result.push_back(translateType(type));
+  }
+
+  /// Map of known translations. Serves as a cache and as recursion stopper for
+  /// translating recursive structs.
+  llvm::DenseMap<llvm::Type *, Type> knownTranslations;
+
+  /// The context in which MLIR types are created.
+  MLIRContext &context;
+};
+
+} // end namespace detail
+} // end namespace LLVM
+} // end namespace mlir
+
+LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
+    : impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {}
+
+LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() {}
+
+Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
+  return impl->translateType(type);
+}
diff --git a/mlir/lib/Target/LLVMIR/TypeTranslation.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
rename from mlir/lib/Target/LLVMIR/TypeTranslation.cpp
rename to mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
--- a/mlir/lib/Target/LLVMIR/TypeTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
@@ -1,4 +1,4 @@
-//===- TypeTranslation.cpp - type translation between MLIR LLVM & LLVM IR -===//
+//===- TypeToLLVM.cpp - type translation from MLIR to LLVM IR -------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Target/LLVMIR/TypeTranslation.h"
+#include "mlir/Target/LLVMIR/TypeToLLVM.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"