diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8570,6 +8570,25 @@
   return SDValue();
 }
 
+/// Check whether \p N supports doing sign/zero extension of its arguments.
+bool canFoldExtensionInOpcode(const SDNode *N) {
+  switch (N->getOpcode()) {
+  case RISCVISD::ADD_VL:
+  case RISCVISD::SUB_VL:
+  case RISCVISD::MUL_VL:
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// Check if all the users of \p Val support sign/zero extending their
+/// arguments.
+bool canFoldExtensionInAllUsers(SDValue Val) {
+  return std::all_of(Val->use_begin(), Val->use_end(),
+                     canFoldExtensionInOpcode);
+}
+
 // Try to form vwadd(u).wv/wx or vwsub(u).wv/wx. It might later be optimized to
 // vwadd(u).vv/vx or vwsub(u).vv/vx.
 static SDValue combineADDSUB_VLToVWADDSUB_VL(SDNode *N, SelectionDAG &DAG,
@@ -8599,7 +8618,8 @@
   // If the RHS is a sext or zext, we can form a widening op.
   if ((Op1.getOpcode() == RISCVISD::VZEXT_VL ||
        Op1.getOpcode() == RISCVISD::VSEXT_VL) &&
-      Op1.hasOneUse() && Op1.getOperand(1) == Mask && Op1.getOperand(2) == VL) {
+      canFoldExtensionInAllUsers(Op1) && Op1.getOperand(1) == Mask &&
+      Op1.getOperand(2) == VL) {
     unsigned ExtOpc = Op1.getOpcode();
     Op1 = Op1.getOperand(0);
     // Re-introduce narrower extends if needed.
@@ -8709,7 +8729,7 @@
   bool IsSignExt = Op0.getOpcode() == RISCVISD::VSEXT_VL;
   bool IsZeroExt = Op0.getOpcode() == RISCVISD::VZEXT_VL;
   bool IsVWMULSU = IsSignExt && Op1.getOpcode() == RISCVISD::VZEXT_VL;
-  if ((!IsSignExt && !IsZeroExt) || !Op0.hasOneUse())
+  if ((!IsSignExt && !IsZeroExt) || !canFoldExtensionInAllUsers(Op0))
     return SDValue();
 
   SDValue Merge = N->getOperand(2);
@@ -8731,7 +8751,7 @@
 
   // See if the other operand is the same opcode.
   if (IsVWMULSU || Op0.getOpcode() == Op1.getOpcode()) {
-    if (!Op1.hasOneUse())
+    if (!canFoldExtensionInAllUsers(Op1))
       return SDValue();
 
     // Make sure the mask and VL match.
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll
@@ -18,6 +18,30 @@
   ret <2 x i16> %e
 }
 
+define <2 x i16> @vwmul_v2i16_multiple_users(<2 x i8>* %x, <2 x i8>* %y, <2 x i8> *%z) {
+; CHECK-LABEL: vwmul_v2i16_multiple_users:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, mu
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a1)
+; CHECK-NEXT:    vle8.v v10, (a2)
+; CHECK-NEXT:    vwmul.vv v11, v8, v9
+; CHECK-NEXT:    vwmul.vv v9, v8, v10
+; CHECK-NEXT:    vsetvli zero, zero, e16, mf4, ta, mu
+; CHECK-NEXT:    vor.vv v8, v11, v9
+; CHECK-NEXT:    ret
+  %a = load <2 x i8>, <2 x i8>* %x
+  %b = load <2 x i8>, <2 x i8>* %y
+  %b2 = load <2 x i8>, <2 x i8>* %z
+  %c = sext <2 x i8> %a to <2 x i16>
+  %d = sext <2 x i8> %b to <2 x i16>
+  %d2 = sext <2 x i8> %b2 to <2 x i16>
+  %e = mul <2 x i16> %c, %d
+  %f = mul <2 x i16> %c, %d2
+  %g = or <2 x i16> %e, %f
+  ret <2 x i16> %g
+}
+
 define <4 x i16> @vwmul_v4i16(<4 x i8>* %x, <4 x i8>* %y) {
 ; CHECK-LABEL: vwmul_v4i16:
 ; CHECK:       # %bb.0: