Index: include/llvm/Transforms/Scalar/SROA.h
===================================================================
--- include/llvm/Transforms/Scalar/SROA.h
+++ include/llvm/Transforms/Scalar/SROA.h
@@ -27,6 +27,7 @@
 /// A private "module" namespace for types and utilities used by SROA. These
 /// are implementation details and should not be used by clients.
 namespace sroa {
+class AggregateLifter;
 class AllocaSliceRewriter;
 class AllocaSlices;
 class Partition;
@@ -107,6 +108,7 @@
   PreservedAnalyses run(Function &F, AnalysisManager<Function> *AM);
 
 private:
+  friend class sroa::AggregateLifter;
   friend class sroa::AllocaSliceRewriter;
   friend class sroa::SROALegacyPass;
 
Index: lib/Transforms/Scalar/SROA.cpp
===================================================================
--- lib/Transforms/Scalar/SROA.cpp
+++ lib/Transforms/Scalar/SROA.cpp
@@ -39,6 +39,7 @@
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/LLVMContext.h"
@@ -2152,6 +2153,192 @@
   return V;
 }
 
+/// \brief Visitor to rewrite loads and stores of aggregate types into memcpys
+/// to and from allocas, so that SROA can optimize them.
+///
+/// The memory accesses are not split up directly, so as to keep the knowledge
+/// of the whole range accessed, which matters when the aggregate contains
+/// padding, especially when aggregates are moved around.
+class llvm::sroa::AggregateLifter
+    : public InstVisitor<AggregateLifter, bool> {
+  // Befriend the base class so it can delegate to private visit methods.
+  friend class llvm::InstVisitor<AggregateLifter, bool>;
+  typedef llvm::InstVisitor<AggregateLifter, bool> Base;
+
+  /// An aggregate-typed instruction paired with the memory holding its value.
+  typedef std::pair<Instruction *, Value *> FixupPair;
+  typedef SmallVector<FixupPair, 8> FixupMap;
+
+  SROA &Pass;
+
+  const DataLayout &DL;
+  /// Builder used to create the lifted allocas in the entry block.
+  IRBuilder<> AtEntry;
+  /// Storage backing the value being lifted; only set while lift() runs.
+  Value *LiftedStorage;
+
+  /// Aggregate values whose users still have to be rewritten.
+  FixupMap FixupWorklist;
+
+public:
+  AggregateLifter(SROA &Pass, Function &F)
+      : Pass(Pass), DL(F.getParent()->getDataLayout()),
+        AtEntry(F.getEntryBlock().begin()), LiftedStorage(nullptr) {}
+
+  /// \brief Lift every simple aggregate load of \p F into an alloca.
+  ///
+  /// \returns true if at least one load was lifted.
+  bool run(Function &F) {
+    for (auto &I : instructions(F)) {
+      if (auto *LI = dyn_cast<LoadInst>(&I))
+        registerLoad(LI);
+    }
+
+    if (FixupWorklist.empty())
+      return false;
+
+    // Lifting a value may queue aggregate values extracted from it, so keep
+    // draining until no new work shows up.
+    while (!FixupWorklist.empty()) {
+      FixupMap CurrentWorklist;
+      std::swap(CurrentWorklist, FixupWorklist);
+      for (auto &P : CurrentWorklist) {
+        auto *I = P.first;
+        lift(I, P.second);
+        Pass.DeadInsts.insert(I);
+      }
+    }
+
+    return true;
+  }
+
+private:
+  /// \brief Give a simple aggregate load its own alloca.
+  ///
+  /// The loaded memory is copied into a new entry-block alloca by a memcpy
+  /// placed right before \p LI, and \p LI is queued so that its users get
+  /// rewritten. Loads that are not simple or not aggregates are ignored.
+  void registerLoad(LoadInst *LI) {
+    if (!LI->isSimple())
+      return;
+
+    Type *T = LI->getType();
+    if (!T->isAggregateType())
+      return;
+
+    // The store size is a 64-bit quantity; do not truncate it.
+    uint64_t Size = DL.getTypeStoreSize(T); // getTypeAllocSize ?
+    unsigned Align = LI->getAlignment();
+
+    AllocaInst *A =
+        AtEntry.CreateAlloca(T, nullptr, LI->getName() + ".lifted");
+    A->setAlignment(Align);
+    Pass.Worklist.insert(A);
+
+    IRBuilder<> Builder(LI);
+    Builder.CreateMemCpy(A, LI->getPointerOperand(), Size, Align,
+                         LI->isVolatile());
+    FixupWorklist.push_back(FixupPair(LI, A));
+  }
+
+  /// \brief Rewrite the users of \p I to read from \p S instead.
+  ///
+  /// Users handled by a visit method are marked dead. The remaining ones are
+  /// pointed at a single reload of \p S, which SROA can then deaggregate.
+  void lift(Instruction *I, Value *S) {
+    assert(LiftedStorage == nullptr && "Expected null storage");
+    LiftedStorage = S;
+
+    assert(I->getType()->isAggregateType() &&
+           "T is expected to be an aggregate");
+
+    DEBUG(dbgs() << "Lifting " << *I << " into " << *S << "\n");
+
+    // Snapshot the users before rewriting anything: replaceUsesOfWith() below
+    // unlinks uses from I's use list, which breaks a live iteration over it.
+    // Deduplicate, as a user appears once per operand that refers to I.
+    SmallVector<Instruction *, 8> Users;
+    SmallPtrSet<Instruction *, 8> Seen;
+    for (User *U : I->users()) {
+      // Only instructions can use an instruction's result.
+      Instruction *UI = cast<Instruction>(U);
+      if (Seen.insert(UI).second)
+        Users.push_back(UI);
+    }
+
+    LoadInst *Reload = nullptr;
+    for (Instruction *UI : Users) {
+      if (visit(UI)) {
+        Pass.DeadInsts.insert(UI);
+        continue;
+      }
+
+      // No luck, let SROA deaggregate the alloca.
+      if (Reload == nullptr) {
+        IRBuilder<> Builder(I);
+        Reload = Builder.CreateLoad(S, I->getName());
+      }
+
+      UI->replaceUsesOfWith(I, Reload);
+    }
+
+    LiftedStorage = nullptr;
+  }
+
+  /// Fallback: instructions without a dedicated visit method are not lifted.
+  bool visitInstruction(Instruction &I) {
+    return false;
+  }
+
+  /// \brief Turn a simple store of the lifted value into a memcpy from its
+  /// storage to the store's destination.
+  ///
+  /// \returns true if a memcpy now stands in for \p SI.
+  bool visitStoreInst(StoreInst &SI) {
+    if (!SI.isSimple())
+      return false;
+
+    Type *T = SI.getValueOperand()->getType();
+    uint64_t Size = DL.getTypeStoreSize(T); // getTypeAllocSize ?
+
+    // NOTE(review): the memcpy takes the store's alignment as a whole, so it
+    // is assumed for LiftedStorage too - confirm the storage is that aligned.
+    IRBuilder<> Builder(&SI);
+    Builder.CreateMemCpy(SI.getPointerOperand(), LiftedStorage, Size,
+                         SI.getAlignment(), SI.isVolatile());
+    return true;
+  }
+
+  /// \brief Turn an extractvalue of the lifted value into a GEP on its storage.
+  ///
+  /// A scalar result is loaded through the GEP; an aggregate result is queued
+  /// so that its own users get lifted with the GEP as storage.
+  bool visitExtractValueInst(ExtractValueInst &EVI) {
+    IRBuilder<> Builder(&EVI);
+
+    // Leading zero steps through the pointer; the rest mirror EVI's indices.
+    SmallVector<Value *, 8> IdxList;
+    IdxList.push_back(Builder.getInt32(0));
+    for (auto Idx : EVI.getIndices()) {
+      IdxList.push_back(Builder.getInt32(Idx));
+    }
+
+    Type *AT = EVI.getAggregateOperand()->getType();
+    Value *GEP = Builder.CreateInBoundsGEP(AT, LiftedStorage, IdxList,
+                                           EVI.getName() + ".lifted");
+    Type *T = EVI.getType();
+    if (T->isAggregateType()) {
+      FixupWorklist.push_back(FixupPair(&EVI, GEP));
+      return true;
+    }
+
+    LoadInst *LI = Builder.CreateLoad(GEP, EVI.getName());
+    EVI.replaceAllUsesWith(LI);
+
+    return true;
+  }
+};
+
 /// \brief Visitor to rewrite instructions using p particular slice of an alloca
 /// to use a new alloca.
 ///
