Index: ../lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- ../lib/Target/X86/X86ISelLowering.cpp
+++ ../lib/Target/X86/X86ISelLowering.cpp
@@ -30872,6 +30872,72 @@
   return DAG.getNode(X86ISD::PCMPGT, SDLoc(N), VT, Shift.getOperand(0), Ones);
 }
 
+/// Check if truncation with saturation from type \p SrcVT to \p DstVT
+/// is valid for the given \p Subtarget.
+///
+/// Requires AVX-512, source elements of 16 to 64 bits and i8/i16/i32
+/// destination elements. Sources that are not 512 bits wide additionally
+/// need VLX, and source elements narrower than 32 bits need BWI.
+static bool
+isSATValidOnSubtarget(EVT SrcVT, EVT DstVT, const X86Subtarget &Subtarget) {
+  if (!Subtarget.hasAVX512())
+    return false;
+  EVT SrcElVT = SrcVT.getScalarType();
+  EVT DstElVT = DstVT.getScalarType();
+  if (SrcElVT.getSizeInBits() < 16 || SrcElVT.getSizeInBits() > 64)
+    return false;
+  // Only byte, word and dword destination elements are accepted; a bare size
+  // range check would also let odd widths such as i12 through.
+  if (DstElVT != MVT::i8 && DstElVT != MVT::i16 && DstElVT != MVT::i32)
+    return false;
+  if (SrcVT.is512BitVector() || Subtarget.hasVLX())
+    return SrcElVT.getSizeInBits() >= 32 || Subtarget.hasBWI();
+  return false;
+}
+
+/// Detect a pattern of truncation with saturation:
+/// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
+/// Return the source value to be truncated or SDValue() if the pattern was not
+/// matched or is unsupported on the current target.
+static SDValue
+detectUSatPattern(SDValue In, EVT VT, const X86Subtarget &Subtarget) {
+  if (In.getOpcode() != ISD::UMIN)
+    return SDValue();
+
+  EVT InVT = In.getValueType();
+  // FIXME: Scalar type may be supported if we move it to vector register.
+  if (!InVT.isVector() || !InVT.isSimple())
+    return SDValue();
+
+  if (!isSATValidOnSubtarget(InVT, VT, Subtarget))
+    return SDValue();
+
+  // Saturation with truncation. We truncate from InVT to VT.
+  assert(InVT.getScalarSizeInBits() > VT.getScalarSizeInBits() &&
+         "Unexpected types for truncate operation");
+
+  // The splat constant may be either operand of the UMIN.
+  SDValue SrcVal;
+  APInt C;
+  if (ISD::isConstantSplatVector(In.getOperand(0).getNode(), C))
+    SrcVal = In.getOperand(1);
+  else if (ISD::isConstantSplatVector(In.getOperand(1).getNode(), C))
+    SrcVal = In.getOperand(0);
+  else
+    return SDValue();
+
+  // NOTE(review): reject a splat value that is narrower than a source element
+  // (e.g. a repeating sub-element bit pattern); it would not describe the
+  // whole element. Confirm against ISD::isConstantSplatVector semantics.
+  if (C.getBitWidth() != InVT.getScalarSizeInBits())
+    return SDValue();
+
+  // C should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according to
+  // the element size of the destination type.
+  return (C == ((uint64_t)1 << VT.getScalarSizeInBits()) - 1) ?
+    SrcVal : SDValue();
+}
+
 /// This function detects the AVG pattern between vectors of unsigned i8/i16,
 /// which is c = (a + b + 1) / 2, and replace this operation with the efficient
 /// X86ISD::AVG instruction.
@@ -31438,6 +31504,12 @@
                           St->getPointerInfo(), St->getAlignment(),
                           St->getMemOperand()->getFlags());
 
+    if (SDValue Val =
+        detectUSatPattern(St->getValue(), St->getMemoryVT(), Subtarget))
+      return EmitTruncSStore(false /* Unsigned saturation */, St->getChain(),
+                             dl, Val, St->getBasePtr(),
+                             St->getMemoryVT(), St->getMemOperand(), DAG);
+
     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
     unsigned NumElems = VT.getVectorNumElements();
     assert(StVT != VT && "Cannot truncate to the same type");
@@ -31974,6 +32046,10 @@
   if (SDValue Avg = detectAVGPattern(Src, VT, DAG, Subtarget, DL))
     return Avg;
 
+  // Try truncate with unsigned saturation.
+  if (SDValue Val = detectUSatPattern(Src, VT, Subtarget))
+    return DAG.getNode(X86ISD::VTRUNCUS, DL, VT, Val);
+
   // The bitcast source is a direct mmx result.
   // Detect bitcasts between i32 to x86mmx
   if (Src.getOpcode() == ISD::BITCAST && VT == MVT::i32) {
Index: ../test/CodeGen/X86/avx512-trunc.ll
===================================================================
--- ../test/CodeGen/X86/avx512-trunc.ll
+++ ../test/CodeGen/X86/avx512-trunc.ll
@@ -500,3 +500,98 @@
   store <8 x i8> %x, <8 x i8>* %res
   ret void
 }
+
+
+define void @usat_trunc_wb_256_mem(<16 x i16> %i, <16 x i8>* %res) {
+; KNL-LABEL: usat_trunc_wb_256_mem:
+; KNL:       ## BB#0:
+; KNL-NEXT:    vpminuw {{.*}}(%rip), %ymm0, %ymm0
+; KNL-NEXT:    vpmovsxwd %ymm0, %zmm0
+; KNL-NEXT:    vpmovdb %zmm0, %xmm0
+; KNL-NEXT:    vmovdqu %xmm0, (%rdi)
+; KNL-NEXT:    retq
+;
+; SKX-LABEL: usat_trunc_wb_256_mem:
+; SKX:       ## BB#0:
+; SKX-NEXT:    vpmovuswb %ymm0, (%rdi)
+; SKX-NEXT:    retq
+  %x3 = icmp ult <16 x i16> %i, <i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255>
+  %x5 = select <16 x i1> %x3, <16 x i16> %i, <16 x i16> <i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255>
+  %x6 = trunc <16 x i16> %x5 to <16 x i8>
+  store <16 x i8> %x6, <16 x i8>* %res, align 1
+  ret void
+}
+
+define <16 x i8> @usat_trunc_wb_256(<16 x i16> %i) {
+; KNL-LABEL: usat_trunc_wb_256:
+; KNL:       ## BB#0:
+; KNL-NEXT:    vpminuw {{.*}}(%rip), %ymm0, %ymm0
+; KNL-NEXT:    vpmovsxwd %ymm0, %zmm0
+; KNL-NEXT:    vpmovdb %zmm0, %xmm0
+; KNL-NEXT:    retq
+;
+; SKX-LABEL: usat_trunc_wb_256:
+; SKX:       ## BB#0:
+; SKX-NEXT:    vpmovuswb %ymm0, %xmm0
+; SKX-NEXT:    retq
+  %x3 = icmp ult <16 x i16> %i, <i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255>
+  %x5 = select <16 x i1> %x3, <16 x i16> %i, <16 x i16> <i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255>
+  %x6 = trunc <16 x i16> %x5 to <16 x i8>
+  ret <16 x i8> %x6
+}
+
+define void @usat_trunc_wb_128_mem(<8 x i16> %i, <8 x i8>* %res) {
+; KNL-LABEL: usat_trunc_wb_128_mem:
+; KNL:       ## BB#0:
+; KNL-NEXT:    vpminuw {{.*}}(%rip), %xmm0, %xmm0
+; KNL-NEXT:    vpshufb {{.*#+}} xmm0 = xmm0[0,2,4,6,8,10,12,14,u,u,u,u,u,u,u,u]
+; KNL-NEXT:    vmovq %xmm0, (%rdi)
+; KNL-NEXT:    retq
+;
+; SKX-LABEL: usat_trunc_wb_128_mem:
+; SKX:       ## BB#0:
+; SKX-NEXT:    vpmovuswb %xmm0, (%rdi)
+; SKX-NEXT:    retq
+  %x3 = icmp ult <8 x i16> %i, <i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255>
+  %x5 = select <8 x i1> %x3, <8 x i16> %i, <8 x i16> <i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255>
+  %x6 = trunc <8 x i16> %x5 to <8 x i8>
+  store <8 x i8> %x6, <8 x i8>* %res, align 1
+  ret void
+}
+
+define void @usat_trunc_db_512_mem(<16 x i32> %i, <16 x i8>* %res) {
+; ALL-LABEL: usat_trunc_db_512_mem:
+; ALL:       ## BB#0:
+; ALL-NEXT:    vpmovusdb %zmm0, (%rdi)
+; ALL-NEXT:    retq
+  %x3 = icmp ult <16 x i32> %i, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+  %x5 = select <16 x i1> %x3, <16 x i32> %i, <16 x i32> <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+  %x6 = trunc <16 x i32> %x5 to <16 x i8>
+  store <16 x i8> %x6, <16 x i8>* %res, align 1
+  ret void
+}
+
+define void @usat_trunc_qb_512_mem(<8 x i64> %i, <8 x i8>* %res) {
+; ALL-LABEL: usat_trunc_qb_512_mem:
+; ALL:       ## BB#0:
+; ALL-NEXT:    vpmovusqb %zmm0, (%rdi)
+; ALL-NEXT:    retq
+  %x3 = icmp ult <8 x i64> %i, <i64 255, i64 255, i64 255, i64 255, i64 255, i64 255, i64 255, i64 255>
+  %x5 = select <8 x i1> %x3, <8 x i64> %i, <8 x i64> <i64 255, i64 255, i64 255, i64 255, i64 255, i64 255, i64 255, i64 255>
+  %x6 = trunc <8 x i64> %x5 to <8 x i8>
+  store <8 x i8> %x6, <8 x i8>* %res, align 1
+  ret void
+}
+
+define void @usat_trunc_qd_512_mem(<8 x i64> %i, <8 x i32>* %res) {
+; ALL-LABEL: usat_trunc_qd_512_mem:
+; ALL:       ## BB#0:
+; ALL-NEXT:    vpmovusqd %zmm0, (%rdi)
+; ALL-NEXT:    retq
+  %x3 = icmp ult <8 x i64> %i, <i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295>
+  %x5 = select <8 x i1> %x3, <8 x i64> %i, <8 x i64> <i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295>
+  %x6 = trunc <8 x i64> %x5 to <8 x i32>
+  store <8 x i32> %x6, <8 x i32>* %res, align 1
+  ret void
+}
+