Index: include/llvm/InitializePasses.h
===================================================================
--- include/llvm/InitializePasses.h
+++ include/llvm/InitializePasses.h
@@ -294,6 +294,7 @@
 void initializePlaceBackedgeSafepointsImplPass(PassRegistry&);
 void initializePlaceSafepointsPass(PassRegistry&);
 void initializeDwarfEHPreparePass(PassRegistry&);
+void initializeAggregateMemAccessRemovalPass(PassRegistry&);
 }
 
 #endif
Index: include/llvm/Transforms/Scalar.h
===================================================================
--- include/llvm/Transforms/Scalar.h
+++ include/llvm/Transforms/Scalar.h
@@ -440,6 +440,12 @@
 //
 FunctionPass *createRewriteStatepointsForGCPass();
 
+//===----------------------------------------------------------------------===//
+//
+// AggregateMemAccessRemoval - Convert aggregate load/store into
+// scalar load/store
+FunctionPass *createAggregateMemAccessRemovalPass();
+
 } // End llvm namespace
 
 #endif
Index: lib/Transforms/IPO/PassManagerBuilder.cpp
===================================================================
--- lib/Transforms/IPO/PassManagerBuilder.cpp
+++ lib/Transforms/IPO/PassManagerBuilder.cpp
@@ -186,6 +186,9 @@
   // Add LibraryInfo if we have some.
   if (LibraryInfo)
     MPM.add(new TargetLibraryInfoWrapperPass(*LibraryInfo));
+
+  // Remove aggregate load/store
+  MPM.add(createAggregateMemAccessRemovalPass());
 
   addInitialAliasAnalysisPasses(MPM);
