Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp =================================================================== --- lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -463,23 +463,25 @@ SDValue getMergeStoreChains(SmallVectorImpl &StoreNodes, unsigned NumStores); - /// This is a helper function for MergeConsecutiveStores. When the source - /// elements of the consecutive stores are all constants or all extracted - /// vector elements, try to merge them into one larger store. + /// This is a helper function for MergeConsecutiveStores. When the + /// source elements of the consecutive stores are all constants or + /// all extracted vector elements, try to merge them into one + /// larger store, introducing bitcasts if necessary. /// \return True if a merged store was created. bool MergeStoresOfConstantsOrVecElts(SmallVectorImpl &StoreNodes, EVT MemVT, unsigned NumStores, bool IsConstantSrc, bool UseVector, bool UseTrunc); - /// This is a helper function for MergeConsecutiveStores. - /// Stores that may be merged are placed in StoreNodes. + /// This is a helper function for MergeConsecutiveStores. Stores + /// that may potentially be merged with St are placed in + /// StoreNodes. void getStoreMergeCandidates(StoreSDNode *St, SmallVectorImpl &StoreNodes); /// Helper function for MergeConsecutiveStores. Checks if /// Candidate stores have indirect dependency through their - /// operands. \return True if safe to merge + /// operands. \return True if safe to merge. 
bool checkMergeStoreCandidatesForDependencies( SmallVectorImpl &StoreNodes, unsigned NumStores); @@ -12467,22 +12469,67 @@ for (unsigned I = 0; I != NumStores; ++I) { StoreSDNode *St = cast(StoreNodes[I].MemNode); SDValue Val = St->getValue(); - if (MemVT.getScalarType().isInteger()) - if (auto *CFP = dyn_cast(Val)) - Val = DAG.getConstant( - (uint32_t)CFP->getValueAPF().bitcastToAPInt().getZExtValue(), - SDLoc(CFP), MemVT); + // If the constant is of the wrong type, convert it now. + if (MemVT != Val.getValueType()) { + // Peek through bitcasts. + while (Val.getOpcode() == ISD::BITCAST) + Val = Val.getOperand(0); + // Deal with constants of wrong size. + if (ElementSizeBytes * 8 != Val.getValueType().getSizeInBits()) { + EVT IntMemVT = + EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits()); + if (auto *CFP = dyn_cast(Val)) + Val = DAG.getConstant( + (uint32_t)CFP->getValueAPF().bitcastToAPInt().getZExtValue(), + SDLoc(CFP), IntMemVT); + else if (auto *C = dyn_cast(Val)) + Val = DAG.getConstant( + C->getAPIntValue().zextOrTrunc(8 * ElementSizeBytes), + SDLoc(C), IntMemVT); + } + // Make sure the correctly sized value is bitcast to the correct type. + Val = DAG.getBitcast(MemVT, Val); + } BuildVector.push_back(Val); } - StoredVal = DAG.getBuildVector(StoreTy, DL, BuildVector); + StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS + : ISD::BUILD_VECTOR, + DL, StoreTy, BuildVector); } else { SmallVector Ops; for (unsigned i = 0; i < NumStores; ++i) { StoreSDNode *St = cast(StoreNodes[i].MemNode); SDValue Val = St->getValue(); - // All operands of BUILD_VECTOR / CONCAT_VECTOR must have the same type. - if (Val.getValueType() != MemVT) - return false; + // Peek through bitcasts. + while (Val.getOpcode() == ISD::BITCAST) + Val = Val.getOperand(0); + // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of + // type MemVT. If the underlying value is not the correct + // type, but it is an extraction of an appropriate vector, we + // can recast Val to be of the correct type. 
This may require + // converting between EXTRACT_VECTOR_ELT and + // EXTRACT_SUBVECTOR. + if (MemVT != Val.getValueType()) { + if (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT || + Val.getOpcode() == ISD::EXTRACT_SUBVECTOR) { + SDValue Vec = Val.getOperand(0); + EVT MemVTScalarTy = MemVT.getScalarType(); + // We may need to add a bitcast here to get types to line up. + if (MemVTScalarTy != Vec.getValueType()) { + unsigned Elts = Vec.getValueType().getSizeInBits() / + MemVTScalarTy.getSizeInBits(); + EVT NewVecTy = + EVT::getVectorVT(*DAG.getContext(), MemVTScalarTy, Elts); + Vec = DAG.getBitcast(NewVecTy, Vec); + } + if (MemVT.isVector()) + Val = DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(Val), MemVT, Vec, + Val.getOperand(1)); + else + Val = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(Val), MemVT, Vec, + Val.getOperand(1)); + } + } Ops.push_back(Val); } @@ -12525,7 +12572,7 @@ // make sure we use trunc store if it's necessary to be legal. SDValue NewStore; - if (UseVector || !UseTrunc) { + if (!UseTrunc) { NewStore = DAG.getStore(NewChain, DL, StoredVal, FirstInChain->getBasePtr(), FirstInChain->getPointerInfo(), FirstInChain->getAlignment()); @@ -12560,6 +12607,10 @@ EVT MemVT = St->getMemoryVT(); SDValue Val = St->getValue(); + // Only peek through bitcasts of non-truncstores + while (Val->getOpcode() == ISD::BITCAST) + Val = Val.getOperand(0); + // We must have a base and an offset. if (!BasePtr.getBase().getNode()) return; @@ -12588,9 +12639,13 @@ if (Other->isVolatile() || Other->isIndexed()) return false; SDValue Val = Other->getValue(); + // Peek through bitcasts. + while (Val.getOpcode() == ISD::BITCAST) + Val = Val.getOperand(0); if (IsLoadSrc) { - // Loads must match type. - if (Other->getMemoryVT() != MemVT) + // Allow loads of different types to merge as integers. + if (MemVT.isInteger() ? 
!MemVT.bitsEq(Other->getMemoryVT()) + : Other->getMemoryVT() != MemVT) return false; // The Load's Base Ptr must also match if (LoadSDNode *OtherLd = dyn_cast(Val)) { @@ -12611,8 +12666,10 @@ return false; } if (IsExtractVecSrc) { - // Must match type. - if (Other->getMemoryVT() != MemVT) + // Do not merge truncated stores here. + if (Other->isTruncatingStore()) + return false; + if (!MemVT.bitsEq(Val.getValueType())) return false; if (!(Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT || Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) @@ -12710,6 +12767,9 @@ // Perform an early exit check. Do not bother looking at stored values that // are not constants, loads, or extracted vector elements. SDValue StoredVal = St->getValue(); + while (StoredVal->getOpcode() == ISD::BITCAST) + StoredVal = StoredVal.getOperand(0); + bool IsLoadSrc = isa(StoredVal); bool IsConstantSrc = isa(StoredVal) || isa(StoredVal); @@ -12901,6 +12961,10 @@ for (unsigned i = 0; i < NumConsecutiveStores; ++i) { StoreSDNode *St = cast(StoreNodes[i].MemNode); SDValue StVal = St->getValue(); + // Peek through bitcasts + while (StVal->getOpcode() == ISD::BITCAST) + StVal = StVal.getOperand(0); + // This restriction could be loosened. // Bail out if any stored values are not elements extracted from a // vector. It should be possible to handle mixed sources, but load @@ -12967,6 +13031,10 @@ for (unsigned i = 0; i < NumConsecutiveStores; ++i) { StoreSDNode *St = cast(StoreNodes[i].MemNode); SDValue Val = St->getValue(); + // Peek through Bitcasts. + while (Val.getOpcode() == ISD::BITCAST) + Val = Val.getOperand(0); + LoadSDNode *Ld = dyn_cast(Val); if (!Ld) break; Index: test/CodeGen/X86/MergeConsecutiveStores.ll =================================================================== --- test/CodeGen/X86/MergeConsecutiveStores.ll +++ test/CodeGen/X86/MergeConsecutiveStores.ll @@ -522,7 +522,7 @@ ; CHECK-NEXT: retq } -; Merging vector stores when sourced from vector loads is not currently handled. 
+; Merging vector stores when sourced from vector loads. define void @merge_vec_stores_from_loads(<4 x float>* %v, <4 x float>* %ptr) { %load_idx0 = getelementptr inbounds <4 x float>, <4 x float>* %v, i64 0 %load_idx1 = getelementptr inbounds <4 x float>, <4 x float>* %v, i64 1 @@ -621,9 +621,6 @@ ret void ; CHECK-LABEL: merge_bitcast -; CHECK: vmovd %xmm0, (%rdi) -; CHECK-NEXT: vpextrd $1, %xmm0, 4(%rdi) -; CHECK-NEXT: vpextrd $2, %xmm0, 8(%rdi) -; CHECK-NEXT: vpextrd $3, %xmm0, 12(%rdi) +; CHECK: vmovups %xmm0, (%rdi) ; CHECK-NEXT: retq }