diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -39,6 +39,8 @@
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 
+#include <unordered_set>
+
 using namespace llvm;
 
 #define DEBUG_TYPE "riscv-lower"
@@ -8290,6 +8292,192 @@
   return combineSelectAndUse(N, N1, N0, DAG, /*AllOnes*/ false);
 }
 
+// Verifies the conditions for applying the optimization and, on success,
+// returns the reference comparison code and the three operands A, B, C.
+// Conditions for the optimization:
+// One operand of the comparisons has to be common; it is written to C.
+// The other two operands are different; they are written to A and B.
+// The comparisons have to be similar with respect to the common operand C,
+// e.g. A < C; C > B are similar,
+// but A < C; B > C are not.
+// The reference comparison code is the comparison code after the common
+// operand has been moved to the right-hand side,
+// e.g. C > A is swapped to A < C.
+static Optional<std::tuple<ISD::CondCode, SDValue, SDValue, SDValue>>
+verifyCompareConds(SDNode *N, SelectionDAG &DAG) {
+  LLVM_DEBUG(
+      dbgs() << "Checking conditions for comparison operation combining.\n";);
+
+  SDValue V0 = N->getOperand(0);
+  SDValue V1 = N->getOperand(1);
+  // Condition 1. The operands have to have the same value type.
+  if (V0.getValueType() != V1.getValueType()) {
+    return None;
+  }
+  // Condition 2.1. The operands have to be used only by the logic operation.
+  if (!V0.hasOneUse() || !V1.hasOneUse()) {
+    return None;
+  }
+
+  SDNode *N0 = V0.getNode();
+  SDNode *N1 = V1.getNode();
+  // Condition 2.2. The operands have to be comparison operations.
+  if (N0->getOpcode() != ISD::SETCC || N1->getOpcode() != ISD::SETCC) {
+    return None;
+  }
+  // SETCC has three operands: op0, op1, cond; only the first two are
+  // searched. GetCmpIt returns an iterator to the condition operand.
+  auto GetCmpIt = [](SDNode *N) { return N->op_begin() + 2; };
+  // Find the position of the common operand in the first comparison.
+  auto CommonOpIt0 = std::find_first_of(N0->op_begin(), GetCmpIt(N0),
+                                        N1->op_begin(), GetCmpIt(N1));
+
+  // Condition 3. The comparisons have to share a common operand.
+  if (CommonOpIt0 == GetCmpIt(N0)) {
+    return None;
+  }
+  // Find the common operand in the second comparison.
+  auto CommonOpIt1 = std::find(N1->op_begin(), GetCmpIt(N1), *CommonOpIt0);
+  assert(CommonOpIt1 != GetCmpIt(N1) &&
+         "Cannot find common operand in the second comparison operation.");
+
+  // Cast difference_type to int; the positions feed the bit operations below.
+  // pair.first is the position of the common operand in the first
+  // comparison (N0); pair.second is its position in the second
+  // comparison (N1).
+  auto CommonPos = std::make_pair(
+      static_cast<int>(std::distance(N0->op_begin(), CommonOpIt0)),
+      static_cast<int>(std::distance(N1->op_begin(), CommonOpIt1)));
+  auto CheckZeroOne = [](std::pair<int, int> P) {
+    return (P.first == 0 || P.first == 1) && (P.second == 0 || P.second == 1);
+  };
+  assert(CheckZeroOne(CommonPos) &&
+         "Position of the operand has to be 0 or 1.");
+
+  // Since a position can only be 0 or 1, invert the positions of the common
+  // operand to get the positions of the other two operands.
+  auto DifferPos = std::make_pair(CommonPos.first ^ 1, CommonPos.second ^ 1);
+
+  LLVM_DEBUG(dbgs() << "Shared operands are on positions: " << CommonPos.first
+                    << " and " << CommonPos.second << '\n';);
+  // If the common operand is at the first position, swap the comparison code
+  // to get the canonical pattern: the common operand has to be on the
+  // right-hand side.
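+  // e.g. for (setcc C, A, setugt) the reference code becomes SETULT, the
+  // code of the canonical form (setcc A, C, setult); the *_swap* tests in
+  // zbb-cmp-combine.ll exercise these cases.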
+  ISD::CondCode RefCond = cast<CondCodeSDNode>(*GetCmpIt(N0))->get();
+  ISD::CondCode AssistCode = cast<CondCodeSDNode>(*GetCmpIt(N1))->get();
+  if (!CommonPos.first)
+    RefCond = ISD::getSetCCSwappedOperands(RefCond);
+  if (!CommonPos.second)
+    AssistCode = ISD::getSetCCSwappedOperands(AssistCode);
+  LLVM_DEBUG(dbgs() << "Reference condition is: " << RefCond << '\n';);
+  // If the two comparison codes do not match, do not perform the
+  // optimization, e.g. for a < c; c < b the second comparison is
+  // canonicalized to b > c, which does not match a < c.
+  if (RefCond != AssistCode) {
+    return None;
+  }
+
+  SDValue A = N0->getOperand(DifferPos.first);
+  SDValue B = N1->getOperand(DifferPos.second);
+  SDValue C = *CommonOpIt0;
+  // Condition 4. The compared operands have to have the same type.
+  if (A.getValueType() != B.getValueType() ||
+      A.getValueType() != C.getValueType()) {
+    return None;
+  }
+
+  LLVM_DEBUG(
+      dbgs() << "The conditions for comparison combining are satisfied.\n";);
+  return std::make_tuple(RefCond, A, B, C);
+}
+
+enum class OperationType : unsigned { Signed = 0, Unsigned, Float };
+static ISD::NodeType getSelectionCode(OperationType OpT, bool IsAnd,
+                                      bool IsGreaterOp) {
+  // Codes of the selection operation. The first index selects signed,
+  // unsigned or float; the second index selects MIN/MAX.
+  static constexpr ISD::NodeType SelectionCodes[3][2] = {
+      {ISD::SMIN, ISD::SMAX},
+      {ISD::UMIN, ISD::UMAX},
+      {ISD::FMINNUM, ISD::FMAXNUM}};
+  const bool ChooseSelCode = IsAnd ^ IsGreaterOp;
+  // The boolean values are converted to 0 and 1 and used as indices into the
+  // matrix.
+  const unsigned TyIdx = static_cast<unsigned>(OpT);
+  assert(TyIdx < 3 && "Index out of bounds");
+  return SelectionCodes[TyIdx][ChooseSelCode];
+}
+
+// Checks the extensions necessary to perform the combineCmpOp optimization.
+// Returns false if some extension is missing.
+static bool checkNecessaryExtensions(SDNode *N,
+                                     const RISCVSubtarget &Subtarget) {
+  // This optimization is only suitable with the Zbb extension.
+  if (!Subtarget.hasStdExtZbb())
+    return false;
+  // Bail out unless the first operand is a comparison; a leaf operand would
+  // have no operands to inspect below.
+  if (N->getOperand(0).getOpcode() != ISD::SETCC)
+    return false;
+  // Compute the compared operand type early to check the necessary
+  // extensions.
+  const auto OpEVT = N->getOperand(0).getOperand(0).getValueType();
+  // Do not perform the optimization with vectors or non-simple types.
+  if (OpEVT.isVector() || !OpEVT.isSimple())
+    return false;
+  const auto OpMVT = OpEVT.getSimpleVT();
+  if (OpMVT == MVT::f64 && !Subtarget.hasStdExtD())
+    return false;
+  if (OpMVT == MVT::f32 && !Subtarget.hasStdExtF())
+    return false;
+  return true;
+}
+
+// Combines two comparison operations and a logic operation into one
+// selection operation (min/max) and one comparison operation. Returns the
+// newly constructed node if the conditions for the optimization are
+// satisfied.
+static SDValue combineCmpOp(SDNode *N, SelectionDAG &DAG,
+                            const RISCVSubtarget &Subtarget) {
+  if (!checkNecessaryExtensions(N, Subtarget)) {
+    return SDValue();
+  }
+  const auto Props = verifyCompareConds(N, DAG);
+  // If the conditions are not satisfied, do not perform the optimization.
+  if (!Props) {
+    return SDValue();
+  }
+
+  const auto [RefOpcode, A, B, C] = Props.value();
+  const auto BitOpcode = N->getOpcode();
+  assert((BitOpcode == ISD::AND || BitOpcode == ISD::OR) &&
+         "This optimization can be used only with AND/OR operations");
+  const auto CmpOpVT = A.getValueType();
+
+  // Condition codes of the less-than family (including or-equal variants).
+  static const std::unordered_set<ISD::CondCode> SetLess = {
+      ISD::SETULT, ISD::SETULE, ISD::SETLT, ISD::SETLE};
+  // Condition codes of the greater-than family (including or-equal variants).
+  static const std::unordered_set<ISD::CondCode> SetGreater = {
+      ISD::SETUGT, ISD::SETUGE, ISD::SETGT, ISD::SETGE};
+
+  bool IsGreaterOp = SetGreater.count(RefOpcode);
+  if (!IsGreaterOp && !SetLess.count(RefOpcode)) {
+    return SDValue();
+  }
+
+  OperationType TypeOfCmp = OperationType::Signed;
+  if (CmpOpVT.isFloatingPoint())
+    TypeOfCmp = OperationType::Float;
+  else if (ISD::isUnsignedIntSetCC(RefOpcode))
+    TypeOfCmp = OperationType::Unsigned;
+
+  const bool IsAnd = BitOpcode == ISD::AND;
+  const ISD::NodeType PickCode =
+      getSelectionCode(TypeOfCmp, IsAnd, IsGreaterOp);
+
+  SDLoc DL(N);
+  SDValue Pick = DAG.getNode(PickCode, DL, CmpOpVT, A, B);
+  SDValue Cmp =
+      DAG.getSetCC(DL, N->getOperand(0).getValueType(), Pick, C, RefOpcode);
+
+  return Cmp;
+}
+
 static SDValue performANDCombine(SDNode *N, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
   SDValue N0 = N->getOperand(0);
@@ -8311,6 +8499,9 @@
     return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, And);
   }
 
+  if (SDValue V = combineCmpOp(N, DAG, Subtarget))
+    return V;
+
   if (SDValue V = combineBinOpToReduce(N, DAG))
     return V;
 
@@ -8330,6 +8521,9 @@
     return SHFL;
   }
 
+  if (SDValue V = combineCmpOp(N, DAG, Subtarget))
+    return V;
+
   if (SDValue V = combineBinOpToReduce(N, DAG))
     return V;
   // fold (or (select cond, 0, y), x) ->
diff --git a/llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll b/llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll
@@ -0,0 +1,317 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+zbb -O3 < %s \
+; RUN:   | FileCheck %s --check-prefixes=CHECK,CHECK-RV64I
+; RUN: llc -mtriple=riscv64 -mattr=+zbb,+f -target-abi=lp64f -O3 < %s \
+; RUN:   | FileCheck %s --check-prefixes=CHECK,CHECK-RV64IF
+; These tests check the optimization that combines two comparison
+; operations and a logic operation into one select (min/max) operation
+; and one comparison operation.
+
+; The 4 patterns below are converted to umin+less.
+define i1 @ulo(i64 %c, i64 %a, i64 %b) {
+; CHECK-LABEL: ulo:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    minu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
+; CHECK-NEXT:    ret
+  %l0 = icmp ult i64 %a, %c
+  %l1 = icmp ult i64 %b, %c
+  %res = or i1 %l0, %l1
+  ret i1 %res
+}
+
+define i1 @ulo_swap1(i64 %c, i64 %a, i64 %b) {
+; CHECK-LABEL: ulo_swap1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    minu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
+; CHECK-NEXT:    ret
+  %l0 = icmp ugt i64 %c, %a
+  %l1 = icmp ult i64 %b, %c
+  %res = or i1 %l0, %l1
+  ret i1 %res
+}
+
+define i1 @ulo_swap2(i64 %c, i64 %a, i64 %b) {
+; CHECK-LABEL: ulo_swap2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    minu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
+; CHECK-NEXT:    ret
+  %l0 = icmp ult i64 %a, %c
+  %l1 = icmp ugt i64 %c, %b
+  %res = or i1 %l0, %l1
+  ret i1 %res
+}
+
+define i1 @ulo_swap12(i64 %c, i64 %a, i64 %b) {
+; CHECK-LABEL: ulo_swap12:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    minu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
+; CHECK-NEXT:    ret
+  %l0 = icmp ugt i64 %c, %a
+  %l1 = icmp ugt i64 %c, %b
+  %res = or i1 %l0, %l1
+  ret i1 %res
+}
+
+; The 4 patterns below are converted to umax+less.
+define i1 @ula(i64 %c, i64 %a, i64 %b) {
+; CHECK-LABEL: ula:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
+; CHECK-NEXT:    ret
+  %l0 = icmp ult i64 %a, %c
+  %l1 = icmp ult i64 %b, %c
+  %res = and i1 %l0, %l1
+  ret i1 %res
+}
+
+define i1 @ula_swap1(i64 %c, i64 %a, i64 %b) {
+; CHECK-LABEL: ula_swap1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
+; CHECK-NEXT:    ret
+  %l0 = icmp ugt i64 %c, %a
+  %l1 = icmp ult i64 %b, %c
+  %res = and i1 %l0, %l1
+  ret i1 %res
+}
+
+define i1 @ula_swap2(i64 %c, i64 %a, i64 %b) {
+; CHECK-LABEL: ula_swap2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
+; CHECK-NEXT:    ret
+  %l0 = icmp ult i64 %a, %c
+  %l1 = icmp ugt i64 %c, %b
+  %res = and i1 %l0, %l1
+  ret i1 %res
+}
+
+define i1 @ula_swap12(i64 %c, i64 %a, i64 %b) {
+; CHECK-LABEL: ula_swap12:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
+; CHECK-NEXT:    ret
+  %l0 = icmp ugt i64 %c, %a
+  %l1 = icmp ugt i64 %c, %b
+  %res = and i1 %l0, %l1
+  ret i1 %res
+}
+
+; The 4 patterns below are converted to umax+greater
+; (the greater-than compare is lowered as sltu with swapped operands).
+define i1 @ugo(i64 %c, i64 %a, i64 %b) {
+; CHECK-LABEL: ugo:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a0, a1
+; CHECK-NEXT:    ret
+  %l0 = icmp ugt i64 %a, %c
+  %l1 = icmp ugt i64 %b, %c
+  %res = or i1 %l0, %l1
+  ret i1 %res
+}
+
+define i1 @ugo_swap1(i64 %c, i64 %a, i64 %b) {
+; CHECK-LABEL: ugo_swap1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a0, a1
+; CHECK-NEXT:    ret
+  %l0 = icmp ult i64 %c, %a
+  %l1 = icmp ugt i64 %b, %c
+  %res = or i1 %l0, %l1
+  ret i1 %res
+}
+
+define i1 @ugo_swap2(i64 %c, i64 %a, i64 %b) {
+; CHECK-LABEL: ugo_swap2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a0, a1
+; CHECK-NEXT:    ret
+  %l0 = icmp ugt i64 %a, %c
+  %l1 = icmp ult i64 %c, %b
+  %res = or i1 %l0, %l1
+  ret i1 %res
+}
+
+define i1 @ugo_swap12(i64 %c, i64 %a, i64 %b) {
+; CHECK-LABEL: ugo_swap12:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a0, a1
+; CHECK-NEXT:    ret
+  %l0 = icmp ult i64 %c, %a
+  %l1 = icmp ult i64 %c, %b
+  %res = or i1 %l0, %l1
+  ret i1 %res
+}
+
+; The pattern below is converted to umin+greater-or-equal
+; (uge is lowered as sltu plus xori).
+define i1 @ugea(i64 %c, i64 %a, i64 %b) {
+; CHECK-LABEL: ugea:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    minu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
+; CHECK-NEXT:    xori a0, a0, 1
+; CHECK-NEXT:    ret
+  %l0 = icmp uge i64 %a, %c
+  %l1 = icmp uge i64 %b, %c
+  %res = and i1 %l0, %l1
+  ret i1 %res
+}
+
+; The pattern below is converted to umin+greater
+; (the greater-than compare is lowered as sltu with swapped operands).
+define i1 @uga(i64 %c, i64 %a, i64 %b) {
+; CHECK-LABEL: uga:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    minu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a0, a1
+; CHECK-NEXT:    ret
+  %l0 = icmp ugt i64 %a, %c
+  %l1 = icmp ugt i64 %b, %c
+  %res = and i1 %l0, %l1
+  ret i1 %res
+}
+
+; The pattern below is converted to smax+less.
+; Sign check.
+define i1 @sla(i64 %c, i64 %a, i64 %b) {
+; CHECK-LABEL: sla:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    max a1, a1, a2
+; CHECK-NEXT:    slt a0, a1, a0
+; CHECK-NEXT:    ret
+  %l0 = icmp slt i64 %a, %c
+  %l1 = icmp slt i64 %b, %c
+  %res = and i1 %l0, %l1
+  ret i1 %res
+}
+
+; The pattern below is converted to fmin+less when the F extension is
+; available; the RV64I run falls back to libcalls.
+; Float check.
+define i1 @flo(float %c, float %a, float %b) {
+; CHECK-RV64I-LABEL: flo:
+; CHECK-RV64I:       # %bb.0:
+; CHECK-RV64I-NEXT:    addi sp, sp, -32
+; CHECK-RV64I-NEXT:    .cfi_def_cfa_offset 32
+; CHECK-RV64I-NEXT:    sd ra, 24(sp) # 8-byte Folded Spill
+; CHECK-RV64I-NEXT:    sd s0, 16(sp) # 8-byte Folded Spill
+; CHECK-RV64I-NEXT:    sd s1, 8(sp) # 8-byte Folded Spill
+; CHECK-RV64I-NEXT:    sd s2, 0(sp) # 8-byte Folded Spill
+; CHECK-RV64I-NEXT:    .cfi_offset ra, -8
+; CHECK-RV64I-NEXT:    .cfi_offset s0, -16
+; CHECK-RV64I-NEXT:    .cfi_offset s1, -24
+; CHECK-RV64I-NEXT:    .cfi_offset s2, -32
+; CHECK-RV64I-NEXT:    mv s0, a2
+; CHECK-RV64I-NEXT:    mv s1, a0
+; CHECK-RV64I-NEXT:    mv a0, a1
+; CHECK-RV64I-NEXT:    mv a1, s1
+; CHECK-RV64I-NEXT:    call __gesf2@plt
+; CHECK-RV64I-NEXT:    mv s2, a0
+; CHECK-RV64I-NEXT:    mv a0, s0
+; CHECK-RV64I-NEXT:    mv a1, s1
+; CHECK-RV64I-NEXT:    call __gesf2@plt
+; CHECK-RV64I-NEXT:    or a0, s2, a0
+; CHECK-RV64I-NEXT:    slti a0, a0, 0
+; CHECK-RV64I-NEXT:    ld ra, 24(sp) # 8-byte Folded Reload
+; CHECK-RV64I-NEXT:    ld s0, 16(sp) # 8-byte Folded Reload
+; CHECK-RV64I-NEXT:    ld s1, 8(sp) # 8-byte Folded Reload
+; CHECK-RV64I-NEXT:    ld s2, 0(sp) # 8-byte Folded Reload
+; CHECK-RV64I-NEXT:    addi sp, sp, 32
+; CHECK-RV64I-NEXT:    ret
+;
+; CHECK-RV64IF-LABEL: flo:
+; CHECK-RV64IF:       # %bb.0:
+; CHECK-RV64IF-NEXT:    fmin.s ft0, fa1, fa2
+; CHECK-RV64IF-NEXT:    fle.s a0, fa0, ft0
+; CHECK-RV64IF-NEXT:    xori a0, a0, 1
+; CHECK-RV64IF-NEXT:    ret
+  %l0 = fcmp ult float %a, %c
+  %l1 = fcmp ult float %b, %c
+  %res = or i1 %l0, %l1
+  ret i1 %res
+}
+
+; Negative test.
+; double requires the D extension, which neither RUN line enables (+f alone
+; is not enough), so the combine must not fire in either configuration.
+define i1 @dlo(double %c, double %a, double %b) {
+; CHECK-LABEL: dlo:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi sp, sp, -32
+; CHECK-NEXT:    .cfi_def_cfa_offset 32
+; CHECK-NEXT:    sd ra, 24(sp) # 8-byte Folded Spill
+; CHECK-NEXT:    sd s0, 16(sp) # 8-byte Folded Spill
+; CHECK-NEXT:    sd s1, 8(sp) # 8-byte Folded Spill
+; CHECK-NEXT:    sd s2, 0(sp) # 8-byte Folded Spill
+; CHECK-NEXT:    .cfi_offset ra, -8
+; CHECK-NEXT:    .cfi_offset s0, -16
+; CHECK-NEXT:    .cfi_offset s1, -24
+; CHECK-NEXT:    .cfi_offset s2, -32
+; CHECK-NEXT:    mv s0, a2
+; CHECK-NEXT:    mv s1, a0
+; CHECK-NEXT:    mv a0, a1
+; CHECK-NEXT:    mv a1, s1
+; CHECK-NEXT:    call __gedf2@plt
+; CHECK-NEXT:    mv s2, a0
+; CHECK-NEXT:    mv a0, s0
+; CHECK-NEXT:    mv a1, s1
+; CHECK-NEXT:    call __gedf2@plt
+; CHECK-NEXT:    or a0, s2, a0
+; CHECK-NEXT:    slti a0, a0, 0
+; CHECK-NEXT:    ld ra, 24(sp) # 8-byte Folded Reload
+; CHECK-NEXT:    ld s0, 16(sp) # 8-byte Folded Reload
+; CHECK-NEXT:    ld s1, 8(sp) # 8-byte Folded Reload
+; CHECK-NEXT:    ld s2, 0(sp) # 8-byte Folded Reload
+; CHECK-NEXT:    addi sp, sp, 32
+; CHECK-NEXT:    ret
+  %l0 = fcmp ult double %a, %c
+  %l1 = fcmp ult double %b, %c
+  %res = or i1 %l0, %l1
+  ret i1 %res
+}
+
+; Negative test.
+; One of the compares has more than one use.
+define i1 @multi_user(i64 %c, i64 %a, i64 %b) {
+; CHECK-LABEL: multi_user:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    sltu a1, a1, a0
+; CHECK-NEXT:    sltu a0, a2, a0
+; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    and a0, a1, a0
+; CHECK-NEXT:    ret
+  %l0 = icmp ugt i64 %c, %a
+  %l1 = icmp ult i64 %b, %c
+  %res = or i1 %l0, %l1
+
+  %out = and i1 %l0, %res
+  ret i1 %out
+}
+
+; Negative test.
+; The comparison conditions do not match after canonicalization.
+define i1 @no_same_ops(i64 %c, i64 %a, i64 %b) {
+; CHECK-LABEL: no_same_ops:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    sltu a1, a0, a1
+; CHECK-NEXT:    sltu a0, a2, a0
+; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    ret
+  %l0 = icmp ult i64 %c, %a
+  %l1 = icmp ugt i64 %c, %b
+  %res = or i1 %l0, %l1
+  ret i1 %res
+}
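
Reviewer note, not part of the patch: the getSelectionCode matrix is indexed
with IsAnd ^ IsGreaterOp, which encodes four Boolean identities. The
standalone C++ sketch below checks them exhaustively over small unsigned
values; the file name and variable names are made up for illustration, and
nothing here belongs to the LLVM tree. The signed variants are analogous.

// minmax_identities.cpp -- build with: c++ -std=c++17 minmax_identities.cpp
#include <algorithm>
#include <cassert>

int main() {
  for (unsigned A = 0; A < 32; ++A)
    for (unsigned B = 0; B < 32; ++B)
      for (unsigned C = 0; C < 32; ++C) {
        unsigned Mn = std::min(A, B);
        unsigned Mx = std::max(A, B);
        // IsAnd = 0, IsGreaterOp = 0 -> index 0 -> UMIN (tests ulo*).
        assert(((A < C) || (B < C)) == (Mn < C));
        // IsAnd = 1, IsGreaterOp = 0 -> index 1 -> UMAX (tests ula*).
        assert(((A < C) && (B < C)) == (Mx < C));
        // IsAnd = 0, IsGreaterOp = 1 -> index 1 -> UMAX (tests ugo*).
        assert(((A > C) || (B > C)) == (Mx > C));
        // IsAnd = 1, IsGreaterOp = 1 -> index 0 -> UMIN (test uga).
        assert(((A > C) && (B > C)) == (Mn > C));
      }
  return 0;
}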