Index: llvm/include/llvm/Analysis/ComplexLogicCombine.h
===================================================================
--- llvm/include/llvm/Analysis/ComplexLogicCombine.h
+++ llvm/include/llvm/Analysis/ComplexLogicCombine.h
@@ -23,19 +23,33 @@
   LogicalOpsHelper *Helper;
   Value *Val;
   LogicalExpr Expr;
-  // TODO: Add weight to measure cost for more than one use value
+  /// Estimated number of instructions in the expression tree rooted at Val.
+  unsigned Weight;
+  /// Estimated number of those instructions that stay alive after the root
+  /// of the tree is replaced.
+  unsigned OneUseWeight;
 
   void printValue(raw_ostream &OS, Value *Val) const;
   void printAndChain(raw_ostream &OS, uint64_t Mask) const;
 
 public:
   LogicalOpNode(LogicalOpsHelper *OpsHelper, Value *SrcVal,
-                const LogicalExpr &SrcExpr)
-      : Helper(OpsHelper), Val(SrcVal), Expr(SrcExpr) {}
+                const LogicalExpr &SrcExpr, unsigned SrcWeight,
+                unsigned SrcOneUseWeight)
+      : Helper(OpsHelper), Val(SrcVal), Expr(SrcExpr), Weight(SrcWeight),
+        OneUseWeight(SrcOneUseWeight) {}
   ~LogicalOpNode() {}
 
   Value *getValue() const { return Val; }
   const LogicalExpr &getExpr() const { return Expr; }
+  unsigned getWeight() const { return Weight; }
+  unsigned getOneUseWeight() const { return OneUseWeight; }
+
+  /// Returns true if replacing this tree by \p InstCnt new instructions is
+  /// expected to reduce the number of instructions.
+  bool worthToCombine(unsigned InstCnt) const {
+    return (OneUseWeight + InstCnt) < Weight;
+  }
 
   void print(raw_ostream &OS) const;
 };
@@ -59,6 +73,9 @@
   LogicalOpNode *visitBinOp(BinaryOperator *BO, unsigned Depth);
   LogicalOpNode *getLogicalOpNode(Value *Val, unsigned Depth = 0);
   Value *logicalOpToValue(LogicalOpNode *Node);
+  /// Builds an and-chain of the leaf values whose bits are set in \p Mask,
+  /// inserting any new instructions before \p I. \p Mask must be non-zero.
+  Value *buildAndChain(Instruction *I, uint64_t Mask);
 };
 
 inline raw_ostream &operator<<(raw_ostream &OS, const LogicalOpNode &I) {
Index: llvm/lib/Analysis/ComplexLogicCombine.cpp
===================================================================
--- llvm/lib/Analysis/ComplexLogicCombine.cpp
+++ llvm/lib/Analysis/ComplexLogicCombine.cpp
@@ -32,6 +32,7 @@
 #include "llvm/Analysis/ComplexLogicCombine.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/IR/Constants.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 
@@ -96,7 +97,7 @@
   }
 
   printAndChain(OS, *Expr.begin());
-  OS << "\n";
+  OS << "\nWeight: " << Weight << "; OneUseWeight: " << OneUseWeight << ";\n\n";
 }
 
 void LogicalOpsHelper::clear() {
@@ -125,7 +126,9 @@
   if (ExprVal != LogicalExpr::ExprAllOne && ExprVal != LogicalExpr::ExprZero &&
       LeafSet.insert(Val).second)
     LeafValues.push_back(Val);
-  LogicalOpNode *Node = new LogicalOpNode(this, Val, LogicalExpr(ExprVal));
+  LogicalOpNode *Node =
+      new LogicalOpNode(this, Val, LogicalExpr(ExprVal), /*SrcWeight=*/0,
+                        /*SrcOneUseWeight=*/0);
   LogicalOpNodes[Val] = Node;
   return Node;
 }
@@ -145,13 +148,26 @@
   if (RHS == nullptr)
     return nullptr;
 
-  LogicalOpNode *Node;
+  // Weight estimates the number of instructions in the tree rooted at BO; a
+  // subtree reachable along several paths is counted once per path.
+  // OneUseWeight estimates how many of them stay alive once the root is
+  // replaced: if BO has more than one use, it is kept with its whole subtree.
+  // TODO: We can reduce the weight if the node can be simplified even if
+  // it is not the root node.
+  unsigned Weight = LHS->getWeight() + RHS->getWeight() + 1;
+  unsigned OneUseWeight = Weight;
+  if (BO->hasOneUse())
+    OneUseWeight = LHS->getOneUseWeight() + RHS->getOneUseWeight();
+
+  LogicalExpr NewExpr;
   if (BO->getOpcode() == Instruction::And)
-    Node = new LogicalOpNode(this, BO, LHS->getExpr() & RHS->getExpr());
+    NewExpr = LHS->getExpr() & RHS->getExpr();
   else if (BO->getOpcode() == Instruction::Or)
-    Node = new LogicalOpNode(this, BO, LHS->getExpr() | RHS->getExpr());
+    NewExpr = LHS->getExpr() | RHS->getExpr();
   else
-    Node = new LogicalOpNode(this, BO, LHS->getExpr() ^ RHS->getExpr());
+    NewExpr = LHS->getExpr() ^ RHS->getExpr();
+  LogicalOpNode *Node =
+      new LogicalOpNode(this, BO, NewExpr, Weight, OneUseWeight);
   LogicalOpNodes[BO] = Node;
   return Node;
 }
@@ -190,8 +206,17 @@
     if (ExprMask == LogicalExpr::ExprAllOne)
       return Constant::getAllOnesValue(Node->getValue()->getType());
 
-    if (llvm::popcount(ExprMask) == 1)
+    unsigned ElementCnt = llvm::popcount(ExprMask);
+    if (ElementCnt == 1)
       return LeafValues[llvm::Log2_64(ExprMask)];
+
+    // An and-chain over ElementCnt leaves needs ElementCnt - 1 instructions.
+    unsigned InstCnt = ElementCnt - 1;
+    // TODO: For now we assume no instruction of the old tree can be reused.
+    // Later we can search for a node with more than one use to reuse.
+    if (Node->worthToCombine(InstCnt))
+      return buildAndChain(cast<Instruction>(Node->getValue()), ExprMask);
+    return nullptr;
   }
 
   // TODO: complex pattern simpilify
@@ -199,6 +224,19 @@
   return nullptr;
 }
 
+Value *LogicalOpsHelper::buildAndChain(Instruction *I, uint64_t Mask) {
+  assert(Mask != 0 && "Expected at least one leaf in the mask");
+  IRBuilder<> Builder(I);
+  Value *AndChain = nullptr;
+  // Visit the set bits from the lowest to the highest one; Mask &= Mask - 1
+  // clears the lowest set bit.
+  for (; Mask != 0; Mask &= Mask - 1) {
+    Value *Leaf = LeafValues[llvm::countr_zero(Mask)];
+    AndChain = AndChain ? Builder.CreateAnd(AndChain, Leaf) : Leaf;
+  }
+  return AndChain;
+}
+
 Value *LogicalOpsHelper::simplify(Value *Root) {
   assert(MaxLogicOpLeafsToScan <= 62 &&
          "Logical leaf node can't larger than 62.");
Index: llvm/test/Transforms/AggressiveInstCombine/complex-logic.ll
===================================================================
--- llvm/test/Transforms/AggressiveInstCombine/complex-logic.ll
+++ llvm/test/Transforms/AggressiveInstCombine/complex-logic.ll
@@ -56,3 +56,65 @@
   %cond = xor i1 %and, %or
   ret i1 %cond
 }
+
+; Every instruction of the tree has one use: the whole tree folds to (b & d).
+define i1 @test5(i1 %a, i1 %b, i1 %c, i1 %d) {
+; CHECK-LABEL: @test5(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i1 [[B:%.*]], [[D:%.*]]
+; CHECK-NEXT:    ret i1 [[TMP1]]
+;
+  %bd = and i1 %b, %d
+  %xor = xor i1 %bd, %c
+  %not.bd = xor i1 %xor, true
+  %xor.ab = xor i1 %a, %b
+  %or1 = or i1 %xor.ab, %c
+  %or2 = or i1 %or1, %not.bd
+  %or3 = or i1 %or2, %a
+  %and = and i1 %or3, %b
+  %and2 = and i1 %and, %d
+  ret i1 %and2
+}
+
+; %or2 has an extra use and must be kept, but replacing the rest of the
+; tree with a single and is still profitable.
+define i1 @test6(i1 %a, i1 %b, i1 %c, i1 %d) {
+; CHECK-LABEL: @test6(
+; CHECK-NEXT:    [[BD:%.*]] = and i1 [[B:%.*]], [[D:%.*]]
+; CHECK-NEXT:    [[XOR:%.*]] = xor i1 [[BD]], [[C:%.*]]
+; CHECK-NEXT:    [[NOT_BD:%.*]] = xor i1 [[XOR]], true
+; CHECK-NEXT:    [[XOR_AB:%.*]] = xor i1 [[A:%.*]], [[B]]
+; CHECK-NEXT:    [[OR1:%.*]] = or i1 [[XOR_AB]], [[C]]
+; CHECK-NEXT:    [[OR2:%.*]] = or i1 [[OR1]], [[NOT_BD]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i1 [[B]], [[D]]
+; CHECK-NEXT:    call void @use1(i1 [[OR2]])
+; CHECK-NEXT:    ret i1 [[TMP1]]
+;
+  %bd = and i1 %b, %d
+  %xor = xor i1 %bd, %c
+  %not.bd = xor i1 %xor, true
+  %xor.ab = xor i1 %a, %b
+  %or1 = or i1 %xor.ab, %c
+  %or2 = or i1 %or1, %not.bd
+  %or3 = or i1 %or2, %a
+  %and = and i1 %or3, %b
+  %and2 = and i1 %and, %d
+  call void @use1(i1 %or2)
+  ret i1 %and2
+}
+
+; %ab has an extra use and must be kept; emitting two new instructions to
+; replace a single one is not profitable, so nothing changes.
+define i1 @test7(i1 %a, i1 %b, i1 %c) {
+; CHECK-LABEL: @test7(
+; CHECK-NEXT:    [[AB:%.*]] = and i1 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    call void @use1(i1 [[AB]])
+; CHECK-NEXT:    [[ABC:%.*]] = and i1 [[AB]], [[C:%.*]]
+; CHECK-NEXT:    ret i1 [[ABC]]
+;
+  %ab = and i1 %a, %b
+  call void @use1(i1 %ab)
+  %abc = and i1 %ab, %c
+  ret i1 %abc
+}
+
+declare void @use1(i1)