diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -886,6 +886,9 @@
   setTargetDAGCombine(ISD::FP_TO_UINT);
   setTargetDAGCombine(ISD::FDIV);
 
+  // Try and combine setcc with csel
+  setTargetDAGCombine(ISD::SETCC);
+
   setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
 
   setTargetDAGCombine(ISD::ANY_EXTEND);
@@ -15316,6 +15319,35 @@
   return performCONDCombine(N, DCI, DAG, 2, 3);
 }
 
+static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG) {
+  assert(N->getOpcode() == ISD::SETCC && "Unexpected opcode!");
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
+
+  // setcc (csel 0, 1, cond, X), 1, ne ==> csel 0, 1, !cond, X
+  if (Cond == ISD::SETNE && isOneConstant(RHS) &&
+      LHS->getOpcode() == AArch64ISD::CSEL &&
+      isNullConstant(LHS->getOperand(0)) && isOneConstant(LHS->getOperand(1)) &&
+      LHS->hasOneUse()) {
+    SDLoc DL(N);
+
+    // Invert CSEL's condition.
+    auto *OpCC = cast<ConstantSDNode>(LHS.getOperand(2));
+    auto OldCond = static_cast<AArch64CC::CondCode>(OpCC->getZExtValue());
+    auto NewCond = getInvertedCondCode(OldCond);
+
+    // csel 0, 1, !cond, X
+    SDValue CSEL =
+        DAG.getNode(AArch64ISD::CSEL, DL, LHS.getValueType(), LHS.getOperand(0),
+                    LHS.getOperand(1), DAG.getConstant(NewCond, DL, MVT::i32),
+                    LHS.getOperand(3));
+    return DAG.getZExtOrTrunc(CSEL, DL, N->getValueType(0));
+  }
+
+  return SDValue();
+}
+
 // Optimize some simple tbz/tbnz cases. Returns the new operand and bit to test
 // as well as whether the test should be inverted. This code is required to
 // catch these cases (as opposed to standard dag combines) because
@@ -16153,6 +16185,8 @@
     return performSelectCombine(N, DCI);
   case ISD::VSELECT:
     return performVSelectCombine(N, DCI.DAG);
+  case ISD::SETCC:
+    return performSETCCCombine(N, DAG);
   case ISD::LOAD:
     if (performTBISimplification(N->getOperand(1), DCI, DAG))
       return SDValue(N, 0);
diff --git a/llvm/test/CodeGen/AArch64/sve-setcc.ll b/llvm/test/CodeGen/AArch64/sve-setcc.ll
--- a/llvm/test/CodeGen/AArch64/sve-setcc.ll
+++ b/llvm/test/CodeGen/AArch64/sve-setcc.ll
@@ -1,5 +1,23 @@
 ; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
 
+; Ensure we use the CC result of SVE compare instructions when branching.
+define void @sve_cmplt_setcc(<vscale x 8 x i16>* %out, <vscale x 8 x i16> %in, <vscale x 8 x i1> %pg) {
+; CHECK-LABEL: @sve_cmplt_setcc
+; CHECK: cmplt p1.h, p0/z, z0.h, #0
+; CHECK-NEXT: b.eq
+entry:
+  %0 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.cmplt.wide.nxv8i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %in, <vscale x 2 x i64> zeroinitializer)
+  %1 = tail call i1 @llvm.aarch64.sve.ptest.any.nxv8i1(<vscale x 8 x i1> %pg, <vscale x 8 x i1> %0)
+  br i1 %1, label %if.then, label %if.end
+
+if.then:
+  tail call void @llvm.masked.store.nxv8i16.p0nxv8i16(<vscale x 8 x i16> %in, <vscale x 8 x i16>* %out, i32 2, <vscale x 8 x i1> %pg)
+  br label %if.end
+
+if.end:
+  ret void
+}
+
 ; Ensure we use the inverted CC result of SVE compare instructions when branching.
 define void @sve_cmplt_setcc_inverted(<vscale x 8 x i16>* %out, <vscale x 8 x i16> %in, <vscale x 8 x i1> %pg) {
 ; CHECK-LABEL: @sve_cmplt_setcc_inverted
@@ -18,7 +36,26 @@
   ret void
 }
 
+; Ensure we combine setcc and csel so as to not end up with an extra compare
+define void @sve_cmplt_setcc_hslo(<vscale x 8 x i16>* %out, <vscale x 8 x i16> %in, <vscale x 8 x i1> %pg) {
+; CHECK-LABEL: @sve_cmplt_setcc_hslo
+; CHECK: cmplt p1.h, p0/z, z0.h, #0
+; CHECK-NEXT: b.hs
+entry:
+  %0 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.cmplt.wide.nxv8i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %in, <vscale x 2 x i64> zeroinitializer)
+  %1 = tail call i1 @llvm.aarch64.sve.ptest.last.nxv8i1(<vscale x 8 x i1> %pg, <vscale x 8 x i1> %0)
+  br i1 %1, label %if.then, label %if.end
+
+if.then:
+  tail call void @llvm.masked.store.nxv8i16.p0nxv8i16(<vscale x 8 x i16> %in, <vscale x 8 x i16>* %out, i32 2, <vscale x 8 x i1> %pg)
+  br label %if.end
+
+if.end:
+  ret void
+}
+
 declare i1 @llvm.aarch64.sve.ptest.any.nxv8i1(<vscale x 8 x i1>, <vscale x 8 x i1>)
+declare i1 @llvm.aarch64.sve.ptest.last.nxv8i1(<vscale x 8 x i1>, <vscale x 8 x i1>)
 declare <vscale x 8 x i1> @llvm.aarch64.sve.cmplt.wide.nxv8i16(<vscale x 8 x i1>, <vscale x 8 x i16>, <vscale x 2 x i64>)
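For context, and not part of the patch: the sketch below is a minimal standalone C++ illustration of why the fold "setcc (csel 0, 1, cond, X), 1, ne ==> csel 0, 1, !cond, X" is sound, and of the way AArch64 condition codes invert by flipping the low bit of their encoding, which is what getInvertedCondCode does conceptually. The enum, invert and csel names here are illustrative only and are not LLVM APIs; the encoding values follow the architectural condition-code encoding.

#include <cassert>

// Architectural AArch64 condition-code encodings; inverted pairs (EQ/NE,
// HS/LO, GE/LT, ...) differ only in the low bit.
enum CondCode : unsigned {
  EQ = 0x0, NE = 0x1, HS = 0x2, LO = 0x3, MI = 0x4, PL = 0x5,
  VS = 0x6, VC = 0x7, HI = 0x8, LS = 0x9, GE = 0xa, LT = 0xb,
  GT = 0xc, LE = 0xd
};

// Invert a condition code by flipping the low bit of its encoding.
static CondCode invert(CondCode CC) { return static_cast<CondCode>(CC ^ 0x1); }

// Model of "csel TVal, FVal, cond": pick TVal when the condition holds.
static unsigned csel(unsigned TVal, unsigned FVal, bool CondHolds) {
  return CondHolds ? TVal : FVal;
}

int main() {
  assert(invert(NE) == EQ && invert(HS) == LO && invert(LT) == GE);
  for (bool CondHolds : {false, true}) {
    // Original pattern: setcc (csel 0, 1, cond, X), 1, ne.
    bool Before = csel(0, 1, CondHolds) != 1u;
    // Folded pattern: csel 0, 1, !cond, X, read back as a boolean.
    bool After = csel(0, 1, !CondHolds) != 0u;
    // Both reproduce the original condition, so the setcc can be folded away.
    assert(Before == After && Before == CondHolds);
  }
  return 0;
}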