diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8290,6 +8290,174 @@
   return combineSelectAndUse(N, N1, N0, DAG, /*AllOnes*/ false);
 }
 
+// Checks whether N, an AND/OR of two comparisons, can be rewritten as a
+// single comparison of a min/max against the operand the comparisons share.
+// On success returns the reference condition code RefCond and the operands
+// A, B, C such that N is equivalent to
+//   (logic (setcc A, C, RefCond), (setcc B, C, RefCond)).
+// Conditions for the optimization:
+//  * both operands of N are single-use SETCC nodes of the same type;
+//  * the comparisons share an operand, which is returned as C; the two other
+//    operands are returned as A and B;
+//  * the compared values are integers;
+//  * after rewriting each comparison so that C is its right-hand operand
+//    (e.g. C > A becomes A < C), both use the same condition code.
+//    E.g. A < C; C > B match, but A < C; B > C do not.
+static Optional<std::tuple<ISD::CondCode, SDValue, SDValue, SDValue>>
+verifyCompareConds(SDNode *N, SelectionDAG &DAG) {
+  LLVM_DEBUG(
+      dbgs() << "Checking conditions for comparison operation combining.\n");
+
+  SDValue V0 = N->getOperand(0);
+  SDValue V1 = N->getOperand(1);
+  // Condition 1. Operations have to have same value type.
+  if (V0.getValueType() != V1.getValueType())
+    return None;
+  // Condition 2.1. Operations have to be used only in logic operation.
+  if (!V0.hasOneUse() || !V1.hasOneUse())
+    return None;
+
+  SDNode *N0 = V0.getNode();
+  SDNode *N1 = V1.getNode();
+  // Condition 2.2. Operands have to be comparison operations.
+  if (N0->getOpcode() != ISD::SETCC || N1->getOpcode() != ISD::SETCC)
+    return None;
+
+  // SETCC has three operands: LHS, RHS and the condition code. Only the first
+  // two are compared values, so the searches below stop at operand 2.
+  auto GetCmpIt = [](SDNode *Cmp) { return Cmp->op_begin() + 2; };
+  // Find an operand of the first comparison that also occurs in the second.
+  auto CommonOpIt0 = std::find_first_of(N0->op_begin(), GetCmpIt(N0),
+                                        N1->op_begin(), GetCmpIt(N1));
+  // Condition 3.1. The comparisons have to share an operand.
+  if (CommonOpIt0 == GetCmpIt(N0))
+    return None;
+  // Condition 3.2. Operations only with integers.
+  if (!CommonOpIt0->getValueType().isInteger())
+    return None;
+
+  // Locate the same operand in the second comparison; find_first_of above
+  // guarantees that it is there.
+  auto CommonOpIt1 = std::find(N1->op_begin(), GetCmpIt(N1), *CommonOpIt0);
+  assert(CommonOpIt1 != GetCmpIt(N1) &&
+         "Cannot find common operand in the second comparison operation.");
+
+  // Position (0 = LHS, 1 = RHS) of the common operand in N0 and in N1.
+  const unsigned CommonPos0 = std::distance(N0->op_begin(), CommonOpIt0);
+  const unsigned CommonPos1 = std::distance(N1->op_begin(), CommonOpIt1);
+  assert(CommonPos0 <= 1 && CommonPos1 <= 1 &&
+         "Position of the operand has to be 0 or 1.");
+  LLVM_DEBUG(dbgs() << "Shared operands are on positions: " << CommonPos0
+                    << " and " << CommonPos1 << '\n');
+
+  // Canonicalize both comparisons so that the common operand is the
+  // right-hand side: if it is the left-hand side, swap the condition code.
+  ISD::CondCode RefCond = cast<CondCodeSDNode>(*GetCmpIt(N0))->get();
+  ISD::CondCode AssistCode = cast<CondCodeSDNode>(*GetCmpIt(N1))->get();
+  if (CommonPos0 == 0)
+    RefCond = ISD::getSetCCSwappedOperands(RefCond);
+  if (CommonPos1 == 0)
+    AssistCode = ISD::getSetCCSwappedOperands(AssistCode);
+  LLVM_DEBUG(dbgs() << "Reference condition is: " << RefCond << '\n');
+  // Condition 3.3. After canonicalization both comparisons have to use the
+  // same condition code, e.g. a < c; c < b is rejected (c < b is b > c).
+  if (RefCond != AssistCode)
+    return None;
+
+  // The non-shared operands are at the other position (CommonPos ^ 1).
+  SDValue A = N0->getOperand(CommonPos0 ^ 1);
+  SDValue B = N1->getOperand(CommonPos1 ^ 1);
+  SDValue C = *CommonOpIt0;
+  // Condition 4. Compared operands have to be of the same type.
+  if (A.getValueType() != B.getValueType() ||
+      A.getValueType() != C.getValueType())
+    return None;
+
+  LLVM_DEBUG(
+      dbgs() << "The conditions for combining comparisons are satisfied.\n");
+  return std::make_tuple(RefCond, A, B, C);
+}
+
+// Returns the min/max opcode that lets
+//   (logic (setcc A, C, CC), (setcc B, C, CC))
+// be rewritten as (setcc (minmax A, B), C, CC), or None if CC is not an
+// ordered integer comparison (e.g. SETEQ/SETNE):
+//   (A < C) | (B < C) == min(A, B) < C    (A < C) & (B < C) == max(A, B) < C
+//   (A > C) | (B > C) == max(A, B) > C    (A > C) & (B > C) == min(A, B) > C
+// The same identities hold for the non-strict and the unsigned predicates.
+static Optional<unsigned> getMinMaxOpcode(ISD::CondCode CC, bool IsAnd) {
+  bool IsSigned;
+  bool IsGreater;
+  switch (CC) {
+  case ISD::SETLT:
+  case ISD::SETLE:
+    IsSigned = true;
+    IsGreater = false;
+    break;
+  case ISD::SETGT:
+  case ISD::SETGE:
+    IsSigned = true;
+    IsGreater = true;
+    break;
+  case ISD::SETULT:
+  case ISD::SETULE:
+    IsSigned = false;
+    IsGreater = false;
+    break;
+  case ISD::SETUGT:
+  case ISD::SETUGE:
+    IsSigned = false;
+    IsGreater = true;
+    break;
+  default:
+    return None;
+  }
+  // AND requires both values to satisfy the predicate, so the extreme value
+  // that is hardest to satisfy is tested (max for "less", min for "greater");
+  // OR requires only one, so the easiest one is tested.
+  const bool UseMax = IsAnd != IsGreater;
+  if (IsSigned)
+    return UseMax ? ISD::SMAX : ISD::SMIN;
+  return UseMax ? ISD::UMAX : ISD::UMIN;
+}
+
+// Combines two comparison operations and a logic operation into one
+// selection operation (min/max) and one comparison:
+//   (and/or (setcc A, C, CC), (setcc B, C, CC))
+//     -> (setcc (min/max A, B), C, CC)
+// Returns the new node if the conditions for the optimization are satisfied,
+// otherwise an empty SDValue.
+static SDValue combineCmpOp(SDNode *N, SelectionDAG &DAG,
+                            const RISCVSubtarget &Subtarget) {
+  if (!Subtarget.hasStdExtZbb())
+    return SDValue();
+
+  const unsigned BitOpcode = N->getOpcode();
+  assert((BitOpcode == ISD::AND || BitOpcode == ISD::OR) &&
+         "This optimization can be used only with AND/OR operations");
+
+  const auto Props = verifyCompareConds(N, DAG);
+  // If conditions are invalidated then do not perform an optimization.
+  if (!Props)
+    return SDValue();
+
+  const auto [RefCond, A, B, C] = *Props;
+  const Optional<unsigned> PickCode =
+      getMinMaxOpcode(RefCond, /*IsAnd=*/BitOpcode == ISD::AND);
+  if (!PickCode)
+    return SDValue();
+
+  // This combine is not told the DAG combine level, so it may also run after
+  // legalization: only create a min/max that is legal for the compared type.
+  const EVT CmpOpVT = A.getValueType();
+  if (!DAG.getTargetLoweringInfo().isOperationLegal(*PickCode, CmpOpVT))
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue Pick = DAG.getNode(*PickCode, DL, CmpOpVT, A, B);
+  return DAG.getSetCC(DL, N->getValueType(0), Pick, C, RefCond);
+}
+
 static SDValue performANDCombine(SDNode *N, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
   SDValue N0 = N->getOperand(0);
@@ -8311,6 +8479,9 @@
     return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, And);
   }
 
+  if (SDValue V = combineCmpOp(N, DAG, Subtarget))
+    return V;
+
   if (SDValue V = combineBinOpToReduce(N, DAG))
     return V;
 
@@ -8330,6
+8506,9 @@ return SHFL; } + if (SDValue V = combineCmpOp(N, DAG, Subtarget)) + return V; + if (SDValue V = combineBinOpToReduce(N, DAG)) return V; // fold (or (select cond, 0, y), x) -> diff --git a/llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll b/llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll --- a/llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll +++ b/llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll @@ -12,9 +12,8 @@ define i1 @ulo(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ulo: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a1, a0 -; CHECK-NEXT: sltu a0, a2, a0 -; CHECK-NEXT: or a0, a1, a0 +; CHECK-NEXT: minu a1, a1, a2 +; CHECK-NEXT: sltu a0, a1, a0 ; CHECK-NEXT: ret %l0 = icmp ult i64 %a, %c %l1 = icmp ult i64 %b, %c @@ -25,9 +24,8 @@ define i1 @ulo_swap1(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ulo_swap1: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a1, a0 -; CHECK-NEXT: sltu a0, a2, a0 -; CHECK-NEXT: or a0, a1, a0 +; CHECK-NEXT: minu a1, a1, a2 +; CHECK-NEXT: sltu a0, a1, a0 ; CHECK-NEXT: ret %l0 = icmp ugt i64 %c, %a %l1 = icmp ult i64 %b, %c @@ -38,9 +36,8 @@ define i1 @ulo_swap2(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ulo_swap2: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a1, a0 -; CHECK-NEXT: sltu a0, a2, a0 -; CHECK-NEXT: or a0, a1, a0 +; CHECK-NEXT: minu a1, a1, a2 +; CHECK-NEXT: sltu a0, a1, a0 ; CHECK-NEXT: ret %l0 = icmp ult i64 %a, %c %l1 = icmp ugt i64 %c, %b @@ -51,9 +48,8 @@ define i1 @ulo_swap12(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ulo_swap12: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a1, a0 -; CHECK-NEXT: sltu a0, a2, a0 -; CHECK-NEXT: or a0, a1, a0 +; CHECK-NEXT: minu a1, a1, a2 +; CHECK-NEXT: sltu a0, a1, a0 ; CHECK-NEXT: ret %l0 = icmp ugt i64 %c, %a %l1 = icmp ugt i64 %c, %b @@ -65,9 +61,8 @@ define i1 @ula(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ula: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a1, a0 -; CHECK-NEXT: sltu a0, a2, a0 -; CHECK-NEXT: and a0, a1, a0 +; CHECK-NEXT: maxu a1, a1, a2 +; CHECK-NEXT: sltu a0, a1, a0 ; CHECK-NEXT: ret %l0 = icmp ult i64 %a, %c
%l1 = icmp ult i64 %b, %c @@ -78,9 +73,8 @@ define i1 @ula_swap1(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ula_swap1: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a1, a0 -; CHECK-NEXT: sltu a0, a2, a0 -; CHECK-NEXT: and a0, a1, a0 +; CHECK-NEXT: maxu a1, a1, a2 +; CHECK-NEXT: sltu a0, a1, a0 ; CHECK-NEXT: ret %l0 = icmp ugt i64 %c, %a %l1 = icmp ult i64 %b, %c @@ -91,9 +85,8 @@ define i1 @ula_swap2(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ula_swap2: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a1, a0 -; CHECK-NEXT: sltu a0, a2, a0 -; CHECK-NEXT: and a0, a1, a0 +; CHECK-NEXT: maxu a1, a1, a2 +; CHECK-NEXT: sltu a0, a1, a0 ; CHECK-NEXT: ret %l0 = icmp ult i64 %a, %c %l1 = icmp ugt i64 %c, %b @@ -104,9 +97,8 @@ define i1 @ula_swap12(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ula_swap12: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a1, a0 -; CHECK-NEXT: sltu a0, a2, a0 -; CHECK-NEXT: and a0, a1, a0 +; CHECK-NEXT: maxu a1, a1, a2 +; CHECK-NEXT: sltu a0, a1, a0 ; CHECK-NEXT: ret %l0 = icmp ugt i64 %c, %a %l1 = icmp ugt i64 %c, %b @@ -119,9 +111,8 @@ define i1 @ugo(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ugo: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a0, a1 -; CHECK-NEXT: sltu a0, a0, a2 -; CHECK-NEXT: or a0, a1, a0 +; CHECK-NEXT: maxu a1, a1, a2 +; CHECK-NEXT: sltu a0, a0, a1 ; CHECK-NEXT: ret %l0 = icmp ugt i64 %a, %c %l1 = icmp ugt i64 %b, %c @@ -132,9 +123,8 @@ define i1 @ugo_swap1(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ugo_swap1: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a0, a1 -; CHECK-NEXT: sltu a0, a0, a2 -; CHECK-NEXT: or a0, a1, a0 +; CHECK-NEXT: maxu a1, a1, a2 +; CHECK-NEXT: sltu a0, a0, a1 ; CHECK-NEXT: ret %l0 = icmp ult i64 %c, %a %l1 = icmp ugt i64 %b, %c @@ -145,9 +135,8 @@ define i1 @ugo_swap2(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ugo_swap2: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a0, a1 -; CHECK-NEXT: sltu a0, a0, a2 -; CHECK-NEXT: or a0, a1, a0 +; CHECK-NEXT: maxu a1, a1, a2 +; CHECK-NEXT: sltu a0, a0, a1 ; CHECK-NEXT: ret %l0 = icmp ugt i64 %a, %c %l1 =
icmp ult i64 %c, %b @@ -158,9 +147,8 @@ define i1 @ugo_swap12(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ugo_swap12: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a0, a1 -; CHECK-NEXT: sltu a0, a0, a2 -; CHECK-NEXT: or a0, a1, a0 +; CHECK-NEXT: maxu a1, a1, a2 +; CHECK-NEXT: sltu a0, a0, a1 ; CHECK-NEXT: ret %l0 = icmp ult i64 %c, %a %l1 = icmp ult i64 %c, %b @@ -173,11 +161,9 @@ define i1 @ugea(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ugea: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a1, a0 -; CHECK-NEXT: xori a1, a1, 1 -; CHECK-NEXT: sltu a0, a2, a0 +; CHECK-NEXT: minu a1, a1, a2 +; CHECK-NEXT: sltu a0, a1, a0 ; CHECK-NEXT: xori a0, a0, 1 -; CHECK-NEXT: and a0, a1, a0 ; CHECK-NEXT: ret %l0 = icmp uge i64 %a, %c %l1 = icmp uge i64 %b, %c @@ -190,9 +176,8 @@ define i1 @uga(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: uga: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a0, a1 -; CHECK-NEXT: sltu a0, a0, a2 -; CHECK-NEXT: and a0, a1, a0 +; CHECK-NEXT: minu a1, a1, a2 +; CHECK-NEXT: sltu a0, a0, a1 ; CHECK-NEXT: ret %l0 = icmp ugt i64 %a, %c %l1 = icmp ugt i64 %b, %c @@ -205,9 +190,8 @@ define i1 @sla(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: sla: ; CHECK: # %bb.0: -; CHECK-NEXT: slt a1, a1, a0 -; CHECK-NEXT: slt a0, a2, a0 -; CHECK-NEXT: and a0, a1, a0 +; CHECK-NEXT: max a1, a1, a2 +; CHECK-NEXT: slt a0, a1, a0 ; CHECK-NEXT: ret %l0 = icmp slt i64 %a, %c %l1 = icmp slt i64 %b, %c @@ -215,6 +199,7 @@ ret i1 %res } +; Negative test ; Float check. define i1 @flo(float %c, float %a, float %b) { ; CHECK-RV64I-LABEL: flo: @@ -261,6 +246,7 @@ ret i1 %res } +; Negative test ; Double check. define i1 @dlo(double %c, double %a, double %b) { ; CHECK-LABEL: dlo: @@ -298,6 +284,7 @@ ret i1 %res } +; Negative test ; More than one user define i1 @multi_user(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: multi_user: @@ -315,6 +302,7 @@ ret i1 %out } -; No same comparations +; Negative test +; No same comparisons define i1 @no_same_ops(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: no_same_ops: