Index: mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -25,6 +25,7 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/ThreadLocalCache.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
Index: mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -56,6 +56,18 @@
 
     /// Name of the target triple attribute.
     static StringRef getTargetTripleAttrName() { return "llvm.target_triple"; }
+
+    /// Returns true if `type` was previously verified as LLVM-compatible.
+    bool isCompatibleTypeCached(::mlir::Type type);
+
+    /// Memoize `type` as a verified LLVM-compatible type.
+    void cacheCompatibleType(::mlir::Type type);
+
+  private:
+    /// A thread-local cache of types already verified as LLVM-compatible.
+    /// This saves a lot of verification time when deeply-nested aggregate
+    /// types occur many times in the program.
+    ::mlir::ThreadLocalCache<::llvm::DenseSet<::mlir::Type>> compatibleTypes;
   }];
 
   let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
Index: mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
===================================================================
--- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2917,6 +2917,14 @@
          op->hasTrait<OpTrait::IsIsolatedFromAbove>();
 }
 
+bool LLVMDialect::isCompatibleTypeCached(Type type) {
+  return compatibleTypes->count(type);
+}
+
+void LLVMDialect::cacheCompatibleType(Type type) {
+  compatibleTypes->insert(type);
+}
+
 void FMFAttr::print(AsmPrinter &printer) const {
   printer << "<";
   printer << stringifyFastmathFlags(this->getFlags());
Index: mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
===================================================================
--- mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -762,23 +762,44 @@
 }
 
 static bool isCompatibleImpl(Type type, SetVector<Type> &callstack) {
-  if (callstack.contains(type))
+  if (!callstack.insert(type))
     return true;
 
-  callstack.insert(type);
   auto stackPopper = llvm::make_scope_exit([&] { callstack.pop_back(); });
 
+  // Right now we only cache compatible, aggregate LLVM types. Ideally we
+  // should memoize any type that can be deeply nested, but putting the
+  // cache storage in LLVM dialect is the easiest and cleanest way to ensure
+  // thread safety. And such storage is only reachable from LLVM types.
+  auto *llvmDialect = dyn_cast<LLVMDialect>(&type.getDialect());
+  if (llvmDialect && llvmDialect->isCompatibleTypeCached(type))
+    return true;
+
   auto isCompatible = [&](Type type) {
     return isCompatibleImpl(type, callstack);
   };
 
+  // Only the outermost query produces a definitive answer: any type already
+  // on the call stack is optimistically assumed to be compatible (this is
+  // what terminates recursion through identified structs). A nested type can
+  // therefore succeed only because of that assumption and still belong to an
+  // incompatible ancestor, so caching it would poison later queries.
+  auto cacheCompatibleType = [&](bool result, Type typeToCache) -> bool {
+    if (result && llvmDialect && callstack.size() == 1)
+      llvmDialect->cacheCompatibleType(typeToCache);
+    return result;
+  };
+
   return llvm::TypeSwitch<Type, bool>(type)
       .Case<LLVMStructType>([&](auto structType) {
-        return llvm::all_of(structType.getBody(), isCompatible);
+        return cacheCompatibleType(
+            llvm::all_of(structType.getBody(), isCompatible), structType);
       })
       .Case<LLVMFunctionType>([&](auto funcType) {
-        return isCompatible(funcType.getReturnType()) &&
-               llvm::all_of(funcType.getParams(), isCompatible);
+        return cacheCompatibleType(
+            isCompatible(funcType.getReturnType()) &&
+                llvm::all_of(funcType.getParams(), isCompatible),
+            funcType);
       })
       .Case<IntegerType>([](auto intType) { return intType.isSignless(); })
       .Case<VectorType>([&](auto vecType) {
@@ -787,7 +808,8 @@
       .Case<LLVMPointerType>([&](auto pointerType) {
         if (pointerType.isOpaque())
           return true;
-        return isCompatible(pointerType.getElementType());
+        return cacheCompatibleType(isCompatible(pointerType.getElementType()),
+                                   pointerType);
       })
       // clang-format off
       .Case<
@@ -795,7 +817,8 @@
           LLVMScalableVectorType,
           LLVMArrayType
       >([&](auto containerType) {
-        return isCompatible(containerType.getElementType());
+        return cacheCompatibleType(isCompatible(containerType.getElementType()),
+                                   containerType);
       })
       .Case<
           BFloat16Type,