Index: llvm/include/llvm/Transforms/Utils/SimplifyCFGOptions.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/SimplifyCFGOptions.h
+++ llvm/include/llvm/Transforms/Utils/SimplifyCFGOptions.h
@@ -28,6 +28,7 @@
   bool NeedCanonicalLoop = true;
   bool HoistCommonInsts = false;
   bool SinkCommonInsts = false;
+  bool PreserveIndirectCallInstWithProfInHoistAndSink = false;
   bool SimplifyCondBranch = true;
   bool FoldTwoEntryPHINode = true;
@@ -62,6 +63,10 @@
     SinkCommonInsts = B;
     return *this;
   }
+  SimplifyCFGOptions &preserveIndirectCallInstWithProfInHoistAndSink(bool B) {
+    PreserveIndirectCallInstWithProfInHoistAndSink = B;
+    return *this;
+  }
   SimplifyCFGOptions &setAssumptionCache(AssumptionCache *Cache) {
     AC = Cache;
     return *this;
Index: llvm/lib/Passes/PassBuilder.cpp
===================================================================
--- llvm/lib/Passes/PassBuilder.cpp
+++ llvm/lib/Passes/PassBuilder.cpp
@@ -790,6 +790,9 @@
       Result.hoistCommonInsts(Enable);
     } else if (ParamName == "sink-common-insts") {
       Result.sinkCommonInsts(Enable);
+    } else if (ParamName ==
+               "preserve-indirect-call-inst-with-prof-in-hoist-and-sink") {
+      Result.preserveIndirectCallInstWithProfInHoistAndSink(Enable);
     } else if (Enable && ParamName.consume_front("bonus-inst-threshold=")) {
       APInt BonusInstThreshold;
       if (ParamName.getAsInteger(0, BonusInstThreshold))
Index: llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp
+++ llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp
@@ -77,6 +77,12 @@
     "sink-common-insts", cl::Hidden, cl::init(false),
     cl::desc("Sink common instructions (default = false)"));
 
+static cl::opt<bool> UserPreserveIndirectCallInstWithProfInHoistAndSink(
+    "preserve-indirect-call-inst-with-prof-in-hoist-and-sink", cl::Hidden,
+    cl::init(false),
+    cl::desc("Preserve indirect call instructions with !prof metadata in "
+             "instruction "
+             "hoist and sink (default = false)"));
+
 STATISTIC(NumSimpl, "Number of blocks simplified");
@@ -323,6 +329,9 @@
     Options.HoistCommonInsts = UserHoistCommonInsts;
   if (UserSinkCommonInsts.getNumOccurrences())
     Options.SinkCommonInsts = UserSinkCommonInsts;
+  if (UserPreserveIndirectCallInstWithProfInHoistAndSink.getNumOccurrences())
+    Options.PreserveIndirectCallInstWithProfInHoistAndSink =
+        UserPreserveIndirectCallInstWithProfInHoistAndSink;
 }
 
 SimplifyCFGPass::SimplifyCFGPass() {
Index: llvm/lib/Transforms/Utils/SimplifyCFG.cpp
===================================================================
--- llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -1580,17 +1580,36 @@
       !isSafeToHoistInstr(I2, SkipFlagsBB2))
     return Changed;
 
-  // If we're going to hoist a call, make sure that the two instructions
-  // we're commoning/hoisting are both marked with musttail, or neither of
-  // them is marked as such. Otherwise, we might end up in a situation where
-  // we hoist from a block where the terminator is a `ret` to a block where
-  // the terminator is a `br`, and `musttail` calls expect to be followed by
-  // a return.
   auto *C1 = dyn_cast<CallInst>(I1);
   auto *C2 = dyn_cast<CallInst>(I2);
