Index: llvm/include/llvm/IR/Metadata.h
===================================================================
--- llvm/include/llvm/IR/Metadata.h
+++ llvm/include/llvm/IR/Metadata.h
@@ -1319,6 +1319,17 @@
   static MDNode *getMostGenericRange(MDNode *A, MDNode *B);
   static MDNode *getMostGenericAliasScope(MDNode *A, MDNode *B);
   static MDNode *getMostGenericAlignmentOrDereferenceable(MDNode *A, MDNode *B);
+  /// Merges the !prof metadata nodes A (attached to AInstr) and B (attached
+  /// to BInstr). If either node is null, the other one is returned as is.
+  /// Returns nullptr if merge is not implemented for the given types of
+  /// instructions and !prof metadata. Currently only implemented with direct
+  /// callsites with branch weights (used only in SamplePGO as documented in
+  /// https://llvm.org/docs/BranchWeightMetadata.html#callinst). Pass in both
+  /// instructions and nodes. Instruction information (e.g., instruction type)
+  /// helps interpret profiles and make implementation clearer.
+  static MDNode *getMergedProfMetadata(MDNode *A, MDNode *B,
+                                       const Instruction *AInstr,
+                                       const Instruction *BInstr);
 };
 
 /// Tuple of metadata.
Index: llvm/lib/IR/Metadata.cpp
===================================================================
--- llvm/lib/IR/Metadata.cpp
+++ llvm/lib/IR/Metadata.cpp
@@ -1072,6 +1072,62 @@
   return B;
 }
 
+MDNode *MDNode::getMergedProfMetadata(MDNode *A, MDNode *B,
+                                      const Instruction *AInstr,
+                                      const Instruction *BInstr) {
+  // Nothing to merge when only one side carries !prof: keep that node as is.
+  if (!A)
+    return B;
+  if (!B)
+    return A;
+
+  assert(AInstr->getMetadata(LLVMContext::MD_prof) == A &&
+         "Caller should guarantee");
+  assert(BInstr->getMetadata(LLVMContext::MD_prof) == B &&
+         "Caller should guarantee");
+
+  const CallInst *ACall = dyn_cast<CallInst>(AInstr);
+  const CallInst *BCall = dyn_cast<CallInst>(BInstr);
+  // Proceed only if both ACall and BCall are direct callsites.
+  // The rest of the cases are not implemented but could be added
+  // when there are use cases.
+  if (!(ACall && BCall && ACall->getCalledFunction() &&
+        BCall->getCalledFunction()))
+    return nullptr;
+
+  auto &Ctx = AInstr->getContext();
+  MDBuilder MDHelper(Ctx);
+
+  // LLVM IR verifier verifies !prof metadata has at least 2 operands.
+  assert(A->getNumOperands() >= 2 && B->getNumOperands() >= 2 &&
+         "!prof annotations should have no less than 2 operands");
+  MDString *AMDS = dyn_cast<MDString>(A->getOperand(0));
+  MDString *BMDS = dyn_cast<MDString>(B->getOperand(0));
+  // LLVM IR verifier verifies first operand is MDString.
+  assert(AMDS != nullptr && BMDS != nullptr &&
+         "first operand should be a non-null MDString");
+  StringRef AProfName = AMDS->getString();
+  StringRef BProfName = BMDS->getString();
+  if (AProfName.equals("branch_weights") &&
+      BProfName.equals("branch_weights")) {
+    ConstantInt *AInstrWeight =
+        mdconst::dyn_extract<ConstantInt>(A->getOperand(1));
+    ConstantInt *BInstrWeight =
+        mdconst::dyn_extract<ConstantInt>(B->getOperand(1));
+    if (AInstrWeight && BInstrWeight) {
+      // Sum the two weights with a saturating add and emit the result as an
+      // i64 constant, regardless of the width of the input weights.
+      return MDNode::get(Ctx,
+                         {MDHelper.createString("branch_weights"),
+                          MDHelper.createConstant(ConstantInt::get(
+                              Type::getInt64Ty(Ctx),
+                              SaturatingAdd(AInstrWeight->getZExtValue(),
+                                            BInstrWeight->getZExtValue())))});
+    }
+  }
+  return nullptr;
+}
+
 static bool isContiguous(const ConstantRange &A, const ConstantRange &B) {
   return A.getUpper() == B.getLower() || A.getLower() == B.getUpper();
 }
Index: llvm/lib/Transforms/Utils/Local.cpp
===================================================================
--- llvm/lib/Transforms/Utils/Local.cpp
+++ llvm/lib/Transforms/Utils/Local.cpp
@@ -2717,6 +2717,9 @@
         // Preserve !nontemporal if it is present on both instructions.
         K->setMetadata(Kind, JMD);
         break;
+      case LLVMContext::MD_prof:
+        K->setMetadata(Kind, MDNode::getMergedProfMetadata(KMD, JMD, K, J));
+        break;
     }
   }
   // Set !invariant.group from J if J has it.
