diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -50,6 +50,7 @@
 #include "llvm/CodeGen/TargetInstrInfo.h"
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/IR/Attributes.h"
+#include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugLoc.h"
@@ -69,6 +70,7 @@
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/Value.h"
+#include "llvm/IR/VectorBuilder.h"
 #include "llvm/MC/MCRegisterInfo.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CodeGen.h"
@@ -80,6 +82,7 @@
 #include "llvm/Support/KnownBits.h"
 #include "llvm/Support/MachineValueType.h"
 #include "llvm/Support/MathExtras.h"
+#include "llvm/Support/TypeSize.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Target/TargetOptions.h"
@@ -898,8 +901,7 @@
                        ISD::VECTOR_SPLICE, ISD::SIGN_EXTEND_INREG,
                        ISD::CONCAT_VECTORS, ISD::EXTRACT_SUBVECTOR,
                        ISD::INSERT_SUBVECTOR, ISD::STORE, ISD::BUILD_VECTOR});
-  if (Subtarget->supportsAddressTopByteIgnored())
-    setTargetDAGCombine(ISD::LOAD);
+  setTargetDAGCombine(ISD::LOAD);
 
   setTargetDAGCombine(ISD::MSTORE);
 
@@ -17808,6 +17810,92 @@
   return SDValue();
 }
 
+// Combine for ISD::LOAD. Performs the TBI simplification that used to live
+// in the dispatch switch, then splits wide non-temporal vector loads into
+// 256-bit chunks so they can be selected as LDNP, plus one small tail load.
+static SDValue performLOADCombine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  SelectionDAG &DAG,
+                                  const AArch64Subtarget *Subtarget) {
+  // Preserve the pre-existing TBI (top-byte-ignore) address simplification.
+  if (Subtarget->supportsAddressTopByteIgnored() &&
+      performTBISimplification(N->getOperand(1), DCI, DAG))
+    return SDValue(N, 0);
+
+  LoadSDNode *LD = cast<LoadSDNode>(N);
+  EVT MemVT = LD->getMemoryVT();
+
+  // Only split non-volatile, non-temporal, fixed-width vector loads that are
+  // wider than 256 bits but not a multiple of 256 bits (exact multiples are
+  // already covered by LDNP pair formation), and whose element size divides
+  // 256 evenly. Lane order below assumes little-endian.
+  if (LD->isVolatile() || !LD->isNonTemporal() || !Subtarget->isLittleEndian())
+    return SDValue();
+  if (!MemVT.isVector() || MemVT.isScalableVector() ||
+      MemVT.getSizeInBits() <= 256 || MemVT.getSizeInBits() % 256 == 0 ||
+      256 % MemVT.getScalarSizeInBits() != 0)
+    return SDValue();
+
+  SDLoc DL(LD);
+  SDValue Chain = LD->getChain();
+  SDValue Ptr = LD->getBasePtr();
+  SDNodeFlags Flags = LD->getFlags();
+
+  // One 256-bit chunk of the original vector type, e.g. v4f64 for a vNf64
+  // load.
+  MVT NewVT =
+      MVT::getVectorVT(MemVT.getVectorElementType().getSimpleVT(),
+                       256 / MemVT.getScalarSizeInBits());
+
+  SmallVector<SDValue, 8> LoadOps;
+  SmallVector<SDValue, 8> LoadOpsChain;
+  // Build the 256-bit loads. getMemBasePlusOffset takes a *byte* offset, so
+  // each chunk advances the pointer by 32 bytes, starting at offset 0.
+  unsigned Num256Loads = MemVT.getSizeInBits() / 256;
+  for (unsigned I = 0; I < Num256Loads; ++I) {
+    unsigned PtrOffset = I * 32;
+    SDValue NewPtr =
+        DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(PtrOffset), DL, Flags);
+    Align NewAlign = commonAlignment(LD->getAlign(), PtrOffset);
+    SDValue NewLoad = DAG.getLoad(
+        NewVT, DL, Chain, NewPtr, LD->getPointerInfo().getWithOffset(PtrOffset),
+        NewAlign, LD->getMemOperand()->getFlags(), LD->getAAInfo());
+    LoadOps.push_back(NewLoad);
+    LoadOpsChain.push_back(SDValue(NewLoad.getNode(), 1));
+  }
+
+  // Load the remaining (<256-bit) tail from the end of the original object,
+  // widen it to a 256-bit chunk by inserting it into an UNDEF vector, and
+  // extract the original type from the concatenation of all chunks below.
+  unsigned BitsRemaining = MemVT.getSizeInBits() % 256;
+  unsigned TailOffset = (MemVT.getSizeInBits() - BitsRemaining) / 8;
+  MVT RemainingVT =
+      MVT::getVectorVT(MemVT.getVectorElementType().getSimpleVT(),
+                       BitsRemaining / MemVT.getScalarSizeInBits());
+  SDValue TailPtr =
+      DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(TailOffset), DL, Flags);
+  Align TailAlign = commonAlignment(LD->getAlign(), TailOffset);
+  SDValue RemainingLoad =
+      DAG.getLoad(RemainingVT, DL, Chain, TailPtr,
+                  LD->getPointerInfo().getWithOffset(TailOffset), TailAlign,
+                  LD->getMemOperand()->getFlags(), LD->getAAInfo());
+  SDValue ExtendedRemainingLoad =
+      DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewVT,
+                  {DAG.getUNDEF(NewVT), RemainingLoad,
+                   DAG.getVectorIdxConstant(0, DL)});
+  LoadOps.push_back(ExtendedRemainingLoad);
+  LoadOpsChain.push_back(SDValue(RemainingLoad.getNode(), 1));
+
+  EVT ConcatVT =
+      EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(),
+                       LoadOps.size() * NewVT.getVectorNumElements());
+  SDValue ConcatVectors =
+      DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT, LoadOps);
+  SDValue ExtractSubVector =
+      DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MemVT,
+                  {ConcatVectors, DAG.getVectorIdxConstant(0, DL)});
+  // Users of the original load's chain must depend on *all* new loads, not
+  // on the incoming chain, or later stores could be reordered before them.
+  SDValue NewChain =
+      DAG.getNode(ISD::TokenFactor, DL, MVT::Other, LoadOpsChain);
+  return DAG.getMergeValues({ExtractSubVector, NewChain}, DL);
+}
+
 static SDValue performSTORECombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    SelectionDAG &DAG,
