diff --git a/llvm/include/llvm/Transforms/Scalar/NaryReassociate.h b/llvm/include/llvm/Transforms/Scalar/NaryReassociate.h
--- a/llvm/include/llvm/Transforms/Scalar/NaryReassociate.h
+++ b/llvm/include/llvm/Transforms/Scalar/NaryReassociate.h
@@ -114,7 +114,8 @@
   bool doOneIteration(Function &F);
 
-  // Reassociates I for better CSE.
-  Instruction *tryReassociate(Instruction *I);
+  // Reassociates I for better CSE. OrigSCEV is set to SE->getSCEV(I) if I is a
+  // candidate (SCEVable add, mul or GEP) and is left untouched otherwise.
+  Instruction *tryReassociate(Instruction *I, const SCEV *&OrigSCEV);
 
   // Reassociate GEP for better CSE.
   Instruction *tryReassociateGEP(GetElementPtrInst *GEP);
diff --git a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
--- a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
@@ -213,18 +213,6 @@
   return Changed;
 }
 
-// Explicitly list the instruction types NaryReassociate handles for now.
-static bool isPotentiallyNaryReassociable(Instruction *I) {
-  switch (I->getOpcode()) {
-  case Instruction::Add:
-  case Instruction::GetElementPtr:
-  case Instruction::Mul:
-    return true;
-  default:
-    return false;
-  }
-}
-
 bool NaryReassociatePass::doOneIteration(Function &F) {
   bool Changed = false;
   SeenExprs.clear();
@@ -236,13 +224,8 @@
     BasicBlock *BB = Node->getBlock();
     for (auto I = BB->begin(); I != BB->end(); ++I) {
       Instruction *OrigI = &*I;
-
-      if (!SE->isSCEVable(OrigI->getType()) ||
-          !isPotentiallyNaryReassociable(OrigI))
-        continue;
-
-      const SCEV *OrigSCEV = SE->getSCEV(OrigI);
-      if (Instruction *NewI = tryReassociate(OrigI)) {
+      const SCEV *OrigSCEV = nullptr;
+      if (Instruction *NewI = tryReassociate(OrigI, OrigSCEV)) {
         Changed = true;
         OrigI->replaceAllUsesWith(NewI);
 
@@ -274,9 +257,9 @@
         // nary-gep.ll.
         if (NewSCEV != OrigSCEV)
           SeenExprs[OrigSCEV].push_back(WeakTrackingVH(NewI));
-      } else
+      } else if (OrigSCEV)
         SeenExprs[OrigSCEV].push_back(WeakTrackingVH(OrigI));
     }
   }
   // Delete all dead instructions from 'DeadInsts'.
   // Please note ScalarEvolution is updated along the way.
@@ -286,16 +269,25 @@
   return Changed;
 }
 
-Instruction *NaryReassociatePass::tryReassociate(Instruction *I) {
+Instruction *NaryReassociatePass::tryReassociate(Instruction *I,
+                                                  const SCEV *&OrigSCEV) {
+  // Only SCEVable values have a SCEV to record; for anything else OrigSCEV
+  // is left untouched and nothing is reassociated.
+  if (!SE->isSCEVable(I->getType()))
+    return nullptr;
+
   switch (I->getOpcode()) {
   case Instruction::Add:
   case Instruction::Mul:
+    OrigSCEV = SE->getSCEV(I);
     return tryReassociateBinaryOp(cast<BinaryOperator>(I));
   case Instruction::GetElementPtr:
+    OrigSCEV = SE->getSCEV(I);
     return tryReassociateGEP(cast<GetElementPtrInst>(I));
   default:
-    llvm_unreachable("should be filtered out by isPotentiallyNaryReassociable");
+    // Not a reassociation candidate; OrigSCEV stays untouched.
+    return nullptr;
   }
 }
 
 static bool isGEPFoldable(GetElementPtrInst *GEP,