diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
--- a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
+++ b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
@@ -9887,8 +9887,7 @@
   bool IsPermutedLoad = false;
   const SDValue *InputLoad = getNormalLoadInput(V1, IsPermutedLoad);
   if (InputLoad && Subtarget.hasVSX() && V2.isUndef() &&
-      (PPC::isSplatShuffleMask(SVOp, 4) || PPC::isSplatShuffleMask(SVOp, 8)) &&
-      InputLoad->hasOneUse()) {
+      (PPC::isSplatShuffleMask(SVOp, 4) || PPC::isSplatShuffleMask(SVOp, 8))) {
     bool IsFourByte = PPC::isSplatShuffleMask(SVOp, 4);
     int SplatIdx =
         PPC::getSplatIdxForPPCMnemonics(SVOp, IsFourByte ? 4 : 8, DAG);
@@ -9904,9 +9903,100 @@
              "Splat of a value outside of the loaded memory");
     }
 
+    // If the load has multiple uses, we can still use load-and-splat as long
+    // as every use only reads data from the splat region. To identify all
+    // uses of the load, look through bitcasts until a vector_shuffle node is
+    // reached; if any other kind of node is reached, give up.
+    bool LSCandidate = true;
+    bool MultiUsePresent = false;
+    bool DisableMultiUse = false;
+    SmallVector<SDNode *> WorkList;
+    SmallPtrSet<SDNode *, 16> Visited;
+
+    for (SDNode *U : (*InputLoad)->uses()) {
+      if (Visited.count(U))
+        continue;
+      Visited.insert(U);
+      WorkList.push_back(U);
+    }
+    while (!WorkList.empty()) {
+      auto *N = WorkList.pop_back_val();
+      if (N->getOpcode() == ISD::BITCAST) {
+        // Look through the bitcast.
+        for (SDNode *U : N->uses()) {
+          if (Visited.count(U))
+            continue;
+          Visited.insert(U);
+          WorkList.push_back(U);
+        }
+      } else if (auto *SV = dyn_cast<ShuffleVectorSDNode>(N)) {
+        if (SV == SVOp)
+          continue;
+
+        // Check whether the load-and-splat covers all of the data referenced
+        // by this other vector_shuffle.
+        auto IsDataFromSplat = [](auto *SV, bool IsFourByte,
+                                  bool isLittleEndian, unsigned SplatIdx,
+                                  unsigned Offset) -> bool {
+          EVT VT = SV->getValueType(0);
+          unsigned ElemSize = VT.getScalarSizeInBits() / 8;
+          unsigned NumElem = VT.getVectorNumElements();
+          unsigned StartByteIdxForPPC =
+              SplatIdx * (IsFourByte ? 4 : 8) + Offset;
+          unsigned StopByteIdxForPPC =
+              (SplatIdx + 1) * (IsFourByte ? 4 : 8) + Offset;
+          for (unsigned i = 0; i < NumElem; ++i) {
+            unsigned ElemStartByteIdxForPPC;
+            unsigned ElemStopByteIdxForPPC;
+            unsigned ByteSeq = SV->getMaskElt(i) * ElemSize;
+            if (isLittleEndian) {
+              if (ByteSeq < 16) {
+                ElemStopByteIdxForPPC = 16 - ByteSeq;
+              } else {
+                ElemStopByteIdxForPPC = 16 + (16 - (ByteSeq - 16));
+              }
+              ElemStartByteIdxForPPC = ElemStopByteIdxForPPC - ElemSize;
+            } else {
+              ElemStartByteIdxForPPC = ByteSeq;
+              ElemStopByteIdxForPPC = ElemStartByteIdxForPPC + ElemSize;
+            }
+            if ((Offset <= ElemStartByteIdxForPPC &&
+                 ElemStopByteIdxForPPC < (Offset + 16)) &&
+                !(StartByteIdxForPPC <= ElemStartByteIdxForPPC &&
+                  ElemStopByteIdxForPPC <= StopByteIdxForPPC))
+              return false;
+          }
+          return true;
+        };
+
+        MultiUsePresent = true;
+        SDValue V1 = SV->getOperand(0);
+        SDValue V2 = SV->getOperand(1);
+        bool IsPermuted = false;
+        const SDValue *V1InputLoad = getNormalLoadInput(V1, IsPermuted);
+        const SDValue *V2InputLoad = getNormalLoadInput(V2, IsPermuted);
+        if (V1InputLoad && *V1InputLoad == *InputLoad)
+          if (!IsDataFromSplat(SV, IsFourByte, isLittleEndian, SplatIdx, 0))
+            LSCandidate = false;
+        if (V2InputLoad && *V2InputLoad == *InputLoad)
+          if (!IsDataFromSplat(SV, IsFourByte, isLittleEndian, SplatIdx, 16))
+            LSCandidate = false;
+      } else
+        DisableMultiUse = true;
+    }
+
+    if (DisableMultiUse) {
+      if (InputLoad->hasOneUse())
+        LSCandidate = true;
+      else
+        LSCandidate = false;
+      MultiUsePresent = false;
+    }
+
     LoadSDNode *LD = cast<LoadSDNode>(*InputLoad);
     // For 4-byte load-and-splat, we need Power9.
