diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8238,6 +8238,174 @@
   return DAG.getNode(ISD::XOR, DL, VT, Logic, DAG.getConstant(1, DL, VT));
 }
 
+// Verifies the conditions for folding
+//   (and/or (setcc A, C, cc), (setcc B, C, cc))
+// into one comparison of min/max(A, B) against C.
+// Returns the reference condition code and the operands A, B, C, or None if
+// the conditions are not satisfied.
+// Conditions for the optimization:
+//   * Both operands of the logic operation are single-use SETCC nodes that
+//     compare scalar integers.
+//   * The two comparisons share one operand, which is returned as C. The two
+//     remaining operands are returned as A and B.
+//   * Both comparisons use the same predicate once the common operand C is
+//     moved to the right-hand side, e.g. A < C; C > B are similar,
+//     but A < C; B > C are not.
+//   * That predicate is a signed or unsigned integer ordering
+//     (<, <=, >, >=).
+// The reference condition code is the predicate with the common operand on
+// the right-hand side, e.g. C > A is swapped to A < C.
+static Optional<std::tuple<ISD::CondCode, SDValue, SDValue, SDValue>>
+verifyCompareConds(SDNode *N, SelectionDAG &DAG) {
+  LLVM_DEBUG(
+      dbgs() << "Checking conditions for comparison operation combining.\n";);
+
+  SDValue V0 = N->getOperand(0);
+  SDValue V1 = N->getOperand(1);
+  assert(V0.getValueType() == V1.getValueType() &&
+         "Operations must have the same value type.");
+
+  // Condition 1. Operations have to be used only in logic operation.
+  if (!V0.hasOneUse() || !V1.hasOneUse())
+    return None;
+
+  // Condition 2. Operands have to be comparison operations.
+  if (V0->getOpcode() != ISD::SETCC || V1->getOpcode() != ISD::SETCC)
+    return None;
+
+  // SETCC has three operands: op0, op1, cond. Only the first two are searched
+  // for the common operand; GetCmpIt returns an iterator to the condition.
+  auto GetCmpIt = [](SDValue V) { return V->op_begin() + 2; };
+  // Finds position of the common operand in the first comparison operation.
+  auto CommonOpIt0 = std::find_first_of(V0->op_begin(), GetCmpIt(V0),
+                                        V1->op_begin(), GetCmpIt(V1));
+
+  // Condition 3.1. Common operand has to be in comparison.
+  if (CommonOpIt0 == GetCmpIt(V0))
+    return None;
+
+  // Condition 3.2. Operations only with scalar integers. The combine is gated
+  // on Zbb, which only provides scalar min/max instructions.
+  if (!CommonOpIt0->getValueType().isScalarInteger())
+    return None;
+
+  // Finds common operand for the second comparison operation.
+  auto CommonOpIt1 = std::find(V1->op_begin(), GetCmpIt(V1), *CommonOpIt0);
+  assert(CommonOpIt1 != GetCmpIt(V1) &&
+         "Cannot find common operand in the second comparison operation.");
+
+  // pair.first is the position of the common operand for the first
+  // operation(V0).
+  // pair.second is the position of the common operand for the
+  // second operation(V1).
+  auto CommonPos =
+      std::make_pair(std::distance(V0->op_begin(), CommonOpIt0),
+                     std::distance(V1->op_begin(), CommonOpIt1));
+  assert((CommonPos.first == 0 || CommonPos.first == 1) &&
+         "Position of the operand has to be 0 or 1.");
+  assert((CommonPos.second == 0 || CommonPos.second == 1) &&
+         "Position of the operand has to be 0 or 1.");
+
+  // Since a position can only be 0 or 1, flipping its low bit gives the
+  // position of the other (non-common) operand.
+  auto DifferPos =
+      std::make_pair(CommonPos.first ^ 1, CommonPos.second ^ 1);
+
+  LLVM_DEBUG(dbgs() << "Shared operands are on positions: " << CommonPos.first
+                    << " and " << CommonPos.second << '\n';);
+  // If common operand at the first position then swap operation to convert to
+  // strict pattern. Common operand has to be right hand side.
+  ISD::CondCode RefCond = cast<CondCodeSDNode>(V0->getOperand(2))->get();
+  ISD::CondCode AssistCode = cast<CondCodeSDNode>(V1->getOperand(2))->get();
+  if (!CommonPos.first)
+    RefCond = ISD::getSetCCSwappedOperands(RefCond);
+  if (!CommonPos.second)
+    AssistCode = ISD::getSetCCSwappedOperands(AssistCode);
+  LLVM_DEBUG(dbgs() << "Reference condition is: " << RefCond << '\n';);
+  // If there are different comparison operations then do not perform an
+  // optimization. a < c; c < b -> will be changed to b > c.
+  if (RefCond != AssistCode)
+    return None;
+
+  // Condition 4. Only signed or unsigned integer orderings (<, <=, >, >=) map
+  // onto min/max. Equality predicates and the floating-point-only ordered
+  // predicates are rejected, since combineCmpOp has no signed or unsigned
+  // min/max to select for them.
+  if (!ISD::isSignedIntSetCC(RefCond) && !ISD::isUnsignedIntSetCC(RefCond))
+    return None;
+
+  SDValue A = V0->getOperand(DifferPos.first);
+  SDValue B = V1->getOperand(DifferPos.second);
+  SDValue C = *CommonOpIt0;
+
+  LLVM_DEBUG(
+      dbgs() << "The conditions for combining comparisons are satisfied.\n";);
+  return std::make_tuple(RefCond, A, B, C);
+}
+
+// Signedness of a combined comparison. Signed and Unsigned are used as row
+// indices in getSelectionCode; None only marks "not yet classified".
+enum class OperationType : unsigned { Signed = 0, Unsigned, None };
+// Returns the min/max opcode that replaces the two comparisons: for a "less"
+// predicate, OR needs the minimum and AND the maximum of the compared
+// operands; for a "greater" predicate it is the other way round.
+static ISD::NodeType getSelectionCode(OperationType OpT, bool IsAnd,
+                                      bool IsGreaterOp) {
+  // Codes of selection operation. The first index selects signed or unsigned,
+  // the second index selects MIN/MAX.
+  static constexpr ISD::NodeType SelectionCodes[2][2] = {
+      {ISD::SMIN, ISD::SMAX}, {ISD::UMIN, ISD::UMAX}};
+  const bool ChooseSelCode = IsAnd ^ IsGreaterOp;
+  // Boolean values are converted to 0 and 1 and used as indices in the matrix.
+  const unsigned TyIdx = static_cast<unsigned>(OpT);
+  assert(TyIdx < 2 && "Access out of boundaries");
+  return SelectionCodes[TyIdx][ChooseSelCode];
+}
+
+// Combines two comparison operations and a logic operation into one selection
+// operation (min, max) and one comparison. Returns the newly constructed node
+// if the conditions for the optimization are satisfied, otherwise SDValue().
+static SDValue combineCmpOp(SDNode *N, SelectionDAG &DAG,
+                            const RISCVSubtarget &Subtarget) {
+  if (!Subtarget.hasStdExtZbb())
+    return SDValue();
+
+  const auto BitOpcode = N->getOpcode();
+  assert((BitOpcode == ISD::AND || BitOpcode == ISD::OR) &&
+         "This optimization can be used only with AND/OR operations");
+
+  const auto Props = verifyCompareConds(N, DAG);
+  // If conditions are invalidated then do not perform an optimization.
+  if (!Props)
+    return SDValue();
+
+  const auto [RefOpcode, A, B, C] = Props.value();
+  const auto CmpOpVT = A.getValueType();
+
+  // ISD::CondCode encodes its predicate bits as N U L G E (bit 4 .. bit 0);
+  // the G bit is set for the "greater" orderings (>, >=).
+  const bool IsGreaterOp = RefOpcode & 0b10;
+  OperationType TypeOfCmp = OperationType::None;
+  if (ISD::isUnsignedIntSetCC(RefOpcode))
+    TypeOfCmp = OperationType::Unsigned;
+  else if (ISD::isSignedIntSetCC(RefOpcode))
+    TypeOfCmp = OperationType::Signed;
+  // Guaranteed by verifyCompareConds, which rejects every other predicate.
+  assert(TypeOfCmp != OperationType::None &&
+         "Operation neither with signed or unsigned integers.");
+
+  const bool IsAnd = BitOpcode == ISD::AND;
+  const ISD::NodeType PickCode =
+      getSelectionCode(TypeOfCmp, IsAnd, IsGreaterOp);
+
+  SDLoc DL(N);
+  SDValue Pick = DAG.getNode(PickCode, DL, CmpOpVT, A, B);
+  SDValue Cmp =
+      DAG.getSetCC(DL, N->getOperand(0).getValueType(), Pick, C, RefOpcode);
+
+  return Cmp;
+}
+
 static SDValue performANDCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const RISCVSubtarget &Subtarget) {
@@ -8262,6 +8430,9 @@
     return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, And);
   }
 
