diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -22,23 +22,32 @@
 #include "llvm/ADT/EnumeratedArray.h"
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/iterator_range.h"
 #include "llvm/Analysis/CallGraph.h"
 #include "llvm/Analysis/CallGraphSCCPass.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
 #include "llvm/IR/Assumptions.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/DebugLoc.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/GlobalValue.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/LLVMContext.h"
 #include "llvm/InitializePasses.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Transforms/IPO.h"
 #include "llvm/Transforms/IPO/Attributor.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
+#include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/CodeExtractor.h"
 
 using namespace llvm;
@@ -92,6 +101,12 @@
     cl::desc("Disable OpenMP optimizations that replace the state machine."),
     cl::Hidden, cl::init(false));
 
+static cl::opt<bool> UseStateMachineSPMDizationGuards(
+    "openmp-opt-use-state-machine-spmdization-guards", cl::ZeroOrMore,
+    cl::desc("Use a state-machine guarding scheme instead of fine-grained "
+             "barriers."),
+    cl::Hidden, cl::init(false));
+
 STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
           "Number of OpenMP runtime calls deduplicated");
 STATISTIC(NumOpenMPParallelRegionsDeleted,
@@ -2998,6 +3013,326 @@
     return ChangeStatus::CHANGED;
   }
 
+  bool createStateMachineGuards(Attributor &A) {
+    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+
+    // First we create two maps that record, per function, all non-guardable
+    // calls and all instructions that are required to be guarded. In the
+    // process we also record which functions are involved in the guarding.
+    DenseMap<Function *, SmallVector<CallBase *>>
+        NonGuardedCallsPerFunctionMap;
+    DenseMap<Function *, SmallVector<Instruction *>>
+        GuardedInstructionsPerFunctionMap;
+    for (Instruction *GuardedI : SPMDCompatibilityTracker)
+      GuardedInstructionsPerFunctionMap[GuardedI->getFunction()].push_back(
+          GuardedI);
+
+    Function *Kernel = getAnchorScope();
+    SmallPtrSet<Function *, 8> Visited;
+    SmallVector<Function *> Worklist;
+    Worklist.push_back(Kernel);
+    while (!Worklist.empty()) {
+      Function *F = Worklist.pop_back_val();
+      if (!Visited.insert(F).second)
+        continue;
+      auto &NonGuardedCalls = NonGuardedCallsPerFunctionMap[F];
+      for (Instruction &I : instructions(F)) {
+        CallBase *CB = dyn_cast<CallBase>(&I);
+        if (!CB || !CB->mayHaveSideEffects())
+          continue;
+        Function *Callee = CB->getCalledFunction();
+        if (!Callee) {
+          NonGuardedCalls.push_back(CB);
+          continue;
+        }
+        const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
+        if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
+          auto *CBAA = A.lookupAAFor<AAKernelInfo>(
+              IRPosition::callsite_function(*CB), nullptr,
+              DepClassTy::OPTIONAL);
+          if (!CBAA || !CBAA->ReachedKnownParallelRegions.isValidState() ||
+              !CBAA->ReachedUnknownParallelRegions.isValidState() ||
+              !CBAA->ReachedKnownParallelRegions.empty() ||
+              !CBAA->ReachedUnknownParallelRegions.empty()) {
+            NonGuardedCalls.push_back(CB);
+            Worklist.push_back(Callee);
+          }
+          continue;
+        }
+        RuntimeFunction RF = It->getSecond();
+        if (RF == OMPRTL___kmpc_parallel_51)
+          NonGuardedCalls.push_back(CB);
+      }
+    }
+
+    // The visited set contains all functions involved in the guarding. We now
+    // visit them one by one and emit guarding code where necessary. All of
+    // them are entered by all threads. If no guarding is necessary in a
+    // function we execute the entire function with all threads. If guarding
+    // is necessary we create a state machine that should fold away in all
+    // trivial cases. The basic structure looks as follows:
+    //
+    //   entryBB:
+    //     if (!IsMainThread) {
+    //   workerBB:
+    //       barrier();
+    //       switch(Continuation) {
+    //       ...
+    //       case N: goto continuationBB_N;
+    //       ...
+    //       }
+    //     }
+    //
+    //     /* function code */
+    //
+    //     Continuation = N;
+    //     store_continuation_values_in_globals();
+    //     barrier();
+    //     goto continuationBB_N;
+    //
+    //   continuationBB_N:
+    //     load_continuation_values_from_globals();
+    //     non_guardable_call(continuation_values...);
+    //     if (!IsMainThread) goto workerBB;
+    //
+    //     /* function code */
+
+    Module &M = *Kernel->getParent();
+    LLVMContext &Ctx = M.getContext();
+    LoopInfo *LI = nullptr;
+    DominatorTree *DT = nullptr;
+    MemorySSAUpdater *MSU = nullptr;
+    using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+
+    auto CreateEntryBranch = [&](BasicBlock &EntryBB) -> BasicBlock * {
+      auto *IP = &*EntryBB.getFirstInsertionPt();
+      while (isa<AllocaInst>(IP))
+        IP = IP->getNextNode();
+      if (llvm::any_of(llvm::make_range(IP->getIterator(), EntryBB.end()),
+                       [](Instruction &I) { return isa<AllocaInst>(I); }))
+        return nullptr;
+#if 0
+      if (llvm::all_of(EntryBB, [](Instruction &I) {
+            StoreInst *SI = dyn_cast<StoreInst>(&I);
+            if (!SI || !isa<AllocaInst>(
+                           SI->getPointerOperand()->stripInBoundsOffsets()))
+              return false;
+            return true;
+          }))
+        return nullptr;
+#endif
+      BasicBlock *InitBB = SplitBlock(&EntryBB, IP, DT, LI, MSU, "init.bb");
+      A.registerManifestAddedBasicBlock(*InitBB);
+      return InitBB;
+    };
+
+    auto CreateExitBB = [&](BasicBlock &EntryBB) {
+      BasicBlock *ExitBB = BasicBlock::Create(Ctx, "exit", EntryBB.getParent(),
+                                              EntryBB.getNextNode());
+      ReturnInst::Create(Ctx, ExitBB);
+      A.registerManifestAddedBasicBlock(*ExitBB);
+      return ExitBB;
+    };
+
+    auto CreateTIDCheckAndBranch = [&](BasicBlock &BB, BasicBlock &TargetBB) {
+      auto *EntryBBTI = BB.getTerminator();
+      assert(EntryBBTI->getNumSuccessors() == 1 && "Ill-formed entry block!");
+      BasicBlock *EntryBBSucc = EntryBBTI->getSuccessor(0);
+      const DebugLoc DL = EntryBBTI->getDebugLoc();
+      EntryBBTI->eraseFromParent();
+      OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(&BB, BB.end()),
+                                               DL);
+      OMPInfoCache.OMPBuilder.updateToLocation(Loc);
+      FunctionCallee HardwareTidFn =
+          OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
+              M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
+      Value *Tid =
+          OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
+      Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
+      OMPInfoCache.OMPBuilder.Builder
+          .CreateCondBr(TidCheck, EntryBBSucc, &TargetBB)
+          ->setDebugLoc(DL);
+    };
+
+    for (Function *F : Visited) {
+      const auto &GuardedInstructions =
+          GuardedInstructionsPerFunctionMap.lookup(F);
+      if (GuardedInstructions.empty())
+        continue;
+      const auto &NonGuardedCalls = NonGuardedCallsPerFunctionMap.lookup(F);
+
+      BasicBlock &EntryBB = F->getEntryBlock();
+      BasicBlock *InitBB = CreateEntryBranch(EntryBB);
+      if (!InitBB) {
+        LLVM_DEBUG(dbgs() << TAG
+                          << "State machine guard creation failed due to "
+                             "allocas in the init block\n");
+        return false;
+      }
+      BasicBlock *ExitBB = CreateExitBB(EntryBB);
+      assert(ExitBB && "Expected exit BB creation to succeed!");
+
+      // Handle trivial kernels first; all threads but the first simply exit.
+      if (NonGuardedCalls.empty()) {
+        assert(F == Kernel && "No non-kernel function should have guarded "
+                              "instructions but no non-guarded calls!");
+        assert(Visited.size() == 1 &&
+               "Did not expect other functions involved in guarding while the "
+               "kernel is trivial.");
+        CreateTIDCheckAndBranch(EntryBB, *ExitBB);
+        // Trivial kernel, we are done.
+        return true;
+      }
+
+      auto IsOKEntryAlloca = [&](Value &V) {
+        if (auto *AI = dyn_cast<AllocaInst>(&V))
+          if (AI->getParent() == &EntryBB)
+            return !AI->isArrayAllocation();
+        return false;
+      };
+
+      auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
+      SmallPtrSet<Instruction *, 8> NonGuardedInsts;
+
+      auto IsOKUser = [&](User *U) -> bool {
+        auto *I = dyn_cast<Instruction>(U);
+        return (!I || NonGuardedInsts.count(I));
+      };
+      auto IsOKOp = [&](Value &V) -> bool {
+        auto *I = dyn_cast<Instruction>(&V);
+        // Verify the value itself is good and check the users if it is an
+        // instruction. Note that __kmpc_alloc_shared calls won't be
+        // instructions at the end of this but globals.
+        if (!I || NonGuardedInsts.count(I) || IsOKEntryAlloca(V))
+          return !I || llvm::all_of(V.users(), IsOKUser);
+        return OpenMPOpt::getCallIfRegularCall(V, &AllocSharedRFI);
+      };
+
+      auto CreateBarrier = [&](FunctionCallee &BarrierFn, const DebugLoc &DL,
+                               OpenMPIRBuilder::LocationDescription &Loc) {
+        OMPInfoCache.OMPBuilder.updateToLocation(Loc);
+        auto *SrcLocStr = OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc);
+        Value *Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr);
+        Value *DummyTID = ConstantInt::getNullValue(
+            BarrierFn.getFunctionType()->getParamType(1));
+        OMPInfoCache.OMPBuilder.Builder
+            .CreateCall(BarrierFn, {Ident, DummyTID})
+            ->setDebugLoc(DL);
+      };
+
+      struct NonGuardedRegion {
+        Instruction *First;
+        Instruction *Last;
+      };
+      SmallVector<NonGuardedRegion> NonGuardedRegions;
+
+      for (CallBase *NonGuardedCall : NonGuardedCalls) {
+        NonGuardedInsts.insert(NonGuardedCall);
+        BasicBlock *NonGuardedCallBB = NonGuardedCall->getParent();
+        Instruction *NonGuardedBegin = NonGuardedCall;
+        while (NonGuardedBegin != &NonGuardedCallBB->front()) {
+          Instruction *PrevI = NonGuardedBegin->getPrevNode();
+          if (PrevI->mayHaveSideEffects()) {
+            auto *SI = dyn_cast<StoreInst>(PrevI);
+            if (!SI || !IsOKEntryAlloca(
+                           *SI->getPointerOperand()->stripInBoundsOffsets()))
+              break;
+          }
+          NonGuardedInsts.insert(PrevI);
+          NonGuardedBegin = PrevI;
+        }
+
+        for (Instruction *NonGuardedInst : NonGuardedInsts) {
+          for (Value *Op : NonGuardedInst->operands()) {
+            if (IsOKOp(*Op))
+              continue;
+            LLVM_DEBUG(dbgs()
+                       << TAG << "State machine guard creation failed due to "
+                       << *Op << " in " << *NonGuardedInst << "\n");
+            return false;
+          }
+        }
+
+        NonGuardedRegions.push_back(
+            NonGuardedRegion{NonGuardedBegin, NonGuardedCall});
+      }
+
+      // We know all non-guarded regions and we know all of them are valid;
+      // now we create the state machine.
+      BasicBlock *StateMachineEntryBB =
+          BasicBlock::Create(Ctx, "sm.entry", F, EntryBB.getNextNode());
+      A.registerManifestAddedBasicBlock(*StateMachineEntryBB);
+      CreateTIDCheckAndBranch(EntryBB, *StateMachineEntryBB);
+
+      IntegerType *Int32Ty = Type::getInt32Ty(Ctx);
+      auto *ContinuationGlobal = new GlobalVariable(
+          M, Int32Ty, /* IsConstant */ false, GlobalValue::InternalLinkage,
+          UndefValue::get(Int32Ty), "continuation.glb", nullptr,
+          GlobalValue::NotThreadLocal,
+          static_cast<unsigned>(AddressSpace::Shared));
+
+      const DebugLoc DL = EntryBB.getTerminator()->getDebugLoc();
+      OpenMPIRBuilder::LocationDescription Loc(
+          InsertPointTy(StateMachineEntryBB, StateMachineEntryBB->end()), DL);
+      FunctionCallee BarrierFn =
+          OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
+              M, OMPRTL___kmpc_barrier_simple_spmd);
+      CreateBarrier(BarrierFn, DL, Loc);
+
+      auto *ContinuationVal =
+          new LoadInst(Int32Ty, ContinuationGlobal, "continuation.val",
+                       StateMachineEntryBB);
+      auto *Switch =
+          SwitchInst::Create(ContinuationVal, ExitBB, NonGuardedRegions.size(),
+                             StateMachineEntryBB);
+
+      unsigned NGRNo = 0;
+      ConstantInt *FinishVal = ConstantInt::get(Int32Ty, -1);
+      for (NonGuardedRegion &NGR : NonGuardedRegions) {
+
+        // Prior to the non-guarded region, set the continuation value and
+        // wake the other threads.
+        ConstantInt *CaseVal = ConstantInt::get(Int32Ty, NGRNo++);
+        new StoreInst(CaseVal, ContinuationGlobal, NGR.First);
+        OpenMPIRBuilder::LocationDescription Loc(
+            InsertPointTy(NGR.First->getParent(), NGR.First->getIterator()),
+            DL);
+        CreateBarrier(BarrierFn, DL, Loc);
+
+        Instruction *End = NGR.Last->getNextNode();
+        OpenMPIRBuilder::LocationDescription Loc2(
+            InsertPointTy(NGR.Last->getParent(), End->getIterator()), DL);
+        CreateBarrier(BarrierFn, DL, Loc2);
+
+        BasicBlock *NonGuardedRegionBB =
+            SplitBlock(NGR.First->getParent(), NGR.First, DT, LI, MSU,
+                       "non.guarded.region");
+        BasicBlock *GuardedBeginBB =
+            SplitBlock(NonGuardedRegionBB, End, DT, LI, MSU, "guarded.begin");
+        A.registerManifestAddedBasicBlock(*NonGuardedRegionBB);
+        A.registerManifestAddedBasicBlock(*GuardedBeginBB);
+
+        ValueToValueMapTy VMap;
+        BasicBlock *NonGuardedRegionBBClone = CloneBasicBlock(
+            NonGuardedRegionBB, VMap, "non.guarded.region.worker", F);
+        SmallVector<BasicBlock *> Clones{NonGuardedRegionBBClone};
+        remapInstructionsInBlocks(Clones, VMap);
+
+        NonGuardedRegionBBClone->getTerminator()->setSuccessor(
+            0, StateMachineEntryBB);
+
+        Switch->addCase(CaseVal, NonGuardedRegionBBClone);
+
+        // After the non-guarded region, set the finish value to ensure threads
+        // will eventually exit.
+        new StoreInst(FinishVal, ContinuationGlobal,
+                      &*GuardedBeginBB->getFirstInsertionPt());
+      }
+    }
+
+    return true;
+  }
+
   void createLocalGuards(Attributor &A) {
     auto CreateGuardedRegion = [&](Instruction *RegionStartI,
                                    Instruction *RegionEndI) {
@@ -3214,7 +3549,10 @@
       return false;
     }
 
-    createLocalGuards(A);
+    // Now create guarding as necessary; local guards always work while
+    // state-machine guards are opt-in and might fail.
+    if (!UseStateMachineSPMDizationGuards || !createStateMachineGuards(A))
+      createLocalGuards(A);
 
     // Adjust the global exec mode flag that tells the runtime what mode this
     // kernel is executed in.
@@ -4230,7 +4568,6 @@
                                  DepClassTy::NONE, /* ForceUpdate */ false,
                                  /* UpdateAfterInit */ false);
 
-  registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id);
   registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
   registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
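For readers following the patch, the guarding scheme is easiest to see at the source level. The sketch below is a minimal, self-contained C rendering of the ASCII diagram inside createStateMachineGuards; it is an illustration only, not the generated IR. The names guarded_kernel, is_main_thread, barrier, and parallel_region are invented stand-ins for the emitted thread-id check (__kmpc_get_hardware_thread_id_in_block), the SPMD barrier (__kmpc_barrier_simple_spmd), and a non-guardable call such as __kmpc_parallel_51, and the continuation variable models the "continuation.glb" shared global.

/* Illustrative only: the stubs below stand in for the runtime calls and the
 * shared global the pass actually emits. */
static int continuation = -1;                  /* models "continuation.glb"        */

static int is_main_thread(void) { return 1; }  /* stub for the tid == 0 check      */
static void barrier(void) {}                   /* stub for the simple SPMD barrier */
static void parallel_region(void) {}           /* stub for a non-guardable call    */

void guarded_kernel(void) {
  if (!is_main_thread()) {
  worker:                          /* corresponds to the "sm.entry" block          */
    barrier();                     /* wait for the main thread                     */
    switch (continuation) {        /* dispatch to the next non-guarded region      */
    case 0:
      goto region_0;               /* the cloned "non.guarded.region.worker" block */
    default:
      return;                      /* finish value (-1) leads to the "exit" block  */
    }
  }

  /* ... guarded code, executed by the main thread only ... */

  continuation = 0;                /* announce non-guarded region 0                */
  barrier();                       /* release the waiting workers                  */
region_0:
  parallel_region();               /* executed by all threads                      */
  barrier();                       /* re-synchronize before continuing             */
  if (!is_main_thread())
    goto worker;                   /* workers re-enter the state machine           */
  continuation = -1;               /* "guarded.begin": store the finish value      */

  /* ... more guarded code, executed by the main thread only ... */
}

In the generated IR the same hand-off happens through the shared "continuation.glb" global, the __kmpc_barrier_simple_spmd calls, and the switch in the "sm.entry" block; the whole scheme stays behind the new, default-off -openmp-opt-use-state-machine-spmdization-guards option and falls back to createLocalGuards when it cannot be applied.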