Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14839,41 +14839,62 @@
 }
 
 /// Perform the scalar expression combine in the form of:
-///   CSEL (c, 1, cc) + b => CSINC(b+c, b, cc)
+///   CSEL(c, 1, cc) + b => CSINC(b+c, b, cc)
+///   CSNEG(c, -1, cc) + b => CSINC(b+c, b, cc)
 static SDValue performAddCSelIntoCSinc(SDNode *N, SelectionDAG &DAG) {
   EVT VT = N->getValueType(0);
   if (!VT.isScalarInteger() || N->getOpcode() != ISD::ADD)
     return SDValue();
 
-  SDValue CSel = N->getOperand(0);
+  SDValue LHS = N->getOperand(0);
   SDValue RHS = N->getOperand(1);
 
   // Handle commutivity.
-  if (CSel.getOpcode() != AArch64ISD::CSEL) {
-    std::swap(CSel, RHS);
-    if (CSel.getOpcode() != AArch64ISD::CSEL) {
+  if (LHS.getOpcode() != AArch64ISD::CSEL &&
+      LHS.getOpcode() != AArch64ISD::CSNEG) {
+    std::swap(LHS, RHS);
+    if (LHS.getOpcode() != AArch64ISD::CSEL &&
+        LHS.getOpcode() != AArch64ISD::CSNEG) {
       return SDValue();
     }
   }
 
-  if (!CSel.hasOneUse())
+  if (!LHS.hasOneUse())
     return SDValue();
 
   AArch64CC::CondCode AArch64CC =
-      static_cast<AArch64CC::CondCode>(CSel.getConstantOperandVal(2));
+      static_cast<AArch64CC::CondCode>(LHS.getConstantOperandVal(2));
 
-  // The CSEL should include a const one operand.
-  ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(CSel.getOperand(0));
-  ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(CSel.getOperand(1));
-  if (!CTVal || !CFVal || (!CTVal->isOne() && !CFVal->isOne()))
+  // The CSEL should include a const one operand, and the CSNEG should include
+  // a const one (true side) or negative-one (false side) operand.
+  ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(LHS.getOperand(0));
+  ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(LHS.getOperand(1));
+  if (!CTVal || !CFVal)
     return SDValue();
 
-  // switch CSEL (1, c, cc) to CSEL (c, 1, !cc)
-  if (CTVal->isOne() && !CFVal->isOne()) {
+  if (!(LHS.getOpcode() == AArch64ISD::CSEL &&
+        (CTVal->isOne() || CFVal->isOne())) &&
+      !(LHS.getOpcode() == AArch64ISD::CSNEG &&
+        (CTVal->isOne() || CFVal->isAllOnes())))
+    return SDValue();
+
+  // Switch CSEL(1, c, cc) to CSEL(c, 1, !cc)
+  if (LHS.getOpcode() == AArch64ISD::CSEL &&
+      CTVal->isOne() && !CFVal->isOne()) {
     std::swap(CTVal, CFVal);
     AArch64CC = AArch64CC::getInvertedCondCode(AArch64CC);
   }
 
+  SDLoc DL(N);
+  // Switch CSNEG(1, c, cc) to CSNEG(-c, -1, !cc)
+  if (LHS.getOpcode() == AArch64ISD::CSNEG &&
+      CTVal->isOne() && !CFVal->isAllOnes()) {
+    APInt C = -CFVal->getAPIntValue();
+    CTVal = cast<ConstantSDNode>(DAG.getConstant(C, DL, VT));
+    CFVal = cast<ConstantSDNode>(DAG.getAllOnesConstant(DL, VT));
+    AArch64CC = AArch64CC::getInvertedCondCode(AArch64CC);
+  }
+
   // It might be neutral for larger constants, as the immediate need to be
   // materialized in a register.
   APInt ADDC = CTVal->getAPIntValue();
@@ -14881,12 +14902,15 @@
   if (!TLI.isLegalAddImmediate(ADDC.getSExtValue()))
     return SDValue();
 
-  assert(CFVal->isOne() && "Unexpected constant value");
+  assert(((LHS.getOpcode() == AArch64ISD::CSEL && CFVal->isOne()) ||
+          (LHS.getOpcode() == AArch64ISD::CSNEG && CFVal->isAllOnes())) &&
+         "Unexpected constant value");
 
-  SDLoc DL(N);
+  // Both canonical forms select between c and 1 (CSNEG negates its false
+  // operand, and -(-1) == 1), so the true arm is always RHS + c.
   SDValue NewNode = DAG.getNode(ISD::ADD, DL, VT, RHS, SDValue(CTVal, 0));
   SDValue CCVal = DAG.getConstant(AArch64CC, DL, MVT::i32);
-  SDValue Cmp = CSel.getOperand(3);
+  SDValue Cmp = LHS.getOperand(3);
 
   return DAG.getNode(AArch64ISD::CSINC, DL, VT, NewNode, RHS, CCVal, Cmp);
 }
Index: llvm/test/CodeGen/AArch64/aarch64-isel-csinc.ll
===================================================================
--- llvm/test/CodeGen/AArch64/aarch64-isel-csinc.ll
+++ llvm/test/CodeGen/AArch64/aarch64-isel-csinc.ll
@@ -112,3 +112,33 @@
   %cond = add nsw i32 %cond.v, %b
   ret i32 %cond
 }
+
+; int csinc8 (int a, int b) { return a ? b-1 : b+1; }
+define dso_local i32 @csinc8(i32 %a, i32 %b) {
+; CHECK-LABEL: csinc8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    cmp w0, #0
+; CHECK-NEXT:    sub w8, w1, #1
+; CHECK-NEXT:    csinc w0, w8, w1, ne
+; CHECK-NEXT:    ret
+entry:
+  %tobool.not = icmp eq i32 %a, 0
+  %cond.v = select i1 %tobool.not, i32 1, i32 -1
+  %cond = add nsw i32 %cond.v, %b
+  ret i32 %cond
+}
+
+; int csinc9 (int a, int b) { return a ? b+1 : b-1; }
+define dso_local i32 @csinc9(i32 %a, i32 %b) {
+; CHECK-LABEL: csinc9:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    cmp w0, #0
+; CHECK-NEXT:    sub w8, w1, #1
+; CHECK-NEXT:    csinc w0, w8, w1, eq
+; CHECK-NEXT:    ret
+entry:
+  %tobool.not = icmp eq i32 %a, 0
+  %cond.v = select i1 %tobool.not, i32 -1, i32 1
+  %cond = add nsw i32 %cond.v, %b
+  ret i32 %cond
+}