Skip to content

Commit 3bafc2f

Browse files
committedOct 12, 2017
[NVPTX] Implemented wmma intrinsics and instructions.
WMMA = "Warp Level Matrix Multiply-Accumulate". These are the new instructions introduced in PTX6.0 and available on sm_70 GPUs. Differential Revision: https://reviews.llvm.org/D38645 llvm-svn: 315601
1 parent 1a7e387 commit 3bafc2f

File tree

6 files changed

+1192
-0
lines changed

6 files changed

+1192
-0
lines changed
 

‎llvm/include/llvm/IR/IntrinsicsNVVM.td

+146
Original file line numberDiff line numberDiff line change
@@ -3869,4 +3869,150 @@ def int_nvvm_match_all_sync_i64p :
38693869
Intrinsic<[llvm_i64_ty, llvm_i1_ty], [llvm_i32_ty, llvm_i64_ty],
38703870
[IntrNoMem, IntrConvergent], "llvm.nvvm.match.all.sync.i64p">;
38713871

3872+
//
3873+
// WMMA instructions
3874+
//
3875+
3876+
// WMMA.LOAD
3877+
// Intrinsic record for one WMMA load variant.
//   Abc        - fragment selector string spliced into the intrinsic name.
//   Layout     - layout string spliced into the intrinsic name.
//   Space      - address-space suffix (may be empty) spliced into the name.
//   Type       - element type string; also selects the result count.
//   regty      - LLVM type of every returned value.
//   WithStride - when non-zero, an extra i32 operand follows the pointer and
//                ".stride" is added to the name.
// Results: four regty values when Abc#Type equals "cf16", eight otherwise.
class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Space,
3878+
string Type, LLVMType regty, int WithStride>
3879+
: Intrinsic<!if(!eq(Abc#Type,"cf16"),
3880+
[regty, regty, regty, regty],
3881+
[regty, regty, regty, regty,
3882+
regty, regty, regty, regty]),
3883+
!if(WithStride, [llvm_ptr_ty, llvm_i32_ty], [llvm_ptr_ty]),
3884+
[], // Properties must be set during instantiation.
3885+
"llvm.nvvm.wmma.load."#Abc#".sync."#Layout#".m16n16k16"
3886+
#Space
3887+
#!if(WithStride,".stride","")
3888+
#"."#Type>;
3889+
3890+
// Emits two load intrinsics for a fixed Abc/Layout/Space/Type combination:
// a "_stride" record (WithStride = 1) and the plain NAME record
// (WithStride = 0).
multiclass NVVM_WMMA_LD_ALST<string Abc, string Layout, string Space,
3891+
string Type, LLVMType regty> {
3892+
def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 1>;
3893+
def NAME : NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 0>;
3894+
}
3895+
3896+
// Expands a load over three Space suffixes: ".global", ".shared" and the
// empty string (presumably the generic address space -- confirm against the
// PTX ISA). Each expansion in turn yields the strided and non-strided forms.
multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout,
3897+
string Type, LLVMType regty> {
3898+
defm _global: NVVM_WMMA_LD_ALST<Abc, Layout, ".global", Type, regty>;
3899+
defm _shared: NVVM_WMMA_LD_ALST<Abc, Layout, ".shared", Type, regty>;
3900+
defm NAME: NVVM_WMMA_LD_ALST<Abc, Layout, "", Type, regty>;
3901+
}
3902+
3903+
// Expands a load over both layout strings, "row" and "col", for the given
// fragment selector (Abc), element type string and register type.
multiclass NVVM_WMMA_LD_AT<string Abc, string Type, LLVMType regty> {
3904+
defm _row: NVVM_WMMA_LD_ALT<Abc, "row", Type, regty>;
3905+
defm _col: NVVM_WMMA_LD_ALT<Abc, "col", Type, regty>;
3906+
}
3907+
3908+
// For some reason ReadOnly<N> and NoCapture<N> confuse tblgen if they are
3909+
// passed to Intrinsic<> from inside of a multiclass. Setting them globally
3910+
// outside of the multiclass works.
3911+
// Instantiate every WMMA load intrinsic. The 'let' overrides the empty
// property list declared in NVVM_WMMA_LD_ALSTS for all records defined
// inside the braces; ReadOnly<0>/NoCapture<0> refer to operand 0, which is
// the pointer operand in that class.
// The a, b and c f16 loads use llvm_v2f16_ty; the c f32 load uses
// llvm_float_ty.
let IntrProperties = [IntrReadMem, IntrArgMemOnly,
3912+
ReadOnly<0>, NoCapture<0>] in {
3913+
defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
3914+
defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
3915+
defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
3916+
defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
3917+
}
3918+
3919+
// WMMA.STORE.D
3920+
// Intrinsic record for one WMMA store variant. It has no results.
// Operands, in order:
//   1. a pointer;
//   2. four regty values when Type equals "f16", eight otherwise;
//   3. an i32, only when WithStride is non-zero.
// The name is built from Layout, Space, an optional ".stride" suffix and
// Type.
class NVVM_WMMA_STD_LSTS<string Layout, string Space,
3921+
string Type, LLVMType regty, int WithStride,
3922+
// This is only used to create a typed empty array we
3923+
// need to pass to !if below.
3924+
list<LLVMType>Empty=[]>
3925+
: Intrinsic<[],
3926+
!listconcat(
3927+
[llvm_ptr_ty],
3928+
!if(!eq(Type,"f16"),
3929+
[regty, regty, regty, regty],
3930+
[regty, regty, regty, regty,
3931+
regty, regty, regty, regty]),
3932+
!if(WithStride, [llvm_i32_ty], Empty)),
3933+
[], // Properties must be set during instantiation.
3934+
"llvm.nvvm.wmma.store.d.sync."#Layout
3935+
#".m16n16k16"#Space
3936+
#!if(WithStride,".stride","")
3937+
#"."#Type>;
3938+
3939+
// Emits two store intrinsics for a fixed Layout/Space/Type combination:
// a "_stride" record (WithStride = 1) and the plain NAME record
// (WithStride = 0).
multiclass NVVM_WMMA_STD_LST<string Layout, string Space,
3940+
string Type, LLVMType regty> {
3941+
def _stride: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 1>;
3942+
def NAME: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 0>;
3943+
}
3944+
3945+
// Expands a store over three Space suffixes: ".global", ".shared" and the
// empty string (presumably the generic address space -- confirm against the
// PTX ISA).
multiclass NVVM_WMMA_STD_LT<string Layout, string Type, LLVMType regty> {
3946+
defm _global: NVVM_WMMA_STD_LST<Layout, ".global", Type, regty>;
3947+
defm _shared: NVVM_WMMA_STD_LST<Layout, ".shared", Type, regty>;
3948+
defm NAME: NVVM_WMMA_STD_LST<Layout, "", Type, regty>;
3949+
}
3950+
3951+
// Expands a store over both layout strings, "row" and "col", for the given
// element type string and register type.
multiclass NVVM_WMMA_STD_T<string Type, LLVMType regty> {
3952+
defm _row: NVVM_WMMA_STD_LT<"row", Type, regty>;
3953+
defm _col: NVVM_WMMA_STD_LT<"col", Type, regty>;
3954+
}
3955+
3956+
// Instantiate every WMMA store intrinsic. The 'let' overrides the empty
// property list declared in NVVM_WMMA_STD_LSTS for all records defined
// inside the braces; WriteOnly<0>/NoCapture<0> refer to operand 0, which is
// the pointer operand in that class.
// The f16 store uses llvm_v2f16_ty; the f32 store uses llvm_float_ty.
let IntrProperties = [IntrWriteMem, IntrArgMemOnly,
3957+
WriteOnly<0>, NoCapture<0>] in {
3958+
defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
3959+
defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
3960+
}
3961+
3962+
// WMMA.MMA
3963+
// Intrinsic record for one WMMA multiply-accumulate variant.
// Results: four d_regty values when DType equals "f16", eight otherwise.
// Operands, in order:
//   - eight llvm_v2f16_ty values for A;
//   - eight llvm_v2f16_ty values for B;
//   - four c_regty values when CType equals "f16", eight otherwise.
// Satfinite is appended verbatim to the intrinsic name and defaults to the
// empty string. Unlike the load/store classes, the property list is fixed
// here as [IntrNoMem].
class NVVM_WMMA_MMA_ABDCS<string ALayout, string BLayout,
3964+
string DType, LLVMType d_regty,
3965+
string CType, LLVMType c_regty,
3966+
string Satfinite = "">
3967+
: Intrinsic<!if(!eq(DType,"f16"),
3968+
[d_regty, d_regty, d_regty, d_regty],
3969+
[d_regty, d_regty, d_regty, d_regty,
3970+
d_regty, d_regty, d_regty, d_regty]),
3971+
!listconcat(
3972+
[// A
3973+
llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
3974+
llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
3975+
// B
3976+
llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
3977+
llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty],
3978+
!if(!eq(CType,"f16"),
3979+
[c_regty, c_regty, c_regty, c_regty],
3980+
[c_regty, c_regty, c_regty, c_regty,
3981+
c_regty, c_regty, c_regty, c_regty])),
3982+
[IntrNoMem],
3983+
"llvm.nvvm.wmma.mma.sync."#ALayout#"."#BLayout
3984+
#".m16n16k16."#DType#"."#CType#Satfinite>;
3985+
3986+
// Emits two mma intrinsics for a fixed layout/type combination: the plain
// NAME record (Satfinite left at its default, the empty string) and a
// "_satfinite" record whose name carries the ".satfinite" suffix.
multiclass NVVM_WMMA_MMA_ABDC<string ALayout, string BLayout,
3987+
string DType, LLVMType d_regty,
3988+
string CType, LLVMType c_regty> {
3989+
def NAME : NVVM_WMMA_MMA_ABDCS<ALayout, BLayout,
3990+
DType, d_regty,
3991+
CType, c_regty>;
3992+
def _satfinite: NVVM_WMMA_MMA_ABDCS<ALayout, BLayout,
3993+
DType, d_regty,
3994+
CType, c_regty,".satfinite">;
3995+
}
3996+
3997+
// Expands an mma over both C types: "f16" (register type llvm_v2f16_ty) and
// "f32" (register type llvm_float_ty), keeping the layouts and D type fixed.
multiclass NVVM_WMMA_MMA_ABD<string ALayout, string BLayout,
3998+
string DType, LLVMType d_regty> {
3999+
defm _f16: NVVM_WMMA_MMA_ABDC<ALayout, BLayout, DType, d_regty,
4000+
"f16", llvm_v2f16_ty>;
4001+
defm _f32: NVVM_WMMA_MMA_ABDC<ALayout, BLayout, DType, d_regty,
4002+
"f32", llvm_float_ty>;
4003+
}
4004+
4005+
// Expands an mma over both D types: "f16" (register type llvm_v2f16_ty) and
// "f32" (register type llvm_float_ty), keeping the A/B layouts fixed.
multiclass NVVM_WMMA_MMA_AB<string ALayout, string BLayout> {
4006+
defm _f16: NVVM_WMMA_MMA_ABD<ALayout, BLayout, "f16", llvm_v2f16_ty>;
4007+
defm _f32: NVVM_WMMA_MMA_ABD<ALayout, BLayout, "f32", llvm_float_ty>;
4008+
}
4009+
4010+
// Expands an mma over both B layout strings, "col" and "row", for the given
// A layout.
multiclass NVVM_WMMA_MMA_A<string ALayout> {
4011+
defm _col: NVVM_WMMA_MMA_AB<ALayout, "col">;
4012+
defm _row: NVVM_WMMA_MMA_AB<ALayout, "row">;
4013+
}
4014+
4015+
// Instantiate every WMMA mma intrinsic, once per A layout string ("col" and
// "row"). The nested multiclasses append the B layout, D type, C type and
// optional "_satfinite" suffixes to these record-name prefixes.
defm int_nvvm_wmma_mma_sync_col: NVVM_WMMA_MMA_A<"col">;
4016+
defm int_nvvm_wmma_mma_sync_row: NVVM_WMMA_MMA_A<"row">;
4017+
38724018
} // let TargetPrefix = "nvvm"

0 commit comments

Comments
 (0)
Please sign in to comment.