diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -534,12 +534,24 @@
   MemorySSA *MSSA;
   std::unique_ptr<MemorySSAUpdater> MSSAUpdater;
 
+  /// Pair current values with a generation number in order to prevent
+  /// reusing values across coroutine suspension points. Value reuse typically
+  /// assumes that execution within the same function happens
+  /// on the same thread, which is no longer true after a coroutine suspension.
+  /// For example, pthread_self() from glibc is defined as readnone, but cannot
+  /// be CSE-ed across coroutine suspension points.
+  /// Furthermore, value reuse creates local variables that need to stay
+  /// alive across suspension points, increasing the coroutine frame size.
+  struct GenerationalValue {
+    Value *V = nullptr;
+    unsigned Generation = 0;
+  };
+
   using AllocatorTy =
       RecyclingAllocator<BumpPtrAllocator,
-                         ScopedHashTableVal<SimpleValue, Value *>>;
-  using ScopedHTType =
-      ScopedHashTable<SimpleValue, Value *, DenseMapInfo<SimpleValue>,
-                      AllocatorTy>;
+                         ScopedHashTableVal<SimpleValue, GenerationalValue>>;
+  using ScopedHTType = ScopedHashTable<SimpleValue, GenerationalValue,
+                                       DenseMapInfo<SimpleValue>, AllocatorTy>;
 
   /// A scoped hash table of the current values of all of our simple
   /// scalar expressions.
@@ -604,8 +616,13 @@
       ScopedHashTable<CallValue, std::pair<Instruction *, unsigned>>;
   CallHTType AvailableCalls;
 
-  /// This is the current generation of the memory value.
-  unsigned CurrentGeneration = 0;
+  struct GenerationPair {
+    unsigned MemoryGeneration = 0;
+    unsigned ValueGeneration = 0;
+  };
+
+  /// This is the current generation of memory and expression values.
+  GenerationPair CurrentGeneration;
 
   /// Set up the EarlyCSE runner for a particular function.
   EarlyCSE(const DataLayout &DL, const TargetLibraryInfo &TLI,
@@ -645,20 +662,21 @@
 public:
   StackNode(ScopedHTType &AvailableValues, LoadHTType &AvailableLoads,
             InvariantHTType &AvailableInvariants, CallHTType &AvailableCalls,
-            unsigned cg, DomTreeNode *n, DomTreeNode::const_iterator child,
+            GenerationPair gp, DomTreeNode *n,
+            DomTreeNode::const_iterator child,
             DomTreeNode::const_iterator end)
-      : CurrentGeneration(cg), ChildGeneration(cg), Node(n), ChildIter(child),
-        EndIter(end),
-        Scopes(AvailableValues, AvailableLoads, AvailableInvariants,
-               AvailableCalls)
-      {}
+      : CurrentGeneration(gp), ChildGeneration(gp), Node(n), ChildIter(child),
+        EndIter(end), Scopes(AvailableValues, AvailableLoads,
+                             AvailableInvariants, AvailableCalls) {}
   StackNode(const StackNode &) = delete;
   StackNode &operator=(const StackNode &) = delete;
 
   // Accessors.
-  unsigned currentGeneration() { return CurrentGeneration; }
-  unsigned childGeneration() { return ChildGeneration; }
-  void childGeneration(unsigned generation) { ChildGeneration = generation; }
+  GenerationPair currentGeneration() { return CurrentGeneration; }
+  GenerationPair childGeneration() { return ChildGeneration; }
+  void childGeneration(GenerationPair generation) {
+    ChildGeneration = generation;
+  }
   DomTreeNode *node() { return Node; }
 
   DomTreeNode::const_iterator childIter() { return ChildIter; }
@@ -673,8 +691,8 @@
   void process() { Processed = true; }
 
 private:
-  unsigned CurrentGeneration;
-  unsigned ChildGeneration;
+  GenerationPair CurrentGeneration;
+  GenerationPair ChildGeneration;
   DomTreeNode *Node;
   DomTreeNode::const_iterator ChildIter;
   DomTreeNode::const_iterator EndIter;
@@ -830,7 +848,7 @@
                            const BasicBlock *BB, const BasicBlock *Pred);
 
   Value *getMatchingValue(LoadValue &InVal, ParseMemoryInst &MemInst,
-                          unsigned CurrentGeneration);
+                          unsigned CurrentMemoryGeneration);
 
   bool overridingStores(const ParseMemoryInst &Earlier,
                         const ParseMemoryInst &Later);
@@ -847,8 +865,8 @@
     return TTI.getOrCreateResultFromMemIntrinsic(II, ExpectedType);
   }
 
-  Value *getOrCreateResultNonTargetMemIntrinsic(IntrinsicInst *II,
-                                                Type *ExpectedType) const {
+  static Value *getOrCreateResultNonTargetMemIntrinsic(IntrinsicInst *II,
+                                                       Type *ExpectedType) {
     switch (II->getIntrinsicID()) {
     case Intrinsic::masked_load:
       return II;
@@ -1080,7 +1098,7 @@
   while (!WorkList.empty()) {
     Instruction *Curr = WorkList.pop_back_val();
 
-    AvailableValues.insert(Curr, TorF);
+    AvailableValues.insert(Curr, {TorF, CurrentGeneration.ValueGeneration});
     LLVM_DEBUG(dbgs() << "EarlyCSE CVP: Add conditional value for '"
                       << Curr->getName() << "' as " << *TorF << " in "
                       << BB->getName() << "\n");
@@ -1106,7 +1124,7 @@
 }
 
 Value *EarlyCSE::getMatchingValue(LoadValue &InVal, ParseMemoryInst &MemInst,
-                                  unsigned CurrentGeneration) {
+                                  unsigned CurrentMemoryGeneration) {
   if (InVal.DefInst == nullptr)
     return nullptr;
   if (InVal.MatchingId != MemInst.getMatchingId())
@@ -1145,8 +1163,8 @@
   }
 
   if (!isOperatingOnInvariantMemAt(MemInst.get(), InVal.Generation) &&
-      !isSameMemGeneration(InVal.Generation, CurrentGeneration, InVal.DefInst,
-                           MemInst.get()))
+      !isSameMemGeneration(InVal.Generation, CurrentMemoryGeneration,
+                           InVal.DefInst, MemInst.get()))
     return nullptr;
 
   if (!Result)
@@ -1196,7 +1214,7 @@
   // just be conservative and invalidate memory if this block has multiple
   // predecessors.
   if (!BB->getSinglePredecessor())
-    ++CurrentGeneration;
+    ++CurrentGeneration.MemoryGeneration;
 
   // If this node has a single predecessor which ends in a conditional branch,
   // we can infer the value of the branch condition given that we took this
@@ -1249,7 +1267,8 @@
       if (CondI && SimpleValue::canHandle(CondI)) {
         LLVM_DEBUG(dbgs() << "EarlyCSE considering assumption: " << Inst
                           << '\n');
-        AvailableValues.insert(CondI, ConstantInt::getTrue(BB->getContext()));
+        AvailableValues.insert(CondI, {ConstantInt::getTrue(BB->getContext()),
+                                       CurrentGeneration.ValueGeneration});
       } else
         LLVM_DEBUG(dbgs() << "EarlyCSE skipping assumption: " << Inst << '\n');
       continue;
@@ -1261,6 +1280,10 @@
       continue;
     }
 
+    if (match(&Inst, m_Intrinsic<Intrinsic::coro_suspend>())) {
+      ++CurrentGeneration.ValueGeneration;
+    }
+
     // We can skip all invariant.start intrinsics since they only read memory,
     // and we can forward values across it. For invariant starts without
     // invariant ends, we can use the fact that the invariantness never ends to
@@ -1282,7 +1305,7 @@
           MemoryLocation::getForArgument(&cast<CallInst>(Inst), 1, TLI);
       // Don't start a scope if we already have a better one pushed
       if (!AvailableInvariants.count(MemLoc))
-        AvailableInvariants.insert(MemLoc, CurrentGeneration);
+        AvailableInvariants.insert(MemLoc, CurrentGeneration.MemoryGeneration);
 
       continue;
     }
@@ -1291,9 +1314,11 @@
             dyn_cast<Instruction>(cast<CallInst>(Inst).getArgOperand(0))) {
       if (SimpleValue::canHandle(CondI)) {
         // Do we already know the actual value of this condition?
-        if (auto *KnownCond = AvailableValues.lookup(CondI)) {
+        auto P = AvailableValues.lookup(CondI);
+        Value *KnownCond = P.V;
+        if (KnownCond && P.Generation == CurrentGeneration.ValueGeneration) {
           // Is the condition known to be true?
-          if (isa<ConstantInt>(KnownCond) &&
+          if (isa<ConstantInt>(P.V) &&
               cast<ConstantInt>(KnownCond)->isOne()) {
             LLVM_DEBUG(dbgs()
                        << "EarlyCSE removing guard: " << Inst << '\n');
@@ -1308,7 +1333,8 @@
         }
 
         // The condition we're on guarding here is true for all dominated
         // locations.
-        AvailableValues.insert(CondI, ConstantInt::getTrue(BB->getContext()));
+        AvailableValues.insert(CondI, {ConstantInt::getTrue(BB->getContext()),
+                                       CurrentGeneration.ValueGeneration});
       }
     }
 
@@ -1349,7 +1375,9 @@
     // If this is a simple instruction that we can value number, process it.
     if (SimpleValue::canHandle(&Inst)) {
       // See if the instruction has an available value. If so, use it.
-      if (Value *V = AvailableValues.lookup(&Inst)) {
+      auto P = AvailableValues.lookup(&Inst);
+      Value *V = P.V;
+      if (V && P.Generation == CurrentGeneration.ValueGeneration) {
        LLVM_DEBUG(dbgs() << "EarlyCSE CSE: " << Inst << "  to: " << *V
                          << '\n');
        if (!DebugCounter::shouldExecute(CSECounter)) {
@@ -1368,7 +1396,7 @@
       }
 
       // Otherwise, just remember that this value is available.
-      AvailableValues.insert(&Inst, &Inst);
+      AvailableValues.insert(&Inst, {&Inst, CurrentGeneration.ValueGeneration});
       continue;
     }
 
@@ -1379,7 +1407,7 @@
       // operation, but we can add this load to our set of available values
       if (MemInst.isVolatile() || !MemInst.isUnordered()) {
        LastStore = nullptr;
-        ++CurrentGeneration;
+        ++CurrentGeneration.MemoryGeneration;
       }
 
       if (MemInst.isInvariantLoad()) {
@@ -1390,7 +1418,8 @@
        // restart it since we want to preserve the earliest point seen.
        auto MemLoc = MemoryLocation::get(&Inst);
        if (!AvailableInvariants.count(MemLoc))
-          AvailableInvariants.insert(MemLoc, CurrentGeneration);
+          AvailableInvariants.insert(MemLoc,
+                                     CurrentGeneration.MemoryGeneration);
       }
 
       // If we have an available version of this load, and if it is the right
@@ -1401,7 +1430,8 @@
       // we can assume the current load loads the same value as the dominating
       // load.
       LoadValue InVal = AvailableLoads.lookup(MemInst.getPointerOperand());
-      if (Value *Op = getMatchingValue(InVal, MemInst, CurrentGeneration)) {
+      if (Value *Op = getMatchingValue(InVal, MemInst,
+                                       CurrentGeneration.MemoryGeneration)) {
        LLVM_DEBUG(dbgs() << "EarlyCSE CSE LOAD: " << Inst
                          << "  to: " << *InVal.DefInst << '\n');
        if (!DebugCounter::shouldExecute(CSECounter)) {
@@ -1420,7 +1450,7 @@
 
       // Otherwise, remember that we have this instruction.
       AvailableLoads.insert(MemInst.getPointerOperand(),
-                            LoadValue(&Inst, CurrentGeneration,
+                            LoadValue(&Inst, CurrentGeneration.MemoryGeneration,
                                       MemInst.getMatchingId(),
                                       MemInst.isAtomic()));
       LastStore = nullptr;
@@ -1443,8 +1473,8 @@
      // generation, replace this instruction.
      std::pair<Instruction *, unsigned> InVal = AvailableCalls.lookup(&Inst);
      if (InVal.first != nullptr &&
-          isSameMemGeneration(InVal.second, CurrentGeneration, InVal.first,
-                              &Inst)) {
+          isSameMemGeneration(InVal.second, CurrentGeneration.MemoryGeneration,
+                              InVal.first, &Inst)) {
        LLVM_DEBUG(dbgs() << "EarlyCSE CSE CALL: " << Inst
                          << "  to: " << *InVal.first << '\n');
        if (!DebugCounter::shouldExecute(CSECounter)) {
@@ -1462,7 +1492,8 @@
      }

      // Otherwise, remember that we have this instruction.
-      AvailableCalls.insert(&Inst, std::make_pair(&Inst, CurrentGeneration));
+      AvailableCalls.insert(
+          &Inst, std::make_pair(&Inst, CurrentGeneration.MemoryGeneration));
      continue;
    }

@@ -1485,7 +1516,9 @@
    if (MemInst.isValid() && MemInst.isStore()) {
      LoadValue InVal = AvailableLoads.lookup(MemInst.getPointerOperand());
      if (InVal.DefInst &&
-          InVal.DefInst == getMatchingValue(InVal, MemInst, CurrentGeneration)) {
+          InVal.DefInst ==
+              getMatchingValue(InVal, MemInst,
+                               CurrentGeneration.MemoryGeneration)) {
        // It is okay to have a LastStore to a different pointer here if MemorySSA
        // tells us that the load and store are from the same memory generation.
        // In that case, LastStore should keep its present value since we're
@@ -1515,7 +1548,7 @@
    // something that could modify memory. If so, our available memory values
    // cannot be used so bump the generation count.
    if (Inst.mayWriteToMemory()) {
-      ++CurrentGeneration;
+      ++CurrentGeneration.MemoryGeneration;

      if (MemInst.isValid() && MemInst.isStore()) {
        // We do a trivial form of DSE if there are two stores to the same
@@ -1543,10 +1576,10 @@
        // version of the pointer. It is safe to forward from volatile stores
        // to non-volatile loads, so we don't have to check for volatility of
        // the store.
-        AvailableLoads.insert(MemInst.getPointerOperand(),
-                              LoadValue(&Inst, CurrentGeneration,
-                                        MemInst.getMatchingId(),
-                                        MemInst.isAtomic()));
+        AvailableLoads.insert(
+            MemInst.getPointerOperand(),
+            LoadValue(&Inst, CurrentGeneration.MemoryGeneration,
+                      MemInst.getMatchingId(), MemInst.isAtomic()));

        // Remember that this was the last unordered store we saw for DSE. We
        // don't yet handle DSE on ordered or volatile stores since we don't
@@ -1582,7 +1615,8 @@
                            CurrentGeneration, DT.getRootNode(),
                            DT.getRootNode()->begin(), DT.getRootNode()->end()));

-  assert(!CurrentGeneration && "Create a new EarlyCSE instance to rerun it.");
+  assert(!CurrentGeneration.MemoryGeneration &&
+         "Create a new EarlyCSE instance to rerun it.");

  // Process the stack.
  while (!nodesToProcess.empty()) {
diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp
--- a/llvm/lib/Transforms/Scalar/GVN.cpp
+++ b/llvm/lib/Transforms/Scalar/GVN.cpp
@@ -2135,9 +2135,17 @@
    }
  }

-  if (IntrinsicInst *IntrinsicI = dyn_cast<IntrinsicInst>(I))
+  if (IntrinsicInst *IntrinsicI = dyn_cast<IntrinsicInst>(I)) {
    if (IntrinsicI->getIntrinsicID() == Intrinsic::assume)
      return processAssumeIntrinsic(IntrinsicI);
+    if (IntrinsicI->getIntrinsicID() == Intrinsic::coro_suspend) {
+      // Prevent value reuse across coroutine suspension points; the resumed
+      // coroutine may run on a different thread, changing e.g. pthread_self().
+      VN.clear();
+      LeaderTable.clear();
+      return false;
+    }
+  }

  if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
    if (processLoad(LI))
diff --git a/llvm/test/Transforms/Coroutines/coro-no-value-reuse-across-suspend.ll b/llvm/test/Transforms/Coroutines/coro-no-value-reuse-across-suspend.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/Coroutines/coro-no-value-reuse-across-suspend.ll
@@ -0,0 +1,62 @@
+; Check that no optimization pass reuses values across suspension points in
+; coroutines.
+; RUN: opt < %s -O3 -S | FileCheck %s
+
+define i8* @foo(i64 %arg) {
+entry:
+  %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
+  %size = call i32 @llvm.coro.size.i32()
+  %alloc = call i8* @myAlloc(i32 %size)
+  %hdl = call i8* @llvm.coro.begin(token %id, i8* %alloc)
+  %x1 = call i64 @pthread_self()
+  call void @print(i64 %x1)
+  %y1 = mul i64 %arg, 2
+  call void @print(i64 %y1)
+
+  %0 = call i8 @llvm.coro.suspend(token none, i1 false)
+  switch i8 %0, label %suspend [i8 0, label %resume
+                                i8 1, label %cleanup]
+resume:
+  %x2 = call i64 @pthread_self()
+  call void @print(i64 %x2)
+  %y2 = mul i64 %arg, 2
+  call void @print(i64 %y2)
+
+  br label %cleanup
+
+cleanup:
+  %mem = call i8* @llvm.coro.free(token %id, i8* %hdl)
+  call void @free(i8* %mem)
+  br label %suspend
+
+suspend:
+  call i1 @llvm.coro.end(i8* %hdl, i1 false)
+  ret i8* %hdl
+}
+
+; CHECK-LABEL: define i8* @foo(
+; CHECK: entry:
+; CHECK: %x1 = call i64 @pthread_self()
+; CHECK: %y1 = mul i64 %arg, 2
+; CHECK: resume:
+; CHECK: %x2 = call i64 @pthread_self()
+; CHECK: %y2 = mul i64 %arg, 2
+
+declare i8* @llvm.coro.free(token, i8*)
+declare i32 @llvm.coro.size.i32()
+declare i8 @llvm.coro.suspend(token, i1)
+declare void @llvm.coro.resume(i8*)
+declare void @llvm.coro.destroy(i8*)
+
+declare token @llvm.coro.id(i32, i8*, i8*, i8*)
+declare i1 @llvm.coro.alloc(token)
+declare i8* @llvm.coro.begin(token, i8*)
+declare i1 @llvm.coro.end(i8*, i1)
+
+declare noalias i8* @myAlloc(i32)
+declare void @free(i8*)
+declare void @print(i64)
+
+declare dso_local i64 @pthread_self() #1
+
+attributes #1 = { nounwind readnone }
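
Note (illustration, not part of the patch): the miscompile this patch guards against can be seen at the source level with a minimal C++20 sketch. The names Task, ResumeOnNewThread, and demo below are hypothetical; the awaiter resumes the coroutine on a freshly spawned thread, so if EarlyCSE or GVN reused the first pthread_self() result past the co_await, the program could wrongly report that both halves ran on the same thread. Compile with, e.g., g++ -std=c++20 -O2 -pthread.

// Hypothetical reproducer; not part of the patch.
#include <chrono>
#include <coroutine>
#include <cstdio>
#include <pthread.h>
#include <thread>

struct Task {
  struct promise_type {
    Task get_return_object() { return {}; }
    std::suspend_never initial_suspend() { return {}; }
    std::suspend_never final_suspend() noexcept { return {}; }
    void return_void() {}
    void unhandled_exception() {}
  };
};

// Awaiter that resumes the coroutine on a freshly spawned thread.
struct ResumeOnNewThread {
  bool await_ready() { return false; }
  void await_suspend(std::coroutine_handle<> h) {
    std::thread([h] { h.resume(); }).detach();
  }
  void await_resume() {}
};

Task demo() {
  pthread_t before = pthread_self(); // readnone, yet must not be reused...
  co_await ResumeOnNewThread{};      // ...across this suspension point
  pthread_t after = pthread_self();  // executes on the spawned thread
  std::printf("same thread: %d\n", pthread_equal(before, after) != 0);
}

int main() {
  demo();
  // Crude synchronization for the sketch: give the resumed half time to run.
  std::this_thread::sleep_for(std::chrono::seconds(1));
}

This mirrors the IR test above: %x1/%x2 correspond to the before/after calls and must remain distinct, which is what the ValueGeneration bump in EarlyCSE and the VN/LeaderTable flush in GVN enforce.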