Index: include/llvm/InitializePasses.h
===================================================================
--- include/llvm/InitializePasses.h
+++ include/llvm/InitializePasses.h
@@ -161,6 +161,7 @@
 void initializeGlobalSplitPass(PassRegistry&);
 void initializeGlobalsAAWrapperPassPass(PassRegistry&);
 void initializeGuardWideningLegacyPassPass(PassRegistry&);
+void initializeHeapToStackLegacyPassPass(PassRegistry&);
 void initializeHotColdSplittingLegacyPassPass(PassRegistry&);
 void initializeHWAddressSanitizerPass(PassRegistry&);
 void initializeIPCPPass(PassRegistry&);
Index: include/llvm/LinkAllPasses.h
===================================================================
--- include/llvm/LinkAllPasses.h
+++ include/llvm/LinkAllPasses.h
@@ -217,6 +217,7 @@
       (void) llvm::createMemDerefPrinter();
       (void) llvm::createMustExecutePrinter();
       (void) llvm::createFloat2IntPass();
+      (void) llvm::createHeapToStackPass();
       (void) llvm::createEliminateAvailableExternallyPass();
       (void) llvm::createScalarizeMaskedMemIntrinPass();
Index: include/llvm/Transforms/Scalar.h
===================================================================
--- include/llvm/Transforms/Scalar.h
+++ include/llvm/Transforms/Scalar.h
@@ -446,6 +446,12 @@
 //===----------------------------------------------------------------------===//
 //
+// HeapToStack - Convert heap allocations to stack allocations where possible.
+//
+FunctionPass *createHeapToStackPass();
+
+//===----------------------------------------------------------------------===//
+//
 // NaryReassociate - Simplify n-ary operations by reassociation.
 //
 FunctionPass *createNaryReassociatePass();
Index: include/llvm/Transforms/Scalar/HeapToStack.h
===================================================================
--- /dev/null
+++ include/llvm/Transforms/Scalar/HeapToStack.h
@@ -0,0 +1,29 @@
+//===-- HeapToStack.h - Heap-to-Stack Conversion ----------------*- C++ -*-===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass moves small allocations from the heap to the stack.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_SCALAR_HEAPTOSTACK_H
+#define LLVM_TRANSFORMS_SCALAR_HEAPTOSTACK_H
+
+#include "llvm/IR/Function.h"
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+/// Convert heap allocations to stack allocations where possible.
+class HeapToStackPass : public PassInfoMixin<HeapToStackPass> {
+public:
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+}
+
+#endif // LLVM_TRANSFORMS_SCALAR_HEAPTOSTACK_H
Index: lib/Passes/PassBuilder.cpp
===================================================================
--- lib/Passes/PassBuilder.cpp
+++ lib/Passes/PassBuilder.cpp
@@ -106,6 +106,7 @@
 #include "llvm/Transforms/Scalar/Float2Int.h"
 #include "llvm/Transforms/Scalar/GVN.h"
 #include "llvm/Transforms/Scalar/GuardWidening.h"
+#include "llvm/Transforms/Scalar/HeapToStack.h"
 #include "llvm/Transforms/Scalar/IVUsersPrinter.h"
 #include "llvm/Transforms/Scalar/IndVarSimplify.h"
 #include "llvm/Transforms/Scalar/InductiveRangeCheckElimination.h"
@@ -364,6 +365,8 @@
   // Catch trivial redundancies
   FPM.addPass(EarlyCSEPass(EnableEarlyCSEMemSSA));
+  FPM.addPass(HeapToStackPass());
+
   // Hoisting of scalars and load expressions.
   if (EnableGVNHoist)
     FPM.addPass(GVNHoistPass());
Index: lib/Passes/PassRegistry.def
===================================================================
--- lib/Passes/PassRegistry.def
+++ lib/Passes/PassRegistry.def
@@ -170,6 +170,7 @@
 FUNCTION_PASS("instsimplify", InstSimplifyPass())
 FUNCTION_PASS("invalidate", InvalidateAllAnalysesPass())
 FUNCTION_PASS("float2int", Float2IntPass())
+FUNCTION_PASS("heap-to-stack", HeapToStackPass())
 FUNCTION_PASS("no-op-function", NoOpFunctionPass())
 FUNCTION_PASS("libcalls-shrinkwrap", LibCallsShrinkWrapPass())
 FUNCTION_PASS("loweratomic", LowerAtomicPass())
Index: lib/Transforms/IPO/PassManagerBuilder.cpp
===================================================================
--- lib/Transforms/IPO/PassManagerBuilder.cpp
+++ lib/Transforms/IPO/PassManagerBuilder.cpp
@@ -321,6 +321,7 @@
   // Break up aggregate allocas, using SSAUpdater.
   MPM.add(createSROAPass());
   MPM.add(createEarlyCSEPass(EnableEarlyCSEMemSSA)); // Catch trivial redundancies
+  MPM.add(createHeapToStackPass());
   if (EnableGVNHoist)
     MPM.add(createGVNHoistPass());
   if (EnableGVNSink) {
Index: lib/Transforms/Scalar/CMakeLists.txt
===================================================================
--- lib/Transforms/Scalar/CMakeLists.txt
+++ lib/Transforms/Scalar/CMakeLists.txt
@@ -16,6 +16,7 @@
   GVN.cpp
   GVNHoist.cpp
   GVNSink.cpp
+  HeapToStack.cpp
   IVUsersPrinter.cpp
   InductiveRangeCheckElimination.cpp
   IndVarSimplify.cpp
Index: lib/Transforms/Scalar/HeapToStack.cpp
===================================================================
--- /dev/null
+++ lib/Transforms/Scalar/HeapToStack.cpp
@@ -0,0 +1,412 @@
+//===-- HeapToStack.cpp - Heap-to-Stack Conversion ------------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass converts heap allocations into stack allocations when the
+// allocation is small, its address does not escape the function, and it is
+// freed before execution can leave the function.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/HeapToStack.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/CaptureTracking.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/MemoryBuiltins.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
+#include <set>
+using namespace llvm;
+
+#define DEBUG_TYPE "heap-to-stack"
+
+STATISTIC(NumAConverted, "Number of converted allocations");
+STATISTIC(NumDConverted, "Number of converted deallocations");
+
+static cl::opt<bool> EnableHeapToStack("enable-heap-to-stack-conversion",
+                                       cl::init(true), cl::Hidden);
+static cl::opt<bool> BeAggressive("aggressive-heap-to-stack-conversion",
+                                  cl::init(false), cl::Hidden);
+static cl::opt<bool>
+    BeVeryAggressive("very-aggressive-heap-to-stack-conversion",
+                     cl::init(false), cl::Hidden);
+static cl::opt<unsigned> MaxHeapToStackSize("max-heap-to-stack-size",
+                                            cl::init(1024), cl::Hidden);
+
+static bool convertAllocations(Function &F, DominatorTree &DT,
+                               LoopInfo &LI, const TargetLibraryInfo &TLI,
+                               OptimizationRemarkEmitter &ORE,
+                               bool &ChangedCFG) {
+  auto &DL = F.getParent()->getDataLayout();
+  bool MadeChange = false;
+  ChangedCFG = false;
+
+  if (!EnableHeapToStack)
+    return MadeChange;
+
+  // First, find allocation and deallocation functions. If we don't have at
+  // least one of each, then we're done.
+  SmallPtrSet<Instruction *, 8> MallocCalls, FreeCalls;
+
+  for (auto &BB : F)
+    for (auto &I : BB) {
+      if (isMallocLikeFn(&I, &TLI))
+        if (auto *CSize = dyn_cast<ConstantInt>(I.getOperand(0)))
+          if (CSize->getValue().sle(MaxHeapToStackSize)) {
+            MallocCalls.insert(&I);
+            LLVM_DEBUG(dbgs() << "H2S: Initial malloc call: " << I << "\n");
+          }
+
+      if (isFreeCall(&I, &TLI)) {
+        FreeCalls.insert(&I);
+        LLVM_DEBUG(dbgs() << "H2S: Initial free call: " << I << "\n");
+      }
+    }
+
+  if (MallocCalls.empty() || FreeCalls.empty())
+    return MadeChange;
+
+  // A map of each malloc call to the set of associated free calls.
+  DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> FreesForMalloc;
+
+  bool ShouldRefine;
+  do {
+    ShouldRefine = false;
+    // For each call to free, the underlying objects of its argument must be
+    // in the already-collected set of calls to malloc. If not, remove the
+    // free.
+
+    FreesForMalloc.clear();
+    SmallVector<Instruction *, 8> BadFreeCalls;
+    for (auto *FreeCall : FreeCalls) {
+      SmallVector<Value *, 4> FreeUOs;
+      GetUnderlyingObjects(FreeCall->getOperand(0), FreeUOs, DL, &LI);
+
+      // Note that the call to GetUnderlyingObjects looks through pointer
+      // arithmetic, and that's okay, but we're relying on the fact that it is
+      // UB to pass an offset-adjusted pointer to free (thus, any offset
+      // applied must be reverted by the time the value gets to the call to
+      // free).
+
+      for (auto *UO : FreeUOs) {
+        if (!isa<Instruction>(UO) ||
+            !MallocCalls.count(cast<Instruction>(UO))) {
+          BadFreeCalls.push_back(FreeCall);
+          LLVM_DEBUG(dbgs() << "H2S: Bad free call: " << *FreeCall <<
+                     " UO: " << *UO << "\n");
+          break;
+        }
+
+        FreesForMalloc[cast<Instruction>(UO)].insert(FreeCall);
+      }
+    }
+
+    for (auto *BadFreeCall : BadFreeCalls)
+      FreeCalls.erase(BadFreeCall);
+
+    if (FreeCalls.empty())
+      return MadeChange;
+
+    // For each malloc call, we need to know if there's a path to the exit
+    // which does not contain one of the associated calls to free. If there
+    // is, then we can't convert this malloc.
+    SmallVector<Instruction *, 8> BadMallocCalls;
+    for (auto *MallocCall : MallocCalls) {
+      if (FreesForMalloc[MallocCall].empty()) {
+        BadMallocCalls.push_back(MallocCall);
+        LLVM_DEBUG(dbgs() << "H2S: Bad malloc call (no frees): " <<
+                   *MallocCall << "\n");
+        continue;
+      }
+
+      // Note: We need to be careful about the case where an infinite loop is
+      // followed by a free. In this case, the free is dead, and the infinite
+      // loop might allow some called function, or another thread, to free the
+      // data elsewhere. Currently, things that cause an infinite loop to be
+      // well defined (e.g., atomics, I/O) will also cause
+      // isGuaranteedToTransferExecutionToSuccessor to return false, so we're
+      // okay on that front.
+
+      struct WLEntry {
+        BasicBlock::iterator I;
+        BasicBlock *BB;
+        std::set<Instruction *> Ptrs;
+
+        WLEntry(BasicBlock::iterator I, BasicBlock *BB, Instruction *Ptr)
+            : I(I), BB(BB) {
+          Ptrs.insert(Ptr);
+        }
+
+        WLEntry(BasicBlock::iterator I, BasicBlock *BB,
+                std::set<Instruction *> &Ptrs)
+            : I(I), BB(BB), Ptrs(Ptrs) {}
+
+        bool operator == (const WLEntry &WLE) const {
+          return I == WLE.I && BB == WLE.BB && Ptrs == WLE.Ptrs;
+        }
+
+        bool operator < (const WLEntry &WLE) const {
+          if (std::less<BasicBlock *>()(BB, WLE.BB))
+            return true;
+          else if (std::less<BasicBlock *>()(WLE.BB, BB))
+            return false;
+
+          if (std::less<Instruction *>()((Instruction *) I.getNodePtr(),
+                                         (Instruction *) WLE.I.getNodePtr()))
+            return true;
+          else if (std::less<Instruction *>()(
+                       (Instruction *) WLE.I.getNodePtr(),
+                       (Instruction *) I.getNodePtr()))
+            return false;
+
+          return Ptrs < WLE.Ptrs;
+        }
+      };
+
+      SmallVector<WLEntry, 8> Worklist;
+      Worklist.push_back(WLEntry(std::next(BasicBlock::iterator(MallocCall)),
+                                 MallocCall->getParent(), MallocCall));
+
+      SmallSet<WLEntry, 8> Visited;
+
+      while (!Worklist.empty()) {
+        bool FoundExit = false, FoundFree = false;
+        WLEntry WLE = Worklist.pop_back_val();
+        if (!Visited.insert(WLE).second)
+          continue;
+
+        auto Ptrs = WLE.Ptrs;
+        for (auto I = WLE.I, IE = WLE.BB->end(); I != IE; ++I) {
+          if (isFreeCall(&*I, &TLI) &&
+              Ptrs.count((Instruction *) I->getOperand(0))) {
+            FoundFree = true;
+            LLVM_DEBUG(dbgs() << "H2S: Paired malloc call: " << *MallocCall <<
+                       " with free at: " << *I << "\n");
+            break;
+          }
+
+          if (auto *GEPI = dyn_cast<GetElementPtrInst>(I)) {
+            if (Ptrs.count((Instruction *) GEPI->getPointerOperand()))
+              Ptrs.insert(&*I);
+          } else if (I->getOpcode() == Instruction::BitCast ||
+                     I->getOpcode() == Instruction::AddrSpaceCast) {
+            if (Ptrs.count((Instruction *) I->getOperand(0)))
+              Ptrs.insert(&*I);
+          } else if (auto CS = CallSite(&*I)) {
+            if (Ptrs.count((Instruction *) CS.getReturnedArgOperand()))
+              Ptrs.insert(&*I);
+          } else if (auto *SI = dyn_cast<SelectInst>(I)) {
+            if (Ptrs.count((Instruction *) SI->getTrueValue()) ||
+                Ptrs.count((Instruction *) SI->getFalseValue()))
+              Ptrs.insert(&*I);
+          }
+
+          if (BeAggressive) {
+            auto NotAnExit = [&](const Instruction *I) {
+              if (const auto *CRI = dyn_cast<CleanupReturnInst>(I))
+                return !CRI->unwindsToCaller();
+              if (const auto *CatchSwitch = dyn_cast<CatchSwitchInst>(I))
+                return !CatchSwitch->unwindsToCaller();
+              if (isa<ReturnInst>(I))
+                return false;
+              if (isa<ResumeInst>(I))
+                return false;
+              if (isa<UnreachableInst>(I))
+                return false;
+
+              if (auto CS = ImmutableCallSite(I))
+                if (!CS.doesNotThrow() && !BeVeryAggressive)
+                  return false;
+
+              return true;
+            };
+
+            if (!NotAnExit(&*I))
+              FoundExit = true;
+          } else {
+            if (!isGuaranteedToTransferExecutionToSuccessor(&*I)) {
+              // Might return, etc., before we find something that might free
+              // the memory.
+              FoundExit = true;
+            }
+          }
+
+          if (FoundExit) {
+            // The exit is a problem only if the pointer is captured...
+            bool AnyCaptured = false;
+            for (auto *Ptr : Ptrs)
+              // TODO: Add an OBB cache to the capturing query.
+              if (PointerMayBeCapturedBefore(Ptr, /* ReturnCaptures */ true,
+                                             /* StoreCaptures */ true, &*I,
+                                             &DT, /* IncludeI */ true)) {
+                LLVM_DEBUG(dbgs() << "H2S: Pointer: " << *Ptr <<
+                           " from malloc call: " << *MallocCall <<
+                           " captured before, or at, the potential exit at: "
+                           << *I << "\n");
+                AnyCaptured = true;
+                break;
+              }
+
+            if (!AnyCaptured) {
+              FoundExit = false;
+            } else {
+              LLVM_DEBUG(dbgs() << "H2S: Bad malloc call: " << *MallocCall <<
+                         " found potential exit at: " << *I << "\n");
+              break;
+            }
+          }
+        }
+
+        if (FoundExit) {
+          BadMallocCalls.push_back(MallocCall);
+          break;
+        } else if (!FoundFree) {
+          for (auto *SBB : successors(WLE.BB)) {
+            auto SPtrs = Ptrs;
+            for (PHINode &PN : SBB->phis())
+              if (Ptrs.count((Instruction *)
+                             PN.getIncomingValueForBlock(WLE.BB)))
+                SPtrs.insert(&PN);
+
+            Worklist.push_back(WLEntry(SBB->begin(), SBB, SPtrs));
+          }
+        }
+      }
+    }
+
+    if (!BadMallocCalls.empty())
+      ShouldRefine = true;
+
+    for (auto *BadMallocCall : BadMallocCalls)
+      MallocCalls.erase(BadMallocCall);
+
+    if (MallocCalls.empty())
+      return MadeChange;
+  } while (ShouldRefine);
+
+  MadeChange = true;
+
+  for (auto *MallocCall : MallocCalls) {
+    {
+      using namespace ore;
+      ORE.emit([&]() {
+        return OptimizationRemark(DEBUG_TYPE, "removed-malloc", MallocCall)
+               << "removed memory allocation (placed on stack)";
+      });
+    }
+
+    LLVM_DEBUG(dbgs() << "H2S: Removing malloc call: " << *MallocCall << "\n");
+
+    unsigned AS = cast<PointerType>(MallocCall->getType())->getAddressSpace();
+    Instruction *AI = new AllocaInst(Type::getInt8Ty(F.getContext()), AS,
+                                     MallocCall->getOperand(0));
+    F.begin()->getInstList().insert(F.begin()->begin(), AI);
+
+    if (AI->getType() != MallocCall->getType()) {
+      auto *BC = new BitCastInst(AI, MallocCall->getType());
+      BC->insertAfter(AI);
+      AI = BC;
+    }
+
+    MallocCall->replaceAllUsesWith(AI);
+
+    if (auto *II = dyn_cast<InvokeInst>(MallocCall)) {
+      auto *NBB = II->getNormalDest();
+      BranchInst::Create(NBB, MallocCall->getParent());
+      MallocCall->eraseFromParent();
+      ChangedCFG = true;
+    } else {
+      MallocCall->eraseFromParent();
+    }
+
+    ++NumAConverted;
+  }
+
+  for (auto *FreeCall : FreeCalls) {
+    {
+      using namespace ore;
+      ORE.emit([&]() {
+        return OptimizationRemark(DEBUG_TYPE, "removed-free", FreeCall)
+               << "removed memory deallocation (placed on stack)";
+      });
+    }
+
+    LLVM_DEBUG(dbgs() << "H2S: Removing free call: " << *FreeCall << "\n");
+
+    FreeCall->eraseFromParent();
+    ++NumDConverted;
+  }
+
+  return MadeChange;
+}
+
+PreservedAnalyses HeapToStackPass::run(Function &F,
+                                       FunctionAnalysisManager &AM) {
+  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+  auto &LI = AM.getResult<LoopAnalysis>(F);
+  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
+  auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
+
+  bool ChangedCFG;
+  if (!convertAllocations(F, DT, LI, TLI, ORE, ChangedCFG))
+    return PreservedAnalyses::all();
+
+  PreservedAnalyses PA;
+  if (!ChangedCFG)
+    PA.preserveSet<CFGAnalyses>();
+  return PA;
+}
+
+namespace {
+class HeapToStackLegacyPass : public FunctionPass {
+public:
+  static char ID; // Pass identification
+
+  HeapToStackLegacyPass() : FunctionPass(ID) {
+    initializeHeapToStackLegacyPassPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+    auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
+
+    bool ChangedCFG;
+    return convertAllocations(F, DT, LI, TLI, ORE, ChangedCFG);
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    FunctionPass::getAnalysisUsage(AU);
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addRequired<LoopInfoWrapperPass>();
+    AU.addRequired<TargetLibraryInfoWrapperPass>();
+    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
+    AU.addPreserved();
+    AU.addPreserved();
+  }
+};
+} // end anonymous namespace
+
+char HeapToStackLegacyPass::ID = 0;
+static const char H2SName[] = "Heap-to-Stack Conversion";
+
+INITIALIZE_PASS_BEGIN(HeapToStackLegacyPass, DEBUG_TYPE, H2SName, false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_END(HeapToStackLegacyPass, DEBUG_TYPE, H2SName, false, false)
+
+FunctionPass *llvm::createHeapToStackPass() {
+  return new HeapToStackLegacyPass();
+}
Index: lib/Transforms/Scalar/Scalar.cpp
===================================================================
--- lib/Transforms/Scalar/Scalar.cpp
+++ lib/Transforms/Scalar/Scalar.cpp
@@ -99,6 +99,7 @@
   initializePlaceBackedgeSafepointsImplPass(Registry);
   initializePlaceSafepointsPass(Registry);
   initializeFloat2IntLegacyPassPass(Registry);
+  initializeHeapToStackLegacyPassPass(Registry);
   initializeLoopDistributeLegacyPass(Registry);
   initializeLoopLoadEliminationPass(Registry);
   initializeLoopSimplifyCFGLegacyPassPass(Registry);
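Note: the patch as posted adds no regression test. Below is a minimal sketch of one, assuming the new-PM pass name "heap-to-stack" registered in PassRegistry.def above; the file path (test/Transforms/HeapToStack/simple.ll) and the @use helper are hypothetical, and the CHECK lines assume the pass rewrites a constant-size, non-escaping malloc/free pair into an i8 alloca sized by the original allocation argument, as the conversion loop does.

; RUN: opt -passes=heap-to-stack -S < %s | FileCheck %s
; Hypothetical test/Transforms/HeapToStack/simple.ll

declare i8* @malloc(i64)
declare void @free(i8* nocapture)

; The helper must not capture the pointer, or the conversion is
; (correctly) blocked by the capture check at potential exits.
declare void @use(i8* nocapture)

; The 32-byte allocation is freed on the only path out of the function and
; its address never escapes, so it should become a stack allocation and both
; library calls should disappear.
define void @test() {
; CHECK-LABEL: @test(
; CHECK: alloca i8, i64 32
; CHECK-NOT: call i8* @malloc
; CHECK-NOT: call void @free
entry:
  %p = call i8* @malloc(i64 32)
  call void @use(i8* %p)
  call void @free(i8* %p)
  ret void
}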