diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -898,8 +898,7 @@
                        ISD::VECTOR_SPLICE, ISD::SIGN_EXTEND_INREG,
                        ISD::CONCAT_VECTORS, ISD::EXTRACT_SUBVECTOR,
                        ISD::INSERT_SUBVECTOR, ISD::STORE, ISD::BUILD_VECTOR});
-  if (Subtarget->supportsAddressTopByteIgnored())
-    setTargetDAGCombine(ISD::LOAD);
+  setTargetDAGCombine(ISD::LOAD);
 
   setTargetDAGCombine(ISD::MSTORE);
 
@@ -17808,6 +17807,79 @@
   return SDValue();
 }
 
+// Break up nontemporal loads larger than 256 bits so the LDNP Q-form 256-bit
+// load instruction can be selected.
+static SDValue performLOADCombine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  SelectionDAG &DAG,
+                                  const AArch64Subtarget *Subtarget) {
+  LoadSDNode *LD = cast<LoadSDNode>(N);
+  EVT MemVT = LD->getMemoryVT();
+  if (LD->isVolatile() || !LD->isNonTemporal())
+    return SDValue(N, 0);
+
+  if (MemVT.isScalableVector() || MemVT.getSizeInBits() <= 256 ||
+      MemVT.getSizeInBits() % 256 == 0 ||
+      256 % MemVT.getScalarSizeInBits() != 0)
+    return SDValue(N, 0);
+
+  SDLoc DL(LD);
+  SDValue Chain = LD->getChain();
+  SDValue BasePtr = LD->getBasePtr();
+  SDNodeFlags Flags = LD->getFlags();
+  SmallVector<SDValue, 4> LoadOps;
+  // Replace any non-temporal load over 256 bits with a series of 256-bit
+  // loads and a scalar/vector load of less than 256 bits. This way we can
+  // utilise 256-bit loads and reduce the number of load instructions
+  // generated.
+  MVT NewVT =
+      MVT::getVectorVT(MemVT.getVectorElementType().getSimpleVT(),
+                       256 / MemVT.getVectorElementType().getSizeInBits());
+  unsigned Num256Loads = MemVT.getSizeInBits() / 256;
+  // Create all 256-bit loads, starting from offset 0 and up to
+  // (Num256Loads - 1) * 32 bytes.
+  for (unsigned I = 0; I < Num256Loads; I++) {
+    unsigned PtrOffset = I * 32;
+    SDValue NewPtr = DAG.getMemBasePlusOffset(
+        BasePtr, TypeSize::Fixed(PtrOffset), DL, Flags);
+    Align NewAlign = commonAlignment(LD->getAlign(), PtrOffset);
+    SDValue NewLoad = DAG.getLoad(
+        NewVT, DL, Chain, NewPtr, LD->getPointerInfo().getWithOffset(PtrOffset),
+        NewAlign, LD->getMemOperand()->getFlags(), LD->getAAInfo());
+    LoadOps.push_back(NewLoad);
+  }
+  // Process the remaining bits of the load operation.
+  // This is done by creating an undef vector matching the size of the
+  // 256-bit loads and inserting the remaining load into it. We extract the
+  // original load type at the end using EXTRACT_SUBVECTOR.
+  unsigned BitsRemaining = MemVT.getSizeInBits() % 256;
+  unsigned PtrOffset = (MemVT.getSizeInBits() - BitsRemaining) / 8;
+  MVT RemainingVT = MVT::getVectorVT(
+      MemVT.getVectorElementType().getSimpleVT(),
+      BitsRemaining / MemVT.getVectorElementType().getSizeInBits());
+  SDValue NewPtr =
+      DAG.getMemBasePlusOffset(BasePtr, TypeSize::Fixed(PtrOffset), DL, Flags);
+  Align NewAlign = commonAlignment(LD->getAlign(), PtrOffset);
+  SDValue RemainingLoad =
+      DAG.getLoad(RemainingVT, DL, Chain, NewPtr,
+                  LD->getPointerInfo().getWithOffset(PtrOffset), NewAlign,
+                  LD->getMemOperand()->getFlags(), LD->getAAInfo());
+  SDValue UndefVector = DAG.getUNDEF(NewVT);
+  SDValue InsertIdx = DAG.getVectorIdxConstant(0, DL);
+  SDValue ExtendedRemainingLoad =
+      DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewVT,
+                  {UndefVector, RemainingLoad, InsertIdx});
+  LoadOps.push_back(ExtendedRemainingLoad);
+
+  EVT ConcatVectorType =
+      EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(),
+                       LoadOps.size() * NewVT.getVectorNumElements());
+  SDValue ConcatVectors =
+      DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVectorType, LoadOps);
+  // Extract the original vector type size.
+  SDValue ExtractSubVector =
+      DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MemVT,
+                  {ConcatVectors, DAG.getVectorIdxConstant(0, DL)});
+  return DAG.getMergeValues({ExtractSubVector, Chain}, DL);
+}
+
 static SDValue performSTORECombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    SelectionDAG &DAG,
@@ -19857,9 +19929,10 @@
   case ISD::SETCC:
     return performSETCCCombine(N, DCI, DAG);
   case ISD::LOAD:
-    if (performTBISimplification(N->getOperand(1), DCI, DAG))
+    if (Subtarget->supportsAddressTopByteIgnored() &&
+        performTBISimplification(N->getOperand(1), DCI, DAG))
       return SDValue(N, 0);
-    break;
+    return performLOADCombine(N, DCI, DAG, Subtarget);
   case ISD::STORE:
     return performSTORECombine(N, DCI, DAG, Subtarget);
   case ISD::MSTORE:
diff --git a/llvm/test/CodeGen/AArch64/nontemporal-load.ll b/llvm/test/CodeGen/AArch64/nontemporal-load.ll
--- a/llvm/test/CodeGen/AArch64/nontemporal-load.ll
+++ b/llvm/test/CodeGen/AArch64/nontemporal-load.ll
@@ -216,6 +216,23 @@
   ret <17 x float> %lv
 }
 
+define <17 x float> @test_ldnp_v17f32_volatile(<17 x float>* %A) {
+; CHECK-LABEL: test_ldnp_v17f32_volatile:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    ldr q0, [x0]
+; CHECK-NEXT:    ldr q1, [x0, #16]
+; CHECK-NEXT:    ldr q2, [x0, #32]
+; CHECK-NEXT:    ldr q3, [x0, #48]
+; CHECK-NEXT:    ldr s4, [x0, #64]
+; CHECK-NEXT:    str q0, [x8]
+; CHECK-NEXT:    stp q1, q2, [x8, #16]
+; CHECK-NEXT:    str q3, [x8, #48]
+; CHECK-NEXT:    str s4, [x8, #64]
+; CHECK-NEXT:    ret
+  %lv = load volatile <17 x float>, <17 x float>* %A, align 8, !nontemporal !0
+  ret <17 x float> %lv
+}
+
 define <33 x double> @test_ldnp_v33f64(<33 x double>* %A) {
 ; CHECK-LABEL: test_ldnp_v33f64:
 ; CHECK:       ; %bb.0: