Index: llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
===================================================================
--- llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -46,6 +46,201 @@
     "aggressive-instcombine-max-scan-instrs", cl::init(64), cl::Hidden,
     cl::desc("Max number of instructions to scan for aggressive instcombine."));
 
+static cl::opt<unsigned> MaxLogicOpLeafsToScan(
+    "aggressive-instcombine-max-logic-op-leafs", cl::init(8), cl::Hidden,
+    cl::desc("Max leaves of logic ops to scan for aggressive instcombine."));
+
+static cl::opt<unsigned> MaxDepthLogicOpsToScan(
+    "aggressive-instcombine-max-depth-logic-ops", cl::init(8), cl::Hidden,
+    cl::desc("Max depth of logic ops to scan for aggressive instcombine."));
+
+/// This class helps to find the simplest expression for a complex chain of
+/// logic operations. All other ops are canonicalized to and/xor.
+/// For example:
+///   a | b      --> (a & b) ^ a ^ b
+///   c ? a : b  --> (c & a) ^ ((c ^ true) & b)
+/// The expression is represented as a set of unsigned masks. Every value that
+/// does not come from a logic operation is a leaf node, and each leaf node
+/// owns one bit of the mask. For example, with the sources a, b and c, the
+/// mask for a is 1, for b is 2 and for c is 4:
+///   a & b & c            --> {7}
+///   a & b ^ c & a        --> {3, 5}
+///   a & b ^ c & a ^ b    --> {3, 5, 2}
+/// Every unsigned mask is an and chain, and the set of masks is an xor chain,
+/// so any logic value can be represented by a set of unsigned masks.
+/// For example:
+///   r1 = (a | b) & c --> (a & b & c) ^ (a & c) ^ (b & c) --> {7, 5, 3}
+/// Finally, the simplest pattern is rebuilt from the expression. For now, the
+/// code is only simplified when the expression is a single leaf or a constant.
+class LogicalOpsHelper {
+private:
+  /// Mask of the constant all-ones term. Each lower bit names one leaf.
+  static constexpr unsigned ExprNegOne = 0x80000000;
+  static constexpr unsigned MaxLeafBits = 31; // Leaf bits sit below ExprNegOne.
+
+  /// An xor chain of and-masks. The empty set is the constant zero.
+  typedef SmallDenseSet<unsigned, 8> LogicalExpr;
+
+  /// Xor \p Term into \p Expr: a ^ a -> 0, so a repeated term cancels.
+  static void toggleTerm(LogicalExpr &Expr, unsigned Term) {
+    if (!Expr.insert(Term).second)
+      Expr.erase(Term);
+  }
+
+  /// Returns LHS & RHS by distributing the and over both xor chains.
+  static LogicalExpr exprAnd(const LogicalExpr &LHS, const LogicalExpr &RHS) {
+    LogicalExpr Ret;
+    for (unsigned LHSAndArray : LHS) {
+      for (unsigned RHSAndArray : RHS) {
+        unsigned NewAndArray = LHSAndArray | RHSAndArray;
+        // a & 1 -> a; 1 & 1 stays ExprNegOne so it can still cancel.
+        if (NewAndArray != ExprNegOne)
+          NewAndArray &= ~ExprNegOne;
+        toggleTerm(Ret, NewAndArray);
+      }
+    }
+    return Ret;
+  }
+
+  /// Returns LHS ^ RHS.
+  static LogicalExpr exprXor(const LogicalExpr &LHS, const LogicalExpr &RHS) {
+    LogicalExpr Ret = LHS;
+    for (unsigned RHSAndArray : RHS)
+      toggleTerm(Ret, RHSAndArray);
+    return Ret;
+  }
+
+  static LogicalExpr exprOr(const LogicalExpr &LHS, const LogicalExpr &RHS) {
+    // a | b --> (a & b) ^ a ^ b
+    return exprXor(exprXor(exprAnd(LHS, RHS), LHS), RHS);
+  }
+
+  /// The expression computed for one visited value.
+  class LogicalOpNode {
+    Value *Val;
+
+  public:
+    explicit LogicalOpNode(Value *OrigVal) : Val(OrigVal) {}
+
+    Value *getValue() const { return Val; }
+
+    LogicalExpr Expr;
+    // TODO: Add weight to measure cost for more than one use value
+  };
+
+public:
+  LogicalOpsHelper() {}
+  // The helper owns its nodes through raw pointers, so it must not be copied.
+  LogicalOpsHelper(const LogicalOpsHelper &) = delete;
+  LogicalOpsHelper &operator=(const LogicalOpsHelper &) = delete;
+  ~LogicalOpsHelper() { clear(); }
+
+  /// Returns a constant or leaf equivalent to \p Root, or nullptr if none.
+  Value *simplify(Value *Root) {
+    clear();
+    return logicalOpToValue(getLogicalOpNode(Root));
+  }
+
+private:
+  /// One owned node per successfully visited value.
+  SmallDenseMap<Value *, LogicalOpNode *, 16> LogicalOpNodes;
+  /// LeafValues[I] is the leaf whose mask is 1 << I.
+  SmallVector<Value *, 8> LeafValues;
+
+  void clear() {
+    for (auto &Entry : LogicalOpNodes)
+      delete Entry.second;
+    LogicalOpNodes.clear();
+    LeafValues.clear();
+  }
+
+  /// Gives \p Val its own expression: a constant form or the next leaf bit.
+  bool visitLeafNode(LogicalOpNode *Node, Value *Val, unsigned Depth) {
+    // A root that is itself a leaf has nothing to simplify.
+    if (Depth == 0)
+      return false;
+
+    auto *ConstVal = dyn_cast<ConstantInt>(Val);
+    if (ConstVal && ConstVal->isAllOnesValue()) {
+      Node->Expr.insert(ExprNegOne);
+    } else if (!ConstVal || !ConstVal->isZero()) {
+      unsigned LeafIdx = LeafValues.size();
+      if (LeafIdx >= MaxLogicOpLeafsToScan || LeafIdx >= MaxLeafBits)
+        return false;
+      Node->Expr.insert(1u << LeafIdx);
+      LeafValues.push_back(Val);
+    }
+    // A zero constant is the empty xor chain, so its Expr stays empty.
+    LogicalOpNodes[Val] = Node;
+    return true;
+  }
+
+  /// Builds the expression of \p BO; opcodes besides and/or/xor are leaves.
+  bool visitBinOp(LogicalOpNode *Node, BinaryOperator *BO, unsigned Depth) {
+    const auto Opcode = BO->getOpcode();
+    if (Opcode != Instruction::And && Opcode != Instruction::Or &&
+        Opcode != Instruction::Xor)
+      return visitLeafNode(Node, BO, Depth);
+
+    LogicalOpNode *LHS = getLogicalOpNode(BO->getOperand(0), Depth + 1);
+    if (LHS == nullptr)
+      return false;
+
+    LogicalOpNode *RHS = getLogicalOpNode(BO->getOperand(1), Depth + 1);
+    if (RHS == nullptr)
+      return false;
+
+    if (Opcode == Instruction::And)
+      Node->Expr = exprAnd(LHS->Expr, RHS->Expr);
+    else if (Opcode == Instruction::Or)
+      Node->Expr = exprOr(LHS->Expr, RHS->Expr);
+    else
+      Node->Expr = exprXor(LHS->Expr, RHS->Expr);
+    LogicalOpNodes[BO] = Node;
+    return true;
+  }
+
+  /// Returns the node for \p Val, or nullptr when it cannot be analyzed.
+  LogicalOpNode *getLogicalOpNode(Value *Val, unsigned Depth = 0) {
+    if (Depth == MaxDepthLogicOpsToScan)
+      return nullptr;
+
+    if (LogicalOpNode *Cached = LogicalOpNodes.lookup(Val))
+      return Cached;
+
+    LogicalOpNode *Node = new LogicalOpNode(Val);
+    // TODO: add select instruction support
+    auto *BO = dyn_cast<BinaryOperator>(Val);
+    if (BO ? visitBinOp(Node, BO, Depth) : visitLeafNode(Node, Val, Depth))
+      return Node;
+
+    // A failed visit never registers the node, so release it here.
+    delete Node;
+    return nullptr;
+  }
+
+  /// Returns the constant or leaf that \p Node reduces to, else nullptr.
+  Value *logicalOpToValue(LogicalOpNode *Node) {
+    if (Node == nullptr)
+      return nullptr;
+
+    Type *Ty = Node->getValue()->getType();
+    if (Node->Expr.empty())
+      return Constant::getNullValue(Ty);
+
+    if (Node->Expr.size() == 1) {
+      unsigned ExprMask = *Node->Expr.begin();
+      if (ExprMask == ExprNegOne)
+        return Constant::getAllOnesValue(Ty);
+      if (llvm::popcount(ExprMask) == 1)
+        return LeafValues[llvm::Log2_32(ExprMask)];
+    }
+
+    // TODO: complex pattern simplify
+    return nullptr;
+  }
+};
+
 /// Match a pattern for a bitwise funnel/rotate operation that partially guards
 /// against undefined behavior by branching around the funnel-shift/rotation
 /// when the shift amount is 0.
