diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3308,23 +3308,29 @@
                        Op.getOperand(2));
 }
 
-// Sets 'C' bit of NZCV to 0 if value is 0, else sets 'C' bit to 1
-static SDValue valueToCarryFlag(SDValue Value, SelectionDAG &DAG) {
+// If Invert is false, sets 'C' bit of NZCV to 0 if value is 0, else sets 'C'
+// bit to 1. If Invert is true, sets 'C' bit of NZCV to 1 if value is 0, else
+// sets 'C' bit to 0.
+static SDValue valueToCarryFlag(SDValue Value, SelectionDAG &DAG, bool Invert) {
   SDLoc DL(Value);
-  SDValue One = DAG.getConstant(1, DL, Value.getValueType());
+  EVT VT = Value.getValueType();
+  SDValue Op0 = Invert ? DAG.getConstant(0, DL, VT) : Value;
+  SDValue Op1 = Invert ? Value : DAG.getConstant(1, DL, VT);
   SDValue Cmp =
-      DAG.getNode(AArch64ISD::SUBS, DL,
-                  DAG.getVTList(Value.getValueType(), MVT::Glue), Value, One);
+      DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, MVT::Glue), Op0, Op1);
   return Cmp.getValue(1);
 }
 
-// Value is 1 if 'C' bit of NZCV is 1, else 0
-static SDValue carryFlagToValue(SDValue Flag, EVT VT, SelectionDAG &DAG) {
+// If Invert is false, value is 1 if 'C' bit of NZCV is 1, else 0.
+// If Invert is true, value is 0 if 'C' bit of NZCV is 1, else 1.
+static SDValue carryFlagToValue(SDValue Flag, EVT VT, SelectionDAG &DAG,
+                                bool Invert) {
   assert(Flag.getResNo() == 1);
   SDLoc DL(Flag);
   SDValue Zero = DAG.getConstant(0, DL, VT);
   SDValue One = DAG.getConstant(1, DL, VT);
-  SDValue CC = DAG.getConstant(AArch64CC::HS, DL, MVT::i32);
+  unsigned Cond = Invert ? AArch64CC::LO : AArch64CC::HS;
+  SDValue CC = DAG.getConstant(Cond, DL, MVT::i32);
   return DAG.getNode(AArch64ISD::CSEL, DL, VT, One, Zero, CC, Flag);
 }
 
@@ -3348,9 +3354,10 @@
   if (VT0 != MVT::i32 && VT0 != MVT::i64)
     return SDValue();
 
+  bool InvertCarry = Opcode == AArch64ISD::SBCS;
   SDValue OpLHS = Op.getOperand(0);
   SDValue OpRHS = Op.getOperand(1);
-  SDValue OpCarryIn = valueToCarryFlag(Op.getOperand(2), DAG);
+  SDValue OpCarryIn = valueToCarryFlag(Op.getOperand(2), DAG, InvertCarry);
 
   SDLoc DL(Op);
   SDVTList VTs = DAG.getVTList(VT0, VT1);
@@ -3358,8 +3365,9 @@
 
   SDValue Sum = DAG.getNode(Opcode, DL, DAG.getVTList(VT0, MVT::Glue), OpLHS,
                             OpRHS, OpCarryIn);
-  SDValue OutFlag = IsSigned ? overflowFlagToValue(Sum.getValue(1), VT1, DAG)
-                             : carryFlagToValue(Sum.getValue(1), VT1, DAG);
+  SDValue OutFlag =
+      IsSigned ? overflowFlagToValue(Sum.getValue(1), VT1, DAG)
+               : carryFlagToValue(Sum.getValue(1), VT1, DAG, InvertCarry);
 
   return DAG.getNode(ISD::MERGE_VALUES, DL, VTs, Sum, OutFlag);
 }
@@ -15517,13 +15525,21 @@
 }
 
 // (ADC{S} l r (CMP (CSET HS carry) 1)) => (ADC{S} l r carry)
