Index: llvm/include/llvm/IR/Constants.h =================================================================== --- llvm/include/llvm/IR/Constants.h +++ llvm/include/llvm/IR/Constants.h @@ -289,7 +289,8 @@ APInt *Payload = nullptr); static Constant *getSNaN(Type *Ty, bool Negative = false, APInt *Payload = nullptr); - static Constant *getNegativeZero(Type *Ty); + static Constant *getZero(Type *Ty, bool Negative = false); + static Constant *getNegativeZero(Type *Ty) { return getZero(Ty, true); } static Constant *getInfinity(Type *Ty, bool Negative = false); /// Return true if Ty is big enough to represent V. @@ -1120,9 +1121,12 @@ /// commutative, callers can acquire the operand 1 identity constant by /// setting AllowRHSConstant to true. For example, any shift has a zero /// identity constant for operand 1: X shift 0 = X. + /// If this is a fadd/fsub operation and we don't care about signed zeros, + /// then setting NSZ to true returns the identity +0.0 instead of -0.0. /// Return nullptr if the operator does not have an identity constant. static Constant *getBinOpIdentity(unsigned Opcode, Type *Ty, - bool AllowRHSConstant = false); + bool AllowRHSConstant = false, + bool NSZ = false); /// Return the absorbing element for the given binary /// operation, i.e. a constant C such that X op C = C and C op X = C for Index: llvm/lib/IR/Constants.cpp =================================================================== --- llvm/lib/IR/Constants.cpp +++ llvm/lib/IR/Constants.cpp @@ -1037,9 +1037,9 @@ return C; } -Constant *ConstantFP::getNegativeZero(Type *Ty) { +Constant *ConstantFP::getZero(Type *Ty, bool Negative) { const fltSemantics &Semantics = Ty->getScalarType()->getFltSemantics(); - APFloat NegZero = APFloat::getZero(Semantics, /*Negative=*/true); + APFloat NegZero = APFloat::getZero(Semantics, Negative); Constant *C = get(Ty->getContext(), NegZero); if (VectorType *VTy = dyn_cast<VectorType>(Ty)) @@ -1048,7 +1048,6 @@ return C; } - Constant *ConstantFP::getZeroValueForNegation(Type *Ty) { if (Ty->isFPOrFPVectorTy()) return getNegativeZero(Ty); @@ -2835,7 +2834,7 @@ } Constant *ConstantExpr::getBinOpIdentity(unsigned Opcode, Type *Ty, - bool AllowRHSConstant) { + bool AllowRHSConstant, bool NSZ) { assert(Instruction::isBinaryOp(Opcode) && "Only binops allowed"); // Commutative opcodes: it does not matter if AllowRHSConstant is set. @@ -2850,8 +2849,7 @@ case Instruction::And: // X & -1 = X return Constant::getAllOnesValue(Ty); case Instruction::FAdd: // X + -0.0 = X - // TODO: If the fadd has 'nsz', should we return +0.0? - return ConstantFP::getNegativeZero(Ty); + return ConstantFP::getZero(Ty, !NSZ); case Instruction::FMul: // X * 1.0 = X return ConstantFP::get(Ty, 1.0); default: Index: llvm/lib/Target/ARM/ARMISelLowering.cpp =================================================================== --- llvm/lib/Target/ARM/ARMISelLowering.cpp +++ llvm/lib/Target/ARM/ARMISelLowering.cpp @@ -16711,9 +16711,10 @@ if (Op.getOpcode() != ISD::BITCAST || Op.getOperand(0).getOpcode() != ARMISD::VMOVIMM) return false; - if (VT == MVT::v4f32 && Op.getOperand(0).getConstantOperandVal(0) == 1664) + uint64_t ImmVal = Op.getOperand(0).getConstantOperandVal(0); + if (VT == MVT::v4f32 && (ImmVal == 1664 || !ImmVal)) return true; - if (VT == MVT::v8f16 && Op.getOperand(0).getConstantOperandVal(0) == 2688) + if (VT == MVT::v8f16 && (ImmVal == 2688 || !ImmVal)) return true; return false; }; Index: llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp =================================================================== --- llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -450,8 +450,11 @@ } if (OpToFold) { - Constant *C = ConstantExpr::getBinOpIdentity(TVI->getOpcode(), - TVI->getType(), true); + FastMathFlags FMF; + if (isa<FPMathOperator>(&SI)) + FMF = SI.getFastMathFlags(); + Constant *C = ConstantExpr::getBinOpIdentity( + TVI->getOpcode(), TVI->getType(), true, FMF.noSignedZeros()); Value *OOp = TVI->getOperand(2-OpToFold); // Avoid creating select between 2 constants unless it's selecting // between 0, 1 and -1. @@ -460,6 +463,8 @@ if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { Value *NewSel = Builder.CreateSelect(SI.getCondition(), OOp, C); + if (isa<FPMathOperator>(&SI)) + cast<Instruction>(NewSel)->setFastMathFlags(FMF); NewSel->takeName(TVI); BinaryOperator *BO = BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel); @@ -482,8 +487,11 @@ } if (OpToFold) { - Constant *C = ConstantExpr::getBinOpIdentity(FVI->getOpcode(), - FVI->getType(), true); + FastMathFlags FMF; + if (isa<FPMathOperator>(&SI)) + FMF = SI.getFastMathFlags(); + Constant *C = ConstantExpr::getBinOpIdentity( + FVI->getOpcode(), FVI->getType(), true, FMF.noSignedZeros()); Value *OOp = FVI->getOperand(2-OpToFold); // Avoid creating select between 2 constants unless it's selecting // between 0, 1 and -1. @@ -492,6 +500,8 @@ if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { Value *NewSel = Builder.CreateSelect(SI.getCondition(), C, OOp); + if (isa<FPMathOperator>(&SI)) + cast<Instruction>(NewSel)->setFastMathFlags(FMF); NewSel->takeName(FVI); BinaryOperator *BO = BinaryOperator::Create(FVI->getOpcode(), TrueVal, NewSel); Index: llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll =================================================================== --- llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll +++ llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll @@ -363,6 +363,20 @@ ret <4 x float> %b } +define arm_aapcs_vfpcc <4 x float> @fadd_v4f32_x2(<4 x float> %x, <4 x float> %y, i32 %n) { +; CHECK-LABEL: fadd_v4f32_x2: +; CHECK: @ %bb.0: @ %entry +; CHECK-NEXT: vctp.32 r0 +; CHECK-NEXT: vpst +; CHECK-NEXT: vaddt.f32 q0, q0, q1 +; CHECK-NEXT: bx lr +entry: + %c = call <4 x i1> @llvm.arm.mve.vctp32(i32 %n) + %a = select <4 x i1> %c, <4 x float> %y, <4 x float> <float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00> + %b = fadd <4 x float> %a, %x + ret <4 x float> %b +} + define arm_aapcs_vfpcc <8 x half> @fadd_v8f16_x(<8 x half> %x, <8 x half> %y, i32 %n) { ; CHECK-LABEL: fadd_v8f16_x: ; CHECK: @ %bb.0: @ %entry @@ -377,6 +391,20 @@ ret <8 x half> %b } +define arm_aapcs_vfpcc <8 x half> @fadd_v8f16_x2(<8 x half> %x, <8 x half> %y, i32 %n) { +; CHECK-LABEL: fadd_v8f16_x2: +; CHECK: @ %bb.0: @ %entry +; CHECK-NEXT: vctp.16 r0 +; CHECK-NEXT: vpst +; CHECK-NEXT: vaddt.f16 q0, q0, q1 +; CHECK-NEXT: bx lr +entry: + %c = call <8 x i1> @llvm.arm.mve.vctp16(i32 %n) + %a = select <8 x i1> %c, <8 x half> %y, <8 x half> <half 0x0000, half 0x00000, half 0x00000, half 0x00000, half 0x00000, half 0x00000, half 0x00000, half 0x00000> + %b = fadd <8 x half> %a, %x + ret <8 x half> %b +} + define arm_aapcs_vfpcc <4 x float> @fsub_v4f32_x(<4 x float> %x, <4 x float> %y, i32 %n) { ; CHECK-LABEL: fsub_v4f32_x: ; CHECK: @ %bb.0: @ %entry Index: llvm/test/Transforms/InstCombine/select-binop-foldable-floating-point.ll =================================================================== --- llvm/test/Transforms/InstCombine/select-binop-foldable-floating-point.ll +++ llvm/test/Transforms/InstCombine/select-binop-foldable-floating-point.ll @@ -45,6 +45,39 @@ ret float %D } +define <4 x float> @select_nsz_fadd_v4f32(<4 x i1> %cond, <4 x float> %A, <4 x float> %B) { +; CHECK-LABEL: @select_nsz_fadd_v4f32( +; CHECK-NEXT: [[C:%.*]] = select nnan nsz <4 x i1> [[COND:%.*]], <4 x float> [[B:%.*]], <4 x float> zeroinitializer +; CHECK-NEXT: [[D:%.*]] = fadd nnan nsz <4 x float> [[C]], [[A:%.*]] +; CHECK-NEXT: ret <4 x float> [[D]] +; + %C = fadd nsz nnan <4 x float> %A, %B + %D = select nsz nnan <4 x i1> %cond, <4 x float> %C, <4 x float> %A + ret <4 x float> %D +} + +define <vscale x 4 x float> @select_nsz_fadd_nxv4f32(<vscale x 4 x i1> %cond, <vscale x 4 x float> %A, <vscale x 4 x float> %B) { +; CHECK-LABEL: @select_nsz_fadd_nxv4f32( +; CHECK-NEXT: [[C:%.*]] = select nnan nsz <vscale x 4 x i1> [[COND:%.*]], <vscale x 4 x float> [[B:%.*]], <vscale x 4 x float> zeroinitializer +; CHECK-NEXT: [[D:%.*]] = fadd nnan nsz <vscale x 4 x float> [[C]], [[A:%.*]] +; CHECK-NEXT: ret <vscale x 4 x float> [[D]] +; + %C = fadd nsz nnan <vscale x 4 x float> %A, %B + %D = select nsz nnan <vscale x 4 x i1> %cond, <vscale x 4 x float> %C, <vscale x 4 x float> %A + ret <vscale x 4 x float> %D +} + +define <vscale x 4 x float> @select_nsz_fadd_nxv4f32_swapops(<vscale x 4 x i1> %cond, <vscale x 4 x float> %A, <vscale x 4 x float> %B) { +; CHECK-LABEL: @select_nsz_fadd_nxv4f32_swapops( +; CHECK-NEXT: [[C:%.*]] = select fast <vscale x 4 x i1> [[COND:%.*]], <vscale x 4 x float> zeroinitializer, <vscale x 4 x float> [[B:%.*]] +; CHECK-NEXT: [[D:%.*]] = fadd fast <vscale x 4 x float> [[C]], [[A:%.*]] +; CHECK-NEXT: ret <vscale x 4 x float> [[D]] +; + %C = fadd fast <vscale x 4 x float> %A, %B + %D = select fast <vscale x 4 x i1> %cond, <vscale x 4 x float> %A, <vscale x 4 x float> %C + ret <vscale x 4 x float> %D +} + define float @select_fmul(i1 %cond, float %A, float %B) { ; CHECK-LABEL: @select_fmul( ; CHECK-NEXT: [[C:%.*]] = select i1 [[COND:%.*]], float [[B:%.*]], float 1.000000e+00 @@ -135,7 +168,7 @@ define <4 x float> @select_nsz_fsub_v4f32(<4 x i1> %cond, <4 x float> %A, <4 x float> %B) { ; CHECK-LABEL: @select_nsz_fsub_v4f32( -; CHECK-NEXT: [[C:%.*]] = select <4 x i1> [[COND:%.*]], <4 x float> [[B:%.*]], <4 x float> zeroinitializer +; CHECK-NEXT: [[C:%.*]] = select nsz <4 x i1> [[COND:%.*]], <4 x float> [[B:%.*]], <4 x float> zeroinitializer ; CHECK-NEXT: [[D:%.*]] = fsub <4 x float> [[A:%.*]], [[C]] ; CHECK-NEXT: ret <4 x float> [[D]] ; @@ -146,7 +179,7 @@ define <vscale x 4 x float> @select_nsz_fsub_nxv4f32(<vscale x 4 x i1> %cond, <vscale x 4 x float> %A, <vscale x 4 x float> %B) { ; CHECK-LABEL: @select_nsz_fsub_nxv4f32( -; CHECK-NEXT: [[C:%.*]] = select <vscale x 4 x i1> [[COND:%.*]], <vscale x 4 x float> [[B:%.*]], <vscale x 4 x float> zeroinitializer +; CHECK-NEXT: [[C:%.*]] = select nsz <vscale x 4 x i1> [[COND:%.*]], <vscale x 4 x float> [[B:%.*]], <vscale x 4 x float> zeroinitializer ; CHECK-NEXT: [[D:%.*]] = fsub <vscale x 4 x float> [[A:%.*]], [[C]] ; CHECK-NEXT: ret <vscale x 4 x float> [[D]] ;