Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3911,9 +3911,13 @@
                        Op.getOperand(2));
   }
   case Intrinsic::aarch64_neon_sdot:
-  case Intrinsic::aarch64_neon_udot: {
-    unsigned Opcode = IntNo == Intrinsic::aarch64_neon_udot ? AArch64ISD::UDOT
-                                                            : AArch64ISD::SDOT;
+  case Intrinsic::aarch64_neon_udot:
+  case Intrinsic::aarch64_sve_sdot:
+  case Intrinsic::aarch64_sve_udot: {
+    unsigned Opcode = (IntNo == Intrinsic::aarch64_neon_udot ||
+                       IntNo == Intrinsic::aarch64_sve_udot)
+                          ? AArch64ISD::UDOT
+                          : AArch64ISD::SDOT;
     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
                        Op.getOperand(2), Op.getOperand(3));
   }
@@ -13304,6 +13308,25 @@
                      DAG.getConstant(0, DL, MVT::i64));
 }
 
+/// Return true if \p N is a DUP or SPLAT_VECTOR of constant zero.
+static bool isDupZero(const SDValue N) {
+  switch (N->getOpcode()) {
+  case AArch64ISD::DUP:
+  case ISD::SPLAT_VECTOR: {
+    auto Opnd0 = N->getOperand(0);
+    if (auto *CN = dyn_cast<ConstantSDNode>(Opnd0))
+      if (CN->isNullValue())
+        return true;
+    if (auto *CN = dyn_cast<ConstantFPSDNode>(Opnd0))
+      if (CN->isZero())
+        return true;
+    break;
+  }
+  }
+
+  return false;
+}
+
 // ADD(UDOT(zero, x, y), A) --> UDOT(A, x, y)
 static SDValue performAddDotCombine(SDNode *N, SelectionDAG &DAG) {
   EVT VT = N->getValueType(0);
@@ -13312,11 +13335,13 @@
   SDValue Dot = N->getOperand(0);
   SDValue A = N->getOperand(1);
+  // Handle commutativity.
   auto isZeroDot = [](SDValue Dot) {
     return (Dot.getOpcode() == AArch64ISD::UDOT ||
             Dot.getOpcode() == AArch64ISD::SDOT) &&
-           ISD::isBuildVectorAllZeros(Dot.getOperand(0).getNode());
+           (isDupZero(Dot.getOperand(0)) ||
+            ISD::isBuildVectorAllZeros(Dot.getOperand(0).getNode()));
   };
   if (!isZeroDot(Dot))
     std::swap(Dot, A);
Index: llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -353,8 +353,8 @@
   defm SDIV_ZPZZ  : sve_int_bin_pred_sd<AArch64sdiv_p>;
   defm UDIV_ZPZZ  : sve_int_bin_pred_sd<AArch64udiv_p>;
 
-  defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot", int_aarch64_sve_sdot>;
-  defm UDOT_ZZZ : sve_intx_dot<0b1, "udot", int_aarch64_sve_udot>;
+  defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot", AArch64sdot>;
+  defm UDOT_ZZZ : sve_intx_dot<0b1, "udot", AArch64udot>;
 
   defm SDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b0, "sdot", int_aarch64_sve_sdot_lane>;
   defm UDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b1, "udot", int_aarch64_sve_udot_lane>;
Index: llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll
===================================================================
--- llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll
+++ llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll
@@ -1,3 +1,4 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s 2>%t | FileCheck %s
 ; RUN: FileCheck --check-prefix=WARN --allow-empty %s <%t
 
@@ -10,8 +11,9 @@
 define <vscale x 16 x i8> @abs_i8(<vscale x 16 x i8> %a, <vscale x 16 x i1> %pg, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: abs_i8:
-; CHECK: abs z0.b, p0/m, z1.b
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    abs z0.b, p0/m, z1.b
+; CHECK-NEXT:    ret
   %out = call <vscale x 16 x i8> @llvm.aarch64.sve.abs.nxv16i8(<vscale x 16 x i8> %a,
                                                                <vscale x 16 x i1> %pg,
                                                                <vscale x 16 x i8> %b)
   ret <vscale x 16 x i8> %out
@@ -20,8 +22,9 @@
 define <vscale x 8 x i16> @abs_i16(<vscale x 8 x i16> %a, <vscale x 8 x i1> %pg, <vscale x 8 x i16> %b) {
 ; CHECK-LABEL: abs_i16:
-; CHECK: abs z0.h, p0/m, z1.h
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    abs z0.h, p0/m, z1.h
+; CHECK-NEXT:    ret
   %out = call <vscale x 8 x i16> @llvm.aarch64.sve.abs.nxv8i16(<vscale x 8 x i16> %a,
                                                                <vscale x 8 x i1> %pg,
                                                                <vscale x 8 x i16> %b)
   ret <vscale x 8 x i16> %out
@@ -30,8 +33,9 @@
 define <vscale x 4 x i32> @abs_i32(<vscale x 4 x i32> %a, <vscale x 4 x i1> %pg, <vscale x 4 x i32> %b) {
 ; CHECK-LABEL: abs_i32:
-; CHECK: abs z0.s, p0/m, z1.s
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    abs z0.s, p0/m, z1.s
+; CHECK-NEXT:    ret
   %out = call <vscale x 4 x i32> @llvm.aarch64.sve.abs.nxv4i32(<vscale x 4 x i32> %a,
                                                                <vscale x 4 x i1> %pg,
                                                                <vscale x 4 x i32> %b)
   ret <vscale x 4 x i32> %out
@@ -40,8 +44,9 @@
 define <vscale x 2 x i64> @abs_i64(<vscale x 2 x i64> %a, <vscale x 2 x i1> %pg, <vscale x 2 x i64> %b) {
 ; CHECK-LABEL: abs_i64:
-; CHECK: abs z0.d, p0/m, z1.d
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    abs z0.d, p0/m, z1.d
+; CHECK-NEXT:    ret
   %out = call <vscale x 2 x i64> @llvm.aarch64.sve.abs.nxv2i64(<vscale x 2 x i64> %a,
                                                                <vscale x 2 x i1> %pg,
                                                                <vscale x 2 x i64> %b)
   ret <vscale x 2 x i64> %out
@@ -54,8 +59,9 @@
 define <vscale x 16 x i8> @neg_i8(<vscale x 16 x i8> %a, <vscale x 16 x i1> %pg, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: neg_i8:
-; CHECK: neg z0.b, p0/m, z1.b
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg z0.b, p0/m, z1.b
+; CHECK-NEXT:    ret
   %out = call <vscale x 16 x i8> @llvm.aarch64.sve.neg.nxv16i8(<vscale x 16 x i8> %a,
                                                                <vscale x 16 x i1> %pg,
                                                                <vscale x 16 x i8> %b)
   ret <vscale x 16 x i8> %out
@@ -64,8 +70,9 @@
 define <vscale x 8 x i16> @neg_i16(<vscale x 8 x i16> %a, <vscale x 8 x i1> %pg, <vscale x 8 x i16> %b) {
 ; CHECK-LABEL: neg_i16:
-; CHECK: neg z0.h, p0/m, z1.h
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg z0.h, p0/m, z1.h
+; CHECK-NEXT:    ret
   %out = call <vscale x 8 x i16> @llvm.aarch64.sve.neg.nxv8i16(<vscale x 8 x i16> %a,
                                                                <vscale x 8 x i1> %pg,
                                                                <vscale x 8 x i16> %b)
   ret <vscale x 8 x i16> %out
@@ -74,8 +81,9 @@
 define <vscale x 4 x i32> @neg_i32(<vscale x 4 x i32> %a, <vscale x 4 x i1> %pg, <vscale x 4 x i32> %b) {
 ; CHECK-LABEL: neg_i32:
-; CHECK: neg z0.s, p0/m, z1.s
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg z0.s, p0/m, z1.s
+; CHECK-NEXT:    ret
   %out = call <vscale x 4 x i32> @llvm.aarch64.sve.neg.nxv4i32(<vscale x 4 x i32> %a,
                                                                <vscale x 4 x i1> %pg,
                                                                <vscale x 4 x i32> %b)
   ret <vscale x 4 x i32> %out
@@ -84,8 +92,9 @@
 define <vscale x 2 x i64> @neg_i64(<vscale x 2 x i64> %a, <vscale x 2 x i1> %pg, <vscale x 2 x i64> %b) {
 ; CHECK-LABEL: neg_i64:
-; CHECK: neg z0.d, p0/m, z1.d
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg z0.d, p0/m, z1.d
+; CHECK-NEXT:    ret
   %out = call <vscale x 2 x i64> @llvm.aarch64.sve.neg.nxv2i64(<vscale x 2 x i64> %a,
                                                                <vscale x 2 x i1> %pg,
                                                                <vscale x 2 x i64> %b)
   ret <vscale x 2 x i64> %out
@@ -96,8 +105,9 @@
 define <vscale x 4 x i32> @sdot_i32(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
 ; CHECK-LABEL: sdot_i32:
-; CHECK: sdot z0.s, z1.b, z2.b
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
   %out = call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4 x i32> %a,
                                                                 <vscale x 16 x i8> %b,
                                                                 <vscale x 16 x i8> %c)
   ret <vscale x 4 x i32> %out
@@ -106,20 +116,44 @@
 define <vscale x 2 x i64> @sdot_i64(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c) {
 ; CHECK-LABEL: sdot_i64:
-; CHECK: sdot z0.d, z1.h, z2.h
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sdot z0.d, z1.h, z2.h
+; CHECK-NEXT:    ret
   %out = call <vscale x 2 x i64> @llvm.aarch64.sve.sdot.nxv2i64(<vscale x 2 x i64> %a,
                                                                 <vscale x 8 x i16> %b,
                                                                 <vscale x 8 x i16> %c)
   ret <vscale x 2 x i64> %out
 }
 
+define <vscale x 2 x i64> @test_sdot_i64_zero(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c) {
+; CHECK-LABEL: test_sdot_i64_zero:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sdot z0.d, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %vdot1.i = call <vscale x 2 x i64> @llvm.aarch64.sve.sdot.nxv2i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c)
+  %ret = add <vscale x 2 x i64> %vdot1.i, %a
+  ret <vscale x 2 x i64> %ret
+}
+
+define <vscale x 4 x i32> @test_sdot_i32_zero(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
+; CHECK-LABEL: test_sdot_i32_zero:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+entry:
+  %vdot1.i = call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c)
+  %ret = add <vscale x 4 x i32> %vdot1.i, %a
+  ret <vscale x 4 x i32> %ret
+}
+
 ; SDOT (Indexed)
 
 define <vscale x 4 x i32> @sdot_lane_i32(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
 ; CHECK-LABEL: sdot_lane_i32:
-; CHECK: sdot z0.s, z1.b, z2.b[2]
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sdot z0.s, z1.b, z2.b[2]
+; CHECK-NEXT:    ret
   %out = call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.lane.nxv4i32(<vscale x 4 x i32> %a,
                                                                      <vscale x 16 x i8> %b,
                                                                      <vscale x 16 x i8> %c,
                                                                      i32 2)
@@ -129,8 +163,9 @@
 define <vscale x 2 x i64> @sdot_lane_i64(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c) {
 ; CHECK-LABEL: sdot_lane_i64:
-; CHECK: sdot z0.d, z1.h, z2.h[1]
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sdot z0.d, z1.h, z2.h[1]
+; CHECK-NEXT:    ret
   %out = call <vscale x 2 x i64> @llvm.aarch64.sve.sdot.lane.nxv2i64(<vscale x 2 x i64> %a,
                                                                      <vscale x 8 x i16> %b,
                                                                      <vscale x 8 x i16> %c,
                                                                      i32 1)
@@ -142,8 +177,9 @@
 define <vscale x 16 x i8> @sqadd_i8(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: sqadd_i8:
-; CHECK: sqadd z0.b, z0.b, z1.b
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sqadd z0.b, z0.b, z1.b
+; CHECK-NEXT:    ret
   %out = call <vscale x 16 x i8> @llvm.aarch64.sve.sqadd.x.nxv16i8(<vscale x 16 x i8> %a,
                                                                    <vscale x 16 x i8> %b)
   ret <vscale x 16 x i8> %out
 }
@@ -151,8 +187,9 @@
 define <vscale x 8 x i16> @sqadd_i16(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
 ; CHECK-LABEL: sqadd_i16:
-; CHECK: sqadd z0.h, z0.h, z1.h
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sqadd z0.h, z0.h, z1.h
+; CHECK-NEXT:    ret
   %out = call <vscale x 8 x i16> @llvm.aarch64.sve.sqadd.x.nxv8i16(<vscale x 8 x i16> %a,
                                                                    <vscale x 8 x i16> %b)
   ret <vscale x 8 x i16> %out
 }
@@ -160,8 +197,9 @@
 define <vscale x 4 x i32> @sqadd_i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b) {
 ; CHECK-LABEL: sqadd_i32:
-; CHECK: sqadd z0.s, z0.s, z1.s
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sqadd z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
   %out = call <vscale x 4 x i32> @llvm.aarch64.sve.sqadd.x.nxv4i32(<vscale x 4 x i32> %a,
                                                                    <vscale x 4 x i32> %b)
   ret <vscale x 4 x i32> %out
 }
@@ -169,8 +207,9 @@
 define <vscale x 2 x i64> @sqadd_i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %b) {
 ; CHECK-LABEL: sqadd_i64:
-; CHECK: sqadd z0.d, z0.d, z1.d
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sqadd z0.d, z0.d, z1.d
+; CHECK-NEXT:    ret
   %out = call <vscale x 2 x i64> @llvm.aarch64.sve.sqadd.x.nxv2i64(<vscale x 2 x i64> %a,
                                                                    <vscale x 2 x i64> %b)
   ret <vscale x 2 x i64> %out
 }
@@ -180,8 +219,9 @@
 define <vscale x 16 x i8> @sqsub_i8(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: sqsub_i8:
-; CHECK: sqsub z0.b, z0.b, z1.b
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sqsub z0.b, z0.b, z1.b
+; CHECK-NEXT:    ret
   %out = call <vscale x 16 x i8> @llvm.aarch64.sve.sqsub.x.nxv16i8(<vscale x 16 x i8> %a,
                                                                    <vscale x 16 x i8> %b)
   ret <vscale x 16 x i8> %out
 }
@@ -189,8 +229,9 @@
 define <vscale x 8 x i16> @sqsub_i16(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
 ; CHECK-LABEL: sqsub_i16:
-; CHECK: sqsub z0.h, z0.h, z1.h
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sqsub z0.h, z0.h, z1.h
+; CHECK-NEXT:    ret
   %out = call <vscale x 8 x i16> @llvm.aarch64.sve.sqsub.x.nxv8i16(<vscale x 8 x i16> %a,
                                                                    <vscale x 8 x i16> %b)
   ret <vscale x 8 x i16> %out
 }
@@ -198,8 +239,9 @@
 define <vscale x 4 x i32> @sqsub_i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b) {
 ; CHECK-LABEL: sqsub_i32:
-; CHECK: sqsub z0.s, z0.s, z1.s
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sqsub z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
   %out = call <vscale x 4 x i32> @llvm.aarch64.sve.sqsub.x.nxv4i32(<vscale x 4 x i32> %a,
                                                                    <vscale x 4 x i32> %b)
   ret <vscale x 4 x i32> %out
 }
@@ -207,8 +249,9 @@
 define <vscale x 2 x i64> @sqsub_i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %b) {
 ; CHECK-LABEL: sqsub_i64:
-; CHECK: sqsub z0.d, z0.d, z1.d
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sqsub z0.d, z0.d, z1.d
+; CHECK-NEXT:    ret
   %out = call <vscale x 2 x i64> @llvm.aarch64.sve.sqsub.x.nxv2i64(<vscale x 2 x i64> %a,
                                                                    <vscale x 2 x i64> %b)
   ret <vscale x 2 x i64> %out
 }
@@ -218,8 +261,9 @@
 define <vscale x 4 x i32> @udot_i32(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
 ; CHECK-LABEL: udot_i32:
-; CHECK: udot z0.s, z1.b, z2.b
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    udot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
   %out = call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4 x i32> %a,
                                                                 <vscale x 16 x i8> %b,
                                                                 <vscale x 16 x i8> %c)
   ret <vscale x 4 x i32> %out
@@ -228,20 +272,44 @@
 define <vscale x 2 x i64> @udot_i64(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c) {
 ; CHECK-LABEL: udot_i64:
-; CHECK: udot z0.d, z1.h, z2.h
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    udot z0.d, z1.h, z2.h
+; CHECK-NEXT:    ret
   %out = call <vscale x 2 x i64> @llvm.aarch64.sve.udot.nxv2i64(<vscale x 2 x i64> %a,
                                                                 <vscale x 8 x i16> %b,
                                                                 <vscale x 8 x i16> %c)
   ret <vscale x 2 x i64> %out
 }
 
+define <vscale x 2 x i64> @test_udot_i64_zero(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c) {
+; CHECK-LABEL: test_udot_i64_zero:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    udot z0.d, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %vdot1.i = call <vscale x 2 x i64> @llvm.aarch64.sve.udot.nxv2i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c)
+  %ret = add <vscale x 2 x i64> %vdot1.i, %a
+  ret <vscale x 2 x i64> %ret
+}
+
+define <vscale x 4 x i32> @test_udot_i32_zero(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
+; CHECK-LABEL: test_udot_i32_zero:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    udot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+entry:
+  %vdot1.i = call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c)
+  %ret = add <vscale x 4 x i32> %vdot1.i, %a
+  ret <vscale x 4 x i32> %ret
+}
+
 ; UDOT (Indexed)
 
 define <vscale x 4 x i32> @udot_lane_i32(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
 ; CHECK-LABEL: udot_lane_i32:
-; CHECK: udot z0.s, z1.b, z2.b[2]
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    udot z0.s, z1.b, z2.b[2]
+; CHECK-NEXT:    ret
   %out = call <vscale x 4 x i32> @llvm.aarch64.sve.udot.lane.nxv4i32(<vscale x 4 x i32> %a,
                                                                      <vscale x 16 x i8> %b,
                                                                      <vscale x 16 x i8> %c,
                                                                      i32 2)
@@ -253,8 +321,9 @@
 define <vscale x 16 x i8> @uqadd_i8(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: uqadd_i8:
-; CHECK: uqadd z0.b, z0.b, z1.b
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uqadd z0.b, z0.b, z1.b
+; CHECK-NEXT:    ret
   %out = call <vscale x 16 x i8> @llvm.aarch64.sve.uqadd.x.nxv16i8(<vscale x 16 x i8> %a,
                                                                    <vscale x 16 x i8> %b)
   ret <vscale x 16 x i8> %out
 }
@@ -262,8 +331,9 @@
 define <vscale x 8 x i16> @uqadd_i16(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
 ; CHECK-LABEL: uqadd_i16:
-; CHECK: uqadd z0.h, z0.h, z1.h
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uqadd z0.h, z0.h, z1.h
+; CHECK-NEXT:    ret
   %out = call <vscale x 8 x i16> @llvm.aarch64.sve.uqadd.x.nxv8i16(<vscale x 8 x i16> %a,
                                                                    <vscale x 8 x i16> %b)
   ret <vscale x 8 x i16> %out
 }
@@ -271,8 +341,9 @@
 define <vscale x 4 x i32> @uqadd_i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b) {
 ; CHECK-LABEL: uqadd_i32:
-; CHECK: uqadd z0.s, z0.s, z1.s
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uqadd z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
   %out = call <vscale x 4 x i32> @llvm.aarch64.sve.uqadd.x.nxv4i32(<vscale x 4 x i32> %a,
                                                                    <vscale x 4 x i32> %b)
   ret <vscale x 4 x i32> %out
 }
@@ -280,8 +351,9 @@
 define <vscale x 2 x i64> @uqadd_i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %b) {
 ; CHECK-LABEL: uqadd_i64:
-; CHECK: uqadd z0.d, z0.d, z1.d
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uqadd z0.d, z0.d, z1.d
+; CHECK-NEXT:    ret
   %out = call <vscale x 2 x i64> @llvm.aarch64.sve.uqadd.x.nxv2i64(<vscale x 2 x i64> %a,
                                                                    <vscale x 2 x i64> %b)
   ret <vscale x 2 x i64> %out
 }
@@ -291,8 +363,9 @@
 define <vscale x 16 x i8> @uqsub_i8(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: uqsub_i8:
-; CHECK: uqsub z0.b, z0.b, z1.b
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uqsub z0.b, z0.b, z1.b
+; CHECK-NEXT:    ret
   %out = call <vscale x 16 x i8> @llvm.aarch64.sve.uqsub.x.nxv16i8(<vscale x 16 x i8> %a,
                                                                    <vscale x 16 x i8> %b)
   ret <vscale x 16 x i8> %out
 }
@@ -300,8 +373,9 @@
 define <vscale x 8 x i16> @uqsub_i16(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
 ; CHECK-LABEL: uqsub_i16:
-; CHECK: uqsub z0.h, z0.h, z1.h
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uqsub z0.h, z0.h, z1.h
+; CHECK-NEXT:    ret
   %out = call <vscale x 8 x i16> @llvm.aarch64.sve.uqsub.x.nxv8i16(<vscale x 8 x i16> %a,
                                                                    <vscale x 8 x i16> %b)
   ret <vscale x 8 x i16> %out
 }
@@ -309,8 +383,9 @@
 define <vscale x 4 x i32> @uqsub_i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b) {
 ; CHECK-LABEL: uqsub_i32:
-; CHECK: uqsub z0.s, z0.s, z1.s
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uqsub z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
   %out = call <vscale x 4 x i32> @llvm.aarch64.sve.uqsub.x.nxv4i32(<vscale x 4 x i32> %a,
                                                                    <vscale x 4 x i32> %b)
   ret <vscale x 4 x i32> %out
 }
@@ -318,8 +393,9 @@
 define <vscale x 2 x i64> @uqsub_i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %b) {
 ; CHECK-LABEL: uqsub_i64:
-; CHECK: uqsub z0.d, z0.d, z1.d
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uqsub z0.d, z0.d, z1.d
+; CHECK-NEXT:    ret
   %out = call <vscale x 2 x i64> @llvm.aarch64.sve.uqsub.x.nxv2i64(<vscale x 2 x i64> %a,
                                                                    <vscale x 2 x i64> %b)
   ret <vscale x 2 x i64> %out
 }
@@ -328,30 +404,36 @@
 ; ADD (tuples)
 
 define <vscale x 4 x i64> @add_i64_tuple2(<vscale x 4 x i64>* %out, <vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2) {
-; CHECK-LABEL: add_i64_tuple2
-; CHECK: add z0.d, z0.d, z0.d
-; CHECK: add z1.d, z1.d, z1.d
+; CHECK-LABEL: add_i64_tuple2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add z0.d, z0.d, z0.d
+; CHECK-NEXT:    add z1.d, z1.d, z1.d
+; CHECK-NEXT:    ret
   %tuple = tail call <vscale x 4 x i64> @llvm.aarch64.sve.tuple.create2.nxv4i64.nxv2i64(<vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2)
   %res = add <vscale x 4 x i64> %tuple, %tuple
   ret <vscale x 4 x i64> %res
 }
 
 define <vscale x 6 x i64> @add_i64_tuple3(<vscale x 6 x i64>* %out, <vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3) {
-; CHECK-LABEL: add_i64_tuple3
-; CHECK: add z0.d, z0.d, z0.d
-; CHECK: add z1.d, z1.d, z1.d
-; CHECK: add z2.d, z2.d, z2.d
+; CHECK-LABEL: add_i64_tuple3:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add z0.d, z0.d, z0.d
+; CHECK-NEXT:    add z1.d, z1.d, z1.d
+; CHECK-NEXT:    add z2.d, z2.d, z2.d
+; CHECK-NEXT:    ret
   %tuple = tail call <vscale x 6 x i64> @llvm.aarch64.sve.tuple.create3.nxv6i64.nxv2i64(<vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3)
   %res = add <vscale x 6 x i64> %tuple, %tuple
   ret <vscale x 6 x i64> %res
 }
 
 define <vscale x 8 x i64> @add_i64_tuple4(<vscale x 8 x i64>* %out, <vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3, <vscale x 2 x i64> %in4) {
-; CHECK-LABEL: add_i64_tuple4
-; CHECK: add z0.d, z0.d, z0.d
-; CHECK: add z1.d, z1.d, z1.d
-; CHECK: add z2.d, z2.d, z2.d
-; CHECK: add z3.d, z3.d, z3.d
+; CHECK-LABEL: add_i64_tuple4:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add z0.d, z0.d, z0.d
+; CHECK-NEXT:    add z1.d, z1.d, z1.d
+; CHECK-NEXT:    add z2.d, z2.d, z2.d
+; CHECK-NEXT:    add z3.d, z3.d, z3.d
+; CHECK-NEXT:    ret
   %tuple = tail call <vscale x 8 x i64> @llvm.aarch64.sve.tuple.create4.nxv8i64.nxv2i64(<vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3, <vscale x 2 x i64> %in4)
   %res = add <vscale x 8 x i64> %tuple, %tuple
   ret <vscale x 8 x i64> %res