Index: llvm/include/llvm/Transforms/Utils/FunctionComparator.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/FunctionComparator.h
+++ llvm/include/llvm/Transforms/Utils/FunctionComparator.h
@@ -332,7 +332,8 @@
   int cmpOrderings(AtomicOrdering L, AtomicOrdering R) const;
   int cmpInlineAsm(const InlineAsm *L, const InlineAsm *R) const;
   int cmpAttrs(const AttributeList L, const AttributeList R) const;
-  int cmpRangeMetadata(const MDNode *L, const MDNode *R) const;
+  int cmpMetadata(const MDNode *L, const MDNode *R) const;
+  int cmpLoadInstMetadata(LoadInst const *L, LoadInst const *R) const;
   int cmpOperandBundlesSchema(const CallBase &LCS, const CallBase &RCS) const;
 
   /// Compare two GEPs for equivalent pointer arithmetic.
Index: llvm/lib/Transforms/Utils/FunctionComparator.cpp
===================================================================
--- llvm/lib/Transforms/Utils/FunctionComparator.cpp
+++ llvm/lib/Transforms/Utils/FunctionComparator.cpp
@@ -157,8 +157,7 @@
   return 0;
 }
 
-int FunctionComparator::cmpRangeMetadata(const MDNode *L,
-                                         const MDNode *R) const {
+int FunctionComparator::cmpMetadata(const MDNode *L, const MDNode *R) const {
   if (L == R)
     return 0;
   if (!L)
@@ -178,12 +177,55 @@
   for (size_t I = 0; I < L->getNumOperands(); ++I) {
-    ConstantInt *LLow = mdconst::extract<ConstantInt>(L->getOperand(I));
-    ConstantInt *RLow = mdconst::extract<ConstantInt>(R->getOperand(I));
+    // Operands are not necessarily integer constants: e.g. !tbaa nodes hold
+    // MDStrings and MDNodes, and any operand may be null. mdconst::extract
+    // would assert on those, so use the non-asserting dyn_extract_or_null.
+    // TODO: operands that are ConstantInts on neither side currently compare
+    // equal; compare them structurally instead.
+    ConstantInt *LLow =
+        mdconst::dyn_extract_or_null<ConstantInt>(L->getOperand(I).get());
+    ConstantInt *RLow =
+        mdconst::dyn_extract_or_null<ConstantInt>(R->getOperand(I).get());
+    if (LLow == RLow)
+      continue;
+    if (!LLow)
+      return -1;
+    if (!RLow)
+      return 1;
     if (int Res = cmpAPInts(LLow->getValue(), RLow->getValue()))
       return Res;
   }
   return 0;
 }
 
+int FunctionComparator::cmpLoadInstMetadata(LoadInst const *L,
+                                            LoadInst const *R) const {
+  if (L == R)
+    return 0;
+  if (!L)
+    return -1;
+  if (!R)
+    return 1;
+
+  // Metadata attached to a load (e.g. !range, !nonnull, !noundef) makes
+  // assertions or constraints that other passes rely on, so loads whose
+  // metadata carries different expectations must not be considered equal.
+  // The debug location is deliberately left out of the comparison.
+  SmallVector<std::pair<unsigned, MDNode *>> MDL, MDR;
+  L->getAllMetadataOtherThanDebugLoc(MDL);
+  R->getAllMetadataOtherThanDebugLoc(MDR);
+  if (int Res = cmpNumbers(MDL.size(), MDR.size()))
+    return Res;
+  // Both lists are sorted by metadata kind ID, so compare them pairwise.
+  for (size_t I = 0, N = MDL.size(); I < N; ++I) {
+    auto const &[KeyL, ML] = MDL[I];
+    auto const &[KeyR, MR] = MDR[I];
+    if (int Res = cmpNumbers(KeyL, KeyR))
+      return Res;
+    if (int Res = cmpMetadata(ML, MR))
+      return Res;
+  }
+  return 0;
+}
+
 int FunctionComparator::cmpOperandBundlesSchema(const CallBase &LCS,
                                                 const CallBase &RCS) const {
   assert(LCS.getOpcode() == RCS.getOpcode() && "Can't compare otherwise!");
@@ -586,9 +628,7 @@
     if (int Res = cmpNumbers(LI->getSyncScopeID(),
                              cast<LoadInst>(R)->getSyncScopeID()))
       return Res;
-    return cmpRangeMetadata(
-        LI->getMetadata(LLVMContext::MD_range),
-        cast<LoadInst>(R)->getMetadata(LLVMContext::MD_range));
+    return cmpLoadInstMetadata(LI, cast<LoadInst>(R));
   }
   if (const StoreInst *SI = dyn_cast<StoreInst>(L)) {
     if (int Res =
@@ -616,8 +656,8 @@
     if (int Res = cmpNumbers(CI->getTailCallKind(),
                              cast<CallInst>(R)->getTailCallKind()))
       return Res;
-    return cmpRangeMetadata(L->getMetadata(LLVMContext::MD_range),
-                            R->getMetadata(LLVMContext::MD_range));
+    return cmpMetadata(L->getMetadata(LLVMContext::MD_range),
+                       R->getMetadata(LLVMContext::MD_range));
   }
   if (const InsertValueInst *IVI = dyn_cast<InsertValueInst>(L)) {
     ArrayRef<unsigned> LIndices = IVI->getIndices();
Index: llvm/test/Transforms/MergeFunc/mergefunc-preserve-nonnull.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/MergeFunc/mergefunc-preserve-nonnull.ll
@@ -0,0 +1,42 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=mergefunc -S < %s | FileCheck %s
+
+; This test makes sure that the mergefunc pass does not merge functions
+; whose loads differ only in their metadata (!nonnull, !noundef, or none).
+
+%1 = type ptr
+
+define void @f1(ptr sret(%1) dereferenceable(8) %0, ptr dereferenceable(24) %1) {
+; CHECK-LABEL: @f1(
+; CHECK-NEXT:    [[TMP3:%.*]] = load ptr, ptr [[TMP1:%.*]], align 8, !nonnull !0
+; CHECK-NEXT:    store ptr [[TMP3]], ptr [[TMP0:%.*]], align 8
+; CHECK-NEXT:    ret void
+;
+  %3 = load ptr, ptr %1, align 8, !nonnull !0
+  store ptr %3, ptr %0, align 8
+  ret void
+}
+
+define void @f2(ptr sret(%1) dereferenceable(8) %0, ptr dereferenceable(24) %1) {
+; CHECK-LABEL: @f2(
+; CHECK-NEXT:    [[TMP3:%.*]] = load ptr, ptr [[TMP1:%.*]], align 8
+; CHECK-NEXT:    store ptr [[TMP3]], ptr [[TMP0:%.*]], align 8
+; CHECK-NEXT:    ret void
+;
+  %3 = load ptr, ptr %1, align 8
+  store ptr %3, ptr %0, align 8
+  ret void
+}
+
+define void @f3(ptr sret(%1) dereferenceable(8) %0, ptr dereferenceable(24) %1) {
+; CHECK-LABEL: @f3(
+; CHECK-NEXT:    [[TMP3:%.*]] = load ptr, ptr [[TMP1:%.*]], align 8, !noundef !0
+; CHECK-NEXT:    store ptr [[TMP3]], ptr [[TMP0:%.*]], align 8
+; CHECK-NEXT:    ret void
+;
+  %3 = load ptr, ptr %1, align 8, !noundef !0
+  store ptr %3, ptr %0, align 8
+  ret void
+}
+
+!0 = !{}