+  if (SDValue V = combineCmpOp(N, DAG, Subtarget))
+    return V;
+
   if (SDValue V = combineBinOpToReduce(N, DAG))
     return V;
 
@@ -8278,6 +8449,9 @@
                                 const RISCVSubtarget &Subtarget) {
   SelectionDAG &DAG = DCI.DAG;
 
+  if (SDValue V = combineCmpOp(N, DAG, Subtarget))
+    return V;
+
   if (SDValue V = combineBinOpToReduce(N, DAG))
     return V;
 
diff --git a/llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll b/llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll
--- a/llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll
+++ b/llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll
@@ -12,9 +12,8 @@ define i1
@ulo(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ulo: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a1, a0 -; CHECK-NEXT: sltu a0, a2, a0 -; CHECK-NEXT: or a0, a1, a0 +; CHECK-NEXT: minu a1, a1, a2 +; CHECK-NEXT: sltu a0, a1, a0 ; CHECK-NEXT: ret %l0 = icmp ult i64 %a, %c %l1 = icmp ult i64 %b, %c @@ -25,9 +24,8 @@ define i1 @ulo_swap1(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ulo_swap1: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a1, a0 -; CHECK-NEXT: sltu a0, a2, a0 -; CHECK-NEXT: or a0, a1, a0 +; CHECK-NEXT: minu a1, a1, a2 +; CHECK-NEXT: sltu a0, a1, a0 ; CHECK-NEXT: ret %l0 = icmp ugt i64 %c, %a %l1 = icmp ult i64 %b, %c @@ -38,9 +36,8 @@ define i1 @ulo_swap2(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ulo_swap2: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a1, a0 -; CHECK-NEXT: sltu a0, a2, a0 -; CHECK-NEXT: or a0, a1, a0 +; CHECK-NEXT: minu a1, a1, a2 +; CHECK-NEXT: sltu a0, a1, a0 ; CHECK-NEXT: ret %l0 = icmp ult i64 %a, %c %l1 = icmp ugt i64 %c, %b @@ -51,9 +48,8 @@ define i1 @ulo_swap12(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ulo_swap12: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a1, a0 -; CHECK-NEXT: sltu a0, a2, a0 -; CHECK-NEXT: or a0, a1, a0 +; CHECK-NEXT: minu a1, a1, a2 +; CHECK-NEXT: sltu a0, a1, a0 ; CHECK-NEXT: ret %l0 = icmp ugt i64 %c, %a %l1 = icmp ugt i64 %c, %b @@ -65,9 +61,8 @@ define i1 @ula(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ula: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a1, a0 -; CHECK-NEXT: sltu a0, a2, a0 -; CHECK-NEXT: and a0, a1, a0 +; CHECK-NEXT: maxu a1, a1, a2 +; CHECK-NEXT: sltu a0, a1, a0 ; CHECK-NEXT: ret %l0 = icmp ult i64 %a, %c %l1 = icmp ult i64 %b, %c @@ -78,9 +73,8 @@ define i1 @ula_swap1(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ula_swap1: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a1, a0 -; CHECK-NEXT: sltu a0, a2, a0 -; CHECK-NEXT: and a0, a1, a0 +; CHECK-NEXT: maxu a1, a1, a2 +; CHECK-NEXT: sltu a0, a1, a0 ; CHECK-NEXT: ret %l0 = icmp ugt i64 %c, %a %l1 = icmp ult i64 %b, %c @@ -91,9 +85,8 @@ define i1 @ula_swap2(i64 %c, i64 %a,
i64 %b) { ; CHECK-LABEL: ula_swap2: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a1, a0 -; CHECK-NEXT: sltu a0, a2, a0 -; CHECK-NEXT: and a0, a1, a0 +; CHECK-NEXT: maxu a1, a1, a2 +; CHECK-NEXT: sltu a0, a1, a0 ; CHECK-NEXT: ret %l0 = icmp ult i64 %a, %c %l1 = icmp ugt i64 %c, %b @@ -104,9 +97,8 @@ define i1 @ula_swap12(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ula_swap12: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a1, a0 -; CHECK-NEXT: sltu a0, a2, a0 -; CHECK-NEXT: and a0, a1, a0 +; CHECK-NEXT: maxu a1, a1, a2 +; CHECK-NEXT: sltu a0, a1, a0 ; CHECK-NEXT: ret %l0 = icmp ugt i64 %c, %a %l1 = icmp ugt i64 %c, %b @@ -119,9 +111,8 @@ define i1 @ugo(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ugo: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a0, a1 -; CHECK-NEXT: sltu a0, a0, a2 -; CHECK-NEXT: or a0, a1, a0 +; CHECK-NEXT: maxu a1, a1, a2 +; CHECK-NEXT: sltu a0, a0, a1 ; CHECK-NEXT: ret %l0 = icmp ugt i64 %a, %c %l1 = icmp ugt i64 %b, %c @@ -132,9 +123,8 @@ define i1 @ugo_swap1(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ugo_swap1: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a0, a1 -; CHECK-NEXT: sltu a0, a0, a2 -; CHECK-NEXT: or a0, a1, a0 +; CHECK-NEXT: maxu a1, a1, a2 +; CHECK-NEXT: sltu a0, a0, a1 ; CHECK-NEXT: ret %l0 = icmp ult i64 %c, %a %l1 = icmp ugt i64 %b, %c @@ -145,9 +135,8 @@ define i1 @ugo_swap2(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ugo_swap2: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a0, a1 -; CHECK-NEXT: sltu a0, a0, a2 -; CHECK-NEXT: or a0, a1, a0 +; CHECK-NEXT: maxu a1, a1, a2 +; CHECK-NEXT: sltu a0, a0, a1 ; CHECK-NEXT: ret %l0 = icmp ugt i64 %a, %c %l1 = icmp ult i64 %c, %b @@ -158,9 +147,8 @@ define i1 @ugo_swap12(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: ugo_swap12: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a0, a1 -; CHECK-NEXT: sltu a0, a0, a2 -; CHECK-NEXT: or a0, a1, a0 +; CHECK-NEXT: maxu a1, a1, a2 +; CHECK-NEXT: sltu a0, a0, a1 ; CHECK-NEXT: ret %l0 = icmp ult i64 %c, %a %l1 = icmp ult i64 %c, %b @@ -173,9 +161,8 @@ define i1 @ugea(i64 %c, i64 %a, i64 %b)
{ ; CHECK-LABEL: ugea: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a1, a0 -; CHECK-NEXT: sltu a0, a2, a0 -; CHECK-NEXT: or a0, a1, a0 +; CHECK-NEXT: minu a1, a1, a2 +; CHECK-NEXT: sltu a0, a1, a0 ; CHECK-NEXT: xori a0, a0, 1 ; CHECK-NEXT: ret %l0 = icmp uge i64 %a, %c @@ -189,9 +176,8 @@ define i1 @uga(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: uga: ; CHECK: # %bb.0: -; CHECK-NEXT: sltu a1, a0, a1 -; CHECK-NEXT: sltu a0, a0, a2 -; CHECK-NEXT: and a0, a1, a0 +; CHECK-NEXT: minu a1, a1, a2 +; CHECK-NEXT: sltu a0, a0, a1 ; CHECK-NEXT: ret %l0 = icmp ugt i64 %a, %c %l1 = icmp ugt i64 %b, %c @@ -204,9 +190,8 @@ define i1 @sla(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: sla: ; CHECK: # %bb.0: -; CHECK-NEXT: slt a1, a1, a0 -; CHECK-NEXT: slt a0, a2, a0 -; CHECK-NEXT: and a0, a1, a0 +; CHECK-NEXT: max a1, a1, a2 +; CHECK-NEXT: slt a0, a1, a0 ; CHECK-NEXT: ret %l0 = icmp slt i64 %a, %c %l1 = icmp slt i64 %b, %c @@ -214,6 +199,7 @@ ret i1 %res } +; Negative test: floating-point compares must not be combined ; Float check. define i1 @flo(float %c, float %a, float %b) { ; CHECK-RV64I-LABEL: flo: @@ -259,6 +245,7 @@ ret i1 %res } +; Negative test: floating-point compares must not be combined ; Double check. define i1 @dlo(double %c, double %a, double %b) { ; CHECK-LABEL: dlo: @@ -296,6 +283,7 @@ ret i1 %res } +; Negative test: a compare with more than one user must not be combined ; More than one user define i1 @multi_user(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: multi_user: @@ -313,6 +301,7 @@ ret i1 %out } +; Negative test: the min/max combine must not fire ; No same comparations define i1 @no_same_ops(i64 %c, i64 %a, i64 %b) { ; CHECK-LABEL: no_same_ops: