diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -3682,11 +3682,11 @@ /// Return if the N is a constant or constant vector equal to the true value /// from getBooleanContents(). - bool isConstTrueVal(const SDNode *N) const; + bool isConstTrueVal(SDValue N) const; /// Return if the N is a constant or constant vector equal to the false value /// from getBooleanContents(). - bool isConstFalseVal(const SDNode *N) const; + bool isConstFalseVal(SDValue N) const; /// Return if \p N is a True value when extended to \p VT. bool isExtendedTrueVal(const ConstantSDNode *N, EVT VT, bool SExt) const; diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -906,9 +906,8 @@ return true; } - if (N.getOpcode() != ISD::SELECT_CC || - !TLI.isConstTrueVal(N.getOperand(2).getNode()) || - !TLI.isConstFalseVal(N.getOperand(3).getNode()) + if (N.getOpcode() != ISD::SELECT_CC || !TLI.isConstTrueVal(N.getOperand(2)) || + !TLI.isConstFalseVal(N.getOperand(3))) return false; if (TLI.getBooleanContents(N.getValueType()) == @@ -8035,8 +8034,8 @@ // fold !(x cc y) -> (x !cc y) unsigned N0Opcode = N0.getOpcode(); SDValue LHS, RHS, CC; - if (TLI.isConstTrueVal(N1.getNode()) && - isSetCCEquivalent(N0, LHS, RHS, CC, /*MatchStrict*/true)) { + if (TLI.isConstTrueVal(N1) && + isSetCCEquivalent(N0, LHS, RHS, CC, /*MatchStrict*/ true)) { ISD::CondCode NotCC = ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(), LHS.getValueType()); if (!LegalOperations || diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -3194,29 +3194,25 @@ // FIXME:
Ideally, this would use ISD::isConstantSplatVector(), but that must // work with truncating build vectors and vectors with elements of less than // 8 bits. -bool TargetLowering::isConstTrueVal(const SDNode *N) const { +bool TargetLowering::isConstTrueVal(SDValue N) const { if (!N) return false; + unsigned EltWidth; APInt CVal; - if (auto *CN = dyn_cast<ConstantSDNode>(N)) { + if (ConstantSDNode *CN = isConstOrConstSplat(N, /*AllowUndefs=*/false, /*AllowTruncation=*/true)) { CVal = CN->getAPIntValue(); - } else if (auto *BV = dyn_cast<BuildVectorSDNode>(N)) { auto *CN = BV->getConstantSplatNode(); - if (!CN) - return false; - - // If this is a truncating build vector, truncate the splat value. - // Otherwise, we may fail to match the expected values below. - unsigned BVEltWidth = BV->getValueType(0).getScalarSizeInBits(); - CVal = CN->getAPIntValue(); - if (BVEltWidth < CVal.getBitWidth()) - CVal = CVal.trunc(BVEltWidth); - } else { + EltWidth = N.getValueType().getScalarSizeInBits(); + } else return false; - } - switch (getBooleanContents(N->getValueType(0))) { + // If this is a truncating splat, truncate the splat value. + // Otherwise, we may fail to match the expected values below.
+ if (EltWidth < CVal.getBitWidth()) + CVal = CVal.trunc(EltWidth); + + switch (getBooleanContents(N.getValueType())) { case UndefinedBooleanContent: return CVal[0]; case ZeroOrOneBooleanContent: @@ -3228,7 +3224,7 @@ llvm_unreachable("Invalid boolean contents"); } -bool TargetLowering::isConstFalseVal(const SDNode *N) const { +bool TargetLowering::isConstFalseVal(SDValue N) const { if (!N) return false; @@ -3763,7 +3759,7 @@ if (TopSetCC.getValueType() == MVT::i1 && VT == MVT::i1 && TopSetCC.getOpcode() == ISD::SETCC && (N0Opc == ISD::ZERO_EXTEND || N0Opc == ISD::SIGN_EXTEND) && - (isConstFalseVal(N1C) || + (isConstFalseVal(N1) || isExtendedTrueVal(N1C, N0->getValueType(0), SExt))) { bool Inverse = (N1C->isZero() && Cond == ISD::SETEQ) || diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp --- a/llvm/lib/Target/ARM/ARMISelLowering.cpp +++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp @@ -14527,7 +14527,7 @@ SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); const TargetLowering *TLI = Subtarget->getTargetLowering(); - if (TLI->isConstTrueVal(N1.getNode()) && + if (TLI->isConstTrueVal(N1) && (N0->getOpcode() == ARMISD::VCMP || N0->getOpcode() == ARMISD::VCMPZ)) { if (CanInvertMVEVCMP(N0)) { SDLoc DL(N0); diff --git a/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll b/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll @@ -0,0 +1,54 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc -mtriple=aarch64-linux-unknown -mattr=+sve -o - < %s | FileCheck %s + +define <vscale x 8 x i1> @not_icmp_sle_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) { +; CHECK-LABEL: not_icmp_sle_nxv8i16: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.h +; CHECK-NEXT: cmpgt p0.h, p0/z, z0.h, z1.h +; CHECK-NEXT: ret + %icmp = icmp sle <vscale x 8 x i16> %a, %b + %tmp = insertelement <vscale x 8 x i1> undef, i1 true, i32 0 + %ones = shufflevector <vscale x 8 x i1> %tmp, <vscale x 8 x i1> undef, <vscale x 8 x i32> zeroinitializer + %not = xor <vscale x 8 x i1> %ones, %icmp + ret <vscale x 8 x i1> %not +} + +define <vscale x 4 x i1>
@not_icmp_sgt_nxv4i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b) { +; CHECK-LABEL: not_icmp_sgt_nxv4i32: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.s +; CHECK-NEXT: cmpge p0.s, p0/z, z1.s, z0.s +; CHECK-NEXT: ret + %icmp = icmp sgt <vscale x 4 x i32> %a, %b + %tmp = insertelement <vscale x 4 x i1> undef, i1 true, i32 0 + %ones = shufflevector <vscale x 4 x i1> %tmp, <vscale x 4 x i1> undef, <vscale x 4 x i32> zeroinitializer + %not = xor <vscale x 4 x i1> %icmp, %ones + ret <vscale x 4 x i1> %not +} + +define <vscale x 2 x i1> @not_fcmp_une_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b) { +; CHECK-LABEL: not_fcmp_une_nxv2f64: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.d +; CHECK-NEXT: fcmeq p0.d, p0/z, z0.d, z1.d +; CHECK-NEXT: ret + %icmp = fcmp une <vscale x 2 x double> %a, %b + %tmp = insertelement <vscale x 2 x i1> undef, i1 true, i32 0 + %ones = shufflevector <vscale x 2 x i1> %tmp, <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer + %not = xor <vscale x 2 x i1> %icmp, %ones + ret <vscale x 2 x i1> %not +} + +define <vscale x 4 x i1> @not_fcmp_uge_nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b) { +; CHECK-LABEL: not_fcmp_uge_nxv4f32: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.s +; CHECK-NEXT: fcmgt p0.s, p0/z, z1.s, z0.s +; CHECK-NEXT: ret + %icmp = fcmp uge <vscale x 4 x float> %a, %b + %tmp = insertelement <vscale x 4 x i1> undef, i1 true, i32 0 + %ones = shufflevector <vscale x 4 x i1> %tmp, <vscale x 4 x i1> undef, <vscale x 4 x i32> zeroinitializer + %not = xor <vscale x 4 x i1> %icmp, %ones + ret <vscale x 4 x i1> %not +} diff --git a/llvm/test/CodeGen/RISCV/rvv/cmp-folds.ll b/llvm/test/CodeGen/RISCV/rvv/cmp-folds.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/RISCV/rvv/cmp-folds.ll @@ -0,0 +1,55 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc -mtriple=riscv32 -mattr=+m,+d,+zfh,+v -verify-machineinstrs < %s | FileCheck %s +; RUN: llc -mtriple=riscv64 -mattr=+m,+d,+zfh,+v -verify-machineinstrs < %s | FileCheck %s + +define <vscale x 8 x i1> @not_icmp_sle_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) { +; CHECK-LABEL: not_icmp_sle_nxv8i16: +; CHECK: # %bb.0: +; CHECK-NEXT: vsetvli a0, zero, e16, m2, ta, mu +; CHECK-NEXT: vmslt.vv v0, v10, v8 +; CHECK-NEXT: ret + %icmp = icmp sle <vscale x 8 x i16> %a, %b + %tmp = insertelement <vscale x 8 x i1> undef, i1 true, i32 0 + %ones = shufflevector <vscale x 8 x i1> %tmp, <vscale x 8 x i1> undef, <vscale x 8 x i32> zeroinitializer + %not = xor <vscale x 8 x i1> %ones, %icmp + ret <vscale x 8 x i1> %not +} + +define <vscale x 4 x i1> @not_icmp_sgt_nxv4i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b) { +; CHECK-LABEL: not_icmp_sgt_nxv4i32: +; CHECK: # %bb.0: +; CHECK-NEXT:
vsetvli a0, zero, e32, m2, ta, mu +; CHECK-NEXT: vmsle.vv v0, v8, v10 +; CHECK-NEXT: ret + %icmp = icmp sgt <vscale x 4 x i32> %a, %b + %tmp = insertelement <vscale x 4 x i1> undef, i1 true, i32 0 + %ones = shufflevector <vscale x 4 x i1> %tmp, <vscale x 4 x i1> undef, <vscale x 4 x i32> zeroinitializer + %not = xor <vscale x 4 x i1> %icmp, %ones + ret <vscale x 4 x i1> %not +} + +define <vscale x 2 x i1> @not_fcmp_une_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b) { +; CHECK-LABEL: not_fcmp_une_nxv2f64: +; CHECK: # %bb.0: +; CHECK-NEXT: vsetvli a0, zero, e64, m2, ta, mu +; CHECK-NEXT: vmfeq.vv v0, v8, v10 +; CHECK-NEXT: ret + %icmp = fcmp une <vscale x 2 x double> %a, %b + %tmp = insertelement <vscale x 2 x i1> undef, i1 true, i32 0 + %ones = shufflevector <vscale x 2 x i1> %tmp, <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer + %not = xor <vscale x 2 x i1> %icmp, %ones + ret <vscale x 2 x i1> %not +} + +define <vscale x 4 x i1> @not_fcmp_uge_nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b) { +; CHECK-LABEL: not_fcmp_uge_nxv4f32: +; CHECK: # %bb.0: +; CHECK-NEXT: vsetvli a0, zero, e32, m2, ta, mu +; CHECK-NEXT: vmflt.vv v0, v8, v10 +; CHECK-NEXT: ret + %icmp = fcmp uge <vscale x 4 x float> %a, %b + %tmp = insertelement <vscale x 4 x i1> undef, i1 true, i32 0 + %ones = shufflevector <vscale x 4 x i1> %tmp, <vscale x 4 x i1> undef, <vscale x 4 x i32> zeroinitializer + %not = xor <vscale x 4 x i1> %icmp, %ones + ret <vscale x 4 x i1> %not +}