Index: llvm/include/llvm/Transforms/Utils/FunctionComparator.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/FunctionComparator.h
+++ llvm/include/llvm/Transforms/Utils/FunctionComparator.h
@@ -333,6 +333,7 @@
   int cmpInlineAsm(const InlineAsm *L, const InlineAsm *R) const;
   int cmpAttrs(const AttributeList L, const AttributeList R) const;
   int cmpRangeMetadata(const MDNode *L, const MDNode *R) const;
+  int cmpLoadInstMetadata(const LoadInst *L, const LoadInst *R) const;
   int cmpOperandBundlesSchema(const CallBase &LCS, const CallBase &RCS) const;
 
   /// Compare two GEPs for equivalent pointer arithmetic.
Index: llvm/lib/Transforms/Utils/FunctionComparator.cpp
===================================================================
--- llvm/lib/Transforms/Utils/FunctionComparator.cpp
+++ llvm/lib/Transforms/Utils/FunctionComparator.cpp
@@ -184,6 +184,51 @@
   return 0;
 }
 
+int FunctionComparator::cmpLoadInstMetadata(const LoadInst *L,
+                                            const LoadInst *R) const {
+  if (L == R)
+    return 0;
+  if (!L)
+    return -1;
+  if (!R)
+    return 1;
+
+  // The metadata kinds compared below make assertions about, or place
+  // constraints on, the loaded value that other optimization passes rely on.
+  // Loads that carry different expectations must be considered different.
+  //
+  // First, check the kinds that are attached as an empty node, where only
+  // presence matters. A load without the metadata orders before one with it.
+  for (auto Kind : {LLVMContext::MD_nonnull, LLVMContext::MD_invariant_group,
+                    LLVMContext::MD_invariant_load, LLVMContext::MD_noundef}) {
+    auto ML = bool(L->getMetadata(Kind));
+    auto MR = bool(R->getMetadata(Kind));
+    if (ML != MR)
+      return ML - MR;
+  }
+  // We are left with `dereferenceable` and `dereferenceable_or_null`, which
+  // carry a single operand: the size of the dereferenceable data behind the
+  // loaded pointer.
+  for (auto Kind : {LLVMContext::MD_dereferenceable,
+                    LLVMContext::MD_dereferenceable_or_null}) {
+    const MDNode *ML = L->getMetadata(Kind);
+    const MDNode *MR = R->getMetadata(Kind);
+    if (ML == MR)
+      continue;
+    // Only one side may carry the metadata. Order it like the presence checks
+    // above instead of dereferencing a null node.
+    if (!ML)
+      return -1;
+    if (!MR)
+      return 1;
+    auto *VL = mdconst::extract<ConstantInt>(ML->getOperand(0));
+    auto *VR = mdconst::extract<ConstantInt>(MR->getOperand(0));
+    if (int Res = cmpAPInts(VL->getValue(), VR->getValue()))
+      return Res;
+  }
+  return 0;
+}
+
 int FunctionComparator::cmpOperandBundlesSchema(const CallBase &LCS,
                                                 const CallBase &RCS) const {
   assert(LCS.getOpcode() == RCS.getOpcode() && "Can't compare otherwise!");
@@ -586,9 +631,11 @@
     if (int Res = cmpNumbers(LI->getSyncScopeID(),
                              cast<LoadInst>(R)->getSyncScopeID()))
       return Res;
-    return cmpRangeMetadata(
-        LI->getMetadata(LLVMContext::MD_range),
-        cast<LoadInst>(R)->getMetadata(LLVMContext::MD_range));
+    if (int Res = cmpRangeMetadata(
+            LI->getMetadata(LLVMContext::MD_range),
+            cast<LoadInst>(R)->getMetadata(LLVMContext::MD_range)))
+      return Res;
+    return cmpLoadInstMetadata(LI, cast<LoadInst>(R));
   }
   if (const StoreInst *SI = dyn_cast<StoreInst>(L)) {
     if (int Res =
Index: llvm/test/Transforms/MergeFunc/mergefunc-preserve-nonnull.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/MergeFunc/mergefunc-preserve-nonnull.ll
@@ -0,0 +1,55 @@
+; RUN: opt -passes=mergefunc -S < %s | FileCheck %s
+
+; This test makes sure that the mergefunc pass does not merge functions
+; whose loads carry different metadata assertions (nonnull, noundef,
+; dereferenceable).
+
+target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
+
+%1 = type ptr
+
+define void @f1(ptr noalias nocapture noundef sret(%1) dereferenceable(24) %0, ptr noalias nocapture noundef dereferenceable(24) %1) {
+  ; CHECK-LABEL: @f1(
+  ; CHECK: %3 = load ptr, ptr %1, align 8, !nonnull !0
+  %3 = load ptr, ptr %1, align 8, !nonnull !0
+  store ptr %3, ptr %0, align 8
+  ret void
+}
+
+define void @f2(ptr noalias nocapture noundef sret(%1) dereferenceable(24) %0, ptr noalias nocapture noundef dereferenceable(24) %1) {
+  ; CHECK-LABEL: @f2(
+  ; CHECK: %3 = load ptr, ptr %1, align 8
+  %3 = load ptr, ptr %1, align 8
+  store ptr %3, ptr %0, align 8
+  ret void
+}
+
+define void @f3(ptr noalias nocapture noundef sret(%1) dereferenceable(24) %0, ptr noalias nocapture noundef dereferenceable(24) %1) {
+  ; CHECK-LABEL: @f3(
+  ; CHECK: %3 = load ptr, ptr %1, align 8, !noundef !0
+  %3 = load ptr, ptr %1, align 8, !noundef !0
+  store ptr %3, ptr %0, align 8
+  ret void
+}
+
+; @f4 and @f5 differ from @f2 only by carrying !dereferenceable, and from each
+; other only by the dereferenceable size.
+define void @f4(ptr noalias nocapture noundef sret(%1) dereferenceable(24) %0, ptr noalias nocapture noundef dereferenceable(24) %1) {
+  ; CHECK-LABEL: @f4(
+  ; CHECK: %3 = load ptr, ptr %1, align 8, !dereferenceable !1
+  %3 = load ptr, ptr %1, align 8, !dereferenceable !1
+  store ptr %3, ptr %0, align 8
+  ret void
+}
+
+define void @f5(ptr noalias nocapture noundef sret(%1) dereferenceable(24) %0, ptr noalias nocapture noundef dereferenceable(24) %1) {
+  ; CHECK-LABEL: @f5(
+  ; CHECK: %3 = load ptr, ptr %1, align 8, !dereferenceable !2
+  %3 = load ptr, ptr %1, align 8, !dereferenceable !2
+  store ptr %3, ptr %0, align 8
+  ret void
+}
+
+!0 = !{}
+!1 = !{i64 1}
+!2 = !{i64 8}