diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -429,6 +429,10 @@
 /// Returns `true` if the given type is compatible with the LLVM dialect.
 bool isCompatibleType(Type type);
 
+/// Returns `true` if the given outer type is compatible with the LLVM dialect
+/// without checking its potential nested types such as struct elements.
+bool isCompatibleOuterType(Type type);
+
 /// Returns `true` if the given type is a floating-point type compatible with
 /// the LLVM dialect.
 bool isCompatibleFloatingPointType(Type type);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -55,6 +55,12 @@
   });
   addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results,
                     ArrayRef<Type> callStack) -> llvm::Optional<LogicalResult> {
+    // Fastpath for types that won't be converted by this callback anyway.
+    if (LLVM::isCompatibleType(type)) {
+      results.push_back(type);
+      return success();
+    }
+
     if (type.isIdentified()) {
       auto convertedType = LLVM::LLVMStructType::getIdentified(
           type.getContext(), ("_Converted_" + type.getName()).str());
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -468,7 +468,7 @@
   Type type = dispatchParse(parser, /*allowAny=*/false);
   if (!type)
     return type;
-  if (!isCompatibleType(type)) {
+  if (!isCompatibleOuterType(type)) {
     parser.emitError(loc) << "unexpected type, expected keyword";
     return nullptr;
   }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/TypeSupport.h"
 
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/TypeSize.h"
 
@@ -120,9 +121,10 @@
 //===----------------------------------------------------------------------===//
 
 bool LLVMPointerType::isValidElementType(Type type) {
-  return isCompatibleType(type) ? !type.isa<LLVMVoidType, LLVMTokenType,
-                                            LLVMMetadataType, LLVMLabelType>()
-                                : type.isa<PointerElementTypeInterface>();
+  return isCompatibleOuterType(type)
+             ? !type.isa<LLVMVoidType, LLVMTokenType, LLVMMetadataType,
+                         LLVMLabelType>()
+             : type.isa<PointerElementTypeInterface>();
 }
 
 LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) {
@@ -483,17 +485,9 @@
 // Utility functions.
 //===----------------------------------------------------------------------===//
 
-bool mlir::LLVM::isCompatibleType(Type type) {
-  // Only signless integers are compatible.
-  if (auto intType = type.dyn_cast<IntegerType>())
-    return intType.isSignless();
-
-  // 1D vector types are compatible if their element types are.
-  if (auto vecType = type.dyn_cast<VectorType>())
-    return vecType.getRank() == 1 && isCompatibleType(vecType.getElementType());
-
+bool mlir::LLVM::isCompatibleOuterType(Type type) {
   // clang-format off
-  return type.isa<
+  if (type.isa<
     BFloat16Type,
     Float16Type,
     Float32Type,
@@ -512,8 +506,79 @@
     LLVMScalableVectorType,
     LLVMVoidType,
     LLVMX86MMXType
-  >();
-  // clang-format on
+  >()) {
+    // clang-format on
+    return true;
+  }
+
+  // Only signless integers are compatible.
+  if (auto intType = type.dyn_cast<IntegerType>())
+    return intType.isSignless();
+
+  // 1D vector types are compatible.
+  if (auto vecType = type.dyn_cast<VectorType>())
+    return vecType.getRank() == 1;
+
+  return false;
+}
+
+/// Returns `true` if `type` and every type nested in it are compatible with
+/// the LLVM dialect. `callstack` holds the types currently being checked;
+/// meeting one of them again means the type is recursive (e.g. an identified
+/// struct referring to itself) and it is accepted here, as the outer frame
+/// still checks all remaining parts.
+static bool isCompatibleImpl(Type type, SetVector<Type> &callstack) {
+  if (callstack.contains(type))
+    return true;
+
+  callstack.insert(type);
+  auto stackPopper = llvm::make_scope_exit([&] { callstack.pop_back(); });
+
+  auto isCompatible = [&](Type type) {
+    return isCompatibleImpl(type, callstack);
+  };
+
+  return llvm::TypeSwitch<Type, bool>(type)
+      .Case<LLVMStructType>([&](auto structType) {
+        return llvm::all_of(structType.getBody(), isCompatible);
+      })
+      .Case<LLVMFunctionType>([&](auto funcType) {
+        return isCompatible(funcType.getReturnType()) &&
+               llvm::all_of(funcType.getParams(), isCompatible);
+      })
+      .Case<IntegerType>([](auto intType) { return intType.isSignless(); })
+      .Case<VectorType>([&](auto vecType) {
+        return vecType.getRank() == 1 && isCompatible(vecType.getElementType());
+      })
+      // clang-format off
+      .Case<
+        LLVMPointerType,
+        LLVMFixedVectorType,
+        LLVMScalableVectorType,
+        LLVMArrayType
+      >([&](auto containerType) {
+        return isCompatible(containerType.getElementType());
+      })
+      .Case<
+        BFloat16Type,
+        Float16Type,
+        Float32Type,
+        Float64Type,
+        Float80Type,
+        Float128Type,
+        LLVMLabelType,
+        LLVMMetadataType,
+        LLVMPPCFP128Type,
+        LLVMTokenType,
+        LLVMVoidType,
+        LLVMX86MMXType
+      >([](Type) { return true; })
+      // clang-format on
+      .Default([](Type) { return false; });
+}
+
+bool mlir::LLVM::isCompatibleType(Type type) {
+  SetVector<Type> callstack;
+  return isCompatibleImpl(type, callstack);
 }
 
 bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-types.mlir b/mlir/test/Conversion/StandardToLLVM/convert-types.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-types.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-types.mlir
@@ -16,6 +16,10 @@
 // CHECK: !llvm.struct<"_Converted_named", (ptr<i42>)>
 func private @named_struct_ptr() -> !llvm.struct<"named", (ptr<!test.smpla>)>
 
+// CHECK-LABEL: @named_no_convert
+// CHECK: !llvm.struct<"no_convert", (ptr<struct<"no_convert">>)>
+func private @named_no_convert() -> !llvm.struct<"no_convert", (ptr<struct<"no_convert">>)>
+
 // CHECK-LABEL: @array_ptr()
 // CHECK: !llvm.array<10 x ptr<i42>>
 func private @array_ptr() -> !llvm.array<10 x ptr<!test.smpla>>