Index: llvm/include/llvm/Analysis/ComplexLogicCombine.h
===================================================================
--- /dev/null
+++ llvm/include/llvm/Analysis/ComplexLogicCombine.h
@@ -0,0 +1,81 @@
+//===----------- ComplexLogicCombine.h --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_COMPLEXLOGICCOMBINE_H
+#define LLVM_ANALYSIS_COMPLEXLOGICCOMBINE_H
+
+#include "LogicalExpr.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+
+namespace llvm {
+
+class LogicalOpsHelper;
+
+/// Pairs an IR value with the LogicalExpr computed for it.
+class LogicalOpNode {
+private:
+  LogicalOpsHelper *Helper;
+  Value *Val;
+  LogicalExpr Expr;
+  // TODO: Add weight to measure cost for more than one use value
+
+  void printValue(raw_ostream &OS, Value *Val) const;
+  void printAndChain(raw_ostream &OS, uint64_t Mask) const;
+
+public:
+  LogicalOpNode(LogicalOpsHelper *OpsHelper, Value *SrcVal,
+                const LogicalExpr &SrcExpr)
+      : Helper(OpsHelper), Val(SrcVal), Expr(SrcExpr) {}
+
+  Value *getValue() const { return Val; }
+  const LogicalExpr &getExpr() const { return Expr; }
+  void print(raw_ostream &OS) const;
+};
+
+/// Builds LogicalOpNodes for and/or/xor chains and maps a simplified
+/// expression back to an IR value.
+class LogicalOpsHelper {
+public:
+  LogicalOpsHelper() = default;
+  // The helper owns the nodes stored in LogicalOpNodes through raw pointers
+  // (clear() deletes them), so a copy would free them twice.
+  LogicalOpsHelper(const LogicalOpsHelper &) = delete;
+  LogicalOpsHelper &operator=(const LogicalOpsHelper &) = delete;
+  ~LogicalOpsHelper() { clear(); }
+
+  /// Returns a simpler equivalent of \p Root, or nullptr if none is found.
+  Value *simplify(Value *Root);
+
+private:
+  friend class LogicalOpNode;
+
+  SmallDenseMap LogicalOpNodes;
+  SmallPtrSet LeafSet;
+  SmallVector LeafValues;
+
+  void clear();
+
+  LogicalOpNode *visitLeafNode(Value *Val, unsigned Depth);
+  LogicalOpNode *visitBinOp(BinaryOperator *BO, unsigned Depth);
+  LogicalOpNode *getLogicalOpNode(Value *Val, unsigned Depth = 0);
+  Value *logicalOpToValue(LogicalOpNode *Node);
+};
+
+inline raw_ostream &operator<<(raw_ostream &OS, const LogicalOpNode &I) {
+  I.print(OS);
+  return OS;
+}
+
+} // namespace llvm
+
+#endif // LLVM_ANALYSIS_COMPLEXLOGICCOMBINE_H
Index: llvm/include/llvm/Analysis/LogicalExpr.h
===================================================================
--- /dev/null
+++ llvm/include/llvm/Analysis/LogicalExpr.h
@@ -0,0 +1,159 @@
+//===------------------- LogicalExpr.h --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Example of a LogicalExpr calculation: for source values {a,b,c,d}, we can
+// represent them as a bitmask with 'a' as the least-significant-bit: {dcba}.
+// LHS is (a * b * c * d + a * d + b + a * c * d), RHS is (a + a * c).
+// Use bit mask to represent the expression:
+// {0b1111, 0b1001, 0b0010, 0b1101} * {0b0001, 0b0101}
+// -->
+// (0b1111 + 0b1001 + 0b0010 + 0b1101) * (0b0001 + 0b0101)
+// -->
+// (0b1111 + 0b1001 + 0b0010 + 0b1101) * 0b0001 + (0b1111 + 0b1001 + 0b0010 +
+// 0b1101) * 0b0101
+// -->
+// (0b1111 | 0b0001) + (0b1001 | 0b0001) + (0b0010 | 0b0001) + (0b1101 | 0b0001)
+// + (0b1111 | 0b0101) + (0b1001 | 0b0101) + (0b0010 | 0b0101) + (0b1101 |
+// 0b0101)
+// -->
+// 0b1111 + 0b1001 + 0b0011 + 0b1101 + 0b1111 + 0b1101 + 0b0111 + 0b1101
+// -->
+// 0b1001 + 0b0011 + 0b1101 + 0b0111
+// -->
+// {0b1001, 0b0011, 0b1101, 0b0111}
+// -->
+// a * d + a * b + a * c * d + a * b * c
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_LOGICALEXPR_H
+#define LLVM_ANALYSIS_LOGICALEXPR_H
+
+#include "llvm/ADT/DenseSet.h"
+#include <cstdint>
+
+namespace llvm {
+typedef SmallDenseSet ExprAddChain;
+
+/// A boolean-ring expression: an xor ("add") chain of and-chains, each
+/// and-chain being a bitmask of the leaves it multiplies together.
+class LogicalExpr {
+private:
+  ExprAddChain AddChain;
+  // Union of all masks in AddChain; kept in sync by updateLeafMask().
+  uint64_t LeafMask = 0;
+
+  inline void updateLeafMask() {
+    LeafMask = 0;
+    for (auto Mask : AddChain)
+      LeafMask |= Mask;
+  }
+
+public:
+  // Reserved mask bits standing for the constants all-ones and zero.
+  static constexpr uint64_t ExprAllOne = 0x8000000000000000;
+  static constexpr uint64_t ExprZero = 0x4000000000000000;
+
+  LogicalExpr() = default;
+  LogicalExpr(uint64_t Mask) {
+    AddChain.insert(Mask);
+    LeafMask = Mask;
+  }
+  LogicalExpr(const ExprAddChain &SrcAddChain) : AddChain(SrcAddChain) {
+    updateLeafMask();
+  }
+
+  unsigned size() const { return AddChain.size(); }
+  ExprAddChain::iterator begin() { return AddChain.begin(); }
+  ExprAddChain::iterator end() { return AddChain.end(); }
+  ExprAddChain::const_iterator begin() const { return AddChain.begin(); }
+  ExprAddChain::const_iterator end() const { return AddChain.end(); }
+
+  /// Ring multiplication (logical and): combines every mask of this
+  /// expression with every mask of \p RHS; equal products cancel out.
+  LogicalExpr &operator*=(const LogicalExpr &RHS) {
+    ExprAddChain NewChain;
+    for (auto LHSMask : AddChain) {
+      // a & 0 -> 0
+      if (LHSMask & ExprZero)
+        continue;
+      for (auto RHSMask : RHS.AddChain) {
+        // 0 & a -> 0
+        if (RHSMask & ExprZero)
+          continue;
+        uint64_t NewMask = LHSMask | RHSMask;
+        // a & 1 -> a
+        if (NewMask != ExprAllOne && ((NewMask & ExprAllOne) != 0))
+          NewMask &= ~ExprAllOne;
+        // a ^ a -> 0
+        if (!NewChain.insert(NewMask).second)
+          NewChain.erase(NewMask);
+      }
+    }
+
+    AddChain = NewChain;
+    updateLeafMask();
+    return *this;
+  }
+
+  /// Ring addition (logical xor): a mask present on both sides cancels out
+  /// and the constant-zero term is dropped.
+  LogicalExpr &operator+=(const LogicalExpr &RHS) {
+    // a ^ a -> 0. Handled up front because the loop below must not modify
+    // the set it is iterating over.
+    if (this == &RHS) {
+      AddChain.clear();
+      LeafMask = 0;
+      return *this;
+    }
+
+    // 0 ^ a -> a
+    AddChain.erase(ExprZero);
+    for (auto RHSMask : RHS.AddChain) {
+      // a ^ 0 -> a
+      if (RHSMask & ExprZero)
+        continue;
+      // a ^ a -> 0
+      if (!AddChain.insert(RHSMask).second)
+        AddChain.erase(RHSMask);
+    }
+    updateLeafMask();
+    return *this;
+  }
+};
+
+inline LogicalExpr operator*(LogicalExpr a, const LogicalExpr &b) {
+  a *= b;
+  return a;
+}
+
+inline LogicalExpr operator+(LogicalExpr a, const LogicalExpr &b) {
+  a += b;
+  return a;
+}
+
+inline LogicalExpr operator&(const LogicalExpr &a, const LogicalExpr &b) {
+  return a * b;
+}
+
+inline LogicalExpr operator^(const LogicalExpr &a, const LogicalExpr &b) {
+  return a + b;
+}
+
+inline LogicalExpr operator|(const LogicalExpr &a, const LogicalExpr &b) {
+  return a * b + a + b;
+}
+
+inline LogicalExpr operator~(const LogicalExpr &a) {
+  LogicalExpr AllOneExpr(LogicalExpr::ExprAllOne);
+  return a + AllOneExpr;
+}
+
+} // namespace llvm
+
+#endif // LLVM_ANALYSIS_LOGICALEXPR_H
Index: llvm/lib/Analysis/CMakeLists.txt
===================================================================
--- llvm/lib/Analysis/CMakeLists.txt
+++ llvm/lib/Analysis/CMakeLists.txt
@@ -46,6 +46,7 @@
   CmpInstAnalysis.cpp
   CostModel.cpp
   CodeMetrics.cpp
