diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h --- a/llvm/include/llvm/Transforms/IPO/Attributor.h +++ b/llvm/include/llvm/Transforms/IPO/Attributor.h @@ -820,6 +820,97 @@ /// If \p LivenessAA is not provided it is queried. bool isAssumedDead(const AbstractAttribute &AA, const AAIsDead *LivenessAA); + /// Helper struct used in the communication between an abstract attribute (AA) + /// that wants to change the signature of a function and the Attributor which + /// applies the changes. The struct is created by the Attributor from the + /// information the AA passes to registerFunctionSignatureRewrite (see the + /// constructor) and handed to the repair callbacks when the rewrite happens. + struct ArgumentReplacementInfo { + /// Callee repair callback type + /// + /// The callee repair callback is invoked once to rewire the replacement + /// arguments in the body of the new function. The argument replacement info, + /// as built from the registerFunctionSignatureRewrite call, is passed + /// together with the replacement function and an iterator to the first + /// replacement argument. + using CalleeRepairCBTy = std::function; + + /// Abstract call site (ACS) repair callback type + /// + /// The abstract call site repair callback is invoked once on every abstract + /// call site of the replaced function (\see ReplacedFn). The callback needs + /// to provide the operands for the call to the new replacement function. + /// The number and types of the operands appended to the provided vector + /// (third argument) are defined by the replacement type vector + /// (\see ReplacementTypes). The first argument is the + /// ArgumentReplacementInfo object registered with the Attributor through + /// the registerFunctionSignatureRewrite call, the second one is the + /// abstract call site that is rewritten. + using ACSRepairCBTy = + std::function &)>; + + /// Simple getters, see the corresponding members for details.
+ ///{ + + Attributor &getAttributor() const { return A; } + const Function &getReplacedFn() const { return ReplacedFn; } + const Argument &getReplacedArg() const { return ReplacedArg; } + unsigned getNumReplacementArgs() const { return ReplacementTypes.size(); } + const SmallVectorImpl &getReplacementTypes() const { + return ReplacementTypes; + } + + ///} + + private: + /// Constructor that takes the argument to be replaced, the types of + /// the replacement arguments, as well as callbacks to repair the call sites + /// and new function after the replacement happened. + ArgumentReplacementInfo(Attributor &A, Argument &Arg, + ArrayRef ReplacementTypes, + CalleeRepairCBTy &&CalleeRepairCB, + ACSRepairCBTy &&ACSRepairCB) + : A(A), ReplacedFn(*Arg.getParent()), ReplacedArg(Arg), + ReplacementTypes(ReplacementTypes.begin(), ReplacementTypes.end()), + CalleeRepairCB(std::move(CalleeRepairCB)), + ACSRepairCB(std::move(ACSRepairCB)) {} + + /// Reference to the attributor to allow access from the callbacks. + Attributor &A; + + /// The "old" function replaced by ReplacementFn. + const Function &ReplacedFn; + + /// The "old" argument replaced by new ones defined via ReplacementTypes. + const Argument &ReplacedArg; + + /// The types of the arguments replacing ReplacedArg. + const SmallVector ReplacementTypes; + + /// Callee repair callback, see CalleeRepairCBTy. + const CalleeRepairCBTy CalleeRepairCB; + + /// Abstract call site (ACS) repair callback, see ACSRepairCBTy. + const ACSRepairCBTy ACSRepairCB; + + /// Allow access to the private members from the Attributor. + friend class Attributor; + }; + + /// Register a rewrite for a function signature. + /// + /// The argument \p Arg is replaced with new ones defined by the number, + /// order, and types in \p ReplacementTypes. The rewiring at the call sites is + /// done through \p ACSRepairCB and at the callee site through + /// \p CalleeRepairCB.
+ /// + /// The request is rejected for var-arg functions, functions with nest, + /// sret, or inalloca arguments, functions with callback or must-tail call + /// sites, and functions that contain must-tail calls. If a rewrite for + /// \p Arg was registered before, it is only superseded if the new one uses + /// fewer replacement arguments. + /// + /// \returns True, if the replacement was registered, false otherwise. + bool registerFunctionSignatureRewrite( + Argument &Arg, ArrayRef ReplacementTypes, + ArgumentReplacementInfo::CalleeRepairCBTy &&CalleeRepairCB, + ArgumentReplacementInfo::ACSRepairCBTy &&ACSRepairCB); + /// Check \p Pred on all function call sites. /// /// This method will evaluate \p Pred on call sites and return @@ -942,6 +1033,11 @@ return nullptr; } + /// Apply all requested function signature rewrites + /// (\see registerFunctionSignatureRewrite) and return CHANGED if the module + /// was altered. Afterwards all registered ArgumentReplacementInfo objects + /// are deleted and ArgumentReplacementMap is cleared. + ChangeStatus rewriteFunctionSignatures(); + /// The set of all abstract attributes. ///{ using AAVector = SmallVector; @@ -965,6 +1061,10 @@ QueryMapTy QueryMap; ///} + /// Map to remember all requested signature changes (= argument replacements). + /// The ArgumentReplacementInfo objects are owned by this map; they are + /// allocated in registerFunctionSignatureRewrite and freed in + /// rewriteFunctionSignatures. + DenseMap> + ArgumentReplacementMap; + /// The information cache that holds pre-processed (LLVM-IR) information. InformationCache &InfoCache; diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp --- a/llvm/lib/Transforms/IPO/Attributor.cpp +++ b/llvm/lib/Transforms/IPO/Attributor.cpp @@ -3238,8 +3238,17 @@ // step. The AAIsDead impl. will also depend on the simplified value // result as a simplified value means we replace the uses. if (Changed == ChangeStatus::CHANGED || Arg.getNumUses() == 0) { - // We can replace the call site arguments with undef because the value is - // not used. + // Check if the argument can be deleted. If so, register a rewrite without + // replacement types. + if (Arg.getParent()->hasLocalLinkage()) + if (A.registerFunctionSignatureRewrite( + Arg, /* ReplacementTypes */ {}, + Attributor::ArgumentReplacementInfo::CalleeRepairCBTy{}, + Attributor::ArgumentReplacementInfo::ACSRepairCBTy{})) + return ChangeStatus::CHANGED; + + // Argument cannot be deleted but we can replace the call site arguments + // with undef because the value is not used.
+ SmallVector, 4> CallSiteOpReplacements; auto RplArgOpWithUndef = [&](AbstractCallSite ACS) { // Check if we have an associated argument or not (which can happen @@ -4490,6 +4499,8 @@ continue; STATS_TRACK(AAIsDead, Function); + // Remember the function so rewriteFunctionSignatures skips it. + ToBeDeletedFunctions.insert(F); F->replaceAllUsesWith(UndefValue::get(F->getType())); F->eraseFromParent(); InternalFns[u] = nullptr; @@ -4498,6 +4508,9 @@ } } + // Rewrite the functions as requested during manifest. + ManifestChange = ManifestChange | rewriteFunctionSignatures(); + if (VerifyMaxFixpointIterations && IterationCounter != MaxFixpointIterations) { errs() << "\n[Attributor] Fixpoint iteration done after: " @@ -4510,6 +4523,254 @@ return ManifestChange; } +bool Attributor::registerFunctionSignatureRewrite( + Argument &Arg, ArrayRef ReplacementTypes, + ArgumentReplacementInfo::CalleeRepairCBTy &&CalleeRepairCB, + ArgumentReplacementInfo::ACSRepairCBTy &&ACSRepairCB) { + + auto CallSiteCanBeChanged = [](AbstractCallSite ACS) { + // Callback calls cannot be rewired yet. + if (ACS.isCallbackCall()) + return false; + // A must-tail call requires matching caller and callee prototypes, which + // the rewrite of the callee would break. + if (auto *CI = dyn_cast<CallInst>(ACS.getInstruction())) + return !CI->isMustTailCall(); + return true; + }; + + Function *Fn = Arg.getParent(); + // Avoid var-arg functions for now. + if (Fn->isVarArg()) { + LLVM_DEBUG(dbgs() << "[Attributor] Cannot rewrite var-args functions\n"); + return false; + } + + // Avoid functions with complicated argument passing semantics. + AttributeList FnAttributeList = Fn->getAttributes(); + if (FnAttributeList.hasAttrSomewhere(Attribute::Nest) || + FnAttributeList.hasAttrSomewhere(Attribute::StructRet) || + FnAttributeList.hasAttrSomewhere(Attribute::InAlloca)) { + LLVM_DEBUG( + dbgs() << "[Attributor] Cannot rewrite due to complex attribute\n"); + return false; + } + + // Avoid callback and must-tail call sites for now. + if (!checkForAllCallSites(CallSiteCanBeChanged, *Fn, true, nullptr)) { + LLVM_DEBUG(dbgs() << "[Attributor] Cannot rewrite all call sites\n"); + return false; + } + + auto InstPred = [](Instruction &I) { + if (auto *CI = dyn_cast(&I)) + return !CI->isMustTailCall(); + return true; + }; + + // Forbid must-tail calls inside the function for now, they require the
+ // prototype of the caller (which is rewritten) to match the callee. + // TODO: Support must-tail calls. + bool AnyDead; + auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(*Fn); + if (!checkForAllInstructionsImpl(OpcodeInstMap, InstPred, nullptr, AnyDead, + {Instruction::Call})) { + LLVM_DEBUG(dbgs() << "[Attributor] Cannot rewrite due to instructions\n"); + return false; + } + + SmallVectorImpl &ARIs = ArgumentReplacementMap[Fn]; + if (ARIs.size() == 0) + ARIs.resize(Fn->arg_size()); + + // If we have a replacement already with less than or equal new arguments, + // ignore this request. + ArgumentReplacementInfo *&ARI = ARIs[Arg.getArgNo()]; + if (ARI && ARI->getNumReplacementArgs() <= ReplacementTypes.size()) { + LLVM_DEBUG(dbgs() << "[Attributor] Existing rewrite is preferred\n"); + return false; + } + + // If we have a replacement already but we like the new one better, delete + // the old. + if (ARI) + delete ARI; + + // Remember the replacement. It is owned by ArgumentReplacementMap and freed + // in rewriteFunctionSignatures. + ARI = new ArgumentReplacementInfo(*this, Arg, ReplacementTypes, + std::move(CalleeRepairCB), + std::move(ACSRepairCB)); + + return true; +} + +ChangeStatus Attributor::rewriteFunctionSignatures() { + ChangeStatus Changed = ChangeStatus::UNCHANGED; + + for (auto &It : ArgumentReplacementMap) { + Function *ReplacedFn = It.getFirst(); + + // Deleted functions do not require rewrites. + if (ToBeDeletedFunctions.count(ReplacedFn)) + continue; + + const SmallVectorImpl &ARIs = It.getSecond(); + assert(ARIs.size() == ReplacedFn->arg_size() && "Inconsistent state!"); + + SmallVector ReplacementArgumentTypes; + SmallVector ReplacementArgumentAttributes; + + // Collect replacement argument types and copy over existing attributes.
+ AttributeList ReplacedFnAttributeList = ReplacedFn->getAttributes(); + for (Argument &Arg : ReplacedFn->args()) { + if (ArgumentReplacementInfo *ARI = ARIs[Arg.getArgNo()]) { + ReplacementArgumentTypes.append(ARI->ReplacementTypes.begin(), + ARI->ReplacementTypes.end()); + ReplacementArgumentAttributes.append(ARI->getNumReplacementArgs(), + AttributeSet()); + } else { + ReplacementArgumentTypes.push_back(Arg.getType()); + ReplacementArgumentAttributes.push_back( + ReplacedFnAttributeList.getParamAttributes(Arg.getArgNo())); + } + } + + FunctionType *ReplacedFnTy = ReplacedFn->getFunctionType(); + Type *RetTy = ReplacedFnTy->getReturnType(); + + // Construct the new function type using the new argument types. + FunctionType *ReplacementFnTy = FunctionType::get( + RetTy, ReplacementArgumentTypes, ReplacedFnTy->isVarArg()); + + LLVM_DEBUG(dbgs() << "[Attributor] Function rewrite '" + << ReplacedFn->getName() << "' from " + << *ReplacedFn->getFunctionType() << " to " + << *ReplacementFnTy << "\n"); + + // Create the new function body and insert it into the module. + Function *ReplacementFn = + Function::Create(ReplacementFnTy, ReplacedFn->getLinkage(), + ReplacedFn->getAddressSpace(), ""); + ReplacedFn->getParent()->getFunctionList().insert(ReplacedFn->getIterator(), + ReplacementFn); + ReplacementFn->takeName(ReplacedFn); + ReplacementFn->copyAttributesFrom(ReplacedFn); + + // Patch the pointer to LLVM function in debug info descriptor. + ReplacementFn->setSubprogram(ReplacedFn->getSubprogram()); + ReplacedFn->setSubprogram(nullptr); + + // Recompute the parameter attributes list based on the new arguments for + // the function.
+ LLVMContext &Ctx = ReplacedFn->getContext(); + ReplacementFn->setAttributes( + AttributeList::get(Ctx, ReplacedFnAttributeList.getFnAttributes(), + ReplacedFnAttributeList.getRetAttributes(), + ReplacementArgumentAttributes)); + + // Since we have now created the new function, splice the body of the old + // function right into the new function, leaving the old rotting hulk of the + // function empty. + ReplacementFn->getBasicBlockList().splice(ReplacementFn->begin(), + ReplacedFn->getBasicBlockList()); + + // Set of all "call-like" instructions that invoke the replaced function. + SmallPtrSet ReplacedCallSites; + + // Callback to create a new "call-like" instruction for a given one. + auto CallSiteReplacementCreator = [&](AbstractCallSite ACS) { + CallBase *ReplacedCB = cast(ACS.getInstruction()); + const AttributeList &ReplacedCallAttributeList = + ReplacedCB->getAttributes(); + + // Collect the new argument operands for the replacement call site. + SmallVector NewArgOperands; + SmallVector NewArgOperandAttributes; + for (unsigned OldArgNum = 0; OldArgNum < ARIs.size(); ++OldArgNum) { + unsigned NewFirstArgNum = NewArgOperands.size(); + if (ArgumentReplacementInfo *ARI = ARIs[OldArgNum]) { + if (ARI->ACSRepairCB) + ARI->ACSRepairCB(*ARI, ACS, NewArgOperands); + assert(ARI->getNumReplacementArgs() + NewFirstArgNum == + NewArgOperands.size() && + "ACS repair callback did not provide as many operand as new " + "types were registered!"); + // Only used in the assertion above; avoid unused warnings in NDEBUG. + (void)NewFirstArgNum; + // TODO: Expose the attribute set to the ACS repair callback + NewArgOperandAttributes.append(ARI->ReplacementTypes.size(), + AttributeSet()); + } else { + NewArgOperands.push_back(ACS.getCallArgOperand(OldArgNum)); + NewArgOperandAttributes.push_back( + ReplacedCallAttributeList.getParamAttributes(OldArgNum)); + } + } + + assert(NewArgOperands.size() == NewArgOperandAttributes.size() && + "Mismatch # argument operands vs.
# argument operand attributes!"); + assert(NewArgOperands.size() == ReplacementFn->arg_size() && + "Mismatch # argument operands vs. # function arguments!"); + + SmallVector OperandBundleDefs; + ReplacedCB->getOperandBundlesAsDefs(OperandBundleDefs); + + // Create a new call or invoke instruction to replace the old one. + CallBase *ReplacementCB; + if (InvokeInst *II = dyn_cast(ReplacedCB)) { + ReplacementCB = InvokeInst::Create(ReplacementFn, II->getNormalDest(), + II->getUnwindDest(), NewArgOperands, + OperandBundleDefs, "", ReplacedCB); + } else { + auto *NewCI = CallInst::Create(ReplacementFn, NewArgOperands, + OperandBundleDefs, "", ReplacedCB); + NewCI->setTailCallKind(cast(ReplacedCB)->getTailCallKind()); + ReplacementCB = NewCI; + } + + // Copy over various properties and the new attributes. + ReplacedCB->replaceAllUsesWith(ReplacementCB); + uint64_t W; + if (ReplacedCB->extractProfTotalWeight(W)) + ReplacementCB->setProfWeight(W); + ReplacementCB->setCallingConv(ReplacedCB->getCallingConv()); + ReplacementCB->setDebugLoc(ReplacedCB->getDebugLoc()); + ReplacementCB->takeName(ReplacedCB); + ReplacementCB->setAttributes( + AttributeList::get(Ctx, ReplacedCallAttributeList.getFnAttributes(), + ReplacedCallAttributeList.getRetAttributes(), + NewArgOperandAttributes)); + + bool Inserted = ReplacedCallSites.insert(ReplacedCB).second; + assert(Inserted && "Call site was replaced twice!"); + (void)Inserted; + + return true; + }; + + // Use the CallSiteReplacementCreator to create replacement call sites. + bool Success = checkForAllCallSites(CallSiteReplacementCreator, *ReplacedFn, + true, nullptr); + assert(Success && "Assumed call site replacement to succeed!"); + (void)Success; + + // Rewire the arguments.
+ auto ReplacedFnArgIt = ReplacedFn->arg_begin(); + auto ReplacementFnArgIt = ReplacementFn->arg_begin(); + for (unsigned OldArgNum = 0; OldArgNum < ARIs.size(); + ++OldArgNum, ++ReplacedFnArgIt) { + if (ArgumentReplacementInfo *ARI = ARIs[OldArgNum]) { + if (ARI->CalleeRepairCB) + ARI->CalleeRepairCB(*ARI, *ReplacementFn, ReplacementFnArgIt); + ReplacementFnArgIt += ARI->ReplacementTypes.size(); + } else { + ReplacementFnArgIt->takeName(&*ReplacedFnArgIt); + ReplacedFnArgIt->replaceAllUsesWith(&*ReplacementFnArgIt); + ++ReplacementFnArgIt; + } + } + + // Eliminate the instructions *after* we visited all of them. + for (Instruction *ReplacedCallSite : ReplacedCallSites) + ReplacedCallSite->eraseFromParent(); + + assert(ReplacedFn->getNumUses() == 0 && "Unexpected leftover uses!"); + ReplacedFn->eraseFromParent(); + Changed = ChangeStatus::CHANGED; + } + + // All rewrites were applied (or skipped for deleted functions). Release the + // replacement infos allocated in registerFunctionSignatureRewrite, including + // the ones of skipped functions, and forget them. Deleting a null entry + // (argument without replacement) is a no-op. + for (auto &It : ArgumentReplacementMap) + for (ArgumentReplacementInfo *ARI : It.getSecond()) + delete ARI; + ArgumentReplacementMap.clear(); + + return Changed; +} + void Attributor::initializeInformationCache(Function &F) { // Walk all instructions to find interesting instructions that might be diff --git a/llvm/test/Transforms/FunctionAttrs/align.ll b/llvm/test/Transforms/FunctionAttrs/align.ll --- a/llvm/test/Transforms/FunctionAttrs/align.ll +++ b/llvm/test/Transforms/FunctionAttrs/align.ll @@ -88,7 +88,7 @@ br i1 %2, label %3, label %5 ;