Index: mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
===================================================================
--- mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -761,15 +761,19 @@
   return false;
 }
 
-static bool isCompatibleImpl(Type type, SetVector<Type> &callstack) {
-  if (callstack.contains(type))
+/// Recursive helper for `isCompatibleType`. `visited` holds every type reached
+/// so far in the current query: types whose check is still in progress
+/// (reaching one of them again, as recursive types do, counts as compatible)
+/// and types already found compatible. Incompatible types never need to be
+/// remembered: each case below either returns a sub-result directly or
+/// combines sub-results with `&&`/`all_of`, so the first `false`
+/// short-circuits all the way up to the root of the query.
+static bool isCompatibleImpl(Type type, DenseSet<Type> &visited) {
+  if (!visited.insert(type).second)
     return true;
 
-  callstack.insert(type);
-  auto stackPopper = llvm::make_scope_exit([&] { callstack.pop_back(); });
-
   auto isCompatible = [&](Type type) {
-    return isCompatibleImpl(type, callstack);
+    return isCompatibleImpl(type, visited);
   };
 
   return llvm::TypeSwitch<Type, bool>(type)
@@ -816,8 +820,16 @@
 }
 
 bool mlir::LLVM::isCompatibleType(Type type) {
-  SetVector<Type> callstack;
-  return isCompatibleImpl(type, callstack);
+  // The memo deliberately lives for a single query. Types are uniqued per
+  // MLIRContext and compared by storage address, so a global or thread_local
+  // cache keyed by `Type` would keep stale entries after the owning context is
+  // destroyed; an unrelated type later allocated at the same address would
+  // then be reported as compatible. Such a cache would also grow without bound
+  // and, if reused across queries, would have to drop types that were judged
+  // compatible only while an enclosing recursive type was still being checked
+  // and later turned out incompatible.
+  DenseSet<Type> visited;
+  return isCompatibleImpl(type, visited);
 }
 
 bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {