diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4069,7 +4069,7 @@ !interleave(widths, "/") # "-bit signless/unsigned integer">; def SPIRV_IsArrayType : CPred<"::llvm::isa<::mlir::spirv::ArrayType>($_self)">; -def SPIRV_IsCooperativeMatrixType : +def SPIRV_IsCooperativeMatrixNVType : CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixNVType>($_self)">; def SPIRV_IsImageType : CPred<"::llvm::isa<::mlir::spirv::ImageType>($_self)">; def SPIRV_IsJointMatrixType : @@ -4100,9 +4100,9 @@ "any SPIR-V pointer type">; def SPIRV_AnyArray : DialectType; -def SPIRV_AnyCooperativeMatrix : DialectType; +def SPIRV_AnyCooperativeMatrixNV : DialectType; def SPIRV_AnyImage : DialectType; def SPIRV_AnyJointMatrix : DialectType; def SPIRV_Composite : AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct, - SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>; + SPIRV_AnyCooperativeMatrixNV, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>; def SPIRV_Type : AnyTypeOf<[ SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector, SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct, - SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix, + SPIRV_AnyCooperativeMatrixNV, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage ]>; def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>; def SPIRV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>; -class SPIRV_CoopMatrixOfType allowedTypes> : - ContainerType, SPIRV_IsCooperativeMatrixType, +class SPIRV_CoopMatrixNVOfType allowedTypes> : + ContainerType, SPIRV_IsCooperativeMatrixNVType, "::llvm::cast<::mlir::spirv::CooperativeMatrixNVType>($_self).getElementType()", - "Cooperative Matrix">; + "Cooperative Matrix NV">; class SPIRV_JointMatrixOfType allowedTypes> : ContainerType, SPIRV_IsJointMatrixType, @@ -4147,10 +4147,10 @@ class SPIRV_ScalarOrVectorOrCoopMatrixOf : AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>, - SPIRV_CoopMatrixOfType<[type]>]>; + SPIRV_CoopMatrixNVOfType<[type]>]>; class SPIRV_MatrixOrCoopMatrixOf : - AnyTypeOf<[SPIRV_AnyMatrix, SPIRV_CoopMatrixOfType<[type]>]>; + AnyTypeOf<[SPIRV_AnyMatrix, SPIRV_CoopMatrixNVOfType<[type]>]>; def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>; def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>; diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td @@ -13,6 +13,10 @@ #ifndef MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS #define MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS +//===----------------------------------------------------------------------===// +// SPV_NV_cooperative_matrix extension ops. +//===----------------------------------------------------------------------===// + // ----- def SPIRV_NVCooperativeMatrixLengthOp : SPIRV_NvVendorOp<"CooperativeMatrixLength", @@ -35,7 +39,7 @@ For example: ``` - %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix + %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix ``` }]; @@ -111,7 +115,7 @@ ``` %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %colMajor - : !spirv.ptr as !spirv.coopmatrix + : !spirv.ptr as !spirv.NV.coopmatrix ``` }]; @@ -130,7 +134,7 @@ ); let results = (outs - SPIRV_AnyCooperativeMatrix:$result + SPIRV_AnyCooperativeMatrixNV:$result ); } @@ -182,7 +186,7 @@ ``` %0 = spirv.NV.CooperativeMatrixMulAdd %arg0, %arg1, %arg2, : - !spirv.coopmatrix + !spirv.NV.coopmatrix ``` }]; @@ -198,13 +202,13 @@ ]; let arguments = (ins - SPIRV_AnyCooperativeMatrix:$a, - SPIRV_AnyCooperativeMatrix:$b, - SPIRV_AnyCooperativeMatrix:$c + SPIRV_AnyCooperativeMatrixNV:$a, + SPIRV_AnyCooperativeMatrixNV:$b, + SPIRV_AnyCooperativeMatrixNV:$c ); let results = (outs - SPIRV_AnyCooperativeMatrix:$result + SPIRV_AnyCooperativeMatrixNV:$result ); } @@ -247,7 +251,7 @@ ``` spirv.NV.CooperativeMatrixStore %arg0, %arg2, %arg1, %arg3 : - !spirv.ptr, !spirv.coopmatrix + !spirv.ptr, !spirv.NV.coopmatrix ``` }]; @@ -260,7 +264,7 @@ let arguments = (ins SPIRV_AnyPtr:$pointer, - SPIRV_AnyCooperativeMatrix:$object, + SPIRV_AnyCooperativeMatrixNV:$object, SPIRV_Integer:$stride, SPIRV_Bool:$columnmajor, OptionalAttr:$memory_access diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -27,7 +27,7 @@ namespace detail { struct ArrayTypeStorage; -struct CooperativeMatrixTypeStorage; +struct CooperativeMatrixNVTypeStorage; struct ImageTypeStorage; struct JointMatrixTypeStorage; struct MatrixTypeStorage; @@ -398,10 +398,10 @@ llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo); -// SPIR-V cooperative matrix type +// SPIR-V NV cooperative matrix type class CooperativeMatrixNVType : public Type::TypeBase { + detail::CooperativeMatrixNVTypeStorage> { public: using Base::Base; @@ -409,11 +409,11 @@ unsigned rows, unsigned columns); Type getElementType() const; - /// Return the scope of the cooperative matrix. + /// Returns the scope of the matrix. Scope getScope() const; - /// return the number of rows of the matrix. + /// Returns the number of rows of the matrix. unsigned getRows() const; - /// return the number of columns of the matrix. + /// Returns the number of columns of the matrix. unsigned getColumns() const; void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -318,9 +318,8 @@ return ArrayType::get(elementType, count, stride); } -// cooperative-matrix-type ::= `!spirv.coopmatrix` `<` element-type ',' scope -// ',' -// rows ',' columns>` +// cooperative-matrix-type ::= +// `!spirv.NV.coopmatrix` `<` rows `x` columns `x` element-type ',' scope> static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser) { if (parser.parseLess()) @@ -786,7 +785,7 @@ if (keyword == "array") return parseArrayType(*this, parser); - if (keyword == "coopmatrix") + if (keyword == "NV.coopmatrix") return parseCooperativeMatrixType(*this, parser); if (keyword == "jointmatrix") return parseJointMatrixType(*this, parser); @@ -891,7 +890,7 @@ } static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) { - os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"; + os << "NV.coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"; os << type.getElementType() << ", " << stringifyScope(type.getScope()); os << ">"; } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -203,23 +203,23 @@ } //===----------------------------------------------------------------------===// -// CooperativeMatrixType +// CooperativeMatrixNVType //===----------------------------------------------------------------------===// -struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage { +struct spirv::detail::CooperativeMatrixNVTypeStorage : public TypeStorage { using KeyTy = std::tuple; - static CooperativeMatrixTypeStorage * + static CooperativeMatrixNVTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key) { - return new (allocator.allocate()) - CooperativeMatrixTypeStorage(key); + return new (allocator.allocate()) + CooperativeMatrixNVTypeStorage(key); } bool operator==(const KeyTy &key) const { return key == KeyTy(elementType, scope, rows, columns); } - CooperativeMatrixTypeStorage(const KeyTy &key) + CooperativeMatrixNVTypeStorage(const KeyTy &key) : elementType(std::get<0>(key)), rows(std::get<2>(key)), columns(std::get<3>(key)), scope(std::get<1>(key)) {} diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir --- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir @@ -11,7 +11,7 @@ %i = arith.constant 16 : index %j = arith.constant 16 : index // CHECK: %[[COLMAJOR:.*]] = spirv.Constant false - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr as !spirv.coopmatrix<16x16xf16, Subgroup> + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr as !spirv.NV.coopmatrix<16x16xf16, Subgroup> %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, #spirv.storage_class> -> !gpu.mma_matrix<16x16xf16, "COp"> // CHECK: spirv.Return gpu.return @@ -33,7 +33,7 @@ %i = arith.constant 16 : index %j = arith.constant 16 : index // CHECK: %[[COLMAJOR:.*]] = spirv.Constant true - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr as !spirv.coopmatrix<16x16xf16, Subgroup> + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr as !spirv.NV.coopmatrix<16x16xf16, Subgroup> %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, #spirv.storage_class> -> !gpu.mma_matrix<16x16xf16, "COp"> // CHECK: spirv.Return gpu.return @@ -49,13 +49,13 @@ gpu.module @kernels { // CHECK-LABEL: spirv.func @gpu_wmma_store_op // CHECK-SAME: !spirv.ptr [0])>, StorageBuffer> - // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup> + // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup> gpu.func @gpu_wmma_store_op(%arg0 : memref<32x32xf16, #spirv.storage_class>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { %i = arith.constant 16 : index %j = arith.constant 16 : index // CHECK: %[[COLMAJOR:.*]] = spirv.Constant false - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr, !spirv.coopmatrix<16x16xf16, Subgroup> + // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr, !spirv.NV.coopmatrix<16x16xf16, Subgroup> gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class> // CHECK: spirv.Return gpu.return @@ -71,14 +71,14 @@ gpu.module @kernels { // CHECK-LABEL: spirv.func @gpu_wmma_store_op_transpose // CHECK-SAME: {{%.*}}: !spirv.ptr [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>} - // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) + // CHECK-SAME: {{%.*}}: !spirv.NV.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi gpu.func @gpu_wmma_store_op_transpose(%arg0 : memref<32x32xf16, #spirv.storage_class>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { %i = arith.constant 16 : index %j = arith.constant 16 : index // CHECK: %[[COLMAJOR:.*]] = spirv.Constant true - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr, !spirv.coopmatrix<16x16xf16, Subgroup> + // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr, !spirv.NV.coopmatrix<16x16xf16, Subgroup> gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension= 32 : index, transpose} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class> // CHECK: spirv.Return gpu.return @@ -93,12 +93,12 @@ spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { gpu.module @kernels { // CHECK-LABEL: spirv.func @gpu_wmma_mma_op - // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup> - // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup> - // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup> + // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup> + // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup> + // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup> gpu.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup> + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, !spirv.NV.coopmatrix<16x16xf16, Subgroup> -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp"> // CHECK: spirv.Return gpu.return @@ -117,7 +117,7 @@ attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { // CHECK: {{%.*}} = spirv.Constant %cst = arith.constant 1.0 : f16 - // CHECK: {{%.*}} = spirv.CompositeConstruct {{%.*}} : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup> + // CHECK: {{%.*}} = spirv.CompositeConstruct {{%.*}} : (f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp"> // CHECK: spirv.Return gpu.return @@ -132,15 +132,15 @@ spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { gpu.module @kernels { // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default - // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup> - // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup> + // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup> + // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup> gpu.func @gpu_wmma_elementwise_op_default(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { - // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup> + // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup> %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> - // CHECK: {{%.*}} = spirv.FNegate {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup> + // CHECK: {{%.*}} = spirv.FNegate {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup> %D = gpu.subgroup_mma_elementwise negatef %C : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> - // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup> + // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup> %E = gpu.subgroup_mma_elementwise divf %D, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> // CHECK: spirv.Return gpu.return @@ -155,14 +155,14 @@ spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { gpu.module @kernels { // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_times_scalar - // CHECK-SAME: %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup> + // CHECK-SAME: %[[A:.+]]: !spirv.NV.coopmatrix<16x16xf16, Subgroup> // CHECK-SAME: %[[S:.+]]: f16 gpu.func @gpu_wmma_elementwise_op_matrix_times_scalar(%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp"> - // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup>, f16 + // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16 %C = gpu.subgroup_mma_elementwise mulf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> - // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup>, f16 + // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16 %D = gpu.subgroup_mma_elementwise mulf %B, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> // CHECK: spirv.Return gpu.return @@ -177,13 +177,13 @@ spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { gpu.module @kernels { // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_plus_scalar - // CHECK-SAME: %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup> + // CHECK-SAME: %[[A:.+]]: !spirv.NV.coopmatrix<16x16xf16, Subgroup> // CHECK-SAME: %[[S:.+]]: f16 gpu.func @gpu_wmma_elementwise_op_matrix_plus_scalar(%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { - // CHECK: %[[SM:.+]] = spirv.CompositeConstruct %[[S]] : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup> + // CHECK: %[[SM:.+]] = spirv.CompositeConstruct %[[S]] : (f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp"> - // CHECK: %{{.+}} = spirv.FAdd %[[A]], %[[SM]] : !spirv.coopmatrix<16x16xf16, Subgroup> + // CHECK: %{{.+}} = spirv.FAdd %[[A]], %[[SM]] : !spirv.NV.coopmatrix<16x16xf16, Subgroup> %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> gpu.return } diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir @@ -138,9 +138,9 @@ // ----- -func.func @convert_f_to_u_coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup>) { - // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup> to !spirv.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.ConvertFToU %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup> to !spirv.coopmatrix<8x16xi32, Subgroup> +func.func @convert_f_to_u_NV.coopmatrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) { + // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %0 = spirv.ConvertFToU %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } @@ -222,9 +222,9 @@ // ----- -func.func @f_convert_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup>) { - // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup> to !spirv.coopmatrix<8x16xf64, Subgroup> - %0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup> to !spirv.coopmatrix<8x16xf64, Subgroup> +func.func @f_convert_coop_matrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) { + // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup> + %0 = spirv.FConvert %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup> spirv.Return } diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir @@ -29,10 +29,10 @@ // ----- -func.func @composite_construct_coopmatrix(%arg0 : f32) -> !spirv.coopmatrix<8x16xf32, Subgroup> { - // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup> - %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup> - return %0: !spirv.coopmatrix<8x16xf32, Subgroup> +func.func @composite_construct_NV.coopmatrix(%arg0 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> { + // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> + %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> + return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup> } // ----- @@ -53,18 +53,18 @@ // ----- -func.func @composite_construct_coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spirv.coopmatrix<8x16xf32, Subgroup> { +func.func @composite_construct_NV.coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> { // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}} - %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.coopmatrix<8x16xf32, Subgroup> - return %0: !spirv.coopmatrix<8x16xf32, Subgroup> + %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> + return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup> } // ----- -func.func @composite_construct_coopmatrix_incorrect_element_type(%arg0 : i32) -> !spirv.coopmatrix<8x16xf32, Subgroup> { +func.func @composite_construct_NV.coopmatrix_incorrect_element_type(%arg0 : i32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> { // expected-error @+1 {{operand type mismatch: expected operand type 'f32', but provided 'i32'}} - %0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.coopmatrix<8x16xf32, Subgroup> - return %0: !spirv.coopmatrix<8x16xf32, Subgroup> + %0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> + return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup> } // ----- @@ -121,9 +121,9 @@ // ----- -func.func @composite_extract_coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup>) -> f32 { - // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[2 : i32] : !spirv.coopmatrix<8x16xf32, Subgroup> - %0 = spirv.CompositeExtract %arg0[2 : i32] : !spirv.coopmatrix<8x16xf32, Subgroup> +func.func @composite_extract_NV.coopmatrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) -> f32 { + // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[2 : i32] : !spirv.NV.coopmatrix<8x16xf32, Subgroup> + %0 = spirv.CompositeExtract %arg0[2 : i32] : !spirv.NV.coopmatrix<8x16xf32, Subgroup> return %0 : f32 } @@ -249,10 +249,10 @@ // ----- -func.func @composite_insert_coopmatrix(%arg0: !spirv.coopmatrix<8x16xi32, Subgroup>, %arg1: i32) -> !spirv.coopmatrix<8x16xi32, Subgroup> { - // CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[5 : i32] : i32 into !spirv.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.CompositeInsert %arg1, %arg0[5 : i32] : i32 into !spirv.coopmatrix<8x16xi32, Subgroup> - return %0: !spirv.coopmatrix<8x16xi32, Subgroup> +func.func @composite_insert_NV.coopmatrix(%arg0: !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %arg1: i32) -> !spirv.NV.coopmatrix<8x16xi32, Subgroup> { + // CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[5 : i32] : i32 into !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %0 = spirv.CompositeInsert %arg1, %arg0[5 : i32] : i32 into !spirv.NV.coopmatrix<8x16xi32, Subgroup> + return %0: !spirv.NV.coopmatrix<8x16xi32, Subgroup> } // ----- diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir @@ -2,150 +2,150 @@ // CHECK-LABEL: @cooperative_matrix_load spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr as !spirv.coopmatrix<16x8xi32, Workgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr as !spirv.coopmatrix<16x8xi32, Workgroup> + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup> + %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup> spirv.Return } // ----- // CHECK-LABEL: @cooperative_matrix_load_memaccess spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr as !spirv.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr as !spirv.coopmatrix<8x16xi32, Subgroup> + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type spirv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spirv.ptr, StorageBuffer>, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup> + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_store -spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" { - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup> - spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup> +spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.NV.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" { + // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Workgroup> + spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Workgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_store_memaccess -spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr, %m : !spirv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" { - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, !spirv.coopmatrix<8x16xi32, Subgroup> - spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr, !spirv.coopmatrix<8x16xi32, Subgroup> +spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" { + // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup> + spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_length spirv.func @cooperative_matrix_length() -> i32 "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup> + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.ReturnValue %0 : i32 } // CHECK-LABEL: @cooperative_matrix_muladd -spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x32xi8, Subgroup>, !spirv.coopmatrix<32x8xi8, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup> - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x32xi8, Subgroup>, !spirv.coopmatrix<32x8xi8, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup> +spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.NV.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> + %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_add -spirv.func @cooperative_matrix_add(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup> - %r = spirv.IAdd %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup> +spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_sub -spirv.func @cooperative_matrix_sub(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup> - %r = spirv.ISub %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup> +spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_sdiv -spirv.func @cooperative_matrix_sdiv(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup> - %r = spirv.SDiv %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup> +spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_udiv -spirv.func @cooperative_matrix_udiv(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup> - %r = spirv.UDiv %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup> +spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_fadd -spirv.func @cooperative_matrix_fadd(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FAdd %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup> +spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> + %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_fsub -spirv.func @cooperative_matrix_fsub(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FSub %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup> +spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> + %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_fdiv -spirv.func @cooperative_matrix_fdiv(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FDiv %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup> +spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> + %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> spirv.Return } // ----- // CHECK-LABEL: @cooperative_matrix_access_chain -spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr, Function>) -> !spirv.ptr "None" { +spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr, Function>) -> !spirv.ptr "None" { %0 = spirv.Constant 0: i32 - // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr, Function>, i32 - %1 = spirv.AccessChain %a[%0] : !spirv.ptr, Function>, i32 + // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr, Function>, i32 + %1 = spirv.AccessChain %a[%0] : !spirv.ptr, Function>, i32 spirv.ReturnValue %1 : !spirv.ptr } // ----- -spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" { +spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<16x16xi32, Subgroup>, !spirv.coopmatrix<16x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup> + %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> spirv.Return } // ----- -spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" { +spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<8x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup> + %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> spirv.Return } // ----- -spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" { +spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix scope must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<16x8xi32, Workgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup> + %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Workgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> spirv.Return } // ----- -spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" { +spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { // expected-error @+1 {{matrix A and B non-integer element types must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xf32, Subgroup>, !spirv.coopmatrix<16x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup> + %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> spirv.Return } // ----- -spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xui8, Subgroup>, %b : !spirv.coopmatrix<16x8xsi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" { +spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xsi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { // expected-error @+1 {{matrix A and B integer element types must be the same bit width}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xui8, Subgroup>, !spirv.coopmatrix<16x8xsi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup> + %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, !spirv.NV.coopmatrix<16x8xsi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> spirv.Return } @@ -153,7 +153,7 @@ spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr, StorageBuffer>, %stride : i32, %b : i1) "None" { // expected-error @+1 {{Pointer must point to a scalar or vector type}} - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup> + %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } @@ -161,6 +161,6 @@ spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { // expected-error @+1 {{Pointer storage class must be Workgroup, StorageBuffer or PhysicalStorageBufferEXT}} - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr as !spirv.coopmatrix<8x16xi32, Subgroup> + %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } diff --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir @@ -9,10 +9,10 @@ } // CHECK-LABEL: @matrix_times_scalar_2 - spirv.func @matrix_times_scalar_2(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup> "None" { - // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, f16 - %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup>, f16 - spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup> + spirv.func @matrix_times_scalar_2(%arg0 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> "None" { + // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16 + %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16 + spirv.ReturnValue %result : !spirv.NV.coopmatrix<16x16xf16, Subgroup> } // CHECK-LABEL: @matrix_transpose_1 diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir @@ -797,7 +797,7 @@ } //===----------------------------------------------------------------------===// -// spirv.SpecConstantComposite (spirv.coopmatrix) +// spirv.SpecConstantComposite (spirv.NV.coopmatrix) //===----------------------------------------------------------------------===// // ----- @@ -805,7 +805,7 @@ spirv.module Logical GLSL450 { spirv.SpecConstant @sc1 = 1.5 : f32 // expected-error @+1 {{unsupported composite type}} - spirv.SpecConstantComposite @scc (@sc1) : !spirv.coopmatrix<8x16xf32, Device> + spirv.SpecConstantComposite @scc (@sc1) : !spirv.NV.coopmatrix<8x16xf32, Device> } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir --- a/mlir/test/Dialect/SPIRV/IR/types.mlir +++ b/mlir/test/Dialect/SPIRV/IR/types.mlir @@ -439,18 +439,18 @@ // CooperativeMatrix //===----------------------------------------------------------------------===// -// CHECK: func private @coop_matrix_type(!spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<8x8xf32, Workgroup>) -func.func private @coop_matrix_type(!spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<8x8xf32, Workgroup>) -> () +// CHECK: func private @coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>) +func.func private @coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>) -> () // ----- // expected-error @+1 {{expected ','}} -func.func private @missing_scope(!spirv.coopmatrix<8x16xi32>) -> () +func.func private @missing_scope(!spirv.NV.coopmatrix<8x16xi32>) -> () // ----- // expected-error @+1 {{expected rows and columns size}} -func.func private @missing_count(!spirv.coopmatrix<8xi32, Subgroup>) -> () +func.func private @missing_count(!spirv.NV.coopmatrix<8xi32, Subgroup>) -> () // ----- diff --git a/mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir --- a/mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir +++ b/mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir @@ -3,100 +3,100 @@ spirv.module Logical GLSL450 requires #spirv.vce { // CHECK-LABEL: @cooperative_matrix_load spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr as !spirv.coopmatrix<16x8xi32, Workgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr as !spirv.coopmatrix<16x8xi32, Workgroup> + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup> + %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_load_memaccess spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr as !spirv.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr as !spirv.coopmatrix<8x16xi32, Subgroup> + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_store - spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.coopmatrix<16x8xi32, Workgroup>, %b : i1) "None" { - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr, !spirv.coopmatrix<16x8xi32, Workgroup> - spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr, !spirv.coopmatrix<16x8xi32, Workgroup> + spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %b : i1) "None" { + // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr, !spirv.NV.coopmatrix<16x8xi32, Workgroup> + spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr, !spirv.NV.coopmatrix<16x8xi32, Workgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_store_memaccess - spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr, %m : !spirv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" { - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, !spirv.coopmatrix<8x16xi32, Subgroup> - spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr, !spirv.coopmatrix<8x16xi32, Subgroup> + spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" { + // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup> + spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_length spirv.func @cooperative_matrix_length() -> i32 "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup> + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.ReturnValue %0 : i32 } // CHECK-LABEL: @cooperative_matrix_muladd - spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<16x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup> - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<16x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup> + spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> + %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_add - spirv.func @cooperative_matrix_add(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup> - %r = spirv.IAdd %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup> + spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_sub - spirv.func @cooperative_matrix_sub(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup> - %r = spirv.ISub %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup> + spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_sdiv - spirv.func @cooperative_matrix_sdiv(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup> - %r = spirv.SDiv %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup> + spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_udiv - spirv.func @cooperative_matrix_udiv(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup> - %r = spirv.UDiv %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup> + spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_fadd - spirv.func @cooperative_matrix_fadd(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FAdd %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup> + spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> + %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_fsub - spirv.func @cooperative_matrix_fsub(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FSub %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup> + spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> + %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_fdiv - spirv.func @cooperative_matrix_fdiv(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FDiv %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup> + spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> + %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_access_chain - spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr, Function>) -> !spirv.ptr "None" { + spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr, Function>) -> !spirv.ptr "None" { %0 = spirv.Constant 0: i32 - // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr, Function>, i32 - %1 = spirv.AccessChain %a[%0] : !spirv.ptr, Function>, i32 + // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr, Function>, i32 + %1 = spirv.AccessChain %a[%0] : !spirv.ptr, Function>, i32 spirv.ReturnValue %1 : !spirv.ptr } } diff --git a/mlir/test/Target/SPIRV/matrix.mlir b/mlir/test/Target/SPIRV/matrix.mlir --- a/mlir/test/Target/SPIRV/matrix.mlir +++ b/mlir/test/Target/SPIRV/matrix.mlir @@ -23,10 +23,10 @@ } // CHECK-LABEL: @matrix_times_scalar_3 - spirv.func @matrix_times_scalar_3(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup> "None" { - // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, f16 - %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup>, f16 - spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup> + spirv.func @matrix_times_scalar_3(%arg0 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> "None" { + // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16 + %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16 + spirv.ReturnValue %result : !spirv.NV.coopmatrix<16x16xf16, Subgroup> } // CHECK-LABEL: @matrix_transpose_1