This is an archive of the discontinued LLVM Phabricator instance.

[AArch64][SVE]Add cost model for masked gather and scatter for scalable vector.
ClosedPublic

Authored by CarolineConcatto on Dec 10 2020, 5:45 AM.

Details

Summary

A new TTI interface, 'Optional<unsigned> getMaxVScale', has been added; it
returns the maximum vscale for a given target.
When the maximum is known, getMaxVScale is used to compute the cost of masked
gather and scatter for scalable vectors.

Depends on D92094

Diff Detail

Event Timeline

CarolineConcatto requested review of this revision. Dec 10 2020, 5:45 AM
Herald added a project: Restricted Project. Dec 10 2020, 5:45 AM
ctetreau added inline comments. Dec 10 2020, 6:09 AM
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
372

Is a concrete 1 here really the best default? I would think that for scalable architectures, the maximum vscale is almost certainly larger than 1.

Regardless, I would think that it should be none by default. It's an unknowable runtime constant after all.

sdesmalen added inline comments. Dec 10 2020, 6:10 AM
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
372

Agreed, None should be the default.

sdesmalen added inline comments. Dec 10 2020, 6:13 AM
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
372

The idea is that not all targets have an architectural maximum, so the query 'getMaxVScale' may or may not have a valid answer: for SVE, the architectural maximum is 2048 bits, but for other targets there may be no such limit.
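
The contract discussed in this comment can be illustrated with a minimal standalone sketch. It does not use LLVM's actual classes: `BaseTTISketch`, `SVELikeTTISketch`, and `requireMaxVScale` are hypothetical names, `std::optional` stands in for LLVM's `Optional`, and the 128-bit granule (the smallest SVE vector length) is an assumption that the thread does not state explicitly.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for the default TTI implementation: no architectural
// maximum is known, so the query has no answer.
struct BaseTTISketch {
  virtual ~BaseTTISketch() = default;
  virtual std::optional<unsigned> getMaxVScale() const { return std::nullopt; }
};

// Hypothetical SVE-like target with an architectural maximum vector length.
struct SVELikeTTISketch : BaseTTISketch {
  static constexpr unsigned MaxVectorBits = 2048; // maximum quoted in the review
  static constexpr unsigned GranuleBits = 128;    // assumed size of vscale == 1
  std::optional<unsigned> getMaxVScale() const override {
    return MaxVectorBits / GranuleBits;
  }
};

// A caller that needs a concrete value can assert that one exists before use.
inline unsigned requireMaxVScale(const BaseTTISketch &TTI) {
  std::optional<unsigned> MaxVScale = TTI.getMaxVScale();
  assert(MaxVScale && "Expected valid max vscale");
  return *MaxVScale;
}
```

Under these assumptions the SVE-like target reports 2048 / 128 = 16, while the base implementation reports no value at all rather than a concrete default such as 1.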

sdesmalen added inline comments. Dec 10 2020, 8:39 AM
llvm/include/llvm/Analysis/TargetTransformInfo.h
931

nit: The maximum value of vscale if the target specifies an architectural maximum vector length, and None otherwise.

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
780

If you rewrite this in the opposite way, you can avoid indentation:

if (!VF.isScalable())
  return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
                                       Alignment, CostKind, I);
782

You can remove this if condition, if you just move the assert above this line.

783

nit: It's probably still worth asking how many gather operations this requires, e.g. from TLI->getTypeLegalizationCost(DL, DataTy), and then multiplying:

auto LT = TLI->getTypeLegalizationCost(DL, DataTy);
auto Cost = /*NumGathers*/ LT.first *
            /*NumElementsPerGather*/ (MaxNumVScale * LT.second.getVectorElementCount().getKnownMinValue()) *
            MemOpCost;

Because it is multiplied, the result is currently the same, but perhaps there would be an added cost per gather instruction that is worth modelling at some point.
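
The multiplication suggested above can be written as a small standalone sketch. This is not the code from the patch: `LegalizationSketch` and `gatherScatterCostSketch` are hypothetical names, the struct only imitates the pair returned by type legalization, and the per-element memory cost is passed in by the caller rather than queried from the target.

```cpp
#include <cassert>
#include <optional>

// Hypothetical summary of a type legalization result.
struct LegalizationSketch {
  unsigned NumParts;           // plays the role of LT.first
  unsigned MinElementsPerPart; // known-minimum element count of LT.second
};

// Worst-case cost: number of gathers, times the maximum number of elements
// each gather can touch, times the cost of one scalar memory operation.
inline unsigned gatherScatterCostSketch(const LegalizationSketch &LT,
                                        std::optional<unsigned> MaxVScale,
                                        unsigned MemOpCost) {
  assert(MaxVScale && "Expected valid max vscale");
  unsigned MaxNumElementsPerGather = *MaxVScale * LT.MinElementsPerPart;
  return LT.NumParts * MaxNumElementsPerGather * MemOpCost;
}
```

With a maximum vscale of 16, a type that legalizes to one part of at least 4 elements, and an illustrative memory cost of 1, this gives 1 * (16 * 4) * 1 = 64. Keeping the number of gathers as a separate factor leaves room for a per-instruction overhead term later.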

llvm/test/Analysis/CostModel/AArch64/sve-getIntrinsicInstrCost-gather.ll
13

Can you also check for the instruction (e.g. llvm.masked.gather and ret)?

  • update the cost model for gather and scatter using legalized type
CarolineConcatto marked 5 inline comments as done. Dec 14 2020, 2:55 AM

Thank you @ctetreau and @sdesmalen for your reviews.
I believe I've updated the patch according to your suggestions.

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
783

Thank you Sander!
I've updated the way we compute the cost as you asked.
I agree with you that it is probably better to use the legalized type.
I can see that something similar is done in getMemOpCost.

tschuett added a comment. Edited Dec 14 2020, 10:07 AM

Assuming you have a machine model for the FUJITSU A64FX, shouldn't you be able to get a lower value for getMaxVScale()?

512/128 -> 4

Even better, I invoked clang with -msve-vector-bits=128.
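
The arithmetic in this comment can be spelled out as a short sketch. It is not how LLVM derives the value: `vscaleForVectorBits` and `GranuleBitsSketch` are hypothetical names, and the 128-bit granule is assumed from the 512/128 division quoted above.

```cpp
#include <cassert>

// Assumed number of vector bits that corresponds to vscale == 1.
constexpr unsigned GranuleBitsSketch = 128;

// Maps a known vector length in bits to the vscale it implies.
inline unsigned vscaleForVectorBits(unsigned VectorBits) {
  assert(VectorBits != 0 && VectorBits % GranuleBitsSketch == 0 &&
         "Vector length must be a non-zero multiple of the granule");
  return VectorBits / GranuleBitsSketch;
}
```

A 512-bit implementation yields 4, a 128-bit vector length yields 1, and the 2048-bit architectural maximum yields 16, which is why a known vector length would tighten the worst-case cost.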

david-arm added inline comments. Dec 16 2020, 4:22 AM
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
784

nit: I wonder if the assert here should say something like "Expected valid max vscale"?

llvm/test/Analysis/CostModel/AArch64/sve-getIntrinsicInstrCost-gather.ll
11

Is it worth having at least one illegal gather test here too? For example, gathers or scatters of <vscale x 8 x i32> - this tests the case when LT.first > 1 in your cost calculation code above.
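
To see why <vscale x 8 x i32> exercises the LT.first > 1 path, here is a deliberately simplified sketch of splitting a scalable vector into register-sized parts. It is not LLVM's legalization algorithm: `SplitSketch` and `splitScalableVectorSketch` are hypothetical names, the 128-bit minimum register size is an assumption, and the model only covers types whose minimum size is a whole multiple of that register size.

```cpp
#include <cassert>

// Hypothetical result: how many legal parts, and elements per part at vscale 1.
struct SplitSketch {
  unsigned NumParts;
  unsigned MinElementsPerPart;
};

// Splits <vscale x MinNumElements x iElementBits> into parts whose minimum
// size is one assumed 128-bit scalable register.
inline SplitSketch splitScalableVectorSketch(unsigned MinNumElements,
                                             unsigned ElementBits) {
  constexpr unsigned RegisterMinBits = 128;
  assert(ElementBits != 0 && "Element size must be non-zero");
  unsigned MinBits = MinNumElements * ElementBits;
  assert(MinBits >= RegisterMinBits && MinBits % RegisterMinBits == 0 &&
         "Sketch only handles whole multiples of the register size");
  return {MinBits / RegisterMinBits, RegisterMinBits / ElementBits};
}
```

For <vscale x 8 x i32> the minimum size is 256 bits, so the sketch reports two parts of four elements each, whereas <vscale x 4 x i32> fits in a single part.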

  • add test for illegal gather
llvm/include/llvm/CodeGen/BasicTTIImpl.h
1344–1345

This needs fixing; my git history got messed up by the rebase. I am trying to fix this!

  • fix history with git

@tschuett I agree that we should refine the value of MaxVScale according to the architecture by using flags.
But I believe this should be the scope of another patch.
Implementing the use of a flag to compute the value of MaxVScale would increase the complexity of the patch, and IMHO it is not related to the cost of gathers and scatters.
The value of MaxVScale is/will be used to compute the cost of other intrinsics/instructions with vscale data types.

david-arm accepted this revision. Dec 18 2020, 3:36 AM

LGTM! Hi @tschuett, I agree: I think it's better to leave any possible refinement of getMaxVScale to a later patch, as it may require new flags.

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
784

nit: Sorry I didn't spot this before. Does the TODO still make sense here?

This revision is now accepted and ready to land. Dec 18 2020, 3:36 AM
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
784

Yes, because I cannot change this to return InstructionCost without changing a long chain of functions to return InstructionCost as well.
I believe we will tackle this problem later, in another patch responsible for changing these functions to return InstructionCost.

  • update TODO message on AArch64TTIImpl::getGatherScatterOpCost

LGTM with nits addressed and TODO removed.

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
778

nit: can you give this a more intuitive name, like LegalVF?

784

For AArch64 there is always an architectural maximum, so getMaxVScale cannot return None. There isn't anything special we need to do such as returning an Invalid cost when we've already asserted that MaxVscale is not None (i.e. since the compiler will already have crashed if that wasn't the case, it doesn't really matter what happens after that point). So I agree that you can indeed remove the TODO as @david-arm suggested.

793

nit: s/NumElementsPerGather/MaxNumElementsPerGather/
(also, why the parentheses around NumElementsPerGather?)

  • remove TODO from MaxVScale assert
  • replace LVF by LegalVF
  • rewrite return equation
CarolineConcatto marked 8 inline comments as done. Jan 4 2021, 5:54 AM

Anyone else seeing crashes with this patch? I'm seeing something like this:

UNREACHABLE executed at /usr/local/google/home/blaikie/dev/llvm/src/llvm/include/llvm/Support/MachineValueType.h:644!
...
 #0 0x00000000096fd56a llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/Support/Unix/Signals.inc:563:11
 #1 0x00000000096fd73b PrintStackTraceSignalHandler(void*) /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/Support/Unix/Signals.inc:630:1
 #2 0x00000000096fbd5b llvm::sys::RunSignalHandlers() /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/Support/Signals.cpp:70:5
 #3 0x00000000096fde6d SignalHandler(int) /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/Support/Unix/Signals.inc:405:1
 #4 0x00007f220be59140 __restore_rt (/lib/x86_64-linux-gnu/libpthread.so.0+0x14140)
 #5 0x00007f220b929c81 raise ./signal/../sysdeps/unix/sysv/linux/raise.c:51:1
 #6 0x00007f220b913537 abort ./stdlib/abort.c:81:7
 #7 0x00000000096234c4 /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/Support/ErrorHandling.cpp:213:3
 #8 0x00000000063112f3 llvm::MVT::getVectorNumElements() const /usr/local/google/home/blaikie/dev/llvm/src/llvm/include/llvm/Support/MachineValueType.h:646:22
 #9 0x000000000631121c llvm::MVT::getVectorElementCount() const /usr/local/google/home/blaikie/dev/llvm/src/llvm/include/llvm/Support/MachineValueType.h:791:32
#10 0x000000000658f5b9 llvm::AArch64TTIImpl::getGatherScatterOpCost(unsigned int, llvm::Type*, llvm::Value const*, bool, llvm::Align, llvm::TargetTransformInfo::TargetCostKind, llvm::Instruction const*) /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp:778:36
#11 0x00000000064b17ac llvm::TargetTransformInfo::Model<llvm::AArch64TTIImpl>::getGatherScatterOpCost(unsigned int, llvm::Type*, llvm::Value const*, bool, llvm::Align, llvm::TargetTransformInfo::TargetCostKind, llvm::Instruction const*) /usr/local/google/home/blaikie/dev/llvm/src/llvm/include/llvm/Analysis/TargetTransformInfo.h:2043:5
#12 0x0000000008216cae llvm::TargetTransformInfo::getGatherScatterOpCost(unsigned int, llvm::Type*, llvm::Value const*, bool, llvm::Align, llvm::TargetTransformInfo::TargetCostKind, llvm::Instruction const*) const /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/Analysis/TargetTransformInfo.cpp:875:23
#13 0x000000000994cf92 llvm::slpvectorizer::BoUpSLP::getEntryCost(llvm::slpvectorizer::BoUpSLP::TreeEntry*) /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp:3741:21
#14 0x000000000994f6cc llvm::slpvectorizer::BoUpSLP::getTreeCost() /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp:4055:25
#15 0x000000000995be1c llvm::SLPVectorizerPass::tryToVectorizeList(llvm::ArrayRef<llvm::Value*>, llvm::slpvectorizer::BoUpSLP&, bool, llvm::ArrayRef<llvm::Value*>) /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp:6264:32
#16 0x000000000995b684 llvm::SLPVectorizerPass::tryToVectorizePair(llvm::Value*, llvm::Value*, llvm::slpvectorizer::BoUpSLP&) /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp:6155:3
#17 0x000000000995c883 llvm::SLPVectorizerPass::tryToVectorize(llvm::Instruction*, llvm::slpvectorizer::BoUpSLP&) /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp:6367:9
#18 0x000000000996b8f8 llvm::SLPVectorizerPass::vectorizeRootInstruction(llvm::PHINode*, llvm::Value*, llvm::BasicBlock*, llvm::slpvectorizer::BoUpSLP&, llvm::TargetTransformInfo*)::$_20::operator()(llvm::Instruction*, llvm::slpvectorizer::BoUpSLP&) const /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp:7484:5
#19 0x000000000996b8b8 bool llvm::function_ref<bool (llvm::Instruction*, llvm::slpvectorizer::BoUpSLP&)>::callback_fn<llvm::SLPVectorizerPass::vectorizeRootInstruction(llvm::PHINode*, llvm::Value*, llvm::BasicBlock*, llvm::slpvectorizer::BoUpSLP&, llvm::TargetTransformInfo*)::$_20>(long, llvm::Instruction*, llvm::slpvectorizer::BoUpSLP&) /usr/local/google/home/blaikie/dev/llvm/src/llvm/include/llvm/ADT/STLExtras.h:185:5
#20 0x0000000009989ac4 llvm::function_ref<bool (llvm::Instruction*, llvm::slpvectorizer::BoUpSLP&)>::operator()(llvm::Instruction*, llvm::slpvectorizer::BoUpSLP&) const /usr/local/google/home/blaikie/dev/llvm/src/llvm/include/llvm/ADT/STLExtras.h:209:5
#21 0x000000000995ceb2 tryToVectorizeHorReductionOrInstOperands(llvm::PHINode*, llvm::Instruction*, llvm::BasicBlock*, llvm::slpvectorizer::BoUpSLP&, llvm::TargetTransformInfo*, llvm::function_ref<bool (llvm::Instruction*, llvm::slpvectorizer::BoUpSLP&)>) /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp:7455:9
#22 0x000000000995cb21 llvm::SLPVectorizerPass::vectorizeRootInstruction(llvm::PHINode*, llvm::Value*, llvm::BasicBlock*, llvm::slpvectorizer::BoUpSLP&, llvm::TargetTransformInfo*) /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp:7486:3
#23 0x0000000009959c0f llvm::SLPVectorizerPass::vectorizeChainsInBlock(llvm::BasicBlock*, llvm::slpvectorizer::BoUpSLP&) /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp:7665:20
#24 0x0000000009958f0c llvm::SLPVectorizerPass::runImpl(llvm::Function&, llvm::ScalarEvolution*, llvm::TargetTransformInfo*, llvm::TargetLibraryInfo*, llvm::AAResults*, llvm::LoopInfo*, llvm::DominatorTree*, llvm::AssumptionCache*, llvm::DemandedBits*, llvm::OptimizationRemarkEmitter*) /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp:5954:16
#25 0x0000000009965558 (anonymous namespace)::SLPVectorizer::runOnFunction(llvm::Function&) /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp:5856:5
#26 0x0000000008ab24a2 llvm::FPPassManager::runOnFunction(llvm::Function&) /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/IR/LegacyPassManager.cpp:1440:23
#27 0x0000000008ab7645 llvm::FPPassManager::runOnModule(llvm::Module&) /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/IR/LegacyPassManager.cpp:1486:16
#28 0x0000000008ab2e46 (anonymous namespace)::MPPassManager::runOnModule(llvm::Module&) /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/IR/LegacyPassManager.cpp:1555:23
#29 0x0000000008ab2976 llvm::legacy::PassManagerImpl::run(llvm::Module&) /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/IR/LegacyPassManager.cpp:542:16
#30 0x0000000008ab7941 llvm::legacy::PassManager::run(llvm::Module&) /usr/local/google/home/blaikie/dev/llvm/src/llvm/lib/IR/LegacyPassManager.cpp:1682:3
#31 0x0000000009b109d1 (anonymous namespace)::EmitAssemblyHelper::EmitAssembly(clang::BackendAction, std::unique_ptr<llvm::raw_pwrite_stream, std::default_delete<llvm::raw_pwrite_stream> >) /usr/local/google/home/blaikie/dev/llvm/src/clang/lib/CodeGen/BackendUtil.cpp:1013:3
#32 0x0000000009b0d273 clang::EmitBackendOutput(clang::DiagnosticsEngine&, clang::HeaderSearchOptions const&, clang::CodeGenOptions const&, clang::TargetOptions const&, clang::LangOptions const&, llvm::DataLayout const&, llvm::Module*, clang::BackendAction, std::unique_ptr<llvm::raw_pwrite_stream, std::default_delete<llvm::raw_pwrite_stream> >) /usr/local/google/home/blaikie/dev/llvm/src/clang/lib/CodeGen/BackendUtil.cpp:1576:5
#33 0x000000000a783d31 clang::BackendConsumer::HandleTranslationUnit(clang::ASTContext&) /usr/local/google/home/blaikie/dev/llvm/src/clang/lib/CodeGen/CodeGenAction.cpp:344:7
#34 0x000000000d1bcc58 clang::ParseAST(clang::Sema&, bool, bool) /usr/local/google/home/blaikie/dev/llvm/src/clang/lib/Parse/ParseAST.cpp:178:12
#35 0x000000000a5acc0d clang::ASTFrontendAction::ExecuteAction() /usr/local/google/home/blaikie/dev/llvm/src/clang/lib/Frontend/FrontendAction.cpp:1058:1
#36 0x000000000a77f718 clang::CodeGenAction::ExecuteAction() /usr/local/google/home/blaikie/dev/llvm/src/clang/lib/CodeGen/CodeGenAction.cpp:1083:5
#37 0x000000000a5ac5d8 clang::FrontendAction::Execute() /usr/local/google/home/blaikie/dev/llvm/src/clang/lib/Frontend/FrontendAction.cpp:953:7
#38 0x000000000a4c22d8 clang::CompilerInstance::ExecuteAction(clang::FrontendAction&) /usr/local/google/home/blaikie/dev/llvm/src/clang/lib/Frontend/CompilerInstance.cpp:957:23
#39 0x000000000a76c714 clang::ExecuteCompilerInvocation(clang::CompilerInstance*) /usr/local/google/home/blaikie/dev/llvm/src/clang/lib/FrontendTool/ExecuteCompilerInvocation.cpp:278:8
#40 0x00000000062266fc cc1_main(llvm::ArrayRef<char const*>, char const*, void*) /usr/local/google/home/blaikie/dev/llvm/src/clang/tools/driver/cc1_main.cpp:240:13
#41 0x0000000006218faa ExecuteCC1Tool(llvm::SmallVectorImpl<char const*>&) /usr/local/google/home/blaikie/dev/llvm/src/clang/tools/driver/driver.cpp:330:5
#42 0x000000000621814d main /usr/local/google/home/blaikie/dev/llvm/src/clang/tools/driver/driver.cpp:407:5

Would be great to get this reverted while it's figured out - I can try to reduce a test case if no one else has something more portable/shareable - but probably worth reverting before that so other people aren't blocked by this as well?

Thank you @dblaikie for pointing out the problem.
I was able to generate a reproducer and I've added it as a test.
I'll push the fix to main ASAP.

I pushed the fix to main in commit:
01c190e907ca

Patch submitted to main.

Thanks!