+  ComplexLogicCombine.cpp
   ConstantFolding.cpp
   CycleAnalysis.cpp
   DDG.cpp
Index: llvm/lib/Analysis/ComplexLogicCombine.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Analysis/ComplexLogicCombine.cpp
@@ -0,0 +1,216 @@
+//===-------- ComplexLogicCombine.cpp ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file helps to find the simplest expression for a complex logic
+// operation chain. We canonicalize all other ops to and/xor.
+// For example:
+//    a | b --> (a & b) ^ a ^ b
+//    c ? a : b --> (c & a) ^ ((c ^ true) & b)
+// We use a mask set to represent the expression. Any value that is not a logic
+// operation is a leaf node. Leaf node is 1 bit in the mask. For example, we
+// have source a, b, c. The mask for a is 1, b is 2, c is 4.
+//    a & b & c --> {7}
+//    a & b ^ c & a --> {3, 5}
+//    a & b ^ c & a ^ b --> {3, 5, 2}
+// Every mask is an and chain. The set of masks is a xor chain.
+// Based on boolean ring, we can treat & as ring multiplication and ^ as ring
+// addition. After that, any logic value can be represented by an unsigned set.
+// For example:
+//    r1 = (a | b) & c -> r1 = (a * b * c) + (a * c) + (b * c) -> {7, 5, 6}
+// Finally we need to rebuild the simplest pattern from the expression. For now,
+// we only simplify the code when the expression is leaf or null.
+//
+// Reference: https://en.wikipedia.org/wiki/Boolean_ring
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ComplexLogicCombine.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "complex-logic-combine"
+
+STATISTIC(NumComplexLogicalOpsSimplified,
+          "Number of complex logical operations simplified");
+
+static cl::opt MaxLogicOpLeafsToScan(
+    "clc-max-logic-leafs", cl::init(8), cl::Hidden,
+    cl::desc("Max leafs of logic ops to scan for complex logical combine."));
+
+static cl::opt MaxDepthLogicOpsToScan(
+    "clc-max-depth", cl::init(8), cl::Hidden,
+    cl::desc("Max depth of logic ops to scan for complex logical combine."));
+
+// Print a leaf value: constants are printed in full, other values by name.
+void LogicalOpNode::printValue(raw_ostream &OS, Value *Val) const {
+  if (auto *ConstVal = dyn_cast(Val))
+    OS << *ConstVal;
+  else
+    OS << Val->getName();
+}
+
+// Print one and-chain (a single mask) as "a * b * ...". All-ones prints as
+// "-1"; a zero or empty mask prints nothing.
+void LogicalOpNode::printAndChain(raw_ostream &OS, uint64_t Mask) const {
+  if (Mask == LogicalExpr::ExprAllOne) {
+    OS << "-1";
+    return;
+  }
+
+  unsigned MulElementCnt = llvm::popcount(Mask);
+  if (((Mask & LogicalExpr::ExprZero) != 0) || MulElementCnt == 0)
+    return;
+
+  // All set bits but the last are followed by " * "; one bit skips the loop.
+  for (unsigned I = 1; I < MulElementCnt; I++) {
+    unsigned MaskIdx = llvm::countr_zero(Mask);
+    printValue(OS, Helper->LeafValues[MaskIdx]);
+    OS << " * ";
+    Mask -= (1ULL << MaskIdx);
+  }
+  printValue(OS, Helper->LeafValues[llvm::countr_zero(Mask)]);
+}
+
+// Print "<value> --> <expr>"; and-chains are joined with " + ".
+void LogicalOpNode::print(raw_ostream &OS) const {
+  OS << *Val << " --> ";
+  if (Expr.size() == 0) {
+    OS << "0\n";
+    return;
+  }
+
+  for (auto I = ++Expr.begin(); I != Expr.end(); ++I) {
+    printAndChain(OS, *I);
+    OS << " + ";
+  }
+
+  printAndChain(OS, *Expr.begin());
+  OS << "\n";
+}
+
+// Free every cached node and forget all leaves.
+void LogicalOpsHelper::clear() {
+  for (auto &Entry : LogicalOpNodes)
+    delete Entry.second;
+  LogicalOpNodes.clear();
+  LeafSet.clear();
+  LeafValues.clear();
+}
+
+// Create and cache the node for a value that is not a logical operation.
+LogicalOpNode *LogicalOpsHelper::visitLeafNode(Value *Val, unsigned Depth) {
+  // Depth 0 means the root itself is not a logical operation, so there is
+  // nothing to do. The leaf budget uses '>=': leaf N owns bit N, so at most
+  // MaxLogicOpLeafsToScan (<= 62) bits are used and bit 62 stays ExprZero.
+  if (Depth == 0 || LeafSet.size() >= MaxLogicOpLeafsToScan)
+    return nullptr;
+
+  uint64_t ExprVal = 1ULL << LeafSet.size();
+  // Constant zero and all-ones are special leaf nodes. They take part in
+  // LogicalExpr's calculation, so they must be detected first.
+  if (auto ConstVal = dyn_cast(Val)) {
+    if (ConstVal->isZero())
+      ExprVal = LogicalExpr::ExprZero;
+    else if (ConstVal->isAllOnesValue())
+      ExprVal = LogicalExpr::ExprAllOne;
+  }
+  if (ExprVal != LogicalExpr::ExprAllOne && ExprVal != LogicalExpr::ExprZero &&
+      LeafSet.insert(Val).second)
+    LeafValues.push_back(Val);
+  LogicalOpNode *Node = new LogicalOpNode(this, Val, LogicalExpr(ExprVal));
+  LogicalOpNodes[Val] = Node;
+  return Node;
+}
+
+// Build the node for a binary operator. and/or/xor combine the expressions of
+// both operands; any other opcode is treated as a leaf.
+LogicalOpNode *LogicalOpsHelper::visitBinOp(BinaryOperator *BO,
+                                            unsigned Depth) {
+  if (BO->getOpcode() != Instruction::And &&
+      BO->getOpcode() != Instruction::Or && BO->getOpcode() != Instruction::Xor)
+    return visitLeafNode(BO, Depth);
+
+  LogicalOpNode *LHS = getLogicalOpNode(BO->getOperand(0), Depth + 1);
+  if (LHS == nullptr)
+    return nullptr;
+
+  LogicalOpNode *RHS = getLogicalOpNode(BO->getOperand(1), Depth + 1);
+  if (RHS == nullptr)
+    return nullptr;
+
+  LogicalOpNode *Node;
+  if (BO->getOpcode() == Instruction::And)
+    Node = new LogicalOpNode(this, BO, LHS->getExpr() & RHS->getExpr());
+  else if (BO->getOpcode() == Instruction::Or)
+    Node = new LogicalOpNode(this, BO, LHS->getExpr() | RHS->getExpr());
+  else
+    Node = new LogicalOpNode(this, BO, LHS->getExpr() ^ RHS->getExpr());
+  LogicalOpNodes[BO] = Node;
+  return Node;
+}
+
+// Return the node for Val, building and caching it on first use. Returns
+// nullptr when the depth limit is hit or the value cannot be analyzed.
+LogicalOpNode *LogicalOpsHelper::getLogicalOpNode(Value *Val, unsigned Depth) {
+  if (Depth == MaxDepthLogicOpsToScan)
+    return nullptr;
+  if (LogicalOpNode *Cached = LogicalOpNodes.lookup(Val))
+    return Cached;
+
+  // TODO: add select instruction support
+  auto *BO = dyn_cast(Val);
+  LogicalOpNode *Node = BO ? visitBinOp(BO, Depth) : visitLeafNode(Val, Depth);
+  if (Node)
+    LLVM_DEBUG(dbgs() << *Node);
+  return Node;
+}
+
+// Rebuild an IR value from a node's expression. Only constant zero, constant
+// all-ones and a single leaf are handled; otherwise returns nullptr.
+Value *LogicalOpsHelper::logicalOpToValue(LogicalOpNode *Node) {
+  const LogicalExpr &Expr = Node->getExpr();
+  // The set is empty when every mask was erased because of a ^ a = 0.
+  if (Expr.size() == 0)
+    return Constant::getNullValue(Node->getValue()->getType());
+
+  if (Expr.size() == 1) {
+    uint64_t ExprMask = *Expr.begin();
+    // ExprZero/ExprAllOne are not in LeafValues.
+    if (ExprMask == LogicalExpr::ExprZero)
+      return Constant::getNullValue(Node->getValue()->getType());
+    if (ExprMask == LogicalExpr::ExprAllOne)
+      return Constant::getAllOnesValue(Node->getValue()->getType());
+    if (llvm::popcount(ExprMask) == 1)
+      return LeafValues[llvm::Log2_64(ExprMask)];
+  }
+
+  // TODO: complex pattern simplify
+  return nullptr;
+}
+
+// Try to simplify Root. Returns the replacement value, or nullptr when no
+// simpler (and different) value is found. Cached nodes persist across calls.
+Value *LogicalOpsHelper::simplify(Value *Root) {
+  assert(MaxLogicOpLeafsToScan <= 62 &&
+         "Logical leaf node can't larger than 62.");
+  LogicalOpNode *RootNode = getLogicalOpNode(Root);
+  if (RootNode == nullptr)
+    return nullptr;
+
+  Value *NewRoot = logicalOpToValue(RootNode);
+  if (NewRoot == Root)
+    return nullptr;
+
+  if (NewRoot != nullptr)
+    NumComplexLogicalOpsSimplified++;
+  return NewRoot;
+}
Index: llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
===================================================================
--- llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -18,6 +18,7 @@
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
+#include
"llvm/Analysis/ComplexLogicCombine.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
@@ -46,6 +47,10 @@
     "aggressive-instcombine-max-scan-instrs", cl::init(64), cl::Hidden,
     cl::desc("Max number of instructions to scan for aggressive instcombine."));
 
