diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8239,6 +8239,234 @@
   return DAG.getNode(ISD::XOR, DL, VT, Logic, DAG.getConstant(1, DL, VT));
 }
 
+// Helper class that holds information about a comparison (SETCC) operation.
+// The first two operands of the operation are the compared values and the
+// last one is the condition code.
+// Compared values are stored in Ops.
+// The condition code is stored in CCode.
+class CmpOpInfo {
+  static unsigned constexpr Size = 2u;
+
+  // Type for storing the operands of the compare operation.
+  using OpsArray = std::array<SDValue, Size>;
+  OpsArray Ops;
+
+  using const_iterator = OpsArray::const_iterator;
+  const_iterator begin() const { return Ops.begin(); }
+  const_iterator end() const { return Ops.end(); }
+
+  ISD::CondCode CCode;
+
+  unsigned CommonPos{Size};
+  unsigned DifferPos{Size};
+
+  // Sets CommonPos and DifferPos based on the incoming position
+  // of the common operand CPos.
+  void setPositions(const_iterator CPos) {
+    assert(CPos != Ops.end() && "Common operand has to be in OpsArray.\n");
+    CommonPos = CPos == Ops.begin() ? 0 : 1;
+    DifferPos = 1 - CommonPos;
+    assert((DifferPos == 0 || DifferPos == 1) &&
+           "Positions can be only 0 or 1.");
+  }
+
+  // Private constructor of comparison info based on comparison operator.
+  // It is private because a CmpOpInfo is only meaningful relative to another
+  // comparison operator. Therefore, infos about comparison operations
+  // have to be collected simultaneously via CmpOpInfo::getInfoAbout().
+  CmpOpInfo(const SDValue &CmpOp)
+      : Ops{CmpOp->getOperand(0), CmpOp->getOperand(1)},
+        CCode{cast<CondCodeSDNode>(CmpOp->getOperand(2))->get()} {}
+
+  // Finds the common operand of Op1 and Op2 and finishes filling CmpOpInfos.
+  // Returns true if a common operand is found, false otherwise.
+  static bool establishCorrespondence(CmpOpInfo &Op1, CmpOpInfo &Op2) {
+    const auto CommonOpIt1 =
+        std::find_first_of(Op1.begin(), Op1.end(), Op2.begin(), Op2.end());
+    if (CommonOpIt1 == Op1.end())
+      return false;
+
+    const auto CommonOpIt2 = std::find(Op2.begin(), Op2.end(), *CommonOpIt1);
+    assert(CommonOpIt2 != Op2.end() &&
+           "Cannot find common operand in the second comparison operation.");
+
+    Op1.setPositions(CommonOpIt1);
+    Op2.setPositions(CommonOpIt2);
+
+    return true;
+  }
+
+public:
+  CmpOpInfo(const CmpOpInfo &) = default;
+  CmpOpInfo(CmpOpInfo &&) = default;
+
+  SDValue const &operator[](unsigned Pos) const {
+    assert(Pos < Size && "Out of range\n");
+    return Ops.at(Pos);
+  }
+
+  // Creates infos about the comparison operations CmpOp0 and CmpOp1.
+  // If there is no common operand returns None. Otherwise, returns
+  // correspondence info about the comparison operations.
+  static llvm::Optional<std::pair<CmpOpInfo, CmpOpInfo>>
+  getInfoAbout(SDValue const &CmpOp0, SDValue const &CmpOp1) {
+    CmpOpInfo Op0{CmpOp0};
+    CmpOpInfo Op1{CmpOp1};
+    if (!establishCorrespondence(Op0, Op1))
+      return None;
+    return std::make_pair(Op0, Op1);
+  }
+
+  // Returns the position of the common operand.
+  unsigned getCPos() const { return CommonPos; }
+
+  // Returns the position of the differing operand.
+  unsigned getDPos() const { return DifferPos; }
+
+  // Returns the common operand.
+  SDValue const &getCOp() const { return operator[](CommonPos); }
+
+  // Returns the differing operand.
+  SDValue const &getDOp() const { return operator[](DifferPos); }
+
+  // Returns the condition code of the comparison operation.
+  ISD::CondCode getCondCode() const { return CCode; }
+};
+
+// Verifies the conditions required to apply the optimization.
+// Returns the reference comparison code and operands A, B, C, or None.
+// Conditions for optimization:
+//   One operand of the comparisons has to be common.
+//   This operand is written to C.
+//   The two remaining operands are written to A and B.
+//   Comparisons have to be similar with respect to common operand C.
+//     e.g. A < C; C > B are similar
+//      but A < C; B > C are not.
+//   Reference comparison code is the comparison code if
+//   common operand is right placed.
+//     e.g. C > A will be swapped to A < C.
+static Optional<std::tuple<ISD::CondCode, SDValue, SDValue, SDValue>>
+verifyCompareConds(SDNode *N, SelectionDAG &DAG) {
+  LLVM_DEBUG(
+      dbgs() << "Checking conditions for comparison operation combining.\n";);
+
+  SDValue V0 = N->getOperand(0);
+  SDValue V1 = N->getOperand(1);
+  assert(V0.getValueType() == V1.getValueType() &&
+         "Operations must have the same value type.");
+
+  // Condition 1. Operations have to be used only in the logic operation.
+  if (!V0.hasOneUse() || !V1.hasOneUse())
+    return None;
+
+  // Condition 2. Operands have to be comparison operations.
+  if (V0->getOpcode() != ISD::SETCC || V1->getOpcode() != ISD::SETCC)
+    return None;
+
+  // Condition 3.1. Operations only with integers.
+  if (!V0->getOperand(0).getValueType().isInteger())
+    return None;
+
+  auto ComparisonInfo = CmpOpInfo::getInfoAbout(V0, V1);
+  // Condition 3.2. A common operand has to be present in both comparisons.
+  if (!ComparisonInfo)
+    return None;
+
+  auto &&[Op0, Op1] = std::move(ComparisonInfo.value());
+
+  LLVM_DEBUG(dbgs() << "Shared operands are on positions: " << Op0.getCPos()
+                    << " and " << Op1.getCPos() << '\n';);
+  // If the common operand is at the first position then swap the condition
+  // to get the strict pattern: the common operand is the right hand side.
+  ISD::CondCode RefCond = Op0.getCondCode();
+  ISD::CondCode AssistCode = Op1.getCondCode();
+  if (!Op0.getCPos())
+    RefCond = ISD::getSetCCSwappedOperands(RefCond);
+  if (!Op1.getCPos())
+    AssistCode = ISD::getSetCCSwappedOperands(AssistCode);
+  LLVM_DEBUG(dbgs() << "Reference condition is: " << RefCond << '\n';);
+  // If the normalized comparison codes differ then do not perform the
+  // optimization, e.g. a < c; c < b: the latter is normalized to b > c.
+  if (RefCond != AssistCode)
+    return None;
+
+  // Conditions can be only similar to Less or Greater. (>, >=, <, <=)
+  // Applying this mask to the operation will determine Less and Greater
+  // operations.
+  const unsigned CmpMask = 0b110;
+  const unsigned MaskedOpcode = CmpMask & RefCond;
+  // If masking gave 0b110, then this is an operation NE, O or TRUE.
+  if (MaskedOpcode == CmpMask)
+    return None;
+  // If masking gave 0b000, then this is an operation EQ, UO or FALSE.
+  if (MaskedOpcode == 0)
+    return None;
+  // Everything else is similar to Less or Greater.
+
+  SDValue A = Op0.getDOp();
+  SDValue B = Op1.getDOp();
+  SDValue C = Op0.getCOp();
+
+  LLVM_DEBUG(
+      dbgs() << "The conditions for combining comparisons are satisfied.\n";);
+  return std::make_tuple(RefCond, A, B, C);
+}
+
+enum class OperationType : unsigned { Signed = 0, Unsigned, None };
+static ISD::NodeType getSelectionCode(OperationType OpT, bool IsAnd,
+                                      bool IsGreaterOp) {
+  // Codes of selection operation. The first index selects signed or unsigned,
+  // the second index selects MIN/MAX.
+  static constexpr ISD::NodeType SelectionCodes[2][2] = {
+      {ISD::SMIN, ISD::SMAX}, {ISD::UMIN, ISD::UMAX}};
+  const bool ChooseSelCode = IsAnd ^ IsGreaterOp;
+  // Boolean values are converted to 0 and 1 and used as indices in the matrix.
+  const unsigned TyIdx = static_cast<unsigned>(OpT);
+  assert(TyIdx < 2 && "Access out of boundaries");
+  return SelectionCodes[TyIdx][ChooseSelCode];
+}
+
+// Combines two comparison operations and a logic operation into one selection
+// operation (min, max) and one comparison. Returns the newly constructed node
+// if the conditions for the optimization are satisfied.
+static SDValue combineCmpOp(SDNode *N, SelectionDAG &DAG,
+                            const RISCVSubtarget &Subtarget) {
+  if (!Subtarget.hasStdExtZbb())
+    return SDValue();
+
+  const auto BitOpcode = N->getOpcode();
+  assert((BitOpcode == ISD::AND || BitOpcode == ISD::OR) &&
+         "This optimization can be used only with AND/OR operations");
+
+  const auto Props = verifyCompareConds(N, DAG);
+  // If the conditions are not satisfied then do not perform the optimization.
+  if (!Props)
+    return SDValue();
+
+  const auto [RefOpcode, A, B, C] = Props.value();
+  const auto CmpOpVT = A.getValueType();
+
+  const bool IsGreaterOp = RefOpcode & 0b10;
+  OperationType TypeOfCmp = OperationType::None;
+  if (ISD::isUnsignedIntSetCC(RefOpcode))
+    TypeOfCmp = OperationType::Unsigned;
+  else if (ISD::isSignedIntSetCC(RefOpcode))
+    TypeOfCmp = OperationType::Signed;
+  assert(TypeOfCmp != OperationType::None &&
+         "Operation neither with signed or unsigned integers.");
+
+  const bool IsAnd = BitOpcode == ISD::AND;
+  const ISD::NodeType PickCode =
+      getSelectionCode(TypeOfCmp, IsAnd, IsGreaterOp);
+
+  SDLoc DL(N);
+  SDValue Pick = DAG.getNode(PickCode, DL, CmpOpVT, A, B);
+  SDValue Cmp =
+      DAG.getSetCC(DL, N->getOperand(0).getValueType(), Pick, C, RefOpcode);
+
+  return Cmp;
+}
+
 static SDValue performANDCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const RISCVSubtarget &Subtarget) {
@@ -8263,6 +8491,9 @@
     return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, And);
   }
 
