diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -13528,6 +13528,85 @@
                      Zero);
 }
 
+/// Returns true if the given SDNode represents a call to the ptrue intrinsic
+/// with the SV_ALL pattern, false otherwise.
+static bool isPTrueAllIntrinsic(SDNode *N) {
+  return getIntrinsicID(N) == Intrinsic::aarch64_sve_ptrue &&
+         N->getConstantOperandVal(1) == AArch64SVEPredPattern::all;
+}
+
+static SDValue combineSVEIntrinsicBinOp(SDNode *N, unsigned Opc,
+                                        SelectionDAG &DAG) {
+  SDLoc DL(N);
+
+  SDValue Pg = N->getOperand(1);
+  SDValue VecA = N->getOperand(2);
+  SDValue VecB = N->getOperand(3);
+  SDNode *PgNode = Pg.getNode();
+  SDNode *VecANode = VecA.getNode();
+  SDNode *VecBNode = VecB.getNode();
+
+  switch (Opc) {
+  case AArch64ISD::MUL_PRED: {
+    if (!isPTrueAllIntrinsic(PgNode))
+      return SDValue();
+
+    const auto IsUnitDup = [&](SDNode *N) {
+      if (getIntrinsicID(N) == Intrinsic::aarch64_sve_dup_x) {
+        auto *DupOperand = cast<ConstantSDNode>(N->getOperand(1));
+        return DupOperand->isOne();
+      }
+      return false;
+    };
+
+    // mul (ptrue sv_all) (dup 1) V => V
+    if (IsUnitDup(VecANode))
+      return VecB;
+    // mul (ptrue sv_all) V (dup 1) => V
+    if (IsUnitDup(VecBNode))
+      return VecA;
+  }
+  }
+
+  return SDValue();
+}
+
+static SDValue combineSVEIntrinsicFPBinOp(SDNode *N, unsigned Opc,
+                                          SelectionDAG &DAG) {
+  SDLoc DL(N);
+
+  SDValue Pg = N->getOperand(1);
+  SDValue VecA = N->getOperand(2);
+  SDValue VecB = N->getOperand(3);
+  SDNode *PgNode = Pg.getNode();
+  SDNode *VecANode = VecA.getNode();
+  SDNode *VecBNode = VecB.getNode();
+
+  switch (Opc) {
+  case AArch64ISD::FMUL_PRED: {
+    if (!isPTrueAllIntrinsic(PgNode))
+      return SDValue();
+
+    const auto IsUnitDup = [&](SDNode *N) {
+      if (getIntrinsicID(N) == Intrinsic::aarch64_sve_dup_x) {
+        auto *DupOperand = cast<ConstantFPSDNode>(N->getOperand(1));
+        return DupOperand->isExactlyValue(1.0);
+      }
+      return false;
+    };
+
+    // fmul (ptrue sv_all) (dup 1.0) V => V
+    if (IsUnitDup(VecANode))
+      return VecB;
+    // fmul (ptrue sv_all) V (dup 1.0) => V
+    if (IsUnitDup(VecBNode))
+      return VecA;
+  }
+  }
+
+  return SDValue();
+}
+
 static SDValue combineSVEReductionOrderedFP(SDNode *N, unsigned Opc,
                                             SelectionDAG &DAG) {
   SDLoc DL(N);
@@ -13717,6 +13796,10 @@
     return combineSVEReductionFP(N, AArch64ISD::FMINNMV_PRED, DAG);
   case Intrinsic::aarch64_sve_fminv:
     return combineSVEReductionFP(N, AArch64ISD::FMINV_PRED, DAG);
+  case Intrinsic::aarch64_sve_mul:
+    return combineSVEIntrinsicBinOp(N, AArch64ISD::MUL_PRED, DAG);
+  case Intrinsic::aarch64_sve_fmul:
+    return combineSVEIntrinsicFPBinOp(N, AArch64ISD::FMUL_PRED, DAG);
   case Intrinsic::aarch64_sve_sel:
     return DAG.getNode(ISD::VSELECT, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
diff --git a/llvm/test/CodeGen/AArch64/sve-mul-fmul-idempotency.ll b/llvm/test/CodeGen/AArch64/sve-mul-fmul-idempotency.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-mul-fmul-idempotency.ll
@@ -0,0 +1,164 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s 2>%t | FileCheck %s
+; RUN: FileCheck --check-prefix=WARN --allow-empty %s <%t
+
+; If this check fails please read test/CodeGen/AArch64/README for instructions on how to resolve it.
+; WARN-NOT: warning
+
+; Idempotent muls -- should compile to just a ret.
+define <vscale x 8 x i16> @idempotent_mul_0(<vscale x 8 x i16> %a) {
+; CHECK-LABEL: idempotent_mul_0:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
+  %2 = call <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16 1)
+  %3 = call <vscale x 8 x i16> @llvm.aarch64.sve.mul.nxv8i16(<vscale x 8 x i1> %1, <vscale x 8 x i16> %a, <vscale x 8 x i16> %2)
+  ret <vscale x 8 x i16> %3
+}
+
+define <vscale x 4 x i32> @idempotent_mul_1(<vscale x 4 x i32> %a) {
+; CHECK-LABEL: idempotent_mul_1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+  %2 = call <vscale x 4 x i32> @llvm.aarch64.sve.dup.x.nxv4i32(i32 1)
+  %3 = call <vscale x 4 x i32> @llvm.aarch64.sve.mul.nxv4i32(<vscale x 4 x i1> %1, <vscale x 4 x i32> %a, <vscale x 4 x i32> %2)
+  ret <vscale x 4 x i32> %3
+}
+
+define <vscale x 2 x i64> @idempotent_mul_2(<vscale x 2 x i64> %a) {
+; CHECK-LABEL: idempotent_mul_2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 2 x i1> @llvm.aarch64.sve.ptrue.nxv2i1(i32 31)
+  %2 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 1)
+  %3 = call <vscale x 2 x i64> @llvm.aarch64.sve.mul.nxv2i64(<vscale x 2 x i1> %1, <vscale x 2 x i64> %a, <vscale x 2 x i64> %2)
+  ret <vscale x 2 x i64> %3
+}
+
+define <vscale x 2 x i64> @idempotent_mul_3(<vscale x 2 x i64> %a) {
+; CHECK-LABEL: idempotent_mul_3:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 2 x i1> @llvm.aarch64.sve.ptrue.nxv2i1(i32 31)
+  %2 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 1)
+  ; Different argument order to the above tests.
+  %3 = call <vscale x 2 x i64> @llvm.aarch64.sve.mul.nxv2i64(<vscale x 2 x i1> %1, <vscale x 2 x i64> %2, <vscale x 2 x i64> %a)
+  ret <vscale x 2 x i64> %3
+}
+
+; Idempotent fmuls -- should compile to just a ret.
+define <vscale x 8 x half> @idempotent_fmul_0(<vscale x 8 x half> %a) {
+; CHECK-LABEL: idempotent_fmul_0:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
+  %2 = call <vscale x 8 x half> @llvm.aarch64.sve.dup.x.nxv8f16(half 1.0)
+  %3 = call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
+  ret <vscale x 8 x half> %3
+}
+
+define <vscale x 4 x float> @idempotent_fmul_1(<vscale x 4 x float> %a) {
+; CHECK-LABEL: idempotent_fmul_1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+  %2 = call <vscale x 4 x float> @llvm.aarch64.sve.dup.x.nxv4f32(float 1.0)
+  %3 = call <vscale x 4 x float> @llvm.aarch64.sve.fmul.nxv4f32(<vscale x 4 x i1> %1, <vscale x 4 x float> %a, <vscale x 4 x float> %2)
+  ret <vscale x 4 x float> %3
+}
+
+define <vscale x 2 x double> @idempotent_fmul_2(<vscale x 2 x double> %a) {
+; CHECK-LABEL: idempotent_fmul_2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 2 x i1> @llvm.aarch64.sve.ptrue.nxv2i1(i32 31)
+  %2 = call <vscale x 2 x double> @llvm.aarch64.sve.dup.x.nxv2f64(double 1.0)
+  %3 = call <vscale x 2 x double> @llvm.aarch64.sve.fmul.nxv2f64(<vscale x 2 x i1> %1, <vscale x 2 x double> %a, <vscale x 2 x double> %2)
+  ret <vscale x 2 x double> %3
+}
+
+define <vscale x 2 x double> @idempotent_fmul_3(<vscale x 2 x double> %a) {
+; CHECK-LABEL: idempotent_fmul_3:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 2 x i1> @llvm.aarch64.sve.ptrue.nxv2i1(i32 31)
+  %2 = call <vscale x 2 x double> @llvm.aarch64.sve.dup.x.nxv2f64(double 1.0)
+  ; Different argument order to the above tests.
+  %3 = call <vscale x 2 x double> @llvm.aarch64.sve.fmul.nxv2f64(<vscale x 2 x i1> %1, <vscale x 2 x double> %2, <vscale x 2 x double> %a)
+  ret <vscale x 2 x double> %3
+}
+
+; Non-idempotent muls -- we don't expect these to be optimised out.
+define <vscale x 8 x i16> @non_idempotent_mul_0(<vscale x 8 x i16> %a) {
+; CHECK-LABEL: non_idempotent_mul_0:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    mov z1.h, #2 // =0x2
+; CHECK-NEXT:    mul z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
+  %2 = call <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16 2)
+  %3 = call <vscale x 8 x i16> @llvm.aarch64.sve.mul.nxv8i16(<vscale x 8 x i1> %1, <vscale x 8 x i16> %a, <vscale x 8 x i16> %2)
+  ret <vscale x 8 x i16> %3
+}
+
+define <vscale x 4 x i32> @non_idempotent_mul_1(<vscale x 4 x i32> %a) {
+; CHECK-LABEL: non_idempotent_mul_1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    mov z1.s, #2 // =0x2
+; CHECK-NEXT:    mul z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+  %2 = call <vscale x 4 x i32> @llvm.aarch64.sve.dup.x.nxv4i32(i32 2)
+  %3 = call <vscale x 4 x i32> @llvm.aarch64.sve.mul.nxv4i32(<vscale x 4 x i1> %1, <vscale x 4 x i32> %a, <vscale x 4 x i32> %2)
+  ret <vscale x 4 x i32> %3
+}
+
+define <vscale x 2 x i64> @non_idempotent_mul_2(<vscale x 2 x i64> %a) {
+; CHECK-LABEL: non_idempotent_mul_2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    mov z1.d, #2 // =0x2
+; CHECK-NEXT:    mul z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 2 x i1> @llvm.aarch64.sve.ptrue.nxv2i1(i32 31)
+  %2 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 2)
+  %3 = call <vscale x 2 x i64> @llvm.aarch64.sve.mul.nxv2i64(<vscale x 2 x i1> %1, <vscale x 2 x i64> %a, <vscale x 2 x i64> %2)
+  ret <vscale x 2 x i64> %3
+}
+
+define <vscale x 8 x i16> @non_idempotent_mul_3(<vscale x 8 x i16> %a) {
+  ; Uses a predicate that is not all true, so it shouldn't be optimized out.
+; CHECK-LABEL: non_idempotent_mul_3:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h, pow2
+; CHECK-NEXT:    mov z1.h, #1 // =0x1
+; CHECK-NEXT:    mul z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 0)
+  %2 = call <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16 1)
+  %3 = call <vscale x 8 x i16> @llvm.aarch64.sve.mul.nxv8i16(<vscale x 8 x i1> %1, <vscale x 8 x i16> %a, <vscale x 8 x i16> %2)
+  ret <vscale x 8 x i16> %3
+}
+
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1>)
+declare <vscale x 2 x i1> @llvm.aarch64.sve.ptrue.nxv2i1(i32 immarg)
+declare <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 immarg)
+declare <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 immarg)
+
+declare <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16)
+declare <vscale x 4 x i32> @llvm.aarch64.sve.dup.x.nxv4i32(i32)
+declare <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64)
+declare <vscale x 8 x half> @llvm.aarch64.sve.dup.x.nxv8f16(half)
+declare <vscale x 4 x float> @llvm.aarch64.sve.dup.x.nxv4f32(float)
+declare <vscale x 2 x double> @llvm.aarch64.sve.dup.x.nxv2f64(double)
+
+declare <vscale x 8 x i16> @llvm.aarch64.sve.mul.nxv8i16(<vscale x 8 x i1>, <vscale x 8 x i16>, <vscale x 8 x i16>)
+declare <vscale x 4 x i32> @llvm.aarch64.sve.mul.nxv4i32(<vscale x 4 x i1>, <vscale x 4 x i32>, <vscale x 4 x i32>)
+declare <vscale x 2 x i64> @llvm.aarch64.sve.mul.nxv2i64(<vscale x 2 x i1>, <vscale x 2 x i64>, <vscale x 2 x i64>)
+
+declare <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 4 x float> @llvm.aarch64.sve.fmul.nxv4f32(<vscale x 4 x i1>, <vscale x 4 x float>, <vscale x 4 x float>)
+declare <vscale x 2 x double> @llvm.aarch64.sve.fmul.nxv2f64(<vscale x 2 x i1>, <vscale x 2 x double>, <vscale x 2 x double>)
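Note (not part of the patch): a hypothetical C-level sketch of the source pattern this combine targets, assuming Clang lowers the overloaded ACLE calls svptrue_b16/svdup_u16/svmul_m to the llvm.aarch64.sve.ptrue, llvm.aarch64.sve.dup.x and llvm.aarch64.sve.mul intrinsics handled above. Under that assumption, the multiply by a splat of 1 under an all-true predicate should now fold away and the function should compile to a plain ret, matching the idempotent_mul_* tests.

/* Hypothetical example; compile with SVE enabled, e.g. -march=armv8-a+sve. */
#include <arm_sve.h>

svuint16_t mul_by_one(svuint16_t a) {
  svbool_t pg = svptrue_b16();    /* ptrue with the SV_ALL pattern            */
  svuint16_t one = svdup_u16(1);  /* splat of the multiplicative identity     */
  return svmul_m(pg, a, one);     /* expected to fold to just 'a' after this
                                     patch, assuming the lowering noted above */
}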