diff --git a/llvm/lib/Target/AArch64/AArch64.h b/llvm/lib/Target/AArch64/AArch64.h
--- a/llvm/lib/Target/AArch64/AArch64.h
+++ b/llvm/lib/Target/AArch64/AArch64.h
@@ -55,6 +55,7 @@
 FunctionPass *createAArch64CollectLOHPass();
 ModulePass *createSVEIntrinsicOptsPass();
+ModulePass *createSVECoalescePTruesPass();
 InstructionSelector *
 createAArch64InstructionSelector(const AArch64TargetMachine &,
                                  AArch64Subtarget &, AArch64RegisterBankInfo &);
@@ -91,6 +92,7 @@
 void initializeFalkorMarkStridedAccessesLegacyPass(PassRegistry&);
 void initializeLDTLSCleanupPass(PassRegistry&);
 void initializeSVEIntrinsicOptsPass(PassRegistry&);
+void initializeSVECoalescePTruesPass(PassRegistry &);
 void initializeAArch64StackTaggingPass(PassRegistry&);
 void initializeAArch64StackTaggingPreRAPass(PassRegistry&);
 } // end namespace llvm
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -153,6 +153,11 @@
                                       cl::desc("Enable SVE intrinsic opts"),
                                       cl::init(true));
 
+static cl::opt<bool>
+    EnableSVECoalescePTrues("aarch64-enable-sve-coalesce-ptrues", cl::Hidden,
+                            cl::desc("Enable the SVE coalesce ptrues pass"),
+                            cl::init(true));
+
 static cl::opt<bool> EnableFalkorHWPFFix("aarch64-enable-falkor-hwpf-fix",
                                          cl::init(true), cl::Hidden);
 
@@ -192,6 +197,7 @@
   initializeFalkorHWPFFixPass(*PR);
   initializeFalkorMarkStridedAccessesLegacyPass(*PR);
   initializeLDTLSCleanupPass(*PR);
+  initializeSVECoalescePTruesPass(*PR);
   initializeSVEIntrinsicOptsPass(*PR);
   initializeAArch64SpeculationHardeningPass(*PR);
   initializeAArch64SLSHardeningPass(*PR);
@@ -452,6 +458,10 @@
   // ourselves.
   addPass(createAtomicExpandPass());
 
+  // Coalesce ptrues.
+  if (EnableSVECoalescePTrues && TM->getOptLevel() == CodeGenOpt::Aggressive)
+    addPass(createSVECoalescePTruesPass());
+
   // Expand any SVE vector library calls that we can't code generate directly.
   if (EnableSVEIntrinsicOpts && TM->getOptLevel() == CodeGenOpt::Aggressive)
     addPass(createSVEIntrinsicOptsPass());
diff --git a/llvm/lib/Target/AArch64/CMakeLists.txt b/llvm/lib/Target/AArch64/CMakeLists.txt
--- a/llvm/lib/Target/AArch64/CMakeLists.txt
+++ b/llvm/lib/Target/AArch64/CMakeLists.txt
@@ -75,6 +75,7 @@
   AArch64TargetMachine.cpp
   AArch64TargetObjectFile.cpp
   AArch64TargetTransformInfo.cpp
+  SVECoalescePTrues.cpp
   SVEIntrinsicOpts.cpp
   AArch64SIMDInstrOpt.cpp
 
diff --git a/llvm/lib/Target/AArch64/SVECoalescePTrues.cpp b/llvm/lib/Target/AArch64/SVECoalescePTrues.cpp
new file
--- /dev/null
+++ b/llvm/lib/Target/AArch64/SVECoalescePTrues.cpp
@@ -0,0 +1,230 @@
+//===----- SVECoalescePTrues - Eliminate Redundant SVE PTrue Calls -------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// The goal of this pass is to remove redundant calls to the SVE ptrue
+// intrinsic.
+//
+// Suppose that we have two SVE ptrue intrinsic calls P1 and P2. If P1 is at
+// least as wide as P2, then P2 can be written as a reinterpret of P1 using the
+// SVE reinterpret intrinsics.
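+//
+// For example (an illustrative sketch only; the value names are arbitrary),
+// a pair of all-true ptrues such as
+//
+//   %p1 = call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
+//   %p2 = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+//
+// can be rewritten so that only the wider ptrue remains and the narrower
+// predicate is derived from it:
+//
+//   %p1 = call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
+//   %sv = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %p1)
+//   %p2 = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %sv)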
+//
+//===----------------------------------------------------------------------===//
+
+#include "Utils/AArch64BaseInfo.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Support/Debug.h"
+#include <algorithm>
+
+using namespace llvm;
+using namespace llvm::PatternMatch;
+
+#define DEBUG_TYPE "aarch64-sve-coalesce-ptrues"
+
+namespace llvm {
+void initializeSVECoalescePTruesPass(PassRegistry &);
+}
+
+namespace {
+struct SVECoalescePTrues : public ModulePass {
+  static char ID; // Pass identification, replacement for typeid
+  SVECoalescePTrues() : ModulePass(ID) {
+    initializeSVECoalescePTruesPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnModule(Module &M) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+private:
+  bool coalescePTrues(BasicBlock &BB,
+                      SmallSetVector<IntrinsicInst *, 4> &PTrues);
+  bool optimizeFunctions(SmallSetVector<Function *, 4> &Functions);
+};
+} // end anonymous namespace
+
+void SVECoalescePTrues::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.addRequired<DominatorTreeWrapperPass>();
+  AU.setPreservesCFG();
+}
+
+char SVECoalescePTrues::ID = 0;
+static const char *name = "SVE coalesce ptrues";
+INITIALIZE_PASS_BEGIN(SVECoalescePTrues, DEBUG_TYPE, name, false, false)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
+INITIALIZE_PASS_END(SVECoalescePTrues, DEBUG_TYPE, name, false, false)
+
+namespace llvm {
+ModulePass *createSVECoalescePTruesPass() { return new SVECoalescePTrues(); }
+} // namespace llvm
+
+/// Checks if a ptrue intrinsic call is promoted. A 'promoted' ptrue is
+/// defined as a ptrue intrinsic call which is converted to a wider
+/// type via a sequence of SVE reinterpret intrinsics. The act of widening a
+/// ptrue will introduce zeroing. For example:
+///
+///   %1 = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+///   %2 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %1)
+///   %3 = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
+///
+/// %1 is promoted, because it is converted:
+///
+///   <vscale x 4 x i1> => <vscale x 16 x i1> => <vscale x 8 x i1>
+///
+/// via a sequence of SVE reinterpret intrinsics.
+bool isPTruePromoted(IntrinsicInst *PTrue) {
+  // Find all users of this intrinsic that are calls to convert-to-svbool
+  // reinterpret intrinsics.
+  SmallVector<IntrinsicInst *, 4> ConvertToUses;
+  for (User *User : PTrue->users()) {
+    auto *IntrUser = dyn_cast<IntrinsicInst>(User);
+    if (IntrUser && IntrUser->getIntrinsicID() ==
+                        Intrinsic::aarch64_sve_convert_to_svbool) {
+      ConvertToUses.push_back(IntrUser);
+    }
+  }
+
+  // If no such calls were found, this ptrue is not promoted.
+  if (ConvertToUses.empty())
+    return false;
+
+  // Otherwise, try to find users of the convert-to-svbool intrinsics that are
+  // calls to the convert-from-svbool intrinsic, and would result in some lanes
+  // being zeroed.
+  const auto *PTrueVTy = cast<ScalableVectorType>(PTrue->getType());
+  for (IntrinsicInst *ConvertToUse : ConvertToUses) {
+    for (User *User : ConvertToUse->users()) {
+      auto *IntrUser = dyn_cast<IntrinsicInst>(User);
+      if (IntrUser && IntrUser->getIntrinsicID() ==
+                          Intrinsic::aarch64_sve_convert_from_svbool) {
+        const auto *IntrUserVTy =
+            cast<ScalableVectorType>(IntrUser->getType());
+
+        // Would some lanes become zeroed by the conversion?
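+        // (Reinterpreting a predicate as a type with more lanes than the
+        // original ptrue leaves the extra lanes zeroed, so the result is not
+        // an all-true predicate of the wider type.)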
+        if (IntrUserVTy->getElementCount().getKnownMinValue() >
+            PTrueVTy->getElementCount().getKnownMinValue())
+          // This is a promoted ptrue.
+          return true;
+      }
+    }
+  }
+
+  // If no matching calls were found, this is not a promoted ptrue.
+  return false;
+}
+
+/// Attempts to coalesce ptrues in a basic block.
+bool SVECoalescePTrues::coalescePTrues(
+    BasicBlock &BB, SmallSetVector<IntrinsicInst *, 4> &PTrues) {
+  if (PTrues.size() <= 1)
+    return false;
+
+  // Find the ptrue with the most lanes.
+  auto *WidestPTrue = *std::max_element(
+      PTrues.begin(), PTrues.end(), [](auto *PTrue1, auto *PTrue2) {
+        auto *PTrue1VTy = cast<ScalableVectorType>(PTrue1->getType());
+        auto *PTrue2VTy = cast<ScalableVectorType>(PTrue2->getType());
+        return PTrue1VTy->getElementCount().getKnownMinValue() <
+               PTrue2VTy->getElementCount().getKnownMinValue();
+      });
+
+  // Remove the widest ptrue, as well as any promoted ptrues, leaving behind
+  // only the ptrues to be coalesced.
+  PTrues.remove(WidestPTrue);
+  PTrues.remove_if([&](auto *PTrue) { return isPTruePromoted(PTrue); });
+
+  // Hoist WidestPTrue to the start of the basic block. It is always safe to
+  // do this: a ptrue intrinsic call takes only an immediate pattern argument,
+  // so it does not depend on any other instruction in the block.
+  WidestPTrue->moveBefore(BB, BB.getFirstInsertionPt());
+
+  LLVMContext &Ctx = BB.getContext();
+  IRBuilder<> Builder(Ctx);
+  Builder.SetInsertPoint(&BB, ++WidestPTrue->getIterator());
+
+  auto *WidestPTrueVTy = cast<ScalableVectorType>(WidestPTrue->getType());
+  auto *ConvertToSVBool =
+      Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_to_svbool,
+                              {WidestPTrueVTy}, {WidestPTrue});
+
+  for (auto *PTrue : PTrues) {
+    auto *PTrueVTy = cast<ScalableVectorType>(PTrue->getType());
+
+    Builder.SetInsertPoint(&BB, ++ConvertToSVBool->getIterator());
+    auto *ConvertFromSVBool =
+        Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
+                                {PTrueVTy}, {ConvertToSVBool});
+    PTrue->replaceAllUsesWith(ConvertFromSVBool);
+    PTrue->eraseFromParent();
+  }
+
+  return true;
+}
+
+bool SVECoalescePTrues::optimizeFunctions(
+    SmallSetVector<Function *, 4> &Functions) {
+  bool Changed = false;
+  for (auto *F : Functions) {
+    DominatorTree *DT =
+        &getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();
+
+    // Traverse the DT with an RPO walk so we see defs before uses, allowing
+    // simplification to be done incrementally.
+    BasicBlock *Root = DT->getRoot();
+    ReversePostOrderTraversal<BasicBlock *> RPOT(Root);
+    for (auto &BB : RPOT) {
+      SmallSetVector<IntrinsicInst *, 4> PTrues;
+
+      // For each basic block, collect the used ptrues and try to coalesce
+      // them.
+      for (Instruction &I : *BB) {
+        if (I.use_empty())
+          continue;
+
+        constexpr unsigned SV_ALL = 31;
+        auto *IntrI = dyn_cast<IntrinsicInst>(&I);
+        if (IntrI && IntrI->getIntrinsicID() == Intrinsic::aarch64_sve_ptrue &&
+            cast<ConstantInt>(IntrI->getOperand(0))->getZExtValue() == SV_ALL) {
+          PTrues.insert(IntrI);
+        }
+      }
+
+      Changed |= coalescePTrues(*BB, PTrues);
+    }
+  }
+
+  return Changed;
+}
+
+bool SVECoalescePTrues::runOnModule(Module &M) {
+  bool Changed = false;
+  SmallSetVector<Function *, 4> Functions;
+
+  // Check for SVE intrinsic declarations first, and store the functions where
+  // they are used so that we only iterate over relevant functions.
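+  // (Users of an intrinsic declaration are the call sites themselves, so each
+  // user visited below is an instruction whose parent function contains at
+  // least one ptrue call.)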
+  for (auto &F : M.getFunctionList()) {
+    if (!F.isDeclaration() ||
+        F.getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
+      continue;
+
+    for (auto I = F.user_begin(), E = F.user_end(); I != E;) {
+      auto *Inst = dyn_cast<Instruction>(*I++);
+      Functions.insert(Inst->getFunction());
+    }
+  }
+
+  if (!Functions.empty())
+    Changed |= optimizeFunctions(Functions);
+
+  return Changed;
+}
diff --git a/llvm/test/CodeGen/AArch64/O3-pipeline.ll b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
--- a/llvm/test/CodeGen/AArch64/O3-pipeline.ll
+++ b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
@@ -18,6 +18,9 @@
 ; CHECK-NEXT: Pre-ISel Intrinsic Lowering
 ; CHECK-NEXT: FunctionPass Manager
 ; CHECK-NEXT:   Expand Atomic instructions
+; CHECK-NEXT: SVE coalesce ptrues
+; CHECK-NEXT: FunctionPass Manager
+; CHECK-NEXT:   Dominator Tree Construction
 ; CHECK-NEXT: SVE intrinsics optimizations
 ; CHECK-NEXT: FunctionPass Manager
 ; CHECK-NEXT:   Dominator Tree Construction
@@ -209,6 +212,9 @@
 ; CHECK-NEXT: Pass Arguments: -domtree
 ; CHECK-NEXT: FunctionPass Manager
 ; CHECK-NEXT:   Dominator Tree Construction
+; CHECK-NEXT: Pass Arguments: -domtree
+; CHECK-NEXT: FunctionPass Manager
+; CHECK-NEXT:   Dominator Tree Construction
 ; CHECK-NEXT: Pass Arguments: -assumption-cache-tracker -targetlibinfo -domtree -loops -scalar-evolution -stack-safety-local
 ; CHECK-NEXT: Assumption Cache Tracker
 ; CHECK-NEXT: Target Library Information
diff --git a/llvm/test/CodeGen/AArch64/sve-coalesce-ptrues.ll b/llvm/test/CodeGen/AArch64/sve-coalesce-ptrues.ll
new file
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-coalesce-ptrues.ll
@@ -0,0 +1,100 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -aarch64-sve-coalesce-ptrues -mtriple=aarch64-linux-gnu -mattr=+sve < %s 2>%t | FileCheck %s
+; RUN: FileCheck --check-prefix=WARN --allow-empty %s <%t
+
+; If this check fails please read test/CodeGen/AArch64/README for instructions on how to resolve it.
+; WARN-NOT: warning
+
+declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 immarg)
+declare <vscale x 2 x i1> @llvm.aarch64.sve.ptrue.nxv2i1(i32 immarg)
+declare <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 immarg)
+declare <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 immarg)
+
+declare <vscale x 16 x i32> @llvm.aarch64.sve.ld1.nxv16i32(<vscale x 16 x i1>, i32*)
+declare <vscale x 2 x i32> @llvm.aarch64.sve.ld1.nxv2i32(<vscale x 2 x i1>, i32*)
+declare <vscale x 4 x i32> @llvm.aarch64.sve.ld1.nxv4i32(<vscale x 4 x i1>, i32*)
+declare <vscale x 8 x i16> @llvm.aarch64.sve.ld1.nxv8i16(<vscale x 8 x i1>, i16*)
+declare <vscale x 8 x i32> @llvm.aarch64.sve.ld1.nxv8i32(<vscale x 8 x i1>, i32*)
+
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1>)
+declare <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1>)
+declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
+
+define <vscale x 8 x i32> @coalesce_test_basic(i32* %addr) {
+; CHECK-LABEL: @coalesce_test_basic(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> [[TMP1]])
+; CHECK-NEXT:    [[TMP3:%.*]] = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.ld1.nxv4i32(<vscale x 4 x i1> [[TMP3]], i32* [[ADDR:%.*]])
+; CHECK-NEXT:    [[TMP5:%.*]] = call <vscale x 8 x i32> @llvm.aarch64.sve.ld1.nxv8i32(<vscale x 8 x i1> [[TMP1]], i32* [[ADDR]])
+; CHECK-NEXT:    ret <vscale x 8 x i32> [[TMP5]]
+;
+  %1 = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+  %2 = call <vscale x 4 x i32> @llvm.aarch64.sve.ld1.nxv4i32(<vscale x 4 x i1> %1, i32* %addr)
+  %3 = call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
+  %4 = call <vscale x 8 x i32> @llvm.aarch64.sve.ld1.nxv8i32(<vscale x 8 x i1> %3, i32* %addr)
+  ret <vscale x 8 x i32> %4
+}
+
+define <vscale x 16 x i32> @coalesce_test_multiple(i32* %addr) {
+; CHECK-LABEL: @coalesce_test_multiple(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 31)
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv16i1(<vscale x 16 x i1> [[TMP1]])
+; CHECK-NEXT:    [[TMP3:%.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[TMP2]])
+; CHECK-NEXT:    [[TMP5:%.*]] = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> [[TMP2]])
+; CHECK-NEXT:    [[TMP6:%.*]] = call <vscale x 2 x i32> @llvm.aarch64.sve.ld1.nxv2i32(<vscale x 2 x i1> [[TMP5]], i32* [[ADDR:%.*]])
+; CHECK-NEXT:    [[TMP7:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.ld1.nxv4i32(<vscale x 4 x i1> [[TMP4]], i32* [[ADDR]])
+; CHECK-NEXT:    [[TMP8:%.*]] = call <vscale x 8 x i32> @llvm.aarch64.sve.ld1.nxv8i32(<vscale x 8 x i1> [[TMP3]], i32* [[ADDR]])
+; CHECK-NEXT:    [[TMP9:%.*]] = call <vscale x 16 x i32> @llvm.aarch64.sve.ld1.nxv16i32(<vscale x 16 x i1> [[TMP1]], i32* [[ADDR]])
+; CHECK-NEXT:    ret <vscale x 16 x i32> [[TMP9]]
+;
+  %1 = tail call <vscale x 2 x i1> @llvm.aarch64.sve.ptrue.nxv2i1(i32 31)
+  %2 = call <vscale x 2 x i32> @llvm.aarch64.sve.ld1.nxv2i32(<vscale x 2 x i1> %1, i32* %addr)
+  %3 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+  %4 = call <vscale x 4 x i32> @llvm.aarch64.sve.ld1.nxv4i32(<vscale x 4 x i1> %3, i32* %addr)
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
+  %6 = call <vscale x 8 x i32> @llvm.aarch64.sve.ld1.nxv8i32(<vscale x 8 x i1> %5, i32* %addr)
+  %7 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 31)
+  %8 = call <vscale x 16 x i32> @llvm.aarch64.sve.ld1.nxv16i32(<vscale x 16 x i1> %7, i32* %addr)
+  ret <vscale x 16 x i32> %8
+}
+
+define <vscale x 4 x i32> @coalesce_test_same_size(i32* %addr) {
+; CHECK-LABEL: @coalesce_test_same_size(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> [[TMP1]])
+; CHECK-NEXT:    [[TMP3:%.*]] = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.ld1.nxv4i32(<vscale x 4 x i1> [[TMP1]], i32* [[ADDR:%.*]])
+; CHECK-NEXT:    [[TMP5:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.ld1.nxv4i32(<vscale x 4 x i1> [[TMP3]], i32* [[ADDR]])
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[TMP5]]
+;
+  %1 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+  %2 = call <vscale x 4 x i32> @llvm.aarch64.sve.ld1.nxv4i32(<vscale x 4 x i1> %1, i32* %addr)
+  %3 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+  %4 = call <vscale x 4 x i32> @llvm.aarch64.sve.ld1.nxv4i32(<vscale x 4 x i1> %3, i32* %addr)
+  ret <vscale x 4 x i32> %4
+}
+
+define <vscale x 8 x i16> @coalesce_test_promoted_ptrue(i32* %addr1, i16* %addr2) {
+; CHECK-LABEL: @coalesce_test_promoted_ptrue(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> [[TMP1]])
+; CHECK-NEXT:    [[TMP3:%.*]] = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+; CHECK-NEXT:    [[TMP4:%.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> [[TMP3]])
+; CHECK-NEXT:    [[TMP5:%.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[TMP4]])
+; CHECK-NEXT:    [[TMP6:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.ld1.nxv4i32(<vscale x 4 x i1> [[TMP3]], i32* [[ADDR1:%.*]])
+; CHECK-NEXT:    [[TMP7:%.*]] = call <vscale x 8 x i16> @llvm.aarch64.sve.ld1.nxv8i16(<vscale x 8 x i1> [[TMP5]], i16* [[ADDR2:%.*]])
+; CHECK-NEXT:    [[TMP8:%.*]] = call <vscale x 8 x i16> @llvm.aarch64.sve.ld1.nxv8i16(<vscale x 8 x i1> [[TMP1]], i16* [[ADDR2]])
+; CHECK-NEXT:    ret <vscale x 8 x i16> [[TMP8]]
+;
+  %1 = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+  %2 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %1)
+  %3 = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
+
+  %4 = call <vscale x 4 x i32> @llvm.aarch64.sve.ld1.nxv4i32(<vscale x 4 x i1> %1, i32* %addr1)
+  %5 = call <vscale x 8 x i16> @llvm.aarch64.sve.ld1.nxv8i16(<vscale x 8 x i1> %3, i16* %addr2)
+
+  %6 = call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
+  %7 = call <vscale x 8 x i16> @llvm.aarch64.sve.ld1.nxv8i16(<vscale x 8 x i1> %6, i16* %addr2)
+  ret <vscale x 8 x i16> %7
+}