diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -47487,6 +47487,23 @@
 //  t23: v16f32 = nnan ninf nsz arcp contract afn reassoc
 //  X86ISD::VFMADDC/VFCMADDC t7, t8, t22
 //  t24: v32f16 = bitcast t23
+// And
+//  t11: v16f32 = ConstantFP vector<0.00>
+//  t4: v32f16,ch = CopyFromReg t0, Register:v32f16 %1
+//  t7: v16f32 = bitcast t4
+//  t6: v32f16,ch = CopyFromReg t0, Register:v32f16 %2
+//  t8: v16f32 = bitcast t6
+//  t21: v16f32 = X86ISD::VFCMADDC/VFMADDC
+//  nnan ninf nsz arcp contract afn reassoc t11, t7, t8
+//  t15: v32f16 = bitcast t21
+//  t2: v32f16,ch = CopyFromReg t0, Register:v32f16 %0
+//  t16: v32f16 = fadd t15, t2
+// into
+//  t6: v32f16,ch = CopyFromReg t0, Register:v32f16 %2
+//  t8: v16f32 = bitcast t6
+//  t24: v16f32 = X86ISD::VFCMADDC/VFMADDC
+//  nnan ninf nsz arcp contract afn reassoc t23, t7, t8
+//  t25: v32f16 = bitcast t24
 static SDValue combineFaddCFmul(SDNode *N, SelectionDAG &DAG,
                                 const X86Subtarget &Subtarget) {
   auto AllowContract = [&DAG](SDNode *N) {
@@ -47503,14 +47520,23 @@
   SDValue LHS = N->getOperand(0);
   SDValue RHS = N->getOperand(1);
   SDValue CFmul, FAddOp1;
-  auto GetCFmulFrom = [&CFmul, &AllowContract](SDValue N) -> bool {
+  auto GetCFmulFrom = [&CFmul, &AllowContract, &DAG](SDValue N) -> bool {
     if (!N.hasOneUse() || N.getOpcode() != ISD::BITCAST)
-    return false;
+      return false;
     SDValue Op0 = N.getOperand(0);
     unsigned Opcode = Op0.getOpcode();
-    if (Op0.hasOneUse() && AllowContract(Op0.getNode()) &&
-        (Opcode == X86ISD::VFMULC || Opcode == X86ISD::VFCMULC))
-      CFmul = Op0;
+    if (Op0.hasOneUse() && AllowContract(Op0.getNode())) {
+      if ((Opcode == X86ISD::VFMULC || Opcode == X86ISD::VFCMULC))
+        CFmul = Op0;
+      else if ((Opcode == X86ISD::VFMADDC || Opcode == X86ISD::VFCMADDC) &&
+               ISD::isBuildVectorAllZeros(Op0->getOperand(0).getNode())) {
+        CFmul = DAG.getNode(Opcode == X86ISD::VFMADDC ? X86ISD::VFMULC
+                                                      : X86ISD::VFCMULC,
+                            SDLoc(Op0), Op0.getSimpleValueType(),
+                            Op0->getOperand(1), Op0->getOperand(2));
+        DAG.ReplaceAllUsesOfValueWith(Op0, CFmul);
+      }
+    }
     return !!CFmul;
   };
 
diff --git a/llvm/test/CodeGen/X86/avx512fp16-combine-vfmac-fadd.ll b/llvm/test/CodeGen/X86/avx512fp16-combine-vfmac-fadd.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/X86/avx512fp16-combine-vfmac-fadd.ll
@@ -0,0 +1,93 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=avx512fp16 --fp-contract=fast --enable-unsafe-fp-math | FileCheck %s
+
+define dso_local <32 x half> @test1(<32 x half> %acc, <32 x half> %a, <32 x half> %b) {
+; CHECK-LABEL: test1:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vfcmaddcph %zmm2, %zmm1, %zmm0
+; CHECK-NEXT:    retq
+entry:
+  %0 = bitcast <32 x half> %a to <16 x float>
+  %1 = bitcast <32 x half> %b to <16 x float>
+  %2 = tail call fast <16 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.512(<16 x float> zeroinitializer, <16 x float> %0, <16 x float> %1, i16 -1, i32 4)
+  %3 = bitcast <16 x float> %2 to <32 x half>
+  %add.i = fadd fast <32 x half> %3, %acc
+  ret <32 x half> %add.i
+}
+
+define dso_local <32 x half> @test2(<32 x half> %acc, <32 x half> %a, <32 x half> %b) {
+; CHECK-LABEL: test2:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vfmaddcph %zmm2, %zmm1, %zmm0
+; CHECK-NEXT:    retq
+entry:
+  %0 = bitcast <32 x half> %a to <16 x float>
+  %1 = bitcast <32 x half> %b to <16 x float>
+  %2 = tail call fast <16 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.512(<16 x float> zeroinitializer, <16 x float> %0, <16 x float> %1, i16 -1, i32 4)
+  %3 = bitcast <16 x float> %2 to <32 x half>
+  %add.i = fadd fast <32 x half> %3, %acc
+  ret <32 x half> %add.i
+}
+
+define dso_local <16 x half> @test3(<16 x half> %acc, <16 x half> %a, <16 x half> %b) {
+; CHECK-LABEL: test3:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vfcmaddcph %ymm2, %ymm1, %ymm0
+; CHECK-NEXT:    retq
+entry:
+  %0 = bitcast <16 x half> %a to <8 x float>
+  %1 = bitcast <16 x half> %b to <8 x float>
+  %2 = tail call fast <8 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.256(<8 x float> zeroinitializer, <8 x float> %0, <8 x float> %1, i8 -1)
+  %3 = bitcast <8 x float> %2 to <16 x half>
+  %add.i = fadd fast <16 x half> %3, %acc
+  ret <16 x half> %add.i
+}
+
+define dso_local <16 x half> @test4(<16 x half> %acc, <16 x half> %a, <16 x half> %b) {
+; CHECK-LABEL: test4:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vfmaddcph %ymm2, %ymm1, %ymm0
+; CHECK-NEXT:    retq
+entry:
+  %0 = bitcast <16 x half> %a to <8 x float>
+  %1 = bitcast <16 x half> %b to <8 x float>
+  %2 = tail call fast <8 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.256(<8 x float> zeroinitializer, <8 x float> %0, <8 x float> %1, i8 -1)
+  %3 = bitcast <8 x float> %2 to <16 x half>
+  %add.i = fadd fast <16 x half> %3, %acc
+  ret <16 x half> %add.i
+}
+
+define dso_local <8 x half> @test5(<8 x half> %acc, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: test5:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vfcmaddcph %xmm2, %xmm1, %xmm0
+; CHECK-NEXT:    retq
+entry:
+  %0 = bitcast <8 x half> %a to <4 x float>
+  %1 = bitcast <8 x half> %b to <4 x float>
+  %2 = tail call fast <4 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.128(<4 x float> zeroinitializer, <4 x float> %0, <4 x float> %1, i8 -1)
+  %3 = bitcast <4 x float> %2 to <8 x half>
+  %add.i = fadd fast <8 x half> %3, %acc
+  ret <8 x half> %add.i
+}
+
+define dso_local <8 x half> @test6(<8 x half> %acc, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: test6:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vfmaddcph %xmm2, %xmm1, %xmm0
+; CHECK-NEXT:    retq
+entry:
+  %0 = bitcast <8 x half> %a to <4 x float>
+  %1 = bitcast <8 x half> %b to <4 x float>
+  %2 = tail call fast <4 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.128(<4 x float> zeroinitializer, <4 x float> %0, <4 x float> %1, i8 -1)
+  %3 = bitcast <4 x float> %2 to <8 x half>
+  %add.i = fadd fast <8 x half> %3, %acc
+  ret <8 x half> %add.i
+}
+
+declare <16 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.512(<16 x float>, <16 x float>, <16 x float>, i16, i32 immarg)
+declare <16 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.512(<16 x float>, <16 x float>, <16 x float>, i16, i32 immarg)
+declare <8 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.256(<8 x float>, <8 x float>, <8 x float>, i8)
+declare <8 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.256(<8 x float>, <8 x float>, <8 x float>, i8)
+declare <4 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.128(<4 x float>, <4 x float>, <4 x float>, i8)
+declare <4 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.128(<4 x float>, <4 x float>, <4 x float>, i8)
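
Not part of the patch, just a schematic sketch of the rewrite chain it enables, using the value names from the tests above (acc, a, b) and the operand orders shown in the comment block: the new GetCFmulFrom code rewrites a complex FMA whose accumulator is all-zeros into a plain complex multiply, and the pre-existing FADD(BITCAST(VFMULC/VFCMULC)) combine in combineFaddCFmul then absorbs the fadd:

  ; fadd (bitcast (VFCMADDC zero, a, b)), acc
  ;   -> fadd (bitcast (VFCMULC a, b)), acc       ; this patch: drop the zero accumulator
  ;   -> bitcast (VFCMADDC a, b, bitcast(acc))    ; existing combine folds the fadd in

The net effect, as the CHECK lines verify, is a single vfcmaddcph/vfmaddcph with acc as the accumulator register instead of a complex multiply-add against zero followed by a separate add.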