diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -15,6 +15,7 @@ #include "llvm/Transforms/IPO/OpenMPOpt.h" #include "llvm/ADT/EnumeratedArray.h" +#include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraphSCCPass.h" @@ -78,6 +79,36 @@ namespace { +struct AAExecutionDomain + : public StateWrapper { + using Base = StateWrapper; + AAExecutionDomain(const IRPosition &IRP, Attributor &A) : Base(IRP) {} + + /// Create an abstract attribute view for the position \p IRP. + static AAExecutionDomain &createForPosition(const IRPosition &IRP, + Attributor &A); + + /// See AbstractAttribute::getName(). + const std::string getName() const override { return "AAExecutionDomain"; } + + /// See AbstractAttribute::getIdAddr(). + const char *getIdAddr() const override { return &ID; } + + /// Check if an instruction is executed by a single thread. + virtual bool isSingleThreadExecution(const Instruction &) const = 0; + /// Check if a basic block is executed by a single thread. + virtual bool isSingleThreadExecution(const BasicBlock &) const = 0; + + /// This function should return true if the type of the \p AA is + /// AAExecutionDomain. + static bool classof(const AbstractAttribute *AA) { + return (AA->getIdAddr() == &ID); + } + + /// Unique ID (due to the unique address) + static const char ID; +}; + struct AAICVTracker; /// OpenMP specific information.
For now, stores RFIs and ICVs also needed for @@ -506,6 +537,8 @@ << OMPInfoCache.ModuleSlice.size() << " functions\n"); if (IsModulePass) { + Changed |= runAttributor(); + if (remarksEnabled()) analysisGlobalization(); } else { @@ -1606,6 +1639,11 @@ GetterRFI.foreachUse(SCC, CreateAA); } + + for (auto &F : M) { + if (!F.isDeclaration()) + A.getOrCreateAAFor(IRPosition::function(F)); + } } }; @@ -2244,9 +2282,136 @@ return Changed; } }; + +struct AAExecutionDomainFunction : public AAExecutionDomain { + AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A) + : AAExecutionDomain(IRP, A) {} + + const std::string getAsStr() const override { + return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) + + "/" + std::to_string(NumBBs) + " BBs thread 0 only."; + } + + /// See AbstractAttribute::trackStatistics(). + void trackStatistics() const override {} + + void initialize(Attributor &A) override { + Function *F = getAnchorScope(); + for (const auto &BB : *F) + SingleThreadedBBs.insert(&BB); + NumBBs = SingleThreadedBBs.size(); + } + + ChangeStatus manifest(Attributor &A) override { + LLVM_DEBUG({ + for (const BasicBlock *BB : SingleThreadedBBs) + dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " " + << BB->getName() << " is executed by a single thread.\n"; + }); + return ChangeStatus::UNCHANGED; + } + + ChangeStatus updateImpl(Attributor &A) override; + + /// Check if an instruction is executed by a single thread. + bool isSingleThreadExecution(const Instruction &I) const override { + return isSingleThreadExecution(*I.getParent()); + } + /// Check if a basic block is single-threaded; false for an invalid state. + bool isSingleThreadExecution(const BasicBlock &BB) const override { + return isValidState() && SingleThreadedBBs.contains(&BB); + } + + /// Set of basic blocks that are executed by a single thread. + DenseSet SingleThreadedBBs; + + /// Total number of basic blocks in this function.
+ long unsigned NumBBs = 0; +}; + +ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { + auto &OMPInfoCache = static_cast(A.getInfoCache()); + Function *F = getAnchorScope(); + ReversePostOrderTraversal RPOT(F); + auto NumSingleThreadedBBs = SingleThreadedBBs.size(); + + bool AllCallSitesKnown; + auto PredForCallSite = [&](AbstractCallSite ACS) { + const auto &ExecutionDomainAA = A.getAAFor( + *this, IRPosition::function(*ACS.getInstruction()->getFunction()), + DepClassTy::REQUIRED); + return ExecutionDomainAA.isSingleThreadExecution(*ACS.getInstruction()); + }; + + if (!A.checkForAllCallSites(PredForCallSite, *this, + /* RequiresAllCallSites */ true, + AllCallSitesKnown)) + SingleThreadedBBs.erase(&F->getEntryBlock()); + + // Check if the edge into the successor block compares a thread-id function to + // a constant zero. + // TODO: Use AAValueSimplify to simplify and propagate constants. + // TODO: Check more than a single use for thread IDs. + auto IsSingleThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) { + if (!Edge || !Edge->isConditional()) + return false; + if (Edge->getSuccessor(0) != SuccessorBB) + return false; + + auto *Cmp = dyn_cast(Edge->getCondition()); + if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality()) + return false; + + ConstantInt *C = dyn_cast(Cmp->getOperand(1)); + if (!C || !C->isZero()) + return false; + + if (auto *CB = dyn_cast(Cmp->getOperand(0))) { + // __kmpc_master is excluded: it returns 1 (not 0) for the master thread. + RuntimeFunction ThreadNumRuntimeIDs[] = {OMPRTL_omp_get_thread_num, + OMPRTL___kmpc_global_thread_num}; + // A missing declaration is null and must not match an indirect call. + for (const auto ThreadNumRuntimeID : ThreadNumRuntimeIDs) { + auto &RFI = OMPInfoCache.RFIs[ThreadNumRuntimeID]; + if (RFI.Declaration && CB->getCalledFunction() == RFI.Declaration) + return true; + } + } + + return false; + }; + + // Merge all the predecessor states into the current basic block. A basic + // block is executed by a single thread if all of its predecessors are.
+ auto MergePredecessorStates = [&](BasicBlock *BB) { + if (pred_begin(BB) == pred_end(BB)) + return SingleThreadedBBs.contains(BB); + // A guarded (thread-id == 0) edge is fine even from a multi-threaded block. + bool IsSingleThreaded = true; + for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB); + PredBB != PredEndBB; ++PredBB) { + if (!IsSingleThreadOnly(dyn_cast((*PredBB)->getTerminator()), + BB)) + IsSingleThreaded &= SingleThreadedBBs.contains(*PredBB); + } + + return IsSingleThreaded; + }; + + for (auto *BB : RPOT) { + if (!MergePredecessorStates(BB)) + SingleThreadedBBs.erase(BB); + } + + return (NumSingleThreadedBBs == SingleThreadedBBs.size()) + ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; +} + } // namespace const char AAICVTracker::ID = 0; +const char AAExecutionDomain::ID = 0; AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP, Attributor &A) { @@ -2274,6 +2439,27 @@ return *AA; } +AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP, + Attributor &A) { + AAExecutionDomainFunction *AA = nullptr; + switch (IRP.getPositionKind()) { + case IRPosition::IRP_INVALID: + case IRPosition::IRP_FLOAT: + case IRPosition::IRP_ARGUMENT: + case IRPosition::IRP_CALL_SITE_ARGUMENT: + case IRPosition::IRP_RETURNED: + case IRPosition::IRP_CALL_SITE_RETURNED: + case IRPosition::IRP_CALL_SITE: + llvm_unreachable( + "AAExecutionDomain can only be created for function position!"); + case IRPosition::IRP_FUNCTION: + AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A); + break; + } + + return *AA; +} + PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { if (!containsOpenMP(M, OMPInModule)) return PreservedAnalyses::all(); diff --git a/llvm/test/Transforms/OpenMP/single_threaded_execution.ll b/llvm/test/Transforms/OpenMP/single_threaded_execution.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/OpenMP/single_threaded_execution.ll @@ -0,0 +1,44 @@ +; RUN: opt -passes=openmp-opt-cgscc -debug-only=openmp-opt -disable-output < %s 2>&1 | FileCheck %s +; REQUIRES: asserts +;
ModuleID = 'single_threaded_execution.c' +; @foo's if.then is guarded by omp_get_thread_num() == 0; @bar is only called from it. +%struct.ident_t = type { i32, i32, i32, i32, i8* } + +@.str = private unnamed_addr constant [4 x i8] c"%d\0A\00", align 1 +@0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 +@1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 0, i8* getelementptr inbounds ([23 x i8], [23 x i8]* @0, i32 0, i32 0) }, align 8 + +; CHECK: [openmp-opt] Basic block @bar entry is executed by a single thread. +; Function Attrs: noinline nounwind uwtable +define internal void @bar() { +entry: + ret void +} + +; CHECK-NOT: [openmp-opt] Basic block @foo entry is executed by a single thread. +; CHECK: [openmp-opt] Basic block @foo if.then is executed by a single thread. +; CHECK-NOT: [openmp-opt] Basic block @foo if.end is executed by a single thread. +; Function Attrs: noinline nounwind uwtable +define dso_local void @foo() { +entry: + %call = call i32 @omp_get_thread_num() + %cmp = icmp eq i32 %call, 0 + br i1 %cmp, label %if.then, label %if.end + +if.then: + call void @bar() + br label %if.end + +if.end: + ret void +} + +declare dso_local i32 @omp_get_thread_num() + +!llvm.module.flags = !{!0} +!llvm.ident = !{!1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{!"clang version 13.0.0"} +!2 = !{!3} +!3 = !{i64 2, i64 -1, i64 -1, i1 true}