
[mlir] Implement SoftwareBF16 to handle the bf16 type
Needs Review (Public)

Authored by yiqian1 on May 25 2022, 9:00 PM.

Details

Summary
Some LLVM targets, such as AMDGPU and X86, do not support bfloat types.
Add a SoftwareBF16 pass to support the bf16 type on such targets. This
pass replaces every bf16 with i16 and then rewrites operations on bf16 as
f32 operations with extended operands and/or truncated results.
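A minimal C sketch of the scheme the summary describes, at the level of individual values rather than MLIR ops: a bf16 value is carried in a 16-bit integer, and an operation such as addition extends its operands to f32, computes in f32, and truncates the result back. The helper names are hypothetical, and round-to-nearest-even truncation is an assumption; the summary does not say which rounding the pass uses.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* The bit-casts below assume float is IEEE-754 binary32. */
static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits");

/* Hypothetical helpers: bf16 values live in uint16_t, mirroring the pass's
   bf16 -> i16 storage rewrite. */

/* Extend bf16 -> f32: a bf16 value is the upper half of a binary32. */
static float bf16_to_f32(uint16_t h) {
    uint32_t bits = (uint32_t)h << 16;
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* Truncate f32 -> bf16, rounding to nearest with ties to even (assumed). */
static uint16_t f32_to_bf16(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);
    if ((bits & 0x7fffffffu) > 0x7f800000u)       /* NaN: keep it a quiet NaN */
        return (uint16_t)((bits >> 16) | 0x0040u);
    bits += 0x7fffu + ((bits >> 16) & 1u);         /* round half to even */
    return (uint16_t)(bits >> 16);
}

/* A bf16 add, emulated as: extend operands, add in f32, truncate result. */
static uint16_t bf16_add(uint16_t a, uint16_t b) {
    return f32_to_bf16(bf16_to_f32(a) + bf16_to_f32(b));
}
```

For example, `bf16_add(0x3F80, 0x3F80)` (1.0 + 1.0) yields `0x4000` (2.0). Because f32 carries more than 2p + 2 significant bits for bf16's p = 8, computing a single +, -, *, or / in f32 and rounding once to bf16 matches rounding the exact result directly; a chain of operations kept in f32 without intermediate truncation can differ from true bf16 arithmetic.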


Event Timeline

yiqian1 created this revision. May 25 2022, 9:00 PM
yiqian1 requested review of this revision. May 25 2022, 9:00 PM

A part of the motivation that is missing here (but was mentioned in the RFC) is why do this at this level rather than in LLVM. E.g., type conversions for backends that don't support specific types already exist in LLVM, with all kinds of helpers for how one can generally lower them (for example, by combining the overflow behavior of a compare op with an add-with-carry op). It feels like one would need to reimplement some of those expansions here if this is done at this level. Given that it was mentioned, was this discussed and evaluated?


+1. This pass also seems to take some opinionated stances on how the conversions should take place, which types to convert to, etc. It isn't clear to me why this shouldn't be part of backend legalization/isel.

krzysz00 added a subscriber: rampitec. Edited Jun 1 2022, 3:38 PM

I'll admit to not having looked into the possibility of doing this in LLVM in any particular detail, especially due to my rather limited knowledge of what infrastructure is available for backend legalization.

(I looked around and found some old discussions with @rampitec that I incorrectly remembered as implying that doing this in the backend would be hard, which is probably why I didn't look into it.)

Doing this as an LLVM backend thing may well be the right call, and it's surprising it hasn't already been done, given that the following LLVM IR

define void @test(bfloat* %0, bfloat* %1) {
  %3 = load bfloat, bfloat* %0, align 2
  store bfloat %3, bfloat* %1, align 2
  ret void
}

gives the following results when run through llc -march=x86-64

LLVM ERROR: Cannot select: t8: ch = store<(store (s16) into %ir.1)> t7:1, t7, t4, undef:i64
  t7: bf16,ch = load<(load (s16) from %ir.0)> t0, t2, undef:i64
    t2: i64,ch = CopyFromReg t0, Register:i64 %0
      t1: i64 = Register %0
    t6: i64 = undef
  t4: i64,ch = CopyFromReg t0, Register:i64 %1
    t3: i64 = Register %1
  t6: i64 = undef
In function: test


Supporting this in the backend would mean full SW emulation, which is why it is hard to do (although it seems it could be done via conversions). I'm not sure there are users for that, though: the reason to use bfloat16 is to have a faster, less precise float, and this would be a slower, less precise float.
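To put "less precise" in numbers: bf16 keeps the sign bit and the full 8-bit exponent of f32 (so the same range), but only the top 7 of f32's 23 fraction bits, giving 8 significant bits against f32's 24. Under round-to-nearest the unit roundoff is therefore:

```latex
\[
  u_{\mathrm{bf16}} = 2^{-8} \approx 3.9 \times 10^{-3},
  \qquad
  u_{\mathrm{f32}} = 2^{-24} \approx 6.0 \times 10^{-8}
\]
```

For example, pi in f32 is 0x40490FDB; rounding it to bf16 gives 0x4049 = 3.140625, off by about 9.7e-4.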

I'll quickly note that we are users for this. Fundamentally, we want to generate either GPU kernels or CPU functions that can be used to validate generated code that uses the bfloat-based MFMA instructions.

Since the hardware doesn't know about bfloat except for those instructions, other operations that can be meaningfully defined on bfloat need to be emulated. We can (and currently do) do this at the MLIR level, but comments above have made the point that how to do such emulation is probably a decision hardware backends should be making.
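To illustrate why the emulation strategy is itself a design decision: with bf16 stored as a 16-bit integer, some operations need no f32 round trip at all, while others are simplest via extension. This C sketch is hypothetical (the names and the `uint16_t` representation are assumptions) and is not taken from the pass or from any backend.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <string.h>

/* The bit-cast below assumes float is IEEE-754 binary32. */
static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits");

/* Hypothetical bf16 ops on a uint16_t representation. */

/* Sign-only ops are pure bit manipulation; no f32 round trip needed. */
static uint16_t bf16_neg(uint16_t h) { return (uint16_t)(h ^ 0x8000u); }
static uint16_t bf16_abs(uint16_t h) { return (uint16_t)(h & 0x7fffu); }

/* Exact widening: a bf16 value is the upper half of a binary32. */
static float bf16_to_f32(uint16_t h) {
    uint32_t bits = (uint32_t)h << 16;
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* Ordered less-than: comparing the raw integers would order negative values,
   -0.0 vs +0.0, and NaNs incorrectly, so extend to f32 (exact) and compare
   there. The result is false if either operand is NaN. */
static bool bf16_olt(uint16_t a, uint16_t b) {
    return bf16_to_f32(a) < bf16_to_f32(b);
}
```

Here `bf16_olt(0xBF80, 0x3F80)` (-1.0 < 1.0) is true, whereas an unsigned comparison of the raw bits would say the opposite.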

And while we *could* avoid all this by pulling all this emulation into the matrix multiply lowering path far before LLVM gets involved ... that's ugly and confusing, IMO. I strongly prefer a world where higher-level code can say bfloat to mean bfloat and then have the details of how those ops are actually implemented live somewhere in (or, as in this patch, near) LLVM.

As to the speed question, would folks be unhappy with something like -femulate-bfloat as a backend option?

And at that point should the discussion be taken to the LLVM Discourse?

To re-raise things, what's the right venue for discussing bf16 emulation in LLVM?


I landed fb34d531af953119593be74753b89baf99fbc194 today, which does it for x86; it should
be portable to other targets fairly easily. It works by promoting all bf16 arithmetic to f32.
Converting from bf16 to f32 is expanded inline, but I opted for a libcall for the other direction
because proper handling of rounding, NaNs and denormals is quite tricky.

So all a user has to do now is provide a version of __truncsfbf2 in the runtime environment
and it should just work. Targets can also directly lower the FP_TO_BF16 node if there's a
better way than a libcall.
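A C sketch of what a user-provided __truncsfbf2 has to compute, covering the cases called out above (rounding, NaNs, denormals). The name `truncsfbf2_sketch` and the `uint16_t` return type are assumptions for illustration; the real symbol's signature and calling convention are target- and runtime-specific, so check the target ABI and the runtime library you link against before providing one.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* The bit-cast below assumes float is IEEE-754 binary32. */
static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits");

/* Hypothetical stand-in for __truncsfbf2: f32 -> bf16 bit pattern. */
static uint16_t truncsfbf2_sketch(float a) {
    uint32_t bits;
    memcpy(&bits, &a, sizeof bits);

    /* NaN: plain truncation could turn a NaN into an infinity
       (0x7f800001 -> 0x7f80), so keep the sign and force the quiet bit. */
    if ((bits & 0x7fffffffu) > 0x7f800000u)
        return (uint16_t)((bits >> 16) | 0x0040u);

    /* Round to nearest, ties to even, on the integer bit pattern. A carry out
       of the low half propagates into the exponent, so subnormals round
       correctly and large finite values can round up to infinity. */
    uint32_t lsb = (bits >> 16) & 1u;
    bits += 0x7fffu + lsb;
    return (uint16_t)(bits >> 16);
}
```

The inline bf16-to-f32 direction needs no such care: it is a 16-bit left shift of the bit pattern, which is exact for every input, including NaNs and subnormals.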