[RISCV] Optimize 2x SELECT for floating-point types
ClosedPublic

Authored by liaolucy on Jun 15 2022, 9:10 AM.

Details

Summary

This patch covers the following opcodes:
Select_FPR16_Using_CC_GPR
Select_FPR32_Using_CC_GPR
Select_FPR64_Using_CC_GPR

Event Timeline

liaolucy created this revision.Jun 15 2022, 9:10 AM
liaolucy requested review of this revision.Jun 15 2022, 9:10 AM

The assembly for the test case before this patch:

        auipc   a0, %pcrel_hi(.LCPI0_0)
        addi    a0, a0, %pcrel_lo(.LBB0_5)
        flw     ft0, 0(a0)
        fmv.w.x ft1, zero
        flt.s   a1, fa0, ft1
        flt.s   a0, ft0, fa0
        beqz    a1, .LBB0_3
# %bb.1:                                # %entry
        beqz    a0, .LBB0_4
.LBB0_2:                                # %entry
        fmv.s   fa0, ft0
        ret
.LBB0_3:                                # %entry
        fmv.s   ft1, fa0
        bnez    a0, .LBB0_2
.LBB0_4:                                # %entry
        fmv.s   ft0, ft1
        fmv.s   fa0, ft0
        ret

After this patch:

# %bb.0:                                # %entry
        fmv.w.x ft0, zero
        flt.s   a0, fa0, ft0
        bnez    a0, .LBB0_3
# %bb.1:                                # %entry
.LBB0_4:                                # %entry
                                        # Label of block must be emitted
        auipc   a0, %pcrel_hi(.LCPI0_0)
        addi    a0, a0, %pcrel_lo(.LBB0_4)
        flw     ft0, 0(a0)
        flt.s   a0, ft0, fa0
        bnez    a0, .LBB0_3
# %bb.2:                                # %entry
        fmv.s   ft0, fa0
.LBB0_3:                                # %entry
        fmv.s   fa0, ft0
        ret
craig.topper added inline comments.Jun 18 2022, 12:02 PM
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
9698

Variable names should start with a capital letter

9705

Why do we need an explicit PseudoBR? Can't we let it fall through?

9775

Should this use next_nodbg to increment?

9777

Shouldn't we be checking NextMIIt != BB->end() before checking the opcode of NextMIIt?
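The two iterator comments above (skip debug instructions when advancing, and compare against the end of the block before reading the opcode) can be illustrated with a standalone sketch. The `Inst` struct, the `SelectOpc` constant, and `nextIsSelect` are hypothetical stand-ins for illustration only; they are not LLVM APIs and not the code in the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a machine instruction: only what the check needs.
struct Inst {
  int Opcode;
  bool IsDebug; // models a debug pseudo-instruction that should be skipped
};

// Illustrative opcode value, not a real LLVM enum.
constexpr int SelectOpc = 1;

// Returns true if the first non-debug instruction after Block[Idx] exists
// and is a select. The end check comes before any dereference.
bool nextIsSelect(const std::vector<Inst> &Block, std::size_t Idx) {
  if (Idx >= Block.size())
    return false;
  auto Next = Block.begin() + static_cast<std::ptrdiff_t>(Idx) + 1;
  // Advance past debug instructions, stopping at the end of the block.
  while (Next != Block.end() && Next->IsDebug)
    ++Next;
  // Only read the opcode once we know Next is a valid instruction.
  return Next != Block.end() && Next->Opcode == SelectOpc;
}
```

With this ordering, asking about the last instruction of a block returns false instead of reading past the end.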

llvm/test/CodeGen/RISCV/select-optimize-multiple.ll
537

Drop dso_local and local_unnamed_addr and #0

liaolucy updated this revision to Diff 438263.Jun 19 2022, 11:59 PM

Addressed craig.topper's comments. Thanks.

This revision is now accepted and ready to land.Jun 27 2022, 9:35 AM
This revision was landed with ongoing or failed builds.Jun 27 2022, 9:02 PM
This revision was automatically updated to reflect the committed changes.

Hi!

This commit is causing a correctness regression in one of the ML models in IREE. Please find attached the LLVM IR after CodeGenPrepare. Hopefully you can easily identify where this change is triggering.
To reproduce: llc bug.ll -mcpu=generic-rv64 -mattr=+m,+a,+f,+d,+c -target-abi=lp64d. Please note that +v is not provided. Could we please consider a revert while this is being investigated?

Thanks,
Diego

@dcaballe Can you provide a file that can be compiled and run? That would help me debug.

I was able to reduce the test case to:

target datalayout = "e-m:e-p:64:64-i64:64-i128:128-n64-S128"
target triple = "riscv64-unknown-unknown-eabi-elf"

; Function Attrs: nofree nosync nounwind
define internal i32 @main_dispatch_72(i8* %i5, <2 x float> %i44, <2 x float> %i61, <2 x i32> %i65) {
bbi:
  %i45 = fcmp uno <2 x float> %i44, zeroinitializer
  %i62 = fcmp oeq <2 x float> %i44, <float 0xFFF0000000000000, float 0xFFF0000000000000>
  %i63 = fcmp oeq <2 x float> %i44, <float 0x7FF0000000000000, float 0x7FF0000000000000>
  %i64 = fcmp ogt <2 x float> %i44, zeroinitializer
  %i66 = icmp ult <2 x i32> %i65, <i32 255, i32 255>
  %i67 = select <2 x i1> %i64, <2 x float> <float 0x7FF0000000000000, float 0x7FF0000000000000>, <2 x float> <float 0x3810000000000000, float 0x3810000000000000>
  %i68 = select <2 x i1> %i66, <2 x float> %i61, <2 x float> %i67
  %i69 = select <2 x i1> %i63, <2 x float> <float 0x7FF0000000000000, float 0x7FF0000000000000>, <2 x float> %i68
  %i70 = select <2 x i1> %i62, <2 x float> zeroinitializer, <2 x float> %i69
  %i71 = select <2 x i1> %i45, <2 x float> %i44, <2 x float> %i70
  %i72 = bitcast i8* %i5 to <2 x float>*
  store <2 x float> %i71, <2 x float>* %i72, align 64
  ret i32 0
}

Your code should trigger with llc bug.ll -mcpu=generic-rv64 -mattr=+m,+a,+f,+d,+c -target-abi=lp64d.

It looks like it's a case where your transformation is applied multiple times on the output produced by previous instances.

I'll revert this commit while this is being investigated.

Thanks,
Diego

I think the expanded branches are emitted in the wrong order. The condition for the last Select_FPR32_Using_CC_GPR needs to be checked first.

llvm/lib/Target/RISCV/RISCVISelLowering.cpp
9775

I think we also need to make sure the condition for the second select doesn't use the output from the first select.

craig.topper reopened this revision.Jul 7 2022, 6:07 PM
This revision is now accepted and ready to land.Jul 7 2022, 6:07 PM
craig.topper requested changes to this revision.Jul 7 2022, 6:07 PM
This revision now requires changes to proceed.Jul 7 2022, 6:07 PM

I think the expanded branches are emitted in the wrong order. The condition for the last Select_FPR32_Using_CC_GPR needs to be checked first.

Thanks, Craig. I dumped all the selects. The middle three pairs of instructions can be optimized:

  %25:fpr32 = Select_FPR32_Using_CC_GPR killed %21:gpr, %24:gpr, 1, %16:fpr32, %23:fpr32
  %26:fpr32 = Select_FPR32_Using_CC_GPR killed %20:gpr, %24:gpr, 1, %16:fpr32, %23:fpr32
  %28:fpr32 = Select_FPR32_Using_CC_GPR killed %8:gpr, %27:gpr, 4, %4:fpr32, killed %26:fpr32

optimize1 
 1.1 %29:fpr32 = Select_FPR32_Using_CC_GPR killed %7:gpr, %27:gpr, 4, %3:fpr32, killed %25:fpr32
 2.1 %30:fpr32 = Select_FPR32_Using_CC_GPR killed %18:gpr, %24:gpr, 1, %16:fpr32, killed %29:fpr32
optimize2
 1.2 %31:fpr32 = Select_FPR32_Using_CC_GPR killed %17:gpr, %24:gpr, 1, %16:fpr32, killed %28:fpr32
 2.2 %32:fpr32 = Select_FPR32_Using_CC_GPR killed %14:gpr, %24:gpr, 1, %19:fpr32, killed %31:fpr32
optimize3
 1.3 %33:fpr32 = Select_FPR32_Using_CC_GPR killed %13:gpr, %24:gpr, 1, %19:fpr32, killed %30:fpr32
 2.3 %34:fpr32 = Select_FPR32_Using_CC_GPR killed %10:gpr, %24:gpr, 0, %1:fpr32, killed %33:fpr32

  %35:fpr32 = Select_FPR32_Using_CC_GPR killed %9:gpr, %24:gpr, 0, %2:fpr32, killed %32:fpr32

I think we should check that select1's rs2 equals select2's rs2; then 1.1 and 2.1 would not be optimized. I will update the patch later.

I don't think checking select1-rs2 == select2-rs2 solves the problem. When the condition of the second select is true, the true value of that select has priority. That means when the select2 condition is true, the select1 condition doesn't matter. The patch is trying to prioritize one of the compares so that the other can be skipped. That means you need to check the select2 condition first to maintain the priority.

liaolucy added inline comments.Jul 8 2022, 12:20 AM
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
9775

I'm still confused. Do you mean I need to add: Next->getOperand(4).getReg() != MI.getOperand(0).getReg() ?

But the assembly generated for the bug.ll file has not changed.

craig.topper added inline comments.Jul 8 2022, 12:43 AM
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
9775

I was theorizing about another possible bug: if either rs1 or rs2 of the second select is the result of the first select, meaning one of the compare operands is also the false operand, it's not safe to do the optimization. Is that already protected against?

liaolucy added inline comments.Jul 8 2022, 12:53 AM
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
9775

If either rs1 or rs2 of the second select is the result of the first select.

This condition cannot hold.

E.g.: fpr32 = Select_FPR32_Using_CC_GPR killed %13:gpr, %24:gpr, 1, %19:fpr32, killed %30:fpr32

rs1 and rs2 are GPRs, but dst is an FPR.

craig.topper added inline comments.Jul 8 2022, 9:08 AM
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
9775

You're right. I forgot this was FP only.
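The register-class argument in this exchange can be sketched with a small standalone model. Everything here (`RegClass`, `VReg`, `SelectPseudo`, and the two helper functions) is a hypothetical simplification for illustration and not an LLVM API; the operand order follows the MIR dumps in this thread (lhs, rhs, condition code, true value, false value). The sketch assumes a virtual register number belongs to exactly one register class.

```cpp
#include <cassert>

// Hypothetical, simplified model of a virtual register: a number plus a class.
enum class RegClass { GPR, FPR32 };

struct VReg {
  unsigned Id;
  RegClass RC;
};

// Hypothetical model of the select pseudo's operands:
//   Dst = (Lhs CC Rhs) ? TrueVal : FalseVal
struct SelectPseudo {
  VReg Dst, Lhs, Rhs;
  unsigned CC;
  VReg TrueVal, FalseVal;
};

// True if the second select's compare reads the value produced by the first.
bool compareReadsFirstResult(const SelectPseudo &First,
                             const SelectPseudo &Second) {
  return Second.Lhs.Id == First.Dst.Id || Second.Rhs.Id == First.Dst.Id;
}

// True if the operand classes match the FP selects discussed here:
// the result is an FPR while both compare operands are GPRs.
bool isFPSelectShape(const SelectPseudo &S) {
  return S.Dst.RC == RegClass::FPR32 && S.Lhs.RC == RegClass::GPR &&
         S.Rhs.RC == RegClass::GPR;
}
```

For a pair such as 1.3 and 2.3, the first select's result is an FPR and the second select's compare operands are GPRs, so under the stated assumption `compareReadsFirstResult` cannot be true; only the false operand of the second select refers to the first result.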

Happy to run any tentative fix that you may have on the full model to know if it fixes the problem.

liaolucy updated this revision to Diff 443400.Jul 8 2022, 6:06 PM

This is an attempted fix based on my guess. @dcaballe Could you help test it? If it doesn't solve the problem, I may need to spend more time analyzing it.

It works! I must have messed up the test when trimming it. I'm attaching the whole function for your reference. Your code was invoked four times on this function with the previous version of the code. It's only invoked twice with the fixed one.

Thanks!

craig.topper added inline comments.Jul 9 2022, 1:26 PM
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
9777

I don't understand the new check. Why was it wrong to optimize 1.1 and 2.1 in the failed case?

liaolucy added inline comments.Jul 9 2022, 10:50 PM
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
9777

Yesterday, I saw that x86 has this check.

// CMOV ((CMOV F, T, cc1), T, cc2) is checked here and handled by a separate
// function - EmitLoweredCascadedSelect

// This checks for case 2, but only do this if we didn't already find
// case 1, as indicated by LastCMOV == MI.
if (LastCMOV == &MI && NextMIIt != ThisMBB->end() &&
    NextMIIt->getOpcode() == MI.getOpcode() &&
    NextMIIt->getOperand(2).getReg() == MI.getOperand(2).getReg() &&
    NextMIIt->getOperand(1).getReg() == MI.getOperand(0).getReg() &&
    NextMIIt->getOperand(1).isKill()) {
  return EmitLoweredCascadedSelect(MI, *NextMIIt, ThisMBB);
}

Here is my attempt to understand it:
1.1 %29:fpr32 = Select_FPR32_Using_CC_GPR killed %7:gpr, %27:gpr, 4, %3:fpr32, killed %25:fpr32
2.1 %30:fpr32 = Select_FPR32_Using_CC_GPR killed %18:gpr, %24:gpr, 1, %16:fpr32, killed %29:fpr32

%7 (a), %27 (b)
%18 (c), %24 (d)

E.g., the four cases are:

a=b,  c=d
a=b,  c!=d
a!=b, c=d
a!=b, c!=d

When b=d:

%7 (a), %27 (b)
%18 (c), %24 (b)

a=b,  c=b     a=b=c
a=b,  c!=b    a!=c (optimize)
a!=b, c=b
a!=b, c!=b    a!=b!=c
craig.topper added inline comments.Jul 9 2022, 11:13 PM
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
9777

X86 has the CMP done as a separate instruction that writes EFLAGS. The getOperand(2) check for X86 is making sure that the EFLAGS come from the same CMP instruction. RISC-V does the comparison as part of the branch, so it's different.

For
1.1 %29:fpr32 = Select_FPR32_Using_CC_GPR killed %7:gpr, %27:gpr, 4, %3:fpr32, killed %25:fpr32
2.1 %30:fpr32 = Select_FPR32_Using_CC_GPR killed %18:gpr, %24:gpr, 1, %16:fpr32, killed %29:fpr32

The original code is
(%18 != %24) ? %16 : ((%7 < %27) ? %3 : %25)

The transform this patch is trying to do needs to be:

bb1:
  BNE %18, %24, bb4
bb2:
  BLTU %7, %27, bb4
bb3:
  // fallthrough
bb4:
  phi %16, bb1, %3, bb2, %25, bb3

The condition of the second select needs to be checked first.
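The ordering requirement can be checked exhaustively with a small standalone sketch. It models the two chained selects as plain C++ expressions and compares two candidate branch orders against them over all four combinations of the two conditions. The function names are hypothetical and the integers merely stand in for the FP values (%16, %3, %25); this is an illustration of the reasoning, not the code in the patch.

```cpp
#include <cassert>

// Reference semantics of two chained selects:
//   C2 ? T2 : (C1 ? T1 : F)
// The second (outer) select has priority when its condition is true.
int chainedSelect(bool C1, bool C2, int T1, int T2, int F) {
  int First = C1 ? T1 : F;
  return C2 ? T2 : First;
}

// Branch sequence that tests the second select's condition first.
int branchSecondFirst(bool C1, bool C2, int T1, int T2, int F) {
  if (C2)
    return T2;
  if (C1)
    return T1;
  return F;
}

// Branch sequence that tests the first select's condition first.
int branchFirstFirst(bool C1, bool C2, int T1, int T2, int F) {
  if (C1)
    return T1;
  if (C2)
    return T2;
  return F;
}

using SelectFn = int (*)(bool, bool, int, int, int);

// True if Impl agrees with the chained selects for every condition pair.
bool matchesReference(SelectFn Impl, int T1, int T2, int F) {
  for (int C1 = 0; C1 <= 1; ++C1)
    for (int C2 = 0; C2 <= 1; ++C2)
      if (Impl(C1 != 0, C2 != 0, T1, T2, F) !=
          chainedSelect(C1 != 0, C2 != 0, T1, T2, F))
        return false;
  return true;
}
```

In this model, `branchFirstFirst` only disagrees with the reference when both conditions are true and the two true values differ, which is the situation in pair 1.1 and 2.1 (true values %3 and %16). When the two true values are identical, either order gives the same result.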

liaolucy updated this revision to Diff 443489.Jul 10 2022, 3:18 AM

Addressed craig.topper's comments. Thanks.

This revision is now accepted and ready to land.Jul 10 2022, 10:45 PM
This revision was landed with ongoing or failed builds.Jul 10 2022, 11:10 PM
This revision was automatically updated to reflect the committed changes.