diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.def b/clang/include/clang/Basic/BuiltinsNVPTX.def --- a/clang/include/clang/Basic/BuiltinsNVPTX.def +++ b/clang/include/clang/Basic/BuiltinsNVPTX.def @@ -759,6 +759,29 @@ TARGET_BUILTIN(__imma_m8n8k32_mma_u4, "vi*iC*iC*iC*IiIi", "", AND(SM_75,PTX63)) TARGET_BUILTIN(__imma_m8n8k32_st_c_i32, "vi*iC*UiIi", "", AND(SM_75,PTX63)) +// Builtins to support double and alternate float WMMA instructions on sm_80 +TARGET_BUILTIN(__dmma_m8n8k4_ld_a, "vd*dC*UiIi", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__dmma_m8n8k4_ld_b, "vd*dC*UiIi", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__dmma_m8n8k4_ld_c, "vd*dC*UiIi", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__dmma_m8n8k4_st_c_f64, "vd*dC*UiIi", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__dmma_m8n8k4_mma_f64, "vd*dC*dC*dC*IiIi", "", AND(SM_80,PTX70)) + +TARGET_BUILTIN(__mma_bf16_m16n16k16_ld_a, "vi*iC*UiIi", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__mma_bf16_m16n16k16_ld_b, "vi*iC*UiIi", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__mma_bf16_m16n16k16_mma_f32, "vf*iC*iC*fC*IiIi", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__mma_bf16_m8n32k16_ld_a, "vi*iC*UiIi", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__mma_bf16_m8n32k16_ld_b, "vi*iC*UiIi", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__mma_bf16_m8n32k16_mma_f32, "vf*iC*iC*fC*IiIi", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__mma_bf16_m32n8k16_ld_a, "vi*iC*UiIi", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__mma_bf16_m32n8k16_ld_b, "vi*iC*UiIi", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__mma_bf16_m32n8k16_mma_f32, "vf*iC*iC*fC*IiIi", "", AND(SM_80,PTX70)) + +TARGET_BUILTIN(__mma_tf32_m16n16k8_ld_a, "vi*iC*UiIi", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__mma_tf32_m16n16k8_ld_b, "vi*iC*UiIi", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__mma_tf32_m16n16k8_ld_c, "vf*fC*UiIi", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__mma_m16n16k8_st_c_f32, "vf*fC*UiIi", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__mma_tf32_m16n16k8_mma_f32, "vf*iC*iC*fC*IiIi", "", AND(SM_80,PTX70)) + // Async Copy 
TARGET_BUILTIN(__nvvm_cp_async_mbarrier_arrive, "vWi*", "", AND(SM_80,PTX70)) TARGET_BUILTIN(__nvvm_cp_async_mbarrier_arrive_shared, "vWi*3", "", AND(SM_80,PTX70)) diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -16402,6 +16402,34 @@ case NVPTX::BI__bmma_m8n8k128_ld_c: return MMA_LDST(2, m8n8k128_load_c_s32); + // Double MMA loads + case NVPTX::BI__dmma_m8n8k4_ld_a: + return MMA_LDST(1, m8n8k4_load_a_f64); + case NVPTX::BI__dmma_m8n8k4_ld_b: + return MMA_LDST(1, m8n8k4_load_b_f64); + case NVPTX::BI__dmma_m8n8k4_ld_c: + return MMA_LDST(2, m8n8k4_load_c_f64); + + // Alternate float MMA loads + case NVPTX::BI__mma_bf16_m16n16k16_ld_a: + return MMA_LDST(4, m16n16k16_load_a_bf16); + case NVPTX::BI__mma_bf16_m16n16k16_ld_b: + return MMA_LDST(4, m16n16k16_load_b_bf16); + case NVPTX::BI__mma_bf16_m8n32k16_ld_a: + return MMA_LDST(2, m8n32k16_load_a_bf16); + case NVPTX::BI__mma_bf16_m8n32k16_ld_b: + return MMA_LDST(8, m8n32k16_load_b_bf16); + case NVPTX::BI__mma_bf16_m32n8k16_ld_a: + return MMA_LDST(8, m32n8k16_load_a_bf16); + case NVPTX::BI__mma_bf16_m32n8k16_ld_b: + return MMA_LDST(2, m32n8k16_load_b_bf16); + case NVPTX::BI__mma_tf32_m16n16k8_ld_a: + return MMA_LDST(4, m16n16k8_load_a_tf32); + case NVPTX::BI__mma_tf32_m16n16k8_ld_b: + return MMA_LDST(2, m16n16k8_load_b_tf32); + case NVPTX::BI__mma_tf32_m16n16k8_ld_c: + return MMA_LDST(8, m16n16k8_load_c_f32); + // NOTE: We need to follow inconsitent naming scheme used by NVCC. Unlike // PTX and LLVM IR where stores always use fragment D, NVCC builtins always // use fragment C for both loads and stores. 
@@ -16433,6 +16461,14 @@ case NVPTX::BI__bmma_m8n8k128_st_c_i32: return MMA_LDST(2, m8n8k128_store_d_s32); + // Double MMA store + case NVPTX::BI__dmma_m8n8k4_st_c_f64: + return MMA_LDST(2, m8n8k4_store_d_f64); + + // Alternate float MMA store + case NVPTX::BI__mma_m16n16k8_st_c_f32: + return MMA_LDST(8, m16n16k8_store_d_f32); + default: llvm_unreachable("Unknown MMA builtin"); } @@ -16446,10 +16482,14 @@ unsigned NumEltsB; unsigned NumEltsC; unsigned NumEltsD; + + // Variants are ordered by layout-A/layout-B/satf, where 'row' has priority + // over 'col' for layout. The index of non-satf variants is expected to match + // the undocumented layout constants used by CUDA's mma.hpp. std::array Variants; unsigned getMMAIntrinsic(int Layout, bool Satf) { - unsigned Index = Layout * 2 + Satf; + unsigned Index = Layout + 4 * Satf; if (Index >= Variants.size()) return 0; return Variants[Index]; @@ -16460,93 +16500,107 @@ // Layout and Satf, 0 otherwise. static NVPTXMmaInfo getNVPTXMmaInfo(unsigned BuiltinID) { // clang-format off -#define MMA_VARIANTS(geom, type) {{ \ +#define MMA_VARIANTS(geom, type) \ Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type, \ - Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type##_satfinite, \ Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type, \ - Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \ Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type, \ + Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type +#define MMA_SATF_VARIANTS(geom, type) \ + MMA_VARIANTS(geom, type), \ + Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type##_satfinite, \ + Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \ Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type##_satfinite, \ - Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type, \ - Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type##_satfinite \ - }} + Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type##_satfinite // Sub-integer MMA only supports row.col layout. 
-#define MMA_VARIANTS_I4(geom, type) {{ \ - 0, \ +#define MMA_VARIANTS_I4(geom, type) \ 0, \ Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type, \ - Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \ 0, \ 0, \ 0, \ - 0 \ - }} -// b1 MMA does not support .satfinite. -#define MMA_VARIANTS_B1(geom, type) {{ \ + Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \ 0, \ + 0 +// b1 MMA does not support .satfinite. +#define MMA_VARIANTS_B1(geom, type) \ 0, \ Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type, \ 0, \ 0, \ 0, \ 0, \ - 0 \ - }} - // clang-format on - switch (BuiltinID) { - // FP MMA - // Note that 'type' argument of MMA_VARIANT uses D_C notation, while - // NumEltsN of return value are ordered as A,B,C,D. - case NVPTX::BI__hmma_m16n16k16_mma_f16f16: - return {8, 8, 4, 4, MMA_VARIANTS(m16n16k16, f16_f16)}; - case NVPTX::BI__hmma_m16n16k16_mma_f32f16: - return {8, 8, 4, 8, MMA_VARIANTS(m16n16k16, f32_f16)}; - case NVPTX::BI__hmma_m16n16k16_mma_f16f32: - return {8, 8, 8, 4, MMA_VARIANTS(m16n16k16, f16_f32)}; - case NVPTX::BI__hmma_m16n16k16_mma_f32f32: - return {8, 8, 8, 8, MMA_VARIANTS(m16n16k16, f32_f32)}; - case NVPTX::BI__hmma_m32n8k16_mma_f16f16: - return {8, 8, 4, 4, MMA_VARIANTS(m32n8k16, f16_f16)}; - case NVPTX::BI__hmma_m32n8k16_mma_f32f16: - return {8, 8, 4, 8, MMA_VARIANTS(m32n8k16, f32_f16)}; - case NVPTX::BI__hmma_m32n8k16_mma_f16f32: - return {8, 8, 8, 4, MMA_VARIANTS(m32n8k16, f16_f32)}; - case NVPTX::BI__hmma_m32n8k16_mma_f32f32: - return {8, 8, 8, 8, MMA_VARIANTS(m32n8k16, f32_f32)}; - case NVPTX::BI__hmma_m8n32k16_mma_f16f16: - return {8, 8, 4, 4, MMA_VARIANTS(m8n32k16, f16_f16)}; - case NVPTX::BI__hmma_m8n32k16_mma_f32f16: - return {8, 8, 4, 8, MMA_VARIANTS(m8n32k16, f32_f16)}; - case NVPTX::BI__hmma_m8n32k16_mma_f16f32: - return {8, 8, 8, 4, MMA_VARIANTS(m8n32k16, f16_f32)}; - case NVPTX::BI__hmma_m8n32k16_mma_f32f32: - return {8, 8, 8, 8, MMA_VARIANTS(m8n32k16, f32_f32)}; - - // Integer MMA - case 
NVPTX::BI__imma_m16n16k16_mma_s8: - return {2, 2, 8, 8, MMA_VARIANTS(m16n16k16, s8)}; - case NVPTX::BI__imma_m16n16k16_mma_u8: - return {2, 2, 8, 8, MMA_VARIANTS(m16n16k16, u8)}; - case NVPTX::BI__imma_m32n8k16_mma_s8: - return {4, 1, 8, 8, MMA_VARIANTS(m32n8k16, s8)}; - case NVPTX::BI__imma_m32n8k16_mma_u8: - return {4, 1, 8, 8, MMA_VARIANTS(m32n8k16, u8)}; - case NVPTX::BI__imma_m8n32k16_mma_s8: - return {1, 4, 8, 8, MMA_VARIANTS(m8n32k16, s8)}; - case NVPTX::BI__imma_m8n32k16_mma_u8: - return {1, 4, 8, 8, MMA_VARIANTS(m8n32k16, u8)}; - - // Sub-integer MMA - case NVPTX::BI__imma_m8n8k32_mma_s4: - return {1, 1, 2, 2, MMA_VARIANTS_I4(m8n8k32, s4)}; - case NVPTX::BI__imma_m8n8k32_mma_u4: - return {1, 1, 2, 2, MMA_VARIANTS_I4(m8n8k32, u4)}; - case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1: - return {1, 1, 2, 2, MMA_VARIANTS_B1(m8n8k128, b1)}; - default: - llvm_unreachable("Unexpected builtin ID."); - } + 0, \ + 0 + // clang-format on + switch (BuiltinID) { + // FP MMA + // Note that 'type' argument of MMA_SATF_VARIANTS uses D_C notation, while + // NumEltsN of return value are ordered as A,B,C,D. 
+ case NVPTX::BI__hmma_m16n16k16_mma_f16f16: + return {8, 8, 4, 4, {{MMA_SATF_VARIANTS(m16n16k16, f16_f16)}}}; + case NVPTX::BI__hmma_m16n16k16_mma_f32f16: + return {8, 8, 4, 8, {{MMA_SATF_VARIANTS(m16n16k16, f32_f16)}}}; + case NVPTX::BI__hmma_m16n16k16_mma_f16f32: + return {8, 8, 8, 4, {{MMA_SATF_VARIANTS(m16n16k16, f16_f32)}}}; + case NVPTX::BI__hmma_m16n16k16_mma_f32f32: + return {8, 8, 8, 8, {{MMA_SATF_VARIANTS(m16n16k16, f32_f32)}}}; + case NVPTX::BI__hmma_m32n8k16_mma_f16f16: + return {8, 8, 4, 4, {{MMA_SATF_VARIANTS(m32n8k16, f16_f16)}}}; + case NVPTX::BI__hmma_m32n8k16_mma_f32f16: + return {8, 8, 4, 8, {{MMA_SATF_VARIANTS(m32n8k16, f32_f16)}}}; + case NVPTX::BI__hmma_m32n8k16_mma_f16f32: + return {8, 8, 8, 4, {{MMA_SATF_VARIANTS(m32n8k16, f16_f32)}}}; + case NVPTX::BI__hmma_m32n8k16_mma_f32f32: + return {8, 8, 8, 8, {{MMA_SATF_VARIANTS(m32n8k16, f32_f32)}}}; + case NVPTX::BI__hmma_m8n32k16_mma_f16f16: + return {8, 8, 4, 4, {{MMA_SATF_VARIANTS(m8n32k16, f16_f16)}}}; + case NVPTX::BI__hmma_m8n32k16_mma_f32f16: + return {8, 8, 4, 8, {{MMA_SATF_VARIANTS(m8n32k16, f32_f16)}}}; + case NVPTX::BI__hmma_m8n32k16_mma_f16f32: + return {8, 8, 8, 4, {{MMA_SATF_VARIANTS(m8n32k16, f16_f32)}}}; + case NVPTX::BI__hmma_m8n32k16_mma_f32f32: + return {8, 8, 8, 8, {{MMA_SATF_VARIANTS(m8n32k16, f32_f32)}}}; + + // Integer MMA + case NVPTX::BI__imma_m16n16k16_mma_s8: + return {2, 2, 8, 8, {{MMA_SATF_VARIANTS(m16n16k16, s8)}}}; + case NVPTX::BI__imma_m16n16k16_mma_u8: + return {2, 2, 8, 8, {{MMA_SATF_VARIANTS(m16n16k16, u8)}}}; + case NVPTX::BI__imma_m32n8k16_mma_s8: + return {4, 1, 8, 8, {{MMA_SATF_VARIANTS(m32n8k16, s8)}}}; + case NVPTX::BI__imma_m32n8k16_mma_u8: + return {4, 1, 8, 8, {{MMA_SATF_VARIANTS(m32n8k16, u8)}}}; + case NVPTX::BI__imma_m8n32k16_mma_s8: + return {1, 4, 8, 8, {{MMA_SATF_VARIANTS(m8n32k16, s8)}}}; + case NVPTX::BI__imma_m8n32k16_mma_u8: + return {1, 4, 8, 8, {{MMA_SATF_VARIANTS(m8n32k16, u8)}}}; + + // Sub-integer MMA + case 
NVPTX::BI__imma_m8n8k32_mma_s4: + return {1, 1, 2, 2, {{MMA_VARIANTS_I4(m8n8k32, s4)}}}; + case NVPTX::BI__imma_m8n8k32_mma_u4: + return {1, 1, 2, 2, {{MMA_VARIANTS_I4(m8n8k32, u4)}}}; + case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1: + return {1, 1, 2, 2, {{MMA_VARIANTS_B1(m8n8k128, b1)}}}; + + // Double MMA + case NVPTX::BI__dmma_m8n8k4_mma_f64: + return {1, 1, 2, 2, {{MMA_VARIANTS(m8n8k4, f64)}}}; + + // Alternate FP MMA + case NVPTX::BI__mma_bf16_m16n16k16_mma_f32: + return {4, 4, 8, 8, {{MMA_VARIANTS(m16n16k16, bf16)}}}; + case NVPTX::BI__mma_bf16_m8n32k16_mma_f32: + return {2, 8, 8, 8, {{MMA_VARIANTS(m8n32k16, bf16)}}}; + case NVPTX::BI__mma_bf16_m32n8k16_mma_f32: + return {8, 2, 8, 8, {{MMA_VARIANTS(m32n8k16, bf16)}}}; + case NVPTX::BI__mma_tf32_m16n16k8_mma_f32: + return {4, 4, 8, 8, {{MMA_VARIANTS(m16n16k8, tf32)}}}; + default: + llvm_unreachable("Unexpected builtin ID."); + } #undef MMA_VARIANTS +#undef MMA_SATF_VARIANTS #undef MMA_VARIANTS_I4 #undef MMA_VARIANTS_B1 } @@ -16844,7 +16898,20 @@ case NVPTX::BI__bmma_m8n8k128_ld_a_b1: case NVPTX::BI__bmma_m8n8k128_ld_b_b1: case NVPTX::BI__bmma_m8n8k128_ld_c: - { + // Double MMA loads. + case NVPTX::BI__dmma_m8n8k4_ld_a: + case NVPTX::BI__dmma_m8n8k4_ld_b: + case NVPTX::BI__dmma_m8n8k4_ld_c: + // Alternate float MMA loads. 
+ case NVPTX::BI__mma_bf16_m16n16k16_ld_a: + case NVPTX::BI__mma_bf16_m16n16k16_ld_b: + case NVPTX::BI__mma_bf16_m8n32k16_ld_a: + case NVPTX::BI__mma_bf16_m8n32k16_ld_b: + case NVPTX::BI__mma_bf16_m32n8k16_ld_a: + case NVPTX::BI__mma_bf16_m32n8k16_ld_b: + case NVPTX::BI__mma_tf32_m16n16k8_ld_a: + case NVPTX::BI__mma_tf32_m16n16k8_ld_b: + case NVPTX::BI__mma_tf32_m16n16k8_ld_c: { Address Dst = EmitPointerWithAlignment(E->getArg(0)); Value *Src = EmitScalarExpr(E->getArg(1)); Value *Ldm = EmitScalarExpr(E->getArg(2)); @@ -16889,7 +16956,9 @@ case NVPTX::BI__imma_m32n8k16_st_c_i32: case NVPTX::BI__imma_m8n32k16_st_c_i32: case NVPTX::BI__imma_m8n8k32_st_c_i32: - case NVPTX::BI__bmma_m8n8k128_st_c_i32: { + case NVPTX::BI__bmma_m8n8k128_st_c_i32: + case NVPTX::BI__dmma_m8n8k4_st_c_f64: + case NVPTX::BI__mma_m16n16k8_st_c_f32: { Value *Dst = EmitScalarExpr(E->getArg(0)); Address Src = EmitPointerWithAlignment(E->getArg(1)); Value *Ldm = EmitScalarExpr(E->getArg(2)); @@ -16941,7 +17010,12 @@ case NVPTX::BI__imma_m8n32k16_mma_u8: case NVPTX::BI__imma_m8n8k32_mma_s4: case NVPTX::BI__imma_m8n8k32_mma_u4: - case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1: { + case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1: + case NVPTX::BI__dmma_m8n8k4_mma_f64: + case NVPTX::BI__mma_bf16_m16n16k16_mma_f32: + case NVPTX::BI__mma_bf16_m8n32k16_mma_f32: + case NVPTX::BI__mma_bf16_m32n8k16_mma_f32: + case NVPTX::BI__mma_tf32_m16n16k8_mma_f32: { Address Dst = EmitPointerWithAlignment(E->getArg(0)); Address SrcA = EmitPointerWithAlignment(E->getArg(1)); Address SrcB = EmitPointerWithAlignment(E->getArg(2)); diff --git a/clang/test/CodeGen/builtins-nvptx-mma.cu b/clang/test/CodeGen/builtins-nvptx-mma.cu --- a/clang/test/CodeGen/builtins-nvptx-mma.cu +++ b/clang/test/CodeGen/builtins-nvptx-mma.cu @@ -3,21 +3,20 @@ // *** DO NOT EDIT *** // // This test has been automatically generated by -// builtins-nvtx-mma.py --ptx=63 --gpu-arch=75 +// builtins-nvtx-mma.py --ptx=70 --gpu-arch=80 // -// Make sure we 
can handle all builtins available on sm_75 with PTX63 -// RUN: %clang_cc1 -triple nvptx64-unknown-unknown -target-cpu sm_75 \ -// RUN: -fcuda-is-device -target-feature +ptx63 \ -// RUN: -DPTX=63 -DSM=75 \ +// Make sure we can handle all builtins available on sm_80 with PTX70 +// RUN: %clang_cc1 -triple nvptx64-unknown-unknown -target-cpu sm_80 \ +// RUN: -fcuda-is-device -target-feature +ptx70 \ +// RUN: -DPTX=70 -DSM=80 \ // RUN: -S -emit-llvm -o - -x cuda %s \ -// RUN: | FileCheck -check-prefixes=CHECK_PTX61_SM70,CHECK_PTX63_SM75,CHECK_PTX63_SM72,CHECK_PTX60_SM70 %s +// RUN: | FileCheck -check-prefixes=CHECK_PTX70_SM80,CHECK_PTX60_SM70,CHECK_PTX63_SM72,CHECK_PTX61_SM70,CHECK_PTX63_SM75 %s // Verify that all builtins have correct constraints. // RUN: %clang_cc1 -triple nvptx-unknown-unknown \ // RUN: -target-cpu sm_60 -target-feature +ptx42 \ -// RUN: -DPTX=63 -DSM=75 -fcuda-is-device -S -o /dev/null -x cuda \ +// RUN: -DPTX=70 -DSM=80 -fcuda-is-device -S -o /dev/null -x cuda \ // RUN: -verify %s - #if !defined(CUDA_VERSION) #define __device__ __attribute__((device)) #define __global__ __attribute__((global)) @@ -29,8 +28,8 @@ // CHECK-LABEL: test_wmma_buitins __device__ void test_wmma_buitins(int *src, int *dst, - float *fsrc, float *fdst, int ldm) { - + float *fsrc, float *fdst, + double *dsrc, double *ddst, int ldm) { #if (PTX >= 60) && (SM >= 70) @@ -751,5 +750,153 @@ // CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k32.mma.row.col.u4.satfinite // expected-error-re@+1 {{'__imma_m8n8k32_mma_u4' needs target feature (sm_75{{.*}},(ptx63{{.*}}}} __imma_m8n8k32_mma_u4(dst, src, src, src, 1, 1); -#endif // (PTX >= 63) && (SM >= 75) +#endif // (PTX >= 63) && (SM >= 75) + +#if (PTX >= 70) && (SM >= 80) + + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.bf16 + // expected-error-re@+1 {{'__mma_bf16_m16n16k16_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m16n16k16_ld_a(dst, src, ldm, 1); + // CHECK_PTX70_SM80: call 
{{.*}} @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.bf16 + // expected-error-re@+1 {{'__mma_bf16_m16n16k16_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m16n16k16_ld_a(dst, src, ldm, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.bf16 + // expected-error-re@+1 {{'__mma_bf16_m16n16k16_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m16n16k16_ld_b(dst, src, ldm, 1); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.bf16 + // expected-error-re@+1 {{'__mma_bf16_m16n16k16_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m16n16k16_ld_b(dst, src, ldm, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32 + // expected-error-re@+1 {{'__mma_tf32_m16n16k8_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_tf32_m16n16k8_ld_a(dst, src, ldm, 1); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.load.a.row.stride.tf32 + // expected-error-re@+1 {{'__mma_tf32_m16n16k8_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_tf32_m16n16k8_ld_a(dst, src, ldm, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.load.b.col.stride.tf32 + // expected-error-re@+1 {{'__mma_tf32_m16n16k8_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_tf32_m16n16k8_ld_b(dst, src, ldm, 1); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.load.b.row.stride.tf32 + // expected-error-re@+1 {{'__mma_tf32_m16n16k8_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_tf32_m16n16k8_ld_b(dst, src, ldm, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.load.c.col.stride.f32 + // expected-error-re@+1 {{'__mma_tf32_m16n16k8_ld_c' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_tf32_m16n16k8_ld_c(fdst, fsrc, ldm, 1); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.load.c.row.stride.f32 + // expected-error-re@+1 {{'__mma_tf32_m16n16k8_ld_c' needs target 
feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_tf32_m16n16k8_ld_c(fdst, fsrc, ldm, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.store.d.col.stride.f32 + // expected-error-re@+1 {{'__mma_m16n16k8_st_c_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_m16n16k8_st_c_f32(fdst, fsrc, ldm, 1); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.store.d.row.stride.f32 + // expected-error-re@+1 {{'__mma_m16n16k8_st_c_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_m16n16k8_st_c_f32(fdst, fsrc, ldm, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.bf16 + // expected-error-re@+1 {{'__mma_bf16_m32n8k16_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m32n8k16_ld_a(dst, src, ldm, 1); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.bf16 + // expected-error-re@+1 {{'__mma_bf16_m32n8k16_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m32n8k16_ld_a(dst, src, ldm, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.bf16 + // expected-error-re@+1 {{'__mma_bf16_m32n8k16_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m32n8k16_ld_b(dst, src, ldm, 1); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.bf16 + // expected-error-re@+1 {{'__mma_bf16_m32n8k16_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m32n8k16_ld_b(dst, src, ldm, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.bf16 + // expected-error-re@+1 {{'__mma_bf16_m8n32k16_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m8n32k16_ld_a(dst, src, ldm, 1); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.bf16 + // expected-error-re@+1 {{'__mma_bf16_m8n32k16_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m8n32k16_ld_a(dst, src, ldm, 0); + // CHECK_PTX70_SM80: call {{.*}} 
@llvm.nvvm.wmma.m8n32k16.load.b.col.stride.bf16 + // expected-error-re@+1 {{'__mma_bf16_m8n32k16_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m8n32k16_ld_b(dst, src, ldm, 1); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.bf16 + // expected-error-re@+1 {{'__mma_bf16_m8n32k16_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m8n32k16_ld_b(dst, src, ldm, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64 + // expected-error-re@+1 {{'__dmma_m8n8k4_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __dmma_m8n8k4_ld_a(ddst, dsrc, ldm, 1); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64 + // expected-error-re@+1 {{'__dmma_m8n8k4_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __dmma_m8n8k4_ld_a(ddst, dsrc, ldm, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.load.b.col.stride.f64 + // expected-error-re@+1 {{'__dmma_m8n8k4_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __dmma_m8n8k4_ld_b(ddst, dsrc, ldm, 1); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.load.b.row.stride.f64 + // expected-error-re@+1 {{'__dmma_m8n8k4_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __dmma_m8n8k4_ld_b(ddst, dsrc, ldm, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64 + // expected-error-re@+1 {{'__dmma_m8n8k4_ld_c' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __dmma_m8n8k4_ld_c(ddst, dsrc, ldm, 1); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64 + // expected-error-re@+1 {{'__dmma_m8n8k4_ld_c' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __dmma_m8n8k4_ld_c(ddst, dsrc, ldm, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64 + // expected-error-re@+1 {{'__dmma_m8n8k4_st_c_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __dmma_m8n8k4_st_c_f64(ddst, dsrc, ldm, 1); + // 
CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64 + // expected-error-re@+1 {{'__dmma_m8n8k4_st_c_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __dmma_m8n8k4_st_c_f64(ddst, dsrc, ldm, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16 + // expected-error-re@+1 {{'__mma_bf16_m16n16k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m16n16k16_mma_f32(fdst, src, src, fsrc, 3, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.row.bf16 + // expected-error-re@+1 {{'__mma_bf16_m16n16k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m16n16k16_mma_f32(fdst, src, src, fsrc, 2, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.col.bf16 + // expected-error-re@+1 {{'__mma_bf16_m16n16k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m16n16k16_mma_f32(fdst, src, src, fsrc, 1, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16 + // expected-error-re@+1 {{'__mma_bf16_m16n16k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m16n16k16_mma_f32(fdst, src, src, fsrc, 0, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.mma.col.col.tf32 + // expected-error-re@+1 {{'__mma_tf32_m16n16k8_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_tf32_m16n16k8_mma_f32(fdst, src, src, fsrc, 3, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.mma.col.row.tf32 + // expected-error-re@+1 {{'__mma_tf32_m16n16k8_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_tf32_m16n16k8_mma_f32(fdst, src, src, fsrc, 2, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.mma.row.col.tf32 + // expected-error-re@+1 {{'__mma_tf32_m16n16k8_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_tf32_m16n16k8_mma_f32(fdst, src, src, fsrc, 1, 0); + // CHECK_PTX70_SM80: call {{.*}} 
@llvm.nvvm.wmma.m16n16k8.mma.row.row.tf32 + // expected-error-re@+1 {{'__mma_tf32_m16n16k8_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_tf32_m16n16k8_mma_f32(fdst, src, src, fsrc, 0, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16 + // expected-error-re@+1 {{'__mma_bf16_m32n8k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m32n8k16_mma_f32(fdst, src, src, fsrc, 3, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.mma.col.row.bf16 + // expected-error-re@+1 {{'__mma_bf16_m32n8k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m32n8k16_mma_f32(fdst, src, src, fsrc, 2, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.mma.row.col.bf16 + // expected-error-re@+1 {{'__mma_bf16_m32n8k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m32n8k16_mma_f32(fdst, src, src, fsrc, 1, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16 + // expected-error-re@+1 {{'__mma_bf16_m32n8k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m32n8k16_mma_f32(fdst, src, src, fsrc, 0, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16 + // expected-error-re@+1 {{'__mma_bf16_m8n32k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m8n32k16_mma_f32(fdst, src, src, fsrc, 3, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.mma.col.row.bf16 + // expected-error-re@+1 {{'__mma_bf16_m8n32k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m8n32k16_mma_f32(fdst, src, src, fsrc, 2, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.mma.row.col.bf16 + // expected-error-re@+1 {{'__mma_bf16_m8n32k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m8n32k16_mma_f32(fdst, src, src, fsrc, 1, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16 + // 
expected-error-re@+1 {{'__mma_bf16_m8n32k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __mma_bf16_m8n32k16_mma_f32(fdst, src, src, fsrc, 0, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64 + // expected-error-re@+1 {{'__dmma_m8n8k4_mma_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __dmma_m8n8k4_mma_f64(ddst, dsrc, dsrc, dsrc, 3, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.mma.col.row.f64 + // expected-error-re@+1 {{'__dmma_m8n8k4_mma_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __dmma_m8n8k4_mma_f64(ddst, dsrc, dsrc, dsrc, 2, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.mma.row.col.f64 + // expected-error-re@+1 {{'__dmma_m8n8k4_mma_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __dmma_m8n8k4_mma_f64(ddst, dsrc, dsrc, dsrc, 1, 0); + // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64 + // expected-error-re@+1 {{'__dmma_m8n8k4_mma_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}} + __dmma_m8n8k4_mma_f64(ddst, dsrc, dsrc, dsrc, 0, 0); +#endif // (PTX >= 70) && (SM >= 80) } diff --git a/clang/test/CodeGen/builtins-nvptx-mma.py b/clang/test/CodeGen/builtins-nvptx-mma.py --- a/clang/test/CodeGen/builtins-nvptx-mma.py +++ b/clang/test/CodeGen/builtins-nvptx-mma.py @@ -47,7 +47,13 @@ in product(geoms, frags, types)] def get_mma_ops(): - return (make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"], + return (make_mma_ops(["m16n16k8"], + ["tf32"], [], ["f32"], []) + + make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"], + ["bf16"], [], ["f32"], []) + + make_mma_ops(["m8n8k4"], + ["f64"], [], ["f64"], []) + + make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"], ["f16"], [], ["f16", "f32"], ["f16", "f32"]) + make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"], ["s8", "u8"], [], ["s32"], []) + @@ -55,14 +61,18 @@ ["s4", "u4"], [], ["s32"], []) + make_mma_ops(["m8n8k128"], ["b1"], [], ["s32"], [])) + def get_ldst_ops(): return 
(make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"], - ["a", "b"], ["f16", "u8", "s8"]) + + ["a", "b"], ["f16", "u8", "s8", "bf16"]) + make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"], ["c", "d"], ["f16", "f32", "s32"]) + make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4","u4"]) + make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) + - make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"])) + make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]) + + make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"]) + + make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"]) + + make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"])) def is_geom_supported(geom): # geometries for FP and ints. @@ -73,6 +83,8 @@ return ptx_version >= 63 and gpu_arch >= 75 if geom == "m16n16k16": return ptx_version >= 60 + if geom in ["m16n16k8", "m8n8k4"]: + return ptx_version >= 70 and gpu_arch >= 80 assert(False) # Unexpected geometry. def is_type_supported(ptx_type): @@ -80,16 +92,24 @@ return ptx_version >= 63 and gpu_arch >= 72 if ptx_type in ["s4", "u4", "b1"]: return ptx_version >= 63 and gpu_arch >= 75 + if ptx_type in ["bf16", "tf32", "f64"]: + return ptx_version >= 70 and gpu_arch >= 80 return ptx_version >= 60 and gpu_arch >= 70 +def is_rnd_supported(op): + # rnd is only supported for FP64 WMMA + return op.a.ptx_type == "f64" + def is_mma_variant_supported(op, layout_a, layout_b, satf): if not (is_type_supported(op.a.ptx_type) and is_geom_supported(op.a.geom)): return False - # sub-integer require row/col layout, and no satf. + + if satf and not op.a.ptx_type in ["f16", "s8", "u8", "s4", "u4"]: + return False + + # sub-integer types require row/col layout. if op.a.ptx_type in ["s4", "u4", "b1"]: - if op.a.ptx_type == "b1" and satf: - return False return layout_a == "row" and layout_b == "col" return True @@ -98,7 +118,7 @@ and is_geom_supported(frag.geom)): return False if frag.ptx_type in ["s4", "u4", "b1"]: - # sub-integer require sm_75 and ptx63, row/col layout for a/b. 
+ # sub-integer types require sm_75 and ptx63, row/col layout for a/b. return ((frag.frag == "a" and layout == "row") or (frag.frag == "b" and layout == "col") or frag.frag in ["c", "d"]) @@ -109,12 +129,21 @@ if frag.geom in ["m16n16k16", "m32n8k16", "m8n32k16"]: if frag.ptx_type in ["f16", "f32"]: prefix = "__hmma" + elif frag.ptx_type == "bf16": + prefix = "__mma_bf16" else: prefix = "__imma" elif frag.geom == "m8n8k32": prefix = "__imma" # sub-integers elif frag.geom == "m8n8k128": prefix = "__bmma" + elif frag.geom == "m8n8k4": + prefix = "__dmma" + elif frag.geom == "m16n16k8": + if frag.ptx_type == "f32": + prefix = "__mma" + else: + prefix = "__mma_tf32" assert prefix return prefix @@ -123,10 +152,13 @@ if prefix == "__hmma": suffix = "" if frag.frag in ["a","b"] else frag.ptx_type - elif prefix in ["__imma", "__bmma"]: - suffix = "" if frag.frag in ["c"] else frag.ptx_type + elif prefix in ["__dmma", "__mma_bf16", "__mma_tf32"]: + suffix = "" if frag.frag in ["a","b","c"] else frag.ptx_type + else: + suffix = "" if frag.frag == "c" else frag.ptx_type if suffix == "s32": suffix = "i32" + if frag.frag == "d": ifrag = "c" op = "st" @@ -143,6 +175,8 @@ if prefix == "__hmma": suffix = op.d.ptx_type + op.c.ptx_type + elif prefix in ["__mma_bf16", "__mma_tf32"]: + suffix = op.d.ptx_type else: suffix = op.a.ptx_type @@ -151,8 +185,9 @@ suffix) return name - def get_required_sm(frag): + if frag.ptx_type in ["f64", "bf16", "tf32"]: + return 80 if frag.ptx_type in ["u4", "s4", "b1"]: return 75 if frag.ptx_type in ["s8", "u8"]: @@ -163,18 +198,34 @@ else: # s8/u8 return 72 if frag.ptx_type in ["f16", "f32"]: - return 70 + if frag.geom == "m16n16k8": + return 80 + else: + return 70 assert(False) def get_required_ptx(frag): + if frag.ptx_type in ["f64", "bf16", "tf32"]: + return 70 if frag.ptx_type in ["f16", "f32"]: - return 60 if frag.geom == "m16n16k16" else 61 + if frag.geom == "m16n16k16": + return 60 + if frag.geom == "m16n16k8": + return 70 + return 61 return 63 
+def get_src_dst_prefix(ptx_type): + if ptx_type == "f32": + return "f" + if ptx_type == "f64": + return "d" + return "" + def gen_wmma_ldst_tests(results): load_template = """ // CHECK${check_suffix}: call {{.*}} @${intrinsic} - // expected-error-re@+1 {{'${builtin}' needs target feature sm_${min_sm}{{.*}},ptx${min_ptx}{{.*}}}} + // expected-error-re@+1 {{'${builtin}' needs target feature (sm_${min_sm}{{.*}},(ptx${min_ptx}{{.*}}}} ${builtin}(${dst}, ${src}, ldm, ${blayout}); """.rstrip() intrinsic_template = "llvm.nvvm.wmma.${geom}.${op}.${frag}.${ilayout}.stride.${itype}" @@ -184,7 +235,7 @@ if not is_ldst_variant_supported(frag, layout): continue - is_fp = frag.ptx_type == "f32" + src_dst_prefix = get_src_dst_prefix(frag.ptx_type) min_sm = get_required_sm(frag) min_ptx = get_required_ptx(frag) params = { @@ -192,8 +243,8 @@ "builtin" : get_ldst_builtin_name(frag), "min_ptx" : min_ptx, "min_sm" : min_sm, - "dst": "fdst" if is_fp else "dst", - "src": "fsrc" if is_fp else "src", + "dst": src_dst_prefix + "dst", + "src": src_dst_prefix + "src", "blayout" : 0 if layout == "row" else 1, "intrinsic" : Template(intrinsic_template).substitute({ "frag" : frag.frag, @@ -208,12 +259,12 @@ return results def mma_signature(op): - if op.a.ptx_type in ["s8", "u8", "s4", "u4", "b1"]: - # int and sub-int ops are identified by input type. - return op.a.ptx_type - else: - # the rest are FP ops identified by accumulator & result type. + if op.a.ptx_type == "f16": + # FP16 ops identified by accumulator & result type. return "%s.%s" % (op.d.ptx_type, op.c.ptx_type) + else: + # other ops are identified by input type. 
+ return op.a.ptx_type # Get numeric value for rowcol parameter of the builtin # AFAICT it uses the encoding accepted by NVVM intrinsics: @@ -229,8 +280,8 @@ def gen_wmma_mma_tests(results): mma_template = """ // CHECK${check_suffix}: call {{.*}} @${intrinsic} - // expected-error-re@+1 {{'${builtin}' needs target feature sm_${min_sm}{{.*}},ptx${min_ptx}{{.*}}}} - ${builtin}(${dst}, ${asrc}, ${asrc}, ${csrc}, ${ilayout}${maybe_isatf}); + // expected-error-re@+1 {{'${builtin}' needs target feature (sm_${min_sm}{{.*}},(ptx${min_ptx}{{.*}}}} + ${builtin}(${dst}, ${asrc}, ${asrc}, ${csrc}, ${ilayout}${maybe_satf}); """.rstrip() intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}" @@ -243,9 +294,9 @@ if not is_mma_variant_supported(op, alayout, blayout, satf): continue - a_is_fp = op.a.ptx_type == "f32" - c_is_fp = op.c.ptx_type == "f32" - d_is_fp = op.d.ptx_type == "f32" + asrc_prefix = get_src_dst_prefix(op.a.ptx_type) + csrc_prefix = get_src_dst_prefix(op.c.ptx_type) + ddst_prefix = get_src_dst_prefix(op.d.ptx_type) min_sm = get_required_sm(op.a) min_ptx = get_required_ptx(op.a) if op.a.ptx_type == "b1": # .b1 MMA has no satf argument. 
@@ -257,11 +308,11 @@ "builtin" : get_mma_builtin_name(op), "min_ptx" : min_ptx, "min_sm" : min_sm, - "dst": "fdst" if d_is_fp else "dst", - "asrc": "fsrc" if a_is_fp else "src", - "csrc": "fsrc" if c_is_fp else "src", + "dst": ddst_prefix + "dst", + "asrc": asrc_prefix + "src", + "csrc": csrc_prefix + "src", "ilayout" : get_ilayout(alayout, blayout), - "maybe_isatf" : isatf_arg, + "maybe_satf" : isatf_arg, "intrinsic" : Template(intrinsic_template).substitute({ "geom" : op.a.geom, "alayout" : alayout, @@ -322,7 +373,8 @@ // CHECK-LABEL: test_wmma_buitins __device__ void test_wmma_buitins(int *src, int *dst, - float *fsrc, float *fdst, int ldm) { + float *fsrc, float *fdst, + double *dsrc, double *ddst, int ldm) { """); for (ptx, sm), tests in sorted(results.items()): diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -52,13 +52,27 @@ string gft = Geom#":"#Frag#":"#ptx_elt_type; string ft = frag#":"#ptx_elt_type; list regs = !cond( - // mma.sync.m8n8k4 uses smaller a/b fragments than wmma fp ops + // mma fp ops use smaller fragments than wmma fp ops !eq(gft,"m8n8k4:a:f16") : !listsplat(llvm_v2f16_ty, 2), !eq(gft,"m8n8k4:b:f16") : !listsplat(llvm_v2f16_ty, 2), - - // fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16 - // All currently supported geometries use the same fragment format, - // so we only need to consider {fragment, type}. 
+ !eq(gft,"m16n8k8:a:f16") : !listsplat(llvm_v2f16_ty, 2), + !eq(gft,"m16n8k8:b:f16") : [llvm_v2f16_ty], + !eq(gft,"m16n8k8:c:f16") : !listsplat(llvm_v2f16_ty, 2), + !eq(gft,"m16n8k8:d:f16") : !listsplat(llvm_v2f16_ty, 2), + !eq(gft,"m16n8k8:c:f32") : !listsplat(llvm_float_ty, 4), + !eq(gft,"m16n8k8:d:f32") : !listsplat(llvm_float_ty, 4), + !eq(gft,"m16n8k16:a:f16") : !listsplat(llvm_v2f16_ty, 4), + !eq(gft,"m16n8k16:b:f16") : !listsplat(llvm_v2f16_ty, 2), + !eq(gft,"m16n8k16:c:f16") : !listsplat(llvm_v2f16_ty, 2), + !eq(gft,"m16n8k16:d:f16") : !listsplat(llvm_v2f16_ty, 2), + !eq(gft,"m16n8k16:c:f32") : !listsplat(llvm_float_ty, 4), + !eq(gft,"m16n8k16:d:f32") : !listsplat(llvm_float_ty, 4), + !eq(gft,"m16n8k4:c:f32") : !listsplat(llvm_float_ty, 4), + !eq(gft,"m16n8k4:d:f32") : !listsplat(llvm_float_ty, 4), + + // wmma fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16 + // All other supported geometries use the same fragment format for f32 and + // f16, so we only need to consider {fragment, type}. 
!eq(ft,"a:f16") : !listsplat(llvm_v2f16_ty, 8), !eq(ft,"b:f16") : !listsplat(llvm_v2f16_ty, 8), !eq(ft,"c:f16") : !listsplat(llvm_v2f16_ty, 4), @@ -66,7 +80,36 @@ !eq(ft,"c:f32") : !listsplat(llvm_float_ty, 8), !eq(ft,"d:f32") : !listsplat(llvm_float_ty, 8), - // u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16 + // wmma tf32 -> f32 @ m16n16k8 + !eq(gft,"m16n16k8:a:tf32") : !listsplat(llvm_i32_ty, 4), + !eq(gft,"m16n16k8:b:tf32") : !listsplat(llvm_i32_ty, 4), + + // mma tf32 -> f32 @ m16n8k4/m16n8k8 + !eq(gft,"m16n8k4:a:tf32") : !listsplat(llvm_i32_ty, 2), + !eq(gft,"m16n8k4:b:tf32") : [llvm_i32_ty], + !eq(gft,"m16n8k8:a:tf32") : !listsplat(llvm_i32_ty, 4), + !eq(gft,"m16n8k8:b:tf32") : !listsplat(llvm_i32_ty, 2), + + !eq(gft,"m8n8k4:a:f64") : [llvm_double_ty], + !eq(gft,"m8n8k4:b:f64") : [llvm_double_ty], + !eq(gft,"m8n8k4:c:f64") : !listsplat(llvm_double_ty, 2), + !eq(gft,"m8n8k4:d:f64") : !listsplat(llvm_double_ty, 2), + + // wmma bf16 -> f32 @ m16n16k16/m8n32k16/m32n8k16 + !eq(gft,"m16n16k16:a:bf16") : !listsplat(llvm_i32_ty, 4), + !eq(gft,"m16n16k16:b:bf16") : !listsplat(llvm_i32_ty, 4), + !eq(gft,"m8n32k16:a:bf16") : !listsplat(llvm_i32_ty, 2), + !eq(gft,"m8n32k16:b:bf16") : !listsplat(llvm_i32_ty, 8), + !eq(gft,"m32n8k16:a:bf16") : !listsplat(llvm_i32_ty, 8), + !eq(gft,"m32n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2), + + // mma bf16 -> f32 @ m16n8k16/m16n8k8 + !eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 4), + !eq(gft,"m16n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2), + !eq(gft,"m16n8k8:a:bf16") : !listsplat(llvm_i32_ty, 2), + !eq(gft,"m16n8k8:b:bf16") : [llvm_i32_ty], + + // wmma u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16 !eq(gft,"m16n16k16:a:u8") : !listsplat(llvm_i32_ty, 2), !eq(gft,"m16n16k16:a:s8") : !listsplat(llvm_i32_ty, 2), !eq(gft,"m16n16k16:b:u8") : !listsplat(llvm_i32_ty, 2), @@ -88,17 +131,65 @@ !eq(gft,"m32n8k16:c:s32") : !listsplat(llvm_i32_ty, 8), !eq(gft,"m32n8k16:d:s32") : !listsplat(llvm_i32_ty, 8), - // u4/s4/b1 -> s32 @ m8n8k32 (u4/s4),
m8n8k128(b1) - !eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty], + // mma u8/s8 -> s32 @ m8n8k16/m16n8k16/m16n8k32 + !eq(gft,"m8n8k16:a:u8") : [llvm_i32_ty], + !eq(gft,"m8n8k16:a:s8") : [llvm_i32_ty], + !eq(gft,"m8n8k16:b:u8") : [llvm_i32_ty], + !eq(gft,"m8n8k16:b:s8") : [llvm_i32_ty], + !eq(gft,"m8n8k16:c:s32") : !listsplat(llvm_i32_ty, 2), + !eq(gft,"m8n8k16:d:s32") : !listsplat(llvm_i32_ty, 2), + + !eq(gft,"m16n8k16:a:u8") : !listsplat(llvm_i32_ty, 2), + !eq(gft,"m16n8k16:a:s8") : !listsplat(llvm_i32_ty, 2), + !eq(gft,"m16n8k16:b:u8") : [llvm_i32_ty], + !eq(gft,"m16n8k16:b:s8") : [llvm_i32_ty], + !eq(gft,"m16n8k16:c:s32") : !listsplat(llvm_i32_ty, 4), + !eq(gft,"m16n8k16:d:s32") : !listsplat(llvm_i32_ty, 4), + + !eq(gft,"m16n8k32:a:u8") : !listsplat(llvm_i32_ty, 4), + !eq(gft,"m16n8k32:a:s8") : !listsplat(llvm_i32_ty, 4), + !eq(gft,"m16n8k32:b:u8") : !listsplat(llvm_i32_ty, 2), + !eq(gft,"m16n8k32:b:s8") : !listsplat(llvm_i32_ty, 2), + !eq(gft,"m16n8k32:c:s32") : !listsplat(llvm_i32_ty, 4), + !eq(gft,"m16n8k32:d:s32") : !listsplat(llvm_i32_ty, 4), + + // wmma/mma u4/s4 -> s32 @ m8n8k32 (u4/s4) !eq(gft,"m8n8k32:a:u4") : [llvm_i32_ty], !eq(gft,"m8n8k32:a:s4") : [llvm_i32_ty], - !eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty], !eq(gft,"m8n8k32:b:u4") : [llvm_i32_ty], !eq(gft,"m8n8k32:b:s4") : [llvm_i32_ty], - !eq(gft,"m8n8k128:c:s32") : !listsplat(llvm_i32_ty, 2), - !eq(gft,"m8n8k128:d:s32") : !listsplat(llvm_i32_ty, 2), !eq(gft,"m8n8k32:c:s32") : !listsplat(llvm_i32_ty, 2), !eq(gft,"m8n8k32:d:s32") : !listsplat(llvm_i32_ty, 2), + + !eq(gft,"m16n8k32:a:u4") : !listsplat(llvm_i32_ty, 2), + !eq(gft,"m16n8k32:a:s4") : !listsplat(llvm_i32_ty, 2), + !eq(gft,"m16n8k32:b:u4") : [llvm_i32_ty], + !eq(gft,"m16n8k32:b:s4") : [llvm_i32_ty], + !eq(gft,"m16n8k32:c:s32") : !listsplat(llvm_i32_ty, 4), + !eq(gft,"m16n8k32:d:s32") : !listsplat(llvm_i32_ty, 4), + + !eq(gft,"m16n8k64:a:u4") : !listsplat(llvm_i32_ty, 4), + !eq(gft,"m16n8k64:a:s4") : !listsplat(llvm_i32_ty, 4), + 
!eq(gft,"m16n8k64:b:u4") : !listsplat(llvm_i32_ty, 2), + !eq(gft,"m16n8k64:b:s4") : !listsplat(llvm_i32_ty, 2), + !eq(gft,"m16n8k64:c:s32") : !listsplat(llvm_i32_ty, 4), + !eq(gft,"m16n8k64:d:s32") : !listsplat(llvm_i32_ty, 4), + + // wmma/mma b1 -> s32 @ m8n8k128(b1) + !eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty], + !eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty], + !eq(gft,"m8n8k128:c:s32") : !listsplat(llvm_i32_ty, 2), + !eq(gft,"m8n8k128:d:s32") : !listsplat(llvm_i32_ty, 2), + + !eq(gft,"m16n8k128:a:b1") : !listsplat(llvm_i32_ty, 2), + !eq(gft,"m16n8k128:b:b1") : [llvm_i32_ty], + !eq(gft,"m16n8k128:c:s32") : !listsplat(llvm_i32_ty, 4), + !eq(gft,"m16n8k128:d:s32") : !listsplat(llvm_i32_ty, 4), + + !eq(gft,"m16n8k256:a:b1") : !listsplat(llvm_i32_ty, 4), + !eq(gft,"m16n8k256:b:b1") : !listsplat(llvm_i32_ty, 2), + !eq(gft,"m16n8k256:c:s32") : !listsplat(llvm_i32_ty, 4), + !eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4), ); } @@ -125,35 +216,40 @@ class MMA_SIGNATURE { list id_frags = !cond( - // int and sub-int ops are identified by input type. - !eq(A.ptx_elt_type, "s8") : [A], - !eq(A.ptx_elt_type, "u8") : [A], - !eq(A.ptx_elt_type, "s4") : [A], - !eq(A.ptx_elt_type, "u4") : [A], - !eq(A.ptx_elt_type, "b1") : [A], - // the rest are FP ops identified by accumulator & result type. - true: [D, C] + // FP16 ops are identified by accumulator & result type. + !eq(A.ptx_elt_type, "f16") : [D, C], + // other ops are identified by input types. + !ne(A.ptx_elt_type, B.ptx_elt_type): [A, B], + true: [A] ); string ret = !foldl("", id_frags, a, b, !strconcat(a, ".", b.ptx_elt_type)); } -class WMMA_NAME_MMA { +class WMMA_NAME { string signature = MMA_SIGNATURE.ret; - string llvm = !if( - !eq(A.geom, "m8n8k4"), - "llvm.nvvm.mma.m8n8k4" - # "." # ALayout - # "." # BLayout - # signature, - "llvm.nvvm.wmma." - # A.geom - # ".mma" - # "." # ALayout - # "." # BLayout - # signature - # !if(Satfinite, ".satfinite", "")); + string llvm = "llvm.nvvm.wmma." + # A.geom + # ".mma" + # "." 
# ALayout + # "." # BLayout + # !if(!ne(Rnd, ""), !strconcat(".", Rnd), "") + # signature + # !if(Satfinite, ".satfinite", ""); + + string record = !subst(".", "_", + !subst("llvm.", "int_", llvm)); +} +class MMA_NAME { + string signature = MMA_SIGNATURE.ret; + string llvm = "llvm.nvvm.mma." + # A.geom + # "." # ALayout + # "." # BLayout + # !if(Satfinite, ".satfinite", "") + # signature; string record = !subst(".", "_", !subst("llvm.", "int_", llvm)); } @@ -188,14 +284,18 @@ list ops = !foreach(x, ret, x.gft); } - - // Creates list of valid combinations of fragments. This is the master list that // drives generation of corresponding intrinsics and instructions. class NVVM_MMA_OPS { - list> fp_mma_ops = MMA_OPS< + list> tf32_wmma_ops = MMA_OPS< + ["m16n16k8"], + ["tf32"], [], ["f32"], []>.ret; + list> bf16_wmma_ops = MMA_OPS< + ["m16n16k16", "m32n8k16", "m8n32k16"], + ["bf16"], [], ["f32"], []>.ret; + list> f64_wmma_ops = MMA_OPS< ["m8n8k4"], - ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret; + ["f64"], [], ["f64"], []>.ret; list> fp_wmma_ops = MMA_OPS< ["m16n16k16", "m32n8k16", "m8n32k16"], ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret; @@ -208,16 +308,50 @@ list> bit_wmma_ops = MMA_OPS< ["m8n8k128"], ["b1"], [], ["s32"], []>.ret; + list> all_wmma_ops = !listconcat( + tf32_wmma_ops, bf16_wmma_ops, f64_wmma_ops, + fp_wmma_ops, int_wmma_ops, subint_wmma_ops, bit_wmma_ops); + + list> tf32_mma_ops = MMA_OPS< + ["m16n8k4", "m16n8k8"], + ["tf32"], [], ["f32"], []>.ret; + list> bf16_mma_ops = MMA_OPS< + ["m16n8k16", "m16n8k8"], + ["bf16"], [], ["f32"], []>.ret; + list> f64_mma_ops = MMA_OPS< + ["m8n8k4"], + ["f64"], [], ["f64"], []>.ret; + list> fp_mma_ops = MMA_OPS< + ["m8n8k4", "m16n8k8", "m16n8k16"], + ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret; + list> int_mma_ops = MMA_OPS< + ["m8n8k16", "m16n8k16", "m16n8k32"], + ["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret; + list> subint_mma_ops = MMA_OPS< + ["m8n8k32", "m16n8k32", "m16n8k64"], + ["s4", "u4"], ["s4", 
"u4"], ["s32"], []>.ret; + list> bit_mma_ops = MMA_OPS< + ["m8n8k128", "m16n8k128", "m16n8k256"], + ["b1"], [], ["s32"], []>.ret; list> all_mma_ops = !listconcat( - fp_mma_ops, fp_wmma_ops, int_wmma_ops, - subint_wmma_ops, bit_wmma_ops); + tf32_mma_ops, bf16_mma_ops, f64_mma_ops, + fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops); list ldst_ab_ops = MMA_LDST_OPS< ["m16n16k16", "m32n8k16", "m8n32k16"], - ["a", "b"], ["f16", "u8", "s8"]>.ret; + ["a", "b"], ["f16", "u8", "s8", "bf16"]>.ret; list ldst_cd_ops = MMA_LDST_OPS< ["m16n16k16", "m32n8k16", "m8n32k16"], ["c", "d"], ["f16", "f32", "s32"]>.ret; + list ldst_tf32_ab_ops = MMA_LDST_OPS< + ["m16n16k8"], + ["a", "b"], ["tf32"]>.ret; + list ldst_tf32_cd_ops = MMA_LDST_OPS< + ["m16n16k8"], + ["c", "d"], ["f32"]>.ret; + list ldst_f64_abcd_ops = MMA_LDST_OPS< + ["m8n8k4"], + ["a", "b", "c", "d"], ["f64"]>.ret; list ldst_subint_ab_ops = MMA_LDST_OPS< ["m8n8k32"], ["a", "b"], ["s4","u4"]>.ret; list ldst_bit_ab_ops = MMA_LDST_OPS< @@ -225,6 +359,9 @@ list ldst_subint_cd_ops = MMA_LDST_OPS< ["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]>.ret; list all_ldst_ops = !listconcat(ldst_ab_ops, ldst_cd_ops, + ldst_tf32_ab_ops, + ldst_tf32_cd_ops, + ldst_f64_abcd_ops, ldst_subint_ab_ops, ldst_bit_ab_ops, ldst_subint_cd_ops); @@ -235,69 +372,110 @@ def NVVM_MMA_OPS : NVVM_MMA_OPS; -// Returns true if this combination of layout/satf is supported; false otherwise. -// MMA ops must provide all parameters. Loads and stores -- only frags and layout_a. -// The class is used to prevent generation of records for the unsupported variants. + +// Returns true if this combination of fragment and layout for WMMA load/store +// ops is supported; false otherwise. +// E.g. +// if NVVM_WMMA_LDST_SUPPORTED<...>.ret then +// def : FOO<>; // The record will only be defined for supported ops. 
+// +class NVVM_WMMA_LDST_SUPPORTED { + string f = frag.frag; + string t = frag.ptx_elt_type; + + bit ret = !cond( + // Sub-int load and store requires A fragment to be of row layout and B + // fragments to be of column layout. + !and(!or(!eq(t, "b1"), + !eq(t, "u4"), + !eq(t, "s4")), + !or(!and(!eq(f, "a"), + !ne(layout, "row")), + !and(!eq(f, "b"), + !ne(layout, "col")))) : false, + true: true + ); +} + +// Returns true if this combination of layout/satf/rnd for WMMA ops is +// supported; false otherwise. +// E.g. +// if NVVM_WMMA_SUPPORTED<...>.ret then +// def : FOO<>; // The record will only be defined for supported ops. +// +class NVVM_WMMA_SUPPORTED frags, string layout_a, string layout_b, int satf, string rnd> { + // WMMA ops check both layouts. + string layout = layout_a # ":" # layout_b; + string t = frags[0].ptx_elt_type; + + bit ret = !cond( + // only f64 wmma functions support rnd options + // any non f64 type that uses a rnd value is invalid + !and(!ne(t, "f64"), !ne(rnd, "")) : false, + + // satf is only valid for select types + !and(!eq(satf, 1), + !ne(t, "s8"), + !ne(t, "u8"), + !ne(t, "s4"), + !ne(t, "u4"), + !ne(t, "f16")): false, + + // Sub-int wmma requires row/column layout + !and(!or(!eq(t, "s4"), + !eq(t, "u4"), + !eq(t, "b1")), + !ne(layout, "row:col")) : false, + true: true + ); +} + +// Returns true if this combination of layout/satf for MMA ops is supported; +// false otherwise. // E.g. // if NVVM_MMA_SUPPORTED<...>.ret then // def : FOO<>; // The record will only be defined for supported ops. // -class NVVM_MMA_SUPPORTED frags, string layout_a, string layout_b="-", int satf=-1> { +class NVVM_MMA_SUPPORTED frags, string layout_a, string layout_b, int satf> { // MMA ops check both layouts. - string mma = frags[0].ptx_elt_type - # ":" # layout_a - # ":" # layout_b; - // Load ops only need type/fragment/layout. 
- string ld = frags[0].ptx_elt_type - # ":" # frags[0].frag - # ":" # layout_a - ; - string ldf = frags[0].ptx_elt_type - # ":" # frags[0].frag - ; - string t = frags[0].ptx_elt_type; + string layout = layout_a # ":" # layout_b; + string a_type = frags[0].ptx_elt_type; + string b_type = frags[1].ptx_elt_type; + string c_type = frags[2].ptx_elt_type; + string d_type = frags[3].ptx_elt_type; + string geom = frags[0].geom; // gcd is a shortcut used to identify instructions that depend on - // geom+frag_c+frag_d. Not all instances of this class have all fragments - // specified. If there are not enough fragments, the tail evaluates to '?'. - string gcd = frags[0].geom - # ":" - # !if(!eq(!size(frags), 4), - frags[2].ptx_elt_type # frags[3].ptx_elt_type, - "?"); + // geom+frag_c+frag_d. + string gcd = geom # ":" # c_type # d_type; bit ret = !cond( - // Sub-int MMA only supports fixed A/B layout. - // b1 does not support .satf. - !eq(mma#":"#satf, "b1:row:col:0") : true, - // mma.m8n8k4 has no .satf modifier. - !and(!eq(frags[0].geom, "m8n8k4"), - !ne(satf, 0)): false, - - // mma.m8n8k4 has no C=f32 D=f16 variant. + + // Limit satf to valid types + !and(!eq(satf, 1), + !ne(a_type, "s8"), + !ne(a_type, "u8"), + !ne(a_type, "s4"), + !ne(a_type, "u4")): false, + + // m8n8k4 has no C=f32 D=f16 variant. !eq(gcd, "m8n8k4:f32f16"): false, - !eq(mma, "s4:row:col") : true, - !eq(mma, "u4:row:col") : true, - !eq(mma, "s4:row:col") : true, - !eq(mma, "u4:row:col") : true, - // Sub-int load/stores have fixed layout for A and B. - !and(!eq(layout_b, "-"), // It's a Load or Store op - !or(!eq(ld, "b1:a:row"), - !eq(ld, "b1:b:col"), - !eq(ldf, "b1:c"), - !eq(ldf, "b1:d"), - !eq(ld, "s4:a:row"), - !eq(ld, "s4:b:col"), - !eq(ldf, "s4:c"), - !eq(ldf, "s4:d"), - !eq(ld, "u4:a:row"), - !eq(ld, "u4:b:col"), - !eq(ldf, "u4:c"), - !eq(ldf, "u4:d"))) : true, - // All other sub-int ops are not supported. 
- !eq(t, "b1") : false, - !eq(t, "s4") : false, - !eq(t, "u4") : false, - // All other (non sub-int) are OK. + + // only m8n8k4 for f16 does not require row:col layout + !and(!ne(layout, "row:col"), + !or(!ne(geom, "m8n8k4"), + !ne(a_type, "f16"))) : false, + + // m16n8k8 requires A and B to be the same type and C and D to be the same + // type. + !and(!eq(geom, "m16n8k8"), + !or(!ne(a_type, b_type), + !ne(c_type, d_type))): false, + + // m16n8k8 requires C and D to be the same type. + !and(!eq(geom, "m16n8k8"), + !ne(c_type, d_type)): false, + + // All other are OK. true: true ); } @@ -4271,36 +4449,59 @@ foreach layout = ["row", "col"] in { foreach stride = [0, 1] in { foreach frag = NVVM_MMA_OPS.all_ld_ops in - if NVVM_MMA_SUPPORTED<[frag], layout>.ret then + if NVVM_WMMA_LDST_SUPPORTED.ret then def WMMA_NAME_LDST<"load", frag, layout, stride>.record : NVVM_WMMA_LD; foreach frag = NVVM_MMA_OPS.all_st_ops in - if NVVM_MMA_SUPPORTED<[frag], layout>.ret then + if NVVM_WMMA_LDST_SUPPORTED.ret then def WMMA_NAME_LDST<"store", frag, layout, stride>.record : NVVM_WMMA_ST; } } // WMMA.MMA -class NVVM_WMMA_MMA : Intrinsic.llvm>; + WMMA_NAME.llvm>; + +foreach layout_a = ["row", "col"] in { + foreach layout_b = ["row", "col"] in { + foreach satf = [0, 1] in { + foreach rnd = ["", "rn", "rz", "rm", "rp"] in { + foreach op = NVVM_MMA_OPS.all_wmma_ops in { + if NVVM_WMMA_SUPPORTED.ret then { + def WMMA_NAME.record + : NVVM_WMMA_MMA; + } + } // op + } // rnd + } // satf + } // layout_b +} // layout_a + +// MMA +class NVVM_MMA + : Intrinsic.llvm>; foreach layout_a = ["row", "col"] in { foreach layout_b = ["row", "col"] in { foreach satf = [0, 1] in { foreach op = NVVM_MMA_OPS.all_mma_ops in { if NVVM_MMA_SUPPORTED.ret then { - def WMMA_NAME_MMA.record - : NVVM_WMMA_MMA; + def MMA_NAME.record + : NVVM_MMA; } - } + } // op } // satf } // layout_b } // layout_a diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp --- 
a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -3490,6 +3490,10 @@ case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride: case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride: case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row: + case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col: + case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col_stride: + case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row: + case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row_stride: case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col: case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride: case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride: @@ -3497,7 +3501,11 @@ case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row: case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride: case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride: - case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row: { + case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row: + case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col: + case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col_stride: + case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row: + case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row_stride: { Info.opc = ISD::INTRINSIC_W_CHAIN; Info.memVT = MVT::v2i32; Info.ptrVal = I.getArgOperand(0); @@ -3515,6 +3523,14 @@ case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride: case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride: case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row: + case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col: + case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col_stride: + case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row: + case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row_stride: + case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col: + case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col_stride: + case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row: + case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row_stride: case 
Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col: case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride: @@ -3523,7 +3539,15 @@ case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row: case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride: case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride: - case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row: { + case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row: + case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col: + case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col_stride: + case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row: + case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row_stride: + case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col: + case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride: + case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row: + case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride: { Info.opc = ISD::INTRINSIC_W_CHAIN; Info.memVT = MVT::v4i32; Info.ptrVal = I.getArgOperand(0); @@ -3603,7 +3627,11 @@ case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col: case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row: case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride: - case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride: { + case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride: + case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col: + case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row: + case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col_stride: + case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row_stride: { Info.opc = ISD::INTRINSIC_W_CHAIN; Info.memVT = MVT::v8f32; Info.ptrVal = I.getArgOperand(0); @@ -3613,6 +3641,16 @@ return true; } + case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col: + case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col_stride: + case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row: + case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row_stride: + + case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col: + case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col_stride: + case 
Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row: + case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row_stride: + case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col: case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride: case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row: @@ -3651,6 +3689,37 @@ return true; } + case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col: + case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col_stride: + case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row: + case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row_stride: + + case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col: + case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col_stride: + case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row: + case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row_stride: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.memVT = MVT::f64; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.flags = MachineMemOperand::MOLoad; + Info.align = Align(8); + return true; + } + + case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col: + case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col_stride: + case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row: + case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row_stride: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.memVT = MVT::v2f64; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.flags = MachineMemOperand::MOLoad; + Info.align = Align(16); + return true; + } + case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col: case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row: case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride: @@ -3683,7 +3752,11 @@ case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col: case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row: case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride: - case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride: { + case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride: + case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col: + case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row: + case 
Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col_stride: + case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row_stride: { Info.opc = ISD::INTRINSIC_VOID; Info.memVT = MVT::v8f32; Info.ptrVal = I.getArgOperand(0); @@ -3731,6 +3804,19 @@ return true; } + case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col: + case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col_stride: + case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row: + case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row_stride: { + Info.opc = ISD::INTRINSIC_VOID; + Info.memVT = MVT::v2f64; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.flags = MachineMemOperand::MOStore; + Info.align = Align(16); + return true; + } + case Intrinsic::nvvm_atomic_load_inc_32: case Intrinsic::nvvm_atomic_load_dec_32: diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -144,6 +144,7 @@ def hasPTX61 : Predicate<"Subtarget->getPTXVersion() >= 61">; def hasPTX63 : Predicate<"Subtarget->getPTXVersion() >= 63">; def hasPTX64 : Predicate<"Subtarget->getPTXVersion() >= 64">; +def hasPTX65 : Predicate<"Subtarget->getPTXVersion() >= 65">; def hasPTX70 : Predicate<"Subtarget->getPTXVersion() >= 70">; def hasSM30 : Predicate<"Subtarget->getSmVersion() >= 30">; diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -1943,21 +1943,21 @@ !strconcat("ldu.global.", TyStr), []>; } -multiclass VLDU_G_ELE_V4 { +multiclass VLDU_G_ELE_V4 { def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins Int32Regs:$src), + regclass:$dst4), (ins Int32Regs:$src), !strconcat("ldu.global.", TyStr), []>; def _areg64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins Int64Regs:$src), + regclass:$dst4), (ins 
Int64Regs:$src), !strconcat("ldu.global.", TyStr), []>; def _ari32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins MEMri:$src), + regclass:$dst4), (ins MEMri:$src), !strconcat("ldu.global.", TyStr), []>; def _ari64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins MEMri64:$src), + regclass:$dst4), (ins MEMri64:$src), !strconcat("ldu.global.", TyStr), []>; def _avar: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins imemAny:$src), + regclass:$dst4), (ins imemAny:$src), !strconcat("ldu.global.", TyStr), []>; } @@ -1997,7 +1997,7 @@ //----------------------------------- -// Support for ldg on sm_35 or later +// Support for ldg on sm_35 or later //----------------------------------- // Don't annotate ld.global.nc as mayLoad, because these loads go through the @@ -2045,7 +2045,7 @@ // vector -// Elementized vector ldg +// Elementized vector ldg multiclass VLDG_G_ELE_V2 { def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2), (ins Int32Regs:$src), @@ -2064,21 +2064,21 @@ !strconcat("ld.global.nc.", TyStr), []>; } -multiclass VLDG_G_ELE_V4 { +multiclass VLDG_G_ELE_V4 { def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins Int32Regs:$src), + regclass:$dst4), (ins Int32Regs:$src), !strconcat("ld.global.nc.", TyStr), []>; def _areg64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins Int64Regs:$src), + regclass:$dst4), (ins Int64Regs:$src), !strconcat("ld.global.nc.", TyStr), []>; def _ari32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins MEMri:$src), + regclass:$dst4), (ins MEMri:$src), !strconcat("ld.global.nc.", TyStr), []>; def _ari64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins MEMri64:$src), + regclass:$dst4), (ins MEMri64:$src), !strconcat("ld.global.nc.", TyStr), []>; def 
_avar: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins imemAny:$src), + regclass:$dst4), (ins imemAny:$src), !strconcat("ld.global.nc.", TyStr), []>; } @@ -7568,12 +7568,15 @@ // In addition to target-independent fields provided by WMMA_REGS, it adds // the fields commonly used to implement specific PTX instruction -- register // types and names, constraints, parts of assembly, etc. -class WMMA_REGINFO +class WMMA_REGINFO : WMMA_REGS { // NVPTX register types used to carry fragment data. NVPTXRegClass regclass = !cond( !eq(ptx_elt_type, "f16") : Float16x2Regs, !eq(ptx_elt_type, "f32") : Float32Regs, + !eq(ptx_elt_type, "f64") : Float64Regs, + !eq(ptx_elt_type, "bf16") : Int32Regs, + !eq(ptx_elt_type, "tf32") : Int32Regs, !eq(ptx_elt_type, "s32") : Int32Regs, !eq(ptx_elt_type, "s8") : Int32Regs, !eq(ptx_elt_type, "u8") : Int32Regs, @@ -7602,6 +7605,9 @@ !or(!eq(ptx_elt_type, "f16"), !eq(ptx_elt_type, "f32"))) : [hasSM70, hasPTX60], + !and(!eq(geom,"m8n8k4"), + !eq(ptx_elt_type, "f64")) : [hasSM80, hasPTX70], + // fp16 -> fp16/fp32 @ m8n32k16/m32n8k16 !and(!or(!eq(geom, "m8n32k16"), !eq(geom, "m32n8k16")), @@ -7616,11 +7622,46 @@ !eq(ptx_elt_type, "s8"), !eq(ptx_elt_type, "s32"))) : [hasSM72, hasPTX63], - // u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1) - !or(!eq(geom,"m8n8k128"), - !eq(geom,"m8n8k32")) : [hasSM75, hasPTX63], + !and(!or(!eq(geom,"m16n16k16"), + !eq(geom,"m8n32k16"), + !eq(geom,"m32n8k16")), + !eq(ptx_elt_type, "bf16")) : [hasSM80, hasPTX70], + + !and(!eq(geom,"m16n16k8"), + !eq(ptx_elt_type, "tf32")) : [hasSM80, hasPTX70], + + !and(!eq(geom,"m16n16k8"), + !eq(ptx_elt_type, "f32")) : [hasSM80, hasPTX70], + + // b1 -> s32 @ m8n8k128(b1) + !and(!ne(op,"mma"), + !eq(geom,"m8n8k128")) : [hasSM75, hasPTX63], + + // u4/s4 -> s32 @ m8n8k32 (u4/s4) + !and(!ne(op,"mma"), + !eq(geom,"m8n8k32")) : [hasSM75, hasPTX63], + + !or(!eq(geom,"m16n8k8"), + !eq(geom,"m8n8k16")) : [hasSM75, hasPTX65], - !eq(geom, "m8n8k4") : 
[hasSM70, hasPTX64]); + !and(!ne(ptx_elt_type,"f64"), + !eq(geom, "m8n8k4")) : [hasSM70, hasPTX64], + + // mma m8n8k32 requires higher PTX version + !and(!eq(op,"mma"), + !eq(geom,"m8n8k32")) : [hasSM75, hasPTX65], + + !and(!eq(ptx_elt_type,"f64"), + !eq(geom, "m8n8k4")) : [hasSM80, hasPTX70], + + !and(!eq(op,"mma"), + !or(!eq(geom, "m16n8k16"), + !eq(geom, "m16n8k4"), + !eq(geom, "m16n8k32"), + !eq(geom, "m16n8k64"), + !eq(geom, "m8n8k128"), + !eq(geom, "m16n8k128"), + !eq(geom, "m16n8k256"))) : [hasSM80, hasPTX70]); // template DAGs for instruction inputs/output. dag Outs = !dag(outs, ptx_regs, reg_names); @@ -7744,11 +7785,11 @@ foreach space = [".global", ".shared", ""] in { foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in { foreach frag = NVVM_MMA_OPS.all_ld_ops in - if NVVM_MMA_SUPPORTED<[frag], layout>.ret then - def : WMMA_LOAD, layout, space, stride, addr>; + if NVVM_WMMA_LDST_SUPPORTED.ret then + def : WMMA_LOAD, layout, space, stride, addr>; foreach frag = NVVM_MMA_OPS.all_st_ops in - if NVVM_MMA_SUPPORTED<[frag], layout>.ret then - def : WMMA_STORE_D, layout, space, stride, addr>; + if NVVM_WMMA_LDST_SUPPORTED.ret then + def : WMMA_STORE_D, layout, space, stride, addr>; } // addr } // space } // stride @@ -7758,46 +7799,84 @@ // WMMA.MMA class WMMA_MMA - : WMMA_INSTR.record, - [FragA.Ins, FragB.Ins, FragC.Ins]>, + string ALayout, string BLayout, int Satfinite, string rnd> + : WMMA_INSTR.record, + [FragA.Ins, FragB.Ins, FragC.Ins]>, // Requires does not seem to have effect on Instruction w/o Patterns. // We set it here anyways and propagate to the Pat<> we construct below. Requires { let OutOperandList = FragD.Outs; let InOperandList = !con(Args, (ins MmaCode:$ptx)); string TypeList = !cond( - !eq(FragD.geom, "m8n8k4") : "." # FragD.ptx_elt_type - # ".f16.f16." - # FragC.ptx_elt_type, - !eq(FragD.ptx_elt_type, "s32") : ".s32" - # "." # FragA.ptx_elt_type - # "." # FragB.ptx_elt_type - # ".s32", - 1: "." # FragD.ptx_elt_type # "." 
# FragC.ptx_elt_type, + !eq(FragA.ptx_elt_type, "f16") : "." # FragD.ptx_elt_type + # "." # FragC.ptx_elt_type, + 1: "." # FragD.ptx_elt_type + # "." # FragA.ptx_elt_type + # "." # FragB.ptx_elt_type + # "." # FragC.ptx_elt_type, ); - let AsmString = !if(!eq(FragA.geom, "m8n8k4"), - "mma.sync.aligned.m8n8k4" - # "." # ALayout - # "." # BLayout - # TypeList # "\n\t\t" - # FragD.regstring # ",\n\t\t" - # FragA.regstring # ",\n\t\t" - # FragB.regstring # ",\n\t\t" - # FragC.regstring # ";", - "wmma.mma" - # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "") - # ".sync" - # "${ptx:aligned}" - # "." # ALayout - # "." # BLayout - # "." # FragA.geom - # TypeList - # !if(Satfinite, ".satfinite", "") # "\n\t\t" - # FragD.regstring # ",\n\t\t" - # FragA.regstring # ",\n\t\t" - # FragB.regstring # ",\n\t\t" - # FragC.regstring # ";"); + let AsmString = "wmma.mma" + # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "") + # ".sync" + # "${ptx:aligned}" + # "." # ALayout + # "." # BLayout + # "." # FragA.geom + # !if(!ne(rnd, ""), !strconcat(".", rnd), "") + # TypeList + # !if(Satfinite, ".satfinite", "") # "\n\t\t" + # FragD.regstring # ",\n\t\t" + # FragA.regstring # ",\n\t\t" + # FragB.regstring # ",\n\t\t" + # FragC.regstring # ";"; +} + +defset list WMMAs = { + foreach layout_a = ["row", "col"] in { + foreach layout_b = ["row", "col"] in { + foreach satf = [0, 1] in { + foreach rnd = ["", "rn", "rz", "rm", "rp"] in { + foreach op = NVVM_MMA_OPS.all_wmma_ops in { + if NVVM_WMMA_SUPPORTED.ret then { + def : WMMA_MMA, + WMMA_REGINFO, + WMMA_REGINFO, + WMMA_REGINFO, + layout_a, layout_b, satf, rnd>; + } + } // op + } // rnd + } // satf + } // layout_b + } // layout_a +} // defset + +// MMA +class MMA + : WMMA_INSTR.record, + [FragA.Ins, FragB.Ins, FragC.Ins]>, + // Requires does not seem to have effect on Instruction w/o Patterns. + // We set it here anyways and propagate to the Pat<> we construct below. 
+ Requires { + let OutOperandList = FragD.Outs; + let InOperandList = !con(Args, (ins MmaCode:$ptx)); + string TypeList = "." # FragD.ptx_elt_type + # "." # FragA.ptx_elt_type + # "." # FragB.ptx_elt_type + # "." # FragC.ptx_elt_type; + let AsmString = "mma.sync.aligned." + # FragA.geom + # "." # ALayout + # "." # BLayout + # !if(Satfinite, ".satfinite", "") + # TypeList + # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "") # "\n\t\t" + # FragD.regstring # ",\n\t\t" + # FragA.regstring # ",\n\t\t" + # FragB.regstring # ",\n\t\t" + # FragC.regstring # ";"; } defset list MMAs = { @@ -7806,11 +7885,11 @@ foreach satf = [0, 1] in { foreach op = NVVM_MMA_OPS.all_mma_ops in { if NVVM_MMA_SUPPORTED.ret then { - def : WMMA_MMA, - WMMA_REGINFO, - WMMA_REGINFO, - WMMA_REGINFO, - layout_a, layout_b, satf>; + def : MMA, + WMMA_REGINFO, + WMMA_REGINFO, + WMMA_REGINFO, + layout_a, layout_b, satf>; } } // op } // satf @@ -7822,12 +7901,12 @@ // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with // the instruction record. -class WMMA_PAT +class MMA_PAT : Pat, Requires; // Build intrinsic->instruction patterns for all MMA instructions. 
-foreach mma = !listconcat(MMAs, MMA_LDSTs) in - def : WMMA_PAT; +foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs) in + def : MMA_PAT; diff --git a/llvm/test/CodeGen/NVPTX/lit.local.cfg b/llvm/test/CodeGen/NVPTX/lit.local.cfg --- a/llvm/test/CodeGen/NVPTX/lit.local.cfg +++ b/llvm/test/CodeGen/NVPTX/lit.local.cfg @@ -1,2 +1,3 @@ if not 'NVPTX' in config.root.targets: config.unsupported = True +config.suffixes.add('.py') diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py --- a/llvm/test/CodeGen/NVPTX/wmma.py +++ b/llvm/test/CodeGen/NVPTX/wmma.py @@ -6,7 +6,7 @@ # RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \ # RUN: --check-prefixes=INTRINSICS,M16N16 # RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \ -# RUN: --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA +# RUN: --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT # RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \ # RUN: | FileCheck %t-ptx60-sm_70.ll @@ -15,7 +15,7 @@ # RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \ # RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM # RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \ -# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA +# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT # RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \ # RUN: | FileCheck %t-ptx61-sm_70.ll @@ -24,7 +24,7 @@ # RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \ # RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT # RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \ -# RUN: --check-prefixes=INTRINSICS,NOSUBINT,NOMMA +# RUN: --check-prefixes=INTRINSICS,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT # RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \ # RUN: | FileCheck %t-ptx63-sm_72.ll @@ -33,7 +33,7 @@ # RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \ # RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT # 
RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \ -# RUN: --check-prefixes=INTRINSICS,NOMMA +# RUN: --check-prefixes=INTRINSICS,NOMMA,NODOUBLE,NOALTFLOAT # RUN: llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \ # RUN: | FileCheck %t-ptx63-sm_75.ll @@ -42,10 +42,28 @@ # RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \ # RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,MMA # RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \ -# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT +# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NODOUBLE,NOALTFLOAT # RUN: llc < %t-ptx64-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 \ # RUN: | FileCheck %t-ptx64-sm_70.ll +# Check all variants of instructions supported by PTX65 on SM75+ +# RUN: python %s --ptx=65 --gpu-arch=75 > %t-ptx65-sm_75.ll +# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \ +# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,PTX65MMA +# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \ +# RUN: --check-prefixes=INTRINSICS +# RUN: llc < %t-ptx65-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx65 \ +# RUN: | FileCheck %t-ptx65-sm_75.ll + +# Check all variants of instructions supported by PTX70 on SM80+ +# RUN: python %s --ptx=70 --gpu-arch=80 > %t-ptx70-sm_80.ll +# RUN: FileCheck %t-ptx70-sm_80.ll < %t-ptx70-sm_80.ll \ +# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX70MMA +# RUN: FileCheck %t-ptx70-sm_80.ll < %t-ptx70-sm_80.ll \ +# RUN: --check-prefixes=INTRINSICS +# RUN: llc < %t-ptx70-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 \ +# RUN: | FileCheck %t-ptx70-sm_80.ll + from __future__ import print_function import argparse @@ -56,19 +74,23 @@ def __init__(self, ptx_type): self.ptx_type = ptx_type self.llvm_type = { - "f16" : "<2 x half>", - "f32" : "float", - "s32" : "i32", - "s8" : "i32", - "u8" : "i32", - "s4" : "i32", - "u4" : "i32", - "b1" : "i32", + "f16" : "<2 x half>", + "f32" : "float", + "f64" : 
"double", + "s32" : "i32", + "s8" : "i32", + "u8" : "i32", + "s4" : "i32", + "u4" : "i32", + "b1" : "i32", + "bf16" : "i32", + "tf32" : "i32", }[ptx_type]; self.ptx_reg_pattern = { "f16" : "%hh[0-9]+", "f32" : "%f[0-9]+", + "f64" : "%fd[0-9]+", }.get(ptx_type, "%r[0-9]+") def __repr__(self): @@ -78,16 +100,8 @@ def __init__(self, geom, frag, ptx_elt_type): self.geom = geom self.frag = frag - self.is_mma = True if geom == "m8n8k4" else False; self.mma_type = MMAType(ptx_elt_type); self.nregs = { - "a:f16" : 2 if self.is_mma else 8, - "b:f16" : 2 if self.is_mma else 8, - "c:f16" : 4, - "d:f16" : 4, - "c:f32" : 8, - "d:f32" : 8, - }.get("%s:%s" % (frag, ptx_elt_type), { # u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16 "m16n16k16:a:u8" : 2, "m16n16k16:a:s8" : 2, @@ -110,18 +124,123 @@ "m32n8k16:c:s32" : 8, "m32n8k16:d:s32" : 8, - # u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1) - "m8n8k128:a:b1" : 1, + "m8n8k16:a:u8": 1, + "m8n8k16:a:s8": 1, + "m8n8k16:b:u8": 1, + "m8n8k16:b:s8": 1, + "m8n8k16:c:s32": 2, + "m8n8k16:d:s32": 2, + + "m16n8k16:a:u8": 2, + "m16n8k16:a:s8": 2, + "m16n8k16:b:u8": 1, + "m16n8k16:b:s8": 1, + "m16n8k16:c:s32": 4, + "m16n8k16:d:s32": 4, + + "m16n8k32:a:u8": 4, + "m16n8k32:a:s8": 4, + "m16n8k32:b:u8": 2, + "m16n8k32:b:s8": 2, + "m16n8k32:c:s32": 4, + "m16n8k32:d:s32": 4, + + # u4/s4 -> s32 @ m8n8k32 (u4/s4) "m8n8k32:a:u4" : 1, "m8n8k32:a:s4" : 1, - "m8n8k128:b:b1" : 1, "m8n8k32:b:u4" : 1, "m8n8k32:b:s4" : 1, - "m8n8k128:c:s32" : 2, - "m8n8k128:d:s32" : 2, "m8n8k32:c:s32" : 2, "m8n8k32:d:s32" : 2, - }.get("%s:%s:%s" % (geom, frag, ptx_elt_type), None)); + + "m16n8k32:a:u4" : 2, + "m16n8k32:a:s4" : 2, + "m16n8k32:b:u4" : 1, + "m16n8k32:b:s4" : 1, + "m16n8k32:c:s32" : 4, + "m16n8k32:d:s32" : 4, + + "m16n8k64:a:u4" : 4, + "m16n8k64:a:s4" : 4, + "m16n8k64:b:u4" : 2, + "m16n8k64:b:s4" : 2, + "m16n8k64:c:s32" : 4, + "m16n8k64:d:s32" : 4, + + # b1 -> s32 @ m8n8k128(b1) + "m8n8k128:a:b1" : 1, + "m8n8k128:b:b1" : 1, + "m8n8k128:c:s32" : 2, + "m8n8k128:d:s32" : 
2, + + "m16n8k128:a:b1" : 2, + "m16n8k128:b:b1" : 1, + "m16n8k128:c:s32" : 4, + "m16n8k128:d:s32" : 4, + + "m16n8k256:a:b1" : 4, + "m16n8k256:b:b1" : 2, + "m16n8k256:c:s32" : 4, + "m16n8k256:d:s32" : 4, + + # bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16 + "m16n16k16:a:bf16" : 4, + "m16n16k16:b:bf16" : 4, + "m8n32k16:a:bf16" : 2, + "m8n32k16:b:bf16" : 8, + "m32n8k16:a:bf16" : 8, + "m32n8k16:b:bf16" : 2, + + "m16n8k16:a:bf16" : 4, + "m16n8k16:b:bf16" : 2, + "m16n8k16:c:f32" : 4, + "m16n8k16:d:f32" : 4, + "m16n8k8:a:bf16" : 2, + "m16n8k8:b:bf16" : 1, + "m16n8k8:c:f32" : 4, + "m16n8k8:d:f32" : 4, + + "m8n8k4:a:f64" : 1, + "m8n8k4:b:f64" : 1, + "m8n8k4:c:f64" : 2, + "m8n8k4:d:f64" : 2, + + # tf32 -> s32 @ m16n16k8 + "m16n16k8:a:tf32" : 4, + "m16n16k8:b:tf32" : 4, + + "m16n8k4:a:tf32" : 2, + "m16n8k4:b:tf32" : 1, + "m16n8k4:c:f32" : 4, + "m16n8k4:d:f32" : 4, + "m16n8k8:a:tf32" : 4, + "m16n8k8:b:tf32" : 2, + "m16n8k8:c:f32" : 4, + "m16n8k8:d:f32" : 4, + + "m8n8k4:a:f16": 2, + "m8n8k4:b:f16": 2, + "m16n8k8:a:f16": 2, + "m16n8k8:b:f16": 1, + "m16n8k8:c:f16": 2, + "m16n8k8:d:f16": 2, + "m16n8k8:c:f32": 4, + "m16n8k8:d:f32": 4, + "m16n8k16:a:f16": 4, + "m16n8k16:b:f16": 2, + "m16n8k16:c:f16": 2, + "m16n8k16:d:f16": 2, + "m16n8k16:c:f32": 4, + "m16n8k16:d:f32": 4, + }.get("%s:%s:%s" % (geom, frag, ptx_elt_type), { + # All other FP shape/fragment/type combinations have the same size + "a:f16" : 8, + "b:f16" : 8, + "c:f16" : 4, + "d:f16" : 4, + "c:f32" : 8, + "d:f32" : 8, + }.get("%s:%s" % (frag, ptx_elt_type), None)) assert(self.nregs); def __repr__(self): @@ -153,9 +272,13 @@ return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type) in product(geoms, frags, types)] -def get_mma_ops(): - return (make_mma_ops(["m8n8k4"], - ["f16"], [], ["f16", "f32"], ["f16", "f32"]) + +def get_wmma_ops(): + return (make_mma_ops(["m16n16k8"], + ["tf32"], [], ["f32"], []) + + make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"], + ["bf16"], [], ["f32"], []) + + make_mma_ops(["m8n8k4"], + 
["f64"], [], ["f64"], []) + make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"], ["f16"], [], ["f16", "f32"], ["f16", "f32"]) + make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"], @@ -164,20 +287,38 @@ ["s4", "u4"], [], ["s32"], []) + make_mma_ops(["m8n8k128"], ["b1"], [], ["s32"], [])) + +def get_mma_ops(): + return (make_mma_ops(["m8n8k4"], + ["f64"], [], ["f64"], []) + + make_mma_ops(["m16n8k4", "m16n8k8"], + ["tf32"], [], ["f32"], []) + + make_mma_ops(["m16n8k16", "m16n8k8"], + ["bf16"], [], ["f32"], []) + + make_mma_ops(["m8n8k4", "m16n8k8", "m16n8k16"], + ["f16"], [], ["f16", "f32"], ["f16", "f32"]) + + make_mma_ops(["m8n8k16", "m16n8k16", "m16n8k32"], + ["s8", "u8"], ["s8", "u8"], ["s32"], []) + + make_mma_ops(["m8n8k32", "m16n8k32", "m16n8k64"], + ["s4", "u4"], ["s4", "u4"], ["s32"], []) + + make_mma_ops(["m8n8k128", "m16n8k128", "m16n8k256"], + ["b1"], [], ["s32"], [])) + def get_ldst_ops(kind): ldst_ops = (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"], - ["a", "b"], ["f16", "u8", "s8"]) + + ["a", "b"], ["f16", "u8", "s8", "bf16"]) + make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"], ["c", "d"], ["f16", "f32", "s32"]) + make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4","u4"]) + make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) + - make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"])) + make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]) + + make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"]) + + make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"]) + + make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"])) return [ x for x in ldst_ops if (x.frag == "d") == (kind == "store")] -def is_geom_supported(geom): +def is_wmma_geom_supported(geom): # geometries for FP and ints. - if geom == "m8n8k4": - return ptx_version >= 64 if geom in ["m8n32k16", "m32n8k16"]: return ptx_version >= 61 # geometries for sub-ints. 
@@ -185,6 +326,21 @@ return ptx_version >= 63 and gpu_arch >= 75 if geom == "m16n16k16": return ptx_version >= 60 + if geom == "m16n8k8": + return ptx_version >= 65 + if geom in ["m16n16k8", "m8n8k4"]: + return ptx_version >= 70 + assert(False) # Unexpected geometry. + +def is_mma_geom_supported(geom): + # geometries for FP and ints. + if geom == "m8n8k4": + return ptx_version >= 64 + if geom in ["m16n8k8", "m8n8k16", "m8n8k32"]: + return ptx_version >= 65 + if geom in ["m16n8k16", "m16n8k4", "m16n8k32", "m16n8k64", "m8n8k128", + "m16n8k128", "m16n8k256"]: + return ptx_version >= 70 assert(False) # Unexpected geometry. def is_type_supported(ptx_type): @@ -192,30 +348,63 @@ return ptx_version >= 63 and gpu_arch >= 72 if ptx_type in ["s4", "u4", "b1"]: return ptx_version >= 63 and gpu_arch >= 75 + if ptx_type in ["bf16", "tf32", "f64"]: + return ptx_version >= 70 return ptx_version >= 60 and gpu_arch >= 70 +def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf): + if not (is_type_supported(op.a.mma_type.ptx_type) + and is_wmma_geom_supported(op.a.geom)): + return False + + # rnd is only supported for FP64 WMMA + if rnd and op.a.mma_type.ptx_type != "f64": + return False + + if satf: + # satfinite for floating points was removed in PTX 6.5 + if op.a.mma_type.ptx_type == "f16" and ptx_version >= 65: + return False + if not op.a.mma_type.ptx_type in ["f16", "s8", "u8", "s4", "u4"]: + return False + + # sub-integer require row/col layout. 
+ if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]: + return layout_a == "row" and layout_b == "col" + return True def is_mma_variant_supported(op, layout_a, layout_b, satf): if not (is_type_supported(op.a.mma_type.ptx_type) - and is_geom_supported(op.a.geom)): + and is_mma_geom_supported(op.a.geom)): + return False + + if satf and not op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4"]: + return False + + # If the type of C is f32 then so must the type of D + if (op.a.geom == "m8n8k4" and op.c.mma_type.ptx_type == "f32" + and op.d.mma_type.ptx_type != "f32"): return False - if op.a.geom == "m8n8k4": - if satf: + + # A and B type must be the same. C and D type must be the same + if (op.a.geom == "m16n8k8" + and (op.a.mma_type.ptx_type != op.b.mma_type.ptx_type + or op.c.mma_type.ptx_type != op.d.mma_type.ptx_type)): return False - if op.c.mma_type.ptx_type == "f32": - # If C is f32, D must be, too. - return op.d.mma_type.ptx_type == "f32" - # sub-integer require row/col layout, and no satf. - if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]: - if op.a.mma_type.ptx_type == "b1" and satf: + # C and D type must be the same + if (op.a.geom == "m16n8k16" + and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type): return False + + # Require row/col layout for all MMA except m8n8k4 on FP16 + if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"): return layout_a == "row" and layout_b == "col" return True def is_ldst_variant_supported(frag, layout): if not (is_type_supported(frag.mma_type.ptx_type) - and is_geom_supported(frag.geom)): + and is_wmma_geom_supported(frag.geom)): return False if frag.mma_type.ptx_type in ["s4", "u4", "b1"]: # sub-integer require sm_75 and ptx63, row/col layout for a/b. @@ -396,24 +585,37 @@ return generated_items def mma_signature(op): - if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]: - # int and sub-int ops are identified by input type. 
- return op.a.mma_type.ptx_type - else: - # the rest are FP ops identified by accumulator & result type. + if op.a.mma_type.ptx_type == "f16": + # FP16 ops identified by accumulator & result type. return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type) + elif op.a.mma_type.ptx_type != op.b.mma_type.ptx_type: + # other ops are identified by input types. + return "%s.%s" % (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type) + else: + # if input types are the same, it only appears once. + return op.a.mma_type.ptx_type def mma_ptx_signature(op): - if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]: - # int and sub-int instructions encode all four types as D.A.B.C - return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c)) - if op.a.geom == "m8n8k4": - return "%s.f16.f16.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type) + # Encode all four types as D.A.B.C + return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c)) + +def wmma_signature(op): + if op.a.mma_type.ptx_type == "f16": + # FP16 ops identified by accumulator & result type. + return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type) else: - # the rest are FP instructions use D.C + # other ops are identified by input type. 
+ return op.a.mma_type.ptx_type + +def wmma_ptx_signature(op): + if op.a.mma_type.ptx_type == "f16": + # FP16 instructions use D.C return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type) + else: + # other instructions encode all four types as D.A.B.C + return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c)) -def gen_wmma_mma_tests(): +def common_mma_test_gen(params, op, intrinsic_template, instruction_template): mma_template = """ declare ${ret_ty} @${intrinsic}( ${args}); @@ -431,10 +633,61 @@ ret ${ret_ty} %r; } """ - wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}" - wmma_instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}.${ptx_signature}${satf}" - mma_intrinsic_template = "llvm.nvvm.mma.${geom}.${alayout}.${blayout}.${intrinsic_signature}" - mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}.${ptx_signature}" + + test_params = params + test_params["intrinsic"] = Template(intrinsic_template).substitute(params) + test_params["function"] = test_params["intrinsic"].replace(".", "_") + test_params["instruction"] = Template(instruction_template).substitute(params) + test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d) + test_params["check_a"] = check_pattern(op.a) + test_params["check_b"] = check_pattern(op.b) + test_params["check_c"] = check_pattern(op.c) + test_params["check_d"] = check_pattern(op.d) + args = ",\n ".join(make_wmma_slice_args(frag) + for frag in (op.a, op.b, op.c)) + test_params["args"] = args + print(Template(mma_template).substitute(test_params)) + return (test_params["intrinsic"], test_params["instruction"]) + +def gen_wmma_mma_tests(): + wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}" + wmma_instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}" + + generated_items=[] + 
+ for op, alayout, blayout, rnd, satf in product( + get_wmma_ops(), + ["row","col"], + ["row","col"], + [".rn", ".rz", ".rm", ".rp", ""], + [".satfinite", ""]): + + if not is_wmma_variant_supported(op, alayout, blayout, rnd, satf): + continue + + params = { + "aligned" : ".aligned" if ptx_version >= 63 else "", + "alayout" : alayout, + "blayout" : blayout, + "intrinsic_signature" : wmma_signature(op), + "ptx_signature" : wmma_ptx_signature(op), + "satf" : satf, + "rnd" : rnd, + "geom" : op.a.geom, + "mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "", + } + + intrinsic_template = wmma_intrinsic_template + instruction_template = wmma_instruction_template + + generated_items.append(common_mma_test_gen(params, op, + intrinsic_template, instruction_template)) + + return generated_items + +def gen_mma_tests(): + mma_intrinsic_template = "llvm.nvvm.mma.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}" + mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${mma_variant}" generated_items=[] @@ -458,28 +711,11 @@ "mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "", } - if op.a.geom == "m8n8k4": - intrinsic_template = mma_intrinsic_template - instruction_template = mma_instruction_template - else: - intrinsic_template = wmma_intrinsic_template - instruction_template = wmma_instruction_template + intrinsic_template = mma_intrinsic_template + instruction_template = mma_instruction_template - test_params = params - test_params["intrinsic"] = Template(intrinsic_template).substitute(params) - test_params["function"] = test_params["intrinsic"].replace(".", "_") - test_params["instruction"] = Template(instruction_template).substitute(params) - test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d) - test_params["check_a"] = check_pattern(op.a) - test_params["check_b"] = check_pattern(op.b) - test_params["check_c"] = check_pattern(op.c) - test_params["check_d"] = check_pattern(op.d) - args = 
",\n ".join(make_wmma_slice_args(frag) - for frag in (op.a, op.b, op.c)) - test_params["args"] = args - print(Template(mma_template).substitute(test_params)) - generated_items.append((test_params["intrinsic"], - test_params["instruction"])) + generated_items.append(common_mma_test_gen(params, op, + intrinsic_template, instruction_template)) return generated_items @@ -497,6 +733,8 @@ ; NOINT-NOT: .{{s32|s8}} ; NOSUBINT-NOT: {{s4|u4|b1}} ; NOMMA-NOT: .m8n8k4. +; NOALTFLOAT-NOT: .{{bf16|tf32}} +; NODOUBLE-NOT: .f64 ; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p ; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p @@ -543,10 +781,61 @@ ; SUBINT-DAG: m8n8k32.mma.{{.*}}.s4 ; SUBINT-DAG: m8n8k128.mma.{{.*}}.b1 +; ALTFLOAT-DAG: m16n16k16.load.{{[ab].*}}.bf16.p +; ALTFLOAT-DAG: m8n32k16.load.{{[ab].*}}.bf16.p +; ALTFLOAT-DAG: m32n8k16.load.{{[ab].*}}.bf16.p +; ALTFLOAT-DAG: m16n16k8.load.{{[ab].*}}.tf32.p +; ALTFLOAT-DAG: m16n16k16.mma.{{.*}}.bf16 +; ALTFLOAT-DAG: m8n32k16.mma.{{.*}}.bf16 +; ALTFLOAT-DAG: m32n8k16.mma.{{.*}}.bf16 +; ALTFLOAT-DAG: m16n16k8.mma.{{.*}}.tf32 + +; DOUBLE-DAG: m8n8k4.load.{{[abc].*}}.f64.p +; DOUBLE-DAG: m8n8k4.store.d.{{.*}}.f64.p +; DOUBLE-DAG: m8n8k4.mma.{{.*}}.f64 + ; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f32 ; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f16 ; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f16 ; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f32 + +; PTX65MMA-DAG: mma.m16n8k8.row.col.f16.f16 +; PTX65MMA-DAG: mma.m16n8k8.row.col.f32.f32 +; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.u8 +; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.s8 +; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.u8 +; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.s8 +; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.u4 +; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.s4 +; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4 +; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4 + +; PTX70MMA-DAG: mma.m8n8k4.row.col.f64 +; PTX70MMA-DAG: mma.m16n8k4.row.col.tf32 +; PTX70MMA-DAG: mma.m16n8k8.row.col.tf32 +; PTX70MMA-DAG: 
mma.m16n8k16.row.col.bf16 +; PTX70MMA-DAG: mma.m16n8k8.row.col.bf16 +; PTX70MMA-DAG: mma.m16n8k16.row.col.f16.f16 +; PTX70MMA-DAG: mma.m16n8k16.row.col.f32.f32 +; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8 +; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8 +; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8 +; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8 +; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8 +; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8 +; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8 +; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8 +; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4 +; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4 +; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4 +; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4 +; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4 +; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4 +; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4 +; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4 +; PTX70MMA-DAG: mma.m8n8k128.row.col.b1 +; PTX70MMA-DAG: mma.m16n8k128.row.col.b1 +; PTX70MMA-DAG: mma.m16n8k256.row.col.b1 ; """) @@ -561,6 +850,7 @@ items = gen_wmma_load_tests() items += gen_wmma_store_tests() items += gen_wmma_mma_tests() + items += gen_mma_tests() gen_check_unsupported_ops(items) parser = argparse.ArgumentParser()