diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.def b/clang/include/clang/Basic/BuiltinsNVPTX.def
--- a/clang/include/clang/Basic/BuiltinsNVPTX.def
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -724,6 +724,7 @@
 TARGET_BUILTIN(__bmma_m8n8k128_ld_a_b1, "vi*iC*UiIi", "", AND(SM_75,PTX63))
 TARGET_BUILTIN(__bmma_m8n8k128_ld_b_b1, "vi*iC*UiIi", "", AND(SM_75,PTX63))
 TARGET_BUILTIN(__bmma_m8n8k128_ld_c, "vi*iC*UiIi", "", AND(SM_75,PTX63))
+TARGET_BUILTIN(__bmma_m8n8k128_mma_and_popc_b1, "vi*iC*iC*iC*Ii", "", AND(SM_75,PTX71))
 TARGET_BUILTIN(__bmma_m8n8k128_mma_xor_popc_b1, "vi*iC*iC*iC*Ii", "", AND(SM_75,PTX63))
 TARGET_BUILTIN(__bmma_m8n8k128_st_c_i32, "vi*iC*UiIi", "", AND(SM_75,PTX63))
 TARGET_BUILTIN(__imma_m16n16k16_ld_a_s8, "vi*iC*UiIi", "", AND(SM_72,PTX63))
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -16630,9 +16630,18 @@
     0, \
     0  // b1 MMA does not support .satfinite.
-#define MMA_VARIANTS_B1(geom, type) \
+#define MMA_VARIANTS_B1_XOR(geom, type) \
     0, \
-    Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type, \
+    Intrinsic::nvvm_wmma_##geom##_mma_xor_popc_row_col_##type, \
+    0, \
+    0, \
+    0, \
+    0, \
+    0, \
+    0
+#define MMA_VARIANTS_B1_AND(geom, type) \
+    0, \
+    Intrinsic::nvvm_wmma_##geom##_mma_and_popc_row_col_##type, \
     0, \
     0, \
     0, \
@@ -16689,7 +16698,9 @@
   case NVPTX::BI__imma_m8n8k32_mma_u4:
     return {1, 1, 2, 2, {{MMA_VARIANTS_I4(m8n8k32, u4)}}};
   case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1:
-    return {1, 1, 2, 2, {{MMA_VARIANTS_B1(m8n8k128, b1)}}};
+    return {1, 1, 2, 2, {{MMA_VARIANTS_B1_XOR(m8n8k128, b1)}}};
+  case NVPTX::BI__bmma_m8n8k128_mma_and_popc_b1:
+    return {1, 1, 2, 2, {{MMA_VARIANTS_B1_AND(m8n8k128, b1)}}};

   // Double MMA
   case NVPTX::BI__dmma_m8n8k4_mma_f64:
@@ -16710,7 +16721,8 @@
 #undef MMA_VARIANTS
 #undef MMA_SATF_VARIANTS
 #undef MMA_VARIANTS_I4
-#undef MMA_VARIANTS_B1
+#undef MMA_VARIANTS_B1_AND
+#undef MMA_VARIANTS_B1_XOR
 }
 } // namespace
@@ -17119,6 +17131,7 @@
   case NVPTX::BI__imma_m8n8k32_mma_s4:
   case NVPTX::BI__imma_m8n8k32_mma_u4:
   case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1:
+  case NVPTX::BI__bmma_m8n8k128_mma_and_popc_b1:
   case NVPTX::BI__dmma_m8n8k4_mma_f64:
   case NVPTX::BI__mma_bf16_m16n16k16_mma_f32:
   case NVPTX::BI__mma_bf16_m8n32k16_mma_f32:
@@ -17136,7 +17149,8 @@
     if (Layout < 0 || Layout > 3)
       return nullptr;
     llvm::APSInt SatfArg;
-    if (BuiltinID == NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1)
+    if (BuiltinID == NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1 ||
+        BuiltinID == NVPTX::BI__bmma_m8n8k128_mma_and_popc_b1)
       SatfArg = 0;  // .b1 does not have satf argument.
     else if (Optional<llvm::APSInt> OptSatfArg = E->getArg(5)->getIntegerConstantExpr(getContext()))
diff --git a/clang/test/CodeGen/builtins-nvptx-mma.cu b/clang/test/CodeGen/builtins-nvptx-mma.cu
--- a/clang/test/CodeGen/builtins-nvptx-mma.cu
+++ b/clang/test/CodeGen/builtins-nvptx-mma.cu
@@ -3,20 +3,21 @@
 // *** DO NOT EDIT ***
 //
 // This test has been automatically generated by
-// builtins-nvtx-mma.py --ptx=70 --gpu-arch=80
+// builtins-nvtx-mma.py --ptx=71 --gpu-arch=80
 //
-// Make sure we can handle all builtins available on sm_80 with PTX70
+// Make sure we can handle all builtins available on sm_80 with PTX71
 // RUN: %clang_cc1 -triple nvptx64-unknown-unknown -target-cpu sm_80 \
-// RUN:            -fcuda-is-device -target-feature +ptx70 \
-// RUN:            -DPTX=70 -DSM=80 \
+// RUN:            -fcuda-is-device -target-feature +ptx71 \
+// RUN:            -DPTX=71 -DSM=80 \
 // RUN:            -S -emit-llvm -o - -x cuda %s \
-// RUN:   | FileCheck -check-prefixes=CHECK_PTX70_SM80,CHECK_PTX60_SM70,CHECK_PTX63_SM72,CHECK_PTX61_SM70,CHECK_PTX63_SM75 %s
+// RUN:   | FileCheck -check-prefixes=CHECK_PTX70_SM80,CHECK_PTX60_SM70,CHECK_PTX63_SM72,CHECK_PTX61_SM70,CHECK_PTX63_SM75,CHECK_PTX71_SM75 %s
 // Verify that all builtins have correct constraints.
 // RUN: %clang_cc1 -triple nvptx-unknown-unknown \
 // RUN:   -target-cpu sm_60 -target-feature +ptx42 \
-// RUN:   -DPTX=70 -DSM=80 -fcuda-is-device -S -o /dev/null -x cuda \
+// RUN:   -DPTX=71 -DSM=80 -fcuda-is-device -S -o /dev/null -x cuda \
 // RUN:   -verify %s
+
 #if !defined(CUDA_VERSION)
 #define __device__ __attribute__((device))
 #define __global__ __attribute__((global))
@@ -31,6 +32,7 @@
                                  float *fsrc, float *fdst,
                                  double *dsrc, double *ddst, int ldm) {
+
 #if (PTX >= 60) && (SM >= 70)
   // CHECK_PTX60_SM70: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.f16
@@ -735,7 +737,7 @@
   // CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k32.store.d.row.stride.s32
   // expected-error-re@+1 {{'__imma_m8n8k32_st_c_i32' needs target feature (sm_75{{.*}},(ptx63{{.*}}}}
   __imma_m8n8k32_st_c_i32(dst, src, ldm, 0);
-  // CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k128.mma.row.col.b1
+  // CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k128.mma.xor.popc.row.col.b1
   // expected-error-re@+1 {{'__bmma_m8n8k128_mma_xor_popc_b1' needs target feature (sm_75{{.*}},(ptx63{{.*}}}}
   __bmma_m8n8k128_mma_xor_popc_b1(dst, src, src, src, 1);
   // CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k32.mma.row.col.s4
@@ -750,7 +752,7 @@
   // CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k32.mma.row.col.u4.satfinite
   // expected-error-re@+1 {{'__imma_m8n8k32_mma_u4' needs target feature (sm_75{{.*}},(ptx63{{.*}}}}
   __imma_m8n8k32_mma_u4(dst, src, src, src, 1, 1);
-#endif // (PTX >= 63) && (SM >= 75)
+#endif // (PTX >= 63) && (SM >= 75)

 #if (PTX >= 70) && (SM >= 80)
@@ -898,5 +900,12 @@
   // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64
   // expected-error-re@+1 {{'__dmma_m8n8k4_mma_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
   __dmma_m8n8k4_mma_f64(ddst, dsrc, dsrc, dsrc, 0, 0);
-#endif // (PTX >= 70) && (SM >= 80)
+#endif // (PTX >= 70) && (SM >= 80)
+
+#if (PTX >= 71) && (SM >= 75)
+
+  // CHECK_PTX71_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k128.mma.and.popc.row.col.b1
+  // expected-error-re@+1 {{'__bmma_m8n8k128_mma_and_popc_b1' needs target feature (sm_75{{.*}},(ptx71{{.*}}}}
+  __bmma_m8n8k128_mma_and_popc_b1(dst, src, src, src, 1);
+#endif // (PTX >= 71) && (SM >= 75)
 }
diff --git a/clang/test/CodeGen/builtins-nvptx-mma.py b/clang/test/CodeGen/builtins-nvptx-mma.py
--- a/clang/test/CodeGen/builtins-nvptx-mma.py
+++ b/clang/test/CodeGen/builtins-nvptx-mma.py
@@ -22,24 +22,29 @@
     return "%s:%s:%s" % (self.geom, self.frag, self.ptx_type)

 class MMAOp:
-  def __init__(self, a, b, c, d):
+  def __init__(self, a, b, c, d, b1op=""):
     self.a = a
     self.b = b
     self.c = c
     self.d = d
+    self.b1op = b1op

   def __repr__(self):
     return ("{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d ))

-def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
+def make_mma_ops(geoms, types_a, types_b, types_c, types_d, b1ops=None):
   ops = []
+  if b1ops is None:
+    b1ops = [""]
   for geom, type_a, type_c in product( geoms, types_a, types_c):
     for type_b, type_d in product(types_b if types_b else [type_a],
                                   types_d if types_d else [type_c]):
-      ops.append(MMAOp(MMAFrag(geom, "a", type_a),
-                       MMAFrag(geom, "b", type_b),
-                       MMAFrag(geom, "c", type_c),
-                       MMAFrag(geom, "d", type_d)))
+      ops += [
+          MMAOp(MMAFrag(geom, "a", type_a),
+                MMAFrag(geom, "b", type_b),
+                MMAFrag(geom, "c", type_c),
+                MMAFrag(geom, "d", type_d), b1op)
+          for b1op in b1ops]
   return ops

 def make_ldst_ops(geoms, frags, types):
@@ -60,9 +65,12 @@
           make_mma_ops(["m8n8k32"],
                        ["s4", "u4"], [], ["s32"], []) +
           make_mma_ops(["m8n8k128"],
-                       ["b1"], [], ["s32"], []))
+                       ["b1"], [], ["s32"], [],
+                       [".xor.popc", ".and.popc"]))

 def get_ldst_ops():
+  # NOTE: fragments are from the point of view of PTX.
+  # fragment `d` is only for store ops; the others are used for both loads and stores.
   return (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                         ["a", "b"], ["f16", "u8", "s8", "bf16"]) +
           make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
@@ -71,8 +79,11 @@
           make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) +
           make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]) +
           make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"]) +
-          make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"]) +
-          make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"]))
+          # TF32 m16n16k8 is odd.
+          # For fragment 'C' it uses __mma_*tf32*_m16n16k8_ld_c
+          # but 'D' calls __mma_m16n16k8_st_c_*f32*.
+          make_ldst_ops(["m16n16k8"], ["a", "b", "c"], ["tf32"]) +
+          make_ldst_ops(["m16n16k8"], ["d"], ["f32"]))

 def is_geom_supported(geom):
   # geometries for FP and ints.
@@ -180,15 +191,19 @@
   else:
     suffix = op.a.ptx_type

-  name = "%s_%s_mma%s_%s" % (prefix, op.a.geom,
-                             "_xor_popc" if op.a.ptx_type == "b1" else "",
-                             suffix)
+  name = "{prefix}_{geom}_mma{b1op}_{suffix}".format(
+      prefix = prefix,
+      geom = op.a.geom,
+      b1op = op.b1op.replace(".","_"),
+      suffix = suffix)
   return name

-def get_required_sm(frag):
+def get_required_sm(frag, b1op=""):
   if frag.ptx_type in ["f64", "bf16", "tf32"]:
     return 80
   if frag.ptx_type in ["u4", "s4", "b1"]:
+    if b1op == "_and_popc":
+      return 80
     return 75
   if frag.ptx_type in ["s8", "u8"]:
     return 72
@@ -204,7 +219,9 @@
     return 70
   assert(False)

-def get_required_ptx(frag):
+def get_required_ptx(frag, b1op=""):
+  if frag.ptx_type == "b1" and b1op == ".and.popc":
+    return 71
   if frag.ptx_type in ["f64", "bf16", "tf32"]:
     return 70
   if frag.ptx_type in ["f16", "f32"]:
@@ -215,11 +232,13 @@
     return 61
   return 63

-def get_src_dst_prefix(ptx_type):
-  if ptx_type == "f32":
+def get_src_dst_prefix(frag):
+  if frag.ptx_type == "f32":
     return "f"
-  if ptx_type == "f64":
+  if frag.ptx_type == "f64":
     return "d"
+  if frag.ptx_type == "tf32" and frag.frag in ["c", "d"]:
+    return "f"
   return ""

 def gen_wmma_ldst_tests():
@@ -235,9 +254,17 @@
     if not is_ldst_variant_supported(frag, layout):
       continue

-    src_dst_prefix = get_src_dst_prefix(frag.ptx_type)
+    src_dst_prefix = get_src_dst_prefix(frag)
+
     min_sm = get_required_sm(frag)
     min_ptx = get_required_ptx(frag)
+    # TF32 uses f32 for accumulator loads.
+    if frag.geom == "m16n16k8" and frag.frag =="c":
+      assert frag.ptx_type == "tf32"
+      itype = "f32"
+    else:
+      itype = frag.ptx_type
+
     params = {
         "check_suffix" : "_PTX%d_SM%d" % (min_ptx, min_sm),
         "builtin" : get_ldst_builtin_name(frag),
@@ -250,7 +277,7 @@
           "frag" : frag.frag,
           "geom" : frag.geom,
           "ilayout" : layout,
-          "itype" : frag.ptx_type,
+          "itype" : itype,
           "op" : "store" if frag.frag == "d" else "load",
         })
     }
@@ -283,7 +310,7 @@
 // expected-error-re@+1 {{'${builtin}' needs target feature (sm_${min_sm}{{.*}},(ptx${min_ptx}{{.*}}}}
   ${builtin}(${dst}, ${asrc}, ${asrc}, ${csrc}, ${ilayout}${maybe_satf});
 """.rstrip()
-  intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}"
+  intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}.${intrinsic_signature}${satf}"

   for op, alayout, blayout, satf in sorted(product( get_mma_ops(),
                                                     ["row","col"],
@@ -294,15 +321,15 @@
     if not is_mma_variant_supported(op, alayout, blayout, satf):
       continue

-    asrc_prefix = get_src_dst_prefix(op.a.ptx_type)
-    csrc_prefix = get_src_dst_prefix(op.c.ptx_type)
-    ddst_prefix = get_src_dst_prefix(op.d.ptx_type)
-    min_sm = get_required_sm(op.a)
-    min_ptx = get_required_ptx(op.a)
+    asrc_prefix = get_src_dst_prefix(op.a)
+    csrc_prefix = get_src_dst_prefix(op.c)
+    ddst_prefix = get_src_dst_prefix(op.d)
     if op.a.ptx_type == "b1": # .b1 MMA has no satf argument.
       isatf_arg = ""
     else:
       isatf_arg = ", 1" if satf else ", 0"
+    min_sm = get_required_sm(op.a, op.b1op)
+    min_ptx = get_required_ptx(op.a, op.b1op)
     params = {
         "check_suffix" : "_PTX%d_SM%d" % (min_ptx, min_sm),
         "builtin" : get_mma_builtin_name(op),
@@ -319,6 +346,7 @@
           "blayout" : blayout,
           "intrinsic_signature" : mma_signature(op),
           "satf" : satf,
+          "b1op" : op.b1op
         })
     }
     results[(min_ptx, min_sm)] += Template(mma_template).substitute(params)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -225,12 +225,13 @@
   string ret = !foldl("", id_frags, a, b, !strconcat(a, ".", b.ptx_elt_type));
 }

-class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd,
+class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd, string b1op,
                 WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
   string signature = MMA_SIGNATURE<A, B, C, D>.ret;
   string llvm = "llvm.nvvm.wmma."
                 # A.geom
                 # ".mma"
+                # b1op
                 # "." # ALayout
                 # "." # BLayout
                 # !if(!ne(Rnd, ""), !strconcat(".", Rnd), "")
@@ -241,11 +242,12 @@
                   !subst("llvm.", "int_", llvm));
 }

-class MMA_NAME<string ALayout, string BLayout, int Satfinite,
+class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op,
               WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
   string signature = MMA_SIGNATURE<A, B, C, D>.ret;
-  string llvm = "llvm.nvvm.mma."
-                # A.geom
+  string llvm = "llvm.nvvm.mma"
+                # b1op
+                # "." # A.geom
                 # "." # ALayout
                 # "." # BLayout
                 # !if(Satfinite, ".satfinite", "")
@@ -430,6 +432,13 @@
 }

+class NVVM_MMA_B1OPS<list<WMMA_REGS> frags> {
+  list<string> ret = !cond(
+    !eq(frags[0].ptx_elt_type, "b1") : [".xor.popc", ".and.popc"],
+    true: [""]
+  );
+}
+
 // Returns true if this combination of layout/satf for MMA ops is supported;
 // false otherwise.
 // E.g.
@@ -4460,25 +4469,27 @@
 }

 // WMMA.MMA
-class NVVM_WMMA_MMA<string ALayout, string BLayout, int Satfinite, string rnd,
+class NVVM_WMMA_MMA<string ALayout, string BLayout, int Satfinite, string rnd, string b1op,
                     WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
   : Intrinsic<D.regs,
               !listconcat(A.regs, B.regs, C.regs),
               [IntrNoMem],
-              WMMA_NAME<ALayout, BLayout, Satfinite, rnd, A, B, C, D>.llvm>;
+              WMMA_NAME<ALayout, BLayout, Satfinite, rnd, b1op, A, B, C, D>.llvm>;

 foreach layout_a = ["row", "col"] in {
   foreach layout_b = ["row", "col"] in {
     foreach satf = [0, 1] in {
       foreach rnd = ["", "rn", "rz", "rm", "rp"] in {
         foreach op = NVVM_MMA_OPS.all_wmma_ops in {
-          if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, rnd, satf>.ret then {
-            def WMMA_NAME<layout_a, layout_b, satf, rnd,
-                          op[0], op[1], op[2], op[3]>.record
-              : NVVM_WMMA_MMA<layout_a, layout_b, satf, rnd,
-                              op[0], op[1], op[2], op[3]>;
-          }
+          foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
+            if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, rnd, satf>.ret then {
+              def WMMA_NAME<layout_a, layout_b, satf, rnd, b1op,
+                            op[0], op[1], op[2], op[3]>.record
+                : NVVM_WMMA_MMA<layout_a, layout_b, satf, rnd, b1op,
+                                op[0], op[1], op[2], op[3]>;
+            }
+          } // b1op
         } // op
       } // rnd
     } // satf
@@ -4486,21 +4497,23 @@
 } // layout_a

 // MMA
-class NVVM_MMA<string ALayout, string BLayout, int Satfinite,
+class NVVM_MMA<string ALayout, string BLayout, int Satfinite, string b1op,
               WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
   : Intrinsic<D.regs,
               !listconcat(A.regs, B.regs, C.regs),
               [IntrNoMem],
-              MMA_NAME<ALayout, BLayout, Satfinite, A, B, C, D>.llvm>;
+              MMA_NAME<ALayout, BLayout, Satfinite, b1op, A, B, C, D>.llvm>;

 foreach layout_a = ["row", "col"] in {
   foreach layout_b = ["row", "col"] in {
     foreach satf = [0, 1] in {
       foreach op = NVVM_MMA_OPS.all_mma_ops in {
-        if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
-          def MMA_NAME<layout_a, layout_b, satf,
-                       op[0], op[1], op[2], op[3]>.record
-            : NVVM_MMA<layout_a, layout_b, satf,
-                       op[0], op[1], op[2], op[3]>;
-        }
+        foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
+          if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
+            def MMA_NAME<layout_a, layout_b, satf, b1op,
+                         op[0], op[1], op[2], op[3]>.record
+              : NVVM_MMA<layout_a, layout_b, satf, b1op,
+                         op[0], op[1], op[2], op[3]>;
+          }
+        } // b1op
       } // op
     } // satf
   } // layout_b
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -146,6 +146,7 @@
 def hasPTX64 : Predicate<"Subtarget->getPTXVersion() >= 64">;
 def hasPTX65 : Predicate<"Subtarget->getPTXVersion() >= 65">;
 def hasPTX70 : Predicate<"Subtarget->getPTXVersion() >= 70">;
+def hasPTX71 : Predicate<"Subtarget->getPTXVersion() >= 71">;

 def hasSM30 : Predicate<"Subtarget->getSmVersion() >= 30">;
 def hasSM70 : Predicate<"Subtarget->getSmVersion() >= 70">;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -7796,15 +7796,24 @@
 } // layout
 } // defset

+// B1 instruction variants need extra constraints.
+class MMA_OP_PREDICATES<WMMA_REGINFO FragA, string b1op> {
+  string Op = b1op;
+  WMMA_REGINFO Frag = FragA;
+  list<Predicate> ret = !listconcat(
+    FragA.Predicates,
+    !if(!eq(b1op, ".and.popc"), [hasSM80,hasPTX71],[])
+  );
+}
 // WMMA.MMA
 class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                WMMA_REGINFO FragC, WMMA_REGINFO FragD,
-               string ALayout, string BLayout, int Satfinite, string rnd>
-  : WMMA_INSTR<WMMA_NAME<ALayout, BLayout, Satfinite, rnd, FragA, FragB, FragC, FragD>.record,
+               string ALayout, string BLayout, int Satfinite, string rnd, string b1op>
+  : WMMA_INSTR<WMMA_NAME<ALayout, BLayout, Satfinite, rnd, b1op, FragA, FragB, FragC, FragD>.record,
                [FragA.Ins, FragB.Ins, FragC.Ins]>,
     // Requires does not seem to have effect on Instruction w/o Patterns.
     // We set it here anyways and propagate to the Pat<> we construct below.
-    Requires<FragA.Predicates> {
+    Requires<MMA_OP_PREDICATES<FragA, b1op>.ret> {
   let OutOperandList = FragD.Outs;
   let InOperandList = !con(Args, (ins MmaCode:$ptx));
   string TypeList = !cond(
@@ -7816,7 +7825,7 @@
        # "." # FragC.ptx_elt_type,
   );
   let AsmString = "wmma.mma"
-                  # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "")
+                  # b1op
                   # ".sync"
                   # "${ptx:aligned}"
                   # "." # ALayout
@@ -7837,13 +7846,15 @@
     foreach satf = [0, 1] in {
       foreach rnd = ["", "rn", "rz", "rm", "rp"] in {
         foreach op = NVVM_MMA_OPS.all_wmma_ops in {
-          if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, rnd, satf>.ret then {
-            def : WMMA_MMA<WMMA_REGINFO<op[0], "wmma.mma">,
-                           WMMA_REGINFO<op[1], "wmma.mma">,
-                           WMMA_REGINFO<op[2], "wmma.mma">,
-                           WMMA_REGINFO<op[3], "wmma.mma">,
-                           layout_a, layout_b, satf, rnd>;
-          }
+          foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
+            if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, rnd, satf>.ret then {
+              def : WMMA_MMA<WMMA_REGINFO<op[0], "wmma.mma">,
+                             WMMA_REGINFO<op[1], "wmma.mma">,
+                             WMMA_REGINFO<op[2], "wmma.mma">,
+                             WMMA_REGINFO<op[3], "wmma.mma">,
+                             layout_a, layout_b, satf, rnd, b1op>;
+            }
+          } // b1op
         } // op
       } // rnd
     } // satf
@@ -7854,12 +7865,12 @@
 // MMA
 class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
           WMMA_REGINFO FragC, WMMA_REGINFO FragD,
-          string ALayout, string BLayout, int Satfinite>
-  : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, FragA, FragB, FragC, FragD>.record,
+          string ALayout, string BLayout, int Satfinite, string b1op>
+  : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, FragA, FragB, FragC, FragD>.record,
                [FragA.Ins, FragB.Ins, FragC.Ins]>,
     // Requires does not seem to have effect on Instruction w/o Patterns.
     // We set it here anyways and propagate to the Pat<> we construct below.
-    Requires<FragA.Predicates> {
+    Requires<MMA_OP_PREDICATES<FragA, b1op>.ret> {
   let OutOperandList = FragD.Outs;
   let InOperandList = !con(Args, (ins MmaCode:$ptx));
   string TypeList = "." # FragD.ptx_elt_type
@@ -7872,7 +7883,7 @@
                   # "." # BLayout
                   # !if(Satfinite, ".satfinite", "")
                   # TypeList
-                  # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "") # "\n\t\t"
+                  # b1op # "\n\t\t"
                   # FragD.regstring # ",\n\t\t"
                   # FragA.regstring # ",\n\t\t"
                   # FragB.regstring # ",\n\t\t"
@@ -7884,13 +7895,15 @@
 foreach layout_a = ["row", "col"] in {
   foreach layout_b = ["row", "col"] in {
     foreach satf = [0, 1] in {
       foreach op = NVVM_MMA_OPS.all_mma_ops in {
-        if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
-          def : MMA<WMMA_REGINFO<op[0], "mma">,
-                    WMMA_REGINFO<op[1], "mma">,
-                    WMMA_REGINFO<op[2], "mma">,
-                    WMMA_REGINFO<op[3], "mma">,
-                    layout_a, layout_b, satf>;
-        }
+        foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
+          if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
+            def : MMA<WMMA_REGINFO<op[0], "mma">,
+                      WMMA_REGINFO<op[1], "mma">,
+                      WMMA_REGINFO<op[2], "mma">,
+                      WMMA_REGINFO<op[3], "mma">,
+                      layout_a, layout_b, satf, b1op>;
+          }
+        } // b1op
       } // op
     } // satf
   } // layout_b
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -55,14 +55,14 @@
 # RUN: llc < %t-ptx65-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx65 \
 # RUN:           | FileCheck %t-ptx65-sm_75.ll

-# Check all variants of instructions supported by PTX70 on SM80+
-# RUN: %python %s --ptx=70 --gpu-arch=80 > %t-ptx70-sm_80.ll
-# RUN: FileCheck %t-ptx70-sm_80.ll < %t-ptx70-sm_80.ll \
-# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX70MMA
-# RUN: FileCheck %t-ptx70-sm_80.ll < %t-ptx70-sm_80.ll \
+# Check all variants of instructions supported by PTX71 on SM80+
+# RUN: %python %s --ptx=71 --gpu-arch=80 > %t-ptx71-sm_80.ll
+# RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
+# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX71MMA
+# RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
 # RUN:           --check-prefixes=INTRINSICS
-# RUN: llc < %t-ptx70-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 \
-# RUN:           | FileCheck %t-ptx70-sm_80.ll
+# RUN: llc < %t-ptx71-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 \
+# RUN:           | FileCheck %t-ptx71-sm_80.ll

 from __future__ import print_function
@@ -649,9 +649,16 @@
     print(Template(mma_template).substitute(test_params))

   return (test_params["intrinsic"], test_params["instruction"])

+def get_b1_ops(ptx_type):
+  if ptx_type != "b1":
+    return [""]
+  if ptx_version >= 71:
+    return [".xor.popc", ".and.popc"]
+  return [".xor.popc"]
+
 def gen_wmma_mma_tests():
-  wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
-  wmma_instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"
+  wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
+  wmma_instruction_template = "wmma.mma${b1op}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"

   generated_items=[]
@@ -665,29 +672,30 @@
     if not is_wmma_variant_supported(op, alayout, blayout, rnd, satf):
       continue

-    params = {
-      "aligned" : ".aligned" if ptx_version >= 63 else "",
-      "alayout" : alayout,
-      "blayout" : blayout,
-      "intrinsic_signature" : wmma_signature(op),
-      "ptx_signature" : wmma_ptx_signature(op),
-      "satf" : satf,
-      "rnd" : rnd,
-      "geom" : op.a.geom,
-      "mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
-    }
-
-    intrinsic_template = wmma_intrinsic_template
-    instruction_template = wmma_instruction_template
-
-    generated_items.append(common_mma_test_gen(params, op,
-      intrinsic_template, instruction_template))
+    for b1op in get_b1_ops(op.a.mma_type.ptx_type):
+      params = {
+        "aligned" : ".aligned" if ptx_version >= 63 else "",
else "", + "alayout" : alayout, + "blayout" : blayout, + "intrinsic_signature" : wmma_signature(op), + "ptx_signature" : wmma_ptx_signature(op), + "satf" : satf, + "rnd" : rnd, + "geom" : op.a.geom, + "b1op" : b1op + } + + intrinsic_template = wmma_intrinsic_template + instruction_template = wmma_instruction_template + + generated_items.append(common_mma_test_gen(params, op, + intrinsic_template, instruction_template)) return generated_items def gen_mma_tests(): - mma_intrinsic_template = "llvm.nvvm.mma.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}" - mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${mma_variant}" + mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}" + mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${b1op}" generated_items=[] @@ -700,22 +708,23 @@ if not is_mma_variant_supported(op, alayout, blayout, satf): continue - params = { - "aligned" : ".aligned" if ptx_version >= 63 else "", - "alayout" : alayout, - "blayout" : blayout, - "intrinsic_signature" : mma_signature(op), - "ptx_signature" : mma_ptx_signature(op), - "satf" : satf, - "geom" : op.a.geom, - "mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "", - } + for b1op in get_b1_ops(op.a.mma_type.ptx_type): + params = { + "aligned" : ".aligned" if ptx_version >= 63 else "", + "alayout" : alayout, + "blayout" : blayout, + "intrinsic_signature" : mma_signature(op), + "ptx_signature" : mma_ptx_signature(op), + "satf" : satf, + "geom" : op.a.geom, + "b1op" : b1op + } - intrinsic_template = mma_intrinsic_template - instruction_template = mma_instruction_template + intrinsic_template = mma_intrinsic_template + instruction_template = mma_instruction_template - generated_items.append(common_mma_test_gen(params, op, - intrinsic_template, instruction_template)) + generated_items.append(common_mma_test_gen(params, op, + intrinsic_template, instruction_template)) return generated_items @@ -810,32 +819,35 @@ ; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4 ; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4 -; PTX70MMA-DAG: mma.m8n8k4.row.col.f64 -; PTX70MMA-DAG: mma.m16n8k4.row.col.tf32 -; PTX70MMA-DAG: mma.m16n8k8.row.col.tf32 -; PTX70MMA-DAG: mma.m16n8k16.row.col.bf16 -; PTX70MMA-DAG: mma.m16n8k8.row.col.bf16 -; PTX70MMA-DAG: mma.m16n8k16.row.col.f16.f16 -; PTX70MMA-DAG: mma.m16n8k16.row.col.f32.f32 -; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8 -; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8 -; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8 -; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8 -; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8 -; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8 -; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8 -; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8 -; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4 -; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4 -; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4 -; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4 -; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4 -; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4 -; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4 -; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4 -; PTX70MMA-DAG: mma.m8n8k128.row.col.b1 -; PTX70MMA-DAG: mma.m16n8k128.row.col.b1 -; PTX70MMA-DAG: mma.m16n8k256.row.col.b1 +; PTX71MMA-DAG: mma.m8n8k4.row.col.f64 +; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32 +; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32 +; PTX71MMA-DAG: 
mma.m16n8k16.row.col.bf16 +; PTX71MMA-DAG: mma.m16n8k8.row.col.bf16 +; PTX71MMA-DAG: mma.m16n8k16.row.col.f16.f16 +; PTX71MMA-DAG: mma.m16n8k16.row.col.f32.f32 +; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8 +; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8 +; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8 +; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8 +; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8 +; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8 +; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8 +; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8 +; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4 +; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4 +; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4 +; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4 +; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4 +; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4 +; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4 +; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4 +; PTX71MMA-DAG: mma.and.popc.m8n8k128.row.col.b1 +; PTX71MMA-DAG: mma.xor.popc.m8n8k128.row.col.b1 +; PTX71MMA-DAG: mma.and.popc.m16n8k128.row.col.b1 +; PTX71MMA-DAG: mma.xor.popc.m16n8k128.row.col.b1 +; PTX71MMA-DAG: mma.and.popc.m16n8k256.row.col.b1 +; PTX71MMA-DAG: mma.xor.popc.m16n8k256.row.col.b1 ; """)
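
Usage note (not part of the patch): a minimal sketch of how the new builtin would be called from CUDA device code once this change is applied, assuming the TU is compiled for at least sm_75 with -target-feature +ptx71 (per the AND(SM_75,PTX71) constraint above). The call shape follows the existing __bmma_m8n8k128_mma_xor_popc_b1 call in builtins-nvptx-mma.cu; the wrapper function name is illustrative only.

  __device__ void bmma_and_popc_example(int *dst, const int *a, const int *b,
                                        const int *c) {
    // a and b are the single-bit fragments produced by __bmma_m8n8k128_ld_a_b1
    // and __bmma_m8n8k128_ld_b_b1; c is the s32 accumulator from
    // __bmma_m8n8k128_ld_c. The trailing constant is the layout selector;
    // 1 (row.col) is the only combination the MMA_VARIANTS_B1_* tables map to
    // an intrinsic, and this call lowers to
    // @llvm.nvvm.wmma.m8n8k128.mma.and.popc.row.col.b1.
    __bmma_m8n8k128_mma_and_popc_b1(dst, a, b, c, 1);
  }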