Index: llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.h =================================================================== --- llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.h +++ llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.h @@ -11,8 +11,8 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" namespace llvm { @@ -25,7 +25,6 @@ LogicalExpr Expr; unsigned Weight; unsigned OneUseWeight; - // TODO: Add weight to measure cost for more than one use value void printValue(raw_ostream &OS, Value *Val) const; void printAndChain(raw_ostream &OS, uint64_t Mask) const; @@ -59,11 +58,13 @@ SmallDenseMap LogicalOpNodes; SmallPtrSet LeafSet; SmallVector LeafValues; + uint64_t PoisonLeafMask; void clear(); LogicalOpNode *visitLeafNode(Value *Val, unsigned Depth); LogicalOpNode *visitBinOp(BinaryOperator *BO, unsigned Depth); + LogicalOpNode *visitSelect(SelectInst *SI, unsigned Depth); LogicalOpNode *getLogicalOpNode(Value *Val, unsigned Depth = 0); Value *logicalOpToValue(LogicalOpNode *Node); Index: llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.cpp =================================================================== --- llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.cpp +++ llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.cpp @@ -32,6 +32,8 @@ #include "ComplexLogicalOpsCombine.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Constants.h" #include "llvm/IR/IRBuilder.h" #include "llvm/Support/CommandLine.h" @@ -105,6 +107,7 @@ void LogicalOpsHelper::clear() { for (auto node : LogicalOpNodes) delete node.second; + PoisonLeafMask = 0; LogicalOpNodes.clear(); LeafSet.clear(); LeafValues.clear(); @@ -131,6 +134,7 @@ LogicalOpNode *Node = new 
LogicalOpNode(this, Val, LogicalExpr(ExprVal), 0, 0); LogicalOpNodes[Val] = Node; + PoisonLeafMask |= isGuaranteedNotToBeUndefOrPoison(Val) ? 0 : ExprVal; return Node; } @@ -169,6 +173,41 @@ return Node; } +/// Model an i1 (or vector-of-i1) `select Cond, TrueVal, FalseVal` as the +/// expression (Cond & TrueVal) ^ (~Cond & FalseVal). Returns nullptr when the +/// select has another type or any of its operands yields no node. +LogicalOpNode *LogicalOpsHelper::visitSelect(SelectInst *SI, unsigned Depth) { + if (!SI->getType()->isIntOrIntVectorTy(1)) + return nullptr; + + LogicalOpNode *Cond = getLogicalOpNode(SI->getCondition(), Depth + 1); + if (Cond == nullptr) + return nullptr; + + LogicalOpNode *TrueVal = getLogicalOpNode(SI->getTrueValue(), Depth + 1); + if (TrueVal == nullptr) + return nullptr; + + LogicalOpNode *FalseVal = getLogicalOpNode(SI->getFalseValue(), Depth + 1); + if (FalseVal == nullptr) + return nullptr; + + LogicalExpr NewExpr = (Cond->getExpr() & TrueVal->getExpr()) ^ + ((~Cond->getExpr()) & FalseVal->getExpr()); + // TODO: We can reduce the weight if the node can be simplified even if + // it is not the root node. + unsigned Weight = + Cond->getWeight() + TrueVal->getWeight() + FalseVal->getWeight() + 1; + unsigned OneUseWeight = Weight; + if (SI->hasOneUse()) + OneUseWeight = Cond->getOneUseWeight() + TrueVal->getOneUseWeight() + + FalseVal->getOneUseWeight(); + LogicalOpNode *Node = + new LogicalOpNode(this, SI, NewExpr, Weight, OneUseWeight); + LogicalOpNodes[SI] = Node; + return Node; +} + LogicalOpNode *LogicalOpsHelper::getLogicalOpNode(Value *Val, unsigned Depth) { if (Depth == MaxDepthLogicOpsToScan) return nullptr; @@ -179,6 +218,8 @@ // TODO: add select instruction support if (auto *BO = dyn_cast<BinaryOperator>(Val)) Node = visitBinOp(BO, Depth); + else if (auto *SI = dyn_cast<SelectInst>(Val)) + Node = visitSelect(SI, Depth); else Node = visitLeafNode(Val, Depth); @@ -205,6 +246,11 @@ if (ExprMask == LogicalExpr::ExprAllOne) return Constant::getAllOnesValue(Node->getValue()->getType()); + // TODO: For now, we use very conservative condition to do poison check, later + // we can give every node a poison mask to detect more cases + if (ExprMask & PoisonLeafMask) + return nullptr; + 
unsigned ElementCnt = llvm::popcount(ExprMask); if (ElementCnt == 1) return LeafValues[llvm::Log2_64(ExprMask)]; Index: llvm/lib/Transforms/AggressiveInstCombine/LogicalExpr.h =================================================================== --- llvm/lib/Transforms/AggressiveInstCombine/LogicalExpr.h +++ llvm/lib/Transforms/AggressiveInstCombine/LogicalExpr.h @@ -53,6 +53,8 @@ } unsigned size() const { return AddChain.size(); } + uint64_t getLeafMask() const { return LeafMask; } + ExprAddChain::iterator begin() { return AddChain.begin(); } ExprAddChain::iterator end() { return AddChain.end(); } ExprAddChain::const_iterator begin() const { return AddChain.begin(); } Index: llvm/test/Transforms/AggressiveInstCombine/complex-logic-ops.ll =================================================================== --- llvm/test/Transforms/AggressiveInstCombine/complex-logic-ops.ll +++ llvm/test/Transforms/AggressiveInstCombine/complex-logic-ops.ll @@ -28,7 +28,15 @@ define void @test2(i1 %a, i1 %b, i1 %c, i1 %d) { ; CHECK-LABEL: @test2( -; CHECK-NEXT: br i1 [[B:%.*]], label [[IF_END:%.*]], label [[IF_THEN:%.*]] +; CHECK-NEXT: [[BD:%.*]] = and i1 [[B:%.*]], [[D:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = xor i1 [[BD]], [[C:%.*]] +; CHECK-NEXT: [[NOT_BD:%.*]] = xor i1 [[XOR]], true +; CHECK-NEXT: [[XOR_AB:%.*]] = xor i1 [[A:%.*]], [[B]] +; CHECK-NEXT: [[OR1:%.*]] = or i1 [[XOR_AB]], [[C]] +; CHECK-NEXT: [[OR2:%.*]] = or i1 [[OR1]], [[NOT_BD]] +; CHECK-NEXT: [[OR3:%.*]] = or i1 [[OR2]], [[A]] +; CHECK-NEXT: [[AND:%.*]] = and i1 [[OR3]], [[B]] +; CHECK-NEXT: br i1 [[AND]], label [[IF_END:%.*]], label [[IF_THEN:%.*]] ; CHECK: if.then: ; CHECK-NEXT: call void @usev() ; CHECK-NEXT: br label [[IF_END]] @@ -55,7 +63,14 @@ define void @test3(i1 %a, i1 %b, i1 %c) { ; CHECK-LABEL: @test3( -; CHECK-NEXT: br i1 [[A:%.*]], label [[IF_END:%.*]], label [[IF_THEN:%.*]] +; CHECK-NEXT: [[NOT_A:%.*]] = xor i1 [[A:%.*]], true +; CHECK-NEXT: [[OR1:%.*]] = or i1 [[NOT_A]], [[B:%.*]] +; CHECK-NEXT: 
[[OR2:%.*]] = or i1 [[OR1]], [[C:%.*]] +; CHECK-NEXT: [[OR3:%.*]] = or i1 [[C]], [[NOT_A]] +; CHECK-NEXT: [[OR4:%.*]] = or i1 [[OR3]], [[B]] +; CHECK-NEXT: [[XOR:%.*]] = xor i1 [[OR2]], [[OR4]] +; CHECK-NEXT: [[COND:%.*]] = xor i1 [[XOR]], [[A]] +; CHECK-NEXT: br i1 [[COND]], label [[IF_END:%.*]], label [[IF_THEN:%.*]] ; CHECK: if.then: ; CHECK-NEXT: tail call void @usev() ; CHECK-NEXT: br label [[IF_END]] @@ -82,8 +97,16 @@ define void @test4(i1 %a, i1 %b, i1 %c, i1 %d) { ; CHECK-LABEL: @test4( -; CHECK-NEXT: [[TMP1:%.*]] = and i1 [[B:%.*]], [[D:%.*]] -; CHECK-NEXT: br i1 [[TMP1]], label [[IF_END:%.*]], label [[IF_THEN:%.*]] +; CHECK-NEXT: [[BD:%.*]] = and i1 [[B:%.*]], [[D:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = xor i1 [[BD]], [[C:%.*]] +; CHECK-NEXT: [[NOT_BD:%.*]] = xor i1 [[XOR]], true +; CHECK-NEXT: [[XOR_AB:%.*]] = xor i1 [[A:%.*]], [[B]] +; CHECK-NEXT: [[OR1:%.*]] = or i1 [[XOR_AB]], [[C]] +; CHECK-NEXT: [[OR2:%.*]] = or i1 [[OR1]], [[NOT_BD]] +; CHECK-NEXT: [[OR3:%.*]] = or i1 [[OR2]], [[A]] +; CHECK-NEXT: [[AND:%.*]] = and i1 [[OR3]], [[B]] +; CHECK-NEXT: [[AND2:%.*]] = and i1 [[AND]], [[D]] +; CHECK-NEXT: br i1 [[AND2]], label [[IF_END:%.*]], label [[IF_THEN:%.*]] ; CHECK: if.then: ; CHECK-NEXT: call void @usev() ; CHECK-NEXT: br label [[IF_END]] @@ -117,9 +140,11 @@ ; CHECK-NEXT: [[XOR_AB:%.*]] = xor i1 [[A:%.*]], [[B]] ; CHECK-NEXT: [[OR1:%.*]] = or i1 [[XOR_AB]], [[C]] ; CHECK-NEXT: [[OR2:%.*]] = or i1 [[OR1]], [[NOT_BD]] -; CHECK-NEXT: [[TMP1:%.*]] = and i1 [[B]], [[D]] +; CHECK-NEXT: [[OR3:%.*]] = or i1 [[OR2]], [[A]] +; CHECK-NEXT: [[AND:%.*]] = and i1 [[OR3]], [[B]] +; CHECK-NEXT: [[AND2:%.*]] = and i1 [[AND]], [[D]] ; CHECK-NEXT: call void @use1(i1 [[OR2]]) -; CHECK-NEXT: br i1 [[TMP1]], label [[IF_END:%.*]], label [[IF_THEN:%.*]] +; CHECK-NEXT: br i1 [[AND2]], label [[IF_END:%.*]], label [[IF_THEN:%.*]] ; CHECK: if.then: ; CHECK-NEXT: call void @usev() ; CHECK-NEXT: br label [[IF_END]] @@ -146,5 +171,100 @@ ret void } +define void @test6(i1 noundef 
%a, i1 %b, i1 %c, i1 %d) { +; CHECK-LABEL: @test6( +; CHECK-NEXT: br i1 [[A:%.*]], label [[IF_END:%.*]], label [[IF_THEN:%.*]] +; CHECK: if.then: +; CHECK-NEXT: call void @usev() +; CHECK-NEXT: br label [[IF_END]] +; CHECK: if.end: +; CHECK-NEXT: ret void +; + %bd = select i1 %d, i1 %b, i1 false + %xor = xor i1 %bd, %c + %not.bd = xor i1 %xor, true + %xor.ab = xor i1 %a, %b + %or1 = or i1 %xor.ab, %c + %or2 = or i1 %or1, %not.bd + %or3 = or i1 %or2, %a + %and = and i1 %or3, %a + br i1 %and, label %if.end, label %if.then + +if.then: + call void @usev() + br label %if.end + +if.end: + ret void +} + +define void @test7(i1 %a, i1 %b, i1 %c, i1 %d) { +; CHECK-LABEL: @test7( +; CHECK-NEXT: [[BD:%.*]] = select i1 [[D:%.*]], i1 [[B:%.*]], i1 false +; CHECK-NEXT: [[XOR:%.*]] = xor i1 [[BD]], [[C:%.*]] +; CHECK-NEXT: [[NOT_BD:%.*]] = xor i1 [[XOR]], true +; CHECK-NEXT: [[XOR_AB:%.*]] = xor i1 [[A:%.*]], [[B]] +; CHECK-NEXT: [[OR1:%.*]] = or i1 [[XOR_AB]], [[C]] +; CHECK-NEXT: [[OR2:%.*]] = or i1 [[OR1]], [[NOT_BD]] +; CHECK-NEXT: [[OR3:%.*]] = or i1 [[OR2]], [[A]] +; CHECK-NEXT: [[AND:%.*]] = and i1 [[OR3]], [[A]] +; CHECK-NEXT: br i1 [[AND]], label [[IF_END:%.*]], label [[IF_THEN:%.*]] +; CHECK: if.then: +; CHECK-NEXT: call void @usev() +; CHECK-NEXT: br label [[IF_END]] +; CHECK: if.end: +; CHECK-NEXT: ret void +; + %bd = select i1 %d, i1 %b, i1 false + %xor = xor i1 %bd, %c + %not.bd = xor i1 %xor, true + %xor.ab = xor i1 %a, %b + %or1 = or i1 %xor.ab, %c + %or2 = or i1 %or1, %not.bd + %or3 = or i1 %or2, %a + %and = and i1 %or3, %a + br i1 %and, label %if.end, label %if.then + +if.then: + call void @usev() + br label %if.end + +if.end: + ret void +} + +define void @test8(i1 %a, i1 %b, i1 %c) { +; CHECK-LABEL: @test8( +; CHECK-NEXT: [[NOT_A:%.*]] = xor i1 [[A:%.*]], true +; CHECK-NEXT: [[OR1:%.*]] = or i1 [[NOT_A]], [[B:%.*]] +; CHECK-NEXT: [[OR2:%.*]] = or i1 [[OR1]], [[C:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = xor i1 [[OR2]], [[A]] +; CHECK-NEXT: [[OR3:%.*]] = or i1 
[[NOT_A]], [[C]] +; CHECK-NEXT: [[OR4:%.*]] = select i1 [[OR3]], i1 true, i1 [[B]] +; CHECK-NEXT: [[COND:%.*]] = xor i1 [[XOR]], [[OR4]] +; CHECK-NEXT: br i1 [[COND]], label [[IF_END:%.*]], label [[IF_THEN:%.*]] +; CHECK: if.then: +; CHECK-NEXT: call void @usev() +; CHECK-NEXT: br label [[IF_END]] +; CHECK: if.end: +; CHECK-NEXT: ret void +; + %not.a = xor i1 %a, true + %or1 = or i1 %not.a, %b + %or2 = or i1 %or1, %c + %xor = xor i1 %or2, %a + %or3 = or i1 %not.a, %c + %or4 = select i1 %or3, i1 true, i1 %b + %cond = xor i1 %xor, %or4 + br i1 %cond, label %if.end, label %if.then + +if.then: + call void @usev() + br label %if.end + +if.end: + ret void +} + declare void @usev() declare void @use1(i1)