diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -7563,7 +7563,6 @@
 }
 
 // Try to form VWMUL, VWMULU or VWMULSU.
-// TODO: Support VWMULSU.vx with a sign extend Op and a splat of scalar Op.
 static SDValue combineMUL_VLToVWMUL_VL(SDNode *N, SelectionDAG &DAG,
                                        bool Commute) {
   assert(N->getOpcode() == RISCVISD::MUL_VL && "Unexpected opcode");
@@ -7623,7 +7622,12 @@
     if (ScalarBits < EltBits)
       return SDValue();
 
-    if (IsSignExt) {
+    if (IsSignExt && ISD::isZEXTLoad(Op1.getNode())) {
+      APInt Mask = APInt::getBitsSetFrom(ScalarBits, NarrowSize);
+      if (!DAG.MaskedValueIsZero(Op1, Mask))
+        return SDValue();
+      IsVWMULSU = true;
+    } else if (IsSignExt) {
       if (DAG.ComputeNumSignBits(Op1) <= (ScalarBits - NarrowSize))
         return SDValue();
     } else {
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmulsu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmulsu.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmulsu.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmulsu.ll
@@ -682,16 +682,13 @@
   ret <16 x i64> %f
 }
 
-; ToDo: add tests for vwmulsu_vx when one input is a scalar splat.
 define <8 x i16> @vwmulsu_vx_v8i16_i8(<8 x i8>* %x, i8* %y) {
 ; CHECK-LABEL: vwmulsu_vx_v8i16_i8:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, mu
-; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a0)
 ; CHECK-NEXT:    lbu a0, 0(a1)
-; CHECK-NEXT:    vsetvli zero, zero, e16, m1, ta, mu
-; CHECK-NEXT:    vsext.vf2 v9, v8
-; CHECK-NEXT:    vmul.vx v8, v9, a0
+; CHECK-NEXT:    vwmulsu.vx v8, v9, a0
 ; CHECK-NEXT:    ret
   %a = load <8 x i8>, <8 x i8>* %x
   %b = load i8, i8* %y
@@ -729,7 +726,7 @@
 ; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, mu
 ; CHECK-NEXT:    vle16.v v9, (a0)
 ; CHECK-NEXT:    lbu a0, 0(a1)
-; CHECK-NEXT:    vwmul.vx v8, v9, a0
+; CHECK-NEXT:    vwmulsu.vx v8, v9, a0
 ; CHECK-NEXT:    ret
   %a = load <4 x i16>, <4 x i16>* %x
   %b = load i8, i8* %y
@@ -745,11 +742,9 @@
 ; CHECK-LABEL: vwmulsu_vx_v4i32_i16:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, mu
-; CHECK-NEXT:    vle16.v v8, (a0)
+; CHECK-NEXT:    vle16.v v9, (a0)
 ; CHECK-NEXT:    lhu a0, 0(a1)
-; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, mu
-; CHECK-NEXT:    vsext.vf2 v9, v8
-; CHECK-NEXT:    vmul.vx v8, v9, a0
+; CHECK-NEXT:    vwmulsu.vx v8, v9, a0
 ; CHECK-NEXT:    ret
   %a = load <4 x i16>, <4 x i16>* %x
   %b = load i16, i16* %y
@@ -784,7 +779,7 @@
 ; RV64-NEXT:    vsetivli zero, 2, e32, mf2, ta, mu
 ; RV64-NEXT:    vle32.v v9, (a0)
 ; RV64-NEXT:    lbu a0, 0(a1)
-; RV64-NEXT:    vwmul.vx v8, v9, a0
+; RV64-NEXT:    vwmulsu.vx v8, v9, a0
 ; RV64-NEXT:    ret
   %a = load <2 x i32>, <2 x i32>* %x
   %b = load i8, i8* %y
@@ -819,7 +814,7 @@
 ; RV64-NEXT:    vsetivli zero, 2, e32, mf2, ta, mu
 ; RV64-NEXT:    vle32.v v9, (a0)
 ; RV64-NEXT:    lhu a0, 0(a1)
-; RV64-NEXT:    vwmul.vx v8, v9, a0
+; RV64-NEXT:    vwmulsu.vx v8, v9, a0
 ; RV64-NEXT:    ret
   %a = load <2 x i32>, <2 x i32>* %x
   %b = load i16, i16* %y
@@ -852,11 +847,9 @@
 ; RV64-LABEL: vwmulsu_vx_v2i64_i32:
 ; RV64:       # %bb.0:
 ; RV64-NEXT:    vsetivli zero, 2, e32, mf2, ta, mu
-; RV64-NEXT:    vle32.v v8, (a0)
+; RV64-NEXT:    vle32.v v9, (a0)
 ; RV64-NEXT:    lwu a0, 0(a1)
-; RV64-NEXT:    vsetvli zero, zero, e64, m1, ta, mu
-; RV64-NEXT:    vsext.vf2 v9, v8
-; RV64-NEXT:    vmul.vx v8, v9, a0
+; RV64-NEXT:    vwmulsu.vx v8, v9, a0
 ; RV64-NEXT:    ret
   %a = load <2 x i32>, <2 x i32>* %x
   %b = load i32, i32* %y
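
Note: below is a minimal, illustrative sketch of the IR shape the updated combine targets, mirroring the vwmulsu_vx_v8i16_i8 test in the diff above: a sign-extended narrow vector multiplied by a splat whose scalar comes from a zero-extending load. The function name is made up for illustration and is not part of the patch; with a standard RVV configuration (e.g. llc with -mattr=+v), this shape is expected to select a single vwmulsu.vx instead of vsext.vf2 plus vmul.vx.

; Illustrative sketch only; not taken from the patch.
define <8 x i16> @vwmulsu_vx_splat_sketch(<8 x i8>* %x, i8* %y) {
  %a = load <8 x i8>, <8 x i8>* %x    ; narrow vector operand
  %b = load i8, i8* %y                ; scalar loaded with lbu (zero-extending load)
  %c = zext i8 %b to i16              ; unsigned scalar operand
  %d = insertelement <8 x i16> undef, i16 %c, i32 0
  %e = shufflevector <8 x i16> %d, <8 x i16> undef, <8 x i32> zeroinitializer
  %f = sext <8 x i8> %a to <8 x i16>  ; signed vector operand
  %g = mul <8 x i16> %f, %e           ; signed x unsigned widening multiply candidate
  ret <8 x i16> %g
}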