Index: include/llvm/InitializePasses.h
===================================================================
--- include/llvm/InitializePasses.h
+++ include/llvm/InitializePasses.h
@@ -300,6 +300,7 @@
 void initializeLoopDistributePass(PassRegistry&);
 void initializeSjLjEHPreparePass(PassRegistry&);
 void initializeDemandedBitsPass(PassRegistry&);
+void initializeAggregateLifterPass(PassRegistry&);
 }
 
 #endif
Index: include/llvm/Transforms/Scalar.h
===================================================================
--- include/llvm/Transforms/Scalar.h
+++ include/llvm/Transforms/Scalar.h
@@ -487,6 +487,13 @@
 //
 FunctionPass *createLoopDistributePass();
 
+//===----------------------------------------------------------------------===//
+//
+// AggregateLifter - Replace aggregate loads and stores with memcpys through an
+// alloca so that SROA can optimize them.
+//
+FunctionPass *createAggregateLifterPass();
+
 } // End llvm namespace
 
 #endif
Index: lib/Transforms/IPO/PassManagerBuilder.cpp
===================================================================
--- lib/Transforms/IPO/PassManagerBuilder.cpp
+++ lib/Transforms/IPO/PassManagerBuilder.cpp
@@ -208,6 +208,9 @@
   if (LibraryInfo)
     MPM.add(new TargetLibraryInfoWrapperPass(*LibraryInfo));
 
