diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -898,8 +898,7 @@
                        ISD::VECTOR_SPLICE, ISD::SIGN_EXTEND_INREG,
                        ISD::CONCAT_VECTORS, ISD::EXTRACT_SUBVECTOR,
                        ISD::INSERT_SUBVECTOR, ISD::STORE, ISD::BUILD_VECTOR});
-  if (Subtarget->supportsAddressTopByteIgnored())
-    setTargetDAGCombine(ISD::LOAD);
+  setTargetDAGCombine(ISD::LOAD);
 
   setTargetDAGCombine(ISD::MSTORE);
 
@@ -17808,6 +17807,76 @@
   return SDValue();
 }
 
+static SDValue performLOADCombine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  SelectionDAG &DAG,
+                                  const AArch64Subtarget *Subtarget) {
+  LoadSDNode *LD = cast<LoadSDNode>(N);
+  if (!LD->isNonTemporal())
+    return SDValue(N, 0);
+
+  EVT MemVT = LD->getMemoryVT();
+  if (MemVT.isScalableVector() || MemVT.getSizeInBits() <= 256 ||
+      MemVT.getSizeInBits() % 256 == 0 ||
+      256 % MemVT.getScalarSizeInBits() != 0)
+    return SDValue(N, 0);
+
+  SDLoc DL(LD);
+  SDValue Chain = LD->getChain();
+  SDValue Ptr = LD->getBasePtr();
+  SDNodeFlags Flags = LD->getFlags();
+  SmallVector<SDValue, 4> LoadOps;
+  // Replace any non-temporal load wider than 256 bits with a series of 256-bit
+  // loads and a scalar/vector load for the remainder. This way we can use
+  // 256-bit loads and reduce the number of load instructions generated.
+  MVT NewVT =
+      MVT::getVectorVT(MemVT.getVectorElementType().getSimpleVT(),
+                       256 / MemVT.getVectorElementType().getSizeInBits());
+  int Num256Loads = MemVT.getSizeInBits() / 256;
+  // Create all 256-bit loads, starting from offset 0 and going up to
+  // (Num256Loads - 1) * 32 bytes.
+  for (int I = 0; I < Num256Loads; I++) {
+    unsigned PtrOffset = I * 32;
+    SDValue NewPtr =
+        DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(PtrOffset), DL, Flags);
+    Align NewAlign = commonAlignment(LD->getAlign(), PtrOffset);
+    SDValue NewLoad = DAG.getLoad(
+        NewVT, DL, Chain, NewPtr,
+        LD->getPointerInfo().getWithOffset(PtrOffset), NewAlign,
+        LD->getMemOperand()->getFlags(), LD->getAAInfo());
+    LoadOps.push_back(NewLoad);
+  }
+
+  // Process the remaining bits of the load operation. This is done by creating
+  // an UNDEF vector that matches the size of the 256-bit loads and inserting
+  // the remaining load into it. The original load type is extracted at the end
+  // using an EXTRACT_SUBVECTOR instruction.
+  unsigned BitsRemaining = MemVT.getSizeInBits() % 256;
+  unsigned PtrOffset = (MemVT.getSizeInBits() - BitsRemaining) / 8;
+  MVT RemainingVT = MVT::getVectorVT(
+      MemVT.getVectorElementType().getSimpleVT(),
+      BitsRemaining / MemVT.getVectorElementType().getSizeInBits());
+  SDValue NewPtr =
+      DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(PtrOffset), DL, Flags);
+  Align NewAlign = commonAlignment(LD->getAlign(), PtrOffset);
+  SDValue RemainingLoad =
+      DAG.getLoad(RemainingVT, DL, Chain, NewPtr,
+                  LD->getPointerInfo().getWithOffset(PtrOffset), NewAlign,
+                  LD->getMemOperand()->getFlags(), LD->getAAInfo());
+  SDValue UndefVector = DAG.getUNDEF(NewVT);
+  SDValue InsertIdx = DAG.getVectorIdxConstant(0, DL);
+  SDValue ExtendedRemainingLoad =
+      DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewVT,
+                  {UndefVector, RemainingLoad, InsertIdx});
+  LoadOps.push_back(ExtendedRemainingLoad);
+
+  EVT ConcatVectorType =
+      EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(),
+                       LoadOps.size() * NewVT.getVectorNumElements());
+  SDValue ConcatVectors =
+      DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVectorType, LoadOps);
+  // Extract a subvector of the original vector type.
+  SDValue ExtractSubVector =
+      DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MemVT,
+                  {ConcatVectors, DAG.getVectorIdxConstant(0, DL)});
+  return DAG.getMergeValues({ExtractSubVector, Chain}, DL);
+}
+
 static SDValue performSTORECombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    SelectionDAG &DAG,
@@ -19857,9 +19926,9 @@
   case ISD::SETCC:
     return performSETCCCombine(N, DCI, DAG);
   case ISD::LOAD:
-    if (performTBISimplification(N->getOperand(1), DCI, DAG))
+    if (Subtarget->supportsAddressTopByteIgnored() &&
+        performTBISimplification(N->getOperand(1), DCI, DAG))
       return SDValue(N, 0);
-    break;
+    return performLOADCombine(N, DCI, DAG, Subtarget);
   case ISD::STORE:
     return performSTORECombine(N, DCI, DAG, Subtarget);
   case ISD::MSTORE:
diff --git a/llvm/test/CodeGen/AArch64/nontemporal-load.ll b/llvm/test/CodeGen/AArch64/nontemporal-load.ll
--- a/llvm/test/CodeGen/AArch64/nontemporal-load.ll
+++ b/llvm/test/CodeGen/AArch64/nontemporal-load.ll
@@ -205,12 +205,12 @@
 define <17 x float> @test_ldnp_v17f32(<17 x float>* %A) {
 ; CHECK-LABEL: test_ldnp_v17f32:
 ; CHECK:       ; %bb.0:
-; CHECK-NEXT:    ldp q1, q2, [x0, #32]
-; CHECK-NEXT:    ldp q3, q4, [x0]
-; CHECK-NEXT:    ldr s0, [x0, #64]
-; CHECK-NEXT:    stp q3, q4, [x8]
-; CHECK-NEXT:    stp q1, q2, [x8, #32]
-; CHECK-NEXT:    str s0, [x8, #64]
+; CHECK-NEXT:    ldnp q0, q1, [x0, #32]
+; CHECK-NEXT:    ldnp q2, q3, [x0]
+; CHECK-NEXT:    ldr s4, [x0, #64]
+; CHECK-NEXT:    stp q0, q1, [x8, #32]
+; CHECK-NEXT:    stp q2, q3, [x8]
+; CHECK-NEXT:    str s4, [x8, #64]
 ; CHECK-NEXT:    ret
   %lv = load <17 x float>, <17 x float>* %A, align 8, !nontemporal !0
   ret <17 x float> %lv
@@ -219,24 +219,24 @@
 define <33 x double> @test_ldnp_v33f64(<33 x double>* %A) {
 ; CHECK-LABEL: test_ldnp_v33f64:
 ; CHECK:       ; %bb.0:
-; CHECK-NEXT:    ldp q0, q1, [x0]
-; CHECK-NEXT:    ldp q2, q3, [x0, #32]
-; CHECK-NEXT:    ldp q4, q5, [x0, #64]
-; CHECK-NEXT:    ldp q6, q7, [x0, #96]
-; CHECK-NEXT:    ldp q16, q17, [x0, #128]
-; CHECK-NEXT:    ldp q18, q19, [x0, #160]
-; CHECK-NEXT:    ldp q21, q22, [x0, #224]
-; CHECK-NEXT:    ldp q23, q24, [x0, #192]
-; CHECK-NEXT:    ldr d20, [x0, #256]
+; CHECK-NEXT:    ldnp q0, q1, [x0]
+; CHECK-NEXT:    ldnp q2, q3, [x0, #32]
+; CHECK-NEXT:    ldnp q4, q5, [x0, #64]
+; CHECK-NEXT:    ldnp q6, q7, [x0, #96]
+; CHECK-NEXT:    ldnp q16, q17, [x0, #128]
+; CHECK-NEXT:    ldnp q18, q19, [x0, #224]
+; CHECK-NEXT:    ldnp q20, q21, [x0, #192]
+; CHECK-NEXT:    ldnp q22, q23, [x0, #160]
+; CHECK-NEXT:    ldr d24, [x0, #256]
 ; CHECK-NEXT:    stp q0, q1, [x8]
 ; CHECK-NEXT:    stp q2, q3, [x8, #32]
 ; CHECK-NEXT:    stp q4, q5, [x8, #64]
-; CHECK-NEXT:    str d20, [x8, #256]
 ; CHECK-NEXT:    stp q6, q7, [x8, #96]
 ; CHECK-NEXT:    stp q16, q17, [x8, #128]
-; CHECK-NEXT:    stp q18, q19, [x8, #160]
-; CHECK-NEXT:    stp q23, q24, [x8, #192]
-; CHECK-NEXT:    stp q21, q22, [x8, #224]
+; CHECK-NEXT:    stp q22, q23, [x8, #160]
+; CHECK-NEXT:    stp q20, q21, [x8, #192]
+; CHECK-NEXT:    stp q18, q19, [x8, #224]
+; CHECK-NEXT:    str d24, [x8, #256]
 ; CHECK-NEXT:    ret
   %lv = load <33 x double>, <33 x double>* %A, align 8, !nontemporal !0
   ret <33 x double> %lv
@@ -245,10 +245,11 @@
 define <33 x i8> @test_ldnp_v33i8(<33 x i8>* %A) {
 ; CHECK-LABEL: test_ldnp_v33i8:
 ; CHECK:       ; %bb.0:
-; CHECK-NEXT:    ldp q1, q0, [x0]
-; CHECK-NEXT:    ldrb w9, [x0, #32]
-; CHECK-NEXT:    stp q1, q0, [x8]
-; CHECK-NEXT:    strb w9, [x8, #32]
+; CHECK-NEXT:    ldnp q0, q1, [x0]
+; CHECK-NEXT:    add x9, x8, #32
+; CHECK-NEXT:    ldr b2, [x0, #32]
+; CHECK-NEXT:    stp q0, q1, [x8]
+; CHECK-NEXT:    st1.b { v2 }[0], [x9]
 ; CHECK-NEXT:    ret
   %lv = load <33 x i8>, <33 x i8>* %A, align 8, !nontemporal !0
   ret <33 x i8> %lv
@@ -295,15 +296,14 @@
 define <5 x double> @test_ldnp_v5f64(<5 x double>* %A) {
 ; CHECK-LABEL: test_ldnp_v5f64:
 ; CHECK:       ; %bb.0:
-; CHECK-NEXT:    ldp q0, q2, [x0]
+; CHECK-NEXT:    ldnp q0, q2, [x0]
+; CHECK-NEXT:    ldr d4, [x0, #32]
 ; CHECK-NEXT:    ext.16b v1, v0, v0, #8
 ; CHECK-NEXT:    ; kill: def $d0 killed $d0 killed $q0
 ; CHECK-NEXT:    ; kill: def $d1 killed $d1 killed $q1
 ; CHECK-NEXT:    ext.16b v3, v2, v2, #8
-; CHECK-NEXT:    ldr d4, [x0, #32]
 ; CHECK-NEXT:    ; kill: def $d2 killed $d2 killed $q2
 ; CHECK-NEXT:    ; kill: def $d3 killed $d3 killed $q3
-; CHECK-NEXT:    ; kill: def $d4 killed $d4 killed $q4
 ; CHECK-NEXT:    ret
   %lv = load <5 x double>, <5 x double>* %A, align 8, !nontemporal !0
   ret <5 x double> %lv
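
As a quick illustration (not part of the patch), the combine applies to any non-temporal, fixed-width vector load wider than 256 bits whose total size is not a multiple of 256 bits. A hypothetical test input such as the following (function name is illustrative only) should be split by performLOADCombine into a single 256-bit load plus a <1 x float> load for the remaining 32 bits:

define <9 x float> @example_ldnp_v9f32(<9 x float>* %A) {
  %lv = load <9 x float>, <9 x float>* %A, align 8, !nontemporal !0
  ret <9 x float> %lv
}

!0 = !{i32 1}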