diff --git a/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h b/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h
@@ -0,0 +1,43 @@
+//===- TypeTranslation.h - Translate types between MLIR & LLVM --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the type translation functions going from MLIR LLVM
+// dialect to LLVM IR and back.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_TYPETRANSLATION_H
+#define MLIR_TARGET_LLVMIR_TYPETRANSLATION_H
+
+namespace llvm {
+class LLVMContext;
+class Type;
+} // namespace llvm
+
+namespace mlir {
+
+class MLIRContext;
+
+namespace LLVM {
+
+class LLVMTypeNew;
+
+/// Translates a type from the MLIR LLVM dialect to LLVM IR, creating it in
+/// `context`. The mapping of identified structs is not kept across calls, so
+/// every call creates new identified structs (auto-renamed by LLVM on name
+/// clashes). Intended exclusively for testing.
+llvm::Type *translateTypeToLLVMIR(LLVMTypeNew type, llvm::LLVMContext &context);
+
+/// Translates a type from LLVM IR to the MLIR LLVM dialect, creating it in
+/// `context`. Intended exclusively for testing.
+LLVMTypeNew translateTypeFromLLVMIR(llvm::Type *type, MLIRContext &context);
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_TYPETRANSLATION_H
diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt
--- a/mlir/lib/Target/CMakeLists.txt
+++ b/mlir/lib/Target/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_translation_library(MLIRTargetLLVMIRModuleTranslation
   LLVMIR/DebugTranslation.cpp
   LLVMIR/ModuleTranslation.cpp
+  LLVMIR/TypeTranslation.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR
diff --git a/mlir/lib/Target/LLVMIR/TypeTranslation.cpp b/mlir/lib/Target/LLVMIR/TypeTranslation.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/TypeTranslation.cpp
@@ -0,0 +1,311 @@
+//===- TypeTranslation.cpp - type translation between MLIR LLVM & LLVM IR -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/TypeTranslation.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/MLIRContext.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Type.h"
+
+using namespace mlir;
+
+namespace {
+/// Support for translating MLIR LLVM dialect types to LLVM IR.
+class TypeToLLVMIRTranslator {
+public:
+  /// Constructs a class creating types in the given LLVM context.
+  TypeToLLVMIRTranslator(llvm::LLVMContext &context) : context(context) {}
+
+  /// Translates a single type.
+  llvm::Type *translateType(LLVM::LLVMTypeNew type) {
+    // Return the cached translation, if any (translations are never null).
+    if (llvm::Type *known = knownTranslations.lookup(type))
+      return known;
+
+    // Dispatch to an appropriate function.
+    llvm::Type *translated =
+        llvm::TypeSwitch<LLVM::LLVMTypeNew, llvm::Type *>(type)
+            .Case([this](LLVM::LLVMVoidType) {
+              return llvm::Type::getVoidTy(context);
+            })
+            .Case([this](LLVM::LLVMHalfType) {
+              return llvm::Type::getHalfTy(context);
+            })
+            .Case([this](LLVM::LLVMBFloatType) {
+              return llvm::Type::getBFloatTy(context);
+            })
+            .Case([this](LLVM::LLVMFloatType) {
+              return llvm::Type::getFloatTy(context);
+            })
+            .Case([this](LLVM::LLVMDoubleType) {
+              return llvm::Type::getDoubleTy(context);
+            })
+            .Case([this](LLVM::LLVMFP128Type) {
+              return llvm::Type::getFP128Ty(context);
+            })
+            .Case([this](LLVM::LLVMX86FP80Type) {
+              return llvm::Type::getX86_FP80Ty(context);
+            })
+            .Case([this](LLVM::LLVMPPCFP128Type) {
+              return llvm::Type::getPPC_FP128Ty(context);
+            })
+            .Case([this](LLVM::LLVMX86MMXType) {
+              return llvm::Type::getX86_MMXTy(context);
+            })
+            .Case([this](LLVM::LLVMTokenType) {
+              return llvm::Type::getTokenTy(context);
+            })
+            .Case([this](LLVM::LLVMLabelType) {
+              return llvm::Type::getLabelTy(context);
+            })
+            .Case([this](LLVM::LLVMMetadataType) {
+              return llvm::Type::getMetadataTy(context);
+            })
+            .Case<LLVM::LLVMArrayType, LLVM::LLVMFunctionType,
+                  LLVM::LLVMIntegerType, LLVM::LLVMPointerType,
+                  LLVM::LLVMStructType, LLVM::LLVMFixedVectorType,
+                  LLVM::LLVMScalableVectorType>(
+                [this](auto t) { return translate(t); })
+            .Default([](LLVM::LLVMTypeNew t) -> llvm::Type * {
+              llvm_unreachable("unknown LLVM dialect type");
+            });
+
+    // Cache the result of the conversion and return.
+    knownTranslations.try_emplace(type, translated);
+    return translated;
+  }
+
+private:
+  /// Translates the given array type.
+  llvm::Type *translate(LLVM::LLVMArrayType type) {
+    return llvm::ArrayType::get(translateType(type.getElementType()),
+                                type.getNumElements());
+  }
+
+  /// Translates the given function type.
+  llvm::Type *translate(LLVM::LLVMFunctionType type) {
+    SmallVector<llvm::Type *, 8> paramTypes;
+    translateTypes(type.getParams(), paramTypes);
+    return llvm::FunctionType::get(translateType(type.getReturnType()),
+                                   paramTypes, type.isVarArg());
+  }
+
+  /// Translates the given integer type.
+  llvm::Type *translate(LLVM::LLVMIntegerType type) {
+    return llvm::IntegerType::get(context, type.getBitWidth());
+  }
+
+  /// Translates the given pointer type.
+  llvm::Type *translate(LLVM::LLVMPointerType type) {
+    return llvm::PointerType::get(translateType(type.getElementType()),
+                                  type.getAddressSpace());
+  }
+
+  /// Translates the given structure type, supports both identified and literal
+  /// structs. This will _create_ a new identified structure every time; use
+  /// `translateType` instead if an existing structure must be looked up.
+  llvm::Type *translate(LLVM::LLVMStructType type) {
+    SmallVector<llvm::Type *, 8> subtypes;
+    if (!type.isIdentified()) {
+      translateTypes(type.getBody(), subtypes);
+      return llvm::StructType::get(context, subtypes, type.isPacked());
+    }
+
+    llvm::StructType *structType =
+        llvm::StructType::create(context, type.getName());
+    // Mark the type we just created as known so that recursive calls can pick
+    // it up and use directly.
+    knownTranslations.try_emplace(type, structType);
+    if (type.isOpaque())
+      return structType;
+
+    translateTypes(type.getBody(), subtypes);
+    structType->setBody(subtypes, type.isPacked());
+    return structType;
+  }
+
+  /// Translates the given fixed-vector type.
+  llvm::Type *translate(LLVM::LLVMFixedVectorType type) {
+    return llvm::FixedVectorType::get(translateType(type.getElementType()),
+                                      type.getNumElements());
+  }
+
+  /// Translates the given scalable-vector type.
+  llvm::Type *translate(LLVM::LLVMScalableVectorType type) {
+    return llvm::ScalableVectorType::get(translateType(type.getElementType()),
+                                         type.getMinNumElements());
+  }
+
+  /// Translates a list of types.
+  void translateTypes(ArrayRef<LLVM::LLVMTypeNew> types,
+                      SmallVectorImpl<llvm::Type *> &result) {
+    result.reserve(result.size() + types.size());
+    for (auto type : types)
+      result.push_back(translateType(type));
+  }
+
+  /// Reference to the context in which the LLVM IR types are created.
+  llvm::LLVMContext &context;
+
+  /// Map of known translations. This serves a double purpose: caches
+  /// translation results to avoid repeated recursive calls and makes sure
+  /// identified structs with the same name (that is, equal) are resolved to an
+  /// existing type instead of creating a new type.
+  llvm::DenseMap<LLVM::LLVMTypeNew, llvm::Type *> knownTranslations;
+};
+} // end namespace
+
+/// Translates a type from MLIR LLVM dialect to LLVM IR. This does not maintain
+/// the mapping for identified structs so new structs will be created with
+/// auto-renaming on each call. This is intended exclusively for testing.
+llvm::Type *mlir::LLVM::translateTypeToLLVMIR(LLVM::LLVMTypeNew type,
+                                              llvm::LLVMContext &context) {
+  return TypeToLLVMIRTranslator(context).translateType(type);
+}
+
+namespace {
+/// Support for translating LLVM IR types to MLIR LLVM dialect types.
+class TypeFromLLVMIRTranslator {
+public:
+  /// Constructs a class creating types in the given MLIR context.
+  TypeFromLLVMIRTranslator(MLIRContext &context) : context(context) {}
+
+  /// Translates the given type, reusing the cached result if there is one.
+  LLVM::LLVMTypeNew translateType(llvm::Type *type) {
+    if (LLVM::LLVMTypeNew known = knownTranslations.lookup(type))
+      return known;
+
+    LLVM::LLVMTypeNew translated =
+        llvm::TypeSwitch<llvm::Type *, LLVM::LLVMTypeNew>(type)
+            .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
+                  llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
+                  llvm::ScalableVectorType>(
+                [this](auto *type) { return translate(type); })
+            .Default([this](llvm::Type *type) {
+              return translatePrimitiveType(type);
+            });
+    knownTranslations.try_emplace(type, translated);
+    return translated;
+  }
+
+private:
+  /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
+  /// type.
+  LLVM::LLVMTypeNew translatePrimitiveType(llvm::Type *type) {
+    if (type->isVoidTy())
+      return LLVM::LLVMVoidType::get(&context);
+    if (type->isHalfTy())
+      return LLVM::LLVMHalfType::get(&context);
+    if (type->isBFloatTy())
+      return LLVM::LLVMBFloatType::get(&context);
+    if (type->isFloatTy())
+      return LLVM::LLVMFloatType::get(&context);
+    if (type->isDoubleTy())
+      return LLVM::LLVMDoubleType::get(&context);
+    if (type->isFP128Ty())
+      return LLVM::LLVMFP128Type::get(&context);
+    if (type->isX86_FP80Ty())
+      return LLVM::LLVMX86FP80Type::get(&context);
+    if (type->isPPC_FP128Ty())
+      return LLVM::LLVMPPCFP128Type::get(&context);
+    if (type->isX86_MMXTy())
+      return LLVM::LLVMX86MMXType::get(&context);
+    if (type->isLabelTy())
+      return LLVM::LLVMLabelType::get(&context);
+    if (type->isMetadataTy())
+      return LLVM::LLVMMetadataType::get(&context);
+    if (type->isTokenTy())
+      return LLVM::LLVMTokenType::get(&context);
+    llvm_unreachable("not a primitive type");
+  }
+
+  /// Translates the given array type.
+  LLVM::LLVMTypeNew translate(llvm::ArrayType *type) {
+    return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
+                                    type->getNumElements());
+  }
+
+  /// Translates the given function type.
+  LLVM::LLVMTypeNew translate(llvm::FunctionType *type) {
+    SmallVector<LLVM::LLVMTypeNew, 8> paramTypes;
+    translateTypes(type->params(), paramTypes);
+    return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
+                                       paramTypes, type->isVarArg());
+  }
+
+  /// Translates the given integer type.
+  LLVM::LLVMTypeNew translate(llvm::IntegerType *type) {
+    return LLVM::LLVMIntegerType::get(&context, type->getBitWidth());
+  }
+
+  /// Translates the given pointer type.
+  LLVM::LLVMTypeNew translate(llvm::PointerType *type) {
+    return LLVM::LLVMPointerType::get(translateType(type->getElementType()),
+                                      type->getAddressSpace());
+  }
+
+  /// Translates the given structure type.
+  LLVM::LLVMTypeNew translate(llvm::StructType *type) {
+    SmallVector<LLVM::LLVMTypeNew, 8> subtypes;
+    if (type->isLiteral()) {
+      translateTypes(type->subtypes(), subtypes);
+      return LLVM::LLVMStructType::getLiteral(&context, subtypes,
+                                              type->isPacked());
+    }
+
+    if (type->isOpaque())
+      return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
+
+    LLVM::LLVMStructType translated =
+        LLVM::LLVMStructType::getIdentified(&context, type->getName());
+    knownTranslations.try_emplace(type, translated);
+    translateTypes(type->subtypes(), subtypes);
+    LogicalResult bodySet = translated.setBody(subtypes, type->isPacked());
+    assert(succeeded(bodySet) &&
+           "could not set the body of an identified struct");
+    (void)bodySet;
+    return translated;
+  }
+
+  /// Translates the given fixed-vector type.
+  LLVM::LLVMTypeNew translate(llvm::FixedVectorType *type) {
+    return LLVM::LLVMFixedVectorType::get(translateType(type->getElementType()),
+                                          type->getNumElements());
+  }
+
+  /// Translates the given scalable-vector type.
+  LLVM::LLVMTypeNew translate(llvm::ScalableVectorType *type) {
+    return LLVM::LLVMScalableVectorType::get(
+        translateType(type->getElementType()), type->getMinNumElements());
+  }
+
+  /// Translates a list of types.
+  void translateTypes(ArrayRef<llvm::Type *> types,
+                      SmallVectorImpl<LLVM::LLVMTypeNew> &result) {
+    result.reserve(result.size() + types.size());
+    for (llvm::Type *type : types)
+      result.push_back(translateType(type));
+  }
+
+  /// Map of known translations. Serves as a cache and as recursion stopper for
+  /// translating recursive structs.
+  llvm::DenseMap<llvm::Type *, LLVM::LLVMTypeNew> knownTranslations;
+
+  /// The context in which MLIR types are created.
+  MLIRContext &context;
+};
+} // end namespace
+
+/// Translates a type from LLVM IR to MLIR LLVM dialect. This is intended
+/// exclusively for testing.
+LLVM::LLVMTypeNew mlir::LLVM::translateTypeFromLLVMIR(llvm::Type *type,
+                                                      MLIRContext &context) {
+  return TypeFromLLVMIRTranslator(context).translateType(type);
+}
diff --git a/mlir/test/Target/llvmir-types.mlir b/mlir/test/Target/llvmir-types.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Target/llvmir-types.mlir
@@ -0,0 +1,228 @@
+// RUN: mlir-translate -test-mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+llvm.func @primitives() {
+  // CHECK: declare void @return_void()
+  // CHECK: declare void @return_void_round()
+  "llvm.test_introduce_func"() { name = "return_void", type = !llvm2.void } : () -> ()
+  // CHECK: declare half @return_half()
+  // CHECK: declare half @return_half_round()
+  "llvm.test_introduce_func"() { name = "return_half", type = !llvm2.half } : () -> ()
+  // CHECK: declare bfloat @return_bfloat()
+  // CHECK: declare bfloat @return_bfloat_round()
+  "llvm.test_introduce_func"() { name = "return_bfloat", type = !llvm2.bfloat } : () -> ()
+  // CHECK: declare float @return_float()
+  // CHECK: declare float @return_float_round()
+  "llvm.test_introduce_func"() { name = "return_float", type = !llvm2.float } : () -> ()
+  // CHECK: declare double @return_double()
+  // CHECK: declare double @return_double_round()
+  "llvm.test_introduce_func"() { name = "return_double", type = !llvm2.double } : () -> ()
+  // CHECK: declare fp128 @return_fp128()
+  // CHECK: declare fp128 @return_fp128_round()
+  "llvm.test_introduce_func"() { name = "return_fp128", type = !llvm2.fp128 } : () -> ()
+  // CHECK: declare x86_fp80 @return_x86_fp80()
+  // CHECK: declare x86_fp80 @return_x86_fp80_round()
+  "llvm.test_introduce_func"() { name = "return_x86_fp80", type = !llvm2.x86_fp80 } : () -> ()
+  // CHECK: declare ppc_fp128 @return_ppc_fp128()
+  // CHECK: declare ppc_fp128 @return_ppc_fp128_round()
+  "llvm.test_introduce_func"() { name = "return_ppc_fp128", type = !llvm2.ppc_fp128 } : () -> ()
+  // CHECK: declare x86_mmx @return_x86_mmx()
+  // CHECK: declare x86_mmx @return_x86_mmx_round()
+  "llvm.test_introduce_func"() { name = "return_x86_mmx", type = !llvm2.x86_mmx } : () -> ()
+  llvm.return
+}
+
+llvm.func @funcs() {
+  // CHECK: declare void @f_void_i32(i32)
+  // CHECK: declare void @f_void_i32_round(i32)
+  "llvm.test_introduce_func"() { name ="f_void_i32", type = !llvm2.func<void (i32)> } : () -> ()
+  // CHECK: declare i32 @f_i32_empty()
+  // CHECK: declare i32 @f_i32_empty_round()
+  "llvm.test_introduce_func"() { name ="f_i32_empty", type = !llvm2.func<i32 ()> } : () -> ()
+  // CHECK: declare i32 @f_i32_half_bfloat_float_double(half, bfloat, float, double)
+  // CHECK: declare i32 @f_i32_half_bfloat_float_double_round(half, bfloat, float, double)
+  "llvm.test_introduce_func"() { name ="f_i32_half_bfloat_float_double", type = !llvm2.func<i32 (half, bfloat, float, double)> } : () -> ()
+  // CHECK: declare i32 @f_i32_i32_i32(i32, i32)
+  // CHECK: declare i32 @f_i32_i32_i32_round(i32, i32)
+  "llvm.test_introduce_func"() { name ="f_i32_i32_i32", type = !llvm2.func<i32 (i32, i32)> } : () -> ()
+  // CHECK: declare void @f_void_variadic(...)
+  // CHECK: declare void @f_void_variadic_round(...)
+  "llvm.test_introduce_func"() { name ="f_void_variadic", type = !llvm2.func<void (...)> } : () -> ()
+  // CHECK: declare void @f_void_i32_i32_variadic(i32, i32, ...)
+  // CHECK: declare void @f_void_i32_i32_variadic_round(i32, i32, ...)
+  "llvm.test_introduce_func"() { name ="f_void_i32_i32_variadic", type = !llvm2.func<void (i32, i32, ...)> } : () -> ()
+  llvm.return
+}
+
+llvm.func @ints() {
+  // CHECK: declare i1 @return_i1()
+  // CHECK: declare i1 @return_i1_round()
+  "llvm.test_introduce_func"() { name = "return_i1", type = !llvm2.i1 } : () -> ()
+  // CHECK: declare i8 @return_i8()
+  // CHECK: declare i8 @return_i8_round()
+  "llvm.test_introduce_func"() { name = "return_i8", type = !llvm2.i8 } : () -> ()
+  // CHECK: declare i16 @return_i16()
+  // CHECK: declare i16 @return_i16_round()
+  "llvm.test_introduce_func"() { name = "return_i16", type = !llvm2.i16 } : () -> ()
+  // CHECK: declare i32 @return_i32()
+  // CHECK: declare i32 @return_i32_round()
+  "llvm.test_introduce_func"() { name = "return_i32", type = !llvm2.i32 } : () -> ()
+  // CHECK: declare i64 @return_i64()
+  // CHECK: declare i64 @return_i64_round()
+  "llvm.test_introduce_func"() { name = "return_i64", type = !llvm2.i64 } : () -> ()
+  // CHECK: declare i57 @return_i57()
+  // CHECK: declare i57 @return_i57_round()
+  "llvm.test_introduce_func"() { name = "return_i57", type = !llvm2.i57 } : () -> ()
+  // CHECK: declare i129 @return_i129()
+  // CHECK: declare i129 @return_i129_round()
+  "llvm.test_introduce_func"() { name = "return_i129", type = !llvm2.i129 } : () -> ()
+  llvm.return
+}
+
+llvm.func @pointers() {
+  // CHECK: declare i8* @return_pi8()
+  // CHECK: declare i8* @return_pi8_round()
+  "llvm.test_introduce_func"() { name = "return_pi8", type = !llvm2.ptr<i8> } : () -> ()
+  // CHECK: declare float* @return_pfloat()
+  // CHECK: declare float* @return_pfloat_round()
+  "llvm.test_introduce_func"() { name = "return_pfloat", type = !llvm2.ptr<float> } : () -> ()
+  // CHECK: declare i8** @return_ppi8()
+  // CHECK: declare i8** @return_ppi8_round()
+  "llvm.test_introduce_func"() { name = "return_ppi8", type = !llvm2.ptr<ptr<i8>> } : () -> ()
+  // CHECK: declare i8***** @return_pppppi8()
+  // CHECK: declare i8***** @return_pppppi8_round()
+  "llvm.test_introduce_func"() { name = "return_pppppi8", type = !llvm2.ptr<ptr<ptr<ptr<ptr<i8>>>>> } : () -> ()
+  // CHECK: declare i8* @return_pi8_0()
+  // CHECK: declare i8* @return_pi8_0_round()
+  "llvm.test_introduce_func"() { name = "return_pi8_0", type = !llvm2.ptr<i8, 0> } : () -> ()
+  // CHECK: declare i8 addrspace(1)* @return_pi8_1()
+  // CHECK: declare i8 addrspace(1)* @return_pi8_1_round()
+  "llvm.test_introduce_func"() { name = "return_pi8_1", type = !llvm2.ptr<i8, 1> } : () -> ()
+  // CHECK: declare i8 addrspace(42)* @return_pi8_42()
+  // CHECK: declare i8 addrspace(42)* @return_pi8_42_round()
+  "llvm.test_introduce_func"() { name = "return_pi8_42", type = !llvm2.ptr<i8, 42> } : () -> ()
+  // CHECK: declare i8 addrspace(42)* addrspace(9)* @return_ppi8_42_9()
+  // CHECK: declare i8 addrspace(42)* addrspace(9)* @return_ppi8_42_9_round()
+  "llvm.test_introduce_func"() { name = "return_ppi8_42_9", type = !llvm2.ptr<ptr<i8, 42>, 9> } : () -> ()
+  llvm.return
+}
+
+llvm.func @vectors() {
+  // CHECK: declare <4 x i32> @return_v4_i32()
+  // CHECK: declare <4 x i32> @return_v4_i32_round()
+  "llvm.test_introduce_func"() { name = "return_v4_i32", type = !llvm2.vec<4 x i32> } : () -> ()
+  // CHECK: declare <4 x float> @return_v4_float()
+  // CHECK: declare <4 x float> @return_v4_float_round()
+  "llvm.test_introduce_func"() { name = "return_v4_float", type = !llvm2.vec<4 x float> } : () -> ()
+  // CHECK: declare <vscale x 4 x i32> @return_vs_4_i32()
+  // CHECK: declare <vscale x 4 x i32> @return_vs_4_i32_round()
+  "llvm.test_introduce_func"() { name = "return_vs_4_i32", type = !llvm2.vec<? x 4 x i32> } : () -> ()
+  // CHECK: declare <vscale x 8 x half> @return_vs_8_half()
+  // CHECK: declare <vscale x 8 x half> @return_vs_8_half_round()
+  "llvm.test_introduce_func"() { name = "return_vs_8_half", type = !llvm2.vec<? x 8 x half> } : () -> ()
+  // CHECK: declare <4 x i8*> @return_v_4_pi8()
+  // CHECK: declare <4 x i8*> @return_v_4_pi8_round()
+  "llvm.test_introduce_func"() { name = "return_v_4_pi8", type = !llvm2.vec<4 x ptr<i8>> } : () -> ()
+  llvm.return
+}
+
+llvm.func @arrays() {
+  // CHECK: declare [10 x i32] @return_a10_i32()
+  // CHECK: declare [10 x i32] @return_a10_i32_round()
+  "llvm.test_introduce_func"() { name = "return_a10_i32", type = !llvm2.array<10 x i32> } : () -> ()
+  // CHECK: declare [8 x float] @return_a8_float()
+  // CHECK: declare [8 x float] @return_a8_float_round()
+  "llvm.test_introduce_func"() { name = "return_a8_float", type = !llvm2.array<8 x float> } : () -> ()
+  // CHECK: declare [10 x i32 addrspace(4)*] @return_a10_pi32_4()
+  // CHECK: declare [10 x i32 addrspace(4)*] @return_a10_pi32_4_round()
+  "llvm.test_introduce_func"() { name = "return_a10_pi32_4", type = !llvm2.array<10 x ptr<i32, 4>> } : () -> ()
+  // CHECK: declare [10 x [4 x float]] @return_a10_a4_float()
+  // CHECK: declare [10 x [4 x float]] @return_a10_a4_float_round()
+  "llvm.test_introduce_func"() { name = "return_a10_a4_float", type = !llvm2.array<10 x array<4 x float>> } : () -> ()
+  llvm.return
+}
+
+llvm.func @literal_structs() {
+  // CHECK: declare {} @return_struct_empty()
+  // CHECK: declare {} @return_struct_empty_round()
+  "llvm.test_introduce_func"() { name = "return_struct_empty", type = !llvm2.struct<()> } : () -> ()
+  // CHECK: declare { i32 } @return_s_i32()
+  // CHECK: declare { i32 } @return_s_i32_round()
+  "llvm.test_introduce_func"() { name = "return_s_i32", type = !llvm2.struct<(i32)> } : () -> ()
+  // CHECK: declare { float, i32 } @return_s_float_i32()
+  // CHECK: declare { float, i32 } @return_s_float_i32_round()
+  "llvm.test_introduce_func"() { name = "return_s_float_i32", type = !llvm2.struct<(float, i32)> } : () -> ()
+  // CHECK: declare { { i32 } } @return_s_s_i32()
+  // CHECK: declare { { i32 } } @return_s_s_i32_round()
+  "llvm.test_introduce_func"() { name = "return_s_s_i32", type = !llvm2.struct<(struct<(i32)>)> } : () -> ()
+  // CHECK: declare { i32, { i32 }, float } @return_s_i32_s_i32_float()
+  // CHECK: declare { i32, { i32 }, float } @return_s_i32_s_i32_float_round()
+  "llvm.test_introduce_func"() { name = "return_s_i32_s_i32_float", type = !llvm2.struct<(i32, struct<(i32)>, float)> } : () -> ()
+
+  // CHECK: declare <{}> @return_sp_empty()
+  // CHECK: declare <{}> @return_sp_empty_round()
+  "llvm.test_introduce_func"() { name = "return_sp_empty", type = !llvm2.struct<packed ()> } : () -> ()
+  // CHECK: declare <{ i32 }> @return_sp_i32()
+  // CHECK: declare <{ i32 }> @return_sp_i32_round()
+  "llvm.test_introduce_func"() { name = "return_sp_i32", type = !llvm2.struct<packed (i32)> } : () -> ()
+  // CHECK: declare <{ float, i32 }> @return_sp_float_i32()
+  // CHECK: declare <{ float, i32 }> @return_sp_float_i32_round()
+  "llvm.test_introduce_func"() { name = "return_sp_float_i32", type = !llvm2.struct<packed (float, i32)> } : () -> ()
+  // CHECK: declare <{ i32, { i32, i1 }, float }> @return_sp_i32_s_i31_1_float()
+  // CHECK: declare <{ i32, { i32, i1 }, float }> @return_sp_i32_s_i31_1_float_round()
+  "llvm.test_introduce_func"() { name = "return_sp_i32_s_i31_1_float", type = !llvm2.struct<packed (i32, struct<(i32, i1)>, float)> } : () -> ()
+
+  // CHECK: declare { <{ i32 }> } @return_s_sp_i32()
+  // CHECK: declare { <{ i32 }> } @return_s_sp_i32_round()
+  "llvm.test_introduce_func"() { name = "return_s_sp_i32", type = !llvm2.struct<(struct<packed (i32)>)> } : () -> ()
+  // CHECK: declare <{ { i32 } }> @return_sp_s_i32()
+  // CHECK: declare <{ { i32 } }> @return_sp_s_i32_round()
+  "llvm.test_introduce_func"() { name = "return_sp_s_i32", type = !llvm2.struct<packed (struct<(i32)>)> } : () -> ()
+  llvm.return
+}
+
+// -----
+// Put structs into a separate split so that we can match their declarations
+// locally.
+
+// CHECK: %empty = type {}
+// CHECK: %opaque = type opaque
+// CHECK: %long = type { i32, { i32, i1 }, float, void ()* }
+// CHECK: %self-recursive = type { %self-recursive* }
+// CHECK: %unpacked = type { i32 }
+// CHECK: %packed = type <{ i32 }>
+// CHECK: %"name with spaces and !^$@$#" = type <{ i32 }>
+// CHECK: %mutually-a = type { %mutually-b* }
+// CHECK: %mutually-b = type { %mutually-a addrspace(3)* }
+// CHECK: %struct-of-arrays = type { [10 x i32] }
+// CHECK: %array-of-structs = type { i32 }
+// CHECK: %ptr-to-struct = type { i8 }
+
+llvm.func @identified_structs() {
+  // CHECK: declare %empty
+  "llvm.test_introduce_func"() { name = "return_s_empty", type = !llvm2.struct<"empty", ()> } : () -> ()
+  // CHECK: declare %opaque
+  "llvm.test_introduce_func"() { name = "return_s_opaque", type = !llvm2.struct<"opaque", opaque> } : () -> ()
+  // CHECK: declare %long
+  "llvm.test_introduce_func"() { name = "return_s_long", type = !llvm2.struct<"long", (i32, struct<(i32, i1)>, float, ptr<func<void ()>>)> } : () -> ()
+  // CHECK: declare %self-recursive
+  "llvm.test_introduce_func"() { name = "return_s_self_recursive", type = !llvm2.struct<"self-recursive", (ptr<struct<"self-recursive">>)> } : () -> ()
+  // CHECK: declare %unpacked
+  "llvm.test_introduce_func"() { name = "return_s_unpacked", type = !llvm2.struct<"unpacked", (i32)> } : () -> ()
+  // CHECK: declare %packed
+  "llvm.test_introduce_func"() { name = "return_s_packed", type = !llvm2.struct<"packed", packed (i32)> } : () -> ()
+  // CHECK: declare %"name with spaces and !^$@$#"
+  "llvm.test_introduce_func"() { name = "return_s_symbols", type = !llvm2.struct<"name with spaces and !^$@$#", packed (i32)> } : () -> ()
+
+  // CHECK: declare %mutually-a
+  "llvm.test_introduce_func"() { name = "return_s_mutually_a", type = !llvm2.struct<"mutually-a", (ptr<struct<"mutually-b", (ptr<struct<"mutually-a">, 3>)>>)> } : () -> ()
+  // CHECK: declare %mutually-b
+  "llvm.test_introduce_func"() { name = "return_s_mutually_b", type = !llvm2.struct<"mutually-b", (ptr<struct<"mutually-a", (ptr<struct<"mutually-b">>)>, 3>)> } : () -> ()
+
+  // CHECK: declare %struct-of-arrays
+  "llvm.test_introduce_func"() { name = "return_s_struct_of_arrays", type = !llvm2.struct<"struct-of-arrays", (array<10 x i32>)> } : () -> ()
+  // CHECK: declare [10 x %array-of-structs]
+  "llvm.test_introduce_func"() { name = "return_s_array_of_structs", type = !llvm2.array<10 x struct<"array-of-structs", (i32)>> } : () -> ()
+  // CHECK: declare %ptr-to-struct*
+  "llvm.test_introduce_func"() { name = "return_s_ptr_to_struct", type = !llvm2.ptr<struct<"ptr-to-struct", (i8)>> } : () -> ()
+  llvm.return
+}
diff --git a/mlir/test/lib/CMakeLists.txt b/mlir/test/lib/CMakeLists.txt
--- a/mlir/test/lib/CMakeLists.txt
+++ b/mlir/test/lib/CMakeLists.txt
@@ -2,4 +2,5 @@
 add_subdirectory(IR)
 add_subdirectory(Pass)
 add_subdirectory(Reducer)
