diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1159,9 +1159,9 @@
   bool isVectorLoadExtDesirable(SDValue ExtVal) const override;
   bool isUsedByReturnOnly(SDNode *N, SDValue &Chain) const override;
   bool mayBeEmittedAsTailCall(const CallInst *CI) const override;
-  bool getIndexedAddressParts(SDNode *Op, SDValue &Base, SDValue &Offset,
-                              ISD::MemIndexedMode &AM, bool &IsInc,
-                              SelectionDAG &DAG) const;
+  bool getIndexedAddressParts(SDNode *N, SDNode *Op, SDValue &Base,
+                              SDValue &Offset, ISD::MemIndexedMode &AM,
+                              bool &IsInc, SelectionDAG &DAG) const;
   bool getPreIndexedAddressParts(SDNode *N, SDValue &Base, SDValue &Offset,
                                  ISD::MemIndexedMode &AM,
                                  SelectionDAG &DAG) const override;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21530,14 +21530,38 @@
   return CI->isTailCall();
 }
 
-bool AArch64TargetLowering::getIndexedAddressParts(SDNode *Op, SDValue &Base,
-                                                   SDValue &Offset,
-                                                   ISD::MemIndexedMode &AM,
-                                                   bool &IsInc,
-                                                   SelectionDAG &DAG) const {
+bool AArch64TargetLowering::getIndexedAddressParts(
+    SDNode *N, SDNode *Op, SDValue &Base, SDValue &Offset,
+    ISD::MemIndexedMode &AM, bool &IsInc, SelectionDAG &DAG) const {
   if (Op->getOpcode() != ISD::ADD && Op->getOpcode() != ISD::SUB)
     return false;
 
+  // Non-null if there is exactly one user of the loaded value (ignoring chain).
+  SDNode *ValOnlyUser = nullptr;
+  for (SDNode::use_iterator UI = N->use_begin(), UE = N->use_end(); UI != UE;
+       ++UI) {
+    if (UI.getUse().getResNo() == 1)
+      continue; // Ignore chain.
+    if (ValOnlyUser == nullptr)
+      ValOnlyUser = *UI;
+    else {
+      ValOnlyUser = nullptr; // Multiple non-chain uses, bail out.
+      break;
+    }
+  }
+
+  auto IsUndefOrZero = [](SDValue V) {
+    return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
+  };
+
+  // If the only user of the value is a scalable vector splat, it is
+  // preferable to do a replicating load (ld1r*).
+  if (ValOnlyUser && ValOnlyUser->getValueType(0).isScalableVector() &&
+      (ValOnlyUser->getOpcode() == ISD::SPLAT_VECTOR ||
+       (ValOnlyUser->getOpcode() == AArch64ISD::DUP_MERGE_PASSTHRU &&
+        IsUndefOrZero(ValOnlyUser->getOperand(2)))))
+    return false;
+
   Base = Op->getOperand(0);
   // All of the indexed addressing mode instructions take a signed
   // 9 bit immediate offset.
@@ -21570,7 +21594,7 @@
     return false;
 
   bool IsInc;
-  if (!getIndexedAddressParts(Ptr.getNode(), Base, Offset, AM, IsInc, DAG))
+  if (!getIndexedAddressParts(N, Ptr.getNode(), Base, Offset, AM, IsInc, DAG))
     return false;
   AM = IsInc ? ISD::PRE_INC : ISD::PRE_DEC;
   return true;
@@ -21591,7 +21615,7 @@
     return false;
 
   bool IsInc;
-  if (!getIndexedAddressParts(Op, Base, Offset, AM, IsInc, DAG))
+  if (!getIndexedAddressParts(N, Op, Base, Offset, AM, IsInc, DAG))
     return false;
   // Post-indexing updates the base, so it's not a valid transform
   // if that's not the same as the load's pointer.
diff --git a/llvm/test/CodeGen/AArch64/sve-ld1r.ll b/llvm/test/CodeGen/AArch64/sve-ld1r.ll
--- a/llvm/test/CodeGen/AArch64/sve-ld1r.ll
+++ b/llvm/test/CodeGen/AArch64/sve-ld1r.ll
@@ -1186,6 +1186,102 @@
   ret %res
 }
 
+
+; Check that a load consumed by a scalable splat prefers a replicating load.
+define i8* @avoid_preindex_load(i8* %src, <vscale x 2 x i64>* %out) {
+; CHECK-LABEL: avoid_preindex_load:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    add x8, x0, #1
+; CHECK-NEXT:    ld1rsb { z0.d }, p0/z, [x0, #1]
+; CHECK-NEXT:    mov x0, x8
+; CHECK-NEXT:    st1d { z0.d }, p0, [x1]
+; CHECK-NEXT:    ret
+  %ptr = getelementptr inbounds i8, i8* %src, i64 1
+  %tmp = load i8, i8* %ptr, align 4
+  %ext = sext i8 %tmp to i64
+  %ins = insertelement <vscale x 2 x i64> undef, i64 %ext, i32 0
+  %dup = shufflevector <vscale x 2 x i64> %ins, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
+  store <vscale x 2 x i64> %dup, <vscale x 2 x i64>* %out
+  ret i8* %ptr
+}
+
+; Check that a load consumed by a scalable splat prefers a replicating
+; load over a pre-indexed load.
+define i8* @avoid_preindex_load_dup(i8* %src, <vscale x 2 x i1> %pg, <vscale x 2 x i64>* %out) {
+; CHECK-LABEL: avoid_preindex_load_dup:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add x8, x0, #1
+; CHECK-NEXT:    ld1rsb { z0.d }, p0/z, [x0, #1]
+; CHECK-NEXT:    mov x0, x8
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    st1d { z0.d }, p0, [x1]
+; CHECK-NEXT:    ret
+  %ptr = getelementptr inbounds i8, i8* %src, i64 1
+  %tmp = load i8, i8* %ptr, align 4
+  %ext = sext i8 %tmp to i64
+  %dup = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.nxv2i64(<vscale x 2 x i64> undef, <vscale x 2 x i1> %pg, i64 %ext)
+  store <vscale x 2 x i64> %dup, <vscale x 2 x i64>* %out
+  ret i8* %ptr
+}
+
+; Same as avoid_preindex_load_dup, but with zero passthru.
+define i8* @avoid_preindex_load_dup_passthru_zero(i8* %src, <vscale x 2 x i1> %pg, <vscale x 2 x i64>* %out) {
+; CHECK-LABEL: avoid_preindex_load_dup_passthru_zero:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add x8, x0, #1
+; CHECK-NEXT:    ld1rsb { z0.d }, p0/z, [x0, #1]
+; CHECK-NEXT:    mov x0, x8
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    st1d { z0.d }, p0, [x1]
+; CHECK-NEXT:    ret
+  %ptr = getelementptr inbounds i8, i8* %src, i64 1
+  %tmp = load i8, i8* %ptr, align 4
+  %ext = sext i8 %tmp to i64
+  %dup = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.nxv2i64(<vscale x 2 x i64> zeroinitializer, <vscale x 2 x i1> %pg, i64 %ext)
+  store <vscale x 2 x i64> %dup, <vscale x 2 x i64>* %out
+  ret i8* %ptr
+}
+
+; If a dup has a non-undef passthru, stick with the pre-indexed load.
+define i8* @preindex_load_dup_passthru(<vscale x 2 x i64> %passthru, i8* %src, <vscale x 2 x i1> %pg, <vscale x 2 x i64>* %out) {
+; CHECK-LABEL: preindex_load_dup_passthru:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldrsb x8, [x0, #1]!
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    mov z0.d, p0/m, x8
+; CHECK-NEXT:    st1d { z0.d }, p1, [x1]
+; CHECK-NEXT:    ret
+  %ptr = getelementptr inbounds i8, i8* %src, i64 1
+  %tmp = load i8, i8* %ptr, align 4
+  %ext = sext i8 %tmp to i64
+  %dup = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.nxv2i64(<vscale x 2 x i64> %passthru, <vscale x 2 x i1> %pg, i64 %ext)
+  store <vscale x 2 x i64> %dup, <vscale x 2 x i64>* %out
+  ret i8* %ptr
+}
+
+; Show that a second user of the load prevents the replicating load
+; check which would ordinarily inhibit indexed loads from firing.
+define i8* @preidx8sext64_instead_of_ld1r(i8* %src, <vscale x 2 x i64>* %out, i64* %dst) {
+; CHECK-LABEL: preidx8sext64_instead_of_ld1r:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldrsb x8, [x0, #1]!
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    mov z0.d, x8
+; CHECK-NEXT:    st1d { z0.d }, p0, [x1]
+; CHECK-NEXT:    str x8, [x2]
+; CHECK-NEXT:    ret
+  %ptr = getelementptr inbounds i8, i8* %src, i64 1
+  %tmp = load i8, i8* %ptr, align 4
+  %ext = sext i8 %tmp to i64
+  %ins = insertelement <vscale x 2 x i64> undef, i64 %ext, i32 0
+  %dup = shufflevector <vscale x 2 x i64> %ins, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
+  store <vscale x 2 x i64> %dup, <vscale x 2 x i64>* %out
+  store i64 %ext, i64* %dst
+  ret i8* %ptr
+}
+
+
 declare <vscale x 16 x i8> @llvm.aarch64.sve.dupq.lane.nxv16i8(<vscale x 16 x i8>, i64)
 declare <vscale x 8 x i16> @llvm.aarch64.sve.dupq.lane.nxv8i16(<vscale x 8 x i16>, i64)
 declare <vscale x 4 x i32> @llvm.aarch64.sve.dupq.lane.nxv4i32(<vscale x 4 x i32>, i64)