diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h --- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h +++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h @@ -121,8 +121,8 @@ VPValue *createNaryOp(unsigned Opcode, ArrayRef Operands, Instruction *Inst = nullptr) { VPInstruction *NewVPInst = createInstruction(Opcode, Operands); - NewVPInst->setUnderlyingValue(Inst); - return NewVPInst; + NewVPInst->setUnderlyingInstr(Inst); + return NewVPInst->getVPValue(); } VPValue *createNaryOp(unsigned Opcode, std::initializer_list Operands, @@ -131,15 +131,17 @@ } VPValue *createNot(VPValue *Operand) { - return createInstruction(VPInstruction::Not, {Operand}); + return createInstruction(VPInstruction::Not, {Operand})->getVPValue(); } VPValue *createAnd(VPValue *LHS, VPValue *RHS) { - return createInstruction(Instruction::BinaryOps::And, {LHS, RHS}); + return createInstruction(Instruction::BinaryOps::And, {LHS, RHS}) + ->getVPValue(); } VPValue *createOr(VPValue *LHS, VPValue *RHS) { - return createInstruction(Instruction::BinaryOps::Or, {LHS, RHS}); + return createInstruction(Instruction::BinaryOps::Or, {LHS, RHS}) + ->getVPValue(); } //===--------------------------------------------------------------------===// diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -7797,14 +7797,22 @@ if (auto Recipe = RecipeBuilder.tryToCreateWidenRecipe(Instr, Range, Plan)) { - // Check if the recipe can be converted to a VPValue. We need the extra - // down-casting step until VPRecipeBase inherits from VPValue. - VPValue *MaybeVPValue = Recipe->toVPValue(); - if (!Instr->getType()->isVoidTy() && MaybeVPValue) { - if (NeedDef.contains(Instr)) - Plan->addOrReplaceVPValue(Instr, MaybeVPValue); + auto RegisterVPValue = [&Plan, &NeedDef](Value *V, VPValue *Def) { + if (NeedDef.contains(V)) + Plan->addOrReplaceVPValue(V, Def); else - Plan->addVPValue(Instr, MaybeVPValue); + Plan->addVPValue(V, Def); + }; + + if (Recipe->getNumDefinedValues() > 1) { + for (auto *Def : Recipe->defined_values()) { + auto *UV = Def->getUnderlyingValue(); + RegisterVPValue(UV, Def); + } + } else if (Recipe->getNumDefinedValues() != 0) { + assert(Recipe->getNumDefinedValues() == 1 && + "all multi-defs should be handled earlier"); + RegisterVPValue(&I, Recipe->getVPValue()); } RecipeBuilder.setRecipe(Instr, Recipe); diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -659,11 +659,6 @@ /// \returns an iterator pointing to the element after the erased one iplist::iterator eraseFromParent(); - /// Returns a pointer to a VPValue, if the recipe inherits from VPValue or - /// nullptr otherwise. - VPValue *toVPValue(); - const VPValue *toVPValue() const; - /// Returns the underlying instruction, if the recipe is a VPValue or nullptr /// otherwise. Instruction *getUnderlyingInstr() { @@ -695,7 +690,7 @@ /// While as any Recipe it may generate a sequence of IR instructions when /// executed, these instructions would always form a single-def expression as /// the VPInstruction is also a single def-use vertex. -class VPInstruction : public VPValue, public VPRecipeBase { +class VPInstruction : public VPRecipeBase { friend class VPlanSlp; public: @@ -717,16 +712,19 @@ void generateInstruction(VPTransformState &State, unsigned Part); protected: - void setUnderlyingInstr(Instruction *I) { setUnderlyingValue(I); } public: + void setUnderlyingInstr(Instruction *I) { + getVPValue()->setUnderlyingValue(I); + } VPInstruction(unsigned Opcode, ArrayRef Operands) - : VPValue(VPValue::VPInstructionSC), - VPRecipeBase(VPRecipeBase::VPInstructionSC, Operands), Opcode(Opcode) {} + : VPRecipeBase(VPRecipeBase::VPInstructionSC, Operands), Opcode(Opcode) { + new VPValue(nullptr, this); + } VPInstruction(unsigned Opcode, ArrayRef Operands) - : VPValue(VPValue::VPInstructionSC), - VPRecipeBase(VPRecipeBase::VPInstructionSC, {}), Opcode(Opcode) { + : VPRecipeBase(VPRecipeBase::VPInstructionSC, {}), Opcode(Opcode) { + new VPValue(nullptr, this); for (auto *I : Operands) addOperand(I->getVPValue()); } @@ -734,11 +732,6 @@ VPInstruction(unsigned Opcode, std::initializer_list Operands) : VPInstruction(Opcode, ArrayRef(Operands)) {} - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPValue *V) { - return V->getVPValueID() == VPValue::VPInstructionSC; - } - VPInstruction *clone() const { SmallVector Operands(operands()); return new VPInstruction(Opcode, Operands); diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp --- a/llvm/lib/Transforms/Vectorize/VPlan.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp @@ -49,28 +49,27 @@ #define DEBUG_TYPE "vplan" raw_ostream &llvm::operator<<(raw_ostream &OS, const VPValue &V) { - const VPInstruction *Instr = dyn_cast(&V); + const VPInstruction *Instr = dyn_cast(V.getDef()); VPSlotTracker SlotTracker( (Instr && Instr->getParent()) ? Instr->getParent()->getPlan() : nullptr); V.print(OS, SlotTracker); return OS; } -VPValue::VPValue(const unsigned char SC, Value *UV, VPDef *Def) - : SubclassID(SC), UnderlyingVal(UV), Def(Def) { +VPValue::VPValue(Value *UV, VPDef *Def) : UnderlyingVal(UV), Def(Def) { if (Def) Def->addDefinedValue(this); } void VPValue::print(raw_ostream &OS, VPSlotTracker &SlotTracker) const { - if (const VPInstruction *Instr = dyn_cast(this)) + if (const VPInstruction *Instr = dyn_cast_or_null(getDef())) Instr->print(OS, SlotTracker); else printAsOperand(OS, SlotTracker); } void VPValue::dump() const { - const VPInstruction *Instr = dyn_cast(this); + const VPInstruction *Instr = dyn_cast_or_null(getDef()); VPSlotTracker SlotTracker( (Instr && Instr->getParent()) ? Instr->getParent()->getPlan() : nullptr); print(dbgs(), SlotTracker); @@ -83,35 +82,6 @@ dbgs() << "\n"; } -VPValue *VPRecipeBase::toVPValue() { - if (auto *V = dyn_cast(this)) - return V; - if (auto *V = dyn_cast(this)) - if (V->getNumDefinedValues() == 1) - return V->getVPValue(); - if (auto *V = dyn_cast(this)) - return V->getVPValue(); - if (auto *V = dyn_cast(this)) - return V->getVPValue(); - if (auto *V = dyn_cast(this)) - return V->getVPValue(); - return nullptr; -} - -const VPValue *VPRecipeBase::toVPValue() const { - if (auto *V = dyn_cast(this)) - return V; - if (auto *V = dyn_cast(this)) - if (V->getNumDefinedValues() == 1) - return V->getVPValue(); - if (auto *V = dyn_cast(this)) - return V->getVPValue(); - if (auto *V = dyn_cast(this)) - return V->getVPValue(); - if (auto *V = dyn_cast(this)) - return V->getVPValue(); - return nullptr; -} // Get the top-most entry block of \p Start. This is the entry block of the // containing VPlan. This function is templated to support both const and non-const blocks @@ -336,12 +306,8 @@ void VPBasicBlock::dropAllReferences(VPValue *NewValue) { for (VPRecipeBase &R : Recipes) { - if (VPValue *Def = R.toVPValue()) + for (auto *Def : R.defined_values()) Def->replaceAllUsesWith(NewValue); - else if (auto *IR = dyn_cast(&R)) { - for (auto *Def : IR->defined_values()) - Def->replaceAllUsesWith(NewValue); - } for (unsigned I = 0, E = R.getNumOperands(); I != E; I++) R.setOperand(I, NewValue); @@ -437,7 +403,7 @@ Value *A = State.get(getOperand(0), Part); Value *B = State.get(getOperand(1), Part); Value *V = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), A, B); - State.set(this, V, Part); + State.set(getVPValue(), V, Part); return; } @@ -445,14 +411,14 @@ case VPInstruction::Not: { Value *A = State.get(getOperand(0), Part); Value *V = Builder.CreateNot(A); - State.set(this, V, Part); + State.set(getVPValue(), V, Part); break; } case VPInstruction::ICmpULE: { Value *IV = State.get(getOperand(0), Part); Value *TC = State.get(getOperand(1), Part); Value *V = Builder.CreateICmpULE(IV, TC); - State.set(this, V, Part); + State.set(getVPValue(), V, Part); break; } case Instruction::Select: { @@ -460,7 +426,7 @@ Value *Op1 = State.get(getOperand(1), Part); Value *Op2 = State.get(getOperand(2), Part); Value *V = Builder.CreateSelect(Cond, Op1, Op2); - State.set(this, V, Part); + State.set(getVPValue(), V, Part); break; } case VPInstruction::ActiveLaneMask: { @@ -474,7 +440,7 @@ Instruction *Call = Builder.CreateIntrinsic( Intrinsic::get_active_lane_mask, {PredTy, ScalarTC->getType()}, {VIVElem0, ScalarTC}, nullptr, "active.lane.mask"); - State.set(this, Call, Part); + State.set(getVPValue(), Call, Part); break; } default: @@ -501,7 +467,7 @@ void VPInstruction::print(raw_ostream &O, VPSlotTracker &SlotTracker) const { if (hasResult()) { - printAsOperand(O, SlotTracker); + getVPValue()->printAsOperand(O, SlotTracker); O << " = "; } @@ -750,8 +716,8 @@ const VPValue *Pred = BasicBlock->getPredicate(); if (Pred) { OS << " +\n" << Indent << " \"BlockPredicate: "; - if (const VPInstruction *PredI = dyn_cast(Pred)) { - PredI->printAsOperand(OS, SlotTracker); + if (const VPInstruction *PredI = Pred->getAs()) { + PredI->getVPValue()->printAsOperand(OS, SlotTracker); OS << " (" << DOT::EscapeString(PredI->getParent()->getName()) << ")\\l\""; } else @@ -768,8 +734,8 @@ const VPValue *CBV = BasicBlock->getCondBit(); if (CBV) { OS << " +\n" << Indent << " \"CondBit: "; - if (const VPInstruction *CBI = dyn_cast(CBV)) { - CBI->printAsOperand(OS, SlotTracker); + if (const VPInstruction *CBI = CBV->getAs()) { + CBI->getVPValue()->printAsOperand(OS, SlotTracker); OS << " (" << DOT::EscapeString(CBI->getParent()->getName()) << ")\\l\""; } else { CBV->printAsOperand(OS, SlotTracker); @@ -1008,7 +974,8 @@ for (VPRecipeBase &VPI : *VPBB) { assert(isa(&VPI) && "Can only handle VPInstructions"); auto *VPInst = cast(&VPI); - auto *Inst = cast(VPInst->getUnderlyingValue()); + auto *Inst = + cast(VPInst->getVPValue()->getUnderlyingValue()); auto *IG = IAI.getInterleaveGroup(Inst); if (!IG) continue; @@ -1044,7 +1011,7 @@ const Value *UV = V->getUnderlyingValue(); if (UV) return; - const auto *VPI = dyn_cast(V); + const auto *VPI = V->getAs(); if (VPI && !VPI->hasResult()) return; @@ -1067,7 +1034,7 @@ void VPSlotTracker::assignSlots(const VPBasicBlock *VPBB) { for (const VPRecipeBase &Recipe : *VPBB) { if (const auto *VPI = dyn_cast(&Recipe)) - assignSlot(VPI); + assignSlot(VPI->getVPValue()); else if (const auto *VPIV = dyn_cast(&Recipe)) assignSlot(VPIV->getVPValue()); } @@ -1079,7 +1046,7 @@ assignSlot(V); for (auto &E : Plan.Value2VPValue) - if (!isa(E.second)) + if (!E.second->isa()) assignSlot(E.second); for (const VPValue *V : Plan.VPCBVs) diff --git a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp --- a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp @@ -94,8 +94,9 @@ for (auto *Phi : PhisToFix) { assert(IRDef2VPValue.count(Phi) && "Missing VPInstruction for PHINode."); VPValue *VPVal = IRDef2VPValue[Phi]; - assert(isa(VPVal) && "Expected VPInstruction for phi node."); - auto *VPPhi = cast(VPVal); + assert(isa(VPVal->getDef()) && + "Expected VPInstruction for phi node."); + auto *VPPhi = VPVal->getAs(); assert(VPPhi->getNumOperands() == 0 && "Expected VPInstruction with no operands."); @@ -215,8 +216,9 @@ // Phi node's operands may have not been visited at this point. We create // an empty VPInstruction that we will fix once the whole plain CFG has // been built. - NewVPInst = cast(VPIRBuilder.createNaryOp( - Inst->getOpcode(), {} /*No operands*/, Inst)); + NewVPInst = + VPIRBuilder.createNaryOp(Inst->getOpcode(), {} /*No operands*/, Inst) + ->getAs(); PhisToFix.push_back(Phi); } else { // Translate LLVM-IR operands into VPValue operands and set them in the @@ -227,11 +229,11 @@ // Build VPInstruction for any arbitraty Instruction without specific // representation in VPlan. - NewVPInst = cast( - VPIRBuilder.createNaryOp(Inst->getOpcode(), VPOperands, Inst)); + NewVPInst = VPIRBuilder.createNaryOp(Inst->getOpcode(), VPOperands, Inst) + ->getAs(); } - IRDef2VPValue[Inst] = NewVPInst; + IRDef2VPValue[Inst] = NewVPInst->getVPValue(); } } diff --git a/llvm/lib/Transforms/Vectorize/VPlanSLP.cpp b/llvm/lib/Transforms/Vectorize/VPlanSLP.cpp --- a/llvm/lib/Transforms/Vectorize/VPlanSLP.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanSLP.cpp @@ -56,11 +56,11 @@ void VPlanSlp::addCombined(ArrayRef Operands, VPInstruction *New) { if (all_of(Operands, [](VPValue *V) { - return cast(V)->getUnderlyingInstr(); + return V->getAs()->getUnderlyingInstr(); })) { unsigned BundleSize = 0; for (VPValue *V : Operands) { - Type *T = cast(V)->getUnderlyingInstr()->getType(); + Type *T = V->getAs()->getUnderlyingInstr()->getType(); assert(!T->isVectorTy() && "Only scalar types supported for now"); BundleSize += T->getScalarSizeInBits(); } @@ -76,8 +76,8 @@ bool VPlanSlp::areVectorizable(ArrayRef Operands) const { // Currently we only support VPInstructions. if (!all_of(Operands, [](VPValue *Op) { - return Op && isa(Op) && - cast(Op)->getUnderlyingInstr(); + return Op && Op->isa() && + Op->getAs()->getUnderlyingInstr(); })) { LLVM_DEBUG(dbgs() << "VPSLP: not all operands are VPInstructions\n"); return false; @@ -88,11 +88,11 @@ // instructions. // FIXME: Deal with non-primitive types. const Instruction *OriginalInstr = - cast(Operands[0])->getUnderlyingInstr(); + Operands[0]->getAs()->getUnderlyingInstr(); unsigned Opcode = OriginalInstr->getOpcode(); unsigned Width = OriginalInstr->getType()->getPrimitiveSizeInBits(); if (!all_of(Operands, [Opcode, Width](VPValue *Op) { - const Instruction *I = cast(Op)->getUnderlyingInstr(); + const Instruction *I = Op->getAs()->getUnderlyingInstr(); return I->getOpcode() == Opcode && I->getType()->getPrimitiveSizeInBits() == Width; })) { @@ -102,7 +102,7 @@ // For now, all operands must be defined in the same BB. if (any_of(Operands, [this](VPValue *Op) { - return cast(Op)->getParent() != &this->BB; + return Op->getAs()->getParent() != &this->BB; })) { LLVM_DEBUG(dbgs() << "VPSLP: operands in different BBs\n"); return false; @@ -120,11 +120,11 @@ // interfere with any of the loads in the bundle if (Opcode == Instruction::Load) { unsigned LoadsSeen = 0; - VPBasicBlock *Parent = cast(Operands[0])->getParent(); + VPBasicBlock *Parent = Operands[0]->getAs()->getParent(); for (auto &I : *Parent) { auto *VPI = cast(&I); if (VPI->getOpcode() == Instruction::Load && - llvm::is_contained(Operands, VPI)) + llvm::is_contained(Operands, VPI->getVPValue())) LoadsSeen++; if (LoadsSeen == Operands.size()) @@ -137,7 +137,8 @@ } if (!all_of(Operands, [](VPValue *Op) { - return cast(cast(Op)->getUnderlyingInstr()) + return cast( + Op->getAs()->getUnderlyingInstr()) ->isSimple(); })) { LLVM_DEBUG(dbgs() << "VPSLP: only simple loads are supported.\n"); @@ -147,7 +148,8 @@ if (Opcode == Instruction::Store) if (!all_of(Operands, [](VPValue *Op) { - return cast(cast(Op)->getUnderlyingInstr()) + return cast( + Op->getAs()->getUnderlyingInstr()) ->isSimple(); })) { LLVM_DEBUG(dbgs() << "VPSLP: only simple stores are supported.\n"); @@ -162,7 +164,7 @@ SmallVector Operands; for (VPValue *V : Values) { // Currently we only support VPInstructions. - auto *U = cast(V); + auto *U = V->getAs(); Operands.push_back(U->getOperand(OperandIndex)); } return Operands; @@ -170,13 +172,13 @@ static bool areCommutative(ArrayRef Values) { return Instruction::isCommutative( - cast(Values[0])->getOpcode()); + Values[0]->getAs()->getOpcode()); } static SmallVector, 4> getOperands(ArrayRef Values) { SmallVector, 4> Result; - auto *VPI = cast(Values[0]); + auto *VPI = Values[0]->getAs(); switch (VPI->getOpcode()) { case Instruction::Load: @@ -195,9 +197,9 @@ /// Returns the opcode of Values or ~0 if they do not all agree. static Optional getOpcode(ArrayRef Values) { - unsigned Opcode = cast(Values[0])->getOpcode(); + unsigned Opcode = Values[0]->getAs()->getOpcode(); if (any_of(Values, [Opcode](VPValue *V) { - return cast(V)->getOpcode() != Opcode; + return V->getAs()->getOpcode() != Opcode; })) return None; return {Opcode}; @@ -223,8 +225,8 @@ /// Traverses and compares operands of V1 and V2 to MaxLevel. static unsigned getLAScore(VPValue *V1, VPValue *V2, unsigned MaxLevel, VPInterleavedAccessInfo &IAI) { - auto *I1 = dyn_cast(V1); - auto *I2 = dyn_cast(V2); + auto *I1 = V1->getAs(); + auto *I2 = V2->getAs(); // Currently we only support VPInstructions. if (!I1 || !I2) return 0; @@ -250,13 +252,13 @@ SmallVector BestCandidates; LLVM_DEBUG(dbgs() << " Candidates for " - << *cast(Last)->getUnderlyingInstr() << " "); + << *Last->getAs()->getUnderlyingInstr() + << " "); for (auto *Candidate : Candidates) { - auto *LastI = cast(Last); - auto *CandidateI = cast(Candidate); + auto *LastI = Last->getAs(); + auto *CandidateI = Candidate->getAs(); if (areConsecutiveOrMatch(LastI, CandidateI, IAI)) { - LLVM_DEBUG(dbgs() << *cast(Candidate)->getUnderlyingInstr() - << " "); + LLVM_DEBUG(dbgs() << *CandidateI->getUnderlyingInstr() << " "); BestCandidates.push_back(Candidate); } } @@ -292,7 +294,7 @@ break; } LLVM_DEBUG(dbgs() << "Found best " - << *cast(Best)->getUnderlyingInstr() + << *Best->getAs()->getUnderlyingInstr() << "\n"); Candidates.erase(Best); @@ -309,7 +311,7 @@ for (auto &Operands : MultiNodeOps) { FinalOrder.push_back({Operands.first, {Operands.second[0]}}); - if (cast(Operands.second[0])->getOpcode() == + if (Operands.second[0]->getAs()->getOpcode() == Instruction::Load) Mode.push_back(OpMode::Load); else @@ -322,8 +324,9 @@ LLVM_DEBUG(dbgs() << " Candidates "); for (auto Ops : MultiNodeOps) { LLVM_DEBUG( - dbgs() << *cast(Ops.second[Lane])->getUnderlyingInstr() - << " "); + dbgs() + << *Ops.second[Lane]->getAs()->getUnderlyingInstr() + << " "); Candidates.insert(Ops.second[Lane]); } LLVM_DEBUG(dbgs() << "\n"); @@ -338,9 +341,11 @@ getBest(Mode[Op], Last, Candidates, IAI); if (Res.second) FinalOrder[Op].second.push_back(Res.second); - else + else { // TODO: handle this case - FinalOrder[Op].second.push_back(markFailed()); + markFailed(); + FinalOrder[Op].second.push_back(nullptr); + } } } @@ -350,7 +355,7 @@ void VPlanSlp::dumpBundle(ArrayRef Values) { dbgs() << " Ops: "; for (auto Op : Values) { - if (auto *VPInstr = cast_or_null(Op)) + if (auto *VPInstr = Op->getAs()) if (auto *Instr = VPInstr->getUnderlyingInstr()) { dbgs() << *Instr << " | "; continue; @@ -407,13 +412,13 @@ auto OperandsOpcode = getOpcode(Operands); if (OperandsOpcode && OperandsOpcode == getOpcode(Values)) { LLVM_DEBUG(dbgs() << " Same opcode, continue building\n"); - CombinedOperands.push_back(buildGraph(Operands)); + CombinedOperands.push_back(buildGraph(Operands)->getVPValue()); } else { LLVM_DEBUG(dbgs() << " Adding multinode Ops\n"); // Create dummy VPInstruction, which will we replace later by the // re-ordered operand. VPInstruction *Op = new VPInstruction(0, {}); - CombinedOperands.push_back(Op); + CombinedOperands.push_back(Op->getVPValue()); MultiNodeOps.emplace_back(Op, Operands); } } @@ -427,10 +432,11 @@ MultiNodeOps.clear(); for (auto &Ops : FinalOrder) { VPInstruction *NewOp = buildGraph(Ops.second); - Ops.first->replaceAllUsesWith(NewOp); + VPValue *NewVal = NewOp ? NewOp->getVPValue() : nullptr; + Ops.first->getVPValue()->replaceAllUsesWith(NewVal); for (unsigned i = 0; i < CombinedOperands.size(); i++) - if (CombinedOperands[i] == Ops.first) - CombinedOperands[i] = NewOp; + if (CombinedOperands[i] == Ops.first->getVPValue()) + CombinedOperands[i] = NewVal; delete Ops.first; Ops.first = NewOp; } @@ -440,10 +446,13 @@ LLVM_DEBUG(dbgs() << " NonCommuntative\n"); if (ValuesOpcode == Instruction::Load) for (VPValue *V : Values) - CombinedOperands.push_back(cast(V)->getOperand(0)); + CombinedOperands.push_back( + cast(V->getDef())->getOperand(0)); else - for (auto &Operands : getOperands(Values)) - CombinedOperands.push_back(buildGraph(Operands)); + for (auto &Operands : getOperands(Values)) { + auto *Res = buildGraph(Operands); + CombinedOperands.push_back(Res ? Res->getVPValue() : nullptr); + } } unsigned Opcode; @@ -464,10 +473,11 @@ assert(CombinedOperands.size() > 0 && "Need more some operands"); auto *VPI = new VPInstruction(Opcode, CombinedOperands); - VPI->setUnderlyingInstr(cast(Values[0])->getUnderlyingInstr()); + VPI->setUnderlyingInstr( + Values[0]->getAs()->getUnderlyingInstr()); LLVM_DEBUG(dbgs() << "Create VPInstruction "; VPI->print(dbgs()); - cast(Values[0])->print(dbgs()); dbgs() << "\n"); + Values[0]->getAs()->print(dbgs()); dbgs() << "\n"); addCombined(Values, VPI); return VPI; } diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -47,9 +47,9 @@ VPRecipeBase *Ingredient = &*I++; // Can only handle VPInstructions. VPInstruction *VPInst = cast(Ingredient); - Instruction *Inst = cast(VPInst->getUnderlyingValue()); + Instruction *Inst = VPInst->getUnderlyingInstr(); if (DeadInstructions.count(Inst)) { - VPInst->replaceAllUsesWith(&DummyValue); + VPInst->getVPValue()->replaceAllUsesWith(&DummyValue); Ingredient->eraseFromParent(); continue; } @@ -79,7 +79,7 @@ new VPWidenRecipe(*Inst, Plan->mapToVPValues(Inst->operands())); NewRecipe->insertBefore(Ingredient); - VPInst->replaceAllUsesWith(&DummyValue); + VPInst->getVPValue()->replaceAllUsesWith(&DummyValue); Ingredient->eraseFromParent(); } } diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h --- a/llvm/lib/Transforms/Vectorize/VPlanValue.h +++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h @@ -51,8 +51,6 @@ friend class VPInstruction; friend class VPDef; - const unsigned char SubclassID; ///< Subclass identifier (for isa/dyn_cast). - SmallVector Users; protected: @@ -61,8 +59,13 @@ VPDef *Def = nullptr; - VPValue(const unsigned char SC, Value *UV = nullptr, VPDef *Def = nullptr); + // Set \p Val as the underlying Value of this VPValue. + void setUnderlyingValue(Value *Val) { + assert(!UnderlyingVal && "Underlying Value is already set."); + UnderlyingVal = Val; + } +public: // DESIGN PRINCIPLE: Access to the underlying IR must be strictly limited to // the front-end and back-end of VPlan so that the middle-end is as // independent as possible of the underlying IR. We grant access to the @@ -74,40 +77,15 @@ Value *getUnderlyingValue() { return UnderlyingVal; } const Value *getUnderlyingValue() const { return UnderlyingVal; } - // Set \p Val as the underlying Value of this VPValue. - void setUnderlyingValue(Value *Val) { - assert(!UnderlyingVal && "Underlying Value is already set."); - UnderlyingVal = Val; - } - -public: - /// An enumeration for keeping track of the concrete subclass of VPValue that - /// are actually instantiated. Values of this enumeration are kept in the - /// SubclassID field of the VPValue objects. They are used for concrete - /// type identification. - enum { - VPValueSC, - VPInstructionSC, - VPMemoryInstructionSC, - VPVWidenCallSC, - VPVWidenSelectSC, - VPVWidenGEPSC - }; - - VPValue(Value *UV = nullptr, VPDef *Def = nullptr) - : VPValue(VPValueSC, UV, Def) {} + VPValue(Value *UV = nullptr, VPDef *Def = nullptr); VPValue(const VPValue &) = delete; VPValue &operator=(const VPValue &) = delete; virtual ~VPValue() { + assert(!Def && "trying to delete a VPVAlue with a remaining linked def"); assert(Users.empty() && "trying to delete a VPValue with remaining users"); } - /// \return an ID for the concrete type of this object. - /// This is used to implement the classof checks. This should not be used - /// for any other purpose, as the values may change as LLVM evolves. - unsigned getVPValueID() const { return SubclassID; } - void printAsOperand(raw_ostream &OS, VPSlotTracker &Tracker) const; void print(raw_ostream &OS, VPSlotTracker &Tracker) const; @@ -162,6 +140,18 @@ void replaceAllUsesWith(VPValue *New); VPDef *getDef() { return Def; } + const VPDef *getDef() const { return Def; } + + template DestTy *getAs() { + return dyn_cast_or_null(getDef()); + } + template const DestTy *getAs() const { + return dyn_cast_or_null(getDef()); + } + + template bool isa() const { + return Def && llvm::isa(Def); + } }; typedef DenseMap Value2VPValueTy; @@ -206,6 +196,8 @@ return Operands[N]; } + inline VPDef *getOperandAsDef(unsigned N) { return getOperand(N)->getDef(); } + void setOperand(unsigned I, VPValue *New) { Operands[I]->removeUser(*this); Operands[I] = New; diff --git a/llvm/unittests/Transforms/Vectorize/VPlanHCFGTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanHCFGTest.cpp --- a/llvm/unittests/Transforms/Vectorize/VPlanHCFGTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/VPlanHCFGTest.cpp @@ -60,34 +60,34 @@ VPInstruction *Idx = dyn_cast(&*Iter++); EXPECT_EQ(Instruction::GetElementPtr, Idx->getOpcode()); EXPECT_EQ(2u, Idx->getNumOperands()); - EXPECT_EQ(Phi, Idx->getOperand(1)); + EXPECT_EQ(Phi->getVPValue(), Idx->getOperand(1)); VPInstruction *Load = dyn_cast(&*Iter++); EXPECT_EQ(Instruction::Load, Load->getOpcode()); EXPECT_EQ(1u, Load->getNumOperands()); - EXPECT_EQ(Idx, Load->getOperand(0)); + EXPECT_EQ(Idx->getVPValue(), Load->getOperand(0)); VPInstruction *Add = dyn_cast(&*Iter++); EXPECT_EQ(Instruction::Add, Add->getOpcode()); EXPECT_EQ(2u, Add->getNumOperands()); - EXPECT_EQ(Load, Add->getOperand(0)); + EXPECT_EQ(Load->getVPValue(), Add->getOperand(0)); VPInstruction *Store = dyn_cast(&*Iter++); EXPECT_EQ(Instruction::Store, Store->getOpcode()); EXPECT_EQ(2u, Store->getNumOperands()); - EXPECT_EQ(Add, Store->getOperand(0)); - EXPECT_EQ(Idx, Store->getOperand(1)); + EXPECT_EQ(Add->getVPValue(), Store->getOperand(0)); + EXPECT_EQ(Idx->getVPValue(), Store->getOperand(1)); VPInstruction *IndvarAdd = dyn_cast(&*Iter++); EXPECT_EQ(Instruction::Add, IndvarAdd->getOpcode()); EXPECT_EQ(2u, IndvarAdd->getNumOperands()); - EXPECT_EQ(Phi, IndvarAdd->getOperand(0)); + EXPECT_EQ(Phi->getVPValue(), IndvarAdd->getOperand(0)); VPInstruction *ICmp = dyn_cast(&*Iter++); EXPECT_EQ(Instruction::ICmp, ICmp->getOpcode()); EXPECT_EQ(2u, ICmp->getNumOperands()); - EXPECT_EQ(IndvarAdd, ICmp->getOperand(0)); - EXPECT_EQ(VecBB->getCondBit(), ICmp); + EXPECT_EQ(IndvarAdd->getVPValue(), ICmp->getOperand(0)); + EXPECT_EQ(VecBB->getCondBit(), ICmp->getVPValue()); // Add an external value to check we do not print the list of external values, // as this is not required with the new printing. diff --git a/llvm/unittests/Transforms/Vectorize/VPlanPredicatorTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanPredicatorTest.cpp --- a/llvm/unittests/Transforms/Vectorize/VPlanPredicatorTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/VPlanPredicatorTest.cpp @@ -96,7 +96,7 @@ EXPECT_EQ(InnerAnd->getOpcode(), Instruction::And); EXPECT_EQ(InnerAnd->getOperand(0), CBV1); EXPECT_EQ(InnerAnd->getOperand(1), CBV2); - EXPECT_EQ(InnerIfPred, InnerAnd); + EXPECT_EQ(InnerIfPred, InnerAnd->getVPValue()); // Test Linearization EXPECT_EQ(InnerLoopLinSucc, OuterIf); @@ -210,19 +210,19 @@ EXPECT_EQ(Not->getOperand(0), InnerCBV); EXPECT_EQ(NotAnd->getOperand(0), OuterCBV); - EXPECT_EQ(NotAnd->getOperand(1), Not); + EXPECT_EQ(NotAnd->getOperand(1), Not->getVPValue()); - EXPECT_EQ(InnerIfTSucc->getPredicate(), And); - EXPECT_EQ(InnerIfFSucc->getPredicate(), NotAnd); + EXPECT_EQ(InnerIfTSucc->getPredicate(), And->getVPValue()); + EXPECT_EQ(InnerIfFSucc->getPredicate(), NotAnd->getVPValue()); EXPECT_EQ(TSuccSucc, FSuccSucc); EXPECT_EQ(Or->getOpcode(), Instruction::Or); - EXPECT_EQ(TSuccSucc->getPredicate(), Or); + EXPECT_EQ(TSuccSucc->getPredicate(), Or->getVPValue()); // Test operands of the Or - account for differences in predecessor block // ordering. - VPInstruction *OrOp0Inst = cast(Or->getOperand(0)); - VPInstruction *OrOp1Inst = cast(Or->getOperand(1)); + VPInstruction *OrOp0Inst = cast(Or->getOperandAsDef(0)); + VPInstruction *OrOp1Inst = cast(Or->getOperandAsDef(1)); bool ValidOrOperands = false; if (((OrOp0Inst == And) && (OrOp1Inst == NotAnd)) || diff --git a/llvm/unittests/Transforms/Vectorize/VPlanSlpTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanSlpTest.cpp --- a/llvm/unittests/Transforms/Vectorize/VPlanSlpTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/VPlanSlpTest.cpp @@ -104,16 +104,20 @@ VPInstruction *Store2 = cast(&*std::next(Body->begin(), 14)); VPlanSlp Slp(VPIAI, *Body); - SmallVector StoreRoot = {Store1, Store2}; + SmallVector StoreRoot = {Store1->getVPValue(), + Store2->getVPValue()}; VPInstruction *CombinedStore = Slp.buildGraph(StoreRoot); EXPECT_EQ(64u, Slp.getWidestBundleBits()); EXPECT_EQ(VPInstruction::SLPStore, CombinedStore->getOpcode()); - auto *CombinedAdd = cast(CombinedStore->getOperand(0)); + auto *CombinedAdd = + cast(CombinedStore->getOperand(0)->getDef()); EXPECT_EQ(Instruction::Add, CombinedAdd->getOpcode()); - auto *CombinedLoadA = cast(CombinedAdd->getOperand(0)); - auto *CombinedLoadB = cast(CombinedAdd->getOperand(1)); + auto *CombinedLoadA = + cast(CombinedAdd->getOperand(0)->getDef()); + auto *CombinedLoadB = + cast(CombinedAdd->getOperand(1)->getDef()); EXPECT_EQ(VPInstruction::SLPLoad, CombinedLoadA->getOpcode()); EXPECT_EQ(VPInstruction::SLPLoad, CombinedLoadB->getOpcode()); @@ -178,23 +182,27 @@ auto VPIAI = getInterleavedAccessInfo(*F, LI->getLoopFor(LoopHeader), *Plan); VPlanSlp Slp(VPIAI, *Body); - SmallVector StoreRoot = {Store1, Store2}; + SmallVector StoreRoot = {Store1->getVPValue(), + Store2->getVPValue()}; VPInstruction *CombinedStore = Slp.buildGraph(StoreRoot); EXPECT_EQ(64u, Slp.getWidestBundleBits()); EXPECT_EQ(VPInstruction::SLPStore, CombinedStore->getOpcode()); - auto *CombinedAdd = cast(CombinedStore->getOperand(0)); + auto *CombinedAdd = + cast(CombinedStore->getOperand(0)->getDef()); EXPECT_EQ(Instruction::Add, CombinedAdd->getOpcode()); - auto *CombinedLoadA = cast(CombinedAdd->getOperand(0)); - auto *CombinedLoadB = cast(CombinedAdd->getOperand(1)); + auto *CombinedLoadA = + cast(CombinedAdd->getOperand(0)->getDef()); + auto *CombinedLoadB = + cast(CombinedAdd->getOperand(1)->getDef()); EXPECT_EQ(VPInstruction::SLPLoad, CombinedLoadA->getOpcode()); EXPECT_EQ(VPInstruction::SLPLoad, CombinedLoadB->getOpcode()); VPInstruction *GetA = cast(&*std::next(Body->begin(), 1)); VPInstruction *GetB = cast(&*std::next(Body->begin(), 3)); - EXPECT_EQ(GetA, CombinedLoadA->getOperand(0)); - EXPECT_EQ(GetB, CombinedLoadB->getOperand(0)); + EXPECT_EQ(GetA->getVPValue(), CombinedLoadA->getOperand(0)); + EXPECT_EQ(GetB->getVPValue(), CombinedLoadB->getOperand(0)); delete CombinedStore; delete CombinedAdd; @@ -248,16 +256,19 @@ VPInstruction *Store2 = cast(&*std::next(Body->begin(), 10)); VPlanSlp Slp(VPIAI, *Body); - SmallVector StoreRoot = {Store1, Store2}; + SmallVector StoreRoot = {Store1->getVPValue(), + Store2->getVPValue()}; VPInstruction *CombinedStore = Slp.buildGraph(StoreRoot); EXPECT_EQ(64u, Slp.getWidestBundleBits()); EXPECT_EQ(VPInstruction::SLPStore, CombinedStore->getOpcode()); - auto *CombinedAdd = cast(CombinedStore->getOperand(0)); + auto *CombinedAdd = + cast(CombinedStore->getOperand(0)->getDef()); EXPECT_EQ(Instruction::Add, CombinedAdd->getOpcode()); - auto *CombinedLoadA = cast(CombinedAdd->getOperand(0)); - EXPECT_EQ(CombinedLoadA, CombinedAdd->getOperand(1)); + auto *CombinedLoadA = + cast(CombinedAdd->getOperand(0)->getDef()); + EXPECT_EQ(CombinedLoadA->getVPValue(), CombinedAdd->getOperand(1)); EXPECT_EQ(VPInstruction::SLPLoad, CombinedLoadA->getOpcode()); delete CombinedStore; @@ -311,7 +322,8 @@ VPInstruction *Store2 = cast(&*std::next(Body->begin(), 10)); VPlanSlp Slp(VPIAI, *Body); - SmallVector StoreRoot = {Store1, Store2}; + SmallVector StoreRoot = {Store1->getVPValue(), + Store2->getVPValue()}; Slp.buildGraph(StoreRoot); EXPECT_FALSE(Slp.isCompletelySLP()); } @@ -320,24 +332,25 @@ VPBasicBlock *Body, VPInterleavedAccessInfo &&IAI) { VPlanSlp Slp(IAI, *Body); - SmallVector StoreRoot = {Store1, Store2}; + SmallVector StoreRoot = {Store1->getVPValue(), + Store2->getVPValue()}; VPInstruction *CombinedStore = Slp.buildGraph(StoreRoot); EXPECT_TRUE(Slp.isCompletelySLP()); EXPECT_EQ(CombinedStore->getOpcode(), VPInstruction::SLPStore); VPInstruction *CombinedAdd = - cast(CombinedStore->getOperand(0)); + cast(CombinedStore->getOperand(0)->getDef()); EXPECT_EQ(CombinedAdd->getOpcode(), Instruction::Add); VPInstruction *CombinedMulAB = - cast(CombinedAdd->getOperand(0)); + cast(CombinedAdd->getOperand(0)->getDef()); VPInstruction *CombinedMulCD = - cast(CombinedAdd->getOperand(1)); + cast(CombinedAdd->getOperand(1)->getDef()); EXPECT_EQ(CombinedMulAB->getOpcode(), Instruction::Mul); VPInstruction *CombinedLoadA = - cast(CombinedMulAB->getOperand(0)); + cast(CombinedMulAB->getOperand(0)->getDef()); EXPECT_EQ(VPInstruction::SLPLoad, CombinedLoadA->getOpcode()); VPInstruction *LoadvA0 = cast(&*std::next(Body->begin(), 2)); VPInstruction *LoadvA1 = cast(&*std::next(Body->begin(), 12)); @@ -345,7 +358,7 @@ EXPECT_EQ(LoadvA1->getOperand(0), CombinedLoadA->getOperand(1)); VPInstruction *CombinedLoadB = - cast(CombinedMulAB->getOperand(1)); + cast(CombinedMulAB->getOperand(1)->getDef()); EXPECT_EQ(VPInstruction::SLPLoad, CombinedLoadB->getOpcode()); VPInstruction *LoadvB0 = cast(&*std::next(Body->begin(), 4)); VPInstruction *LoadvB1 = cast(&*std::next(Body->begin(), 14)); @@ -355,7 +368,7 @@ EXPECT_EQ(CombinedMulCD->getOpcode(), Instruction::Mul); VPInstruction *CombinedLoadC = - cast(CombinedMulCD->getOperand(0)); + cast(CombinedMulCD->getOperand(0)->getDef()); EXPECT_EQ(VPInstruction::SLPLoad, CombinedLoadC->getOpcode()); VPInstruction *LoadvC0 = cast(&*std::next(Body->begin(), 7)); VPInstruction *LoadvC1 = cast(&*std::next(Body->begin(), 17)); @@ -363,7 +376,7 @@ EXPECT_EQ(LoadvC1->getOperand(0), CombinedLoadC->getOperand(1)); VPInstruction *CombinedLoadD = - cast(CombinedMulCD->getOperand(1)); + cast(CombinedMulCD->getOperand(1)->getDef()); EXPECT_EQ(VPInstruction::SLPLoad, CombinedLoadD->getOpcode()); VPInstruction *LoadvD0 = cast(&*std::next(Body->begin(), 9)); VPInstruction *LoadvD1 = cast(&*std::next(Body->begin(), 19)); @@ -593,7 +606,8 @@ auto VPIAI = getInterleavedAccessInfo(*F, LI->getLoopFor(LoopHeader), *Plan); VPlanSlp Slp(VPIAI, *Body); - SmallVector StoreRoot = {Store1, Store2}; + SmallVector StoreRoot = {Store1->getVPValue(), + Store2->getVPValue()}; EXPECT_EQ(nullptr, Slp.buildGraph(StoreRoot)); // FIXME Need to select better first value for lane0. @@ -730,7 +744,8 @@ VPInstruction *Store2 = cast(&*std::next(BB2->begin(), 5)); VPlanSlp Slp(VPIAI, *BB2); - SmallVector StoreRoot = {Store1, Store2}; + SmallVector StoreRoot = {Store1->getVPValue(), + Store2->getVPValue()}; EXPECT_EQ(nullptr, Slp.buildGraph(StoreRoot)); EXPECT_EQ(0u, Slp.getWidestBundleBits()); } @@ -793,7 +808,8 @@ VPInstruction *Store2 = cast(&*std::next(BB2->begin(), 3)); VPlanSlp Slp(VPIAI, *BB2); - SmallVector StoreRoot = {Store1, Store2}; + SmallVector StoreRoot = {Store1->getVPValue(), + Store2->getVPValue()}; EXPECT_EQ(nullptr, Slp.buildGraph(StoreRoot)); EXPECT_EQ(0u, Slp.getWidestBundleBits()); } @@ -852,7 +868,8 @@ VPInstruction *Store2 = cast(&*std::next(Body->begin(), 14)); VPlanSlp Slp(VPIAI, *Body); - SmallVector StoreRoot = {Store1, Store2}; + SmallVector StoreRoot = {Store1->getVPValue(), + Store2->getVPValue()}; EXPECT_EQ(nullptr, Slp.buildGraph(StoreRoot)); EXPECT_FALSE(Slp.isCompletelySLP()); } @@ -911,7 +928,8 @@ VPInstruction *Store2 = cast(&*std::next(Body->begin(), 14)); VPlanSlp Slp(VPIAI, *Body); - SmallVector StoreRoot = {Store1, Store2}; + SmallVector StoreRoot = {Store1->getVPValue(), + Store2->getVPValue()}; Slp.buildGraph(StoreRoot); EXPECT_FALSE(Slp.isCompletelySLP()); } diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp --- a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp @@ -351,7 +351,7 @@ { std::string I4Dump; raw_string_ostream OS(I4Dump); - OS << *I4; + OS << *I4->getVPValue(); OS.flush(); EXPECT_EQ("vp<%2> = mul vp<%1> vp<%0>", I4Dump); }