This is an archive of the discontinued LLVM Phabricator instance.

[X86][TwoAddressInstructionPass] Teach tryInstructionCommute to continue checking for commutable FMA operands in more cases.
ClosedPublic

Authored by craig.topper on Feb 23 2020, 12:33 AM.

Details

Summary

Previously we would only check for another commutable operand if the first commute was an aggressive commute.

But if we have two kill operands and neither is tied to the def at the start, we should consider both operands as the one to use as the new def.

This improves the loop in the fma-commute-loop.ll test. The test is derived from a Discourse post: https://llvm.discourse.group/t/unnecessary-vmovapd-instructions-generated-can-you-hint-in-favor-of-vfmadd231pd/582
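
Concretely, in a reduction loop where the accumulator is a loop-carried phi, which kill operand we pick as the new def decides whether an extra copy survives. A hand-written sketch (register names are illustrative, not taken from the test output):

        vfmadd213pd     zmm1, zmm2, zmm0        # zmm1 = (zmm2 * zmm1) + zmm0
        vmovapd         zmm0, zmm1              # copy the result back into the accumulator register

versus, with the accumulator chosen as the tied def:

        vfmadd231pd     zmm0, zmm1, zmm2        # zmm0 = (zmm1 * zmm2) + zmm0, accumulates in place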

It does degrade some of the fastmath tests, but that's probably just due to the known problems with our decision making when physical register constraints from above and below apply to small code with multiple two-address instructions.

Diff Detail

Event Timeline

craig.topper created this revision. Feb 23 2020, 12:33 AM
Herald added a project: Restricted Project. Feb 23 2020, 12:33 AM
Herald added a subscriber: hiraditya.
bondhugula added a subscriber: bondhugula. (Edited) Feb 24 2020, 5:02 AM

It's really nice to see this patch. I think the instruction sequence pattern that this has an impact on (the one in fma-commute-loop.ll and on the discourse post) is arguably the most important one for HPC! I tested this patch on the code generated from MLIR in the context of this experimentation: https://github.com/bondhugula/llvm-project/blob/hop/mlir/docs/HighPerfCodeGen.md#tweaking-m_c-k_c-m_r-n_r-to-maximize-reuse
and all the FMA 213pd's shown in the llc assembly output snippet therein are replaced by 231pd's. It didn't have any impact on performance, though, since the 213pd's in that sequence didn't lead to any additional inter-register shuffling, but FWIW it's much more intuitive to see the same consistent set of registers thanks to the 231 forms.

There seem to be a few extra register moves with this change:

llvm/test/CodeGen/X86/recip-fastmath.ll:819

Any chance that we can avoid this regression (and the other ones above)?

craig.topper marked an inline comment as done. Feb 26 2020, 10:34 AM
craig.topper added inline comments.
llvm/test/CodeGen/X86/recip-fastmath.ll:819

Not without a new heuristic. We would need to see that one of the kills we're choosing for the vfmsub is tied to xmm0, that the user of the vfmsub's def (the vfnmadd132ps) has its own def tied to xmm0, and that it's commutable into a form that would allow the output of the vfmsub to be tied to that def.
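
For illustration, the general shape of such a regression looks roughly like this (hypothetical registers and instructions, not the actual recip-fastmath.ll:819 output): the chain's final value has to end up in xmm0, but the per-instruction commute choice strands it in another register, so a copy gets inserted:

        vfmsub213ps     xmm1, xmm2, xmm3        # xmm1 = (xmm2 * xmm1) - xmm3
        vfnmadd213ps    xmm1, xmm4, xmm0        # xmm1 = -(xmm4 * xmm1) + xmm0
        vmovaps         xmm0, xmm1              # extra move to satisfy the xmm0 constraint

Avoiding that would require the commute decisions on both instructions to be made with the downstream xmm0 constraint in view.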

OK, please can you add the additional tests now?

Rebase after uploading the test. Hopefully with context this time.

craig.topper edited the summary of this revision. Feb 27 2020, 8:56 PM
craig.topper edited the summary of this revision.
RKSimon accepted this revision. Feb 28 2020, 3:08 AM

LGTM, please can you raise a bug about fixing the extra move cases?

This revision is now accepted and ready to land. Feb 28 2020, 3:08 AM
This revision was automatically updated to reflect the committed changes.
chriselrod added a comment. (Edited) Feb 21 2021, 11:16 PM

I'm still seeing this for a complex dot product on LLVM 11.0.1. Godbolt example

LLVM IR:

L61:                                              ; preds = %L61, %L30.preheader
  %value_phi20178 = phi <8 x double> [ %res.i118, %L61 ], [ zeroinitializer, %L30.preheader ]
  %value_phi19177 = phi <8 x double> [ %res.i119, %L61 ], [ zeroinitializer, %L30.preheader ]
  %value_phi16176 = phi <8 x double> [ %res.i124, %L61 ], [ zeroinitializer, %L30.preheader ]
  %value_phi15175 = phi <8 x double> [ %res.i125, %L61 ], [ zeroinitializer, %L30.preheader ]
  %value_phi9174 = phi i64 [ %ptr.2.i114, %L61 ], [ %11, %L30.preheader ]
  %value_phi173 = phi i64 [ %ptr.2.i117, %L61 ], [ %10, %L30.preheader ]
  %ptr.1.i148 = inttoptr i64 %value_phi173 to <16 x double>*
  %res.i149 = load <16 x double>, <16 x double>* %ptr.1.i148, align 8
  %res.i147 = shufflevector <16 x double> %res.i149, <16 x double> undef, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
  %res.i146 = shufflevector <16 x double> %res.i149, <16 x double> undef, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
  %ptr.0.i143 = inttoptr i64 %value_phi173 to double*
  %ptr.1.i144 = getelementptr inbounds double, double* %ptr.0.i143, i64 16
  %ptr.1.i141 = bitcast double* %ptr.1.i144 to <16 x double>*
  %res.i142 = load <16 x double>, <16 x double>* %ptr.1.i141, align 8
  %res.i140 = shufflevector <16 x double> %res.i142, <16 x double> undef, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
  %res.i139 = shufflevector <16 x double> %res.i142, <16 x double> undef, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
  %ptr.1.i137 = inttoptr i64 %value_phi9174 to <16 x double>*
  %res.i138 = load <16 x double>, <16 x double>* %ptr.1.i137, align 8
  %res.i136 = shufflevector <16 x double> %res.i138, <16 x double> undef, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
  %res.i135 = shufflevector <16 x double> %res.i138, <16 x double> undef, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
  %ptr.0.i132 = inttoptr i64 %value_phi9174 to double*
  %ptr.1.i133 = getelementptr inbounds double, double* %ptr.0.i132, i64 16
  %ptr.1.i130 = bitcast double* %ptr.1.i133 to <16 x double>*
  %res.i131 = load <16 x double>, <16 x double>* %ptr.1.i130, align 8
  %res.i129 = shufflevector <16 x double> %res.i131, <16 x double> undef, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
  %res.i128 = shufflevector <16 x double> %res.i131, <16 x double> undef, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
  %res.i127 = call nsz contract <8 x double> @llvm.fma.v8f64(<8 x double> %res.i146, <8 x double> %res.i135, <8 x double> %value_phi15175)
  %res.i126 = call nsz contract <8 x double> @llvm.fma.v8f64(<8 x double> %res.i139, <8 x double> %res.i128, <8 x double> %value_phi16176)
  %res.i125 = call nsz contract <8 x double> @llvm.fma.v8f64(<8 x double> %res.i147, <8 x double> %res.i136, <8 x double> %res.i127)
  %res.i124 = call nsz contract <8 x double> @llvm.fma.v8f64(<8 x double> %res.i140, <8 x double> %res.i129, <8 x double> %res.i126)
  %res.i123 = fneg nsz contract <8 x double> %res.i146
  %res.i122 = call nsz contract <8 x double> @llvm.fma.v8f64(<8 x double> %res.i123, <8 x double> %res.i136, <8 x double> %value_phi19177)
  %res.i121 = fneg nsz contract <8 x double> %res.i139
  %res.i120 = call nsz contract <8 x double> @llvm.fma.v8f64(<8 x double> %res.i121, <8 x double> %res.i129, <8 x double> %value_phi20178)
  %res.i119 = call nsz contract <8 x double> @llvm.fma.v8f64(<8 x double> %res.i147, <8 x double> %res.i135, <8 x double> %res.i122)
  %res.i118 = call nsz contract <8 x double> @llvm.fma.v8f64(<8 x double> %res.i140, <8 x double> %res.i128, <8 x double> %res.i120)
  %ptr.1.i116 = getelementptr inbounds double, double* %ptr.0.i143, i64 32
  %ptr.2.i117 = ptrtoint double* %ptr.1.i116 to i64
  %ptr.1.i113 = getelementptr inbounds double, double* %ptr.0.i132, i64 32
  %ptr.2.i114 = ptrtoint double* %ptr.1.i113 to i64
  %.not = icmp ugt double* %ptr.1.i116, %ptr.1.i157
  br i1 %.not, label %L88, label %L61

ASM:

.LBB0_5:                                # %L61
        vmovupd zmm8, zmmword ptr [rsi]
        vmovupd zmm6, zmmword ptr [rsi + 64]
        vmovupd zmm9, zmmword ptr [rsi + 128]
        vmovupd zmm7, zmmword ptr [rsi + 192]
        vmovapd zmm10, zmm8
        vpermt2pd       zmm10, zmm2, zmm6
        vpermt2pd       zmm8, zmm3, zmm6
        vmovapd zmm11, zmm9
        vpermt2pd       zmm11, zmm2, zmm7
        vmovupd zmm12, zmmword ptr [rdx]
        vmovupd zmm13, zmmword ptr [rdx + 64]
        vmovupd zmm14, zmmword ptr [rdx + 128]
        vpermt2pd       zmm9, zmm3, zmm7
        vmovupd zmm15, zmmword ptr [rdx + 192]
        vmovapd zmm6, zmm12
        vpermt2pd       zmm6, zmm2, zmm13
        vpermt2pd       zmm12, zmm3, zmm13
        vmovapd zmm7, zmm14
        vpermt2pd       zmm7, zmm2, zmm15
        vpermt2pd       zmm14, zmm3, zmm15
        vfmadd231pd     zmm0, zmm8, zmm12       # zmm0 = (zmm8 * zmm12) + zmm0
        vfmadd231pd     zmm1, zmm9, zmm14       # zmm1 = (zmm9 * zmm14) + zmm1
        vfmadd231pd     zmm0, zmm10, zmm6       # zmm0 = (zmm10 * zmm6) + zmm0
        vfmadd231pd     zmm1, zmm11, zmm7       # zmm1 = (zmm11 * zmm7) + zmm1
        vfmsub213pd     zmm6, zmm8, zmm5        # zmm6 = (zmm8 * zmm6) - zmm5
        vfmsub213pd     zmm7, zmm9, zmm4        # zmm7 = (zmm9 * zmm7) - zmm4
        vfmsub231pd     zmm6, zmm10, zmm12      # zmm6 = (zmm10 * zmm12) - zmm6
        vfmsub231pd     zmm7, zmm11, zmm14      # zmm7 = (zmm11 * zmm14) - zmm7
        add     rsi, 256
        add     rdx, 256
        vmovapd zmm4, zmm7
        vmovapd zmm5, zmm6
        cmp     rsi, r11
        jbe     .LBB0_5

Note the

        vfmsub213pd     zmm6, zmm8, zmm5        # zmm6 = (zmm8 * zmm6) - zmm5
        vfmsub213pd     zmm7, zmm9, zmm4        # zmm7 = (zmm9 * zmm7) - zmm4
        vfmsub231pd     zmm6, zmm10, zmm12      # zmm6 = (zmm10 * zmm12) - zmm6
        vfmsub231pd     zmm7, zmm11, zmm14      # zmm7 = (zmm11 * zmm14) - zmm7
# ...
        vmovapd zmm4, zmm7
        vmovapd zmm5, zmm6

EDIT:
This seems to be a separate (unrelated?) problem.

First, an aside pointing to what looks like a missed optimization, where there's an extra separate loop preheader:

  %3 = load i64, i64* %2, align 8
  %res.i159 = shl nuw i64 %3, 1
  %indname.i = add i64 %res.i159, -32
  %9 = icmp ult i64 %3, 16
  br i1 %9, label %L93, label %L30.preheader

L30.preheader:                                    ; preds = %top
  %.not172 = icmp slt i64 %indname.i, 0
  br i1 %.not172, label %L88, label %L61

I'm not sure whether the optimization of eliminating it/the extra branch is valid or not.
It'd be invalid if %3 were less than 0, but there is an assume on sgt %3, 0; there could be other corner cases I'm missing. Running through the code with a few possible example values of %3 (and a few assumed-impossible ones):

  %3 = [ -1, 0, 15, 16, 17 ]
  %4 = icmp sgt i64 %3, 0
  call void @llvm.assume(i1 %4) # the `-1` and `0` values aren't possible
  %res.i159 = shl nuw i64 %3, 1 = [ -2, 0, 30, 32, 34 ]
  %indname.i = add i64 %res.i159, -32 = [ -34, -32, -2, 0, 2 ]
  %9 = icmp ult i64 %3, 16 [ 0xffff...<16=false, 0<16=true, 15<16=true, 16<16=false, 17<16=false ]
  br i1 %9, label %L93, label %L30.preheader 

L30.preheader:                                    ; preds = %top
  %.not172 = icmp slt i64 %indname.i, 0 # [ -34<0=true, couldn't reach, couldn't reach, 0<0=false, 2<0=false ]
  br i1 %.not172, label %L88, label %L61

Anyway, the point of this long aside on the maybe-missed optimization is that when I did eliminate the redundancy, I got this Godbolt: https://godbolt.org/z/448vco.

The two vmovapds associated with the vfmsubs are now gone (the remaining ones are all necessary because of the vperms).

So my theory (disclaimer: I am totally ignorant of LLVM's internals) is that the cause of those vmovapds is the difficulty of allocating registers consistently across all those blocks in the presence of phi nodes.
Code with the extra vmovapds:

        vfmsub213pd     zmm6, zmm8, zmm5        # zmm6 = (zmm8 * zmm6) - zmm5
        vfmsub213pd     zmm7, zmm9, zmm4        # zmm7 = (zmm9 * zmm7) - zmm4
        vfmsub231pd     zmm6, zmm10, zmm12      # zmm6 = (zmm10 * zmm12) - zmm6
        vfmsub231pd     zmm7, zmm11, zmm14      # zmm7 = (zmm11 * zmm14) - zmm7
        add     rsi, 256
        add     rdx, 256
        vmovapd zmm4, zmm7
        vmovapd zmm5, zmm6
        cmp     rsi, r11
        jbe     .LBB0_5
        jmp     .LBB0_6
.LBB0_3:
        vxorpd  xmm0, xmm0, xmm0
        mov     rsi, r8
        vxorpd  xmm1, xmm1, xmm1
        vxorpd  xmm6, xmm6, xmm6
        vxorpd  xmm7, xmm7, xmm7
.LBB0_6:                                # %L88
        vaddpd  zmm2, zmm6, zmm7
        vaddpd  zmm0, zmm0, zmm1 # calculated earlier in the loop

corresponding IR:

  %res.i123 = fneg nsz contract <8 x double> %res.i146
  %res.i122 = call nsz contract <8 x double> @llvm.fma.v8f64(<8 x double> %res.i123, <8 x double> %res.i136, <8 x double> %value_phi19177)
  %res.i121 = fneg nsz contract <8 x double> %res.i139
  %res.i120 = call nsz contract <8 x double> @llvm.fma.v8f64(<8 x double> %res.i121, <8 x double> %res.i129, <8 x double> %value_phi20178)
  %res.i119 = call nsz contract <8 x double> @llvm.fma.v8f64(<8 x double> %res.i147, <8 x double> %res.i135, <8 x double> %res.i122)
  %res.i118 = call nsz contract <8 x double> @llvm.fma.v8f64(<8 x double> %res.i140, <8 x double> %res.i128, <8 x double> %res.i120)
  %ptr.1.i116 = getelementptr inbounds double, double* %ptr.0.i143, i64 32
  %ptr.2.i117 = ptrtoint double* %ptr.1.i116 to i64
  %ptr.1.i113 = getelementptr inbounds double, double* %ptr.0.i132, i64 32
  %ptr.2.i114 = ptrtoint double* %ptr.1.i113 to i64
  %.not = icmp ugt double* %ptr.1.i116, %ptr.1.i157
  br i1 %.not, label %L88, label %L61

L88:                                              ; preds = %L61, %L30.preheader
  %value_phi.lcssa = phi i64 [ %10, %L30.preheader ], [ %ptr.2.i117, %L61 ]
  %value_phi9.lcssa = phi i64 [ %11, %L30.preheader ], [ %ptr.2.i114, %L61 ]
  %value_phi15.lcssa = phi <8 x double> [ zeroinitializer, %L30.preheader ], [ %res.i125, %L61 ]
  %value_phi16.lcssa = phi <8 x double> [ zeroinitializer, %L30.preheader ], [ %res.i124, %L61 ]
  %value_phi19.lcssa = phi <8 x double> [ zeroinitializer, %L30.preheader ], [ %res.i119, %L61 ]
  %value_phi20.lcssa = phi <8 x double> [ zeroinitializer, %L30.preheader ], [ %res.i118, %L61 ]
  %res.i111 = fadd nsz contract <8 x double> %value_phi15.lcssa, %value_phi16.lcssa
  %res.i110 = fadd nsz contract <8 x double> %value_phi19.lcssa, %value_phi20.lcssa
  br label %L93

The start of the loop iteration takes zmm4 and zmm5 as inputs, and then updates them into zmm7 and zmm6.
Thus, the next loop iteration requires zmm4 and zmm5 to be updated; hence either the vmovapds are necessary, or the registers assigned to the fmsub results would have to change. However, .LBB0_6 expects zmm6 and zmm7 as inputs, so that would have to be changed as well.
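
For example, keeping the accumulators in zmm4/zmm5 the whole time would look roughly like this (a hand-written sketch that follows the IR directly; it is not compiler output, and the non-loop path and .LBB0_6 would also have to switch to zmm4/zmm5):

        vfnmadd231pd    zmm5, zmm8, zmm6        # zmm5 = -(zmm8 * zmm6) + zmm5
        vfnmadd231pd    zmm4, zmm9, zmm7        # zmm4 = -(zmm9 * zmm7) + zmm4
        vfmadd231pd     zmm5, zmm10, zmm12      # zmm5 = (zmm10 * zmm12) + zmm5
        vfmadd231pd     zmm4, zmm11, zmm14      # zmm4 = (zmm11 * zmm14) + zmm4
        # no trailing vmovapds; the accumulators never leave zmm4/zmm5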

Obviously it's possible, but it's a bigger, more involved change, and thus understandably harder to do than the case where the extra blocks get eliminated and the ASM instead looks like:

vfmsub231pd     zmm2, zmm6, zmm14       # zmm2 = (zmm6 * zmm14) - zmm2
vfmsub231pd     zmm4, zmm8, zmm12       # zmm4 = (zmm8 * zmm12) - zmm4
vfmsub231pd     zmm2, zmm10, zmm11      # zmm2 = (zmm10 * zmm11) - zmm2
vfmsub231pd     zmm4, zmm7, zmm13       # zmm4 = (zmm7 * zmm13) - zmm4
add     rsi, 256
add     rdx, 256
cmp     rsi, r11
jbe     .LBB0_3
vaddpd  zmm1, zmm0, zmm1  # calculated earlier in the loop
vaddpd  zmm0, zmm2, zmm4

corresponding IR:

  %res.i138 = fneg reassoc nsz arcp contract afn <8 x double> %res.i161
  %res.i137 = call reassoc nsz arcp contract afn <8 x double> @llvm.fma.v8f64(<8 x double> %res.i138, <8 x double> %res.i151, <8 x double> %value_phi18)
  %res.i136 = fneg reassoc nsz arcp contract afn <8 x double> %res.i154
  %res.i135 = call reassoc nsz arcp contract afn <8 x double> @llvm.fma.v8f64(<8 x double> %res.i136, <8 x double> %res.i144, <8 x double> %value_phi19)
  %res.i134 = call reassoc nsz arcp contract afn <8 x double> @llvm.fma.v8f64(<8 x double> %res.i162, <8 x double> %res.i150, <8 x double> %res.i137)
  %res.i133 = call reassoc nsz arcp contract afn <8 x double> @llvm.fma.v8f64(<8 x double> %res.i155, <8 x double> %res.i143, <8 x double> %res.i135)
  %ptr.1.i131 = getelementptr inbounds double, double* %ptr.0.i158, i64 32
  %ptr.2.i132 = ptrtoint double* %ptr.1.i131 to i64
  %ptr.1.i128 = getelementptr inbounds double, double* %ptr.0.i147, i64 32
  %ptr.2.i129 = ptrtoint double* %ptr.1.i128 to i64
  %.not = icmp ugt double* %ptr.1.i131, %ptr.1.i172
  br i1 %.not, label %L110, label %L51

L110:                                             ; preds = %L51
  %res.i126 = fadd nsz contract <8 x double> %res.i140, %res.i139
  %res.i125 = fadd nsz contract <8 x double> %res.i134, %res.i133
  br label %L129

No phi nodes here in the second block, and no unnecessary vmovapds in the ASM.