
[AArch64][SVE] Combine predicated FMUL/FADD into FMA
ClosedPublic

Authored by MattDevereau on Oct 12 2021, 6:25 AM.

Details

Summary

[AArch64][SVE] Combine predicated FMUL/FADD into FMA

Combine FADD and FMUL intrinsics into FMA when the result of the FMUL is an FADD operand with only one use and both use the same predicate.

Event Timeline

MattDevereau requested review of this revision. Oct 12 2021, 6:25 AM

ran clang-format

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
718

I'd expect this all to look simpler, please have a go at simplifying. I think you can drop the matching logic above and instead add a condition that checks the AddOp1's predicate matches the Add's predicate.

One issue I have is that both the swap and m_SVEFAdd as it currently is serve the purpose of testing both operand orders. I think it's important for clarity to only do this once.

724

It seems to me that we need to check whether the fast-math flags contain 'contract', and bail out if they do not.

759

In addition to the diff suggestion above, I think you could just call this FMLA. (I'd prefer to see FMLA capitalized in this context since it is an abbreviation, so it matches e.g. "SVE" in style.)

761

I think I would prefer to see this in the big-switch-of-intrinsics or alternatively in a function called instCombineSVEVectorFAdd. It feels wrong for it to appear inside VectorBinOp since it doesn't apply to any BinOp other than FAdd, so the "hierarchy" is wrong in my view.

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
700

Please add a comment around your combine of the form // fold (fadd a (fmul b c)) -> (fma a b c)

718

If you want to use the match syntax above I think you can also bind the fmul with m_And(m_Value(FMul), m_SVEFMul(m_Deferred(p), m_Value(a), m_Value(b))) and then you could grab FMul and c simultaneously with the existing logic.

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
724

Please also check that the flags are equal instead of taking their intersection. The fadd might plausibly have a flag which allows more aggressive optimization and contracting it in this way may prevent those optimizations from taking place.

Matt added a subscriber: Matt. Oct 13 2021, 2:49 PM
MattDevereau marked 5 inline comments as done. Oct 20 2021, 5:29 AM
MattDevereau added inline comments.
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
718

I've opted to check AddOp1's predicate matches the Add's predicate

724

I'm not sure what you mean here. Since flags is its own fresh variable, isn't the intersection the same as comparing if the flags are equal? The original flags on the fadd intrinsic will be preserved.

MattDevereau added inline comments. Oct 20 2021, 5:51 AM
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
718

Was this what you meant?

if (!match(&II, m_SVEFAdd(m_Value(p),
                          m_And(m_Value(FMul), m_SVEFMul(m_Deferred(p), m_Value(a), m_Value(b))),
                          m_Value(c)))){
  return None;
}

FMul isn't matching in this expression

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
718

Sorry, I think I meant match(X, m_CombineAnd(A, B)), which matches X against both A and B. I erroneously wrote m_And(A, B), which matches X only when X is itself an 'and' operation whose operands match A and B.

See these two:

https://github.com/llvm/llvm-project/blob/3efd2a0bec0298d804f274fcc10ea14431b61de1/llvm/include/llvm/IR/PatternMatch.h#L1122-L1126

https://github.com/llvm/llvm-project/blob/3efd2a0bec0298d804f274fcc10ea14431b61de1/llvm/include/llvm/IR/PatternMatch.h#L194-L218
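
To make the distinction concrete, here is a minimal standalone sketch of the two combinator shapes. It does not use LLVM's PatternMatch.h: Node, BindAny, IsOp, CombineAnd, AndOf and the two capture helpers are hypothetical names invented for this illustration, and the toy matchers only approximate how the real ones behave.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy expression node standing in for llvm::Value (hypothetical type).
struct Node {
  std::string Op;                 // e.g. "fmul", "and", "leaf"
  std::vector<const Node *> Args; // operands
};

// Binds whatever node it sees, in the spirit of m_Value(V).
struct BindAny {
  const Node *&Out;
  bool match(const Node *N) const {
    Out = N;
    return true;
  }
};

// Succeeds when the node has the given opcode; operands are ignored.
struct IsOp {
  std::string Op;
  bool match(const Node *N) const { return N->Op == Op; }
};

// Applies both sub-matchers to the SAME node (the m_CombineAnd idea).
template <typename L, typename R> struct CombineAnd {
  L Lhs;
  R Rhs;
  bool match(const Node *N) const { return Lhs.match(N) && Rhs.match(N); }
};

// Requires the node to be an "and" and matches its two operands
// (the m_And idea).
template <typename L, typename R> struct AndOf {
  L Lhs;
  R Rhs;
  bool match(const Node *N) const {
    return N->Op == "and" && N->Args.size() == 2 &&
           Lhs.match(N->Args[0]) && Rhs.match(N->Args[1]);
  }
};

// Captures N into FMul and checks that N is an fmul, both on the same node.
bool captureFMul(const Node *N, const Node *&FMul) {
  assert(N && "expected a non-null node");
  return CombineAnd<BindAny, IsOp>{BindAny{FMul}, IsOp{"fmul"}}.match(N);
}

// The mistaken form: this only succeeds if N itself is an "and" node.
bool captureViaAnd(const Node *N, const Node *&FMul) {
  assert(N && "expected a non-null node");
  return AndOf<BindAny, IsOp>{BindAny{FMul}, IsOp{"fmul"}}.match(N);
}
```

Given an fmul node, captureFMul succeeds and binds the node itself, while captureViaAnd fails because the node is not an 'and'. That is consistent with the "FMul isn't matching" symptom reported above.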

724

When I wrote 'contracting' I think I was thinking 'intersecting'.

My suggestion is to do if (flags1 == flags2) return None;, which cannot be the same as doing an intersection: in my proposed case the optimization would not take place. Intersection implies constructing a new set of flags which is different from the flag sets on the inputs. I'm also not sure what you mean by 'the original flags on the fadd intrinsic will be preserved': we're expecting that this optimization will replace the fadd with the newly constructed fma, so the fadd is going to be erased.

Restructured logic path by adding instCombineSVEVectorFAdd
Added check for no contract flag when comparing fast flags
Extended matching logic in instCombineSVEVectorFmla to capture value* FMul

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
724

Sorry, I meant if (flags1 != flags2) return None;

MattDevereau added inline comments. Oct 20 2021, 8:59 AM
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
724

There's no != operator on llvm::FastMathFlags which is why I used the intersection before. From reading https://llvm.org/docs/LangRef.html#fast-math-flags it seems the contract flag needs some special attention?

Added != operator to FastMathFlags
Compared FastMathFlags for equality instead of taking their intersection for FMLA combines

MattDevereau marked 5 inline comments as done. Oct 21 2021, 3:39 AM

Wrong argument order needs fixing, and some nits. Please also run clang-format.

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
700

The fold as written here looks correct to me, but the match in m_SVEFAdd isn't a 1:1 with this (it is (fadd (fmul a b) c), which is different). The effect is that the resulting fma from the combine has the wrong argument order.

714

Please put a line break after this.

715

Is the dyn_cast needed? You should only need it if you need some methods which are on IntrinsicInst but not Value.

Also, nit, dyn_cast has logic which checks the type of the thing being cast -- this is unnecessary because the type is already constrained by the match() logic above, so if you did need a cast you can write cast<Ty>(Val) instead of dyn_cast.

716

It seems to me that p != II.getOperand(0) should already be checked by m_Deferred?

720

Nit. I think this condition would read a bit more clearly if it were split into two, with the allowContract checks negated.

if (!FAddFlags.allowContract() || !FMulFlags.allowContract())
  return false;

And the second condition is a bit hidden on the right compared to just having it as a separate if (FAddFlags != FMulFlags) return None.

768

Nit. FMLA and Fmla (in the function name). I think it should be consistent on FMLA.

Updated some formatting and used more appropriate casting

MattDevereau marked 6 inline comments as done. Oct 21 2021, 5:33 AM
MattDevereau added inline comments.
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
715

replaced the dyn_cast with cast

716

removed it

MattDevereau marked 2 inline comments as done. Oct 21 2021, 8:15 AM
bsmith added inline comments. Oct 22 2021, 6:38 AM
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
719–723

None of this seems to take into account the global fast-math options, i.e. the "unsafe-fp-math"="true" attribute, hence I don't think this optimization can ever be triggered from C, only directly written IR with the fast flags.

720

Do we not care about Reassociation here also?

722–723

I feel like it might be sufficient to check both operands for reassociation and contraction, rather than doing a full flag check.

MattDevereau added inline comments. Oct 22 2021, 7:05 AM
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
719–723

Compiling foo.c

svfloat16_t fmla_example(svbool_t p, svfloat16_t a, svfloat16_t b, svfloat16_t c) {
  return svadd_f16_m(p, a, svmul_f16_m(p, b, c));
}

with

clang foo.c -S -march=armv8-a+sve -emit-llvm -o - -Ofast

emits

; Function Attrs: mustprogress nofree nosync nounwind readnone uwtable willreturn vscale_range(0,16)
define dso_local <vscale x 8 x half> @fmla_example(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
entry:
  %0 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
  %1 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> %0, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
  ret <vscale x 8 x half> %1
}

for me after implementing this patch

bsmith added inline comments. Oct 22 2021, 7:11 AM
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
719–723

You're right, I screwed up my testcase: I wasn't expecting the intrinsics to gain the fast flag, but they do. It still won't trigger when using the global flag and not the instruction-level flag, but we perhaps don't care much about that given that the instruction-level flags do get used.

Thanks for your review @bsmith, a couple of thoughts on your queries.

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
720

My read of the LangRef suggests that contract allows an FMA contraction (But not two of them). So to my eyes this looks sufficient. Please can you clarify your concern?

722–723

I feel like it might be sufficient to check both operands for reassociation and contraction, rather than doing a full flag check.

The full check was chosen because:

  • We're taking two instructions and turning them into one.
  • Therefore we have two sets of flags.
  • We need to put some flags on the resulting op.
  • We could intersect them, but that would result in losing flags.
  • It's conceivable that the lost flags could allow for the whole op to be removed by dead code elimination.
    • (I don't have a concrete example of this in practice, but maybe via nnan/ninf for example?)
  • Therefore, the conservative thing to do is only to apply the optimization if the flags are equal (and to preserve the flags).
  • This hasn't been a data-driven analysis but it seems a conservative one to work with for the moment. If there is evidence to consider an alternative approach, we'd consider it.

The DAG level FMA contractions don't bother with this, but they are running late in the pipeline.
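
The trade-off between intersecting the flags and requiring them to be equal can be sketched with a toy flag set. This is only an illustration under stated assumptions: the Flags struct, its bit encoding, and the two policy functions are hypothetical and are not llvm::FastMathFlags or the code in this patch.

```cpp
#include <cassert>
#include <cstdint>

// Made-up bit encoding; the names echo LangRef fast-math flags only.
enum : std::uint8_t { Reassoc = 1, NoNaNs = 2, NoInfs = 4, Contract = 8 };

// Toy stand-in for a fast-math flag set (hypothetical type).
struct Flags {
  std::uint8_t Bits;
  bool allowContract() const { return (Bits & Contract) != 0; }
  bool operator==(Flags O) const { return Bits == O.Bits; }
  bool operator!=(Flags O) const { return Bits != O.Bits; }
};

// Policy A: fuse whenever both ops allow contraction and give the result
// the intersection of the flags. Flags present on only one op are lost.
bool fuseByIntersection(Flags FAdd, Flags FMul, Flags &Out) {
  if (!FAdd.allowContract() || !FMul.allowContract())
    return false;
  Out = Flags{static_cast<std::uint8_t>(FAdd.Bits & FMul.Bits)};
  return true;
}

// Policy B (the conservative one argued for above): fuse only when both
// ops allow contraction AND carry identical flags, so nothing is dropped.
bool fuseIfEqual(Flags FAdd, Flags FMul, Flags &Out) {
  if (!FAdd.allowContract() || !FMul.allowContract())
    return false;
  if (FAdd != FMul)
    return false;
  Out = FAdd;
  assert(Out.allowContract() && "fused op keeps the contract flag");
  return true;
}
```

With an fadd carrying contract and nnan and an fmul carrying only contract, policy A fuses and silently drops nnan, whereas policy B declines to fuse. When the flags are identical both policies agree.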

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
720

Oh, on another reading I get it; 'contract' is like a weaker form of 'reassociation', and I also see what you're talking about with respect to the global flags.

Added Global Fusion check before instCombineSVEVectorFMLA
Added check for Reassociation flag on FADD and FMUL for instCombineSVEVectorFMLA

MattDevereau marked 4 inline comments as done. Oct 25 2021, 3:51 AM
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
771

This logic needs tweaking; AllowFusionGlobally is independent of the fast math checks, so if either are present the optimization should take place.

Also would be good to see some tests covering these cases.

773

Looks as though clang format is needed.

Moved Global fast-math checks down a level

Remove global flag condition checking for this pass

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
719

It looks like you've dropped FAddFlags != FMulFlags, and it looks like we need a test for that case.

767

stray 's'

Readded nequal flags condition
Reformatted tests to catch the nequal flags condition

MattDevereau marked 2 inline comments as done. Oct 26 2021, 3:22 AM
peterwaller-arm accepted this revision. Oct 26 2021, 6:17 AM
peterwaller-arm added inline comments.
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
700

nit. I just realised we are missing 'p' here, which would serve as a small hint that this exists (and is different to other combines) because of the predicate.

This revision is now accepted and ready to land. Oct 26 2021, 6:17 AM
bsmith accepted this revision. Oct 26 2021, 8:04 AM
This revision was automatically updated to reflect the committed changes.

Sorry @MattDevereau for the late review, but this is the first time I've had a chance to look at the patch and I think there's an issue that needs fixing.

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
702–713

Is this correct? It looks like you're allowing the fmul to be either operand of the fadd, effectively saying fadd(p, a, b) == fadd(p, b, a). Although true for the active lanes, the inactive lanes take the value of the first data operand and so these two forms are not identical.

724

Given the above I don't think we need AllowReassoc given we shouldn't be changing the order of operations.

MattDevereau added inline comments. Nov 2 2021, 6:39 AM
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
702–713

That sounds right but it seems like something that would probably be caught by a test somewhere. I've had a look at FMLA here and it says "Inactive elements in the destination vector register remain unmodified.". FADD on the other hand places the value of the first data operand in the destination vector, so FMLA and FADD seem to differ here. I'm probably wrong but from what I can see, the addition part of the FMLA can assume fadd(p, a, b) == fadd(p, b, a) but a normal FADD instruction cannot? Either way it seems like a more extensive set of rules should be considered for this.

724

This seems linked to the other comment, which I think needs a bit more consideration.

peterwaller-arm added inline comments. Nov 2 2021, 6:54 AM
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
702–713

That sounds right but it seems like something that would probably be caught by a test somewhere.

Not sure what you mean there, can you be more specific? The <it> is ambiguous: what would be caught? And by which testing?

I've had a look at FMLA [...]

FADD on the other hand places the value of the first data operand in the destination vector, so FMLA and FADD seem to differ here. I'm probably wrong but from what I can see, the addition part of the FMLA can assume fadd(p, a, b) == fadd(p, b, a) but a normal FADD instruction cannot? Either way it seems like a more extensive set of rules should be considered for this.

I agree with what Paul's saying, consider: X = A + (B * C) and Y = (B * C) + A.

In the case of inactive lanes before the combine, X ends up with values from A and Y ends up with values from (B * C).

But the combine as is rewrites both cases to FMLA(A, B, C). FMLA(A, B, C) takes lanes from A in case of an inactive predicate.

Therefore the output does not have lanes from (B * C) for the expression Y, as it should.

So this needs updating, and we can ignore the Reassoc flag.
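
The inactive-lane argument can be checked with a small lane-wise model. This is a sketch under assumptions, not the intrinsics themselves: predFAdd, predFMul and predFMla are hypothetical helpers that model merging predication on plain vectors of float, and fused rounding is deliberately not modeled (the example values are small integers, so the results are exact either way).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<float>;
using Pred = std::vector<bool>;

// Merging predicated add: active lanes get A + B, inactive lanes keep A
// (the first data operand).
Vec predFAdd(const Pred &P, const Vec &A, const Vec &B) {
  assert(P.size() == A.size() && A.size() == B.size());
  Vec R(A);
  for (std::size_t I = 0; I < R.size(); ++I)
    if (P[I])
      R[I] = A[I] + B[I];
  return R;
}

// Merging predicated multiply: active lanes get A * B, inactive lanes keep A.
Vec predFMul(const Pred &P, const Vec &A, const Vec &B) {
  assert(P.size() == A.size() && A.size() == B.size());
  Vec R(A);
  for (std::size_t I = 0; I < R.size(); ++I)
    if (P[I])
      R[I] = A[I] * B[I];
  return R;
}

// Predicated multiply-add: active lanes get Acc + B * C, inactive lanes
// keep Acc. Rounding of a real fused operation is not modeled here.
Vec predFMla(const Pred &P, const Vec &Acc, const Vec &B, const Vec &C) {
  assert(P.size() == Acc.size() && Acc.size() == B.size() &&
         B.size() == C.size());
  Vec R(Acc);
  for (std::size_t I = 0; I < R.size(); ++I)
    if (P[I])
      R[I] = Acc[I] + B[I] * C[I];
  return R;
}
```

With P = {true, false}, A = {1, 2}, B = {3, 4} and C = {5, 6}, the product is {15, 4}. X = predFAdd(P, A, product) gives {16, 2}, which equals predFMla(P, A, B, C). Y = predFAdd(P, product, A) gives {16, 4}, so rewriting Y to FMLA(A, B, C) would change the inactive lane.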