@@ -19857,9 +19945,7 @@
   case ISD::SETCC:
     return performSETCCCombine(N, DCI, DAG);
   case ISD::LOAD:
-    if (performTBISimplification(N->getOperand(1), DCI, DAG))
-      return SDValue(N, 0);
-    break;
+    return performLOADCombine(N, DCI, DAG, Subtarget);
   case ISD::STORE:
     return performSTORECombine(N, DCI, DAG, Subtarget);
   case ISD::MSTORE:
diff --git a/llvm/test/CodeGen/AArch64/nontemporal-load.ll b/llvm/test/CodeGen/AArch64/nontemporal-load.ll
--- a/llvm/test/CodeGen/AArch64/nontemporal-load.ll
+++ b/llvm/test/CodeGen/AArch64/nontemporal-load.ll
@@ -205,12 +205,12 @@
 define <17 x float> @test_ldnp_v17f32(<17 x float>* %A) {
 ; CHECK-LABEL: test_ldnp_v17f32:
 ; CHECK:       ; %bb.0:
-; CHECK-NEXT:    ldp q1, q2, [x0, #32]
-; CHECK-NEXT:    ldp q3, q4, [x0]
+; CHECK-NEXT:    ldnp q1, q2, [x0, #32]
+; CHECK-NEXT:    ldnp q3, q4, [x0]
 ; CHECK-NEXT:    ldr s0, [x0, #64]
 ; CHECK-NEXT:    stp q3, q4, [x8]
 ; CHECK-NEXT:    stp q1, q2, [x8, #32]
 ; CHECK-NEXT:    str s0, [x8, #64]
 ; CHECK-NEXT:    ret
   %lv = load <17 x float>, <17 x float>* %A, align 8, !nontemporal !0
   ret <17 x float> %lv
@@ -219,24 +219,24 @@
 define <33 x double> @test_ldnp_v33f64(<33 x double>* %A) {
 ; CHECK-LABEL: test_ldnp_v33f64:
 ; CHECK:       ; %bb.0:
-; CHECK-NEXT:    ldp q0, q1, [x0]
-; CHECK-NEXT:    ldp q2, q3, [x0, #32]
-; CHECK-NEXT:    ldp q4, q5, [x0, #64]
-; CHECK-NEXT:    ldp q6, q7, [x0, #96]
-; CHECK-NEXT:    ldp q16, q17, [x0, #128]
-; CHECK-NEXT:    ldp q18, q19, [x0, #160]
-; CHECK-NEXT:    ldp q21, q22, [x0, #224]
-; CHECK-NEXT:    ldp q23, q24, [x0, #192]
+; CHECK-NEXT:    ldnp q0, q1, [x0]
+; CHECK-NEXT:    ldnp q2, q3, [x0, #32]
+; CHECK-NEXT:    ldnp q4, q5, [x0, #64]
+; CHECK-NEXT:    ldnp q6, q7, [x0, #96]
+; CHECK-NEXT:    ldnp q16, q17, [x0, #128]
+; CHECK-NEXT:    ldnp q18, q19, [x0, #160]
+; CHECK-NEXT:    ldnp q21, q22, [x0, #224]
+; CHECK-NEXT:    ldnp q23, q24, [x0, #192]
 ; CHECK-NEXT:    ldr d20, [x0, #256]
 ; CHECK-NEXT:    stp q0, q1, [x8]
 ; CHECK-NEXT:    stp q2, q3, [x8, #32]
 ; CHECK-NEXT:    stp q4, q5, [x8, #64]
 ; CHECK-NEXT:    str d20, [x8, #256]
 ; CHECK-NEXT:    stp q6, q7, [x8, #96]
 ; CHECK-NEXT:    stp q16, q17, [x8, #128]
 ; CHECK-NEXT:    stp q18, q19, [x8, #160]
 ; CHECK-NEXT:    stp q23, q24, [x8, #192]
 ; CHECK-NEXT:    stp q21, q22, [x8, #224]
 ; CHECK-NEXT:    ret
   %lv = load <33 x double>, <33 x double>* %A, align 8, !nontemporal !0
   ret <33 x double> %lv
@@ -245,10 +245,11 @@
 define <33 x i8> @test_ldnp_v33i8(<33 x i8>* %A) {
 ; CHECK-LABEL: test_ldnp_v33i8:
 ; CHECK:       ; %bb.0:
-; CHECK-NEXT:    ldp q1, q0, [x0]
-; CHECK-NEXT:    ldrb w9, [x0, #32]
-; CHECK-NEXT:    stp q1, q0, [x8]
-; CHECK-NEXT:    strb w9, [x8, #32]
+; CHECK-NEXT:    ldnp q0, q1, [x0]
+; CHECK-NEXT:    add x9, x8, #32
+; CHECK-NEXT:    ldr b2, [x0, #32]
+; CHECK-NEXT:    stp q0, q1, [x8]
+; CHECK-NEXT:    st1.b { v2 }[0], [x9]
 ; CHECK-NEXT:    ret
   %lv = load <33 x i8>, <33 x i8>* %A, align 8, !nontemporal !0
   ret <33 x i8> %lv
@@ -295,15 +296,14 @@
 define <5 x double> @test_ldnp_v5f64(<5 x double>* %A) {
 ; CHECK-LABEL: test_ldnp_v5f64:
 ; CHECK:       ; %bb.0:
-; CHECK-NEXT:    ldp q0, q2, [x0]
+; CHECK-NEXT:    ldnp q0, q2, [x0]
+; CHECK-NEXT:    ldr d4, [x0, #32]
 ; CHECK-NEXT:    ext.16b v1, v0, v0, #8
 ; CHECK-NEXT:    ; kill: def $d0 killed $d0 killed $q0
 ; CHECK-NEXT:    ; kill: def $d1 killed $d1 killed $q1
 ; CHECK-NEXT:    ext.16b v3, v2, v2, #8
-; CHECK-NEXT:    ldr d4, [x0, #32]
 ; CHECK-NEXT:    ; kill: def $d2 killed $d2 killed $q2
 ; CHECK-NEXT:    ; kill: def $d3 killed $d3 killed $q3
-; CHECK-NEXT:    ; kill: def $d4 killed $d4 killed $q4
 ; CHECK-NEXT:    ret
   %lv = load <5 x double>, <5 x double>* %A, align 8, !nontemporal !0
   ret <5 x double> %lv