+static cl::opt EnableClcForAll(
+    "enable-clc-for-all", cl::Hidden, cl::init(false),
+    cl::desc("Enable complex logical combine for every logical operation"));
+
 /// Match a pattern for a bitwise funnel/rotate operation that partially guards
 /// against undefined behavior by branching around the funnel-shift/rotation
 /// when the shift amount is 0.
@@ -838,6 +843,16 @@
 
     const DataLayout &DL = F.getParent()->getDataLayout();
 
+    LogicalOpsHelper Helper;
+    // Replace all uses of V with its simplified form, if the helper finds one.
+    // Returns true when a replacement was made.
+    auto ComplexLogicalSimplify = [](LogicalOpsHelper &Helper, Value *V) {
+      Value *NewV = Helper.simplify(V);
+      if (NewV)
+        V->replaceAllUsesWith(NewV);
+      return NewV != nullptr;
+    };
+
     // Walk the block backwards for efficiency. We're matching a chain of
     // use->defs, so we're more likely to succeed by starting from the bottom.
     // Also, we want to avoid matching partial patterns.
@@ -854,6 +869,20 @@
       // needs to be called at the end of this sequence, otherwise we may make
       // bugs.
       MadeChange |= foldSqrt(I, TTI, TLI);
     }
+
+    if (EnableClcForAll) {
+      // Run in a separate walk: the sequence above must end with foldSqrt, so
+      // `I` must not be touched after that call inside the main loop.
+      for (Instruction &I : llvm::reverse(BB)) {
+        if (I.getOpcode() == Instruction::And ||
+            I.getOpcode() == Instruction::Or ||
+            I.getOpcode() == Instruction::Xor)
+          MadeChange |= ComplexLogicalSimplify(Helper, &I);
+      }
+    } else if (auto *BI = dyn_cast(BB.getTerminator())) {
+      if (BI->isConditional())
+        MadeChange |= ComplexLogicalSimplify(Helper, BI->getCondition());
+    }
   }
 
Index: llvm/test/Transforms/AggressiveInstCombine/complex-logic.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/AggressiveInstCombine/complex-logic.ll
@@ -0,0 +1,58 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=aggressive-instcombine -enable-clc-for-all -S | FileCheck %s
+
+define i1 @test1(i1 %a, i1 %b, i1 %c, i1 %d) {
+; CHECK-LABEL: @test1(
+; CHECK-NEXT:    ret i1 true
+;
+  %bd = and i1 %b, %d
+  %not.bd = xor i1 %bd, true
+  %xor.ab = xor i1 %a, %b
+  %or1 = or i1 %xor.ab, %c
+  %or2 = or i1 %or1, %not.bd
+  %or3 = or i1 %or2, %a
+  ret i1 %or3
+}
+
+define i1 @test2(i1 %a, i1 %b, i1 %c, i1 %d) {
+; CHECK-LABEL: @test2(
+; CHECK-NEXT:    ret i1 [[B:%.*]]
+;
+  %bd = and i1 %b, %d
+  %xor = xor i1 %bd, %c
+  %not.bd = xor i1 %xor, true
+  %xor.ab = xor i1 %a, %b
+  %or1 = or i1 %xor.ab, %c
+  %or2 = or i1 %or1, %not.bd
+  %or3 = or i1 %or2, %a
+  %and = and i1 %or3, %b
+  ret i1 %and
+}
+
+define i1 @test3(i1 %a, i1 %b, i1 %c) {
+; CHECK-LABEL: @test3(
+; CHECK-NEXT:    ret i1 [[A:%.*]]
+;
+  %not.a = xor i1 %a, true
+  %or1 = or i1 %not.a, %b
+  %or2 = or i1 %or1, %c
+  %or3 = or i1 %c, %not.a
+  %or4 = or i1 %or3, %b
+  %xor = xor i1 %or2, %or4
+  %cond = xor i1 %xor, %a
+  ret i1 %cond
+}
+
+define i1 @test4(i1 %a, i1 %b, i1 %c) {
+; CHECK-LABEL: @test4(
+; CHECK-NEXT:    ret i1 [[C:%.*]]
+;
+  %ab = and i1 %a, %b
+  %bc = and i1 %b, %c
+  %xor.ac = xor i1 %a, %c
+  %or = or i1 %ab, %xor.ac
+  %not.bc = xor i1 %bc, true
+  %and = and i1 %not.bc, %a
+  %cond = xor i1 %and, %or
+  ret i1 %cond
+}