diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
@@ -131,7 +131,7 @@
   bool selectVSplatUimm5(SDValue N, SDValue &SplatVal);
   bool selectVSplatSimm5Plus1(SDValue N, SDValue &SplatVal);
   bool selectVSplatSimm5Plus1NonZero(SDValue N, SDValue &SplatVal);
-  bool selectExtOneUseVSplat(SDValue N, SDValue &SplatVal);
+  bool selectLowBitsVSplat(SDValue N, SDValue &SplatVal);
   bool selectFPImm(SDValue N, SDValue &Imm);

   bool selectRVVSimm5(SDValue N, unsigned Width, SDValue &Imm);
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -3016,13 +3016,25 @@
   return true;
 }

-bool RISCVDAGToDAGISel::selectExtOneUseVSplat(SDValue N, SDValue &SplatVal) {
-  if (N->getOpcode() == ISD::SIGN_EXTEND ||
-      N->getOpcode() == ISD::ZERO_EXTEND) {
+bool RISCVDAGToDAGISel::selectLowBitsVSplat(SDValue N, SDValue &SplatVal) {
+  auto IsTrunc = [this](SDValue N) {
+    if (N->getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
+      return false;
+    SDValue VL;
+    selectVLOp(N->getOperand(2), VL);
+    return N->getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
+           isa<ConstantSDNode>(VL) &&
+           cast<ConstantSDNode>(VL)->getSExtValue() == RISCV::VLMaxSentinel;
+  };
+
+  // We can have multiple nested truncates, so unravel them all if needed.
+  while (N->getOpcode() == ISD::SIGN_EXTEND ||
+         N->getOpcode() == ISD::ZERO_EXTEND || IsTrunc(N)) {
     if (!N.hasOneUse())
       return false;
     N = N->getOperand(0);
   }
+
   return selectVSplat(N, SplatVal);
 }
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -576,8 +576,10 @@
 def SplatPat_simm5_plus1_nonzero
     : ComplexPattern<vAny, 1, "selectVSplatSimm5Plus1NonZero", [], [], 3>;

-def ext_oneuse_SplatPat
-    : ComplexPattern<vAny, 1, "selectExtOneUseVSplat", [], [], 2>;
+// Selects extends or truncates of splats where we only care about the lowest 8
+// bits of each element.
+def LowBitsSplatPat
+    : ComplexPattern<vAny, 1, "selectLowBitsVSplat", [], [], 2>;

 def SelectFPImm : ComplexPattern<fAny, 1, "selectFPImm", [], [], 1>;

@@ -1452,7 +1454,7 @@
                (vti.Vector
                  (riscv_trunc_vector_vl
                    (op (wti.Vector wti.RegClass:$rs2),
-                       (wti.Vector (ext_oneuse_SplatPat (XLenVT GPR:$rs1)))),
+                       (wti.Vector (LowBitsSplatPat (XLenVT GPR:$rs1)))),
                    (vti.Mask true_mask),
                    VLOpFrag)),
                (!cast<Instruction>(instruction_name#"_WX_"#vti.LMul.MX)
diff --git a/llvm/test/CodeGen/RISCV/rvv/vnsrl-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vnsrl-sdnode.ll
--- a/llvm/test/CodeGen/RISCV/rvv/vnsrl-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vnsrl-sdnode.ll
@@ -652,13 +652,8 @@
 ;
 ; RV64-LABEL: vnsrl_wx_i64_nxv1i16:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetvli a1, zero, e64, m1, ta, ma
-; RV64-NEXT:    vmv.v.x v9, a0
-; RV64-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
-; RV64-NEXT:    vnsrl.wi v9, v9, 0
-; RV64-NEXT:    vsrl.vv v8, v8, v9
-; RV64-NEXT:    vsetvli zero, zero, e16, mf4, ta, ma
-; RV64-NEXT:    vnsrl.wi v8, v8, 0
+; RV64-NEXT:    vsetvli a1, zero, e16, mf4, ta, ma
+; RV64-NEXT:    vnsrl.wx v8, v8, a0
 ; RV64-NEXT:    ret
   %head = insertelement <vscale x 1 x i64> poison, i64 %b, i32 0
   %splat = shufflevector <vscale x 1 x i64> %head, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
@@ -689,15 +684,8 @@
 ;
 ; RV64-LABEL: vnsrl_wx_i64_nxv1i8:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetvli a1, zero, e64, m1, ta, ma
-; RV64-NEXT:    vmv.v.x v9, a0
-; RV64-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
-; RV64-NEXT:    vnsrl.wi v9, v9, 0
-; RV64-NEXT:    vsetvli zero, zero, e16, mf4, ta, ma
-; RV64-NEXT:    vnsrl.wi v9, v9, 0
-; RV64-NEXT:    vsrl.vv v8, v8, v9
-; RV64-NEXT:    vsetvli zero, zero, e8, mf8, ta, ma
-; RV64-NEXT:    vnsrl.wi v8, v8, 0
+; RV64-NEXT:    vsetvli a1, zero, e8, mf8, ta, ma
+; RV64-NEXT:    vnsrl.wx v8, v8, a0
 ; RV64-NEXT:    ret
   %head = insertelement <vscale x 1 x i64> poison, i64 %b, i32 0
   %splat = shufflevector <vscale x 1 x i64> %head, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer