Index: llvm/include/llvm/IR/IntrinsicsNVVM.td =================================================================== --- llvm/include/llvm/IR/IntrinsicsNVVM.td +++ llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -50,6 +50,7 @@ string geom = Geom; string frag = Frag; string ptx_elt_type = PtxEltType; + string gft = Geom#":"#Frag#":"#ptx_elt_type; string ft = frag#":"#ptx_elt_type; list regs = !cond( // fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16 @@ -60,7 +61,42 @@ !eq(ft,"c:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret, !eq(ft,"d:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret, !eq(ft,"c:f32") : RepLLVMType<8, llvm_float_ty>.ret, - !eq(ft,"d:f32") : RepLLVMType<8, llvm_float_ty>.ret); + !eq(ft,"d:f32") : RepLLVMType<8, llvm_float_ty>.ret, + + // u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16 + !eq(gft,"m16n16k16:a:u8") : RepLLVMType<2, llvm_i32_ty>.ret, + !eq(gft,"m16n16k16:a:s8") : RepLLVMType<2, llvm_i32_ty>.ret, + !eq(gft,"m16n16k16:b:u8") : RepLLVMType<2, llvm_i32_ty>.ret, + !eq(gft,"m16n16k16:b:s8") : RepLLVMType<2, llvm_i32_ty>.ret, + !eq(gft,"m16n16k16:c:s32") : RepLLVMType<8, llvm_i32_ty>.ret, + !eq(gft,"m16n16k16:d:s32") : RepLLVMType<8, llvm_i32_ty>.ret, + + !eq(gft,"m8n32k16:a:u8") : [llvm_i32_ty], + !eq(gft,"m8n32k16:a:s8") : [llvm_i32_ty], + !eq(gft,"m8n32k16:b:u8") : RepLLVMType<4, llvm_i32_ty>.ret, + !eq(gft,"m8n32k16:b:s8") : RepLLVMType<4, llvm_i32_ty>.ret, + !eq(gft,"m8n32k16:c:s32") : RepLLVMType<8, llvm_i32_ty>.ret, + !eq(gft,"m8n32k16:d:s32") : RepLLVMType<8, llvm_i32_ty>.ret, + + !eq(gft,"m32n8k16:a:u8") : RepLLVMType<4, llvm_i32_ty>.ret, + !eq(gft,"m32n8k16:a:s8") : RepLLVMType<4, llvm_i32_ty>.ret, + !eq(gft,"m32n8k16:b:u8") : [llvm_i32_ty], + !eq(gft,"m32n8k16:b:s8") : [llvm_i32_ty], + !eq(gft,"m32n8k16:c:s32") : RepLLVMType<8, llvm_i32_ty>.ret, + !eq(gft,"m32n8k16:d:s32") : RepLLVMType<8, llvm_i32_ty>.ret, + + // u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1) + !eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty], + !eq(gft,"m8n8k32:a:u4") : [llvm_i32_ty], + !eq(gft,"m8n8k32:a:s4") : [llvm_i32_ty], + !eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty], + !eq(gft,"m8n8k32:b:u4") : [llvm_i32_ty], + !eq(gft,"m8n8k32:b:s4") : [llvm_i32_ty], + !eq(gft,"m8n8k128:c:s32") : RepLLVMType<2, llvm_i32_ty>.ret, + !eq(gft,"m8n8k128:d:s32") : RepLLVMType<2, llvm_i32_ty>.ret, + !eq(gft,"m8n8k32:c:s32") : RepLLVMType<2, llvm_i32_ty>.ret, + !eq(gft,"m8n8k32:d:s32") : RepLLVMType<2, llvm_i32_ty>.ret, + ); } class WMMA_NAME_LDST { @@ -84,22 +120,163 @@ # !if(WithStride, "_stride", ""); } -class WMMA_NAME_MMA { +class MMA_SIGNATURE { + list id_frags = !cond( + // int and sub-int ops are identified by input type. + !eq(A.ptx_elt_type, "s8") : [A], + !eq(A.ptx_elt_type, "u8") : [A], + !eq(A.ptx_elt_type, "s4") : [A], + !eq(A.ptx_elt_type, "u4") : [A], + !eq(A.ptx_elt_type, "b1") : [A], + // the rest are FP ops identified by accumulator & result type. + 1: [D, C] + ); + string ret = !foldl("", id_frags, a, b, !strconcat(a, ".", b.ptx_elt_type)); +} + +class WMMA_NAME_MMA { + string signature = MMA_SIGNATURE.ret; string llvm = "llvm.nvvm.wmma." - # C.geom + # A.geom # ".mma" # "." # ALayout # "." # BLayout - # "." # D.ptx_elt_type // Intrinsic encodes 'd' first. - # "." # C.ptx_elt_type + # signature # !if(Satfinite, ".satfinite", ""); string record = !subst(".", "_", !subst("llvm.", "int_", llvm)); } +// Generates list of 4-tuples of WMMA_REGS representing a valid MMA op. +// Geom: list of supported geometries. +// TypeN: PTX type of the corresponding fragment's element. 
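+//        e.g. "f16", or "s8"/"u8"/"s32" (int ops), or "s4"/"u4"/"b1"/"s32" (sub-int ops).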
+// TypeB and TypeD may be empty if it must match that of TypeA or TypeC. +class MMA_OPS Geom, list TypeA, list TypeB, + list TypeC, list TypeD> { + list> ret = + !foldl([]>, Geom, t1, geom, !listconcat(t1, + !foldl([]>, TypeA, t2, type_a, !listconcat(t2, + !foldl([]>, !if(!size(TypeB), TypeB, [type_a]), t3, type_b, !listconcat(t3, + !foldl([]>, TypeC, t4, type_c, !listconcat(t4, + !foldl([]>, !if(!size(TypeC), TypeC, [type_c]), t5, type_d, !listconcat(t5, + [[WMMA_REGS, + WMMA_REGS, + WMMA_REGS, + WMMA_REGS]])))))))))); + // Debugging aid for readable representation of the list above. + list> ops = !foreach(x, ret, [x[0].gft, x[1].gft, x[2].gft, x[3].gft]); +} + +class MMA_LDST_OPS Geom, list Frags, list Types> { + list ret = + !foldl([], Geom, t1, geom, !listconcat(t1, + !foldl([], Frags, t2, frag, !listconcat(t2, + !foldl([], Types, t3, type, !listconcat(t3, + [WMMA_REGS])))))); + // Debugging aid for readable representation of the list above. + list ops = !foreach(x, ret, x.gft); +} + + + +// Creates list of valid combinations of fragments. This is the master list that +// drives generation of corresponding intrinsics and instructions. +class NVVM_MMA_OPS { + list> fp_mma_ops = MMA_OPS< + ["m16n16k16", "m32n8k16", "m8n32k16"], + ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret; + list> int_mma_ops = MMA_OPS< + ["m16n16k16", "m32n8k16", "m8n32k16"], + ["s8", "u8"], [], ["s32"], []>.ret; + list> subint_mma_ops = MMA_OPS< + ["m8n8k32"], + ["s4", "u4"], [], ["s32"], []>.ret; + list> bit_mma_ops = MMA_OPS< + ["m8n8k128"], + ["b1"], [], ["s32"], []>.ret; + list> all_mma_ops = !listconcat(fp_mma_ops, int_mma_ops, + subint_mma_ops, bit_mma_ops); + + list ldst_ab_ops = MMA_LDST_OPS< + ["m16n16k16", "m32n8k16", "m8n32k16"], + ["a", "b"], ["f16", "u8", "s8"]>.ret; + list ldst_cd_ops = MMA_LDST_OPS< + ["m16n16k16", "m32n8k16", "m8n32k16"], + ["c", "d"], ["f16", "f32", "s32"]>.ret; + list ldst_subint_ab_ops = MMA_LDST_OPS< + ["m8n8k32"], ["a", "b"], ["s4","u4"]>.ret; + list ldst_bit_ab_ops = MMA_LDST_OPS< + ["m8n8k128"], ["a", "b"], ["b1"]>.ret; + list ldst_subint_cd_ops = MMA_LDST_OPS< + ["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]>.ret; + list all_ldst_ops = !listconcat(ldst_ab_ops, ldst_cd_ops, + ldst_subint_ab_ops, + ldst_bit_ab_ops, + ldst_subint_cd_ops); + // Separate A/B/C fragments (loads) from D (stores). + list all_ld_ops = !foldl([], all_ldst_ops, a, b, + !listconcat(a, !if(!eq(b.frag,"d"), [],[b]))); + list all_st_ops = !foldl([], all_ldst_ops, a, b, + !listconcat(a, !if(!eq(b.frag,"d"), [b],[]))); +} + +def NVVM_MMA_OPS : NVVM_MMA_OPS; + +// Returns [1] if this combination of layout/satf is supported, [] otherwise. +// MMA ops must provide all parameters. Loads and stores -- only frags and layout_a. +// The class is used to prevent generation of records for the unsupported variants. +// E.g. +// foreach _ = NVVM_MMA_SUPPORTED<...>.ret in = +// def : FOO<>; // The record will only be defined for supported ops. +// +class NVVM_MMA_SUPPORTED frags, string layout_a, string layout_b="-", int satf=-1> { + // MMA ops check both layout and satf. + string mma = frags[0].ptx_elt_type + # ":" # layout_a + # ":" # layout_b + # ":" # satf + ; + // Load ops only need type/fragment/layout. ; + string ld = frags[0].ptx_elt_type + # ":" # frags[0].frag + # ":" # layout_a + ; + string ldf = frags[0].ptx_elt_type + # ":" # frags[0].frag + ; + string t = frags[0].ptx_elt_type; + list ret = !cond( + // Sub-int MMA only supports fixed A/B layout w/o satf. 
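+    // (PTX 6.3 only defines b1/s4/u4 wmma.mma for A=row, B=col and without
+    //  .satfinite, which is why only those combinations appear below.)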
+ !eq(mma, "b1:row:col:0") : [1], + !eq(mma, "s4:row:col:0") : [1], + !eq(mma, "u4:row:col:0") : [1], + !eq(mma, "s4:row:col:0") : [1], + !eq(mma, "u4:row:col:0") : [1], + // Sub-int load/stores have fixed layout for A and B. + !and(!eq(layout_b, "-"), // It's a Load or Store op + !or(!eq(ld, "b1:a:row"), + !eq(ld, "b1:b:col"), + !eq(ldf, "b1:c"), + !eq(ldf, "b1:d"), + !eq(ld, "s4:a:row"), + !eq(ld, "s4:b:col"), + !eq(ldf, "s4:c"), + !eq(ldf, "s4:d"), + !eq(ld, "u4:a:row"), + !eq(ld, "u4:b:col"), + !eq(ldf, "u4:c"), + !eq(ldf, "u4:d"))) : [1], + // All other sub-int ops are not supported. + !eq(t, "b1") : [], + !eq(t, "s4") : [], + !eq(t, "u4") : [], + // All other (non sub-int) are OK. + 1: [1] + ); +} + let TargetPrefix = "nvvm" in { def int_nvvm_prmt : GCCBuiltin<"__nvvm_prmt">, Intrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], @@ -3970,51 +4147,41 @@ WMMA_NAME_LDST<"store", Frag, Layout, WithStride>.intr>; // Create all load/store variants -foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in { - foreach layout = ["row", "col"] in { - foreach stride = [0, 1] in { - foreach frag = [WMMA_REGS, - WMMA_REGS, - WMMA_REGS, - WMMA_REGS] in { - def WMMA_NAME_LDST<"load", frag, layout, stride>.record +foreach layout = ["row", "col"] in { + foreach stride = [0, 1] in { + foreach frag = NVVM_MMA_OPS.all_ld_ops in + foreach _ = NVVM_MMA_SUPPORTED<[frag], layout>.ret in + def WMMA_NAME_LDST<"load", frag, layout, stride>.record : NVVM_WMMA_LD; - } - foreach frag = [WMMA_REGS, - WMMA_REGS] in { - def WMMA_NAME_LDST<"store", frag, layout, stride>.record + foreach frag = NVVM_MMA_OPS.all_st_ops in + foreach _ = NVVM_MMA_SUPPORTED<[frag], layout>.ret in + def WMMA_NAME_LDST<"store", frag, layout, stride>.record : NVVM_WMMA_ST; - } - } } } // WMMA.MMA -class NVVM_WMMA_MMA +class NVVM_WMMA_MMA : Intrinsic.regs, - WMMA_REGS.regs, - C.regs), + !listconcat(A.regs, B.regs, C.regs), [IntrNoMem], - WMMA_NAME_MMA.llvm>; + WMMA_NAME_MMA.llvm>; -foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in { - foreach layout_a = ["row", "col"] in { - foreach layout_b = ["row", "col"] in { - foreach frag_c = [WMMA_REGS, - WMMA_REGS] in { - foreach frag_d = [WMMA_REGS, - WMMA_REGS] in { - foreach satf = [0, 1] in { - def WMMA_NAME_MMA.record - : NVVM_WMMA_MMA; - } +foreach layout_a = ["row", "col"] in { + foreach layout_b = ["row", "col"] in { + foreach satf = [0, 1] in { + foreach op = NVVM_MMA_OPS.all_mma_ops in { + foreach _ = NVVM_MMA_SUPPORTED.ret in { + def WMMA_NAME_MMA.record + : NVVM_WMMA_MMA; } } - } - } -} + } // satf + } // layout_b +} // layout_a } // let TargetPrefix = "nvvm" Index: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp =================================================================== --- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -3398,6 +3398,94 @@ Info.align = 16; return true; } + case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col: + case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col_stride: + case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col_stride: + case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col: + case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row: + case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride: + case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride: + case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row: + case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col: + case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride: + case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride: + case 
Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col: + case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row: + case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride: + case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride: + case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.memVT = MVT::v2i32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.flags = MachineMemOperand::MOLoad; + Info.align = 8; + return true; + } + + case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col: + case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col_stride: + case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col_stride: + case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col: + case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row: + case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride: + case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride: + case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row: + + case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col: + case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride: + case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col_stride: + case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col: + case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row: + case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride: + case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride: + case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.memVT = MVT::v4i32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.flags = MachineMemOperand::MOLoad; + Info.align = 16; + return true; + } + + case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col: + case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col_stride: + case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col_stride: + case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col: + case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row: + case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row_stride: + case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row_stride: + case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row: + + case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col: + case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col_stride: + case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col_stride: + case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col: + case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row: + case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row_stride: + case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row_stride: + case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row: + case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row: + case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row_stride: + case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col: + case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col_stride: + case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row: + case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row_stride: + case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row_stride: + case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row: + case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col: + case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride: + case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride: + case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.memVT = MVT::i32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.flags = MachineMemOperand::MOLoad; + Info.align = 4; + return true; + } case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col: case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row: @@ -3441,6 +3529,44 @@ return true; } + case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col: + case 
Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride: + case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row: + case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row_stride: + case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col: + case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col_stride: + case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row: + case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row_stride: + case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col: + case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col_stride: + case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row: + case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row_stride: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.memVT = MVT::v8i32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.flags = MachineMemOperand::MOLoad; + Info.align = 16; + return true; + } + + case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col: + case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col_stride: + case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row: + case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row_stride: + case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col: + case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride: + case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row: + case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.memVT = MVT::v2i32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.flags = MachineMemOperand::MOLoad; + Info.align = 8; + return true; + } + case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col: case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row: case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride: @@ -3483,6 +3609,44 @@ return true; } + case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col: + case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col_stride: + case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row: + case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row_stride: + case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col: + case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col_stride: + case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row: + case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row_stride: + case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col: + case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col_stride: + case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row: + case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row_stride: { + Info.opc = ISD::INTRINSIC_VOID; + Info.memVT = MVT::v8i32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.flags = MachineMemOperand::MOStore; + Info.align = 16; + return true; + } + + case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col: + case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col_stride: + case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row: + case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row_stride: + case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col: + case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride: + case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row: + case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: { + Info.opc = ISD::INTRINSIC_VOID; + Info.memVT = MVT::v2i32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.flags = MachineMemOperand::MOStore; + Info.align = 8; + return true; + } + case Intrinsic::nvvm_atomic_load_add_f32: case Intrinsic::nvvm_atomic_load_add_f64: case Intrinsic::nvvm_atomic_load_inc_32: Index: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td =================================================================== --- llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ 
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -142,9 +142,12 @@ def hasPTX31 : Predicate<"Subtarget->getPTXVersion() >= 31">; def hasPTX60 : Predicate<"Subtarget->getPTXVersion() >= 60">; def hasPTX61 : Predicate<"Subtarget->getPTXVersion() >= 61">; +def hasPTX63 : Predicate<"Subtarget->getPTXVersion() >= 63">; def hasSM30 : Predicate<"Subtarget->getSmVersion() >= 30">; def hasSM70 : Predicate<"Subtarget->getSmVersion() >= 70">; +def hasSM72 : Predicate<"Subtarget->getSmVersion() >= 72">; +def hasSM75 : Predicate<"Subtarget->getSmVersion() >= 75">; def useShortPtr : Predicate<"useShortPointers()">; def useFP16Math: Predicate<"Subtarget->allowFP16Math()">; Index: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td =================================================================== --- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -7406,12 +7406,18 @@ // In addition to target-independent fields provided by WMMA_REGS, it adds // the fields commonly used to implement specific PTX instruction -- register // types and names, constraints, parts of assembly, etc. -class WMMA_REGINFO - : WMMA_REGS { +class WMMA_REGINFO + : WMMA_REGS { // NVPTX register types used to carry fragment data. NVPTXRegClass regclass = !cond( - !eq(PtxEltType, "f16") : Float16x2Regs, - !eq(PtxEltType, "f32") : Float32Regs); + !eq(ptx_elt_type, "f16") : Float16x2Regs, + !eq(ptx_elt_type, "f32") : Float32Regs, + !eq(ptx_elt_type, "s32") : Int32Regs, + !eq(ptx_elt_type, "s8") : Int32Regs, + !eq(ptx_elt_type, "u8") : Int32Regs, + !eq(ptx_elt_type, "s4") : Int32Regs, + !eq(ptx_elt_type, "u4") : Int32Regs, + !eq(ptx_elt_type, "b1") : Int32Regs); // Instruction input/output arguments for the fragment. list ptx_regs = !foreach(tmp, regs, regclass); @@ -7433,15 +7439,27 @@ // all fragments of the instruction are viable. list Predicates = !cond( // fp16 -> fp16/fp32 @ m16n16k16 - !and(!eq(Geom, "m16n16k16"), - !or(!eq(PtxEltType, "f16"), - !eq(PtxEltType, "f32"))) : [hasSM70, hasPTX60], + !and(!eq(geom, "m16n16k16"), + !or(!eq(ptx_elt_type, "f16"), + !eq(ptx_elt_type, "f32"))) : [hasSM70, hasPTX60], // fp16 -> fp16/fp32 @ m8n32k16/m32n8k16 - !and(!or(!eq(Geom, "m8n32k16"), - !eq(Geom, "m32n8k16")), - !or(!eq(PtxEltType, "f16"), - !eq(PtxEltType, "f32"))) : [hasSM70, hasPTX61]); + !and(!or(!eq(geom, "m8n32k16"), + !eq(geom, "m32n8k16")), + !or(!eq(ptx_elt_type, "f16"), + !eq(ptx_elt_type, "f32"))) : [hasSM70, hasPTX61], + + // u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16 + !and(!or(!eq(geom,"m16n16k16"), + !eq(geom,"m8n32k16"), + !eq(geom,"m32n8k16")), + !or(!eq(ptx_elt_type, "u8"), + !eq(ptx_elt_type, "s8"), + !eq(ptx_elt_type, "s32"))) : [hasSM72, hasPTX63], + + // u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1) + !or(!eq(geom,"m8n8k128"), + !eq(geom,"m8n8k32")) : [hasSM75, hasPTX63]); // template DAGs for instruction inputs/output. 
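+  // Outs/Ins below carry one operand per fragment register, using the
+  // register class selected above.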
dag Outs = !dag(outs, ptx_regs, reg_names); @@ -7559,44 +7577,48 @@ // Create all load/store variants defset list MMA_LDSTs = { - foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in { - foreach layout = ["row", "col"] in { - foreach stride = [0, 1] in { - foreach space = [".global", ".shared", ""] in { - foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in { - foreach frag = [WMMA_REGINFO, - WMMA_REGINFO, - WMMA_REGINFO, - WMMA_REGINFO] in { - def : WMMA_LOAD; - } - foreach frag = [WMMA_REGINFO, - WMMA_REGINFO] in { - def : WMMA_STORE_D; - } - } // addr - } // space - } // stride - } // layout - } // geom + foreach layout = ["row", "col"] in { + foreach stride = [0, 1] in { + foreach space = [".global", ".shared", ""] in { + foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in { + foreach frag = NVVM_MMA_OPS.all_ld_ops in + foreach _ = NVVM_MMA_SUPPORTED<[frag], layout>.ret in + def : WMMA_LOAD, layout, space, stride, addr>; + foreach frag = NVVM_MMA_OPS.all_st_ops in + foreach _ = NVVM_MMA_SUPPORTED<[frag], layout>.ret in + def : WMMA_STORE_D, layout, space, stride, addr>; + } // addr + } // space + } // stride + } // layout } // defset // WMMA.MMA class WMMA_MMA - : WMMA_INSTR.record, + : WMMA_INSTR.record, [FragA.Ins, FragB.Ins, FragC.Ins]>, - Requires { + // Requires does not seem to have effect on Instruction w/o Patterns. + // We set it here anyways and propagate to the Pat<> we construct below. + Requires { let OutOperandList = FragD.Outs; let InOperandList = !con(Args, (ins MmaCode:$ptx)); - let AsmString = "wmma.mma.sync" + string TypeList = !cond( + !eq(FragD.ptx_elt_type, "s32") : ".s32" + # "." # FragA.ptx_elt_type + # "." # FragB.ptx_elt_type + # ".s32", + 1: "." # FragD.ptx_elt_type # "." # FragC.ptx_elt_type, + ); + let AsmString = "wmma.mma" + # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "") + # ".sync" # "${ptx:aligned}" # "." # ALayout # "." # BLayout # "." # FragA.geom - # "." # FragD.ptx_elt_type - # "." # FragC.ptx_elt_type + # TypeList # !if(Satfinite, ".satfinite", "") # "\n\t\t" # FragD.regstring # ",\n\t\t" # FragA.regstring # ",\n\t\t" @@ -7605,32 +7627,32 @@ } defset list MMAs = { - foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in { - foreach layout_a = ["row", "col"] in { - foreach layout_b = ["row", "col"] in { - foreach frag_c = [WMMA_REGINFO, - WMMA_REGINFO] in { - foreach frag_d = [WMMA_REGINFO, - WMMA_REGINFO] in { - foreach satf = [0, 1] in { - def : WMMA_MMA, - WMMA_REGINFO, - frag_c, frag_d, layout_a, layout_b, satf>; - } // satf - } // frag_d - } // frag_c - } // layout_b - } // layout_a - } // geom + foreach layout_a = ["row", "col"] in { + foreach layout_b = ["row", "col"] in { + foreach satf = [0, 1] in { + foreach op = NVVM_MMA_OPS.all_mma_ops in { + foreach _ = NVVM_MMA_SUPPORTED.ret in { + def : WMMA_MMA, + WMMA_REGINFO, + WMMA_REGINFO, + WMMA_REGINFO, + layout_a, layout_b, satf>; + } + } // op + } // satf + } // layout_b + } // layout_a } // defset + // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with // the instruction record. class WMMA_PAT : Pat; + (wi ptx.version))>, + Requires; // Build intrinsic->instruction patterns for all MMA instructions. 
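+// The Requires<> added to WMMA_PAT above keeps ISel from matching patterns for
+// variants the target's SM/PTX version does not support.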
foreach mma = !listconcat(MMAs, MMA_LDSTs) in Index: llvm/test/CodeGen/NVPTX/wmma.py =================================================================== --- llvm/test/CodeGen/NVPTX/wmma.py +++ llvm/test/CodeGen/NVPTX/wmma.py @@ -1,10 +1,42 @@ # This test generates all variants of wmma intrinsics and verifies that LLVM # generates correct instructions for them. -# RUN: python %s > %t.ll -# RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 | FileCheck %t.ll -# RUN: python %s --ptx=63 > %t-ptx63.ll -# RUN: llc < %t-ptx63.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx63 | FileCheck %t-ptx63.ll +# Check all variants of instructions supported by PTX60 on SM70 +# RUN: python %s --ptx=60 --gpu-arch=70 > %t-ptx60-sm_70.ll +# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \ +# RUN: --check-prefixes=INTRINSICS,PTX60,SM70 +# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \ +# RUN: --check-prefixes=INTRINSICS,PTX60U,SM70U +# RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \ +# RUN: | FileCheck %t-ptx60-sm_70.ll + +# Check all variants of instructions supported by PTX61 on SM70 +# RUN: python %s --ptx=61 --gpu-arch=70 > %t-ptx61-sm_70.ll +# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \ +# RUN: --check-prefixes=INTRINSICS,PTX60,PTX61,SM70 +# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \ +# RUN: --check-prefixes=INTRINSICS,PTX61U,SM70U +# RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \ +# RUN: | FileCheck %t-ptx61-sm_70.ll + +# Check all variants of instructions supported by PTX63 on SM72 +# RUN: python %s --ptx=63 --gpu-arch=72 > %t-ptx63-sm_72.ll +# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \ +# RUN: --check-prefixes=INTRINSICS,PTX60,PTX61,PTX63,SM70,SM72 +# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \ +# RUN: --check-prefixes=INTRINSICS,PTX63U,SM72U +# RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \ +# RUN: | FileCheck %t-ptx63-sm_72.ll + +# Check all variants of instructions supported by PTX63 on SM75 +# RUN: python %s --ptx=63 --gpu-arch=75 > %t-ptx63-sm_75.ll +# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \ +# RUN: --check-prefixes=INTRINSICS,PTX60,PTX61,PTX63,SM70,SM72,SM75 +# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \ +# RUN: --check-prefixes=INTRINSICS,PTX63U,SM75U +# RUN: llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \ +# RUN: | FileCheck %t-ptx63-sm_75.ll + from __future__ import print_function @@ -12,13 +44,172 @@ from itertools import product from string import Template -def make_wmma_slice_ty(abcd, itype): - elt_ty = "<2 x half>" if itype == "f16" else "float" - num_elts = 4 if abcd in "cd" and itype == "f16" else 8; - return [elt_ty] * num_elts +class MMAType: + def __init__(self, ptx_type): + self.ptx_type = ptx_type + self.llvm_type = { + "f16" : "<2 x half>", + "f32" : "float", + "s32" : "i32", + "s8" : "i32", + "u8" : "i32", + "s4" : "i32", + "u4" : "i32", + "b1" : "i32", + }[ptx_type]; -def make_wmma_ld_ret_ty(abc, itype): - return "{%s}" % ", ".join(make_wmma_slice_ty(abc, itype)) + self.ptx_reg_pattern = { + "f16" : "%hh[0-9]+", + "f32" : "%f[0-9]+", + }.get(ptx_type, "%r[0-9]+") + + def __repr__(self): + return "%s/%s" % (self.ptx_type, self.llvm_type) + +class MMAFrag: + def __init__(self, geom, frag, ptx_elt_type): + self.geom = geom + self.frag = frag + self.mma_type = MMAType(ptx_elt_type); + self.nregs = { + "a:f16" : 8, + "b:f16" : 8, + "c:f16" : 4, + "d:f16" : 4, + "c:f32" : 8, + "d:f32" : 8, + }.get("%s:%s" % 
(frag, ptx_elt_type), { + # u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16 + "m16n16k16:a:u8" : 2, + "m16n16k16:a:s8" : 2, + "m16n16k16:b:u8" : 2, + "m16n16k16:b:s8" : 2, + "m16n16k16:c:s32" : 8, + "m16n16k16:d:s32" : 8, + + "m8n32k16:a:u8" : 1, + "m8n32k16:a:s8" : 1, + "m8n32k16:b:u8" : 4, + "m8n32k16:b:s8" : 4, + "m8n32k16:c:s32" : 8, + "m8n32k16:d:s32" : 8, + + "m32n8k16:a:u8" : 4, + "m32n8k16:a:s8" : 4, + "m32n8k16:b:u8" : 1, + "m32n8k16:b:s8" : 1, + "m32n8k16:c:s32" : 8, + "m32n8k16:d:s32" : 8, + + # u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1) + "m8n8k128:a:b1" : 1, + "m8n8k32:a:u4" : 1, + "m8n8k32:a:s4" : 1, + "m8n8k128:b:b1" : 1, + "m8n8k32:b:u4" : 1, + "m8n8k32:b:s4" : 1, + "m8n8k128:c:s32" : 2, + "m8n8k128:d:s32" : 2, + "m8n8k32:c:s32" : 2, + "m8n8k32:d:s32" : 2, + }.get("%s:%s:%s" % (geom, frag, ptx_elt_type), None)); + assert(self.nregs); + + def __repr__(self): + return "%s:%s:%s%s" % (self.geom, self.frag, self.mma_type, + "" if self.nregs == 1 else ("*%d" % self.nregs)) + +class MMAOp: + def __init__(self, a, b, c, d): + self.a = a + self.b = b + self.c = c + self.d = d + + def __repr__(self): + return ("{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d )) + +def make_mma_ops(geoms, types_a, types_b, types_c, types_d): + ops = [] + for geom, type_a, type_c in product( geoms, types_a, types_c): + for type_b, type_d in product(types_b if types_b else [type_a], + types_d if types_d else [type_c]): + ops.append(MMAOp(MMAFrag(geom, "a", type_a), + MMAFrag(geom, "b", type_b), + MMAFrag(geom, "c", type_c), + MMAFrag(geom, "d", type_d))) + return ops + +def make_ldst_ops(geoms, frags, types): + return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type) + in product(geoms, frags, types)] + +def get_mma_ops(): + return (make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"], + ["f16"], [], ["f16", "f32"], ["f16", "f32"]) + + make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"], + ["s8", "u8"], [], ["s32"], []) + + make_mma_ops(["m8n8k32"], + ["s4", "u4"], [], ["s32"], []) + + make_mma_ops(["m8n8k128"], + ["b1"], [], ["s32"], [])) +def get_ldst_ops(kind): + ldst_ops = (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"], + ["a", "b"], ["f16", "u8", "s8"]) + + make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"], + ["c", "d"], ["f16", "f32", "s32"]) + + make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4","u4"]) + + make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) + + make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"])) + return [ x for x in ldst_ops if (x.frag == "d") == (kind == "store")] + +def is_geom_supported(geom): + # geometries for FP and ints. + if geom in ["m8n32k16", "m32n8k16"]: + return ptx_version >= 61 + # geometries for sub-ints. + if geom in ["m8n8k32", "m8n8k128"]: + return ptx_version >= 63 and gpu_arch >= 75 + if geom == "m16n16k16": + return ptx_version >= 60 + assert(False) # Unexpected geometry. + +def is_type_supported(ptx_type): + if ptx_type in ["s8", "u8", "s32"]: + return ptx_version >= 63 and gpu_arch >= 72 + if ptx_type in ["s4", "u4", "b1"]: + return ptx_version >= 63 and gpu_arch >= 75 + return ptx_version >= 60 and gpu_arch >= 70 + + +def is_mma_variant_supported(op, layout_a, layout_b, satf): + if not (is_type_supported(op.a.mma_type.ptx_type) + and is_geom_supported(op.a.geom)): + return False + # sub-integer require row/col layout, and no satf. 
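+  # (Keep this in sync with NVVM_MMA_SUPPORTED in IntrinsicsNVVM.td: only the
+  #  row/col layout without satf exists for b1/s4/u4.)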
+ if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]: + return layout_a == "row" and layout_b == "col" and satf == "" + return True + +def is_ldst_variant_supported(frag, layout): + if not (is_type_supported(frag.mma_type.ptx_type) + and is_geom_supported(frag.geom)): + return False + if frag.mma_type.ptx_type in ["s4", "u4", "b1"]: + # sub-integer require sm_75 and ptx63, row/col layout for a/b. + return ((frag.frag == "a" and layout == "row") + or (frag.frag == "b" and layout == "col") + or frag.frag in ["c", "d"]) + return True + +def make_wmma_slice_ty(frag): + return [frag.mma_type.llvm_type] * frag.nregs + +def make_wmma_ld_ret_ty(frag): + results = make_wmma_slice_ty(frag) + if len(results) == 1: + return "%s" % results[0] + return "{%s}" % ", ".join(results) # returns address space def get_aspace(space): @@ -36,10 +227,8 @@ def get_pspace(space): return "p%di8" % get_aspace(space); -# Convenient test patterns. -check_f16_8 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 8) -check_f16_4 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 4) -check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8) +def check_pattern(frag): + return "{{%s}}" % ", *".join([frag.mma_type.ptx_reg_pattern] * frag.nregs) known_geoms = ["m16n16k16", "m8n32k16", "m32n8k16"] @@ -69,38 +258,35 @@ intrinsic_template = "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}" instruction_template = "wmma.load.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}" - for geom, abc, layout, space, stride, itype in product( - known_geoms, - "abc", + generated_items = [] + + for frag, layout, space, stride in product( + get_ldst_ops("load"), ["row","col"], ["",".shared",".global"], ["", ".stride"], - ["f16", "f32"]): + ): + if not is_ldst_variant_supported(frag, layout): + continue params = { - "abc" : abc, + "abc" : frag.frag, "aligned" : ".aligned" if ptx_version >= 63 else "", "layout" : layout, "space" : space, "stride" : stride, - "itype" : itype, + "itype" : frag.mma_type.ptx_type, "pspace" : get_pspace(space), "as" : "addrspace(%d)" % get_aspace(space), - "geom" : geom, + "geom" : frag.geom, } - if itype == "f32" and abc != "c": - continue - test_params = params test_params["intrinsic"] = Template(intrinsic_template).substitute(params) test_params["function"] = test_params["intrinsic"].replace(".","_") test_params["instruction"] = Template(instruction_template).substitute(params) - test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype) - if abc == "c" : - test_params["check_result"] = check_f16_4 if itype == "f16" else check_f32_8 - else: - test_params["check_result"] = check_f16_8 + test_params["ret_ty"] = make_wmma_ld_ret_ty(frag) + test_params["check_result"] = check_pattern(frag) if stride: test_params["extra_args"] = ", i32 %stride"; @@ -111,9 +297,14 @@ print(Template(load_template).substitute(test_params)) -def make_wmma_slice_args(itype, abcd, prefix="v"): - return ", ".join(["%s %%%s%d" % (t, prefix, i) for i,t - in enumerate(make_wmma_slice_ty(abcd, itype))]) + generated_items.append((test_params["intrinsic"], + test_params["instruction"])) + + return generated_items + +def make_wmma_slice_args(frag): + return ", ".join(["%s %%%s%d" % (t, frag.frag, i) for i,t + in enumerate(make_wmma_slice_ty(frag))]) def gen_wmma_store_tests(): store_template = """ @@ -141,41 +332,64 @@ intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}" instruction_template = "wmma.store.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}" - for geom, abc, layout, space, stride, itype 
in product( - known_geoms, - "d", + generated_items = [] + + for frag, layout, space, stride in product( + get_ldst_ops("store"), ["row","col"], ["",".shared",".global"], - ["", ".stride"], - ["f16", "f32"]): + ["", ".stride"]): + + if not is_ldst_variant_supported(frag, layout): + continue params = { - "abc" : abc, + "abc" : frag.frag, "aligned" : ".aligned" if ptx_version >= 63 else "", "layout" : layout, "space" : space, "stride" : stride, - "itype" : itype, + "itype" : frag.mma_type.ptx_type, "pspace" : get_pspace(space), "as" : "addrspace(%d)" % get_aspace(space), - "geom" : geom, + "geom" : frag.geom, } test_params = params test_params["intrinsic"] = Template(intrinsic_template).substitute(params) test_params["function"] = test_params["intrinsic"].replace(".","_") test_params["instruction"] = Template(instruction_template).substitute(params) - test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype) - test_params["check_args"] = check_f16_4 if itype == "f16" else check_f32_8 + test_params["ret_ty"] = make_wmma_ld_ret_ty(frag) + test_params["check_args"] = check_pattern(frag) if stride: test_params["extra_args"] = ", i32 %stride"; test_params["stride_pattern"] = ", %r{{[0-9]+}};" else: test_params["extra_args"] = "" test_params["stride_pattern"] = ";" - test_params["args"] = make_wmma_slice_args(itype, "d"); + test_params["args"] = make_wmma_slice_args(frag); print(Template(store_template).substitute(test_params)) + generated_items.append((test_params["intrinsic"], + test_params["instruction"])) + + return generated_items + +def mma_signature(op): + if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]: + # int and sub-int ops are identified by input type. + return op.a.mma_type.ptx_type + else: + # the rest are FP ops identified by accumulator & result type. 
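+    # E.g. an f16 MMA with f32 accumulator and f32 result is "f32.f32";
+    # with f16 accumulator and f32 result it is "f32.f16".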
+ return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type) + +def mma_ptx_signature(op): + if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]: + # int and sub-int instructions encode all four types as D.A.B.C + return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c)) + else: + # the rest are FP instructions use D.C + return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type) def gen_wmma_mma_tests(): mma_template = """ @@ -187,58 +401,129 @@ ${args}) { ; CHECK: ${instruction} ; CHECK-NEXT: ${check_d} -; CHECK-NEXT: ${check_ab} -; CHECK-NEXT: ${check_ab} +; CHECK-NEXT: ${check_a} +; CHECK-NEXT: ${check_b} ; CHECK-NEXT: ${check_c} %r = call ${ret_ty} @${intrinsic}( ${args}); ret ${ret_ty} %r; } """ - intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${dtype}.${ctype}${satf}" - instruction_template = "wmma.mma.sync${aligned}.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}" + intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}" + instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}.${ptx_signature}${satf}" - for geom, alayout, blayout, ctype, dtype, satf in product( - known_geoms, + generated_items=[] + + for op, alayout, blayout, satf in product( + get_mma_ops(), ["row","col"], ["row","col"], - ["f16", "f32"], - ["f16", "f32"], [".satfinite", ""]): + if not is_mma_variant_supported(op, alayout, blayout, satf): + continue + params = { "aligned" : ".aligned" if ptx_version >= 63 else "", "alayout" : alayout, "blayout" : blayout, - "ctype" : ctype, - "dtype" : dtype, + "intrinsic_signature" : mma_signature(op), + "ptx_signature" : mma_ptx_signature(op), "satf" : satf, - "geom" : geom, + "geom" : op.a.geom, + "mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "", } test_params = params test_params["intrinsic"] = Template(intrinsic_template).substitute(params) test_params["function"] = test_params["intrinsic"].replace(".", "_") test_params["instruction"] = Template(instruction_template).substitute(params) - test_params["ret_ty"] = make_wmma_ld_ret_ty("d", dtype) - test_params["check_ab"] = check_f16_8 - test_params["check_c"] = check_f16_4 if ctype == "f16" else check_f32_8 - test_params["check_d"] = check_f16_4 if dtype == "f16" else check_f32_8 - args = ",\n ".join(make_wmma_slice_args(t, abcd, prefix=abcd) - for abcd, t in (("a", "f16"), - ("b", "f16"), - ("c", ctype))) + test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d) + test_params["check_a"] = check_pattern(op.a) + test_params["check_b"] = check_pattern(op.b) + test_params["check_c"] = check_pattern(op.c) + test_params["check_d"] = check_pattern(op.d) + args = ",\n ".join(make_wmma_slice_args(frag) + for frag in (op.a, op.b, op.c)) test_params["args"] = args print(Template(mma_template).substitute(test_params)) + generated_items.append((test_params["intrinsic"], + test_params["instruction"])) -def main(): - gen_wmma_load_tests() - gen_wmma_store_tests() - gen_wmma_mma_tests() + return generated_items + +# Append complete list of intrinsics and instructions we've generated tests for. +# Generate set of checks to verify that that we did generate sensible set of +# tests for the given combination of PTX and SM variants. +# +# PTX: verifies that we did generate tests for correct classes of intrinsics. +# PTXU: verifies that we did not generate intrinsics unsupported by +# the PTX version. +# SM: verifies that we did generate correct classes of instructions for the SM. 
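+#     (e.g. SM72 expects the s8/u8 wmma variants to show up once we target sm_72+)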
+# SMU: verifies that we did not generate instructions unsupported by the SM +# +# Note that SM/PTX constraints overlap, but DAG checks do not allow overlapping +# matches. We implicitly rely that we generate multiple variants of most of the +# instructions and usually have enough input data to find more than one match of +# the same kind, if necessary. When it's not possible (e.g. there's only one +# m8n8k128.mma.row.col.b1), we may need to match PTX instruction instead. +def gen_check_unsupported_ops(items): + print("; Complete list of intrinsics supported by PTX%d on sm_%d" + % (ptx_version, gpu_arch)) + print("; INTRINSICS: {{^; INTRINSICS_LIST_BEGIN}}") + print(""" +; PTX60-DAG: m16n16k16.load.{{[ab].*}}.f16.p +; PTX60-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p +; PTX60U-NOT: m32n8k16 +; PTX60U-NOT: m8n32k16 +; PTX60U-NOT: .{{s32|s[48]|u[48]|b1}} + +; All features of PTX60, plus m32n8k16/m8n32k16 geometries. +; PTX61-DAG: m32n8k16.load.{{[ab].*}}.f16.p +; PTX61-DAG: m32n8k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p +; PTX61-DAG: m8n32k16.load.{{[ab].*}}.f16.p +; PTX61-DAG: m8n32k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p +; PTX61U-NOT: .{{s32|s[48]|u[48]|b1}} + +; SM70U-NOT: .{{s32|s[48]|u[48]|b1}} + +; PTX63 supports all features of PTX60+PTX61, plus support for integers. +; Alas we can"t just use PTX checks for that as available instructions +; depend on SM integers need sm72+ and subinteger ops need sm75, so we +; transition to SM checks +; SM72-DAG: m16n16k16.load.{{[ab].*}}.s8.p +; SM72-DAG: m8n32k16.load.{{[ab].*}}.s8.p +; SM72-DAG: m32n8k16.load.{{[ab].*}}.s8.p +; SM72-DAG: m16n16k16.load.{{[ab].*}}.u8.p +; SM72-DAG: m8n32k16.load.{{[ab].*}}.u8.p +; SM72-DAG: m32n8k16.load.{{[ab].*}}.u8.p +; SM72-DAG: m32n8k16.{{load|store}}.{{[cd].*\.s32}}.p +; SM72U-NOT: .{{s4|u4|b1}} + +; SM75-DAG: m8n8k128.load.{{[ab].*}}.b1.p +; SM75-DAG: m8n8k32.load.{{[ab].*}}.s4.p +; SM75-DAG: m8n8k32.load.{{[ab].*}}.u4.p +; SM75-DAG: m8n8k128.{{load|store}}.{{[cd].*\.s32}}.p +; SM75-DAG: m8n8k32.{{load|store}}.{{[cd].*\.s32}}.p +""") + + print("; INTRINSICS_LIST_BEGIN") + for intrinsic, instruction in sorted(items): + print("; ", intrinsic, " -> ", instruction,"") + print("; INTRINSICS_LIST_END") + print("; INTRINSICS: ; INTRINSICS_LIST_END") + +def gen_tests(): + items = gen_wmma_load_tests() + items += gen_wmma_store_tests() + items += gen_wmma_mma_tests() + gen_check_unsupported_ops(items) parser = argparse.ArgumentParser() -parser.add_argument('--ptx', type=int, default=60) +parser.add_argument("--ptx", type=int, default=60) +parser.add_argument("--gpu-arch", type=int, default=70) args = parser.parse_args() ptx_version = args.ptx +gpu_arch = args.gpu_arch -main() +gen_tests()
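
For reference, a minimal sketch of what one of the newly generated test cases
looks like for the int8 path (illustrative only -- the real functions and CHECK
lines are produced by wmma.py above, and the function name here is made up):
A and B of an m16n16k16 s8 MMA occupy two i32 registers each, C and D occupy
eight, and the call should lower to the new PTX instruction when built for
sm_72 with ptx63.

  define {i32, i32, i32, i32, i32, i32, i32, i32} @example_m16n16k16_mma_row_col_s8(
      i32 %a0, i32 %a1, i32 %b0, i32 %b1,
      i32 %c0, i32 %c1, i32 %c2, i32 %c3,
      i32 %c4, i32 %c5, i32 %c6, i32 %c7) {
    ; CHECK: wmma.mma.sync.aligned.row.col.m16n16k16.s32.s8.s8.s32
    %r = call {i32, i32, i32, i32, i32, i32, i32, i32}
        @llvm.nvvm.wmma.m16n16k16.mma.row.col.s8(
            i32 %a0, i32 %a1, i32 %b0, i32 %b1,
            i32 %c0, i32 %c1, i32 %c2, i32 %c3,
            i32 %c4, i32 %c5, i32 %c6, i32 %c7)
    ret {i32, i32, i32, i32, i32, i32, i32, i32} %r
  }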