diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h --- a/llvm/include/llvm/CodeGen/ISDOpcodes.h +++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h @@ -571,6 +571,13 @@ /// implicitly truncated to it. SPLAT_VECTOR, + /// SPLAT_VECTOR_PARTS(Lo, Hi) - Returns a vector with the scalar + /// values Lo and Hi joined together and then duplicated in all lanes. This + /// represents a SPLAT_VECTOR that has had its scalar operand expanded. The + /// total width of the scalars must cover the element width. This allows + /// representing a 64-bit splat on a target with 32-bit integers. + SPLAT_VECTOR_PARTS, + /// MULHU/MULHS - Multiply high - Multiply two integers of type iN, /// producing an unsigned/signed value of type i[2*N], then return the top /// part. diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp @@ -4192,6 +4192,7 @@ case ISD::EXTRACT_ELEMENT: Res = ExpandOp_EXTRACT_ELEMENT(N); break; case ISD::INSERT_VECTOR_ELT: Res = ExpandOp_INSERT_VECTOR_ELT(N); break; case ISD::SCALAR_TO_VECTOR: Res = ExpandOp_SCALAR_TO_VECTOR(N); break; + case ISD::SPLAT_VECTOR: Res = ExpandIntOp_SPLAT_VECTOR(N); break; case ISD::SELECT_CC: Res = ExpandIntOp_SELECT_CC(N); break; case ISD::SETCC: Res = ExpandIntOp_SETCC(N); break; case ISD::SETCCCARRY: Res = ExpandIntOp_SETCCCARRY(N); break; @@ -4447,6 +4448,14 @@ LowCmp.getValue(1), Cond); } +SDValue DAGTypeLegalizer::ExpandIntOp_SPLAT_VECTOR(SDNode *N) { + // Split the operand and replace with SPLAT_VECTOR_PARTS. + SDValue Lo, Hi; + GetExpandedInteger(N->getOperand(0), Lo, Hi); + return DAG.getNode(ISD::SPLAT_VECTOR_PARTS, SDLoc(N), N->getValueType(0), Lo, + Hi); +} + SDValue DAGTypeLegalizer::ExpandIntOp_Shift(SDNode *N) { // The value being shifted is legal, but the shift amount is too big. 
// It follows that either the result of the shift is undefined, or the diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h @@ -480,6 +480,7 @@ SDValue ExpandIntOp_UINT_TO_FP(SDNode *N); SDValue ExpandIntOp_RETURNADDR(SDNode *N); SDValue ExpandIntOp_ATOMIC_STORE(SDNode *N); + SDValue ExpandIntOp_SPLAT_VECTOR(SDNode *N); void IntegerExpandSetCCOperands(SDValue &NewLHS, SDValue &NewRHS, ISD::CondCode &CCCode, const SDLoc &dl); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -1372,6 +1372,22 @@ const APInt &NewVal = Elt->getValue(); EVT ViaEltVT = TLI->getTypeToTransformTo(*getContext(), EltVT); unsigned ViaEltSizeInBits = ViaEltVT.getSizeInBits(); + + // For scalable vectors, try to use a SPLAT_VECTOR_PARTS node. 
+ if (VT.isScalableVector()) { + assert(EltVT.getSizeInBits() % ViaEltSizeInBits == 0 && + "Can only handle an even split!"); + unsigned Parts = EltVT.getSizeInBits() / ViaEltSizeInBits; + assert(Parts == 2 && "Can't handle more than 2 parts yet!"); + + SDValue Lo = + getConstant(NewVal.trunc(ViaEltSizeInBits), DL, ViaEltVT, isT, isO); + SDValue Hi = + getConstant(NewVal.lshr(ViaEltSizeInBits).trunc(ViaEltSizeInBits), DL, + ViaEltVT, isT, isO); + return getNode(ISD::SPLAT_VECTOR_PARTS, DL, VT, Lo, Hi); + } + unsigned ViaVecNumElts = VT.getSizeInBits() / ViaEltSizeInBits; EVT ViaVecVT = EVT::getVectorVT(*getContext(), ViaEltVT, ViaVecNumElts); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp @@ -289,6 +289,7 @@ case ISD::SCALAR_TO_VECTOR: return "scalar_to_vector"; case ISD::VECTOR_SHUFFLE: return "vector_shuffle"; case ISD::SPLAT_VECTOR: return "splat_vector"; + case ISD::SPLAT_VECTOR_PARTS: return "splat_vector_parts"; case ISD::VECTOR_REVERSE: return "vector_reverse"; case ISD::CARRY_FALSE: return "carry_false"; case ISD::ADDC: return "addc"; diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -430,7 +430,7 @@ SDValue lowerRETURNADDR(SDValue Op, SelectionDAG &DAG) const; SDValue lowerShiftLeftParts(SDValue Op, SelectionDAG &DAG) const; SDValue lowerShiftRightParts(SDValue Op, SelectionDAG &DAG, bool IsSRA) const; - SDValue lowerSPLATVECTOR(SDValue Op, SelectionDAG &DAG) const; + SDValue lowerSPLAT_VECTOR_PARTS(SDValue Op, SelectionDAG &DAG) const; SDValue lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG, int64_t ExtTrueVal) const; SDValue lowerVectorMaskTrunc(SDValue Op, SelectionDAG &DAG) const; diff --git 
a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -400,7 +400,6 @@ } else { // We must custom-lower certain vXi64 operations on RV32 due to the vector // element type being illegal. - setOperationAction(ISD::SPLAT_VECTOR, MVT::i64, Custom); setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::i64, Custom); setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::i64, Custom); @@ -425,15 +424,13 @@ for (MVT VT : IntVecVTs) { setOperationAction(ISD::SPLAT_VECTOR, VT, Legal); + setOperationAction(ISD::SPLAT_VECTOR_PARTS, VT, Custom); setOperationAction(ISD::SMIN, VT, Legal); setOperationAction(ISD::SMAX, VT, Legal); setOperationAction(ISD::UMIN, VT, Legal); setOperationAction(ISD::UMAX, VT, Legal); - if (!Subtarget.is64Bit() && VT.getVectorElementType() == MVT::i64) - setOperationAction(ISD::ABS, VT, Custom); - setOperationAction(ISD::ROTL, VT, Expand); setOperationAction(ISD::ROTR, VT, Expand); @@ -1288,8 +1285,8 @@ Op.getOperand(0).getValueType().getVectorElementType() == MVT::i1) return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ -1); return lowerFixedLengthVectorExtendToRVV(Op, DAG, RISCVISD::VSEXT_VL); - case ISD::SPLAT_VECTOR: - return lowerSPLATVECTOR(Op, DAG); + case ISD::SPLAT_VECTOR_PARTS: + return lowerSPLAT_VECTOR_PARTS(Op, DAG); case ISD::INSERT_VECTOR_ELT: return lowerINSERT_VECTOR_ELT(Op, DAG); case ISD::EXTRACT_VECTOR_ELT: @@ -2013,30 +2010,27 @@ return DAG.getMergeValues(Parts, DL); } -// Custom-lower a SPLAT_VECTOR where XLEN(SplatVal)) { - if (isInt<32>(CVal->getSExtValue())) - return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, - DAG.getConstant(CVal->getSExtValue(), DL, MVT::i32)); - } + SDValue Lo = Op.getOperand(0); + SDValue Hi = Op.getOperand(1); - if (SplatVal.getOpcode() == ISD::SIGN_EXTEND && - SplatVal.getOperand(0).getValueType() == MVT::i32) { - return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, - 
SplatVal.getOperand(0)); + if (isa(Lo) && isa(Hi)) { + int32_t LoC = cast(Lo)->getSExtValue(); + int32_t HiC = cast(Hi)->getSExtValue(); + // If Hi constant is all the same sign bit as Lo, lower this as a custom + // node in order to try and match RVV vector/scalar instructions. + if ((LoC >> 31) == HiC) + return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, Lo); } // Else, on RV32 we lower an i64-element SPLAT_VECTOR thus, being careful not @@ -2047,11 +2041,7 @@ // vsll.vx vY, vY, /*32*/ // vsrl.vx vY, vY, /*32*/ // vor.vv vX, vX, vY - SDValue One = DAG.getConstant(1, DL, MVT::i32); - SDValue Zero = DAG.getConstant(0, DL, MVT::i32); SDValue ThirtyTwoV = DAG.getConstant(32, DL, VecVT); - SDValue Lo = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, SplatVal, Zero); - SDValue Hi = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, SplatVal, One); Lo = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, Lo); Lo = DAG.getNode(ISD::SHL, DL, VecVT, Lo, ThirtyTwoV); @@ -2920,17 +2910,6 @@ MVT VT = Op.getSimpleValueType(); SDValue X = Op.getOperand(0); - // For scalable vectors we just need to deal with i64 on RV32 since the - // default expansion crashes in getConstant. - if (VT.isScalableVector()) { - assert(!Subtarget.is64Bit() && VT.getVectorElementType() == MVT::i64 && - "Unexpected custom lowering!"); - SDValue SplatZero = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VT, - DAG.getConstant(0, DL, MVT::i32)); - SDValue NegX = DAG.getNode(ISD::SUB, DL, VT, SplatZero, X); - return DAG.getNode(ISD::SMAX, DL, VT, X, NegX); - } - assert(VT.isFixedLengthVector() && "Unexpected type"); MVT ContainerVT =