Index: llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.h
===================================================================
--- llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.h
+++ llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.h
@@ -23,19 +23,26 @@
   LogicalOpsHelper *Helper;
   Value *Val;
   LogicalExpr Expr;
+  unsigned Weight;
+  unsigned OneUseWeight;
   // TODO: Add weight to measure cost for more than one use value
 
   void printValue(raw_ostream &OS, Value *Val) const;
   void printAndChain(raw_ostream &OS, uint64_t Mask) const;
 
 public:
-  LogicalOpNode(LogicalOpsHelper *OpsHelper, Value *SrcVal,
-                const LogicalExpr &SrcExpr)
-      : Helper(OpsHelper), Val(SrcVal), Expr(SrcExpr) {}
+  LogicalOpNode(LogicalOpsHelper *Helper, Value *Val,
+                const LogicalExpr &SrcExpr, unsigned Weight,
+                unsigned OneUseWeight)
+      : Helper(Helper), Val(Val), Expr(SrcExpr), Weight(Weight),
+        OneUseWeight(OneUseWeight) {}
   ~LogicalOpNode() {}
 
   Value *getValue() const { return Val; }
   const LogicalExpr &getExpr() const { return Expr; }
+  unsigned getWeight() const { return Weight; }
+  unsigned getOneUseWeight() const { return OneUseWeight; }
+
   void print(raw_ostream &OS) const;
 };
 
@@ -58,7 +65,9 @@
   LogicalOpNode *visitLeafNode(Value *Val, unsigned Depth);
   LogicalOpNode *visitBinOp(BinaryOperator *BO, unsigned Depth);
   LogicalOpNode *getLogicalOpNode(Value *Val, unsigned Depth = 0);
+
   Value *logicalOpToValue(LogicalOpNode *Node);
+  Value *buildAndChain(Instruction *I, uint64_t Mask);
 };
 
 inline raw_ostream &operator<<(raw_ostream &OS, const LogicalOpNode &I) {
Index: llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.cpp
===================================================================
--- llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.cpp
+++ llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.cpp
@@ -33,6 +33,7 @@
 #include "ComplexLogicalOpsCombine.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/IR/Constants.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 
@@ -98,6 +99,7 @@
 
   printAndChain(OS, *Expr.begin());
   OS << "\n";
+  OS << "Weight: " << Weight << "; OneUseWeight: " << OneUseWeight << ";\n\n";
 }
 
 void LogicalOpsHelper::clear() {
@@ -126,7 +128,8 @@
   if (ExprVal != LogicalExpr::ExprAllOne && ExprVal != LogicalExpr::ExprZero &&
       LeafSet.insert(Val).second)
     LeafValues.push_back(Val);
-  LogicalOpNode *Node = new LogicalOpNode(this, Val, LogicalExpr(ExprVal));
+  LogicalOpNode *Node =
+      new LogicalOpNode(this, Val, LogicalExpr(ExprVal), 0, 0);
   LogicalOpNodes[Val] = Node;
   return Node;
 }
@@ -146,13 +149,22 @@
   if (RHS == nullptr)
     return nullptr;
 
-  LogicalOpNode *Node;
+  // TODO: We can reduce the weight if the node can be simplified even if
+  // it is not the root node.
+  unsigned Weight = LHS->getWeight() + RHS->getWeight() + 1;
+  unsigned OneUseWeight = Weight;
+  if (BO->hasOneUse())
+    OneUseWeight = LHS->getOneUseWeight() + RHS->getOneUseWeight();
+
+  LogicalExpr NewExpr;
   if (BO->getOpcode() == Instruction::And)
-    Node = new LogicalOpNode(this, BO, LHS->getExpr() & RHS->getExpr());
+    NewExpr = LHS->getExpr() & RHS->getExpr();
   else if (BO->getOpcode() == Instruction::Or)
-    Node = new LogicalOpNode(this, BO, LHS->getExpr() | RHS->getExpr());
+    NewExpr = LHS->getExpr() | RHS->getExpr();
   else
-    Node = new LogicalOpNode(this, BO, LHS->getExpr() ^ RHS->getExpr());
+    NewExpr = LHS->getExpr() ^ RHS->getExpr();
+  LogicalOpNode *Node =
+      new LogicalOpNode(this, BO, NewExpr, Weight, OneUseWeight);
   LogicalOpNodes[BO] = Node;
   return Node;
 }
@@ -193,8 +205,16 @@
     if (ExprMask == LogicalExpr::ExprAllOne)
       return Constant::getAllOnesValue(Node->getValue()->getType());
 
-    if (llvm::popcount(ExprMask) == 1)
+    unsigned ElementCnt = llvm::popcount(ExprMask);
+    if (ElementCnt == 1)
       return LeafValues[llvm::Log2_64(ExprMask)];
