Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -9738,6 +9738,137 @@
     }
   }
 
+  bool NewLoad = false;
+  bool BCNumEltsChanged = false;
+  EVT ExtVT = VT.getVectorElementType();
+  EVT LVT = ExtVT;
+
+  // If the result of load has to be truncated, then it's not necessarily
+  // profitable.
+  if (NVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, NVT))
+    return SDValue();
+
+  if (InVec.getOpcode() == ISD::BITCAST) {
+    // Don't duplicate a load with other uses.
+    if (!InVec.hasOneUse())
+      return SDValue();
+
+    EVT BCVT = InVec.getOperand(0).getValueType();
+    if (!BCVT.isVector() || ExtVT.bitsGT(BCVT.getVectorElementType()))
+      return SDValue();
+    if (VT.getVectorNumElements() != BCVT.getVectorNumElements())
+      BCNumEltsChanged = true;
+    InVec = InVec.getOperand(0);
+    ExtVT = BCVT.getVectorElementType();
+    NewLoad = true;
+  }
+
+  // Rewrites N, an extract of element EltIndex from the vector loaded by
+  // OriginalLoad, into a load of just that element with type LoadVT. Returns
+  // SDValue(N, 0) after updating the DAG, or an empty SDValue if the narrowed
+  // load is not formed.
+  auto ReplaceLoad = [&](EVT LoadVT, LoadSDNode *OriginalLoad,
+                         SDValue EltIndex) {
+    unsigned Align = OriginalLoad->getAlignment();
+    if (NewLoad) {
+      // Check the resultant load doesn't need a higher alignment than the
+      // original load.
+      unsigned NewAlign =
+          TLI.getDataLayout()
+              ->getABITypeAlignment(LoadVT.getTypeForEVT(*DAG.getContext()));
+
+      if (NewAlign > Align || !TLI.isOperationLegalOrCustom(ISD::LOAD, LoadVT))
+        return SDValue();
+
+      Align = NewAlign;
+    }
+
+    SDValue NewPtr = OriginalLoad->getBasePtr();
+    SDValue Offset;
+    EVT PtrType = NewPtr.getValueType();
+    MachinePointerInfo MPI;
+    if (ConstEltNo) {
+      int Elt = cast<ConstantSDNode>(EltIndex)->getZExtValue();
+      unsigned PtrOff = LoadVT.getSizeInBits() * Elt / 8;
+      // Element 0 keeps offset 0, as it did before this code was factored
+      // into a lambda; subtracting 0 from the vector size would start the
+      // load one full vector past the base pointer.
+      if (Elt && TLI.isBigEndian())
+        PtrOff = VT.getSizeInBits() / 8 - PtrOff;
+      Offset = DAG.getConstant(PtrOff, PtrType);
+      MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
+    } else {
+      // With a variable index the big-endian adjustment cannot special-case
+      // element 0 and would address memory past the end of the vector, so
+      // only fold on little-endian targets.
+      if (TLI.isBigEndian())
+        return SDValue();
+
+      // ISD::ADD needs both operands in the pointer type, so convert the
+      // index to PtrType before scaling it by the element size.
+      Offset = DAG.getNode(
+          ISD::MUL, SDLoc(N), PtrType,
+          DAG.getZExtOrTrunc(EltIndex, SDLoc(N), PtrType),
+          DAG.getConstant(LoadVT.getStoreSize(), PtrType));
+      MPI = OriginalLoad->getPointerInfo();
+    }
+    NewPtr = DAG.getNode(ISD::ADD, SDLoc(N), PtrType, NewPtr, Offset);
+
+    // The replacement we need to do here is a little tricky: we need to
+    // replace an extractelement of a load with a load.
+    // Use ReplaceAllUsesOfValuesWith to do the replacement.
+    // Note that this replacement assumes that the extractvalue is the only
+    // use of the load; that's okay because we don't want to perform this
+    // transformation in other cases anyway.
+    SDValue Load;
+    SDValue Chain;
+    if (NVT.bitsGT(LoadVT)) {
+      // If the result type of vextract is wider than the load, then issue an
+      // extending load instead.
+      ISD::LoadExtType ExtType = TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadVT)
+                                     ? ISD::ZEXTLOAD
+                                     : ISD::EXTLOAD;
+      Load = DAG.getExtLoad(ExtType, SDLoc(N), NVT, OriginalLoad->getChain(),
+                            NewPtr, MPI, LoadVT, OriginalLoad->isVolatile(),
+                            OriginalLoad->isNonTemporal(), Align,
+                            OriginalLoad->getTBAAInfo());
+      Chain = Load.getValue(1);
+    } else {
+      Load = DAG.getLoad(
+          LoadVT, SDLoc(N), OriginalLoad->getChain(), NewPtr, MPI,
+          OriginalLoad->isVolatile(), OriginalLoad->isNonTemporal(),
+          OriginalLoad->isInvariant(), Align, OriginalLoad->getTBAAInfo());
+      Chain = Load.getValue(1);
+      if (NVT.bitsLT(LoadVT))
+        Load = DAG.getNode(ISD::TRUNCATE, SDLoc(N), NVT, Load);
+      else
+        Load = DAG.getNode(ISD::BITCAST, SDLoc(N), NVT, Load);
+    }
+    WorkListRemover DeadNodes(*this);
+    SDValue From[] = { SDValue(N, 0), SDValue(OriginalLoad, 1) };
+    SDValue To[] = { Load, Chain };
+    DAG.ReplaceAllUsesOfValuesWith(From, To, 2);
+    // Since we're explicitly calling ReplaceAllUses, add the new node to the
+    // worklist explicitly as well.
+    AddToWorkList(Load.getNode());
+    AddUsersToWorkList(Load.getNode()); // Add users too
+    // Make sure to revisit this node to clean it up; it will usually be dead.
+    AddToWorkList(N);
+    ++OpsNarrowed;
+    return SDValue(N, 0);
+  };
+
+  // (vextract (vN[if]M load $addr), i) -> ([if]M load $addr + i * size)
+  if (!LegalOperations && !ConstEltNo && InVec.hasOneUse() &&
+      ISD::isNormalLoad(InVec.getNode())) {
+    SDValue Index = N->getOperand(1);
+    LoadSDNode *OrigLoad = cast<LoadSDNode>(InVec);
+    // Narrowing changes the width of the memory access, so leave volatile
+    // loads alone.
+    if (!OrigLoad->isVolatile())
+      return ReplaceLoad(LVT, OrigLoad, Index);
+  }
+
   // Perform only after legalization to ensure build_vector / vector_shuffle
   // optimizations have already been done.
   if (!LegalOperations) return SDValue();
