NewGVN: use PredicateInfo to derive equalities and comparison results.

- Canonicalize comparison operand order by rank (constants, then arguments in
  order, then instructions by DFS number) instead of by pointer value.
- Evaluate llvm.ssa.copy calls that carry PredicateInfo: a copy is made equal
  to the other comparison operand only when the controlling assume or branch
  edge proves equality (assumed eq, true edge of eq, false edge of ne).
- Fold comparisons whose operands match a dominating branch condition.
- Own the PredicateInfo through std::unique_ptr so it is freed between runs.

Index: lib/Transforms/Scalar/NewGVN.cpp
===================================================================
--- lib/Transforms/Scalar/NewGVN.cpp
+++ lib/Transforms/Scalar/NewGVN.cpp
@@ -81,13 +81,13 @@
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/MemorySSA.h"
+#include "llvm/Transforms/Utils/PredicateInfo.h"
 #include <unordered_map>
 #include <utility>
 #include <vector>
 using namespace llvm;
 using namespace PatternMatch;
 using namespace llvm::GVNExpression;
-
 #define DEBUG_TYPE "newgvn"
 
 STATISTIC(NumGVNInstrDeleted, "Number of instructions deleted");
@@ -208,9 +208,13 @@
   AliasAnalysis *AA;
   MemorySSA *MSSA;
   MemorySSAWalker *MSSAWalker;
+  std::unique_ptr<PredicateInfo> PredInfo;
   BumpPtrAllocator ExpressionAllocator;
   ArrayRecycler<Value *> ArgRecycler;
 
+  // Number of function arguments, used by ranking
+  unsigned int NumFuncArgs;
+
   // Congruence class info.
   CongruenceClass *InitialClass;
   std::vector<CongruenceClass *> CongruenceClasses;
@@ -288,7 +292,6 @@
     AU.addRequired<TargetLibraryInfoWrapperPass>();
     AU.addRequired<MemorySSAWrapperPass>();
     AU.addRequired<AAResultsWrapperPass>();
-    AU.addPreserved<DominatorTreeWrapperPass>();
     AU.addPreserved<GlobalsAAWrapperPass>();
   }
 
@@ -337,6 +340,7 @@
   const Expression *performSymbolicPHIEvaluation(Instruction *);
   const Expression *performSymbolicAggrValueEvaluation(Instruction *);
   const Expression *performSymbolicCmpEvaluation(Instruction *);
+  const Expression *performSymbolicPredicateInfoEvaluation(Instruction *);
 
   // Congruence finding.
   Value *lookupOperandLeader(Value *) const;
@@ -347,6 +351,9 @@
   MemoryAccess *lookupMemoryAccessEquiv(MemoryAccess *) const;
   bool isMemoryAccessTop(const MemoryAccess *) const;
 
+  // Ranking
+  unsigned int getRank(Value *V) const;
+
   // Reachability handling.
   void updateReachableEdge(BasicBlock *, BasicBlock *);
   void processOutgoingEdges(TerminatorInst *, BasicBlock *);
@@ -573,7 +580,7 @@
     // Sort the operand value numbers so x<y and y>x get the same value
     // number.
     CmpInst::Predicate Predicate = CI->getPredicate();
-    if (E->getOperand(0) > E->getOperand(1)) {
+    if (getRank(E->getOperand(0)) > getRank(E->getOperand(1))) {
       E->swapOperands(0, 1);
       Predicate = CmpInst::getSwappedPredicate(Predicate);
     }
@@ -828,12 +835,79 @@
   return E;
 }
 
