diff --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp --- a/llvm/lib/Transforms/Scalar/NewGVN.cpp +++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp @@ -62,6 +62,7 @@ #include "llvm/ADT/Hashing.h" #include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" @@ -668,8 +669,46 @@ bool runGVN(); private: + /// Helper struct to return an Expression with an optional extra dependency. + /// Move-only: the destructor asserts that ExtraDep was handled, so a copy made + /// on return would trip the assertion when the source object is destroyed. + struct ExprResult { + const Expression *Expr; + Value *ExtraDep; + + ExprResult(const Expression *Expr, Value *ExtraDep) + : Expr(Expr), ExtraDep(ExtraDep) {} + ExprResult(const ExprResult &) = delete; + ExprResult(ExprResult &&Other) + : Expr(Other.Expr), ExtraDep(Other.ExtraDep) { + // Clear the source so that its destructor's assertion stays quiet. + Other.Expr = nullptr; + Other.ExtraDep = nullptr; + } + ExprResult &operator=(const ExprResult &) = delete; + ExprResult &operator=(ExprResult &&) = delete; + + ~ExprResult() { assert(!ExtraDep && "unhandled ExtraDep"); } + + operator bool() const { return Expr; } + + static ExprResult none() { return {nullptr, nullptr}; } + static ExprResult some(const Expression *Expr, Value *ExtraDep = nullptr) { + return {Expr, ExtraDep}; + } + + /// Add \p User as an additional user of ExtraDep, if it is set, using \p + /// AddFn. + void addExtraUserTo(function_ref<void(Value *, Instruction *)> AddFn, + Instruction *User) { + if (ExtraDep && ExtraDep != User) + AddFn(ExtraDep, User); + ExtraDep = nullptr; + } + }; + // Expression handling. - const Expression *createExpression(Instruction *) const; + ExprResult createExpression(Instruction *) const; const Expression *createBinaryExpression(unsigned, Type *, Value *, Value *, Instruction *) const; @@ -742,10 +781,9 @@ void valueNumberInstruction(Instruction *); // Symbolic evaluation. 
- const Expression *checkSimplificationResults(Expression *, Instruction *, - Value *) const; - const Expression *performSymbolicEvaluation(Value *, - SmallPtrSetImpl &) const; + ExprResult checkExprResults(Expression *, Instruction *, Value *) const; + ExprResult performSymbolicEvaluation(Value *, + SmallPtrSetImpl &) const; const Expression *performSymbolicLoadCoercion(Type *, Value *, LoadInst *, Instruction *, MemoryAccess *) const; @@ -757,7 +781,7 @@ Instruction *I, BasicBlock *PHIBlock) const; const Expression *performSymbolicAggrValueEvaluation(Instruction *) const; - const Expression *performSymbolicCmpEvaluation(Instruction *) const; + ExprResult performSymbolicCmpEvaluation(Instruction *) const; const Expression *performSymbolicPredicateInfoEvaluation(Instruction *) const; // Congruence finding. @@ -1052,19 +1076,22 @@ E->op_push_back(lookupOperandLeader(Arg2)); Value *V = SimplifyBinOp(Opcode, E->getOperand(0), E->getOperand(1), SQ); - if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) - return SimplifiedE; + if (auto Simplified = checkExprResults(E, I, V)) { + Simplified.addExtraUserTo( + [this](Value *V, Value *U) { addAdditionalUsers(V, U); }, I); + return Simplified.Expr; + } return E; } // Take a Value returned by simplification of Expression E/Instruction // I, and see if it resulted in a simpler expression. If so, return // that expression. 
-const Expression *NewGVN::checkSimplificationResults(Expression *E, - Instruction *I, - Value *V) const { +NewGVN::ExprResult NewGVN::checkExprResults(Expression *E, Instruction *I, + Value *V) const { if (!V) - return nullptr; + return ExprResult::none(); + if (auto *C = dyn_cast(V)) { if (I) LLVM_DEBUG(dbgs() << "Simplified " << *I << " to " @@ -1073,52 +1100,37 @@ assert(isa(E) && "We should always have had a basic expression here"); deleteExpression(E); - return createConstantExpression(C); + return ExprResult::some(createConstantExpression(C)); } else if (isa(V) || isa(V)) { if (I) LLVM_DEBUG(dbgs() << "Simplified " << *I << " to " << " variable " << *V << "\n"); deleteExpression(E); - return createVariableExpression(V); + return ExprResult::some(createVariableExpression(V)); } CongruenceClass *CC = ValueToClass.lookup(V); if (CC) { if (CC->getLeader() && CC->getLeader() != I) { - // If we simplified to something else, we need to communicate - // that we're users of the value we simplified to. - if (I != V) { - // Don't add temporary instructions to the user lists. - if (!AllTempInstructions.count(I)) - addAdditionalUsers(V, I); - } - return createVariableOrConstant(CC->getLeader()); + return ExprResult::some(createVariableOrConstant(CC->getLeader()), V); } if (CC->getDefiningExpr()) { - // If we simplified to something else, we need to communicate - // that we're users of the value we simplified to. - if (I != V) { - // Don't add temporary instructions to the user lists. - if (!AllTempInstructions.count(I)) - addAdditionalUsers(V, I); - } - if (I) LLVM_DEBUG(dbgs() << "Simplified " << *I << " to " << " expression " << *CC->getDefiningExpr() << "\n"); NumGVNOpsSimplified++; deleteExpression(E); - return CC->getDefiningExpr(); + return ExprResult::some(CC->getDefiningExpr(), V); } } - return nullptr; + return ExprResult::none(); } // Create a value expression from the instruction I, replacing operands with // their leaders. 
-const Expression *NewGVN::createExpression(Instruction *I) const { +NewGVN::ExprResult NewGVN::createExpression(Instruction *I) const { auto *E = new (ExpressionAllocator) BasicExpression(I->getNumOperands()); bool AllConstant = setBasicExpressionInfo(I, E); @@ -1149,8 +1161,8 @@ E->getOperand(1)->getType() == I->getOperand(1)->getType())); Value *V = SimplifyCmpInst(Predicate, E->getOperand(0), E->getOperand(1), SQ); - if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) - return SimplifiedE; + if (auto Simplified = checkExprResults(E, I, V)) + return Simplified; } else if (isa(I)) { if (isa(E->getOperand(0)) || E->getOperand(1) == E->getOperand(2)) { @@ -1158,24 +1170,24 @@ E->getOperand(2)->getType() == I->getOperand(2)->getType()); Value *V = SimplifySelectInst(E->getOperand(0), E->getOperand(1), E->getOperand(2), SQ); - if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) - return SimplifiedE; + if (auto Simplified = checkExprResults(E, I, V)) + return Simplified; } } else if (I->isBinaryOp()) { Value *V = SimplifyBinOp(E->getOpcode(), E->getOperand(0), E->getOperand(1), SQ); - if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) - return SimplifiedE; + if (auto Simplified = checkExprResults(E, I, V)) + return Simplified; } else if (auto *CI = dyn_cast(I)) { Value *V = SimplifyCastInst(CI->getOpcode(), E->getOperand(0), CI->getType(), SQ); - if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) - return SimplifiedE; + if (auto Simplified = checkExprResults(E, I, V)) + return Simplified; } else if (isa(I)) { Value *V = SimplifyGEPInst( E->getType(), ArrayRef(E->op_begin(), E->op_end()), SQ); - if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) - return SimplifiedE; + if (auto Simplified = checkExprResults(E, I, V)) + return Simplified; } else if (AllConstant) { // We don't bother trying to simplify unless all of the operands // were constant. 
@@ -1189,10 +1201,10 @@ C.emplace_back(cast(Arg)); if (Value *V = ConstantFoldInstOperands(I, C, DL, TLI)) - if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) - return SimplifiedE; + if (auto Simplified = checkExprResults(E, I, V)) + return Simplified; } - return E; + return ExprResult::some(E); } const AggregateValueExpression * @@ -1778,7 +1790,7 @@ return createAggregateValueExpression(I); } -const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) const { +NewGVN::ExprResult NewGVN::performSymbolicCmpEvaluation(Instruction *I) const { assert(isa(I) && "Expected a cmp instruction."); auto *CI = cast(I); @@ -1798,14 +1810,17 @@ // of an assume. auto *CmpPI = PredInfo->getPredicateInfoFor(I); if (dyn_cast_or_null(CmpPI)) - return createConstantExpression(ConstantInt::getTrue(CI->getType())); + return ExprResult::some( + createConstantExpression(ConstantInt::getTrue(CI->getType()))); if (Op0 == Op1) { // This condition does not depend on predicates, no need to add users if (CI->isTrueWhenEqual()) - return createConstantExpression(ConstantInt::getTrue(CI->getType())); + return ExprResult::some( + createConstantExpression(ConstantInt::getTrue(CI->getType()))); else if (CI->isFalseWhenEqual()) - return createConstantExpression(ConstantInt::getFalse(CI->getType())); + return ExprResult::some( + createConstantExpression(ConstantInt::getFalse(CI->getType()))); } // NOTE: Because we are comparing both operands here and below, and using @@ -1865,15 +1880,15 @@ if (CmpInst::isImpliedTrueByMatchingCmp(BranchPredicate, OurPredicate)) { addPredicateUsers(PI, I); - return createConstantExpression( - ConstantInt::getTrue(CI->getType())); + return ExprResult::some( + createConstantExpression(ConstantInt::getTrue(CI->getType()))); } if (CmpInst::isImpliedFalseByMatchingCmp(BranchPredicate, OurPredicate)) { addPredicateUsers(PI, I); - return createConstantExpression( - ConstantInt::getFalse(CI->getType())); + return ExprResult::some( + 
createConstantExpression(ConstantInt::getFalse(CI->getType()))); } } else { // Just handle the ne and eq cases, where if we have the same @@ -1881,14 +1896,14 @@ if (BranchPredicate == OurPredicate) { addPredicateUsers(PI, I); // Same predicate, same ops,we know it was false, so this is false. - return createConstantExpression( - ConstantInt::getFalse(CI->getType())); + return ExprResult::some( + createConstantExpression(ConstantInt::getFalse(CI->getType()))); } else if (BranchPredicate == CmpInst::getInversePredicate(OurPredicate)) { addPredicateUsers(PI, I); // Inverse predicate, we know the other was false, so this is true. - return createConstantExpression( - ConstantInt::getTrue(CI->getType())); + return ExprResult::some( + createConstantExpression(ConstantInt::getTrue(CI->getType()))); } } } @@ -1899,9 +1914,10 @@ } // Substitute and symbolize the value before value numbering. -const Expression * +NewGVN::ExprResult NewGVN::performSymbolicEvaluation(Value *V, SmallPtrSetImpl &Visited) const { + const Expression *E = nullptr; if (auto *C = dyn_cast(V)) E = createConstantExpression(C); @@ -1937,11 +1953,11 @@ break; case Instruction::BitCast: case Instruction::AddrSpaceCast: - E = createExpression(I); + return createExpression(I); break; case Instruction::ICmp: case Instruction::FCmp: - E = performSymbolicCmpEvaluation(I); + return performSymbolicCmpEvaluation(I); break; case Instruction::FNeg: case Instruction::Add: @@ -1977,16 +1993,16 @@ case Instruction::ExtractElement: case Instruction::InsertElement: case Instruction::GetElementPtr: - E = createExpression(I); + return createExpression(I); break; case Instruction::ShuffleVector: // FIXME: Add support for shufflevector to createExpression. 
- return nullptr; + return ExprResult::none(); default: - return nullptr; + return ExprResult::none(); } } - return E; + return ExprResult::some(E); } // Look up a container of values/instructions in a map, and touch all the @@ -2414,9 +2430,15 @@ Value *CondEvaluated = findConditionEquivalence(Cond); if (!CondEvaluated) { if (auto *I = dyn_cast(Cond)) { - const Expression *E = createExpression(I); - if (const auto *CE = dyn_cast(E)) { + auto Res = createExpression(I); + if (const auto *CE = dyn_cast(Res.Expr)) { CondEvaluated = CE->getConstantValue(); + Res.addExtraUserTo( + [this](Value *V, Value *U) { addAdditionalUsers(V, U); }, I); + } else { + // Did not use simplification result, no need to add the extra + // dependency. + Res.ExtraDep = nullptr; } } else if (isa(Cond)) { CondEvaluated = Cond; @@ -2600,7 +2622,10 @@ TempToBlock.insert({TransInst, PredBB}); InstrDFS.insert({TransInst, IDFSNum}); - const Expression *E = performSymbolicEvaluation(TransInst, Visited); + auto Res = performSymbolicEvaluation(TransInst, Visited); + const Expression *E = Res.Expr; + Res.addExtraUserTo([this](Value *V, Value *U) { addAdditionalUsers(V, U); }, + OrigInst); InstrDFS.erase(TransInst); AllTempInstructions.erase(TransInst); TempToBlock.erase(TransInst); @@ -3027,7 +3052,11 @@ const Expression *Symbolized = nullptr; SmallPtrSet Visited; if (DebugCounter::shouldExecute(VNCounter)) { - Symbolized = performSymbolicEvaluation(I, Visited); + auto Res = performSymbolicEvaluation(I, Visited); + Symbolized = Res.Expr; + Res.addExtraUserTo( + [this](Value *V, Value *U) { addAdditionalUsers(V, U); }, I); + // Make a phi of ops if necessary if (Symbolized && !isa(Symbolized) && !isa(Symbolized) && PHINodeUses.count(I)) { diff --git a/llvm/test/Transforms/NewGVN/phi-of-ops-simplification-dependencies.ll b/llvm/test/Transforms/NewGVN/phi-of-ops-simplification-dependencies.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/NewGVN/phi-of-ops-simplification-dependencies.ll 
@@ -0,0 +1,118 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -newgvn -S %s | FileCheck %s + +declare void @use.i16(i16*) +declare void @use.i32(i32) + +; Test cases from PR35074, where the simplification dependencies need to be +; tracked for phi-of-ops root instructions. + +define void @test1() { +; CHECK-LABEL: @test1( +; CHECK-NEXT: entry: +; CHECK-NEXT: br label [[FOR_COND:%.*]] +; CHECK: for.cond: +; CHECK-NEXT: [[PHIOFOPS:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[Y_0:%.*]], [[FOR_INC6:%.*]] ] +; CHECK-NEXT: [[Y_0]] = phi i32 [ 1, [[ENTRY]] ], [ [[INC7:%.*]], [[FOR_INC6]] ] +; CHECK-NEXT: br i1 undef, label [[FOR_INC6]], label [[FOR_BODY_LR_PH:%.*]] +; CHECK: for.body.lr.ph: +; CHECK-NEXT: br label [[FOR_BODY4:%.*]] +; CHECK: for.body4: +; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[PHIOFOPS]], [[Y_0]] +; CHECK-NEXT: br i1 [[CMP]], label [[FOR_END:%.*]], label [[FOR_BODY4_1:%.*]] +; CHECK: for.end: +; CHECK-NEXT: ret void +; CHECK: for.inc6: +; CHECK-NEXT: [[INC7]] = add nuw nsw i32 [[Y_0]], 1 +; CHECK-NEXT: br label [[FOR_COND]] +; CHECK: for.body4.1: +; CHECK-NEXT: [[INC_1:%.*]] = add nuw nsw i32 [[Y_0]], 1 +; CHECK-NEXT: tail call void @use.i32(i32 [[INC_1]]) +; CHECK-NEXT: br label [[FOR_END]] +; +entry: + br label %for.cond + +for.cond: ; preds = %for.inc6, %entry + %y.0 = phi i32 [ 1, %entry ], [ %inc7, %for.inc6 ] + br i1 undef, label %for.inc6, label %for.body.lr.ph + +for.body.lr.ph: ; preds = %for.cond + %sub = add nsw i32 %y.0, -1 + br label %for.body4 + +for.body4: ; preds = %for.body.lr.ph + %cmp = icmp ugt i32 %sub, %y.0 + br i1 %cmp, label %for.end, label %for.body4.1 + +for.end: ; preds = %for.body4.1, %for.body4 + ret void + +for.inc6: ; preds = %for.cond + %inc7 = add nuw nsw i32 %y.0, 1 + br label %for.cond + +for.body4.1: ; preds = %for.body4 + %inc.1 = add nuw nsw i32 %y.0, 1 + tail call void @use.i32(i32 %inc.1) + br label %for.end +} + +define void @test2(i1 %c, i16* %ptr, i64 %N) { +; 
CHECK-LABEL: @test2( +; CHECK-NEXT: entry: +; CHECK-NEXT: br label [[HEADER:%.*]] +; CHECK: header: +; CHECK-NEXT: [[PHIOFOPS:%.*]] = phi i64 [ -1, [[ENTRY:%.*]] ], [ [[IV:%.*]], [[LATCH:%.*]] ] +; CHECK-NEXT: [[IV]] = phi i64 [ [[IV_NEXT:%.*]], [[LATCH]] ], [ 0, [[ENTRY]] ] +; CHECK-NEXT: br i1 [[C:%.*]], label [[IF_THEN:%.*]], label [[IF_ELSE:%.*]] +; CHECK: if.then: +; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i64 [[IV]], 0 +; CHECK-NEXT: br i1 [[CMP1]], label [[LATCH]], label [[LOR_RHS:%.*]] +; CHECK: lor.rhs: +; CHECK-NEXT: [[IV_ADD_1:%.*]] = add i64 [[IV]], 1 +; CHECK-NEXT: [[IDX_1:%.*]] = getelementptr inbounds i16, i16* [[PTR:%.*]], i64 [[IV_ADD_1]] +; CHECK-NEXT: call void @use.i16(i16* [[IDX_1]]) +; CHECK-NEXT: ret void +; CHECK: if.else: +; CHECK-NEXT: [[IDX_2:%.*]] = getelementptr inbounds i16, i16* [[PTR]], i64 [[PHIOFOPS]] +; CHECK-NEXT: call void @use.i16(i16* [[IDX_2]]) +; CHECK-NEXT: br label [[LATCH]] +; CHECK: latch: +; CHECK-NEXT: [[IV_NEXT]] = add i64 [[IV]], 1 +; CHECK-NEXT: [[EC:%.*]] = icmp ugt i64 [[IV_NEXT]], [[N:%.*]] +; CHECK-NEXT: br i1 [[EC]], label [[HEADER]], label [[EXIT:%.*]] +; CHECK: exit: +; CHECK-NEXT: ret void +; +entry: + br label %header + +header: ; preds = %latch, %entry + %iv = phi i64 [ %iv.next, %latch ], [ 0, %entry ] + br i1 %c, label %if.then, label %if.else + +if.then: + %cmp1 = icmp eq i64 %iv, 0 + br i1 %cmp1, label %latch, label %lor.rhs + +lor.rhs: ; preds = %if.then + %iv.add.1 = add i64 %iv, 1 + %idx.1 = getelementptr inbounds i16, i16* %ptr, i64 %iv.add.1 + call void @use.i16(i16* %idx.1) + ret void + +if.else: + %iv.sub.1 = add i64 %iv, -1 + %idx.2 = getelementptr inbounds i16, i16* %ptr, i64 %iv.sub.1 + call void @use.i16(i16* %idx.2) + br label %latch + +latch: + %iv.next = add i64 %iv, 1 + %ec = icmp ugt i64 %iv.next, %N + br i1 %ec, label %header, label %exit + +exit: + ret void +}