Skip to content

Commit 914d4ba

Browse files
committedMar 20, 2018
[NVPTX] Make tensor load/store intrinsics overloaded.
This way we can support address-space specific variants without explicitly encoding the space in the name of the intrinsic. Fewer intrinsics to deal with -> less boilerplate. Added a bit of tablegen magic to match/replace an intrinsic with a pointer argument in a particular address space with the space-specific instruction variant. Updated tests to use non-default address spaces. Differential Revision: https://reviews.llvm.org/D43268 llvm-svn: 328006
1 parent 3a99893 commit 914d4ba

File tree

5 files changed

+174
-157
lines changed

5 files changed

+174
-157
lines changed
 

‎clang/lib/CodeGen/CGBuiltin.cpp

+3-5
Original file line numberDiff line numberDiff line change
@@ -10527,8 +10527,7 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
1052710527
llvm_unreachable("Unexpected builtin ID.");
1052810528
}
1052910529
Value *Result =
10530-
Builder.CreateCall(CGM.getIntrinsic(IID),
10531-
{Builder.CreatePointerCast(Src, VoidPtrTy), Ldm});
10530+
Builder.CreateCall(CGM.getIntrinsic(IID, Src->getType()), {Src, Ldm});
1053210531

1053310532
// Save returned values.
1053410533
for (unsigned i = 0; i < NumResults; ++i) {
@@ -10567,10 +10566,9 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
1056710566
default:
1056810567
llvm_unreachable("Unexpected builtin ID.");
1056910568
}
10570-
Function *Intrinsic = CGM.getIntrinsic(IID);
10569+
Function *Intrinsic = CGM.getIntrinsic(IID, Dst->getType());
1057110570
llvm::Type *ParamType = Intrinsic->getFunctionType()->getParamType(1);
10572-
SmallVector<Value *, 10> Values;
10573-
Values.push_back(Builder.CreatePointerCast(Dst, VoidPtrTy));
10571+
SmallVector<Value *, 10> Values = {Dst};
1057410572
for (unsigned i = 0; i < NumResults; ++i) {
1057510573
Value *V = Builder.CreateAlignedLoad(
1057610574
Builder.CreateGEP(Src.getPointer(), llvm::ConstantInt::get(IntTy, i)),

‎llvm/include/llvm/IR/IntrinsicsNVVM.td

+20-45
Original file line numberDiff line numberDiff line change
@@ -3884,90 +3884,65 @@ def int_nvvm_match_all_sync_i64p :
38843884
//
38853885

38863886
// WMMA.LOAD
3887-
class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Space,
3888-
string Type, LLVMType regty, int WithStride>
3887+
class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Type,
3888+
LLVMType regty, int WithStride>
38893889
: Intrinsic<!if(!eq(Abc#Type,"cf16"),
38903890
[regty, regty, regty, regty],
38913891
[regty, regty, regty, regty,
38923892
regty, regty, regty, regty]),
3893-
!if(WithStride, [llvm_ptr_ty, llvm_i32_ty], [llvm_ptr_ty]),
3894-
[], // Properties must be set during instantiation.
3893+
!if(WithStride, [llvm_anyptr_ty, llvm_i32_ty], [llvm_anyptr_ty]),
3894+
[IntrReadMem, IntrArgMemOnly, ReadOnly<0>, NoCapture<0>],
38953895
"llvm.nvvm.wmma.load."#Abc#".sync."#Layout#".m16n16k16"
3896-
#Space
38973896
#!if(WithStride,".stride","")
38983897
#"."#Type>;
38993898

3900-
multiclass NVVM_WMMA_LD_ALST<string Abc, string Layout, string Space,
3901-
string Type, LLVMType regty> {
3902-
def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 1>;
3903-
def NAME : NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 0>;
3904-
}
3905-
3906-
multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout,
3907-
string Type, LLVMType regty> {
3908-
defm _global: NVVM_WMMA_LD_ALST<Abc, Layout, ".global", Type, regty>;
3909-
defm _shared: NVVM_WMMA_LD_ALST<Abc, Layout, ".shared", Type, regty>;
3910-
defm NAME: NVVM_WMMA_LD_ALST<Abc, Layout, "", Type, regty>;
3899+
multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout, string Type,
3900+
LLVMType regty> {
3901+
def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Type, regty, 1>;
3902+
def NAME : NVVM_WMMA_LD_ALSTS<Abc, Layout, Type, regty, 0>;
39113903
}
39123904

39133905
multiclass NVVM_WMMA_LD_AT<string Abc, string Type, LLVMType regty> {
39143906
defm _row: NVVM_WMMA_LD_ALT<Abc, "row", Type, regty>;
39153907
defm _col: NVVM_WMMA_LD_ALT<Abc, "col", Type, regty>;
39163908
}
39173909

3918-
// For some reason ReadOnly<N> and NoCapture<N> confuses tblgen if they are
3919-
// passed to Intrinsic<> form inside of a multiclass. Setting them globally
3920-
// outside of the multiclass works.
3921-
let IntrProperties = [IntrReadMem, IntrArgMemOnly,
3922-
ReadOnly<0>, NoCapture<0>] in {
3923-
defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
3924-
defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
3925-
defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
3926-
defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
3927-
}
3910+
defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
3911+
defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
3912+
defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
3913+
defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
39283914

39293915
// WMMA.STORE.D
3930-
class NVVM_WMMA_STD_LSTS<string Layout, string Space,
3931-
string Type, LLVMType regty, int WithStride,
3916+
class NVVM_WMMA_STD_LSTS<string Layout, string Type, LLVMType regty, int WithStride,
39323917
// This is only used to create a typed empty array we
39333918
// need to pass to !if below.
39343919
list<LLVMType>Empty=[]>
39353920
: Intrinsic<[],
39363921
!listconcat(
3937-
[llvm_ptr_ty],
3922+
[llvm_anyptr_ty],
39383923
!if(!eq(Type,"f16"),
39393924
[regty, regty, regty, regty],
39403925
[regty, regty, regty, regty,
39413926
regty, regty, regty, regty]),
39423927
!if(WithStride, [llvm_i32_ty], Empty)),
3943-
[], // Properties must be set during instantiation.
3928+
[IntrWriteMem, IntrArgMemOnly, WriteOnly<0>, NoCapture<0>],
39443929
"llvm.nvvm.wmma.store.d.sync."#Layout
3945-
#".m16n16k16"#Space
3930+
#".m16n16k16"
39463931
#!if(WithStride,".stride","")
39473932
#"."#Type>;
39483933

3949-
multiclass NVVM_WMMA_STD_LST<string Layout, string Space,
3950-
string Type, LLVMType regty> {
3951-
def _stride: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 1>;
3952-
def NAME: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 0>;
3953-
}
3954-
39553934
multiclass NVVM_WMMA_STD_LT<string Layout, string Type, LLVMType regty> {
3956-
defm _global: NVVM_WMMA_STD_LST<Layout, ".global", Type, regty>;
3957-
defm _shared: NVVM_WMMA_STD_LST<Layout, ".shared", Type, regty>;
3958-
defm NAME: NVVM_WMMA_STD_LST<Layout, "", Type, regty>;
3935+
def _stride: NVVM_WMMA_STD_LSTS<Layout, Type, regty, 1>;
3936+
def NAME: NVVM_WMMA_STD_LSTS<Layout, Type, regty, 0>;
39593937
}
39603938

39613939
multiclass NVVM_WMMA_STD_T<string Type, LLVMType regty> {
39623940
defm _row: NVVM_WMMA_STD_LT<"row", Type, regty>;
39633941
defm _col: NVVM_WMMA_STD_LT<"col", Type, regty>;
39643942
}
39653943

3966-
let IntrProperties = [IntrWriteMem, IntrArgMemOnly,
3967-
WriteOnly<0>, NoCapture<0>] in {
3968-
defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
3969-
defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
3970-
}
3944+
defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
3945+
defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
39713946

39723947
// WMMA.MMA
39733948
class NVVM_WMMA_MMA_ABDCS<string ALayout, string BLayout,

‎llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

+5-53
Original file line numberDiff line numberDiff line change
@@ -3327,26 +3327,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
33273327
case Intrinsic::nvvm_wmma_load_a_f16_row:
33283328
case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
33293329
case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
3330-
case Intrinsic::nvvm_wmma_load_a_f16_col_shared:
3331-
case Intrinsic::nvvm_wmma_load_a_f16_row_shared:
3332-
case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride:
3333-
case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride:
3334-
case Intrinsic::nvvm_wmma_load_a_f16_col_global:
3335-
case Intrinsic::nvvm_wmma_load_a_f16_row_global:
3336-
case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride:
3337-
case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride:
33383330
case Intrinsic::nvvm_wmma_load_b_f16_col:
33393331
case Intrinsic::nvvm_wmma_load_b_f16_row:
33403332
case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
3341-
case Intrinsic::nvvm_wmma_load_b_f16_row_stride:
3342-
case Intrinsic::nvvm_wmma_load_b_f16_col_shared:
3343-
case Intrinsic::nvvm_wmma_load_b_f16_row_shared:
3344-
case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride:
3345-
case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride:
3346-
case Intrinsic::nvvm_wmma_load_b_f16_col_global:
3347-
case Intrinsic::nvvm_wmma_load_b_f16_row_global:
3348-
case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride:
3349-
case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride: {
3333+
case Intrinsic::nvvm_wmma_load_b_f16_row_stride: {
33503334
Info.opc = ISD::INTRINSIC_W_CHAIN;
33513335
Info.memVT = MVT::v8f16;
33523336
Info.ptrVal = I.getArgOperand(0);
@@ -3359,15 +3343,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
33593343
case Intrinsic::nvvm_wmma_load_c_f16_col:
33603344
case Intrinsic::nvvm_wmma_load_c_f16_row:
33613345
case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
3362-
case Intrinsic::nvvm_wmma_load_c_f16_row_stride:
3363-
case Intrinsic::nvvm_wmma_load_c_f16_col_shared:
3364-
case Intrinsic::nvvm_wmma_load_c_f16_row_shared:
3365-
case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride:
3366-
case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride:
3367-
case Intrinsic::nvvm_wmma_load_c_f16_col_global:
3368-
case Intrinsic::nvvm_wmma_load_c_f16_row_global:
3369-
case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride:
3370-
case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride: {
3346+
case Intrinsic::nvvm_wmma_load_c_f16_row_stride: {
33713347
Info.opc = ISD::INTRINSIC_W_CHAIN;
33723348
Info.memVT = MVT::v4f16;
33733349
Info.ptrVal = I.getArgOperand(0);
@@ -3380,15 +3356,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
33803356
case Intrinsic::nvvm_wmma_load_c_f32_col:
33813357
case Intrinsic::nvvm_wmma_load_c_f32_row:
33823358
case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
3383-
case Intrinsic::nvvm_wmma_load_c_f32_row_stride:
3384-
case Intrinsic::nvvm_wmma_load_c_f32_col_shared:
3385-
case Intrinsic::nvvm_wmma_load_c_f32_row_shared:
3386-
case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride:
3387-
case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride:
3388-
case Intrinsic::nvvm_wmma_load_c_f32_col_global:
3389-
case Intrinsic::nvvm_wmma_load_c_f32_row_global:
3390-
case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride:
3391-
case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride: {
3359+
case Intrinsic::nvvm_wmma_load_c_f32_row_stride: {
33923360
Info.opc = ISD::INTRINSIC_W_CHAIN;
33933361
Info.memVT = MVT::v8f32;
33943362
Info.ptrVal = I.getArgOperand(0);
@@ -3401,15 +3369,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
34013369
case Intrinsic::nvvm_wmma_store_d_f16_col:
34023370
case Intrinsic::nvvm_wmma_store_d_f16_row:
34033371
case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
3404-
case Intrinsic::nvvm_wmma_store_d_f16_row_stride:
3405-
case Intrinsic::nvvm_wmma_store_d_f16_col_shared:
3406-
case Intrinsic::nvvm_wmma_store_d_f16_row_shared:
3407-
case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride:
3408-
case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride:
3409-
case Intrinsic::nvvm_wmma_store_d_f16_col_global:
3410-
case Intrinsic::nvvm_wmma_store_d_f16_row_global:
3411-
case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride:
3412-
case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride: {
3372+
case Intrinsic::nvvm_wmma_store_d_f16_row_stride: {
34133373
Info.opc = ISD::INTRINSIC_VOID;
34143374
Info.memVT = MVT::v4f16;
34153375
Info.ptrVal = I.getArgOperand(0);
@@ -3422,15 +3382,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
34223382
case Intrinsic::nvvm_wmma_store_d_f32_col:
34233383
case Intrinsic::nvvm_wmma_store_d_f32_row:
34243384
case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
3425-
case Intrinsic::nvvm_wmma_store_d_f32_row_stride:
3426-
case Intrinsic::nvvm_wmma_store_d_f32_col_shared:
3427-
case Intrinsic::nvvm_wmma_store_d_f32_row_shared:
3428-
case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride:
3429-
case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride:
3430-
case Intrinsic::nvvm_wmma_store_d_f32_col_global:
3431-
case Intrinsic::nvvm_wmma_store_d_f32_row_global:
3432-
case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride:
3433-
case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride: {
3385+
case Intrinsic::nvvm_wmma_store_d_f32_row_stride: {
34343386
Info.opc = ISD::INTRINSIC_VOID;
34353387
Info.memVT = MVT::v8f32;
34363388
Info.ptrVal = I.getArgOperand(0);

0 commit comments

Comments
 (0)
Please sign in to comment.