diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -56,6 +56,13 @@
   return OS;
 }
 
+VPValue::VPValue(const unsigned char SC, Value *UV, VPDef *Def)
+    : SubclassID(SC), UnderlyingVal(UV), Def(Def) {
+  // Register this value with the VPDef that defines it, if any.
+  if (Def)
+    Def->addDefinedValue(this);
+}
+
 void VPValue::print(raw_ostream &OS, VPSlotTracker &SlotTracker) const {
   if (const VPInstruction *Instr = dyn_cast<VPInstruction>(this))
     Instr->print(OS, SlotTracker);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -22,6 +22,7 @@
 
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TinyPtrVector.h"
 #include "llvm/ADT/iterator_range.h"
 
 namespace llvm {
@@ -32,6 +33,7 @@
 class VPSlotTracker;
 class VPUser;
 class VPRecipeBase;
+class VPDef;
 
 // This is the base class of the VPlan Def/Use graph, used for modeling the data
 // flow into, within and out of the VPlan. VPValues can stand for live-ins
@@ -44,6 +46,7 @@
   friend class VPInterleavedAccessInfo;
   friend class VPSlotTracker;
   friend class VPRecipeBase;
+  friend class VPDef;
 
   const unsigned char SubclassID; ///< Subclass identifier (for isa/dyn_cast).
 
@@ -53,8 +56,11 @@
   // Hold the underlying Value, if any, attached to this VPValue.
   Value *UnderlyingVal;
 
-  VPValue(const unsigned char SC, Value *UV = nullptr)
-      : SubclassID(SC), UnderlyingVal(UV) {}
+  /// Pointer to the VPDef that defines this VPValue. If it is nullptr, the
+  /// VPValue is not defined by any VPDef.
+  VPDef *Def = nullptr;
+
+  VPValue(const unsigned char SC, Value *UV = nullptr, VPDef *Def = nullptr);
 
   // DESIGN PRINCIPLE: Access to the underlying IR must be strictly limited to
   // the front-end and back-end of VPlan so that the middle-end is as
@@ -80,7 +86,8 @@
   /// type identification.
   enum { VPValueSC, VPInstructionSC, VPMemoryInstructionSC };
 
-  VPValue(Value *UV = nullptr) : VPValue(VPValueSC, UV) {}
+  VPValue(Value *UV = nullptr, VPDef *Def = nullptr)
+      : VPValue(VPValueSC, UV, Def) {}
   VPValue(const VPValue &) = delete;
   VPValue &operator=(const VPValue &) = delete;
 
@@ -145,6 +152,11 @@
   }
 
   void replaceAllUsesWith(VPValue *New);
+
+  /// Returns the VPDef that defines this VPValue, or nullptr if it is not
+  /// defined by a VPDef.
+  VPDef *getDef() { return Def; }
+  const VPDef *getDef() const { return Def; }
 };
 
 typedef DenseMap<Value *, VPValue *> Value2VPValueTy;
@@ -212,6 +224,60 @@
   /// Method to support type inquiry through isa, cast, and dyn_cast.
   static inline bool classof(const VPRecipeBase *Recipe);
 };
