diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -56,6 +56,12 @@
   return OS;
 }
 
+VPValue::VPValue(const unsigned char SC, Value *UV, VPDef *Def)
+    : SubclassID(SC), UnderlyingVal(UV), Def(Def) {
+  if (Def)
+    Def->addDefinedValue(this);
+}
+
 void VPValue::print(raw_ostream &OS, VPSlotTracker &SlotTracker) const {
   if (const VPInstruction *Instr = dyn_cast<VPInstruction>(this))
     Instr->print(OS, SlotTracker);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -22,6 +22,7 @@
 
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TinyPtrVector.h"
 #include "llvm/ADT/iterator_range.h"
 
 namespace llvm {
@@ -31,7 +32,9 @@
 class Value;
 class VPSlotTracker;
 class VPUser;
+class VPValue;
 class VPRecipeBase;
+class VPDef;
 
 // This is the base class of the VPlan Def/Use graph, used for modeling the data
 // flow into, within and out of the VPlan. VPValues can stand for live-ins
@@ -44,6 +47,7 @@
   friend class VPInterleavedAccessInfo;
   friend class VPSlotTracker;
   friend class VPRecipeBase;
+  friend class VPDef;
 
   const unsigned char SubclassID; ///< Subclass identifier (for isa/dyn_cast).
 
@@ -53,8 +57,14 @@
   // Hold the underlying Value, if any, attached to this VPValue.
   Value *UnderlyingVal;
 
-  VPValue(const unsigned char SC, Value *UV = nullptr)
-      : SubclassID(SC), UnderlyingVal(UV) {}
+  /// The VPDef that defines this VPValue, or nullptr if it is not defined by
+  /// a VPDef.
+  VPDef *Def = nullptr;
+
+  /// Create a VPValue with subclass ID \p SC and underlying IR value \p UV. If
+  /// \p Def is non-null, the new VPValue registers itself with \p Def, which
+  /// takes ownership of it.
+  VPValue(const unsigned char SC, Value *UV = nullptr, VPDef *Def = nullptr);
 
   // DESIGN PRINCIPLE: Access to the underlying IR must be strictly limited to
   // the front-end and back-end of VPlan so that the middle-end is as
@@ -80,7 +90,8 @@
   /// type identification.
   enum { VPValueSC, VPInstructionSC, VPMemoryInstructionSC };
 
-  VPValue(Value *UV = nullptr) : VPValue(VPValueSC, UV) {}
+  VPValue(Value *UV = nullptr, VPDef *Def = nullptr)
+      : VPValue(VPValueSC, UV, Def) {}
   VPValue(const VPValue &) = delete;
   VPValue &operator=(const VPValue &) = delete;
 
@@ -145,6 +156,10 @@
   }
 
   void replaceAllUsesWith(VPValue *New);
+
+  /// Return the VPDef that defines this VPValue, or nullptr if there is none.
+  VPDef *getDef() { return Def; }
+  const VPDef *getDef() const { return Def; }
 };
 
 typedef DenseMap<Value *, VPValue *> Value2VPValueTy;
@@ -212,6 +227,61 @@
   /// Method to support type inquiry through isa, cast, and dyn_cast.
   static inline bool classof(const VPRecipeBase *Recipe);
 };
+
+/// VPDef is a VPUser that can define one or more VPValues. Each VPValue
+/// created with a VPDef registers itself with that VPDef on construction; the
+/// VPDef owns the VPValues it defines and deletes them when it is destroyed.
+class VPDef : public VPUser {
+  /// The VPValues defined by this VPDef, in the order they were added.
+  TinyPtrVector<VPValue *> DefinedValues;
+
+public:
+  VPDef(ArrayRef<VPValue *> Operands) : VPUser(Operands) {}
+
+  template <typename IterT>
+  VPDef(iterator_range<IterT> Operands) : VPUser(Operands) {}
+
+  // VPDef owns its defined VPValues, so copying it would lead to them being
+  // deleted twice.
+  VPDef(const VPDef &) = delete;
+  VPDef &operator=(const VPDef &) = delete;
+
+  virtual ~VPDef() {
+    for (VPValue *D : DefinedValues) {
+      assert(D->Def == this &&
+             "all defined VPValues should point to the containing VPDef");
+      assert(D->user_begin() == D->user_end() &&
+             "all defined VPValues should have no more users");
+      D->Def = nullptr;
+      delete D;
+    }
+  }
+
+  /// Add \p V to the list of VPValues defined by this VPDef. Called by the
+  /// VPValue constructor; the VPDef takes ownership of \p V.
+  void addDefinedValue(VPValue *V) {
+    assert(V->Def == this &&
+           "can only add VPValue already linked with this VPDef");
+    DefinedValues.push_back(V);
+  }
+
+  /// Return the \p I-th VPValue defined by this VPDef.
+  VPValue *getVPValue(unsigned I = 0) {
+    assert(DefinedValues[I] != nullptr && "no VPValue defined at index");
+    return DefinedValues[I];
+  }
+  const VPValue *getVPValue(unsigned I = 0) const {
+    assert(DefinedValues[I] != nullptr && "no VPValue defined at index");
+    return DefinedValues[I];
+  }
+
+  /// Return all VPValues defined by this VPDef.
+  ArrayRef<VPValue *> defined_values() const { return DefinedValues; }
+
+  /// Return the number of VPValues defined by this VPDef.
+  unsigned getNumDefinedValues() const { return DefinedValues.size(); }
+};
+
 class VPlan;
 class VPBasicBlock;
 class VPRegionBlock;
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
--- a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
@@ -521,5 +521,46 @@
   delete Load;
 }
 
+struct VPMultiDef : public VPDef {
+  // Creates two VPValues that register themselves with this VPDef; the VPDef
+  // owns them and deletes them on destruction.
+  VPMultiDef() : VPDef({}) {
+    new VPValue(nullptr, this);
+    new VPValue(nullptr, this);
+  }
+};
+
+TEST(VPMultiValueTest, traverseUseLists) {
+  // Check that the def-use chains of a multi-def can be traversed in both
+  // directions.
+
+  // Create a multi-value def which defines 2 values.
+  VPMultiDef MultiDef;
+  EXPECT_EQ(2u, MultiDef.getNumDefinedValues());
+
+  VPInstruction I1(1, {MultiDef.getVPValue(0), MultiDef.getVPValue(1)});
+  VPInstruction I2(2, {MultiDef.getVPValue(0)});
+  VPInstruction I3(3, {MultiDef.getVPValue(1)});
+
+  SmallVector<VPUser *, 4> MultiDefV0Users(MultiDef.getVPValue(0)->user_begin(),
+                                           MultiDef.getVPValue(0)->user_end());
+  EXPECT_EQ(2u, MultiDefV0Users.size());
+  EXPECT_EQ(&I1, MultiDefV0Users[0]);
+  EXPECT_EQ(&I2, MultiDefV0Users[1]);
+
+  SmallVector<VPUser *, 4> MultiDefV1Users(MultiDef.getVPValue(1)->user_begin(),
+                                           MultiDef.getVPValue(1)->user_end());
+  EXPECT_EQ(2u, MultiDefV1Users.size());
+  EXPECT_EQ(&I1, MultiDefV1Users[0]);
+  EXPECT_EQ(&I3, MultiDefV1Users[1]);
+
+  // Now check that we can get the right VPDef for each multi-value
+  // handed out.
+  EXPECT_EQ(&MultiDef, I1.getOperand(0)->getDef());
+  EXPECT_EQ(&MultiDef, I1.getOperand(1)->getDef());
+  EXPECT_EQ(&MultiDef, I2.getOperand(0)->getDef());
+  EXPECT_EQ(&MultiDef, I3.getOperand(0)->getDef());
+}
+
 } // namespace
 } // namespace llvm