diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td
--- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td
+++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td
@@ -36,6 +36,25 @@ def SPV_CapabilityArrayAttr : TypedArrayAttrBase<
     SPV_CapabilityAttr, "SPIR-V capability array attribute">;
 
 
+// Description of cooperative matrix operations supported on the
+// target. Represents `VkCooperativeMatrixPropertiesNV`. See
+// https://renderdoc.org/vkspec_chunked/chap9.html#VkCooperativeMatrixPropertiesNV
+def SPV_CooperativeMatrixPropertiesNVAttr :
+    StructAttr<"CooperativeMatrixPropertiesNV", SPIRV_Dialect, [
+    StructFieldAttr<"mSize", I32Attr>,
+    StructFieldAttr<"nSize", I32Attr>,
+    StructFieldAttr<"kSize", I32Attr>,
+    StructFieldAttr<"aType", TypeAttr>,
+    StructFieldAttr<"bType", TypeAttr>,
+    StructFieldAttr<"cType", TypeAttr>,
+    StructFieldAttr<"resultType", TypeAttr>,
+    StructFieldAttr<"scope", SPV_ScopeAttr>
+]>;
+
+def SPV_CooperativeMatrixPropertiesNVArrayAttr :
+    TypedArrayAttrBase<SPV_CooperativeMatrixPropertiesNVAttr,
+                       "CooperativeMatrixPropertiesNV array attribute">;
+
 // This attribute specifies the limits for various resources on the target
 // architecture.
 //
@@ -60,7 +79,13 @@
 
     // The default number of invocations in each subgroup.
     // 0x7FFFFFFF means unknown.
-    StructFieldAttr<"subgroup_size", DefaultValuedAttr<I32Attr, "0x7FFFFFFF">>
+    StructFieldAttr<"subgroup_size", DefaultValuedAttr<I32Attr, "0x7FFFFFFF">>,
+
+    // The configurations of cooperative matrix operations
+    // supported. Default is an empty list.
+    StructFieldAttr<
+      "cooperative_matrix_properties_nv",
+      DefaultValuedAttr<SPV_CooperativeMatrixPropertiesNVArrayAttr, "{}">>
 ]>;
 
 #endif // SPIRV_TARGET_AND_ABI
diff --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
--- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
+++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
@@ -140,7 +140,8 @@
       /*max_compute_shared_memory_size=*/nullptr,
       /*max_compute_workgroup_invocations=*/nullptr,
       /*max_compute_workgroup_size=*/nullptr,
-      /*subgroup_size=*/nullptr, context);
+      /*subgroup_size=*/nullptr,
+      /*cooperative_matrix_properties_nv=*/nullptr, context);
 }
 
 StringRef spirv::getTargetEnvAttrName() { return "spv.target_env"; }
diff --git a/mlir/test/Dialect/SPIRV/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/target-and-abi.mlir
--- a/mlir/test/Dialect/SPIRV/target-and-abi.mlir
+++ b/mlir/test/Dialect/SPIRV/target-and-abi.mlir
@@ -171,6 +171,44 @@
 
 // -----
 
+func @target_env_cooperative_matrix() attributes{
+  // CHECK:      spv.target_env = #spv.target_env<
+  // CHECK-SAME:   SPV_NV_cooperative_matrix
+  // CHECK-SAME: cooperative_matrix_properties_nv = [
+  // CHECK-SAME:   {aType = i8, bType = i8, cType = i32,
+  // CHECK-SAME:    kSize = 32 : i32, mSize = 8 : i32, nSize = 8 : i32
+  // CHECK-SAME:    resultType = i32, scope = 3 : i32}
+  // CHECK-SAME:   {aType = f16, bType = f16, cType = f16,
+  // CHECK-SAME:    kSize = 16 : i32, mSize = 8 : i32, nSize = 8 : i32
+  // CHECK-SAME:    resultType = f16, scope = 3 : i32}
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader],
+             [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>,
+    {
+      cooperative_matrix_properties_nv = [{
+        mSize = 8: i32,
+        nSize = 8: i32,
+        kSize = 32: i32,
+        aType = i8,
+        bType = i8,
+        cType = i32,
+        resultType = i32,
+        scope = 3: i32
+      }, {
+        mSize = 8: i32,
+        nSize = 8: i32,
+        kSize = 16: i32,
+        aType = f16,
+        bType = f16,
+        cType = f16,
+        resultType = f16,
+        scope = 3: i32
+      }]
+    }>
+} { return }
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.vce
 //===----------------------------------------------------------------------===//