-// (SBC{S} l r (CMP (CSET LO carry) 1)) => (SBC{S} l r carry)
+// (SBC{S} l r (CMP 0 (CSET LO carry))) => (SBC{S} l r carry)
 static SDValue foldOverflowCheck(SDNode *Op, SelectionDAG &DAG, bool IsAdd) {
   SDValue CmpOp = Op->getOperand(2);
-  if (!(isCMP(CmpOp) && isOneConstant(CmpOp.getOperand(1))))
+  if (!isCMP(CmpOp))
     return SDValue();
 
-  SDValue CsetOp = CmpOp->getOperand(0);
+  if (IsAdd) {
+    if (!isOneConstant(CmpOp.getOperand(1)))
+      return SDValue();
+  } else {
+    if (!isNullConstant(CmpOp.getOperand(0)))
+      return SDValue();
+  }
+
+  SDValue CsetOp = CmpOp->getOperand(IsAdd ? 0 : 1);
   auto CC = getCSETCondCode(CsetOp);
   if (CC != (IsAdd ? AArch64CC::HS : AArch64CC::LO))
     return SDValue();
diff --git a/llvm/test/CodeGen/AArch64/i128-math.ll b/llvm/test/CodeGen/AArch64/i128-math.ll
--- a/llvm/test/CodeGen/AArch64/i128-math.ll
+++ b/llvm/test/CodeGen/AArch64/i128-math.ll
@@ -92,7 +92,7 @@
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    subs x0, x0, x2
 ; CHECK-NEXT:    sbcs x1, x1, x3
-; CHECK-NEXT:    cset w8, hs
+; CHECK-NEXT:    cset w8, lo
 ; CHECK-NEXT:    eor w2, w8, #0x1
 ; CHECK-NEXT:    ret
   %1 = tail call { i128, i1 } @llvm.usub.with.overflow.i128(i128 %x, i128 %y)
@@ -110,7 +110,7 @@
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    subs x0, x0, x2
 ; CHECK-NEXT:    sbcs x1, x1, x3
-; CHECK-NEXT:    cset w2, hs
+; CHECK-NEXT:    cset w2, lo
 ; CHECK-NEXT:    ret
   %1 = tail call { i128, i1 } @llvm.usub.with.overflow.i128(i128 %x, i128 %y)
   %2 = extractvalue { i128, i1 } %1, 0
@@ -126,7 +126,7 @@
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    subs x8, x0, x2
 ; CHECK-NEXT:    sbcs x9, x1, x3
-; CHECK-NEXT:    cset w10, hs
+; CHECK-NEXT:    cset w10, lo
 ; CHECK-NEXT:    cmp w10, #0
 ; CHECK-NEXT:    csel x0, xzr, x8, ne
 ; CHECK-NEXT:    csel x1, xzr, x9, ne
diff --git a/llvm/test/CodeGen/AArch64/usub_sat_vec.ll b/llvm/test/CodeGen/AArch64/usub_sat_vec.ll
--- a/llvm/test/CodeGen/AArch64/usub_sat_vec.ll
+++ b/llvm/test/CodeGen/AArch64/usub_sat_vec.ll
@@ -346,13 +346,13 @@
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    subs x8, x2, x6
 ; CHECK-NEXT:    sbcs x9, x3, x7
-; CHECK-NEXT:    cset w10, hs
+; CHECK-NEXT:    cset w10, lo
 ; CHECK-NEXT:    cmp w10, #0
 ; CHECK-NEXT:    csel x2, xzr, x8, ne
 ; CHECK-NEXT:    csel x3, xzr, x9, ne
 ; CHECK-NEXT:    subs x8, x0, x4
 ; CHECK-NEXT:    sbcs x9, x1, x5
-; CHECK-NEXT:    cset w10, hs
+; CHECK-NEXT:    cset w10, lo
 ; CHECK-NEXT:    cmp w10, #0
 ; CHECK-NEXT:    csel x8, xzr, x8, ne
 ; CHECK-NEXT:    csel x1, xzr, x9, ne