[SROA] Leave allocas alone when a user is expensive in code size.

For every alloca that does not escape, query TargetTransformInfo for the
code-size cost (TCK_CodeSize) of each slice's user instruction, and skip the
alloca if any of them is at least TCC_Expensive.

TTI is threaded into SROA::runImpl from two places:
  - the new pass manager, via TargetIRAnalysis;
  - the legacy pass, via TargetTransformInfoWrapperPass, which is registered
    as a pass dependency.

AllocaSlices::shouldExpand is defined inside the class body. The out-of-line
print()/dump() definitions are conditionally compiled, so placing it next to
them would leave it undefined in release builds.

Index: include/llvm/Transforms/Scalar/SROA.h
===================================================================
--- include/llvm/Transforms/Scalar/SROA.h
+++ include/llvm/Transforms/Scalar/SROA.h
@@ -19,6 +19,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/Support/Compiler.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include <vector>
 
 namespace llvm {
@@ -66,6 +67,7 @@
   LLVMContext *C = nullptr;
   DominatorTree *DT = nullptr;
   AssumptionCache *AC = nullptr;
+  TargetTransformInfo *TTI = nullptr;
 
   /// Worklist of alloca instructions to simplify.
   ///
@@ -121,7 +123,8 @@
 
   /// Helper used by both the public run method and by the legacy pass.
   PreservedAnalyses runImpl(Function &F, DominatorTree &RunDT,
-                            AssumptionCache &RunAC);
+                            AssumptionCache &RunAC,
+                            TargetTransformInfo &RunTTI);
 
   bool presplitLoadsAndStores(AllocaInst &AI, sroa::AllocaSlices &AS);
   AllocaInst *rewritePartition(AllocaInst &AI, sroa::AllocaSlices &AS,
Index: lib/Transforms/Scalar/SROA.cpp
===================================================================
--- lib/Transforms/Scalar/SROA.cpp
+++ lib/Transforms/Scalar/SROA.cpp
@@ -297,6 +297,27 @@
   void dump() const;
 #endif
 
+  /// Decide whether SROA should split and rewrite this alloca.
+  ///
+  /// Returns false if the user instruction of any slice has a code-size cost
+  /// of at least TargetTransformInfo::TCC_Expensive on the current target,
+  /// and true otherwise. Defined in-class (outside the print()/dump() debug
+  /// section) because runOnAlloca calls it in every build configuration.
+  bool shouldExpand(const TargetTransformInfo &TTI) const {
+    for (const Slice &S : *this) {
+      // Skip slices that carry no use.
+      const Use *U = S.getUse();
+      if (!U)
+        continue;
+      // Every user of an alloca-derived pointer is an instruction.
+      if (TTI.getInstructionCost(cast<Instruction>(U->getUser()),
+                                 TargetTransformInfo::TCK_CodeSize) >=
+          TargetTransformInfo::TCC_Expensive)
+        return false;
+    }
+    return true;
+  }
+
 private:
   template <typename DerivedT, typename RetT = void> class BuilderBase;
   class SliceBuilder;
@@ -4394,5 +4415,7 @@
   AllocaSlices AS(DL, AI);
   LLVM_DEBUG(AS.print(dbgs()));
-  if (AS.isEscaped())
+  // Escaped allocas are never rewritten, so test that first and only pay for
+  // the TTI cost queries in shouldExpand() when the alloca is a candidate.
+  if (AS.isEscaped() || !AS.shouldExpand(*TTI))
     return Changed;
 
@@ -4490,11 +4513,13 @@
 }
 
 PreservedAnalyses SROA::runImpl(Function &F, DominatorTree &RunDT,
-                                AssumptionCache &RunAC) {
+                                AssumptionCache &RunAC,
+                                TargetTransformInfo &RunTTI) {
   LLVM_DEBUG(dbgs() << "SROA function: " << F.getName() << "\n");
   C = &F.getContext();
   DT = &RunDT;
   AC = &RunAC;
+  TTI = &RunTTI;
 
   BasicBlock &EntryBB = F.getEntryBlock();
   for (BasicBlock::iterator I = EntryBB.begin(), E = std::prev(EntryBB.end());
@@ -4542,7 +4567,8 @@
 
 PreservedAnalyses SROA::run(Function &F, FunctionAnalysisManager &AM) {
   return runImpl(F, AM.getResult<DominatorTreeAnalysis>(F),
-                 AM.getResult<AssumptionAnalysis>(F));
+                 AM.getResult<AssumptionAnalysis>(F),
+                 AM.getResult<TargetIRAnalysis>(F));
 }
 
 /// A legacy pass for the legacy pass manager that wraps the \c SROA pass.
@@ -4566,11 +4592,13 @@
 
     auto PA = Impl.runImpl(
         F, getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
-        getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F));
+        getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
+        getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F));
     return !PA.areAllPreserved();
   }
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetTransformInfoWrapperPass>();
     AU.addRequired<AssumptionCacheTracker>();
     AU.addRequired<DominatorTreeWrapperPass>();
     AU.addPreserved<GlobalsAAWrapperPass>();
@@ -4589,4 +4617,5 @@
 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
 INITIALIZE_PASS_END(SROALegacyPass, "sroa", "Scalar Replacement Of Aggregates",
                     false, false)