diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -7271,10 +7271,42 @@
     // the smaller value for free with a truncate.
     SDValue Value = MemSetValue;
     if (VT.bitsLT(LargestVT)) {
+      // Helper to query the least-cost index for canCombineStoreAndExtract.
+      const auto QueryIndex = [](const TargetLowering &TLI, SelectionDAG &DAG,
+                                 EVT Type, unsigned NumElts, unsigned &Index) {
+        Index = -1U;
+        unsigned Cost = std::numeric_limits<unsigned>::max();
+        bool Ret = false;
+        for (unsigned i = 0; i < NumElts; ++i) {
+          unsigned TmpC;
+          if (TLI.canCombineStoreAndExtract(
+                  Type.getTypeForEVT(*DAG.getContext()),
+                  ConstantInt::get(*DAG.getContext(), APInt(8, i)), TmpC) &&
+              TmpC < Cost) {
+            Cost = TmpC;
+            Index = i;
+            Ret = true;
+          }
+        }
+        return Ret;
+      };
+
+      unsigned Index;
+      unsigned NElts = LargestVT.getSizeInBits() / VT.getSizeInBits();
+      EVT SVT = EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(), NElts);
       if (!LargestVT.isVector() && !VT.isVector() &&
           TLI.isTruncateFree(LargestVT, VT))
         Value = DAG.getNode(ISD::TRUNCATE, dl, VT, MemSetValue);
-      else
+      else if (LargestVT.isVector() && !VT.isVector() &&
+               QueryIndex(TLI, DAG, SVT, NElts, Index) &&
+               TLI.isTypeLegal(SVT) &&
+               LargestVT.getSizeInBits() == SVT.getSizeInBits()) {
+        // Targets that can combine store(extractelement VectorTy, Idx) get
+        // the smaller value for free.
+        SDValue TailValue = DAG.getNode(ISD::BITCAST, dl, SVT, MemSetValue);
+        Value = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, TailValue,
+                            DAG.getVectorIdxConstant(Index, dl));
+      } else
         Value = getMemsetValue(Src, VT, DAG, dl);
     }
     assert(Value.getValueType() == VT && "Value with wrong type.");
diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.h b/llvm/lib/Target/PowerPC/PPCISelLowering.h
--- a/llvm/lib/Target/PowerPC/PPCISelLowering.h
+++ b/llvm/lib/Target/PowerPC/PPCISelLowering.h
@@ -804,6 +804,9 @@
       return true;
     }
 
+    bool canCombineStoreAndExtract(Type *VectorTy, Value *Idx,
+                                   unsigned &Cost) const override;
+
     bool isCtlzFast() const override {
       return true;
     }
diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
--- a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
+++ b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
@@ -1608,6 +1608,36 @@
   return VT.isScalarInteger();
 }
 
+bool PPCTargetLowering::canCombineStoreAndExtract(Type *VectorTy, Value *Idx,
+                                                  unsigned &Cost) const {
+  if (!Subtarget.isPPC64() || !Subtarget.hasVSX())
+    return false;
+
+  if (!isa<ConstantInt>(Idx))
+    return false;
+
+  if (auto *VTy = dyn_cast<FixedVectorType>(VectorTy)) {
+    if (VTy->getScalarType()->isIntegerTy()) {
+      unsigned BitWidth = VTy->getScalarSizeInBits();
+      unsigned ElemIdx;
+      // Accept the combine only if the element index matches the one that
+      // can be moved directly from a VSR.
+      if (BitWidth == 32) {
+        ElemIdx = Subtarget.isLittleEndian() ? 2 : 1;
+      } else if (BitWidth == 64) {
+        ElemIdx = Subtarget.isLittleEndian() ? 1 : 0;
+      } else {
+        return false;
+      }
+      if (cast<ConstantInt>(Idx)->getZExtValue() == ElemIdx) {
+        Cost = 1;
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
 const char *PPCTargetLowering::getTargetNodeName(unsigned Opcode) const {
   switch ((PPCISD::NodeType)Opcode) {
   case PPCISD::FIRST_NUMBER: break;
@@ -16789,10 +16819,20 @@
   if (getTargetMachine().getOptLevel() != CodeGenOpt::None) {
     // We should use Altivec/VSX loads and stores when available. For unaligned
    // addresses, unaligned VSX loads are only fast starting with the P8.
-    if (Subtarget.hasAltivec() && Op.size() >= 16 &&
-        (Op.isAligned(Align(16)) ||
-         ((Op.isMemset() && Subtarget.hasVSX()) || Subtarget.hasP8Vector())))
-      return MVT::v4i32;
+    if (Subtarget.hasAltivec() && Op.size() >= 16) {
+      if (Op.isMemset() && Subtarget.hasVSX()) {
+        uint64_t TailSize = Op.size() % 16;
+        // For memset lowering, the tail size needs to differ from the vector
+        // element size so that the tail can be borrowed from the splat
+        // vector; otherwise a constant tail is generated.
+        if (TailSize > 2 && TailSize <= 4) {
+          return MVT::v8i16;
+        }
+        return MVT::v4i32;
+      }
+      if (Op.isAligned(Align(16)) || Subtarget.hasP8Vector())
+        return MVT::v4i32;
+    }
   }
 
   if (Subtarget.isPPC64()) {
diff --git a/llvm/test/CodeGen/ARM/memset-align.ll b/llvm/test/CodeGen/ARM/memset-align.ll
--- a/llvm/test/CodeGen/ARM/memset-align.ll
+++ b/llvm/test/CodeGen/ARM/memset-align.ll
@@ -18,9 +18,9 @@
 ; CHECK-NEXT:    strd r1, r1, [sp]
 ; CHECK-NEXT:    vst1.64 {d16, d17}, [r2]!
 ; CHECK-NEXT:    str r1, [r2]
+; CHECK-NEXT:    add.w r2, r0, #15
+; CHECK-NEXT:    vst1.32 {d16[0]}, [r2]
 ; CHECK-NEXT:    str r1, [sp, #20]
-; CHECK-NEXT:    movs r1, #0
-; CHECK-NEXT:    str.w r1, [sp, #15]
 ; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    add sp, #24
 ; CHECK-NEXT:    pop {r7, pc}
diff --git a/llvm/test/CodeGen/PowerPC/p10-fi-elim.ll b/llvm/test/CodeGen/PowerPC/p10-fi-elim.ll
--- a/llvm/test/CodeGen/PowerPC/p10-fi-elim.ll
+++ b/llvm/test/CodeGen/PowerPC/p10-fi-elim.ll
@@ -26,34 +26,32 @@
 ; CHECK-NEXT:    stdu r1, -80(r1)
 ; CHECK-NEXT:    .cfi_def_cfa_offset 80
 ; CHECK-NEXT:    .cfi_offset lr, 16
-; CHECK-NEXT:    lxv v2, 0(r3)
 ; CHECK-NEXT:    mr r9, r6
 ; CHECK-NEXT:    mr r6, r5
-; CHECK-NEXT:    li r0, 4
-; CHECK-NEXT:    li r11, 3
-; CHECK-NEXT:    std r0, 0(r3)
-; CHECK-NEXT:    stb r11, 0(0)
-; CHECK-NEXT:    li r12, -127
-; CHECK-NEXT:    stb r12, 0(r3)
-; CHECK-NEXT:    li r2, 1
-; CHECK-NEXT:    stb r11, 0(r3)
-; CHECK-NEXT:    stb r12, 0(r3)
-; CHECK-NEXT:    stw r2, 0(r3)
-; CHECK-NEXT:    mfvsrd r5, v2
-; CHECK-NEXT:    vaddudm v3, v2, v2
-; CHECK-NEXT:    pstxv v2, 64(r1), 0
-; CHECK-NEXT:    neg r5, r5
-; CHECK-NEXT:    mfvsrd r10, v3
-; CHECK-NEXT:    std r5, 0(r3)
+; CHECK-NEXT:    li r5, 3
+; CHECK-NEXT:    li r10, -127
+; CHECK-NEXT:    lxv v2, 0(r3)
+; CHECK-NEXT:    stb r5, 0(0)
+; CHECK-NEXT:    stb r10, 0(r3)
+; CHECK-NEXT:    stb r5, 0(r3)
 ; CHECK-NEXT:    lbz r5, 2(r7)
 ; CHECK-NEXT:    mr r7, r9
-; CHECK-NEXT:    neg r10, r10
-; CHECK-NEXT:    std r2, 0(r3)
-; CHECK-NEXT:    std r0, 0(r3)
-; CHECK-NEXT:    std r10, 0(r3)
+; CHECK-NEXT:    li r12, 1
+; CHECK-NEXT:    stb r10, 0(r3)
+; CHECK-NEXT:    stw r12, 0(r3)
+; CHECK-NEXT:    li r11, 4
+; CHECK-NEXT:    std r11, 0(r3)
+; CHECK-NEXT:    vaddudm v4, v2, v2
+; CHECK-NEXT:    vnegd v3, v2
+; CHECK-NEXT:    pstxv v2, 64(r1), 0
 ; CHECK-NEXT:    rlwinm r5, r5, 0, 27, 27
+; CHECK-NEXT:    vnegd v2, v4
+; CHECK-NEXT:    stxsd v3, 0(r3)
+; CHECK-NEXT:    std r12, 0(r3)
+; CHECK-NEXT:    std r11, 0(r3)
 ; CHECK-NEXT:    stb r5, 0(0)
 ; CHECK-NEXT:    lbz r5, 2(r8)
+; CHECK-NEXT:    stxsd v2, 0(r3)
 ; CHECK-NEXT:    rlwinm r5, r5, 0, 27, 27
 ; CHECK-NEXT:    stb r5, 0(r3)
 ; CHECK-NEXT:    li r5, 2
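
Note (not part of the patch): a minimal reproducer sketch for the new lowering
path; the function name, run line, and CPU choice are illustrative assumptions.
A 20-byte memset leaves a 4-byte tail (20 % 16 == 4), so getOptimalMemOpType
returns v8i16 for the wide store, and getMemsetStores materializes the tail i32
as store (extractelement (bitcast <8 x i16> to <4 x i32>), 2) instead of
rematerializing the splat value as a separate constant:

; RUN: llc -mtriple=powerpc64le-unknown-linux-gnu -mcpu=pwr9 < %s
define void @memset20(ptr %p) {
entry:
  ; 16 bytes are stored from the splat vector, then the 4-byte tail is
  ; extracted from the same vector (word element 2 on little endian).
  call void @llvm.memset.p0.i64(ptr %p, i8 1, i64 20, i1 false)
  ret void
}
declare void @llvm.memset.p0.i64(ptr, i8, i64, i1)

On little-endian PPC64 with VSX, word element 2 is the lane that can be moved
or stored directly from a VSR, which is why canCombineStoreAndExtract accepts
only that index for 32-bit elements.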