Index: llvm/lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- llvm/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -48952,6 +48952,83 @@
   return DAG.getBitcast(VT, CFmul);
 }
 
+/// This inverts a canonicalization in IR that replaces a variable select arm
+/// with an identity constant. Codegen improves if we re-use the variable
+/// operand rather than load a constant. This can also be converted into a
+/// masked vector operation if the target supports it.
+static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
+                                              bool ShouldCommuteOperands) {
+  // Match a select as operand 1. The identity constant that we are looking for
+  // is only valid as operand 1 of a non-commutative binop.
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  if (ShouldCommuteOperands)
+    std::swap(N0, N1);
+
+  // TODO: Should this apply to scalar select too?
+  if (!N1.hasOneUse() || N1.getOpcode() != ISD::VSELECT)
+    return SDValue();
+
+  unsigned Opcode = N->getOpcode();
+  EVT VT = N->getValueType(0);
+  SDValue Cond = N1.getOperand(0);
+  SDValue TVal = N1.getOperand(1);
+  SDValue FVal = N1.getOperand(2);
+
+  // TODO: This (and possibly the entire function) belongs in a
+  // target-independent location with target hooks.
+  // TODO: The cases should match with IR's ConstantExpr::getBinOpIdentity().
+  // TODO: With fast-math (NSZ), allow the opposite-sign form of zero?
+  auto isIdentityConstantForOpcode = [](unsigned Opcode, SDValue V) {
+    if (ConstantFPSDNode *C = isConstOrConstSplatFP(V)) {
+      switch (Opcode) {
+      case ISD::FADD: // X + -0.0 --> X
+        return C->isZero() && C->isNegative();
+      case ISD::FSUB: // X - 0.0 --> X
+        return C->isZero() && !C->isNegative();
+      }
+    }
+    return false;
+  };
+
+  // This transform increases uses of N0, so freeze it to be safe.
+  // binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
+  if (isIdentityConstantForOpcode(Opcode, TVal)) {
+    SDValue F0 = DAG.getFreeze(N0);
+    SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags());
+    return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO);
+  }
+  // binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
+  if (isIdentityConstantForOpcode(Opcode, FVal)) {
+    SDValue F0 = DAG.getFreeze(N0);
+    SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags());
+    return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0);
+  }
+
+  return SDValue();
+}
+
+static SDValue combineBinopWithSelect(SDNode *N, SelectionDAG &DAG,
+                                      const X86Subtarget &Subtarget) {
+  // TODO: This is too general. There are cases where pre-AVX512 codegen would
+  // benefit. The transform may also be profitable for scalar code.
+  if (!Subtarget.hasAVX512())
+    return SDValue();
+
+  if (!Subtarget.hasVLX() && !N->getValueType(0).is512BitVector())
+    return SDValue();
+
+  if (SDValue Sel = foldSelectWithIdentityConstant(N, DAG, false))
+    return Sel;
+
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (TLI.isCommutativeBinOp(N->getOpcode()))
+    if (SDValue Sel = foldSelectWithIdentityConstant(N, DAG, true))
+      return Sel;
+
+  return SDValue();
+}
+
 /// Do target-specific dag combines on floating-point adds/subs.
 static SDValue combineFaddFsub(SDNode *N, SelectionDAG &DAG,
                                const X86Subtarget &Subtarget) {
@@ -48961,6 +49038,9 @@
   if (SDValue COp = combineFaddCFmul(N, DAG, Subtarget))
     return COp;
 
+  if (SDValue Sel = combineBinopWithSelect(N, DAG, Subtarget))
+    return Sel;
+
   return SDValue();
 }
 
Index: llvm/test/CodeGen/X86/avx512fp16-arith-intrinsics.ll
===================================================================
--- llvm/test/CodeGen/X86/avx512fp16-arith-intrinsics.ll
+++ llvm/test/CodeGen/X86/avx512fp16-arith-intrinsics.ll
@@ -83,8 +83,8 @@
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    kmovd %edi, %k1
 ; CHECK-NEXT:    vsubph %zmm2, %zmm1, %zmm0 {%k1} {z}
-; CHECK-NEXT:    vsubph (%rsi), %zmm1, %zmm1 {%k1} {z}
-; CHECK-NEXT:    vsubph %zmm1, %zmm0, %zmm0
+; CHECK-NEXT:    vsubph (%rsi), %zmm1, %zmm1
+; CHECK-NEXT:    vsubph %zmm1, %zmm0, %zmm0 {%k1}
 ; CHECK-NEXT:    retq
   %mask = bitcast i32 %msk to <32 x i1>
   %val = load <32 x half>, <32 x half>* %ptr
Index: llvm/test/CodeGen/X86/vector-bo-select.ll
===================================================================
--- llvm/test/CodeGen/X86/vector-bo-select.ll
+++ llvm/test/CodeGen/X86/vector-bo-select.ll
@@ -27,9 +27,8 @@
 ; AVX512VL:       # %bb.0:
 ; AVX512VL-NEXT:    vpslld $31, %xmm0, %xmm0
 ; AVX512VL-NEXT:    vptestmd %xmm0, %xmm0, %k1
-; AVX512VL-NEXT:    vbroadcastss {{.*#+}} xmm0 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
-; AVX512VL-NEXT:    vmovaps %xmm2, %xmm0 {%k1}
-; AVX512VL-NEXT:    vaddps %xmm0, %xmm1, %xmm0
+; AVX512VL-NEXT:    vaddps %xmm2, %xmm1, %xmm1 {%k1}
+; AVX512VL-NEXT:    vmovaps %xmm1, %xmm0
 ; AVX512VL-NEXT:    retq
   %s = select <4 x i1> %b, <4 x float> %y, <4 x float> <float -0.0, float -0.0, float -0.0, float -0.0>
   %r = fadd <4 x float> %x, %s
@@ -62,9 +61,8 @@
 ; AVX512VL-NEXT:    vpmovsxwd %xmm0, %ymm0
 ; AVX512VL-NEXT:    vpslld $31, %ymm0, %ymm0
 ; AVX512VL-NEXT:    vptestmd %ymm0, %ymm0, %k1
-; AVX512VL-NEXT:    vbroadcastss {{.*#+}} ymm0 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
-; AVX512VL-NEXT:    vmovaps %ymm2, %ymm0 {%k1}
-; AVX512VL-NEXT:    vaddps %ymm1, %ymm0, %ymm0
+; AVX512VL-NEXT:    vaddps %ymm2, %ymm1, %ymm1 {%k1}
+; AVX512VL-NEXT:    vmovaps %ymm1, %ymm0
 ; AVX512VL-NEXT:    retq
   %s = select <8 x i1> %b, <8 x float> %y, <8 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>
   %r = fadd <8 x float> %s, %x
@@ -92,8 +90,8 @@
 ; AVX512-NEXT:    vpmovsxbd %xmm0, %zmm0
 ; AVX512-NEXT:    vpslld $31, %zmm0, %zmm0
 ; AVX512-NEXT:    vptestmd %zmm0, %zmm0, %k1
-; AVX512-NEXT:    vbroadcastss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2 {%k1}
 ; AVX512-NEXT:    vaddps %zmm2, %zmm1, %zmm0
+; AVX512-NEXT:    vmovaps %zmm1, %zmm0 {%k1}
 ; AVX512-NEXT:    retq
   %s = select <16 x i1> %b, <16 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>, <16 x float> %y
   %r = fadd <16 x float> %x, %s
@@ -121,8 +119,8 @@
 ; AVX512-NEXT:    vpmovsxbd %xmm0, %zmm0
 ; AVX512-NEXT:    vpslld $31, %zmm0, %zmm0
 ; AVX512-NEXT:    vptestmd %zmm0, %zmm0, %k1
-; AVX512-NEXT:    vbroadcastss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2 {%k1}
-; AVX512-NEXT:    vaddps %zmm1, %zmm2, %zmm0
+; AVX512-NEXT:    vaddps %zmm2, %zmm1, %zmm0
+; AVX512-NEXT:    vmovaps %zmm1, %zmm0 {%k1}
 ; AVX512-NEXT:    retq
   %s = select <16 x i1> %b, <16 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>, <16 x float> %y
   %r = fadd <16 x float> %s, %x
@@ -152,14 +150,16 @@
 ; AVX512VL:       # %bb.0:
 ; AVX512VL-NEXT:    vpslld $31, %xmm0, %xmm0
 ; AVX512VL-NEXT:    vptestmd %xmm0, %xmm0, %k1
-; AVX512VL-NEXT:    vmovaps %xmm2, %xmm0 {%k1} {z}
-; AVX512VL-NEXT:    vsubps %xmm0, %xmm1, %xmm0
+; AVX512VL-NEXT:    vsubps %xmm2, %xmm1, %xmm1 {%k1}
+; AVX512VL-NEXT:    vmovaps %xmm1, %xmm0
 ; AVX512VL-NEXT:    retq
   %s = select <4 x i1> %b, <4 x float> %y, <4 x float> zeroinitializer
   %r = fsub <4 x float> %x, %s
   ret <4 x float> %r
 }
 
+; negative test - fsub is not commutative; there is no identity constant for operand 0
+
 define <8 x float> @fsub_v8f32_commute(<8 x i1> %b, <8 x float> noundef %x, <8 x float> noundef %y) {
 ; AVX2-LABEL: fsub_v8f32_commute:
 ; AVX2:       # %bb.0:
@@ -214,15 +214,17 @@
 ; AVX512:       # %bb.0:
 ; AVX512-NEXT:    vpmovsxbd %xmm0, %zmm0
 ; AVX512-NEXT:    vpslld $31, %zmm0, %zmm0
-; AVX512-NEXT:    vptestnmd %zmm0, %zmm0, %k1
-; AVX512-NEXT:    vmovaps %zmm2, %zmm0 {%k1} {z}
-; AVX512-NEXT:    vsubps %zmm0, %zmm1, %zmm0
+; AVX512-NEXT:    vptestmd %zmm0, %zmm0, %k1
+; AVX512-NEXT:    vsubps %zmm2, %zmm1, %zmm0
+; AVX512-NEXT:    vmovaps %zmm1, %zmm0 {%k1}
 ; AVX512-NEXT:    retq
   %s = select <16 x i1> %b, <16 x float> zeroinitializer, <16 x float> %y
   %r = fsub <16 x float> %x, %s
   ret <16 x float> %r
 }
 
+; negative test - fsub is not commutative; there is no identity constant for operand 0
+
 define <16 x float> @fsub_v16f32_commute_swap(<16 x i1> %b, <16 x float> noundef %x, <16 x float> noundef %y) {
 ; AVX2-LABEL: fsub_v16f32_commute_swap:
 ; AVX2:       # %bb.0:
@@ -318,9 +320,7 @@
 ; AVX512VL-LABEL: fadd_v8f32_cast_cond:
 ; AVX512VL:       # %bb.0:
 ; AVX512VL-NEXT:    kmovw %edi, %k1
-; AVX512VL-NEXT:    vbroadcastss {{.*#+}} ymm2 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
-; AVX512VL-NEXT:    vmovaps %ymm1, %ymm2 {%k1}
-; AVX512VL-NEXT:    vaddps %ymm2, %ymm0, %ymm0
+; AVX512VL-NEXT:    vaddps %ymm1, %ymm0, %ymm0 {%k1}
 ; AVX512VL-NEXT:    retq
   %b = bitcast i8 %pb to <8 x i1>
   %s = select <8 x i1> %b, <8 x float> %y, <8 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>
@@ -384,9 +384,7 @@
 ; AVX512-LABEL: fadd_v8f64_cast_cond:
 ; AVX512:       # %bb.0:
 ; AVX512-NEXT:    kmovw %edi, %k1
-; AVX512-NEXT:    vbroadcastsd {{.*#+}} zmm2 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
-; AVX512-NEXT:    vmovapd %zmm1, %zmm2 {%k1}
-; AVX512-NEXT:    vaddpd %zmm2, %zmm0, %zmm0
+; AVX512-NEXT:    vaddpd %zmm1, %zmm0, %zmm0 {%k1}
 ; AVX512-NEXT:    retq
   %b = bitcast i8 %pb to <8 x i1>
   %s = select <8 x i1> %b, <8 x double> %y, <8 x double> <double -0.0, double -0.0, double -0.0, double -0.0, double -0.0, double -0.0, double -0.0, double -0.0>
@@ -457,8 +455,7 @@
 ; AVX512VL-LABEL: fsub_v8f32_cast_cond:
 ; AVX512VL:       # %bb.0:
 ; AVX512VL-NEXT:    kmovw %edi, %k1
-; AVX512VL-NEXT:    vmovaps %ymm1, %ymm1 {%k1} {z}
-; AVX512VL-NEXT:    vsubps %ymm1, %ymm0, %ymm0
+; AVX512VL-NEXT:    vsubps %ymm1, %ymm0, %ymm0 {%k1}
 ; AVX512VL-NEXT:    retq
   %b = bitcast i8 %pb to <8 x i1>
   %s = select <8 x i1> %b, <8 x float> %y, <8 x float> zeroinitializer
@@ -523,8 +520,7 @@
 ; AVX512-LABEL: fsub_v8f64_cast_cond:
 ; AVX512:       # %bb.0:
 ; AVX512-NEXT:    kmovw %edi, %k1
-; AVX512-NEXT:    vmovapd %zmm1, %zmm1 {%k1} {z}
-; AVX512-NEXT:    vsubpd %zmm1, %zmm0, %zmm0
+; AVX512-NEXT:    vsubpd %zmm1, %zmm0, %zmm0 {%k1}
 ; AVX512-NEXT:    retq
   %b = bitcast i8 %pb to <8 x i1>
   %s = select <8 x i1> %b, <8 x double> %y, <8 x double> zeroinitializer
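
Note: a minimal sketch of the IR shape this combine targets, distilled from the
vector-bo-select.ll tests above (the function name is illustrative):

define <4 x float> @fadd_masked(<4 x i1> %b, <4 x float> %x, <4 x float> %y) {
  %s = select <4 x i1> %b, <4 x float> %y, <4 x float> <float -0.0, float -0.0, float -0.0, float -0.0>
  %r = fadd <4 x float> %x, %s   ; -0.0 is the FADD identity, so this folds to
  ret <4 x float> %r             ; select %b, (fadd %x, %y), %x
}

With AVX512VL, the select becomes the add's write-mask, so codegen is a single
"vaddps %xmm2, %xmm1, %xmm1 {%k1}" (plus a register move) instead of a constant
broadcast, a masked blend, and an unmasked add.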