+  if (SDValue V = combineCmpOp(N, DAG, Subtarget))
+    return V;
+
   if (SDValue V = combineBinOpToReduce(N, DAG))
     return V;
 
@@ -8279,6 +8510,9 @@
                                 const RISCVSubtarget &Subtarget) {
   SelectionDAG &DAG = DCI.DAG;
 
+  if (SDValue V = combineCmpOp(N, DAG, Subtarget))
+    return V;
+
   if (SDValue V = combineBinOpToReduce(N, DAG))
     return V;
 
diff --git a/llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll b/llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll
--- a/llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll
+++ b/llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll
@@ -12,9 +12,8 @@
 define i1 @ulo(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ulo:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a1, a0
-; CHECK-NEXT:    sltu a0, a2, a0
-; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    minu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
 ; CHECK-NEXT:    ret
   %l0 = icmp ult i64 %a, %c
   %l1 = icmp ult i64 %b, %c
@@ -25,9 +24,8 @@
 define i1 @ulo_swap1(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ulo_swap1:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a1, a0
-; CHECK-NEXT:    sltu a0, a2, a0
-; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    minu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
 ; CHECK-NEXT:    ret
   %l0 = icmp ugt i64 %c, %a
   %l1 = icmp ult i64 %b, %c
@@ -38,9 +36,8 @@
 define i1 @ulo_swap2(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ulo_swap2:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a1, a0
-; CHECK-NEXT:    sltu a0, a2, a0
-; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    minu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
 ; CHECK-NEXT:    ret
   %l0 = icmp ult i64 %a, %c
   %l1 = icmp ugt i64 %c, %b
@@ -51,9 +48,8 @@
 define i1 @ulo_swap12(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ulo_swap12:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a1, a0
-; CHECK-NEXT:    sltu a0, a2, a0
-; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    minu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
 ; CHECK-NEXT:    ret
   %l0 = icmp ugt i64 %c, %a
   %l1 = icmp ugt i64 %c, %b
@@ -65,9 +61,8 @@
 define i1 @ula(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ula:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a1, a0
-; CHECK-NEXT:    sltu a0, a2, a0
-; CHECK-NEXT:    and a0, a1, a0
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
 ; CHECK-NEXT:    ret
   %l0 = icmp ult i64 %a, %c
   %l1 = icmp ult i64 %b, %c
@@ -78,9 +73,8 @@
 define i1 @ula_swap1(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ula_swap1:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a1, a0
-; CHECK-NEXT:    sltu a0, a2, a0
-; CHECK-NEXT:    and a0, a1, a0
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
 ; CHECK-NEXT:    ret
   %l0 = icmp ugt i64 %c, %a
   %l1 = icmp ult i64 %b, %c
@@ -91,9 +85,8 @@
 define i1 @ula_swap2(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ula_swap2:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a1, a0
-; CHECK-NEXT:    sltu a0, a2, a0
-; CHECK-NEXT:    and a0, a1, a0
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
 ; CHECK-NEXT:    ret
   %l0 = icmp ult i64 %a, %c
   %l1 = icmp ugt i64 %c, %b
@@ -104,9 +97,8 @@
 define i1 @ula_swap12(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ula_swap12:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a1, a0
-; CHECK-NEXT:    sltu a0, a2, a0
-; CHECK-NEXT:    and a0, a1, a0
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
 ; CHECK-NEXT:    ret
   %l0 = icmp ugt i64 %c, %a
   %l1 = icmp ugt i64 %c, %b
@@ -119,9 +111,8 @@
 define i1 @ugo(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ugo:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a0, a1
-; CHECK-NEXT:    sltu a0, a0, a2
-; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a0, a1
 ; CHECK-NEXT:    ret
   %l0 = icmp ugt i64 %a, %c
   %l1 = icmp ugt i64 %b, %c
@@ -132,9 +123,8 @@
 define i1 @ugo_swap1(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ugo_swap1:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a0, a1
-; CHECK-NEXT:    sltu a0, a0, a2
-; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a0, a1
 ; CHECK-NEXT:    ret
   %l0 = icmp ult i64 %c, %a
   %l1 = icmp ugt i64 %b, %c
@@ -145,9 +135,8 @@
 define i1 @ugo_swap2(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ugo_swap2:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a0, a1
-; CHECK-NEXT:    sltu a0, a0, a2
-; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a0, a1
 ; CHECK-NEXT:    ret
   %l0 = icmp ugt i64 %a, %c
   %l1 = icmp ult i64 %c, %b
@@ -158,9 +147,8 @@
 define i1 @ugo_swap12(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ugo_swap12:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a0, a1
-; CHECK-NEXT:    sltu a0, a0, a2
-; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a0, a1
 ; CHECK-NEXT:    ret
   %l0 = icmp ult i64 %c, %a
   %l1 = icmp ult i64 %c, %b
@@ -173,9 +161,8 @@
 define i1 @ugea(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ugea:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a1, a0
-; CHECK-NEXT:    sltu a0, a2, a0
-; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    minu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
 ; CHECK-NEXT:    xori a0, a0, 1
 ; CHECK-NEXT:    ret
   %l0 = icmp uge i64 %a, %c
@@ -189,9 +176,8 @@
 define i1 @uga(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: uga:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a0, a1
-; CHECK-NEXT:    sltu a0, a0, a2
-; CHECK-NEXT:    and a0, a1, a0
+; CHECK-NEXT:    minu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a0, a1
 ; CHECK-NEXT:    ret
   %l0 = icmp ugt i64 %a, %c
   %l1 = icmp ugt i64 %b, %c
@@ -204,9 +190,8 @@
 define i1 @sla(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: sla:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    slt a1, a1, a0
-; CHECK-NEXT:    slt a0, a2, a0
-; CHECK-NEXT:    and a0, a1, a0
+; CHECK-NEXT:    max a1, a1, a2
+; CHECK-NEXT:    slt a0, a1, a0
 ; CHECK-NEXT:    ret
   %l0 = icmp slt i64 %a, %c
   %l1 = icmp slt i64 %b, %c
@@ -214,6 +199,7 @@
   ret i1 %res
 }
 
+; Negative test
 ; Float check.
 define i1 @flo(float %c, float %a, float %b) {
 ; CHECK-RV64I-LABEL: flo:
@@ -259,6 +245,7 @@
   ret i1 %res
 }
 
+; Negative test
 ; Double check.
 define i1 @dlo(double %c, double %a, double %b) {
 ; CHECK-LABEL: dlo:
@@ -296,6 +283,7 @@
   ret i1 %res
 }
 
+; Negative test
 ; More than one user
 define i1 @multi_user(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: multi_user:
@@ -313,6 +301,7 @@
   ret i1 %out
 }
 
+; Negative test
 ; No same comparations
 define i1 @no_same_ops(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: no_same_ops: