[OpenMPOpt] Guard SPMD-incompatible instructions in manifest instead of only tracking them.

What the patch below does (as visible in this diff):
  * AAKernelInfoFunction gains a GuardedInstructions set plus a
    getGuardedInstructions() accessor. It records instructions of that
    function that have already been placed in a guarded region, so a
    function reached from several kernels is not guarded twice.
  * In manifest, the CreateGuardedRegion lambda splits the parent block into
    region.check.tid -> region.guarded -> region.guarded.end ->
    region.barrier -> region.exit. region.check.tid calls
    __kmpc_get_hardware_thread_id_in_block(). Only a thread whose result is
    zero branches into region.guarded; the others go straight to
    region.barrier. region.barrier begins with a __kmpc_barrier_simple_spmd
    call.
  * Values defined in region.guarded that have users outside it are handled
    in three steps. First, each value is stored into a new internal
    GlobalVariable in AddressSpace::Shared, at the end of
    region.guarded.end. Second, it is reloaded before the terminator of
    region.barrier. Third, the outside uses are rewritten to the load.
    A second barrier is emitted only if at least one such value exists.
  * For every instruction in SPMDCompatibilityTracker not yet guarded, its
    basic block is scanned. Each maximal run of consecutive tracked
    instructions is recorded as a (start, end) pair, and the pairs are
    guarded afterwards.
  * In the hunk at old line 3322, the check now consults AAHeapToStack. If
    every underlying object is a call that AAHeapToStack assumes is
    converted to a stack allocation, the instruction is not inserted into
    the tracker.
  * For non-kernel-entry functions, SPMDCompatibilityTracker is forced to a
    pessimistic fixpoint when ParallelLevels is not in a valid state.
  * Task-related and unknown OpenMP runtime calls now call
    indicatePessimisticFixpoint() and return, where they previously used
    break.

NOTE(review), to confirm before landing:
  * This copy looks lossy. Newlines are collapsed, and text between angle
    brackets appears stripped (e.g. "SmallPtrSet GuardedInstructions",
    "cast(Usr)", "static_cast(AddressSpace::Shared)", "A.getAAFor(").
    As a result, the first hunk's removed and added lines read identically.
    The patch cannot be applied with git apply or patch as-is; regenerate
    it from the original commit.
  * The region-collection loop pushes a (start, end) pair only when it
    reaches a non-tracked instruction. A run that extends to the end of the
    basic block is therefore dropped, even though its instructions were
    already inserted into GuardedInstructions. This presumably relies on a
    block terminator never being tracked (e.g. an invoke). Confirm, or
    flush the pending region after the inner loop.
  * CreateGuardedRegion calls SplitBlock at RegionEndI->getNextNode(). That
    is null if RegionEndI is a terminator, which is the same assumption as
    above.
  * The escaping-value loop asserts that every outside user lives in
    region.exit. A PHI user in some other successor block would trip this
    assertion. Confirm that this cannot happen.

diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -503,7 +503,7 @@ /// State to track if we are in SPMD-mode, assumed or know, and why we decided /// we cannot be. If it is assumed, then RequiresFullRuntime should also be /// false. - BooleanStateWithPtrSetVector SPMDCompatibilityTracker; + BooleanStateWithPtrSetVector SPMDCompatibilityTracker; /// The __kmpc_target_init call in this kernel, if any. If we find more than /// one we abort as the kernel is malformed. @@ -2795,6 +2795,12 @@ AAKernelInfoFunction(const IRPosition &IRP, Attributor &A) : AAKernelInfo(IRP, A) {} + SmallPtrSet GuardedInstructions; + + SmallPtrSetImpl &getGuardedInstructions() { + return GuardedInstructions; + } + /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { // This is a high-level transform that might change the constant arguments @@ -2991,6 +2997,188 @@ return false; } + auto CreateGuardedRegion = [&](Instruction *RegionStartI, + Instruction *RegionEndI) { + LoopInfo *LI = nullptr; + DominatorTree *DT = nullptr; + MemorySSAUpdater *MSU = nullptr; + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + + BasicBlock *ParentBB = RegionStartI->getParent(); + Function *Fn = ParentBB->getParent(); + Module &M = *Fn->getParent(); + + // Create all the blocks and logic. + // ParentBB: + // goto RegionCheckTidBB + // RegionCheckTidBB: + // Tid = __kmpc_hardware_thread_id() + // if (Tid != 0) + // goto RegionBarrierBB + // RegionStartBB: + // + // goto RegionEndBB + // RegionEndBB: + // + // goto RegionBarrierBB + // RegionBarrierBB: + // __kmpc_simple_barrier_spmd() + // // second barrier is omitted if lacking escaping values. 
+ // + // __kmpc_simple_barrier_spmd() + // goto RegionExitBB + // RegionExitBB: + // + + BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(), + DT, LI, MSU, "region.guarded.end"); + BasicBlock *RegionBarrierBB = + SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI, + MSU, "region.barrier"); + BasicBlock *RegionExitBB = + SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(), + DT, LI, MSU, "region.exit"); + BasicBlock *RegionStartBB = + SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded"); + + assert(ParentBB->getUniqueSuccessor() == RegionStartBB && + "Expected a different CFG"); + + BasicBlock *RegionCheckTidBB = SplitBlock( + ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid"); + + // Register basic blocks with the Attributor. + A.registerManifestAddedBasicBlock(*RegionEndBB); + A.registerManifestAddedBasicBlock(*RegionBarrierBB); + A.registerManifestAddedBasicBlock(*RegionExitBB); + A.registerManifestAddedBasicBlock(*RegionStartBB); + A.registerManifestAddedBasicBlock(*RegionCheckTidBB); + + bool HasBroadcastValues = false; + // Find escaping outputs from the guarded region to outside users and + // broadcast their values to them. + for (Instruction &I : *RegionStartBB) { + SmallPtrSet OutsideUsers; + for (User *Usr : I.users()) { + Instruction &UsrI = *cast(Usr); + if (UsrI.getParent() != RegionStartBB) + OutsideUsers.insert(&UsrI); + } + + if (OutsideUsers.empty()) + continue; + + HasBroadcastValues = true; + + // Emit a global variable in shared memory to store the broadcasted + // value. + auto *SharedMem = new GlobalVariable( + M, I.getType(), /* IsConstant */ false, + GlobalValue::InternalLinkage, UndefValue::get(I.getType()), + I.getName() + ".guarded.output.alloc", nullptr, + GlobalValue::NotThreadLocal, + static_cast(AddressSpace::Shared)); + + // Emit a store instruction to update the value. 
+ new StoreInst(&I, SharedMem, RegionEndBB->getTerminator()); + + LoadInst *LoadI = new LoadInst(I.getType(), SharedMem, + I.getName() + ".guarded.output.load", + RegionBarrierBB->getTerminator()); + + // Emit a load instruction and replace uses of the output value. + for (Instruction *UsrI : OutsideUsers) { + assert(UsrI->getParent() == RegionExitBB && + "Expected escaping users in exit region"); + UsrI->replaceUsesOfWith(&I, LoadI); + } + } + + auto &OMPInfoCache = static_cast(A.getInfoCache()); + + // Go to tid check BB in ParentBB. + const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc(); + ParentBB->getTerminator()->eraseFromParent(); + OpenMPIRBuilder::LocationDescription Loc( + InsertPointTy(ParentBB, ParentBB->end()), DL); + OMPInfoCache.OMPBuilder.updateToLocation(Loc); + auto *SrcLocStr = OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc); + Value *Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr); + BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL); + + // Add check for Tid in RegionCheckTidBB + RegionCheckTidBB->getTerminator()->eraseFromParent(); + OpenMPIRBuilder::LocationDescription LocRegionCheckTid( + InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL); + OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid); + FunctionCallee HardwareTidFn = + OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( + M, OMPRTL___kmpc_get_hardware_thread_id_in_block); + Value *Tid = + OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {}); + Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid); + OMPInfoCache.OMPBuilder.Builder + .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB) + ->setDebugLoc(DL); + + // First barrier for synchronization, ensures main thread has updated + // values. 
+ FunctionCallee BarrierFn = + OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( + M, OMPRTL___kmpc_barrier_simple_spmd); + OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy( + RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt())); + OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid}) + ->setDebugLoc(DL); + + // Second barrier ensures workers have read broadcast values. + if (HasBroadcastValues) + CallInst::Create(BarrierFn, {Ident, Tid}, "", + RegionBarrierBB->getTerminator()) + ->setDebugLoc(DL); + }; + + SmallVector, 4> GuardedRegions; + + for (Instruction *GuardedI : SPMDCompatibilityTracker) { + BasicBlock *BB = GuardedI->getParent(); + auto *CalleeAA = A.lookupAAFor( + IRPosition::function(*GuardedI->getFunction()), nullptr, + DepClassTy::NONE); + assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo"); + auto &CalleeAAFunction = *cast(CalleeAA); + // Continue if instruction is already guarded. + if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI)) + continue; + + Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr; + for (Instruction &I : *BB) { + // If instruction I needs to be guarded update the guarded region + // bounds. + if (SPMDCompatibilityTracker.contains(&I)) { + CalleeAAFunction.getGuardedInstructions().insert(&I); + if (GuardedRegionStart) + GuardedRegionEnd = &I; + else + GuardedRegionStart = GuardedRegionEnd = &I; + + continue; + } + + // Instruction I does not need guarding, store + // any region found and reset bounds. + if (GuardedRegionStart) { + GuardedRegions.push_back( + std::make_pair(GuardedRegionStart, GuardedRegionEnd)); + GuardedRegionStart = nullptr; + GuardedRegionEnd = nullptr; + } + } + } + + for (auto &GR : GuardedRegions) + CreateGuardedRegion(GR.first, GR.second); + // Adjust the global exec mode flag that tells the runtime what mode this // kernel is executed in. 
Function *Kernel = getAnchorScope(); @@ -3322,8 +3510,21 @@ if (llvm::all_of(Objects, [](const Value *Obj) { return isa(Obj); })) return true; + // Check for AAHeapToStack moved objects which must not be guarded. + auto &HS = A.getAAFor( + *this, IRPosition::function(*I.getFunction()), + DepClassTy::REQUIRED); + if (llvm::all_of(Objects, [&HS](const Value *Obj) { + auto *CB = dyn_cast(Obj); + if (!CB) + return false; + return HS.isAssumedHeapToStack(*CB); + })) { + return true; + } } - // For now we give up on everything but stores. + + // Insert instruction that needs guarding. SPMDCompatibilityTracker.insert(&I); return true; }; @@ -3337,6 +3538,9 @@ if (!IsKernelEntry) { updateReachingKernelEntries(A); updateParallelLevels(A); + + if (!ParallelLevels.isValidState()) + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); } // Callback to check a call instruction. @@ -3554,7 +3758,8 @@ // We do not look into tasks right now, just give up. SPMDCompatibilityTracker.insert(&CB); ReachedUnknownParallelRegions.insert(&CB); - break; + indicatePessimisticFixpoint(); + return; case OMPRTL___kmpc_alloc_shared: case OMPRTL___kmpc_free_shared: // Return without setting a fixpoint, to be resolved in updateImpl. @@ -3563,7 +3768,8 @@ // Unknown OpenMP runtime calls cannot be executed in SPMD-mode, // generally. SPMDCompatibilityTracker.insert(&CB); - break; + indicatePessimisticFixpoint(); + return; } // All other OpenMP runtime calls will not reach parallel regions so they // can be safely ignored for now. Since it is a known OpenMP runtime call we