diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -899,8 +899,7 @@
                        ISD::VECTOR_SPLICE, ISD::SIGN_EXTEND_INREG,
                        ISD::CONCAT_VECTORS, ISD::EXTRACT_SUBVECTOR,
                        ISD::INSERT_SUBVECTOR, ISD::STORE, ISD::BUILD_VECTOR});
-  if (Subtarget->supportsAddressTopByteIgnored())
-    setTargetDAGCombine(ISD::LOAD);
+  setTargetDAGCombine(ISD::LOAD);
 
   setTargetDAGCombine(ISD::MSTORE);
 
@@ -18080,6 +18079,102 @@
   return SDValue();
 }
 
+// Combine for ISD::LOAD. First try TBI (top-byte-ignore) address
+// simplification when the subtarget supports it. Otherwise, on little-endian
+// targets, split a simple, non-extending, unindexed non-temporal load of a
+// fixed-length vector that is wider than 256 bits but not a multiple of 256
+// bits into 256-bit loads plus one narrower load for the tail, so that the
+// 256-bit pieces can be selected as LDNP Q-register pair loads.
+static SDValue performLOADCombine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  SelectionDAG &DAG,
+                                  const AArch64Subtarget *Subtarget) {
+  // A successful TBI rewrite updates N's address operand in place, which may
+  // also CSE N into an existing node, so do not touch N further. Users of the
+  // new address are re-queued by the combiner, so the split below is still
+  // attempted when N is revisited.
+  if (Subtarget->supportsAddressTopByteIgnored() &&
+      performTBISimplification(N->getOperand(1), DCI, DAG))
+    return SDValue(N, 0);
+
+  LoadSDNode *LD = cast<LoadSDNode>(N);
+  // Only plain loads may be split: volatile/atomic loads must remain a single
+  // access, and extending or indexed loads have results that do not match the
+  // (MemVT value, chain) pair produced below.
+  if (!LD->isSimple() || !LD->isNonTemporal() || !ISD::isNormalLoad(LD) ||
+      !Subtarget->isLittleEndian())
+    return SDValue();
+
+  EVT MemVT = LD->getMemoryVT();
+  if (!MemVT.isFixedLengthVector())
+    return SDValue();
+
+  const uint64_t MemBits = MemVT.getFixedSizeInBits();
+  EVT EltVT = MemVT.getVectorElementType();
+  const uint64_t EltBits = EltVT.getFixedSizeInBits();
+  if (MemBits <= 256 || MemBits % 256 == 0 || 256 % EltBits != 0)
+    return SDValue();
+
+  SDLoc DL(LD);
+  LLVMContext &Ctx = *DAG.getContext();
+  SDValue Chain = LD->getChain();
+  SDValue BasePtr = LD->getBasePtr();
+  SDNodeFlags Flags = LD->getFlags();
+  SmallVector<SDValue, 4> LoadOps;
+  SmallVector<SDValue, 4> LoadOpsChain;
+
+  // Replace the load with a series of 256-bit loads followed by one load of
+  // the remaining (< 256) bits. Build the types as EVTs: the piece types need
+  // not be simple (e.g. the tail of a <43 x i8> load is v11i8), and
+  // MVT::getVectorVT would return an invalid type for them.
+  EVT NewVT = EVT::getVectorVT(Ctx, EltVT, 256 / EltBits);
+  const unsigned Num256Loads = MemBits / 256;
+  // Emit the 256-bit loads at byte offsets I * 32 for I in [0, Num256Loads).
+  for (unsigned I = 0; I < Num256Loads; ++I) {
+    unsigned PtrOffset = I * 32;
+    SDValue NewPtr = DAG.getMemBasePlusOffset(
+        BasePtr, TypeSize::Fixed(PtrOffset), DL, Flags);
+    Align NewAlign = commonAlignment(LD->getAlign(), PtrOffset);
+    SDValue NewLoad = DAG.getLoad(
+        NewVT, DL, Chain, NewPtr, LD->getPointerInfo().getWithOffset(PtrOffset),
+        NewAlign, LD->getMemOperand()->getFlags(), LD->getAAInfo());
+    LoadOps.push_back(NewLoad);
+    LoadOpsChain.push_back(NewLoad.getValue(1));
+  }
+
+  // Load the remaining bits and widen them to NewVT by inserting into an UNDEF
+  // vector, so every CONCAT_VECTORS operand has the same type. The original
+  // MemVT is extracted from the concatenation at the end.
+  const uint64_t BitsRemaining = MemBits % 256;
+  const unsigned PtrOffset = (MemBits - BitsRemaining) / 8;
+  EVT RemainingVT = EVT::getVectorVT(Ctx, EltVT, BitsRemaining / EltBits);
+  SDValue NewPtr =
+      DAG.getMemBasePlusOffset(BasePtr, TypeSize::Fixed(PtrOffset), DL, Flags);
+  Align NewAlign = commonAlignment(LD->getAlign(), PtrOffset);
+  SDValue RemainingLoad =
+      DAG.getLoad(RemainingVT, DL, Chain, NewPtr,
+                  LD->getPointerInfo().getWithOffset(PtrOffset), NewAlign,
+                  LD->getMemOperand()->getFlags(), LD->getAAInfo());
+  SDValue UndefVector = DAG.getUNDEF(NewVT);
+  SDValue InsertIdx = DAG.getVectorIdxConstant(0, DL);
+  SDValue ExtendedRemainingLoad =
+      DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewVT,
+                  {UndefVector, RemainingLoad, InsertIdx});
+  LoadOps.push_back(ExtendedRemainingLoad);
+  LoadOpsChain.push_back(RemainingLoad.getValue(1));
+  EVT ConcatVT = EVT::getVectorVT(
+      Ctx, EltVT, LoadOps.size() * NewVT.getVectorNumElements());
+  SDValue ConcatVectors =
+      DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT, LoadOps);
+  // Extract the original vector type size.
+  SDValue ExtractSubVector =
+      DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MemVT,
+                  {ConcatVectors, DAG.getVectorIdxConstant(0, DL)});
+  SDValue TokenFactor =
+      DAG.getNode(ISD::TokenFactor, DL, MVT::Other, LoadOpsChain);
+  return DAG.getMergeValues({ExtractSubVector, TokenFactor}, DL);
+}
+
 static SDValue performSTORECombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    SelectionDAG &DAG,
@@ -20129,9 +20224,7 @@
   case ISD::SETCC:
     return performSETCCCombine(N, DCI, DAG);
   case ISD::LOAD:
-    if (performTBISimplification(N->getOperand(1), DCI, DAG))
-      return SDValue(N, 0);
-    break;
+    return performLOADCombine(N, DCI, DAG, Subtarget);
   case ISD::STORE:
     return performSTORECombine(N, DCI, DAG, Subtarget);
   case ISD::MSTORE:
diff --git a/llvm/test/CodeGen/AArch64/nontemporal-load.ll b/llvm/test/CodeGen/AArch64/nontemporal-load.ll --- a/llvm/test/CodeGen/AArch64/nontemporal-load.ll +++ b/llvm/test/CodeGen/AArch64/nontemporal-load.ll @@ -320,12 +320,12 @@ define <17 x float> @test_ldnp_v17f32(<17 x float>* %A) { ; CHECK-LABEL: test_ldnp_v17f32: ; CHECK: ; %bb.0: -; CHECK-NEXT: ldp q1, q2, [x0, #32] -; CHECK-NEXT: ldp q3, q4, [x0] -; CHECK-NEXT: ldr s0, [x0, #64] -; CHECK-NEXT: stp q3, q4, [x8] -; CHECK-NEXT: stp q1, q2, [x8, #32] -; CHECK-NEXT: str s0, [x8, #64] +; CHECK-NEXT: ldnp q0, q1, [x0, #32] +; CHECK-NEXT: ldnp q2, q3, [x0] +; CHECK-NEXT: ldr s4, [x0, #64] +; CHECK-NEXT: stp q0, q1, [x8, #32] +; CHECK-NEXT: stp q2, q3, [x8] +; CHECK-NEXT: str s4, [x8, #64] ; CHECK-NEXT: ret ; ; CHECK-BE-LABEL: test_ldnp_v17f32: @@ -354,24 +354,24 @@ define <33 x double> @test_ldnp_v33f64(<33 x double>* %A) { ; CHECK-LABEL: test_ldnp_v33f64: ; CHECK: ; %bb.0: -; CHECK-NEXT: ldp q0, q1, [x0] -; CHECK-NEXT: ldp q2, q3, [x0, #32] -; CHECK-NEXT: ldp q4, q5, [x0, #64] -; CHECK-NEXT: ldp q6, q7, [x0, #96] -; CHECK-NEXT: ldp q16, q17, [x0, #128] -; CHECK-NEXT: ldp q18, q19, [x0, #160] -; CHECK-NEXT: ldp q21, q22, [x0, #224] -; CHECK-NEXT: ldp q23, q24, [x0, #192] -; CHECK-NEXT: ldr d20, [x0, 
#256] +; CHECK-NEXT: ldnp q0, q1, [x0] +; CHECK-NEXT: ldnp q2, q3, [x0, #32] +; CHECK-NEXT: ldnp q4, q5, [x0, #64] +; CHECK-NEXT: ldnp q6, q7, [x0, #96] +; CHECK-NEXT: ldnp q16, q17, [x0, #128] +; CHECK-NEXT: ldnp q18, q19, [x0, #224] +; CHECK-NEXT: ldnp q20, q21, [x0, #192] +; CHECK-NEXT: ldnp q22, q23, [x0, #160] +; CHECK-NEXT: ldr d24, [x0, #256] ; CHECK-NEXT: stp q0, q1, [x8] ; CHECK-NEXT: stp q2, q3, [x8, #32] ; CHECK-NEXT: stp q4, q5, [x8, #64] -; CHECK-NEXT: str d20, [x8, #256] ; CHECK-NEXT: stp q6, q7, [x8, #96] ; CHECK-NEXT: stp q16, q17, [x8, #128] -; CHECK-NEXT: stp q18, q19, [x8, #160] -; CHECK-NEXT: stp q23, q24, [x8, #192] -; CHECK-NEXT: stp q21, q22, [x8, #224] +; CHECK-NEXT: stp q22, q23, [x8, #160] +; CHECK-NEXT: stp q20, q21, [x8, #192] +; CHECK-NEXT: stp q18, q19, [x8, #224] +; CHECK-NEXT: str d24, [x8, #256] ; CHECK-NEXT: ret ; ; CHECK-BE-LABEL: test_ldnp_v33f64: @@ -448,10 +448,11 @@ define <33 x i8> @test_ldnp_v33i8(<33 x i8>* %A) { ; CHECK-LABEL: test_ldnp_v33i8: ; CHECK: ; %bb.0: -; CHECK-NEXT: ldp q1, q0, [x0] -; CHECK-NEXT: ldrb w9, [x0, #32] -; CHECK-NEXT: stp q1, q0, [x8] -; CHECK-NEXT: strb w9, [x8, #32] +; CHECK-NEXT: ldnp q0, q1, [x0] +; CHECK-NEXT: add x9, x8, #32 +; CHECK-NEXT: ldr b2, [x0, #32] +; CHECK-NEXT: stp q0, q1, [x8] +; CHECK-NEXT: st1.b { v2 }[0], [x9] ; CHECK-NEXT: ret ; ; CHECK-BE-LABEL: test_ldnp_v33i8: @@ -556,15 +557,14 @@ define <5 x double> @test_ldnp_v5f64(<5 x double>* %A) { ; CHECK-LABEL: test_ldnp_v5f64: ; CHECK: ; %bb.0: -; CHECK-NEXT: ldp q0, q2, [x0] +; CHECK-NEXT: ldnp q0, q2, [x0] +; CHECK-NEXT: ldr d4, [x0, #32] ; CHECK-NEXT: ext.16b v1, v0, v0, #8 ; CHECK-NEXT: ; kill: def $d0 killed $d0 killed $q0 ; CHECK-NEXT: ; kill: def $d1 killed $d1 killed $q1 ; CHECK-NEXT: ext.16b v3, v2, v2, #8 -; CHECK-NEXT: ldr d4, [x0, #32] ; CHECK-NEXT: ; kill: def $d2 killed $d2 killed $q2 ; CHECK-NEXT: ; kill: def $d3 killed $d3 killed $q3 -; CHECK-NEXT: ; kill: def $d4 killed $d4 killed $q4 ; CHECK-NEXT: ret ; ; 
CHECK-BE-LABEL: test_ldnp_v5f64: