Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -32052,6 +32052,29 @@
     return DAG.getNode(Opcode, DL, N->getValueType(0), LHS, RHS);
   }
 
+  // Some mask scalar intrinsics rely on checking if only one bit is set
+  // and implement it in C code like this:
+  //   A[0] = (U & 1) ? A[0] : W[0];
+  // This creates some redundant instructions that break pattern matching.
+  // fold (select (setcc (and (X, 1), 0, seteq), Y, Z)) -> select(X, Z, Y)
+  if (Subtarget.hasAVX512() && N->getOpcode() == ISD::SELECT &&
+      Cond.getOpcode() == ISD::SETCC && (VT == MVT::f32 || VT == MVT::f64)) {
+    ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
+    SDValue AndNode = Cond.getOperand(0);
+    if (AndNode.getOpcode() == ISD::AND && CC == ISD::SETEQ &&
+        isNullConstant(Cond.getOperand(1)) &&
+        AndNode.getOperand(1).getSimpleValueType().isScalarInteger() &&
+        isa<ConstantSDNode>(AndNode.getOperand(1)) &&
+        cast<ConstantSDNode>(AndNode.getOperand(1))->getAPIntValue() == 1) {
+      // LHS and RHS swapped due to
+      // setcc outputting 1 when AND resulted in 0 and vice versa.
+      SDValue Mask = AndNode.getOperand(0);
+      if (Mask.getValueType() != MVT::i8)
+        Mask = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, Mask);
+      return DAG.getNode(ISD::SELECT, DL, VT, Mask, RHS, LHS);
+    }
+  }
+
   // v16i8 (select v16i1, v16i8, v16i8) does not have a proper
   // lowering on KNL. In this case we convert it to
   // v16i8 (select v16i8, v16i8, v16i8) and use AVX instruction.
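
(Context, not part of the patch: the redundant select idiom above typically comes
from mask-scalar intrinsics written in plain C. A minimal sketch, assuming clang's
AVX-512 intrinsic headers; the wrapper name is illustrative only.)

    #include <immintrin.h>

    // result[0] = (U & 1) ? A[0] + B[0] : W[0]; result[127:32] = A[127:32].
    // The (U & 1) test lowers to and+setcc+select nodes; the combine above
    // rewrites that into a plain select on the mask bit so the masked VADDSS
    // patterns can match again.
    __m128 masked_add_ss(__m128 W, __mmask8 U, __m128 A, __m128 B) {
      return _mm_mask_add_ss(W, U, A, B);
    }
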
Index: lib/Target/X86/X86InstrAVX512.td
===================================================================
--- lib/Target/X86/X86InstrAVX512.td
+++ lib/Target/X86/X86InstrAVX512.td
@@ -6633,6 +6633,56 @@
 defm VFNMSUB : avx512_fma3s<0xAF, 0xBF, 0x9F, "vfnmsub", X86Fnmsub, X86Fnmsubs1,
                             X86FnmsubRnds1, X86Fnmsubs3, X86FnmsubRnds3>;
 
+multiclass avx512_scalar_fma_patterns<SDNode Op, string Prefix, string Suffix,
+                                      SDNode Move, ValueType VT, ValueType EltVT,
+                                      RegisterClass VR128, PatLeaf ZeroFP> {
+  let Predicates = [HasFMA, HasAVX512] in {
+    def : Pat<(VT (Move (VT VR128:$src2), (VT (scalar_to_vector
+                (X86selects VK1WM:$mask,
+                 (Op (EltVT (extractelt (VT VR128:$src1), (iPTR 0))),
+                     (EltVT (extractelt (VT VR128:$src2), (iPTR 0))),
+                     (EltVT (extractelt (VT VR128:$src3), (iPTR 0)))),
+                 (EltVT (extractelt (VT VR128:$src2), (iPTR 0)))))))),
+              (!cast<Instruction>(Prefix#"213"#Suffix#"Zr_Intk")
+               VR128:$src2, VK1WM:$mask, VR128:$src1, VR128:$src3)>;
+
+    def : Pat<(VT (Move (VT VR128:$src2), (VT (scalar_to_vector
+                (X86selects VK1WM:$mask,
+                 (Op (EltVT (extractelt (VT VR128:$src1), (iPTR 0))),
+                     (EltVT (extractelt (VT VR128:$src2), (iPTR 0))),
+                     (EltVT (extractelt (VT VR128:$src3), (iPTR 0)))),
+                 (EltVT (extractelt (VT VR128:$src3), (iPTR 0)))))))),
+              (!cast<Instruction>(Prefix#"231"#Suffix#"Zr_Intk")
+               VR128:$src3, VK1WM:$mask, VR128:$src2, VR128:$src1)>;
+
+    def : Pat<(VT (Move (VT VR128:$src2), (VT (scalar_to_vector
+                (X86selects VK1WM:$mask,
+                 (Op (EltVT (extractelt (VT VR128:$src1), (iPTR 0))),
+                     (EltVT (extractelt (VT VR128:$src2), (iPTR 0))),
+                     (EltVT (extractelt (VT VR128:$src3), (iPTR 0)))),
+                 (EltVT ZeroFP)))))),
+              (!cast<Instruction>(Prefix#"213"#Suffix#"Zr_Intkz")
+               VR128:$src2, VK1WM:$mask, VR128:$src1, VR128:$src3)>;
+  }
+}
+
+defm : avx512_scalar_fma_patterns<fma, "VFMADD", "SS", X86Movss,
+                                  v4f32, f32, VR128X, fp32imm0>;
+defm : avx512_scalar_fma_patterns<X86Fmsub, "VFMSUB", "SS", X86Movss,
+                                  v4f32, f32, VR128X, fp32imm0>;
+defm : avx512_scalar_fma_patterns<X86Fnmadd, "VFNMADD", "SS", X86Movss,
+                                  v4f32, f32, VR128X, fp32imm0>;
+defm : avx512_scalar_fma_patterns<X86Fnmsub, "VFNMSUB", "SS", X86Movss,
+                                  v4f32, f32, VR128X, fp32imm0>;
+
+defm : avx512_scalar_fma_patterns<fma, "VFMADD", "SD", X86Movsd,
+                                  v2f64, f64, VR128X, fp64imm0>;
+defm : avx512_scalar_fma_patterns<X86Fmsub, "VFMSUB", "SD", X86Movsd,
+                                  v2f64, f64, VR128X, fp64imm0>;
+defm : avx512_scalar_fma_patterns<X86Fnmadd, "VFNMADD", "SD", X86Movsd,
+                                  v2f64, f64, VR128X, fp64imm0>;
+defm : avx512_scalar_fma_patterns<X86Fnmsub, "VFNMSUB", "SD", X86Movsd,
+                                  v2f64, f64, VR128X, fp64imm0>;
+
 //===----------------------------------------------------------------------===//
 // AVX-512  Packed Multiply of Unsigned 52-bit Integers and Add the Low 52-bit IFMA
 //===----------------------------------------------------------------------===//
@@ -8435,6 +8485,42 @@
                              VEX_W, AVX512AIi8Base, EVEX_4V, EVEX_CD8<64, CD8VT1>;
 
+multiclass avx512_masked_scalar<SDNode OpNode, string OpcPrefix, SDNode Move,
+                                dag Mask, X86VectorVTInfo _, PatLeaf ZeroFP,
+                                dag OutMask, Predicate BasePredicate> {
+  let Predicates = [BasePredicate] in {
+    def : Pat<(Move _.VT:$src1, (scalar_to_vector (X86selects Mask,
+               (OpNode (extractelt _.VT:$src2, (iPTR 0))),
+               (extractelt _.VT:$dst, (iPTR 0))))),
+              (!cast<Instruction>("V"#OpcPrefix#r_Intk)
+               _.VT:$dst, OutMask, _.VT:$src2, _.VT:$src1)>;
+
+    def : Pat<(Move _.VT:$src1, (scalar_to_vector (X86selects Mask,
+               (OpNode (extractelt _.VT:$src2, (iPTR 0))),
+               ZeroFP))),
+              (!cast<Instruction>("V"#OpcPrefix#r_Intkz)
+               OutMask, _.VT:$src2, _.VT:$src1)>;
+  }
+}
+
+multiclass avx512_masked_scalar_imm<SDNode OpNode, string OpcPrefix,
+                                    SDNode Move, dag Mask, X86VectorVTInfo _,
+                                    PatLeaf ZeroFP, bits<8> ImmV, dag OutMask,
+                                    Predicate BasePredicate> {
+  let Predicates = [BasePredicate] in {
+    def : Pat<(Move _.VT:$src1, (scalar_to_vector (X86selects Mask,
+               (OpNode (extractelt _.VT:$src2, (iPTR 0))),
+               (extractelt _.VT:$dst, (iPTR 0))))),
+              (!cast<Instruction>("V"#OpcPrefix#r_Intk)
+               _.VT:$dst, OutMask, _.VT:$src1, _.VT:$src2, (i32 ImmV))>;
+
+    def : Pat<(Move _.VT:$src1, (scalar_to_vector (X86selects Mask,
+               (OpNode (extractelt _.VT:$src2, (iPTR 0))), ZeroFP))),
+              (!cast<Instruction>("V"#OpcPrefix#r_Intkz)
+               OutMask, _.VT:$src1, _.VT:$src2, (i32 ImmV))>;
+  }
+}
+
 //-------------------------------------------------
 // Integer truncate and extend operations
 //-------------------------------------------------
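
(Context, not part of the patch: avx512_masked_scalar and avx512_masked_scalar_imm
target masked unary scalar operations of exactly this shape; a hedged illustration
with the AVX-512 sqrt intrinsic. The defm instantiation sites are outside this hunk.)

    #include <immintrin.h>

    // result[0] = (U & 1) ? sqrt(B[0]) : W[0]; result[127:32] = A[127:32].
    // After the select combine in X86ISelLowering.cpp, this folds to a single
    // masked VSQRTSS instead of a sqrt followed by a blend.
    __m128 masked_sqrt_ss(__m128 W, __mmask8 U, __m128 A, __m128 B) {
      return _mm_mask_sqrt_ss(W, U, A, B);
    }
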
@@ -10783,69 +10869,54 @@
 
 // TODO: Some canonicalization in lowering would simplify the number of
 // patterns we have to try to match.
-multiclass AVX512_scalar_math_f32_patterns<SDNode Op, string OpcPrefix> {
+multiclass AVX512_scalar_math_fp_patterns<SDNode Op, string OpcPrefix,
+                                          SDNode MoveNode, X86VectorVTInfo _,
+                                          PatLeaf ZeroFP> {
   let Predicates = [HasAVX512] in {
     // extracted scalar math op with insert via movss
-    def : Pat<(v4f32 (X86Movss (v4f32 VR128X:$dst), (v4f32 (scalar_to_vector
-          (Op (f32 (extractelt (v4f32 VR128X:$dst), (iPTR 0))),
-          FR32X:$src))))),
-      (!cast<Instruction>("V"#OpcPrefix#SSZrr_Int) v4f32:$dst,
-      (COPY_TO_REGCLASS FR32X:$src, VR128X))>;
+    def : Pat<(_.VT (MoveNode (_.VT VR128X:$dst), (_.VT (scalar_to_vector
+          (Op (_.EltVT (extractelt (_.VT VR128X:$dst), (iPTR 0))),
+          _.FRC:$src))))),
+      (!cast<Instruction>("V"#OpcPrefix#Zrr_Int) _.VT:$dst,
+      (COPY_TO_REGCLASS _.FRC:$src, VR128X))>;
 
     // vector math op with insert via movss
-    def : Pat<(v4f32 (X86Movss (v4f32 VR128X:$dst),
-          (Op (v4f32 VR128X:$dst), (v4f32 VR128X:$src)))),
-      (!cast<Instruction>("V"#OpcPrefix#SSZrr_Int) v4f32:$dst, v4f32:$src)>;
+    def : Pat<(_.VT (MoveNode (_.VT VR128X:$dst),
+          (Op (_.VT VR128X:$dst), (_.VT VR128X:$src)))),
+      (!cast<Instruction>("V"#OpcPrefix#Zrr_Int) _.VT:$dst, _.VT:$src)>;
 
     // extracted masked scalar math op with insert via movss
-    def : Pat<(X86Movss (v4f32 VR128X:$src1),
+    def : Pat<(MoveNode (_.VT VR128X:$src1),
           (scalar_to_vector
            (X86selects VK1WM:$mask,
-            (Op (f32 (extractelt (v4f32 VR128X:$src1), (iPTR 0))),
-                FR32X:$src2),
-            FR32X:$src0))),
-      (!cast<Instruction>("V"#OpcPrefix#SSZrr_Intk) (COPY_TO_REGCLASS FR32X:$src0, VR128X),
-       VK1WM:$mask, v4f32:$src1,
-       (COPY_TO_REGCLASS FR32X:$src2, VR128X))>;
-  }
-}
-
-defm : AVX512_scalar_math_f32_patterns<fadd, "ADD">;
-defm : AVX512_scalar_math_f32_patterns<fsub, "SUB">;
-defm : AVX512_scalar_math_f32_patterns<fmul, "MUL">;
-defm : AVX512_scalar_math_f32_patterns<fdiv, "DIV">;
-
-multiclass AVX512_scalar_math_f64_patterns<SDNode Op, string OpcPrefix> {
-  let Predicates = [HasAVX512] in {
-    // extracted scalar math op with insert via movsd
-    def : Pat<(v2f64 (X86Movsd (v2f64 VR128X:$dst), (v2f64 (scalar_to_vector
-          (Op (f64 (extractelt (v2f64 VR128X:$dst), (iPTR 0))),
-          FR64X:$src))))),
-      (!cast<Instruction>("V"#OpcPrefix#SDZrr_Int) v2f64:$dst,
-      (COPY_TO_REGCLASS FR64X:$src, VR128X))>;
-
-    // vector math op with insert via movsd
-    def : Pat<(v2f64 (X86Movsd (v2f64 VR128X:$dst),
-          (Op (v2f64 VR128X:$dst), (v2f64 VR128X:$src)))),
-      (!cast<Instruction>("V"#OpcPrefix#SDZrr_Int) v2f64:$dst, v2f64:$src)>;
-
+            (Op (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0))),
+                _.FRC:$src2),
+            _.FRC:$src0))),
+      (!cast<Instruction>("V"#OpcPrefix#Zrr_Intk) (COPY_TO_REGCLASS _.FRC:$src0, VR128X),
+       VK1WM:$mask, _.VT:$src1,
+       (COPY_TO_REGCLASS _.FRC:$src2, VR128X))>;
+
     // extracted masked scalar math op with insert via movss
-    def : Pat<(X86Movsd (v2f64 VR128X:$src1),
+    def : Pat<(MoveNode (_.VT VR128X:$src1),
           (scalar_to_vector
            (X86selects VK1WM:$mask,
-            (Op (f64 (extractelt (v2f64 VR128X:$src1), (iPTR 0))),
-                FR64X:$src2),
-            FR64X:$src0))),
-      (!cast<Instruction>("V"#OpcPrefix#SDZrr_Intk) (COPY_TO_REGCLASS FR64X:$src0, VR128X),
-       VK1WM:$mask, v2f64:$src1,
-       (COPY_TO_REGCLASS FR64X:$src2, VR128X))>;
+            (Op (_.EltVT (extractelt (_.VT VR128X:$src1), (iPTR 0))),
+                _.FRC:$src2), (_.EltVT ZeroFP)))),
+      (!cast<Instruction>("V"#OpcPrefix#Zrr_Intkz)
+       VK1WM:$mask, _.VT:$src1,
+       (COPY_TO_REGCLASS _.FRC:$src2, VR128X))>;
   }
 }
 
-defm : AVX512_scalar_math_f64_patterns<fadd, "ADD">;
-defm : AVX512_scalar_math_f64_patterns<fsub, "SUB">;
-defm : AVX512_scalar_math_f64_patterns<fmul, "MUL">;
-defm : AVX512_scalar_math_f64_patterns<fdiv, "DIV">;
+defm : AVX512_scalar_math_fp_patterns<fadd, "ADDSS", X86Movss, v4f32x_info, fp32imm0>;
+defm : AVX512_scalar_math_fp_patterns<fsub, "SUBSS", X86Movss, v4f32x_info, fp32imm0>;
+defm : AVX512_scalar_math_fp_patterns<fmul, "MULSS", X86Movss, v4f32x_info, fp32imm0>;
+defm : AVX512_scalar_math_fp_patterns<fdiv, "DIVSS", X86Movss, v4f32x_info, fp32imm0>;
+
+defm : AVX512_scalar_math_fp_patterns<fadd, "ADDSD", X86Movsd, v2f64x_info, fp64imm0>;
+defm : AVX512_scalar_math_fp_patterns<fsub, "SUBSD", X86Movsd, v2f64x_info, fp64imm0>;
+defm : AVX512_scalar_math_fp_patterns<fmul, "MULSD", X86Movsd, v2f64x_info, fp64imm0>;
+defm : AVX512_scalar_math_fp_patterns<fdiv, "DIVSD", X86Movsd, v2f64x_info, fp64imm0>;
 
 //===----------------------------------------------------------------------===//
 // AES instructions
Index: lib/Target/X86/X86InstrFMA.td
===================================================================
--- lib/Target/X86/X86InstrFMA.td
+++ lib/Target/X86/X86InstrFMA.td
@@ -364,6 +364,28 @@
 defm VFNMSUB : fma3s<0x9F, 0xAF, 0xBF, "vfnmsub", X86Fnmsubs1, X86Fnmsub,
                      SchedWriteFMA.Scl>, VEX_LIG;
 
+multiclass scalar_fma_patterns<SDNode Op, string Prefix, string Suffix,
+                               SDNode Move, ValueType VT, ValueType EltVT> {
+  let Predicates = [HasFMA] in {
+    def : Pat<(VT (Move (VT VR128:$src2), (VT (scalar_to_vector
+                (Op (EltVT (extractelt (VT VR128:$src1), (iPTR 0))),
+                    (EltVT (extractelt (VT VR128:$src2), (iPTR 0))),
+                    (EltVT (extractelt (VT VR128:$src3), (iPTR 0)))))))),
+              (!cast<Instruction>(Prefix#"213"#Suffix#"r_Int")
+               VR128:$src2, VR128:$src1, VR128:$src3)>;
+  }
+}
+
+defm : scalar_fma_patterns<fma, "VFMADD", "SS", X86Movss, v4f32, f32>;
+defm : scalar_fma_patterns<X86Fmsub, "VFMSUB", "SS", X86Movss, v4f32, f32>;
+defm : scalar_fma_patterns<X86Fnmadd, "VFNMADD", "SS", X86Movss, v4f32, f32>;
+defm : scalar_fma_patterns<X86Fnmsub, "VFNMSUB", "SS", X86Movss, v4f32, f32>;
+
+defm : scalar_fma_patterns<fma, "VFMADD", "SD", X86Movsd, v2f64, f64>;
+defm : scalar_fma_patterns<X86Fmsub, "VFMSUB", "SD", X86Movsd, v2f64, f64>;
+defm : scalar_fma_patterns<X86Fnmadd, "VFNMADD", "SD", X86Movsd, v2f64, f64>;
+defm : scalar_fma_patterns<X86Fnmsub, "VFNMSUB", "SD", X86Movsd, v2f64, f64>;
+
 //===----------------------------------------------------------------------===//
 // FMA4 - AMD 4 operand Fused Multiply-Add instructions
 //===----------------------------------------------------------------------===//
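
(Context, not part of the patch: the unmasked shape the new scalar_fma_patterns
target, assuming clang's FMA intrinsic headers.)

    #include <immintrin.h>

    // result[0] = A[0] * B[0] + C[0]; result[127:32] = A[127:32].
    // Once the intrinsic is lowered to extract/fma/insert nodes, the patterns
    // above fold the whole chain back into a single VFMADD213SS.
    __m128 fmadd_ss(__m128 A, __m128 B, __m128 C) {
      return _mm_fmadd_ss(A, B, C);
    }
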
Index: lib/Target/X86/X86InstrSSE.td
===================================================================
--- lib/Target/X86/X86InstrSSE.td
+++ lib/Target/X86/X86InstrSSE.td
@@ -2682,78 +2682,49 @@
 
 // TODO: Some canonicalization in lowering would simplify the number of
 // patterns we have to try to match.
-multiclass scalar_math_f32_patterns<SDNode Op, string OpcPrefix> {
-  let Predicates = [UseSSE1] in {
-    // extracted scalar math op with insert via movss
-    def : Pat<(v4f32 (X86Movss (v4f32 VR128:$dst), (v4f32 (scalar_to_vector
-          (Op (f32 (extractelt (v4f32 VR128:$dst), (iPTR 0))),
-          FR32:$src))))),
-      (!cast<Instruction>(OpcPrefix#SSrr_Int) v4f32:$dst,
-       (COPY_TO_REGCLASS FR32:$src, VR128))>;
-
-    // vector math op with insert via movss
-    def : Pat<(v4f32 (X86Movss (v4f32 VR128:$dst),
-          (Op (v4f32 VR128:$dst), (v4f32 VR128:$src)))),
-      (!cast<Instruction>(OpcPrefix#SSrr_Int) v4f32:$dst, v4f32:$src)>;
-  }
-
-  // Repeat everything for AVX.
-  let Predicates = [UseAVX] in {
-    // extracted scalar math op with insert via movss
-    def : Pat<(v4f32 (X86Movss (v4f32 VR128:$dst), (v4f32 (scalar_to_vector
-          (Op (f32 (extractelt (v4f32 VR128:$dst), (iPTR 0))),
-          FR32:$src))))),
-      (!cast<Instruction>("V"#OpcPrefix#SSrr_Int) v4f32:$dst,
-       (COPY_TO_REGCLASS FR32:$src, VR128))>;
-
-    // vector math op with insert via movss
-    def : Pat<(v4f32 (X86Movss (v4f32 VR128:$dst),
-          (Op (v4f32 VR128:$dst), (v4f32 VR128:$src)))),
-      (!cast<Instruction>("V"#OpcPrefix#SSrr_Int) v4f32:$dst, v4f32:$src)>;
-  }
-}
-
-defm : scalar_math_f32_patterns<fadd, "ADD">;
-defm : scalar_math_f32_patterns<fsub, "SUB">;
-defm : scalar_math_f32_patterns<fmul, "MUL">;
-defm : scalar_math_f32_patterns<fdiv, "DIV">;
-
-multiclass scalar_math_f64_patterns<SDNode Op, string OpcPrefix> {
-  let Predicates = [UseSSE2] in {
-    // extracted scalar math op with insert via movsd
-    def : Pat<(v2f64 (X86Movsd (v2f64 VR128:$dst), (v2f64 (scalar_to_vector
-          (Op (f64 (extractelt (v2f64 VR128:$dst), (iPTR 0))),
-          FR64:$src))))),
-      (!cast<Instruction>(OpcPrefix#SDrr_Int) v2f64:$dst,
-       (COPY_TO_REGCLASS FR64:$src, VR128))>;
-
-    // vector math op with insert via movsd
-    def : Pat<(v2f64 (X86Movsd (v2f64 VR128:$dst),
-          (Op (v2f64 VR128:$dst), (v2f64 VR128:$src)))),
-      (!cast<Instruction>(OpcPrefix#SDrr_Int) v2f64:$dst, v2f64:$src)>;
-  }
-
-  // Repeat everything for AVX.
-  let Predicates = [UseAVX] in {
-    // extracted scalar math op with insert via movsd
-    def : Pat<(v2f64 (X86Movsd (v2f64 VR128:$dst), (v2f64 (scalar_to_vector
-          (Op (f64 (extractelt (v2f64 VR128:$dst), (iPTR 0))),
-          FR64:$src))))),
-      (!cast<Instruction>("V"#OpcPrefix#SDrr_Int) v2f64:$dst,
-       (COPY_TO_REGCLASS FR64:$src, VR128))>;
-
-    // vector math op with insert via movsd
-    def : Pat<(v2f64 (X86Movsd (v2f64 VR128:$dst),
-          (Op (v2f64 VR128:$dst), (v2f64 VR128:$src)))),
-      (!cast<Instruction>("V"#OpcPrefix#SDrr_Int) v2f64:$dst, v2f64:$src)>;
-  }
-}
-
-defm : scalar_math_f64_patterns<fadd, "ADD">;
-defm : scalar_math_f64_patterns<fsub, "SUB">;
-defm : scalar_math_f64_patterns<fmul, "MUL">;
-defm : scalar_math_f64_patterns<fdiv, "DIV">;
-
+multiclass scalar_math_patterns<SDNode Op, string OpcPrefix, SDNode Move,
+                                ValueType VT, ValueType EltTy,
+                                RegisterClass RC, Predicate BasePredicate> {
+  let Predicates = [BasePredicate] in {
+    // extracted scalar math op with insert via movss/movsd
+    def : Pat<(VT (Move (VT VR128:$dst), (VT (scalar_to_vector
+          (Op (EltTy (extractelt (VT VR128:$dst), (iPTR 0))),
+          RC:$src))))),
+      (!cast<Instruction>(OpcPrefix#rr_Int) VT:$dst,
+       (COPY_TO_REGCLASS RC:$src, VR128))>;
+
+    // vector math op with insert via movss/movsd
+    def : Pat<(VT (Move (VT VR128:$dst),
+          (Op (VT VR128:$dst), (VT VR128:$src)))),
+      (!cast<Instruction>(OpcPrefix#rr_Int) VT:$dst, VT:$src)>;
+  }
+
+  // Repeat for AVX versions of the instructions.
+  let Predicates = [UseAVX] in {
+    // extracted scalar math op with insert via movss/movsd
+    def : Pat<(VT (Move (VT VR128:$dst), (VT (scalar_to_vector
+          (Op (EltTy (extractelt (VT VR128:$dst), (iPTR 0))),
+          RC:$src))))),
+      (!cast<Instruction>("V"#OpcPrefix#rr_Int) VT:$dst,
+       (COPY_TO_REGCLASS RC:$src, VR128))>;
+
+    // vector math op with insert via movss/movsd
+    def : Pat<(VT (Move (VT VR128:$dst),
+          (Op (VT VR128:$dst), (VT VR128:$src)))),
+      (!cast<Instruction>("V"#OpcPrefix#rr_Int) VT:$dst, VT:$src)>;
+  }
+}
+
+defm : scalar_math_patterns<fadd, "ADDSS", X86Movss, v4f32, f32, FR32, UseSSE1>;
+defm : scalar_math_patterns<fsub, "SUBSS", X86Movss, v4f32, f32, FR32, UseSSE1>;
+defm : scalar_math_patterns<fmul, "MULSS", X86Movss, v4f32, f32, FR32, UseSSE1>;
+defm : scalar_math_patterns<fdiv, "DIVSS", X86Movss, v4f32, f32, FR32, UseSSE1>;
+
+defm : scalar_math_patterns<fadd, "ADDSD", X86Movsd, v2f64, f64, FR64, UseSSE2>;
+defm : scalar_math_patterns<fsub, "SUBSD", X86Movsd, v2f64, f64, FR64, UseSSE2>;
+defm : scalar_math_patterns<fmul, "MULSD", X86Movsd, v2f64, f64, FR64, UseSSE2>;
+defm : scalar_math_patterns<fdiv, "DIVSD", X86Movsd, v2f64, f64, FR64, UseSSE2>;
+
 /// Unop Arithmetic
 /// In addition, we also have a special variant of the scalar form here to
 /// represent the associated intrinsic operation. This form is unlike the
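
(Context, not part of the patch: the source-level idiom behind the "extracted
scalar math op" patterns. A sketch relying on the clang/GCC vector-extension
subscript on __m128; the function name is illustrative.)

    #include <immintrin.h>

    // a[0] += b[0] becomes extractelt + fadd + insert via movss in the DAG,
    // which scalar_math_patterns folds into a single (V)ADDSS.
    __m128 add_lowest(__m128 a, __m128 b) {
      a[0] += b[0];
      return a;
    }
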
@@ -2982,13 +2953,42 @@
 
 // There is no f64 version of the reciprocal approximation instructions.
 
-// TODO: We should add *scalar* op patterns for these just like we have for
-// the binops above. If the binop and unop patterns could all be unified
-// that would be even better.
+multiclass scalar_unary_math_patterns<SDNode OpNode, string OpcPrefix,
+                                      SDNode Move, ValueType VT,
+                                      Predicate BasePredicate> {
+  let Predicates = [BasePredicate] in {
+    def : Pat<(VT (Move VT:$dst, (scalar_to_vector
+                                  (OpNode (extractelt VT:$src, 0))))),
+              (!cast<Instruction>(OpcPrefix#r_Int) VT:$dst, VT:$src)>;
+  }
+
+  // Repeat for AVX versions of the instructions.
+  let Predicates = [HasAVX] in {
+    def : Pat<(VT (Move VT:$dst, (scalar_to_vector
+                                  (OpNode (extractelt VT:$src, 0))))),
+              (!cast<Instruction>("V"#OpcPrefix#r_Int) VT:$dst, VT:$src)>;
+  }
+}
+
+multiclass scalar_unary_math_imm_patterns<SDNode OpNode, string OpcPrefix,
+                                          SDNode Move, ValueType VT,
+                                          bits<8> ImmV,
+                                          Predicate BasePredicate> {
+  let Predicates = [BasePredicate] in {
+    def : Pat<(VT (Move VT:$dst, (scalar_to_vector
+                                  (OpNode (extractelt VT:$src, 0))))),
+              (!cast<Instruction>(OpcPrefix#r_Int) VT:$dst, VT:$src, (i32 ImmV))>;
+  }
+
+  // Repeat for AVX versions of the instructions.
+  let Predicates = [HasAVX] in {
+    def : Pat<(VT (Move VT:$dst, (scalar_to_vector
+                                  (OpNode (extractelt VT:$src, 0))))),
+              (!cast<Instruction>("V"#OpcPrefix#r_Int) VT:$dst, VT:$src, (i32 ImmV))>;
+  }
+}
 
-multiclass scalar_unary_math_patterns<Intrinsic Intr, string OpcPrefix, SDNode Move,
-                                      ValueType VT, Predicate BasePredicate> {
+multiclass scalar_unary_math_intr_patterns<Intrinsic Intr, string OpcPrefix,
+                                           SDNode Move, ValueType VT,
+                                           Predicate BasePredicate> {
   let Predicates = [BasePredicate] in {
     def : Pat<(VT (Move VT:$dst, (Intr VT:$src))),
               (!cast<Instruction>(OpcPrefix#r_Int) VT:$dst, VT:$src)>;
@@ -3001,14 +3001,14 @@
   }
 }
 
-defm : scalar_unary_math_patterns<int_x86_sse_rcp_ss, "RCPSS", X86Movss,
-                                  v4f32, UseSSE1>;
-defm : scalar_unary_math_patterns<int_x86_sse_rsqrt_ss, "RSQRTSS", X86Movss,
-                                  v4f32, UseSSE1>;
-defm : scalar_unary_math_patterns<int_x86_sse_sqrt_ss, "SQRTSS", X86Movss,
-                                  v4f32, UseSSE1>;
-defm : scalar_unary_math_patterns<int_x86_sse2_sqrt_sd, "SQRTSD", X86Movsd,
-                                  v2f64, UseSSE2>;
+defm : scalar_unary_math_intr_patterns<int_x86_sse_rcp_ss, "RCPSS", X86Movss,
+                                       v4f32, UseSSE1>;
+defm : scalar_unary_math_intr_patterns<int_x86_sse_rsqrt_ss, "RSQRTSS", X86Movss,
+                                       v4f32, UseSSE1>;
+defm : scalar_unary_math_intr_patterns<int_x86_sse_sqrt_ss, "SQRTSS", X86Movss,
+                                       v4f32, UseSSE1>;
+defm : scalar_unary_math_intr_patterns<int_x86_sse2_sqrt_sd, "SQRTSD", X86Movsd,
+                                       v2f64, UseSSE2>;
 
 //===----------------------------------------------------------------------===//
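
(Context, not part of the patch: the renamed *_intr_patterns still match the
whole-vector intrinsic node, while the new scalar_unary_math_patterns match the
decomposed form produced by scalar code like the sketch below. Illustration
only; getting an fsqrt node out of sqrtf assumes -fno-math-errno.)

    #include <immintrin.h>
    #include <math.h>

    // x[0] = sqrtf(x[0]) lowers to extractelt + fsqrt + insert via movss,
    // the exact shape scalar_unary_math_patterns is written to fold.
    __m128 sqrt_lowest(__m128 x) {
      x[0] = sqrtf(x[0]);
      return x;
    }
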