@@ -9748,30 +9879,6 @@
 
   if (ConstEltNo) {
     int Elt = cast<ConstantSDNode>(EltNo)->getZExtValue();
-    bool NewLoad = false;
-    bool BCNumEltsChanged = false;
-    EVT ExtVT = VT.getVectorElementType();
-    EVT LVT = ExtVT;
-
-    // If the result of load has to be truncated, then it's not necessarily
-    // profitable.
-    if (NVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, NVT))
-      return SDValue();
-
-    if (InVec.getOpcode() == ISD::BITCAST) {
-      // Don't duplicate a load with other uses.
-      if (!InVec.hasOneUse())
-        return SDValue();
-
-      EVT BCVT = InVec.getOperand(0).getValueType();
-      if (!BCVT.isVector() || ExtVT.bitsGT(BCVT.getVectorElementType()))
-        return SDValue();
-      if (VT.getVectorNumElements() != BCVT.getVectorNumElements())
-        BCNumEltsChanged = true;
-      InVec = InVec.getOperand(0);
-      ExtVT = BCVT.getVectorElementType();
-      NewLoad = true;
-    }
 
     LoadSDNode *LN0 = nullptr;
     const ShuffleVectorSDNode *SVN = nullptr;
@@ -9814,6 +9921,7 @@
       if (ISD::isNormalLoad(InVec.getNode())) {
         LN0 = cast<LoadSDNode>(InVec);
         Elt = (Idx < (int)NumElems) ? Idx : Idx - (int)NumElems;
+        EltNo = DAG.getConstant(Elt, EltNo.getValueType());
       }
     }
 
@@ -9826,72 +9934,7 @@
     if (Elt == -1)
       return DAG.getUNDEF(LVT);
 
-    unsigned Align = LN0->getAlignment();
-    if (NewLoad) {
-      // Check the resultant load doesn't need a higher alignment than the
-      // original load.
-      unsigned NewAlign =
-        TLI.getDataLayout()
-            ->getABITypeAlignment(LVT.getTypeForEVT(*DAG.getContext()));
-
-      if (NewAlign > Align || !TLI.isOperationLegalOrCustom(ISD::LOAD, LVT))
-        return SDValue();
-
-      Align = NewAlign;
-    }
-
-    SDValue NewPtr = LN0->getBasePtr();
-    unsigned PtrOff = 0;
-
-    if (Elt) {
-      PtrOff = LVT.getSizeInBits() * Elt / 8;
-      EVT PtrType = NewPtr.getValueType();
-      if (TLI.isBigEndian())
-        PtrOff = VT.getSizeInBits() / 8 - PtrOff;
-      NewPtr = DAG.getNode(ISD::ADD, SDLoc(N), PtrType, NewPtr,
-                           DAG.getConstant(PtrOff, PtrType));
-    }
-
-    // The replacement we need to do here is a little tricky: we need to
-    // replace an extractelement of a load with a load.
-    // Use ReplaceAllUsesOfValuesWith to do the replacement.
-    // Note that this replacement assumes that the extractvalue is the only
-    // use of the load; that's okay because we don't want to perform this
-    // transformation in other cases anyway.
-    SDValue Load;
-    SDValue Chain;
-    if (NVT.bitsGT(LVT)) {
-      // If the result type of vextract is wider than the load, then issue an
-      // extending load instead.
-      ISD::LoadExtType ExtType = TLI.isLoadExtLegal(ISD::ZEXTLOAD, LVT)
-        ? ISD::ZEXTLOAD : ISD::EXTLOAD;
-      Load = DAG.getExtLoad(ExtType, SDLoc(N), NVT, LN0->getChain(),
-                            NewPtr, LN0->getPointerInfo().getWithOffset(PtrOff),
-                            LVT, LN0->isVolatile(), LN0->isNonTemporal(),
-                            Align, LN0->getTBAAInfo());
-      Chain = Load.getValue(1);
-    } else {
-      Load = DAG.getLoad(LVT, SDLoc(N), LN0->getChain(), NewPtr,
-                         LN0->getPointerInfo().getWithOffset(PtrOff),
-                         LN0->isVolatile(), LN0->isNonTemporal(),
-                         LN0->isInvariant(), Align, LN0->getTBAAInfo());
-      Chain = Load.getValue(1);
-      if (NVT.bitsLT(LVT))
-        Load = DAG.getNode(ISD::TRUNCATE, SDLoc(N), NVT, Load);
-      else
-        Load = DAG.getNode(ISD::BITCAST, SDLoc(N), NVT, Load);
-    }
-    WorkListRemover DeadNodes(*this);
-    SDValue From[] = { SDValue(N, 0), SDValue(LN0,1) };
-    SDValue To[] = { Load, Chain };
-    DAG.ReplaceAllUsesOfValuesWith(From, To, 2);
-    // Since we're explcitly calling ReplaceAllUses, add the new node to the
-    // worklist explicitly as well.
-    AddToWorkList(Load.getNode());
-    AddUsersToWorkList(Load.getNode()); // Add users too
-    // Make sure to revisit this node to clean it up; it will usually be dead.
-    AddToWorkList(N);
-    return SDValue(N, 0);
+    return ReplaceLoad(LVT, LN0, EltNo);
   }
 
   return SDValue();
Index: test/CodeGen/X86/vec_splat.ll
===================================================================
--- test/CodeGen/X86/vec_splat.ll
+++ test/CodeGen/X86/vec_splat.ll
@@ -1,5 +1,6 @@
 ; RUN: llc < %s -march=x86 -mcpu=pentium4 -mattr=+sse2 | FileCheck %s -check-prefix=SSE2
 ; RUN: llc < %s -march=x86 -mcpu=pentium4 -mattr=+sse3 | FileCheck %s -check-prefix=SSE3
+; RUN: llc < %s -march=x86-64 -mattr=+avx | FileCheck %s -check-prefix=AVX
 
 define void @test_v4sf(<4 x float>* %P, <4 x float>* %Q, float %X) nounwind {
 	%tmp = insertelement <4 x float> zeroinitializer, float %X, i32 0		; <<4 x float>> [#uses=1]
@@ -32,3 +33,20 @@
 ; SSE3-LABEL: test_v2sd:
 ; SSE3: movddup
 }
+
+; Fold extract of a load into the load's address computation. This avoids spilling to the stack.
+define <4 x float> @load_extract_splat(<4 x float>* nocapture readonly %ptr, i64 %i, i64 %j) nounwind {
+  %1 = getelementptr inbounds <4 x float>* %ptr, i64 %i
+  %2 = load <4 x float>* %1, align 16
+  %3 = trunc i64 %j to i32
+  %4 = extractelement <4 x float> %2, i32 %3
+  %5 = insertelement <4 x float> undef, float %4, i32 0
+  %6 = insertelement <4 x float> %5, float %4, i32 1
+  %7 = insertelement <4 x float> %6, float %4, i32 2
+  %8 = insertelement <4 x float> %7, float %4, i32 3
+  ret <4 x float> %8
+
+; AVX-LABEL: load_extract_splat:
+; AVX-NOT: rsp
+; AVX: vbroadcastss
+}