
[NVPTX] Implement NVPTXTargetLowering::getSqrtEstimate.
Closed, Public

Authored by jlebar on Jan 9 2017, 7:30 PM.

Details

Summary

This lets us lower to sqrt.approx and rsqrt.approx under more
circumstances.

  • Now we emit sqrt.approx and rsqrt.approx for calls to @llvm.sqrt.f32, when fast-math is enabled. Previously, we would only emit these for calls to @llvm.nvvm.sqrt.f. (With this patch we no longer emit sqrt.approx for calls to @llvm.nvvm.sqrt.f; we rely on instcombine to simplify llvm.nvvm.sqrt.f into llvm.sqrt.f32.)
  • Now we emit the ftz version of rsqrt.approx when ftz is enabled. Previously, we only emitted rsqrt.approx when ftz was disabled.

Event Timeline

jlebar updated this revision to Diff 83769.Jan 9 2017, 7:30 PM
jlebar retitled this revision to [NVPTX] Lower to sqrt.approx and rsqrt.approx under more circumstances..
jlebar updated this object.
jlebar added a reviewer: majnemer.
jlebar added subscribers: tra, llvm-commits.

Can you comment on how this relates to other targets? On x86, AArch64, PPC, and for the AMD GPUs, we have implemented the callback functions getSqrtEstimate and getRecipEstimate to handle generating estimates. The callbacks also specify how many refinement iterations are used to provide answers of approximately the correct precision.

This is important because we allow for user control of the number of refinement steps, where approximations are used, etc. (see Clang's test/Driver/mrecip.c - we implement a generalized version of GCC's -mrecip option). Obviously there are some issues here with separating host vs. accelerator options, but we should try to reuse this infrastructure for NVPTX to the extent possible.

P.S. You might also look at implementing combineRepeatedFPDivisors in NVPTX.

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
941

I don't understand this comment. I thought that f32 sqrt on NVIDIA GPUs was only approximate for early generations (sm_1x) and was correct for later ones (sm_2x+): https://devtalk.nvidia.com/default/topic/475101/sqrt-precision/

jlebar edited reviewers, added: tra; removed: majnemer.Jan 10 2017, 8:40 AM
jlebar added a subscriber: majnemer.

Can you comment on how this relates to other targets? On x86, AArch64, PPC, and for the AMD GPUs, we have implemented the callback functions getSqrtEstimate and getRecipEstimate to handle generating estimates. The callbacks also specify how many refinement iterations are used to provide answers of approximately the correct precision.

Is your thought that because we provide these, we should emit our own such functions rather than emitting the .approx versions?

This being SASS, I'm not totally sure what sqrt.approx.f32 is doing, but here's the SASS for

void f(float a, float* b) { *b = sqrt.approx.f32(a); }

https://gist.github.com/anonymous/79a2a90fd22b0fa37fd3e880641bb9b4. It looks like one iteration of Newton's method?

And here's the equivalent code for rsqrt.approx.f32: https://gist.github.com/0d8881a652e039ee3aff566176d9c98b -- it's the same as above just without a call to recip.
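For reference, the "one iteration of Newton's method" guessed at above would look something like this in scalar code. This is only an illustrative sketch of the standard division-free Newton step for 1/sqrt, not anything decoded from the SASS:

```cpp
#include <cassert>
#include <cmath>

// One division-free Newton-Raphson step refining an estimate y0 of
// 1/sqrt(a); each step roughly doubles the number of correct bits.
float rsqrt_newton_step(float a, float y0) {
  return y0 * (1.5f - 0.5f * a * y0 * y0);
}

// sqrt(a) can then be recovered either as 1/rsqrt(a) (as the sqrt.approx
// SASS appears to do, with its extra reciprocal) or as a * rsqrt(a).
float sqrt_from_rsqrt(float a, float y) { return a * y; }
```

Starting from a deliberately bad estimate, one or two steps visibly tighten the result toward the true rsqrt.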

An unfortunate effect of the fact that we're using ptxas is that we may not be able to match the performance of {r}sqrt.approx with our own implementation in ptx. My thought is we should turn this on in any case, particularly because different GPUs have different machine instructions: a given GPU may be slow with our PTX Newton's method implementation, or may have an explicit {r}sqrt.approx machine instruction. Then if someone wants to characterize the performance difference between this and LLVM's getSqrtEstimate on their particular hw, they can.

Does that sound like a reasonable plan?

At the moment this is all moot because of an llvm bug: if you compile with the UnsafeFPMath TargetOption enabled, then call an intrinsic (or any other function) which doesn't have the unsafe-fp-math attr set, the inliner will explicitly set UnsafeFPMath=false on you. This has the effect of pretty much disabling UnsafeFPMath across the board in CUDA programs (and I expect lots of other programs too).

P.S. You might also look at implementing combineRepeatedFPDivisors in NVPTX.

Thanks, I will look into that. There's already div.approx enabled with fast-math, and I'll need to figure out whether this is preferable.

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
941

This is inverse-sqrt; the PTX rsqrt instruction is always approximate. Per the first para of the comment, we only enable this transformation when both approx sqrt and approx div are enabled. The comment here is saying, maybe we should also enable it when only one of the two is enabled.

If this was confusing to you it's probably not well-written, but I am sort of at a loss as to how to improve it. Suggestions more than welcome.

Can you comment on how this relates to other targets? On x86, AArch64, PPC, and for the AMD GPUs, we have implemented the callback functions getSqrtEstimate and getRecipEstimate to handle generating estimates. The callbacks also specify how many refinement iterations are used to provide answers of approximately the correct precision.

Is your thought that because we provide these, we should emit our own such functions rather than emitting the .approx versions?

This being SASS, I'm not totally sure what sqrt.approx.f32 is doing, but here's the SASS for

void f(float a, float* b) { *b = sqrt.approx.f32(a); }

https://gist.github.com/anonymous/79a2a90fd22b0fa37fd3e880641bb9b4. It looks like one iteration of Newton's method?

And here's the equivalent code for rsqrt.approx.f32: https://gist.github.com/0d8881a652e039ee3aff566176d9c98b -- it's the same as above just without a call to recip.

An unfortunate effect of the fact that we're using ptxas is that we may not be able to match the performance of {r}sqrt.approx with our own implementation in ptx. My thought is we should turn this on in any case, particularly because different GPUs have different machine instructions: a given GPU may be slow with our PTX Newton's method implementation, or may have an explicit {r}sqrt.approx machine instruction. Then if someone wants to characterize the performance difference between this and LLVM's getSqrtEstimate on their particular hw, they can.

Does that sound like a reasonable plan?

When you say, "turn this on", what exactly is "this"? Do you mean the PTX builtin-approximation, or do you mean using our generic CodeGen with one newton iteration, or do you mean using the PTX builtin with our own CodeGen but subtracting the one newton iteration from the count, or something else?

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
938

The way we normally seem to handle this kind of situation, which I think is probably better in this case too, is to transform the target-specific intrinsic into generic IR in InstCombine. We might want to move all of that logic into the targets at some point via TTI, but that's another matter.

941

I missed that this was specifically talking about the inverse. The comment starts by talking about " @llvm.sqrt.f32 and @llvm.nvvm.sqrt.f" which look like regular sqrts. Saying explicitly that, unlike the sqrt, which is or may be exact, the rsqrt is always approximate would help.

When you say, "turn this on", what exactly is "this"? Do you mean the PTX builtin-approximation, or do you mean using our generic CodeGen with one newton iteration, or do you mean using the PTX builtin with our own CodeGen but subtracting the one newton iteration from the count, or something else?

I would like to enable the functionality that this patch is (trying) to enable, namely emitting ptx-builtin-approximate versions of these functions when compiling with fastmath.

The way we normally seem to handle this kind of situation, which I think is probably better in this case too, is to transform the target-specific intrinsic into generic IR in InstCombine.

Thanks for the pointer, I wasn't aware of that, but I agree it's nicer. I'm spinning a patch to use AutoUpgrade to get rid of some nvvm intrinsics entirely, and to use InstCombine to transform other nvvm intrinsics that we can't unconditionally remove into llvm intrinsics, where possible.

But specifically for sqrt (AFAIK) I need to rely on some behavior that's not quite kosher according to the langref. Specifically, I need to transform llvm.nvvm.sqrt.f into

s = llvm.sqrt(arg)
select(arg >= -0.0, s, NaN)

The langref says that llvm.sqrt is undefined behavior if you call it with a negative value, but @arsenm, @mehdi_amini, and I think it should say that it *returns undef* if passed a negative value. arsenm looked at the blame and the language has been unchanged since llvm.sqrt was added. His and @mehdi_amini's GPU backends already assume the "returns undef" semantics, and I have code in flight for XLA to do the same. One of them also did a quick audit and concluded none of our current optimizations use the "undefined behavior" behavior.

Mehdi pointed me at a larger patch to fix up this and other math intrinsics which I now can't find (sorry), but maybe we can make this smaller change in parallel?

WDYT?
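The select-based rewrite above can be sketched in scalar code as follows (an illustration only, not the actual InstCombine code; it assumes, as the transform implies, that llvm.nvvm.sqrt.f returns NaN for inputs below -0.0):

```cpp
#include <cassert>
#include <cmath>

// Scalar model of the proposed rewrite: compute llvm.sqrt, then select
// NaN when arg < -0.0 to restore the nvvm intrinsic's semantics.
float nvvm_sqrt_lowered(float arg) {
  float s = std::sqrt(arg);                 // s = llvm.sqrt(arg)
  return arg >= -0.0f ? s : std::nanf("");  // select(arg >= -0.0, s, NaN)
}
```

Note that `arg >= -0.0f` is true for both signed zeros, so sqrt(-0.0) still flows through the first arm.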

Can you comment on how this relates to other targets? On x86, AArch64, PPC, and for the AMD GPUs, we have implemented the callback functions getSqrtEstimate and getRecipEstimate to handle generating estimates. The callbacks also specify how many refinement iterations are used to provide answers of approximately the correct precision.

Is your thought that because we provide these, we should emit our own such functions rather than emitting the .approx versions?

This being SASS, I'm not totally sure what sqrt.approx.f32 is doing, but here's the SASS for

void f(float a, float* b) { *b = sqrt.approx.f32(a); }

https://gist.github.com/anonymous/79a2a90fd22b0fa37fd3e880641bb9b4. It looks like one iteration of Newton's method?

I don't see any Newton-Raphson there, but I'm not used to reading SASS. It seems to me that:

  1. Check if the input is a denormal to anticipate underflow: FSETP.LT.AND P0, PT, |R0|, 1.175494350822287508e-38, PT;
  2. Range-up: FMUL R0, R0, 16777216;
  3. If it was a denormal, use the upscaled value, otherwise use the original: SEL R0, R0, c[0x0][0x140], P0;

(I have no idea what the load to R3 is doing here...)

  4. Reciprocal square-root (may go to infinity or NaN with denormal depending on the implementation, hence the scaling above?): MUFU.RSQ R0, R0;
  5. Reciprocal, to get the sqrt result: MUFU.RCP R0, R0;
  6. Scale-down (only if the input was a denormal): @P0 FMUL R0, R0, 0.000244140625;
  7. Store the result.

So without any Newton iteration, I don't think this can be IEEE compliant, which is expected considering that it is the approx version. This is fine with fast-math though (I'm fairly sure some graphics shader compilers wouldn't even care about handling the denormals...).

What does the SASS look like for the non-approx version? (I believe llvm.sqrt should do the same as the non-approx without the fast-math flag.)
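A scalar sketch of the denormal rescaling decoded above, with library calls as hypothetical stand-ins for MUFU.RSQ/MUFU.RCP. The constants 2^24 and 2^-12 are the ones visible in the SASS, and they cancel because sqrt(x * 2^24) == sqrt(x) * 2^12:

```cpp
#include <cassert>
#include <cfloat>
#include <cmath>

// Denormal inputs are scaled up into the normal range before the hardware
// approximation (which may misbehave on denormals), then the result is
// scaled back down: sqrt(x * 2^24) == sqrt(x) * 2^12, hence the 2^-12.
float sqrt_with_denormal_rescale(float x) {
  bool denorm = std::fabs(x) < FLT_MIN;      // FSETP.LT.AND |x|, 2^-126
  float v = denorm ? x * 16777216.0f : x;    // FMUL 2^24 (+ SEL)
  float r = 1.0f / std::sqrt(v);             // MUFU.RSQ stand-in
  float s = 1.0f / r;                        // MUFU.RCP stand-in
  return denorm ? s * 0.000244140625f : s;   // @P0 FMUL 2^-12
}
```

Both the normal path and the rescaled denormal path land on the same answer (up to rounding of the two reciprocals).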

An unfortunate effect of the fact that we're using ptxas is that we may not be able to match the performance of {r}sqrt.approx with our own implementation in ptx.

Can you clarify what you mean?

Can you comment on how this relates to other targets? On x86, AArch64, PPC, and for the AMD GPUs, we have implemented the callback functions getSqrtEstimate and getRecipEstimate to handle generating estimates. The callbacks also specify how many refinement iterations are used to provide answers of approximately the correct precision.

Is your thought that because we provide these, we should emit our own such functions rather than emitting the .approx versions?

This being SASS, I'm not totally sure what sqrt.approx.f32 is doing, but here's the SASS for

Ah, thanks for figuring it out! Of course MUFU.RSQ is a reciprocal sqrt instruction. Maybe we're loading something into R3 as preparation for the MUFU.RSQ, otherwise that just seems like dumb codegen.

So without any Newton iteration, I don't think this can be IEEE compliant, which is expected considering that it is the approx version. This is fine with fast-math though (I'm fairly sure some graphics shader compilers wouldn't even care about handling the denormals...).

Right.

What does the SASS look like for the non-approx version? (I believe llvm.sqrt should do the same as the non-approx without the fast-math flag.)

https://gist.github.com/b3fa71a72a02785cc47be606556d6d4a

An unfortunate effect of the fact that we're using ptxas is that we may not be able to match the performance of {r}sqrt.approx with our own implementation in ptx.

Can you clarify what you mean?

I meant that I wasn't sure whether we could generate code which matched the performance+accuracy of PTX sqrt.approx without using that instruction (e.g. by LLVM emitting a Newton's method hunk). In particular, now that you parsed the asm -- we see here that we're calling a special HW instruction for the rsqrt, and I have no way to cause this instruction to be emitted except by writing the PTX sqrt.approx.

Unless the suggestion is to take the approx sqrt generated by PTX sqrt.approx and then refine it using Newton's method? That's an interesting idea but out of scope for this patch, I think. I'd rather wait to do that until someone wants it.

What does the SASS look like for the non-approx version? (I believe llvm.sqrt should do the same as the non-approx without the fast-math flag.)

https://gist.github.com/b3fa71a72a02785cc47be606556d6d4a

Wow...

An unfortunate effect of the fact that we're using ptxas is that we may not be able to match the performance of {r}sqrt.approx with our own implementation in ptx.

Can you clarify what you mean?

I meant that I wasn't sure whether we could generate code which matched the performance+accuracy of PTX sqrt.approx without using that instruction (e.g. by LLVM emitting a Newton's method hunk). In particular, now that you parsed the asm -- we see here that we're calling a special HW instruction for the rsqrt, and I have no way to cause this instruction to be emitted except by writing the PTX sqrt.approx.

I bet you can get the rsqrt with the PTX rsqrt.approx.f32, but that's not really the point here.

Unless the suggestion is to take the approx sqrt generated by PTX sqrt.approx and then refine it using Newton's method? That's an interesting idea but out of scope for this patch, I think. I'd rather wait to do that until someone wants it.

Technically I don't think it is correct for your patch to lower llvm.sqrt (with the FMF) to PTX sqrt.approx, because "The maximum absolute error for sqrt.f32 is TBD."
llvm.sqrt should get roughly the same result as the last link you gave.
I don't know how much it matters for users of the PTX backend though.

Technically I don't think it is correct for your patch to lower llvm.sqrt (with the FMF) to PTX sqrt.approx, because "The maximum absolute error for sqrt.f32 is TBD."

The patch only does this transformation with fastmath enabled (or if you pass a special flag to llvm that specifically asks for this transformation):

defm FSQRT_f32_approx_ftz :
  FSQRT_f32<"approx.ftz.", [doF32FTZ, do_SQRTF32_APPROX]>;
defm FSQRT_f32_approx : FSQRT_f32<"approx.", [do_SQRTF32_APPROX]>;
defm FSQRT_f32_ftz : FSQRT_f32<"rn.ftz.", [doF32FTZ]>;
defm FSQRT_f32_noftz : FSQRT_f32<"rn.", []>;

Surely fastmath implies we should be lowering to the approx instruction, no?

When fastmath is disabled, we lower to PTX sqrt.rn.f32, which is spec'ed to be exact.

I agree the commit message should be clearer. :) I think I was trying to say, now we will *under some circumstances* emit sqrt.approx.f32 for llvm.sqrt.

Technically I don't think it is correct for your patch to lower llvm.sqrt (with the FMF) to PTX sqrt.approx, because "The maximum absolute error for sqrt.f32 is TBD."

The patch only does this transformation with fastmath enabled (or if you pass a special flag to llvm that specifically asks for this transformation):

Ah, good :)

I didn't read the patch, just the discussion, and misunderstood.

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
932

I have to say that I find it amazing that someone is finally documenting the backend TableGen; this has always driven me crazy :)

The patch only does this transformation with fastmath enabled (or if you pass a special flag to llvm that specifically asks for this transformation):

Hal, maybe this was your point of confusion as well, and if so, I'm sorry about that. Really bad commit message on my part.

jlebar updated this object.Jan 12 2017, 8:56 PM
tra accepted this revision.Jan 13 2017, 1:16 PM
tra edited edge metadata.

One small nit, looks good otherwise.

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
767–775

Nit. I'd use the period as a prefix everywhere, which is closer to the way types and instruction options are used (almost) everywhere else.
I.e.:

!strconcat("sqrt", Options, ".f32 \t$dst, $a;"),
...
defm FSQRT_f32_ftz : FSQRT_f32<".rn.ftz", [doF32FTZ]>;
This revision is now accepted and ready to land.Jan 13 2017, 1:16 PM

I'm spinning a patch to use AutoUpgrade to get rid of some nvvm intrinsics entirely, and to use InstCombine to transform other nvvm intrinsics that we can't unconditionally remove into llvm intrinsics, where possible.

This is coming along, but I think it makes sense to leave the llvm.nvvm.sqrt.f special case in this patch.

The "upgrade patch" will upgrade llvm.nvvm.sqrt.f to llvm.sqrt.f32. Therefore it needs this patch, which adds additional patterns around llvm.sqrt.f32. But we can only get rid of the pattern matching 1/llvm.nvvm.sqrt.f after the upgrade patch lands.

jlebar updated this revision to Diff 84370.Jan 13 2017, 1:56 PM
jlebar edited edge metadata.

Update per tra's comments, and try to make the heading on the rsqrt section clearer.

jlebar marked an inline comment as done.Jan 13 2017, 2:55 PM

...

But specifically for sqrt (AFAIK) I need to rely on some behavior that's not quite kosher according to the langref. Specifically, I need to transform llvm.nvvm.sqrt.f into

s = llvm.sqrt(arg)
select(arg >= -0.0, s, NaN)

The langref says that llvm.sqrt is undefined behavior if you call it with a negative value, but @arsenm, @mehdi_amini, and I think it should say that it *returns undef* if passed a negative value. arsenm looked at the blame and the language has been unchanged since llvm.sqrt was added. His and @mehdi_amini's GPU backends already assume the "returns undef" semantics, and I have code in flight for XLA to do the same. One of them also did a quick audit and concluded none of our current optimizations use the "undefined behavior" behavior.

Mehdi pointed me at a larger patch to fix up this and other math intrinsics which I now can't find (sorry), but maybe we can make this smaller change in parallel?

WDYT?

Sounds reasonable to me.

The patch only does this transformation with fastmath enabled (or if you pass a special flag to llvm that specifically asks for this transformation):

Hal, maybe this was your point of confusion as well, and if so, I'm sorry about that. Really bad commit message on my part.

No, I understood that -- I only really care about sqrt and friends when fast-math is enabled. Otherwise the performance is horrible essentially everywhere. ;)

No, I understood that -- I only really care about sqrt and friends when fast-math is enabled. Otherwise the performance is horrible essentially everywhere. ;)

Heh, well, I expect this to be a strict improvement with fast-math enabled. :)

I'll check this in when I can. Thanks for your reviews and attention, everyone.

Alright, let me try to summarize what we want to do here...

  1. Update the LangRef to say that llvm.sqrt returns undef when provided with an input < -0 (instead of having undefined behavior).
  2. In InstCombine, transform llvm.nvvm.sqrt.f into

    s = llvm.sqrt(arg); select(arg >= -0.0, s, NaN);

    unless we have 'fast' set on the intrinsic, in which case we don't need the select (and we set 'fast' on the llvm.sqrt).
  3. When we get to the backend:
     a. We pattern match the select + llvm.sqrt (no 'fast') into the regular sqrt call.
     b. We otherwise transform llvm.sqrt (no 'fast') into the regular sqrt call.
     c. We transform select + llvm.sqrt (with 'fast'), or just llvm.sqrt (with 'fast'), into some approximation.
  4. For the approximation, we can use:
     a. The approximation which PTX provides (sqrt.approx.f32).
     b. The approximation which PTX provides, potentially with some extra Newton iterations.
     c. Our own approximation formed using rsqrt.approx.f32 (the generic code in DAGCombine should do this automatically using x*rsqrt(x)).

We might not want the cost of the extra denormal handling in the PTX-provided approximation, but if we lower the nvvm intrinsic in the same way as the regular sqrt intrinsic, then are we forced to do this (it has fixed semantics)? If not (and, given that you want to transform 1/nvvm.sqrt -> rsqrt, I assume you're comfortable with an answer of no), then we should consider other options.

We should provide the option of generating newton iteration fixups, as other targets do, as a user-configurable option (using our target-independent infrastructure for this purpose). The natural way of doing that is just to let the code in DAGCombine form the sqrt approximation from x*rsqrt(x), which is probably faster than the PTX "builtin" approximation anyway (because it lacks the denormal fixups). I think it is definitely worth an experiment (i.e. is x*rsqrt faster than sqrt?).

Regardless, we should find a way to hook this up to our target-independent infrastructure which allows the user to select where to generate approximations and how many newton iterations to use. To do this, we should implement the associated TLI callbacks instead of directly matching the patterns in the TableGen files.

We might not want the cost of the extra denormal handling in the PTX-provided approximation, but if we lower the nvvm intrinsic in the same way as the regular sqrt intrinsic, then are we forced to do this (it has fixed semantics)?

The SASS for sqrt.approx.f32 and rsqrt.approx.f32 are identical except for the presence of an additional reciprocal call in sqrt.approx.f32. They both have special cases for denormals.

https://gist.github.com/anonymous/79a2a90fd22b0fa37fd3e880641bb9b4
https://gist.github.com/anonymous/0d8881a652e039ee3aff566176d9c98b

If you want a version without the denormal handling, you want sqrt.approx.f32.ftz and rsqrt.approx.f32.ftz. These look like:

https://gist.github.com/031ea494458f44e2d1ef4e16eec51699
https://gist.github.com/b352a20d18792d05395f76ab3aad742f

(Unlike the ones above, these are not identical except for the presence/absence of a reciprocal call. But they're nonetheless very close.)

With this patch we lower the "fast" version of llvm.sqrt.f32 to sqrt.approx.f32 or sqrt.approx.f32.ftz as appropriate based on the module's configuration, and same for 1.0/llvm.sqrt.f32 going to rsqrt.approx.f32{.ftz}.

This patch lets us treat llvm.nvvm.sqrt.f the same as llvm.sqrt.f32, but I'm happy to get rid of this once we auto-upgrade the nvvm intrinsic to the generic intrinsic (which we can't do until this patch lands because that would regress performance).

Right now ftz is specified in one of three ways (see NVVMReflect.cpp), but I have a patch out to reduce it to just one way. That will then make it easier for us to switch it to one canonical way, instead of the janky "__CUDA_FTZ" metadata thing we have now.

[...] The natural way of doing that is just to let the code in DAGCombine form the sqrt approximation from x*rsqrt(x), which is probably faster than the PTX "builtin" approximation anyway (because it lacks the denormal fixups). I think it is definitely worth an experiment (i.e. is x*rsqrt faster than sqrt?).

The denormal behaviors of the reciprocal and non-reciprocal versions of sqrt.approx.f32{.ftz} are all the same as far as I can tell.

I'm happy to entertain the addition of more exotic (r)sqrt approximations controlled by flags or whatever, but this is likely not something my customers care deeply about.

We might not want the cost of the extra denormal handling in the PTX-provided approximation, but if we lower the nvvm intrinsic in the same way as the regular sqrt intrinsic, then are we forced to do this (it has fixed semantics)?

The SASS for sqrt.approx.f32 and rsqrt.approx.f32 are identical except for the presence of an additional reciprocal call in sqrt.approx.f32. They both have special cases for denormals.

https://gist.github.com/anonymous/79a2a90fd22b0fa37fd3e880641bb9b4
https://gist.github.com/anonymous/0d8881a652e039ee3aff566176d9c98b

If you want a version without the denormal handling, you want sqrt.approx.f32.ftz and rsqrt.approx.f32.ftz. These look like:

https://gist.github.com/031ea494458f44e2d1ef4e16eec51699
https://gist.github.com/b352a20d18792d05395f76ab3aad742f

(Unlike the ones above, these are not identical except for the presence/absence of a reciprocal call. But they're nonetheless very close.)

With this patch we lower the "fast" version of llvm.sqrt.f32 to sqrt.approx.f32 or sqrt.approx.f32.ftz as appropriate based on the module's configuration, and same for 1.0/llvm.sqrt.f32 going to rsqrt.approx.f32{.ftz}.

This patch lets us treat llvm.nvvm.sqrt.f the same as llvm.sqrt.f32, but I'm happy to get rid of this once we auto-upgrade the nvvm intrinsic to the generic intrinsic (which we can't do until this patch lands because that would regress performance).

Right now ftz is specified in one of three ways (see NVVMReflect.cpp), but I have a patch out to reduce it to just one way. That will then make it easier for us to switch it to one canonical way, instead of the janky "__CUDA_FTZ" metadata thing we have now.

[...] The natural way of doing that is just to let the code in DAGCombine form the sqrt approximation from x*rsqrt(x), which is probably faster than the PTX "builtin" approximation anyway (because it lacks the denormal fixups). I think it is definitely worth an experiment (i.e. is x*rsqrt faster than sqrt?).

The denormal behaviors of the reciprocal and non-reciprocal versions of sqrt.approx.f32{.ftz} are all the same as far as I can tell.

I'm happy to entertain the addition of more exotic (r)sqrt approximations controlled by flags or whatever, but this is likely not something my customers care deeply about.

Okay, thanks for checking. If there's no performance benefit to generating x*rsqrt.approx(x) over sqrt.approx(x), then there's no need to consider it.

However, many of us have customers who do tend to care about being able to adjust the accuracy of the approximations used for sqrt (and reciprocals for division). I certainly do. If my Googling has been sufficiently informative, these approximations generally have a maximum numerical error of ~2-3 ULP. This will be good enough for many purposes, but prohibitively large for some others. Because this is a generic property of hardware-provided approximations, our target-independent codegen has an infrastructure for generating refinement code of user-configurable expense. Using this should be considered the best practice for backend construction. Specifically, unless the provided approximations have a maximum error <= 1 ULP, refinement might be desirable, and the backend should never directly match the sqrt intrinsic (etc.) in TableGen. Instead, the backend should always implement the associated TLI callbacks so that the user can independently control which approximations are used, and how aggressively (if at all) these are refined. For NVPTX, I'd expect the default number of refinement iterations to be set to zero.

I realize that I'm asking for a little extra work here, but this is the right way to implement this support.

That having been said, you still need part of this patch, AFAICT, because all of the infrastructure to which I'm referring is in DAGCombine, and we can't depend on DAGCombine to lower otherwise-legal instructions. So you do need to match sqrt, regardless of anything else. You should not, however, need to add the 1/sqrt -> rsqrt patterns; once the TLI callbacks are implemented, DAGCombine should do that for you. Thanks again!
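To make the "~2-3 ULP" figure concrete, the ULP error between an approximate and an exact float result can be measured by mapping floats onto a monotonic integer scale. This is the standard bit-twiddling trick, sketched here for illustration; it is not code from this patch:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>

// Map a float to an integer key such that adjacent representable floats
// map to adjacent integers (and -0.0 and +0.0 coincide at key 0).
int64_t float_key(float f) {
  int32_t i;
  std::memcpy(&i, &f, sizeof i);
  return i < 0 ? int64_t{INT32_MIN} - i : int64_t{i};
}

// ULP distance: the number of representable floats between a and b.
int64_t ulp_distance(float a, float b) {
  return std::llabs(float_key(a) - float_key(b));
}
```

With this, "maximum error <= 1 ULP" means `ulp_distance(approx, exact)` never exceeds 1 over the input domain.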

jlebar updated this revision to Diff 84632.Jan 16 2017, 11:08 PM

Update to use getSqrtEstimate, per Hal's suggestions.

jlebar retitled this revision from [NVPTX] Lower to sqrt.approx and rsqrt.approx under more circumstances. to [NVPTX] Implement NVPTXTargetLowering::getSqrtEstimate..Jan 16 2017, 11:09 PM
jlebar edited the summary of this revision. (Show Details)
jlebar edited reviewers, added: hfinkel; removed: tra.
This revision now requires review to proceed.Jan 16 2017, 11:09 PM

Thank you for the explanation, Hal. I'd misunderstood what you were asking for -- your request was, in the end, quite reasonable, and indeed was better than what I had.

Please let me know what you think here.

This relies on D28794 to upgrade llvm.nvvm.sqrt.f to llvm.sqrt.f32, which then allows this "emit approx sqrt" function to fire.

This looks good, although please add a test case for sqrtf() where we've specified a non-zero number of extra steps. I ask because, unfortunately, I think it won't work correctly. The code in DAGCombiner::buildSqrtEstimateImpl, when some number of extra iterations are requested, calls buildSqrtNROneConst (or buildSqrtNRTwoConst), and these seem to assume that the incoming base estimate is always a reciprocal estimate, and so they always need to multiply by x so that they compute sqrt(x) = x*rsqrt(x).

To refine an estimate for sqrt, instead of using the iteration X_{i+1} = X_i (1.5 - A X_i^2 / 2), you'd use X_{i+1} = 0.5(X_i + A/X_i). The problem here is that you need to divide to make this work, and that defeats the purpose of having a fast approximation. As a result, the apparently "hacky" solution here is actually the best: When extra iterations are requested, the function should return the rsqrt estimate instead of the sqrt estimate. I think we just need to clean up the interface a bit to make all of this really make sense.

This seems to indicate that, if I'm reading the code correctly, we have a bug here for other backends for sqrt() when no extra iterations are requested (it will directly return the rsqrt estimate instead of the proper sqrt estimate). The fact that we might return a sqrt estimate directly, when no iterations are requested, seems like a special case. Maybe the best thing is to always return rsqrt here, but when no iterations are requested for sqrt, just return SDValue() and match the sqrt later in TableGen?
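The two refinement schemes being contrasted can be sketched as follows (illustrative scalar code only):

```cpp
#include <cassert>
#include <cmath>

// Heron's iteration refines a sqrt estimate directly, but needs a divide
// per step, which defeats the purpose of a fast approximation:
//   X_{i+1} = 0.5 * (X_i + A / X_i)
float heron_step(float a, float x) { return 0.5f * (x + a / x); }

// The division-free Newton iteration refines an *rsqrt* estimate instead,
// which is why the refinement code wants a reciprocal estimate as input
// and multiplies by A at the end to recover sqrt(A) = A * rsqrt(A):
//   X_{i+1} = X_i * (1.5 - 0.5 * A * X_i^2)
float rsqrt_step(float a, float x) { return x * (1.5f - 0.5f * a * x * x); }
```

Both iterations converge; the point is that only the second avoids a division, at the cost of working on rsqrt rather than sqrt.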

This seems to indicate that, if I'm reading the code correctly, we have a bug here for other backends for sqrt() when no extra iterations are requested (it will directly return the rsqrt estimate instead of the proper sqrt estimate)

Wow, what an amazingly broken API.

How's this for a plan: Let me implement this function correctly for NVPTX and check in a test. Then I will at the very least improve the comment on the virtual function and try to contact the backend owners. I would normally volunteer to refactor the API, but it seems that nobody is testing it, otherwise they would have realized that it is broken, so I'm pretty hesitant to go after it myself.

jlebar updated this revision to Diff 86346.Jan 30 2017, 2:23 PM

Implement getSqrtEstimate API correctly.

Also lower approx sqrt(x) as x * rsqrt(x) when x is a double, since that was easy.

All right, I think this should work. I verified by hand that the approximation routines do something sane -- namely, that when the error is zero, we get the intended exact result. I don't have the numerical chops to say whether it's actually a refinement or not... :)

This seems to indicate that, if I'm reading the code correctly, we have a bug here for other backends for sqrt() when no extra iterations are requested (it will directly return the rsqrt estimate instead of the proper sqrt estimate)

Wow, what an amazingly broken API.

How's this for a plan: Let me implement this function correctly for NVPTX and check in a test. Then I will at the very least improve the comment on the virtual function and try to contact the backend owners. I would normally volunteer to refactor the API, but it seems that nobody is testing it, otherwise they would have realized that it is broken, so I'm pretty hesitant to go after it myself.

The most natural way I can think of to fix this API is to make it always expect a rsqrt estimate (including in the ExtraIterations == 0 case). In that case, if a target has a direct sqrt estimate to use in the ExtraIterations == 0 case, it would return SDValue() and then match the sqrt node in TableGen (or in C++). That fix is incompatible with the code here, so I'd like your opinion on this.

jlebar added a comment.EditedJan 30 2017, 6:59 PM

The most natural way I can think of to fix this API is to make it always expect a rsqrt estimate (including in the ExtraIterations == 0 case). In that case, if a target has a direct sqrt estimate to use in the ExtraIterations == 0 case, it would return SDValue() and then match the sqrt node in TableGen (or in C++). That fix is incompatible with the code here, so I'd like your opinion on this.

I wonder if part of the problem is that the API is just too flexible. In particular, what if we didn't let you modify NumIters within this function?

We would have three or four TLI functions:

  • "get me the default number of Newton's method iters for an (r)sqrt approximation" (could be one method with a bool param or two; I lean towards two because http://jlebar.com/2011/12/16/Boolean_parameters_to_API_functions_considered_harmful..html)
  • "get me an approximate rsqrt"
  • "get me an approximate sqrt", where the default returns x * approx_rsqrt(x).

The problem right now is that because the function may modify NumIters, everything has to be stuffed into one big, confusing function.

We don't strictly need "get me an approx sqrt", but defining it this way lets us have a reasonable default, which seems nice.

(I would still rather check in this patch and refactor separately, though.)
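That split could be sketched roughly as follows (hypothetical names, and scalar doubles standing in for SDValues; this is not the real TargetLowering interface):

```cpp
#include <cassert>
#include <cmath>

// Sketch of the proposed interface split: the iteration counts live in
// their own hooks, so the estimate hooks never need to modify NumIters.
struct RecipEstimateSketch {
  // "get me the default number of Newton's-method iterations".
  virtual int getSqrtRefinementSteps() const { return 2; }
  virtual int getRsqrtRefinementSteps() const { return 2; }

  // "get me an approximate rsqrt": the one hook a target must provide.
  virtual double getRsqrtEstimate(double X) const = 0;

  // "get me an approximate sqrt", defaulting to x * approx_rsqrt(x); a
  // target with a direct sqrt.approx instruction would override this.
  virtual double getSqrtEstimate(double X) const {
    return X * getRsqrtEstimate(X);
  }
  virtual ~RecipEstimateSketch() = default;
};

// A toy target whose rsqrt "estimate" is exact, just for illustration.
struct ToyTarget : RecipEstimateSketch {
  double getRsqrtEstimate(double X) const override {
    return 1.0 / std::sqrt(X);
  }
};
```

A target then overrides only what it has hardware for; the default sqrt path reuses the rsqrt hook, which is the "reasonable default" mentioned above.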

hfinkel accepted this revision.Jan 30 2017, 7:54 PM

The most natural way I can think of to fix this API is to make it always expect a rsqrt estimate (including in the ExtraIterations == 0 case). In that case, if a target has a direct sqrt estimate to use in the ExtraIterations == 0 case, it would return SDValue() and then match the sqrt node in TableGen (or in C++). That fix is incompatible with the code here, so I'd like your opinion on this.

I wonder if part of the problem is that the API is just too flexible. In particular, what if we didn't let you modify NumIters within this function?

We would have three or four TLI functions:

  • "get me the default number of Newton's method iters for an (r)sqrt approximation" (could be one method with a bool param or two; I lean towards two because http://jlebar.com/2011/12/16/Boolean_parameters_to_API_functions_considered_harmful..html)
  • "get me an approximate rsqrt"
  • "get me an approximate sqrt", where the default returns x * approx_rsqrt(x).

    The problem right now is that because the function may modify NumIters, everything has to be stuffed into one big, confusing function.

    We don't strictly need "get me an approx sqrt", but defining it this way lets us have a reasonable default, which seems nice.

    (I would still rather check in this patch and refactor separately, though.)

I think this makes sense. Let's commit this and then refactor/fix things. Thanks!

This revision is now accepted and ready to land.Jan 30 2017, 7:54 PM
This revision was automatically updated to reflect the committed changes.

All right, thank you, Hal!

Looking through this code, we don't even parse the mrecip string properly. For example, all,!sqrt should enable reciprocal approximations for everything other than sqrt, but as I read TargetLoweringBase::getOpEnabled, it enables it for nothing. :(
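The intended semantics are easy to state: "all" sets a baseline and a later "!sqrt" carves out an exception. A minimal standalone parser with those semantics might look like this (a sketch, not the actual TargetLoweringBase::getOpEnabled code):

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Intended -mrecip composition: a specific entry for Op, if present,
// wins; otherwise fall back to whatever "all"/"!all" last said. So
// "all,!sqrt" enables reciprocal approximations for everything except
// sqrt, rather than disabling everything.
static bool isRecipEnabled(const std::string &Options, const std::string &Op) {
  bool All = false, Enabled = false, Seen = false;
  std::stringstream SS(Options);
  std::string Tok;
  while (std::getline(SS, Tok, ',')) {
    bool Negated = !Tok.empty() && Tok[0] == '!';
    std::string Name = Negated ? Tok.substr(1) : Tok;
    if (Name == "all")
      All = !Negated;          // baseline for unmentioned operations
    else if (Name == Op) {
      Seen = true;             // a specific entry overrides the baseline
      Enabled = !Negated;
    }
  }
  return Seen ? Enabled : All;
}
```

Under these semantics, "all,!sqrt" reports sqrt disabled but div enabled, which is what the option text promises.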

escha added a subscriber: escha.Jan 30 2017, 11:46 PM

afaik, x * rsqrt(x) is wrong when x is zero (it gives NaN instead of 0). we use x * rsqrt(x) for our expansion, but we have to use an extra select_cc to handle the zero special case.

afaik, x * rsqrt(x) is wrong when x is zero (it gives NaN instead of 0). we use x * rsqrt(x) for our expansion, but we have to use an extra select_cc to handle the zero special case.

Gosh darnit, I even found this bug in something else a few weeks ago, and I completely forgot here. I think I was distracted by a signed-zeroes nonissue (because we're already fast-math).

Thank you, will fix in the morning.
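The zero-input hazard is easy to demonstrate with scalar C++ (using 1/sqrt as a stand-in for rsqrt.approx; hypothetical helper names, not the backend code):

```cpp
#include <cassert>
#include <cmath>

// rsqrt(0) is +inf, and 0 * inf is NaN, so the bare x * rsqrt(x)
// expansion returns NaN for x == 0 instead of 0.
static float sqrtViaRsqrtNaive(float X) {
  float Rsqrt = 1.0f / std::sqrt(X); // stand-in for rsqrt.approx
  return X * Rsqrt;                  // NaN when X == 0
}

// The fix escha describes: an extra select to special-case zero.
static float sqrtViaRsqrtGuarded(float X) {
  float Rsqrt = 1.0f / std::sqrt(X);
  return X == 0.0f ? 0.0f : X * Rsqrt; // the extra select_cc
}
```

Note that the rcp(rsqrt(x)) expansion adopted later in the thread sidesteps this case, since rcp(+inf) is 0 without any select.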

escha added a comment.Jan 31 2017, 6:49 AM

Don't be too embarrassed; when we switched internally from an rcp(rsqrt(x)) expansion to x * rsqrt(x), we *also* completely missed this, and literally only found the bug when it showed up as internal test failures. The new expansion was even brought up and talked about at a team meeting and signed off by multiple people, and nobody thought to consider the number zero, myself included.

Don't be too embarrassed; when we switched internally from an rcp(rsqrt(x)) expansion to x * rsqrt(x), we *also* completely missed this.

:) I switched to rcp(rsqrt(x)) in r293713. I tested and it's significantly faster than x*rsqrt(x), to say nothing of adding an extra select around that.

Thanks again for catching this.

escha added a comment.Jan 31 2017, 3:31 PM

That really surprises me that it's faster! I would expect SFU functions like RCP/RSQRT to dwarf the cost of a multiply, especially for double.

Also, do be careful that rcp(rsqrt(x)) and x * rsqrt(x) have different precisions under some implementations (because fmul is 0.5 ULP, while rcp/rsqrt may be as low as 2.5 ULP each).

jlebar added a comment.EditedJan 31 2017, 3:43 PM

That really surprises me that it's faster! I would expect SFU functions like RCP/RSQRT to dwarf the cost of a multiply, especially for double.

Me too. :)

Also, do be careful that rcp(rsqrt(x)) and x * rsqrt(x) have different precisions under some implementations (because fmul is 0.5 ULP, while rcp/rsqrt may be as low as 2.5 ULP each).

Yeah, I'm banking on the "you asked for it" aspect of fast-math. In particular, the only approximate f64 rcp instruction is flush-denormals-to-zero, so we call that even if ftz is entirely disabled.

The performance difference is the same with and without ftz on the mul:

precise sqrt - 73us
x*rsqrt.approx(x) - 64us
recip.approx(rsqrt.approx(x)) - 48us
rsqrt.approx(x) - 48us

Maybe it's an unfair microbenchmark, because I do nothing other than the sqrt and a store. https://gist.github.com/0ac6f0b0f994339838f5452f96e77cff