Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -12845,11 +12845,10 @@ return SDValue(N, 0); // Return N so it doesn't get rechecked! } -static SDValue tryToFoldExtOfMaskedLoad(SelectionDAG &DAG, - const TargetLowering &TLI, EVT VT, - SDNode *N, SDValue N0, - ISD::LoadExtType ExtLoadType, - ISD::NodeType ExtOpc) { +static SDValue +tryToFoldExtOfMaskedLoad(SelectionDAG &DAG, const TargetLowering &TLI, EVT VT, + bool LegalOperations, SDNode *N, SDValue N0, + ISD::LoadExtType ExtLoadType, ISD::NodeType ExtOpc) { if (!N0.hasOneUse()) return SDValue(); @@ -12857,7 +12856,8 @@ if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD) return SDValue(); - if (!TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0))) + if ((LegalOperations || !cast<MaskedLoadSDNode>(N0)->isSimple()) && + !TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0))) return SDValue(); if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0))) @@ -13130,8 +13130,8 @@ return foldedExt; if (SDValue foldedExt = - tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::SEXTLOAD, - ISD::SIGN_EXTEND)) + tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0, + ISD::SEXTLOAD, ISD::SIGN_EXTEND)) return foldedExt; // fold (sext (load x)) to multiple smaller sextloads. @@ -13409,8 +13409,8 @@ return foldedExt; if (SDValue foldedExt = - tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::ZEXTLOAD, - ISD::ZERO_EXTEND)) + tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0, + ISD::ZEXTLOAD, ISD::ZERO_EXTEND)) return foldedExt; // fold (zext (load x)) to multiple smaller zextloads. 
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp =================================================================== --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -5342,8 +5342,25 @@ } bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const { - return ExtVal.getValueType().isScalableVector() || - Subtarget->useSVEForFixedLengthVectors(); + // It may be worth creating extending masked loads if there are multiple + // masked loads using the same predicate. That way we'll end up creating + // extending masked loads that may then get split by the legaliser. This + // results in just one set of predicate unpacks at the start, instead of + // multiple sets of vector unpacks after each load. + EVT ExtVT = ExtVal.getValueType(); + + if (auto *Ld = dyn_cast<MaskedLoadSDNode>(ExtVal->getOperand(0))) { + if (!isLoadExtLegalOrCustom(ISD::ZEXTLOAD, ExtVT, Ld->getValueType(0))) { + unsigned NumExtMaskedLoads = 0; + for (auto *U : Ld->getMask()->uses()) + NumExtMaskedLoads += isa<MaskedLoadSDNode>(U) ? 
1 : 0; + + if (NumExtMaskedLoads <= 1) + return false; + } + } + + return ExtVT.isScalableVector() || Subtarget->useSVEForFixedLengthVectors(); } unsigned getGatherVecOpcode(bool IsScaled, bool IsSigned, bool NeedsExtend) { Index: llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll =================================================================== --- llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll +++ llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll @@ -170,14 +170,14 @@ define @masked_sload_x2_4i8_4i64(ptr %a, ptr %b, %c) { ; CHECK-LABEL: masked_sload_x2_4i8_4i64: ; CHECK: // %bb.0: -; CHECK-NEXT: ld1sb { z0.s }, p0/z, [x0] -; CHECK-NEXT: ld1sb { z1.s }, p0/z, [x1] -; CHECK-NEXT: sunpkhi z2.d, z0.s -; CHECK-NEXT: sunpklo z0.d, z0.s -; CHECK-NEXT: sunpkhi z3.d, z1.s -; CHECK-NEXT: sunpklo z1.d, z1.s -; CHECK-NEXT: add z0.d, z0.d, z1.d -; CHECK-NEXT: add z1.d, z2.d, z3.d +; CHECK-NEXT: punpkhi p1.h, p0.b +; CHECK-NEXT: punpklo p0.h, p0.b +; CHECK-NEXT: ld1sb { z1.d }, p1/z, [x0, #1, mul vl] +; CHECK-NEXT: ld1sb { z0.d }, p0/z, [x0] +; CHECK-NEXT: ld1sb { z2.d }, p1/z, [x1, #1, mul vl] +; CHECK-NEXT: ld1sb { z3.d }, p0/z, [x1] +; CHECK-NEXT: add z0.d, z0.d, z3.d +; CHECK-NEXT: add z1.d, z1.d, z2.d ; CHECK-NEXT: ret %aval = call @llvm.masked.load.nxv4i8( *%a, i32 16, %c, zeroinitializer) %bval = call @llvm.masked.load.nxv4i8( *%b, i32 16, %c, zeroinitializer) @@ -190,14 +190,14 @@ define @masked_sload_x2_4i16_4i64(ptr %a, ptr %b, %c) { ; CHECK-LABEL: masked_sload_x2_4i16_4i64: ; CHECK: // %bb.0: -; CHECK-NEXT: ld1sh { z0.s }, p0/z, [x0] -; CHECK-NEXT: ld1sh { z1.s }, p0/z, [x1] -; CHECK-NEXT: sunpkhi z2.d, z0.s -; CHECK-NEXT: sunpklo z0.d, z0.s -; CHECK-NEXT: sunpkhi z3.d, z1.s -; CHECK-NEXT: sunpklo z1.d, z1.s -; CHECK-NEXT: add z0.d, z0.d, z1.d -; CHECK-NEXT: add z1.d, z2.d, z3.d +; CHECK-NEXT: punpkhi p1.h, p0.b +; CHECK-NEXT: punpklo p0.h, p0.b +; CHECK-NEXT: ld1sh { z1.d }, p1/z, [x0, #1, mul vl] +; CHECK-NEXT: ld1sh { z0.d }, p0/z, [x0] +; CHECK-NEXT: ld1sh { 
z2.d }, p1/z, [x1, #1, mul vl] +; CHECK-NEXT: ld1sh { z3.d }, p0/z, [x1] +; CHECK-NEXT: add z0.d, z0.d, z3.d +; CHECK-NEXT: add z1.d, z1.d, z2.d ; CHECK-NEXT: ret %aval = call @llvm.masked.load.nxv4i16( *%a, i32 16, %c, zeroinitializer) %bval = call @llvm.masked.load.nxv4i16( *%b, i32 16, %c, zeroinitializer) @@ -210,14 +210,14 @@ define @masked_sload_x2_8i8_8i32(ptr %a, ptr %b, %c) { ; CHECK-LABEL: masked_sload_x2_8i8_8i32: ; CHECK: // %bb.0: -; CHECK-NEXT: ld1sb { z0.h }, p0/z, [x0] -; CHECK-NEXT: ld1sb { z1.h }, p0/z, [x1] -; CHECK-NEXT: sunpkhi z2.s, z0.h -; CHECK-NEXT: sunpklo z0.s, z0.h -; CHECK-NEXT: sunpkhi z3.s, z1.h -; CHECK-NEXT: sunpklo z1.s, z1.h -; CHECK-NEXT: add z0.s, z0.s, z1.s -; CHECK-NEXT: add z1.s, z2.s, z3.s +; CHECK-NEXT: punpkhi p1.h, p0.b +; CHECK-NEXT: punpklo p0.h, p0.b +; CHECK-NEXT: ld1sb { z1.s }, p1/z, [x0, #1, mul vl] +; CHECK-NEXT: ld1sb { z0.s }, p0/z, [x0] +; CHECK-NEXT: ld1sb { z2.s }, p1/z, [x1, #1, mul vl] +; CHECK-NEXT: ld1sb { z3.s }, p0/z, [x1] +; CHECK-NEXT: add z0.s, z0.s, z3.s +; CHECK-NEXT: add z1.s, z1.s, z2.s ; CHECK-NEXT: ret %aval = call @llvm.masked.load.nxv8i8( *%a, i32 16, %c, zeroinitializer) %bval = call @llvm.masked.load.nxv8i8( *%b, i32 16, %c, zeroinitializer) @@ -230,24 +230,24 @@ define @masked_sload_x2_8i8_8i64(ptr %a, ptr %b, %c) { ; CHECK-LABEL: masked_sload_x2_8i8_8i64: ; CHECK: // %bb.0: -; CHECK-NEXT: ld1sb { z0.h }, p0/z, [x0] -; CHECK-NEXT: ld1sb { z1.h }, p0/z, [x1] -; CHECK-NEXT: sunpkhi z2.s, z0.h -; CHECK-NEXT: sunpklo z0.s, z0.h -; CHECK-NEXT: sunpklo z3.s, z1.h -; CHECK-NEXT: sunpkhi z1.s, z1.h -; CHECK-NEXT: sunpkhi z4.d, z2.s -; CHECK-NEXT: sunpklo z2.d, z2.s -; CHECK-NEXT: sunpkhi z5.d, z0.s -; CHECK-NEXT: sunpklo z0.d, z0.s -; CHECK-NEXT: sunpklo z6.d, z3.s -; CHECK-NEXT: sunpkhi z7.d, z1.s -; CHECK-NEXT: sunpklo z24.d, z1.s -; CHECK-NEXT: sunpkhi z1.d, z3.s -; CHECK-NEXT: add z0.d, z0.d, z6.d -; CHECK-NEXT: add z3.d, z4.d, z7.d -; CHECK-NEXT: add z1.d, z5.d, z1.d -; CHECK-NEXT: add z2.d, 
z2.d, z24.d +; CHECK-NEXT: punpkhi p1.h, p0.b +; CHECK-NEXT: punpklo p0.h, p0.b +; CHECK-NEXT: punpkhi p2.h, p1.b +; CHECK-NEXT: punpklo p1.h, p1.b +; CHECK-NEXT: punpkhi p3.h, p0.b +; CHECK-NEXT: punpklo p0.h, p0.b +; CHECK-NEXT: ld1sb { z3.d }, p2/z, [x0, #3, mul vl] +; CHECK-NEXT: ld1sb { z2.d }, p1/z, [x0, #2, mul vl] +; CHECK-NEXT: ld1sb { z1.d }, p3/z, [x0, #1, mul vl] +; CHECK-NEXT: ld1sb { z0.d }, p0/z, [x0] +; CHECK-NEXT: ld1sb { z4.d }, p2/z, [x1, #3, mul vl] +; CHECK-NEXT: ld1sb { z5.d }, p1/z, [x1, #2, mul vl] +; CHECK-NEXT: ld1sb { z6.d }, p3/z, [x1, #1, mul vl] +; CHECK-NEXT: ld1sb { z7.d }, p0/z, [x1] +; CHECK-NEXT: add z2.d, z2.d, z5.d +; CHECK-NEXT: add z3.d, z3.d, z4.d +; CHECK-NEXT: add z0.d, z0.d, z7.d +; CHECK-NEXT: add z1.d, z1.d, z6.d ; CHECK-NEXT: ret %aval = call @llvm.masked.load.nxv8i8( *%a, i32 16, %c, zeroinitializer) %bval = call @llvm.masked.load.nxv8i8( *%b, i32 16, %c, zeroinitializer) Index: llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll =================================================================== --- llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll +++ llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll @@ -165,14 +165,14 @@ define @masked_zload_x2_4i8_4i64(ptr %a, ptr %b, %c) { ; CHECK-LABEL: masked_zload_x2_4i8_4i64: ; CHECK: // %bb.0: -; CHECK-NEXT: ld1b { z0.s }, p0/z, [x0] -; CHECK-NEXT: ld1b { z1.s }, p0/z, [x1] -; CHECK-NEXT: uunpkhi z2.d, z0.s -; CHECK-NEXT: uunpklo z0.d, z0.s -; CHECK-NEXT: uunpkhi z3.d, z1.s -; CHECK-NEXT: uunpklo z1.d, z1.s -; CHECK-NEXT: add z0.d, z0.d, z1.d -; CHECK-NEXT: add z1.d, z2.d, z3.d +; CHECK-NEXT: punpkhi p1.h, p0.b +; CHECK-NEXT: punpklo p0.h, p0.b +; CHECK-NEXT: ld1b { z1.d }, p1/z, [x0, #1, mul vl] +; CHECK-NEXT: ld1b { z0.d }, p0/z, [x0] +; CHECK-NEXT: ld1b { z2.d }, p1/z, [x1, #1, mul vl] +; CHECK-NEXT: ld1b { z3.d }, p0/z, [x1] +; CHECK-NEXT: add z0.d, z0.d, z3.d +; CHECK-NEXT: add z1.d, z1.d, z2.d ; CHECK-NEXT: ret %aval = call @llvm.masked.load.nxv4i8( *%a, i32 16, %c, 
zeroinitializer) %bval = call @llvm.masked.load.nxv4i8( *%b, i32 16, %c, zeroinitializer) @@ -185,14 +185,14 @@ define @masked_zload_x2_4i16_4i64(ptr %a, ptr %b, %c) { ; CHECK-LABEL: masked_zload_x2_4i16_4i64: ; CHECK: // %bb.0: -; CHECK-NEXT: ld1h { z0.s }, p0/z, [x0] -; CHECK-NEXT: ld1h { z1.s }, p0/z, [x1] -; CHECK-NEXT: uunpkhi z2.d, z0.s -; CHECK-NEXT: uunpklo z0.d, z0.s -; CHECK-NEXT: uunpkhi z3.d, z1.s -; CHECK-NEXT: uunpklo z1.d, z1.s -; CHECK-NEXT: add z0.d, z0.d, z1.d -; CHECK-NEXT: add z1.d, z2.d, z3.d +; CHECK-NEXT: punpkhi p1.h, p0.b +; CHECK-NEXT: punpklo p0.h, p0.b +; CHECK-NEXT: ld1h { z1.d }, p1/z, [x0, #1, mul vl] +; CHECK-NEXT: ld1h { z0.d }, p0/z, [x0] +; CHECK-NEXT: ld1h { z2.d }, p1/z, [x1, #1, mul vl] +; CHECK-NEXT: ld1h { z3.d }, p0/z, [x1] +; CHECK-NEXT: add z0.d, z0.d, z3.d +; CHECK-NEXT: add z1.d, z1.d, z2.d ; CHECK-NEXT: ret %aval = call @llvm.masked.load.nxv4i16( *%a, i32 16, %c, zeroinitializer) %bval = call @llvm.masked.load.nxv4i16( *%b, i32 16, %c, zeroinitializer) @@ -205,14 +205,14 @@ define @masked_zload_x2_8i8_8i32(ptr %a, ptr %b, %c) { ; CHECK-LABEL: masked_zload_x2_8i8_8i32: ; CHECK: // %bb.0: -; CHECK-NEXT: ld1b { z0.h }, p0/z, [x0] -; CHECK-NEXT: ld1b { z1.h }, p0/z, [x1] -; CHECK-NEXT: uunpkhi z2.s, z0.h -; CHECK-NEXT: uunpklo z0.s, z0.h -; CHECK-NEXT: uunpkhi z3.s, z1.h -; CHECK-NEXT: uunpklo z1.s, z1.h -; CHECK-NEXT: add z0.s, z0.s, z1.s -; CHECK-NEXT: add z1.s, z2.s, z3.s +; CHECK-NEXT: punpkhi p1.h, p0.b +; CHECK-NEXT: punpklo p0.h, p0.b +; CHECK-NEXT: ld1b { z1.s }, p1/z, [x0, #1, mul vl] +; CHECK-NEXT: ld1b { z0.s }, p0/z, [x0] +; CHECK-NEXT: ld1b { z2.s }, p1/z, [x1, #1, mul vl] +; CHECK-NEXT: ld1b { z3.s }, p0/z, [x1] +; CHECK-NEXT: add z0.s, z0.s, z3.s +; CHECK-NEXT: add z1.s, z1.s, z2.s ; CHECK-NEXT: ret %aval = call @llvm.masked.load.nxv8i8( *%a, i32 16, %c, zeroinitializer) %bval = call @llvm.masked.load.nxv8i8( *%b, i32 16, %c, zeroinitializer) @@ -225,24 +225,24 @@ define @masked_zload_x2_8i8_8i64(ptr %a, ptr 
%b, %c) { ; CHECK-LABEL: masked_zload_x2_8i8_8i64: ; CHECK: // %bb.0: -; CHECK-NEXT: ld1b { z0.h }, p0/z, [x0] -; CHECK-NEXT: ld1b { z1.h }, p0/z, [x1] -; CHECK-NEXT: uunpkhi z2.s, z0.h -; CHECK-NEXT: uunpklo z0.s, z0.h -; CHECK-NEXT: uunpklo z3.s, z1.h -; CHECK-NEXT: uunpkhi z1.s, z1.h -; CHECK-NEXT: uunpkhi z4.d, z2.s -; CHECK-NEXT: uunpklo z2.d, z2.s -; CHECK-NEXT: uunpkhi z5.d, z0.s -; CHECK-NEXT: uunpklo z0.d, z0.s -; CHECK-NEXT: uunpklo z6.d, z3.s -; CHECK-NEXT: uunpkhi z7.d, z1.s -; CHECK-NEXT: uunpklo z24.d, z1.s -; CHECK-NEXT: uunpkhi z1.d, z3.s -; CHECK-NEXT: add z0.d, z0.d, z6.d -; CHECK-NEXT: add z3.d, z4.d, z7.d -; CHECK-NEXT: add z1.d, z5.d, z1.d -; CHECK-NEXT: add z2.d, z2.d, z24.d +; CHECK-NEXT: punpkhi p1.h, p0.b +; CHECK-NEXT: punpklo p0.h, p0.b +; CHECK-NEXT: punpkhi p2.h, p1.b +; CHECK-NEXT: punpklo p1.h, p1.b +; CHECK-NEXT: punpkhi p3.h, p0.b +; CHECK-NEXT: punpklo p0.h, p0.b +; CHECK-NEXT: ld1b { z3.d }, p2/z, [x0, #3, mul vl] +; CHECK-NEXT: ld1b { z2.d }, p1/z, [x0, #2, mul vl] +; CHECK-NEXT: ld1b { z1.d }, p3/z, [x0, #1, mul vl] +; CHECK-NEXT: ld1b { z0.d }, p0/z, [x0] +; CHECK-NEXT: ld1b { z4.d }, p2/z, [x1, #3, mul vl] +; CHECK-NEXT: ld1b { z5.d }, p1/z, [x1, #2, mul vl] +; CHECK-NEXT: ld1b { z6.d }, p3/z, [x1, #1, mul vl] +; CHECK-NEXT: ld1b { z7.d }, p0/z, [x1] +; CHECK-NEXT: add z2.d, z2.d, z5.d +; CHECK-NEXT: add z3.d, z3.d, z4.d +; CHECK-NEXT: add z0.d, z0.d, z7.d +; CHECK-NEXT: add z1.d, z1.d, z6.d ; CHECK-NEXT: ret %aval = call @llvm.masked.load.nxv8i8( *%a, i32 16, %c, zeroinitializer) %bval = call @llvm.masked.load.nxv8i8( *%b, i32 16, %c, zeroinitializer)