-  if (C1 && C2)
+  if (C1 && C2) {
+    // If we're going to hoist a call, make sure that the two instructions
+    // we're commoning/hoisting are both marked with musttail, or neither of
+    // them is marked as such. Otherwise, we might end up in a situation
+    // where we hoist from a block where the terminator is a `ret` to a
+    // block where the terminator is a `br`, and `musttail` calls expect to
+    // be followed by a return.
     if (C1->isMustTailCall() != C2->isMustTailCall())
       return Changed;
+    // If C1 and C2 are indirect calls with different !prof metadata (i.e.,
+    // "VP" for indirect calls), conservatively do not hoist them, since
+    // merging by a naive concatenation might make profile data less precise.
+    // Here two !prof metadata nodes are considered different if the pointers
+    // are different. If they turn out to have the same content by a deep
+    // comparison, hoisting would still make sense.
+    // FIXME:
+    // 1. Probably add an 'eq' comparator for "VP" of indirect calls and use
+    //    it here, so that hoisting happens for icalls with essentially the
+    //    same value profile metadata.
+    // 2. Two direct calls with different "branch_weights" would be
+    //    simplified to one (and get both !prof dropped). The branch
+    //    weights should probably be preserved with a sum of weights.
+    if (Options.PreserveIndirectCallInstWithProfInHoistAndSink &&
+        C1->isIndirectCall() && C2->isIndirectCall() &&
+        C1->getMetadata(LLVMContext::MD_prof) !=
+            C2->getMetadata(LLVMContext::MD_prof))
+      return Changed;
+  }
 
   if (!TTI.isProfitableToHoist(I1) || !TTI.isProfitableToHoist(I2))
     return Changed;
@@ -1778,7 +1797,8 @@
 // PHI node (because an operand varies in each input block), add to PHIOperands.
 static bool canSinkInstructions(
     ArrayRef<Instruction *> Insts,
-    DenseMap<Instruction *, SmallVector<Value *, 4>> &PHIOperands) {
+    DenseMap<Instruction *, SmallVector<Value *, 4>> &PHIOperands,
+    bool PreserveCallInstWithProf) {
   // Prune out obviously bad instructions to move. Each instruction must have
   // exactly zero or one use, and we check later that use is by a single, common
   // PHI instruction in the successor.
@@ -1794,13 +1814,14 @@
     if (I->getParent()->getSingleSuccessor() == I->getParent())
       return false;
 
-    // Conservatively return false if I is an inline-asm instruction. Sinking
-    // and merging inline-asm instructions can potentially create arguments
-    // that cannot satisfy the inline-asm constraints.
-    // If the instruction has nomerge or convergent attribute, return false.
-    if (const auto *C = dyn_cast<CallBase>(I))
+    if (const auto *C = dyn_cast<CallBase>(I)) {
+      // Conservatively return false if I is an inline-asm instruction. Sinking
+      // and merging inline-asm instructions can potentially create arguments
+      // that cannot satisfy the inline-asm constraints.
+      // If the instruction has nomerge or convergent attribute, return false.
       if (C->isInlineAsm() || C->cannotMerge() || C->isConvergent())
         return false;
+    }
 
     // Each instruction must have zero or one use.
     if (HasUse && !I->hasOneUse())
       return false;
@@ -1866,6 +1887,20 @@
   if (HaveIndirectCalls) {
     if (!AllCallsAreIndirect)
       return false;
+
+    if (PreserveCallInstWithProf) {
+      std::optional<MDNode *> ProfMD = std::nullopt;
+      // If not all indirect calls have the same !prof (i.e., value profile)
+      // metadata, conservatively do not sink them, given that merging value
+      // profiles with a naive concatenation may make profiles inaccurate and
+      // counterproductive for indirect-call-promotion.
+      for (const Instruction *I : Insts) {
+        if (ProfMD == std::nullopt)
+          ProfMD = std::make_optional(I->getMetadata(LLVMContext::MD_prof));
+        else if (I->getMetadata(LLVMContext::MD_prof) != *ProfMD)
+          return false;
+      }
+    }
   } else {
     // All callees must be identical.
     Value *Callee = nullptr;
@@ -2082,8 +2117,8 @@
 /// Check whether BB's predecessors end with unconditional branches. If it is
 /// true, sink any common code from the predecessors to BB.
-static bool SinkCommonCodeFromPredecessors(BasicBlock *BB,
-                                           DomTreeUpdater *DTU) {
+static bool SinkCommonCodeFromPredecessors(BasicBlock *BB, DomTreeUpdater *DTU,
+                                           bool PreserveCallInstWithProf) {
   // We support two situations:
   //   (1) all incoming arcs are unconditional
   //   (2) there are non-unconditional incoming arcs
@@ -2148,7 +2183,7 @@
   DenseMap<Instruction *, SmallVector<Value *, 4>> PHIOperands;
   LockstepReverseIterator LRI(UnconditionalPreds);
   while (LRI.isValid() &&
-         canSinkInstructions(*LRI, PHIOperands)) {
+         canSinkInstructions(*LRI, PHIOperands, PreserveCallInstWithProf)) {
     LLVM_DEBUG(dbgs() << "SINK: instruction can be sunk: " << *(*LRI)[0]
                       << "\n");
     InstructionsToSink.insert((*LRI).begin(), (*LRI).end());
@@ -7263,7 +7298,8 @@
     return true;
 
   if (SinkCommon && Options.SinkCommonInsts)
-    if (SinkCommonCodeFromPredecessors(BB, DTU) ||
+    if (SinkCommonCodeFromPredecessors(
+            BB, DTU, Options.PreserveIndirectCallInstWithProfInHoistAndSink) ||
         MergeCompatibleInvokes(BB, DTU)) {
       // SinkCommonCodeFromPredecessors() does not automatically CSE PHI's,
       // so we may now how duplicate PHI's.
Index: llvm/test/Transforms/SimplifyCFG/preserve-call-metadata-in-hoist.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/SimplifyCFG/preserve-call-metadata-in-hoist.ll
@@ -0,0 +1,110 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -passes='simplifycfg' -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s --check-prefix=HOIST
+; RUN: opt < %s -passes='simplifycfg' -simplifycfg-require-and-preserve-domtree=1 -preserve-indirect-call-inst-with-prof-in-hoist-and-sink -S | FileCheck %s --check-prefix=HOIST_P
+; RUN: opt < %s -passes='simplifycfg<preserve-indirect-call-inst-with-prof-in-hoist-and-sink>' -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s --check-prefix=HOIST_P
+
+; IR @call_not_hoist is generated based on the following C++ code, with manually-annotated !prof
+; Without preserving call instructions, `d->func1` is hoisted while it may not make sense
+; to do so. For example, the candidate calls are different based on derived type.
+; Note merging metadata will make the profile data less precise and not desirable either.
+; class Base { +; public: +; virtual int gettype() = 0; +; virtual int func1(int a, int b) = 0; +;}; +; +; int func2(int x, int y); +; +; Base* createptr(int c); +; +; int func(int x, int a, int b, int c) { +; Base* d = createptr(c); +; if (d->gettype() % 5 == 0) { +; auto ret = d->func1(a, b); +; return ret + func2(b, a); +; } +; return d->func1(a, b); +; } +define i32 @call_not_hoist(i32 %x, i32 %a, i32 %b, i32 %c) { +; HOIST-LABEL: define i32 @call_not_hoist +; HOIST-SAME: (i32 [[X:%.*]], i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]]) { +; HOIST-NEXT: entry: +; HOIST-NEXT: [[CALL:%.*]] = tail call ptr @createptr(i32 [[C]]) +; HOIST-NEXT: [[VTABLE:%.*]] = load ptr, ptr [[CALL]], align 8 +; HOIST-NEXT: [[TMP0:%.*]] = load ptr, ptr [[VTABLE]], align 8 +; HOIST-NEXT: [[CALL1:%.*]] = tail call i32 [[TMP0]](ptr [[CALL]]) +; HOIST-NEXT: [[REM:%.*]] = srem i32 [[CALL1]], 5 +; HOIST-NEXT: [[CMP:%.*]] = icmp eq i32 [[REM]], 0 +; HOIST-NEXT: [[VTABLE2:%.*]] = load ptr, ptr [[CALL]], align 8 +; HOIST-NEXT: [[VFN3:%.*]] = getelementptr inbounds ptr, ptr [[VTABLE2]], i64 1 +; HOIST-NEXT: [[TMP1:%.*]] = load ptr, ptr [[VFN3]], align 8 +; HOIST-NEXT: [[CALL4:%.*]] = tail call i32 [[TMP1]](ptr [[CALL]], i32 [[A]], i32 [[B]]) +; HOIST-NEXT: br i1 [[CMP]], label [[IF_THEN:%.*]], label [[CLEANUP:%.*]] +; HOIST: if.then: +; HOIST-NEXT: [[CALL5:%.*]] = tail call i32 @func2(i32 [[B]], i32 [[A]]) +; HOIST-NEXT: [[ADD:%.*]] = add nsw i32 [[CALL5]], [[CALL4]] +; HOIST-NEXT: br label [[CLEANUP]] +; HOIST: cleanup: +; HOIST-NEXT: [[RETVAL_0:%.*]] = phi i32 [ [[ADD]], [[IF_THEN]] ], [ [[CALL4]], [[ENTRY:%.*]] ] +; HOIST-NEXT: ret i32 [[RETVAL_0]] +; +; HOIST_P-LABEL: define i32 @call_not_hoist +; HOIST_P-SAME: (i32 [[X:%.*]], i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]]) { +; HOIST_P-NEXT: entry: +; HOIST_P-NEXT: [[CALL:%.*]] = tail call ptr @createptr(i32 [[C]]) +; HOIST_P-NEXT: [[VTABLE:%.*]] = load ptr, ptr [[CALL]], align 8 +; HOIST_P-NEXT: [[TMP0:%.*]] = load ptr, ptr [[VTABLE]], align 8 
+; HOIST_P-NEXT: [[CALL1:%.*]] = tail call i32 [[TMP0]](ptr [[CALL]]) +; HOIST_P-NEXT: [[REM:%.*]] = srem i32 [[CALL1]], 5 +; HOIST_P-NEXT: [[CMP:%.*]] = icmp eq i32 [[REM]], 0 +; HOIST_P-NEXT: [[VTABLE2:%.*]] = load ptr, ptr [[CALL]], align 8 +; HOIST_P-NEXT: [[VFN3:%.*]] = getelementptr inbounds ptr, ptr [[VTABLE2]], i64 1 +; HOIST_P-NEXT: [[TMP1:%.*]] = load ptr, ptr [[VFN3]], align 8 +; HOIST_P-NEXT: br i1 [[CMP]], label [[IF_THEN:%.*]], label [[IF_END:%.*]] +; HOIST_P: if.then: +; HOIST_P-NEXT: [[CALL4:%.*]] = tail call i32 [[TMP1]](ptr [[CALL]], i32 [[A]], i32 [[B]]), !prof [[PROF0:![0-9]+]] +; HOIST_P-NEXT: [[CALL5:%.*]] = tail call i32 @func2(i32 [[B]], i32 [[A]]) +; HOIST_P-NEXT: [[ADD:%.*]] = add nsw i32 [[CALL5]], [[CALL4]] +; HOIST_P-NEXT: br label [[CLEANUP:%.*]] +; HOIST_P: if.end: +; HOIST_P-NEXT: [[CALL8:%.*]] = tail call i32 [[TMP1]](ptr [[CALL]], i32 [[A]], i32 [[B]]), !prof [[PROF1:![0-9]+]] +; HOIST_P-NEXT: br label [[CLEANUP]] +; HOIST_P: cleanup: +; HOIST_P-NEXT: [[RETVAL_0:%.*]] = phi i32 [ [[ADD]], [[IF_THEN]] ], [ [[CALL8]], [[IF_END]] ] +; HOIST_P-NEXT: ret i32 [[RETVAL_0]] +; +entry: + %call = tail call ptr @createptr(i32 %c) + %vtable = load ptr, ptr %call, align 8 + %0 = load ptr, ptr %vtable, align 8 + %call1 = tail call i32 %0(ptr %call) + %rem = srem i32 %call1, 5 + %cmp = icmp eq i32 %rem, 0 + br i1 %cmp, label %if.then, label %if.end + +if.then: ; preds = %entry + %vtable2 = load ptr, ptr %call, align 8 + %vfn3 = getelementptr inbounds ptr, ptr %vtable2, i64 1 + %1 = load ptr, ptr %vfn3, align 8 + %call4 = tail call i32 %1(ptr %call, i32 %a, i32 %b), !prof !0 + %call5 = tail call i32 @func2(i32 %b, i32 %a) + %add = add nsw i32 %call5, %call4 + br label %cleanup + +if.end: ; preds = %entry + %vtable6 = load ptr, ptr %call, align 8 + %vfn7 = getelementptr inbounds ptr, ptr %vtable6, i64 1 + %2 = load ptr, ptr %vfn7, align 8 + %call8 = tail call i32 %2(ptr %call, i32 %a, i32 %b), !prof !1 + br label %cleanup + +cleanup: ; preds = 
%if.end, %if.then
+  %retval.0 = phi i32 [ %add, %if.then ], [ %call8, %if.end ]
+  ret i32 %retval.0
+}
+
+declare ptr @createptr(i32)
+declare i32 @func2(i32, i32)
+
+!0 = !{!"VP", i32 0, i64 1600, i64 12345, i64 1030, i64 678, i64 410}
+!1 = !{!"VP", i32 0, i64 1601, i64 54321, i64 1030, i64 678, i64 410}
Index: llvm/test/Transforms/SimplifyCFG/preserve-call-metadata-in-sink.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/SimplifyCFG/preserve-call-metadata-in-sink.ll
@@ -0,0 +1,88 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -passes='simplifycfg' -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s --check-prefix=SINK
+; RUN: opt < %s -passes='simplifycfg' -simplifycfg-require-and-preserve-domtree=1 -preserve-indirect-call-inst-with-prof-in-hoist-and-sink -S | FileCheck %s --check-prefix=SINK_P
+; RUN: opt < %s -passes='simplifycfg<preserve-indirect-call-inst-with-prof-in-hoist-and-sink>' -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s --check-prefix=SINK_P
+
+; IR @call_not_sinked is generated based on the following C++ code, with manually annotated !prof
+; Note 'd->func1' and 'd->func2' are two virtual functions with different offsets so could have different value profiles.
+; Without preserving call instructions with !prof, call instructions are simplified to one with selected offsets.
+; class Base { +; public: +; virtual int func() = 0; +; virtual int func1(int a, int b) = 0; +; virtual int func2(int a, int b) = 0; +; }; +; +; Base* createptr(int c); +; +; int func(int x, int a, int b, int c) { +; Base* d = createptr(c); +; if (x % 1000 == 0) +; return d->func1(a, b); +; return d->func2(a, b); +; } +define i32 @call_not_sinked(i32 %x, i32 %a, i32 %b, i32 %c) { +; SINK-LABEL: define i32 @call_not_sinked +; SINK-SAME: (i32 [[X:%.*]], i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]]) { +; SINK-NEXT: entry: +; SINK-NEXT: [[CALL:%.*]] = tail call ptr @createptr(i32 [[C]]) +; SINK-NEXT: [[REM:%.*]] = srem i32 [[X]], 1000 +; SINK-NEXT: [[CMP:%.*]] = icmp eq i32 [[REM]], 0 +; SINK-NEXT: [[DOT:%.*]] = select i1 [[CMP]], i64 1, i64 2 +; SINK-NEXT: [[VTABLE2:%.*]] = load ptr, ptr [[CALL]], align 8 +; SINK-NEXT: [[VFN3:%.*]] = getelementptr inbounds ptr, ptr [[VTABLE2]], i64 [[DOT]] +; SINK-NEXT: [[TMP0:%.*]] = load ptr, ptr [[VFN3]], align 8 +; SINK-NEXT: [[CALL4:%.*]] = tail call i32 [[TMP0]](ptr [[CALL]], i32 [[A]], i32 [[B]]) +; SINK-NEXT: ret i32 [[CALL4]] +; +; SINK_P-LABEL: define i32 @call_not_sinked +; SINK_P-SAME: (i32 [[X:%.*]], i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]]) { +; SINK_P-NEXT: entry: +; SINK_P-NEXT: [[CALL:%.*]] = tail call ptr @createptr(i32 [[C]]) +; SINK_P-NEXT: [[REM:%.*]] = srem i32 [[X]], 1000 +; SINK_P-NEXT: [[CMP:%.*]] = icmp eq i32 [[REM]], 0 +; SINK_P-NEXT: br i1 [[CMP]], label [[IF_THEN:%.*]], label [[IF_END:%.*]] +; SINK_P: if.then: +; SINK_P-NEXT: [[VTABLE:%.*]] = load ptr, ptr [[CALL]], align 8 +; SINK_P-NEXT: [[VFN:%.*]] = getelementptr inbounds ptr, ptr [[VTABLE]], i64 1 +; SINK_P-NEXT: [[TMP0:%.*]] = load ptr, ptr [[VFN]], align 8 +; SINK_P-NEXT: [[CALL1:%.*]] = tail call i32 [[TMP0]](ptr [[CALL]], i32 [[A]], i32 [[B]]) +; SINK_P-NEXT: br label [[CLEANUP:%.*]] +; SINK_P: if.end: +; SINK_P-NEXT: [[VTABLE2:%.*]] = load ptr, ptr [[CALL]], align 8 +; SINK_P-NEXT: [[VFN3:%.*]] = getelementptr inbounds ptr, ptr [[VTABLE2]], 
i64 2 +; SINK_P-NEXT: [[TMP1:%.*]] = load ptr, ptr [[VFN3]], align 8 +; SINK_P-NEXT: [[CALL4:%.*]] = tail call i32 [[TMP1]](ptr [[CALL]], i32 [[A]], i32 [[B]]), !prof [[PROF0:![0-9]+]] +; SINK_P-NEXT: br label [[CLEANUP]] +; SINK_P: cleanup: +; SINK_P-NEXT: [[RETVAL_0:%.*]] = phi i32 [ [[CALL1]], [[IF_THEN]] ], [ [[CALL4]], [[IF_END]] ] +; SINK_P-NEXT: ret i32 [[RETVAL_0]] +; +entry: + %call = tail call ptr @createptr(i32 %c) + %rem = srem i32 %x, 1000 + %cmp = icmp eq i32 %rem, 0 + br i1 %cmp, label %if.then, label %if.end + +if.then: ; preds = %entry + %vtable = load ptr, ptr %call, align 8 + %vfn = getelementptr inbounds ptr, ptr %vtable, i64 1 + %0 = load ptr, ptr %vfn + %call1 = tail call i32 %0(ptr %call, i32 %a, i32 %b) + br label %cleanup + +if.end: ; preds = %entry + %vtable2 = load ptr, ptr %call, align 8 + %vfn3 = getelementptr inbounds ptr, ptr %vtable2, i64 2 + %1 = load ptr, ptr %vfn3 + %call4 = tail call i32 %1(ptr %call, i32 %a, i32 %b), !prof !0 + br label %cleanup + +cleanup: ; preds = %if.end, %if.then + %retval.0 = phi i32 [ %call1, %if.then ], [ %call4, %if.end ] + ret i32 %retval.0 +} + +declare ptr @createptr(i32) + +!0 =!{!"VP", i32 0, i64 1600, i64 12345, i64 1030, i64 678, i64 410}