If both instructions have it @@ -2745,6 +2748,7 @@ LLVMContext::MD_dereferenceable_or_null, LLVMContext::MD_access_group, LLVMContext::MD_preserve_access_index, + LLVMContext::MD_prof, LLVMContext::MD_nontemporal, LLVMContext::MD_noundef}; combineMetadata(K, J, KnownIDs, KDominatesJ); Index: llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-hoist.ll =================================================================== --- /dev/null +++ llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-hoist.ll @@ -0,0 +1,62 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals --version 2 +; RUN: opt < %s -passes='simplifycfg' -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s --check-prefix=HOIST + +; Test case based on C++ code with manually annotated !prof metadata. +; This is to test that when calls to 'func1' from 'if.then' block +; and 'if.else' block are hoisted, the branch_weights are merged and +; attached to merged call rather than dropped. 
+; +; int func1(int a, int b) ; +; int func2(int a, int b) ; + +; int func(int a, int b, bool c) { +; int sum= 0; +; if(c) { +; sum += func1(a, b); +; } else { +; sum += func1(a, b); +; sum -= func2(a, b); +; } +; return sum; +; } +define i32 @_Z4funciib(i32 %a, i32 %b, i1 %c) { +; HOIST-LABEL: define i32 @_Z4funciib +; HOIST-SAME: (i32 [[A:%.*]], i32 [[B:%.*]], i1 [[C:%.*]]) { +; HOIST-NEXT: entry: +; HOIST-NEXT: [[CALL:%.*]] = tail call i32 @_Z5func1ii(i32 [[A]], i32 [[B]]), !prof [[PROF0:![0-9]+]] +; HOIST-NEXT: br i1 [[C]], label [[IF_END:%.*]], label [[IF_ELSE:%.*]] +; HOIST: if.else: +; HOIST-NEXT: [[CALL3:%.*]] = tail call i32 @_Z5func2ii(i32 [[A]], i32 [[B]]) +; HOIST-NEXT: [[SUB:%.*]] = sub i32 [[CALL]], [[CALL3]] +; HOIST-NEXT: br label [[IF_END]] +; HOIST: if.end: +; HOIST-NEXT: [[SUM_0:%.*]] = phi i32 [ [[SUB]], [[IF_ELSE]] ], [ [[CALL]], [[ENTRY:%.*]] ] +; HOIST-NEXT: ret i32 [[SUM_0]] +; +entry: + br i1 %c, label %if.then, label %if.else + +if.then: ; preds = %entry + %call = tail call i32 @_Z5func1ii(i32 %a, i32 %b), !prof !0 + br label %if.end + +if.else: ; preds = %entry + %call1 = tail call i32 @_Z5func1ii(i32 %a, i32 %b), !prof !1 + %call3 = tail call i32 @_Z5func2ii(i32 %a, i32 %b) + %sub = sub i32 %call1, %call3 + br label %if.end + +if.end: ; preds = %if.else, %if.then + %sum.0 = phi i32 [ %call, %if.then ], [ %sub, %if.else ] + ret i32 %sum.0 +} + +declare i32 @_Z5func1ii(i32, i32) + +declare i32 @_Z5func2ii(i32, i32) + +!0 = !{!"branch_weights", i32 10} +!1 = !{!"branch_weights", i32 90} +;. +; HOIST: [[PROF0]] = !{!"branch_weights", i64 100} +;. 
Index: llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-sink.ll =================================================================== --- /dev/null +++ llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-sink.ll @@ -0,0 +1,63 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals --version 2 +; RUN: opt < %s -passes='simplifycfg' -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s --check-prefix=SINK + + +; Test case based on the following C++ code with manually annotated !prof metadata. +; This is to test that when calls to 'func1' from 'if.then' and 'if.else' are +; sunk, the branch weights are merged and attached to sunk call. +; +; int func1(int a, int b) ; +; int func2(int a, int b) ; + +; int func(int a, int b, bool c) { +; int sum = 0; +; if (c) { +; sum += func1(a,b); +; } else { +; b -= func2(a,b); +; sum += func1(a,b); +; } +; return sum; +; } + +define i32 @_Z4funciib(i32 %a, i32 %b, i1 %c) { +; SINK-LABEL: define i32 @_Z4funciib +; SINK-SAME: (i32 [[A:%.*]], i32 [[B:%.*]], i1 [[C:%.*]]) { +; SINK-NEXT: entry: +; SINK-NEXT: br i1 [[C]], label [[IF_END:%.*]], label [[IF_ELSE:%.*]] +; SINK: if.else: +; SINK-NEXT: [[CALL1:%.*]] = tail call i32 @_Z5func2ii(i32 [[A]], i32 [[B]]) +; SINK-NEXT: [[SUB:%.*]] = sub i32 [[B]], [[CALL1]] +; SINK-NEXT: br label [[IF_END]] +; SINK: if.end: +; SINK-NEXT: [[SUB_SINK:%.*]] = phi i32 [ [[SUB]], [[IF_ELSE]] ], [ [[B]], [[ENTRY:%.*]] ] +; SINK-NEXT: [[CALL2:%.*]] = tail call i32 @_Z5func1ii(i32 [[A]], i32 [[SUB_SINK]]), !prof [[PROF0:![0-9]+]] +; SINK-NEXT: ret i32 [[CALL2]] +; +entry: + br i1 %c, label %if.then, label %if.else + +if.then: ; preds = %entry + %call = tail call i32 @_Z5func1ii(i32 %a, i32 %b), !prof !0 + br label %if.end + +if.else: ; preds = %entry + %call1 = tail call i32 @_Z5func2ii(i32 %a, i32 %b) + %sub = sub i32 %b, %call1 + %call2 = tail call i32 @_Z5func1ii(i32 %a, i32 %sub), !prof !1 + br label %if.end + 
+if.end: ; preds = %if.else, %if.then + %sum.0 = phi i32 [ %call, %if.then ], [ %call2, %if.else ] + ret i32 %sum.0 +} + +declare i32 @_Z5func1ii(i32, i32) + +declare i32 @_Z5func2ii(i32, i32) + +!0 = !{!"branch_weights", i32 10} +!1 = !{!"branch_weights", i32 90} +;. +; SINK: [[PROF0]] = !{!"branch_weights", i64 100} +;.