diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -47662,22 +47662,48 @@
   return Res;
 }
 
-// Try to combine the following nodes
-// t21: v16f32 = X86ISD::VFMULC/VFCMULC t7, t8
-// t15: v32f16 = bitcast t21
-// t16: v32f16 = fadd nnan ninf nsz arcp contract afn reassoc t15, t2
-// into X86ISD::VFMADDC/VFCMADDC if possible:
-// t22: v16f32 = bitcast t2
-// t23: v16f32 = nnan ninf nsz arcp contract afn reassoc
-//      X86ISD::VFMADDC/VFCMADDC t7, t8, t22
-// t24: v32f16 = bitcast t23
+// Try to combine the following nodes:
+//   FADD(A, FMUL(B, C))   -> FMA(B, C, A)
+//   FADD(A, FMA(B, C, Z)) -> FMA(B, C, A)
+// where FMUL/FMA are the FP16 complex nodes X86ISD::VFMULC/VFCMULC and
+// X86ISD::VFMADDC/VFCMADDC, and Z is a splat of -0.0, or of +0.0 when signed
+// zeros can be ignored. Either way FMA(B, C, Z) can stand in for FMUL(B, C).
 static SDValue combineFaddCFmul(SDNode *N, SelectionDAG &DAG,
                                 const X86Subtarget &Subtarget) {
-  auto AllowContract = [&DAG](SDNode *N) {
+  auto AllowContract = [&DAG](const SDNodeFlags &Flags) {
     return DAG.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast ||
-           N->getFlags().hasAllowContract();
+           Flags.hasAllowContract();
   };
-  if (N->getOpcode() != ISD::FADD || !Subtarget.hasFP16() || !AllowContract(N))
+
+  auto HasNoSignedZero = [&DAG](const SDNodeFlags &Flags) {
+    return DAG.getTarget().Options.NoSignedZerosFPMath ||
+           Flags.hasNoSignedZeros();
+  };
+  auto IsVectorAllNegativeZero = [](const SDNode *N) {
+    if (N->getOpcode() != X86ISD::VBROADCAST_LOAD)
+      return false;
+    assert(N->getSimpleValueType(0).getScalarType() == MVT::f32 &&
+           "Unexpected vector type!");
+    // Only peek through an address wrapper: the base pointer need not be one
+    // and may have no operands at all (e.g. a FrameIndex).
+    SDValue BasePtr = N->getOperand(1);
+    if (BasePtr.getOpcode() == X86ISD::Wrapper ||
+        BasePtr.getOpcode() == X86ISD::WrapperRIP)
+      BasePtr = BasePtr.getOperand(0);
+    auto *CP = dyn_cast<ConstantPoolSDNode>(BasePtr);
+    if (!CP || CP->isMachineConstantPoolEntry() || CP->getOffset() != 0)
+      return false;
+    // Each f32 lane holds two packed f16 -0.0 values (0x8000).
+    APInt AI = APInt(32, 0x80008000);
+    if (const auto *CI = dyn_cast<ConstantInt>(CP->getConstVal()))
+      return CI->getValue() == AI;
+    if (const auto *CF = dyn_cast<ConstantFP>(CP->getConstVal()))
+      return CF->getValue() == APFloat(APFloat::IEEEsingle(), AI);
+    return false;
+  };
+
+  if (N->getOpcode() != ISD::FADD || !Subtarget.hasFP16() ||
+      !AllowContract(N->getFlags()))
     return SDValue();
 
   EVT VT = N->getValueType(0);
@@ -47686,16 +47712,33 @@
 
   SDValue LHS = N->getOperand(0);
   SDValue RHS = N->getOperand(1);
-  SDValue CFmul, FAddOp1;
-  auto GetCFmulFrom = [&CFmul, &AllowContract](SDValue N) -> bool {
+  bool IsConj;
+  SDValue FAddOp1, MulOp0, MulOp1;
+  auto GetCFmulFrom = [&MulOp0, &MulOp1, &IsConj, &AllowContract,
+                       &IsVectorAllNegativeZero,
+                       &HasNoSignedZero](SDValue N) -> bool {
     if (!N.hasOneUse() || N.getOpcode() != ISD::BITCAST)
-        return false;
+      return false;
     SDValue Op0 = N.getOperand(0);
     unsigned Opcode = Op0.getOpcode();
-    if (Op0.hasOneUse() && AllowContract(Op0.getNode()) &&
-        (Opcode == X86ISD::VFMULC || Opcode == X86ISD::VFCMULC))
-      CFmul = Op0;
-    return !!CFmul;
+    if (Op0.hasOneUse() && AllowContract(Op0->getFlags())) {
+      if ((Opcode == X86ISD::VFMULC || Opcode == X86ISD::VFCMULC)) {
+        MulOp0 = Op0.getOperand(0);
+        MulOp1 = Op0.getOperand(1);
+        IsConj = Opcode == X86ISD::VFCMULC;
+        return true;
+      }
+      if ((Opcode == X86ISD::VFMADDC || Opcode == X86ISD::VFCMADDC) &&
+          ((ISD::isBuildVectorAllZeros(Op0->getOperand(2).getNode()) &&
+            HasNoSignedZero(Op0->getFlags())) ||
+           IsVectorAllNegativeZero(Op0->getOperand(2).getNode()))) {
+        MulOp0 = Op0.getOperand(0);
+        MulOp1 = Op0.getOperand(1);
+        IsConj = Opcode == X86ISD::VFCMADDC;
+        return true;
+      }
+    }
+    return false;
   };
 
   if (GetCFmulFrom(LHS))
@@ -47706,14 +47749,14 @@
     return SDValue();
 
   MVT CVT = MVT::getVectorVT(MVT::f32, VT.getVectorNumElements() / 2);
-  assert(CFmul->getValueType(0) == CVT && "Complex type mismatch");
   FAddOp1 = DAG.getBitcast(CVT, FAddOp1);
-  unsigned newOp = CFmul.getOpcode() == X86ISD::VFMULC ? X86ISD::VFMADDC
-                                                       : X86ISD::VFCMADDC;
+  unsigned NewOp = IsConj ? X86ISD::VFCMADDC : X86ISD::VFMADDC;
   // FIXME: How do we handle when fast math flags of FADD are different from
   // CFMUL's?
-  CFmul = DAG.getNode(newOp, SDLoc(N), CVT, FAddOp1, CFmul.getOperand(0),
-                      CFmul.getOperand(1), N->getFlags());
+  // The addend is the last operand of VFMADDC/VFCMADDC (GetCFmulFrom reads it
+  // from operand 2), so FAddOp1 must go last.
+  SDValue CFmul =
+      DAG.getNode(NewOp, SDLoc(N), CVT, MulOp0, MulOp1, FAddOp1, N->getFlags());
   return DAG.getBitcast(VT, CFmul);
 }
 
diff --git a/llvm/test/CodeGen/X86/avx512fp16-combine-vfmac-fadd.ll b/llvm/test/CodeGen/X86/avx512fp16-combine-vfmac-fadd.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/X86/avx512fp16-combine-vfmac-fadd.ll
@@ -0,0 +1,222 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown --fp-contract=fast --enable-no-signed-zeros-fp-math -mattr=avx512fp16 | FileCheck %s --check-prefixes=CHECK,NO-SZ
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown --fp-contract=fast -mattr=avx512fp16 | FileCheck %s --check-prefixes=CHECK,HAS-SZ
+
+; FADD(acc, FMA(a, b, +0.0)) can be combined to FMA(a, b, acc) only if signed zeros can be ignored (nsz).
+define dso_local <32 x half> @test1(<32 x half> %acc, <32 x half> %a, <32 x half> %b) {
+; NO-SZ-LABEL: test1:
+; NO-SZ:       # %bb.0: # %entry
+; NO-SZ-NEXT:    vfcmaddcph %zmm2, %zmm1, %zmm0
+; NO-SZ-NEXT:    retq
+;
+; HAS-SZ-LABEL: test1:
+; HAS-SZ:       # %bb.0: # %entry
+; HAS-SZ-NEXT:    vxorps %xmm3, %xmm3, %xmm3
+; HAS-SZ-NEXT:    vfcmaddcph %zmm2, %zmm1, %zmm3
+; HAS-SZ-NEXT:    vaddph %zmm0, %zmm3, %zmm0
+; HAS-SZ-NEXT:    retq
+entry:
+  %0 = bitcast <32 x half> %a to <16 x float>
+  %1 = bitcast <32 x half> %b to <16 x float>
+  %2 = tail call <16 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.512(<16 x float> %0, <16 x float> %1, <16 x float> zeroinitializer, i16 -1, i32 4)
+  %3 = bitcast <16 x float> %2 to <32 x half>
+  %add.i = fadd <32 x half> %3, %acc
+  ret <32 x half> %add.i
+}
+
+define dso_local <32 x half> @test2(<32 x half> %acc, <32 x half> %a, <32 x half> %b) {
+; NO-SZ-LABEL: test2:
+; NO-SZ:       # %bb.0: # %entry
+; NO-SZ-NEXT:    vfmaddcph %zmm2, %zmm1, %zmm0
+; NO-SZ-NEXT:    retq
+;
+; HAS-SZ-LABEL: test2:
+; HAS-SZ:       # %bb.0: # %entry
+; HAS-SZ-NEXT:    vxorps %xmm3, %xmm3, %xmm3
+; HAS-SZ-NEXT:    vfmaddcph %zmm2, %zmm1, %zmm3
+; HAS-SZ-NEXT:    vaddph %zmm0, %zmm3, %zmm0
+; HAS-SZ-NEXT:    retq
+entry:
+  %0 = bitcast <32 x half> %a to <16 x float>
+  %1 = bitcast <32 x half> %b to <16 x float>
+  %2 = tail call <16 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.512(<16 x float> %0, <16 x float> %1, <16 x float> zeroinitializer, i16 -1, i32 4)
+  %3 = bitcast <16 x float> %2 to <32 x half>
+  %add.i = fadd <32 x half> %3, %acc
+  ret <32 x half> %add.i
+}
+
+define dso_local <16 x half> @test3(<16 x half> %acc, <16 x half> %a, <16 x half> %b) {
+; NO-SZ-LABEL: test3:
+; NO-SZ:       # %bb.0: # %entry
+; NO-SZ-NEXT:    vfcmaddcph %ymm2, %ymm1, %ymm0
+; NO-SZ-NEXT:    retq
+;
+; HAS-SZ-LABEL: test3:
+; HAS-SZ:       # %bb.0: # %entry
+; HAS-SZ-NEXT:    vxorps %xmm3, %xmm3, %xmm3
+; HAS-SZ-NEXT:    vfcmaddcph %ymm2, %ymm1, %ymm3
+; HAS-SZ-NEXT:    vaddph %ymm0, %ymm3, %ymm0
+; HAS-SZ-NEXT:    retq
+entry:
+  %0 = bitcast <16 x half> %a to <8 x float>
+  %1 = bitcast <16 x half> %b to <8 x float>
+  %2 = tail call <8 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.256(<8 x float> %0, <8 x float> %1, <8 x float> zeroinitializer, i8 -1)
+  %3 = bitcast <8 x float> %2 to <16 x half>
+  %add.i = fadd <16 x half> %3, %acc
+  ret <16 x half> %add.i
+}
+
+define dso_local <16 x half> @test4(<16 x half> %acc, <16 x half> %a, <16 x half> %b) {
+; NO-SZ-LABEL: test4:
+; NO-SZ:       # %bb.0: # %entry
+; NO-SZ-NEXT:    vfmaddcph %ymm2, %ymm1, %ymm0
+; NO-SZ-NEXT:    retq
+;
+; HAS-SZ-LABEL: test4:
+; HAS-SZ:       # %bb.0: # %entry
+; HAS-SZ-NEXT:    vxorps %xmm3, %xmm3, %xmm3
+; HAS-SZ-NEXT:    vfmaddcph %ymm2, %ymm1, %ymm3
+; HAS-SZ-NEXT:    vaddph %ymm0, %ymm3, %ymm0
+; HAS-SZ-NEXT:    retq
+entry:
+  %0 = bitcast <16 x half> %a to <8 x float>
+  %1 = bitcast <16 x half> %b to <8 x float>
+  %2 = tail call <8 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.256(<8 x float> %0, <8 x float> %1, <8 x float> zeroinitializer, i8 -1)
+  %3 = bitcast <8 x float> %2 to <16 x half>
+  %add.i = fadd <16 x half> %3, %acc
+  ret <16 x half> %add.i
+}
+
+define dso_local <8 x half> @test5(<8 x half> %acc, <8 x half> %a, <8 x half> %b) {
+; NO-SZ-LABEL: test5:
+; NO-SZ:       # %bb.0: # %entry
+; NO-SZ-NEXT:    vfcmaddcph %xmm2, %xmm1, %xmm0
+; NO-SZ-NEXT:    retq
+;
+; HAS-SZ-LABEL: test5:
+; HAS-SZ:       # %bb.0: # %entry
+; HAS-SZ-NEXT:    vxorps %xmm3, %xmm3, %xmm3
+; HAS-SZ-NEXT:    vfcmaddcph %xmm2, %xmm1, %xmm3
+; HAS-SZ-NEXT:    vaddph %xmm0, %xmm3, %xmm0
+; HAS-SZ-NEXT:    retq
+entry:
+  %0 = bitcast <8 x half> %a to <4 x float>
+  %1 = bitcast <8 x half> %b to <4 x float>
+  %2 = tail call <4 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.128(<4 x float> %0, <4 x float> %1, <4 x float> zeroinitializer, i8 -1)
+  %3 = bitcast <4 x float> %2 to <8 x half>
+  %add.i = fadd <8 x half> %3, %acc
+  ret <8 x half> %add.i
+}
+
+define dso_local <8 x half> @test6(<8 x half> %acc, <8 x half> %a, <8 x half> %b) {
+; NO-SZ-LABEL: test6:
+; NO-SZ:       # %bb.0: # %entry
+; NO-SZ-NEXT:    vfmaddcph %xmm2, %xmm1, %xmm0
+; NO-SZ-NEXT:    retq
+;
+; HAS-SZ-LABEL: test6:
+; HAS-SZ:       # %bb.0: # %entry
+; HAS-SZ-NEXT:    vxorps %xmm3, %xmm3, %xmm3
+; HAS-SZ-NEXT:    vfmaddcph %xmm2, %xmm1, %xmm3
+; HAS-SZ-NEXT:    vaddph %xmm0, %xmm3, %xmm0
+; HAS-SZ-NEXT:    retq
+entry:
+  %0 = bitcast <8 x half> %a to <4 x float>
+  %1 = bitcast <8 x half> %b to <4 x float>
+  %2 = tail call <4 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.128(<4 x float> %0, <4 x float> %1, <4 x float> zeroinitializer, i8 -1)
+  %3 = bitcast <4 x float> %2 to <8 x half>
+  %add.i = fadd <8 x half> %3, %acc
+  ret <8 x half> %add.i
+}
+
+; FADD(acc, FMA(a, b, -0.0)) can be combined to FMA(a, b, acc) whether or not signed zeros can be ignored.
+define dso_local <32 x half> @test13(<32 x half> %acc, <32 x half> %a, <32 x half> %b) {
+; CHECK-LABEL: test13:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vfcmaddcph %zmm2, %zmm1, %zmm0
+; CHECK-NEXT:    retq
+entry:
+  %0 = bitcast <32 x half> %a to <16 x float>
+  %1 = bitcast <32 x half> %b to <16 x float>
+  %2 = tail call <16 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.512(<16 x float> %0, <16 x float> %1, <16 x float> <float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000>, i16 -1, i32 4)
+  %3 = bitcast <16 x float> %2 to <32 x half>
+  %add.i = fadd <32 x half> %3, %acc
+  ret <32 x half> %add.i
+}
+
+define dso_local <32 x half> @test14(<32 x half> %acc, <32 x half> %a, <32 x half> %b) {
+; CHECK-LABEL: test14:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vfmaddcph %zmm2, %zmm1, %zmm0
+; CHECK-NEXT:    retq
+entry:
+  %0 = bitcast <32 x half> %a to <16 x float>
+  %1 = bitcast <32 x half> %b to <16 x float>
+  %2 = tail call <16 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.512(<16 x float> %0, <16 x float> %1, <16 x float> <float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000>, i16 -1, i32 4)
+  %3 = bitcast <16 x float> %2 to <32 x half>
+  %add.i = fadd <32 x half> %3, %acc
+  ret <32 x half> %add.i
+}
+
+define dso_local <16 x half> @test15(<16 x half> %acc, <16 x half> %a, <16 x half> %b) {
+; CHECK-LABEL: test15:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vfcmaddcph %ymm2, %ymm1, %ymm0
+; CHECK-NEXT:    retq
+entry:
+  %0 = bitcast <16 x half> %a to <8 x float>
+  %1 = bitcast <16 x half> %b to <8 x float>
+  %2 = tail call <8 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.256(<8 x float> %0, <8 x float> %1, <8 x float> <float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000>, i8 -1)
+  %3 = bitcast <8 x float> %2 to <16 x half>
+  %add.i = fadd <16 x half> %3, %acc
+  ret <16 x half> %add.i
+}
+
+define dso_local <16 x half> @test16(<16 x half> %acc, <16 x half> %a, <16 x half> %b) {
+; CHECK-LABEL: test16:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vfmaddcph %ymm2, %ymm1, %ymm0
+; CHECK-NEXT:    retq
+entry:
+  %0 = bitcast <16 x half> %a to <8 x float>
+  %1 = bitcast <16 x half> %b to <8 x float>
+  %2 = tail call <8 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.256(<8 x float> %0, <8 x float> %1, <8 x float> <float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000>, i8 -1)
+  %3 = bitcast <8 x float> %2 to <16 x half>
+  %add.i = fadd <16 x half> %3, %acc
+  ret <16 x half> %add.i
+}
+
+define dso_local <8 x half> @test17(<8 x half> %acc, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: test17:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vfcmaddcph %xmm2, %xmm1, %xmm0
+; CHECK-NEXT:    retq
+entry:
+  %0 = bitcast <8 x half> %a to <4 x float>
+  %1 = bitcast <8 x half> %b to <4 x float>
+  %2 = tail call <4 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.128(<4 x float> %0, <4 x float> %1, <4 x float> <float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000>, i8 -1)
+  %3 = bitcast <4 x float> %2 to <8 x half>
+  %add.i = fadd <8 x half> %3, %acc
+  ret <8 x half> %add.i
+}
+
+define dso_local <8 x half> @test18(<8 x half> %acc, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: test18:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vfmaddcph %xmm2, %xmm1, %xmm0
+; CHECK-NEXT:    retq
+entry:
+  %0 = bitcast <8 x half> %a to <4 x float>
+  %1 = bitcast <8 x half> %b to <4 x float>
+  %2 = tail call <4 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.128(<4 x float> %0, <4 x float> %1, <4 x float> <float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000>, i8 -1)
+  %3 = bitcast <4 x float> %2 to <8 x half>
+  %add.i = fadd <8 x half> %3, %acc
+  ret <8 x half> %add.i
+}
+
+declare <16 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.512(<16 x float>, <16 x float>, <16 x float>, i16, i32 immarg)
+declare <16 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.512(<16 x float>, <16 x float>, <16 x float>, i16, i32 immarg)
+declare <8 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.256(<8 x float>, <8 x float>, <8 x float>, i8)
+declare <8 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.256(<8 x float>, <8 x float>, <8 x float>, i8)
+declare <4 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.128(<4 x float>, <4 x float>, <4 x float>, i8)
+declare <4 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.128(<4 x float>, <4 x float>, <4 x float>, i8)