+  // Lift aggregate loads/stores into allocas so that SROA can optimize them.
+  MPM.add(createAggregateLifterPass());
+
   addInitialAliasAnalysisPasses(MPM);
 
   if (!DisableUnitAtATime) {
Index: lib/Transforms/Scalar/AggregateLifter.cpp
===================================================================
--- /dev/null
+++ lib/Transforms/Scalar/AggregateLifter.cpp
@@ -0,0 +1,267 @@
+//===- AggregateLifter.cpp - Lift aggregate loads/stores into alloca ------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This transform removes loads and stores of aggregate type. Each simple
+/// aggregate load is replaced by a memcpy into a fresh alloca, and the users
+/// of the loaded value are rewritten to work on that alloca: stores become
+/// memcpys out of it and extractvalues become a GEP plus a scalar load. This
+/// allows subsequent passes, notably SROA, to optimize properly.
+///
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar.h"
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "aggregate-lifter"
+
+STATISTIC(NumLoadsLifted, "Number of aggregate loads lifted into an alloca");
+STATISTIC(NumStoresLifted, "Number of aggregate stores turned into memcpys");
+STATISTIC(NumExtractsLifted, "Number of extractvalues turned into GEPs");
+
+namespace {
+
+/// An aggregate-typed instruction whose users still have to be rewritten,
+/// together with the memory that now holds its value and the alignment that
+/// memory is known to have.
+struct FixupEntry {
+  Instruction *Aggregate;
+  Value *Ptr;
+  unsigned Align;
+};
+
+class AggregateLifter : public FunctionPass {
+  /// Aggregates whose users still need rewriting. A vector rather than a
+  /// pointer-keyed map keeps the processing order, and hence the output,
+  /// deterministic.
+  SmallVector<FixupEntry, 8> FixupWorklist;
+  /// Instructions made dead by the rewrite. Each one is recorded after the
+  /// instruction whose value it consumes, so erasing them in reverse order
+  /// never deletes a value that still has users.
+  SmallVector<Instruction *, 16> InstrsToErase;
+
+public:
+  static char ID;
+
+  AggregateLifter() : FunctionPass(ID) {
+    initializeAggregateLifterPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+  const char *getPassName() const override { return "AggregateLifter"; }
+
+private:
+  void runOnLoad(LoadInst *LI);
+  void liftAggregate(const FixupEntry &E);
+  bool liftExtractValue(ExtractValueInst *EVI, Value *Ptr, unsigned Align);
+};
+
+} // end anonymous namespace
+
+bool AggregateLifter::runOnFunction(Function &F) {
+  if (skipOptnoneFunction(F))
+    return false;
+
+  DEBUG(dbgs() << "AggregateLifter function: " << F.getName() << "\n");
+
+  // Replace every simple aggregate load with a memcpy into an alloca. The
+  // loads are only queued for deletion, so the iterator stays valid.
+  for (Instruction &I : instructions(F))
+    if (LoadInst *LI = dyn_cast<LoadInst>(&I))
+      runOnLoad(LI);
+
+  // Rewrite the users of every lifted aggregate. liftAggregate may append
+  // entries for nested aggregates, so iterate by index and copy each entry
+  // before use (push_back can reallocate the storage).
+  for (unsigned i = 0; i != FixupWorklist.size(); ++i) {
+    FixupEntry E = FixupWorklist[i];
+    liftAggregate(E);
+  }
+  FixupWorklist.clear();
+
+  // Delete the dead instructions, users before the values they use.
+  bool Changed = !InstrsToErase.empty();
+  while (!InstrsToErase.empty())
+    InstrsToErase.pop_back_val()->eraseFromParent();
+  return Changed;
+}
+
+void AggregateLifter::runOnLoad(LoadInst *LI) {
+  // Volatile and atomic loads must be left alone.
+  if (!LI->isSimple())
+    return;
+
+  Type *T = LI->getType();
+  if (!T->isAggregateType())
+    return;
+
+  Function *F = LI->getParent()->getParent();
+  const DataLayout &DL = F->getParent()->getDataLayout();
+
+  // The load reads exactly its store size; an alloca of T is at least that
+  // large.
+  uint64_t Size = DL.getTypeStoreSize(T);
+  // An alignment of 0 on a load means the ABI alignment of the loaded type.
+  unsigned ABIAlign = DL.getABITypeAlignment(T);
+  unsigned LoadAlign = LI->getAlignment() ? LI->getAlignment() : ABIAlign;
+  // Give the alloca at least the ABI alignment. memcpy's alignment applies to
+  // both pointers, so the copy itself may only assume the load's alignment.
+  unsigned AllocaAlign = std::max(LoadAlign, ABIAlign);
+
+  IRBuilder<> AtEntry(&*F->getEntryBlock().begin());
+  AllocaInst *A =
+      AtEntry.CreateAlloca(T, nullptr, LI->getName() + ".lifted");
+  A->setAlignment(AllocaAlign);
+
+  IRBuilder<> Builder(LI);
+  Builder.CreateMemCpy(A, LI->getPointerOperand(), Size, LoadAlign);
+
+  ++NumLoadsLifted;
+  InstrsToErase.push_back(LI);
+  FixupEntry E = {LI, A, AllocaAlign};
+  FixupWorklist.push_back(E);
+}
+
+/// Replace a simple store of a lifted aggregate with a memcpy from \p Src,
+/// whose alignment is known to be \p SrcAlign. Volatile and atomic stores are
+/// left untouched and reported as not lifted.
+static bool liftStore(StoreInst *SI, Value *Src, unsigned SrcAlign) {
+  if (!SI->isSimple())
+    return false;
+
+  Function *F = SI->getParent()->getParent();
+  const DataLayout &DL = F->getParent()->getDataLayout();
+
+  Type *T = SI->getValueOperand()->getType();
+  uint64_t Size = DL.getTypeStoreSize(T);
+  // An alignment of 0 on a store means the ABI alignment of the stored type.
+  unsigned StoreAlign =
+      SI->getAlignment() ? SI->getAlignment() : DL.getABITypeAlignment(T);
+
+  // memcpy's alignment applies to both pointers, so it must not claim more
+  // than is known about the lifted source.
+  IRBuilder<> Builder(SI);
+  Builder.CreateMemCpy(SI->getPointerOperand(), Src, Size,
+                       std::min(StoreAlign, SrcAlign));
+  return true;
+}
+
+/// Rewrite \p EVI, an extractvalue of an aggregate held at \p Ptr (aligned to
+/// \p Align), as an inbounds GEP into that memory. Scalar results are loaded
+/// from the GEP; aggregate results are queued so their users get lifted too.
+bool AggregateLifter::liftExtractValue(ExtractValueInst *EVI, Value *Ptr,
+                                       unsigned Align) {
+  Function *F = EVI->getParent()->getParent();
+  const DataLayout &DL = F->getParent()->getDataLayout();
+  IRBuilder<> Builder(EVI);
+
+  // Build the GEP indices and, in the same walk, the byte offset of the
+  // extracted field, from which its alignment follows.
+  Type *AggTy = EVI->getAggregateOperand()->getType();
+  SmallVector<Value *, 4> IdxList;
+  IdxList.push_back(Builder.getInt32(0));
+  uint64_t Offset = 0;
+  Type *CurTy = AggTy;
+  for (unsigned Idx : EVI->getIndices()) {
+    IdxList.push_back(Builder.getInt32(Idx));
+    if (StructType *STy = dyn_cast<StructType>(CurTy)) {
+      Offset += DL.getStructLayout(STy)->getElementOffset(Idx);
+      CurTy = STy->getElementType(Idx);
+    } else {
+      // extractvalue can only index structs and arrays.
+      Type *ElemTy = cast<ArrayType>(CurTy)->getElementType();
+      Offset += Idx * DL.getTypeAllocSize(ElemTy);
+      CurTy = ElemTy;
+    }
+  }
+  unsigned FieldAlign = static_cast<unsigned>(MinAlign(Align, Offset));
+
+  Value *GEP = Builder.CreateInBoundsGEP(AggTy, Ptr, IdxList,
+                                         EVI->getName() + ".lifted");
+  if (EVI->getType()->isAggregateType()) {
+    // Nested aggregate: its own users are rewritten against the GEP later.
+    FixupEntry E = {EVI, GEP, FieldAlign};
+    FixupWorklist.push_back(E);
+    return true;
+  }
+
+  LoadInst *Load = Builder.CreateAlignedLoad(GEP, FieldAlign);
+  Load->takeName(EVI);
+  EVI->replaceAllUsesWith(Load);
+  return true;
+}
+
+void AggregateLifter::liftAggregate(const FixupEntry &E) {
+  Instruction *I = E.Aggregate;
+  assert(I->getType()->isAggregateType() && "expected an aggregate value");
+
+  DEBUG(dbgs() << "Lifting " << *I << " into " << *E.Ptr << "\n");
+
+  // Snapshot the users first: rewriting an operand below unlinks that use
+  // from I's use list, which would invalidate a live users() iterator.
+  SmallVector<User *, 8> Users(I->user_begin(), I->user_end());
+
+  LoadInst *Reload = nullptr;
+  for (User *U : Users) {
+    if (StoreInst *SI = dyn_cast<StoreInst>(U)) {
+      if (liftStore(SI, E.Ptr, E.Align)) {
+        ++NumStoresLifted;
+        InstrsToErase.push_back(SI);
+        continue;
+      }
+    }
+
+    if (ExtractValueInst *EVI = dyn_cast<ExtractValueInst>(U)) {
+      if (liftExtractValue(EVI, E.Ptr, E.Align)) {
+        ++NumExtractsLifted;
+        InstrsToErase.push_back(EVI);
+        continue;
+      }
+    }
+
+    // Any other user gets a reload of the whole aggregate, shared by all such
+    // users; SROA may still be able to split it up.
+    if (!Reload) {
+      IRBuilder<> Builder(I);
+      Reload = Builder.CreateAlignedLoad(E.Ptr, E.Align, I->getName());
+    }
+    U->replaceUsesOfWith(I, Reload);
+  }
+}
+
+void AggregateLifter::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.setPreservesCFG();
+}
+
+char AggregateLifter::ID = 0;
+
+FunctionPass *llvm::createAggregateLifterPass() {
+  return new AggregateLifter();
+}
+
+INITIALIZE_PASS(AggregateLifter, "aggregate-lifter",
+                "Lift aggregate store/load into alloca.", false, false)
Index: lib/Transforms/Scalar/CMakeLists.txt
===================================================================
--- lib/Transforms/Scalar/CMakeLists.txt
+++ lib/Transforms/Scalar/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_llvm_library(LLVMScalarOpts
   ADCE.cpp
+  AggregateLifter.cpp
   AlignmentFromAssumptions.cpp
   BDCE.cpp
   ConstantHoisting.cpp
Index: lib/Transforms/Scalar/Scalar.cpp
===================================================================
--- lib/Transforms/Scalar/Scalar.cpp
+++ lib/Transforms/Scalar/Scalar.cpp
@@ -84,6 +84,7 @@
   initializePlaceSafepointsPass(Registry);
   initializeFloat2IntPass(Registry);
   initializeLoopDistributePass(Registry);
+  initializeAggregateLifterPass(Registry);
 }
 
 void LLVMInitializeScalarOpts(LLVMPassRegistryRef R) {
Index: test/Transforms/AggregateLifter/extractvalue.ll
===================================================================
--- /dev/null
+++ test/Transforms/AggregateLifter/extractvalue.ll
@@ -0,0 +1,42 @@
+; RUN: opt -aggregate-lifter -S < %s | FileCheck %s
+; RUN: opt -O3 -S < %s | FileCheck %s -check-prefix=CHECK-OPT
+
+target datalayout = "e-i64:64-f80:128-n8:16:32:64"
+target triple = "x86_64-unknown-linux-gnu"
+
+%A = type { i8*, i64 }
+%B = type { %A, i64 }
+
+@a0 = global %A zeroinitializer
+@a1 = global %A zeroinitializer
+@b0 = global %B zeroinitializer
+
+; CHECK-LABEL: @return(
+; CHECK-OPT-LABEL: @return(
+define i64 @return() {
+  %1 = load %A, %A* @a0
+  %2 = extractvalue %A %1, 1
+  ret i64 %2
+; CHECK: alloca %A
+; CHECK: call void @llvm.memcpy.p0i8.p0i8.i64
+; CHECK-NEXT: getelementptr inbounds %A, %A* %.lifted, i32 0, i32 1
+; CHECK-NEXT: load i64, i64*
+; CHECK-NEXT: ret i64
+; CHECK-OPT-NEXT: load i64, i64* getelementptr inbounds
+; CHECK-OPT-NEXT: ret i64
+}
+
+; CHECK-LABEL: @nested(
+; CHECK-OPT-LABEL: @nested(
+define i64 @nested() {
+  %1 = load %B, %B* @b0
+  %2 = extractvalue %B %1, 0
+  %3 = extractvalue %A %2, 1
+  ret i64 %3
+; CHECK: alloca %B
+; CHECK: call void @llvm.memcpy.p0i8.p0i8.i64
+; CHECK-NEXT: getelementptr inbounds %B, %B*
+; CHECK-NEXT: getelementptr inbounds %A, %A*
+; CHECK-OPT-NEXT: load i64, i64* getelementptr inbounds
+; CHECK-OPT-NEXT: ret i64
+}
Index: test/Transforms/AggregateLifter/fallback.ll
===================================================================
--- /dev/null
+++ test/Transforms/AggregateLifter/fallback.ll
@@ -0,0 +1,26 @@
+; RUN: opt -aggregate-lifter -S < %s | FileCheck %s
+
+target datalayout = "e-i64:64-f80:128-n8:16:32:64"
+target triple = "x86_64-unknown-linux-gnu"
+
+%A = type { i8*, i64 }
+
+@a0 = global %A zeroinitializer
+
+declare void @use(%A)
+
+; Users that are neither stores nor extractvalues must all be rewritten to a
+; single reload of the lifted aggregate.
+; CHECK-LABEL: @fallback(
+define void @fallback() {
+  %1 = load %A, %A* @a0
+  call void @use(%A %1)
+  call void @use(%A %1)
+  ret void
+; CHECK: alloca %A
+; CHECK: call void @llvm.memcpy.p0i8.p0i8.i64
+; CHECK-NEXT: [[RELOAD:%[0-9]+]] = load %A, %A* %.lifted
+; CHECK-NEXT: call void @use(%A [[RELOAD]])
+; CHECK-NEXT: call void @use(%A [[RELOAD]])
+; CHECK-NEXT: ret void
+}
Index: test/Transforms/AggregateLifter/store.ll
===================================================================
--- /dev/null
+++ test/Transforms/AggregateLifter/store.ll
@@ -0,0 +1,42 @@
+; RUN: opt -aggregate-lifter -S < %s | FileCheck %s
+; RUN: opt -O3 -S < %s | FileCheck %s -check-prefix=CHECK-OPT
+
+target datalayout = "e-i64:64-f80:128-n8:16:32:64"
+target triple = "x86_64-unknown-linux-gnu"
+
+%A = type { i8*, i64 }
+%B = type { %A, i64 }
+
+@a0 = global %A zeroinitializer
+@a1 = global %A zeroinitializer
+@b0 = global %B zeroinitializer
+
+; CHECK-LABEL: @forward(
+; CHECK-OPT-LABEL: @forward(
+define void @forward() {
+  %1 = load %A, %A* @a0
+  store %A %1, %A* @a1
+  ret void
+; CHECK: alloca %A
+; CHECK: call void @llvm.memcpy.p0i8.p0i8.i64
+; CHECK: call void @llvm.memcpy.p0i8.p0i8.i64
+; CHECK-NEXT: ret void
+; CHECK-OPT-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64
+; CHECK-OPT-NEXT: ret void
+}
+
+; CHECK-LABEL: @forwardelement(
+; CHECK-OPT-LABEL: @forwardelement(
+define void @forwardelement() {
+  %1 = load %B, %B* @b0
+  %2 = extractvalue %B %1, 0
+  store %A %2, %A* @a0
+  ret void
+; CHECK: alloca %B
+; CHECK: call void @llvm.memcpy.p0i8.p0i8.i64
+; CHECK-NEXT: getelementptr inbounds %B, %B*
+; CHECK: call void @llvm.memcpy.p0i8.p0i8.i64
+; CHECK-NEXT: ret void
+; CHECK-OPT-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64
+; CHECK-OPT-NEXT: ret void
+}