+const Expression *
+NewGVN::performSymbolicPredicateInfoEvaluation(Instruction *I) {
+  if (auto *PI = PredInfo->getPredicateInfoFor(I)) {
+    DEBUG(dbgs() << "Found predicate info from instruction !\n");
+    auto *CopyOf = I->getOperand(0);
+    auto *Cmp = PI->Comparison;
+    // If this is an assume predicate and a copy of a comparison, it must be
+    // true.
+    if (isa<PredicateAssume>(PI) && CopyOf == Cmp)
+      return createConstantExpression(ConstantInt::getTrue(Cmp->getType()));
+
+    Value *FirstOp = lookupOperandLeader(Cmp->getOperand(0));
+    Value *SecondOp = lookupOperandLeader(Cmp->getOperand(1));
+    CmpInst::Predicate Predicate = Cmp->getPredicate();
+    // Sort the ops by rank (constants first), swapping the predicate so the
+    // comparison keeps its meaning.
+    if (getRank(FirstOp) > getRank(SecondOp)) {
+      std::swap(FirstOp, SecondOp);
+      Predicate = CmpInst::getSwappedPredicate(Predicate);
+    }
+
+    if (isa<PredicateAssume>(PI)) {
+      // Only an assumed eq proves equality; sle/uge/ne etc. do not.
+      if (Predicate != CmpInst::ICMP_EQ)
+        return nullptr;
+      if (auto *C = dyn_cast<Constant>(FirstOp))
+        return createConstantExpression(C);
+      return createVariableExpression(FirstOp);
+    } else if (const auto *PBranch = dyn_cast<PredicateBranch>(PI)) {
+      // If this is a copy of the comparison, its value is determined by the
+      // edge.
+      if (CopyOf == Cmp) {
+        if (PBranch->TrueEdge)
+          return createConstantExpression(ConstantInt::getTrue(Cmp->getType()));
+        return createConstantExpression(ConstantInt::getFalse(Cmp->getType()));
+      } else if ((PBranch->TrueEdge && Predicate == CmpInst::ICMP_EQ) ||
+                 (!PBranch->TrueEdge && Predicate == CmpInst::ICMP_NE)) {
+        // Otherwise we equal the other operand only when this edge proves
+        // equality: the true edge of eq or the false edge of ne. (The false
+        // edge of slt, for instance, only proves sge.)
+        if (auto *C = dyn_cast<Constant>(FirstOp))
+          return createConstantExpression(C);
+        return createVariableExpression(FirstOp);
+      } else if (((PBranch->TrueEdge && Predicate == CmpInst::FCMP_OEQ) ||
+                  (!PBranch->TrueEdge && Predicate == CmpInst::FCMP_UNE)) &&
+                 isa<ConstantFP>(FirstOp) &&
+                 !cast<ConstantFP>(FirstOp)->isZero()) {
+        return createConstantExpression(cast<ConstantFP>(FirstOp));
+      }
+    }
+  }
+  return nullptr;
+}
+
 // Evaluate read only and pure calls, and create an expression result.
 const Expression *NewGVN::performSymbolicCallEvaluation(Instruction *I) {
   auto *CI = cast<CallInst>(I);
-  if (AA->doesNotAccessMemory(CI))
+  if (auto *II = dyn_cast<IntrinsicInst>(I)) {
+    // Things with the returned attribute are copies of arguments
+    if (auto *ReturnedValue = II->getReturnedArgOperand()) {
+      if (II->getIntrinsicID() == Intrinsic::ssa_copy) {
+        const Expression *Result = performSymbolicPredicateInfoEvaluation(I);
+        if (Result)
+          return Result;
+      }
+      if (auto *C = dyn_cast<Constant>(ReturnedValue))
+        return createConstantExpression(C);
+      return createVariableExpression(ReturnedValue);
+    }
+  }
+  if (AA->doesNotAccessMemory(CI)) {
     return createCallExpression(CI, nullptr);
-  if (AA->onlyReadsMemory(CI)) {
+  } else if (AA->onlyReadsMemory(CI)) {
     MemoryAccess *DefiningAccess = MSSAWalker->getClobberingMemoryAccess(CI);
     return createCallExpression(CI, lookupMemoryAccessEquiv(DefiningAccess));
   }
@@ -964,9 +1038,8 @@
         // expression.
         assert(II->getNumArgOperands() == 2 &&
                "Expect two args for recognised intrinsics.");
-        return createBinaryExpression(Opcode, EI->getType(),
-                                      II->getArgOperand(0),
-                                      II->getArgOperand(1));
+        return createBinaryExpression(
+            Opcode, EI->getType(), II->getArgOperand(0), II->getArgOperand(1));
       }
     }
   }