+
+    unsigned InstCnt = ElementCnt - 1;
+    // TODO: For now we assume no existing instruction can be reused. Later we
+    //       could reuse multi-use nodes from the old expression tree.
+    if ((Node->getOneUseWeight() + InstCnt) < Node->getWeight())
+      return buildAndChain(cast<Instruction>(Node->getValue()), ExprMask);
+
     return nullptr;
   }
   // TODO: complex pattern simpilify
@@ -202,6 +222,20 @@
   return nullptr;
 }
 
+/// Build an and-chain of the leaf values selected by the set bits of \p Mask,
+/// lowest bit first, inserting any new instructions before \p I.
+Value *LogicalOpsHelper::buildAndChain(Instruction *I, uint64_t Mask) {
+  IRBuilder<> Builder(I);
+  Value *AndChain = nullptr;
+  // Mask &= Mask - 1 clears the lowest set bit using 64-bit arithmetic, so
+  // leaf indices of 31 and above are handled without an int shift overflow.
+  for (; Mask; Mask &= Mask - 1) {
+    Value *Leaf = LeafValues[llvm::countr_zero(Mask)];
+    AndChain = AndChain ? Builder.CreateAnd(AndChain, Leaf) : Leaf;
+  }
+  return AndChain;
+}
+
 Value *LogicalOpsHelper::simplify(Value *Root) {
   LogicalOpNode *RootNode = getLogicalOpNode(Root);
   Root = logicalOpToValue(RootNode);
Index: llvm/test/Transforms/AggressiveInstCombine/complex-logic-ops.ll
===================================================================
--- llvm/test/Transforms/AggressiveInstCombine/complex-logic-ops.ll
+++ llvm/test/Transforms/AggressiveInstCombine/complex-logic-ops.ll
@@ -105,4 +105,71 @@
   ret void
 }
 
+define void @test5(i1 %a, i1 %b, i1 %c, i1 %d) {
+; CHECK-LABEL: @test5(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i1 [[B:%.*]], [[D:%.*]]
+; CHECK-NEXT:    br i1 [[TMP1]], label [[IF_END:%.*]], label [[IF_THEN:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    call void @usev()
+; CHECK-NEXT:    br label [[IF_END]]
+; CHECK:       if.end:
+; CHECK-NEXT:    ret void
+;
+  %bd = and i1 %b, %d
+  %xor = xor i1 %bd, %c
+  %not.bd = xor i1 %xor, true
+  %xor.ab = xor i1 %a, %b
+  %or1 = or i1 %xor.ab, %c
+  %or2 = or i1 %or1, %not.bd
+  %or3 = or i1 %or2, %a
+  %and = and i1 %or3, %b
+  %and2 = and i1 %and, %d
+  br i1 %and2, label %if.end, label %if.then
+
+if.then:
+  call void @usev()
+  br label %if.end
+
+if.end:
+  ret void
+}
+
+define void @test6(i1 %a, i1 %b, i1 %c, i1 %d) {
+; CHECK-LABEL: @test6(
+; CHECK-NEXT:    [[BD:%.*]] = and i1 [[B:%.*]], [[D:%.*]]
+; CHECK-NEXT:    [[XOR:%.*]] = xor i1 [[BD]], [[C:%.*]]
+; CHECK-NEXT:    [[NOT_BD:%.*]] = xor i1 [[XOR]], true
+; CHECK-NEXT:    [[XOR_AB:%.*]] = xor i1 [[A:%.*]], [[B]]
+; CHECK-NEXT:    [[OR1:%.*]] = or i1 [[XOR_AB]], [[C]]
+; CHECK-NEXT:    [[OR2:%.*]] = or i1 [[OR1]], [[NOT_BD]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i1 [[B]], [[D]]
+; CHECK-NEXT:    call void @use1(i1 [[OR2]])
+; CHECK-NEXT:    br i1 [[TMP1]], label [[IF_END:%.*]], label [[IF_THEN:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    call void @usev()
+; CHECK-NEXT:    br label [[IF_END]]
+; CHECK:       if.end:
+; CHECK-NEXT:    ret void
+;
+  %bd = and i1 %b, %d
+  %xor = xor i1 %bd, %c
+  %not.bd = xor i1 %xor, true
+  %xor.ab = xor i1 %a, %b
+  %or1 = or i1 %xor.ab, %c
+  %or2 = or i1 %or1, %not.bd
+  %or3 = or i1 %or2, %a
+  %and = and i1 %or3, %b
+  %and2 = and i1 %and, %d
+  call void @use1(i1 %or2)
+  br i1 %and2, label %if.end, label %if.then
+
+if.then:
+  call void @usev()
+  br label %if.end
+
+if.end:
+  ret void
+}
+
 declare void @usev()
+declare void @use1(i1)