+add_subdirectory(Target)
 add_subdirectory(Transforms)
diff --git a/mlir/test/lib/Target/CMakeLists.txt b/mlir/test/lib/Target/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Target/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_translation_library(MLIRTestLLVMTypeTranslation
+  TestLLVMTypeTranslation.cpp
+
+  LINK_COMPONENTS
+  Core
+  TransformUtils
+
+  LINK_LIBS PUBLIC
+  MLIRLLVMIR
+  MLIRTargetLLVMIRModuleTranslation
+  MLIRTestIR
+  MLIRTranslation
+  )
diff --git a/mlir/test/lib/Target/TestLLVMTypeTranslation.cpp b/mlir/test/lib/Target/TestLLVMTypeTranslation.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Target/TestLLVMTypeTranslation.cpp
@@ -0,0 +1,82 @@
+//===- TestLLVMTypeTranslation.cpp - Test MLIR/LLVM IR type translation ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "mlir/Target/LLVMIR/TypeTranslation.h"
+#include "mlir/Translation.h"
+
+using namespace mlir;
+
+namespace {
+class TestLLVMTypeTranslation : public LLVM::ModuleTranslation {
+  // Allow access to the constructors under MSVC.
+  friend LLVM::ModuleTranslation;
+
+public:
+  using LLVM::ModuleTranslation::ModuleTranslation;
+
+protected:
+  /// Simple test facility for translating types from MLIR LLVM dialect to LLVM
+  /// IR. This converts the "llvm.test_introduce_func" operation into an LLVM IR
+  /// function with the name extracted from the `name` attribute that returns
+  /// the type contained in the `type` attribute if it is a non-function type or
+  /// that has the signature obtained by converting `type` if it is a function
+  /// type. This is a temporary check before type translation is substituted
+  /// into the main translation flow and exercised here.
+  LogicalResult convertOperation(Operation &op,
+                                 llvm::IRBuilder<> &builder) override {
+    if (op.getName().getStringRef() == "llvm.test_introduce_func") {
+      auto attr = op.getAttrOfType<TypeAttr>("type");
+      assert(attr && "expected 'type' attribute");
+      auto type = attr.getValue().cast<LLVM::LLVMTypeNew>();
+
+      auto nameAttr = op.getAttrOfType<StringAttr>("name");
+      assert(nameAttr && "expected 'name' attribute");
+
+      llvm::Type *translated =
+          LLVM::translateTypeToLLVMIR(type, builder.getContext());
+
+      llvm::Module *module = builder.GetInsertBlock()->getModule();
+      if (auto *funcType = dyn_cast<llvm::FunctionType>(translated))
+        module->getOrInsertFunction(nameAttr.getValue(), funcType);
+      else
+        module->getOrInsertFunction(nameAttr.getValue(), translated);
+
+      std::string roundtripName = (Twine(nameAttr.getValue()) + "_round").str();
+      LLVM::LLVMTypeNew translatedBack =
+          LLVM::translateTypeFromLLVMIR(translated, *op.getContext());
+      llvm::Type *translatedBackAndForth =
+          LLVM::translateTypeToLLVMIR(translatedBack, builder.getContext());
+      if (auto *funcType = dyn_cast<llvm::FunctionType>(translatedBackAndForth))
+        module->getOrInsertFunction(roundtripName, funcType);
+      else
+        module->getOrInsertFunction(roundtripName, translatedBackAndForth);
+      return success();
+    }
+
+    return LLVM::ModuleTranslation::convertOperation(op, builder);
+  }
+};
+} // namespace
+
+namespace mlir {
+void registerTestLLVMTypeTranslation() {
+  TranslateFromMLIRRegistration reg(
+      "test-mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) {
+        std::unique_ptr<llvm::Module> llvmModule =
+            LLVM::ModuleTranslation::translateModule<TestLLVMTypeTranslation>(
+                module.getOperation());
+        // Translation failures are reported by returning a null module.
+        if (!llvmModule)
+          return failure();
+        llvmModule->print(output, nullptr);
+        return success();
+      });
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-translate/CMakeLists.txt b/mlir/tools/mlir-translate/CMakeLists.txt
--- a/mlir/tools/mlir-translate/CMakeLists.txt
+++ b/mlir/tools/mlir-translate/CMakeLists.txt
@@ -13,7 +13,11 @@
   PRIVATE
   ${dialect_libs}
   ${translation_libs}
+  ${test_libs}
   MLIRIR
+  # TODO: remove after LLVM dialect transition is complete; translation uses a
+  # registration function defined in this library unconditionally.
+  MLIRLLVMTypeTestDialect
   MLIRParser
   MLIRPass
   MLIRSPIRV
diff --git a/mlir/tools/mlir-translate/mlir-translate.cpp b/mlir/tools/mlir-translate/mlir-translate.cpp
--- a/mlir/tools/mlir-translate/mlir-translate.cpp
+++ b/mlir/tools/mlir-translate/mlir-translate.cpp
@@ -49,17 +49,21 @@
 
 namespace mlir {
 // Defined in the test directory, no public header.
+void registerLLVMTypeTestDialect();
+void registerTestLLVMTypeTranslation();
 void registerTestRoundtripSPIRV();
 void registerTestRoundtripDebugSPIRV();
 } // namespace mlir
 
 static void registerTestTranslations() {
+  registerTestLLVMTypeTranslation();
   registerTestRoundtripSPIRV();
   registerTestRoundtripDebugSPIRV();
 }
 
 int main(int argc, char **argv) {
   registerAllDialects();
+  registerLLVMTypeTestDialect();
   registerAllTranslations();
   registerTestTranslations();
   llvm::InitLLVM y(argc, argv);