-    if ((IsFourByte && Subtarget.hasP9Vector()) || !IsFourByte) {
+    if (LSCandidate &&
+        ((IsFourByte && Subtarget.hasP9Vector()) || !IsFourByte)) {
       uint64_t Offset = 0;
       if (IsFourByte)
         Offset = isLittleEndian ? (3 - SplatIdx) * 4 : SplatIdx * 4;
@@ -9935,6 +10025,9 @@
       DAG.ReplaceAllUsesOfValueWith(InputLoad->getValue(1),
                                     LdSplt.getValue(1));
       if (LdSplt.getValueType() != SVOp->getValueType(0))
         LdSplt = DAG.getBitcast(SVOp->getValueType(0), LdSplt);
+      if (MultiUsePresent)
+        DAG.ReplaceAllUsesOfValueWith(InputLoad->getValue(0),
+                                      LdSplt.getValue(0));
       return LdSplt;
     }
   }
diff --git a/llvm/test/CodeGen/PowerPC/load-and-splat.ll b/llvm/test/CodeGen/PowerPC/load-and-splat.ll
--- a/llvm/test/CodeGen/PowerPC/load-and-splat.ll
+++ b/llvm/test/CodeGen/PowerPC/load-and-splat.ll
@@ -1401,3 +1401,130 @@
   %2 = bitcast <8 x i16> %1 to <4 x i32>
   ret <4 x i32> %2
 }
