Index: lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -12,14 +12,17 @@
 //===----------------------------------------------------------------------===//
 
 #include "InstCombine.h"
+#include "llvm/ADT/Optional.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/Loads.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
 using namespace llvm;
+using namespace llvm::PatternMatch;
 
 #define DEBUG_TYPE "instcombine"
 
@@ -345,6 +348,129 @@
   return NewLoad;
 }
 
+/// \brief Combine integer loads to vector loads when the integer's bits are
+/// just a concatenation of non-integer (and non-vector) types.
+///
+/// This specifically matches the pattern of loading an integer, right-shifting,
+/// truncating, and casting it to a non-integer type. When the shift is an exact
+/// multiple of the result non-integer type's size, this is more naturally
+/// expressed as a load of a vector and an extractelement. This shows up largely
+/// because large integers are sometimes used to represent a "generic" load or
+/// store, and only later optimization may uncover that there is a more natural
+/// type to represent the load with.
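+///
+/// For example (illustrative IR only, not taken from the test suite; the
+/// element index chosen depends on the target's endianness):
+///   %wide    = load i64* %p
+///   %shifted = lshr i64 %wide, 32
+///   %narrow  = trunc i64 %shifted to i32
+///   %elt     = bitcast i32 %narrow to float
+/// is rewritten to load a <2 x float> from the same pointer and extract the
+/// corresponding element.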
+static Instruction *combineIntegerLoadToVectorLoad(InstCombiner &IC,
+                                                   LoadInst &LI) {
+  // FIXME: This is probably a reasonable transform to make for atomic loads.
+  assert(LI.isSimple() && "Do not call for non-simple loads!");
+
+  const DataLayout &DL = *IC.getDataLayout();
+  unsigned BaseBits = LI.getType()->getIntegerBitWidth();
+  Type *ElementTy = nullptr;
+  int ElementSize;
+
+  // We match any number of element extractions from the loaded integer. Each of
+  // these should be RAUW'ed with an actual extract element instruction at the
+  // given index of a loaded vector.
+  struct ExtractedElement {
+    Instruction *Element;
+    int Index;
+  };
+  SmallVector<ExtractedElement, 2> Elements;
+
+  // Lambda to match the bit cast in the extracted element (which is the root
+  // pattern matched). Accepts the instruction and shifted bits, returns false
+  // if at any point we failed to match a suitable bitcast for element
+  // extraction.
+  auto MatchBitCast = [&](Instruction *I, unsigned ShiftBits) {
+    // The truncate must be bitcasted to some element type.
+    auto *BC = dyn_cast<BitCastInst>(I);
+    if (!BC)
+      return false;
+
+    // Check that the element type matches.
+    if (!ElementTy) {
+      ElementTy = BC->getType();
+      // We don't handle integers, sub-vectors, or any aggregate types.
+      if (!ElementTy->isSingleValueType() || ElementTy->isIntegerTy() ||
+          ElementTy->isVectorTy())
+        return false;
+
+      ElementSize = DL.getTypeSizeInBits(ElementTy);
+      // The base integer size and the shift need to be multiples of the element
+      // size in bits.
+      if (BaseBits % ElementSize || ShiftBits % ElementSize)
+        return false;
+    } else if (ElementTy != BC->getType()) {
+      return false;
+    }
+
+    // Compute the vector index and store the element with it.
+    int Index = (DL.isLittleEndian() ? BaseBits - ElementSize - ShiftBits
+                                     : ShiftBits) /
+                ElementSize;
+    ExtractedElement E = {BC, Index};
+    Elements.push_back(std::move(E));
+    return true;
+  };
+
+  // Lambda to match the truncate in the extracted element. Accepts the
+  // instruction and shifted bits. Returns false if at any point we failed to
+  // match a suitable truncate for element extraction.
+  auto MatchTruncate = [&](Instruction *I, unsigned ShiftBits) {
+    // Handle the truncate to the bit size of the element.
+    auto *T = dyn_cast<TruncInst>(I);
+    if (!T)
+      return false;
+
+    // Walk all the users of the truncate, which must all be bitcasts.
+    for (User *TU : T->users())
+      if (!MatchBitCast(cast<Instruction>(TU), ShiftBits))
+        return false;
+    return true;
+  };
+
+  for (User *U : LI.users()) {
+    Instruction *I = cast<Instruction>(U);
+
+    // Strip off a logical shift right and retain the shifted amount.
+    ConstantInt *ShiftC;
+    if (!match(I, m_LShr(m_Value(), m_ConstantInt(ShiftC)))) {
+      // This must be a direct truncate.
+      if (!MatchTruncate(I, 0))
+        return nullptr;
+      continue;
+    }
+
+    unsigned ShiftBits = ShiftC->getLimitedValue(BaseBits);
+    // We can't handle shifts of more than the number of bits in the integer.
+    if (ShiftBits == BaseBits)
+      return nullptr;
+
+    // Match all the element extraction users of the shift.
+    for (User *IU : I->users())
+      if (!MatchTruncate(cast<Instruction>(IU), ShiftBits))
+        return nullptr;
+  }
+
+  // If we didn't find any extracted elements, there is nothing to do here.
+  if (Elements.empty())
+    return nullptr;
+
+  // Create a vector load and rewrite all of the elements extracted as
+  // extractelement instructions.
+  VectorType *VTy = VectorType::get(ElementTy, BaseBits / ElementSize);
+  LoadInst *NewLI = combineLoadToNewType(IC, LI, VTy);
+
+  for (const auto &E : Elements) {
+    IC.Builder->SetInsertPoint(E.Element);
+    E.Element->replaceAllUsesWith(
+        IC.Builder->CreateExtractElement(NewLI, IC.Builder->getInt32(E.Index)));
+    IC.EraseInstFromFunction(*E.Element);
+  }
+
+  // Return the original load to indicate it has been combined away.
+  return &LI;
+}
+
 /// \brief Combine loads to match the type of value their uses after looking
 /// through intervening bitcasts.
 ///
@@ -373,6 +499,8 @@
 
   // Fold away bit casts of the loaded value by loading the desired type.
+  // FIXME: We should also canonicalize loads of vectors when their elements are
+  // cast to other types.
   if (LI.hasOneUse())
     if (auto *BC = dyn_cast<BitCastInst>(LI.user_back())) {
       LoadInst *NewLoad = combineLoadToNewType(IC, LI, BC->getDestTy());
@@ -381,8 +509,12 @@
       return &LI;
     }
 
-  // FIXME: We should also canonicalize loads of vectors when their elements are
-  // cast to other types.
+  // Try to combine integer loads into vector loads when the integer is just
+  // loading a bag of bits that are casted into vector element chunks.
+  if (LI.getType()->isIntegerTy())
+    if (Instruction *R = combineIntegerLoadToVectorLoad(IC, LI))
+      return R;
+
   return nullptr;
 }
 
@@ -491,6 +623,187 @@
   return nullptr;
 }
 
+/// \brief Helper to combine a store to use a new value.
+///
+/// This just does the work of combining a store to use a new value, potentially
+/// of a different type. It handles metadata, etc., and returns the new
+/// instruction. The new value is stored to a bitcast of the pointer argument to
+/// the original store.
+///
+/// Note that this will create the instructions with whatever insert point the
+/// \c InstCombiner currently is using.
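+///
+/// For example (illustrative only), rewriting a store of an i64 to use a
+/// <2 x float> value produces IR along the lines of:
+///   %cast = bitcast i64* %ptr to <2 x float>*
+///   store <2 x float> %v, <2 x float>* %cast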
+static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &OldSI,
+                                         Value *V) {
+  Value *Ptr = OldSI.getPointerOperand();
+  unsigned AS = OldSI.getPointerAddressSpace();
+  SmallVector<std::pair<unsigned, MDNode *>, 8> MD;
+  OldSI.getAllMetadata(MD);
+
+  StoreInst *NewSI = IC.Builder->CreateAlignedStore(
+      V, IC.Builder->CreateBitCast(Ptr, V->getType()->getPointerTo(AS)),
+      OldSI.getAlignment());
+  for (const auto &MDPair : MD) {
+    unsigned ID = MDPair.first;
+    MDNode *N = MDPair.second;
+    // Note, essentially every kind of metadata should be preserved here! This
+    // routine is supposed to clone a store instruction changing *only its
+    // type*. The only metadata it makes sense to drop is metadata which is
+    // invalidated when the pointer type changes. This should essentially
+    // never be the case in LLVM, but we explicitly switch over only known
+    // metadata to be conservatively correct. If you are adding metadata to
+    // LLVM which pertains to stores, you almost certainly want to add it
+    // here.
+    switch (ID) {
+    case LLVMContext::MD_dbg:
+    case LLVMContext::MD_tbaa:
+    case LLVMContext::MD_prof:
+    case LLVMContext::MD_fpmath:
+    case LLVMContext::MD_tbaa_struct:
+    case LLVMContext::MD_alias_scope:
+    case LLVMContext::MD_noalias:
+    case LLVMContext::MD_nontemporal:
+    case LLVMContext::MD_mem_parallel_loop_access:
+    case LLVMContext::MD_nonnull:
+      // All of these directly apply.
+      NewSI->setMetadata(ID, N);
+      break;
+
+    case LLVMContext::MD_invariant_load:
+    case LLVMContext::MD_range:
+      break;
+    }
+  }
+  return NewSI;
+}
+
+
+/// \brief Combine integer stores to vector stores when the integer's bits are
+/// just a concatenation of non-integer (and non-vector) types.
+///
+/// This specifically matches the pattern of taking a sequence of non-integer
+/// types, casting them to integers, extending, shifting, and or-ing them
+/// together to make a concatenation, and then storing the result. This shows up
+/// because large integers are sometimes used to represent a "generic" load or
+/// store, and only later optimization may uncover that there is a more natural
+/// type to represent the store with.
+///
+/// \returns true if the store was successfully combined away. This indicates
+/// the caller must erase the store instruction. We have to let the caller erase
+/// the store instruction as otherwise there is no way to signal whether it was
+/// combined or not: IC.EraseInstFromFunction returns a null pointer.
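+///
+/// For example (illustrative IR only, not taken from the test suite; the shift
+/// amounts depend on the element size and the target's endianness):
+///   %lo.bits = bitcast float %lo to i32
+///   %lo.ext  = zext i32 %lo.bits to i64
+///   %hi.bits = bitcast float %hi to i32
+///   %hi.ext  = zext i32 %hi.bits to i64
+///   %hi.shl  = shl i64 %hi.ext, 32
+///   %packed  = or i64 %hi.shl, %lo.ext
+///   store i64 %packed, i64* %p
+/// is rewritten to build a <2 x float> with insertelement instructions and
+/// store that vector through a bitcast of the original pointer.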
+static bool combineIntegerStoreToVectorStore(InstCombiner &IC, StoreInst &SI) {
+  // FIXME: This is probably a reasonable transform to make for atomic stores.
+  assert(SI.isSimple() && "Do not call for non-simple stores!");
+
+  Instruction *OrigV = dyn_cast<Instruction>(SI.getValueOperand());
+  if (!OrigV)
+    return false;
+
+  // We only handle values which are used entirely to store to memory. If the
+  // value is used directly as an SSA value, then even if there are matching
+  // element insertion and element extraction, we rely on basic integer
+  // combining to forward the bits and delete the intermediate math. Here we
+  // just need to clean up the places where it actually reaches memory.
+  SmallVector<StoreInst *, 2> Stores;
+  for (User *U : OrigV->users())
+    if (auto *SIU = dyn_cast<StoreInst>(U))
+      Stores.push_back(SIU);
+    else
+      return false;
+
+  const DataLayout &DL = *IC.getDataLayout();
+  unsigned BaseBits = OrigV->getType()->getIntegerBitWidth();
+  Type *ElementTy = nullptr;
+  int ElementSize;
+
+  // We need to match some number of element insertions into an integer. Each
+  // insertion takes the form of an element value (and type), an insertion index
+  // (a multiple of the element type's bit width), and the base it was inserted
+  // into.
+  struct InsertedElement {
+    Value *Base;
+    Value *Element;
+    int Index;
+  };
+  auto MatchInsertedElement = [&](Value *V) -> Optional<InsertedElement> {
+    // Handle a null input to make it easy to loop over bases.
+    if (!V)
+      return Optional<InsertedElement>();
+
+    assert(!V->getType()->isVectorTy() && "Must not be a vector.");
+    assert(V->getType()->isIntegerTy() && "Must be an integer value.");
+
+    Value *Base = nullptr, *Element;
+    ConstantInt *ShiftC = nullptr;
+    auto InsertPattern = m_CombineOr(
+        m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(Element))))),
+              m_ConstantInt(ShiftC)),
+        m_ZExt(m_OneUse(m_BitCast(m_Value(Element)))));
+    if (!match(V,
+               m_CombineOr(
+                   m_CombineOr(
+                       m_Or(m_OneUse(m_Value(Base)), m_OneUse(InsertPattern)),
+                       m_Or(m_OneUse(InsertPattern), m_OneUse(m_Value(Base)))),
+                   InsertPattern)))
+      return Optional<InsertedElement>();
+
+    // We can't handle shifts wider than the number of bits in the integer.
+    unsigned ShiftBits = ShiftC ? ShiftC->getLimitedValue(BaseBits) : 0;
+    if (ShiftBits == BaseBits)
+      return Optional<InsertedElement>();
+
+    // All of the elements inserted need to be the same type. Either capture the
+    // first element type or check this element type against the previous
+    // element types.
+    if (!ElementTy) {
+      ElementTy = Element->getType();
+      // The base integer size and the shift need to be multiples of the element
+      // size in bits.
+      ElementSize = DL.getTypeSizeInBits(ElementTy);
+      if (BaseBits % ElementSize || ShiftBits % ElementSize)
+        return Optional<InsertedElement>();
+    } else if (ElementTy != Element->getType()) {
+      return Optional<InsertedElement>();
+    }
+
+    // We don't handle integers, sub-vectors, or any aggregate types.
+    if (!ElementTy->isSingleValueType() || ElementTy->isIntegerTy() ||
+        ElementTy->isVectorTy())
+      return Optional<InsertedElement>();
+
+    int Index = (DL.isLittleEndian() ? BaseBits - ElementSize - ShiftBits
+                                     : ShiftBits) /
+                ElementSize;
+    InsertedElement Result = {Base, Element, Index};
+    return Result;
+  };
+
+  SmallVector<InsertedElement, 2> Elements;
+  Value *V = OrigV;
+  while (Optional<InsertedElement> E = MatchInsertedElement(V)) {
+    V = E->Base;
+    Elements.push_back(std::move(*E));
+  }
+  // If searching for elements found none, or didn't terminate in either an
+  // undef or a direct zext, we can't form a vector.
+  if (Elements.empty() || (V && !isa<UndefValue>(V)))
+    return false;
+
+  // Build a storable vector by looping over the inserted elements.
+  VectorType *VTy = VectorType::get(ElementTy, BaseBits / ElementSize);
+  V = UndefValue::get(VTy);
+  IC.Builder->SetInsertPoint(OrigV);
+  for (const auto &E : Elements)
+    V = IC.Builder->CreateInsertElement(V, E.Element,
+                                        IC.Builder->getInt32(E.Index));
+
+  for (StoreInst *OldSI : Stores) {
+    IC.Builder->SetInsertPoint(OldSI);
+    combineStoreToNewValue(IC, *OldSI, V);
+    if (OldSI != &SI)
+      IC.EraseInstFromFunction(*OldSI);
+  }
+  return true;
+}
+
 /// \brief Combine stores to match the type of value being stored.
 ///
 /// The core idea here is that the memory does not have any intrinsic type and
@@ -517,52 +830,20 @@
   if (!SI.isSimple())
     return false;
 
-  Value *Ptr = SI.getPointerOperand();
   Value *V = SI.getValueOperand();
-  unsigned AS = SI.getPointerAddressSpace();
-  SmallVector<std::pair<unsigned, MDNode *>, 8> MD;
-  SI.getAllMetadata(MD);
 
   // Fold away bit casts of the stored value by storing the original type.
   if (auto *BC = dyn_cast<BitCastInst>(V)) {
-    V = BC->getOperand(0);
-    StoreInst *NewStore = IC.Builder->CreateAlignedStore(
-        V, IC.Builder->CreateBitCast(Ptr, V->getType()->getPointerTo(AS)),
-        SI.getAlignment());
-    for (const auto &MDPair : MD) {
-      unsigned ID = MDPair.first;
-      MDNode *N = MDPair.second;
-      // Note, essentially every kind of metadata should be preserved here! This
-      // routine is supposed to clone a store instruction changing *only its
-      // type*. The only metadata it makes sense to drop is metadata which is
-      // invalidated when the pointer type changes. This should essentially
-      // never be the case in LLVM, but we explicitly switch over only known
-      // metadata to be conservatively correct. If you are adding metadata to
-      // LLVM which pertains to stores, you almost certainly want to add it
-      // here.
-      switch (ID) {
-      case LLVMContext::MD_dbg:
-      case LLVMContext::MD_tbaa:
-      case LLVMContext::MD_prof:
-      case LLVMContext::MD_fpmath:
-      case LLVMContext::MD_tbaa_struct:
-      case LLVMContext::MD_alias_scope:
-      case LLVMContext::MD_noalias:
-      case LLVMContext::MD_nontemporal:
-      case LLVMContext::MD_mem_parallel_loop_access:
-      case LLVMContext::MD_nonnull:
-        // All of these directly apply.
-        NewStore->setMetadata(ID, N);
-        break;
-
-      case LLVMContext::MD_invariant_load:
-      case LLVMContext::MD_range:
-        break;
-      }
-    }
+    combineStoreToNewValue(IC, SI, BC->getOperand(0));
     return true;
   }
 
+  // If this is an integer store and we have data layout, look for a pattern of
+  // storing a vector as an integer (modeled as a bag of bits).
+  if (V->getType()->isIntegerTy() && IC.getDataLayout() &&
+      combineIntegerStoreToVectorStore(IC, SI))
+    return true;
+
   // FIXME: We should also canonicalize loads of vectors when their elements are
   // cast to other types.
   return false;
Index: test/Transforms/InstCombine/loadstore-vector.ll
===================================================================
--- /dev/null
+++ test/Transforms/InstCombine/loadstore-vector.ll
@@ -0,0 +1,159 @@
+; RUN: opt -instcombine -S < %s | FileCheck %s
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+
+; Basic test for turning element extraction from integer loads and element
+; insertion into integer stores into extraction and insertion with vectors.
+define void @test1({ float, float }* %x, float %a, float %b, { float, float }* %out) {
+; CHECK-LABEL: @test1(
+entry:
+  %x.cast = bitcast { float, float }* %x to i64*
+  %x.load = load i64* %x.cast, align 4
+; CHECK-NOT: load i64*
+; CHECK: load <2 x float>*
+
+  %l1.sroa.0.0.extract.trunc = trunc i64 %x.load to i32
+  %l1.sroa.4.0.extract.shift = lshr i64 %x.load, 32
+  %l1.sroa.4.0.extract.trunc = trunc i64 %l1.sroa.4.0.extract.shift to i32
+  %hi.cast = bitcast i32 %l1.sroa.4.0.extract.trunc to float
+  %lo.cast = bitcast i32 %l1.sroa.0.0.extract.trunc to float
+; CHECK-NOT: trunc
+; CHECK-NOT: lshr
+; CHECK: extractelement <2 x float>
+; CHECK-NEXT: extractelement <2 x float>
+
+  %add.i.i = fadd float %lo.cast, %a
+  %add5.i.i = fadd float %hi.cast, %b
+
+  %add.lo.cast = bitcast float %add.i.i to i32
+  %add.hi.cast = bitcast float %add5.i.i to i32
+  %l1.sroa.4.0.insert.ext = zext i32 %add.hi.cast to i64
+  %l1.sroa.4.0.insert.shift = shl nuw i64 %l1.sroa.4.0.insert.ext, 32
+  %l1.sroa.0.0.insert.ext = zext i32 %add.lo.cast to i64
+  %l1.sroa.0.0.insert.insert = or i64 %l1.sroa.4.0.insert.shift, %l1.sroa.0.0.insert.ext
+; CHECK-NOT: zext i32
+; CHECK-NOT: shl {{.*}} i64
+; CHECK-NOT: or i64
+; CHECK: insertelement <2 x float>
+; CHECK-NEXT: insertelement <2 x float>
+
+  %out.cast = bitcast { float, float }* %out to i64*
+  store i64 %l1.sroa.0.0.insert.insert, i64* %out.cast, align 4
+; CHECK-NOT: store i64
+; CHECK: store <2 x float>
+
+  ret void
+}
+
+define void @test2({ float, float }* %x, float %a, float %b, { float, float }* %out1, { float, float }* %out2) {
+; CHECK-LABEL: @test2(
+entry:
+  %x.cast = bitcast { float, float }* %x to i64*
+  %x.load = load i64* %x.cast, align 4
+; CHECK-NOT: load i64*
+; CHECK: load <2 x float>*
+
+  %l1.sroa.0.0.extract.trunc = trunc i64 %x.load to i32
+  %l1.sroa.4.0.extract.shift = lshr i64 %x.load, 32
+  %l1.sroa.4.0.extract.trunc = trunc i64 %l1.sroa.4.0.extract.shift to i32
+  %hi.cast = bitcast i32 %l1.sroa.4.0.extract.trunc to float
+  %lo.cast = bitcast i32 %l1.sroa.0.0.extract.trunc to float
+; CHECK-NOT: trunc
+; CHECK-NOT: lshr
+; CHECK: extractelement <2 x float>
+; CHECK-NEXT: extractelement <2 x float>
+
+  %add.i.i = fadd float %lo.cast, %a
+  %add5.i.i = fadd float %hi.cast, %b
+
+  %add.lo.cast = bitcast float %add.i.i to i32
+  %add.hi.cast = bitcast float %add5.i.i to i32
+  %l1.sroa.4.0.insert.ext = zext i32 %add.hi.cast to i64
+  %l1.sroa.4.0.insert.shift = shl nuw i64 %l1.sroa.4.0.insert.ext, 32
+  %l1.sroa.0.0.insert.ext = zext i32 %add.lo.cast to i64
+  %l1.sroa.0.0.insert.insert = or i64 %l1.sroa.4.0.insert.shift, %l1.sroa.0.0.insert.ext
+; CHECK-NOT: zext i32
+; CHECK-NOT: shl {{.*}} i64
+; CHECK-NOT: or i64
+; CHECK: insertelement <2 x float>
+; CHECK-NEXT: insertelement <2 x float>
+
+  %out1.cast = bitcast { float, float }* %out1 to i64*
+  store i64 %l1.sroa.0.0.insert.insert, i64* %out1.cast, align 4
+  %out2.cast = bitcast { float, float }* %out2 to i64*
+  store i64 %l1.sroa.0.0.insert.insert, i64* %out2.cast, align 4
+; CHECK-NOT: store i64
+; CHECK: store <2 x float>
+; CHECK-NOT: store i64
+; CHECK: store <2 x float>
+
+  ret void
+}
+
+; We handle some cases where there is partial CSE but not complete CSE of
+; repeated insertion and extraction. Currently, we don't catch the store side
+; yet because it would require extreme heroics to match this reliably.
+define void @test3({ float, float, float }* %x, float %a, float %b, { float, float, float }* %out1, { float, float, float }* %out2) {
+; CHECK-LABEL: @test3(
+entry:
+  %x.cast = bitcast { float, float, float }* %x to i96*
+  %x.load = load i96* %x.cast, align 4
+; CHECK-NOT: load i96*
+; CHECK: load <3 x float>*
+
+  %lo.trunc = trunc i96 %x.load to i32
+  %lo.cast = bitcast i32 %lo.trunc to float
+  %mid.shift = lshr i96 %x.load, 32
+  %mid.trunc = trunc i96 %mid.shift to i32
+  %mid.cast = bitcast i32 %mid.trunc to float
+  %mid.trunc2 = trunc i96 %mid.shift to i32
+  %mid.cast2 = bitcast i32 %mid.trunc2 to float
+  %hi.shift = lshr i96 %mid.shift, 32
+  %hi.trunc = trunc i96 %hi.shift to i32
+  %hi.cast = bitcast i32 %hi.trunc to float
+; CHECK-NOT: trunc
+; CHECK-NOT: lshr
+; CHECK: extractelement <3 x float> {{.*}}, i32 2
+; CHECK-NEXT: extractelement <3 x float> {{.*}}, i32 1
+; CHECK-NEXT: extractelement <3 x float> {{.*}}, i32 1
+; CHECK-NEXT: extractelement <3 x float> {{.*}}, i32 0
+
+  %add.lo = fadd float %lo.cast, %a
+  %add.mid = fadd float %mid.cast, %b
+  %add.hi = fadd float %hi.cast, %mid.cast2
+
+  %add.lo.cast = bitcast float %add.lo to i32
+  %add.mid.cast = bitcast float %add.mid to i32
+  %add.hi.cast = bitcast float %add.hi to i32
+  %result.hi.ext = zext i32 %add.hi.cast to i96
+  %result.hi.shift = shl nuw i96 %result.hi.ext, 32
+  %result.mid.ext = zext i32 %add.mid.cast to i96
+  %result.mid.or = or i96 %result.hi.shift, %result.mid.ext
+  %result.mid.shift = shl nuw i96 %result.mid.or, 32
+  %result.lo.ext = zext i32 %add.lo.cast to i96
+  %result.lo.or = or i96 %result.mid.shift, %result.lo.ext
+; FIXME-NOT: zext i32
+; FIXME-NOT: shl {{.*}} i64
+; FIXME-NOT: or i64
+; FIXME: insertelement <3 x float> {{.*}}, i32 2
+; FIXME-NEXT: insertelement <3 x float> {{.*}}, i32 1
+; FIXME-NEXT: insertelement <3 x float> {{.*}}, i32 0
+
+  %out1.cast = bitcast { float, float, float }* %out1 to i96*
+  store i96 %result.lo.or, i96* %out1.cast, align 4
+; FIXME-NOT: store i96
+; FIXME: store <3 x float>
+
+  %result2.lo.ext = zext i32 %add.lo.cast to i96
+  %result2.lo.or = or i96 %result.mid.shift, %result2.lo.ext
+; FIXME-NOT: zext i32
+; FIXME-NOT: shl {{.*}} i64
+; FIXME-NOT: or i64
+; FIXME: insertelement <3 x float> {{.*}}, i32 0
+
+  %out2.cast = bitcast { float, float, float }* %out2 to i96*
+  store i96 %result2.lo.or, i96* %out2.cast, align 4
+; FIXME-NOT: store i96
+; FIXME: store <3 x float>
+
+  ret void
+}