Index: llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.h =================================================================== --- llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.h +++ llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.h @@ -11,8 +11,8 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" namespace llvm { @@ -25,7 +25,7 @@ LogicalExpr Expr; unsigned Weight; unsigned OneUseWeight; - // TODO: Add weight to measure cost for more than one use value + uint64_t PoisonMaskSI; // Maybe-poison leaves from select true/false arms void printValue(raw_ostream &OS, Value *Val) const; void printAndChain(raw_ostream &OS, uint64_t Mask) const; @@ -33,15 +33,16 @@ public: LogicalOpNode(LogicalOpsHelper *Helper, Value *Val, const LogicalExpr &SrcExpr, unsigned Weight, - unsigned OneUseWeight) + unsigned OneUseWeight, uint64_t PoisonMaskSI) : Helper(Helper), Val(Val), Expr(SrcExpr), Weight(Weight), - OneUseWeight(OneUseWeight) {} + OneUseWeight(OneUseWeight), PoisonMaskSI(PoisonMaskSI) {} ~LogicalOpNode() {} Value *getValue() const { return Val; } const LogicalExpr &getExpr() const { return Expr; } unsigned getWeight() const { return Weight; } unsigned getOneUseWeight() const { return OneUseWeight; } + uint64_t getPoisonMaskSI() const { return PoisonMaskSI; } void print(raw_ostream &OS) const; }; @@ -59,11 +60,14 @@ SmallDenseMap LogicalOpNodes; SmallPtrSet LeafSet; SmallVector LeafValues; + bool NeedPoisonCheck = false; // Set once a select has been visited. + uint64_t PoisonLeafMask = 0; // Leaves that may be undef or poison. void clear(); LogicalOpNode *visitLeafNode(Value *Val, unsigned Depth); LogicalOpNode *visitBinOp(BinaryOperator *BO, unsigned Depth); + LogicalOpNode *visitSelect(SelectInst *SI, unsigned Depth); LogicalOpNode *getLogicalOpNode(Value *Val, unsigned Depth = 0); Value *logicalOpToValue(LogicalOpNode *Node); Index:
llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.cpp =================================================================== --- llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.cpp +++ llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.cpp @@ -32,6 +32,8 @@ #include "ComplexLogicalOpsCombine.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Constants.h" #include "llvm/IR/IRBuilder.h" #include "llvm/Support/CommandLine.h" @@ -99,12 +101,17 @@ printAndChain(OS, *Expr.begin()); OS << "\n"; - OS << "Weight: " << Weight << "; OneUseWeight: " << OneUseWeight << ";\n\n"; + OS << "Weight: " << Weight << "; OneUseWeight: " << OneUseWeight << ";\n"; + OS << "PoisonSrc: 0x" + << Twine::utohexstr(Expr.getLeafMask() & Helper->PoisonLeafMask) + << "; PoisonMaskSI: 0x" << Twine::utohexstr(PoisonMaskSI) << ";\n\n"; } void LogicalOpsHelper::clear() { for (auto node : LogicalOpNodes) delete node.second; + NeedPoisonCheck = false; + PoisonLeafMask = 0; LogicalOpNodes.clear(); LeafSet.clear(); LeafValues.clear(); @@ -129,8 +136,9 @@ LeafSet.insert(Val).second) LeafValues.push_back(Val); LogicalOpNode *Node = - new LogicalOpNode(this, Val, LogicalExpr(ExprVal), 0, 0); + new LogicalOpNode(this, Val, LogicalExpr(ExprVal), 0, 0, 0); LogicalOpNodes[Val] = Node; + PoisonLeafMask |= isGuaranteedNotToBeUndefOrPoison(Val) ? 0 : ExprVal; return Node; } @@ -149,13 +157,6 @@ if (RHS == nullptr) return nullptr; - // TODO: We can reduce the weight if th node can be simplified even if - // it is not the root node.
- unsigned Weight = LHS->getWeight() + RHS->getWeight() + 1; - unsigned OneUseWeight = Weight; - if (BO->hasOneUse()) - OneUseWeight = LHS->getOneUseWeight() + RHS->getOneUseWeight(); - LogicalExpr NewExpr; if (BO->getOpcode() == Instruction::And) NewExpr = LHS->getExpr() & RHS->getExpr(); @@ -163,12 +164,73 @@ NewExpr = LHS->getExpr() | RHS->getExpr(); else NewExpr = LHS->getExpr() ^ RHS->getExpr(); + + // TODO: We can reduce the weight if the node can be simplified even if + // it is not the root node. + unsigned Weight = LHS->getWeight() + RHS->getWeight() + 1; + unsigned OneUseWeight = Weight; + if (BO->hasOneUse()) + OneUseWeight = LHS->getOneUseWeight() + RHS->getOneUseWeight(); + + uint64_t PoisonMaskSI = 0; + if (NeedPoisonCheck) { + uint64_t LPI = LHS->getExpr().getLeafMask() ^ LHS->getPoisonMaskSI(); + uint64_t RPI = RHS->getExpr().getLeafMask() ^ RHS->getPoisonMaskSI(); + PoisonMaskSI = + ((~LPI) & RHS->getPoisonMaskSI()) | ((~RPI) & LHS->getPoisonMaskSI()); + PoisonMaskSI &= NewExpr.getLeafMask() & PoisonLeafMask; + } + LogicalOpNode *Node = - new LogicalOpNode(this, BO, NewExpr, Weight, OneUseWeight); + new LogicalOpNode(this, BO, NewExpr, Weight, OneUseWeight, PoisonMaskSI); LogicalOpNodes[BO] = Node; return Node; } +// Model "select C, T, F" as (C & T) ^ (~C & F). Select does not propagate +// poison from the arm it does not choose, so the possibly-poison leaves of T +// and F are recorded in PoisonMaskSI and checked in logicalOpToValue. +LogicalOpNode *LogicalOpsHelper::visitSelect(SelectInst *SI, unsigned Depth) { + // Bail out when a scalar i1 condition selects between <N x i1> values: + // mixing i1 and vector leaves could produce a result of the wrong type. + if (!SI->getType()->isIntOrIntVectorTy(1) || + SI->getCondition()->getType() != SI->getType()) + return nullptr; + + LogicalOpNode *Cond = getLogicalOpNode(SI->getCondition(), Depth + 1); + if (Cond == nullptr) + return nullptr; + + LogicalOpNode *TrueVal = getLogicalOpNode(SI->getTrueValue(), Depth + 1); + if (TrueVal == nullptr) + return nullptr; + + LogicalOpNode *FalseVal = getLogicalOpNode(SI->getFalseValue(), Depth + 1); + if (FalseVal == nullptr) + return nullptr; + + LogicalExpr NewExpr = (Cond->getExpr() & TrueVal->getExpr()) ^ + ((~Cond->getExpr()) & FalseVal->getExpr()); + // TODO: We can reduce the weight if the node can be simplified even if + // it is not the root node.
+ unsigned Weight = + Cond->getWeight() + TrueVal->getWeight() + FalseVal->getWeight() + 1; + unsigned OneUseWeight = Weight; + if (SI->hasOneUse()) + OneUseWeight = Cond->getOneUseWeight() + TrueVal->getOneUseWeight() + + FalseVal->getOneUseWeight(); + + uint64_t PoisonMaskSI = NewExpr.getLeafMask() & PoisonLeafMask; + PoisonMaskSI &= + TrueVal->getExpr().getLeafMask() | FalseVal->getExpr().getLeafMask(); + + LogicalOpNode *Node = + new LogicalOpNode(this, SI, NewExpr, Weight, OneUseWeight, PoisonMaskSI); + LogicalOpNodes[SI] = Node; + NeedPoisonCheck = true; + return Node; +} + LogicalOpNode *LogicalOpsHelper::getLogicalOpNode(Value *Val, unsigned Depth) { if (Depth == MaxDepthLogicOpsToScan) return nullptr; @@ -179,6 +241,7 @@ - // TODO: add select instruction support if (auto *BO = dyn_cast<BinaryOperator>(Val)) Node = visitBinOp(BO, Depth); + else if (auto *SI = dyn_cast<SelectInst>(Val)) + Node = visitSelect(SI, Depth); else Node = visitLeafNode(Val, Depth); @@ -205,6 +268,9 @@ if (ExprMask == LogicalExpr::ExprAllOne) return Constant::getAllOnesValue(Node->getValue()->getType()); + if (NeedPoisonCheck && Node->getPoisonMaskSI() != 0) + return nullptr; + unsigned ElementCnt = llvm::popcount(ExprMask); if (ElementCnt == 1) return LeafValues[llvm::Log2_64(ExprMask)]; @@ -237,6 +303,7 @@ } Value *LogicalOpsHelper::simplify(Value *Root) { + clear(); // Drop cached nodes and reset all poison-tracking state. LogicalOpNode *RootNode = getLogicalOpNode(Root); Root = logicalOpToValue(RootNode); if (Root != nullptr) Index: llvm/lib/Transforms/AggressiveInstCombine/LogicalExpr.h =================================================================== --- llvm/lib/Transforms/AggressiveInstCombine/LogicalExpr.h +++ llvm/lib/Transforms/AggressiveInstCombine/LogicalExpr.h @@ -53,6 +53,8 @@ } unsigned size() const { return AddChain.size(); } + uint64_t getLeafMask() const { return LeafMask; } + ExprAddChain::iterator begin() { return AddChain.begin(); } ExprAddChain::iterator end() { return AddChain.end(); } ExprAddChain::const_iterator begin() const
{ return AddChain.begin(); } Index: llvm/test/Transforms/AggressiveInstCombine/complex-logic-ops.ll =================================================================== --- llvm/test/Transforms/AggressiveInstCombine/complex-logic-ops.ll +++ llvm/test/Transforms/AggressiveInstCombine/complex-logic-ops.ll @@ -146,5 +146,85 @@ ret void } +define void @test6(i1 noundef %a, i1 %b, i1 %c, i1 %d) { +; CHECK-LABEL: @test6( +; CHECK-NEXT: br i1 [[A:%.*]], label [[IF_END:%.*]], label [[IF_THEN:%.*]] +; CHECK: if.then: +; CHECK-NEXT: call void @usev() +; CHECK-NEXT: br label [[IF_END]] +; CHECK: if.end: +; CHECK-NEXT: ret void +; + %bd = select i1 %d, i1 %b, i1 false + %xor = xor i1 %bd, %c + %not.bd = xor i1 %xor, true + %xor.ab = xor i1 %a, %b + %or1 = or i1 %xor.ab, %c + %or2 = or i1 %or1, %not.bd + %or3 = or i1 %or2, %a + %and = and i1 %or3, %a + br i1 %and, label %if.end, label %if.then + +if.then: + call void @usev() + br label %if.end + +if.end: + ret void +} + +define void @test7(i1 %a, i1 %b, i1 %c, i1 %d) { +; CHECK-LABEL: @test7( +; CHECK-NEXT: br i1 [[A:%.*]], label [[IF_END:%.*]], label [[IF_THEN:%.*]] +; CHECK: if.then: +; CHECK-NEXT: call void @usev() +; CHECK-NEXT: br label [[IF_END]] +; CHECK: if.end: +; CHECK-NEXT: ret void +; + %bd = select i1 %d, i1 %b, i1 false + %xor = xor i1 %bd, %c + %not.bd = xor i1 %xor, true + %xor.ab = xor i1 %a, %b + %or1 = or i1 %xor.ab, %c + %or2 = or i1 %or1, %not.bd + %or3 = or i1 %or2, %a + %and = and i1 %or3, %a + br i1 %and, label %if.end, label %if.then + +if.then: + call void @usev() + br label %if.end + +if.end: + ret void +} + +define void @test8(i1 %a, i1 %b, i1 %c) { +; CHECK-LABEL: @test8( +; CHECK-NEXT: br i1 [[A:%.*]], label [[IF_END:%.*]], label [[IF_THEN:%.*]] +; CHECK: if.then: +; CHECK-NEXT: call void @usev() +; CHECK-NEXT: br label [[IF_END]] +; CHECK: if.end: +; CHECK-NEXT: ret void +; + %not.a = xor i1 %a, true + %or1 = or i1 %not.a, %b + %or2 = or i1 %or1, %c + %xor = xor i1 %or2, %a + %or3 = or i1 %not.a,
%c + %or4 = select i1 %or3, i1 true, i1 %b + %cond = xor i1 %xor, %or4 + br i1 %cond, label %if.end, label %if.then + +if.then: + call void @usev() + br label %if.end + +if.end: + ret void +} + declare void @usev() declare void @use1(i1)