diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -825,7 +825,7 @@ // splat of 0 or undef) once vector selects supported in SVE codegen. See // D68877 for more details. for (MVT VT : MVT::integer_scalable_vector_valuetypes()) { - if (isTypeLegal(VT) && VT.getVectorElementType() != MVT::i1) + if (isTypeLegal(VT)) setOperationAction(ISD::SPLAT_VECTOR, VT, Custom); } setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom); @@ -7135,26 +7135,31 @@ switch (ElemVT.getSimpleVT().SimpleTy) { case MVT::i8: case MVT::i16: + case MVT::i32: SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i32); - break; + return DAG.getNode(AArch64ISD::DUP, dl, VT, SplatVal); case MVT::i64: SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i64); - break; - case MVT::i32: - // Fine as is - break; - // TODO: we can support splats of i1s and float types, but haven't added - // patterns yet. - case MVT::i1: + return DAG.getNode(AArch64ISD::DUP, dl, VT, SplatVal); + case MVT::i1: { + // The general case of i1. There isn't any natural way to do this, + // so we use some trickery with whilelo. + // TODO: Add special cases for splat of constant true/false. + SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i64); + SplatVal = DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, MVT::i64, SplatVal, + DAG.getValueType(MVT::i1)); + SDValue ID = DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, dl, + MVT::i64); + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT, ID, + DAG.getConstant(0, dl, MVT::i64), SplatVal); + } + // TODO: we can support float types, but haven't added patterns yet. case MVT::f16: case MVT::f32: case MVT::f64: default: - llvm_unreachable("Unsupported SPLAT_VECTOR input operand type"); - break; + report_fatal_error("Unsupported SPLAT_VECTOR input operand type"); } - - return DAG.getNode(AArch64ISD::DUP, dl, VT, SplatVal); } static bool resolveBuildVector(BuildVectorSDNode *BVN, APInt &CnstBits, diff --git a/llvm/test/CodeGen/AArch64/sve-vector-splat.ll b/llvm/test/CodeGen/AArch64/sve-vector-splat.ll --- a/llvm/test/CodeGen/AArch64/sve-vector-splat.ll +++ b/llvm/test/CodeGen/AArch64/sve-vector-splat.ll @@ -93,3 +93,43 @@ %splat = shufflevector %ins, undef, zeroinitializer ret %splat } + +define @sve_splat_2xi1(i1 %val) { +; CHECK-LABEL: @sve_splat_2xi1 +; CHECK: sbfx x8, x0, #0, #1 +; CHECK-NEXT: whilelo p0.d, xzr, x8 +; CHECK-NEXT: ret + %ins = insertelement undef, i1 %val, i32 0 + %splat = shufflevector %ins, undef, zeroinitializer + ret %splat +} + +define @sve_splat_4xi1(i1 %val) { +; CHECK-LABEL: @sve_splat_4xi1 +; CHECK: sbfx x8, x0, #0, #1 +; CHECK-NEXT: whilelo p0.s, xzr, x8 +; CHECK-NEXT: ret + %ins = insertelement undef, i1 %val, i32 0 + %splat = shufflevector %ins, undef, zeroinitializer + ret %splat +} + +define @sve_splat_8xi1(i1 %val) { +; CHECK-LABEL: @sve_splat_8xi1 +; CHECK: sbfx x8, x0, #0, #1 +; CHECK-NEXT: whilelo p0.h, xzr, x8 +; CHECK-NEXT: ret + %ins = insertelement undef, i1 %val, i32 0 + %splat = shufflevector %ins, undef, zeroinitializer + ret %splat +} + +define @sve_splat_16xi1(i1 %val) { +; CHECK-LABEL: @sve_splat_16xi1 +; CHECK: sbfx x8, x0, #0, #1 +; CHECK-NEXT: whilelo p0.b, xzr, x8 +; CHECK-NEXT: ret + %ins = insertelement undef, i1 %val, i32 0 + %splat = shufflevector %ins, undef, zeroinitializer + ret %splat +}