Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -463,23 +463,25 @@
     SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
                                 unsigned NumStores);
 
-    /// This is a helper function for MergeConsecutiveStores. When the source
-    /// elements of the consecutive stores are all constants or all extracted
-    /// vector elements, try to merge them into one larger store.
+    /// This is a helper function for MergeConsecutiveStores. When the
+    /// source elements of the consecutive stores are all constants or
+    /// all extracted vector elements, try to merge them into one
+    /// larger store, introducing bitcasts if necessary.
     /// \return True if a merged store was created.
     bool MergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
                                          EVT MemVT, unsigned NumStores,
                                          bool IsConstantSrc, bool UseVector,
                                          bool UseTrunc);
 
-    /// This is a helper function for MergeConsecutiveStores.
-    /// Stores that may be merged are placed in StoreNodes.
+    /// This is a helper function for MergeConsecutiveStores. Stores
+    /// that may potentially be merged with St are placed in
+    /// StoreNodes.
    void getStoreMergeCandidates(StoreSDNode *St,
                                 SmallVectorImpl<MemOpLink> &StoreNodes);
 
    /// Helper function for MergeConsecutiveStores. Checks if
    /// Candidate stores have indirect dependency through their
-    /// operands. \return True if safe to merge
+    /// operands. \return True if safe to merge.
    bool checkMergeStoreCandidatesForDependencies(
        SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores);
 
@@ -12436,56 +12438,102 @@
   if (NumStores < 2)
     return false;
 
-  int64_t ElementSizeBytes = MemVT.getSizeInBits() / 8;
-
   // The latest Node in the DAG.
   SDLoc DL(StoreNodes[0].MemNode);
-  SDValue StoredVal;
+
+  int64_t ElementSizeBytes = MemVT.getSizeInBits() / 8;
+  unsigned SizeInBits = NumStores * ElementSizeBytes * 8;
+
+  EVT StoreTy;
   if (UseVector) {
-    bool IsVec = MemVT.isVector();
     unsigned Elts = NumStores;
-    if (IsVec) {
-      // When merging vector stores, get the total number of elements.
+    // When merging vector stores, get the total number of elements.
+    if (MemVT.isVector())
       Elts *= MemVT.getVectorNumElements();
-    }
     // Get the type for the merged vector store.
-    EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
-    assert(TLI.isTypeLegal(Ty) && "Illegal vector store");
+    StoreTy = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
+  } else
+    StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits);
 
+  SDValue StoredVal;
+  if (UseVector) {
     if (IsConstantSrc) {
       SmallVector<SDValue, 8> BuildVector;
-      for (unsigned I = 0, E = Ty.getVectorNumElements(); I != E; ++I) {
+      for (unsigned I = 0; I != NumStores; ++I) {
         StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
         SDValue Val = St->getValue();
-        if (MemVT.getScalarType().isInteger())
-          if (auto *CFP = dyn_cast<ConstantFPSDNode>(St->getValue()))
-            Val = DAG.getConstant(
-                (uint32_t)CFP->getValueAPF().bitcastToAPInt().getZExtValue(),
-                SDLoc(CFP), MemVT);
+        // Make sure constant is correct bit-size.
+        if (MemVT != Val.getValueType()) {
+          while (Val.getOpcode() == ISD::BITCAST)
+            Val = Val.getOperand(0);
+          if (ElementSizeBytes * 8 != Val.getValueType().getSizeInBits()) {
+            EVT IntMemVT =
+                EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits());
+            if (auto *CFP = dyn_cast<ConstantFPSDNode>(Val))
+              Val = DAG.getConstant(
+                  (uint32_t)CFP->getValueAPF().bitcastToAPInt().getZExtValue(),
+                  SDLoc(CFP), IntMemVT);
+            if (auto *C = dyn_cast<ConstantSDNode>(Val))
+              Val = DAG.getConstant(
+                  C->getAPIntValue().zextOrTrunc(8 * ElementSizeBytes),
+                  SDLoc(C), IntMemVT);
+          }
+          // Make sure it's the correct type.
+          Val = DAG.getBitcast(MemVT, Val);
+        }
         BuildVector.push_back(Val);
       }
-      StoredVal = DAG.getBuildVector(Ty, DL, BuildVector);
+      StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
+                                                : ISD::BUILD_VECTOR,
+                              DL, StoreTy, BuildVector);
     } else {
       SmallVector<SDValue, 8> Ops;
       for (unsigned i = 0; i < NumStores; ++i) {
         StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
         SDValue Val = St->getValue();
-        // All operands of BUILD_VECTOR / CONCAT_VECTOR must have the same type.
-        if (Val.getValueType() != MemVT)
-          return false;
+        // Peek through bitcasts.
+        while (Val.getOpcode() == ISD::BITCAST)
+          Val = Val.getOperand(0);
+        // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of
+        // type MemVT. If the underlying value is not the correct
+        // type, but it is an extraction of an appropriate vector we
+        // can recast Val to be of the correct type. This may require
+        // converting between EXTRACT_VECTOR_ELT and
+        // EXTRACT_SUBVECTOR.
+        if (MemVT != Val.getValueType()) {
+          if (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
+              Val.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
+            SDValue Vec = Val.getOperand(0);
+            // We may need to add a bitcast here to get types to line up.
+            if (MemVT != Val.getValueType()) {
+              EVT NewVecScalarTy = MemVT.getScalarType();
+              unsigned Elts = Vec.getValueType().getSizeInBits() /
+                              NewVecScalarTy.getSizeInBits();
+              EVT NewVecTy =
+                  EVT::getVectorVT(*DAG.getContext(), NewVecScalarTy, Elts);
+              Vec = DAG.getBitcast(NewVecTy, Vec);
+            }
+            if (MemVT.isVector())
+              Val = DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(Val), MemVT, Vec,
+                                Val.getOperand(1));
+            else
+              Val = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(Val), MemVT, Vec,
+                                Val.getOperand(1));
+          }
+        }
         Ops.push_back(Val);
       }
       // Build the extracted vector elements back into a vector.
-      StoredVal = DAG.getNode(IsVec ? ISD::CONCAT_VECTORS : ISD::BUILD_VECTOR,
-                              DL, Ty, Ops);
+      StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
+                                                : ISD::BUILD_VECTOR,
+                              DL, StoreTy, Ops);
     }
   } else {
     // We should always use a vector store when merging extracted vector
     // elements, so this path implies a store of constants.
     assert(IsConstantSrc && "Merged vector elements should use vector store");
 
-    unsigned SizeInBits = NumStores * ElementSizeBytes * 8;
     APInt StoreInt(SizeInBits, 0);
 
     // Construct a single integer constant which is made of the smaller
@@ -12507,7 +12555,6 @@
     }
 
     // Create the new Load and Store operations.
-    EVT StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits);
     StoredVal = DAG.getConstant(StoreInt, DL, StoreTy);
   }
 
@@ -12516,7 +12563,7 @@
   // make sure we use trunc store if it's necessary to be legal.
   SDValue NewStore;
-  if (UseVector || !UseTrunc) {
+  if (!UseTrunc) {
     NewStore = DAG.getStore(NewChain, DL, StoredVal, FirstInChain->getBasePtr(),
                             FirstInChain->getPointerInfo(),
                             FirstInChain->getAlignment());
@@ -12550,6 +12597,11 @@
   BaseIndexOffset BasePtr = BaseIndexOffset::match(St->getBasePtr(), DAG);
   EVT MemVT = St->getMemoryVT();
 
+  SDValue Val = St->getValue();
+  // Only peek through bitcasts of non-truncstores.
+  while (Val->getOpcode() == ISD::BITCAST)
+    Val = Val.getOperand(0);
+
   // We must have a base and an offset.
   if (!BasePtr.getBase().getNode())
     return;
 
@@ -12558,44 +12610,61 @@
   if (BasePtr.getBase().isUndef())
     return;
 
-  bool IsConstantSrc = isa<ConstantSDNode>(St->getValue()) ||
-                       isa<ConstantFPSDNode>(St->getValue());
-  bool IsExtractVecSrc =
-      (St->getValue().getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
-       St->getValue().getOpcode() == ISD::EXTRACT_SUBVECTOR);
-  bool IsLoadSrc = isa<LoadSDNode>(St->getValue());
+  bool IsConstantSrc = isa<ConstantSDNode>(Val) || isa<ConstantFPSDNode>(Val);
+  bool IsExtractVecSrc = (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
+                          Val.getOpcode() == ISD::EXTRACT_SUBVECTOR);
+  bool IsLoadSrc = isa<LoadSDNode>(Val);
   BaseIndexOffset LBasePtr;
   // Match on loadbaseptr if relevant.
-  if (IsLoadSrc)
-    LBasePtr = BaseIndexOffset::match(
-        cast<LoadSDNode>(St->getValue())->getBasePtr(), DAG);
-
+  EVT LoadVT;
+  if (IsLoadSrc) {
+    auto *Ld = cast<LoadSDNode>(Val);
+    LBasePtr = BaseIndexOffset::match(Ld->getBasePtr(), DAG);
+    LoadVT = Ld->getMemoryVT();
+    // Don't allow hidden truncations / extensions in merge candidates.
+    if (!MemVT.bitsEq(LoadVT))
+      return;
+  }
   auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr,
                             int64_t &Offset) -> bool {
     if (Other->isVolatile() || Other->isIndexed())
       return false;
-    // We can merge constant floats to equivalent integers
-    if (Other->getMemoryVT() != MemVT)
-      if (!(MemVT.isInteger() && MemVT.bitsEq(Other->getMemoryVT()) &&
-            isa<ConstantFPSDNode>(Other->getValue())))
-        return false;
+    // Peek through bitcasts.
+    SDValue Val = Other->getValue();
+    while (Val.getOpcode() == ISD::BITCAST)
+      Val = Val.getOperand(0);
     if (IsLoadSrc) {
+      // Allow loads of different types to merge as integers.
+      if (MemVT.isInteger() ? !MemVT.bitsEq(Other->getMemoryVT())
+                            : Other->getMemoryVT() != MemVT)
+        return false;
      // The Load's Base Ptr must also match
-      if (LoadSDNode *OtherLd = dyn_cast<LoadSDNode>(Other->getValue())) {
+      if (LoadSDNode *OtherLd = dyn_cast<LoadSDNode>(Val)) {
        auto LPtr = BaseIndexOffset::match(OtherLd->getBasePtr(), DAG);
+        if (!LoadVT.bitsEq(OtherLd->getMemoryVT()))
+          return false;
        if (!(LBasePtr.equalBaseIndex(LPtr, DAG)))
          return false;
      } else
        return false;
    }
-    if (IsConstantSrc)
-      if (!(isa<ConstantSDNode>(Other->getValue()) ||
-            isa<ConstantFPSDNode>(Other->getValue())))
+    if (IsConstantSrc) {
+      // Allow merging constants of different types as integers.
+      if (MemVT.isInteger() ? !MemVT.bitsEq(Other->getMemoryVT())
+                            : Other->getMemoryVT() != MemVT)
+        return false;
+      if (!(isa<ConstantSDNode>(Val) || isa<ConstantFPSDNode>(Val)))
+        return false;
+    }
+    if (IsExtractVecSrc) {
+      // Do not merge truncated stores here.
+      if (Other->isTruncatingStore())
        return false;
-    if (IsExtractVecSrc)
-      if (!(Other->getValue().getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
-            Other->getValue().getOpcode() == ISD::EXTRACT_SUBVECTOR))
+      if (!MemVT.bitsEq(Val.getValueType()) ||
+          !(Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
+            Val.getOpcode() == ISD::EXTRACT_SUBVECTOR))
        return false;
+    }
    Ptr = BaseIndexOffset::match(Other->getBasePtr(), DAG);
    return (BasePtr.equalBaseIndex(Ptr, DAG, Offset));
  };
 
@@ -12687,6 +12756,9 @@
   // Perform an early exit check. Do not bother looking at stored values that
   // are not constants, loads, or extracted vector elements.
   SDValue StoredVal = St->getValue();
+  while (StoredVal->getOpcode() == ISD::BITCAST)
+    StoredVal = StoredVal.getOperand(0);
+
   bool IsLoadSrc = isa<LoadSDNode>(StoredVal);
   bool IsConstantSrc = isa<ConstantSDNode>(StoredVal) ||
                        isa<ConstantFPSDNode>(StoredVal);
@@ -12696,12 +12768,6 @@
   if (!IsConstantSrc && !IsLoadSrc && !IsExtractVecSrc)
     return false;
 
-  // Don't merge vectors into wider vectors if the source data comes from loads.
-  // TODO: This restriction can be lifted by using logic similar to the
-  // ExtractVecSrc case.
-  if (MemVT.isVector() && IsLoadSrc)
-    return false;
-
   SmallVector<MemOpLink, 8> StoreNodes;
   // Find potential store merge candidates by searching through chain sub-DAG
   getStoreMergeCandidates(St, StoreNodes);
 
@@ -12888,14 +12954,18 @@
     bool IsVec = MemVT.isVector();
     for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
       StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
-      unsigned StoreValOpcode = St->getValue().getOpcode();
+      SDValue StVal = St->getValue();
+      // Peek through bitcasts.
+      while (StVal->getOpcode() == ISD::BITCAST)
+        StVal = StVal.getOperand(0);
+      // This restriction could be loosened.
      // Bail out if any stored values are not elements extracted from a
      // vector. It should be possible to handle mixed sources, but load
      // sources need more careful handling (see the block of code below that
      // handles consecutive loads).
-      if (StoreValOpcode != ISD::EXTRACT_VECTOR_ELT &&
-          StoreValOpcode != ISD::EXTRACT_SUBVECTOR)
+      if (StVal.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
+          StVal.getOpcode() != ISD::EXTRACT_SUBVECTOR)
        return RV;
 
      // Find a legal type for the vector store.
@@ -12958,7 +13028,11 @@
   BaseIndexOffset LdBasePtr;
   for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
     StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
-    LoadSDNode *Ld = dyn_cast<LoadSDNode>(St->getValue());
+    SDValue Val = St->getValue();
+    while (Val.getOpcode() == ISD::BITCAST)
+      Val = Val.getOperand(0);
+
+    LoadSDNode *Ld = dyn_cast<LoadSDNode>(Val);
     if (!Ld)
       break;
 
@@ -12970,12 +13044,8 @@
     if (Ld->isVolatile() || Ld->isIndexed())
       break;
 
-    // We do not accept ext loads.
-    if (Ld->getExtensionType() != ISD::NON_EXTLOAD)
-      break;
-
-    // The stored memory type must be the same.
-    if (Ld->getMemoryVT() != MemVT)
+    // The stored memory type must be the same size as MemVT.
+    if (!MemVT.bitsEq(Ld->getMemoryVT()))
      break;
 
    BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld->getBasePtr(), DAG);
@@ -13039,7 +13109,10 @@
      isDereferenceable = false;
 
    // Find a legal type for the vector store.
-    EVT StoreTy = EVT::getVectorVT(Context, MemVT, i + 1);
+    unsigned Elts =
+        (i + 1) * ((MemVT.isVector()) ? MemVT.getVectorNumElements() : 1);
+    EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
+
    bool IsFastSt, IsFastLd;
    if (TLI.isTypeLegal(StoreTy) &&
        TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
@@ -13108,7 +13181,11 @@
   // to memory.
   EVT JointMemOpVT;
   if (UseVectorTy) {
-    JointMemOpVT = EVT::getVectorVT(Context, MemVT, NumElem);
+    // Find a legal type for the vector store.
+    unsigned Elts = NumElem;
+    if (MemVT.isVector())
+      Elts *= MemVT.getVectorNumElements();
+    JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
  } else {
    unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
    JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
@@ -13153,6 +13230,7 @@
   // Transfer chain users from old loads to the new load.
   for (unsigned i = 0; i < NumElem; ++i) {
     LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
+    AddToWorklist(Ld);
     DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
                                   SDValue(NewLoad.getNode(), 1));
   }
Index: test/CodeGen/X86/MergeConsecutiveStores.ll
===================================================================
--- test/CodeGen/X86/MergeConsecutiveStores.ll
+++ test/CodeGen/X86/MergeConsecutiveStores.ll
@@ -522,7 +522,7 @@
 ; CHECK-NEXT: retq
 }
 
-; Merging vector stores when sourced from vector loads is not currently handled.
+; Merging vector stores when sourced from vector loads.
 define void @merge_vec_stores_from_loads(<4 x float>* %v, <4 x float>* %ptr) {
   %load_idx0 = getelementptr inbounds <4 x float>, <4 x float>* %v, i64 0
   %load_idx1 = getelementptr inbounds <4 x float>, <4 x float>* %v, i64 1
@@ -535,10 +535,9 @@
   ret void
 ; CHECK-LABEL: merge_vec_stores_from_loads
-; CHECK: vmovaps
-; CHECK-NEXT: vmovaps
-; CHECK-NEXT: vmovaps
-; CHECK-NEXT: vmovaps
+; CHECK: vmovups
+; CHECK-NEXT: vmovups
+; CHECK-NEXT: vzeroupper
 ; CHECK-NEXT: retq
 }
 
@@ -622,9 +621,6 @@
   ret void
 ; CHECK-LABEL: merge_bitcast
-; CHECK: vmovd %xmm0, (%rdi)
-; CHECK-NEXT: vpextrd $1, %xmm0, 4(%rdi)
-; CHECK-NEXT: vpextrd $2, %xmm0, 8(%rdi)
-; CHECK-NEXT: vpextrd $3, %xmm0, 12(%rdi)
+; CHECK: vmovups %xmm0, (%rdi)
 ; CHECK-NEXT: retq
 }
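
For reviewers who want a concrete picture of the new behaviour, here is a minimal hand-written IR sketch (illustration only, not part of the patch; the function and value names are invented). Each stored value is a bitcast of an extracted lane, which is the shape the bitcast peek-through in getStoreMergeCandidates and MergeStoresOfConstantsOrVecElts is meant to recognise. No FileCheck lines are given because the exact instruction sequence depends on the subtarget.

; Illustration only -- not part of this patch.
; Both stored values are bitcasts of extracted lanes, so with the peek-through
; logic above the two adjacent f32 stores become candidates for a single
; 8-byte merged store instead of being rejected for having a mismatched
; value type.
define void @store_bitcast_lanes(<2 x i32> %v, float* %p) {
  %i0 = extractelement <2 x i32> %v, i32 0
  %i1 = extractelement <2 x i32> %v, i32 1
  %f0 = bitcast i32 %i0 to float
  %f1 = bitcast i32 %i1 to float
  %q = getelementptr inbounds float, float* %p, i64 1
  store float %f0, float* %p, align 4
  store float %f1, float* %q, align 4
  ret void
}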