diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -47,6 +47,22 @@ I != E;) { SDNode *N = &*I++; // Preincrement iterator to avoid invalidation issues. + // Convert integer SPLAT_VECTOR to VMV_V_X_VL to reduce isel burden. + if (N->getOpcode() == ISD::SPLAT_VECTOR && + N->getSimpleValueType(0).isInteger()) { + MVT VT = N->getSimpleValueType(0); + SDLoc DL(N); + SDValue VL = CurDAG->getTargetConstant(RISCV::VLMaxSentinel, DL, + Subtarget->getXLenVT()); + SDValue Result = CurDAG->getNode(RISCVISD::VMV_V_X_VL, DL, VT, N->getOperand(0), VL); + + --I; + CurDAG->ReplaceAllUsesOfValueWith(SDValue(N, 0), Result); + ++I; + CurDAG->DeleteNode(N); + continue; + } + // Lower SPLAT_VECTOR_SPLIT_I64 to two scalar stores and a stride 0 vector // load. Done after lowering and combining so that we have a chance to // optimize this to VMV_V_X_VL when the upper bits aren't needed. @@ -1879,8 +1895,7 @@ } bool RISCVDAGToDAGISel::selectVSplat(SDValue N, SDValue &SplatVal) { - if (N.getOpcode() != ISD::SPLAT_VECTOR && - N.getOpcode() != RISCVISD::VMV_V_X_VL) + if (N.getOpcode() != RISCVISD::VMV_V_X_VL) return false; SplatVal = N.getOperand(0); return true; @@ -1892,14 +1907,13 @@ SelectionDAG &DAG, const RISCVSubtarget &Subtarget, ValidateFn ValidateImm) { - if ((N.getOpcode() != ISD::SPLAT_VECTOR && - N.getOpcode() != RISCVISD::VMV_V_X_VL) || + if (N.getOpcode() != RISCVISD::VMV_V_X_VL || !isa<ConstantSDNode>(N.getOperand(0))) return false; int64_t SplatImm = cast<ConstantSDNode>(N.getOperand(0))->getSExtValue(); - // ISD::SPLAT_VECTOR, RISCVISD::VMV_V_X_VL share semantics when the operand + // The semantics of RISCVISD::VMV_V_X_VL is that when the operand // type is wider than the resulting vector element type: an implicit // truncation first takes place. 
Therefore, perform a manual // truncation/sign-extension in order to ignore any truncated bits and catch @@ -1940,8 +1954,7 @@ } bool RISCVDAGToDAGISel::selectVSplatUimm5(SDValue N, SDValue &SplatVal) { - if ((N.getOpcode() != ISD::SPLAT_VECTOR && - N.getOpcode() != RISCVISD::VMV_V_X_VL) || + if (N.getOpcode() != RISCVISD::VMV_V_X_VL || !isa<ConstantSDNode>(N.getOperand(0))) return false; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td @@ -399,7 +399,7 @@ vti.Vti.RegClass:$rs2, vti.Vti.RegClass:$rs1, vti.Vti.AVL, vti.Vti.Log2SEW)>; def : Pat<(op (vti.Wti.Vector (fpext_oneuse (vti.Vti.Vector vti.Vti.RegClass:$rs2))), - (vti.Wti.Vector (fpext_oneuse (vti.Vti.Vector (SplatPat vti.Vti.ScalarRegClass:$rs1))))), + (vti.Wti.Vector (fpext_oneuse (vti.Vti.Vector (splat_vector vti.Vti.ScalarRegClass:$rs1))))), (!cast<Instruction>(instruction_name#"_V"#vti.Vti.ScalarSuffix#"_"#vti.Vti.LMul.MX) vti.Vti.RegClass:$rs2, vti.Vti.ScalarRegClass:$rs1, vti.Vti.AVL, vti.Vti.Log2SEW)>; @@ -414,7 +414,7 @@ vti.Wti.RegClass:$rs2, vti.Vti.RegClass:$rs1, vti.Vti.AVL, vti.Vti.Log2SEW)>; def : Pat<(op (vti.Wti.Vector vti.Wti.RegClass:$rs2), - (vti.Wti.Vector (fpext_oneuse (vti.Vti.Vector (SplatPat vti.Vti.ScalarRegClass:$rs1))))), + (vti.Wti.Vector (fpext_oneuse (vti.Vti.Vector (splat_vector vti.Vti.ScalarRegClass:$rs1))))), (!cast<Instruction>(instruction_name#"_W"#vti.Vti.ScalarSuffix#"_"#vti.Vti.LMul.MX) vti.Wti.RegClass:$rs2, vti.Vti.ScalarRegClass:$rs1, vti.Vti.AVL, vti.Vti.Log2SEW)>; @@ -497,12 +497,6 @@ foreach vti = AllIntegerVectors in { // Emit shift by 1 as an add since it might be faster. 
- def : Pat<(shl (vti.Vector vti.RegClass:$rs1), - (vti.Vector (splat_vector (XLenVT 1)))), - (!cast<Instruction>("PseudoVADD_VV_"# vti.LMul.MX) - vti.RegClass:$rs1, vti.RegClass:$rs1, vti.AVL, vti.Log2SEW)>; -} -foreach vti = [VI64M1, VI64M2, VI64M4, VI64M8] in { def : Pat<(shl (vti.Vector vti.RegClass:$rs1), (vti.Vector (riscv_vmv_v_x_vl 1, (XLenVT srcvalue)))), (!cast<Instruction>("PseudoVADD_VV_"# vti.LMul.MX) @@ -851,17 +845,6 @@ // Vector Splats //===----------------------------------------------------------------------===// -let Predicates = [HasVInstructions] in { -foreach vti = AllIntegerVectors in { - def : Pat<(vti.Vector (SplatPat GPR:$rs1)), - (!cast<Instruction>("PseudoVMV_V_X_" # vti.LMul.MX) - GPR:$rs1, vti.AVL, vti.Log2SEW)>; - def : Pat<(vti.Vector (SplatPat_simm5 simm5:$rs1)), - (!cast<Instruction>("PseudoVMV_V_I_" # vti.LMul.MX) - simm5:$rs1, vti.AVL, vti.Log2SEW)>; -} -} // Predicates = [HasVInstructions] - let Predicates = [HasVInstructionsAnyF] in { foreach fvti = AllFloatVectors in { def : Pat<(fvti.Vector (splat_vector fvti.ScalarRegClass:$rs1)), diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td @@ -279,15 +279,13 @@ def rvv_vecreduce_#kind#_vl : SDNode<"RISCVISD::VECREDUCE_"#kind#"_VL", SDTRVVVecReduce>; // Give explicit Complexity to prefer simm5/uimm5. -def SplatPat : ComplexPattern; -def SplatPat_simm5 : ComplexPattern; -def SplatPat_uimm5 : ComplexPattern; +def SplatPat : ComplexPattern; +def SplatPat_simm5 : ComplexPattern; +def SplatPat_uimm5 : ComplexPattern; def SplatPat_simm5_plus1 - : ComplexPattern; + : ComplexPattern; def SplatPat_simm5_plus1_nonzero - : ComplexPattern; + : ComplexPattern; // Ignore the vl operand. def SplatFPOp : PatFrag<(ops node:$op),