diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h --- a/llvm/include/llvm/AsmParser/LLToken.h +++ b/llvm/include/llvm/AsmParser/LLToken.h @@ -407,6 +407,7 @@ kw_noUnwind, kw_mayThrow, kw_hasUnknownCall, + kw_mustBeUnreachable, kw_calls, kw_callee, kw_params, diff --git a/llvm/include/llvm/IR/ModuleSummaryIndex.h b/llvm/include/llvm/IR/ModuleSummaryIndex.h --- a/llvm/include/llvm/IR/ModuleSummaryIndex.h +++ b/llvm/include/llvm/IR/ModuleSummaryIndex.h @@ -581,6 +581,13 @@ // If there are calls to unknown targets (e.g. indirect) unsigned HasUnknownCall : 1; + // Indicate if a function must be an unreachable function. + // + // This bit is sufficient but not necessary; + // If this bit is on, the function must be unreachable; + // if this bit is off, the function might be reachable or unreachable. + unsigned MustBeUnreachable : 1; + FFlags &operator&=(const FFlags &RHS) { this->ReadNone &= RHS.ReadNone; this->ReadOnly &= RHS.ReadOnly; @@ -591,13 +598,15 @@ this->NoUnwind &= RHS.NoUnwind; this->MayThrow &= RHS.MayThrow; this->HasUnknownCall &= RHS.HasUnknownCall; + this->MustBeUnreachable &= RHS.MustBeUnreachable; return *this; } bool anyFlagSet() { return this->ReadNone | this->ReadOnly | this->NoRecurse | this->ReturnDoesNotAlias | this->NoInline | this->AlwaysInline | - this->NoUnwind | this->MayThrow | this->HasUnknownCall; + this->NoUnwind | this->MayThrow | this->HasUnknownCall | + this->MustBeUnreachable; } operator std::string() { @@ -613,6 +622,7 @@ OS << ", noUnwind: " << this->NoUnwind; OS << ", mayThrow: " << this->MayThrow; OS << ", hasUnknownCall: " << this->HasUnknownCall; + OS << ", mustBeUnreachable: " << this->MustBeUnreachable; OS << ")"; return OS.str(); } diff --git a/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp --- a/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp +++ b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp @@ -234,6 +234,26 @@ return false; } +// Returns true 
if `F` must be an unreachable function. +// +// Note if this helper function returns true, `F` is guaranteed +// to be unreachable; if it returns false, `F` might still +// be unreachable but not detected by this helper function. +static bool mustBeUnreachableFunction(const Function &F) { + if (!F.empty()) { + const BasicBlock &entryBlock = F.getEntryBlock(); + // A function must be unreachable if its entry basic block + // ends with an 'unreachable'. + if (!entryBlock.empty()) { + const Instruction *inst = &(*entryBlock.rbegin()); + if (inst->getOpcode() == Instruction::Unreachable) { + return true; + } + } + } + return false; +} + static void computeFunctionSummary( ModuleSummaryIndex &Index, const Module &M, const Function &F, BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, DominatorTree &DT, @@ -488,7 +508,8 @@ // Don't try to import functions with noinline attribute. F.getAttributes().hasFnAttr(Attribute::NoInline), F.hasFnAttribute(Attribute::AlwaysInline), - F.hasFnAttribute(Attribute::NoUnwind), MayThrow, HasUnknownCall}; + F.hasFnAttribute(Attribute::NoUnwind), MayThrow, HasUnknownCall, + mustBeUnreachableFunction(F)}; std::vector ParamAccesses; if (auto *SSI = GetSSICallback(F)) ParamAccesses = SSI->getParamAccesses(Index); @@ -737,7 +758,8 @@ F->hasFnAttribute(Attribute::AlwaysInline), F->hasFnAttribute(Attribute::NoUnwind), /* MayThrow */ true, - /* HasUnknownCall */ true}, + /* HasUnknownCall */ true, + /* MustBeUnreachable */ false}, /*EntryCount=*/0, ArrayRef{}, ArrayRef{}, ArrayRef{}, diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp --- a/llvm/lib/AsmParser/LLLexer.cpp +++ 
b/llvm/lib/AsmParser/LLLexer.cpp @@ -773,6 +773,7 @@ KEYWORD(noUnwind); KEYWORD(mayThrow); KEYWORD(hasUnknownCall); + KEYWORD(mustBeUnreachable); KEYWORD(calls); KEYWORD(callee); KEYWORD(params); diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp --- a/llvm/lib/AsmParser/LLParser.cpp +++ b/llvm/lib/AsmParser/LLParser.cpp @@ -8533,6 +8533,7 @@ /// [',' 'noUnwind' ':' Flag]? ')' /// [',' 'mayThrow' ':' Flag]? ')' /// [',' 'hasUnknownCall' ':' Flag]? ')' +/// [',' 'mustBeUnreachable' ':' Flag]? ')' bool LLParser::parseOptionalFFlags(FunctionSummary::FFlags &FFlags) { assert(Lex.getKind() == lltok::kw_funcFlags); @@ -8599,6 +8600,12 @@ return true; FFlags.HasUnknownCall = Val; break; + case lltok::kw_mustBeUnreachable: + Lex.Lex(); + if (parseToken(lltok::colon, "expected ':'") || parseFlag(Val)) + return true; + FFlags.MustBeUnreachable = Val; + break; 
default: return error(Lex.getLoc(), "expected function flag type"); } diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp --- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp +++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp @@ -932,6 +932,7 @@ Flags.NoUnwind = (RawFlags >> 6) & 0x1; Flags.MayThrow = (RawFlags >> 7) & 0x1; Flags.HasUnknownCall = (RawFlags >> 8) & 0x1; + Flags.MustBeUnreachable = (RawFlags >> 9) & 0x1; return Flags; } diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp --- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp +++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp @@ -1067,6 +1067,7 @@ RawFlags |= (Flags.NoUnwind << 6); RawFlags |= (Flags.MayThrow << 7); RawFlags |= (Flags.HasUnknownCall << 8); + RawFlags |= (Flags.MustBeUnreachable << 9); return RawFlags; } 
diff --git a/llvm/lib/IR/ModuleSummaryIndex.cpp b/llvm/lib/IR/ModuleSummaryIndex.cpp --- a/llvm/lib/IR/ModuleSummaryIndex.cpp +++ b/llvm/lib/IR/ModuleSummaryIndex.cpp @@ -447,11 +447,17 @@ static std::string fflagsToString(FunctionSummary::FFlags F) { auto FlagValue = [](unsigned V) { return V ? '1' : '0'; }; - char FlagRep[] = {FlagValue(F.ReadNone), FlagValue(F.ReadOnly), - FlagValue(F.NoRecurse), FlagValue(F.ReturnDoesNotAlias), - FlagValue(F.NoInline), FlagValue(F.AlwaysInline), - FlagValue(F.NoUnwind), FlagValue(F.MayThrow), - FlagValue(F.HasUnknownCall), 0}; + char FlagRep[] = {FlagValue(F.ReadNone), + FlagValue(F.ReadOnly), + FlagValue(F.NoRecurse), + FlagValue(F.ReturnDoesNotAlias), + FlagValue(F.NoInline), + FlagValue(F.AlwaysInline), + FlagValue(F.NoUnwind), + FlagValue(F.MayThrow), + FlagValue(F.HasUnknownCall), + FlagValue(F.MustBeUnreachable), + 0}; return FlagRep; } diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp --- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -562,17 +562,20 @@ void buildTypeIdentifierMap( std::vector &Bits, DenseMap> &TypeIdMap); - bool - tryFindVirtualCallTargets(std::vector &TargetsForSlot, - const std::set &TypeMemberInfos, - uint64_t ByteOffset); + + // `ExportSummary` may be nullptr; callers do not guarantee it is set. 
+ bool tryFindVirtualCallTargets( + std::vector &TargetsForSlot, + const std::set &TypeMemberInfos, uint64_t ByteOffset, + ModuleSummaryIndex *ExportSummary, const bool IsExported); void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn, bool &IsExported); bool trySingleImplDevirt(ModuleSummaryIndex *ExportSummary, MutableArrayRef TargetsForSlot, VTableSlotInfo &SlotInfo, - WholeProgramDevirtResolution *Res); + WholeProgramDevirtResolution *Res, + const bool preComputedIsExported); void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Constant *JT, bool &IsExported); @@ -646,6 +649,15 @@ runForTesting(Module &M, function_ref AARGetter, function_ref OREGetter, function_ref LookupDomTree); + + // Returns true if the function is unreachable from all summaries. + // + // In particular, identifies a function as unreachable if and only if + // 1) All summaries are function summaries. + // 2) All function summaries indicate it's unreachable. + static bool mustBeUnreachableFunction(Function *TheFn, + ModuleSummaryIndex *ExportSummary, + bool IsExported); }; struct DevirtIndex { @@ -933,6 +945,50 @@ return Changed; } +bool DevirtModule::mustBeUnreachableFunction(Function *TheFn, + ModuleSummaryIndex *ExportSummary, + bool IsExported) { + // ExportSummary is absent. No sufficient information to determine + // if the function is reachable. + if (ExportSummary == nullptr) { + return false; + } + assert((TheFn != nullptr) && "Caller guarantees that TheFn is not nullptr"); + + // If a function will be exported, use external linkage to get its + // global identifier. + const std::string rewrittenFuncGlobalIdentifier = + GlobalValue::getGlobalIdentifier(TheFn->getName(), + IsExported ? 
GlobalValue::ExternalLinkage + : TheFn->getLinkage(), + TheFn->getParent()->getSourceFileName()); + + if (ValueInfo TheFnVI = ExportSummary->getValueInfo( + GlobalValue::getGUID(rewrittenFuncGlobalIdentifier))) { + bool AllSummariesAreFunctionSummary = true; + bool AllFunctionSummariesIndicateUnreachable = true; + for (auto &Summary : TheFnVI.getSummaryList()) { + if (auto *FS = dyn_cast(Summary.get())) { + if (!FS->fflags().MustBeUnreachable) { + AllFunctionSummariesIndicateUnreachable = false; + break; + } + } else { + AllSummariesAreFunctionSummary = false; + break; + } + } + // Identifies a function as unreachable if and only if the summary list + // is non-empty, every summary is a function summary, and every function + // summary indicates it's unreachable. An empty list proves nothing. + if (!TheFnVI.getSummaryList().empty() && AllSummariesAreFunctionSummary && + AllFunctionSummariesIndicateUnreachable) { + return true; + } + } + return false; +} + void DevirtModule::buildTypeIdentifierMap( std::vector &Bits, DenseMap> &TypeIdMap) { @@ -969,7 +1025,8 @@ bool DevirtModule::tryFindVirtualCallTargets( std::vector &TargetsForSlot, - const std::set &TypeMemberInfos, uint64_t ByteOffset) { + const std::set &TypeMemberInfos, uint64_t ByteOffset, + ModuleSummaryIndex *ExportSummary, const bool IsExported) { for (const TypeMemberInfo &TM : TypeMemberInfos) { if (!TM.Bits->GV->isConstant()) return false; @@ -997,6 +1054,10 @@ if (Fn->getName() == "__cxa_pure_virtual") continue; + if (mustBeUnreachableFunction(Fn, ExportSummary, IsExported)) { + continue; + } + TargetsForSlot.push_back({Fn, &TM}); } @@ -1141,13 +1202,15 @@ bool DevirtModule::trySingleImplDevirt( ModuleSummaryIndex *ExportSummary, MutableArrayRef TargetsForSlot, VTableSlotInfo &SlotInfo, - WholeProgramDevirtResolution *Res) { + WholeProgramDevirtResolution *Res, const bool preComputedIsExported) { // See if the program contains a single implementation of this virtual // function. 
Function *TheFn = TargetsForSlot[0].Fn; - for (auto &&Target : TargetsForSlot) - if (TheFn != Target.Fn) + for (auto &&Target : TargetsForSlot) { + if (TheFn != Target.Fn) { return false; + } + } // If so, update each call site to call that implementation directly. if (RemarksEnabled) @@ -2118,6 +2181,19 @@ bool DidVirtualConstProp = false; std::map DevirtTargets; for (auto &S : CallSlots) { + + // True if CSInfo or any ConstCSInfo entry of this slot is exported. + bool precomputedIsExported = false; + VTableSlotInfo &vtableSlotInfo = S.second; + if (vtableSlotInfo.CSInfo.isExported()) { + precomputedIsExported = true; + } + for (auto &P : vtableSlotInfo.ConstCSInfo) { + if (P.second.isExported()) { + precomputedIsExported = true; + } + } + // Search each of the members of the type identifier for the virtual // function implementation at offset S.first.ByteOffset, and add to // TargetsForSlot. @@ -2137,9 +2213,11 @@ cast(S.first.TypeID)->getString()) .WPDRes[S.first.ByteOffset]; if (tryFindVirtualCallTargets(TargetsForSlot, TypeMemberInfos, - S.first.ByteOffset)) { + S.first.ByteOffset, ExportSummary, + precomputedIsExported)) { - if (!trySingleImplDevirt(ExportSummary, TargetsForSlot, S.second, Res)) { + if (!trySingleImplDevirt(ExportSummary, TargetsForSlot, S.second, Res, + precomputedIsExported)) { DidVirtualConstProp |= tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first); diff --git a/llvm/test/Bitcode/thinlto-function-summary-refgraph.ll b/llvm/test/Bitcode/thinlto-function-summary-refgraph.ll --- a/llvm/test/Bitcode/thinlto-function-summary-refgraph.ll +++ b/llvm/test/Bitcode/thinlto-function-summary-refgraph.ll @@ -158,7 +158,7 @@ ; DIS-DAG: = gv: (name: "globalvar", summaries: (variable: (module: ^0, flags: (linkage: external, visibility: default, notEligibleToImport: 0, live: 0, dsoLocal: 0, canAutoHide: 0), varFlags: (readonly: 1, writeonly: 0, constant: 1)))) ; guid = 12887606300320728018 ; DIS-DAG: = gv: (name: "func2") ; guid = 14069196320850861797 ; DIS-DAG: = gv: (name: 
"llvm.ctpop.i8") ; guid = 15254915475081819833 -; DIS-DAG: = gv: (name: "main", summaries: (function: (module: ^0, flags: (linkage: external, visibility: default, notEligibleToImport: 0, live: 0, dsoLocal: 0, canAutoHide: 0), insts: 9, funcFlags: (readNone: 0, readOnly: 0, noRecurse: 0, returnDoesNotAlias: 0, noInline: 0, alwaysInline: 0, noUnwind: 0, mayThrow: 0, hasUnknownCall: 1), calls: ((callee: ^{{.*}})), refs: (^{{.*}})))) ; guid = 15822663052811949562 +; DIS-DAG: = gv: (name: "main", summaries: (function: (module: ^0, flags: (linkage: external, visibility: default, notEligibleToImport: 0, live: 0, dsoLocal: 0, canAutoHide: 0), insts: 9, funcFlags: (readNone: 0, readOnly: 0, noRecurse: 0, returnDoesNotAlias: 0, noInline: 0, alwaysInline: 0, noUnwind: 0, mayThrow: 0, hasUnknownCall: 1, mustBeUnreachable: 0), calls: ((callee: ^{{.*}})), refs: (^{{.*}})))) ; guid = 15822663052811949562 ; DIS-DAG: = gv: (name: "bar", summaries: (variable: (module: ^0, flags: (linkage: external, visibility: default, notEligibleToImport: 0, live: 0, dsoLocal: 0, canAutoHide: 0), varFlags: (readonly: 1, writeonly: 1, constant: 0), refs: (^{{.*}})))) ; guid = 16434608426314478903 ; Don't try to match the exact GUID. Since it is private, the file path ; will get hashed, and that will be test dependent.