diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -583,6 +583,15 @@
   /// implicitly truncated to it.
   SPLAT_VECTOR,
 
+  /// SPLAT_VECTOR_PARTS(SCALAR1, SCALAR2, ...) - Returns a vector with the
+  /// scalar values joined together and then duplicated in all lanes. This
+  /// represents a SPLAT_VECTOR that has had its scalar operand expanded. This
+  /// allows representing a 64-bit splat on a target with 32-bit integers. The
+  /// total width of the scalars must cover the element width. SCALAR1 contains
+  /// the least significant bits of the value regardless of endianness and all
+  /// scalars should have the same type.
+  SPLAT_VECTOR_PARTS,
+
   /// MULHU/MULHS - Multiply high - Multiply two integers of type iN,
   /// producing an unsigned/signed value of type i[2*N], then return the top
   /// part.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -4194,6 +4194,7 @@
   case ISD::EXTRACT_ELEMENT:   Res = ExpandOp_EXTRACT_ELEMENT(N); break;
   case ISD::INSERT_VECTOR_ELT: Res = ExpandOp_INSERT_VECTOR_ELT(N); break;
   case ISD::SCALAR_TO_VECTOR:  Res = ExpandOp_SCALAR_TO_VECTOR(N); break;
+  case ISD::SPLAT_VECTOR:      Res = ExpandIntOp_SPLAT_VECTOR(N); break;
   case ISD::SELECT_CC:         Res = ExpandIntOp_SELECT_CC(N); break;
   case ISD::SETCC:             Res = ExpandIntOp_SETCC(N); break;
   case ISD::SETCCCARRY:        Res = ExpandIntOp_SETCCCARRY(N); break;
@@ -4449,6 +4450,14 @@
                         LowCmp.getValue(1), Cond);
 }
 
+SDValue DAGTypeLegalizer::ExpandIntOp_SPLAT_VECTOR(SDNode *N) {
+  // Split the operand and replace with SPLAT_VECTOR_PARTS.
+  SDValue Lo, Hi;
+  GetExpandedInteger(N->getOperand(0), Lo, Hi);
+  return DAG.getNode(ISD::SPLAT_VECTOR_PARTS, SDLoc(N), N->getValueType(0), Lo,
+                     Hi);
+}
+
 SDValue DAGTypeLegalizer::ExpandIntOp_Shift(SDNode *N) {
   // The value being shifted is legal, but the shift amount is too big.
   // It follows that either the result of the shift is undefined, or the
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -481,6 +481,7 @@
   SDValue ExpandIntOp_UINT_TO_FP(SDNode *N);
   SDValue ExpandIntOp_RETURNADDR(SDNode *N);
   SDValue ExpandIntOp_ATOMIC_STORE(SDNode *N);
+  SDValue ExpandIntOp_SPLAT_VECTOR(SDNode *N);
 
   void IntegerExpandSetCCOperands(SDValue &NewLHS, SDValue &NewRHS,
                                   ISD::CondCode &CCCode, const SDLoc &dl);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1383,6 +1383,22 @@
     const APInt &NewVal = Elt->getValue();
     EVT ViaEltVT = TLI->getTypeToTransformTo(*getContext(), EltVT);
     unsigned ViaEltSizeInBits = ViaEltVT.getSizeInBits();
+
+    // For scalable vectors, try to use a SPLAT_VECTOR_PARTS node.
+    if (VT.isScalableVector()) {
+      assert(EltVT.getSizeInBits() % ViaEltSizeInBits == 0 &&
+             "Can only handle an even split!");
+      unsigned Parts = EltVT.getSizeInBits() / ViaEltSizeInBits;
+
+      SmallVector<SDValue, 2> ScalarParts;
+      for (unsigned i = 0; i != Parts; ++i)
+        ScalarParts.push_back(getConstant(
+            NewVal.lshr(i * ViaEltSizeInBits).trunc(ViaEltSizeInBits), DL,
+            ViaEltVT, isT, isO));
+
+      return getNode(ISD::SPLAT_VECTOR_PARTS, DL, VT, ScalarParts);
+    }
+
     unsigned ViaVecNumElts = VT.getSizeInBits() / ViaEltSizeInBits;
     EVT ViaVecVT = EVT::getVectorVT(*getContext(), ViaEltVT, ViaVecNumElts);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -290,6 +290,7 @@
   case ISD::VECTOR_SHUFFLE:      return "vector_shuffle";
   case ISD::VECTOR_SPLICE:       return "vector_splice";
   case ISD::SPLAT_VECTOR:        return "splat_vector";
+  case ISD::SPLAT_VECTOR_PARTS:  return "splat_vector_parts";
   case ISD::VECTOR_REVERSE:      return "vector_reverse";
   case ISD::CARRY_FALSE:         return "carry_false";
   case ISD::ADDC:                return "addc";
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -437,7 +437,7 @@
   SDValue lowerRETURNADDR(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerShiftLeftParts(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerShiftRightParts(SDValue Op, SelectionDAG &DAG, bool IsSRA) const;
-  SDValue lowerSPLATVECTOR(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerSPLAT_VECTOR_PARTS(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG,
                              int64_t ExtTrueVal) const;
   SDValue lowerVectorMaskTrunc(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -399,7 +399,6 @@
     if (!Subtarget.is64Bit()) {
       // We must custom-lower certain vXi64 operations on RV32 due to the vector
       // element type being illegal.
-      setOperationAction(ISD::SPLAT_VECTOR, MVT::i64, Custom);
       setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::i64, Custom);
       setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::i64, Custom);
 
@@ -424,15 +423,13 @@
     for (MVT VT : IntVecVTs) {
       setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
+      setOperationAction(ISD::SPLAT_VECTOR_PARTS, VT, Custom);
 
       setOperationAction(ISD::SMIN, VT, Legal);
       setOperationAction(ISD::SMAX, VT, Legal);
       setOperationAction(ISD::UMIN, VT, Legal);
       setOperationAction(ISD::UMAX, VT, Legal);
 
-      if (!Subtarget.is64Bit() && VT.getVectorElementType() == MVT::i64)
-        setOperationAction(ISD::ABS, VT, Custom);
-
       setOperationAction(ISD::ROTL, VT, Expand);
       setOperationAction(ISD::ROTR, VT, Expand);
@@ -1313,8 +1310,8 @@
         Op.getOperand(0).getValueType().getVectorElementType() == MVT::i1)
       return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ -1);
     return lowerFixedLengthVectorExtendToRVV(Op, DAG, RISCVISD::VSEXT_VL);
-  case ISD::SPLAT_VECTOR:
-    return lowerSPLATVECTOR(Op, DAG);
+  case ISD::SPLAT_VECTOR_PARTS:
+    return lowerSPLAT_VECTOR_PARTS(Op, DAG);
   case ISD::INSERT_VECTOR_ELT:
     return lowerINSERT_VECTOR_ELT(Op, DAG);
   case ISD::EXTRACT_VECTOR_ELT:
@@ -2035,30 +2032,28 @@
   return DAG.getMergeValues(Parts, DL);
 }
 
-// Custom-lower a SPLAT_VECTOR where XLEN<SEW, as the SEW element type is
-// illegal (currently only vXi64 RV32).
-SDValue RISCVTargetLowering::lowerSPLATVECTOR(SDValue Op,
-                                              SelectionDAG &DAG) const {
+// Custom-lower a SPLAT_VECTOR_PARTS where XLEN<SEW, as the SEW element type
+// is illegal (currently only vXi64 RV32).
+SDValue RISCVTargetLowering::lowerSPLAT_VECTOR_PARTS(SDValue Op,
+                                                     SelectionDAG &DAG) const {
   SDLoc DL(Op);
   EVT VecVT = Op.getValueType();
   assert(!Subtarget.is64Bit() && VecVT.getVectorElementType() == MVT::i64 &&
-         "Unexpected SPLAT_VECTOR lowering");
-  SDValue SplatVal = Op.getOperand(0);
+         "Unexpected SPLAT_VECTOR_PARTS lowering");
 
-  // If we can prove that the value is a sign-extended 32-bit value, lower this
-  // as a custom node in order to try and match RVV vector/scalar instructions.
-  if (auto *CVal = dyn_cast<ConstantSDNode>(SplatVal)) {
-    if (isInt<32>(CVal->getSExtValue()))
-      return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT,
-                         DAG.getConstant(CVal->getSExtValue(), DL, MVT::i32));
-  }
+  assert(Op.getNumOperands() == 2 && "Unexpected number of operands!");
+  SDValue Lo = Op.getOperand(0);
+  SDValue Hi = Op.getOperand(1);
 
-  if (SplatVal.getOpcode() == ISD::SIGN_EXTEND &&
-      SplatVal.getOperand(0).getValueType() == MVT::i32) {
-    return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT,
-                       SplatVal.getOperand(0));
+  if (isa<ConstantSDNode>(Lo) && isa<ConstantSDNode>(Hi)) {
+    int32_t LoC = cast<ConstantSDNode>(Lo)->getSExtValue();
+    int32_t HiC = cast<ConstantSDNode>(Hi)->getSExtValue();
+    // If Hi constant is all the same sign bit as Lo, lower this as a custom
+    // node in order to try and match RVV vector/scalar instructions.
+    if ((LoC >> 31) == HiC)
+      return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, Lo);
   }
 
   // Else, on RV32 we lower an i64-element SPLAT_VECTOR thus, being careful not
@@ -2069,11 +2064,7 @@
   // vsll.vx vY, vY, /*32*/
   // vsrl.vx vY, vY, /*32*/
   // vor.vv vX, vX, vY
-  SDValue One = DAG.getConstant(1, DL, MVT::i32);
-  SDValue Zero = DAG.getConstant(0, DL, MVT::i32);
   SDValue ThirtyTwoV = DAG.getConstant(32, DL, VecVT);
-  SDValue Lo = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, SplatVal, Zero);
-  SDValue Hi = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, SplatVal, One);
 
   Lo = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, Lo);
   Lo = DAG.getNode(ISD::SHL, DL, VecVT, Lo, ThirtyTwoV);
@@ -3162,17 +3153,6 @@
   MVT VT = Op.getSimpleValueType();
   SDValue X = Op.getOperand(0);
 
-  // For scalable vectors we just need to deal with i64 on RV32 since the
-  // default expansion crashes in getConstant.
-  if (VT.isScalableVector()) {
-    assert(!Subtarget.is64Bit() && VT.getVectorElementType() == MVT::i64 &&
-           "Unexpected custom lowering!");
-    SDValue SplatZero = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VT,
-                                    DAG.getConstant(0, DL, MVT::i32));
-    SDValue NegX = DAG.getNode(ISD::SUB, DL, VT, SplatZero, X);
-    return DAG.getNode(ISD::SMAX, DL, VT, X, NegX);
-  }
-
   assert(VT.isFixedLengthVector() && "Unexpected type");
 
   MVT ContainerVT =
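
Illustrative sketch (not part of the patch): the arithmetic behind the two constant paths above is plain 32-bit splitting. The standalone C++ below, with made-up helper names (splitSplatValue, fitsInSExt32), models how getConstant splits a 64-bit splat value into least-significant-first parts, and how lowerSPLAT_VECTOR_PARTS recognises that Hi is just the sign bits of Lo, in which case a single RISCVISD::SPLAT_VECTOR_I64 of Lo is enough.

// Standalone model of the 32-bit splitting used by the patch (hypothetical
// helpers; not LLVM code).
#include <cassert>
#include <cstdint>

struct SplatParts {
  uint32_t Lo; // least significant 32 bits, operand 0 of SPLAT_VECTOR_PARTS
  uint32_t Hi; // most significant 32 bits, operand 1
};

// Mirrors NewVal.lshr(i * 32).trunc(32) in SelectionDAG::getConstant for a
// 64-bit element type split into two 32-bit parts.
static SplatParts splitSplatValue(uint64_t Val) {
  return {static_cast<uint32_t>(Val), static_cast<uint32_t>(Val >> 32)};
}

// Mirrors the constant check in lowerSPLAT_VECTOR_PARTS: if Hi equals the
// sign bits of Lo, the value is a sign-extended 32-bit immediate and one
// SPLAT_VECTOR_I64 of Lo suffices.
static bool fitsInSExt32(const SplatParts &P) {
  return (static_cast<int32_t>(P.Lo) >> 31) == static_cast<int32_t>(P.Hi);
}

int main() {
  assert(fitsInSExt32(splitSplatValue(UINT64_C(0xFFFFFFFFFFFFFFFF))));   // -1
  assert(fitsInSExt32(splitSplatValue(7)));                              // small positive
  assert(!fitsInSExt32(splitSplatValue(UINT64_C(0x0000000080000000))));  // needs both halves
  return 0;
}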
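
Illustrative sketch (not part of the patch): for non-constant parts, the comments in lowerSPLAT_VECTOR_PARTS describe a vmv.v.x/vsll.vx/vsrl.vx/vor.vv sequence. The scalar model below, using the made-up name buildLane, shows one resulting 64-bit lane, assuming SPLAT_VECTOR_I64 sign-extends its 32-bit scalar into each lane (which is why the Lo half is shifted left then right by 32 before the OR).

// Scalar model of one 64-bit lane built by the RV32 fallback (hypothetical
// helper name; not LLVM code).
#include <cassert>
#include <cstdint>

static uint64_t buildLane(uint32_t Lo, uint32_t Hi) {
  // SPLAT_VECTOR_I64 sign-extends the 32-bit scalar into the 64-bit lane.
  uint64_t LoLane = static_cast<uint64_t>(static_cast<int64_t>(static_cast<int32_t>(Lo)));
  uint64_t HiLane = static_cast<uint64_t>(static_cast<int64_t>(static_cast<int32_t>(Hi)));
  LoLane = (LoLane << 32) >> 32; // vsll.vx + vsrl.vx: clear the copied sign bits
  HiLane <<= 32;                 // vsll.vx: move Hi into the upper half
  return HiLane | LoLane;        // vor.vv: combine the two halves
}

int main() {
  assert(buildLane(0x89ABCDEF, 0x01234567) == UINT64_C(0x0123456789ABCDEF));
  assert(buildLane(0x00000000, 0x80000000) == UINT64_C(0x8000000000000000));
  return 0;
}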