Index: llvm/lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- llvm/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -30480,6 +30480,37 @@
   return DAG.getNode(X86ISD::ADDSUB, DL, VT, Opnd0, Opnd1);
 }
 
+// Fold a 128-bit vector_shuffle<0,3> / <0,5,6,7> of ((fceil|ffloor) A), B
+// into X86ISD::VRNDSCALES: round only element 0 of A and take the
+// remaining elements from B (ROUNDSS/ROUNDSD semantics, SSE4.1+).
+static SDValue combineShuffleFloorCeil(SDNode *N, SelectionDAG &DAG,
+                                       const X86Subtarget &Subtarget) {
+  if (!isa<ShuffleVectorSDNode>(N))
+    return SDValue();
+  EVT VT = N->getValueType(0);
+  unsigned Num = VT.getVectorNumElements();
+  if (Num * VT.getScalarSizeInBits() != 128 || !Subtarget.hasSSE41())
+    return SDValue();
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  int Op = N0.getOpcode();
+  if ((Num != 2 && Num != 4) || (Op != ISD::FCEIL && Op != ISD::FFLOOR))
+    return SDValue();
+
+  // The mask being matched here is equivalent to a 0...01 select mask.
+  ShuffleVectorSDNode *SVOp = cast<ShuffleVectorSDNode>(N);
+  if (SVOp->getMaskElt(0) != 0)
+    return SDValue();
+  for (unsigned i = 1; i < Num; ++i)
+    if (SVOp->getMaskElt(i) != Num + i)
+      return SDValue();
+
+  int Imm = (Op == ISD::FCEIL) ? 2 : 1; // Round imm: 1 = toward -inf, 2 = toward +inf.
+  SDLoc DL(N);
+  return DAG.getNode(X86ISD::VRNDSCALES, DL, VT, N1, N0.getOperand(0),
+                     DAG.getConstant(Imm, DL, MVT::i32));
+}
+
 // We are looking for a shuffle where both sources are concatenated with undef
 // and have a width that is half of the output's width. AVX2 has VPERMD/Q, so
 // if we can express this as a single-source shuffle, that's preferable.
@@ -30653,6 +30684,9 @@
           EltsFromConsecutiveLoads(VT, Elts, dl, DAG, Subtarget, true))
     return LD;
 
+  if (SDValue RndScale = combineShuffleFloorCeil(N, DAG, Subtarget)) // Fold to VRNDSCALES.
+    return RndScale;
+
   // For AVX2, we sometimes want to combine
   // (vector_shuffle (concat_vectors t1, undef)
   //                 (concat_vectors t2, undef))
Index: llvm/lib/Target/X86/X86InstrAVX512.td
===================================================================
--- llvm/lib/Target/X86/X86InstrAVX512.td
+++ llvm/lib/Target/X86/X86InstrAVX512.td
@@ -9370,10 +9370,14 @@
 let Predicates = [HasAVX512] in {
 def : Pat<(v16f32 (ffloor VR512:$src)),
           (VRNDSCALEPSZrri VR512:$src, (i32 0x9))>;
+def : Pat<(v16f32 (vselect v16f32_info.KRCWM:$mask, (ffloor VR512:$src), VR512:$dst)),
+          (VRNDSCALEPSZrrik VR512:$dst, v16f32_info.KRCWM:$mask, VR512:$src, (i32 0x9))>; // Merge-masked floor.
 def : Pat<(v16f32 (fnearbyint VR512:$src)),
           (VRNDSCALEPSZrri VR512:$src, (i32 0xC))>;
 def : Pat<(v16f32 (fceil VR512:$src)),
           (VRNDSCALEPSZrri VR512:$src, (i32 0xA))>;
+def : Pat<(v16f32 (vselect v16f32_info.KRCWM:$mask, (fceil VR512:$src), VR512:$dst)),
+          (VRNDSCALEPSZrrik VR512:$dst, v16f32_info.KRCWM:$mask, VR512:$src, (i32 0xA))>; // Merge-masked ceil.
 def : Pat<(v16f32 (frint VR512:$src)),
           (VRNDSCALEPSZrri VR512:$src, (i32 0x4))>;
 def : Pat<(v16f32 (ftrunc VR512:$src)),
@@ -9381,10 +9385,14 @@
 
 def : Pat<(v8f64 (ffloor VR512:$src)),
           (VRNDSCALEPDZrri VR512:$src, (i32 0x9))>;
+def : Pat<(v8f64 (vselect v8f64_info.KRCWM:$mask, (ffloor VR512:$src), VR512:$dst)),
+          (VRNDSCALEPDZrrik VR512:$dst, v8f64_info.KRCWM:$mask, VR512:$src, (i32 0x9))>; // Merge-masked floor.
 def : Pat<(v8f64 (fnearbyint VR512:$src)),
           (VRNDSCALEPDZrri VR512:$src, (i32 0xC))>;
 def : Pat<(v8f64 (fceil VR512:$src)),
           (VRNDSCALEPDZrri VR512:$src, (i32 0xA))>;
+def : Pat<(v8f64 (vselect v8f64_info.KRCWM:$mask, (fceil VR512:$src), VR512:$dst)),
+          (VRNDSCALEPDZrrik VR512:$dst, v8f64_info.KRCWM:$mask, VR512:$src, (i32 0xA))>; // Merge-masked ceil.
 def : Pat<(v8f64 (frint VR512:$src)),
           (VRNDSCALEPDZrri VR512:$src, (i32 0x4))>;
 def : Pat<(v8f64 (ftrunc VR512:$src)),