@@ -3869,4 +3869,150 @@ def int_nvvm_match_all_sync_i64p :
3869
3869
Intrinsic<[llvm_i64_ty, llvm_i1_ty], [llvm_i32_ty, llvm_i64_ty],
3870
3870
[IntrNoMem, IntrConvergent], "llvm.nvvm.match.all.sync.i64p">;
3871
3871
3872
+ //
3873
+ // WMMA instructions
3874
+ //
3875
+
3876
+ // WMMA.LOAD
3877
// Describes one llvm.nvvm.wmma.load.* intrinsic for the m16n16k16 shape.
//
//   Abc        - operand selector spliced into the intrinsic name; the
//                instantiations in this file pass "a", "b" or "c".
//   Layout     - layout string spliced into the name ("row" / "col").
//   Space      - address-space suffix including its leading dot
//                (".global", ".shared"), or "" for none.
//   Type       - element type string ("f16" / "f32"), used as the final name
//                component.
//   regty      - LLVM type of each returned value.
//   WithStride - non-zero adds a trailing llvm_i32_ty operand and a ".stride"
//                component to the name.
class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Space,
                         string Type, LLVMType regty, int WithStride>
  : Intrinsic<
      // Results: four values when Abc = "c" and Type = "f16" (the two strings
      // concatenate to "cf16"), eight values in every other case.
      !if(!eq(!strconcat(Abc, Type), "cf16"),
          [regty, regty, regty, regty],
          [regty, regty, regty, regty, regty, regty, regty, regty]),
      // Operands: the pointer, optionally followed by an i32.
      !if(WithStride, [llvm_ptr_ty, llvm_i32_ty], [llvm_ptr_ty]),
      // Properties must be set during instantiation.
      [],
      !strconcat("llvm.nvvm.wmma.load.", Abc, ".sync.", Layout, ".m16n16k16",
                 Space,
                 !if(WithStride, ".stride", ""),
                 ".", Type)>;
3889
+
3890
// Expands one (Abc, Layout, Space, Type) load combination into two
// intrinsics: a "_stride" record built with WithStride = 1 and an unsuffixed
// record built with WithStride = 0.
multiclass NVVM_WMMA_LD_ALST<string Abc, string Layout, string Space,
                             string Type, LLVMType regty> {
  // Variant with the extra i32 operand.
  def _stride : NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty,
                                   /*WithStride=*/1>;
  // Variant that takes only the pointer.
  def NAME    : NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty,
                                   /*WithStride=*/0>;
}
3895
+
3896
// Expands one (Abc, Layout, Type) load combination over the three address
// space spellings used in the intrinsic name: ".global", ".shared" and the
// empty string (no suffix).
multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout,
                            string Type, LLVMType regty> {
  defm _global : NVVM_WMMA_LD_ALST<Abc, Layout, ".global", Type, regty>;
  defm _shared : NVVM_WMMA_LD_ALST<Abc, Layout, ".shared", Type, regty>;
  // No address-space component in either the record or the intrinsic name.
  defm NAME    : NVVM_WMMA_LD_ALST<Abc, Layout, "", Type, regty>;
}
3902
+
3903
// Top of the load hierarchy: expands one (Abc, Type) pair over both layout
// strings, "row" and "col".
multiclass NVVM_WMMA_LD_AT<string Abc, string Type, LLVMType regty> {
  defm _row : NVVM_WMMA_LD_ALT<Abc, "row", Type, regty>;
  defm _col : NVVM_WMMA_LD_ALT<Abc, "col", Type, regty>;
}
3907
+
3908
+ // For some reason ReadOnly<N> and NoCapture<N> confuse tblgen if they are
3909
+ // passed to Intrinsic<> from inside of a multiclass. Setting them globally
3910
+ // outside of the multiclass works.
3911
// Instantiate the WMMA load intrinsics. The underlying class leaves its
// property list empty, so the properties are supplied here through an
// enclosing `let` that applies to every record the defms below produce.
let IntrProperties = [
      IntrReadMem, IntrArgMemOnly, ReadOnly<0>, NoCapture<0>
    ] in {
  // Abc = "a" / "b": f16 only.
  defm int_nvvm_wmma_load_a_f16 : NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
  defm int_nvvm_wmma_load_b_f16 : NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
  // Abc = "c": f16 and f32.
  defm int_nvvm_wmma_load_c_f16 : NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
  defm int_nvvm_wmma_load_c_f32 : NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
}
3918
+
3919
+ // WMMA.STORE.D
3920
// Describes one llvm.nvvm.wmma.store.d.* intrinsic for the m16n16k16 shape.
//
//   Layout     - layout string spliced into the name ("row" / "col").
//   Space      - address-space suffix including its leading dot, or "".
//   Type       - element type string ("f16" / "f32"), used as the final name
//                component; also selects the number of value operands.
//   regty      - LLVM type of each value operand.
//   WithStride - non-zero adds a trailing llvm_i32_ty operand and a ".stride"
//                component to the name.
//   Empty      - not meant to be overridden; it exists only to give the
//                "no extra operand" branch of the !if below a typed empty
//                list.
class NVVM_WMMA_STD_LSTS<string Layout, string Space,
                         string Type, LLVMType regty, int WithStride,
                         list<LLVMType> Empty = []>
  : Intrinsic<
      // No results.
      [],
      // Operands: pointer, then four (f16) or eight (otherwise) values, then
      // the optional i32.
      !listconcat(
        [llvm_ptr_ty],
        !if(!eq(Type, "f16"),
            [regty, regty, regty, regty],
            [regty, regty, regty, regty, regty, regty, regty, regty]),
        !if(WithStride, [llvm_i32_ty], Empty)),
      // Properties must be set during instantiation.
      [],
      !strconcat("llvm.nvvm.wmma.store.d.sync.", Layout, ".m16n16k16", Space,
                 !if(WithStride, ".stride", ""),
                 ".", Type)>;
3938
+
3939
// Expands one (Layout, Space, Type) store combination into two intrinsics:
// a "_stride" record built with WithStride = 1 and an unsuffixed record built
// with WithStride = 0.
multiclass NVVM_WMMA_STD_LST<string Layout, string Space,
                             string Type, LLVMType regty> {
  // Variant with the extra i32 operand.
  def _stride : NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty,
                                   /*WithStride=*/1>;
  // Variant without it.
  def NAME    : NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty,
                                   /*WithStride=*/0>;
}
3944
+
3945
// Expands one (Layout, Type) store combination over the three address-space
// spellings used in the intrinsic name: ".global", ".shared" and the empty
// string (no suffix).
multiclass NVVM_WMMA_STD_LT<string Layout, string Type, LLVMType regty> {
  defm _global : NVVM_WMMA_STD_LST<Layout, ".global", Type, regty>;
  defm _shared : NVVM_WMMA_STD_LST<Layout, ".shared", Type, regty>;
  // No address-space component in either the record or the intrinsic name.
  defm NAME    : NVVM_WMMA_STD_LST<Layout, "", Type, regty>;
}
3950
+
3951
// Top of the store hierarchy: expands one element type over both layout
// strings, "row" and "col".
multiclass NVVM_WMMA_STD_T<string Type, LLVMType regty> {
  defm _row : NVVM_WMMA_STD_LT<"row", Type, regty>;
  defm _col : NVVM_WMMA_STD_LT<"col", Type, regty>;
}
3955
+
3956
// Instantiate the WMMA store intrinsics. The underlying class leaves its
// property list empty, so the properties are supplied here through an
// enclosing `let` that applies to every record the defms below produce.
let IntrProperties = [
      IntrWriteMem, IntrArgMemOnly, WriteOnly<0>, NoCapture<0>
    ] in {
  defm int_nvvm_wmma_store_d_f16 : NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
  defm int_nvvm_wmma_store_d_f32 : NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
}
3961
+
3962
+ // WMMA.MMA
3963
// Describes one llvm.nvvm.wmma.mma.sync.* intrinsic for the m16n16k16 shape.
//
//   ALayout, BLayout - layout strings for A and B, spliced into the name.
//   DType, d_regty   - element type string and per-value LLVM type of the
//                      results (D).
//   CType, c_regty   - element type string and per-value LLVM type of the
//                      C operands.
//   Satfinite        - suffix appended verbatim to the intrinsic name; ""
//                      by default, ".satfinite" for the saturating variant.
class NVVM_WMMA_MMA_ABDCS<string ALayout, string BLayout,
                          string DType, LLVMType d_regty,
                          string CType, LLVMType c_regty,
                          string Satfinite = "">
  : Intrinsic<
      // Results (D): four values for "f16", eight otherwise.
      !if(!eq(DType, "f16"),
          [d_regty, d_regty, d_regty, d_regty],
          [d_regty, d_regty, d_regty, d_regty,
           d_regty, d_regty, d_regty, d_regty]),
      // Operands are laid out as A, then B, then C.
      !listconcat(
        // A: always eight v2f16 values.
        [llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
         llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty],
        // B: always eight v2f16 values.
        [llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
         llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty],
        // C: four values for "f16", eight otherwise.
        !if(!eq(CType, "f16"),
            [c_regty, c_regty, c_regty, c_regty],
            [c_regty, c_regty, c_regty, c_regty,
             c_regty, c_regty, c_regty, c_regty])),
      [IntrNoMem],
      !strconcat("llvm.nvvm.wmma.mma.sync.", ALayout, ".", BLayout,
                 ".m16n16k16.", DType, ".", CType, Satfinite)>;
3985
+
3986
// Expands one fully specified (ALayout, BLayout, DType, CType) combination
// into two intrinsics: the plain form and a "_satfinite" form whose intrinsic
// name carries the ".satfinite" suffix.
multiclass NVVM_WMMA_MMA_ABDC<string ALayout, string BLayout,
                              string DType, LLVMType d_regty,
                              string CType, LLVMType c_regty> {
  // Plain variant: Satfinite keeps its default of "".
  def NAME : NVVM_WMMA_MMA_ABDCS<ALayout, BLayout,
                                 DType, d_regty,
                                 CType, c_regty>;

  // Saturating variant.
  def _satfinite : NVVM_WMMA_MMA_ABDCS<ALayout, BLayout,
                                       DType, d_regty,
                                       CType, c_regty, ".satfinite">;
}
3996
+
3997
// Expands one (ALayout, BLayout, DType) combination over both C element
// types: "f16" (values of llvm_v2f16_ty) and "f32" (values of llvm_float_ty).
multiclass NVVM_WMMA_MMA_ABD<string ALayout, string BLayout,
                             string DType, LLVMType d_regty> {
  defm _f16 : NVVM_WMMA_MMA_ABDC<ALayout, BLayout,
                                 DType, d_regty,
                                 "f16", llvm_v2f16_ty>;
  defm _f32 : NVVM_WMMA_MMA_ABDC<ALayout, BLayout,
                                 DType, d_regty,
                                 "f32", llvm_float_ty>;
}
4004
+
4005
// Expands one (ALayout, BLayout) pair over both D element types: "f16"
// (values of llvm_v2f16_ty) and "f32" (values of llvm_float_ty).
multiclass NVVM_WMMA_MMA_AB<string ALayout, string BLayout> {
  defm _f16 : NVVM_WMMA_MMA_ABD<ALayout, BLayout, "f16", llvm_v2f16_ty>;
  defm _f32 : NVVM_WMMA_MMA_ABD<ALayout, BLayout, "f32", llvm_float_ty>;
}
4009
+
4010
// Top of the MMA hierarchy: expands one A layout over both B layouts,
// "col" and "row".
multiclass NVVM_WMMA_MMA_A<string ALayout> {
  defm _col : NVVM_WMMA_MMA_AB<ALayout, "col">;
  defm _row : NVVM_WMMA_MMA_AB<ALayout, "row">;
}
4014
+
4015
// Instantiate the WMMA MMA intrinsics, one defm per A layout. Each expands
// over B layout, D type, C type and the optional satfinite suffix.
defm int_nvvm_wmma_mma_sync_col : NVVM_WMMA_MMA_A<"col">;
defm int_nvvm_wmma_mma_sync_row : NVVM_WMMA_MMA_A<"row">;
4017
+
3872
4018
} // let TargetPrefix = "nvvm"
0 commit comments