Index: llvm/lib/Transforms/Scalar/SROA.cpp =================================================================== --- llvm/lib/Transforms/Scalar/SROA.cpp +++ llvm/lib/Transforms/Scalar/SROA.cpp @@ -3205,50 +3205,123 @@ /// This pass aggressively rewrites all aggregate loads and stores on /// a particular pointer (or any pointer derived from it which we can identify) /// with scalar loads and stores. -class AggLoadStoreRewriter : public InstVisitor { +/// +/// It also rewrites integer loads and stores which cover multiple fields in +/// the original aggregate. If we have the following structure type +/// +/// {i32, {i32, i32}} +/// +/// %ptrval = ptrtoint %struct.ptr* %ptr to i64 +/// %ptrval2 = add i64 %ptrval, 4 +/// %ptr1 = inttoptr i64 %ptrval to i64* +/// %ptr2 = inttoptr i64 %ptrval2 to i64* +/// %val1 = load i64, i64* ptr1, align 4 +/// %val2 = load i64, i64* ptr2, align 4 +/// +/// The first 64-bit load will be rewritten to 2 32-bit loads because it +/// actually accesses 2 fields in the original aggregate, and the two fields +/// don't belong to the same inner structure. +/// +/// The second load won't be rewritten because all fields accessed by the load +/// belong to the same inner structure, which is a common case in LLVM IR. +class AggLoadStoreRewriter : public PtrUseVisitor { // Befriend the base class so it can delegate to private visit methods. - friend class InstVisitor; + friend class PtrUseVisitor; + using Base = PtrUseVisitor; - /// Queue of pointer uses to analyze and potentially rewrite. - SmallVector Queue; + /// Field offsets of the aggregate. + SmallVector FieldOffset; - /// Set to prevent us from cycling with phi nodes and loops. - SmallPtrSet Visited; + /// Offset to field number in the aggregate. + SmallDenseMap Offset2Number; - /// The current pointer use being rewritten. This is used to dig up the used - /// value (as opposed to the user). - Use *U = nullptr; + /// The integer types with the same width as corresponding inner fields. 
+ SmallVector InnerTypes; - /// Used to calculate offsets, and hence alignment, of subobjects. - const DataLayout &DL; + /// This structure is used to uniquely identify the parent structure of an + /// inner field. + struct ParentInfo { + Type *ParentTy; + uint64_t Offset; // The offset of parent structure in the aggregate. + ParentInfo(Type *P, uint64_t O): ParentTy(P), Offset(O) {} + }; + + SmallVector Parents; + + /// We try to split integer load/store if it covers multiple fields in + /// original struct. But if the aggregate is too large we will stop splitting. + bool SplitInteger; + + /// IR is changed. + bool Changed; public: - AggLoadStoreRewriter(const DataLayout &DL) : DL(DL) {} + AggLoadStoreRewriter(const DataLayout &DL) : PtrUseVisitor(DL) {} /// Rewrite loads and stores through a pointer and all pointers derived from /// it. bool rewrite(Instruction &I) { LLVM_DEBUG(dbgs() << " Rewriting FCA loads and stores...\n"); - enqueueUsers(I); - bool Changed = false; - while (!Queue.empty()) { - U = Queue.pop_back_val(); - Changed |= visit(cast(U->getUser())); - } + SplitInteger = true; + extractAllocaFields(I); + Changed = false; + visitPtr(I); return Changed; } private: - /// Enqueue all the users of the given instruction for further processing. - /// This uses a set to de-duplicate users. - void enqueueUsers(Instruction &I) { - for (Use &U : I.uses()) - if (Visited.insert(U.getUser()).second) - Queue.push_back(&U); + /// Helper function of extractAllocaFields. Recursively extracts field + /// information of the specified type. 
+ void extractTypeFields(Type *Ty, uint64_t BaseOffset, + Type *ParentTy, uint64_t ParentOffset) { + if (ArrayType *ATy = dyn_cast(Ty)) { + if (DL.getTypeStoreSizeInBits(Ty) <= 128) { + uint64_t ElementSize = DL.getTypeAllocSize(ATy->getElementType()); + for (unsigned Idx = 0, Size = ATy->getNumElements(); Idx != Size; ++Idx) + extractTypeFields(ATy->getElementType(), + BaseOffset + Idx * ElementSize, ATy, BaseOffset); + return; + } + } + + if (StructType *STy = dyn_cast(Ty)) { + const StructLayout *SL = DL.getStructLayout(STy); + for (unsigned Idx = 0, Size = STy->getNumElements(); Idx != Size; ++Idx) + extractTypeFields(STy->getElementType(Idx), + BaseOffset + SL->getElementOffset(Idx), + STy, BaseOffset); + return; + } + + uint64_t BitSize = DL.getTypeStoreSizeInBits(Ty); + if (BitSize > IntegerType::MAX_INT_BITS) { + SplitInteger = false; + return; + } + + Offset2Number.insert(std::make_pair(BaseOffset, FieldOffset.size())); + FieldOffset.push_back(BaseOffset); + Type* IntTy = Type::getIntNTy(Ty->getContext(), BitSize); + InnerTypes.push_back(IntTy); + Parents.push_back(ParentInfo(ParentTy, ParentOffset)); + return; } - // Conservative default is to not rewrite anything. - bool visitInstruction(Instruction &I) { return false; } + /// Extract fields information of an alloca object. + void extractAllocaFields(Instruction &I) { + FieldOffset.clear(); + InnerTypes.clear(); + Offset2Number.clear(); + Parents.clear(); + + Type *AllocaTy = dyn_cast(&I)->getAllocatedType(); + extractTypeFields(AllocaTy, 0, nullptr, 0); + + uint64_t AllocSize = DL.getTypeAllocSize(AllocaTy); + Offset2Number.insert(std::make_pair(AllocSize, FieldOffset.size())); + FieldOffset.push_back(AllocSize); + InnerTypes.push_back(nullptr); + } /// Generic recursive split emission class. template class OpSplitter { @@ -3278,12 +3351,15 @@ /// alignments. 
const DataLayout &DL; + const AggLoadStoreRewriter &Rewriter; + /// Initialize the splitter with an insertion point, Ptr and start with a /// single zero GEP index. OpSplitter(Instruction *InsertionPoint, Value *Ptr, Type *BaseTy, - Align BaseAlign, const DataLayout &DL) + Align BaseAlign, const DataLayout &DL, + const AggLoadStoreRewriter &Rewriter) : IRB(InsertionPoint), GEPIndices(1, IRB.getInt32(0)), Ptr(Ptr), - BaseTy(BaseTy), BaseAlign(BaseAlign), DL(DL) {} + BaseTy(BaseTy), BaseAlign(BaseAlign), DL(DL), Rewriter(Rewriter) {} public: /// Generic recursive split emission routine. @@ -3301,9 +3377,15 @@ /// whether this is splitting a load or a store respectively. void emitSplitOps(Type *Ty, Value *&Agg, const Twine &Name) { if (Ty->isSingleValueType()) { - unsigned Offset = DL.getIndexedOffsetInType(BaseTy, GEPIndices); - return static_cast(this)->emitFunc( - Ty, Agg, commonAlignment(BaseAlign, Offset), Name); + if (Indices.size()) { + unsigned Offset = DL.getIndexedOffsetInType(BaseTy, GEPIndices); + return static_cast(this)->emitFunc( + Ty, Agg, commonAlignment(BaseAlign, Offset), Name); + } + + // Integer type covers multiple fields in original structure. + assert(Ty->isIntegerTy() && "Ty should be integer type"); + return static_cast(this)->emitIntegerFunc(Agg, Name); } if (ArrayType *ATy = dyn_cast(Ty)) { @@ -3344,9 +3426,10 @@ AAMDNodes AATags; LoadOpSplitter(Instruction *InsertionPoint, Value *Ptr, Type *BaseTy, - AAMDNodes AATags, Align BaseAlign, const DataLayout &DL) + AAMDNodes AATags, Align BaseAlign, const DataLayout &DL, + const AggLoadStoreRewriter &Rewriter) : OpSplitter(InsertionPoint, Ptr, BaseTy, BaseAlign, - DL), + DL, Rewriter), AATags(AATags) {} /// Emit a leaf load of a single value. 
This is called at the leaves of the @@ -3363,31 +3446,69 @@ Agg = IRB.CreateInsertValue(Agg, Load, Indices, Name + ".insert"); LLVM_DEBUG(dbgs() << " to: " << *Load << "\n"); } + + /// Split an integer load into multiple smaller loads according to + /// underlying aggregate fields. + void emitIntegerFunc(Value *&V, const Twine &Name) { + uint64_t BeginOffset = Rewriter.Offset.getZExtValue(); + uint64_t EndOffset = BeginOffset + DL.getTypeStoreSize(V->getType()); + int BeginIdx = Rewriter.Offset2Number.find(BeginOffset)->second; + int EndIdx = Rewriter.Offset2Number.find(EndOffset)->second; + auto It = Rewriter.InnerTypes.begin(); + Type *AggTy = StructType::get(V->getType()->getContext(), + makeArrayRef(It+BeginIdx, It+EndIdx)); + Type *PtrTy = AggTy->getPointerTo( + Ptr->getType()->getPointerAddressSpace()); + Ptr = IRB.CreateBitCast(Ptr, PtrTy, Name + ".bitcast"); + + for (int i=BeginIdx; i < EndIdx; i++) { + unsigned InsideOffset = Rewriter.FieldOffset[i] - BeginOffset; + GEPIndices.push_back(IRB.getInt32(i - BeginIdx)); + Value *InBoundsGEP = + IRB.CreateInBoundsGEP(AggTy, Ptr, GEPIndices, Name + ".gep"); + IntegerType *IntTy = cast(*(It+i)); + LoadInst *Load = + IRB.CreateAlignedLoad(IntTy, InBoundsGEP, + commonAlignment(BaseAlign, InsideOffset), + Name + ".load"); + if (AATags) + Load->setAAMetadata(AATags); + V = insertInteger(DL, IRB, V, Load, InsideOffset, Name); + GEPIndices.pop_back(); + LLVM_DEBUG(dbgs() << " to: " << *Load << "\n"); + } + } }; - bool visitLoadInst(LoadInst &LI) { +public: + void visitLoadInst(LoadInst &LI) { assert(LI.getPointerOperand() == *U); - if (!LI.isSimple() || LI.getType()->isSingleValueType()) - return false; + if (!LI.isSimple()) + return; + if (LI.getType()->isSingleValueType() && !isSplittableType(LI.getType())) + return; // We have an aggregate being loaded, split it apart. 
LLVM_DEBUG(dbgs() << " original: " << LI << "\n"); AAMDNodes AATags; LI.getAAMetadata(AATags); LoadOpSplitter Splitter(&LI, *U, LI.getType(), AATags, - getAdjustedAlignment(&LI, 0, DL), DL); + getAdjustedAlignment(&LI, 0, DL), DL, *this); Value *V = UndefValue::get(LI.getType()); Splitter.emitSplitOps(LI.getType(), V, LI.getName() + ".fca"); LI.replaceAllUsesWith(V); LI.eraseFromParent(); - return true; + Changed = true; + return; } +private: struct StoreOpSplitter : public OpSplitter { StoreOpSplitter(Instruction *InsertionPoint, Value *Ptr, Type *BaseTy, - AAMDNodes AATags, Align BaseAlign, const DataLayout &DL) + AAMDNodes AATags, Align BaseAlign, const DataLayout &DL, + const AggLoadStoreRewriter &Rewriter) : OpSplitter(InsertionPoint, Ptr, BaseTy, BaseAlign, - DL), + DL, Rewriter), AATags(AATags) {} AAMDNodes AATags; /// Emit a leaf store of a single value. This is called at the leaves of the @@ -3408,48 +3529,86 @@ Store->setAAMetadata(AATags); LLVM_DEBUG(dbgs() << " to: " << *Store << "\n"); } + + /// Split an integer store into multiple smaller stores according to + /// underlying aggregate fields. 
+ void emitIntegerFunc(Value *&V, const Twine &Name) { + uint64_t BeginOffset = Rewriter.Offset.getZExtValue(); + uint64_t EndOffset = BeginOffset + DL.getTypeStoreSize(V->getType()); + int BeginIdx = Rewriter.Offset2Number.find(BeginOffset)->second; + int EndIdx = Rewriter.Offset2Number.find(EndOffset)->second; + auto It = Rewriter.InnerTypes.begin(); + Type *AggTy = StructType::get(V->getType()->getContext(), + makeArrayRef(It+BeginIdx, It+EndIdx)); + Type *PtrTy = AggTy->getPointerTo( + Ptr->getType()->getPointerAddressSpace()); + Ptr = IRB.CreateBitCast(Ptr, PtrTy, Name + ".bitcast"); + + for (int i=BeginIdx; i < EndIdx; i++) { + unsigned InsideOffset = Rewriter.FieldOffset[i] - BeginOffset; + IntegerType *IntTy = cast(*(It+i)); + Value *SubV = extractInteger(DL, IRB, V, IntTy, InsideOffset, Name); + GEPIndices.push_back(IRB.getInt32(i - BeginIdx)); + Value *InBoundsGEP = + IRB.CreateInBoundsGEP(AggTy, Ptr, GEPIndices, Name + ".gep"); + StoreInst *Store = + IRB.CreateAlignedStore(SubV, InBoundsGEP, + commonAlignment(BaseAlign, InsideOffset)); + if (AATags) + Store->setAAMetadata(AATags); + LLVM_DEBUG(dbgs() << " to: " << *Store << "\n"); + GEPIndices.pop_back(); + } + } }; - bool visitStoreInst(StoreInst &SI) { +public: + void visitStoreInst(StoreInst &SI) { + Base::visitStoreInst(SI); if (!SI.isSimple() || SI.getPointerOperand() != *U) - return false; + return; Value *V = SI.getValueOperand(); - if (V->getType()->isSingleValueType()) - return false; + if (V->getType()->isSingleValueType() && !isSplittableType(V->getType())) + return; // We have an aggregate being stored, split it apart. 
LLVM_DEBUG(dbgs() << " original: " << SI << "\n"); AAMDNodes AATags; SI.getAAMetadata(AATags); StoreOpSplitter Splitter(&SI, *U, V->getType(), AATags, - getAdjustedAlignment(&SI, 0, DL), DL); + getAdjustedAlignment(&SI, 0, DL), DL, *this); Splitter.emitSplitOps(V->getType(), V, V->getName() + ".fca"); SI.eraseFromParent(); - return true; - } - - bool visitBitCastInst(BitCastInst &BC) { - enqueueUsers(BC); - return false; - } - - bool visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) { - enqueueUsers(ASC); - return false; - } - - bool visitGetElementPtrInst(GetElementPtrInst &GEPI) { - enqueueUsers(GEPI); - return false; + Changed = true; + return; } - bool visitPHINode(PHINode &PN) { + void visitPHINode(PHINode &PN) { enqueueUsers(PN); - return false; + return; } - bool visitSelectInst(SelectInst &SI) { + void visitSelectInst(SelectInst &SI) { enqueueUsers(SI); + return; + } + + /// If an integer type covers multiple original structure fields, we can still + /// split it. + bool isSplittableType(Type *Ty) { + if (!Ty->isIntegerTy() || !IsOffsetKnown || !SplitInteger) + return false; + uint64_t BeginOffset = Offset.getZExtValue(); + uint64_t EndOffset = BeginOffset + DL.getTypeStoreSize(Ty); + if (Offset2Number.count(BeginOffset) && Offset2Number.count(EndOffset)) { + int BeginIdx = Offset2Number[BeginOffset]; + int EndIdx = Offset2Number[EndOffset]; + ParentInfo BeginP = Parents[BeginIdx]; + ParentInfo EndP = Parents[EndIdx - 1]; + if ((BeginP.ParentTy == EndP.ParentTy) && (BeginP.Offset == EndP.Offset)) + return false; + return EndIdx > (BeginIdx + 1); + } return false; } }; Index: llvm/test/Transforms/SROA/split-integer.ll =================================================================== --- /dev/null +++ llvm/test/Transforms/SROA/split-integer.ll @@ -0,0 +1,45 @@ +; RUN: opt < %s -sroa -S | FileCheck %s + + +%inner = type { i32, i32 } +%outer = type { i8, %inner } + +; CHECK-LABEL: @foo +; CHECK-NOT: alloca +; CHECK-NOT: store +; CHECK-NOT: load + +define i64 
@foo() { +entry: + %tmpstruct1 = alloca %outer, align 8 + %tmpstruct2 = alloca %outer, align 8 + %ptr1 = getelementptr inbounds %outer, %outer* %tmpstruct2, i64 0, i32 0 + store i8 0, i8* %ptr1, align 8 + %innerptr = getelementptr inbounds %outer, %outer* %tmpstruct2, i64 0, i32 1 + %ptr2 = bitcast %inner* %innerptr to i64* + store i64 4, i64* %ptr2, align 4 + %altptr = bitcast %outer* %tmpstruct2 to i64* + %split = load i64, i64* %altptr, align 8 + %construct1 = insertvalue { i64, i32 } undef, i64 %split, 0 + %construct2 = insertvalue { i64, i32 } %construct1, i32 0, 1 + %first64 = extractvalue { i64, i32 } %construct2, 0 + %last32 = extractvalue { i64, i32 } %construct2, 1 + %tmpptr = bitcast %outer* %tmpstruct1 to i64* + store i64 %first64, i64* %tmpptr + %lastptr = getelementptr inbounds %outer, %outer* %tmpstruct1, i64 0, i32 1, i32 1 + store i32 %last32, i32* %lastptr, align 8 + %flagptr = getelementptr inbounds %outer, %outer* %tmpstruct1, i64 0, i32 0 + %flag = load i8, i8* %flagptr, align 8 + %structptr = getelementptr inbounds %outer, %outer* %tmpstruct1, i64 0, i32 1 + %valptr = bitcast %inner* %structptr to i64* + %value = load i64, i64* %valptr, align 4 + %cond = icmp eq i8 %flag, 0 + br i1 %cond, label %true, label %exit + +exit: + %retv = phi i64 [ 4, %true ], [ %value, %entry ] + ret i64 %retv + +true: + br label %exit +}