diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h --- a/llvm/include/llvm/CodeGen/SelectionDAG.h +++ b/llvm/include/llvm/CodeGen/SelectionDAG.h @@ -2195,6 +2195,9 @@ return I != SDEI.end() ? I->second.NoMerge : false; } + /// Copy \p From's extra info (if any) onto \p To, overwriting To's entry. + void copyExtraInfo(SDNode *From, SDNode *To); + /// Return the current function's default denormal handling kind for the given /// floating point type. DenormalMode getDenormalMode(EVT VT) const { diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -1047,6 +1047,9 @@ // If any of the SDDbgValue nodes refer to this SDNode, invalidate // them and forget about that node. DbgInfo->erase(N); + + // Also forget N's extra info, so no stale entry stays keyed on N. + SDEI.erase(N); } #ifndef NDEBUG @@ -10177,6 +10180,8 @@ // Preserve Debug Values transferDbgValues(FromN, To); + // Preserve extra info: To's node inherits From's extra info. + copyExtraInfo(From, To.getNode()); // Iterate over all the existing uses of From. New uses will be added // to the beginning of the use list, which we avoid visiting. @@ -10238,6 +10243,8 @@ assert((i < To->getNumValues()) && "Invalid To location"); transferDbgValues(SDValue(From, i), SDValue(To, i)); } + // Preserve extra info: To inherits From's extra info. + copyExtraInfo(From, To); // Iterate over just the existing users of From. See the comments in // the ReplaceAllUsesWith above. @@ -10280,9 +10287,12 @@ if (From->getNumValues() == 1) // Handle the simple case efficiently. return ReplaceAllUsesWith(SDValue(From, 0), To[0]); - // Preserve Debug Info. - for (unsigned i = 0, e = From->getNumValues(); i != e; ++i) + for (unsigned i = 0, e = From->getNumValues(); i != e; ++i) { + // Preserve Debug Info. transferDbgValues(SDValue(From, i), To[i]); + // Preserve extra info: each replacement node inherits From's.
+ copyExtraInfo(From, To[i].getNode()); + } // Iterate over just the existing users of From. See the comments in // the ReplaceAllUsesWith above. @@ -10335,6 +10345,8 @@ // Preserve Debug Info. transferDbgValues(From, To); + // Preserve extra info. + copyExtraInfo(From.getNode(), To.getNode()); // Iterate over just the existing users of From. See the comments in // the ReplaceAllUsesWith above. @@ -10488,6 +10500,8 @@ return ReplaceAllUsesOfValueWith(*From, *To); transferDbgValues(*From, *To); + // Preserve extra info. + copyExtraInfo(From->getNode(), To->getNode()); // Read up all the uses and make records of them. This helps // processing new uses that are introduced during the @@ -11933,6 +11947,18 @@ } } +void SelectionDAG::copyExtraInfo(SDNode *From, SDNode *To) { + assert(From && To && "Invalid SDNode; empty source SDValue?"); + auto I = SDEI.find(From); + if (I == SDEI.end()) + return; + + // Copy the entry out before calling operator[]: inserting To may grow the + // DenseMap and free the bucket holding I->second (a use-after-free). + auto NEI = I->second; + SDEI[To] = std::move(NEI); +} + #ifndef NDEBUG static void checkForCyclesHelper(const SDNode *N, SmallPtrSetImpl &Visited, diff --git a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp --- a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp +++ b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp @@ -591,4 +591,33 @@ EXPECT_EQ(Op.getOpcode(), ISD::SPLAT_VECTOR); } +TEST_F(AArch64SelectionDAGTest, ReplaceAllUsesWith) { + SDLoc Loc; + EVT IntVT = EVT::getIntegerVT(Context, 8); + + SDValue N0 = DAG->getConstant(0x42, Loc, IntVT); + SDValue N1 = DAG->getRegister(0, IntVT); + // Construct a node and attach arbitrary ExtraInfo to it.
+ SDValue N2 = DAG->getNode(ISD::SUB, Loc, IntVT, N0, N1); + EXPECT_FALSE(DAG->getHeapAllocSite(N2.getNode())); + EXPECT_FALSE(DAG->getNoMergeSiteInfo(N2.getNode())); + MDNode *MD = MDNode::get(Context, None); + DAG->addHeapAllocSite(N2.getNode(), MD); + DAG->addNoMergeSiteInfo(N2.getNode(), true); + EXPECT_EQ(DAG->getHeapAllocSite(N2.getNode()), MD); + EXPECT_TRUE(DAG->getNoMergeSiteInfo(N2.getNode())); + + SDValue Root = DAG->getNode(ISD::ADD, Loc, IntVT, N2, N2); + EXPECT_EQ(Root->getOperand(0)->getOpcode(), ISD::SUB); + // Create a replacement with no ExtraInfo; RAUW must copy N2's onto it. + SDValue New = DAG->getNode(ISD::ADD, Loc, IntVT, N1, N1); + EXPECT_FALSE(DAG->getHeapAllocSite(New.getNode())); + EXPECT_FALSE(DAG->getNoMergeSiteInfo(New.getNode())); + + DAG->ReplaceAllUsesWith(N2, New); + EXPECT_EQ(Root->getOperand(0), New); + EXPECT_EQ(DAG->getHeapAllocSite(New.getNode()), MD); + EXPECT_TRUE(DAG->getNoMergeSiteInfo(New.getNode())); +} + } // end namespace llvm