This is an archive of the discontinued LLVM Phabricator instance.

[MLIR][GPU] Replace fdiv on fp16 with promoted (fp32) multiplication with reciprocal plus one (conditional) Newton iteration.
ClosedPublic

Authored by csigg on May 22 2022, 1:00 AM.

Details

Summary

This is correct for all values, i.e. it produces the same results as promoting the division to fp32 in the NVPTX backend. But it is faster (~10% on average, sometimes more) because:

  • it performs fewer Newton iterations
  • it avoids the slow path for special values, e.g. denormals
  • it allows reuse of the reciprocal for multiple divisions by the same divisor
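
The refinement used here can be sketched in plain Python (a sketch only: doubles stand in for fp32, `rcp_err` models the inaccuracy of the hardware `rcp.approx` instruction, and the names are illustrative, not from the patch):

```python
def div_newton(a: float, b: float, rcp_err: float = 1e-3) -> float:
    """Divide a by b via an approximate reciprocal plus one Newton step."""
    rcp = (1.0 / b) * (1.0 + rcp_err)  # deliberately inaccurate reciprocal
    result = a * rcp                   # initial quotient estimate
    err = a - b * result               # residual; fma(-b, result, a) on the GPU
    return result + rcp * err          # one Newton refinement step

# One refinement step roughly squares the reciprocal's relative error,
# so an error of 1e-3 shrinks to about 1e-6.
q = div_newton(355.0, 113.0)
assert abs(q - 355.0 / 113.0) / (355.0 / 113.0) < 1e-5
```

Note that `rcp` depends only on the divisor, which is why the reciprocal can be reused across several divisions by the same value.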

Test program:

#include <stdio.h>
#include "cuda_fp16.h"

// This is a variant of CUDA's own __hdiv which is faster than hdiv_promote
// below and doesn't suffer from the perf cliff of div.rn.fp32 with 'special'
// values.
__device__ half hdiv_newton(half a, half b) {
  float fa = __half2float(a);
  float fb = __half2float(b);

  float rcp;
  asm("{rcp.approx.ftz.f32 %0, %1;\n}" : "=f"(rcp) : "f"(fb));

  float result = fa * rcp;
  auto exponent = reinterpret_cast<const unsigned&>(result) & 0x7f800000;
  if (exponent != 0 && exponent != 0x7f800000) {
    float err = __fmaf_rn(-fb, result, fa);
    result = __fmaf_rn(rcp, err, result);
  }

  return __float2half(result);
}

// Surprisingly, this is faster than CUDA's own __hdiv.
__device__ half hdiv_promote(half a, half b) {
  return __float2half(__half2float(a) / __half2float(b));
}

// This is an approximation that is accurate up to 1 ulp.
__device__ half hdiv_approx(half a, half b) {
  float fa = __half2float(a);
  float fb = __half2float(b);

  float result;
  asm("{div.approx.ftz.f32 %0, %1, %2;\n}" : "=f"(result) : "f"(fa), "f"(fb));
  return __float2half(result);
}

__global__ void CheckCorrectness() {
  int i = threadIdx.x + blockIdx.x * blockDim.x;
  half x = reinterpret_cast<const half&>(i);
  for (int j = 0; j < 65536; ++j) {
    half y = reinterpret_cast<const half&>(j);
    half d1 = hdiv_newton(x, y);
    half d2 = hdiv_promote(x, y);
    auto s1 = reinterpret_cast<const short&>(d1);
    auto s2 = reinterpret_cast<const short&>(d2);
    if (s1 != s2) {
      printf("%f (%u) / %f (%u), got %f (%hu), expected: %f (%hu)\n",
             __half2float(x), i, __half2float(y), j, __half2float(d1), s1,
             __half2float(d2), s2);
      //__trap();
    }
  }
}

__device__ half dst;

__global__ void ProfileBuiltin(half x) {
  #pragma unroll 1
  for (int i = 0; i < 10000000; ++i) {
    x = x / x;
  }
  dst = x;
}

__global__ void ProfilePromote(half x) {
  #pragma unroll 1
  for (int i = 0; i < 10000000; ++i) {
    x = hdiv_promote(x, x);
  }
  dst = x;
}

__global__ void ProfileNewton(half x) {
  #pragma unroll 1
  for (int i = 0; i < 10000000; ++i) {
    x = hdiv_newton(x, x);
  }
  dst = x;
}

__global__ void ProfileApprox(half x) {
  #pragma unroll 1
  for (int i = 0; i < 10000000; ++i) {
    x = hdiv_approx(x, x);
  }
  dst = x;
}

int main() {
  CheckCorrectness<<<256, 256>>>();
  half one = __float2half(1.0f);
  ProfileBuiltin<<<1, 1>>>(one);  // 1.001s
  ProfilePromote<<<1, 1>>>(one);  // 0.560s
  ProfileNewton<<<1, 1>>>(one);   // 0.508s
  ProfileApprox<<<1, 1>>>(one);   // 0.304s
  auto status = cudaDeviceSynchronize();
  printf("%s\n", cudaGetErrorString(status));
}
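
The guard in hdiv_newton skips the refinement when the initial fp32 result is zero or denormal (exponent bits all zero) or inf/NaN (exponent bits all ones), since the Newton step is only valid for normal finite values. The bit test can be reproduced on the CPU as follows (a sketch; `struct` stands in for the `reinterpret_cast`, and the function names are illustrative):

```python
import struct

def f32_exponent_bits(x: float) -> int:
    """Return the raw exponent field of x rounded to IEEE-754 binary32."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits & 0x7F800000

def needs_refinement(result: float) -> bool:
    """Mirror the condition in hdiv_newton: refine only normal finite values."""
    e = f32_exponent_bits(result)
    return e != 0 and e != 0x7F800000

assert needs_refinement(1.5)               # normal value: refine
assert not needs_refinement(0.0)           # zero/denormal: skip
assert not needs_refinement(float("inf"))  # inf/NaN: skip
```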

Diff Detail

Event Timeline

csigg created this revision.May 22 2022, 1:00 AM
csigg requested review of this revision.May 22 2022, 1:00 AM
Herald added projects: Restricted Project, Restricted Project, Restricted Project.
tra added a subscriber: tra.May 24 2022, 10:32 AM

I would suggest separating it into separate LLVM and MLIR patches.

LLVM changes look OK to me. No idea about MLIR. We would probably want to lower fp16 fdiv the same way in LLVM, too, but that would also have to be a separate patch.

tra added a reviewer: tra.May 24 2022, 10:33 AM
csigg added a comment.May 25 2022, 4:42 AM

I would suggest separating it into separate LLVM and MLIR patches.

Thanks Artem. I separated out the LLVM changes in https://reviews.llvm.org/D126369.

herhut added inline comments.May 25 2022, 8:26 AM
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
158

This pattern is a bit misplaced here, as LLVM::FDivOp is not really a GPU dialect operation. Instead, should this be a special lowering of the arith dialect to NVVM (which we do not have yet) or a rewrite at the LLVM dialect level?

When lowering to LLVM, we already typically configure a different lowering for math dialect, so configuring the lowering of arith dialect differently seems like an OK option. That would mean a specialized pattern for arith.divf with higher priority. That would also give users a choice.

304

I assume this is to differentiate this pattern somehow, but there is no need for an extra patterns.add here.

csigg added inline comments.May 30 2022, 12:03 AM
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
158

Yes, I agree it's a bit misplaced. I considered it the best of all questionable options.

Adding it to ArithToLLVM doesn't really work, because we don't want it to depend on the NVVM dialect.

How about adding it as a separate pass to mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td?

csigg updated this revision to Diff 432871.May 30 2022, 2:28 AM

Rebase.

csigg updated this revision to Diff 432880.May 30 2022, 3:31 AM

Make fdiv rewrite an NVVM transform pass instead.

herhut accepted this revision.May 31 2022, 1:41 AM

Separate pass works for me.

mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
19 (On Diff #432894)

Maybe llvm-optimize-for-nvvm? Or even llvm-optimize-for-nvvm-target?

This does not really optimize NVVM but rewrites LLVM IR.

This revision is now accepted and ready to land.May 31 2022, 1:41 AM
csigg updated this revision to Diff 433772.Jun 2 2022, 9:26 AM

Rename pass.

The shared library build was broken, so I had to revert: https://lab.llvm.org/buildbot/#/builders/61/builds/27377

csigg added a comment.Jun 4 2022, 3:25 AM

Thanks Mehdi for reverting.