diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp --- a/llvm/lib/Analysis/LazyValueInfo.cpp +++ b/llvm/lib/Analysis/LazyValueInfo.cpp @@ -1166,11 +1166,16 @@ return ValueLatticeElement::getRange(NWR); } -static std::optional -getValueFromConditionImpl(Value *Val, Value *Cond, bool isTrueDest, - bool isRevisit, - SmallDenseMap &Visited, - SmallVectorImpl &Worklist) { +// A condition Value * plus the branch direction queried for it (isTrueDest). +typedef PointerIntPair CondValue; + +static std::optional getValueFromConditionImpl( + Value *Val, CondValue CondVal, bool isRevisit, + SmallDenseMap &Visited, + SmallVectorImpl &Worklist) { + + Value *Cond = CondVal.getPointer(); + bool isTrueDest = CondVal.getInt(); if (!isRevisit) { if (ICmpInst *ICI = dyn_cast(Cond)) return getValueFromICmpCondition(Val, ICI, isTrueDest); @@ -1181,6 +1186,17 @@ return getValueFromOverflowCondition(Val, WO, isTrueDest); } + Value *N; + if (match(Cond, m_Not(m_Value(N)))) { + CondValue NKey(N, !isTrueDest); + auto NV = Visited.find(NKey); + if (NV == Visited.end()) { + Worklist.push_back(NKey); + return std::nullopt; + } + return NV->second; + } + Value *L, *R; bool IsAnd; if (match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))) @@ -1190,13 +1206,13 @@ else return ValueLatticeElement::getOverdefined(); - auto LV = Visited.find(L); - auto RV = Visited.find(R); + auto LV = Visited.find(CondValue(L, isTrueDest)); + auto RV = Visited.find(CondValue(R, isTrueDest)); // if (L && R) -> intersect L and R - // if (!(L || R)) -> intersect L and R + // if (!(L || R)) -> intersect !L and !R // if (L || R) -> union L and R - // if (!(L && R)) -> union L and R + // if (!(L && R)) -> union !L and !R if ((isTrueDest ^ IsAnd) && (LV != Visited.end())) { ValueLatticeElement V = LV->second; if (V.isOverdefined()) @@ -1210,9 +1226,9 @@ if (LV == Visited.end() || RV == Visited.end()) { assert(!isRevisit); if (LV == Visited.end()) - Worklist.push_back(L); + 
Worklist.push_back(CondValue(L, isTrueDest)); if (RV == Visited.end()) - Worklist.push_back(R); + Worklist.push_back(CondValue(R, isTrueDest)); return std::nullopt; } @@ -1222,12 +1238,13 @@ ValueLatticeElement getValueFromCondition(Value *Val, Value *Cond, bool isTrueDest) { assert(Cond && "precondition"); - SmallDenseMap Visited; - SmallVector Worklist; + SmallDenseMap Visited; + SmallVector Worklist; - Worklist.push_back(Cond); + CondValue CondKey(Cond, isTrueDest); + Worklist.push_back(CondKey); do { - Value *CurrentCond = Worklist.back(); + CondValue CurrentCond = Worklist.back(); // Insert an Overdefined placeholder into the set to prevent // infinite recursion if there exists IRs that use not // dominated by its def as in this example: @@ -1237,14 +1254,14 @@ Visited.try_emplace(CurrentCond, ValueLatticeElement::getOverdefined()); bool isRevisit = !Iter.second; std::optional Result = getValueFromConditionImpl( - Val, CurrentCond, isTrueDest, isRevisit, Visited, Worklist); + Val, CurrentCond, isRevisit, Visited, Worklist); if (Result) { Visited[CurrentCond] = *Result; Worklist.pop_back(); } } while (!Worklist.empty()); - auto Result = Visited.find(Cond); + auto Result = Visited.find(CondKey); assert(Result != Visited.end()); return Result->second; } diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/basic.ll b/llvm/test/Transforms/CorrelatedValuePropagation/basic.ll --- a/llvm/test/Transforms/CorrelatedValuePropagation/basic.ll +++ b/llvm/test/Transforms/CorrelatedValuePropagation/basic.ll @@ -1853,6 +1853,63 @@ ret void } +define i1 @xor_neg_cond(i32 %a) { +; CHECK-LABEL: @xor_neg_cond( +; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i32 [[A:%.*]], 10 +; CHECK-NEXT: [[XOR:%.*]] = xor i1 [[CMP1]], true +; CHECK-NEXT: br i1 [[XOR]], label [[EXIT:%.*]], label [[GUARD:%.*]] +; CHECK: guard: +; CHECK-NEXT: ret i1 true +; CHECK: exit: +; CHECK-NEXT: ret i1 false +; + %cmp1 = icmp eq i32 %a, 10 + %xor = xor i1 %cmp1, true + br i1 %xor, label %exit, label %guard + 
+guard: + %cmp2 = icmp eq i32 %a, 10 + ret i1 %cmp2 + +exit: + ret i1 false +} + +define i1 @xor_approx(i32 %a) { +; CHECK-LABEL: @xor_approx( +; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i32 [[A:%.*]], 2 +; CHECK-NEXT: [[CMP2:%.*]] = icmp ult i32 [[A]], 5 +; CHECK-NEXT: [[CMP3:%.*]] = icmp ugt i32 [[A]], 7 +; CHECK-NEXT: [[CMP4:%.*]] = icmp ult i32 [[A]], 9 +; CHECK-NEXT: [[AND1:%.*]] = and i1 [[CMP1]], [[CMP2]] +; CHECK-NEXT: [[AND2:%.*]] = and i1 [[CMP3]], [[CMP4]] +; CHECK-NEXT: [[OR:%.*]] = or i1 [[AND1]], [[AND2]] +; CHECK-NEXT: [[XOR:%.*]] = xor i1 [[OR]], true +; CHECK-NEXT: br i1 [[XOR]], label [[EXIT:%.*]], label [[GUARD:%.*]] +; CHECK: guard: +; CHECK-NEXT: [[CMP5:%.*]] = icmp eq i32 [[A]], 6 +; CHECK-NEXT: ret i1 [[CMP5]] +; CHECK: exit: +; CHECK-NEXT: ret i1 false +; + %cmp1 = icmp ugt i32 %a, 2 + %cmp2 = icmp ult i32 %a, 5 + %cmp3 = icmp ugt i32 %a, 7 + %cmp4 = icmp ult i32 %a, 9 + %and1 = and i1 %cmp1, %cmp2 + %and2 = and i1 %cmp3, %cmp4 + %or = or i1 %and1, %and2 + %xor = xor i1 %or, true + br i1 %xor, label %exit, label %guard + +guard: + %cmp5 = icmp eq i32 %a, 6 + ret i1 %cmp5 + +exit: + ret i1 false +} + declare i32 @llvm.uadd.sat.i32(i32, i32) declare i32 @llvm.usub.sat.i32(i32, i32) declare i32 @llvm.sadd.sat.i32(i32, i32)