diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13054,6 +13054,11 @@
   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
 
+  auto isFusedOp = [HasFMA, HasFMAD](SDValue N) {
+    unsigned Opcode = N.getOpcode();
+    return (HasFMA && Opcode == ISD::FMA) || (HasFMAD && Opcode == ISD::FMAD);
+  };
+
   // Is the node an FMUL and contractable either due to global flags or
   // SDNodeFlags.
   auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
@@ -13085,12 +13090,12 @@
   // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
   // This requires reassociation because it changes the order of operations.
   SDValue FMA, E;
-  if (CanReassociate && N0.getOpcode() == PreferredFusedOpcode &&
+  if (CanReassociate && isFusedOp(N0) &&
       N0.getOperand(2).getOpcode() == ISD::FMUL && N0.hasOneUse() &&
       N0.getOperand(2).hasOneUse()) {
     FMA = N0;
     E = N1;
-  } else if (CanReassociate && N1.getOpcode() == PreferredFusedOpcode &&
+  } else if (CanReassociate && isFusedOp(N1) &&
              N1.getOperand(2).getOpcode() == ISD::FMUL && N1.hasOneUse() &&
              N1.getOperand(2).hasOneUse()) {
     FMA = N1;
@@ -13146,7 +13151,7 @@
                        DAG.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
   };
 
-  if (N0.getOpcode() == PreferredFusedOpcode) {
+  if (isFusedOp(N0)) {
     SDValue N02 = N0.getOperand(2);
     if (N02.getOpcode() == ISD::FP_EXTEND) {
       SDValue N020 = N02.getOperand(0);
@@ -13176,7 +13181,7 @@
   };
   if (N0.getOpcode() == ISD::FP_EXTEND) {
     SDValue N00 = N0.getOperand(0);
-    if (N00.getOpcode() == PreferredFusedOpcode) {
+    if (isFusedOp(N00)) {
       SDValue N002 = N00.getOperand(2);
       if (isContractableFMUL(N002) &&
           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
@@ -13190,7 +13195,7 @@
 
   // fold (fadd x, (fma y, z, (fpext (fmul u, v)))
   // -> (fma y, z, (fma (fpext u), (fpext v), x))
-  if (N1.getOpcode() == PreferredFusedOpcode) {
+  if (isFusedOp(N1)) {
     SDValue N12 = N1.getOperand(2);
     if (N12.getOpcode() == ISD::FP_EXTEND) {
       SDValue N120 = N12.getOperand(0);
@@ -13211,7 +13216,7 @@
   // interesting for all targets, especially GPUs.
   if (N1.getOpcode() == ISD::FP_EXTEND) {
     SDValue N10 = N1.getOperand(0);
-    if (N10.getOpcode() == PreferredFusedOpcode) {
+    if (isFusedOp(N10)) {
       SDValue N102 = N10.getOperand(2);
       if (isContractableFMUL(N102) &&
           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
@@ -13407,12 +13412,17 @@
     return isContractableFMUL(N) && isReassociable(N.getNode());
   };
 
+  auto isFusedOp = [HasFMA, HasFMAD](SDValue N) {
+    unsigned Opcode = N.getOpcode();
+    return (HasFMA && Opcode == ISD::FMA) || (HasFMAD && Opcode == ISD::FMAD);
+  };
+
   // More folding opportunities when target permits.
   if (Aggressive && isReassociable(N)) {
     bool CanFuse = Options.UnsafeFPMath || N->getFlags().hasAllowContract();
     // fold (fsub (fma x, y, (fmul u, v)), z)
     // -> (fma x, y (fma u, v, (fneg z)))
-    if (CanFuse && N0.getOpcode() == PreferredFusedOpcode &&
+    if (CanFuse && isFusedOp(N0) &&
         isContractableAndReassociableFMUL(N0.getOperand(2)) &&
         N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) {
       return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
@@ -13425,7 +13435,7 @@
 
     // fold (fsub x, (fma y, z, (fmul u, v)))
     // -> (fma (fneg y), z, (fma (fneg u), v, x))
-    if (CanFuse && N1.getOpcode() == PreferredFusedOpcode &&
+    if (CanFuse && isFusedOp(N1) &&
         isContractableAndReassociableFMUL(N1.getOperand(2)) &&
         N1->hasOneUse() && NoSignedZero) {
       SDValue N20 = N1.getOperand(2).getOperand(0);
@@ -13439,8 +13449,7 @@
 
     // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
     // -> (fma x, y (fma (fpext u), (fpext v), (fneg z)))
-    if (N0.getOpcode() == PreferredFusedOpcode &&
-        N0->hasOneUse()) {
+    if (isFusedOp(N0) && N0->hasOneUse()) {
       SDValue N02 = N0.getOperand(2);
       if (N02.getOpcode() == ISD::FP_EXTEND) {
         SDValue N020 = N02.getOperand(0);
@@ -13466,7 +13475,7 @@
     // interesting for all targets, especially GPUs.
     if (N0.getOpcode() == ISD::FP_EXTEND) {
       SDValue N00 = N0.getOperand(0);
-      if (N00.getOpcode() == PreferredFusedOpcode) {
+      if (isFusedOp(N00)) {
         SDValue N002 = N00.getOperand(2);
         if (isContractableAndReassociableFMUL(N002) &&
             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
@@ -13486,8 +13495,7 @@
 
     // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
    // -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
-    if (N1.getOpcode() == PreferredFusedOpcode &&
-        N1.getOperand(2).getOpcode() == ISD::FP_EXTEND &&
+    if (isFusedOp(N1) && N1.getOperand(2).getOpcode() == ISD::FP_EXTEND &&
         N1->hasOneUse()) {
       SDValue N120 = N1.getOperand(2).getOperand(0);
       if (isContractableAndReassociableFMUL(N120) &&
@@ -13511,8 +13519,7 @@
     // FIXME: This turns two single-precision and one double-precision
     // operation into two double-precision operations, which might not be
     // interesting for all targets, especially GPUs.
-    if (N1.getOpcode() == ISD::FP_EXTEND &&
-        N1.getOperand(0).getOpcode() == PreferredFusedOpcode) {
+    if (N1.getOpcode() == ISD::FP_EXTEND && isFusedOp(N1.getOperand(0))) {
       SDValue CvtSrc = N1.getOperand(0);
       SDValue N100 = CvtSrc.getOperand(0);
       SDValue N101 = CvtSrc.getOperand(1);
diff --git a/llvm/test/CodeGen/AMDGPU/dagcombine-fma-fmad.ll b/llvm/test/CodeGen/AMDGPU/dagcombine-fma-fmad.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/dagcombine-fma-fmad.ll
@@ -0,0 +1,148 @@
+; RUN: llc -march=amdgcn -mcpu=gfx1010 -verify-machineinstrs < %s | FileCheck -enable-var-scope -check-prefix=GCN %s
+
+; Check there are 13 fma/fmac instructions
+; GCN-LABEL: {{^}}_amdgpu_ps_main:
+; GCN: v_fma
+; GCN: v_fma
+; GCN: v_fma
+; GCN: v_fma
+; GCN: v_fma
+; GCN: v_fma
+; GCN: v_fma
+; GCN: v_fma
+; GCN: v_fma
+; GCN: v_fma
+; GCN: v_fma
+; GCN: v_fma
+; GCN: v_fma
+; GCN-NOT: v_fma
+define amdgpu_ps float @_amdgpu_ps_main() #0 {
+.entry:
+  %0 = call <3 x float> @llvm.amdgcn.image.sample.2d.v3f32.f32(i32 7, float undef, float undef, <8 x i32> undef, <4 x i32> undef, i1 false, i32 0, i32 0)
+  %.i2243 = extractelement <3 x float> %0, i32 2
+  %1 = call <3 x i32> @llvm.amdgcn.s.buffer.load.v3i32(<4 x i32> undef, i32 0, i32 0)
+  %2 = shufflevector <3 x i32> %1, <3 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 undef>
+  %3 = bitcast <4 x i32> %2 to <4 x float>
+  %.i2248 = extractelement <4 x float> %3, i32 2
+  %.i2249 = fmul reassoc nnan nsz arcp contract afn float %.i2243, %.i2248
+  %4 = call reassoc nnan nsz arcp contract afn float @llvm.amdgcn.fmed3.f32(float undef, float 0.000000e+00, float 1.000000e+00)
+  %5 = call <3 x float> @llvm.amdgcn.image.sample.2d.v3f32.f32(i32 7, float undef, float undef, <8 x i32> undef, <4 x i32> undef, i1 false, i32 0, i32 0)
+  %.i2333 = extractelement <3 x float> %5, i32 2
+  %6 = call reassoc nnan nsz arcp contract afn float @llvm.amdgcn.fmed3.f32(float undef, float 0.000000e+00, float 1.000000e+00)
+  %7 = call <2 x float> @llvm.amdgcn.image.sample.2d.v2f32.f32(i32 3, float undef, float undef, <8 x i32> undef, <4 x i32> undef, i1 false, i32 0, i32 0)
+  %.i1408 = extractelement <2 x float> %7, i32 1
+  %.i0364 = extractelement <2 x float> %7, i32 0
+  %8 = call float @llvm.amdgcn.image.sample.2d.f32.f32(i32 1, float undef, float undef, <8 x i32> undef, <4 x i32> undef, i1 false, i32 0, i32 0)
+  %9 = call <3 x i32> @llvm.amdgcn.s.buffer.load.v3i32(<4 x i32> undef, i32 112, i32 0)
+  %10 = shufflevector <3 x i32> %9, <3 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 undef>
+  %11 = bitcast <4 x i32> %10 to <4 x float>
+  %.i2360 = extractelement <4 x float> %11, i32 2
+  %.i2363 = fmul reassoc nnan nsz arcp contract afn float %.i2360, %8
+  %12 = call <3 x i32> @llvm.amdgcn.s.buffer.load.v3i32(<4 x i32> undef, i32 96, i32 0)
+  %13 = shufflevector <3 x i32> %12, <3 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 undef>
+  %14 = bitcast <4 x i32> %13 to <4 x float>
+  %.i2367 = extractelement <4 x float> %14, i32 2
+  %.i2370 = fmul reassoc nnan nsz arcp contract afn float %.i0364, %.i2367
+  %15 = call <3 x i32> @llvm.amdgcn.s.buffer.load.v3i32(<4 x i32> undef, i32 32, i32 0)
+  %16 = shufflevector <3 x i32> %15, <3 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 undef>
+  %17 = bitcast <4 x i32> %16 to <4 x float>
+  %.i2373 = extractelement <4 x float> %17, i32 2
+  %.i2376 = fsub reassoc nnan nsz arcp contract afn float %.i2373, %.i2370
+  %.i2383 = fmul reassoc nnan nsz arcp contract afn float %.i2376, %6
+  %.i2386 = fadd reassoc nnan nsz arcp contract afn float %.i2370, %.i2383
+  %18 = call reassoc nnan nsz arcp contract afn float @llvm.amdgcn.fmed3.f32(float undef, float 0.000000e+00, float 1.000000e+00)
+  %19 = fmul reassoc nnan nsz arcp contract afn float %18, %.i2363
+  %.i2394 = fsub reassoc nnan nsz arcp contract afn float %.i2386, %19
+  %.i2397 = fmul reassoc nnan nsz arcp contract afn float %.i2363, %18
+  %.i2404 = fmul reassoc nnan nsz arcp contract afn float %.i2394, %4
+  %.i2407 = fadd reassoc nnan nsz arcp contract afn float %.i2397, %.i2404
+  %20 = call i32 @llvm.amdgcn.s.buffer.load.i32(<4 x i32> undef, i32 92, i32 0)
+  %21 = bitcast i32 %20 to float
+  %22 = call i32 @llvm.amdgcn.s.buffer.load.i32(<4 x i32> undef, i32 124, i32 0)
+  %23 = bitcast i32 %22 to float
+  %24 = fsub reassoc nnan nsz arcp contract afn float %21, %23
+  %25 = fmul reassoc nnan nsz arcp contract afn float %.i1408, %24
+  %26 = fadd reassoc nnan nsz arcp contract afn float %25, %23
+  %27 = call i32 @llvm.amdgcn.s.buffer.load.i32(<4 x i32> undef, i32 44, i32 0)
+  %28 = bitcast i32 %27 to float
+  %29 = fsub reassoc nnan nsz arcp contract afn float %28, %26
+  %30 = fmul reassoc nnan nsz arcp contract afn float %6, %29
+  %31 = fadd reassoc nnan nsz arcp contract afn float %26, %30
+  %32 = call i32 @llvm.amdgcn.s.buffer.load.i32(<4 x i32> undef, i32 192, i32 0)
+  %33 = bitcast i32 %32 to float
+  %34 = fadd reassoc nnan nsz arcp contract afn float %33, -1.000000e+00
+  %35 = fmul reassoc nnan nsz arcp contract afn float %18, %34
+  %36 = fadd reassoc nnan nsz arcp contract afn float %35, 1.000000e+00
+  %37 = fsub reassoc nnan nsz arcp contract afn float %31, %36
+  %38 = fmul reassoc nnan nsz arcp contract afn float %37, %4
+  %39 = fadd reassoc nnan nsz arcp contract afn float %36, %38
+  %40 = fmul reassoc nnan nsz arcp contract afn float %39, 0x3F847AE140000000
+  %41 = fadd reassoc nnan nsz arcp contract afn float %40, 0x3F947AE140000000
+  %.i2415 = fmul reassoc nnan nsz arcp contract afn float %.i2407, %41
+  %42 = call <3 x float> @llvm.amdgcn.image.load.mip.2d.v3f32.i32(i32 7, i32 undef, i32 undef, i32 0, <8 x i32> undef, i32 0, i32 0)
+  %.i2521 = extractelement <3 x float> %42, i32 2
+  %43 = call reassoc nnan nsz arcp contract afn float @llvm.amdgcn.fmed3.f32(float undef, float 0.000000e+00, float 1.000000e+00)
+  %44 = call <3 x float> @llvm.amdgcn.image.sample.2d.v3f32.f32(i32 7, float undef, float undef, <8 x i32> undef, <4 x i32> undef, i1 false, i32 0, i32 0)
+  %.i2465 = extractelement <3 x float> %44, i32 2
+  %.i2466 = fmul reassoc nnan nsz arcp contract afn float %.i2465, %43
+  %.i2469 = fmul reassoc nnan nsz arcp contract afn float %.i2415, %.i2466
+  %45 = call <3 x i32> @llvm.amdgcn.s.buffer.load.v3i32(<4 x i32> undef, i32 64, i32 0)
+  %46 = shufflevector <3 x i32> %45, <3 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 undef>
+  %47 = bitcast <4 x i32> %46 to <4 x float>
+  %.i2476 = extractelement <4 x float> %47, i32 2
+  %.i2479 = fmul reassoc nnan nsz arcp contract afn float %.i2476, %18
+  %48 = call <3 x i32> @llvm.amdgcn.s.buffer.load.v3i32(<4 x i32> undef, i32 80, i32 0)
+  %49 = shufflevector <3 x i32> %48, <3 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 undef>
+  %50 = bitcast <4 x i32> %49 to <4 x float>
+  %.i2482 = extractelement <4 x float> %50, i32 2
+  %.i2485 = fsub reassoc nnan nsz arcp contract afn float %.i2482, %.i2479
+  %.i2488 = fmul reassoc nnan nsz arcp contract afn float %.i2249, %18
+  %.i2491 = fmul reassoc nnan nsz arcp contract afn float %.i2485, %4
+  %.i2494 = fadd reassoc nnan nsz arcp contract afn float %.i2479, %.i2491
+  %51 = call <3 x float> @llvm.amdgcn.image.sample.2d.v3f32.f32(i32 7, float undef, float undef, <8 x i32> undef, <4 x i32> undef, i1 false, i32 0, i32 0)
+  %.i2515 = extractelement <3 x float> %51, i32 2
+  %.i2516 = fadd reassoc nnan nsz arcp contract afn float %.i2515, %.i2494
+  %.i2522 = fadd reassoc nnan nsz arcp contract afn float %.i2521, %.i2516
+  %.i2525 = fmul reassoc nnan nsz arcp contract afn float %.i2522, %43
+  %52 = call <3 x i32> @llvm.amdgcn.s.buffer.load.v3i32(<4 x i32> undef, i32 16, i32 0)
+  %53 = shufflevector <3 x i32> %52, <3 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 undef>
+  %54 = bitcast <4 x i32> %53 to <4 x float>
+  %.i2530 = extractelement <4 x float> %54, i32 2
+  %.i2531 = fmul reassoc nnan nsz arcp contract afn float %.i2333, %.i2530
+  %.i2536 = fsub reassoc nnan nsz arcp contract afn float %.i2531, %.i2488
+  %.i2539 = fmul reassoc nnan nsz arcp contract afn float %.i2536, %4
+  %.i2542 = fadd reassoc nnan nsz arcp contract afn float %.i2488, %.i2539
+  %.i2545 = fmul reassoc nnan nsz arcp contract afn float %.i2525, %.i2542
+  %.i2548 = fadd reassoc nnan nsz arcp contract afn float %.i2469, %.i2545
+  %.i2551 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %.i2548, float 0.000000e+00)
+  ret float %.i2551
+}
+
+; Function Attrs: nofree nosync nounwind readnone speculatable willreturn
+declare float @llvm.maxnum.f32(float, float) #1
+
+; Function Attrs: nounwind readnone speculatable willreturn
+declare float @llvm.amdgcn.fmed3.f32(float, float, float) #2
+
+; Function Attrs: nounwind readonly willreturn
+declare <2 x float> @llvm.amdgcn.image.sample.2d.v2f32.f32(i32 immarg, float, float, <8 x i32>, <4 x i32>, i1 immarg, i32 immarg, i32 immarg) #3
+
+; Function Attrs: nounwind readonly willreturn
+declare float @llvm.amdgcn.image.sample.2d.f32.f32(i32 immarg, float, float, <8 x i32>, <4 x i32>, i1 immarg, i32 immarg, i32 immarg) #3
+
+; Function Attrs: nounwind readonly willreturn
+declare <3 x float> @llvm.amdgcn.image.sample.2d.v3f32.f32(i32 immarg, float, float, <8 x i32>, <4 x i32>, i1 immarg, i32 immarg, i32 immarg) #3
+
+; Function Attrs: nounwind readonly willreturn
+declare <3 x float> @llvm.amdgcn.image.load.mip.2d.v3f32.i32(i32 immarg, i32, i32, i32, <8 x i32>, i32 immarg, i32 immarg) #3
+
+; Function Attrs: nounwind readnone willreturn
+declare i32 @llvm.amdgcn.s.buffer.load.i32(<4 x i32>, i32, i32 immarg) #3
+
+; Function Attrs: nounwind readnone willreturn
+declare <3 x i32> @llvm.amdgcn.s.buffer.load.v3i32(<4 x i32>, i32, i32 immarg) #3
+
+attributes #0 = { "denormal-fp-math-f32"="preserve-sign" }
+attributes #1 = { nofree nosync nounwind readnone speculatable willreturn }
+attributes #2 = { nounwind readnone speculatable willreturn }
+attributes #3 = { nounwind readonly willreturn }
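
The C++ sketch below restates the predicate change at the heart of this patch: every DAGCombiner.cpp hunk swaps an exact-opcode comparison against PreferredFusedOpcode for the new isFusedOp lambda, so the fadd/fsub combines still fire when a chain mixes ISD::FMA and ISD::FMAD nodes on a target where both are available (such as gfx1010 with f32 denormals flushed, as in the test above). This is a minimal standalone model, assuming plain enum opcodes in place of LLVM's ISD namespace and SDValue; it is illustrative only, not LLVM's API.

// Standalone model of the isFusedOp change -- not LLVM code.
#include <cassert>

enum Opcode { FMUL, FADD, FMA, FMAD };

struct Node {
  Opcode Op;
};

int main() {
  // Assumed target properties: both fused forms are legal for the type.
  const bool HasFMA = true;
  const bool HasFMAD = true;
  // DAGCombiner prefers FMAD when the target has it.
  const Opcode PreferredFusedOpcode = HasFMAD ? FMAD : FMA;

  // Old predicate: only the preferred opcode counted as a fused op.
  auto isFusedOpOld = [&](const Node &N) {
    return N.Op == PreferredFusedOpcode;
  };

  // New predicate, as introduced by the patch: either opcode counts,
  // gated on target support.
  auto isFusedOpNew = [&](const Node &N) {
    return (HasFMA && N.Op == FMA) || (HasFMAD && N.Op == FMAD);
  };

  // A node already fused as FMA (e.g. by an earlier contraction) used
  // to be rejected, blocking the follow-on combines.
  Node FusedAsFMA{FMA};
  assert(!isFusedOpOld(FusedAsFMA));
  assert(isFusedOpNew(FusedAsFMA));

  // The preferred form is accepted by both predicates.
  Node FusedAsFMAD{FMAD};
  assert(isFusedOpOld(FusedAsFMAD));
  assert(isFusedOpNew(FusedAsFMAD));
  return 0;
}

Matching only the preferred opcode silently rejected the ISD::FMA form of an already-fused node, which is why the shader test above checks that all 13 multiply-add chains still compile to v_fma/v_fmac.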