diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14696,6 +14696,56 @@
                      Dot.getOperand(2));
 }
 
+// Try to fold (sub Y, (csel X, -X, pl)) -> (add Y, (csel -X, X, pl)) when
+// the condition came from (subs X, 0). This matches the CSEL expansion of
+// the abs node lowered by LowerABS. By swapping the operands, we convert
+// abs to nabs. Note that (csel X, -X, pl) will be matched to csneg by the
+// CondSelectOp pattern.
+static SDValue performCombineSubABS(SDNode *N, SelectionDAG &DAG) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+
+  // If the abs node has other uses, don't do the combine.
+  if (N1.getOpcode() != AArch64ISD::CSEL || !N1.hasOneUse())
+    return SDValue();
+
+  ConstantSDNode *CCNode = cast<ConstantSDNode>(N1->getOperand(2));
+  AArch64CC::CondCode CC =
+      static_cast<AArch64CC::CondCode>(CCNode->getSExtValue());
+
+  if (CC != AArch64CC::PL && CC != AArch64CC::MI)
+    return SDValue();
+
+  // The condition should come from SUBS.
+  SDValue Cmp = N1.getOperand(3);
+  if (Cmp.getOpcode() != AArch64ISD::SUBS || !isNullConstant(Cmp.getOperand(1)))
+    return SDValue();
+  assert(Cmp.getResNo() == 1 && "Unexpected result number");
+
+  // Get the X.
+  SDValue X = Cmp.getOperand(0);
+
+  SDValue FalseOp = N1.getOperand(0);
+  SDValue TrueOp = N1.getOperand(1);
+
+  // CSEL operands should be X and NegX. Order doesn't matter.
+  auto IsNeg = [](SDValue Value, SDValue X) {
+    return Value.getOpcode() == ISD::SUB &&
+           isNullConstant(Value.getOperand(0)) && X == Value.getOperand(1);
+  };
+  if (!(IsNeg(FalseOp, X) && TrueOp == X) &&
+      !(IsNeg(TrueOp, X) && FalseOp == X))
+    return SDValue();
+
+  // Build a new CSEL with the operands swapped.
+  SDLoc DL(N);
+  MVT VT = N->getSimpleValueType(0);
+  SDValue Csel = DAG.getNode(AArch64ISD::CSEL, DL, VT, TrueOp, FalseOp,
+                             N1.getOperand(2), Cmp);
+  // Convert sub to add.
+  return DAG.getNode(ISD::ADD, DL, VT, N0, Csel);
+}
+
 // The basic add/sub long vector instructions have variants with "2" on the end
 // which act on the high-half of their inputs. They are normally matched by
 // patterns like:
@@ -17617,8 +17667,12 @@
   default:
     LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
     break;
-  case ISD::ADD:
   case ISD::SUB:
+    if (SDValue Val = performCombineSubABS(N, DAG)) {
+      return Val;
+    }
+    LLVM_FALLTHROUGH;
+  case ISD::ADD:
     return performAddSubCombine(N, DCI, DAG);
   case ISD::XOR:
     return performXorCombine(N, DAG, DCI, Subtarget);
diff --git a/llvm/test/CodeGen/AArch64/neg-abs.ll b/llvm/test/CodeGen/AArch64/neg-abs.ll
--- a/llvm/test/CodeGen/AArch64/neg-abs.ll
+++ b/llvm/test/CodeGen/AArch64/neg-abs.ll
@@ -8,8 +8,7 @@
 ; CHECK-LABEL: neg_abs64:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: cmp x0, #0
-; CHECK-NEXT: cneg x8, x0, mi
-; CHECK-NEXT: neg x0, x8
+; CHECK-NEXT: cneg x0, x0, pl
 ; CHECK-NEXT: ret
   %abs = tail call i64 @llvm.abs.i64(i64 %x, i1 true)
   %neg = sub nsw i64 0, %abs
@@ -22,8 +21,7 @@
 ; CHECK-LABEL: neg_abs32:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: cmp w0, #0
-; CHECK-NEXT: cneg w8, w0, mi
-; CHECK-NEXT: neg w0, w8
+; CHECK-NEXT: cneg w0, w0, pl
 ; CHECK-NEXT: ret
   %abs = tail call i32 @llvm.abs.i32(i32 %x, i1 true)
   %neg = sub nsw i32 0, %abs