Index: llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
===================================================================
--- llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -14,6 +14,7 @@
 
 #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
 #include "AggressiveInstCombineInternal.h"
+#include "ComplexLogicalOpsCombine.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/AssumptionCache.h"
@@ -855,6 +856,20 @@
       // bugs.
       MadeChange |= foldSqrt(I, TTI, TLI);
     }
+
+    // TODO: enable it for all logical instructions.
+    // Simplify the tree of logic ops feeding a conditional branch. The
+    // replacement is a constant or a value the tree is built from, so it is
+    // equivalent to the old condition and dominates all of its uses.
+    auto *BI = dyn_cast<BranchInst>(BB.getTerminator());
+    if (BI && BI->isConditional()) {
+      Value *Cond = BI->getCondition();
+      LogicalOpsHelper Helper;
+      if (Value *NewCond = Helper.simplify(Cond)) {
+        Cond->replaceAllUsesWith(NewCond);
+        MadeChange = true;
+      }
+    }
   }
 
   // We're done with transforms, so remove dead instructions.
Index: llvm/lib/Transforms/AggressiveInstCombine/CMakeLists.txt
===================================================================
--- llvm/lib/Transforms/AggressiveInstCombine/CMakeLists.txt
+++ llvm/lib/Transforms/AggressiveInstCombine/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_llvm_component_library(LLVMAggressiveInstCombine
   AggressiveInstCombine.cpp
+  ComplexLogicalOpsCombine.cpp
   TruncInstCombine.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms
Index: llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.h
===================================================================
--- /dev/null
+++ llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.h
@@ -0,0 +1,70 @@
+//===----- ComplexLogicalOpsCombine.h --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+
+namespace llvm {
+
+class LogicalOpsHelper;
+/// An xor-chain of and-chains; each element is a bitmask of leaf indices.
+using LogicalExpr = SmallDenseSet<unsigned, 8>;
+
+/// The LogicalExpr computed for one value of a logic-op tree.
+class LogicalOpNode {
+private:
+  LogicalOpsHelper *Helper;
+  Value *Val;
+  // TODO: Add weight to measure cost for more than one use value
+
+  void printValue(raw_ostream &OS, Value *Val) const;
+  void printAndChain(raw_ostream &OS, unsigned AndChain) const;
+
+public:
+  LogicalOpNode(LogicalOpsHelper *OpsHelper, Value *OrigVal)
+      : Helper(OpsHelper), Val(OrigVal) {}
+  Value *getValue() const { return Val; }
+  void print(raw_ostream &OS) const;
+
+  LogicalExpr Expr;
+};
+
+/// Owns the nodes built while simplifying; not copyable for that reason.
+class LogicalOpsHelper {
+public:
+  LogicalOpsHelper() = default;
+  LogicalOpsHelper(const LogicalOpsHelper &) = delete;
+  LogicalOpsHelper &operator=(const LogicalOpsHelper &) = delete;
+  ~LogicalOpsHelper() { clear(); }
+  /// Returns an existing value or constant equal to Root, or nullptr.
+  Value *simplify(Value *Root);
+
+private:
+  friend class LogicalOpNode;
+
+  SmallDenseMap<Value *, LogicalOpNode *, 16> LogicalOpNodes;
+  SmallPtrSet<Value *, 8> LeafSet;
+  SmallVector<Value *, 8> LeafValues;
+
+  void clear();
+  bool visitLeafNode(LogicalOpNode *Node, Value *Val, unsigned Depth);
+  bool visitBinOp(LogicalOpNode *Node, BinaryOperator *BO, unsigned Depth);
+  LogicalOpNode *getLogicalOpNode(Value *Val, unsigned Depth = 0);
+  Value *logicalOpToValue(LogicalOpNode *Node);
+};
+
+inline raw_ostream &operator<<(raw_ostream &OS, const LogicalOpNode &I) {
+  I.print(OS);
+  return OS;
+}
+
+} // namespace llvm
Index: llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Transforms/AggressiveInstCombine/ComplexLogicalOpsCombine.cpp
@@ -0,0 +1,276 @@
+//===-------- ComplexLogicalOpsCombine.cpp -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file helps to find the simplest expression for a complex chain of
+// logic operations. All other ops are canonicalized to and/xor, for example:
+//   a | b --> (a & b) ^ a ^ b
+//   c ? a : b --> (c & a) ^ ((c ^ true) & b)
+// An expression is represented as a set of unsigned values. Every value that
+// does not come from a logic operation is a leaf node, and every leaf node is
+// one bit of the unsigned value. For example, for the sources a, b and c, the
+// mask of a is 1, the mask of b is 2 and the mask of c is 4:
+//   a & b & c --> {7}
+//   a & b ^ c & a --> {3, 5}
+//   a & b ^ c & a ^ b --> {3, 5, 2}
+// Every unsigned value is an and-chain and the set is an xor-chain of them,
+// so any logic value can be represented by a set of unsigned values:
+//   r1 = (a | b) & c -> r1 = (a & b & c) ^ (a & c) ^ (b & c) -> {7, 5, 6}
+// The constant 0 is the empty set and the constant -1 is {ExprNegOne}.
+// Finally, the simplest pattern is rebuilt from the expression. For now, the
+// code is only simplified when the expression is a single leaf or a constant.
+//
+//===----------------------------------------------------------------------===//
+
+#include "ComplexLogicalOpsCombine.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/bit.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include <algorithm>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "aggressive-instcombine"
+
+STATISTIC(NumComplexLogicalOpsSimplified,
+          "Number of complex logical operations simplified");
+
+static cl::opt<unsigned> MaxLogicOpLeafsToScan(
+    "aggressive-instcombine-max-logic-op-leafs", cl::init(8), cl::Hidden,
+    cl::desc("Max number of leaf values of logic ops to scan for aggressive "
+             "instcombine."));
+
+static cl::opt<unsigned> MaxDepthLogicOpsToScan(
+    "aggressive-instcombine-max-depth-logic-ops", cl::init(8), cl::Hidden,
+    cl::desc("Max depth of logic ops to scan for aggressive instcombine."));
+
+// Leaf bits live below these two reserved marker bits, so at most MaxLeafBits
+// distinct leaves fit into one and-chain.
+static constexpr unsigned ExprZero = 0x40000000;
+static constexpr unsigned ExprNegOne = 0x80000000;
+static constexpr unsigned MaxLeafBits = 30;
+
+/// Returns LHS & RHS by distributing the and over both xor-chains.
+static LogicalExpr exprAnd(const LogicalExpr &LHS, const LogicalExpr &RHS) {
+  LogicalExpr Ret;
+  for (unsigned LHSAndArray : LHS) {
+    // a & 0 -> 0
+    if (LHSAndArray & ExprZero)
+      continue;
+    for (unsigned RHSAndArray : RHS) {
+      // a & 0 -> 0
+      if (RHSAndArray & ExprZero)
+        continue;
+      unsigned NewAndArray = LHSAndArray | RHSAndArray;
+      // a & -1 -> a, but -1 & -1 must stay -1 so that it can still cancel.
+      if (NewAndArray != ExprNegOne)
+        NewAndArray &= ~ExprNegOne;
+      // a ^ a -> 0
+      if (!Ret.insert(NewAndArray).second)
+        Ret.erase(NewAndArray);
+    }
+  }
+  return Ret;
+}
+
+/// Returns LHS ^ RHS; and-chains present in both operands cancel out.
+static LogicalExpr exprXor(const LogicalExpr &LHS, const LogicalExpr &RHS) {
+  LogicalExpr Ret = LHS;
+  for (unsigned RHSAndArray : RHS) {
+    // a ^ a -> 0
+    if (!Ret.insert(RHSAndArray).second)
+      Ret.erase(RHSAndArray);
+  }
+  return Ret;
+}
+
+/// Returns LHS | RHS.
+static LogicalExpr exprOr(const LogicalExpr &LHS, const LogicalExpr &RHS) {
+  // a | b --> (a & b) ^ a ^ b
+  return exprXor(exprXor(exprAnd(LHS, RHS), LHS), RHS);
+}
+
+void LogicalOpNode::printValue(raw_ostream &OS, Value *Val) const {
+  if (auto *ConstVal = dyn_cast<Constant>(Val))
+    OS << *ConstVal;
+  else
+    OS << Val->getName();
+}
+
+/// Prints one and-chain as "x * y * z" ("*" means and), or "-1".
+void LogicalOpNode::printAndChain(raw_ostream &OS, unsigned AndChain) const {
+  if (AndChain == ExprNegOne) {
+    OS << "-1";
+    return;
+  }
+
+  unsigned ChainLength = llvm::popcount(AndChain);
+  if (((AndChain & ExprZero) != 0) || ChainLength == 0)
+    return;
+
+  // Print the leaves from the lowest bit upwards, separated by " * ".
+  for (unsigned I = 0; I < ChainLength; ++I) {
+    if (I != 0)
+      OS << " * ";
+    printValue(OS, Helper->LeafValues[llvm::countr_zero(AndChain)]);
+    AndChain &= AndChain - 1; // Clear the lowest set bit.
+  }
+}
+
+/// Prints "<value> --> <expr>", where " + " (xor) separates the and-chains.
+void LogicalOpNode::print(raw_ostream &OS) const {
+  OS << *Val << " --> ";
+  if (Expr.empty()) {
+    OS << "0\n";
+    return;
+  }
+
+  for (auto I = ++Expr.begin(); I != Expr.end(); ++I) {
+    printAndChain(OS, *I);
+    OS << " + ";
+  }
+
+  printAndChain(OS, *Expr.begin());
+  OS << "\n";
+}
+
+void LogicalOpsHelper::clear() {
+  for (auto &Entry : LogicalOpNodes)
+    delete Entry.second;
+  LogicalOpNodes.clear();
+  LeafSet.clear();
+  LeafValues.clear();
+}
+
+/// Builds the expression of a value that is not decomposed any further:
+/// the constants 0 and -1, or an opaque leaf that gets its own bit.
+bool LogicalOpsHelper::visitLeafNode(LogicalOpNode *Node, Value *Val,
+                                     unsigned Depth) {
+  // A leaf root has nothing to simplify (and a constant root must never be
+  // replaced through RAUW).
+  if (Depth == 0)
+    return false;
+
+  if (auto *ConstVal = dyn_cast<ConstantInt>(Val)) {
+    // 0 is the empty xor-chain and needs no leaf bit.
+    if (ConstVal->isZero())
+      return true;
+    // -1 is the dedicated ExprNegOne and-chain and needs no leaf bit either.
+    if (ConstVal->isAllOnesValue()) {
+      Node->Expr.insert(ExprNegOne);
+      return true;
+    }
+  }
+
+  // Leaf bits must never reach the reserved ExprZero/ExprNegOne bits.
+  unsigned MaxLeafs = std::min<unsigned>(MaxLogicOpLeafsToScan, MaxLeafBits);
+  if (LeafValues.size() >= MaxLeafs)
+    return false;
+
+  // getLogicalOpNode visits every value at most once, so Val is a new leaf.
+  unsigned LeafIdx = LeafValues.size();
+  LeafSet.insert(Val);
+  LeafValues.push_back(Val);
+  Node->Expr.insert(1u << LeafIdx);
+  return true;
+}
+
+/// Builds the expression of an and/or/xor from those of its operands; any
+/// other binary operator is treated as a leaf.
+bool LogicalOpsHelper::visitBinOp(LogicalOpNode *Node, BinaryOperator *BO,
+                                  unsigned Depth) {
+  // We can only simplify and, or, xor in the binary operator.
+  Instruction::BinaryOps Opcode = BO->getOpcode();
+  if (Opcode != Instruction::And && Opcode != Instruction::Or &&
+      Opcode != Instruction::Xor)
+    return visitLeafNode(Node, BO, Depth);
+
+  LogicalOpNode *LHS = getLogicalOpNode(BO->getOperand(0), Depth + 1);
+  if (LHS == nullptr)
+    return false;
+
+  LogicalOpNode *RHS = getLogicalOpNode(BO->getOperand(1), Depth + 1);
+  if (RHS == nullptr)
+    return false;
+
+  if (Opcode == Instruction::And)
+    Node->Expr = exprAnd(LHS->Expr, RHS->Expr);
+  else if (Opcode == Instruction::Or)
+    Node->Expr = exprOr(LHS->Expr, RHS->Expr);
+  else
+    Node->Expr = exprXor(LHS->Expr, RHS->Expr);
+  return true;
+}
+
+/// Returns the node for Val, building and caching it on first use. Returns
+/// nullptr when Val is too deep or cannot be represented; such failures are
+/// cached as nullptr so that every value is visited at most once.
+LogicalOpNode *LogicalOpsHelper::getLogicalOpNode(Value *Val, unsigned Depth) {
+  if (Depth >= MaxDepthLogicOpsToScan)
+    return nullptr;
+
+  auto It = LogicalOpNodes.find(Val);
+  if (It != LogicalOpNodes.end())
+    return It->second;
+
+  auto *Node = new LogicalOpNode(this, Val);
+  bool Succeed;
+  // TODO: add select instruction support
+  if (auto *BO = dyn_cast<BinaryOperator>(Val))
+    Succeed = visitBinOp(Node, BO, Depth);
+  else
+    Succeed = visitLeafNode(Node, Val, Depth);
+
+  if (Succeed) {
+    LLVM_DEBUG(dbgs() << *Node);
+  } else {
+    // Free the failed node; only successful nodes are printed and cached.
+    delete Node;
+    Node = nullptr;
+  }
+
+  // Don't reuse It: the recursive visits above may have rehashed the map.
+  LogicalOpNodes[Val] = Node;
+  return Node;
+}
+
+/// Materializes Node's expression when it is a constant or a single leaf;
+/// returns nullptr otherwise.
+Value *LogicalOpsHelper::logicalOpToValue(LogicalOpNode *Node) {
+  if (Node == nullptr)
+    return nullptr;
+
+  Type *Ty = Node->getValue()->getType();
+  if (Node->Expr.empty())
+    return Constant::getNullValue(Ty);
+
+  if (Node->Expr.size() == 1) {
+    unsigned ExprMask = *Node->Expr.begin();
+    if (ExprMask == ExprNegOne)
+      return Constant::getAllOnesValue(Ty);
+
+    // Only a single in-range leaf bit names an existing value; a marker bit
+    // such as ExprZero must never be used as an index into LeafValues.
+    unsigned LeafIdx = llvm::countr_zero(ExprMask);
+    if (llvm::popcount(ExprMask) == 1 && LeafIdx < LeafValues.size())
+      return LeafValues[LeafIdx];
+  }
+
+  // TODO: complex pattern simplify
+
+  return nullptr;
+}
+
+Value *LogicalOpsHelper::simplify(Value *Root) {
+  Value *NewRoot = logicalOpToValue(getLogicalOpNode(Root));
+  if (NewRoot != nullptr)
+    ++NumComplexLogicalOpsSimplified;
+  return NewRoot;
+}
Index: llvm/test/Transforms/AggressiveInstCombine/complex-logic-ops.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/AggressiveInstCombine/complex-logic-ops.ll
@@ -0,0 +1,101 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=aggressive-instcombine -S | FileCheck %s
+
+define void @test1(i1 %a, i1 %b, i1 %c, i1 %d) {
+; CHECK-LABEL: @test1(
+; CHECK-NEXT:    br i1 true, label [[IF_END:%.*]], label [[IF_THEN:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    call void @usev()
+; CHECK-NEXT:    br label [[IF_END]]
+; CHECK:       if.end:
+; CHECK-NEXT:    ret void
+;
+  %bd = and i1 %b, %d
+  %not.bd = xor i1 %bd, true
+  %xor.ab = xor i1 %a, %b
+  %or1 = or i1 %xor.ab, %c
+  %or2 = or i1 %or1, %not.bd
+  %or3 = or i1 %or2, %a
+  br i1 %or3, label %if.end, label %if.then
+
+if.then:
+  call void @usev()
+  br label %if.end
+
+if.end:
+  ret void
+}
+
+define void @test2(i1 %a, i1 %b, i1 %c, i1 %d) {
+; CHECK-LABEL: @test2(
+; CHECK-NEXT:    br i1 [[B:%.*]], label [[IF_END:%.*]], label [[IF_THEN:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    call void @usev()
+; CHECK-NEXT:    br label [[IF_END]]
+; CHECK:       if.end:
+; CHECK-NEXT:    ret void
+;
+  %bd = and i1 %b, %d
+  %xor = xor i1 %bd, %c
+  %not.bd = xor i1 %xor, true
+  %xor.ab = xor i1 %a, %b
+  %or1 = or i1 %xor.ab, %c
+  %or2 = or i1 %or1, %not.bd
+  %or3 = or i1 %or2, %a
+  %and = and i1 %or3, %b
+  br i1 %and, label %if.end, label %if.then
+
+if.then:
+  call void @usev()
+  br label %if.end
+
+if.end:
+  ret void
+}
+
+; (a ^ a) ^ false folds to false without treating false as a leaf value.
+define void @test3(i1 %a) {
+; CHECK-LABEL: @test3(
+; CHECK-NEXT:    br i1 false, label [[IF_END:%.*]], label [[IF_THEN:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    call void @usev()
+; CHECK-NEXT:    br label [[IF_END]]
+; CHECK:       if.end:
+; CHECK-NEXT:    ret void
+;
+  %xor.aa = xor i1 %a, %a
+  %r = xor i1 %xor.aa, false
+  br i1 %r, label %if.end, label %if.then
+
+if.then:
+  call void @usev()
+  br label %if.end
+
+if.end:
+  ret void
+}
+
+; true & true stays true, so xor-ing it with true folds to false.
+define void @test4(i1 %a) {
+; CHECK-LABEL: @test4(
+; CHECK-NEXT:    br i1 [[A:%.*]], label [[IF_END:%.*]], label [[IF_THEN:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    call void @usev()
+; CHECK-NEXT:    br label [[IF_END]]
+; CHECK:       if.end:
+; CHECK-NEXT:    ret void
+;
+  %t = and i1 true, true
+  %not.t = xor i1 %t, true
+  %r = or i1 %not.t, %a
+  br i1 %r, label %if.end, label %if.then
+
+if.then:
+  call void @usev()
+  br label %if.end
+
+if.end:
+  ret void
+}
+
+declare void @usev()