Skip to content

Commit 56b039e

Browse files
Author: Igor Breger
Committed: Feb 1, 2016
AVX512: fix mask handling for gather/scatter/prefetch intrinsics.
Differential Revision: http://reviews.llvm.org/D16755
llvm-svn: 259346
1 parent f8d0f18 commit 56b039e

File tree

2 files changed

+77
-43
lines changed

2 files changed

+77
-43
lines changed
 

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 18 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -16492,6 +16492,11 @@ static SDValue getMaskNode(SDValue Mask, MVT MaskVT,
1649216492
const X86Subtarget &Subtarget,
1649316493
SelectionDAG &DAG, SDLoc dl) {
1649416494

16495+
if (isAllOnesConstant(Mask))
16496+
return DAG.getTargetConstant(1, dl, MaskVT);
16497+
if (X86::isZeroNode(Mask))
16498+
return DAG.getTargetConstant(0, dl, MaskVT);
16499+
1649516500
if (MaskVT.bitsGT(Mask.getSimpleValueType())) {
1649616501
// Mask should be extended
1649716502
Mask = DAG.getNode(ISD::ANY_EXTEND, dl,
@@ -17409,79 +17414,52 @@ static SDValue getGatherNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
1740917414
SDValue Scale = DAG.getTargetConstant(C->getZExtValue(), dl, MVT::i8);
1741017415
MVT MaskVT = MVT::getVectorVT(MVT::i1,
1741117416
Index.getSimpleValueType().getVectorNumElements());
17412-
SDValue MaskInReg;
17413-
ConstantSDNode *MaskC = dyn_cast<ConstantSDNode>(Mask);
17414-
if (MaskC)
17415-
MaskInReg = DAG.getTargetConstant(MaskC->getSExtValue(), dl, MaskVT);
17416-
else {
17417-
MVT BitcastVT = MVT::getVectorVT(MVT::i1,
17418-
Mask.getSimpleValueType().getSizeInBits());
1741917417

17420-
// In case when MaskVT equals v2i1 or v4i1, low 2 or 4 elements
17421-
// are extracted by EXTRACT_SUBVECTOR.
17422-
MaskInReg = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MaskVT,
17423-
DAG.getBitcast(BitcastVT, Mask),
17424-
DAG.getIntPtrConstant(0, dl));
17425-
}
17418+
SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
1742617419
SDVTList VTs = DAG.getVTList(Op.getValueType(), MaskVT, MVT::Other);
1742717420
SDValue Disp = DAG.getTargetConstant(0, dl, MVT::i32);
1742817421
SDValue Segment = DAG.getRegister(0, MVT::i32);
1742917422
if (Src.getOpcode() == ISD::UNDEF)
1743017423
Src = getZeroVector(Op.getSimpleValueType(), Subtarget, DAG, dl);
17431-
SDValue Ops[] = {Src, MaskInReg, Base, Scale, Index, Disp, Segment, Chain};
17424+
SDValue Ops[] = {Src, VMask, Base, Scale, Index, Disp, Segment, Chain};
1743217425
SDNode *Res = DAG.getMachineNode(Opc, dl, VTs, Ops);
1743317426
SDValue RetOps[] = { SDValue(Res, 0), SDValue(Res, 2) };
1743417427
return DAG.getMergeValues(RetOps, dl);
1743517428
}
1743617429

1743717430
static SDValue getScatterNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
1743817431
SDValue Src, SDValue Mask, SDValue Base,
17439-
SDValue Index, SDValue ScaleOp, SDValue Chain) {
17432+
SDValue Index, SDValue ScaleOp, SDValue Chain,
17433+
const X86Subtarget &Subtarget) {
1744017434
SDLoc dl(Op);
1744117435
auto *C = cast<ConstantSDNode>(ScaleOp);
1744217436
SDValue Scale = DAG.getTargetConstant(C->getZExtValue(), dl, MVT::i8);
1744317437
SDValue Disp = DAG.getTargetConstant(0, dl, MVT::i32);
1744417438
SDValue Segment = DAG.getRegister(0, MVT::i32);
1744517439
MVT MaskVT = MVT::getVectorVT(MVT::i1,
1744617440
Index.getSimpleValueType().getVectorNumElements());
17447-
SDValue MaskInReg;
17448-
ConstantSDNode *MaskC = dyn_cast<ConstantSDNode>(Mask);
17449-
if (MaskC)
17450-
MaskInReg = DAG.getTargetConstant(MaskC->getSExtValue(), dl, MaskVT);
17451-
else {
17452-
MVT BitcastVT = MVT::getVectorVT(MVT::i1,
17453-
Mask.getSimpleValueType().getSizeInBits());
1745417441

17455-
// In case when MaskVT equals v2i1 or v4i1, low 2 or 4 elements
17456-
// are extracted by EXTRACT_SUBVECTOR.
17457-
MaskInReg = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MaskVT,
17458-
DAG.getBitcast(BitcastVT, Mask),
17459-
DAG.getIntPtrConstant(0, dl));
17460-
}
17442+
SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
1746117443
SDVTList VTs = DAG.getVTList(MaskVT, MVT::Other);
17462-
SDValue Ops[] = {Base, Scale, Index, Disp, Segment, MaskInReg, Src, Chain};
17444+
SDValue Ops[] = {Base, Scale, Index, Disp, Segment, VMask, Src, Chain};
1746317445
SDNode *Res = DAG.getMachineNode(Opc, dl, VTs, Ops);
1746417446
return SDValue(Res, 1);
1746517447
}
1746617448