+
+define <8 x float> @test_splat_multiuseW(<8 x float>* %vp) {
+; P9-LABEL: test_splat_multiuseW:
+; P9:       # %bb.0: # %entry
+; P9-NEXT:    lxv v3, 16(r3)
+; P9-NEXT:    lxvwsx v2, 0, r3
+; P9-NEXT:    addis r3, r2, .LCPI26_0@toc@ha
+; P9-NEXT:    addi r3, r3, .LCPI26_0@toc@l
+; P9-NEXT:    lxv v4, 0(r3)
+; P9-NEXT:    vperm v3, v3, v2, v4
+; P9-NEXT:    blr
+;
+; P8-LABEL: test_splat_multiuseW:
+; P8:       # %bb.0: # %entry
+; P8-NEXT:    li r4, 16
+; P8-NEXT:    addis r5, r2, .LCPI26_0@toc@ha
+; P8-NEXT:    lxvd2x vs1, 0, r3
+; P8-NEXT:    lxvd2x vs0, r3, r4
+; P8-NEXT:    addi r4, r5, .LCPI26_0@toc@l
+; P8-NEXT:    lxvd2x vs2, 0, r4
+; P8-NEXT:    xxswapd v4, vs1
+; P8-NEXT:    xxswapd v2, vs0
+; P8-NEXT:    xxswapd v3, vs2
+; P8-NEXT:    vperm v3, v4, v2, v3
+; P8-NEXT:    xxspltw v2, v4, 3
+; P8-NEXT:    blr
+;
+; P7-LABEL: test_splat_multiuseW:
+; P7:       # %bb.0: # %entry
+; P7-NEXT:    li r4, 16
+; P7-NEXT:    addis r5, r2, .LCPI26_0@toc@ha
+; P7-NEXT:    lxvw4x v4, 0, r3
+; P7-NEXT:    lxvw4x v2, r3, r4
+; P7-NEXT:    addi r4, r5, .LCPI26_0@toc@l
+; P7-NEXT:    lxvw4x v3, 0, r4
+; P7-NEXT:    vperm v3, v4, v2, v3
+; P7-NEXT:    xxspltw v2, v4, 0
+; P7-NEXT:    blr
+;
+; P9-AIX32-LABEL: test_splat_multiuseW:
+; P9-AIX32:       # %bb.0: # %entry
+; P9-AIX32-NEXT:    lwz r4, L..C3(r2) # %const.0
+; P9-AIX32-NEXT:    lxv v3, 16(r3)
+; P9-AIX32-NEXT:    lxvwsx v2, 0, r3
+; P9-AIX32-NEXT:    lxv v4, 0(r4)
+; P9-AIX32-NEXT:    vperm v3, v3, v2, v4
+; P9-AIX32-NEXT:    blr
+;
+; P8-AIX32-LABEL: test_splat_multiuseW:
+; P8-AIX32:       # %bb.0: # %entry
+; P8-AIX32-NEXT:    lwz r4, L..C3(r2) # %const.0
+; P8-AIX32-NEXT:    li r5, 16
+; P8-AIX32-NEXT:    lxvw4x v4, 0, r3
+; P8-AIX32-NEXT:    lxvw4x v2, r3, r5
+; P8-AIX32-NEXT:    lxvw4x v3, 0, r4
+; P8-AIX32-NEXT:    vperm v3, v2, v4, v3
+; P8-AIX32-NEXT:    xxspltw v2, v4, 0
+; P8-AIX32-NEXT:    blr
+;
+; P7-AIX32-LABEL: test_splat_multiuseW:
+; P7-AIX32:       # %bb.0: # %entry
+; P7-AIX32-NEXT:    lwz r4, L..C3(r2) # %const.0
+; P7-AIX32-NEXT:    li r5, 16
+; P7-AIX32-NEXT:    lxvw4x v4, 0, r3
+; P7-AIX32-NEXT:    lxvw4x v2, r3, r5
+; P7-AIX32-NEXT:    lxvw4x v3, 0, r4
+; P7-AIX32-NEXT:    vperm v3, v4, v2, v3
+; P7-AIX32-NEXT:    xxspltw v2, v4, 0
+; P7-AIX32-NEXT:    blr
+entry:
+  %vec = load <8 x float>, <8 x float>* %vp
+  %res = shufflevector <8 x float> %vec, <8 x float> undef, <8 x i32> 
+  ret <8 x float> %res
+}
+
+define <4 x double> @test_splat_multiuseD(<4 x double>* %vp) {
+; P9-LABEL: test_splat_multiuseD:
+; P9:       # %bb.0: # %entry
+; P9-NEXT:    lxvdsx v2, 0, r3
+; P9-NEXT:    lxv vs0, 16(r3)
+; P9-NEXT:    xxmrghd v3, vs0, v2
+; P9-NEXT:    blr
+;
+; P8-LABEL: test_splat_multiuseD:
+; P8:       # %bb.0: # %entry
+; P8-NEXT:    li r4, 16
+; P8-NEXT:    lxvdsx v2, 0, r3
+; P8-NEXT:    lxvd2x vs0, r3, r4
+; P8-NEXT:    xxswapd vs0, vs0
+; P8-NEXT:    xxmrgld v3, v2, vs0
+; P8-NEXT:    blr
+;
+; P7-LABEL: test_splat_multiuseD:
+; P7:       # %bb.0: # %entry
+; P7-NEXT:    li r4, 16
+; P7-NEXT:    lxvdsx v2, 0, r3
+; P7-NEXT:    lxvd2x vs0, r3, r4
+; P7-NEXT:    xxmrghd v3, vs0, v2
+; P7-NEXT:    blr
+;
+; P9-AIX32-LABEL: test_splat_multiuseD:
+; P9-AIX32:       # %bb.0: # %entry
+; P9-AIX32-NEXT:    lxvdsx v2, 0, r3
+; P9-AIX32-NEXT:    lxv vs0, 16(r3)
+; P9-AIX32-NEXT:    xxmrghd v3, vs0, v2
+; P9-AIX32-NEXT:    blr
+;
+; P8-AIX32-LABEL: test_splat_multiuseD:
+; P8-AIX32:       # %bb.0: # %entry
+; P8-AIX32-NEXT:    li r4, 16
+; P8-AIX32-NEXT:    lxvdsx v2, 0, r3
+; P8-AIX32-NEXT:    lxvd2x vs0, r3, r4
+; P8-AIX32-NEXT:    xxmrghd v3, vs0, v2
+; P8-AIX32-NEXT:    blr
+;
+; P7-AIX32-LABEL: test_splat_multiuseD:
+; P7-AIX32:       # %bb.0: # %entry
+; P7-AIX32-NEXT:    li r4, 16
+; P7-AIX32-NEXT:    lxvdsx v2, 0, r3
+; P7-AIX32-NEXT:    lxvd2x vs0, r3, r4
+; P7-AIX32-NEXT:    xxmrghd v3, vs0, v2
+; P7-AIX32-NEXT:    blr
entry:
+  %vec = load <4 x double>, <4 x double>* %vp
+  %res = shufflevector <4 x double> %vec, <4 x double> undef, <4 x i32> 
+  ret <4 x double> %res
+}