+
+/// This class augments a recipe with a set of VPValues defined by the recipe.
+/// It allows recipes to define zero, one or multiple VPValues. A VPDef does
+/// not own the VPValues it defines; on destruction it only detaches them.
+class VPDef {
+  friend class VPValue;
+
+  /// The VPValues defined by the VPDef.
+  TinyPtrVector<VPValue *> DefinedValues;
+
+  /// Add \p V as a defined value by the VPDef. \p V must already point to this
+  /// VPDef (this is called from the VPValue constructor).
+  void addDefinedValue(VPValue *V) {
+    assert(V->getDef() == this &&
+           "can only add VPValue already linked with this VPDef");
+    DefinedValues.push_back(V);
+  }
+
+public:
+  VPDef() = default;
+
+  // The VPValues defined by a VPDef point back to it, so a copy would list
+  // values that still point to the original VPDef.
+  VPDef(const VPDef &) = delete;
+  VPDef &operator=(const VPDef &) = delete;
+
+  virtual ~VPDef() {
+    for (VPValue *D : DefinedValues) {
+      assert(D->Def == this &&
+             "all defined VPValues should point to the containing VPDef");
+      assert(D->getNumUsers() == 0 &&
+             "all defined VPValues should have no more users");
+      // Detach the value; VPDef does not own (and so does not delete) it.
+      D->Def = nullptr;
+    }
+  }
+
+  /// Returns the VPValue with index \p I defined by the VPDef.
+  VPValue *getVPValue(unsigned I = 0) {
+    assert(DefinedValues[I] && "defined value must be non-null");
+    return DefinedValues[I];
+  }
+  const VPValue *getVPValue(unsigned I = 0) const {
+    assert(DefinedValues[I] && "defined value must be non-null");
+    return DefinedValues[I];
+  }
+
+  /// Returns an ArrayRef of the values defined by the VPDef.
+  ArrayRef<VPValue *> defined_values() { return DefinedValues; }
+
+  /// Returns the number of values defined by the VPDef.
+  unsigned getNumDefinedValues() const { return DefinedValues.size(); }
+};
+
 class VPlan;
 class VPBasicBlock;
 class VPRegionBlock;
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
--- a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
@@ -521,5 +521,54 @@
   delete Load;
 }
 
+struct VPMultiDef : public VPDef {
+  // Creates two VPValues defined by this VPDef. VPDef does not take ownership
+  // of them; they must be deleted by the user of VPMultiDef, after the
+  // VPMultiDef itself has been destroyed.
+  VPMultiDef() {
+    new VPValue(nullptr, this);
+    new VPValue(nullptr, this);
+  }
+};
+
+TEST(VPMultiValueTest, traverseUseLists) {
+  // Check that the def-use chains of a multi-def can be traversed in both
+  // directions.
+
+  // Owners for the values created by MultiDef. They are declared before
+  // MultiDef, so they are destroyed after it, i.e. only once ~VPDef has
+  // detached the values.
+  std::unique_ptr<VPValue> V0Owner;
+  std::unique_ptr<VPValue> V1Owner;
+
+  // Create a multi-value def which defines 2 values.
+  VPMultiDef MultiDef;
+  V0Owner.reset(MultiDef.getVPValue(0));
+  V1Owner.reset(MultiDef.getVPValue(1));
+
+  VPInstruction I1(1, {MultiDef.getVPValue(0), MultiDef.getVPValue(1)});
+  VPInstruction I2(2, {MultiDef.getVPValue(0)});
+  VPInstruction I3(3, {MultiDef.getVPValue(1)});
+
+  SmallVector<VPUser *, 4> MultiDefV0Users(
+      MultiDef.getVPValue(0)->user_begin(), MultiDef.getVPValue(0)->user_end());
+  EXPECT_EQ(2u, MultiDefV0Users.size());
+  EXPECT_EQ(&I1, MultiDefV0Users[0]);
+  EXPECT_EQ(&I2, MultiDefV0Users[1]);
+
+  SmallVector<VPUser *, 4> MultiDefV1Users(
+      MultiDef.getVPValue(1)->user_begin(), MultiDef.getVPValue(1)->user_end());
+  EXPECT_EQ(2u, MultiDefV1Users.size());
+  EXPECT_EQ(&I1, MultiDefV1Users[0]);
+  EXPECT_EQ(&I3, MultiDefV1Users[1]);
+
+  // Now check that we can get the right VPDef for each multi-value
+  // handed out.
+  EXPECT_EQ(&MultiDef, I1.getOperand(0)->getDef());
+  EXPECT_EQ(&MultiDef, I1.getOperand(1)->getDef());
+  EXPECT_EQ(&MultiDef, I2.getOperand(0)->getDef());
+  EXPECT_EQ(&MultiDef, I3.getOperand(0)->getDef());
+}
+
 } // namespace
 } // namespace llvm