diff --git a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h --- a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h +++ b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h @@ -139,7 +139,9 @@ std::vector getStructLatticeValueFor(Value *V) const; - void removeLatticeValueFor(Value *V); + /// Transfer the lattice value of \p From to \p To and erase the entry for + /// \p From. \p From must have a lattice value and \p To must not have one. + void replaceAndRemoveLatticeValueFor(Value *From, Value *To); /// Invalidate the Lattice Value of \p Call and its users after specializing /// the call. Then recompute it. @@ -193,7 +195,6 @@ void visitCall(CallInst &I); bool simplifyInstsInBlock(BasicBlock &BB, - SmallPtrSetImpl &InsertedValues, Statistic &InstRemovedStat, Statistic &InstReplacedStat); diff --git a/llvm/lib/Transforms/IPO/SCCP.cpp b/llvm/lib/Transforms/IPO/SCCP.cpp --- a/llvm/lib/Transforms/IPO/SCCP.cpp +++ b/llvm/lib/Transforms/IPO/SCCP.cpp @@ -207,7 +207,6 @@ MadeChanges |= ReplacedPointerArg; } - SmallPtrSet InsertedValues; for (BasicBlock &BB : F) { if (!Solver.isBlockExecutable(&BB)) { LLVM_DEBUG(dbgs() << " BasicBlock Dead:" << BB); @@ -221,7 +220,7 @@ } MadeChanges |= Solver.simplifyInstsInBlock( - BB, InsertedValues, NumInstRemoved, NumInstReplaced); + BB, NumInstRemoved, NumInstReplaced); } DomTreeUpdater DTU = IsFuncSpecEnabled && Specializer.isClonedFunction(&F) diff --git a/llvm/lib/Transforms/Scalar/SCCP.cpp b/llvm/lib/Transforms/Scalar/SCCP.cpp --- a/llvm/lib/Transforms/Scalar/SCCP.cpp +++ b/llvm/lib/Transforms/Scalar/SCCP.cpp @@ -92,7 +92,6 @@ // delete their contents now. Note that we cannot actually delete the blocks, // as we cannot modify the CFG of the function. - SmallPtrSet InsertedValues; SmallVector BlocksToErase; for (BasicBlock &BB : F) { if (!Solver.isBlockExecutable(&BB)) { @@ -103,8 +102,8 @@ continue; } - MadeChanges |= Solver.simplifyInstsInBlock(BB, InsertedValues, - NumInstRemoved, NumInstReplaced); + MadeChanges |= + Solver.simplifyInstsInBlock(BB, NumInstRemoved, NumInstReplaced); } // Remove unreachable blocks and non-feasible edges.
diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp --- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -104,15 +104,14 @@ /// Try to use \p Inst's value range from \p Solver to infer the NUW flag. static bool refineInstruction(SCCPSolver &Solver, - const SmallPtrSetImpl &InsertedValues, Instruction &Inst) { if (!isa(Inst)) return false; - auto GetRange = [&Solver, &InsertedValues](Value *Op) { + auto GetRange = [&Solver](Value *Op) { if (auto *Const = dyn_cast(Op)) return ConstantRange(Const->getValue()); - if (isa(Op) || InsertedValues.contains(Op)) { + if (isa(Op)) { unsigned Bitwidth = Op->getType()->getScalarSizeInBits(); return ConstantRange::getFull(Bitwidth); } @@ -145,7 +144,6 @@ /// Try to replace signed instructions with their unsigned equivalent. static bool replaceSignedInst(SCCPSolver &Solver, - SmallPtrSetImpl &InsertedValues, Instruction &Inst) { // Determine if a signed value is known to be >= 0. auto isNonNegative = [&Solver](Value *V) { @@ -167,7 +165,7 @@ case Instruction::SExt: { // If the source value is not negative, this is a zext. Value *Op0 = Inst.getOperand(0); - if (InsertedValues.count(Op0) || !isNonNegative(Op0)) + if (!isNonNegative(Op0)) return false; NewInst = new ZExtInst(Op0, Inst.getType(), "", &Inst); break; @@ -175,7 +173,7 @@ case Instruction::AShr: { // If the shifted value is not negative, this is a logical shift right. Value *Op0 = Inst.getOperand(0); - if (InsertedValues.count(Op0) || !isNonNegative(Op0)) + if (!isNonNegative(Op0)) return false; NewInst = BinaryOperator::CreateLShr(Op0, Inst.getOperand(1), "", &Inst); break; @@ -184,8 +182,7 @@ case Instruction::SRem: { // If both operands are not negative, this is the same as udiv/urem.
Value *Op0 = Inst.getOperand(0), *Op1 = Inst.getOperand(1); - if (InsertedValues.count(Op0) || InsertedValues.count(Op1) || - !isNonNegative(Op0) || !isNonNegative(Op1)) + if (!isNonNegative(Op0) || !isNonNegative(Op1)) return false; auto NewOpcode = Inst.getOpcode() == Instruction::SDiv ? Instruction::UDiv : Instruction::URem; @@ -199,15 +196,13 @@ // Wire up the new instruction and update state. assert(NewInst && "Expected replacement instruction"); NewInst->takeName(&Inst); - InsertedValues.insert(NewInst); Inst.replaceAllUsesWith(NewInst); - Solver.removeLatticeValueFor(&Inst); + Solver.replaceAndRemoveLatticeValueFor(&Inst, NewInst); Inst.eraseFromParent(); return true; } bool SCCPSolver::simplifyInstsInBlock(BasicBlock &BB, - SmallPtrSetImpl &InsertedValues, Statistic &InstRemovedStat, Statistic &InstReplacedStat) { bool MadeChanges = false; @@ -220,10 +215,10 @@ MadeChanges = true; ++InstRemovedStat; - } else if (replaceSignedInst(*this, InsertedValues, Inst)) { + } else if (replaceSignedInst(*this, Inst)) { MadeChanges = true; ++InstReplacedStat; - } else if (refineInstruction(*this, InsertedValues, Inst)) { + } else if (refineInstruction(*this, Inst)) { MadeChanges = true; } } @@ -734,7 +729,12 @@ return StructValues; } - void removeLatticeValueFor(Value *V) { ValueState.erase(V); } + void replaceAndRemoveLatticeValueFor(Value *From, Value *To) { + assert(ValueState.count(From) && "From does not exist in ValueState"); + assert(!ValueState.count(To) && "To already exists in ValueState"); + ValueState.insert(std::make_pair(To, std::move(ValueState[From]))); + ValueState.erase(From); + } /// Invalidate the Lattice Value of \p Call and its users after specializing /// the call. Then recompute it.
@@ -2013,8 +2013,8 @@ return Visitor->getStructLatticeValueFor(V); } -void SCCPSolver::removeLatticeValueFor(Value *V) { - return Visitor->removeLatticeValueFor(V); +void SCCPSolver::replaceAndRemoveLatticeValueFor(Value *From, Value *To) { + Visitor->replaceAndRemoveLatticeValueFor(From, To); } void SCCPSolver::resetLatticeValueFor(CallBase *Call) { diff --git a/llvm/test/Transforms/SCCP/add-nuw-nsw-flags.ll b/llvm/test/Transforms/SCCP/add-nuw-nsw-flags.ll --- a/llvm/test/Transforms/SCCP/add-nuw-nsw-flags.ll +++ b/llvm/test/Transforms/SCCP/add-nuw-nsw-flags.ll @@ -125,9 +125,9 @@ ; CHECK-NEXT: br i1 [[CMP]], label [[THEN:%.*]], label [[ELSE:%.*]] ; CHECK: then: ; CHECK-NEXT: [[SEXT:%.*]] = zext i8 [[A]] to i16 -; CHECK-NEXT: [[ADD_1:%.*]] = add i16 [[SEXT]], 1 -; CHECK-NEXT: [[ADD_2:%.*]] = add i16 [[SEXT]], -128 -; CHECK-NEXT: [[ADD_3:%.*]] = add i16 [[SEXT]], -127 +; CHECK-NEXT: [[ADD_1:%.*]] = add nuw nsw i16 [[SEXT]], 1 +; CHECK-NEXT: [[ADD_2:%.*]] = add nuw nsw i16 [[SEXT]], -128 +; CHECK-NEXT: [[ADD_3:%.*]] = add nsw i16 [[SEXT]], -127 ; CHECK-NEXT: [[RES_1:%.*]] = xor i16 [[ADD_1]], [[ADD_2]] ; CHECK-NEXT: [[RES_2:%.*]] = xor i16 [[RES_1]], [[ADD_3]] ; CHECK-NEXT: ret i16 [[RES_2]] @@ -222,7 +222,7 @@ ; CHECK-NEXT: [[CONV:%.*]] = zext i8 [[COND4]] to i16 ; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[ELSE:%.*]] ; CHECK: then: -; CHECK-NEXT: [[ADD:%.*]] = add i16 1, [[CONV]] +; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i16 1, [[CONV]] ; CHECK-NEXT: ret i16 [[ADD]] ; CHECK: else: ; CHECK-NEXT: ret i16 0 diff --git a/llvm/test/Transforms/SCCP/ip-ranges-sext.ll b/llvm/test/Transforms/SCCP/ip-ranges-sext.ll --- a/llvm/test/Transforms/SCCP/ip-ranges-sext.ll +++ b/llvm/test/Transforms/SCCP/ip-ranges-sext.ll @@ -127,7 +127,7 @@ ; CHECK-LABEL: @test7( ; CHECK-NEXT: [[P:%.*]] = and i16 [[X:%.*]], 15 ; CHECK-NEXT: [[EXT_1:%.*]] = zext i16 [[P]] to i32 -; CHECK-NEXT: [[EXT_2:%.*]] = sext i32 [[EXT_1]] to i64 +; CHECK-NEXT: [[EXT_2:%.*]] = zext i32
[[EXT_1]] to i64 ; CHECK-NEXT: ret i64 [[EXT_2]] ; %p = and i16 %x, 15