Skip to content

Commit 914d4ba

Browse files
committedMar 20, 2018
[NVPTX] Make tensor load/store intrinsics overloaded.
This way we can support address-space-specific variants without explicitly encoding the space in the name of the intrinsic. Fewer intrinsics to deal with means less boilerplate. Added a bit of TableGen magic to match an intrinsic whose pointer argument is in a particular address space and replace it with the space-specific instruction variant. Updated tests to use non-default address spaces. Differential Revision: https://reviews.llvm.org/D43268 — llvm-svn: 328006
1 parent 3a99893 commit 914d4ba

File tree

5 files changed

+174
-157
lines changed

5 files changed

+174
-157
lines changed
 

‎clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10527,8 +10527,7 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
1052710527
llvm_unreachable("Unexpected builtin ID.");
1052810528
}
1052910529
Value *Result =
10530-
Builder.CreateCall(CGM.getIntrinsic(IID),
10531-
{Builder.CreatePointerCast(Src, VoidPtrTy), Ldm});
10530+
Builder.CreateCall(CGM.getIntrinsic(IID, Src->getType()), {Src, Ldm});
1053210531

1053310532
// Save returned values.
1053410533
for (unsigned i = 0; i < NumResults; ++i) {
@@ -10567,10 +10566,9 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
1056710566
default:
1056810567
llvm_unreachable("Unexpected builtin ID.");
1056910568
}
10570-
Function *Intrinsic = CGM.getIntrinsic(IID);
10569+
Function *Intrinsic = CGM.getIntrinsic(IID, Dst->getType());
1057110570
llvm::Type *ParamType = Intrinsic->getFunctionType()->getParamType(1);
10572-
SmallVector<Value *, 10> Values;
10573-
Values.push_back(Builder.CreatePointerCast(Dst, VoidPtrTy));
10571+
SmallVector<Value *, 10> Values = {Dst};
1057410572
for (unsigned i = 0; i < NumResults; ++i) {
1057510573
Value *V = Builder.CreateAlignedLoad(
1057610574
Builder.CreateGEP(Src.getPointer(), llvm::ConstantInt::get(IntTy, i)),

‎llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 20 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3884,90 +3884,65 @@ def int_nvvm_match_all_sync_i64p :
38843884
//
38853885

38863886
// WMMA.LOAD
3887-
class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Space,
3888-
string Type, LLVMType regty, int WithStride>
3887+
class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Type,
3888+
LLVMType regty, int WithStride>
38893889
: Intrinsic<!if(!eq(Abc#Type,"cf16"),
38903890
[regty, regty, regty, regty],
38913891
[regty, regty, regty, regty,
38923892
regty, regty, regty, regty]),
3893-
!if(WithStride, [llvm_ptr_ty, llvm_i32_ty], [llvm_ptr_ty]),
3894-
[], // Properties must be set during instantiation.
3893+
!if(WithStride, [llvm_anyptr_ty, llvm_i32_ty], [llvm_anyptr_ty]),
3894+
[IntrReadMem, IntrArgMemOnly, ReadOnly<0>, NoCapture<0>],
38953895
"llvm.nvvm.wmma.load."#Abc#".sync."#Layout#".m16n16k16"
3896-
#Space
38973896
#!if(WithStride,".stride","")
38983897
#"."#Type>;
38993898

3900-
multiclass NVVM_WMMA_LD_ALST<string Abc, string Layout, string Space,
3901-
string Type, LLVMType regty> {
3902-
def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 1>;
3903-
def NAME : NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 0>;
3904-
}
3905-
3906-
multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout,
3907-
string Type, LLVMType regty> {
3908-
defm _global: NVVM_WMMA_LD_ALST<Abc, Layout, ".global", Type, regty>;
3909-
defm _shared: NVVM_WMMA_LD_ALST<Abc, Layout, ".shared", Type, regty>;
3910-
defm NAME: NVVM_WMMA_LD_ALST<Abc, Layout, "", Type, regty>;
3899+
multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout, string Type,
3900+
LLVMType regty> {
3901+
def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Type, regty, 1>;
3902+
def NAME : NVVM_WMMA_LD_ALSTS<Abc, Layout, Type, regty, 0>;
39113903
}
39123904

39133905
multiclass NVVM_WMMA_LD_AT<string Abc, string Type, LLVMType regty> {
39143906
defm _row: NVVM_WMMA_LD_ALT<Abc, "row", Type, regty>;
39153907
defm _col: NVVM_WMMA_LD_ALT<Abc, "col", Type, regty>;
39163908
}
39173909

3918-
// For some reason ReadOnly<N> and NoCapture<N> confuses tblgen if they are
3919-
// passed to Intrinsic<> form inside of a multiclass. Setting them globally
3920-
// outside of the multiclass works.
3921-
let IntrProperties = [IntrReadMem, IntrArgMemOnly,
3922-
ReadOnly<0>, NoCapture<0>] in {
3923-
defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
3924-
defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
3925-
defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
3926-
defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
3927-
}
3910+
defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
3911+
defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
3912+
defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
3913+
defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
39283914

39293915
// WMMA.STORE.D
3930-
class NVVM_WMMA_STD_LSTS<string Layout, string Space,
3931-
string Type, LLVMType regty, int WithStride,
3916+
class NVVM_WMMA_STD_LSTS<string Layout, string Type, LLVMType regty, int WithStride,
39323917
// This is only used to create a typed empty array we
39333918
// need to pass to !if below.
39343919
list<LLVMType>Empty=[]>
39353920
: Intrinsic<[],
39363921
!listconcat(
3937-
[llvm_ptr_ty],
3922+
[llvm_anyptr_ty],
39383923
!if(!eq(Type,"f16"),
39393924
[regty, regty, regty, regty],
39403925
[regty, regty, regty, regty,
39413926
regty, regty, regty, regty]),
39423927
!if(WithStride, [llvm_i32_ty], Empty)),
3943-
[], // Properties must be set during instantiation.
3928+
[IntrWriteMem, IntrArgMemOnly, WriteOnly<0>, NoCapture<0>],
39443929
"llvm.nvvm.wmma.store.d.sync."#Layout
3945-
#".m16n16k16"#Space
3930+
#".m16n16k16"
39463931
#!if(WithStride,".stride","")
39473932
#"."#Type>;
39483933

3949-
multiclass NVVM_WMMA_STD_LST<string Layout, string Space,
3950-
string Type, LLVMType regty> {
3951-
def _stride: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 1>;
3952-
def NAME: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 0>;
3953-
}
3954-
39553934
multiclass NVVM_WMMA_STD_LT<string Layout, string Type, LLVMType regty> {
3956-
defm _global: NVVM_WMMA_STD_LST<Layout, ".global", Type, regty>;
3957-
defm _shared: NVVM_WMMA_STD_LST<Layout, ".shared", Type, regty>;
3958-
defm NAME: NVVM_WMMA_STD_LST<Layout, "", Type, regty>;
3935+
def _stride: NVVM_WMMA_STD_LSTS<Layout, Type, regty, 1>;
3936+
def NAME: NVVM_WMMA_STD_LSTS<Layout, Type, regty, 0>;
39593937
}
39603938

39613939
multiclass NVVM_WMMA_STD_T<string Type, LLVMType regty> {
39623940
defm _row: NVVM_WMMA_STD_LT<"row", Type, regty>;
39633941
defm _col: NVVM_WMMA_STD_LT<"col", Type, regty>;
39643942
}
39653943

3966-
let IntrProperties = [IntrWriteMem, IntrArgMemOnly,
3967-
WriteOnly<0>, NoCapture<0>] in {
3968-
defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
3969-
defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
3970-
}
3944+
defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
3945+
defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
39713946

39723947
// WMMA.MMA
39733948
class NVVM_WMMA_MMA_ABDCS<string ALayout, string BLayout,

‎llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 5 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -3327,26 +3327,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
33273327
case Intrinsic::nvvm_wmma_load_a_f16_row:
33283328
case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
33293329
case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
3330-
case Intrinsic::nvvm_wmma_load_a_f16_col_shared:
3331-
case Intrinsic::nvvm_wmma_load_a_f16_row_shared:
3332-
case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride:
3333-
case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride:
3334-
case Intrinsic::nvvm_wmma_load_a_f16_col_global:
3335-
case Intrinsic::nvvm_wmma_load_a_f16_row_global:
3336-
case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride:
3337-
case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride:
33383330
case Intrinsic::nvvm_wmma_load_b_f16_col:
33393331
case Intrinsic::nvvm_wmma_load_b_f16_row:
33403332
case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
3341-
case Intrinsic::nvvm_wmma_load_b_f16_row_stride:
3342-
case Intrinsic::nvvm_wmma_load_b_f16_col_shared:
3343-
case Intrinsic::nvvm_wmma_load_b_f16_row_shared:
3344-
case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride:
3345-
case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride:
3346-
case Intrinsic::nvvm_wmma_load_b_f16_col_global:
3347-
case Intrinsic::nvvm_wmma_load_b_f16_row_global:
3348-
case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride:
3349-
case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride: {
3333+
case Intrinsic::nvvm_wmma_load_b_f16_row_stride: {
33503334
Info.opc = ISD::INTRINSIC_W_CHAIN;
33513335
Info.memVT = MVT::v8f16;
33523336
Info.ptrVal = I.getArgOperand(0);
@@ -3359,15 +3343,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
33593343
case Intrinsic::nvvm_wmma_load_c_f16_col:
33603344
case Intrinsic::nvvm_wmma_load_c_f16_row:
33613345
case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
3362-
case Intrinsic::nvvm_wmma_load_c_f16_row_stride:
3363-
case Intrinsic::nvvm_wmma_load_c_f16_col_shared:
3364-
case Intrinsic::nvvm_wmma_load_c_f16_row_shared:
3365-
case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride:
3366-
case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride:
3367-
case Intrinsic::nvvm_wmma_load_c_f16_col_global:
3368-
case Intrinsic::nvvm_wmma_load_c_f16_row_global:
3369-
case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride:
3370-
case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride: {
3346+
case Intrinsic::nvvm_wmma_load_c_f16_row_stride: {
33713347
Info.opc = ISD::INTRINSIC_W_CHAIN;
33723348
Info.memVT = MVT::v4f16;
33733349
Info.ptrVal = I.getArgOperand(0);
@@ -3380,15 +3356,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
33803356
case Intrinsic::nvvm_wmma_load_c_f32_col:
33813357
case Intrinsic::nvvm_wmma_load_c_f32_row:
33823358
case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
3383-
case Intrinsic::nvvm_wmma_load_c_f32_row_stride:
3384-
case Intrinsic::nvvm_wmma_load_c_f32_col_shared:
3385-
case Intrinsic::nvvm_wmma_load_c_f32_row_shared:
3386-
case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride:
3387-
case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride:
3388-
case Intrinsic::nvvm_wmma_load_c_f32_col_global:
3389-
case Intrinsic::nvvm_wmma_load_c_f32_row_global:
3390-
case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride:
3391-
case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride: {
3359+
case Intrinsic::nvvm_wmma_load_c_f32_row_stride: {
33923360
Info.opc = ISD::INTRINSIC_W_CHAIN;
33933361
Info.memVT = MVT::v8f32;
33943362
Info.ptrVal = I.getArgOperand(0);
@@ -3401,15 +3369,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
34013369
case Intrinsic::nvvm_wmma_store_d_f16_col:
34023370
case Intrinsic::nvvm_wmma_store_d_f16_row:
34033371
case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
3404-
case Intrinsic::nvvm_wmma_store_d_f16_row_stride:
3405-
case Intrinsic::nvvm_wmma_store_d_f16_col_shared:
3406-
case Intrinsic::nvvm_wmma_store_d_f16_row_shared:
3407-
case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride:
3408-
case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride:
3409-
case Intrinsic::nvvm_wmma_store_d_f16_col_global:
3410-
case Intrinsic::nvvm_wmma_store_d_f16_row_global:
3411-
case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride:
3412-
case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride: {
3372+
case Intrinsic::nvvm_wmma_store_d_f16_row_stride: {
34133373
Info.opc = ISD::INTRINSIC_VOID;
34143374
Info.memVT = MVT::v4f16;
34153375
Info.ptrVal = I.getArgOperand(0);
@@ -3422,15 +3382,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
34223382
case Intrinsic::nvvm_wmma_store_d_f32_col:
34233383
case Intrinsic::nvvm_wmma_store_d_f32_row:
34243384
case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
3425-
case Intrinsic::nvvm_wmma_store_d_f32_row_stride:
3426-
case Intrinsic::nvvm_wmma_store_d_f32_col_shared:
3427-
case Intrinsic::nvvm_wmma_store_d_f32_row_shared:
3428-
case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride:
3429-
case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride:
3430-
case Intrinsic::nvvm_wmma_store_d_f32_col_global:
3431-
case Intrinsic::nvvm_wmma_store_d_f32_row_global:
3432-
case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride:
3433-
case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride: {
3385+
case Intrinsic::nvvm_wmma_store_d_f32_row_stride: {
34343386
Info.opc = ISD::INTRINSIC_VOID;
34353387
Info.memVT = MVT::v8f32;
34363388
Info.ptrVal = I.getArgOperand(0);

‎llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 110 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7379,13 +7379,16 @@ class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
73797379
string Type, NVPTXRegClass regclass,
73807380
DAGOperand SrcOp, bit WithStride>
73817381
: EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
7382-
// Intrinsic that matches this instruction.
7383-
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_load_"
7384-
# Abc
7385-
# "_" # Type
7386-
# "_" # Layout
7387-
# !subst(".","_",Space)
7388-
# !if(WithStride,"_stride", ""));
7382+
// Pattern (created by WMMA_LOAD_INTR_HELPER below) that matches the intrinsic
7383+
// for this instruction.
7384+
PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_LOAD_"
7385+
# !subst("a", "A",
7386+
!subst("b", "B",
7387+
!subst("c", "C_" # Type, Abc)))
7388+
# "_" # Layout
7389+
# !subst(".", "_", Space)
7390+
# !if(WithStride,"_stride", "")
7391+
# "_Intr");
73897392
dag OutsR03 = (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3);
73907393
dag OutsR47 = (outs regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7);
73917394
dag Outs = !if(!eq(Abc#Type,"cf16"), OutsR03, !con(OutsR03, OutsR47));
@@ -7410,7 +7413,7 @@ class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
74107413
!subst(imem, ADDRvar,
74117414
!subst(MEMri64, ADDRri64,
74127415
!subst(MEMri, ADDRri,
7413-
!subst(ins, Intr, tmp)))));
7416+
!subst(ins, IntrMatcher, tmp)))));
74147417
// Finally, concatenate both parts together. !con() requires both dags to have
74157418
// the same operator, so we wrap PatArgs in a (set ...) dag.
74167419
let Pattern = [!con(PatOuts, (set PatArgs))];
@@ -7425,20 +7428,52 @@ class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
74257428
#";";
74267429
}
74277430

7428-
multiclass WMMA_LOAD_ALSTO<string Abc, string Layout, string Space,
7429-
string Type, NVPTXRegClass regclass,
7430-
DAGOperand SrcOp> {
7431-
def _stride: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp, 1>;
7432-
def NAME: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp, 0>;
7431+
class WMMA_LOAD_INTR_HELPER<string Abc, string Layout, string Space,
7432+
string Type, bit WithStride>
7433+
: PatFrag <(ops),(ops)> {
7434+
// Intrinsic that matches this instruction.
7435+
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_load_"
7436+
# Abc
7437+
# "_" # Type
7438+
# "_" # Layout
7439+
# !if(WithStride,"_stride", ""));
7440+
code match_generic = [{
7441+
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
7442+
}];
7443+
code match_shared = [{
7444+
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
7445+
}];
7446+
code match_global = [{
7447+
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
7448+
}];
7449+
7450+
let Operands = !if(WithStride, (ops node:$src, node:$ldm), (ops node:$src));
7451+
let Fragment = !foreach(tmp, Operands, !subst(ops, Intr, tmp));
7452+
let PredicateCode = !if(!eq(Space, ".shared"), match_shared,
7453+
!if(!eq(Space, ".global"), match_global, match_generic));
7454+
}
7455+
7456+
multiclass WMMA_LOAD_ALSTS<string Abc, string Layout, string Space,
7457+
string Type, NVPTXRegClass regclass, bit WithStride> {
7458+
def _avar: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, imem, WithStride>;
7459+
def _areg: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, Int32Regs, WithStride>;
7460+
def _areg64: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, Int64Regs, WithStride>;
7461+
def _ari: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, MEMri, WithStride>;
7462+
def _ari64: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, MEMri64, WithStride>;
7463+
}
7464+
7465+
multiclass WMMA_LOAD_ALSTSh<string Abc, string Layout, string Space,
7466+
string Type, NVPTXRegClass regclass, bit WithStride> {
7467+
// Define a PatFrag that matches the appropriate intrinsic that loads from the
7468+
// given address space.
7469+
def _Intr : WMMA_LOAD_INTR_HELPER<Abc, Layout, Space, Type, WithStride>;
7470+
defm NAME: WMMA_LOAD_ALSTS<Abc, Layout, Space, Type, regclass, WithStride>;
74337471
}
74347472

74357473
multiclass WMMA_LOAD_ALST<string Abc, string Layout, string Space,
7436-
string Type, NVPTXRegClass regclass> {
7437-
defm _avar: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, imem>;
7438-
defm _areg: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, Int32Regs>;
7439-
defm _areg64: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, Int64Regs>;
7440-
defm _ari: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, MEMri>;
7441-
defm _ari64: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, MEMri64>;
7474+
string Type, NVPTXRegClass regclass> {
7475+
defm _stride: WMMA_LOAD_ALSTSh<Abc, Layout, Space, Type, regclass, 1>;
7476+
defm NAME: WMMA_LOAD_ALSTSh<Abc, Layout, Space, Type, regclass, 0>;
74427477
}
74437478

74447479
multiclass WMMA_LOAD_ALT<string Abc, string Layout,
@@ -7461,15 +7496,16 @@ defm INT_WMMA_LOAD_C_f32: WMMA_LOAD_AT<"c", "f32", Float32Regs>;
74617496
//
74627497
// wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
74637498
//
7464-
class WMMA_STORE_D_LSTOS<string Layout, string Space,
7499+
class WMMA_STORE_D_LSTSO<string Layout, string Space,
74657500
string Type, NVPTXRegClass regclass,
7466-
DAGOperand DstOp, bit WithStride>
7501+
bit WithStride, DAGOperand DstOp>
74677502
: EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
7468-
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_store_d_"
7469-
# Type
7470-
# "_" # Layout
7471-
# !subst(".","_",Space)
7472-
# !if(WithStride,"_stride", ""));
7503+
PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_STORE_D"
7504+
# "_" # Type
7505+
# "_" # Layout
7506+
# !subst(".", "_", Space)
7507+
# !if(WithStride,"_stride", "")
7508+
# "_Intr");
74737509

74747510
dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3);
74757511
dag InsR47 = (ins regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7);
@@ -7483,7 +7519,7 @@ class WMMA_STORE_D_LSTOS<string Layout, string Space,
74837519
!subst(imem, ADDRvar,
74847520
!subst(MEMri64, ADDRri64,
74857521
!subst(MEMri, ADDRri,
7486-
!subst(ins, Intr, tmp)))));
7522+
!subst(ins, IntrMatcher, tmp)))));
74877523
let Pattern = [PatArgs];
74887524
let OutOperandList = (outs);
74897525
let InOperandList = Ins;
@@ -7501,20 +7537,56 @@ class WMMA_STORE_D_LSTOS<string Layout, string Space,
75017537

75027538
}
75037539

7504-
multiclass WMMA_STORE_D_LSTO<string Layout, string Space,
7505-
string Type, NVPTXRegClass regclass,
7506-
DAGOperand DstOp> {
7507-
def _stride: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp, 1>;
7508-
def NAME: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp, 0>;
7540+
class WMMA_STORE_INTR_HELPER<string Layout, string Space,
7541+
string Type, bit WithStride>
7542+
: PatFrag <(ops),(ops)> {
7543+
// Intrinsic that matches this instruction.
7544+
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_store_d"
7545+
# "_" # Type
7546+
# "_" # Layout
7547+
# !if(WithStride, "_stride", ""));
7548+
code match_generic = [{
7549+
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
7550+
}];
7551+
code match_shared = [{
7552+
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
7553+
}];
7554+
code match_global = [{
7555+
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
7556+
}];
7557+
7558+
dag Args = !if(!eq(Type,"f16"),
7559+
(ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3),
7560+
(ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3,
7561+
node:$r4, node:$r5, node:$r6, node:$r7));
7562+
dag StrideArg = !if(WithStride, (ops node:$ldm), (ops));
7563+
let Operands = !con(Args, StrideArg);
7564+
let Fragment = !foreach(tmp, Operands, !subst(ops, Intr, tmp));
7565+
let PredicateCode = !if(!eq(Space, ".shared"), match_shared,
7566+
!if(!eq(Space, ".global"), match_global, match_generic));
7567+
}
7568+
7569+
multiclass WMMA_STORE_D_LSTS<string Layout, string Space,
7570+
string Type, NVPTXRegClass regclass, bit WithStride> {
7571+
def _avar: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, imem>;
7572+
def _areg: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, Int32Regs>;
7573+
def _areg64: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, Int64Regs>;
7574+
def _ari: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, MEMri>;
7575+
def _ari64: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, MEMri64>;
7576+
}
7577+
7578+
multiclass WMMA_STORE_D_LSTSh<string Layout, string Space,
7579+
string Type, NVPTXRegClass regclass, bit WithStride> {
7580+
// Define a PatFrag that matches the appropriate intrinsic that stores to the
7581+
// given address space.
7582+
def _Intr: WMMA_STORE_INTR_HELPER<Layout, Space, Type, WithStride>;
7583+
defm NAME: WMMA_STORE_D_LSTS<Layout, Space, Type, regclass, WithStride>;
75097584
}
75107585

75117586
multiclass WMMA_STORE_D_LST<string Layout, string Space,
7512-
string Type, NVPTXRegClass regclass> {
7513-
defm _avar: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, imem>;
7514-
defm _areg: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, Int32Regs>;
7515-
defm _areg64: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, Int64Regs>;
7516-
defm _ari: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, MEMri>;
7517-
defm _ari64: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, MEMri64>;
7587+
string Type, NVPTXRegClass regclass > {
7588+
defm _stride: WMMA_STORE_D_LSTSh<Layout, Space, Type, regclass, 1>;
7589+
defm NAME: WMMA_STORE_D_LSTSh<Layout, Space, Type, regclass, 0>;
75187590
}
75197591

75207592
multiclass WMMA_STORE_D_LT<string Layout,

‎llvm/test/CodeGen/NVPTX/wmma.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,35 +15,51 @@ def make_wmma_slice_ty(abcd, itype):
1515
def make_wmma_ld_ret_ty(abc, itype):
1616
return "{%s}" % ", ".join(make_wmma_slice_ty(abc, itype))
1717

18+
# returns address space
19+
def get_aspace(space):
20+
space_map = {
21+
".global" : 1,
22+
".shared" : 3,
23+
".const" : 4,
24+
".local" : 5,
25+
".param" : 101,
26+
"" : 0,
27+
".generic": 0
28+
}
29+
return space_map[space];
30+
31+
def get_pspace(space):
32+
return "p%di8" % get_aspace(space);
33+
1834
# Convenient test patterns.
1935
check_f16_8 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 8)
2036
check_f16_4 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 4)
2137
check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8)
2238

2339
def gen_wmma_load_tests():
2440
load_template = """
25-
declare ${ret_ty} @llvm.nvvm.wmma.load.$intrinsic_suffix(i8* %src ${extra_args});
41+
declare ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src ${extra_args});
2642
2743
; CHECK-LABEL: .func {{.*}}test_wmma_load_${function_suffix}(
28-
define ${ret_ty} @test_wmma_load_${function_suffix}(i8* %src ${extra_args}) {
44+
define ${ret_ty} @test_wmma_load_${function_suffix}(i8 ${as}* %src ${extra_args}) {
2945
; CHECK wmma.load.${intrinsic_suffix}
3046
; CHECK: {${check_result}}
3147
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
32-
%v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8* %src ${extra_args});
48+
%v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src ${extra_args});
3349
ret ${ret_ty} %v0;
3450
}
3551
3652
; CHECK-LABEL: .func{{.*}}test_wmma_load_${function_suffix}_o(
37-
define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8* %src ${extra_args}) {
53+
define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8 ${as}* %src ${extra_args}) {
3854
; CHECK wmma.load.${intrinsic_suffix}
3955
; CHECK: {${check_result}}
4056
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
41-
%src1 = getelementptr i8, i8* %src, i32 128;
42-
%v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8* %src1 ${extra_args});
57+
%src1 = getelementptr i8, i8 ${as}* %src, i32 128;
58+
%v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src1 ${extra_args});
4359
ret ${ret_ty} %v0;
4460
}
4561
"""
46-
suffix_template = "${abc}.sync.${layout}.m16n16k16${space}${stride}.${itype}"
62+
suffix_template = "${abc}.sync.${layout}.m16n16k16${stride}.${itype}.${pspace}"
4763
instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"
4864

4965
for abc, layout, space, stride, itype in product(
@@ -58,7 +74,9 @@ def gen_wmma_load_tests():
5874
"layout" : layout,
5975
"space" : space,
6076
"stride" : stride,
61-
"itype" : itype
77+
"itype" : itype,
78+
"pspace" : get_pspace(space),
79+
"as" : "addrspace(%d)" % get_aspace(space)
6280
}
6381

6482
if itype == "f32" and abc != "c":
@@ -89,28 +107,28 @@ def make_wmma_slice_args(itype, abcd, prefix="v"):
89107

90108
def gen_wmma_store_tests():
91109
store_template = """
92-
declare void @llvm.nvvm.wmma.store.$intrinsic_suffix(i8* %src, ${args}${extra_args});
110+
declare void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src, ${args}${extra_args});
93111
94112
; CHECK-LABEL: .func {{.*}}test_wmma_store_${function_suffix}(
95-
define void @test_wmma_store_${function_suffix}(i8* %src, ${args}${extra_args}) {
113+
define void @test_wmma_store_${function_suffix}(i8 ${as}* %src, ${args}${extra_args}) {
96114
; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}
97115
; CHECK: {${check_args}}
98116
; CHECK: ${stride_pattern}
99-
call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8* %src, ${args} ${extra_args});
117+
call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src, ${args} ${extra_args});
100118
ret void
101119
}
102120
103121
; CHECK-LABEL: .func{{.*}}test_wmma_store_${function_suffix}_o(
104-
define void @test_wmma_store_${function_suffix}_o(i8* %src, ${args}${extra_args}) {
122+
define void @test_wmma_store_${function_suffix}_o(i8 ${as}* %src, ${args}${extra_args}) {
105123
; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}+128]
106124
; CHECK: ${check_args}
107125
; CHECK: ${stride_pattern}
108-
%src1 = getelementptr i8, i8* %src, i32 128;
109-
call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8* %src1, ${args}${extra_args});
126+
%src1 = getelementptr i8, i8 ${as}* %src, i32 128;
127+
call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src1, ${args}${extra_args});
110128
ret void
111129
}
112130
"""
113-
suffix_template = "${abc}.sync.${layout}.m16n16k16${space}${stride}.${itype}"
131+
suffix_template = "${abc}.sync.${layout}.m16n16k16${stride}.${itype}.${pspace}"
114132
instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"
115133

116134
for abc, layout, space, stride, itype in product(
@@ -125,7 +143,9 @@ def gen_wmma_store_tests():
125143
"layout" : layout,
126144
"space" : space,
127145
"stride" : stride,
128-
"itype" : itype
146+
"itype" : itype,
147+
"pspace" : get_pspace(space),
148+
"as" : "addrspace(%d)" % get_aspace(space)
129149
}
130150

131151
test_params = params

0 commit comments

Comments
 (0)
Please sign in to comment.