Index: ../include/llvm/Target/TargetLowering.h =================================================================== --- ../include/llvm/Target/TargetLowering.h +++ ../include/llvm/Target/TargetLowering.h @@ -3075,6 +3075,17 @@ /// possibly more for vectors. SDValue expandUnalignedStore(StoreSDNode *ST, SelectionDAG &DAG) const; + /// Increments memory address \p Addr past the vector of type \p DataVT that + /// is loaded or stored there. If the data is kept in memory in compressed + /// form (\p IsCompressedMemory), the address is advanced by the number of + /// elements actually accessed, i.e. the number of set bits in \p Mask, + /// times the element size; otherwise by the full size of \p DataVT. + /// \p DataVT is a vector type. \p Mask is a vector value. + /// \p DataVT and \p Mask have the same number of vector elements. + SDValue IncrementMemoryAddress(SDValue Addr, SDValue Mask, const SDLoc &DL, + EVT DataVT, SelectionDAG &DAG, + bool IsCompressedMemory) const; + //===--------------------------------------------------------------------===// // Instruction Emitting Hooks // Index: ../lib/CodeGen/SelectionDAG/DAGCombiner.cpp =================================================================== --- ../lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ ../lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -5542,6 +5542,7 @@ MaskedStoreSDNode *MST = dyn_cast(N); SDValue Mask = MST->getMask(); SDValue Data = MST->getValue(); + EVT VT = Data.getValueType(); SDLoc DL(N); // If the MSTORE data type requires splitting and the mask is provided by a @@ -5551,16 +5552,13 @@ if (Mask.getOpcode() == ISD::SETCC) { // Check if any splitting is required. 
- if (TLI.getTypeAction(*DAG.getContext(), Data.getValueType()) != + if (TLI.getTypeAction(*DAG.getContext(), VT) != TargetLowering::TypeSplitVector) return SDValue(); SDValue MaskLo, MaskHi, Lo, Hi; std::tie(MaskLo, MaskHi) = SplitVSETCC(Mask.getNode(), DAG); - EVT LoVT, HiVT; - std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(MST->getValueType(0)); - SDValue Chain = MST->getChain(); SDValue Ptr = MST->getBasePtr(); @@ -5570,8 +5568,7 @@ // if Alignment is equal to the vector size, // take the half of it for the second part unsigned SecondHalfAlignment = - (Alignment == Data->getValueType(0).getSizeInBits()/8) ? - Alignment/2 : Alignment; + (Alignment == VT.getSizeInBits() / 8) ? Alignment / 2 : Alignment; EVT LoMemVT, HiMemVT; std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT); @@ -5585,11 +5582,11 @@ Alignment, MST->getAAInfo(), MST->getRanges()); Lo = DAG.getMaskedStore(Chain, DL, DataLo, Ptr, MaskLo, LoMemVT, MMO, - MST->isTruncatingStore(), MST->isCompressingStore()); + MST->isTruncatingStore(), + MST->isCompressingStore()); - unsigned IncrementSize = LoMemVT.getSizeInBits()/8; - Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, - DAG.getConstant(IncrementSize, DL, Ptr.getValueType())); + Ptr = TLI.IncrementMemoryAddress(Ptr, MaskLo, DL, LoMemVT, DAG, + MST->isCompressingStore()); MMO = DAG.getMachineFunction(). 
getMachineMemOperand(MST->getPointerInfo(), @@ -5598,7 +5595,8 @@ MST->getRanges()); Hi = DAG.getMaskedStore(Chain, DL, DataHi, Ptr, MaskHi, HiMemVT, MMO, - MST->isTruncatingStore(), MST->isCompressingStore()); + MST->isTruncatingStore(), + MST->isCompressingStore()); AddToWorklist(Lo.getNode()); AddToWorklist(Hi.getNode()); @@ -5737,11 +5735,10 @@ Alignment, MLD->getAAInfo(), MLD->getRanges()); Lo = DAG.getMaskedLoad(LoVT, DL, Chain, Ptr, MaskLo, Src0Lo, LoMemVT, MMO, - ISD::NON_EXTLOAD); + ISD::NON_EXTLOAD, MLD->isExpandingLoad()); - unsigned IncrementSize = LoMemVT.getSizeInBits()/8; - Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, - DAG.getConstant(IncrementSize, DL, Ptr.getValueType())); + Ptr = TLI.IncrementMemoryAddress(Ptr, MaskLo, DL, LoMemVT, DAG, + MLD->isExpandingLoad()); MMO = DAG.getMachineFunction(). getMachineMemOperand(MLD->getPointerInfo(), @@ -5749,7 +5746,7 @@ SecondHalfAlignment, MLD->getAAInfo(), MLD->getRanges()); Hi = DAG.getMaskedLoad(HiVT, DL, Chain, Ptr, MaskHi, Src0Hi, HiMemVT, MMO, - ISD::NON_EXTLOAD); + ISD::NON_EXTLOAD, MLD->isExpandingLoad()); AddToWorklist(Lo.getNode()); AddToWorklist(Hi.getNode()); Index: ../lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp =================================================================== --- ../lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ ../lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -1114,11 +1114,10 @@ Alignment, MLD->getAAInfo(), MLD->getRanges()); Lo = DAG.getMaskedLoad(LoVT, dl, Ch, Ptr, MaskLo, Src0Lo, LoMemVT, MMO, - ExtType); + ExtType, MLD->isExpandingLoad()); - unsigned IncrementSize = LoMemVT.getSizeInBits()/8; - Ptr = DAG.getNode(ISD::ADD, dl, Ptr.getValueType(), Ptr, - DAG.getConstant(IncrementSize, dl, Ptr.getValueType())); + Ptr = TLI.IncrementMemoryAddress(Ptr, MaskLo, dl, LoMemVT, DAG, + MLD->isExpandingLoad()); MMO = DAG.getMachineFunction(). 
getMachineMemOperand(MLD->getPointerInfo(), @@ -1126,7 +1125,7 @@ SecondHalfAlignment, MLD->getAAInfo(), MLD->getRanges()); Hi = DAG.getMaskedLoad(HiVT, dl, Ch, Ptr, MaskHi, Src0Hi, HiMemVT, MMO, - ExtType); + ExtType, MLD->isExpandingLoad()); // Build a factor node to remember that this load is independent of the @@ -1769,19 +1768,18 @@ Alignment, N->getAAInfo(), N->getRanges()); Lo = DAG.getMaskedStore(Ch, DL, DataLo, Ptr, MaskLo, LoMemVT, MMO, - N->isTruncatingStore()); - - unsigned IncrementSize = LoMemVT.getSizeInBits()/8; - Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, - DAG.getConstant(IncrementSize, DL, Ptr.getValueType())); + N->isTruncatingStore(), + N->isCompressingStore()); + Ptr = TLI.IncrementMemoryAddress(Ptr, MaskLo, DL, LoMemVT, DAG, + N->isCompressingStore()); MMO = DAG.getMachineFunction(). getMachineMemOperand(N->getPointerInfo(), MachineMemOperand::MOStore, HiMemVT.getStoreSize(), SecondHalfAlignment, N->getAAInfo(), N->getRanges()); Hi = DAG.getMaskedStore(Ch, DL, DataHi, Ptr, MaskHi, HiMemVT, MMO, - N->isTruncatingStore()); + N->isTruncatingStore(), N->isCompressingStore()); // Build a factor node to remember that this store is independent of the // other one. @@ -2881,7 +2879,8 @@ SDValue Res = DAG.getMaskedLoad(WidenVT, dl, N->getChain(), N->getBasePtr(), Mask, Src0, N->getMemoryVT(), - N->getMemOperand(), ExtType); + N->getMemOperand(), ExtType, + N->isExpandingLoad()); // Legalize the chain result - switch anything that used the old chain to // use the new one. 
ReplaceValueWith(SDValue(N, 1), Res.getValue(1)); @@ -3317,7 +3316,7 @@ "Mask and data vectors should have the same number of elements"); return DAG.getMaskedStore(MST->getChain(), dl, WideVal, MST->getBasePtr(), Mask, MST->getMemoryVT(), MST->getMemOperand(), - false); + false, MST->isCompressingStore()); } SDValue DAGTypeLegalizer::WidenVecOp_MSCATTER(SDNode *N, unsigned OpNo) { Index: ../lib/CodeGen/SelectionDAG/TargetLowering.cpp =================================================================== --- ../lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ ../lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -3610,6 +3610,38 @@ return Result; } +SDValue +TargetLowering::IncrementMemoryAddress(SDValue Addr, SDValue Mask, + const SDLoc &DL, EVT DataVT, + SelectionDAG &DAG, + bool IsCompressedMemory) const { + SDValue Increment; + EVT AddrVT = Addr.getValueType(); + EVT MaskVT = Mask.getValueType(); + assert(DataVT.getVectorNumElements() == MaskVT.getVectorNumElements() && + "Incompatible types of Data and Mask"); + if (IsCompressedMemory) { + // Advance the pointer by (number of '1's in the mask) * (element size). + EVT MaskIntVT = EVT::getIntegerVT(*DAG.getContext(), MaskVT.getSizeInBits()); + SDValue MaskInIntReg = DAG.getBitcast(MaskIntVT, Mask); + if (MaskIntVT.getSizeInBits() < 32) { + MaskInIntReg = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, MaskInIntReg); + MaskIntVT = MVT::i32; + } + + // Count the '1's in the integer-typed mask with a CTPOP node. + Increment = DAG.getNode(ISD::CTPOP, DL, MaskIntVT, MaskInIntReg); + Increment = DAG.getZExtOrTrunc(Increment, DL, AddrVT); + // Scale is the size of one DataVT element in bytes. 
+ SDValue Scale = DAG.getConstant(DataVT.getScalarSizeInBits() / 8, DL, + AddrVT); + Increment = DAG.getNode(ISD::MUL, DL, AddrVT, Increment, Scale); + } else + Increment = DAG.getConstant(DataVT.getSizeInBits() / 8, DL, AddrVT); + + return DAG.getNode(ISD::ADD, DL, AddrVT, Addr, Increment); +} + //===----------------------------------------------------------------------===// // Implementation of Emulated TLS Model //===----------------------------------------------------------------------===// Index: ../test/CodeGen/X86/compress_expand.ll =================================================================== --- ../test/CodeGen/X86/compress_expand.ll +++ ../test/CodeGen/X86/compress_expand.ll @@ -1,6 +1,6 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py -; RUN: llc -mattr=+avx512vl,+avx512dq,+avx512bw < %s | FileCheck %s --check-prefix=ALL --check-prefix=SKX -; RUN: llc -mattr=+avx512f < %s | FileCheck %s --check-prefix=ALL --check-prefix=KNL +; RUN: llc -mcpu=skylake-avx512 < %s | FileCheck %s --check-prefix=ALL --check-prefix=SKX +; RUN: llc -mcpu=knl < %s | FileCheck %s --check-prefix=ALL --check-prefix=KNL target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" target triple = "x86_64-unknown-linux-gnu" @@ -235,6 +235,157 @@ ret void } +define <2 x float> @test13(float* %base, <2 x float> %src0, <2 x i32> %trigger) { +; SKX-LABEL: test13: +; SKX: # BB#0: +; SKX-NEXT: vpxord %xmm2, %xmm2, %xmm2 +; SKX-NEXT: vpblendd {{.*#+}} xmm1 = xmm1[0],xmm2[1],xmm1[2],xmm2[3] +; SKX-NEXT: vpcmpeqq %xmm2, %xmm1, %k0 +; SKX-NEXT: kshiftlb $6, %k0, %k0 +; SKX-NEXT: kshiftrb $6, %k0, %k1 +; SKX-NEXT: vexpandps (%rdi), %xmm0 {%k1} +; SKX-NEXT: retq +; +; KNL-LABEL: test13: +; KNL: # BB#0: +; KNL-NEXT: # kill: %XMM0 %XMM0 %ZMM0 +; KNL-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; KNL-NEXT: vpblendd {{.*#+}} xmm1 = xmm1[0],xmm2[1],xmm1[2],xmm2[3] +; KNL-NEXT: vpcmpeqq %xmm2, %xmm1, %xmm1 +; KNL-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[0,2,2,3] +; KNL-NEXT: vmovq 
{{.*#+}} xmm1 = xmm1[0],zero +; KNL-NEXT: vpxord %zmm2, %zmm2, %zmm2 +; KNL-NEXT: vinserti32x4 $0, %xmm1, %zmm2, %zmm1 +; KNL-NEXT: vpslld $31, %zmm1, %zmm1 +; KNL-NEXT: vptestmd %zmm1, %zmm1, %k1 +; KNL-NEXT: vexpandps (%rdi), %zmm0 {%k1} +; KNL-NEXT: # kill: %XMM0 %XMM0 %ZMM0 +; KNL-NEXT: retq + %mask = icmp eq <2 x i32> %trigger, zeroinitializer + %res = call <2 x float> @llvm.masked.expandload.v2f32(float* %base, <2 x i1> %mask, <2 x float> %src0) + ret <2 x float> %res +} + +define void @test14(float* %base, <2 x float> %V, <2 x i32> %trigger) { +; SKX-LABEL: test14: +; SKX: # BB#0: +; SKX-NEXT: vpxord %xmm2, %xmm2, %xmm2 +; SKX-NEXT: vpblendd {{.*#+}} xmm1 = xmm1[0],xmm2[1],xmm1[2],xmm2[3] +; SKX-NEXT: vpcmpeqq %xmm2, %xmm1, %k0 +; SKX-NEXT: kshiftlb $6, %k0, %k0 +; SKX-NEXT: kshiftrb $6, %k0, %k1 +; SKX-NEXT: vcompressps %xmm0, (%rdi) {%k1} +; SKX-NEXT: retq +; +; KNL-LABEL: test14: +; KNL: # BB#0: +; KNL-NEXT: # kill: %XMM0 %XMM0 %ZMM0 +; KNL-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; KNL-NEXT: vpblendd {{.*#+}} xmm1 = xmm1[0],xmm2[1],xmm1[2],xmm2[3] +; KNL-NEXT: vpcmpeqq %xmm2, %xmm1, %xmm1 +; KNL-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[0,2,2,3] +; KNL-NEXT: vmovq {{.*#+}} xmm1 = xmm1[0],zero +; KNL-NEXT: vpxord %zmm2, %zmm2, %zmm2 +; KNL-NEXT: vinserti32x4 $0, %xmm1, %zmm2, %zmm1 +; KNL-NEXT: vpslld $31, %zmm1, %zmm1 +; KNL-NEXT: vptestmd %zmm1, %zmm1, %k1 +; KNL-NEXT: vcompressps %zmm0, (%rdi) {%k1} +; KNL-NEXT: retq + %mask = icmp eq <2 x i32> %trigger, zeroinitializer + call void @llvm.masked.compressstore.v2f32(<2 x float> %V, float* %base, <2 x i1> %mask) + ret void +} + +define <32 x float> @test15(float* %base, <32 x float> %src0, <32 x i32> %trigger) { +; ALL-LABEL: test15: +; ALL: # BB#0: +; ALL-NEXT: vpxord %zmm4, %zmm4, %zmm4 +; ALL-NEXT: vpcmpeqd %zmm4, %zmm3, %k1 +; ALL-NEXT: vpcmpeqd %zmm4, %zmm2, %k2 +; ALL-NEXT: kmovw %k2, %eax +; ALL-NEXT: popcntl %eax, %eax +; ALL-NEXT: vexpandps (%rdi,%rax,4), %zmm1 {%k1} +; ALL-NEXT: vexpandps (%rdi), %zmm0 {%k2} 
+; ALL-NEXT: retq + %mask = icmp eq <32 x i32> %trigger, zeroinitializer + %res = call <32 x float> @llvm.masked.expandload.v32f32(float* %base, <32 x i1> %mask, <32 x float> %src0) + ret <32 x float> %res +} + +define <16 x double> @test16(double* %base, <16 x double> %src0, <16 x i32> %trigger) { +; SKX-LABEL: test16: +; SKX: # BB#0: +; SKX-NEXT: vextracti32x8 $1, %zmm2, %ymm3 +; SKX-NEXT: vpxord %ymm4, %ymm4, %ymm4 +; SKX-NEXT: vpcmpeqd %ymm4, %ymm3, %k1 +; SKX-NEXT: vpcmpeqd %ymm4, %ymm2, %k2 +; SKX-NEXT: kmovb %k2, %eax +; SKX-NEXT: popcntl %eax, %eax +; SKX-NEXT: vexpandpd (%rdi,%rax,8), %zmm1 {%k1} +; SKX-NEXT: vexpandpd (%rdi), %zmm0 {%k2} +; SKX-NEXT: retq +; +; KNL-LABEL: test16: +; KNL: # BB#0: +; KNL-NEXT: vpxor %ymm3, %ymm3, %ymm3 +; KNL-NEXT: vextracti64x4 $1, %zmm2, %ymm4 +; KNL-NEXT: vpcmpeqd %zmm3, %zmm4, %k1 +; KNL-NEXT: vpcmpeqd %zmm3, %zmm2, %k2 +; KNL-NEXT: vexpandpd (%rdi), %zmm0 {%k2} +; KNL-NEXT: kmovw %k2, %eax +; KNL-NEXT: movzbl %al, %eax +; KNL-NEXT: popcntl %eax, %eax +; KNL-NEXT: vexpandpd (%rdi,%rax,8), %zmm1 {%k1} +; KNL-NEXT: retq + %mask = icmp eq <16 x i32> %trigger, zeroinitializer + %res = call <16 x double> @llvm.masked.expandload.v16f64(double* %base, <16 x i1> %mask, <16 x double> %src0) + ret <16 x double> %res +} + +define void @test17(float* %base, <32 x float> %V, <32 x i32> %trigger) { +; ALL-LABEL: test17: +; ALL: # BB#0: +; ALL-NEXT: vpxord %zmm4, %zmm4, %zmm4 +; ALL-NEXT: vpcmpeqd %zmm4, %zmm3, %k1 +; ALL-NEXT: vpcmpeqd %zmm4, %zmm2, %k2 +; ALL-NEXT: kmovw %k2, %eax +; ALL-NEXT: popcntl %eax, %eax +; ALL-NEXT: vcompressps %zmm1, (%rdi,%rax,4) {%k1} +; ALL-NEXT: vcompressps %zmm0, (%rdi) {%k2} +; ALL-NEXT: retq + %mask = icmp eq <32 x i32> %trigger, zeroinitializer + call void @llvm.masked.compressstore.v32f32(<32 x float> %V, float* %base, <32 x i1> %mask) + ret void +} + +define void @test18(double* %base, <16 x double> %V, <16 x i1> %mask) { +; SKX-LABEL: test18: +; SKX: # BB#0: +; SKX-NEXT: vpsllw $7, %xmm2, %xmm2 
+; SKX-NEXT: vpmovb2m %xmm2, %k1 +; SKX-NEXT: kshiftrw $8, %k1, %k2 +; SKX-NEXT: kmovb %k1, %eax +; SKX-NEXT: popcntl %eax, %eax +; SKX-NEXT: vcompresspd %zmm1, (%rdi,%rax,8) {%k2} +; SKX-NEXT: vcompresspd %zmm0, (%rdi) {%k1} +; SKX-NEXT: retq +; +; KNL-LABEL: test18: +; KNL: # BB#0: +; KNL-NEXT: vpmovsxbd %xmm2, %zmm2 +; KNL-NEXT: vpslld $31, %zmm2, %zmm2 +; KNL-NEXT: vptestmd %zmm2, %zmm2, %k1 +; KNL-NEXT: kshiftrw $8, %k1, %k2 +; KNL-NEXT: kmovw %k1, %eax +; KNL-NEXT: movzbl %al, %eax +; KNL-NEXT: popcntl %eax, %eax +; KNL-NEXT: vcompresspd %zmm1, (%rdi,%rax,8) {%k2} +; KNL-NEXT: vcompresspd %zmm0, (%rdi) {%k1} +; KNL-NEXT: retq + call void @llvm.masked.compressstore.v16f64(<16 x double> %V, double* %base, <16 x i1> %mask) + ret void +} + declare void @llvm.masked.compressstore.v16f32(<16 x float>, float* , <16 x i1>) declare void @llvm.masked.compressstore.v8f32(<8 x float>, float* , <8 x i1>) declare void @llvm.masked.compressstore.v8f64(<8 x double>, double* , <8 x i1>) @@ -245,3 +396,11 @@ declare void @llvm.masked.compressstore.v4f32(<4 x float>, float* , <4 x i1>) declare void @llvm.masked.compressstore.v4i64(<4 x i64>, i64* , <4 x i1>) declare void @llvm.masked.compressstore.v2i64(<2 x i64>, i64* , <2 x i1>) +declare void @llvm.masked.compressstore.v2f32(<2 x float>, float* , <2 x i1>) +declare void @llvm.masked.compressstore.v32f32(<32 x float>, float* , <32 x i1>) +declare void @llvm.masked.compressstore.v16f64(<16 x double>, double* , <16 x i1>) +declare void @llvm.masked.compressstore.v32f64(<32 x double>, double* , <32 x i1>) + +declare <2 x float> @llvm.masked.expandload.v2f32(float* , <2 x i1> , <2 x float> ) +declare <32 x float> @llvm.masked.expandload.v32f32(float* , <32 x i1> , <32 x float> ) +declare <16 x double> @llvm.masked.expandload.v16f64(double* , <16 x i1> , <16 x double> )