Index: llvm/include/llvm/Analysis/LogicCombine.h
===================================================================
--- llvm/include/llvm/Analysis/LogicCombine.h
+++ llvm/include/llvm/Analysis/LogicCombine.h
@@ -24,20 +24,22 @@
   LogicalExpr Expr;
   unsigned Weight;
   unsigned OneUseWeight;
+  uint64_t PoisonMaskSI; // Leaves whose poison may be blocked by a select.
 
   void printAndChain(raw_ostream &OS, uint64_t LeafBits) const;
 
 public:
   LogicalOpNode(LogicCombiner *Helper, Value *Val, const LogicalExpr &SrcExpr,
-                unsigned Weight, unsigned OneUseWeight)
+                unsigned Weight, unsigned OneUseWeight, uint64_t PoisonMaskSI)
       : Helper(Helper), Val(Val), Expr(SrcExpr), Weight(Weight),
-        OneUseWeight(OneUseWeight) {}
+        OneUseWeight(OneUseWeight), PoisonMaskSI(PoisonMaskSI) {}
   ~LogicalOpNode() {}
 
   Value *getValue() const { return Val; }
   const LogicalExpr &getExpr() const { return Expr; }
   unsigned getWeight() const { return Weight; }
   unsigned getOneUseWeight() const { return OneUseWeight; }
+  uint64_t getPoisonMaskSI() const { return PoisonMaskSI; }
 
   bool worthToCombine(unsigned InstCnt) const {
     return (OneUseWeight + InstCnt) < Weight;
@@ -58,11 +60,13 @@
   SpecificBumpPtrAllocator Alloc;
   SmallDenseMap LogicalOpNodes;
   SmallSetVector LeafValues;
+  uint64_t LeafsMayPoison = 0; // Leaves that may be undef or poison.
 
   void clear();
 
   LogicalOpNode *visitLeafNode(Value *Val, unsigned Depth);
   LogicalOpNode *visitBinOp(BinaryOperator *BO, unsigned Depth);
+  LogicalOpNode *visitSelect(SelectInst *SI, unsigned Depth);
   LogicalOpNode *getLogicalOpNode(Value *Val, unsigned Depth = 0);
   Value *logicalOpToValue(LogicalOpNode *Node);
   Value *buildAndChain(IRBuilder<> &Builder, Type *Ty, uint64_t LeafBits);
Index: llvm/include/llvm/Analysis/LogicalExpr.h
===================================================================
--- llvm/include/llvm/Analysis/LogicalExpr.h
+++ llvm/include/llvm/Analysis/LogicalExpr.h
@@ -75,6 +75,8 @@
   }
 
   unsigned size() const { return AddChain.size(); }
+  uint64_t getLeafMask() const { return LeafMask; }
+
   ExprAddChain::iterator begin() { return AddChain.begin(); }
   ExprAddChain::iterator end() { return AddChain.end(); }
   ExprAddChain::const_iterator begin() const { return AddChain.begin(); }
Index: llvm/lib/Analysis/LogicCombine.cpp
===================================================================
--- llvm/lib/Analysis/LogicCombine.cpp
+++ llvm/lib/Analysis/LogicCombine.cpp
@@ -32,6 +32,8 @@
 #include "llvm/Analysis/LogicCombine.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -89,10 +91,14 @@
     printAndChain(OS, *I);
   }
 
-  OS << "\nWeight: " << Weight << "; OneUseWeight: " << OneUseWeight << ";\n\n";
+  OS << "\nWeight: " << Weight << "; OneUseWeight: " << OneUseWeight << ";\n";
+  OS << "PoisonSrc: 0x"
+     << Twine::utohexstr(Expr.getLeafMask() & Helper->LeafsMayPoison)
+     << "; PoisonMaskSI: 0x" << Twine::utohexstr(PoisonMaskSI) << ";\n\n";
 }
 
 void LogicCombiner::clear() {
+  LeafsMayPoison = 0;
   LogicalOpNodes.clear();
   LeafValues.clear();
 }
@@ -112,10 +118,13 @@
     else if (ConstVal->isAllOnesValue())
       ExprVal = LogicalExpr::ExprAllOne;
   }
-  if (ExprVal != LogicalExpr::ExprAllOne && ExprVal != 0)
+  if (ExprVal != LogicalExpr::ExprAllOne && ExprVal != 0) {
+    if (!isGuaranteedNotToBeUndefOrPoison(Val))
+      LeafsMayPoison |= ExprVal;
     LeafValues.insert(Val);
+  }
   LogicalOpNode *Node = new (Alloc.Allocate())
-      LogicalOpNode(this, Val, LogicalExpr(ExprVal), 0, 0);
+      LogicalOpNode(this, Val, LogicalExpr(ExprVal), 0, 0, 0);
   LogicalOpNodes[Val] = Node;
   return Node;
 }
@@ -146,12 +155,65 @@
     NewExpr = LHS->getExpr() | RHS->getExpr();
   else
     NewExpr = LHS->getExpr() ^ RHS->getExpr();
+
+  // LPI/RPI: leaves each operand uses directly, i.e. not behind a select.
+  // A direct use propagates poison, so a leaf stays masked only if one
+  // operand masks it and the other operand does not use it directly.
+  uint64_t LPI = LHS->getExpr().getLeafMask() ^ LHS->getPoisonMaskSI();
+  uint64_t RPI = RHS->getExpr().getLeafMask() ^ RHS->getPoisonMaskSI();
+  uint64_t PoisonMaskSI =
+      (~LPI & RHS->getPoisonMaskSI()) | (~RPI & LHS->getPoisonMaskSI());
+  PoisonMaskSI &= NewExpr.getLeafMask() & LeafsMayPoison;
+
   LogicalOpNode *Node = new (Alloc.Allocate())
-      LogicalOpNode(this, BO, NewExpr, Weight, OneUseWeight);
+      LogicalOpNode(this, BO, NewExpr, Weight, OneUseWeight, PoisonMaskSI);
   LogicalOpNodes[BO] = Node;
   return Node;
 }
 
+/// Model an i1 (or i1 vector) select as (Cond & TrueVal) ^ (~Cond & FalseVal)
+/// and record in PoisonMaskSI the leaves whose poison the select may block.
+LogicalOpNode *LogicCombiner::visitSelect(SelectInst *SI, unsigned Depth) {
+  if (!SI->getType()->isIntOrIntVectorTy(1))
+    return nullptr;
+
+  LogicalOpNode *Cond = getLogicalOpNode(SI->getCondition(), Depth + 1);
+  if (Cond == nullptr)
+    return nullptr;
+
+  LogicalOpNode *TrueVal = getLogicalOpNode(SI->getTrueValue(), Depth + 1);
+  if (TrueVal == nullptr)
+    return nullptr;
+
+  LogicalOpNode *FalseVal = getLogicalOpNode(SI->getFalseValue(), Depth + 1);
+  if (FalseVal == nullptr)
+    return nullptr;
+
+  LogicalExpr NewExpr = (Cond->getExpr() & TrueVal->getExpr()) ^
+                        ((~Cond->getExpr()) & FalseVal->getExpr());
+  // TODO: We can reduce the weight if the node can be simplified even if
+  // it is not the root node.
+  unsigned Weight =
+      Cond->getWeight() + TrueVal->getWeight() + FalseVal->getWeight() + 1;
+  unsigned OneUseWeight = Weight;
+  if (SI->hasOneUse())
+    OneUseWeight = Cond->getOneUseWeight() + TrueVal->getOneUseWeight() +
+                   FalseVal->getOneUseWeight();
+
+  // The condition can pick one arm and hide poison coming from the other, so
+  // every leaf of the true/false values is potentially masked. Leaves already
+  // masked inside the condition (e.g. a nested select) stay masked as well.
+  uint64_t PoisonMaskSI = TrueVal->getExpr().getLeafMask() |
+                          FalseVal->getExpr().getLeafMask() |
+                          Cond->getPoisonMaskSI();
+  PoisonMaskSI &= NewExpr.getLeafMask() & LeafsMayPoison;
+
+  LogicalOpNode *Node = new (Alloc.Allocate())
+      LogicalOpNode(this, SI, NewExpr, Weight, OneUseWeight, PoisonMaskSI);
+  LogicalOpNodes[SI] = Node;
+  return Node;
+}
+
 LogicalOpNode *LogicCombiner::getLogicalOpNode(Value *Val, unsigned Depth) {
   if (Depth == MaxDepthLogicOpsToScan)
     return nullptr;
@@ -162,6 +224,7 @@
-  // TODO: add select instruction support
   if (auto *BO = dyn_cast(Val))
     Node = visitBinOp(BO, Depth);
+  else if (auto *SI = dyn_cast(Val))
+    Node = visitSelect(SI, Depth);
   else
     Node = visitLeafNode(Val, Depth);
 
@@ -178,6 +241,11 @@
   if (Expr.size() == 0)
     return Constant::getNullValue(Node->getValue()->getType());
 
+  // A select in the original expression can block poison from some leaves;
+  // rebuilding it from plain logic instructions could expose that poison.
+  if (Node->getPoisonMaskSI() != 0)
+    return nullptr;
+
   Instruction *I = cast(Node->getValue());
   Type *Ty = I->getType();
   if (Expr.size() == 1) {
Index: llvm/test/Transforms/AggressiveInstCombine/logic-combine.ll
===================================================================
--- llvm/test/Transforms/AggressiveInstCombine/logic-combine.ll
+++ llvm/test/Transforms/AggressiveInstCombine/logic-combine.ll
@@ -161,6 +161,59 @@
   ret i32 %abbc
 }
 
+define i1 @leaf3_select_ret_and(i1 %a, i1 %b, i1 noundef %c) {
+; CHECK-LABEL: @leaf3_select_ret_and(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i1 [[A:%.*]], [[C:%.*]]
+; CHECK-NEXT:    ret i1 [[TMP1]]
+;
+  %ab = and i1 %a, %b
+  %si = select i1 %a, i1 %c, i1 %b
+  %xor2 = xor i1 %si, %b
+  %cond = xor i1 %xor2, %ab
+  ret i1 %cond
+}
+
+define i1 @leaf3_select_ret_and2(i1 %a, i1 %b, i1 %c) {
+; CHECK-LABEL: @leaf3_select_ret_and2(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i1 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i1 [[TMP1]]
+;
+  %ac = and i1 %a, %c
+  %si = select i1 %a, i1 %c, i1 %b
+  %xor2 = xor i1 %si, %b
+  %cond = xor i1 %xor2, %ac
+  ret i1 %cond
+}
+
+define i1 @leaf3_select_ret_leaf(i1 %a, i1 %b, i1 %c) {
+; CHECK-LABEL: @leaf3_select_ret_leaf(
+; CHECK-NEXT:    ret i1 [[B:%.*]]
+;
+  %ab = and i1 %a, %b
+  %ac = and i1 %a, %c
+  %si = select i1 %a, i1 %c, i1 %b
+  %xor2 = xor i1 %si, %ab
+  %cond = xor i1 %xor2, %ac
+  ret i1 %cond
+}
+
+; negative test, %c may be poison and the select hides it when %a is false
+
+define i1 @leaf3_select_undef_ret_and(i1 %a, i1 %b, i1 %c) {
+; CHECK-LABEL: @leaf3_select_undef_ret_and(
+; CHECK-NEXT:    [[AB:%.*]] = and i1 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[SI:%.*]] = select i1 [[A]], i1 [[C:%.*]], i1 [[B]]
+; CHECK-NEXT:    [[XOR2:%.*]] = xor i1 [[SI]], [[B]]
+; CHECK-NEXT:    [[COND:%.*]] = xor i1 [[XOR2]], [[AB]]
+; CHECK-NEXT:    ret i1 [[COND]]
+;
+  %ab = and i1 %a, %b
+  %si = select i1 %a, i1 %c, i1 %b
+  %xor2 = xor i1 %si, %b
+  %cond = xor i1 %xor2, %ab
+  ret i1 %cond
+}
+
 define i8 @leaf4_ret_const_true(i8 %a, i8 %b, i8 %c, i8 %d) {
 ; CHECK-LABEL: @leaf4_ret_const_true(
 ; CHECK-NEXT:    ret i8 -1
@@ -329,5 +382,82 @@
   ret i32 %and2
 }
 
+define i1 @leaf4_select_noundef_complex_ret_leaf(i1 noundef %a, i1 %b, i1 %c, i1 %d) {
+; CHECK-LABEL: @leaf4_select_noundef_complex_ret_leaf(
+; CHECK-NEXT:    ret i1 [[A:%.*]]
+;
+  %bd = select i1 %d, i1 %b, i1 false
+  %xor = xor i1 %bd, %c
+  %not.bd = xor i1 %xor, true
+  %xor.ab = xor i1 %a, %b
+  %or1 = or i1 %xor.ab, %c
+  %or2 = or i1 %or1, %not.bd
+  %or3 = or i1 %or2, %a
+  %and = and i1 %or3, %a
+  ret i1 %and
+}
+
+define i1 @leaf4_select_noundef_complex_ret_and(i1 %a, i1 %b, i1 %c, i1 noundef %d) {
+; CHECK-LABEL: @leaf4_select_noundef_complex_ret_and(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i1 [[B:%.*]], [[D:%.*]]
+; CHECK-NEXT:    ret i1 [[TMP1]]
+;
+  %ab = and i1 %a, %b
+  %bc = and i1 %b, %c
+  %xor.ac = xor i1 %a, %c
+  %or = or i1 %ab, %xor.ac
+  %not.bc = xor i1 %bc, true
+  %and = and i1 %not.bc, %a
+  %xor = xor i1 %and, %or
+  %si = select i1 %b, i1 %d, i1 %xor
+  %xor2 = xor i1 %si, %c
+  %cond = xor i1 %xor2, %bc
+  ret i1 %cond
+}
+
+define i1 @leaf4_select_poison_masked_complex_ret_leaf(i1 %a, i1 %b, i1 %c, i1 %d) {
+; CHECK-LABEL: @leaf4_select_poison_masked_complex_ret_leaf(
+; CHECK-NEXT:    ret i1 [[A:%.*]]
+;
+  %bd = select i1 %d, i1 %b, i1 false
+  %xor = xor i1 %bd, %c
+  %not.bd = xor i1 %xor, true
+  %xor.ab = xor i1 %a, %b
+  %or1 = or i1 %xor.ab, %c
+  %or2 = or i1 %or1, %not.bd
+  %or3 = or i1 %or2, %a
+  %and = and i1 %or3, %a
+  ret i1 %and
+}
+
+define i1 @leaf4_select_poison_masked_complex_ret_leaf2(i1 %a, i1 %b, i1 %c) {
+; CHECK-LABEL: @leaf4_select_poison_masked_complex_ret_leaf2(
+; CHECK-NEXT:    ret i1 [[A:%.*]]
+;
+  %not.a = xor i1 %a, true
+  %or1 = or i1 %not.a, %b
+  %or2 = or i1 %or1, %c
+  %xor = xor i1 %or2, %a
+  %or3 = or i1 %not.a, %c
+  %or4 = select i1 %or3, i1 true, i1 %b
+  %cond = xor i1 %xor, %or4
+  ret i1 %cond
+}
+
+; negative test, the select used as a condition hides poison from %b
+
+define i1 @leaf2_select_poison_masked_in_cond(i1 noundef %a, i1 %b) {
+; CHECK-LABEL: @leaf2_select_poison_masked_in_cond(
+; CHECK-NEXT:    [[AB:%.*]] = select i1 [[A:%.*]], i1 [[B:%.*]], i1 false
+; CHECK-NEXT:    [[SI:%.*]] = select i1 [[AB]], i1 [[A]], i1 false
+; CHECK-NEXT:    [[COND:%.*]] = and i1 [[SI]], [[A]]
+; CHECK-NEXT:    ret i1 [[COND]]
+;
+  %ab = select i1 %a, i1 %b, i1 false
+  %si = select i1 %ab, i1 %a, i1 false
+  %cond = and i1 %si, %a
+  ret i1 %cond
+}
+
 declare void @use8(i8)
 declare void @use32(i32)