Index: llvm/include/llvm/Analysis/LogicCombine.h
===================================================================
--- llvm/include/llvm/Analysis/LogicCombine.h
+++ llvm/include/llvm/Analysis/LogicCombine.h
@@ -13,8 +13,7 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/Statistic.h"
-#include "llvm/IR/InstrTypes.h"
-#include "llvm/IR/Instruction.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/Support/Allocator.h"
 
 namespace llvm {
@@ -26,18 +25,29 @@
   LogicCombiner *Helper;
   Value *Val;
   LogicalExpr Expr;
-  // TODO: Add weight to measure cost for more than one use value
+  unsigned Weight;               // Logic ops in this tree; leaves are 0.
+  unsigned OneUseWeight;         // Part of Weight kept alive by other users.
+  unsigned ProfitRebuildInstCnt; // Insts that die if this node is replaced.
 
   void printAndChain(raw_ostream &OS, uint64_t LeafBits) const;
 
 public:
-  LogicalOpNode(LogicCombiner *OpsHelper, Value *SrcVal,
-                const LogicalExpr &SrcExpr)
-      : Helper(OpsHelper), Val(SrcVal), Expr(SrcExpr) {}
+  LogicalOpNode(LogicCombiner *Helper, Value *Val, const LogicalExpr &SrcExpr,
+                unsigned Weight, unsigned OneUseWeight,
+                unsigned ProfitRebuildInstCnt)
+      : Helper(Helper), Val(Val), Expr(SrcExpr), Weight(Weight),
+        OneUseWeight(OneUseWeight), ProfitRebuildInstCnt(ProfitRebuildInstCnt) {
+  }
   ~LogicalOpNode() {}
 
   Value *getValue() const { return Val; }
   const LogicalExpr &getExpr() const { return Expr; }
+  unsigned getWeight() const { return Weight; }
+  unsigned getOneUseWeight() const { return OneUseWeight; }
+
+  bool worthToCombine(unsigned InstCnt) const {
+    return InstCnt < ProfitRebuildInstCnt;
+  }
 
   void print(raw_ostream &OS) const;
 };
@@ -61,6 +71,7 @@
   LogicalOpNode *visitBinOp(BinaryOperator *BO, unsigned Depth);
   LogicalOpNode *getLogicalOpNode(Value *Val, unsigned Depth = 0);
   Value *logicalOpToValue(LogicalOpNode *Node);
+  Value *buildAndChain(IRBuilder<> &Builder, Type *Ty, uint64_t LeafBits);
 };
 
 inline raw_ostream &operator<<(raw_ostream &OS, const LogicalOpNode &I) {
Index: llvm/lib/Analysis/LogicCombine.cpp
===================================================================
--- llvm/lib/Analysis/LogicCombine.cpp
+++ llvm/lib/Analysis/LogicCombine.cpp
@@ -89,7 +89,8 @@
     printAndChain(OS, *I);
   }
 
-  OS << "\n";
+  OS << "\nWeight: " << Weight << "; OneUseWeight: " << OneUseWeight
+     << "; ProfitRebuildInstCnt: " << ProfitRebuildInstCnt << ";\n\n";
 }
 
 void LogicCombiner::clear() {
@@ -114,8 +115,8 @@
   }
   if (ExprVal != LogicalExpr::ExprAllOne && ExprVal != 0)
     LeafValues.insert(Val);
-  LogicalOpNode *Node =
-      new (Alloc.Allocate()) LogicalOpNode(this, Val, LogicalExpr(ExprVal));
+  LogicalOpNode *Node = new (Alloc.Allocate())
+      LogicalOpNode(this, Val, LogicalExpr(ExprVal), 0, 0, 0);
   LogicalOpNodes[Val] = Node;
   return Node;
 }
@@ -132,16 +133,25 @@
   if (RHS == nullptr)
     return nullptr;
 
-  LogicalOpNode *Node;
+  // Weight counts the and/or/xor ops in this node's tree. OneUseWeight is the
+  // part of it kept alive by values that have other users, so replacing this
+  // node should make Weight - OneUseWeight instructions dead.
+  unsigned Weight = LHS->getWeight() + RHS->getWeight() + 1;
+  unsigned OneUseWeight = LHS->getOneUseWeight() + RHS->getOneUseWeight();
+  unsigned ProfitRebuildInstCnt = Weight - OneUseWeight;
+  // This node's own extra users keep its whole tree alive for its parents.
+  if (!BO->hasOneUse())
+    OneUseWeight = Weight;
+
+  LogicalExpr NewExpr;
   if (BO->getOpcode() == Instruction::And)