Index: lib/Transforms/Scalar/AggregateMemAccessRemoval.cpp
===================================================================
--- /dev/null
+++ lib/Transforms/Scalar/AggregateMemAccessRemoval.cpp
@@ -0,0 +1,177 @@
+//===- AggregateMemAccessRemoval.cpp - Remove aggregate loads/stores ------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This transform removes loads and stores of aggregate type by replacing
+/// them with equivalent scalar loads and stores. This will allow subsequent
+/// passes to optimize them properly.
+///
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar.h"
+
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Pass.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "aggregate-removal"
+
+namespace {
+
+class AggregateMemAccessRemoval : public FunctionPass {
+public:
+  AggregateMemAccessRemoval() : FunctionPass(ID) {}
+
+  bool runOnFunction(Function &F) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+  const char *getPassName() const override {
+    return "AggregateMemAccessRemoval";
+  }
+
+  static char ID;
+
+private:
+  // Replace an aggregate-typed load/store with a scalar access of its single
+  // element; dead instructions are collected rather than erased in place so
+  // the caller's instruction iteration stays valid.
+  void runOnLoad(LoadInst *LI, const DataLayout *DL,
+                 SmallVectorImpl<Instruction *> &InstrsToErase);
+  void runOnStore(StoreInst *SI, const DataLayout *DL,
+                  SmallVectorImpl<Instruction *> &InstrsToErase);
+};
+}
+
+bool AggregateMemAccessRemoval::runOnFunction(Function &F) {
+  if (skipOptnoneFunction(F))
+    return false;
+
+  DEBUG(dbgs() << "AggregateMemAccessRemoval function: " << F.getName()
+               << "\n");
+
+  // Module::getDataLayout() returns a reference, so the pointer formed here
+  // can never be null; no "missing data layout" bail-out is needed.
+  const DataLayout *DL = &F.getParent()->getDataLayout();
+
+  SmallVector<Instruction *, 8> InstrsToErase;
+
+  for (auto &I : inst_range(F)) {
+    // Only unordered (non-atomic, non-volatile-ordering) accesses are safe
+    // to split into multiple scalar accesses.
+    LoadInst *LI = dyn_cast<LoadInst>(&I);
+    if (LI && LI->isUnordered()) {
+      runOnLoad(LI, DL, InstrsToErase);
+      continue;
+    }
+
+    StoreInst *SI = dyn_cast<StoreInst>(&I);
+    if (SI && SI->isUnordered()) {
+      runOnStore(SI, DL, InstrsToErase);
+      continue;
+    }
+  }
+
+  // Erasing during the instruction walk above would invalidate the iterator,
+  // so the replaced aggregate accesses are deleted in a second sweep.
+  for (Instruction *I : InstrsToErase)
+    I->eraseFromParent();
+
+  return !InstrsToErase.empty();
+}
+
+void AggregateMemAccessRemoval::runOnLoad(
+    LoadInst *LI, const DataLayout *DL,
+    SmallVectorImpl<Instruction *> &InstrsToErase) {
+  Type *T = LI->getType();
+  if (!T->isAggregateType())
+    return;
+
+  DEBUG(dbgs() << "\tload : " << *LI << "\n");
+
+  IRBuilder<> Builder(LI->getParent(), LI);
+  if (StructType *ST = dyn_cast<StructType>(T)) {
+    // Opaque structs have no known layout to unpack.
+    if (ST->isOpaque())
+      return;
+
+    // If the struct only has one element, we unpack.
+    if (ST->getNumElements() == 1) {
+      unsigned Align = LI->getAlignment();
+      Value *Addr = Builder.CreateStructGEP(LI->getPointerOperand(), 0);
+
+      // TODO: preserve metadata.
+      LoadInst *NewLI = Builder.CreateAlignedLoad(Addr, Align);
+
+      // A single-element struct must occupy exactly the storage of that
+      // element, so reusing the original alignment is sound.
+      assert(DL->getTypeStoreSize(ST) ==
+             DL->getTypeStoreSize(NewLI->getType()));
+
+      // Rebuild the aggregate value expected by existing users.
+      Value *V = UndefValue::get(T);
+      V = Builder.CreateInsertValue(V, NewLI, 0);
+      LI->replaceAllUsesWith(V);
+
+      InstrsToErase.push_back(LI);
+
+      // Recursion is limited by how nested the type structure is.
+      // We remove one level of the aggregate each time we recurse.
+      runOnLoad(NewLI, DL, InstrsToErase);
+    }
+  }
+}
+
+void AggregateMemAccessRemoval::runOnStore(
+    StoreInst *SI, const DataLayout *DL,
+    SmallVectorImpl<Instruction *> &InstrsToErase) {
+  Value *V = SI->getValueOperand();
+  Type *T = V->getType();
+  if (!T->isAggregateType())
+    return;
+
+  DEBUG(dbgs() << "\tstore : " << *SI << "\n");
+
+  IRBuilder<> Builder(SI->getParent(), SI);
+  if (StructType *ST = dyn_cast<StructType>(T)) {
+    // Opaque structs have no known layout to unpack.
+    if (ST->isOpaque())
+      return;
+
+    // If the struct only has one element, we unpack.
+    if (ST->getNumElements() == 1) {
+      unsigned Align = SI->getAlignment();
+      Value *Addr = Builder.CreateStructGEP(SI->getPointerOperand(), 0);
+      Value *NewV = Builder.CreateExtractValue(V, 0);
+
+      // A single-element struct must occupy exactly the storage of that
+      // element, so reusing the original alignment is sound.
+      assert(DL->getTypeStoreSize(ST) ==
+             DL->getTypeStoreSize(NewV->getType()));
+
+      // TODO: preserve metadata.
+      StoreInst *NewSI = Builder.CreateAlignedStore(NewV, Addr, Align);
+
+      InstrsToErase.push_back(SI);
+
+      // Recursion is limited by how nested the type structure is.
+      // We remove one level of the aggregate each time we recurse.
+      runOnStore(NewSI, DL, InstrsToErase);
+    }
+  }
+}
+
+void AggregateMemAccessRemoval::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.setPreservesCFG();
+}
+
+char AggregateMemAccessRemoval::ID = 0;
+
+FunctionPass *llvm::createAggregateMemAccessRemovalPass() {
+  return new AggregateMemAccessRemoval();
+}
+
+INITIALIZE_PASS(AggregateMemAccessRemoval, "aggregate-removal",
+                "Transform aggregate loads/stores into scalar accesses.",
+                false, false)
Index: lib/Transforms/Scalar/CMakeLists.txt
===================================================================
--- lib/Transforms/Scalar/CMakeLists.txt
+++ lib/Transforms/Scalar/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_llvm_library(LLVMScalarOpts
   ADCE.cpp
+  AggregateMemAccessRemoval.cpp
   AlignmentFromAssumptions.cpp
   BDCE.cpp
   ConstantHoisting.cpp
Index: lib/Transforms/Scalar/Scalar.cpp
===================================================================
--- lib/Transforms/Scalar/Scalar.cpp
+++ lib/Transforms/Scalar/Scalar.cpp
@@ -76,6 +76,7 @@
   initializeLoadCombinePass(Registry);
   initializePlaceBackedgeSafepointsImplPass(Registry);
   initializePlaceSafepointsPass(Registry);
+  initializeAggregateMemAccessRemovalPass(Registry);
 }
 
 void LLVMInitializeScalarOpts(LLVMPassRegistryRef R) {
Index: test/Transforms/AggregateMemAccessRemoval/load.ll
===================================================================
--- /dev/null
+++ test/Transforms/AggregateMemAccessRemoval/load.ll
@@ -0,0 +1,19 @@
+; RUN: opt -aggregate-removal -S < %s | FileCheck %s
+
+target datalayout = "e-i64:64-f80:128-n8:16:32:64"
+target triple = "x86_64-unknown-linux-gnu"
+
+%A__vtbl = type { i8*, i32 (%A*)* }
+%A = type { %A__vtbl* }
+
+declare i8* @allocmemory(i64)
+
+define %A @structA() {
+body:
+  %0 = tail call i8* @allocmemory(i64 32)
+  %1 = bitcast i8* %0 to %A*
+; CHECK: load %A__vtbl*,
+; CHECK: insertvalue %A undef, %A__vtbl* {{.*}}, 0
+  %2 = load %A, %A* %1, align 8
+  ret %A %2
+}
Index: test/Transforms/AggregateMemAccessRemoval/store.ll
===================================================================
--- /dev/null
+++ test/Transforms/AggregateMemAccessRemoval/store.ll
@@ -0,0 +1,22 @@
+; RUN: opt -aggregate-removal -S < %s | FileCheck %s
+
+target datalayout = "e-i64:64-f80:128-n8:16:32:64"
+target triple = "x86_64-unknown-linux-gnu"
+
+%A__vtbl = type { i8*, i32 (%A*)* }
+%A = type { %A__vtbl* }
+
+@A__vtblZ = constant %A__vtbl { i8* null, i32 (%A*)* @A.foo }
+
+declare i32 @A.foo(%A* nocapture %this)
+
+declare i8* @allocmemory(i64)
+
+define void @structs() {
+body:
+  %0 = tail call i8* @allocmemory(i64 32)
+  %1 = bitcast i8* %0 to %A*
+; CHECK: store %A__vtbl* @A__vtblZ
+  store %A { %A__vtbl* @A__vtblZ }, %A* %1, align 8
+  ret void
+}