diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
@@ -134,7 +134,9 @@
   }
   bool selectVSplatSimm5Plus1(SDValue N, SDValue &SplatVal);
   bool selectVSplatSimm5Plus1NonZero(SDValue N, SDValue &SplatVal);
-  bool selectExtOneUseVSplat(SDValue N, SDValue &SplatVal);
+  // Matches the splat of a value which can be extended or truncated, such that
+  // only the bottom 8 bits are preserved.
+  bool selectLow8BitsVSplat(SDValue N, SDValue &SplatVal);
   bool selectFPImm(SDValue N, SDValue &Imm);
 
   bool selectRVVSimm5(SDValue N, unsigned Width, SDValue &Imm);
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -3017,13 +3017,30 @@
   return true;
 }
 
-bool RISCVDAGToDAGISel::selectExtOneUseVSplat(SDValue N, SDValue &SplatVal) {
-  if (N->getOpcode() == ISD::SIGN_EXTEND ||
-      N->getOpcode() == ISD::ZERO_EXTEND) {
-    if (!N.hasOneUse())
+bool RISCVDAGToDAGISel::selectLow8BitsVSplat(SDValue N, SDValue &SplatVal) {
+  // Truncates are custom lowered during legalization.
+  auto IsTrunc = [this](SDValue V) {
+    if (V->getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
+      return false;
+    SDValue VL;
+    selectVLOp(V->getOperand(2), VL);
+    // Any vmset_vl is ok, since any bits past VL are undefined and we can
+    // assume they are set.
+    return V->getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
+           isa<ConstantSDNode>(VL) &&
+           cast<ConstantSDNode>(VL)->getSExtValue() == RISCV::VLMaxSentinel;
+  };
+
+  // We can have multiple nested truncates, so unravel them all if needed.
+  // Each peeled node must be single-use and keep at least 8 bits per element;
+  // the element width is checked, not the size of the whole vector.
+  while (N->getOpcode() == ISD::SIGN_EXTEND ||
+         N->getOpcode() == ISD::ZERO_EXTEND || IsTrunc(N)) {
+    if (!N.hasOneUse() || N.getScalarValueSizeInBits() < 8)
       return false;
     N = N->getOperand(0);
   }
+
   return selectVSplat(N, SplatVal);
 }
 
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -577,8 +577,10 @@
 def SplatPat_simm5_plus1_nonzero
     : ComplexPattern<vAny, 1, "selectVSplatSimm5Plus1NonZero", [], [], 3>;
 
-def ext_oneuse_SplatPat
-    : ComplexPattern<vAny, 1, "selectExtOneUseVSplat", [], [], 2>;
+// Selects extends or truncates of splats where we only care about the lowest 8
+// bits of each element.
+def Low8BitsSplatPat
+    : ComplexPattern<vAny, 1, "selectLow8BitsVSplat", [], [], 2>;
 
 def SelectFPImm : ComplexPattern<fAny, 1, "selectFPImm", [], [], 1>;
 
@@ -1453,7 +1455,7 @@
       (vti.Vector
         (riscv_trunc_vector_vl
           (op (wti.Vector wti.RegClass:$rs2),
-              (wti.Vector (ext_oneuse_SplatPat (XLenVT GPR:$rs1)))),
+              (wti.Vector (Low8BitsSplatPat (XLenVT GPR:$rs1)))),
           (vti.Mask true_mask),
           VLOpFrag)),
       (!cast<Instruction>(instruction_name#"_WX_"#vti.LMul.MX)
diff --git a/llvm/test/CodeGen/RISCV/rvv/vnsrl-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vnsrl-sdnode.ll
--- a/llvm/test/CodeGen/RISCV/rvv/vnsrl-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vnsrl-sdnode.ll
@@ -652,13 +652,8 @@
 ;
 ; RV64-LABEL: vnsrl_wx_i64_nxv1i16:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetvli a1, zero, e64, m1, ta, ma
-; RV64-NEXT:    vmv.v.x v9, a0
-; RV64-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
-; RV64-NEXT:    vnsrl.wi v9, v9, 0
-; RV64-NEXT:    vsrl.vv v8, v8, v9
-; RV64-NEXT:    vsetvli zero, zero, e16, mf4, ta, ma
-; RV64-NEXT:    vnsrl.wi v8, v8, 0
+; RV64-NEXT:    vsetvli a1, zero, e16, mf4, ta, ma
+; RV64-NEXT:    vnsrl.wx v8, v8, a0
 ; RV64-NEXT:    ret
   %head = insertelement <vscale x 1 x i64> poison, i64 %b, i32 0
   %splat = shufflevector <vscale x 1 x i64> %head, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
@@ -689,15 +684,8 @@
 ;
 ; RV64-LABEL: vnsrl_wx_i64_nxv1i8:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetvli a1, zero, e64, m1, ta, ma
-; RV64-NEXT:    vmv.v.x v9, a0
-; RV64-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
-; RV64-NEXT:    vnsrl.wi v9, v9, 0
-; RV64-NEXT:    vsetvli zero, zero, e16, mf4, ta, ma
-; RV64-NEXT:    vnsrl.wi v9, v9, 0
-; RV64-NEXT:    vsrl.vv v8, v8, v9
-; RV64-NEXT:    vsetvli zero, zero, e8, mf8, ta, ma
-; RV64-NEXT:    vnsrl.wi v8, v8, 0
+; RV64-NEXT:    vsetvli a1, zero, e8, mf8, ta, ma
+; RV64-NEXT:    vnsrl.wx v8, v8, a0
 ; RV64-NEXT:    ret
   %head = insertelement <vscale x 1 x i64> poison, i64 %b, i32 0
   %splat = shufflevector <vscale x 1 x i64> %head, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer