diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -1294,6 +1294,11 @@
   if (!all_of(PN.operands(), [](Value *V) { return isa<ConstantInt>(V); }))
     return nullptr;
 
+  // The idom's condition is a scalar integer, which can only be converted to
+  // a scalar integer phi.
+  if (!PN.getType()->isIntegerTy())
+    return nullptr;
+  unsigned PNSize = PN.getType()->getIntegerBitWidth();
   BasicBlock *BB = PN.getParent();
   // Do not bother with unreachable instructions.
   if (!DT.isReachableFromEntry(BB))
@@ -1303,10 +1308,25 @@
   LLVMContext &Context = PN.getContext();
   auto *IDom = DT.getNode(BB)->getIDom()->getBlock();
   Value *Cond;
-  SmallDenseMap<ConstantInt *, BasicBlock *, 8> SuccForValue;
+  // Map the value the condition takes on the edge to each idom successor,
+  // converted to the phi type, to that successor. Depending on the widths the
+  // value is kept, truncated, or both zero- and sign-extended.
+  using SuccMap = SmallDenseMap<ConstantInt *, BasicBlock *, 8>;
+  SuccMap SuccValue, SuccValueTrunc, SuccValueZExt, SuccValueSExt;
   SmallDenseMap<BasicBlock *, unsigned, 8> SuccCount;
   auto AddSucc = [&](ConstantInt *C, BasicBlock *Succ) {
-    SuccForValue[C] = Succ;
+    // A condition narrower than the phi may be either zero- or sign-extended,
+    // so record both images.
+    const APInt &Val = C->getValue();
+    unsigned CondSize = Val.getBitWidth();
+    if (CondSize < PNSize) {
+      SuccValueZExt[ConstantInt::get(Context, Val.zext(PNSize))] = Succ;
+      SuccValueSExt[ConstantInt::get(Context, Val.sext(PNSize))] = Succ;
+    } else if (CondSize == PNSize) {
+      SuccValue[C] = Succ;
+    } else {
+      SuccValueTrunc[ConstantInt::get(Context, Val.trunc(PNSize))] = Succ;
+    }
     ++SuccCount[Succ];
   };
   if (auto *BI = dyn_cast<BranchInst>(IDom->getTerminator())) {
@@ -1325,51 +1345,78 @@
     return nullptr;
   }
 
-  if (Cond->getType() != PN.getType())
-    return nullptr;
-
   // Check that edges outgoing from the idom's terminators dominate respective
   // inputs of the Phi.
-  std::optional<bool> Invert;
-  for (auto Pair : zip(PN.incoming_values(), PN.blocks())) {
-    auto *Input = cast<ConstantInt>(std::get<0>(Pair));
-    BasicBlock *Pred = std::get<1>(Pair);
-    auto IsCorrectInput = [&](ConstantInt *Input) {
-      // The input needs to be dominated by the corresponding edge of the idom.
-      // This edge cannot be a multi-edge, as that would imply that multiple
-      // different condition values follow the same edge.
-      auto It = SuccForValue.find(Input);
-      return It != SuccForValue.end() && SuccCount[It->second] == 1 &&
-             DT.dominates(BasicBlockEdge(IDom, It->second),
-                          BasicBlockEdge(Pred, BB));
-    };
-
-    // Depending on the constant, the condition may need to be inverted.
-    bool NeedsInvert;
-    if (IsCorrectInput(Input))
-      NeedsInvert = false;
-    else if (IsCorrectInput(cast<ConstantInt>(ConstantExpr::getNot(Input))))
-      NeedsInvert = true;
-    else
-      return nullptr;
-
-    // Make sure the inversion requirement is always the same.
-    if (Invert && *Invert != NeedsInvert)
-      return nullptr;
-
-    Invert = NeedsInvert;
-  }
-
-  if (!*Invert)
-    return Cond;
+  // Each input is compared against the condition converted to the phi type,
+  // as recorded in SuccForValue. Returns std::nullopt if some input does not
+  // match, and otherwise whether the converted condition must be inverted.
+  auto CheckInputs = [&](const SuccMap &SuccForValue) -> std::optional<bool> {
+    std::optional<bool> Invert;
+    for (auto Pair : zip(PN.incoming_values(), PN.blocks())) {
+      auto *Input = cast<ConstantInt>(std::get<0>(Pair));
+      BasicBlock *Pred = std::get<1>(Pair);
+      auto IsCorrectInput = [&](ConstantInt *Input) {
+        // The input needs to be dominated by the corresponding edge of the
+        // idom. This edge cannot be a multi-edge, as that would imply that
+        // multiple different condition values follow the same edge.
+        auto It = SuccForValue.find(Input);
+        return It != SuccForValue.end() && SuccCount[It->second] == 1 &&
+               DT.dominates(BasicBlockEdge(IDom, It->second),
+                            BasicBlockEdge(Pred, BB));
+      };
+
+      // Depending on the constant, the condition may need to be inverted.
+      bool NeedsInvert;
+      if (IsCorrectInput(Input))
+        NeedsInvert = false;
+      else if (IsCorrectInput(cast<ConstantInt>(ConstantExpr::getNot(Input))))
+        NeedsInvert = true;
+      else
+        return std::nullopt;
+
+      // Make sure the inversion requirement is always the same.
+      if (Invert && *Invert != NeedsInvert)
+        return std::nullopt;
+
+      Invert = NeedsInvert;
+    }
+    return Invert;
+  };
+
+  // Pick the conversion of the condition to the phi type under which all
+  // inputs match. For a narrower condition, zext is tried before sext.
+  unsigned CondSize = Cond->getType()->getIntegerBitWidth();
+  std::optional<Instruction::CastOps> CastOp;
+  std::optional<bool> Invert;
+  if (CondSize == PNSize) {
+    Invert = CheckInputs(SuccValue);
+  } else if (CondSize > PNSize) {
+    CastOp = Instruction::Trunc;
+    Invert = CheckInputs(SuccValueTrunc);
+  } else {
+    CastOp = Instruction::ZExt;
+    Invert = CheckInputs(SuccValueZExt);
+    if (!Invert) {
+      CastOp = Instruction::SExt;
+      Invert = CheckInputs(SuccValueSExt);
+    }
+  }
+  if (!Invert)
+    return nullptr;
+
+  // The phi is the condition itself, so no new instruction is needed.
+  if (!CastOp && !*Invert)
+    return Cond;
 
-  // This Phi is actually opposite to branching condition of IDom. We invert
-  // the condition that will potentially open up some opportunities for
-  // sinking.
+  // Emit the conversion of the condition, if any, at the start of the phi's
+  // block and invert the result if the phi is its opposite. The inversion may
+  // open up some opportunities for sinking.
   auto InsertPt = BB->getFirstInsertionPt();
   if (InsertPt != BB->end()) {
     Self.Builder.SetInsertPoint(&*InsertPt);
-    return Self.Builder.CreateNot(Cond);
+    Value *Res = CastOp ? Self.Builder.CreateCast(*CastOp, Cond, PN.getType())
+                        : Cond;
+    return *Invert ? Self.Builder.CreateNot(Res) : Res;
   }
 
   return nullptr;
diff --git a/llvm/test/Transforms/InstCombine/simple_phi_condition.ll b/llvm/test/Transforms/InstCombine/simple_phi_condition.ll
--- a/llvm/test/Transforms/InstCombine/simple_phi_condition.ll
+++ b/llvm/test/Transforms/InstCombine/simple_phi_condition.ll
@@ -668,7 +668,7 @@
 ; CHECK:       bb1:
 ; CHECK-NEXT:    br label [[BB5]]
 ; CHECK:       bb5:
-; CHECK-NEXT:    [[DOT0:%.*]] = phi i32 [ 1, [[BB1]] ], [ 0, [[BB4]] ], [ -1, [[BB3]] ]
+; CHECK-NEXT:    [[DOT0:%.*]] = sext i8 [[TMP0]] to i32
 ; CHECK-NEXT:    ret i32 [[DOT0]]
 ;
   switch i8 %0, label %bb2 [
@@ -711,7 +711,7 @@
 ; CHECK:       bb1:
 ; CHECK-NEXT:    br label [[BB5]]
 ; CHECK:       bb5:
-; CHECK-NEXT:    [[DOT0:%.*]] = phi i32 [ 1, [[BB1]] ], [ 0, [[BB4]] ], [ 255, [[BB3]] ]
+; CHECK-NEXT:    [[DOT0:%.*]] = zext i8 [[TMP0]] to i32
 ; CHECK-NEXT:    ret i32 [[DOT0]]
 ;
   switch i8 %0, label %bb2 [
@@ -753,7 +753,7 @@
 ; CHECK:       bb1:
 ; CHECK-NEXT:    br label [[BB5]]
 ; CHECK:       bb5:
-; CHECK-NEXT:    [[DOT0:%.*]] = phi i8 [ 1, [[BB1]] ], [ 0, [[BB4]] ], [ -1, [[BB3]] ]
+; CHECK-NEXT:    [[DOT0:%.*]] = trunc i32 [[TMP0]] to i8
 ; CHECK-NEXT:    ret i8 [[DOT0]]
 ;
   switch i32 %0, label %bb2 [
@@ -796,7 +796,8 @@
 ; CHECK:       bb1:
 ; CHECK-NEXT:    br label [[BB5]]
 ; CHECK:       bb5:
-; CHECK-NEXT:    [[DOT0:%.*]] = phi i32 [ -2, [[BB1]] ], [ -1, [[BB4]] ], [ 0, [[BB3]] ]
+; CHECK-NEXT:    [[TMP2:%.*]] = xor i8 [[TMP0]], -1
+; CHECK-NEXT:    [[DOT0:%.*]] = sext i8 [[TMP2]] to i32
 ; CHECK-NEXT:    ret i32 [[DOT0]]
 ;
   switch i8 %0, label %bb2 [
@@ -839,7 +840,8 @@
 ; CHECK:       bb1:
 ; CHECK-NEXT:    br label [[BB5]]
 ; CHECK:       bb5:
-; CHECK-NEXT:    [[DOT0:%.*]] = phi i32 [ -2, [[BB1]] ], [ -1, [[BB4]] ], [ -256, [[BB3]] ]
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i8 [[TMP0]] to i32
+; CHECK-NEXT:    [[DOT0:%.*]] = xor i32 [[TMP2]], -1
 ; CHECK-NEXT:    ret i32 [[DOT0]]
 ;
   switch i8 %0, label %bb2 [
@@ -881,7 +883,8 @@
 ; CHECK:       bb1:
 ; CHECK-NEXT:    br label [[BB5]]
 ; CHECK:       bb5:
-; CHECK-NEXT:    [[DOT0:%.*]] = phi i8 [ -2, [[BB1]] ], [ -1, [[BB4]] ], [ 0, [[BB3]] ]
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i32 [[TMP0]] to i8
+; CHECK-NEXT:    [[DOT0:%.*]] = xor i8 [[TMP2]], -1
 ; CHECK-NEXT:    ret i8 [[DOT0]]
 ;
   switch i32 %0, label %bb2 [