diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -105,8 +105,9 @@
   SPLAT_VECTOR_I64,
   // Read VLENB CSR
   READ_VLENB,
-  // Truncates a RVV integer vector by one power-of-two.
-  TRUNCATE_VECTOR,
+  // Truncates a RVV integer vector by one power-of-two. Carries both an extra
+  // mask and VL operand.
+  TRUNCATE_VECTOR_VL,
   // Unit-stride fault-only-first load
   VLEFF,
   VLEFF_MASK,
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -446,7 +446,7 @@
       setOperationAction(ISD::FP_TO_SINT, VT, Custom);
       setOperationAction(ISD::FP_TO_UINT, VT, Custom);

-      // Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR"
+      // Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL"
       // nodes which truncate by one power of two at a time.
       setOperationAction(ISD::TRUNCATE, VT, Custom);

@@ -524,6 +524,8 @@
      // By default everything must be expanded.
      for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op)
        setOperationAction(Op, VT, Expand);
+     for (MVT OtherVT : MVT::fixedlen_vector_valuetypes())
+       setTruncStoreAction(VT, OtherVT, Expand);

      // We use EXTRACT_SUBVECTOR as a "cast" from scalable to fixed.
      setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
@@ -568,6 +570,7 @@
        setOperationAction(ISD::VSELECT, VT, Custom);

+       setOperationAction(ISD::TRUNCATE, VT, Custom);
        setOperationAction(ISD::ANY_EXTEND, VT, Custom);
        setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
        setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
@@ -1153,7 +1156,7 @@
   }
   case ISD::TRUNCATE: {
     SDLoc DL(Op);
-    EVT VT = Op.getValueType();
+    MVT VT = Op.getSimpleValueType();
     // Only custom-lower vector truncates
     if (!VT.isVector())
       return Op;
@@ -1163,28 +1166,42 @@
       return lowerVectorMaskTrunc(Op, DAG);

     // RVV only has truncates which operate from SEW*2->SEW, so lower arbitrary
-    // truncates as a series of "RISCVISD::TRUNCATE_VECTOR" nodes which
+    // truncates as a series of "RISCVISD::TRUNCATE_VECTOR_VL" nodes which
     // truncate by one power of two at a time.
-    EVT DstEltVT = VT.getVectorElementType();
+    MVT DstEltVT = VT.getVectorElementType();

     SDValue Src = Op.getOperand(0);
-    EVT SrcVT = Src.getValueType();
-    EVT SrcEltVT = SrcVT.getVectorElementType();
+    MVT SrcVT = Src.getSimpleValueType();
+    MVT SrcEltVT = SrcVT.getVectorElementType();

     assert(DstEltVT.bitsLT(SrcEltVT) &&
            isPowerOf2_64(DstEltVT.getSizeInBits()) &&
            isPowerOf2_64(SrcEltVT.getSizeInBits()) &&
            "Unexpected vector truncate lowering");

+    MVT ContainerVT = SrcVT;
+    if (SrcVT.isFixedLengthVector()) {
+      ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector(
+          DAG, SrcVT, Subtarget);
+      Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
+    }
+
     SDValue Result = Src;
+    SDValue Mask, VL;
+    std::tie(Mask, VL) =
+        getDefaultVLOps(SrcVT, ContainerVT, DL, DAG, Subtarget);
     LLVMContext &Context = *DAG.getContext();
-    const ElementCount Count = SrcVT.getVectorElementCount();
+    const ElementCount Count = ContainerVT.getVectorElementCount();
     do {
-      SrcEltVT = EVT::getIntegerVT(Context, SrcEltVT.getSizeInBits() / 2);
+      SrcEltVT = MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2);
       EVT ResultVT = EVT::getVectorVT(Context, SrcEltVT, Count);
-      Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR, DL, ResultVT, Result);
+      Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, ResultVT, Result,
+                           Mask, VL);
     } while (SrcEltVT != DstEltVT);

+    if (SrcVT.isFixedLengthVector())
+      Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
+
     return Result;
   }
   case ISD::ANY_EXTEND:
@@ -5300,7 +5317,7 @@
   NODE_NAME_CASE(VMV_X_S)
   NODE_NAME_CASE(SPLAT_VECTOR_I64)
   NODE_NAME_CASE(READ_VLENB)
-  NODE_NAME_CASE(TRUNCATE_VECTOR)
+  NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
   NODE_NAME_CASE(VLEFF)
   NODE_NAME_CASE(VLEFF_MASK)
   NODE_NAME_CASE(VSLIDEUP_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -28,10 +28,6 @@

 def rv32_splat_i64 : SDNode<"RISCVISD::SPLAT_VECTOR_I64", SDTSplatI64>;

-def riscv_trunc_vector : SDNode<"RISCVISD::TRUNCATE_VECTOR",
-                                SDTypeProfile<1, 1,
-                                              [SDTCisVec<0>, SDTCisVec<1>]>>;
-
 // Give explicit Complexity to prefer simm5/uimm5.
 def SplatPat : ComplexPattern;
 def SplatPat_simm5 : ComplexPattern;
@@ -433,15 +429,6 @@
 defm "" : VPatBinarySDNode_VV_VX_VI;
 defm "" : VPatBinarySDNode_VV_VX_VI;

-// 12.7. Vector Narrowing Integer Right Shift Instructions
-foreach vtiTofti = AllFractionableVF2IntVectors in {
-  defvar vti = vtiTofti.Vti;
-  defvar fti = vtiTofti.Fti;
-  def : Pat<(fti.Vector (riscv_trunc_vector (vti.Vector vti.RegClass:$rs1))),
-            (!cast<Instruction>("PseudoVNSRL_WI_"#fti.LMul.MX)
-                 vti.RegClass:$rs1, 0, fti.AVL, fti.SEW)>;
-}
-
 // 12.8. Vector Integer Comparison Instructions
 defm "" : VPatIntegerSetCCSDNode_VV_VX_VI;
 defm "" : VPatIntegerSetCCSDNode_VV_VX_VI;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -148,6 +148,13 @@
 def riscv_sext_vl : SDNode<"RISCVISD::VSEXT_VL", SDT_RISCVVEXTEND_VL>;
 def riscv_zext_vl : SDNode<"RISCVISD::VZEXT_VL", SDT_RISCVVEXTEND_VL>;

+def riscv_trunc_vector_vl : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL",
+                                   SDTypeProfile<1, 3, [SDTCisVec<0>,
+                                                        SDTCisVec<1>,
+                                                        SDTCisSameNumEltsAs<0, 2>,
+                                                        SDTCVecEltisVT<2, i1>,
+                                                        SDTCisVT<3, XLenVT>]>>;
+
 // Ignore the vl operand.
 def SplatFPOp : PatFrag<(ops node:$op),
                         (riscv_vfmv_v_f_vl node:$op, srcvalue)>;

@@ -443,6 +450,17 @@
 defm "" : VPatBinaryVL_VV_VX_VI;
 defm "" : VPatBinaryVL_VV_VX_VI;

+// 12.7. Vector Narrowing Integer Right Shift Instructions
+foreach vtiTofti = AllFractionableVF2IntVectors in {
+  defvar vti = vtiTofti.Vti;
+  defvar fti = vtiTofti.Fti;
+  def : Pat<(fti.Vector (riscv_trunc_vector_vl (vti.Vector vti.RegClass:$rs1),
+                                               (vti.Mask true_mask),
+                                               (XLenVT (VLOp GPR:$vl)))),
+            (!cast<Instruction>("PseudoVNSRL_WI_"#fti.LMul.MX)
+                vti.RegClass:$rs1, 0, GPR:$vl, fti.SEW)>;
+}
+
 // 12.8. Vector Integer Comparison Instructions
 foreach vti = AllIntegerVectors in {
   defm "" : VPatIntegerSetCCVL_VV;
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll
@@ -192,3 +192,85 @@
   store <32 x i32> %b, <32 x i32>* %z
   ret void
 }
+
+define void @trunc_v4i8_v4i32(<4 x i32>* %x, <4 x i8>* %z) {
+; CHECK-LABEL: trunc_v4i8_v4i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, zero, 4
+; CHECK-NEXT:    vsetvli a3, a2, e32,m1,ta,mu
+; CHECK-NEXT:    vle32.v v25, (a0)
+; CHECK-NEXT:    vsetvli a0, a2, e16,mf2,ta,mu
+; CHECK-NEXT:    vnsrl.wi v26, v25, 0
+; CHECK-NEXT:    vsetvli a0, a2, e8,mf4,ta,mu
+; CHECK-NEXT:    vnsrl.wi v25, v26, 0
+; CHECK-NEXT:    vsetvli a0, a2, e8,m1,ta,mu
+; CHECK-NEXT:    vse8.v v25, (a1)
+; CHECK-NEXT:    ret
+  %a = load <4 x i32>, <4 x i32>* %x
+  %b = trunc <4 x i32> %a to <4 x i8>
+  store <4 x i8> %b, <4 x i8>* %z
+  ret void
+}
+
+define void @trunc_v8i8_v8i32(<8 x i32>* %x, <8 x i8>* %z) {
+; LMULMAX8-LABEL: trunc_v8i8_v8i32:
+; LMULMAX8:       # %bb.0:
+; LMULMAX8-NEXT:    addi a2, zero, 8
+; LMULMAX8-NEXT:    vsetvli a3, a2, e32,m2,ta,mu
+; LMULMAX8-NEXT:    vle32.v v26, (a0)
+; LMULMAX8-NEXT:    vsetvli a0, a2, e16,m1,ta,mu
+; LMULMAX8-NEXT:    vnsrl.wi v25, v26, 0
+; LMULMAX8-NEXT:    vsetvli a0, a2, e8,mf2,ta,mu
+; LMULMAX8-NEXT:    vnsrl.wi v26, v25, 0
+; LMULMAX8-NEXT:    vsetvli a0, a2, e8,m1,ta,mu
+; LMULMAX8-NEXT:    vse8.v v26, (a1)
+; LMULMAX8-NEXT:    ret
+;
+; LMULMAX2-LABEL: trunc_v8i8_v8i32:
+; LMULMAX2:       # %bb.0:
+; LMULMAX2-NEXT:    addi a2, zero, 8
+; LMULMAX2-NEXT:    vsetvli a3, a2, e32,m2,ta,mu
+; LMULMAX2-NEXT:    vle32.v v26, (a0)
+; LMULMAX2-NEXT:    vsetvli a0, a2, e16,m1,ta,mu
+; LMULMAX2-NEXT:    vnsrl.wi v25, v26, 0
+; LMULMAX2-NEXT:    vsetvli a0, a2, e8,mf2,ta,mu
+; LMULMAX2-NEXT:    vnsrl.wi v26, v25, 0
+; LMULMAX2-NEXT:    vsetvli a0, a2, e8,m1,ta,mu
+; LMULMAX2-NEXT:    vse8.v v26, (a1)
+; LMULMAX2-NEXT:    ret
+;
+; LMULMAX1-LABEL: trunc_v8i8_v8i32:
+; LMULMAX1:       # %bb.0:
+; LMULMAX1-NEXT:    addi sp, sp, -16
+; LMULMAX1-NEXT:    .cfi_def_cfa_offset 16
+; LMULMAX1-NEXT:    addi a2, zero, 4
+; LMULMAX1-NEXT:    vsetvli a3, a2, e32,m1,ta,mu
+; LMULMAX1-NEXT:    addi a3, a0, 16
+; LMULMAX1-NEXT:    vle32.v v25, (a3)
+; LMULMAX1-NEXT:    vle32.v v26, (a0)
+; LMULMAX1-NEXT:    vsetvli a0, a2, e16,mf2,ta,mu
+; LMULMAX1-NEXT:    vnsrl.wi v27, v25, 0
+; LMULMAX1-NEXT:    vsetvli a0, a2, e8,mf4,ta,mu
+; LMULMAX1-NEXT:    vnsrl.wi v25, v27, 0
+; LMULMAX1-NEXT:    addi a0, sp, 12
+; LMULMAX1-NEXT:    vsetvli a3, a2, e8,m1,ta,mu
+; LMULMAX1-NEXT:    vse8.v v25, (a0)
+; LMULMAX1-NEXT:    vsetvli a0, a2, e16,mf2,ta,mu
+; LMULMAX1-NEXT:    vnsrl.wi v25, v26, 0
+; LMULMAX1-NEXT:    vsetvli a0, a2, e8,mf4,ta,mu
+; LMULMAX1-NEXT:    vnsrl.wi v26, v25, 0
+; LMULMAX1-NEXT:    vsetvli a0, a2, e8,m1,ta,mu
+; LMULMAX1-NEXT:    addi a0, sp, 8
+; LMULMAX1-NEXT:    vse8.v v26, (a0)
+; LMULMAX1-NEXT:    addi a0, zero, 8
+; LMULMAX1-NEXT:    vsetvli a0, a0, e8,m1,ta,mu
+; LMULMAX1-NEXT:    addi a0, sp, 8
+; LMULMAX1-NEXT:    vle8.v v25, (a0)
+; LMULMAX1-NEXT:    vse8.v v25, (a1)
+; LMULMAX1-NEXT:    addi sp, sp, 16
+; LMULMAX1-NEXT:    ret
+  %a = load <8 x i32>, <8 x i32>* %x
+  %b = trunc <8 x i32> %a to <8 x i8>
+  store <8 x i8> %b, <8 x i8>* %z
+  ret void
+}
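
Note (not part of the patch): since TRUNCATE_VECTOR_VL only narrows by one power of two per step, a wider source type simply lowers to a longer chain of vnsrl.wi instructions. As a minimal sketch in the same style as the tests above, an i64-to-i8 truncate would go through three narrowing steps (i64->i32->i16->i8); the function name below is illustrative and the expected assembly is deliberately omitted rather than guessed:

; Hypothetical extra test case exercising a three-step narrowing chain.
define void @trunc_v4i8_v4i64(<4 x i64>* %x, <4 x i8>* %z) {
  %a = load <4 x i64>, <4 x i64>* %x
  %b = trunc <4 x i64> %a to <4 x i8>
  store <4 x i8> %b, <4 x i8>* %z
  ret void
}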