Index: llvm/include/llvm/Analysis/ComplexLogicCombine.h =================================================================== --- llvm/include/llvm/Analysis/ComplexLogicCombine.h +++ llvm/include/llvm/Analysis/ComplexLogicCombine.h @@ -11,8 +11,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Instruction.h" +#include "llvm/IR/IRBuilder.h" namespace llvm { @@ -23,19 +22,28 @@ LogicalOpsHelper *Helper; Value *Val; LogicalExpr Expr; - // TODO: Add weight to measure cost for more than one use value + unsigned Weight; + unsigned OneUseWeight; void printValue(raw_ostream &OS, Value *Val) const; void printAndChain(raw_ostream &OS, uint64_t Mask) const; public: - LogicalOpNode(LogicalOpsHelper *OpsHelper, Value *SrcVal, - const LogicalExpr &SrcExpr) - : Helper(OpsHelper), Val(SrcVal), Expr(SrcExpr) {} + LogicalOpNode(LogicalOpsHelper *Helper, Value *Val, + const LogicalExpr &SrcExpr, unsigned Weight, + unsigned OneUseWeight) + : Helper(Helper), Val(Val), Expr(SrcExpr), Weight(Weight), + OneUseWeight(OneUseWeight) {} ~LogicalOpNode() {} Value *getValue() const { return Val; } const LogicalExpr &getExpr() const { return Expr; } + unsigned getWeight() const { return Weight; } + unsigned getOneUseWeight() const { return OneUseWeight; } + + bool worthToCombine(unsigned InstCnt) const { + return (OneUseWeight + InstCnt) < Weight; + } void print(raw_ostream &OS) const; }; @@ -59,6 +67,7 @@ LogicalOpNode *visitBinOp(BinaryOperator *BO, unsigned Depth); LogicalOpNode *getLogicalOpNode(Value *Val, unsigned Depth = 0); Value *logicalOpToValue(LogicalOpNode *Node); + Value *buildMask(IRBuilder<> &Builder, Type *Ty, uint64_t Mask); }; inline raw_ostream &operator<<(raw_ostream &OS, const LogicalOpNode &I) { Index: llvm/lib/Analysis/ComplexLogicCombine.cpp =================================================================== --- llvm/lib/Analysis/ComplexLogicCombine.cpp +++ 
llvm/lib/Analysis/ComplexLogicCombine.cpp @@ -95,7 +95,7 @@ } printAndChain(OS, *Expr.begin()); - OS << "\n"; + OS << "\nWeight: " << Weight << "; OneUseWeight: " << OneUseWeight << ";\n\n"; } void LogicalOpsHelper::clear() { @@ -124,7 +124,8 @@ if (ExprVal != LogicalExpr::ExprAllOne && ExprVal != 0 && LeafSet.insert(Val).second) LeafValues.push_back(Val); - LogicalOpNode *Node = new LogicalOpNode(this, Val, LogicalExpr(ExprVal)); + LogicalOpNode *Node = + new LogicalOpNode(this, Val, LogicalExpr(ExprVal), 0, 0); LogicalOpNodes[Val] = Node; return Node; } @@ -144,13 +145,22 @@ if (RHS == nullptr) return nullptr; - LogicalOpNode *Node; + // TODO: We can reduce the weight if the node can be simplified even if + // it is not the root node. + unsigned Weight = LHS->getWeight() + RHS->getWeight() + 1; + unsigned OneUseWeight = Weight; + if (BO->hasOneUse()) + OneUseWeight = LHS->getOneUseWeight() + RHS->getOneUseWeight(); + + LogicalExpr NewExpr; if (BO->getOpcode() == Instruction::And) - Node = new LogicalOpNode(this, BO, LHS->getExpr() & RHS->getExpr()); + NewExpr = LHS->getExpr() & RHS->getExpr(); else if (BO->getOpcode() == Instruction::Or) - Node = new LogicalOpNode(this, BO, LHS->getExpr() | RHS->getExpr()); + NewExpr = LHS->getExpr() | RHS->getExpr(); else - Node = new LogicalOpNode(this, BO, LHS->getExpr() ^ RHS->getExpr()); + NewExpr = LHS->getExpr() ^ RHS->getExpr(); + LogicalOpNode *Node = + new LogicalOpNode(this, BO, NewExpr, Weight, OneUseWeight); LogicalOpNodes[BO] = Node; return Node; } @@ -181,16 +191,11 @@ if (Expr.size() == 0) return Constant::getNullValue(Node->getValue()->getType()); + Instruction *I = cast<Instruction>(Node->getValue()); + Type *Ty = I->getType(); if (Expr.size() == 1) { - uint64_t ExprMask = *Expr.begin(); - if (ExprMask == 0) - return Constant::getNullValue(Node->getValue()->getType()); - // ExprAllOne is not in the LeafValues - if (ExprMask == LogicalExpr::ExprAllOne) - return Constant::getAllOnesValue(Node->getValue()->getType()); - - if
(llvm::popcount(ExprMask) == 1) - return LeafValues[llvm::Log2_64(ExprMask)]; + IRBuilder<> Builder(I); + return buildMask(Builder, Ty, *Expr.begin()); } // TODO: complex pattern simpilify @@ -198,6 +203,29 @@ return nullptr; } +Value *LogicalOpsHelper::buildMask(IRBuilder<> &Builder, Type *Ty, + uint64_t Mask) { + // ExprZero/ExprAllOne is not in the LeafValues + if (Mask == 0) + return Constant::getNullValue(Ty); + if (Mask == LogicalExpr::ExprAllOne) + return Constant::getAllOnesValue(Ty); + + unsigned ElementCnt = llvm::popcount(Mask); + if (ElementCnt == 1) + return LeafValues[llvm::Log2_64(Mask)]; + + unsigned MaskIdx = llvm::countr_zero(Mask); + Value *AndChain = LeafValues[MaskIdx]; + Mask -= (1ULL << MaskIdx); + for (unsigned I = 1; I < ElementCnt; I++) { + MaskIdx = llvm::countr_zero(Mask); + AndChain = Builder.CreateAnd(AndChain, LeafValues[MaskIdx]); + Mask -= (1ULL << MaskIdx); + } + return AndChain; +} + Value *LogicalOpsHelper::simplify(Value *Root) { assert(MaxLogicOpLeafsToScan <= 63 && "Logical leaf node can't larger than 63."); Index: llvm/test/Transforms/AggressiveInstCombine/complex-logic.ll =================================================================== --- llvm/test/Transforms/AggressiveInstCombine/complex-logic.ll +++ llvm/test/Transforms/AggressiveInstCombine/complex-logic.ll @@ -96,6 +96,16 @@ ret i4 %or } +define i1 @leaf2_ret_and(i1 %a, i1 %b) { +; CHECK-LABEL: @leaf2_ret_and( +; CHECK-NEXT: [[TMP1:%.*]] = and i1 [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: ret i1 [[TMP1]] +; + %ab = and i1 %a, %b + %and.ab.a = and i1 %ab, %a + ret i1 %and.ab.a +} + define i1 @leaf3_complex_ret_const_false(i1 %a, i1 %b, i1 %c) { ; CHECK-LABEL: @leaf3_complex_ret_const_false( ; CHECK-NEXT: ret i1 false @@ -121,6 +131,36 @@ ret i1 %cond } +define i1 @leaf3_ret_and_chain(i1 %a, i1 %b, i1 %c) { +; CHECK-LABEL: @leaf3_ret_and_chain( +; CHECK-NEXT: [[TMP1:%.*]] = and i1 [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = and i1 [[TMP1]], [[C:%.*]] +; CHECK-NEXT: ret 
i1 [[TMP2]] +; + %ab = and i1 %a, %b + %abc = and i1 %ab, %c + %aabc = and i1 %abc, %a + %aabcc = and i1 %aabc, %c + ret i1 %aabcc +} + +; negative test, the extra use cost is equal to what it can save + +define i1 @leaf3_ret_and_chain_extra_use(i1 %a, i1 %b, i1 %c) { +; CHECK-LABEL: @leaf3_ret_and_chain_extra_use( +; CHECK-NEXT: [[TMP1:%.*]] = and i1 [[B:%.*]], [[C:%.*]] +; CHECK-NEXT: call void @use1(i1 [[TMP1]]) +; CHECK-NEXT: [[TMP2:%.*]] = and i1 [[A:%.*]], [[B]] +; CHECK-NEXT: [[TMP3:%.*]] = and i1 [[TMP2]], [[C]] +; CHECK-NEXT: ret i1 [[TMP3]] +; + %ab = and i1 %a, %b + %bc = and i1 %b, %c + call void @use1(i1 %bc) + %abbc = and i1 %ab, %bc + ret i1 %abbc +} + define i1 @leaf4_ret_const_true(i1 %a, i1 %b, i1 %c, i1 %d) { ; CHECK-LABEL: @leaf4_ret_const_true( ; CHECK-NEXT: ret i1 true @@ -163,3 +203,47 @@ %and = and i1 %or3, %b ret i1 %and } + +define i1 @leaf4_complex_ret_and_chain(i1 %a, i1 %b, i1 %c, i1 %d) { +; CHECK-LABEL: @leaf4_complex_ret_and_chain( +; CHECK-NEXT: [[TMP1:%.*]] = and i1 [[B:%.*]], [[D:%.*]] +; CHECK-NEXT: ret i1 [[TMP1]] +; + %bd = and i1 %b, %d + %xor = xor i1 %bd, %c + %not.bd = xor i1 %xor, true + %xor.ab = xor i1 %a, %b + %or1 = or i1 %xor.ab, %c + %or2 = or i1 %or1, %not.bd + %or3 = or i1 %or2, %a + %and = and i1 %or3, %b + %and2 = and i1 %and, %d + ret i1 %and2 +} + +define i1 @leaf4_complex_ret_and_chain_extra_use(i1 %a, i1 %b, i1 %c, i1 %d) { +; CHECK-LABEL: @leaf4_complex_ret_and_chain_extra_use( +; CHECK-NEXT: [[TMP1:%.*]] = and i1 [[B:%.*]], [[D:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = xor i1 [[TMP1]], [[C:%.*]] +; CHECK-NEXT: [[NOT_BD:%.*]] = xor i1 [[XOR]], true +; CHECK-NEXT: [[XOR_AB:%.*]] = xor i1 [[A:%.*]], [[B]] +; CHECK-NEXT: [[OR1:%.*]] = or i1 [[XOR_AB]], [[C]] +; CHECK-NEXT: [[OR2:%.*]] = or i1 [[OR1]], [[NOT_BD]] +; CHECK-NEXT: [[TMP2:%.*]] = and i1 [[B]], [[D]] +; CHECK-NEXT: call void @use1(i1 [[OR2]]) +; CHECK-NEXT: ret i1 [[TMP2]] +; + %bd = and i1 %b, %d + %xor = xor i1 %bd, %c + %not.bd = xor i1 %xor, true + %xor.ab =
xor i1 %a, %b + %or1 = or i1 %xor.ab, %c + %or2 = or i1 %or1, %not.bd + %or3 = or i1 %or2, %a + %and = and i1 %or3, %b + %and2 = and i1 %and, %d + call void @use1(i1 %or2) + ret i1 %and2 +} + +declare void @use1(i1)