Index: llvm/include/llvm/IR/Metadata.h
===================================================================
--- llvm/include/llvm/IR/Metadata.h
+++ llvm/include/llvm/IR/Metadata.h
@@ -1319,6 +1319,19 @@
   static MDNode *getMostGenericRange(MDNode *A, MDNode *B);
   static MDNode *getMostGenericAliasScope(MDNode *A, MDNode *B);
   static MDNode *getMostGenericAlignmentOrDereferenceable(MDNode *A, MDNode *B);
+  /// Merge the !prof metadata \p A (attached to \p AInstr) and \p B (attached
+  /// to \p BInstr). If either node is null, the other is returned as-is.
+  /// Returns nullptr if merging is not implemented for the given instructions
+  /// and !prof metadata. Currently only implemented for direct callsites with
+  /// branch weights (used only in SamplePGO, as documented in
+  /// https://llvm.org/docs/BranchWeightMetadata.html#callinst).
+  ///
+  /// Both the instructions and the nodes are passed in: instruction
+  /// information (e.g., instruction type) helps interpret the profiles and
+  /// makes the implementation clearer.
+  static MDNode *getMergedProfMetadata(MDNode *A, MDNode *B,
+                                       const Instruction *AInstr,
+                                       const Instruction *BInstr);
 };
 
 /// Tuple of metadata.
Index: llvm/lib/IR/Metadata.cpp
===================================================================
--- llvm/lib/IR/Metadata.cpp
+++ llvm/lib/IR/Metadata.cpp
@@ -1072,6 +1072,61 @@
   return B;
 }
 
+MDNode *MDNode::getMergedProfMetadata(MDNode *A, MDNode *B,
+                                      const Instruction *AInstr,
+                                      const Instruction *BInstr) {
+  // With at most one profile present there is nothing to merge.
+  if (!A)
+    return B;
+  if (!B)
+    return A;
+
+  assert(AInstr->getMetadata(LLVMContext::MD_prof) == A &&
+         "Caller should guarantee");
+  assert(BInstr->getMetadata(LLVMContext::MD_prof) == B &&
+         "Caller should guarantee");
+
+  const CallInst *ACall = dyn_cast<CallInst>(AInstr);
+  const CallInst *BCall = dyn_cast<CallInst>(BInstr);
+  // Proceed only if both ACall and BCall are direct callsites.
+  // The rest of the cases are not implemented but could be added
+  // when there are use cases.
+  if (!(ACall && BCall && ACall->getCalledFunction() &&
+        BCall->getCalledFunction()))
+    return nullptr;
+
+  auto &Ctx = AInstr->getContext();
+  MDBuilder MDHelper(Ctx);
+
+  // LLVM IR verifier verifies !prof metadata has at least 2 operands.
+  assert(A->getNumOperands() >= 2 && B->getNumOperands() >= 2 &&
+         "!prof annotations should have no less than 2 operands");
+  MDString *AMDS = dyn_cast<MDString>(A->getOperand(0));
+  MDString *BMDS = dyn_cast<MDString>(B->getOperand(0));
+  // LLVM IR verifier verifies first operand is MDString.
+  assert(AMDS != nullptr && BMDS != nullptr &&
+         "first operand should be a non-null MDString");
+  StringRef AProfName = AMDS->getString();
+  StringRef BProfName = BMDS->getString();
+  if (AProfName.equals("branch_weights") &&
+      BProfName.equals("branch_weights")) {
+    ConstantInt *AInstrWeight =
+        mdconst::dyn_extract<ConstantInt>(A->getOperand(1));
+    ConstantInt *BInstrWeight =
+        mdconst::dyn_extract<ConstantInt>(B->getOperand(1));
+    if (AInstrWeight && BInstrWeight) {
+      // Sum the two weights, clamping on overflow instead of wrapping.
+      return MDNode::get(
+          Ctx, {MDHelper.createString("branch_weights"),
+                MDHelper.createConstant(ConstantInt::get(
+                    Type::getInt64Ty(Ctx),
+                    SaturatingAdd(AInstrWeight->getZExtValue(),
+                                  BInstrWeight->getZExtValue())))});
+    }
+  }
+  return nullptr;
+}
+
 static bool isContiguous(const ConstantRange &A, const ConstantRange &B) {
   return A.getUpper() == B.getLower() || A.getLower() == B.getUpper();
 }
Index: llvm/lib/Transforms/Utils/Local.cpp
===================================================================
--- llvm/lib/Transforms/Utils/Local.cpp
+++ llvm/lib/Transforms/Utils/Local.cpp
@@ -2717,6 +2717,9 @@
         // Preserve !nontemporal if it is present on both instructions.
         K->setMetadata(Kind, JMD);
         break;
+      case LLVMContext::MD_prof:
+        K->setMetadata(Kind, MDNode::getMergedProfMetadata(KMD, JMD, K, J));
+        break;
     }
   }
   // Set !invariant.group from J if J has it. If both instructions have it
@@ -2745,6 +2748,7 @@
       LLVMContext::MD_dereferenceable_or_null,
       LLVMContext::MD_access_group,
       LLVMContext::MD_preserve_access_index,
+      LLVMContext::MD_prof,
       LLVMContext::MD_nontemporal,
       LLVMContext::MD_noundef};
   combineMetadata(K, J, KnownIDs, KDominatesJ);
Index: llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-hoist.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-hoist.ll
@@ -0,0 +1,72 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -passes='simplifycfg' -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s --check-prefix=HOIST
+
+define dso_local i32 @_Z4funciiii(i32 %a, i32 %b){
+; HOIST-LABEL: define dso_local i32 @_Z4funciiii
+; HOIST-SAME: (i32 [[A:%.*]], i32 [[B:%.*]]) {
+; HOIST-NEXT:  entry:
+; HOIST-NEXT:    br label [[FOR_BODY:%.*]]
+; HOIST:       for.cond.cleanup:
+; HOIST-NEXT:    [[SUM_1_LCSSA:%.*]] = phi i32 [ [[SUM_1:%.*]], [[FOR_INC:%.*]] ]
+; HOIST-NEXT:    ret i32 [[SUM_1_LCSSA]]
+; HOIST:       for.body:
+; HOIST-NEXT:    [[I_016:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[INC:%.*]], [[FOR_INC]] ]
+; HOIST-NEXT:    [[SUM_015:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[SUM_1]], [[FOR_INC]] ]
+; HOIST-NEXT:    [[REM_LHS_TRUNC:%.*]] = trunc i32 [[I_016]] to i16
+; HOIST-NEXT:    [[REM14:%.*]] = urem i16 [[REM_LHS_TRUNC]], 100
+; HOIST-NEXT:    [[CMP1:%.*]] = icmp eq i16 [[REM14]], 0
+; HOIST-NEXT:    [[CALL:%.*]] = tail call i32 @_Z5func1ii(i32 [[A]], i32 [[B]]), !prof [[PROF0:![0-9]+]]
+; HOIST-NEXT:    [[ADD:%.*]] = add nsw i32 [[CALL]], [[SUM_015]]
+; HOIST-NEXT:    br i1 [[CMP1]], label [[FOR_INC]], label [[IF_ELSE:%.*]]
+; HOIST:       if.else:
+; HOIST-NEXT:    [[CALL4:%.*]] = tail call i32 @_Z5func2ii(i32 [[A]], i32 [[B]])
+; HOIST-NEXT:    [[SUB:%.*]] = sub i32 [[ADD]], [[CALL4]]
+; HOIST-NEXT:    br label [[FOR_INC]]
+; HOIST:       for.inc:
+; HOIST-NEXT:    [[SUM_1]] = phi i32 [ [[SUB]], [[IF_ELSE]] ], [ [[ADD]], [[FOR_BODY]] ]
+; HOIST-NEXT:    [[INC]] = add nuw nsw i32 [[I_016]], 1
+; HOIST-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i32 [[INC]], 1000
+; HOIST-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY]]
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.inc
+  %sum.1.lcssa = phi i32 [ %sum.1, %for.inc ]
+  ret i32 %sum.1.lcssa
+
+for.body:                                         ; preds = %entry, %for.inc
+  %i.016 = phi i32 [ 0, %entry ], [ %inc, %for.inc ]
+  %sum.015 = phi i32 [ 0, %entry ], [ %sum.1, %for.inc ]
+  %rem.lhs.trunc = trunc i32 %i.016 to i16
+  %rem14 = urem i16 %rem.lhs.trunc, 100
+  %cmp1 = icmp eq i16 %rem14, 0
+  br i1 %cmp1, label %if.then, label %if.else
+
+if.then:                                          ; preds = %for.body
+  %call = tail call i32 @_Z5func1ii(i32 %a, i32 %b), !prof !0
+  %add = add nsw i32 %call, %sum.015
+  br label %for.inc
+
+if.else:                                          ; preds = %for.body
+  %call2 = tail call i32 @_Z5func1ii(i32 %a, i32 %b), !prof !1
+  %add3 = add nsw i32 %call2, %sum.015
+  %call4 = tail call i32 @_Z5func2ii(i32 %a, i32 %b)
+  %sub = sub i32 %add3, %call4
+  br label %for.inc
+
+for.inc:                                          ; preds = %if.then, %if.else
+  %sum.1 = phi i32 [ %add, %if.then ], [ %sub, %if.else ]
+  %inc = add nuw nsw i32 %i.016, 1
+  %exitcond.not = icmp eq i32 %inc, 1000
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+}
+
+declare i32 @_Z5func1ii(i32, i32)
+
+declare i32 @_Z5func2ii(i32, i32)
+
+!0 = !{!"branch_weights", i32 10}
+!1 = !{!"branch_weights", i32 990}
+
+; HOIST: [[PROF0]] = !{!"branch_weights", i64 1000}
Index: llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-sink.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-sink.ll
@@ -0,0 +1,74 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -passes='simplifycfg' -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s --check-prefix=HOIST
+
+define i32 @_Z4funcii(i32 %a, i32 %b) {
+; HOIST-LABEL: define i32 @_Z4funcii
+; HOIST-SAME: (i32 [[A:%.*]], i32 [[B:%.*]]) {
+; HOIST-NEXT:  entry:
+; HOIST-NEXT:    br label [[FOR_BODY:%.*]]
+; HOIST:       for.cond.cleanup:
+; HOIST-NEXT:    [[SUM_1_LCSSA:%.*]] = phi i32 [ [[SUM_1:%.*]], [[FOR_INC:%.*]] ]
+; HOIST-NEXT:    ret i32 [[SUM_1_LCSSA]]
+; HOIST:       for.body:
+; HOIST-NEXT:    [[I_017:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[INC:%.*]], [[FOR_INC]] ]
+; HOIST-NEXT:    [[SUM_016:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[SUM_1]], [[FOR_INC]] ]
+; HOIST-NEXT:    [[B_ADDR_015:%.*]] = phi i32 [ [[B]], [[ENTRY]] ], [ [[SUB_SINK:%.*]], [[FOR_INC]] ]
+; HOIST-NEXT:    [[REM_LHS_TRUNC:%.*]] = trunc i32 [[I_017]] to i16
+; HOIST-NEXT:    [[REM14:%.*]] = urem i16 [[REM_LHS_TRUNC]], 100
+; HOIST-NEXT:    [[CMP1:%.*]] = icmp eq i16 [[REM14]], 0
+; HOIST-NEXT:    br i1 [[CMP1]], label [[FOR_INC]], label [[IF_ELSE:%.*]]
+; HOIST:       if.else:
+; HOIST-NEXT:    [[CALL2:%.*]] = tail call i32 @_Z5func2ii(i32 [[A]], i32 [[B_ADDR_015]])
+; HOIST-NEXT:    [[SUB:%.*]] = sub nsw i32 [[B_ADDR_015]], [[CALL2]]
+; HOIST-NEXT:    br label [[FOR_INC]]
+; HOIST:       for.inc:
+; HOIST-NEXT:    [[SUB_SINK]] = phi i32 [ [[SUB]], [[IF_ELSE]] ], [ [[B_ADDR_015]], [[FOR_BODY]] ]
+; HOIST-NEXT:    [[CALL3:%.*]] = tail call i32 @_Z5func1ii(i32 [[A]], i32 [[SUB_SINK]]), !prof [[PROF0:![0-9]+]]
+; HOIST-NEXT:    [[SUM_1]] = add nsw i32 [[CALL3]], [[SUM_016]]
+; HOIST-NEXT:    [[INC]] = add nuw nsw i32 [[I_017]], 1
+; HOIST-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i32 [[INC]], 1000
+; HOIST-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY]]
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.inc
+  %sum.1.lcssa = phi i32 [ %sum.1, %for.inc ]
+  ret i32 %sum.1.lcssa
+
+for.body:                                         ; preds = %entry, %for.inc
+  %i.017 = phi i32 [ 0, %entry ], [ %inc, %for.inc ]
+  %sum.016 = phi i32 [ 0, %entry ], [ %sum.1, %for.inc ]
+  %b.addr.015 = phi i32 [ %b, %entry ], [ %b.addr.1, %for.inc ]
+  %rem.lhs.trunc = trunc i32 %i.017 to i16
+  %rem14 = urem i16 %rem.lhs.trunc, 100
+  %cmp1 = icmp eq i16 %rem14, 0
+  br i1 %cmp1, label %if.then, label %if.else
+
+if.then:                                          ; preds = %for.body
+  %call = tail call i32 @_Z5func1ii(i32 %a, i32 %b.addr.015), !prof !0
+  br label %for.inc
+
+if.else:                                          ; preds = %for.body
+  %call2 = tail call i32 @_Z5func2ii(i32 %a, i32 %b.addr.015)
+  %sub = sub nsw i32 %b.addr.015, %call2
+  %call3 = tail call i32 @_Z5func1ii(i32 %a, i32 %sub), !prof !1
+  br label %for.inc
+
+for.inc:                                          ; preds = %if.then, %if.else
+  %b.addr.1 = phi i32 [ %b.addr.015, %if.then ], [ %sub, %if.else ]
+  %call.pn = phi i32 [ %call, %if.then ], [ %call3, %if.else ]
+  %sum.1 = add nsw i32 %call.pn, %sum.016
+  %inc = add nuw nsw i32 %i.017, 1
+  %exitcond.not = icmp eq i32 %inc, 1000
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+}
+
+declare i32 @_Z5func1ii(i32, i32)
+
+declare i32 @_Z5func2ii(i32, i32)
+
+!0 = !{!"branch_weights", i32 10}
+!1 = !{!"branch_weights", i32 990}
+
+; HOIST: [[PROF0]] = !{!"branch_weights", i64 1000}