Index: include/llvm/Transforms/Scalar/GVN.h =================================================================== --- include/llvm/Transforms/Scalar/GVN.h +++ include/llvm/Transforms/Scalar/GVN.h @@ -176,7 +176,7 @@ // Block-local map of equivalent values to their leader, does not // propagate to any successors. Entries added mid-block are applied // to the remaining instructions in the block. - SmallMapVector ReplaceWithConstMap; + SmallMapVector ReplaceOperandsWithMap; SmallVector InstrsToErase; // Map the block to reversed postorder traversal number. It is used to @@ -281,7 +281,7 @@ void verifyRemoved(const Instruction *I) const; bool splitCriticalEdges(); BasicBlock *splitCriticalEdges(BasicBlock *Pred, BasicBlock *Succ); - bool replaceOperandsWithConsts(Instruction *I) const; + bool replaceOperandsForInBlockEquality(Instruction *I) const; bool propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root, bool DominatesByEdge); bool processFoldableCondBr(BranchInst *BI); Index: lib/Transforms/Scalar/GVN.cpp =================================================================== --- lib/Transforms/Scalar/GVN.cpp +++ lib/Transforms/Scalar/GVN.cpp @@ -1387,6 +1387,14 @@ return PerformLoadPRE(LI, ValuesPerBlock, UnavailableBlocks); } +static bool hasUsersIn(Value *V, BasicBlock *BB) { + for (User *U : V->users()) + if (isa(U) && + cast(U)->getParent() == BB) + return true; + return false; +} + bool GVN::processAssumeIntrinsic(IntrinsicInst *IntrinsicI) { assert(IntrinsicI->getIntrinsicID() == Intrinsic::assume && "This function can only be called with llvm.assume intrinsic"); @@ -1425,12 +1433,21 @@ // We can replace assume value with true, which covers cases like this: // call void @llvm.assume(i1 %cmp) // br i1 %cmp, label %bb1, label %bb2 ; will change %cmp to true - ReplaceWithConstMap[V] = True; + ReplaceOperandsWithMap[V] = True; - // If one of *cmp *eq operand is const, adding it to map will cover this: + // If we find an equality fact, canonicalize all 
dominated uses to one + // of the two values. We heuristically choose the "oldest" of the two. + // + // Key cases to cover are: + // 1) // %cmp = fcmp oeq float 3.000000e+00, %0 ; const on lhs could happen // call void @llvm.assume(i1 %cmp) // ret float %0 ; will change it to ret float 3.000000e+00 + // 2) + // %load = load float, float* %addr + // %cmp = fcmp oeq float %load, %0 ; const on lhs could happen + // call void @llvm.assume(i1 %cmp) + // ret float %load ; will change it to ret float %0 if (auto *CmpI = dyn_cast(V)) { if (CmpI->getPredicate() == CmpInst::Predicate::ICMP_EQ || CmpI->getPredicate() == CmpInst::Predicate::FCMP_OEQ || @@ -1438,13 +1455,43 @@ CmpI->getFastMathFlags().noNaNs())) { Value *CmpLHS = CmpI->getOperand(0); Value *CmpRHS = CmpI->getOperand(1); - if (isa(CmpLHS)) + // Heuristically pick the better replacement -- the choice of heuristic + // isn't terribly important here, but the fact we canonicalize on some + // replacement is for exposing other simplifications. + // TODO: pull this out as a helper function and reuse w/existing + // (slightly different) logic. + if (isa(CmpLHS) && !isa(CmpRHS)) + std::swap(CmpLHS, CmpRHS); + if (!isa(CmpLHS) && isa(CmpRHS)) std::swap(CmpLHS, CmpRHS); - auto *RHSConst = dyn_cast(CmpRHS); + if ((isa(CmpLHS) && isa(CmpRHS)) || + (isa(CmpLHS) && isa(CmpRHS))) { + // Move the 'oldest' value to the right-hand side, using the value + // number as a proxy for age. + uint32_t LVN = VN.lookupOrAdd(CmpLHS); + uint32_t RVN = VN.lookupOrAdd(CmpRHS); + if (LVN < RVN) + std::swap(CmpLHS, CmpRHS); + } - // If only one operand is constant. - if (RHSConst != nullptr && !isa(CmpLHS)) - ReplaceWithConstMap[CmpLHS] = RHSConst; + // Handle degenerate case where we either haven't pruned a dead path or + // removed a trivial assume yet.
+ if (isa(CmpLHS) && isa(CmpRHS)) + return Changed; + + LLVM_DEBUG(dbgs() << "Replacing dominated uses of " + << *CmpLHS << " with " + << *CmpRHS << "\n"); + + + // Set up the replacement map - this handles uses within the same block + if (hasUsersIn(CmpLHS, IntrinsicI->getParent())) + ReplaceOperandsWithMap[CmpLHS] = CmpRHS; + + // NOTE: The non-block local cases are handled by the call to + // propagateEquality above; this block is just about handling the block + // local cases. TODO: There's a bunch of logic in propagateEquality which + // isn't duplicated for the block local case, can we share it somehow? } } return Changed; @@ -1659,16 +1706,15 @@ InvalidBlockRPONumbers = false; } -// Tries to replace instruction with const, using information from -// ReplaceWithConstMap. -bool GVN::replaceOperandsWithConsts(Instruction *Instr) const { +bool GVN::replaceOperandsForInBlockEquality(Instruction *Instr) const { + // TODO: We can remove the separate ReplaceOperandsWithMap data structure in + // favor of putting equalities into the leader table and using findLeader + // here. bool Changed = false; for (unsigned OpNum = 0; OpNum < Instr->getNumOperands(); ++OpNum) { - Value *Operand = Instr->getOperand(OpNum); - auto it = ReplaceWithConstMap.find(Operand); - if (it != ReplaceWithConstMap.end()) { - assert(!isa(Operand) && - "Replacing constants with constants is invalid"); + Value *Operand = Instr->getOperand(OpNum); + auto it = ReplaceOperandsWithMap.find(Operand); + if (it != ReplaceOperandsWithMap.end()) { LLVM_DEBUG(dbgs() << "GVN replacing: " << *Operand << " with " << *it->second << " in instruction " << *Instr << '\n'); Instr->setOperand(OpNum, it->second); @@ -2060,13 +2106,13 @@ return false; // Clearing map before every BB because it can be used only for single BB.
- ReplaceWithConstMap.clear(); + ReplaceOperandsWithMap.clear(); bool ChangedFunction = false; for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) { - if (!ReplaceWithConstMap.empty()) - ChangedFunction |= replaceOperandsWithConsts(&*BI); + if (!ReplaceOperandsWithMap.empty()) + ChangedFunction |= replaceOperandsForInBlockEquality(&*BI); ChangedFunction |= processInstruction(&*BI); if (InstrsToErase.empty()) { Index: test/Transforms/GVN/equality-assume.ll =================================================================== --- test/Transforms/GVN/equality-assume.ll +++ test/Transforms/GVN/equality-assume.ll @@ -6,7 +6,7 @@ ; CHECK-NEXT: [[LOAD:%.*]] = load i32, i32* [[P:%.*]] ; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[LOAD]], [[V:%.*]] ; CHECK-NEXT: call void @llvm.assume(i1 [[C]]) -; CHECK-NEXT: ret i32 [[LOAD]] +; CHECK-NEXT: ret i32 [[V]] ; %load = load i32, i32* %p %c = icmp eq i32 %load, %v @@ -32,7 +32,7 @@ ; CHECK-NEXT: [[LOAD:%.*]] = load i32, i32* [[P:%.*]] ; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[LOAD]], [[V:%.*]] ; CHECK-NEXT: call void @llvm.assume(i1 [[C]]) -; CHECK-NEXT: ret i32 [[LOAD]] +; CHECK-NEXT: ret i32 [[V]] ; %load = load i32, i32* %p %c = icmp eq i32 %load, %v