diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -53,6 +53,10 @@
   string gft = Geom#":"#Frag#":"#ptx_elt_type;
   string ft = frag#":"#ptx_elt_type;
   list<LLVMType> regs = !cond(
+    // mma.sync.m8n8k4 uses smaller a/b fragments than wmma fp ops
+    !eq(gft,"m8n8k4:a:f16") : RepLLVMType<2, llvm_v2f16_ty>.ret,
+    !eq(gft,"m8n8k4:b:f16") : RepLLVMType<2, llvm_v2f16_ty>.ret,
+
     // fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
     // All currently supported geometries use the same fragment format,
     // so we only need to consider {fragment, type}.
@@ -137,13 +141,19 @@
 class WMMA_NAME_MMA<string ALayout, string BLayout, int Satfinite,
                     WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
   string signature = MMA_SIGNATURE<A, B, C, D>.ret;
-  string llvm = "llvm.nvvm.wmma."
-                # A.geom
-                # ".mma"
-                # "." # ALayout
-                # "." # BLayout
-                # signature
-                # !if(Satfinite, ".satfinite", "");
+  string llvm = !if(
+      !eq(A.geom, "m8n8k4"),
+      "llvm.nvvm.mma.m8n8k4"
+          # "." # ALayout
+          # "." # BLayout
+          # signature,
+      "llvm.nvvm.wmma."
+          # A.geom
+          # ".mma"
+          # "." # ALayout
+          # "." # BLayout
+          # signature
+          # !if(Satfinite, ".satfinite", ""));
 
   string record = !subst(".", "_",
                   !subst("llvm.", "int_", llvm));
@@ -160,7 +170,7 @@
     !foldl([]<list<WMMA_REGS>>, TypeA, t2, type_a, !listconcat(t2,
     !foldl([]<list<WMMA_REGS>>, !if(!size(TypeB), TypeB, [type_a]), t3, type_b, !listconcat(t3,
     !foldl([]<list<WMMA_REGS>>, TypeC, t4, type_c, !listconcat(t4,
-    !foldl([]<list<WMMA_REGS>>, !if(!size(TypeC), TypeC, [type_c]), t5, type_d, !listconcat(t5,
+    !foldl([]<list<WMMA_REGS>>, !if(!size(TypeD), TypeD, [type_c]), t5, type_d, !listconcat(t5,
        [[WMMA_REGS<geom, "a", type_a>,
          WMMA_REGS<geom, "b", type_b>,
          WMMA_REGS<geom, "c", type_c>,
@@ -185,19 +195,23 @@
 // drives generation of corresponding intrinsics and instructions.
 class NVVM_MMA_OPS {
   list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
+            ["m8n8k4"],
+            ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
+  list<list<WMMA_REGS>> fp_wmma_ops = MMA_OPS<
             ["m16n16k16", "m32n8k16", "m8n32k16"],
             ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
-  list<list<WMMA_REGS>> int_mma_ops = MMA_OPS<
+  list<list<WMMA_REGS>> int_wmma_ops = MMA_OPS<
             ["m16n16k16", "m32n8k16", "m8n32k16"],
             ["s8", "u8"], [], ["s32"], []>.ret;
-  list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS<
+  list<list<WMMA_REGS>> subint_wmma_ops = MMA_OPS<
             ["m8n8k32"],
             ["s4", "u4"], [], ["s32"], []>.ret;
-  list<list<WMMA_REGS>> bit_mma_ops = MMA_OPS<
+  list<list<WMMA_REGS>> bit_wmma_ops = MMA_OPS<
             ["m8n8k128"],
             ["b1"], [], ["s32"], []>.ret;
-  list<list<WMMA_REGS>> all_mma_ops = !listconcat(fp_mma_ops, int_mma_ops,
-                                                  subint_mma_ops, bit_mma_ops);
+  list<list<WMMA_REGS>> all_mma_ops = !listconcat(
+                                        fp_mma_ops, fp_wmma_ops, int_wmma_ops,
+                                        subint_wmma_ops, bit_wmma_ops);
 
   list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
             ["m16n16k16", "m32n8k16", "m8n32k16"],
@@ -245,10 +259,25 @@
               # ":" # frags[0].frag
               ;
   string t = frags[0].ptx_elt_type;
+
+  // gcd is a shortcut used to identify instructions that depend on
+  // geom+frag_c+frag_d. Not all instances of this class have all fragments
+  // specified. If there are not enough fragments, the tail evaluates to '?'.
+  string gcd = frags[0].geom
+               # ":"
+               # !if(!eq(!size(frags), 4),
+                     frags[2].ptx_elt_type # frags[3].ptx_elt_type,
+                     "?");
   list<int> ret = !cond(
     // Sub-int MMA only supports fixed A/B layout.
     // b1 does not support .satf.
     !eq(mma#":"#satf, "b1:row:col:0") : [1],
+    // mma.m8n8k4 has no .satf modifier.
+    !and(!eq(frags[0].geom, "m8n8k4"),
+         !ne(satf, 0)): [],
+
+    // mma.m8n8k4 has no C=f32 D=f16 variant.
+    !eq(gcd, "m8n8k4:f32f16"): [],
     !eq(mma, "s4:row:col") : [1],
     !eq(mma, "u4:row:col") : [1],
     !eq(mma, "s4:row:col") : [1],
@@ -4094,7 +4123,7 @@
                 [IntrWriteMem, IntrArgMemOnly, WriteOnly<0>, NoCapture<0>],
                 WMMA_NAME_LDST<"store", Frag, Layout, WithStride>.intr>;
 
-// Create all load/store variants 
+// Create all load/store variants
 foreach layout = ["row", "col"] in {
   foreach stride = [0, 1] in {
     foreach frag = NVVM_MMA_OPS.all_ld_ops in
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -7400,7 +7400,9 @@
 
     // u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
     !or(!eq(geom,"m8n8k128"),
-        !eq(geom,"m8n8k32")) : [hasSM75, hasPTX63]);
+        !eq(geom,"m8n8k32")) : [hasSM75, hasPTX63],
+
+    !eq(geom, "m8n8k4") : [hasSM70, hasPTX64]);
 
   // template DAGs for instruction inputs/output.
   dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -7546,25 +7548,37 @@
   let OutOperandList = FragD.Outs;
   let InOperandList = !con(Args, (ins MmaCode:$ptx));
   string TypeList = !cond(
+    !eq(FragD.geom, "m8n8k4") : "." # FragD.ptx_elt_type
+                                # ".f16.f16."
+                                # FragC.ptx_elt_type,
     !eq(FragD.ptx_elt_type, "s32") : ".s32"
                                      # "." # FragA.ptx_elt_type
                                      # "." # FragB.ptx_elt_type
                                      # ".s32",
     1: "." # FragD.ptx_elt_type # "." # FragC.ptx_elt_type,
   );
-  let AsmString = "wmma.mma"
-                  # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "")
-                  # ".sync"
-                  # "${ptx:aligned}"
-                  # "." # ALayout
-                  # "." # BLayout
-                  # "." # FragA.geom
-                  # TypeList
-                  # !if(Satfinite, ".satfinite", "") # "\n\t\t"
-                  # FragD.regstring # ",\n\t\t"
-                  # FragA.regstring # ",\n\t\t"
-                  # FragB.regstring # ",\n\t\t"
-                  # FragC.regstring # ";";
+  let AsmString = !if(!eq(FragA.geom, "m8n8k4"),
+                  "mma.sync.aligned.m8n8k4"
+                  # "." # ALayout
+                  # "." # BLayout
+                  # TypeList # "\n\t\t"
+                  # FragD.regstring # ",\n\t\t"
+                  # FragA.regstring # ",\n\t\t"
+                  # FragB.regstring # ",\n\t\t"
+                  # FragC.regstring # ";",
+                  "wmma.mma"
+                  # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "")
+                  # ".sync"
+                  # "${ptx:aligned}"
+                  # "." # ALayout
+                  # "." # BLayout
+                  # "." # FragA.geom
+                  # TypeList
+                  # !if(Satfinite, ".satfinite", "") # "\n\t\t"
+                  # FragD.regstring # ",\n\t\t"
+                  # FragA.regstring # ",\n\t\t"
+                  # FragB.regstring # ",\n\t\t"
+                  # FragC.regstring # ";");
 }
 
 defset list<WMMA_INSTR> MMAs = {
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -4,39 +4,47 @@
 # Check all variants of instructions supported by PTX60 on SM70
 # RUN: python %s --ptx=60 --gpu-arch=70 > %t-ptx60-sm_70.ll
 # RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
-# RUN:           --check-prefixes=INTRINSICS,PTX60,SM70
+# RUN:           --check-prefixes=INTRINSICS,M16N16
 # RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
-# RUN:           --check-prefixes=INTRINSICS,PTX60U,SM70U
+# RUN:           --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA
 # RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \
 # RUN:           | FileCheck %t-ptx60-sm_70.ll
 
 # Check all variants of instructions supported by PTX61 on SM70
 # RUN: python %s --ptx=61 --gpu-arch=70 > %t-ptx61-sm_70.ll
 # RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
-# RUN:           --check-prefixes=INTRINSICS,PTX60,PTX61,SM70
+# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM
 # RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
-# RUN:           --check-prefixes=INTRINSICS,PTX61U,SM70U
+# RUN:           --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA
 # RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \
 # RUN:           | FileCheck %t-ptx61-sm_70.ll
 
 # Check all variants of instructions supported by PTX63 on SM72
 # RUN: python %s --ptx=63 --gpu-arch=72 > %t-ptx63-sm_72.ll
 # RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
-# RUN:           --check-prefixes=INTRINSICS,PTX60,PTX61,PTX63,SM70,SM72
+# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT
 # RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
-# RUN:           --check-prefixes=INTRINSICS,PTX63U,SM72U
+# RUN:           --check-prefixes=INTRINSICS,NOSUBINT,NOMMA
 # RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \
 # RUN:           | FileCheck %t-ptx63-sm_72.ll
 
 # Check all variants of instructions supported by PTX63 on SM75
 # RUN: python %s --ptx=63 --gpu-arch=75 > %t-ptx63-sm_75.ll
 # RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
-# RUN:           --check-prefixes=INTRINSICS,PTX60,PTX61,PTX63,SM70,SM72,SM75
+# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT
 # RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
-# RUN:           --check-prefixes=INTRINSICS,PTX63U,SM75U
+# RUN:           --check-prefixes=INTRINSICS,NOMMA
 # RUN: llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \
 # RUN:           | FileCheck %t-ptx63-sm_75.ll
 
+# Check all variants of instructions supported by PTX64 on SM70+
+# RUN: python %s --ptx=64 --gpu-arch=70 > %t-ptx64-sm_70.ll
+# RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
+# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,MMA
+# RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
+# RUN:           --check-prefixes=INTRINSICS,NOINT,NOSUBINT
+# RUN: llc < %t-ptx64-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 \
+# RUN:           | FileCheck %t-ptx64-sm_70.ll
 
 from __future__ import print_function
 
@@ -70,10 +78,11 @@
   def __init__(self, geom, frag, ptx_elt_type):
     self.geom = geom
     self.frag = frag
+    self.is_mma = True if geom == "m8n8k4" else False;
     self.mma_type = MMAType(ptx_elt_type);
     self.nregs = {
-        "a:f16" : 8,
-        "b:f16" : 8,
+        "a:f16" : 2 if self.is_mma else 8,
+        "b:f16" : 2 if self.is_mma else 8,
         "c:f16" : 4,
         "d:f16" : 4,
         "c:f32" : 8,
@@ -145,7 +154,9 @@
           in product(geoms, frags, types)]
 
 def get_mma_ops():
-  return (make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
+  return (make_mma_ops(["m8n8k4"],
+                       ["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
+          make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                        ["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
           make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                        ["s8", "u8"], [], ["s32"], []) +
@@ -165,6 +176,8 @@
 
 def is_geom_supported(geom):
   # geometries for FP and ints.
+  if geom == "m8n8k4":
+    return ptx_version >= 64
   if geom in ["m8n32k16", "m32n8k16"]:
     return ptx_version >= 61
   # geometries for sub-ints.
@@ -186,6 +199,13 @@
   if not (is_type_supported(op.a.mma_type.ptx_type)
           and is_geom_supported(op.a.geom)):
     return False
+  if op.a.geom == "m8n8k4":
+    if satf:
+      return False
+    if op.c.mma_type.ptx_type == "f32":
+      # If C is f32, D must be, too.
+      return op.d.mma_type.ptx_type == "f32"
+
   # sub-integer require row/col layout, and no satf.
   if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
     if op.a.mma_type.ptx_type == "b1" and satf:
@@ -232,8 +252,6 @@
 def check_pattern(frag):
   return "{{%s}}" % ", *".join([frag.mma_type.ptx_reg_pattern] * frag.nregs)
 
-known_geoms = ["m16n16k16", "m8n32k16", "m32n8k16"]
-
 def gen_wmma_load_tests():
   load_template = """
 declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
@@ -389,6 +407,8 @@
   if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
     # int and sub-int instructions encode all four types as D.A.B.C
     return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
+  if op.a.geom == "m8n8k4":
+    return "%s.f16.f16.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
   else:
     # the rest are FP instructions use D.C
     return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
@@ -411,8 +431,10 @@
   ret ${ret_ty} %r;
 }
 """
-  intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}"
-  instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}.${ptx_signature}${satf}"
+  wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}"
+  wmma_instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}.${ptx_signature}${satf}"
+  mma_intrinsic_template = "llvm.nvvm.mma.${geom}.${alayout}.${blayout}.${intrinsic_signature}"
+  mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}.${ptx_signature}"
 
   generated_items=[]
 
@@ -436,6 +458,13 @@
         "mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
     }
 
+    if op.a.geom == "m8n8k4":
+      intrinsic_template = mma_intrinsic_template
+      instruction_template = mma_instruction_template
+    else:
+      intrinsic_template = wmma_intrinsic_template
+      instruction_template = wmma_instruction_template
+
     test_params = params
     test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
     test_params["function"] = test_params["intrinsic"].replace(".", "_")
@@ -458,55 +487,68 @@
 # Generate set of checks to verify that that we did generate sensible set of
 # tests for the given combination of PTX and SM variants.
 #
-# PTX: verifies that we did generate tests for correct classes of intrinsics.
-# PTXU: verifies that we did not generate intrinsics unsupported by
-#       the PTX version.
-# SM: verifies that we did generate correct classes of instructions for the SM.
-# SMU: verifies that we did not generate instructions unsupported by the SM
-#
-# Note that SM/PTX constraints overlap, but DAG checks do not allow overlapping
-# matches. We implicitly rely that we generate multiple variants of most of the
-# instructions and usually have enough input data to find more than one match of
-# the same kind, if necessary. When it's not possible (e.g. there's only one
-# m8n8k128.mma.row.col.b1), we may need to match PTX instruction instead.
 def gen_check_unsupported_ops(items):
   print("; Complete list of intrinsics supported by PTX%d on sm_%d"
         % (ptx_version, gpu_arch))
   print("; INTRINSICS: {{^; INTRINSICS_LIST_BEGIN}}")
   print("""
-; PTX60-DAG: m16n16k16.load.{{[ab].*}}.f16.p
-; PTX60-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
-; PTX60U-NOT: m32n8k16
-; PTX60U-NOT: m8n32k16
-; PTX60U-NOT: .{{s32|s[48]|u[48]|b1}}
-
-; All features of PTX60, plus m32n8k16/m8n32k16 geometries.
-; PTX61-DAG: m32n8k16.load.{{[ab].*}}.f16.p
-; PTX61-DAG: m32n8k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
-; PTX61-DAG: m8n32k16.load.{{[ab].*}}.f16.p
-; PTX61-DAG: m8n32k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
-; PTX61U-NOT: .{{s32|s[48]|u[48]|b1}}
-
-; SM70U-NOT: .{{s32|s[48]|u[48]|b1}}
-
-; PTX63 supports all features of PTX60+PTX61, plus support for integers.
-; Alas we can"t just use PTX checks for that as available instructions
-; depend on SM integers need sm72+ and subinteger ops need sm75, so we
-; transition to SM checks
-; SM72-DAG: m16n16k16.load.{{[ab].*}}.s8.p
-; SM72-DAG: m8n32k16.load.{{[ab].*}}.s8.p
-; SM72-DAG: m32n8k16.load.{{[ab].*}}.s8.p
-; SM72-DAG: m16n16k16.load.{{[ab].*}}.u8.p
-; SM72-DAG: m8n32k16.load.{{[ab].*}}.u8.p
-; SM72-DAG: m32n8k16.load.{{[ab].*}}.u8.p
-; SM72-DAG: m32n8k16.{{load|store}}.{{[cd].*\.s32}}.p
-; SM72U-NOT: .{{s4|u4|b1}}
-
-; SM75-DAG: m8n8k128.load.{{[ab].*}}.b1.p
-; SM75-DAG: m8n8k32.load.{{[ab].*}}.s4.p
-; SM75-DAG: m8n8k32.load.{{[ab].*}}.u4.p
-; SM75-DAG: m8n8k128.{{load|store}}.{{[cd].*\.s32}}.p
-; SM75-DAG: m8n8k32.{{load|store}}.{{[cd].*\.s32}}.p
+
+; NOEXTGEOM-NOT: {{m8n32|m32n8}}
+; NOINT-NOT: .{{s32|s8}}
+; NOSUBINT-NOT: {{s4|u4|b1}}
+; NOMMA-NOT: .m8n8k4.
+
+; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
+; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
+; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f32
+; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f16
+; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f16
+; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f32
+
+; PTX61 adds support for m32n8k16/m8n32k16 geometries.
+; EXTGEOM-DAG: m32n8k16.load.{{[ab].*}}.f16.p
+; EXTGEOM-DAG: m32n8k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
+; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f32
+; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f16
+; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f16
+; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f32
+
+; EXTGEOM-DAG: m8n32k16.load.{{[ab].*}}.f16.p
+; EXTGEOM-DAG: m8n32k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
+; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f32
+; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f16
+; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f16
+; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f32
+
+; INT-DAG: m16n16k16.load.{{[ab].*}}.s8.p
+; INT-DAG: m8n32k16.load.{{[ab].*}}.s8.p
+; INT-DAG: m32n8k16.load.{{[ab].*}}.s8.p
+; INT-DAG: m16n16k16.load.{{[ab].*}}.u8.p
+; INT-DAG: m8n32k16.load.{{[ab].*}}.u8.p
+; INT-DAG: m32n8k16.load.{{[ab].*}}.u8.p
+; INT-DAG: m32n8k16.{{load|store}}.{{[cd].*\.s32}}.p
+; INT-DAG: m16n16k16.mma.{{.*}}.u8
+; INT-DAG: m16n16k16.mma.{{.*}}.s8
+; INT-DAG: m8n32k16.mma.{{.*}}.u8
+; INT-DAG: m8n32k16.mma.{{.*}}.s8
+; INT-DAG: m32n8k16.mma.{{.*}}.u8
+; INT-DAG: m32n8k16.mma.{{.*}}.s8
+
+; SUBINT-DAG: m8n8k128.load.{{[ab].*}}.b1.p
+; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.s4.p
+; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.u4.p
+; SUBINT-DAG: m8n8k128.{{load|store}}.{{[cd].*\.s32}}.p
+; SUBINT-DAG: m8n8k32.{{load|store}}.{{[cd].*\.s32}}.p
+; SUBINT-DAG: m8n8k32.mma.{{.*}}.u4
+; SUBINT-DAG: m8n8k32.mma.{{.*}}.s4
+; SUBINT-DAG: m8n8k128.mma.{{.*}}.b1
+
+; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f32
+; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f16
+; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f16
+; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f32
+;
+
 """)
 
   print("; INTRINSICS_LIST_BEGIN")
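For reference, here is a hand-written sketch (not part of the patch) of the kind of test case wmma.py would now emit for one m8n8k4 variant, following the intrinsic/instruction templates and fragment sizes defined above; the function and value names are illustrative only. The A and B fragments are two <2 x half> registers each, an f32 C/D fragment is eight f32 registers, and the expected PTX for the f32/f32 variant is mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32.

; Illustrative sketch only -- names and formatting are not copied from the
; generated test file.
declare {float, float, float, float, float, float, float, float}
  @llvm.nvvm.mma.m8n8k4.row.col.f32.f32(
    <2 x half>, <2 x half>,          ; A fragment: 2 registers
    <2 x half>, <2 x half>,          ; B fragment: 2 registers
    float, float, float, float,      ; C fragment: 8 registers
    float, float, float, float)

; CHECK-LABEL: test_mma_m8n8k4_row_col_f32_f32
; CHECK: mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32
define {float, float, float, float, float, float, float, float}
  @test_mma_m8n8k4_row_col_f32_f32(
    <2 x half> %a0, <2 x half> %a1, <2 x half> %b0, <2 x half> %b1,
    float %c0, float %c1, float %c2, float %c3,
    float %c4, float %c5, float %c6, float %c7) {
  %r = call {float, float, float, float, float, float, float, float}
    @llvm.nvvm.mma.m8n8k4.row.col.f32.f32(
      <2 x half> %a0, <2 x half> %a1, <2 x half> %b0, <2 x half> %b1,
      float %c0, float %c1, float %c2, float %c3,
      float %c4, float %c5, float %c6, float %c7)
  ret {float, float, float, float, float, float, float, float} %r
}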