SwitchInstProfUpdateWrapper: tolerate inconsistent branch_weights metadata.

Replace the `Changed` flag with a three-state `State` (Invalid, Initialized,
Changed). init() leaves the wrapper in the sticky Invalid state when the
switch's branch_weights node does not have exactly getNumSuccessors() + 1
operands. In that state addCase() and setSuccessorWeight() do not touch prof
data, eraseFromParent() keeps the state, and the destructor only rewrites
!prof when State == Changed. The static getSuccessorWeight() now returns None
for such inconsistent metadata. A unit test covers both too many and too few
weights.

Index: llvm/include/llvm/IR/Instructions.h
===================================================================
--- llvm/include/llvm/IR/Instructions.h
+++ llvm/include/llvm/IR/Instructions.h
@@ -3439,15 +3439,24 @@
 /// their prof branch_weights metadata.
 class SwitchInstProfUpdateWrapper {
   SwitchInst &SI;
-  Optional<SmallVector<uint32_t, 8> > Weights;
-  bool Changed = false;
+  Optional<SmallVector<uint32_t, 8> > Weights = None;
+
+  // Sticky invalid state is needed to safely ignore operations with prof data
+  // in cases where SwitchInstProfUpdateWrapper is created from SwitchInst
+  // with inconsistent prof data. TODO: once we fix all prof data
+  // inconsistencies we can turn invalid state to assertions.
+  enum {
+    Invalid,
+    Initialized,
+    Changed
+  } State = Invalid;
 
 protected:
   static MDNode *getProfBranchWeightsMD(const SwitchInst &SI);
 
   MDNode *buildProfBranchWeightsMD();
 
-  Optional<SmallVector<uint32_t, 8> > getProfBranchWeights();
+  void init();
 
 public:
   using CaseWeightOpt = Optional<uint32_t>;
@@ -3455,11 +3464,10 @@
   SwitchInst &operator*() { return SI; }
   operator SwitchInst *() { return &SI; }
 
-  SwitchInstProfUpdateWrapper(SwitchInst &SI)
-      : SI(SI), Weights(getProfBranchWeights()) {}
+  SwitchInstProfUpdateWrapper(SwitchInst &SI) : SI(SI) { init(); }
 
   ~SwitchInstProfUpdateWrapper() {
-    if (Changed)
+    if (State == Changed)
       SI.setMetadata(LLVMContext::MD_prof, buildProfBranchWeightsMD());
   }
 
Index: llvm/lib/IR/Instructions.cpp
===================================================================
--- llvm/lib/IR/Instructions.cpp
+++ llvm/lib/IR/Instructions.cpp
@@ -3880,7 +3880,7 @@
 }
 
 MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() {
-  assert(Changed && "called only if metadata has changed");
+  assert(State == Changed && "called only if metadata has changed");
 
   if (!Weights)
     return nullptr;
@@ -3897,11 +3897,17 @@
   return MDBuilder(SI.getParent()->getContext()).createBranchWeights(*Weights);
 }
 
