[AArch64][SVE] Fold zero-accumulator SVE sdot/udot into a following add

* Lower llvm.aarch64.sve.{s,u}dot to AArch64ISD::SDOT/UDOT (the nodes the
  NEON dot intrinsics already lower to) and select the SVE SDOT_ZZZ/UDOT_ZZZ
  patterns from those nodes.
* Let the isZeroDot matcher (add(dot(0, b, c), a) -> dot(a, b, c), see the
  new tests) recognise zero accumulators via a new isZerosVector() helper
  that also accepts constant splats and AArch64ISD::DUP of zero.
* Let ISD::isConstantSplatVector() report FP SPLAT_VECTOR constants as their
  raw bit pattern.

isZerosVector() deliberately uses isNullFPConstant(), which only matches
+0.0: -0.0 has its sign bit set and is not an all-zeros bit pattern.

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -145,6 +145,11 @@
     if (auto *Op0 = dyn_cast<ConstantSDNode>(N->getOperand(0))) {
       SplatVal = Op0->getAPIntValue().truncOrSelf(EltSize);
       return true;
+    }
+    // FP splats are reported as their raw bit pattern.
+    if (auto *Op0 = dyn_cast<ConstantFPSDNode>(N->getOperand(0))) {
+      SplatVal = Op0->getValueAPF().bitcastToAPInt().truncOrSelf(EltSize);
+      return true;
     }
   }
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2153,6 +2153,27 @@
 //                      Lowering Code
 //===----------------------------------------------------------------------===//
 
+/// isZerosVector - Check whether SDNode N is a zero-filled vector.
+///
+/// Bitcasts are looked through. Besides constant all-zeros splats and build
+/// vectors, an AArch64ISD::DUP of integer 0 or floating-point +0.0 counts.
+static bool isZerosVector(const SDNode *N) {
+  // Look through a bit convert.
+  while (N->getOpcode() == ISD::BITCAST)
+    N = N->getOperand(0).getNode();
+
+  if (ISD::isConstantSplatVectorAllZeros(N))
+    return true;
+
+  if (N->getOpcode() != AArch64ISD::DUP)
+    return false;
+
+  // isNullFPConstant only matches +0.0: -0.0 has its sign bit set, so a DUP
+  // of it is not an all-zeros bit pattern.
+  SDValue Opnd0 = N->getOperand(0);
+  return isNullConstant(Opnd0) || isNullFPConstant(Opnd0);
+}
+
 /// changeIntCCToAArch64CC - Convert a DAG integer condition code to an AArch64
 /// CC
 static AArch64CC::CondCode changeIntCCToAArch64CC(ISD::CondCode CC) {
@@ -3924,9 +3945,13 @@
                        Op.getOperand(2));
   }
   case Intrinsic::aarch64_neon_sdot:
-  case Intrinsic::aarch64_neon_udot: {
-    unsigned Opcode = IntNo == Intrinsic::aarch64_neon_udot ? AArch64ISD::UDOT
-                                                            : AArch64ISD::SDOT;
+  case Intrinsic::aarch64_neon_udot:
+  case Intrinsic::aarch64_sve_sdot:
+  case Intrinsic::aarch64_sve_udot: {
+    unsigned Opcode = (IntNo == Intrinsic::aarch64_neon_udot ||
+                       IntNo == Intrinsic::aarch64_sve_udot)
+                          ? AArch64ISD::UDOT
+                          : AArch64ISD::SDOT;
     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
                        Op.getOperand(2), Op.getOperand(3));
   }
@@ -13340,7 +13365,7 @@
   auto isZeroDot = [](SDValue Dot) {
     return (Dot.getOpcode() == AArch64ISD::UDOT ||
             Dot.getOpcode() == AArch64ISD::SDOT) &&
-           ISD::isBuildVectorAllZeros(Dot.getOperand(0).getNode());
+           isZerosVector(Dot.getOperand(0).getNode());
   };
   if (!isZeroDot(Dot))
     std::swap(Dot, A);
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -353,8 +353,8 @@
   defm SDIV_ZPZZ  : sve_int_bin_pred_sd<AArch64sdiv_p>;
   defm UDIV_ZPZZ  : sve_int_bin_pred_sd<AArch64udiv_p>;
 
-  defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot", int_aarch64_sve_sdot>;
-  defm UDOT_ZZZ : sve_intx_dot<0b1, "udot", int_aarch64_sve_udot>;
+  defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot", AArch64sdot>;
+  defm UDOT_ZZZ : sve_intx_dot<0b1, "udot", AArch64udot>;
 
   defm SDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b0, "sdot", int_aarch64_sve_sdot_lane>;
   defm UDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b1, "udot", int_aarch64_sve_udot_lane>;
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll
@@ -114,6 +114,26 @@
   ret <vscale x 2 x i64> %out
 }
 
+define <vscale x 2 x i64> @test_sdot_i64_zero(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c) {
+; CHECK-LABEL: test_sdot_i64_zero:
+; CHECK: sdot z0.d, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+  %vdot1.i = call <vscale x 2 x i64> @llvm.aarch64.sve.sdot.nxv2i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c)
+  %ret = add <vscale x 2 x i64> %vdot1.i, %a
+  ret <vscale x 2 x i64> %ret
+}
+
+define <vscale x 4 x i32> @test_sdot_i32_zero(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
+; CHECK-LABEL: test_sdot_i32_zero:
+; CHECK: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+entry:
+  %vdot1.i = call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c)
+  %ret = add <vscale x 4 x i32> %vdot1.i, %a
+  ret <vscale x 4 x i32> %ret
+}
+
 ; SDOT (Indexed)
 
 define <vscale x 4 x i32> @sdot_lane_i32(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
@@ -236,6 +256,26 @@
   ret <vscale x 2 x i64> %out
 }
 
+define <vscale x 2 x i64> @test_udot_i64_zero(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c) {
+; CHECK-LABEL: test_udot_i64_zero:
+; CHECK: udot z0.d, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+  %vdot1.i = call <vscale x 2 x i64> @llvm.aarch64.sve.udot.nxv2i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c)
+  %ret = add <vscale x 2 x i64> %vdot1.i, %a
+  ret <vscale x 2 x i64> %ret
+}
+
+define <vscale x 4 x i32> @test_udot_i32_zero(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
+; CHECK-LABEL: test_udot_i32_zero:
+; CHECK: udot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+entry:
+  %vdot1.i = call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c)
+  %ret = add <vscale x 4 x i32> %vdot1.i, %a
+  ret <vscale x 4 x i32> %ret
+}
+
 ; UDOT (Indexed)
 
 define <vscale x 4 x i32> @udot_lane_i32(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {