Index: clang/lib/CodeGen/CGBuiltin.cpp
===================================================================
--- clang/lib/CodeGen/CGBuiltin.cpp
+++ clang/lib/CodeGen/CGBuiltin.cpp
@@ -10504,23 +10504,23 @@
     unsigned NumResults;
     switch (BuiltinID) {
     case NVPTX::BI__hmma_m16n16k16_ld_a:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_load_a_f16_col_stride
-                       : Intrinsic::nvvm_wmma_load_a_f16_row_stride;
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride
+                       : Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride;
       NumResults = 8;
       break;
     case NVPTX::BI__hmma_m16n16k16_ld_b:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_load_b_f16_col_stride
-                       : Intrinsic::nvvm_wmma_load_b_f16_row_stride;
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride
+                       : Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride;
       NumResults = 8;
       break;
     case NVPTX::BI__hmma_m16n16k16_ld_c_f16:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_load_c_f16_col_stride
-                       : Intrinsic::nvvm_wmma_load_c_f16_row_stride;
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride
+                       : Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride;
       NumResults = 4;
       break;
     case NVPTX::BI__hmma_m16n16k16_ld_c_f32:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_load_c_f32_col_stride
-                       : Intrinsic::nvvm_wmma_load_c_f32_row_stride;
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride
+                       : Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride;
       NumResults = 8;
       break;
     default:
@@ -10555,13 +10555,13 @@
     // for some reason nvcc builtins use _c_.
     switch (BuiltinID) {
     case NVPTX::BI__hmma_m16n16k16_st_c_f16:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_store_d_f16_col_stride
-                       : Intrinsic::nvvm_wmma_store_d_f16_row_stride;
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride
+                       : Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride;
       NumResults = 4;
       break;
     case NVPTX::BI__hmma_m16n16k16_st_c_f32:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_store_d_f32_col_stride
-                       : Intrinsic::nvvm_wmma_store_d_f32_row_stride;
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride
+                       : Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride;
       break;
     default:
       llvm_unreachable("Unexpected builtin ID.");
@@ -10580,8 +10580,8 @@
     return Result;
   }
 
-  // BI__hmma_m16n16k16_mma_(d, a, b, c, layout, satf)
-  // --> Intrinsic::nvvm_wmma_mma_sync
+  // BI__hmma_m16n16k16_mma_(d, a, b, c, layout, satf) -->
+  // Intrinsic::nvvm_wmma_m16n16k16_mma_sync
   case NVPTX::BI__hmma_m16n16k16_mma_f16f16:
   case NVPTX::BI__hmma_m16n16k16_mma_f32f16:
   case NVPTX::BI__hmma_m16n16k16_mma_f32f32:
@@ -10602,15 +10602,15 @@
     bool Satf = SatfArg.getSExtValue();
 
     // clang-format off
-#define MMA_VARIANTS(type) {{ \
-      Intrinsic::nvvm_wmma_mma_sync_row_row_##type, \
-      Intrinsic::nvvm_wmma_mma_sync_row_row_##type##_satfinite, \
-      Intrinsic::nvvm_wmma_mma_sync_row_col_##type, \
-      Intrinsic::nvvm_wmma_mma_sync_row_col_##type##_satfinite, \
-      Intrinsic::nvvm_wmma_mma_sync_col_row_##type, \
-      Intrinsic::nvvm_wmma_mma_sync_col_row_##type##_satfinite, \
-      Intrinsic::nvvm_wmma_mma_sync_col_col_##type, \
-      Intrinsic::nvvm_wmma_mma_sync_col_col_##type##_satfinite \
+#define MMA_VARIANTS(type) {{ \
+      Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_##type, \
+      Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_##type##_satfinite, \
+      Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_##type, \
+      Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_##type##_satfinite, \
+      Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_##type, \
+      Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_##type##_satfinite, \
+      Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_##type, \
+      Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_##type##_satfinite \
     }}
     // clang-format on
Index: llvm/include/llvm/IR/IntrinsicsNVVM.td
===================================================================
--- llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -3884,39 +3884,53 @@
 //
 // WMMA.LOAD
-class NVVM_WMMA_LD_ALSTS
+class NVVM_WMMA_LD_GALSTS
   : Intrinsic, NoCapture<0>],
-              "llvm.nvvm.wmma.load."#Abc#".sync."#Layout#".m16n16k16"
-              #!if(WithStride,".stride","")
-              #"."#Type>;
+              "llvm.nvvm.wmma."
+              # Geometry
+              # ".load"
+              # "." # Abc
+              # "." # Layout
+              # !if(WithStride, ".stride", "")
+              # "." # Type>;
 
-multiclass NVVM_WMMA_LD_ALT {
-  def _stride: NVVM_WMMA_LD_ALSTS;
-  def NAME : NVVM_WMMA_LD_ALSTS;
+multiclass NVVM_WMMA_LD_GALT {
+  def _stride: NVVM_WMMA_LD_GALSTS;
+  def NAME : NVVM_WMMA_LD_GALSTS;
 }
 
-multiclass NVVM_WMMA_LD_AT {
-  defm _row: NVVM_WMMA_LD_ALT;
-  defm _col: NVVM_WMMA_LD_ALT;
+multiclass NVVM_WMMA_LD_GAT {
+  defm _row: NVVM_WMMA_LD_GALT;
+  defm _col: NVVM_WMMA_LD_GALT;
 }
 
-defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
-defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
-defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
-defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
+multiclass NVVM_WMMA_LD_G {
+  defm _a_f16: NVVM_WMMA_LD_GAT;
+  defm _b_f16: NVVM_WMMA_LD_GAT;
+  defm _c_f16: NVVM_WMMA_LD_GAT;
+  defm _c_f32: NVVM_WMMA_LD_GAT;
+}
+
+multiclass NVVM_WMMA_LD {
+  defm _m16n16k16_load: NVVM_WMMA_LD_G<"m16n16k16">;
+}
+
+defm int_nvvm_wmma: NVVM_WMMA_LD;
 
 // WMMA.STORE.D
-class NVVM_WMMA_STD_LSTSEmpty=[]>
+class NVVM_WMMA_STD_GLSTSEmpty=[]>
   : Intrinsic<[],
               !listconcat(
                 [llvm_anyptr_ty],
@@ -3926,29 +3940,40 @@
                    regty, regty, regty, regty]),
                 !if(WithStride, [llvm_i32_ty], Empty)),
               [IntrWriteMem, IntrArgMemOnly, WriteOnly<0>, NoCapture<0>],
-              "llvm.nvvm.wmma.store.d.sync."#Layout
-              #".m16n16k16"
-              #!if(WithStride,".stride","")
-              #"."#Type>;
+              "llvm.nvvm.wmma."
+              # Geometry
+              # ".store.d"
+              # "." # Layout
+              # !if(WithStride, ".stride", "")
+              # "." # Type>;
 
-multiclass NVVM_WMMA_STD_LT {
-  def _stride: NVVM_WMMA_STD_LSTS;
-  def NAME: NVVM_WMMA_STD_LSTS;
+multiclass NVVM_WMMA_STD_GLT {
+  def _stride: NVVM_WMMA_STD_GLSTS;
+  def NAME: NVVM_WMMA_STD_GLSTS;
 }
 
-multiclass NVVM_WMMA_STD_T {
-  defm _row: NVVM_WMMA_STD_LT<"row", Type, regty>;
-  defm _col: NVVM_WMMA_STD_LT<"col", Type, regty>;
+multiclass NVVM_WMMA_STD_GT {
+  defm _row: NVVM_WMMA_STD_GLT;
+  defm _col: NVVM_WMMA_STD_GLT;
+}
+multiclass NVVM_WMMA_STD_G {
+  defm _d_f16: NVVM_WMMA_STD_GT;
+  defm _d_f32: NVVM_WMMA_STD_GT;
+}
+
+multiclass NVVM_WMMA_STD {
+  defm _m16n16k16_store: NVVM_WMMA_STD_G<"m16n16k16">;
 }
 
-defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
-defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
+defm int_nvvm_wmma: NVVM_WMMA_STD;
 
 // WMMA.MMA
-class NVVM_WMMA_MMA_ABDCS
+class NVVM_WMMA_MMA_GABDCS
   : Intrinsic;
+              "llvm.nvvm.wmma."
+              # Geometry
+              # ".mma"
+              # "." # ALayout
+              # "." # BLayout
+              # "." # DType
+              # "." # CType
+              # Satfinite> {
+}
 
-multiclass NVVM_WMMA_MMA_ABDC {
-  def NAME : NVVM_WMMA_MMA_ABDCS;
-  def _satfinite: NVVM_WMMA_MMA_ABDCS;
+multiclass NVVM_WMMA_MMA_GABDC {
+  def NAME : NVVM_WMMA_MMA_GABDCS;
+  def _satfinite: NVVM_WMMA_MMA_GABDCS;
 }
 
-multiclass NVVM_WMMA_MMA_ABD {
-  defm _f16: NVVM_WMMA_MMA_ABDC;
-  defm _f32: NVVM_WMMA_MMA_ABDC;
 }
 
-multiclass NVVM_WMMA_MMA_AB {
-  defm _f16: NVVM_WMMA_MMA_ABD;
-  defm _f32: NVVM_WMMA_MMA_ABD;
+multiclass NVVM_WMMA_MMA_GAB {
+  defm _f16: NVVM_WMMA_MMA_GABD;
+  defm _f32: NVVM_WMMA_MMA_GABD;
 }
 
-multiclass NVVM_WMMA_MMA_A {
-  defm _col: NVVM_WMMA_MMA_AB;
-  defm _row: NVVM_WMMA_MMA_AB;
+multiclass NVVM_WMMA_MMA_GA {
+  defm _col: NVVM_WMMA_MMA_GAB;
+  defm _row: NVVM_WMMA_MMA_GAB;
 }
 
-defm int_nvvm_wmma_mma_sync_col: NVVM_WMMA_MMA_A<"col">;
-defm int_nvvm_wmma_mma_sync_row: NVVM_WMMA_MMA_A<"row">;
+multiclass NVVM_WMMA_MMA_G {
+  defm _col: NVVM_WMMA_MMA_GA;
+  defm _row: NVVM_WMMA_MMA_GA;
+}
+
+multiclass NVVM_WMMA_MMA {
+  defm _m16n16k16_mma : NVVM_WMMA_MMA_G<"m16n16k16">;
+}
+
+defm int_nvvm_wmma : NVVM_WMMA_MMA;
 
 } // let TargetPrefix = "nvvm"
Index: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3323,14 +3323,14 @@
     // Our result depends on both our and other thread's arguments.
     Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
     return true;
-  case Intrinsic::nvvm_wmma_load_a_f16_col:
-  case Intrinsic::nvvm_wmma_load_a_f16_row:
-  case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
-  case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
-  case Intrinsic::nvvm_wmma_load_b_f16_col:
-  case Intrinsic::nvvm_wmma_load_b_f16_row:
-  case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
-  case Intrinsic::nvvm_wmma_load_b_f16_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v8f16;
     Info.ptrVal = I.getArgOperand(0);
@@ -3340,10 +3340,10 @@
     return true;
   }
-  case Intrinsic::nvvm_wmma_load_c_f16_col:
-  case Intrinsic::nvvm_wmma_load_c_f16_row:
-  case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
-  case Intrinsic::nvvm_wmma_load_c_f16_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::v4f16;
    Info.ptrVal = I.getArgOperand(0);
@@ -3353,10 +3353,10 @@
     return true;
   }
-  case Intrinsic::nvvm_wmma_load_c_f32_col:
-  case Intrinsic::nvvm_wmma_load_c_f32_row:
-  case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
-  case Intrinsic::nvvm_wmma_load_c_f32_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v8f32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3366,10 +3366,10 @@
     return true;
   }
-  case Intrinsic::nvvm_wmma_store_d_f16_col:
-  case Intrinsic::nvvm_wmma_store_d_f16_row:
-  case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
-  case Intrinsic::nvvm_wmma_store_d_f16_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride: {
     Info.opc = ISD::INTRINSIC_VOID;
    Info.memVT = MVT::v4f16;
     Info.ptrVal = I.getArgOperand(0);
@@ -3379,10 +3379,10 @@
     return true;
   }
-  case Intrinsic::nvvm_wmma_store_d_f32_col:
-  case Intrinsic::nvvm_wmma_store_d_f32_row:
-  case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
-  case Intrinsic::nvvm_wmma_store_d_f32_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col:
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride: {
     Info.opc = ISD::INTRINSIC_VOID;
     Info.memVT = MVT::v8f32;
     Info.ptrVal = I.getArgOperand(0);
Index: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -7375,16 +7375,15 @@
 class EmptyNVPTXInst : NVPTXInst<(outs), (ins), "?", []>;
 
-class WMMA_LOAD_ALSTOS
+class WMMA_LOAD_GALSTOS
   : EmptyNVPTXInst,
     Requires<[hasPTX60, hasSM70]> {
   // Pattern (created by WMMA_LOAD_INTR_HELPER below) that matches the intrinsic
   // for this function.
-  PatFrag IntrMatcher = !cast("INT_WMMA_LOAD_"
-                              # !subst("a", "A",
-                                !subst("b", "B",
-                                !subst("c", "C_" # Type, Abc)))
+  PatFrag IntrMatcher = !cast("INT_WMMA_"
+                              # Geometry # "_load_"
+                              # !subst("c", "c_" # Type, Abc)
                               # "_" # Layout
                               # !subst(".", "_", Space)
                               # !if(WithStride,"_stride", "")
@@ -7419,23 +7418,28 @@
   let Pattern = [!con(PatOuts, (set PatArgs))];
   let OutOperandList = Outs;
   let InOperandList = Ins;
-  let AsmString = "wmma.load."#Abc#".sync."#Layout#".m16n16k16"#Space#"." #Type# " \t"
-                  #!if(!eq(Abc#Type,"cf16"),
-                     "{{$r0, $r1, $r2, $r3}}",
-                     "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
-                  #", [$src]"
-                  #!if(WithStride, ", $ldm", "")
-                  #";";
+  let AsmString = "wmma.load."
+                  # Abc
+                  # ".sync."
+                  # Layout
+                  # ".m16n16k16"
+                  # Space
+                  # "." # Type # " \t"
+                  # !if(!eq(Abc#Type, "cf16"), "{{$r0, $r1, $r2, $r3}}", "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
+                  # ", [$src]"
+                  # !if(WithStride, ", $ldm", "")
+                  # ";";
 }
 
-class WMMA_LOAD_INTR_HELPER
+class WMMA_LOAD_INTR_HELPER
   : PatFrag <(ops),(ops)> {
   // Intrinsic that matches this instruction.
-  Intrinsic Intr = !cast("int_nvvm_wmma_load_"
-                         # Abc
-                         # "_" # Type
-                         # "_" # Layout
+  Intrinsic Intr = !cast("int_nvvm_wmma"
+                         # "_" # Geometry # "_load_"
+                         # Abc # "_" # Type # "_" # Layout
                          # !if(WithStride,"_stride", ""));
   code match_generic = [{
     return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
@@ -7453,62 +7457,81 @@
                      !if(!eq(Space, ".global"), match_global, match_generic));
 }
 
-multiclass WMMA_LOAD_ALSTS {
-  def _avar: WMMA_LOAD_ALSTOS;
-  def _areg: WMMA_LOAD_ALSTOS;
-  def _areg64: WMMA_LOAD_ALSTOS;
-  def _ari: WMMA_LOAD_ALSTOS;
-  def _ari64: WMMA_LOAD_ALSTOS;
+multiclass WMMA_LOAD_GALSTS {
+  def _avar: WMMA_LOAD_GALSTOS;
+  def _areg: WMMA_LOAD_GALSTOS;
+  def _areg64: WMMA_LOAD_GALSTOS;
+  def _ari: WMMA_LOAD_GALSTOS;
+  def _ari64: WMMA_LOAD_GALSTOS;
 }
 
-multiclass WMMA_LOAD_ALSTSh {
+multiclass WMMA_LOAD_GALSTSh {
   // Define a PatFrag that matches appropriate intrinsic that loads from the
   // given address space.
-  def _Intr : WMMA_LOAD_INTR_HELPER;
-  defm NAME: WMMA_LOAD_ALSTS;
+  def _Intr: WMMA_LOAD_INTR_HELPER;
+  defm NAME: WMMA_LOAD_GALSTS;
 }
 
-multiclass WMMA_LOAD_ALST {
-  defm _stride: WMMA_LOAD_ALSTSh;
-  defm NAME: WMMA_LOAD_ALSTSh;
+multiclass WMMA_LOAD_GALST {
+  defm _stride: WMMA_LOAD_GALSTSh;
+  defm NAME: WMMA_LOAD_GALSTSh;
 }
 
-multiclass WMMA_LOAD_ALT {
-  defm _global: WMMA_LOAD_ALST;
-  defm _shared: WMMA_LOAD_ALST;
-  defm NAME: WMMA_LOAD_ALST;
+multiclass WMMA_LOAD_GALT {
+  defm _global: WMMA_LOAD_GALST;
+  defm _shared: WMMA_LOAD_GALST;
+  defm NAME: WMMA_LOAD_GALST;
 }
 
-multiclass WMMA_LOAD_AT {
-  defm _row: WMMA_LOAD_ALT;
-  defm _col: WMMA_LOAD_ALT;
+multiclass WMMA_LOAD_GAT {
+  defm _row: WMMA_LOAD_GALT;
+  defm _col: WMMA_LOAD_GALT;
 }
 
-defm INT_WMMA_LOAD_A: WMMA_LOAD_AT<"a", "f16", Float16x2Regs>;
-defm INT_WMMA_LOAD_B: WMMA_LOAD_AT<"b", "f16", Float16x2Regs>;
-defm INT_WMMA_LOAD_C_f16: WMMA_LOAD_AT<"c", "f16", Float16x2Regs>;
-defm INT_WMMA_LOAD_C_f32: WMMA_LOAD_AT<"c", "f32", Float32Regs>;
+multiclass WMMA_LOAD_G {
+  defm _load_a: WMMA_LOAD_GAT;
+  defm _load_b: WMMA_LOAD_GAT;
+  defm _load_c_f16: WMMA_LOAD_GAT;
+  defm _load_c_f32: WMMA_LOAD_GAT;
+}
+
+defm INT_WMMA_m16n16k16: WMMA_LOAD_G<"m16n16k16">;
 
 //
 // wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
 //
-class WMMA_STORE_D_LSTSO
+class WMMA_STORE_D_GLSTSO
   : EmptyNVPTXInst,
     Requires<[hasPTX60, hasSM70]> {
-  PatFrag IntrMatcher = !cast("INT_WMMA_STORE_D"
+  PatFrag IntrMatcher = !cast("INT_WMMA"
+                              # "_" # Geometry # "_store_d"
                               # "_" # Type
                               # "_" # Layout
                               # !subst(".", "_", Space)
                               # !if(WithStride,"_stride", "")
                               # "_Intr");
-
-  dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3);
-  dag InsR47 = (ins regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7);
+  dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1,
+                    regclass:$r2, regclass:$r3);
+  dag InsR47 = (ins regclass:$r4, regclass:$r5,
+                    regclass:$r6, regclass:$r7);
   dag InsR = !if(!eq(Type,"f16"), InsR03, !con(InsR03, InsR47));
   dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins));
   dag Ins = !con(InsR, StrideArg);
@@ -7525,7 +7548,7 @@
   let InOperandList = Ins;
   let AsmString = "wmma.store.d.sync."
                   # Layout
-                  # ".m16n16k16"
+                  # "." # Geometry
                   # Space
                   # "." # Type
                   # " \t[$src],"
@@ -7537,11 +7560,13 @@
 }
 
-class WMMA_STORE_INTR_HELPER
   : PatFrag <(ops),(ops)> {
   // Intrinsic that matches this instruction.
-  Intrinsic Intr = !cast("int_nvvm_wmma_store_d"
+  Intrinsic Intr = !cast("int_nvvm_wmma_"
+                         # Geometry
+                         # "_store_d"
                          # "_" # Type
                          # "_" # Layout
                          # !if(WithStride, "_stride", ""));
   code match_generic = [{
     return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
@@ -7566,57 +7591,77 @@
                      !if(!eq(Space, ".global"), match_global, match_generic));
 }
 
-multiclass WMMA_STORE_D_LSTS {
-  def _avar: WMMA_STORE_D_LSTSO;
-  def _areg: WMMA_STORE_D_LSTSO;
-  def _areg64: WMMA_STORE_D_LSTSO;
-  def _ari: WMMA_STORE_D_LSTSO;
-  def _ari64: WMMA_STORE_D_LSTSO;
+multiclass WMMA_STORE_D_GLSTS {
+  def _avar: WMMA_STORE_D_GLSTSO;
+  def _areg: WMMA_STORE_D_GLSTSO;
+  def _areg64: WMMA_STORE_D_GLSTSO;
+  def _ari: WMMA_STORE_D_GLSTSO;
+  def _ari64: WMMA_STORE_D_GLSTSO;
 }
 
-multiclass WMMA_STORE_D_LSTSh {
+multiclass WMMA_STORE_D_GLSTSh {
   // Define a PatFrag that matches appropriate intrinsic that loads from the
   // given address space.
-  def _Intr: WMMA_STORE_INTR_HELPER;
-  defm NAME: WMMA_STORE_D_LSTS;
+  def _Intr: WMMA_STORE_INTR_HELPER;
+  defm NAME: WMMA_STORE_D_GLSTS;
 }
 
-multiclass WMMA_STORE_D_LST {
-  defm _stride: WMMA_STORE_D_LSTSh;
-  defm NAME: WMMA_STORE_D_LSTSh;
+  defm _stride: WMMA_STORE_D_GLSTSh;
+  defm NAME: WMMA_STORE_D_GLSTSh;
 }
 
-multiclass WMMA_STORE_D_LT {
-  defm _global: WMMA_STORE_D_LST;
-  defm _shared: WMMA_STORE_D_LST;
-  defm NAME: WMMA_STORE_D_LST;
+  defm _global: WMMA_STORE_D_GLST;
+  defm _shared: WMMA_STORE_D_GLST;
+  defm NAME: WMMA_STORE_D_GLST;
 }
 
-multiclass WMMA_STORE_D_T {
-  defm _row: WMMA_STORE_D_LT<"row", Type, regclass>;
-  defm _col: WMMA_STORE_D_LT<"col", Type, regclass>;
+multiclass WMMA_STORE_D_GT {
+  defm _row: WMMA_STORE_D_GLT;
+  defm _col: WMMA_STORE_D_GLT;
 }
 
-defm INT_WMMA_STORE_D_f16: WMMA_STORE_D_T<"f16", Float16x2Regs>;
-defm INT_WMMA_STORE_D_f32: WMMA_STORE_D_T<"f32", Float32Regs>;
+multiclass WMMA_STORE_D_G {
+  defm _store_d_f16: WMMA_STORE_D_GT;
+  defm _store_d_f32: WMMA_STORE_D_GT;
+}
+
+// multiclass WMMA_STORE_D {
+//   defm _m16n16k16: WMMA_STORE_D_G<"m16n16k16">;
+// }
+
+defm INT_WMMA_m16n16k16: WMMA_STORE_D_G<"m16n16k16">;
 
 // WMMA.MMA
-class WMMA_MMA_ABDCS
   : EmptyNVPTXInst,
     Requires<[hasPTX60, hasSM70]> {
-  Intrinsic Intr = !cast("int_nvvm_wmma_mma_sync_"
-                         # ALayout
+  Intrinsic Intr = !cast("int_nvvm_wmma_"
+                         # Geometry
+                         # "_mma"
+                         # "_" # ALayout
                          # "_" # BLayout
                          # "_" # DType
                          # "_" # CType
-                         # !subst(".","_",Satfinite));
+                         # !subst(".", "_", Satfinite));
   dag Outs = !if(!eq(DType,"f16"),
                  (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3),
                  (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3,
@@ -7655,33 +7700,38 @@
                        "{{$c0, $c1, $c2, $c3, $c4, $c5, $c6, $c7}};");
 }
 
-multiclass WMMA_MMA_ABDC {
-  def _satfinite: WMMA_MMA_ABDCS;
-  def NAME: WMMA_MMA_ABDCS;
 }
 
-multiclass WMMA_MMA_ABD {
-  defm _f16: WMMA_MMA_ABDC;
-  defm _f32: WMMA_MMA_ABDC;
+  defm _f16: WMMA_MMA_GABDC;
+  defm _f32: WMMA_MMA_GABDC;
 }
 
-multiclass WMMA_MMA_AB {
-  defm _f16: WMMA_MMA_ABD;
-  defm _f32: WMMA_MMA_ABD;
+multiclass WMMA_MMA_GAB {
+  defm _f16: WMMA_MMA_GABD;
+  defm _f32: WMMA_MMA_GABD;
 }
 
-multiclass WMMA_MMA_A {
-  defm _col: WMMA_MMA_AB;
-  defm _row: WMMA_MMA_AB;
+multiclass WMMA_MMA_GA {
+  defm _col: WMMA_MMA_GAB;
+  defm _row: WMMA_MMA_GAB;
 }
 
-defm INT_WMMA_MMA_col: WMMA_MMA_A<"col">;
-defm INT_WMMA_MMA_row: WMMA_MMA_A<"row">;
+multiclass WMMA_MMA_G {
+  defm _col: WMMA_MMA_GA;
+  defm _row: WMMA_MMA_GA;
 }
+defm INT_WMMA_MMA_m16n16k16 : WMMA_MMA_G<"m16n16k16">;
Index: llvm/test/CodeGen/NVPTX/wmma.py
===================================================================
--- llvm/test/CodeGen/NVPTX/wmma.py
+++ llvm/test/CodeGen/NVPTX/wmma.py
@@ -38,29 +38,29 @@
 def gen_wmma_load_tests():
load_template = """ -declare ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src ${extra_args}); +declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args}); -; CHECK-LABEL: .func {{.*}}test_wmma_load_${function_suffix}( -define ${ret_ty} @test_wmma_load_${function_suffix}(i8 ${as}* %src ${extra_args}) { -; CHECK wmma.load.${intrinsic_suffix} +; CHECK-LABEL: .func {{.*}}test_${function}( +define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) { +; CHECK ${instruction} ; CHECK: {${check_result}} ; CHECK: [%rd{{[0-9]+}}]${stride_pattern} - %v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src ${extra_args}); + %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args}); ret ${ret_ty} %v0; } -; CHECK-LABEL: .func{{.*}}test_wmma_load_${function_suffix}_o( -define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8 ${as}* %src ${extra_args}) { -; CHECK wmma.load.${intrinsic_suffix} +; CHECK-LABEL: .func{{.*}}test_${function}_o( +define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) { +; CHECK ${instruction} ; CHECK: {${check_result}} ; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern} %src1 = getelementptr i8, i8 ${as}* %src, i32 128; - %v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src1 ${extra_args}); + %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args}); ret ${ret_ty} %v0; } """ - suffix_template = "${abc}.sync.${layout}.m16n16k16${stride}.${itype}.${pspace}" - instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}" + intrinsic_template = "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}" + instruction_template = "wmma.load.${abc}.sync.${geom}.${layout}${space}.${itype}" for abc, layout, space, stride, itype in product( "abc", @@ -76,16 +76,17 @@ "stride" : stride, "itype" : itype, "pspace" : get_pspace(space), - "as" : "addrspace(%d)" % get_aspace(space) + "as" : "addrspace(%d)" % get_aspace(space), + "geom" : "m16n16k16", } if itype == "f32" and abc != "c": continue test_params = params - test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params) - test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".","_") - test_params["instruction_suffix"] = Template(instruction_template).substitute(params) + test_params["intrinsic"] = Template(intrinsic_template).substitute(params) + test_params["function"] = test_params["intrinsic"].replace(".","_") + test_params["instruction"] = Template(instruction_template).substitute(params) test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype) if abc == "c" : test_params["check_result"] = check_f16_4 if itype == "f16" else check_f32_8 @@ -107,29 +108,29 @@ def gen_wmma_store_tests(): store_template = """ -declare void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src, ${args}${extra_args}); +declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args}); -; CHECK-LABEL: .func {{.*}}test_wmma_store_${function_suffix}( -define void @test_wmma_store_${function_suffix}(i8 ${as}* %src, ${args}${extra_args}) { -; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}} +; CHECK-LABEL: .func {{.*}}test_${function}( +define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) { +; CHECK ${instruction} {{.*}}[%rd{{[0-9+]}} ; CHECK: {${check_args}} ; CHECK: ${stride_pattern} - call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src, ${args} ${extra_args}); + call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args}); ret void } -; CHECK-LABEL: 
-define void @test_wmma_store_${function_suffix}_o(i8 ${as}* %src, ${args}${extra_args}) {
-; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}+128]
+; CHECK-LABEL: .func{{.*}}test_${function}_o(
+define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
+; CHECK ${instruction} {{.*}}[%rd{{[0-9+]}}+128]
 ; CHECK: ${check_args}
 ; CHECK: ${stride_pattern}
   %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
-  call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src1, ${args}${extra_args});
+  call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
  ret void
 }
 """
-  suffix_template = "${abc}.sync.${layout}.m16n16k16${stride}.${itype}.${pspace}"
-  instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"
+  intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
+  instruction_template = "wmma.store.${abc}.sync.${geom}.${layout}${space}.${itype}"
 
   for abc, layout, space, stride, itype in product(
       "d",
@@ -145,13 +146,14 @@
          "stride" : stride,
          "itype" : itype,
          "pspace" : get_pspace(space),
-          "as" : "addrspace(%d)" % get_aspace(space)
+          "as" : "addrspace(%d)" % get_aspace(space),
+          "geom" : "m16n16k16",
      }
 
      test_params = params
-      test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params)
-      test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".","_")
-      test_params["instruction_suffix"] = Template(instruction_template).substitute(params)
+      test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+      test_params["function"] = test_params["intrinsic"].replace(".","_")
+      test_params["instruction"] = Template(instruction_template).substitute(params)
      test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
      test_params["check_args"] = check_f16_4 if itype == "f16" else check_f32_8
      if stride:
@@ -166,23 +168,24 @@
 
 def gen_wmma_mma_tests():
   mma_template = """
-declare ${ret_ty} @llvm.nvvm.wmma.mma.sync.$intrinsic_suffix(
+declare ${ret_ty} @${intrinsic}(
         ${args});
 
-; CHECK-LABEL: .func {{.*}}test_wmma_mma_${function_suffix}(
-define ${ret_ty} @test_wmma_mma_${function_suffix}(
+; CHECK-LABEL: .func {{.*}}test_${function}(
+define ${ret_ty} @test_${function}(
         ${args}) {
-; CHECK wmma.mma.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}
+; CHECK ${instruction} {{.*}}[%rd{{[0-9+]}}
 ; CHECK ${check_d}
 ; CHECK ${check_ab}
 ; CHECK ${check_ab}
 ; CHECK ${check_c}
-  %r = call ${ret_ty} @llvm.nvvm.wmma.mma.sync.${intrinsic_suffix}(
+  %r = call ${ret_ty} @${intrinsic}(
         ${args});
   ret ${ret_ty} %r;
 }
 """
-  suffix_template = "${alayout}.${blayout}.m16n16k16.${dtype}.${ctype}${satf}"
+  intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${dtype}.${ctype}${satf}"
+  instruction_template = "wmma.mma.sync.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}"
 
   for alayout, blayout, ctype, dtype, satf in product(
      ["row","col"],
@@ -196,12 +199,14 @@
          "blayout" : blayout,
          "ctype" : ctype,
          "dtype" : dtype,
-          "satf" : satf
+          "satf" : satf,
+          "geom" : "m16n16k16",
      }
 
      test_params = params
-      test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params)
-      test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".", "_")
+      test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+      test_params["function"] = test_params["intrinsic"].replace(".", "_")
+      test_params["instruction"] = Template(instruction_template).substitute(params)
      test_params["ret_ty"] = make_wmma_ld_ret_ty("d", dtype)
test_params["check_ab"] = check_f16_8 test_params["check_c"] = check_f16_4 if ctype == "f16" else check_f32_8