Index: include/llvm/InitializePasses.h
===================================================================
--- include/llvm/InitializePasses.h
+++ include/llvm/InitializePasses.h
@@ -272,6 +272,7 @@
 void initializeBBVectorizePass(PassRegistry&);
 void initializeMachineFunctionPrinterPassPass(PassRegistry&);
 void initializeStackMapLivenessPass(PassRegistry&);
+void initializeLoadCombinePass(PassRegistry&);
 }
 
 #endif
Index: include/llvm/Transforms/Scalar.h
===================================================================
--- include/llvm/Transforms/Scalar.h
+++ include/llvm/Transforms/Scalar.h
@@ -19,6 +19,7 @@
 
 namespace llvm {
 
+class BasicBlockPass;
 class FunctionPass;
 class Pass;
 class GetElementPtrInst;
@@ -383,6 +384,12 @@
 //
 FunctionPass *createSeparateConstOffsetFromGEPPass();
 
+//===----------------------------------------------------------------------===//
+//
+// LoadCombine - Combine loads into bigger loads.
+//
+BasicBlockPass *createLoadCombinePass();
+
 } // End llvm namespace
 
 #endif
Index: lib/Transforms/IPO/PassManagerBuilder.cpp
===================================================================
--- lib/Transforms/IPO/PassManagerBuilder.cpp
+++ lib/Transforms/IPO/PassManagerBuilder.cpp
@@ -219,6 +219,8 @@
   if (SLPVectorize)
     MPM.add(createSLPVectorizerPass());   // Vectorize parallel scalar chains.
 
+  MPM.add(createLoadCombinePass());
+
   if (BBVectorize) {
     MPM.add(createBBVectorizePass());
     MPM.add(createInstructionCombiningPass());
@@ -344,6 +346,8 @@
   // More scalar chains could be vectorized due to more alias information
   PM.add(createSLPVectorizerPass()); // Vectorize parallel scalar chains.
 
+  PM.add(createLoadCombinePass());
+
   // Cleanup and simplify the code after the scalar optimizations.
   PM.add(createInstructionCombiningPass());
Index: lib/Transforms/Scalar/CMakeLists.txt
===================================================================
--- lib/Transforms/Scalar/CMakeLists.txt
+++ lib/Transforms/Scalar/CMakeLists.txt
@@ -12,6 +12,7 @@
   IndVarSimplify.cpp
   JumpThreading.cpp
   LICM.cpp
+  LoadCombine.cpp
   LoopDeletion.cpp
   LoopIdiomRecognize.cpp
   LoopInstSimplify.cpp
Index: lib/Transforms/Scalar/LoadCombine.cpp
===================================================================
--- /dev/null
+++ lib/Transforms/Scalar/LoadCombine.cpp
@@ -0,0 +1,272 @@
+//===- LoadCombine.cpp - Combine Adjacent Loads ---------------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This transformation combines adjacent loads.
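+///
+/// For example (an illustrative sketch of the target idiom, mirroring the
+/// tests added below; the value names are only for exposition), two adjacent
+/// i16 loads such as
+///
+///   %lo = load i16* %p
+///   %p1 = getelementptr inbounds i16* %p, i64 1
+///   %hi = load i16* %p1
+///
+/// whose results are zero-extended, shifted, and or'd together into an i32
+/// can be replaced by a single i32 load plus extraction of the two halves.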
+///
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar.h"
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/TargetFolder.h"
+#include "llvm/Pass.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "load-combine"
+
+STATISTIC(NumLoadsAnalyzed, "Number of loads analyzed for combining");
+STATISTIC(NumLoadsCombined, "Number of loads combined");
+
+namespace {
+struct PointerOffsetPair {
+  const Value *Pointer;
+  uint64_t Offset;
+};
+
+struct LoadPOPPair {
+  LoadPOPPair(LoadInst *L, PointerOffsetPair P) : Load(L), POP(P) {}
+  LoadInst *Load;
+  PointerOffsetPair POP;
+};
+
+class LoadCombine : public BasicBlockPass {
+  LLVMContext *C;
+  const DataLayout *DL;
+
+public:
+  LoadCombine()
+      : BasicBlockPass(ID),
+        C(nullptr), DL(nullptr) {
+    initializeLoadCombinePass(*PassRegistry::getPassRegistry());
+  }
+  bool doInitialization(Function &) override;
+  bool runOnBasicBlock(BasicBlock &BB) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+  const char *getPassName() const override { return "LoadCombine"; }
+  static char ID;
+
+  typedef IRBuilder<true, TargetFolder> BuilderTy;
+
+private:
+  BuilderTy *Builder;
+
+  PointerOffsetPair getPointerOffsetPair(const LoadInst &);
+  bool combineLoads(DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> &);
+  bool aggregateLoads(SmallVectorImpl<LoadPOPPair> &);
+  bool combineLoads(SmallVectorImpl<LoadPOPPair> &);
+};
+}
+
+bool LoadCombine::doInitialization(Function &F) {
+  DEBUG(dbgs() << "LoadCombine function: " << F.getName() << "\n");
+  C = &F.getContext();
+  DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>();
+  if (!DLP) {
+    DEBUG(dbgs() << "  Skipping LoadCombine -- no target data!\n");
+    return false;
+  }
+  DL = &DLP->getDataLayout();
+  return true;
+}
+
+PointerOffsetPair LoadCombine::getPointerOffsetPair(const LoadInst &LI) {
+  PointerOffsetPair POP;
+  if (auto *GEP = dyn_cast<GetElementPtrInst>(LI.getPointerOperand())) {
+    POP.Pointer = GEP->getPointerOperand();
+    unsigned BitWidth = DL->getPointerTypeSizeInBits(GEP->getType());
+    APInt Offset(BitWidth, 0);
+    if (GEP->accumulateConstantOffset(*DL, Offset)) {
+      POP.Offset = Offset.getZExtValue();
+      return POP;
+    }
+    // Can't handle GEPs with variable indices.
+    POP.Pointer = nullptr;
+    return POP;
+  }
+  POP.Pointer = LI.getPointerOperand();
+  POP.Offset = 0;
+  return POP;
+}
+
+bool LoadCombine::combineLoads(
+    DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> &LoadMap) {
+  bool Combined = false;
+  for (auto &Loads : LoadMap) {
+    if (Loads.second.size() < 2)
+      continue;
+    std::sort(Loads.second.begin(), Loads.second.end(),
+              [](const LoadPOPPair &A, const LoadPOPPair &B) {
+                return A.POP.Offset < B.POP.Offset;
+              });
+    if (aggregateLoads(Loads.second))
+      Combined = true;
+  }
+  return Combined;
+}
+
+/// \brief Try to aggregate loads from a sorted list of loads to be combined.
+///
+/// It is guaranteed that no writes occur between any of the loads. All loads
+/// have the same base pointer. There are at least two loads.
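+///
+/// Loads whose offset leaves a gap after the preceding load start a new
+/// candidate group, and overlapping loads are currently skipped (see the
+/// FIXME below).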
+bool LoadCombine::aggregateLoads(SmallVectorImpl<LoadPOPPair> &Loads) {
+  assert(Loads.size() >= 2 && "Insufficient loads!");
+  LoadInst *BaseLoad = nullptr;
+  SmallVector<LoadPOPPair, 8> AggregateLoads;
+  bool Combined = false;
+  uint64_t PrevOffset = -1ull;
+  uint64_t PrevSize = 0;
+  for (auto &L : Loads) {
+    if (PrevOffset == -1ull) {
+      BaseLoad = L.Load;
+      PrevOffset = L.POP.Offset;
+      PrevSize = L.Load->getType()->getPrimitiveSizeInBits() / 8;
+      AggregateLoads.push_back(L);
+      continue;
+    }
+    if (L.Load->getAlignment() != BaseLoad->getAlignment())
+      continue;
+    if (L.POP.Offset > PrevOffset + PrevSize) {
+      // No other load will be combinable with the current group.
+      if (combineLoads(AggregateLoads))
+        Combined = true;
+      AggregateLoads.clear();
+      PrevOffset = -1;
+      continue;
+    }
+    if (L.POP.Offset != PrevOffset + PrevSize)
+      // This load overlaps the previous load: its offset falls inside the
+      // bytes already covered.
+      // FIXME: We may want to handle this case.
+      continue;
+    PrevOffset = L.POP.Offset;
+    PrevSize = L.Load->getType()->getPrimitiveSizeInBits() / 8;
+    AggregateLoads.push_back(L);
+  }
+  if (combineLoads(AggregateLoads))
+    Combined = true;
+  return Combined;
+}
+
+static Value *extractInteger(const DataLayout &DL, LoadCombine::BuilderTy &IRB,
+                             Value *V, IntegerType *Ty, uint64_t Offset,
+                             const Twine &Name) {
+  DEBUG(dbgs() << "       start: " << *V << "\n");
+  IntegerType *IntTy = cast<IntegerType>(V->getType());
+  assert(DL.getTypeStoreSize(Ty) + Offset <= DL.getTypeStoreSize(IntTy) &&
+         "Element extends past full value");
+  uint64_t ShAmt = 8 * Offset;
+  if (DL.isBigEndian())
+    ShAmt = 8 * (DL.getTypeStoreSize(IntTy) - DL.getTypeStoreSize(Ty) - Offset);
+  if (ShAmt) {
+    V = IRB.CreateLShr(V, ShAmt, Name + ".shift");
+    DEBUG(dbgs() << "     shifted: " << *V << "\n");
+  }
+  assert(Ty->getBitWidth() <= IntTy->getBitWidth() &&
+         "Cannot extract to a larger integer!");
+  if (Ty != IntTy) {
+    V = IRB.CreateTrunc(V, Ty, Name + ".trunc");
+    DEBUG(dbgs() << "     trunced: " << *V << "\n");
+  }
+  return V;
+}
+
+/// \brief Given a list of combinable loads, combine the maximum number of them.
+bool LoadCombine::combineLoads(SmallVectorImpl<LoadPOPPair> &Loads) {
+  // Remove loads from the end while the total size is not a power of 2.
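+  // Illustrative example: three i8 loads total 24 bits, which is not a power
+  // of two, so the last load is dropped and the remaining two loads (16 bits)
+  // are combined.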
+  unsigned TotalSize = 0;
+  for (const auto &L : Loads)
+    TotalSize += L.Load->getType()->getPrimitiveSizeInBits();
+  while (TotalSize != 0 && !isPowerOf2_32(TotalSize))
+    TotalSize -= Loads.pop_back_val().Load->getType()->getPrimitiveSizeInBits();
+  if (Loads.size() < 2)
+    return false;
+
+  DEBUG({
+    dbgs() << "***** Combining Loads ******\n";
+    for (const auto &L : Loads) {
+      dbgs() << L.POP.Offset << ": " << *L.Load << "\n";
+    }
+  });
+
+  Value *Ptr = Loads[0].Load->getPointerOperand();
+  LoadInst *NewLoad = new LoadInst(
+      CastInst::CreatePointerCast(
+          Ptr, PointerType::get(IntegerType::get(Ptr->getContext(), TotalSize),
+                                Ptr->getType()->getPointerAddressSpace()),
+          "", Loads[0].Load),
+      Twine(Loads[0].Load->getName()) + ".combined", false,
+      Loads[0].Load->getAlignment(), Loads[0].Load);
+
+  for (const auto &L : Loads) {
+    Builder->SetInsertPoint(L.Load);
+    Value *V = extractInteger(
+        *DL, *Builder, NewLoad, cast<IntegerType>(L.Load->getType()),
+        L.POP.Offset - Loads[0].POP.Offset, "combine.extract");
+    L.Load->replaceAllUsesWith(V);
+  }
+
+  NumLoadsCombined = NumLoadsCombined + Loads.size();
+  return true;
+}
+
+bool LoadCombine::runOnBasicBlock(BasicBlock &BB) {
+  if (skipOptnoneFunction(BB) || !DL)
+    return false;
+
+  IRBuilder<true, TargetFolder>
+      TheBuilder(BB.getContext(), TargetFolder(DL));
+  Builder = &TheBuilder;
+
+  DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> LoadMap;
+
+  bool Combined = false;
+  for (auto &I : BB) {
+    if (I.mayWriteToMemory()) {
+      if (combineLoads(LoadMap))
+        Combined = true;
+      LoadMap.clear();
+      continue;
+    }
+    LoadInst *LI = dyn_cast<LoadInst>(&I);
+    if (!LI)
+      continue;
+    ++NumLoadsAnalyzed;
+    if (!LI->isSimple() || !LI->getType()->isIntegerTy())
+      continue;
+    auto POP = getPointerOffsetPair(*LI);
+    if (!POP.Pointer)
+      continue;
+    LoadMap[POP.Pointer].push_back(LoadPOPPair(LI, POP));
+  }
+  if (combineLoads(LoadMap))
+    Combined = true;
+  return Combined;
+}
+
+void LoadCombine::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.setPreservesCFG();
+}
+
+char LoadCombine::ID = 0;
+
+BasicBlockPass *llvm::createLoadCombinePass() {
+  return new LoadCombine();
+}
+
+INITIALIZE_PASS(LoadCombine, "load-combine", "Combine Adjacent Loads", false,
+                false)
Index: lib/Transforms/Scalar/Scalar.cpp
===================================================================
--- lib/Transforms/Scalar/Scalar.cpp
+++ lib/Transforms/Scalar/Scalar.cpp
@@ -65,6 +65,7 @@
   initializeSinkingPass(Registry);
   initializeTailCallElimPass(Registry);
   initializeSeparateConstOffsetFromGEPPass(Registry);
+  initializeLoadCombinePass(Registry);
 }
 
 void LLVMInitializeScalarOpts(LLVMPassRegistryRef R) {
Index: test/Transforms/LoadCombine/load-combine.ll
===================================================================
--- /dev/null
+++ test/Transforms/LoadCombine/load-combine.ll
@@ -0,0 +1,161 @@
+; RUN: opt < %s -load-combine -instcombine -S | FileCheck %s
+
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+; Combine read from char* idiom.
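+; The eight single-byte loads below should be combined into a single i64 load.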
+define i64 @LoadU64_x64_0(i64* %pData) {
+  %1 = bitcast i64* %pData to i8*
+  %2 = load i8* %1, align 1
+  %3 = zext i8 %2 to i64
+  %4 = shl nuw i64 %3, 56
+  %5 = getelementptr inbounds i8* %1, i64 1
+  %6 = load i8* %5, align 1
+  %7 = zext i8 %6 to i64
+  %8 = shl nuw nsw i64 %7, 48
+  %9 = or i64 %8, %4
+  %10 = getelementptr inbounds i8* %1, i64 2
+  %11 = load i8* %10, align 1
+  %12 = zext i8 %11 to i64
+  %13 = shl nuw nsw i64 %12, 40
+  %14 = or i64 %9, %13
+  %15 = getelementptr inbounds i8* %1, i64 3
+  %16 = load i8* %15, align 1
+  %17 = zext i8 %16 to i64
+  %18 = shl nuw nsw i64 %17, 32
+  %19 = or i64 %14, %18
+  %20 = getelementptr inbounds i8* %1, i64 4
+  %21 = load i8* %20, align 1
+  %22 = zext i8 %21 to i64
+  %23 = shl nuw nsw i64 %22, 24
+  %24 = or i64 %19, %23
+  %25 = getelementptr inbounds i8* %1, i64 5
+  %26 = load i8* %25, align 1
+  %27 = zext i8 %26 to i64
+  %28 = shl nuw nsw i64 %27, 16
+  %29 = or i64 %24, %28
+  %30 = getelementptr inbounds i8* %1, i64 6
+  %31 = load i8* %30, align 1
+  %32 = zext i8 %31 to i64
+  %33 = shl nuw nsw i64 %32, 8
+  %34 = or i64 %29, %33
+  %35 = getelementptr inbounds i8* %1, i64 7
+  %36 = load i8* %35, align 1
+  %37 = zext i8 %36 to i64
+  %38 = or i64 %34, %37
+  ret i64 %38
+; CHECK: load i64*
+; CHECK-NOT: load
+}
+
+; Combine simple adjacent loads.
+define i32 @"2xi16_i32"(i16* %x) {
+  %1 = load i16* %x, align 2
+  %2 = getelementptr inbounds i16* %x, i64 1
+  %3 = load i16* %2, align 2
+  %4 = zext i16 %3 to i32
+  %5 = shl nuw i32 %4, 16
+  %6 = zext i16 %1 to i32
+  %7 = or i32 %5, %6
+  ret i32 %7
+; CHECK: load i32*
+; CHECK-NOT: load
+}
+
+; Don't combine loads across stores.
+define i32 @"2xi16_i32_store"(i16* %x, i16* %y) {
+  %1 = load i16* %x, align 2
+  store i16 0, i16* %y, align 2
+  %2 = getelementptr inbounds i16* %x, i64 1
+  %3 = load i16* %2, align 2
+  %4 = zext i16 %3 to i32
+  %5 = shl nuw i32 %4, 16
+  %6 = zext i16 %1 to i32
+  %7 = or i32 %5, %6
+  ret i32 %7
+; CHECK: load i16*
+; CHECK: store
+; CHECK: load i16*
+}
+
+; Don't combine loads with a gap.
+define i32 @"2xi16_i32_gap"(i16* %x) {
+  %1 = load i16* %x, align 2
+  %2 = getelementptr inbounds i16* %x, i64 2
+  %3 = load i16* %2, align 2
+  %4 = zext i16 %3 to i32
+  %5 = shl nuw i32 %4, 16
+  %6 = zext i16 %1 to i32
+  %7 = or i32 %5, %6
+  ret i32 %7
+; CHECK: load i16*
+; CHECK: load i16*
}
+
+; Combine out of order loads.
+define i32 @"2xi16_i32_order"(i16* %x) {
+  %1 = getelementptr inbounds i16* %x, i64 1
+  %2 = load i16* %1, align 2
+  %3 = load i16* %x, align 2
+  %4 = zext i16 %2 to i32
+  %5 = shl nuw i32 %4, 16
+  %6 = zext i16 %3 to i32
+  %7 = or i32 %5, %6
+  ret i32 %7
+; CHECK: load i32*
+; CHECK-NOT: load
+}
+
+; Don't combine overlapping loads.
+define i32 @"2xi16_i32_overlap"(i8* %x) {
+  %1 = bitcast i8* %x to i16*
+  %2 = load i16* %1, align 2
+  %3 = getelementptr inbounds i8* %x, i64 1
+  %4 = bitcast i8* %3 to i16*
+  %5 = load i16* %4, align 2
+  %6 = zext i16 %5 to i32
+  %7 = shl nuw i32 %6, 16
+  %8 = zext i16 %2 to i32
+  %9 = or i32 %7, %8
+  ret i32 %9
+; CHECK: load i16*
+; CHECK: load i16*
+}
+
+; Non power of two.
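+; Only a power-of-two-sized prefix is combined: the first four i8 loads
+; (32 bits) should become one i32 load, while the remaining three loads are
+; left untouched.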
+define i64 @"2xi16_i32_npo2"(i8* %x) {
+  %1 = load i8* %x, align 1
+  %2 = zext i8 %1 to i64
+  %3 = getelementptr inbounds i8* %x, i64 1
+  %4 = load i8* %3, align 1
+  %5 = zext i8 %4 to i64
+  %6 = shl nuw nsw i64 %5, 8
+  %7 = or i64 %6, %2
+  %8 = getelementptr inbounds i8* %x, i64 2
+  %9 = load i8* %8, align 1
+  %10 = zext i8 %9 to i64
+  %11 = shl nuw nsw i64 %10, 16
+  %12 = or i64 %11, %7
+  %13 = getelementptr inbounds i8* %x, i64 3
+  %14 = load i8* %13, align 1
+  %15 = zext i8 %14 to i64
+  %16 = shl nuw nsw i64 %15, 24
+  %17 = or i64 %16, %12
+  %18 = getelementptr inbounds i8* %x, i64 4
+  %19 = load i8* %18, align 1
+  %20 = zext i8 %19 to i64
+  %21 = shl nuw nsw i64 %20, 32
+  %22 = or i64 %21, %17
+  %23 = getelementptr inbounds i8* %x, i64 5
+  %24 = load i8* %23, align 1
+  %25 = zext i8 %24 to i64
+  %26 = shl nuw nsw i64 %25, 40
+  %27 = or i64 %26, %22
+  %28 = getelementptr inbounds i8* %x, i64 6
+  %29 = load i8* %28, align 1
+  %30 = zext i8 %29 to i64
+  %31 = shl nuw nsw i64 %30, 48
+  %32 = or i64 %31, %27
+  ret i64 %32
+; CHECK: load i32*
+}