Index: include/llvm/Transforms/IPO/ArgumentPromotion.h
===================================================================
--- include/llvm/Transforms/IPO/ArgumentPromotion.h
+++ include/llvm/Transforms/IPO/ArgumentPromotion.h
@@ -13,6 +13,7 @@
 #include "llvm/Analysis/CGSCCPassManager.h"
 #include "llvm/Analysis/LazyCallGraph.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/Transforms/Utils/ModuleUtils.h"
 
 namespace llvm {
 
@@ -23,6 +24,10 @@
 /// direct (by-value) arguments.
 class ArgumentPromotionPass : public PassInfoMixin<ArgumentPromotionPass> {
   unsigned MaxElements;
+  // Tracks globals whose !associated metadata points at functions. When a
+  // function is re-created with promoted arguments, that metadata is
+  // redirected to the new function.
+  AssociatedGlobalsMap AssociatedGlobals;
 
 public:
   ArgumentPromotionPass(unsigned MaxElements = 3u) : MaxElements(MaxElements) {}
Index: include/llvm/Transforms/IPO/DeadArgumentElimination.h
===================================================================
--- include/llvm/Transforms/IPO/DeadArgumentElimination.h
+++ include/llvm/Transforms/IPO/DeadArgumentElimination.h
@@ -24,6 +24,7 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/Transforms/Utils/ModuleUtils.h"
 #include <map>
 #include <set>
 #include <string>
@@ -107,6 +108,12 @@
   /// directly to F.
   UseMap Uses;
 
+  /// Maps Functions to GlobalObjects which reference them via MDNodes
+  /// (e.g. MD_associated metadata). When a function is re-created, any such
+  /// metadata that points at the old function must be redirected to the new
+  /// one (with dead args/return values removed).
+  AssociatedGlobalsMap AssociatedGlobals;
+
   using LiveSet = std::set<RetOrArg>;
   using LiveFuncSet = std::set<const Function *>;
 
Index: include/llvm/Transforms/Utils/ModuleUtils.h
===================================================================
--- include/llvm/Transforms/Utils/ModuleUtils.h
+++ include/llvm/Transforms/Utils/ModuleUtils.h
@@ -15,6 +15,7 @@
 #define LLVM_TRANSFORMS_UTILS_MODULEUTILS_H
 
 #include "llvm/ADT/StringRef.h"
+#include <map>
 #include <utility> // for std::pair
 
 namespace llvm {
@@ -29,6 +30,11 @@
 class Value;
 class Type;
+class GlobalObject;
 
+/// Maps each Function to the GlobalObjects whose !associated metadata refers
+/// to it (see collectAssociatedGlobals / FixupMetadataReferences).
+using AssociatedGlobalsMap = std::multimap<Function *, GlobalObject *>;
+
 /// Append F to the list of global ctors of module M with the given Priority.
 /// This wraps the function in the appropriate structure and stores it along
 /// side other global constructors. For details see
@@ -96,6 +102,24 @@
 /// unique identifier for this module, so we return the empty string.
 std::string getUniqueModuleId(Module *M);
 
+/// Fills \p AssociatedGlobals with a map from each Function to the
+/// GlobalObjects that refer to it through !associated (MD_associated)
+/// metadata. Any previous contents of \p AssociatedGlobals are discarded.
+/// Transformations that replace a function with a modified copy use this map
+/// to redirect that metadata to the copy (see FixupMetadataReferences).
+void collectAssociatedGlobals(Module *M,
+                              AssociatedGlobalsMap *AssociatedGlobals);
+
+/// Changes GlobalObject metadata that points at OldFn to point at NewFn, using
+/// the data collected by collectAssociatedGlobals, and re-keys those entries
+/// under NewFn so that NewFn can itself be replaced later.
+/// This should be used by transformations that 'modify' functions by deleting
+/// the original and replacing it with a modified version (e.g.
+/// DeadArgumentElimination, ArgumentPromotion). It should only be used when
+/// the new function is to be considered 'the same' as the old one.
+void FixupMetadataReferences(AssociatedGlobalsMap *AssociatedGlobals,
+                             Function *OldFn, Function *NewFn);
+
 } // End llvm namespace
 
 #endif // LLVM_TRANSFORMS_UTILS_MODULEUTILS_H
Index: lib/Transforms/IPO/ArgumentPromotion.cpp
===================================================================
--- lib/Transforms/IPO/ArgumentPromotion.cpp
+++ lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -963,6 +963,9 @@
                                              CGSCCUpdateResult &UR) {
   bool Changed = false, LocalChange;
 
+  collectAssociatedGlobals(C.begin()->getFunction().getParent(),
+                           &AssociatedGlobals);
+
   // Iterate until we stop promoting from this SCC.
   do {
     LocalChange = false;
@@ -990,12 +993,15 @@
       // swaps out the particular function mapped to a particular node in the
       // graph.
       C.getOuterRefSCC().replaceNodeFunction(N, *NewF);
+      FixupMetadataReferences(&AssociatedGlobals, &OldF, NewF);
       OldF.eraseFromParent();
     }
 
     Changed |= LocalChange;
   } while (LocalChange);
 
+  AssociatedGlobals.clear();
+
   if (!Changed)
     return PreservedAnalyses::all();
 
@@ -1008,6 +1014,7 @@
 struct ArgPromotion : public CallGraphSCCPass {
   // Pass identification, replacement for typeid
   static char ID;
+  AssociatedGlobalsMap AssociatedGlobals;
 
   explicit ArgPromotion(unsigned MaxElements = 3)
       : CallGraphSCCPass(ID), MaxElements(MaxElements) {
@@ -1053,6 +1060,8 @@
   if (skipSCC(SCC))
     return false;
 
+  collectAssociatedGlobals(&SCC.getCallGraph().getModule(), &AssociatedGlobals);
+
   // Get the callgraph information that we need to update to reflect our
   // changes.
   CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
@@ -1090,6 +1099,8 @@
     else
       OldF->setLinkage(Function::ExternalLinkage);
 
+    FixupMetadataReferences(&AssociatedGlobals, OldF, NewF);
+
     // And updat ethe SCC we're iterating as well.
     SCC.ReplaceNode(OldNode, NewNode);
   }
Index: lib/Transforms/IPO/DeadArgumentElimination.cpp
===================================================================
--- lib/Transforms/IPO/DeadArgumentElimination.cpp
+++ lib/Transforms/IPO/DeadArgumentElimination.cpp
@@ -246,6 +246,8 @@
   for (auto MD : MDs)
     NF->addMetadata(MD.first, *MD.second);
 
+  FixupMetadataReferences(&AssociatedGlobals, &Fn, NF);
+
   // Fix up any BlockAddresses that refer to the function.
   Fn.replaceAllUsesWith(ConstantExpr::getBitCast(NF, Fn.getType()));
   // Delete the bitcast that we just created, so that NF does not
@@ -1077,6 +1079,8 @@
   for (auto MD : MDs)
     NF->addMetadata(MD.first, *MD.second);
 
+  FixupMetadataReferences(&AssociatedGlobals, F, NF);
+
   // Now that the old function is dead, delete it.
   F->eraseFromParent();
 
@@ -1087,6 +1091,10 @@
                                                    ModuleAnalysisManager &) {
   bool Changed = false;
 
+  // Save information about MD_associated metadata so that we can
+  // redirect it to our re-created functions.
+  collectAssociatedGlobals(&M, &AssociatedGlobals);
+
   // First pass: Do a simple check to see if any functions can have their "..."
   // removed.  We can do this if they never call va_start.  This loop cannot be
   // fused with the next loop, because deleting a function invalidates
@@ -1120,6 +1128,8 @@
   for (auto &F : M)
     Changed |= RemoveDeadArgumentsFromCallers(F);
 
+  AssociatedGlobals.clear();
+
   if (!Changed)
     return PreservedAnalyses::all();
   return PreservedAnalyses::none();
Index: lib/Transforms/Utils/ModuleUtils.cpp
===================================================================
--- lib/Transforms/Utils/ModuleUtils.cpp
+++ lib/Transforms/Utils/ModuleUtils.cpp
@@ -269,3 +269,48 @@
   MD5::stringifyResult(R, Str);
   return ("$" + Str).str();
 }
+
+void llvm::collectAssociatedGlobals(Module *M,
+                                    AssociatedGlobalsMap *AssociatedGlobals) {
+  // Rebuild from scratch: callers that collect once per SCC must not
+  // accumulate duplicate entries or pointers to functions erased since.
+  AssociatedGlobals->clear();
+  for (GlobalObject &GO : M->global_objects()) {
+    MDNode *MD = GO.getMetadata(LLVMContext::MD_associated);
+    if (!MD || MD->getNumOperands() == 0)
+      continue;
+
+    const MDOperand &Op = MD->getOperand(0);
+    if (!Op.get())
+      continue;
+
+    auto *VM = dyn_cast<ValueAsMetadata>(Op);
+    if (!VM)
+      report_fatal_error("MD_associated operand is not ValueAsMetadata");
+
+    if (auto *OtherFunc = dyn_cast<Function>(VM->getValue()))
+      AssociatedGlobals->insert(std::make_pair(OtherFunc, &GO));
+  }
+}
+
+void llvm::FixupMetadataReferences(AssociatedGlobalsMap *AssociatedGlobals,
+                                   Function *OldFn, Function *NewFn) {
+  auto Range = AssociatedGlobals->equal_range(OldFn);
+  if (Range.first == Range.second)
+    return;
+
+  // Every referencing global gets the same node, so build it only once.
+  MDNode *MD = MDNode::get(NewFn->getContext(), ValueAsMetadata::get(NewFn));
+  SmallVector<GlobalObject *, 4> Referrers;
+  for (auto I = Range.first; I != Range.second; ++I) {
+    I->second->setMetadata(LLVMContext::MD_associated, MD);
+    Referrers.push_back(I->second);
+  }
+
+  // Re-key the entries under NewFn. OldFn is about to be erased (its address
+  // may later be reused by an unrelated function), and NewFn may itself be
+  // replaced again, e.g. by a later ArgumentPromotion iteration.
+  AssociatedGlobals->erase(Range.first, Range.second);
+  for (GlobalObject *GO : Referrers)
+    AssociatedGlobals->insert(std::make_pair(NewFn, GO));
+}
Index: test/Transforms/ArgumentPromotion/md-associated.ll
===================================================================
--- /dev/null
+++ test/Transforms/ArgumentPromotion/md-associated.ll
@@ -0,0 +1,39 @@
+; RUN: opt < %s -argpromotion -S | FileCheck %s
+; RUN: opt < %s -passes=argpromotion -S | FileCheck %s
+
+; CHECK: @b = internal global i32 2, !associated !0
+@b = internal global i32 2, !associated !0
+
+; @c is associated with a function that is promoted twice
+; (i32** -> i32* -> i32); its metadata must follow both rewrites.
+; CHECK: @c = internal global i32 3, !associated !1
+@c = internal global i32 3, !associated !1
+
+@G1 = constant i32 0
+@G2 = constant i32* @G1
+
+; CHECK: define internal i32 @test(i32 %A.val)
+define internal i32 @test(i32* %A) {
+  %X = load i32, i32* %A
+  ret i32 %X
+}
+
+; CHECK: define internal i32 @test2(i32 %x.val.val)
+define internal i32 @test2(i32** %x) {
+  %y = load i32*, i32** %x
+  %z = load i32, i32* %y
+  ret i32 %z
+}
+
+define void @caller() {
+  %A = alloca i32
+  store i32 1, i32* %A
+  %C = call i32 @test(i32* %A)
+  %D = call i32 @test2(i32** @G2)
+  ret void
+}
+
+; CHECK: !0 = !{i32 (i32)* @test}
+!0 = !{i32 (i32*)* @test}
+; CHECK: !1 = !{i32 (i32)* @test2}
+!1 = !{i32 (i32**)* @test2}
Index: test/Transforms/DeadArgElim/md-associated.ll
===================================================================
--- /dev/null
+++ test/Transforms/DeadArgElim/md-associated.ll
@@ -0,0 +1,13 @@
+; RUN: opt -S -deadargelim %s | FileCheck %s
+; RUN: opt -S -passes=deadargelim %s | FileCheck %s
+
+; CHECK: @b = internal global i32 2, !associated !0
+@b = internal global i32 2, !associated !0
+
+; CHECK: define internal void @test() {
+define internal void @test(i32 %dead) {
+  ret void
+}
+
+; CHECK: !0 = !{void ()* @test}
+!0 = !{void (i32)* @test}