This patch improves the lowering by changing the target LLVM intrinsics from reduce.fmax and reduce.fmin, which have different semantics for handling NaN, to reduce.fmaximum and reduce.fminimum.
Fixes #63969
Depends on D155869
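As a rough illustration (a hedged sketch; the reduction syntax and the LLVM dialect spellings of the intrinsics are assumed and may differ across MLIR versions), the change affects the lowering of a floating-point vector reduction roughly as follows:

```mlir
// Input: a floating-point max reduction in the vector dialect.
%red = vector.reduction <maxf>, %v : vector<4xf32> into f32

// Before: lowered to reduce.fmax, which ignores NaN operands and returns
// the other (non-NaN) element.
%old = llvm.intr.vector.reduce.fmax(%v) : (vector<4xf32>) -> f32

// After: lowered to reduce.fmaximum, which propagates NaN whenever any
// element of the input is NaN.
%new = llvm.intr.vector.reduce.fmaximum(%v) : (vector<4xf32>) -> f32
```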
Differential D155877: [mlir][vector] Improve lowering to LLVM for `minf`, `maxf` reductions
Authored by unterumarmung on Jul 20 2023, 12:17 PM.
Event Timeline

Thanks for looking into this!
Thank you for providing such a clear explanation! Now I understand that the issue goes beyond just lowering to LLVM IR itself. I have a few questions about the points you mentioned:
I had quickly tested x86 and AArch64, but let's run more extensive testing that also includes RISC-V: https://github.com/openxla/iree/pull/14472
MLIR allows breaking changes. In this case they would be well justified. We usually send a PSA to Discourse in advance, letting the community know...

This change causes downstream breakages. On CUDA backends it generates an instruction that is valid only on SM80 and above (and therefore causes a crash on anything below SM80). Posting this here as an FYI. Looking at the patch, this is probably just exposing an issue in the NVPTX backend rather than anything here being the root cause.

It could be that the new intrinsics are not supported/implemented in the NVPTX backend for some cases? I think we can revert it if that is the case and something is filed against the NVPTX backend. Otherwise, this will become a blocker for a much larger effort: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671

I believe that the issue lies within the NVPTX backend, and it should be addressed as a bug. If this change is causing significant issues and there isn't a straightforward solution to the bug, we can definitely consider reverting it. Additionally, I'd like to request @mravishankar to review the RFC mentioned, particularly from the NVPTX backend standpoint. Your insights would be greatly appreciated.

I filed this issue against the NVPTX backend: https://github.com/llvm/llvm-project/issues/64606. I'll take a look at the RFC, but my knowledge of NVPTX is pretty dated at this point. This comes down to whether these instructions are supported or not. It looks like they are only supported on SM80 and above. If I don't hear back, then instead of reverting the change, maybe refactor it so that the use of llvm.intr.minimum can be controlled by downstream users.

@unterumarmung and @dcaballe, based on the discussion here (https://github.com/llvm/llvm-project/issues/64606), this patch either needs to be reverted or needs a way to opt in/out so that CUDA backends don't hit the issue... Maybe create a separate entry point which populates these patterns, or add a flag that enforces NaN propagation semantics, true by default but settable to false on CUDA backends. This is breaking downstream tests, so a fix sooner rather than later would be appreciated.

Actually, it seems like we have to make this opt-in somehow. It is failing on CUDA on architectures lower than sm_80, and that doesn't seem easy to fix. I was going to update the bug saying this. Either we need to revert or make this transformation opt-in.

If a workaround is required at the MLIR level, we should add expansion patterns for these ops to Arith/Transforms/ExpandOps.cpp, where we can introduce a compare + select for the NaN part and another compare + select for the +-0.0 part. However, these ops are first-class instructions/intrinsics in LLVM, so they should be supported in one way or another by the backends.
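For reference, a minimal sketch of the compare + select idea mentioned above, for a NaN-propagating maximum of two f32 values (hypothetical; the exact predicates and the +-0.0 handling would need to match whatever semantics the expansion is meant to implement):

```mlir
// Pick the larger operand; 'ugt' is also true when the operands are
// unordered, so a NaN %a already wins here.
%cmp = arith.cmpf ugt, %a, %b : f32
%sel = arith.select %cmp, %a, %b : f32
// NaN part: if %b is NaN, propagate it explicitly.
%bnan = arith.cmpf uno, %b, %b : f32
%max  = arith.select %bnan, %b, %sel : f32
// A further compare + select on the signs would be needed for the
// +0.0 vs -0.0 case; omitted here.
```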
IIRC, I added this comment quite some time ago.
Have you tried replacing these two ops with an llvm.maximum/llvm.minimum? Maybe they are more widely supported now.
Otherwise, we would have to propagate the NaNs ourselves here.
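If llvm.maximum/llvm.minimum are indeed usable here, the replacement would presumably be the NaN-propagating elementwise intrinsics, roughly like this (a sketch; the LLVM dialect spellings llvm.intr.maximum / llvm.intr.minimum are assumed):

```mlir
// Elementwise max/min that propagate NaN, mapping to llvm.maximum/llvm.minimum.
%max = llvm.intr.maximum(%a, %b) : (f32, f32) -> f32
%min = llvm.intr.minimum(%a, %b) : (f32, f32) -> f32
```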