-    Node = new (Alloc.Allocate())
-        LogicalOpNode(this, BO, LHS->getExpr() & RHS->getExpr());
+    NewExpr = LHS->getExpr() & RHS->getExpr();
   else if (BO->getOpcode() == Instruction::Or)
-    Node = new (Alloc.Allocate())
-        LogicalOpNode(this, BO, LHS->getExpr() | RHS->getExpr());
+    NewExpr = LHS->getExpr() | RHS->getExpr();
   else
-    Node = new (Alloc.Allocate())
-        LogicalOpNode(this, BO, LHS->getExpr() ^ RHS->getExpr());
+    NewExpr = LHS->getExpr() ^ RHS->getExpr();
+  LogicalOpNode *Node = new (Alloc.Allocate()) LogicalOpNode(
+      this, BO, NewExpr, Weight, OneUseWeight, ProfitRebuildInstCnt);
   LogicalOpNodes[BO] = Node;
   return Node;
 }
@@ -172,16 +182,20 @@
   if (Expr.size() == 0)
     return Constant::getNullValue(Node->getValue()->getType());
 
+  Instruction *I = cast<Instruction>(Node->getValue());
+  Type *Ty = I->getType();
   if (Expr.size() == 1) {
     uint64_t LeafBits = *Expr.begin();
-    if (LeafBits == 0)
-      return Constant::getNullValue(Node->getValue()->getType());
-    // ExprAllOne is not in the LeafValues
-    if (LeafBits == LogicalExpr::ExprAllOne)
-      return Constant::getAllOnesValue(Node->getValue()->getType());
-
-    if (popcount(LeafBits) == 1)
-      return LeafValues[Log2_64(LeafBits)];
+    // Rebuilding N leaves needs N - 1 'and' instructions; constants need none.
+    unsigned InstCnt = 0;
+    if (LeafBits != 0 && LeafBits != LogicalExpr::ExprAllOne)
+      InstCnt = popcount(LeafBits) - 1;
+    // TODO: For now we assume no node of the old expression tree can be
+    // reused. Later we could also reuse nodes that have more than one use.
+    if (Node->worthToCombine(InstCnt)) {
+      IRBuilder<> Builder(I);
+      return buildAndChain(Builder, Ty, LeafBits);
+    }
   }
 
   // TODO: find the simplest form from logical expression when it is not
@@ -190,6 +204,34 @@
   return nullptr;
 }
 
+/// Materialize the and-chain \p LeafBits as IR of type \p Ty, where bit I of
+/// \p LeafBits selects LeafValues[I]. New 'and' instructions are created at
+/// \p Builder's insertion point. A zero chain yields the null constant and
+/// ExprAllOne yields the all-ones constant.
+Value *LogicCombiner::buildAndChain(IRBuilder<> &Builder, Type *Ty,
+                                    uint64_t LeafBits) {
+  if (LeafBits == 0)
+    return Constant::getNullValue(Ty);
+
+  // ExprAllOne is not in the LeafValues
+  if (LeafBits == LogicalExpr::ExprAllOne)
+    return Constant::getAllOnesValue(Ty);
+
+  unsigned LeafCnt = popcount(LeafBits);
+  if (LeafCnt == 1)
+    return LeafValues[Log2_64(LeafBits)];
+
+  unsigned LeafIdx = countr_zero(LeafBits);
+  Value *AndChain = LeafValues[LeafIdx];
+  LeafBits -= (1ULL << LeafIdx);
+  for (unsigned I = 1; I < LeafCnt; I++) {
+    LeafIdx = countr_zero(LeafBits);
+    AndChain = Builder.CreateAnd(AndChain, LeafValues[LeafIdx]);
+    LeafBits -= (1ULL << LeafIdx);
+  }
+  return AndChain;
+}
+
 Value *LogicCombiner::simplify(Value *Root) {
   assert(MaxLogicOpLeafsToScan <= 63 &&
          "Logical leaf node can't be larger than 63.");
Index: llvm/test/Transforms/AggressiveInstCombine/logic-combine.ll
===================================================================
--- llvm/test/Transforms/AggressiveInstCombine/logic-combine.ll
+++ llvm/test/Transforms/AggressiveInstCombine/logic-combine.ll
@@ -96,6 +96,16 @@
   ret i1 %or
 }
 