-Optional<SmallVector<uint32_t, 8> >
-SwitchInstProfUpdateWrapper::getProfBranchWeights() {
+void SwitchInstProfUpdateWrapper::init() {
   MDNode *ProfileData = getProfBranchWeightsMD(SI);
-  if (!ProfileData)
-    return None;
+  if (!ProfileData) {
+    State = Initialized;
+    return;
+  }
+
+  if (ProfileData->getNumOperands() != SI.getNumSuccessors() + 1) {
+    State = Invalid;
+    return;
+  }
 
   SmallVector<uint32_t, 8> Weights;
   for (unsigned CI = 1, CE = SI.getNumSuccessors(); CI <= CE; ++CI) {
@@ -3909,7 +3915,8 @@
     uint32_t CW = C->getValue().getZExtValue();
     Weights.push_back(CW);
   }
-  return Weights;
+  State = Initialized;
+  this->Weights = std::move(Weights);
 }
 
 SwitchInst::CaseIt
@@ -3917,7 +3924,7 @@
   if (Weights) {
     assert(SI.getNumSuccessors() == Weights->size() &&
            "num of prof branch_weights must accord with num of successors");
-    Changed = true;
+    State = Changed;
     // Copy the last case to the place of the removed one and shrink.
     // This is tightly coupled with the way SwitchInst::removeCase() removes
     // the cases in SwitchInst::removeCase(CaseIt).
@@ -3932,12 +3939,15 @@
     SwitchInstProfUpdateWrapper::CaseWeightOpt W) {
   SI.addCase(OnVal, Dest);
 
+  if (State == Invalid)
+    return;
+
   if (!Weights && W && *W) {
-    Changed = true;
+    State = Changed;
     Weights = SmallVector<uint32_t, 8>(SI.getNumSuccessors(), 0);
     Weights.getValue()[SI.getNumSuccessors() - 1] = *W;
   } else if (Weights) {
-    Changed = true;
+    State = Changed;
     Weights.getValue().push_back(W ? *W : 0);
   }
   if (Weights)
@@ -3948,10 +3958,11 @@
 SymbolTableList<Instruction>::iterator
 SwitchInstProfUpdateWrapper::eraseFromParent() {
   // Instruction is erased. Mark as unchanged to not touch it in the destructor.
-  Changed = false;
-
-  if (Weights)
-    Weights->resize(0);
+  if (State != Invalid) {
+    State = Initialized;
+    if (Weights)
+      Weights->resize(0);
+  }
   return SI.eraseFromParent();
 }
 
@@ -3964,7 +3975,7 @@
 
 void SwitchInstProfUpdateWrapper::setSuccessorWeight(
     unsigned idx, SwitchInstProfUpdateWrapper::CaseWeightOpt W) {
-  if (!W)
+  if (!W || State == Invalid)
     return;
 
   if (!Weights && *W)
@@ -3973,7 +3984,7 @@
   if (Weights) {
     auto &OldW = Weights.getValue()[idx];
     if (*W != OldW) {
-      Changed = true;
+      State = Changed;
       OldW = *W;
     }
   }
@@ -3983,9 +3994,10 @@
 SwitchInstProfUpdateWrapper::getSuccessorWeight(const SwitchInst &SI,
                                                 unsigned idx) {
   if (MDNode *ProfileData = getProfBranchWeightsMD(SI))
-    return mdconst::extract<ConstantInt>(ProfileData->getOperand(idx + 1))
-        ->getValue()
-        .getZExtValue();
+    if (ProfileData->getNumOperands() == SI.getNumSuccessors() + 1)
+      return mdconst::extract<ConstantInt>(ProfileData->getOperand(idx + 1))
+          ->getValue()
+          .getZExtValue();
 
   return None;
 }
Index: llvm/unittests/IR/InstructionsTest.cpp
===================================================================
--- llvm/unittests/IR/InstructionsTest.cpp
+++ llvm/unittests/IR/InstructionsTest.cpp
@@ -753,6 +753,85 @@
   EXPECT_EQ(BB1.get(), Handle.getCaseSuccessor());
 }
 
+TEST(InstructionsTest, SwitchInstProfUpdateWrapper) {
+  LLVMContext C;
+
+  std::unique_ptr<BasicBlock> BB1, BB2, BB3;
+  BB1.reset(BasicBlock::Create(C));
+  BB2.reset(BasicBlock::Create(C));
+  BB3.reset(BasicBlock::Create(C));
+
+  // We create block 0 after the others so that it gets destroyed first and
+  // clears the uses of the other basic blocks.
+  std::unique_ptr<BasicBlock> BB0(BasicBlock::Create(C));
+
+  auto *Int32Ty = Type::getInt32Ty(C);
+
+  SwitchInst *SI =
+      SwitchInst::Create(UndefValue::get(Int32Ty), BB0.get(), 4, BB0.get());
+  SI->addCase(ConstantInt::get(Int32Ty, 1), BB1.get());
+  SI->addCase(ConstantInt::get(Int32Ty, 2), BB2.get());
+  SI->setMetadata(LLVMContext::MD_prof,
+                  MDBuilder(C).createBranchWeights({ 9, 1, 22 }));
+
+  {
+    SwitchInstProfUpdateWrapper SIW(*SI);
+    EXPECT_EQ(*SIW.getSuccessorWeight(0), 9u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(1), 1u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u);
+    SIW.setSuccessorWeight(0, 99u);
+    SIW.setSuccessorWeight(1, 11u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(0), 99u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(1), 11u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u);
+  }
+
+  { // Create another wrapper and check that the data persist.
+    SwitchInstProfUpdateWrapper SIW(*SI);
+    EXPECT_EQ(*SIW.getSuccessorWeight(0), 99u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(1), 11u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u);
+  }
+
+  // Make prof data invalid by adding one extra weight.
+  SI->setMetadata(LLVMContext::MD_prof, MDBuilder(C).createBranchWeights(
+                                            { 99, 11, 22, 33 })); // extra
+  { // Invalid prof data makes wrapper act as if there were no prof data.
+    SwitchInstProfUpdateWrapper SIW(*SI);
+    ASSERT_FALSE(SIW.getSuccessorWeight(0).hasValue());
+    ASSERT_FALSE(SIW.getSuccessorWeight(1).hasValue());
+    ASSERT_FALSE(SIW.getSuccessorWeight(2).hasValue());
+    SIW.addCase(ConstantInt::get(Int32Ty, 3), BB3.get(), 39);
+    ASSERT_FALSE(SIW.getSuccessorWeight(3).hasValue()); // did not add weight 39
+  }
+
+  { // With added 3rd case the prof data become consistent with num of cases.
+    SwitchInstProfUpdateWrapper SIW(*SI);
+    EXPECT_EQ(*SIW.getSuccessorWeight(0), 99u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(1), 11u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(3), 33u);
+  }
+
+  // Make prof data invalid by removing one weight.
+  SI->setMetadata(LLVMContext::MD_prof,
+                  MDBuilder(C).createBranchWeights({ 99, 11, 22 })); // shorter
+  { // Invalid prof data makes wrapper act as if there were no prof data.
+    SwitchInstProfUpdateWrapper SIW(*SI);
+    ASSERT_FALSE(SIW.getSuccessorWeight(0).hasValue());
+    ASSERT_FALSE(SIW.getSuccessorWeight(1).hasValue());
+    ASSERT_FALSE(SIW.getSuccessorWeight(2).hasValue());
+    SIW.removeCase(SwitchInst::CaseIt(SI, 2));
+  }
+
+  { // With removed 3rd case the prof data become consistent with num of cases.
+    SwitchInstProfUpdateWrapper SIW(*SI);
+    EXPECT_EQ(*SIW.getSuccessorWeight(0), 99u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(1), 11u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u);
+  }
+}
+
 TEST(InstructionsTest, CommuteShuffleMask) {
   SmallVector<int, 16> Indices({-1, 0, 7});
   ShuffleVectorInst::commuteShuffleMask(Indices, 4);