diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -25,6 +25,7 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/ThreadLocalCache.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -61,6 +61,15 @@
     static StringRef getEmitCWrapperAttrName() {
       return "llvm.emit_c_interface";
     }
+
+    /// Returns `true` if the given type is compatible with the LLVM dialect.
+    static bool isCompatibleType(Type);
+
+  private:
+    /// A cache storing compatible LLVM types that have been verified. This
+    /// can save us lots of verification time if there are many occurrences
+    /// of some deeply-nested aggregate types in the program.
+    ThreadLocalCache<DenseSet<Type>> compatibleTypes;
   }];
 
   let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -451,7 +451,8 @@
 // Utility functions.
 //===----------------------------------------------------------------------===//
 
-/// Returns `true` if the given type is compatible with the LLVM dialect.
+/// Returns `true` if the given type is compatible with the LLVM dialect. This
+/// is an alias to `LLVMDialect::isCompatibleType`.
 bool isCompatibleType(Type type);
 
 /// Returns `true` if the given outer type is compatible with the LLVM dialect
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -721,63 +721,89 @@
   return false;
 }
 
-static bool isCompatibleImpl(Type type, SetVector<Type> &callstack) {
-  if (callstack.contains(type))
+/// Returns `true` if `type` is compatible with the LLVM dialect.
+///
+/// `cache` holds types proven compatible by earlier queries and is only read
+/// here. `visited` collects every type reached by the current query; a type
+/// reached again while it is still being checked (a cycle through an
+/// identified struct) is optimistically assumed to be compatible. Members of
+/// `visited` are therefore only known to be compatible once the top-level
+/// query has succeeded, and committing them to `cache` is left to the caller.
+/// Caching each type as soon as its own check returned `true` would let, for
+/// example, a pointer back to a recursive struct that is later rejected stay
+/// cached as compatible.
+static bool isCompatibleImpl(Type type, const DenseSet<Type> &cache,
+                             DenseSet<Type> &visited) {
+  if (cache.contains(type) || !visited.insert(type).second)
     return true;
 
-  callstack.insert(type);
-  auto stackPopper = llvm::make_scope_exit([&] { callstack.pop_back(); });
-
   auto isCompatible = [&](Type type) {
-    return isCompatibleImpl(type, callstack);
+    return isCompatibleImpl(type, cache, visited);
   };
 
   return llvm::TypeSwitch<Type, bool>(type)
       .Case<LLVMStructType>([&](auto structType) {
         return llvm::all_of(structType.getBody(), isCompatible);
       })
       .Case<LLVMFunctionType>([&](auto funcType) {
         return isCompatible(funcType.getReturnType()) &&
                llvm::all_of(funcType.getParams(), isCompatible);
       })
       .Case<IntegerType>([](auto intType) { return intType.isSignless(); })
       .Case<VectorType>([&](auto vecType) {
         return vecType.getRank() == 1 && isCompatible(vecType.getElementType());
       })
       .Case<LLVMPointerType>([&](auto pointerType) {
         if (pointerType.isOpaque())
           return true;
         return isCompatible(pointerType.getElementType());
       })
       // clang-format off
       .Case<
         LLVMFixedVectorType,
         LLVMScalableVectorType,
         LLVMArrayType
       >([&](auto containerType) {
         return isCompatible(containerType.getElementType());
       })
       .Case<
         BFloat16Type,
         Float16Type,
         Float32Type,
         Float64Type,
         Float80Type,
         Float128Type,
         LLVMLabelType,
         LLVMMetadataType,
         LLVMPPCFP128Type,
         LLVMTokenType,
         LLVMVoidType,
         LLVMX86MMXType
       >([](Type) { return true; })
       // clang-format on
       .Default([](Type) { return false; });
 }
 
+bool LLVMDialect::isCompatibleType(Type type) {
+  auto *llvmDialect = type.getContext()->getLoadedDialect<LLVMDialect>();
+
+  // Without a loaded LLVM dialect there is no shared cache; use an empty one
+  // that only lives for this query.
+  DenseSet<Type> localCache;
+  DenseSet<Type> &cache =
+      llvmDialect ? llvmDialect->compatibleTypes.get() : localCache;
+
+  DenseSet<Type> visited;
+  if (!isCompatibleImpl(type, cache, visited))
+    return false;
+
+  // The query succeeded, so every type it reached is known to be compatible.
+  cache.insert(visited.begin(), visited.end());
+  return true;
+}
+
 bool mlir::LLVM::isCompatibleType(Type type) {
-  SetVector<Type> callstack;
-  return isCompatibleImpl(type, callstack);
+  return LLVMDialect::isCompatibleType(type);
 }
 
 bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {