diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4069,7 +4069,7 @@ !interleave(widths, "/") # "-bit signless/unsigned integer">; def SPIRV_IsArrayType : CPred<"::llvm::isa<::mlir::spirv::ArrayType>($_self)">; -def SPIRV_IsCooperativeMatrixType : +def SPIRV_IsCooperativeMatrixNVType : CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixNVType>($_self)">; def SPIRV_IsImageType : CPred<"::llvm::isa<::mlir::spirv::ImageType>($_self)">; def SPIRV_IsJointMatrixType : @@ -4100,9 +4100,9 @@ "any SPIR-V pointer type">; def SPIRV_AnyArray : DialectType; -def SPIRV_AnyCooperativeMatrix : DialectType; +def SPIRV_AnyCooperativeMatrixNV : DialectType; def SPIRV_AnyImage : DialectType; def SPIRV_AnyJointMatrix : DialectType; def SPIRV_Composite : AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct, - SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>; + SPIRV_AnyCooperativeMatrixNV, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>; def SPIRV_Type : AnyTypeOf<[ SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector, SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct, - SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix, + SPIRV_AnyCooperativeMatrixNV, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage ]>; def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>; def SPIRV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>; -class SPIRV_CoopMatrixOfType allowedTypes> : - ContainerType, SPIRV_IsCooperativeMatrixType, +class SPIRV_CoopMatrixNVOfType allowedTypes> : + ContainerType, SPIRV_IsCooperativeMatrixNVType, "::llvm::cast<::mlir::spirv::CooperativeMatrixNVType>($_self).getElementType()", - "Cooperative Matrix">; + "Cooperative Matrix NV">; class SPIRV_JointMatrixOfType allowedTypes> : ContainerType, SPIRV_IsJointMatrixType, @@ -4147,10 +4147,10 @@ class SPIRV_ScalarOrVectorOrCoopMatrixOf : AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>, - SPIRV_CoopMatrixOfType<[type]>]>; + SPIRV_CoopMatrixNVOfType<[type]>]>; class SPIRV_MatrixOrCoopMatrixOf : - AnyTypeOf<[SPIRV_AnyMatrix, SPIRV_CoopMatrixOfType<[type]>]>; + AnyTypeOf<[SPIRV_AnyMatrix, SPIRV_CoopMatrixNVOfType<[type]>]>; def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>; def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>; diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td @@ -13,6 +13,10 @@ #ifndef MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS #define MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS +//===----------------------------------------------------------------------===// +// SPV_NV_cooperative_matrix extension ops. +//===----------------------------------------------------------------------===// + // ----- def SPIRV_NVCooperativeMatrixLengthOp : SPIRV_NvVendorOp<"CooperativeMatrixLength", @@ -35,7 +39,7 @@ For example: ``` - %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix + %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix ``` }]; @@ -111,7 +115,7 @@ ``` %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %colMajor - : !spirv.ptr as !spirv.coopmatrix + : !spirv.ptr as !spirv.NV.coopmatrix ``` }]; @@ -130,7 +134,7 @@ ); let results = (outs - SPIRV_AnyCooperativeMatrix:$result + SPIRV_AnyCooperativeMatrixNV:$result ); } @@ -182,7 +186,7 @@ ``` %0 = spirv.NV.CooperativeMatrixMulAdd %arg0, %arg1, %arg2, : - !spirv.coopmatrix + !spirv.NV.coopmatrix ``` }]; @@ -198,13 +202,13 @@ ]; let arguments = (ins - SPIRV_AnyCooperativeMatrix:$a, - SPIRV_AnyCooperativeMatrix:$b, - SPIRV_AnyCooperativeMatrix:$c + SPIRV_AnyCooperativeMatrixNV:$a, + SPIRV_AnyCooperativeMatrixNV:$b, + SPIRV_AnyCooperativeMatrixNV:$c ); let results = (outs - SPIRV_AnyCooperativeMatrix:$result + SPIRV_AnyCooperativeMatrixNV:$result ); } @@ -247,7 +251,7 @@ ``` spirv.NV.CooperativeMatrixStore %arg0, %arg2, %arg1, %arg3 : - !spirv.ptr, !spirv.coopmatrix + !spirv.ptr, !spirv.NV.coopmatrix ``` }]; @@ -260,7 +264,7 @@ let arguments = (ins SPIRV_AnyPtr:$pointer, - SPIRV_AnyCooperativeMatrix:$object, + SPIRV_AnyCooperativeMatrixNV:$object, SPIRV_Integer:$stride, SPIRV_Bool:$columnmajor, OptionalAttr:$memory_access diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -27,7 +27,7 @@ namespace detail { struct ArrayTypeStorage; -struct CooperativeMatrixTypeStorage; +struct CooperativeMatrixNVTypeStorage; struct ImageTypeStorage; struct JointMatrixTypeStorage; struct MatrixTypeStorage; @@ -398,10 +398,10 @@ llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo); -// SPIR-V cooperative matrix type +// SPIR-V NV cooperative matrix type class CooperativeMatrixNVType : public Type::TypeBase { + detail::CooperativeMatrixNVTypeStorage> { public: using Base::Base; @@ -409,11 +409,11 @@ unsigned rows, unsigned columns); Type getElementType() const; - /// Return the scope of the cooperative matrix. + /// Returns the scope of the matrix. Scope getScope() const; - /// return the number of rows of the matrix. + /// Returns the number of rows of the matrix. unsigned getRows() const; - /// return the number of columns of the matrix. + /// Returns the number of columns of the matrix. unsigned getColumns() const; void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -786,7 +786,7 @@ if (keyword == "array") return parseArrayType(*this, parser); - if (keyword == "coopmatrix") + if (keyword == "NV.coopmatrix") return parseCooperativeMatrixType(*this, parser); if (keyword == "jointmatrix") return parseJointMatrixType(*this, parser); @@ -891,7 +891,7 @@ } static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) { - os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"; + os << "NV.coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"; os << type.getElementType() << ", " << stringifyScope(type.getScope()); os << ">"; } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -203,23 +203,23 @@ } //===----------------------------------------------------------------------===// -// CooperativeMatrixType +// CooperativeMatrixNVType //===----------------------------------------------------------------------===// -struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage { +struct spirv::detail::CooperativeMatrixNVTypeStorage : public TypeStorage { using KeyTy = std::tuple; - static CooperativeMatrixTypeStorage * + static CooperativeMatrixNVTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key) { - return new (allocator.allocate()) - CooperativeMatrixTypeStorage(key); + return new (allocator.allocate()) + CooperativeMatrixNVTypeStorage(key); } bool operator==(const KeyTy &key) const { return key == KeyTy(elementType, scope, rows, columns); } - CooperativeMatrixTypeStorage(const KeyTy &key) + CooperativeMatrixNVTypeStorage(const KeyTy &key) : elementType(std::get<0>(key)), rows(std::get<2>(key)), columns(std::get<3>(key)), scope(std::get<1>(key)) {} diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir @@ -138,9 +138,9 @@ // ----- -func.func @convert_f_to_u_coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup>) { - // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup> to !spirv.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.ConvertFToU %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup> to !spirv.coopmatrix<8x16xi32, Subgroup> +func.func @convert_f_to_u_NV.coopmatrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) { + // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %0 = spirv.ConvertFToU %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } @@ -222,9 +222,9 @@ // ----- -func.func @f_convert_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup>) { - // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup> to !spirv.coopmatrix<8x16xf64, Subgroup> - %0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup> to !spirv.coopmatrix<8x16xf64, Subgroup> +func.func @f_convert_coop_matrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) { + // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup> + %0 = spirv.FConvert %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup> spirv.Return } diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir @@ -29,10 +29,10 @@ // ----- -func.func @composite_construct_coopmatrix(%arg0 : f32) -> !spirv.coopmatrix<8x16xf32, Subgroup> { - // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup> - %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup> - return %0: !spirv.coopmatrix<8x16xf32, Subgroup> +func.func @composite_construct_NV.coopmatrix(%arg0 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> { + // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> + %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> + return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup> } // ----- @@ -53,18 +53,18 @@ // ----- -func.func @composite_construct_coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spirv.coopmatrix<8x16xf32, Subgroup> { +func.func @composite_construct_NV.coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> { // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}} - %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.coopmatrix<8x16xf32, Subgroup> - return %0: !spirv.coopmatrix<8x16xf32, Subgroup> + %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> + return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup> } // ----- -func.func @composite_construct_coopmatrix_incorrect_element_type(%arg0 : i32) -> !spirv.coopmatrix<8x16xf32, Subgroup> { +func.func @composite_construct_NV.coopmatrix_incorrect_element_type(%arg0 : i32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> { // expected-error @+1 {{operand type mismatch: expected operand type 'f32', but provided 'i32'}} - %0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.coopmatrix<8x16xf32, Subgroup> - return %0: !spirv.coopmatrix<8x16xf32, Subgroup> + %0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> + return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup> } // ----- @@ -121,9 +121,9 @@ // ----- -func.func @composite_extract_coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup>) -> f32 { - // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[2 : i32] : !spirv.coopmatrix<8x16xf32, Subgroup> - %0 = spirv.CompositeExtract %arg0[2 : i32] : !spirv.coopmatrix<8x16xf32, Subgroup> +func.func @composite_extract_NV.coopmatrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) -> f32 { + // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[2 : i32] : !spirv.NV.coopmatrix<8x16xf32, Subgroup> + %0 = spirv.CompositeExtract %arg0[2 : i32] : !spirv.NV.coopmatrix<8x16xf32, Subgroup> return %0 : f32 } @@ -249,10 +249,10 @@ // ----- -func.func @composite_insert_coopmatrix(%arg0: !spirv.coopmatrix<8x16xi32, Subgroup>, %arg1: i32) -> !spirv.coopmatrix<8x16xi32, Subgroup> { - // CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[5 : i32] : i32 into !spirv.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.CompositeInsert %arg1, %arg0[5 : i32] : i32 into !spirv.coopmatrix<8x16xi32, Subgroup> - return %0: !spirv.coopmatrix<8x16xi32, Subgroup> +func.func @composite_insert_NV.coopmatrix(%arg0: !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %arg1: i32) -> !spirv.NV.coopmatrix<8x16xi32, Subgroup> { + // CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[5 : i32] : i32 into !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %0 = spirv.CompositeInsert %arg1, %arg0[5 : i32] : i32 into !spirv.NV.coopmatrix<8x16xi32, Subgroup> + return %0: !spirv.NV.coopmatrix<8x16xi32, Subgroup> } // ----- diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir @@ -2,150 +2,150 @@ // CHECK-LABEL: @cooperative_matrix_load spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr as !spirv.coopmatrix<16x8xi32, Workgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr as !spirv.coopmatrix<16x8xi32, Workgroup> + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup> + %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup> spirv.Return } // ----- // CHECK-LABEL: @cooperative_matrix_load_memaccess spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr as !spirv.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr as !spirv.coopmatrix<8x16xi32, Subgroup> + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type spirv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spirv.ptr, StorageBuffer>, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup> + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_store -spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" { - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup> - spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup> +spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.NV.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" { + // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Workgroup> + spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Workgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_store_memaccess -spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr, %m : !spirv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" { - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, !spirv.coopmatrix<8x16xi32, Subgroup> - spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr, !spirv.coopmatrix<8x16xi32, Subgroup> +spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" { + // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup> + spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_length spirv.func @cooperative_matrix_length() -> i32 "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup> + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.ReturnValue %0 : i32 } // CHECK-LABEL: @cooperative_matrix_muladd -spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x32xi8, Subgroup>, !spirv.coopmatrix<32x8xi8, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup> - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x32xi8, Subgroup>, !spirv.coopmatrix<32x8xi8, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup> +spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.NV.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> + %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_add -spirv.func @cooperative_matrix_add(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup> - %r = spirv.IAdd %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup> +spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_sub -spirv.func @cooperative_matrix_sub(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup> - %r = spirv.ISub %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup> +spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_sdiv -spirv.func @cooperative_matrix_sdiv(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup> - %r = spirv.SDiv %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup> +spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_udiv -spirv.func @cooperative_matrix_udiv(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup> - %r = spirv.UDiv %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup> +spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_fadd -spirv.func @cooperative_matrix_fadd(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FAdd %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup> +spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> + %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_fsub -spirv.func @cooperative_matrix_fsub(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FSub %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup> +spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> + %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> spirv.Return } // CHECK-LABEL: @cooperative_matrix_fdiv -spirv.func @cooperative_matrix_fdiv(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FDiv %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup> +spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { + // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> + %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> spirv.Return } // ----- // CHECK-LABEL: @cooperative_matrix_access_chain -spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr, Function>) -> !spirv.ptr "None" { +spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr, Function>) -> !spirv.ptr "None" { %0 = spirv.Constant 0: i32 - // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr, Function>, i32 - %1 = spirv.AccessChain %a[%0] : !spirv.ptr, Function>, i32 + // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr, Function>, i32 + %1 = spirv.AccessChain %a[%0] : !spirv.ptr, Function>, i32 spirv.ReturnValue %1 : !spirv.ptr } // ----- -spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" { +spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<16x16xi32, Subgroup>, !spirv.coopmatrix<16x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup> + %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> spirv.Return } // ----- -spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" { +spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<8x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup> + %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> spirv.Return } // ----- -spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" { +spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix scope must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<16x8xi32, Workgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup> + %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Workgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> spirv.Return } // ----- -spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" { +spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { // expected-error @+1 {{matrix A and B non-integer element types must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xf32, Subgroup>, !spirv.coopmatrix<16x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup> + %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> spirv.Return } // ----- -spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xui8, Subgroup>, %b : !spirv.coopmatrix<16x8xsi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" { +spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xsi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { // expected-error @+1 {{matrix A and B integer element types must be the same bit width}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xui8, Subgroup>, !spirv.coopmatrix<16x8xsi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup> + %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, !spirv.NV.coopmatrix<16x8xsi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> spirv.Return } @@ -153,7 +153,7 @@ spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr, StorageBuffer>, %stride : i32, %b : i1) "None" { // expected-error @+1 {{Pointer must point to a scalar or vector type}} - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup> + %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } @@ -161,6 +161,6 @@ spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { // expected-error @+1 {{Pointer storage class must be Workgroup, StorageBuffer or PhysicalStorageBufferEXT}} - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr as !spirv.coopmatrix<8x16xi32, Subgroup> + %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } diff --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir @@ -9,10 +9,10 @@ } // CHECK-LABEL: @matrix_times_scalar_2 - spirv.func @matrix_times_scalar_2(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup> "None" { - // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, f16 - %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup>, f16 - spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup> + spirv.func @matrix_times_scalar_2(%arg0 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> "None" { + // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16 + %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16 + spirv.ReturnValue %result : !spirv.NV.coopmatrix<16x16xf16, Subgroup> } // CHECK-LABEL: @matrix_transpose_1 diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir @@ -797,7 +797,7 @@ } //===----------------------------------------------------------------------===// -// spirv.SpecConstantComposite (spirv.coopmatrix) +// spirv.SpecConstantComposite (spirv.NV.coopmatrix) //===----------------------------------------------------------------------===// // ----- @@ -805,7 +805,7 @@ spirv.module Logical GLSL450 { spirv.SpecConstant @sc1 = 1.5 : f32 // expected-error @+1 {{unsupported composite type}} - spirv.SpecConstantComposite @scc (@sc1) : !spirv.coopmatrix<8x16xf32, Device> + spirv.SpecConstantComposite @scc (@sc1) : !spirv.NV.coopmatrix<8x16xf32, Device> } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir --- a/mlir/test/Dialect/SPIRV/IR/types.mlir +++ b/mlir/test/Dialect/SPIRV/IR/types.mlir @@ -439,18 +439,18 @@ // CooperativeMatrix //===----------------------------------------------------------------------===// -// CHECK: func private @coop_matrix_type(!spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<8x8xf32, Workgroup>) -func.func private @coop_matrix_type(!spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<8x8xf32, Workgroup>) -> () +// CHECK: func private @coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>) +func.func private @coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>) -> () // ----- // expected-error @+1 {{expected ','}} -func.func private @missing_scope(!spirv.coopmatrix<8x16xi32>) -> () +func.func private @missing_scope(!spirv.NV.coopmatrix<8x16xi32>) -> () // ----- // expected-error @+1 {{expected rows and columns size}} -func.func private @missing_count(!spirv.coopmatrix<8xi32, Subgroup>) -> () +func.func private @missing_count(!spirv.NV.coopmatrix<8xi32, Subgroup>) -> () // -----