@@ -4208,7 +4395,8 @@
       Worklist.insert(AI);
   }
 
-  bool Changed = false;
+  AggregateLifter Lifter(*this, F);
+  bool Changed = Lifter.run(F);
   // A set of deleted alloca instruction pointers which should be removed from
   // the list of promotable allocas.
SmallPtrSet DeletedAllocas;
Index: test/Transforms/SROA/lifter.ll
===================================================================
--- /dev/null
+++ test/Transforms/SROA/lifter.ll
@@ -0,0 +1,92 @@
+; RUN: opt -sroa -S < %s | FileCheck %s
+; RUN: opt -O3 -S < %s | FileCheck %s -check-prefix=CHECK-OPT
+
+target datalayout = "e-i64:64-f80:128-n8:16:32:64"
+target triple = "x86_64-unknown-linux-gnu"
+
+%A = type { i8*, i64 } ; two-field aggregate: a pointer and an i64
+%B = type { %A, i64 } ; aggregate that nests %A as its first field
+
+@a0 = global %A zeroinitializer
+@a1 = global %A zeroinitializer
+@b0 = global %B zeroinitializer
+
+; CHECK-LABEL: @forward(
+; CHECK-OPT-LABEL: @forward(
+define void @forward() { ; copies a whole %A from @a0 to @a1
+  %1 = load %A, %A* @a0 ; aggregate load
+  store %A %1, %A* @a1 ; aggregate store of the loaded value
+  ret void
+; CHECK: alloca %A
+; CHECK: call void @llvm.memcpy.p0i8.p0i8.i64
+; CHECK: call void @llvm.memcpy.p0i8.p0i8.i64
+; CHECK-NEXT: ret void
+; CHECK-OPT-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64
+; CHECK-OPT-NEXT: ret void
+}
+
+; CHECK-LABEL: @forwardelement(
+; CHECK-OPT-LABEL: @forwardelement(
+define void @forwardelement() { ; stores the %A field of a loaded %B to @a0
+  %1 = load %B, %B* @b0 ; aggregate load
+  %2 = extractvalue %B %1, 0 ; field 0 is itself an aggregate (%A)
+  store %A %2, %A* @a0 ; aggregate store of the extracted field
+  ret void
+; CHECK: alloca %B
+; CHECK: call void @llvm.memcpy.p0i8.p0i8.i64
+; CHECK-NEXT: getelementptr inbounds %B, %B*
+; CHECK: call void @llvm.memcpy.p0i8.p0i8.i64
+; CHECK-NEXT: ret void
+; CHECK-OPT-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64
+; CHECK-OPT-NEXT: ret void
+}
+
+; CHECK-LABEL: @return(
+; CHECK-OPT-LABEL: @return(
+define i64 @return() { ; returns the i64 field of a loaded %A
+  %1 = load %A, %A* @a0 ; aggregate load
+  %2 = extractvalue %A %1, 1 ; field 1 is a scalar (i64)
+  ret i64 %2
+; CHECK-NOT: alloca %A
+; CHECK-NOT: call void @llvm.memcpy.p0i8.p0i8.i64
+; CHECK: load i64, i64*
+; CHECK-NEXT: ret i64
+; CHECK-OPT-NEXT: load i64, i64* getelementptr inbounds
+; CHECK-OPT-NEXT: ret i64
+}
+
+; CHECK-LABEL: @nested(
+; CHECK-OPT-LABEL: @nested(
+define i64 @nested() { ; returns a scalar reached through two extractvalues
+  %1 = load %B, %B* @b0 ; aggregate load
+  %2 = extractvalue %B %1, 0 ; aggregate field (%A)
+  %3 = extractvalue %A %2, 1 ; scalar field (i64) of that aggregate
+  ret i64 %3
+; CHECK-NOT: alloca %B
+; CHECK-NOT: call void @llvm.memcpy.p0i8.p0i8.i64
+; CHECK: load i64, i64*
+; CHECK: ret i64
+; CHECK-OPT-NEXT: load i64, i64* getelementptr inbounds
+; CHECK-OPT-NEXT: ret i64
+}
+
+; CHECK-LABEL: @multibb(
+; CHECK-OPT-LABEL: @multibb(
+define %A @multibb() { ; returns a whole %A loaded outside the entry block
+  br label %body
+body:
+  %1 = load %A, %A* @a0 ; aggregate load in the second basic block
+  ret %A %1 ; the aggregate is returned by value
+; CHECK-NOT: alloca %A
+; CHECK-NOT: call void @llvm.memcpy.p0i8.p0i8.i64
+; CHECK: load i8*, i8**
+; CHECK-NEXT: load i64, i64*
+; CHECK-NEXT: insertvalue %A
+; CHECK-NEXT: insertvalue %A
+; CHECK-NEXT: ret %A
+; CHECK-OPT: load i8*, i8**
+; CHECK-OPT-NEXT: load i64, i64*
+; CHECK-OPT-NEXT: insertvalue %A
+; CHECK-OPT-NEXT: insertvalue %A
+; CHECK-OPT-NEXT: ret %A
+}