Index: llvm/include/llvm/CodeGen/TargetLowering.h
===================================================================
--- llvm/include/llvm/CodeGen/TargetLowering.h
+++ llvm/include/llvm/CodeGen/TargetLowering.h
@@ -802,7 +802,7 @@
   }
 
   // Return true if the target wants to transform Op(Splat(X)) -> Splat(Op(X))
-  virtual bool preferScalarizeSplat(unsigned Opc) const { return true; }
+  virtual bool preferScalarizeSplat(SDNode *N) const { return true; }
 
   /// Return true if the target wants to use the optimization that
   /// turns ext(promotableInst1(...(promotableInstN(load)))) into
Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -25619,7 +25619,7 @@
       (N0.getOpcode() == ISD::SPLAT_VECTOR ||
        TLI.isExtractVecEltCheap(VT, Index0)) &&
       TLI.isOperationLegalOrCustom(Opcode, EltVT) &&
-      TLI.preferScalarizeSplat(Opcode)) {
+      TLI.preferScalarizeSplat(N)) {
     SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL);
     SDValue Elt =
         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, SrcEltVT, Src0, IndexC);
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1222,6 +1222,8 @@
 
   bool isConstantUnsignedBitfieldExtractLegal(unsigned Opc, LLT Ty1,
                                               LLT Ty2) const override;
+
+  bool preferScalarizeSplat(SDNode *N) const override;
 };
 
 namespace AArch64 {
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -24417,3 +24417,14 @@
 
   return nullptr;
 }
+
+bool AArch64TargetLowering::preferScalarizeSplat(SDNode *N) const {
+  unsigned Opc = N->getOpcode();
+  if (Opc == ISD::ZERO_EXTEND || Opc == ISD::SIGN_EXTEND ||
+      Opc == ISD::ANY_EXTEND) {
+    if (any_of(N->uses(),
+               [&](SDNode *Use) { return Use->getOpcode() == ISD::MUL; }))
+      return false;
+  }
+  return true;
+}
Index: llvm/lib/Target/RISCV/RISCVISelLowering.h
===================================================================
--- llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -401,7 +401,7 @@
 
   bool isIntDivCheap(EVT VT, AttributeList Attr) const override;
 
-  bool preferScalarizeSplat(unsigned Opc) const override;
+  bool preferScalarizeSplat(SDNode *N) const override;
 
   bool softPromoteHalfType() const override { return true; }
 
Index: llvm/lib/Target/RISCV/RISCVISelLowering.cpp
===================================================================
--- llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -14925,9 +14925,10 @@
   return OptSize && !VT.isVector();
 }
 
-bool RISCVTargetLowering::preferScalarizeSplat(unsigned Opc) const {
+bool RISCVTargetLowering::preferScalarizeSplat(SDNode *N) const {
   // Scalarize zero_ext and sign_ext might stop match to widening instruction in
   // some situation.
+  unsigned Opc = N->getOpcode();
   if (Opc == ISD::ZERO_EXTEND || Opc == ISD::SIGN_EXTEND)
     return false;
   return true;
Index: llvm/lib/Target/X86/X86ISelLowering.h
===================================================================
--- llvm/lib/Target/X86/X86ISelLowering.h
+++ llvm/lib/Target/X86/X86ISelLowering.h
@@ -1126,7 +1126,7 @@
                               unsigned OldShiftOpcode, unsigned NewShiftOpcode,
                               SelectionDAG &DAG) const override;
 
-    bool preferScalarizeSplat(unsigned Opc) const override;
+    bool preferScalarizeSplat(SDNode *N) const override;
 
     bool shouldFoldConstantShiftPairToMask(const SDNode *N,
                                            CombineLevel Level) const override;
Index: llvm/lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- llvm/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -6118,8 +6118,8 @@
   return NewShiftOpcode == ISD::SHL;
 }
 
-bool X86TargetLowering::preferScalarizeSplat(unsigned Opc) const {
-  return Opc != ISD::FP_EXTEND;
+bool X86TargetLowering::preferScalarizeSplat(SDNode *N) const {
+  return N->getOpcode() != ISD::FP_EXTEND;
 }
 
 bool X86TargetLowering::shouldFoldConstantShiftPairToMask(
Index: llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-int-extends.ll
===================================================================
--- llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-int-extends.ll
+++ llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-int-extends.ll
@@ -891,4 +891,44 @@
   ret void
 }
 
+define void @extend_and_mul(i32 %0, <8 x i64> %1, ptr %2) #0 {
+; CHECK-LABEL: extend_and_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z4.s, w0
+; CHECK-NEXT:    // kill: def $q3 killed $q3 def $z3
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    uunpklo z4.d, z4.s
+; CHECK-NEXT:    mul z3.d, p0/m, z3.d, z4.d
+; CHECK-NEXT:    mul z2.d, p0/m, z2.d, z4.d
+; CHECK-NEXT:    mul z1.d, p0/m, z1.d, z4.d
+; CHECK-NEXT:    mul z0.d, p0/m, z0.d, z4.d
+; CHECK-NEXT:    stp q0, q1, [x1]
+; CHECK-NEXT:    stp q2, q3, [x1, #32]
+; CHECK-NEXT:    ret
+  %broadcast.splatinsert2 = insertelement <8 x i32> poison, i32 %0, i64 0
+  %broadcast.splat3 = shufflevector <8 x i32> %broadcast.splatinsert2, <8 x i32> zeroinitializer, <8 x i32> zeroinitializer
+  %4 = zext <8 x i32> %broadcast.splat3 to <8 x i64>
+  %5 = mul <8 x i64> %4, %1
+  store <8 x i64> %5, ptr %2, align 2
+  ret void
+}
+
+define void @zero_extend_no_mul(ptr %s, i16 %0) #0 {
+; CHECK-LABEL: zero_extend_no_mul:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    and w8, w1, #0xffff
+; CHECK-NEXT:    mov z0.s, w8
+; CHECK-NEXT:    stp q0, q0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %broadcast.splatinsert2 = insertelement <8 x i16> poison, i16 %0, i64 0
+  %broadcast.splat3 = shufflevector <8 x i16> %broadcast.splatinsert2, <8 x i16> zeroinitializer, <8 x i32> zeroinitializer
+  %1 = zext <8 x i16> %broadcast.splat3 to <8 x i32>
+  store <8 x i32> %1, ptr %s, align 2
+  ret void
+}
+
 attributes #0 = { nounwind "target-features"="+sve" }