@@ -974,16 +1047,69 @@
   return createAggregateValueExpression(I);
 }
 const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) {
-  CmpInst *CI = dyn_cast<CmpInst>(I);
-  // See if our operands are equal and that implies something.
+  auto *CI = dyn_cast<CmpInst>(I);
+  // See if our operands are equal to those of a previous predicate, and if so,
+  // if it implies true or false.
   auto Op0 = lookupOperandLeader(CI->getOperand(0));
   auto Op1 = lookupOperandLeader(CI->getOperand(1));
+  // Avoid processing the same info twice
+  const PredicateBase *LastPredInfo = nullptr;
+
+  // See if we know something about the comparison itself, like it is the target
+  // of an assume.
+  auto *CmpPI = PredInfo->getPredicateInfoFor(I);
+  if (dyn_cast_or_null<PredicateAssume>(CmpPI))
+    return createConstantExpression(ConstantInt::getTrue(CI->getType()));
+
+  // See if we know something just from the operands themselves
   if (Op0 == Op1) {
     if (CI->isTrueWhenEqual())
       return createConstantExpression(ConstantInt::getTrue(CI->getType()));
     else if (CI->isFalseWhenEqual())
       return createConstantExpression(ConstantInt::getFalse(CI->getType()));
   }
+
+  // See if our operands have predicate info, so that we may be able to derive
+  // something from a previous comparison.
+  for (const auto &Op : CI->operands()) {
+    auto *PI = PredInfo->getPredicateInfoFor(Op);
+    if (const auto *PBranch = dyn_cast_or_null<PredicateBranch>(PI)) {
+      if (PI == LastPredInfo)
+        continue;
+      LastPredInfo = PI;
+      // TODO: Along the false edge, we may know more things too, like icmp of
+      // same operands is false.
+      //
+      auto *BranchCond = PBranch->Comparison;
+      if (lookupOperandLeader(BranchCond->getOperand(0)) == Op0 &&
+          lookupOperandLeader(BranchCond->getOperand(1)) == Op1) {
+        if (PBranch->TrueEdge) {
+          // If we know the previous predicate is true and we are in the true
+          // edge then we may be implied true or false.
+          if (CI->isImpliedTrueByMatchingCmp(BranchCond->getPredicate()))
+            return createConstantExpression(
+                ConstantInt::getTrue(CI->getType()));
+          if (CI->isImpliedFalseByMatchingCmp(BranchCond->getPredicate()))
+            return createConstantExpression(
+                ConstantInt::getFalse(CI->getType()));
+        } else {
+          // On the false edge the previous predicate P was false, so the same
+          // operands compared with P are false and with inverse(P) are true.
+          if (BranchCond->getPredicate() == CI->getPredicate()) {
+            // Same predicate, same ops, we know it was false, so this is false.
+            return createConstantExpression(
+                ConstantInt::getFalse(CI->getType()));
+          } else if (BranchCond->getPredicate() == CI->getInversePredicate()) {
+            // Inverse predicate: we know the other comparison was false, so
+            // this one is true, because on the same operands "a inverse(P) b"
+            // is exactly the negation of "a P b".
+            return createConstantExpression(
+                ConstantInt::getTrue(CI->getType()));
+          }
+        }
+      }
+    }
+  }
   // Create expression will take care of simplifyCmpInst
   return createExpression(I);
 }
@@ -1696,11 +1822,13 @@
                     TargetLibraryInfo *_TLI, AliasAnalysis *_AA,
                     MemorySSA *_MSSA) {
   bool Changed = false;
+  NumFuncArgs = F.arg_size();
   DT = _DT;
   AC = _AC;
   TLI = _TLI;
   AA = _AA;
   MSSA = _MSSA;
+  PredInfo.reset(new PredicateInfo(F, *DT, *AC));
   DL = &F.getParent()->getDataLayout();
   MSSAWalker = MSSA->getWalker();
 
@@ -1775,6 +1903,9 @@
   while (TouchedInstructions.any()) {
     ++Iterations;
     // Walk through all the instructions in all the blocks in RPO.
+    // TODO: As we hit a new block, we should push and pop equalities into a
+    // table lookupOperandLeader can use, to catch things PredicateInfo
+    // might miss, like edge-only equivalences.
     for (int InstrNum = TouchedInstructions.find_first(); InstrNum != -1;
          InstrNum = TouchedInstructions.find_next(InstrNum)) {
 
@@ -2327,8 +2458,13 @@
 
         // If we replaced something in an instruction, handle the patching of
         // metadata.
-        if (auto *ReplacedInst = dyn_cast<Instruction>(MemberUse->get()))
-          patchReplacementInstruction(ReplacedInst, Result);
+        if (auto *ReplacedInst = dyn_cast<Instruction>(MemberUse->get())) {
+          // Skip this if we are replacing predicateinfo with its original
+          // operand, as we already know we can just drop it.
+          auto *PI = PredInfo->getPredicateInfoFor(ReplacedInst);
+          if (!PI || Result != PI->OriginalOp)
+            patchReplacementInstruction(ReplacedInst, Result);
+        }
 
         assert(isa<Instruction>(MemberUse->getUser()));
         MemberUse->set(Result);
@@ -2393,3 +2529,18 @@
 
   return AnythingReplaced;
 }
+
+unsigned int NewGVN::getRank(Value *V) const {
+  if (isa<Constant>(V))
+    return 0;
+  else if (Argument *A = dyn_cast<Argument>(V))
+    return 1 + A->getArgNo();
+
+  // Need to shift the instruction DFS by number of arguments + 1 to account for
+  // the constant and argument ranking above.
+  unsigned Result = InstrDFS.lookup(V);
+  if (Result > 0)
+    return 2 + NumFuncArgs + Result;
+  // Unreachable or something else, just return a really large number.
+  return ~0;
+}