Index: llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
===================================================================
--- llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -46,6 +46,199 @@
     "aggressive-instcombine-max-scan-instrs", cl::init(64), cl::Hidden,
     cl::desc("Max number of instructions to scan for aggressive instcombine."));
 
+static cl::opt<unsigned> MaxLogicOpLeafsToScan(
+    "aggressive-instcombine-max-logic-op-leafs", cl::init(8), cl::Hidden,
+    cl::desc("Max number of logic op leafs to scan for aggressive "
+             "instcombine."));
+
+static cl::opt<unsigned> MaxDepthLogicOpsToScan(
+    "aggressive-instcombine-max-depth-logic-ops", cl::init(8), cl::Hidden,
+    cl::desc("Max depth of logic ops to scan for aggressive instcombine."));
+
+/// And-chain mask of the constant true. Leaf masks use the bits below it.
+static const unsigned ExprNegOne = 0x80000000;
+/// Number of mask bits usable for leaves; the top bit is ExprNegOne.
+static const unsigned MaxLeafBits = 31;
+
+/// An xor of and-chains, one mask per chain. The empty set is constant false.
+typedef SmallDenseSet<unsigned, 8> LogicExpression;
+
+/// Returns the normal form of LHS & RHS by distributing the and over each xor.
+static LogicExpression exprAnd(const LogicExpression &LHS,
+                               const LogicExpression &RHS) {
+  LogicExpression Ret;
+  for (unsigned LHSAndArray : LHS) {
+    for (unsigned RHSAndArray : RHS) {
+      unsigned NewAndArray = LHSAndArray | RHSAndArray;
+      // a & 1 -> a. Keep the bit for 1 & 1 -> 1: clearing it would leave an
+      // empty mask that no longer cancels against ExprNegOne.
+      if (NewAndArray != ExprNegOne)
+        NewAndArray &= ~ExprNegOne;
+      // a ^ a -> 0
+      if (!Ret.insert(NewAndArray).second)
+        Ret.erase(NewAndArray);
+    }
+  }
+  return Ret;
+}
+
+/// Returns the normal form of LHS ^ RHS: the symmetric difference of the sets.
+static LogicExpression exprXor(LogicExpression LHS,
+                               const LogicExpression &RHS) {
+  for (unsigned RHSAndArray : RHS) {
+    // a ^ a -> 0
+    if (!LHS.insert(RHSAndArray).second)
+      LHS.erase(RHSAndArray);
+  }
+  return LHS;
+}
+
+/// Returns the normal form of LHS | RHS.
+static LogicExpression exprOr(LogicExpression LHS, const LogicExpression &RHS) {
+  // a | b --> (a & b) ^ a ^ b
+  return exprXor(exprXor(exprAnd(LHS, RHS), LHS), RHS);
+}
+
+struct LogicOpNode {
+  Value *Val;
+  LogicExpression Expression;
+  // TODO: Add weight to measure cost for more than one use value
+};
+
+/// This class helps to find the simplest expression for a complex logic
+/// operation chain. We canonicalize all other ops to and/xor.
+/// For example:
+///    a | b --> (a & b) ^ a ^ b
+///    c ? a : b --> (c & a) ^ ((c ^ true) & b)
+/// We use an unsigned set to represent the expression. Every value that does
+/// not come from a logic operation is a leaf node. A leaf node is 1 bit in
+/// the unsigned value. For example, we have sources a, b, c. The mask for a is
+/// 1, b is 2, c is 4.
+///    a & b & c --> {7}
+///    a & b ^ c & a --> {3, 5}
+///    a & b ^ c & a ^ b --> {3, 5, 2}
+/// Every unsigned value is an and chain. The unsigned set is an xor chain.
+/// After that, any logic value can be represented by an unsigned set.
+/// For example:
+///    r1 = (a | b) & c -> r1 = (a & b & c) ^ (a & c) ^ (b & c) -> {7, 5, 6}
+/// Finally we need to rebuild the simplest pattern from the expression. For
+/// now, we only simplify the code when the expression is a leaf or a constant.
+class LogicOpsHelper {
+public:
+  LogicOpsHelper() {}
+  // The helper owns its nodes through raw pointers, so it must not be copied.
+  LogicOpsHelper(const LogicOpsHelper &) = delete;
+  LogicOpsHelper &operator=(const LogicOpsHelper &) = delete;
+  ~LogicOpsHelper() { clear(); }
+
+  /// Returns a value equivalent to \p Root that is either a constant or one of
+  /// the leaves of its and/or/xor tree, or nullptr when no such value exists
+  /// or a scan limit was hit.
+  Value *simplify(Value *Root) {
+    clear();
+    LogicOpNode *RootNode = getLogicOpNode(Root);
+    return logicOpToValue(RootNode);
+  }
+
+private:
+  // Memoized nodes, owned by this helper and freed in clear().
+  SmallDenseMap<Value *, LogicOpNode *, 16> LogicOpNodes;
+  // Leaf values; the leaf at index I has the mask 1 << I.
+  SmallVector<Value *, 8> LeafValues;
+
+  void clear() {
+    for (auto &Entry : LogicOpNodes)
+      delete Entry.second;
+    LogicOpNodes.clear();
+    LeafValues.clear();
+  }
+
+  /// Returns the memoized node for \p Val, building it (and the nodes of its
+  /// operands) on first use. Returns nullptr when the depth or leaf limit is
+  /// hit, or when the root itself is not an and/or/xor.
+  LogicOpNode *getLogicOpNode(Value *Val, unsigned Depth = 0) {
+    if (Depth == MaxDepthLogicOpsToScan)
+      return nullptr;
+
+    auto It = LogicOpNodes.find(Val);
+    if (It != LogicOpNodes.end())
+      return It->second;
+
+    // The node is only allocated once every bail-out below has been passed, so
+    // a failed scan does not leak it.
+    LogicExpression Expression;
+
+    // TODO: add select instruction support
+    auto *BO = dyn_cast<BinaryOperator>(Val);
+    if (BO && (BO->getOpcode() == Instruction::And ||
+               BO->getOpcode() == Instruction::Or ||
+               BO->getOpcode() == Instruction::Xor)) {
+      LogicOpNode *LHS = getLogicOpNode(BO->getOperand(0), Depth + 1);
+      if (LHS == nullptr)
+        return nullptr;
+
+      LogicOpNode *RHS = getLogicOpNode(BO->getOperand(1), Depth + 1);
+      if (RHS == nullptr)
+        return nullptr;
+
+      if (BO->getOpcode() == Instruction::And)
+        Expression = exprAnd(LHS->Expression, RHS->Expression);
+      else if (BO->getOpcode() == Instruction::Or)
+        Expression = exprOr(LHS->Expression, RHS->Expression);
+      else
+        Expression = exprXor(LHS->Expression, RHS->Expression);
+    } else {
+      // Every other value, including a non-logic binary operator, is an
+      // opaque leaf. A leaf root leaves nothing to simplify.
+      if (Depth == 0)
+        return nullptr;
+
+      auto *ConstVal = dyn_cast<ConstantInt>(Val);
+      if (ConstVal && ConstVal->isAllOnesValue()) {
+        Expression.insert(ExprNegOne);
+      } else if (!ConstVal || !ConstVal->isZero()) {
+        if (LeafValues.size() >= MaxLogicOpLeafsToScan ||
+            LeafValues.size() >= MaxLeafBits)
+          return nullptr;
+        Expression.insert(1u << LeafValues.size());
+        LeafValues.push_back(Val);
+      }
+      // A zero constant keeps the empty expression, i.e. constant false.
+    }
+
+    LogicOpNode *Node = new LogicOpNode();
+    Node->Val = Val;
+    Node->Expression = Expression;
+    LogicOpNodes[Val] = Node;
+    return Node;
+  }
+
+  /// Rebuilds a value from the expression of \p Node. Only handles
+  /// expressions that are empty, constant true, or a single leaf.
+  Value *logicOpToValue(LogicOpNode *Node) {
+    if (Node == nullptr)
+      return nullptr;
+
+    if (Node->Expression.empty())
+      return Constant::getNullValue(Node->Val->getType());
+
+    if (Node->Expression.size() == 1) {
+      unsigned Expr = *Node->Expression.begin();
+      if (Expr == ExprNegOne)
+        return Constant::getAllOnesValue(Node->Val->getType());
+
+      // A single set bit is one leaf; the bound check keeps a stray mask from
+      // indexing past the recorded leaves.
+      if (llvm::popcount(Expr) == 1 && llvm::Log2_32(Expr) < LeafValues.size())
+        return LeafValues[llvm::Log2_32(Expr)];
+    }
+
+    // TODO: complex pattern simplify
+
+    return nullptr;
+  }
+};
+
 /// Match a pattern for a bitwise funnel/rotate operation that partially guards
 /// against undefined behavior by branching around the funnel-shift/rotation
 /// when the shift amount is 0.
