Index: lib/Target/X86/X86ISelLowering.cpp =================================================================== --- lib/Target/X86/X86ISelLowering.cpp +++ lib/Target/X86/X86ISelLowering.cpp @@ -36504,31 +36504,31 @@ } static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { MaskedStoreSDNode *Mst = cast<MaskedStoreSDNode>(N); - if (Mst->isCompressingStore()) return SDValue(); + EVT VT = Mst->getValue().getValueType(); if (!Mst->isTruncatingStore()) { if (SDValue ScalarStore = reduceMaskedStoreToScalarStore(Mst, DAG)) return ScalarStore; - // If the mask is checking (0 > X), we're creating a vector with all-zeros - // or all-ones elements based on the sign bits of X. AVX1 masked store only - // cares about the sign bit of each mask element, so eliminate the compare: - // mstore val, ptr, (pcmpgt 0, X) --> mstore val, ptr, X - // Note that by waiting to match an x86-specific PCMPGT node, we're - // eliminating potentially more complex matching of a setcc node which has - // a full range of predicates. + // If the mask value has been legalized to a non-boolean vector, try to + // simplify ops leading up to it. We only demand the MSB of each lane. 
SDValue Mask = Mst->getMask(); - if (Mask.getOpcode() == X86ISD::PCMPGT && - ISD::isBuildVectorAllZeros(Mask.getOperand(0).getNode())) { - assert(Mask.getValueType() == Mask.getOperand(1).getValueType() && - "Unexpected type for PCMPGT"); - return DAG.getMaskedStore( - Mst->getChain(), SDLoc(N), Mst->getValue(), Mst->getBasePtr(), - Mask.getOperand(1), Mst->getMemoryVT(), Mst->getMemOperand()); + if (Mask.getScalarValueSizeInBits() != 1) { + TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(), + !DCI.isBeforeLegalizeOps()); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + APInt DemandedMask(APInt::getSignMask(VT.getScalarSizeInBits())); + KnownBits Known; + if (TLI.SimplifyDemandedBits(Mask, DemandedMask, Known, TLO)) { + DCI.AddToWorklist(Mask.getNode()); + DCI.CommitTargetLoweringOpt(TLO); + return SDValue(N, 0); + } } // TODO: AVX512 targets should also be able to simplify something like the @@ -36539,7 +36539,6 @@ } // Resolve truncating stores. - EVT VT = Mst->getValue().getValueType(); unsigned NumElems = VT.getVectorNumElements(); EVT StVT = Mst->getMemoryVT(); SDLoc dl(Mst); @@ -40367,7 +40366,7 @@ case ISD::LOAD: return combineLoad(N, DAG, DCI, Subtarget); case ISD::MLOAD: return combineMaskedLoad(N, DAG, DCI, Subtarget); case ISD::STORE: return combineStore(N, DAG, Subtarget); - case ISD::MSTORE: return combineMaskedStore(N, DAG, Subtarget); + case ISD::MSTORE: return combineMaskedStore(N, DAG, DCI, Subtarget); case ISD::SINT_TO_FP: return combineSIntToFP(N, DAG, Subtarget); case ISD::UINT_TO_FP: return combineUIntToFP(N, DAG, Subtarget); case ISD::FADD: Index: test/CodeGen/X86/masked_memop.ll =================================================================== --- test/CodeGen/X86/masked_memop.ll +++ test/CodeGen/X86/masked_memop.ll @@ -1279,6 +1279,7 @@ } ; TODO: SimplifyDemandedBits should eliminate an ashr here. +; It works for AVX2, but not the more complicated pattern for AVX1. 
define void @masked_store_bool_mask_demand_trunc_sext(<4 x double> %x, <4 x double>* %p, <4 x i32> %masksrc) { ; AVX1-LABEL: masked_store_bool_mask_demand_trunc_sext: @@ -1296,7 +1297,6 @@ ; AVX2-LABEL: masked_store_bool_mask_demand_trunc_sext: ; AVX2: ## %bb.0: ; AVX2-NEXT: vpslld $31, %xmm1, %xmm1 -; AVX2-NEXT: vpsrad $31, %xmm1, %xmm1 ; AVX2-NEXT: vpmovsxdq %xmm1, %ymm1 ; AVX2-NEXT: vmaskmovpd %ymm0, %ymm1, (%rdi) ; AVX2-NEXT: vzeroupper @@ -1338,7 +1338,6 @@ ; AVX1-NEXT: vmovd %ecx, %xmm2 ; AVX1-NEXT: vpunpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm2[0] ; AVX1-NEXT: vpslld $31, %xmm1, %xmm1 -; AVX1-NEXT: vpsrad $31, %xmm1, %xmm1 ; AVX1-NEXT: vmaskmovps %xmm0, %xmm1, (%rdi) ; AVX1-NEXT: retq ; @@ -1350,7 +1349,6 @@ ; AVX2-NEXT: vmovd %ecx, %xmm2 ; AVX2-NEXT: vpunpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm2[0] ; AVX2-NEXT: vpslld $31, %xmm1, %xmm1 -; AVX2-NEXT: vpsrad $31, %xmm1, %xmm1 ; AVX2-NEXT: vpmaskmovd %xmm0, %xmm1, (%rdi) ; AVX2-NEXT: retq ;