+define i8 @leaf2_ret_and(i8 %a, i8 %b) {
+; CHECK-LABEL: @leaf2_ret_and(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[TMP1]]
+;
+  %ab = and i8 %a, %b
+  %and.ab.a = and i8 %ab, %a
+  ret i8 %and.ab.a
+}
+
 define i8 @leaf3_complex_ret_const_false(i8 %a, i8 %b, i8 %c) {
 ; CHECK-LABEL: @leaf3_complex_ret_const_false(
 ; CHECK-NEXT:    ret i8 0
@@ -121,6 +131,36 @@
   ret i8 %cond
 }
 
+define i32 @leaf3_ret_and_chain(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: @leaf3_ret_and_chain(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], [[C:%.*]]
+; CHECK-NEXT:    ret i32 [[TMP2]]
+;
+  %ab = and i32 %a, %b
+  %abc = and i32 %ab, %c
+  %aabc = and i32 %abc, %a
+  %aabcc = and i32 %aabc, %c
+  ret i32 %aabcc
+}
+
+; negative test, with the extra use the rebuild cost equals what it saves
+
+define i32 @leaf3_ret_and_chain_extra_use(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: @leaf3_ret_and_chain_extra_use(
+; CHECK-NEXT:    [[AB:%.*]] = and i32 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[BC:%.*]] = and i32 [[B]], [[C:%.*]]
+; CHECK-NEXT:    call void @use32(i32 [[BC]])
+; CHECK-NEXT:    [[ABBC:%.*]] = and i32 [[AB]], [[BC]]
+; CHECK-NEXT:    ret i32 [[ABBC]]
+;
+  %ab = and i32 %a, %b
+  %bc = and i32 %b, %c
+  call void @use32(i32 %bc)
+  %abbc = and i32 %ab, %bc
+  ret i32 %abbc
+}
+
 define i8 @leaf4_ret_const_true(i8 %a, i8 %b, i8 %c, i8 %d) {
 ; CHECK-LABEL: @leaf4_ret_const_true(
 ; CHECK-NEXT:    ret i8 -1
@@ -247,4 +287,47 @@
   ret i8 %r
 }
 
+define i32 @leaf4_complex_ret_and_chain(i32 %a, i32 %b, i32 %c, i32 %d) {
+; CHECK-LABEL: @leaf4_complex_ret_and_chain(
+; CHECK-NEXT:    [[AND2:%.*]] = and i32 [[B:%.*]], [[D:%.*]]
+; CHECK-NEXT:    ret i32 [[AND2]]
+;
+  %bd = and i32 %b, %d
+  %xor = xor i32 %bd, %c
+  %not.bd = xor i32 %xor, -1
+  %xor.ab = xor i32 %a, %b
+  %or1 = or i32 %xor.ab, %c
+  %or2 = or i32 %or1, %not.bd
+  %or3 = or i32 %or2, %a
+  %and = and i32 %or3, %b
+  %and2 = and i32 %and, %d
+  ret i32 %and2
+}
+
+define i32 @leaf4_complex_ret_and_chain_extra_use(i32 %a, i32 %b, i32 %c, i32 %d) {
+; CHECK-LABEL: @leaf4_complex_ret_and_chain_extra_use(
+; CHECK-NEXT:    [[BD:%.*]] = and i32 [[B:%.*]], [[D:%.*]]
+; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[BD]], [[C:%.*]]
+; CHECK-NEXT:    [[NOT_BD:%.*]] = xor i32 [[XOR]], -1
+; CHECK-NEXT:    [[XOR_AB:%.*]] = xor i32 [[A:%.*]], [[B]]
+; CHECK-NEXT:    [[OR1:%.*]] = or i32 [[XOR_AB]], [[C]]
+; CHECK-NEXT:    [[OR2:%.*]] = or i32 [[OR1]], [[NOT_BD]]
+; CHECK-NEXT:    [[AND2:%.*]] = and i32 [[B]], [[D]]
+; CHECK-NEXT:    call void @use32(i32 [[OR2]])
+; CHECK-NEXT:    ret i32 [[AND2]]
+;
+  %bd = and i32 %b, %d
+  %xor = xor i32 %bd, %c
+  %not.bd = xor i32 %xor, -1
+  %xor.ab = xor i32 %a, %b
+  %or1 = or i32 %xor.ab, %c
+  %or2 = or i32 %or1, %not.bd
+  %or3 = or i32 %or2, %a
+  %and = and i32 %or3, %b
+  %and2 = and i32 %and, %d
+  call void @use32(i32 %or2)
+  ret i32 %and2
+}
+
 declare void @use8(i8)
+declare void @use32(i32)