@@ -855,6 +1048,19 @@
       // bugs.
       MadeChange |= foldSqrt(I, TTI, TLI);
     }
+
+    // Simplify complex logic ops.
+    if (auto *BI = dyn_cast<BranchInst>(BB.getTerminator())) {
+      if (BI->isConditional()) {
+        LogicOpsHelper Helper;
+        Value *Cond = BI->getCondition();
+        Value *NewCond = Helper.simplify(Cond);
+        if (NewCond) {
+          Cond->replaceAllUsesWith(NewCond);
+          MadeChange = true;
+        }
+      }
+    }
   }
 
   // We're done with transforms, so remove dead instructions.
Index: llvm/test/Transforms/AggressiveInstCombine/complex-logic-ops.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/AggressiveInstCombine/complex-logic-ops.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=aggressive-instcombine -S | FileCheck %s
+
+define void @test1(i1 %a, i1 %b, i1 %c, i1 %d) {
+; CHECK-LABEL: @test1(
+; CHECK-NEXT:    br i1 true, label [[IF_END:%.*]], label [[IF_THEN:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    call void @usev()
+; CHECK-NEXT:    br label [[IF_END]]
+; CHECK:       if.end:
+; CHECK-NEXT:    ret void
+;
+  %bd = and i1 %b, %d
+  %not.bd = xor i1 %bd, true
+  %xor.ab = xor i1 %a, %b
+  %or1 = or i1 %xor.ab, %c
+  %or2 = or i1 %or1, %not.bd
+  %or3 = or i1 %or2, %a
+  br i1 %or3, label %if.end, label %if.then
+
+if.then:
+  call void @usev()
+  br label %if.end
+
+if.end:
+  ret void
+}
+
+define void @test2(i1 %a, i1 %b, i1 %c, i1 %d) {
+; CHECK-LABEL: @test2(
+; CHECK-NEXT:    br i1 [[B:%.*]], label [[IF_END:%.*]], label [[IF_THEN:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    call void @usev()
+; CHECK-NEXT:    br label [[IF_END]]
+; CHECK:       if.end:
+; CHECK-NEXT:    ret void
+;
+  %bd = and i1 %b, %d
+  %xor = xor i1 %bd, %c
+  %not.bd = xor i1 %xor, true
+  %xor.ab = xor i1 %a, %b
+  %or1 = or i1 %xor.ab, %c
+  %or2 = or i1 %or1, %not.bd
+  %or3 = or i1 %or2, %a
+  %and = and i1 %or3, %b
+  br i1 %and, label %if.end, label %if.then
+
+if.then:
+  call void @usev()
+  br label %if.end
+
+if.end:
+  ret void
+}
+
+declare void @usev()