Index: llvm/trunk/include/llvm/Transforms/Scalar/GVN.h =================================================================== --- llvm/trunk/include/llvm/Transforms/Scalar/GVN.h +++ llvm/trunk/include/llvm/Transforms/Scalar/GVN.h @@ -178,7 +178,7 @@ // Block-local map of equivalent values to their leader, does not // propagate to any successors. Entries added mid-block are applied // to the remaining instructions in the block. - SmallMapVector ReplaceWithConstMap; + SmallMapVector ReplaceOperandsWithMap; SmallVector InstrsToErase; // Map the block to reversed postorder traversal number. It is used to @@ -283,7 +283,7 @@ void verifyRemoved(const Instruction *I) const; bool splitCriticalEdges(); BasicBlock *splitCriticalEdges(BasicBlock *Pred, BasicBlock *Succ); - bool replaceOperandsWithConsts(Instruction *I) const; + bool replaceOperandsForInBlockEquality(Instruction *I) const; bool propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root, bool DominatesByEdge); bool processFoldableCondBr(BranchInst *BI); Index: llvm/trunk/lib/Transforms/Scalar/GVN.cpp =================================================================== --- llvm/trunk/lib/Transforms/Scalar/GVN.cpp +++ llvm/trunk/lib/Transforms/Scalar/GVN.cpp @@ -1387,6 +1387,14 @@ return PerformLoadPRE(LI, ValuesPerBlock, UnavailableBlocks); } +static bool hasUsersIn(Value *V, BasicBlock *BB) { + for (User *U : V->users()) + if (isa(U) && + cast(U)->getParent() == BB) + return true; + return false; +} + bool GVN::processAssumeIntrinsic(IntrinsicInst *IntrinsicI) { assert(IntrinsicI->getIntrinsicID() == Intrinsic::assume && "This function can only be called with llvm.assume intrinsic"); @@ -1425,12 +1433,23 @@ // We can replace assume value with true, which covers cases like this: // call void @llvm.assume(i1 %cmp) // br i1 %cmp, label %bb1, label %bb2 ; will change %cmp to true - ReplaceWithConstMap[V] = True; + ReplaceOperandsWithMap[V] = True; - // If one of *cmp *eq operand is const, adding it to map
will cover this: + // If we find an equality fact, canonicalize all dominated uses in this block + // to one of the two values. We heuristically choose the "oldest" of the + // two where age is determined by value number. (Note that propagateEquality + // above handles the cross block case.) + // + // Key cases to cover are: + // 1) // %cmp = fcmp oeq float 3.000000e+00, %0 ; const on lhs could happen // call void @llvm.assume(i1 %cmp) // ret float %0 ; will change it to ret float 3.000000e+00 + // 2) + // %load = load i32, i32* %addr + // %cmp = icmp eq i32 %load, %0 + // call void @llvm.assume(i1 %cmp) + // ret i32 %load ; will change it to ret i32 %0 if (auto *CmpI = dyn_cast(V)) { if (CmpI->getPredicate() == CmpInst::Predicate::ICMP_EQ || CmpI->getPredicate() == CmpInst::Predicate::FCMP_OEQ || @@ -1438,13 +1457,50 @@ CmpI->getFastMathFlags().noNaNs())) { Value *CmpLHS = CmpI->getOperand(0); Value *CmpRHS = CmpI->getOperand(1); - if (isa(CmpLHS)) + // Heuristically pick the better replacement -- the choice of heuristic + // isn't terribly important here, but the fact we canonicalize on some + // replacement is for exposing other simplifications. + // TODO: pull this out as a helper function and reuse w/existing + // (slightly different) logic. + if (isa(CmpLHS) && !isa(CmpRHS)) + std::swap(CmpLHS, CmpRHS); + if (!isa(CmpLHS) && isa(CmpRHS)) std::swap(CmpLHS, CmpRHS); - auto *RHSConst = dyn_cast(CmpRHS); + if ((isa(CmpLHS) && isa(CmpRHS)) || + (isa(CmpLHS) && isa(CmpRHS))) { + // Move the 'oldest' value to the right-hand side, using the value + // number as a proxy for age. + uint32_t LVN = VN.lookupOrAdd(CmpLHS); + uint32_t RVN = VN.lookupOrAdd(CmpRHS); + if (LVN < RVN) + std::swap(CmpLHS, CmpRHS); + } - // If only one operand is constant. - if (RHSConst != nullptr && !isa(CmpLHS)) - ReplaceWithConstMap[CmpLHS] = RHSConst; + // Handle degenerate case where we either haven't pruned a dead path or + // removed a trivial assume yet.
+ if (isa(CmpLHS) && isa(CmpRHS)) + return Changed; + + // +0.0 and -0.0 compare equal, but do not imply equivalence -- for every FP + // type (half, float, double, ... and FP vectors), not just float. Unless we can prove equivalence, bail. + if (CmpRHS->getType()->isFPOrFPVectorTy() && + (!isa(CmpRHS) || cast(CmpRHS)->isZero())) + return Changed; + + LLVM_DEBUG(dbgs() << "Replacing dominated uses of " + << *CmpLHS << " with " + << *CmpRHS << " in block " + << IntrinsicI->getParent()->getName() << "\n"); + + + // Set up the replacement map - this handles uses within the same block + if (hasUsersIn(CmpLHS, IntrinsicI->getParent())) + ReplaceOperandsWithMap[CmpLHS] = CmpRHS; + + // NOTE: The non-block local cases are handled by the call to + // propagateEquality above; this block is just about handling the block + // local cases. TODO: There's a bunch of logic in propagateEquality which + // isn't duplicated for the block local case, can we share it somehow? } } return Changed; @@ -1697,16 +1753,15 @@ InvalidBlockRPONumbers = false; } -// Tries to replace instruction with const, using information from -// ReplaceWithConstMap. -bool GVN::replaceOperandsWithConsts(Instruction *Instr) const { +bool GVN::replaceOperandsForInBlockEquality(Instruction *Instr) const { + // TODO: We can remove the separate ReplaceOperandsWithMap data structure in + // favor of putting equalities into the leader table and using findLeader + // here.
bool Changed = false; for (unsigned OpNum = 0; OpNum < Instr->getNumOperands(); ++OpNum) { - Value *Operand = Instr->getOperand(OpNum); - auto it = ReplaceWithConstMap.find(Operand); - if (it != ReplaceWithConstMap.end()) { - assert(!isa(Operand) && - "Replacing constants with constants is invalid"); + Value *Operand = Instr->getOperand(OpNum); + auto it = ReplaceOperandsWithMap.find(Operand); + if (it != ReplaceOperandsWithMap.end()) { LLVM_DEBUG(dbgs() << "GVN replacing: " << *Operand << " with " << *it->second << " in instruction " << *Instr << '\n'); Instr->setOperand(OpNum, it->second); @@ -2098,13 +2153,13 @@ return false; // Clearing map before every BB because it can be used only for single BB. - ReplaceWithConstMap.clear(); + ReplaceOperandsWithMap.clear(); bool ChangedFunction = false; for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) { - if (!ReplaceWithConstMap.empty()) - ChangedFunction |= replaceOperandsWithConsts(&*BI); + if (!ReplaceOperandsWithMap.empty()) + ChangedFunction |= replaceOperandsForInBlockEquality(&*BI); ChangedFunction |= processInstruction(&*BI); if (InstrsToErase.empty()) { Index: llvm/trunk/test/Transforms/GVN/equality-assume.ll =================================================================== --- llvm/trunk/test/Transforms/GVN/equality-assume.ll +++ llvm/trunk/test/Transforms/GVN/equality-assume.ll @@ -6,7 +6,7 @@ ; CHECK-NEXT: [[LOAD:%.*]] = load i32, i32* [[P:%.*]] ; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[LOAD]], [[V:%.*]] ; CHECK-NEXT: call void @llvm.assume(i1 [[C]]) -; CHECK-NEXT: ret i32 [[LOAD]] +; CHECK-NEXT: ret i32 [[V]] ; %load = load i32, i32* %p %c = icmp eq i32 %load, %v @@ -27,8 +27,9 @@ ret i32 %v } -define float @float_oeq(float* %p, float %v) { -; CHECK-LABEL: @float_oeq( +; Lack of equivalence due to +0.0 vs -0.0 +define float @neg_float_oeq(float* %p, float %v) { +; CHECK-LABEL: @neg_float_oeq( ; CHECK-NEXT: [[LOAD:%.*]] = load float, float* [[P:%.*]] ; CHECK-NEXT: [[C:%.*]] = fcmp oeq
float [[LOAD]], [[V:%.*]] ; CHECK-NEXT: call void @llvm.assume(i1 [[C]]) @@ -40,8 +41,9 @@ ret float %load } -define float @float_ueq(float* %p, float %v) { -; CHECK-LABEL: @float_ueq( +; Lack of equivalence due to +0.0 vs -0.0 +define float @neg_float_ueq(float* %p, float %v) { +; CHECK-LABEL: @neg_float_ueq( ; CHECK-NEXT: [[LOAD:%.*]] = load float, float* [[P:%.*]] ; CHECK-NEXT: [[C:%.*]] = fcmp ueq float [[LOAD]], [[V:%.*]] ; CHECK-NEXT: call void @llvm.assume(i1 [[C]]) @@ -66,8 +68,9 @@ ret float %load } -define float @float_ueq_constant(float* %p) { -; CHECK-LABEL: @float_ueq_constant( +; Lack of equivalence due to NaN +define float @neg_float_ueq_constant(float* %p) { +; CHECK-LABEL: @neg_float_ueq_constant( ; CHECK-NEXT: [[LOAD:%.*]] = load float, float* [[P:%.*]] ; CHECK-NEXT: [[C:%.*]] = fcmp ueq float [[LOAD]], 5.000000e+00 ; CHECK-NEXT: call void @llvm.assume(i1 [[C]]) @@ -79,12 +82,25 @@ ret float %load } +define float @float_ueq_constant_nnan(float* %p) { +; CHECK-LABEL: @float_ueq_constant_nnan( +; CHECK-NEXT: [[LOAD:%.*]] = load float, float* [[P:%.*]] +; CHECK-NEXT: [[C:%.*]] = fcmp nnan ueq float [[LOAD]], 5.000000e+00 +; CHECK-NEXT: call void @llvm.assume(i1 [[C]]) +; CHECK-NEXT: ret float 5.000000e+00 +; + %load = load float, float* %p + %c = fcmp nnan ueq float %load, 5.0 + call void @llvm.assume(i1 %c) + ret float %load +} + define i32 @test2(i32* %p, i32 %v) { ; CHECK-LABEL: @test2( ; CHECK-NEXT: [[LOAD:%.*]] = load i32, i32* [[P:%.*]] ; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[LOAD]], [[V:%.*]] ; CHECK-NEXT: call void @llvm.assume(i1 [[C]]) -; CHECK-NEXT: ret i32 [[LOAD]] +; CHECK-NEXT: ret i32 [[V]] ; %load = load i32, i32* %p %c = icmp eq i32 %load, %v @@ -93,8 +109,6 @@ ret i32 %load2 } - - define i32 @test3(i32* %p, i32 %v) { ; CHECK-LABEL: @test3( ; CHECK-NEXT: [[LOAD:%.*]] = load i32, i32* [[P:%.*]]