Index: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
@@ -26784,13 +26784,86 @@
   return DCI.CombineTo(N, NewVec, WideLd.getValue(1), true);
 }
 
-/// PerformMSTORECombine - Resolve truncating stores
+
+/// If exactly one element of the mask is set for a non-truncating masked store,
+/// it is a vector extract and scalar store.
+/// Note: It is expected that the degenerate cases of an all-zeros or all-ones
+/// mask have already been optimized in IR, so we don't bother with those here.
+static SDValue reduceMaskedStoreToScalarStore(MaskedStoreSDNode *MS,
+                                              SelectionDAG &DAG) {
+  // TODO: This is not x86-specific, so it could be lifted to DAGCombiner.
+  // However, some target hooks may need to be added to know when the transform
+  // is profitable. Endianness would also have to be considered.
+
+  // If V is a build vector of boolean constants and exactly one of those
+  // constants is true, return the operand index of that true element.
+  // Otherwise, return -1.
+  auto getOneTrueElt = [](SDValue V) {
+    // This needs to be a build vector of booleans.
+    // TODO: Checking for the i1 type matches the IR definition for the mask,
+    // but the mask check could be loosened to i8 or other types. That might
+    // also require checking more than 'allOnesValue'; eg, the x86 HW
+    // instructions only require that the MSB is set for each mask element.
+    // The ISD::MSTORE comments/definition do not specify how the mask operand
+    // is formatted.
+    auto *BV = dyn_cast<BuildVectorSDNode>(V);
+    if (!BV || BV->getValueType(0).getVectorElementType() != MVT::i1)
+      return -1;
+
+    int TrueIndex = -1;
+    unsigned NumElts = BV->getValueType(0).getVectorNumElements();
+    for (unsigned i = 0; i < NumElts; ++i) {
+      const SDValue &Op = BV->getOperand(i);
+      if (Op.getOpcode() == ISD::UNDEF)
+        continue;
+      auto *ConstNode = dyn_cast<ConstantSDNode>(Op);
+      if (!ConstNode)
+        return -1;
+      if (ConstNode->getAPIntValue().isAllOnesValue()) {
+        // If we already found a one, this is too many.
+        if (TrueIndex >= 0)
+          return -1;
+        TrueIndex = i;
+      }
+    }
+    return TrueIndex;
+  };
+
+  int TrueMaskElt = getOneTrueElt(MS->getMask());
+  if (TrueMaskElt < 0)
+    return SDValue();
+
+  SDLoc DL(MS);
+  EVT VT = MS->getValue().getValueType();
+  EVT EltVT = VT.getVectorElementType();
+
+  // Extract the one scalar element that is actually being stored.
+  SDValue ExtractIndex = DAG.getIntPtrConstant(TrueMaskElt, DL);
+  SDValue Extract = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT,
+                                MS->getValue(), ExtractIndex);
+
+  // Store that element at the appropriate offset from the base pointer.
+  SDValue StoreAddr = MS->getBasePtr();
+  unsigned EltSize = EltVT.getStoreSize();
+  if (TrueMaskElt != 0) {
+    unsigned StoreOffset = TrueMaskElt * EltSize;
+    SDValue StoreOffsetVal = DAG.getIntPtrConstant(StoreOffset, DL);
+    StoreAddr = DAG.getNode(ISD::ADD, DL, StoreAddr.getValueType(), StoreAddr,
+                            StoreOffsetVal);
+  }
+  unsigned Alignment = MinAlign(MS->getAlignment(), EltSize);
+  return DAG.getStore(MS->getChain(), DL, Extract, StoreAddr,
+                      MS->getPointerInfo(), MS->isVolatile(),
+                      MS->isNonTemporal(), Alignment);
+}
+
 static SDValue PerformMSTORECombine(SDNode *N, SelectionDAG &DAG,
                                     const X86Subtarget &Subtarget) {
   MaskedStoreSDNode *Mst = cast<MaskedStoreSDNode>(N);
   if (!Mst->isTruncatingStore())
-    return SDValue();
+    return reduceMaskedStoreToScalarStore(Mst, DAG);
 
+  // Resolve truncating stores.
   EVT VT = Mst->getValue().getValueType();
   unsigned NumElems = VT.getVectorNumElements();
   EVT StVT = Mst->getMemoryVT();
Index: llvm/trunk/test/CodeGen/X86/masked_memop.ll
===================================================================
--- llvm/trunk/test/CodeGen/X86/masked_memop.ll
+++ llvm/trunk/test/CodeGen/X86/masked_memop.ll
@@ -991,36 +991,92 @@
   ret void
 }
 
-define void @test22(<4 x i32> %trigger, <4 x i32>* %addr, <4 x i32> %val) {
-; AVX1-LABEL: test22:
-; AVX1:       ## BB#0:
-; AVX1-NEXT:    movl $-1, %eax
-; AVX1-NEXT:    vmovd %eax, %xmm0
-; AVX1-NEXT:    vmaskmovps %xmm1, %xmm0, (%rdi)
-; AVX1-NEXT:    retq
+; When only one element of the mask is set, reduce to a scalar store.
+
+define void @one_mask_bit_set1(<4 x i32>* %addr, <4 x i32> %val) {
+; AVX-LABEL: one_mask_bit_set1:
+; AVX:       ## BB#0:
+; AVX-NEXT:    vmovd %xmm0, (%rdi)
+; AVX-NEXT:    retq
 ;
-; AVX2-LABEL: test22:
-; AVX2:       ## BB#0:
-; AVX2-NEXT:    movl $-1, %eax
-; AVX2-NEXT:    vmovd %eax, %xmm0
-; AVX2-NEXT:    vpmaskmovd %xmm1, %xmm0, (%rdi)
-; AVX2-NEXT:    retq
+; AVX512-LABEL: one_mask_bit_set1:
+; AVX512:       ## BB#0:
+; AVX512-NEXT:    vmovd %xmm0, (%rdi)
+; AVX512-NEXT:    retq
+  call void @llvm.masked.store.v4i32(<4 x i32> %val, <4 x i32>* %addr, i32 4, <4 x i1> <i1 true, i1 false, i1 false, i1 false>)
+  ret void
+}
+
+; Choose a different element to show that the correct address offset is produced.
+
+define void @one_mask_bit_set2(<4 x float>* %addr, <4 x float> %val) {
+; AVX-LABEL: one_mask_bit_set2:
+; AVX:       ## BB#0:
+; AVX-NEXT:    vextractps $2, %xmm0, 8(%rdi)
+; AVX-NEXT:    retq
 ;
-; AVX512F-LABEL: test22:
-; AVX512F:       ## BB#0:
-; AVX512F-NEXT:    movl $-1, %eax
-; AVX512F-NEXT:    vmovd %eax, %xmm0
-; AVX512F-NEXT:    vpmaskmovd %xmm1, %xmm0, (%rdi)
-; AVX512F-NEXT:    retq
+; AVX512-LABEL: one_mask_bit_set2:
+; AVX512:       ## BB#0:
+; AVX512-NEXT:    vextractps $2, %xmm0, 8(%rdi)
+; AVX512-NEXT:    retq
+  call void @llvm.masked.store.v4f32(<4 x float> %val, <4 x float>* %addr, i32 4, <4 x i1> <i1 false, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+; Choose a different scalar type and a high element of a 256-bit vector because AVX doesn't support those evenly.
+
+define void @one_mask_bit_set3(<4 x i64>* %addr, <4 x i64> %val) {
+; AVX-LABEL: one_mask_bit_set3:
+; AVX:       ## BB#0:
+; AVX-NEXT:    vextractf128 $1, %ymm0, %xmm0
+; AVX-NEXT:    vmovlps %xmm0, 16(%rdi)
+; AVX-NEXT:    vzeroupper
+; AVX-NEXT:    retq
 ;
-; SKX-LABEL: test22:
-; SKX:       ## BB#0:
-; SKX-NEXT:    movb $1, %al
-; SKX-NEXT:    kmovw %eax, %k1
-; SKX-NEXT:    vmovdqu32 %xmm1, (%rdi) {%k1}
-; SKX-NEXT:    retq
-  %mask = icmp eq <4 x i32> %trigger, zeroinitializer
-  call void @llvm.masked.store.v4i32(<4 x i32>%val, <4 x i32>* %addr, i32 4, <4 x i1> <i1 true, i1 false, i1 false, i1 false>)
+; AVX512-LABEL: one_mask_bit_set3:
+; AVX512:       ## BB#0:
+; AVX512-NEXT:    vextractf128 $1, %ymm0, %xmm0
+; AVX512-NEXT:    vmovq %xmm0, 16(%rdi)
+; AVX512-NEXT:    retq
+  call void @llvm.masked.store.v4i64(<4 x i64> %val, <4 x i64>* %addr, i32 4, <4 x i1> <i1 false, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+; Choose a different scalar type and a high element of a 256-bit vector because AVX doesn't support those evenly.
+
+define void @one_mask_bit_set4(<4 x double>* %addr, <4 x double> %val) {
+; AVX-LABEL: one_mask_bit_set4:
+; AVX:       ## BB#0:
+; AVX-NEXT:    vextractf128 $1, %ymm0, %xmm0
+; AVX-NEXT:    vmovhpd %xmm0, 24(%rdi)
+; AVX-NEXT:    vzeroupper
+; AVX-NEXT:    retq
+;
+; AVX512-LABEL: one_mask_bit_set4:
+; AVX512:       ## BB#0:
+; AVX512-NEXT:    vextractf128 $1, %ymm0, %xmm0
+; AVX512-NEXT:    vmovhpd %xmm0, 24(%rdi)
+; AVX512-NEXT:    retq
+  call void @llvm.masked.store.v4f64(<4 x double> %val, <4 x double>* %addr, i32 4, <4 x i1> <i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+; Try a 512-bit vector to make sure AVX doesn't die and AVX512 works as expected.
+
+define void @one_mask_bit_set5(<8 x double>* %addr, <8 x double> %val) {
+; AVX-LABEL: one_mask_bit_set5:
+; AVX:       ## BB#0:
+; AVX-NEXT:    vextractf128 $1, %ymm1, %xmm0
+; AVX-NEXT:    vmovlps %xmm0, 48(%rdi)
+; AVX-NEXT:    vzeroupper
+; AVX-NEXT:    retq
+;
+; AVX512-LABEL: one_mask_bit_set5:
+; AVX512:       ## BB#0:
+; AVX512-NEXT:    vextractf32x4 $3, %zmm0, %xmm0
+; AVX512-NEXT:    vmovlpd %xmm0, 48(%rdi)
+; AVX512-NEXT:    retq
+  call void @llvm.masked.store.v8f64(<8 x double> %val, <8 x double>* %addr, i32 4, <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 false>)
   ret void
 }
 
@@ -1030,8 +1086,10 @@
 declare void @llvm.masked.store.v16i32(<16 x i32>, <16 x i32>*, i32, <16 x i1>)
 declare void @llvm.masked.store.v8i32(<8 x i32>, <8 x i32>*, i32, <8 x i1>)
 declare void @llvm.masked.store.v4i32(<4 x i32>, <4 x i32>*, i32, <4 x i1>)
+declare void @llvm.masked.store.v4i64(<4 x i64>, <4 x i64>*, i32, <4 x i1>)
 declare void @llvm.masked.store.v2f32(<2 x float>, <2 x float>*, i32, <2 x i1>)
 declare void @llvm.masked.store.v2i32(<2 x i32>, <2 x i32>*, i32, <2 x i1>)
+declare void @llvm.masked.store.v4f32(<4 x float>, <4 x float>*, i32, <4 x i1>)
 declare void @llvm.masked.store.v16f32(<16 x float>, <16 x float>*, i32, <16 x i1>)
 declare void @llvm.masked.store.v16f32p(<16 x float>*, <16 x float>**, i32, <16 x i1>)
 declare <16 x float> @llvm.masked.load.v16f32(<16 x float>*, i32, <16 x i1>, <16 x float>)
@@ -1043,6 +1101,7 @@
 declare <4 x double> @llvm.masked.load.v4f64(<4 x double>*, i32, <4 x i1>, <4 x double>)
 declare <2 x double> @llvm.masked.load.v2f64(<2 x double>*, i32, <2 x i1>, <2 x double>)
 declare void @llvm.masked.store.v8f64(<8 x double>, <8 x double>*, i32, <8 x i1>)
+declare void @llvm.masked.store.v4f64(<4 x double>, <4 x double>*, i32, <4 x i1>)
 declare void @llvm.masked.store.v2f64(<2 x double>, <2 x double>*, i32, <2 x i1>)
 declare void @llvm.masked.store.v2i64(<2 x i64>, <2 x i64>*, i32, <2 x i1>)
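
For reference, the effect of the new reduceMaskedStoreToScalarStore() combine can be sketched at the IR level, even though the transform itself runs on SelectionDAG nodes. Taking the one_mask_bit_set2 test above as the example (the function and value names below are illustrative and not part of the patch), a masked store whose mask has exactly one true lane (element 2) reduces to an extract of that lane plus an ordinary store at byte offset 2 * 4 = 8, with alignment MinAlign(4, 4) = 4:

  ; Equivalent scalarized form of one_mask_bit_set2 (sketch only).
  define void @one_mask_bit_set2_scalarized(<4 x float>* %addr, <4 x float> %val) {
    ; Extract the single lane selected by the mask.
    %elt = extractelement <4 x float> %val, i32 2
    ; Address of element 2: base pointer plus 2 * sizeof(float) = 8 bytes.
    %scalar_ptr = bitcast <4 x float>* %addr to float*
    %elt_ptr = getelementptr inbounds float, float* %scalar_ptr, i64 2
    ; Plain scalar store with alignment MinAlign(4, 4) = 4.
    store float %elt, float* %elt_ptr, align 4
    ret void
  }

This corresponds to the vextractps $2, %xmm0, 8(%rdi) expected by the AVX and AVX512 check lines in that test.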