Index: mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -25,6 +25,7 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/ThreadLocalCache.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" Index: mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -56,6 +56,15 @@ /// Name of the target triple attribute. static StringRef getTargetTripleAttrName() { return "llvm.target_triple"; } + + /// Returns `true` if the given type is compatible with the LLVM dialect. + static bool isCompatibleType(Type); + + private: + /// A cache storing compatible LLVM types that have been verified. This + /// can save us lots of verification time if there are many occurrences + /// of some deeply-nested aggregate types in the program. + ThreadLocalCache<DenseSet<Type>> compatibleTypes; }]; let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; Index: mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -476,7 +476,8 @@ // Utility functions. //===----------------------------------------------------------------------===// -/// Returns `true` if the given type is compatible with the LLVM dialect. +/// Returns `true` if the given type is compatible with the LLVM dialect. This +/// is an alias to `LLVMDialect::isCompatibleType`. 
bool isCompatibleType(Type type); /// Returns `true` if the given outer type is compatible with the LLVM dialect Index: mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp =================================================================== --- mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -761,33 +761,48 @@ return false; } -static bool isCompatibleImpl(Type type, SetVector<Type> &callstack) { - if (callstack.contains(type)) +static bool isCompatibleImpl(Type type, SetVector<Type> &callstack, + DenseSet<Type> *compatibleTypes) { + if (!callstack.insert(type)) return true; - callstack.insert(type); auto stackPopper = llvm::make_scope_exit([&] { callstack.pop_back(); }); + if (compatibleTypes && compatibleTypes->count(type)) + return true; + auto isCompatible = [&](Type type) { - return isCompatibleImpl(type, callstack); + return isCompatibleImpl(type, callstack, compatibleTypes); + }; + + auto cacheCompatibleType = [&](bool result, Type typeToCache) -> bool { + if (result && compatibleTypes) + compatibleTypes->insert(typeToCache); + return result; }; return llvm::TypeSwitch<Type, bool>(type) .Case<LLVMStructType>([&](auto structType) { - return llvm::all_of(structType.getBody(), isCompatible); + return cacheCompatibleType( + llvm::all_of(structType.getBody(), isCompatible), structType); }) .Case<LLVMFunctionType>([&](auto funcType) { - return isCompatible(funcType.getReturnType()) && - llvm::all_of(funcType.getParams(), isCompatible); + return cacheCompatibleType( + isCompatible(funcType.getReturnType()) && + llvm::all_of(funcType.getParams(), isCompatible), + funcType); }) .Case<IntegerType>([](auto intType) { return intType.isSignless(); }) .Case<VectorType>([&](auto vecType) { - return vecType.getRank() == 1 && isCompatible(vecType.getElementType()); + return cacheCompatibleType(vecType.getRank() == 1 && + isCompatible(vecType.getElementType()), + vecType); }) .Case<LLVMPointerType>([&](auto pointerType) { if (pointerType.isOpaque()) return true; - return isCompatible(pointerType.getElementType()); + return 
cacheCompatibleType(isCompatible(pointerType.getElementType()), + pointerType); }) // clang-format off .Case< @@ -795,7 +810,8 @@ LLVMScalableVectorType, LLVMArrayType >([&](auto containerType) { - return isCompatible(containerType.getElementType()); + return cacheCompatibleType(isCompatible(containerType.getElementType()), + containerType); }) .Case< BFloat16Type, @@ -815,9 +831,17 @@ .Default([](Type) { return false; }); } -bool mlir::LLVM::isCompatibleType(Type type) { +bool LLVMDialect::isCompatibleType(Type type) { SetVector<Type> callstack; - return isCompatibleImpl(type, callstack); + DenseSet<Type> *compatibleTypes = nullptr; + if (auto *llvmDialect = + type.getContext()->getLoadedDialect<LLVMDialect>()) + compatibleTypes = &llvmDialect->compatibleTypes.get(); + return isCompatibleImpl(type, callstack, compatibleTypes); +} + +bool mlir::LLVM::isCompatibleType(Type type) { + return LLVMDialect::isCompatibleType(type); } bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {