Index: lib/CodeGen/SelectionDAG/LegalizeTypes.h
===================================================================
--- lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -798,6 +798,7 @@
   SDValue WidenVecOp_EXTRACT_SUBVECTOR(SDNode *N);
   SDValue WidenVecOp_STORE(SDNode* N);
   SDValue WidenVecOp_MSTORE(SDNode* N, unsigned OpNo);
+  SDValue WidenVecOp_MGATHER(SDNode* N, unsigned OpNo);
   SDValue WidenVecOp_MSCATTER(SDNode* N, unsigned OpNo);
   SDValue WidenVecOp_SETCC(SDNode* N);
 
Index: lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -3633,6 +3633,7 @@
   case ISD::EXTRACT_VECTOR_ELT: Res = WidenVecOp_EXTRACT_VECTOR_ELT(N); break;
   case ISD::STORE:              Res = WidenVecOp_STORE(N); break;
   case ISD::MSTORE:             Res = WidenVecOp_MSTORE(N, OpNo); break;
+  case ISD::MGATHER:            Res = WidenVecOp_MGATHER(N, OpNo); break;
   case ISD::MSCATTER:           Res = WidenVecOp_MSCATTER(N, OpNo); break;
   case ISD::SETCC:              Res = WidenVecOp_SETCC(N); break;
   case ISD::FCOPYSIGN:          Res = WidenVecOp_FCOPYSIGN(N); break;
@@ -3898,36 +3899,92 @@
                             false, MST->isCompressingStore());
 }
 
-SDValue DAGTypeLegalizer::WidenVecOp_MSCATTER(SDNode *N, unsigned OpNo) {
-  assert(OpNo == 1 && "Can widen only data operand of mscatter");
-  MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
-  SDValue DataOp = MSC->getValue();
-  SDValue Mask = MSC->getMask();
+// Widen an MGATHER whose index operand (operand 4) is being widened. The
+// pass-through and mask are widened to the index's element count, a wide
+// gather is emitted, and the original-width result is extracted from it.
+SDValue DAGTypeLegalizer::WidenVecOp_MGATHER(SDNode *N, unsigned OpNo) {
+  assert(OpNo == 4 && "Can widen only the index of mgather");
+  auto *MG = cast<MaskedGatherSDNode>(N);
+  SDValue DataOp = MG->getPassThru();
+  SDValue Mask = MG->getMask();
+  SDValue Scale = MG->getScale();
   EVT MaskVT = Mask.getValueType();
-  SDValue Scale = MSC->getScale();
+  EVT DataVT = DataOp.getValueType();
+
+  // Widen index.
+  SDValue Index = GetWidenedVector(MG->getIndex());
+  unsigned NumElts = Index.getValueType().getVectorNumElements();
 
   // Widen the value.
-  SDValue WideVal = GetWidenedVector(DataOp);
-  EVT WideVT = WideVal.getValueType();
-  unsigned NumElts = WideVT.getVectorNumElements();
-  SDLoc dl(N);
+  EVT WideDataVT = EVT::getVectorVT(*DAG.getContext(),
+                                    DataVT.getVectorElementType(),
+                                    NumElts);
+  DataOp = ModifyToType(DataOp, WideDataVT);
 
   // The mask should be widened as well.
   EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
                                     MaskVT.getVectorElementType(), NumElts);
   Mask = ModifyToType(Mask, WideMaskVT, true);
 
-  // Widen index.
+  SDLoc dl(N);
+  SDValue Ops[] = {MG->getChain(), DataOp, Mask, MG->getBasePtr(), Index,
+                   Scale};
+  SDValue Res = DAG.getMaskedGather(DAG.getVTList(WideDataVT, MVT::Other),
+                                    MG->getMemoryVT(), dl, Ops,
+                                    MG->getMemOperand());
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+
+  Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, N->getValueType(0), Res,
+                    DAG.getConstant(0, dl,
+                                    TLI.getVectorIdxTy(DAG.getDataLayout())));
+
+  ReplaceValueWith(SDValue(N, 0), Res.getValue(0));
+  return SDValue();
+}
+
+// Widen an MSCATTER on its data operand (OpNo 1) or index operand (OpNo 4).
+// The other of the two and the mask are widened to the same element count;
+// the mask uses ModifyToType(..., true), presumably zero-filling new lanes.
+SDValue DAGTypeLegalizer::WidenVecOp_MSCATTER(SDNode *N, unsigned OpNo) {
+  MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
+  SDValue DataOp = MSC->getValue();
+  SDValue Mask = MSC->getMask();
   SDValue Index = MSC->getIndex();
-  EVT WideIndexVT = EVT::getVectorVT(*DAG.getContext(),
-                                     Index.getValueType().getScalarType(),
-                                     NumElts);
-  Index = ModifyToType(Index, WideIndexVT);
+  SDValue Scale = MSC->getScale();
 
-  SDValue Ops[] = {MSC->getChain(), WideVal, Mask, MSC->getBasePtr(), Index,
+  unsigned NumElts;
+  if (OpNo == 1) {
+    DataOp = GetWidenedVector(DataOp);
+    NumElts = DataOp.getValueType().getVectorNumElements();
+
+    // Widen index.
+    EVT IndexVT = Index.getValueType();
+    EVT WideIndexVT = EVT::getVectorVT(*DAG.getContext(),
+                                       IndexVT.getVectorElementType(), NumElts);
+
+    Index = ModifyToType(Index, WideIndexVT);
+  } else if (OpNo == 4) {
+    Index = GetWidenedVector(Index);
+    NumElts = Index.getValueType().getVectorNumElements();
+
+    // Widen the data.
+    EVT DataVT = DataOp.getValueType();
+    EVT WideDataVT = EVT::getVectorVT(*DAG.getContext(),
+                                      DataVT.getVectorElementType(), NumElts);
+    DataOp = ModifyToType(DataOp, WideDataVT);
+  } else
+    llvm_unreachable("Can't widen this operand of mscatter");
+
+  // The mask should be widened as well.
+  EVT MaskVT = Mask.getValueType();
+  EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
+                                    MaskVT.getVectorElementType(), NumElts);
+
+  Mask = ModifyToType(Mask, WideMaskVT, true);
+  SDValue Ops[] = {MSC->getChain(), DataOp, Mask, MSC->getBasePtr(), Index,
                    Scale};
   return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
-                              MSC->getMemoryVT(), dl, Ops,
+                              MSC->getMemoryVT(), SDLoc(N), Ops,
                               MSC->getMemOperand());
 }
 
Index: test/CodeGen/X86/masked_gather_scatter_widen.ll
===================================================================
--- /dev/null
+++ test/CodeGen/X86/masked_gather_scatter_widen.ll
@@ -0,0 +1,140 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=x86_64-unknown-linux-gnu -mattr=+avx512vl -mattr=+avx512dq -x86-experimental-vector-widening-legalization < %s | FileCheck %s --check-prefix=CHECK --check-prefix=WIDEN
+; RUN: llc -mtriple=x86_64-unknown-linux-gnu -mattr=+avx512vl -mattr=+avx512dq < %s | FileCheck %s --check-prefix=CHECK --check-prefix=PROMOTE
+
+define <2 x double> @test_gather_v2i32_index(double* %base, <2 x i32> %ind, <2 x i1> %mask, <2 x double> %src0) {
+; WIDEN-LABEL: test_gather_v2i32_index:
+; WIDEN:       # %bb.0:
+; WIDEN-NEXT:    # kill: def $xmm2 killed $xmm2 def $ymm2
+; WIDEN-NEXT:    vpsllq $63, %xmm1, %xmm1
+; WIDEN-NEXT:    vpmovq2m %xmm1, %k1
+; WIDEN-NEXT:    vgatherdpd (%rdi,%xmm0,8), %ymm2 {%k1}
+; WIDEN-NEXT:    vmovapd %xmm2, %xmm0
+; WIDEN-NEXT:    vzeroupper
+; WIDEN-NEXT:    retq
+;
+; PROMOTE-LABEL: test_gather_v2i32_index:
+; PROMOTE:       # %bb.0:
+; PROMOTE-NEXT:    vpsllq $32, %xmm0, %xmm0
+; PROMOTE-NEXT:    vpsraq $32, %xmm0, %xmm0
+; PROMOTE-NEXT:    vpsllq $63, %xmm1, %xmm1
+; PROMOTE-NEXT:    vpmovq2m %xmm1, %k1
+; PROMOTE-NEXT:    vgatherqpd (%rdi,%xmm0,8), %xmm2 {%k1}
+; PROMOTE-NEXT:    vmovapd %xmm2, %xmm0
+; PROMOTE-NEXT:    retq
+  %gep.random = getelementptr double, double* %base, <2 x i32> %ind
+  %res = call <2 x double> @llvm.masked.gather.v2f64.v2p0f64(<2 x double*> %gep.random, i32 4, <2 x i1> %mask, <2 x double> %src0)
+  ret <2 x double> %res
+}
+
+define void @test_scatter_v2i32_index(<2 x double> %a1, double* %base, <2 x i32> %ind, <2 x i1> %mask) {
+; WIDEN-LABEL: test_scatter_v2i32_index:
+; WIDEN:       # %bb.0:
+; WIDEN-NEXT:    # kill: def $xmm0 killed $xmm0 def $ymm0
+; WIDEN-NEXT:    vpsllq $63, %xmm2, %xmm2
+; WIDEN-NEXT:    vpmovq2m %xmm2, %k1
+; WIDEN-NEXT:    vscatterdpd %ymm0, (%rdi,%xmm1,8) {%k1}
+; WIDEN-NEXT:    vzeroupper
+; WIDEN-NEXT:    retq
+;
+; PROMOTE-LABEL: test_scatter_v2i32_index:
+; PROMOTE:       # %bb.0:
+; PROMOTE-NEXT:    vpsllq $63, %xmm2, %xmm2
+; PROMOTE-NEXT:    vpmovq2m %xmm2, %k1
+; PROMOTE-NEXT:    vpsllq $32, %xmm1, %xmm1
+; PROMOTE-NEXT:    vpsraq $32, %xmm1, %xmm1
+; PROMOTE-NEXT:    vscatterqpd %xmm0, (%rdi,%xmm1,8) {%k1}
+; PROMOTE-NEXT:    retq
+  %gep = getelementptr double, double *%base, <2 x i32> %ind
+  call void @llvm.masked.scatter.v2f64.v2p0f64(<2 x double> %a1, <2 x double*> %gep, i32 4, <2 x i1> %mask)
+  ret void
+}
+
+define <2 x i32> @test_gather_v2i32_data(<2 x i32*> %ptr, <2 x i1> %mask, <2 x i32> %src0) {
+; WIDEN-LABEL: test_gather_v2i32_data:
+; WIDEN:       # %bb.0:
+; WIDEN-NEXT:    vpsllq $63, %xmm1, %xmm1
+; WIDEN-NEXT:    vpmovq2m %xmm1, %k1
+; WIDEN-NEXT:    vpgatherqd (,%xmm0), %xmm2 {%k1}
+; WIDEN-NEXT:    vmovdqa %xmm2, %xmm0
+; WIDEN-NEXT:    retq
+;
+; PROMOTE-LABEL: test_gather_v2i32_data:
+; PROMOTE:       # %bb.0:
+; PROMOTE-NEXT:    vpsllq $63, %xmm1, %xmm1
+; PROMOTE-NEXT:    vpmovq2m %xmm1, %k1
+; PROMOTE-NEXT:    vpshufd {{.*#+}} xmm1 = xmm2[0,2,2,3]
+; PROMOTE-NEXT:    vpgatherqd (,%xmm0), %xmm1 {%k1}
+; PROMOTE-NEXT:    vpmovzxdq {{.*#+}} xmm0 = xmm1[0],zero,xmm1[1],zero
+; PROMOTE-NEXT:    retq
+  %res = call <2 x i32> @llvm.masked.gather.v2i32.v2p0i32(<2 x i32*> %ptr, i32 4, <2 x i1> %mask, <2 x i32> %src0)
+  ret <2 x i32>%res
+}
+
+define void @test_scatter_v2i32_data(<2 x i32>%a1, <2 x i32*> %ptr, <2 x i1>%mask) {
+; WIDEN-LABEL: test_scatter_v2i32_data:
+; WIDEN:       # %bb.0:
+; WIDEN-NEXT:    vpsllq $63, %xmm2, %xmm2
+; WIDEN-NEXT:    vpmovq2m %xmm2, %k1
+; WIDEN-NEXT:    vpscatterqd %xmm0, (,%xmm1) {%k1}
+; WIDEN-NEXT:    retq
+;
+; PROMOTE-LABEL: test_scatter_v2i32_data:
+; PROMOTE:       # %bb.0:
+; PROMOTE-NEXT:    vpsllq $63, %xmm2, %xmm2
+; PROMOTE-NEXT:    vpmovq2m %xmm2, %k1
+; PROMOTE-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
+; PROMOTE-NEXT:    vpscatterqd %xmm0, (,%xmm1) {%k1}
+; PROMOTE-NEXT:    retq
+  call void @llvm.masked.scatter.v2i32.v2p0i32(<2 x i32> %a1, <2 x i32*> %ptr, i32 4, <2 x i1> %mask)
+  ret void
+}
+
+define <2 x i32> @test_gather_v2i32_data_index(i32* %base, <2 x i32> %ind, <2 x i1> %mask, <2 x i32> %src0) {
+; WIDEN-LABEL: test_gather_v2i32_data_index:
+; WIDEN:       # %bb.0:
+; WIDEN-NEXT:    vpsllq $63, %xmm1, %xmm1
+; WIDEN-NEXT:    vpmovq2m %xmm1, %k1
+; WIDEN-NEXT:    vpgatherdd (%rdi,%xmm0,4), %xmm2 {%k1}
+; WIDEN-NEXT:    vmovdqa %xmm2, %xmm0
+; WIDEN-NEXT:    retq
+;
+; PROMOTE-LABEL: test_gather_v2i32_data_index:
+; PROMOTE:       # %bb.0:
+; PROMOTE-NEXT:    vpsllq $63, %xmm1, %xmm1
+; PROMOTE-NEXT:    vpmovq2m %xmm1, %k1
+; PROMOTE-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
+; PROMOTE-NEXT:    vpshufd {{.*#+}} xmm1 = xmm2[0,2,2,3]
+; PROMOTE-NEXT:    vpgatherdd (%rdi,%xmm0,4), %xmm1 {%k1}
+; PROMOTE-NEXT:    vpmovzxdq {{.*#+}} xmm0 = xmm1[0],zero,xmm1[1],zero
+; PROMOTE-NEXT:    retq
+  %gep.random = getelementptr i32, i32* %base, <2 x i32> %ind
+  %res = call <2 x i32> @llvm.masked.gather.v2i32.v2p0i32(<2 x i32*> %gep.random, i32 4, <2 x i1> %mask, <2 x i32> %src0)
+  ret <2 x i32> %res
+}
+
+define void @test_scatter_v2i32_data_index(<2 x i32> %a1, i32* %base, <2 x i32> %ind, <2 x i1> %mask) {
+; WIDEN-LABEL: test_scatter_v2i32_data_index:
+; WIDEN:       # %bb.0:
+; WIDEN-NEXT:    vpsllq $63, %xmm2, %xmm2
+; WIDEN-NEXT:    vpmovq2m %xmm2, %k1
+; WIDEN-NEXT:    vpscatterdd %xmm0, (%rdi,%xmm1,4) {%k1}
+; WIDEN-NEXT:    retq
+;
+; PROMOTE-LABEL: test_scatter_v2i32_data_index:
+; PROMOTE:       # %bb.0:
+; PROMOTE-NEXT:    vpsllq $63, %xmm2, %xmm2
+; PROMOTE-NEXT:    vpmovq2m %xmm2, %k1
+; PROMOTE-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
+; PROMOTE-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
+; PROMOTE-NEXT:    vpscatterdd %xmm0, (%rdi,%xmm1,4) {%k1}
+; PROMOTE-NEXT:    retq
+  %gep = getelementptr i32, i32 *%base, <2 x i32> %ind
+  call void @llvm.masked.scatter.v2i32.v2p0i32(<2 x i32> %a1, <2 x i32*> %gep, i32 4, <2 x i1> %mask)
+  ret void
+}
+
+declare <2 x double> @llvm.masked.gather.v2f64.v2p0f64(<2 x double*>, i32, <2 x i1>, <2 x double>)
+declare void @llvm.masked.scatter.v2f64.v2p0f64(<2 x double>, <2 x double*>, i32, <2 x i1>)
+declare <2 x i32> @llvm.masked.gather.v2i32.v2p0i32(<2 x i32*>, i32, <2 x i1>, <2 x i32>)
+declare void @llvm.masked.scatter.v2i32.v2p0i32(<2 x i32> , <2 x i32*> , i32 , <2 x i1>)