diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -405,6 +405,10 @@
     /// than the vector element type, and is implicitly truncated to it.
     SCALAR_TO_VECTOR,
 
+    /// SPLAT_VECTOR(VAL) - Returns a vector with the scalar value VAL
+    /// duplicated in all lanes.
+    SPLAT_VECTOR,
+
     /// MULHU/MULHS - Multiply high - Multiply two integers of type iN,
     /// producing an unsigned/signed value of type i[2*N], then return the top
     /// part.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -161,6 +161,7 @@
   SDValue EmitStackConvert(SDValue SrcOp, EVT SlotVT, EVT DestVT,
                            const SDLoc &dl, SDValue ChainIn);
   SDValue ExpandBUILD_VECTOR(SDNode *Node);
+  SDValue ExpandSPLAT_VECTOR(SDNode *Node);
   SDValue ExpandSCALAR_TO_VECTOR(SDNode *Node);
   void ExpandDYNAMIC_STACKALLOC(SDNode *Node,
                                 SmallVectorImpl<SDValue> &Results);
@@ -1995,6 +1996,14 @@
   return ExpandVectorBuildThroughStack(Node);
 }
 
+SDValue SelectionDAGLegalize::ExpandSPLAT_VECTOR(SDNode *Node) {
+  SDLoc DL(Node);
+  EVT VT = Node->getValueType(0);
+  SDValue SplatVal = Node->getOperand(0);
+
+  return DAG.getSplatBuildVector(VT, DL, SplatVal);
+}
+
 // Expand a node into a call to a libcall.  If the result value
 // does not fit into a register, return the lo part and set the hi part to the
 // by-reg argument.  If it does fit into a single register, return the result
@@ -3653,6 +3662,9 @@
   case ISD::BUILD_VECTOR:
     Results.push_back(ExpandBUILD_VECTOR(Node));
     break;
+  case ISD::SPLAT_VECTOR:
+    Results.push_back(ExpandSPLAT_VECTOR(Node));
+    break;
   case ISD::SRA:
   case ISD::SRL:
   case ISD::SHL: {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -100,6 +100,8 @@
                          Res = PromoteIntRes_BUILD_VECTOR(N); break;
   case ISD::SCALAR_TO_VECTOR:
                          Res = PromoteIntRes_SCALAR_TO_VECTOR(N); break;
+  case ISD::SPLAT_VECTOR:
+                         Res = PromoteIntRes_SPLAT_VECTOR(N); break;
   case ISD::CONCAT_VECTORS:
                          Res = PromoteIntRes_CONCAT_VECTORS(N); break;
@@ -1130,6 +1132,8 @@
                          Res = PromoteIntOp_INSERT_VECTOR_ELT(N, OpNo);break;
   case ISD::SCALAR_TO_VECTOR:
                          Res = PromoteIntOp_SCALAR_TO_VECTOR(N); break;
+  case ISD::SPLAT_VECTOR:
+                         Res = PromoteIntOp_SPLAT_VECTOR(N); break;
   case ISD::VSELECT:
   case ISD::SELECT:      Res = PromoteIntOp_SELECT(N, OpNo); break;
   case ISD::SELECT_CC:   Res = PromoteIntOp_SELECT_CC(N, OpNo); break;
@@ -1360,6 +1364,11 @@
                                 GetPromotedInteger(N->getOperand(0))), 0);
 }
 
+SDValue DAGTypeLegalizer::PromoteIntOp_SPLAT_VECTOR(SDNode *N) {
+  return SDValue(DAG.UpdateNodeOperands(N,
+                                GetPromotedInteger(N->getOperand(0))), 0);
+}
+
 SDValue DAGTypeLegalizer::PromoteIntOp_SELECT(SDNode *N, unsigned OpNo) {
   assert(OpNo == 0 && "Only know how to promote the condition!");
   SDValue Cond = N->getOperand(0);
@@ -4101,6 +4110,23 @@
   return DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, NOutVT, Op);
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_SPLAT_VECTOR(SDNode *N) {
+  SDLoc dl(N);
+
+  SDValue SplatVal = N->getOperand(0);
+
+  assert(!SplatVal.getValueType().isVector() && "Input must be a scalar");
+
+  EVT OutVT = N->getValueType(0);
+  EVT NOutVT = TLI.getTypeToTransformTo(*DAG.getContext(), OutVT);
+  assert(NOutVT.isVector() &&
+         "Type must be promoted to a vector type");
+  EVT NOutElemVT = NOutVT.getVectorElementType();
+
+  SDValue Op = DAG.getNode(ISD::ANY_EXTEND, dl, NOutElemVT, SplatVal);
+
+  return DAG.getNode(ISD::SPLAT_VECTOR, dl, NOutVT, Op);
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_CONCAT_VECTORS(SDNode *N) {
   SDLoc dl(N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -297,6 +297,7 @@
   SDValue PromoteIntRes_VECTOR_SHUFFLE(SDNode *N);
   SDValue PromoteIntRes_BUILD_VECTOR(SDNode *N);
   SDValue PromoteIntRes_SCALAR_TO_VECTOR(SDNode *N);
+  SDValue PromoteIntRes_SPLAT_VECTOR(SDNode *N);
   SDValue PromoteIntRes_EXTEND_VECTOR_INREG(SDNode *N);
   SDValue PromoteIntRes_INSERT_VECTOR_ELT(SDNode *N);
   SDValue PromoteIntRes_CONCAT_VECTORS(SDNode *N);
@@ -354,6 +355,7 @@
   SDValue PromoteIntOp_EXTRACT_SUBVECTOR(SDNode *N);
   SDValue PromoteIntOp_CONCAT_VECTORS(SDNode *N);
   SDValue PromoteIntOp_SCALAR_TO_VECTOR(SDNode *N);
+  SDValue PromoteIntOp_SPLAT_VECTOR(SDNode *N);
   SDValue PromoteIntOp_SELECT(SDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_SELECT_CC(SDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_SETCC(SDNode *N, unsigned OpNo);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -3534,17 +3534,30 @@
 void SelectionDAGBuilder::visitShuffleVector(const User &I) {
   SDValue Src1 = getValue(I.getOperand(0));
   SDValue Src2 = getValue(I.getOperand(1));
+  Constant *MaskV = cast<Constant>(I.getOperand(2));
   SDLoc DL = getCurSDLoc();
-
-  SmallVector<int, 8> Mask;
-  ShuffleVectorInst::getShuffleMask(cast<Constant>(I.getOperand(2)), Mask);
-  unsigned MaskNumElts = Mask.size();
-
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   EVT VT = TLI.getValueType(DAG.getDataLayout(), I.getType());
   EVT SrcVT = Src1.getValueType();
   unsigned SrcNumElts = SrcVT.getVectorNumElements();
 
+  if (MaskV->isNullValue() && VT.isScalableVector()) {
+    // Canonical splat form of first element of first input vector.
+    SDValue FirstElt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL,
+                                   SrcVT.getScalarType(), Src1,
+                                   DAG.getConstant(0, DL,
+                                   TLI.getVectorIdxTy(DAG.getDataLayout())));
+    setValue(&I, DAG.getNode(ISD::SPLAT_VECTOR, DL, VT, FirstElt));
+    return;
+  }
+
+  // For now, we only handle splats for scalable vectors.
+  assert(!VT.isScalableVector() && "Unsupported scalable vector shuffle");
+
+  SmallVector<int, 8> Mask;
+  ShuffleVectorInst::getShuffleMask(MaskV, Mask);
+  unsigned MaskNumElts = Mask.size();
+
   if (SrcNumElts == MaskNumElts) {
     setValue(&I, DAG.getVectorShuffle(VT, DL, Src1, Src2, Mask));
     return;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -280,6 +280,7 @@
   case ISD::EXTRACT_SUBVECTOR:          return "extract_subvector";
   case ISD::SCALAR_TO_VECTOR:           return "scalar_to_vector";
   case ISD::VECTOR_SHUFFLE:             return "vector_shuffle";
+  case ISD::SPLAT_VECTOR:               return "splat_vector";
   case ISD::CARRY_FALSE:                return "carry_false";
   case ISD::ADDC:                       return "addc";
   case ISD::ADDE:                       return "adde";
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -690,6 +690,7 @@
     setOperationAction(ISD::ANY_EXTEND_VECTOR_INREG, VT, Expand);
     setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, VT, Expand);
     setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, VT, Expand);
+    setOperationAction(ISD::SPLAT_VECTOR, VT, Expand);
   }
 
   // Constrained floating-point operations default to expand.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -672,6 +672,7 @@
   SDValue LowerSCALAR_TO_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerSPLAT_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerEXTRACT_SUBVECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerShiftLeftParts(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -801,6 +801,16 @@
     setTruncStoreAction(MVT::v4i16, MVT::v4i8, Custom);
   }
 
+  if (Subtarget->hasSVE()) {
+    for (MVT VT : MVT::integer_scalable_vector_valuetypes()) {
+      if (isTypeLegal(VT)) {
+        if (VT.getVectorElementType() != MVT::i1) {
+          setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+        }
+      }
+    }
+  }
+
   PredictableSelectIsExpensive = Subtarget->predictableSelectIsExpensive();
 }
 
@@ -3002,6 +3012,8 @@
     return LowerBUILD_VECTOR(Op, DAG);
   case ISD::VECTOR_SHUFFLE:
     return LowerVECTOR_SHUFFLE(Op, DAG);
+  case ISD::SPLAT_VECTOR:
+    return LowerSPLAT_VECTOR(Op, DAG);
   case ISD::EXTRACT_SUBVECTOR:
     return LowerEXTRACT_SUBVECTOR(Op, DAG);
   case ISD::SRA:
@@ -7023,6 +7035,41 @@
   return GenerateTBL(Op, ShuffleMask, DAG);
 }
 
+SDValue AArch64TargetLowering::LowerSPLAT_VECTOR(SDValue Op,
+                                                 SelectionDAG &DAG) const {
+  SDLoc dl(Op);
+  EVT VT = Op.getValueType();
+  EVT ElemVT = VT.getScalarType();
+
+  SDValue SplatVal = Op.getOperand(0);
+
+  // Extend input splat value where needed to fit into a GPR (32b or 64b only)
+  // FPRs don't have this restriction.
+  switch (ElemVT.getSimpleVT().SimpleTy) {
+  case MVT::i8:
+  case MVT::i16:
+    SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i32);
+    break;
+  case MVT::i64:
+    SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i64);
+    break;
+  case MVT::i32:
+    // Fine as is
+    break;
+  // TODO: we can support splats of i1s and float types, but haven't added
+  // patterns yet.
+  case MVT::i1:
+  case MVT::f16:
+  case MVT::f32:
+  case MVT::f64:
+  default:
+    llvm_unreachable("Unsupported SPLAT_VECTOR input operand type");
+    break;
+  }
+
+  return DAG.getNode(AArch64ISD::DUP, dl, VT, SplatVal);
+}
+
 static bool resolveBuildVector(BuildVectorSDNode *BVN, APInt &CnstBits,
                                APInt &UndefBits) {
   EVT VT = BVN->getValueType(0);
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -187,7 +187,7 @@
   defm FCPY_ZPmI : sve_int_dup_fpimm_pred<"fcpy">;
 
   // Splat scalar register (unpredicated, GPR or vector + element index)
-  defm DUP_ZR  : sve_int_perm_dup_r<"dup">;
+  defm DUP_ZR  : sve_int_perm_dup_r<"dup", AArch64dup>;
   defm DUP_ZZI : sve_int_perm_dup_i<"dup">;
 
   // Splat scalar register (predicated)
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -624,11 +624,12 @@
 //===----------------------------------------------------------------------===//
 
 class sve_int_perm_dup_r<bits<2> sz8_64, string asm, ZPRRegOp zprty,
-                         RegisterClass srcRegType>
+                         ValueType vt, RegisterClass srcRegType,
+                         SDPatternOperator op>
 : I<(outs zprty:$Zd), (ins srcRegType:$Rn),
   asm, "\t$Zd, $Rn",
   "",
-  []>, Sched<[]> {
+  [(set (vt zprty:$Zd), (op srcRegType:$Rn))]>, Sched<[]> {
   bits<5> Rn;
   bits<5> Zd;
   let Inst{31-24} = 0b00000101;
@@ -638,11 +639,11 @@
   let Inst{4-0} = Zd;
 }
 
-multiclass sve_int_perm_dup_r<string asm> {
-  def _B : sve_int_perm_dup_r<0b00, asm, ZPR8,  GPR32sp>;
-  def _H : sve_int_perm_dup_r<0b01, asm, ZPR16, GPR32sp>;
-  def _S : sve_int_perm_dup_r<0b10, asm, ZPR32, GPR32sp>;
-  def _D : sve_int_perm_dup_r<0b11, asm, ZPR64, GPR64sp>;
+multiclass sve_int_perm_dup_r<string asm, SDPatternOperator op> {
+  def _B : sve_int_perm_dup_r<0b00, asm, ZPR8,  nxv16i8, GPR32sp, op>;
+  def _H : sve_int_perm_dup_r<0b01, asm, ZPR16, nxv8i16, GPR32sp, op>;
+  def _S : sve_int_perm_dup_r<0b10, asm, ZPR32, nxv4i32, GPR32sp, op>;
+  def _D : sve_int_perm_dup_r<0b11, asm, ZPR64, nxv2i64, GPR64sp, op>;
 
   def : InstAlias<"mov $Zd, $Rn",
                   (!cast<Instruction>(NAME # _B) ZPR8:$Zd, GPR32sp:$Rn), 1>;
diff --git a/llvm/test/CodeGen/AArch64/sve-vector-splat.ll b/llvm/test/CodeGen/AArch64/sve-vector-splat.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-vector-splat.ll
@@ -0,0 +1,37 @@
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
+
+define <vscale x 16 x i8> @sve_splat_i8(i8 %val) {
+; CHECK-LABEL: @sve_splat_i8
+; CHECK: mov z0.b, w0
+; CHECK-NEXT: ret
+  %ins = insertelement <vscale x 16 x i8> undef, i8 %val, i32 0
+  %splat = shufflevector <vscale x 16 x i8> %ins, <vscale x 16 x i8> undef, <vscale x 16 x i32> zeroinitializer
+  ret <vscale x 16 x i8> %splat
+}
+
+define <vscale x 8 x i16> @sve_splat_i16(i16 %val) {
+; CHECK-LABEL: @sve_splat_i16
+; CHECK: mov z0.h, w0
+; CHECK-NEXT: ret
+  %ins = insertelement <vscale x 8 x i16> undef, i16 %val, i32 0
+  %splat = shufflevector <vscale x 8 x i16> %ins, <vscale x 8 x i16> undef, <vscale x 8 x i32> zeroinitializer
+  ret <vscale x 8 x i16> %splat
+}
+
+define <vscale x 4 x i32> @sve_splat_i32(i32 %val) {
+; CHECK-LABEL: @sve_splat_i32
+; CHECK: mov z0.s, w0
+; CHECK-NEXT: ret
+  %ins = insertelement <vscale x 4 x i32> undef, i32 %val, i32 0
+  %splat = shufflevector <vscale x 4 x i32> %ins, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer
+  ret <vscale x 4 x i32> %splat
+}
+
+define <vscale x 2 x i64> @sve_splat_i64(i64 %val) {
+; CHECK-LABEL: @sve_splat_i64
+; CHECK: mov z0.d, x0
+; CHECK-NEXT: ret
+  %ins = insertelement <vscale x 2 x i64> undef, i64 %val, i32 0
+  %splat = shufflevector <vscale x 2 x i64> %ins, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
+  ret <vscale x 2 x i64> %splat
+}