1746717449
static SDValue getPrefetchNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
1746817450
SDValue Mask, SDValue Base, SDValue Index,
17469-
SDValue ScaleOp, SDValue Chain) {
17451+
SDValue ScaleOp, SDValue Chain,
17452+
const X86Subtarget &Subtarget) {
1747017453
SDLoc dl(Op);
1747117454
auto *C = cast<ConstantSDNode>(ScaleOp);
1747217455
SDValue Scale = DAG.getTargetConstant(C->getZExtValue(), dl, MVT::i8);
1747317456
SDValue Disp = DAG.getTargetConstant(0, dl, MVT::i32);
1747417457
SDValue Segment = DAG.getRegister(0, MVT::i32);
1747517458
MVT MaskVT =
1747617459
MVT::getVectorVT(MVT::i1, Index.getSimpleValueType().getVectorNumElements());
17477-
SDValue MaskInReg;
17478-
ConstantSDNode *MaskC = dyn_cast<ConstantSDNode>(Mask);
17479-
if (MaskC)
17480-
MaskInReg = DAG.getTargetConstant(MaskC->getSExtValue(), dl, MaskVT);
17481-
else
17482-
MaskInReg = DAG.getBitcast(MaskVT, Mask);
17460+
SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
1748317461
//SDVTList VTs = DAG.getVTList(MVT::Other);
17484-
SDValue Ops[] = {MaskInReg, Base, Scale, Index, Disp, Segment, Chain};
17462+
SDValue Ops[] = {VMask, Base, Scale, Index, Disp, Segment, Chain};
1748517463
SDNode *Res = DAG.getMachineNode(Opc, dl, MVT::Other, Ops);
1748617464
return SDValue(Res, 0);
1748717465
}
@@ -17678,7 +17656,7 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,
1767817656
SDValue Src = Op.getOperand(5);
1767917657
SDValue Scale = Op.getOperand(6);
1768017658
return getScatterNode(IntrData->Opc0, Op, DAG, Src, Mask, Base, Index,
17681-
Scale, Chain);
17659+
Scale, Chain, Subtarget);
1768217660
}
1768317661
case PREFETCH: {
1768417662
SDValue Hint = Op.getOperand(6);
@@ -17690,7 +17668,8 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,
1769017668
SDValue Index = Op.getOperand(3);
1769117669
SDValue Base = Op.getOperand(4);
1769217670
SDValue Scale = Op.getOperand(5);
17693-
return getPrefetchNode(Opcode, Op, DAG, Mask, Base, Index, Scale, Chain);
17671+
return getPrefetchNode(Opcode, Op, DAG, Mask, Base, Index, Scale, Chain,
17672+
Subtarget);
1769417673
}
1769517674
// Read Time Stamp Counter (RDTSC) and Processor ID (RDTSCP).
1769617675
case RDTSC: {

llvm/test/CodeGen/X86/avx512-gather-scatter-intrin.ll

Lines changed: 59 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -259,18 +259,22 @@ define void @prefetch(<8 x i64> %ind, i8* %base) {
259259
; CHECK: ## BB#0:
260260
; CHECK-NEXT: kxnorw %k0, %k0, %k1
261261
; CHECK-NEXT: vgatherpf0qps (%rdi,%zmm0,4) {%k1}
262+
; CHECK-NEXT: kxorw %k0, %k0, %k1
262263
; CHECK-NEXT: vgatherpf1qps (%rdi,%zmm0,4) {%k1}
264+
; CHECK-NEXT: movb $1, %al
265+
; CHECK-NEXT: kmovb %eax, %k1
263266
; CHECK-NEXT: vscatterpf0qps (%rdi,%zmm0,2) {%k1}
267+
; CHECK-NEXT: movb $120, %al
268+
; CHECK-NEXT: kmovb %eax, %k1
264269
; CHECK-NEXT: vscatterpf1qps (%rdi,%zmm0,2) {%k1}
265270
; CHECK-NEXT: retq
266271
call void @llvm.x86.avx512.gatherpf.qps.512(i8 -1, <8 x i64> %ind, i8* %base, i32 4, i32 0)
267-
call void @llvm.x86.avx512.gatherpf.qps.512(i8 -1, <8 x i64> %ind, i8* %base, i32 4, i32 1)
268-
call void @llvm.x86.avx512.scatterpf.qps.512(i8 -1, <8 x i64> %ind, i8* %base, i32 2, i32 0)
269-
call void @llvm.x86.avx512.scatterpf.qps.512(i8 -1, <8 x i64> %ind, i8* %base, i32 2, i32 1)
272+
call void @llvm.x86.avx512.gatherpf.qps.512(i8 0, <8 x i64> %ind, i8* %base, i32 4, i32 1)
273+
call void @llvm.x86.avx512.scatterpf.qps.512(i8 1, <8 x i64> %ind, i8* %base, i32 2, i32 0)
274+
call void @llvm.x86.avx512.scatterpf.qps.512(i8 120, <8 x i64> %ind, i8* %base, i32 2, i32 1)
270275
ret void
271276
}
272277

273-
274278
declare <2 x double> @llvm.x86.avx512.gather3div2.df(<2 x double>, i8*, <2 x i64>, i8, i32)
275279

276280
define <2 x double>@test_int_x86_avx512_gather3div2_df(<2 x double> %x0, i8* %x1, <2 x i64> %x2, i8 %x3) {
@@ -790,3 +794,54 @@ define void@test_int_x86_avx512_scattersiv8_si(i8* %x0, i8 %x1, <8 x i32> %x2, <
790794
ret void
791795
}
792796

797+
define void @scatter_mask_test(i8* %x0, <8 x i32> %x2, <8 x i32> %x3) {
798+
; CHECK-LABEL: scatter_mask_test:
799+
; CHECK: ## BB#0:
800+
; CHECK-NEXT: kxnorw %k0, %k0, %k1
801+
; CHECK-NEXT: vpscatterdd %ymm1, (%rdi,%ymm0,2) {%k1}
802+
; CHECK-NEXT: kxorw %k0, %k0, %k1
803+
; CHECK-NEXT: vpscatterdd %ymm1, (%rdi,%ymm0,4) {%k1}
804+
; CHECK-NEXT: movb $1, %al
805+
; CHECK-NEXT: kmovb %eax, %k1
806+
; CHECK-NEXT: vpscatterdd %ymm1, (%rdi,%ymm0,2) {%k1}
807+
; CHECK-NEXT: movb $96, %al
808+
; CHECK-NEXT: kmovb %eax, %k1
809+
; CHECK-NEXT: vpscatterdd %ymm1, (%rdi,%ymm0,4) {%k1}
810+
; CHECK-NEXT: retq
811+
call void @llvm.x86.avx512.scattersiv8.si(i8* %x0, i8 -1, <8 x i32> %x2, <8 x i32> %x3, i32 2)
812+
call void @llvm.x86.avx512.scattersiv8.si(i8* %x0, i8 0, <8 x i32> %x2, <8 x i32> %x3, i32 4)
813+
call void @llvm.x86.avx512.scattersiv8.si(i8* %x0, i8 1, <8 x i32> %x2, <8 x i32> %x3, i32 2)
814+
call void @llvm.x86.avx512.scattersiv8.si(i8* %x0, i8 96, <8 x i32> %x2, <8 x i32> %x3, i32 4)
815+
ret void
816+
}
817+
818+
define <16 x float> @gather_mask_test(<16 x i32> %ind, <16 x float> %src, i8* %base) {
819+
; CHECK-LABEL: gather_mask_test:
820+
; CHECK: ## BB#0:
821+
; CHECK-NEXT: kxnorw %k0, %k0, %k1
822+
; CHECK-NEXT: vmovaps %zmm1, %zmm2
823+
; CHECK-NEXT: vgatherdps (%rdi,%zmm0,4), %zmm2 {%k1}
824+
; CHECK-NEXT: kxorw %k0, %k0, %k1
825+
; CHECK-NEXT: vmovaps %zmm1, %zmm3
826+
; CHECK-NEXT: vgatherdps (%rdi,%zmm0,4), %zmm3 {%k1}
827+
; CHECK-NEXT: movw $1, %ax
828+
; CHECK-NEXT: kmovw %eax, %k1
829+
; CHECK-NEXT: vmovaps %zmm1, %zmm4
830+
; CHECK-NEXT: vgatherdps (%rdi,%zmm0,4), %zmm4 {%k1}
831+
; CHECK-NEXT: movw $220, %ax
832+
; CHECK-NEXT: kmovw %eax, %k1
833+
; CHECK-NEXT: vgatherdps (%rdi,%zmm0,4), %zmm1 {%k1}
834+
; CHECK-NEXT: vaddps %zmm3, %zmm2, %zmm0
835+
; CHECK-NEXT: vaddps %zmm4, %zmm1, %zmm1
836+
; CHECK-NEXT: vaddps %zmm0, %zmm1, %zmm0
837+
; CHECK-NEXT: retq
838+
%res = call <16 x float> @llvm.x86.avx512.gather.dps.512 (<16 x float> %src, i8* %base, <16 x i32>%ind, i16 -1, i32 4)
839+
%res1 = call <16 x float> @llvm.x86.avx512.gather.dps.512 (<16 x float> %src, i8* %base, <16 x i32>%ind, i16 0, i32 4)
840+
%res2 = call <16 x float> @llvm.x86.avx512.gather.dps.512 (<16 x float> %src, i8* %base, <16 x i32>%ind, i16 1, i32 4)
841+
%res3 = call <16 x float> @llvm.x86.avx512.gather.dps.512 (<16 x float> %src, i8* %base, <16 x i32>%ind, i16 220, i32 4)
842+
843+
%res4 = fadd <16 x float> %res, %res1
844+
%res5 = fadd <16 x float> %res3, %res2
845+
%res6 = fadd <16 x float> %res5, %res4
846+
ret <16 x float> %res6
847+
}

0 commit comments

Comments (0)