@@ -855,6 +1050,19 @@
       // bugs.
       MadeChange |= foldSqrt(I, TTI, TLI);
     }
+
+    // Simplify the chain of logic ops that feeds a conditional branch.
+    if (auto *BI = dyn_cast<BranchInst>(BB.getTerminator())) {
+      if (BI->isConditional()) {
+        LogicalOpsHelper Helper;
+        Value *Cond = BI->getCondition();
+        Value *NewCond = Helper.simplify(Cond);
+        if (NewCond) {
+          Cond->replaceAllUsesWith(NewCond);
+          MadeChange = true;
+        }
+      }
+    }
   }
 
   // We're done with transforms, so remove dead instructions.
Index: llvm/test/Transforms/AggressiveInstCombine/complex-logic-ops.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/AggressiveInstCombine/complex-logic-ops.ll
@@ -0,0 +1,78 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=aggressive-instcombine -S | FileCheck %s
+
+define void @test1(i1 %a, i1 %b, i1 %c, i1 %d) {
+; CHECK-LABEL: @test1(
+; CHECK-NEXT:    br i1 true, label [[IF_END:%.*]], label [[IF_THEN:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    call void @usev()
+; CHECK-NEXT:    br label [[IF_END]]
+; CHECK:       if.end:
+; CHECK-NEXT:    ret void
+;
+  %bd = and i1 %b, %d
+  %not.bd = xor i1 %bd, true
+  %xor.ab = xor i1 %a, %b
+  %or1 = or i1 %xor.ab, %c
+  %or2 = or i1 %or1, %not.bd
+  %or3 = or i1 %or2, %a
+  br i1 %or3, label %if.end, label %if.then
+
+if.then:
+  call void @usev()
+  br label %if.end
+
+if.end:
+  ret void
+}
+
+define void @test2(i1 %a, i1 %b, i1 %c, i1 %d) {
+; CHECK-LABEL: @test2(
+; CHECK-NEXT:    br i1 [[B:%.*]], label [[IF_END:%.*]], label [[IF_THEN:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    call void @usev()
+; CHECK-NEXT:    br label [[IF_END]]
+; CHECK:       if.end:
+; CHECK-NEXT:    ret void
+;
+  %bd = and i1 %b, %d
+  %xor = xor i1 %bd, %c
+  %not.bd = xor i1 %xor, true
+  %xor.ab = xor i1 %a, %b
+  %or1 = or i1 %xor.ab, %c
+  %or2 = or i1 %or1, %not.bd
+  %or3 = or i1 %or2, %a
+  %and = and i1 %or3, %b
+  br i1 %and, label %if.end, label %if.then
+
+if.then:
+  call void @usev()
+  br label %if.end
+
+if.end:
+  ret void
+}
+
+; The zero constant must fold away instead of being read as a leaf.
+define void @test3(i1 %a) {
+; CHECK-LABEL: @test3(
+; CHECK-NEXT:    br i1 false, label [[IF_END:%.*]], label [[IF_THEN:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    call void @usev()
+; CHECK-NEXT:    br label [[IF_END]]
+; CHECK:       if.end:
+; CHECK-NEXT:    ret void
+;
+  %aa = xor i1 %a, %a
+  %xor = xor i1 %aa, false
+  br i1 %xor, label %if.end, label %if.then
+
+if.then:
+  call void @usev()
+  br label %if.end
+
+if.end:
+  ret void
+}
+
+declare void @usev()