diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -27,12 +27,12 @@ // In addition to normal types arithmetic instructions can support cooperative // matrix. let arguments = (ins - SPV_ScalarOrVectorOrCoopMatrixOf:$operand1, - SPV_ScalarOrVectorOrCoopMatrixOf:$operand2 + SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf:$operand1, + SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf:$operand2 ); let results = (outs - SPV_ScalarOrVectorOrCoopMatrixOf:$result + SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf:$result ); let assemblyFormat = "operands attr-dict `:` type($result)"; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td @@ -64,6 +64,25 @@ TypedArrayAttrBase; +def SPV_JointMatrixPropertiesINTELAttr : + SPV_Attr<"JointMatrixPropertiesINTEL", "joint_matrix_props"> { + let parameters = (ins + "int":$m_size, + "int":$n_size, + "int":$k_size, + "mlir::Type":$a_type, + "mlir::Type":$b_type, + "mlir::Type":$c_type, + "mlir::Type":$result_type, + "mlir::spirv::ScopeAttr":$scope + ); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def SPV_JointMatrixPropertiesINTELArrayAttr : + TypedArrayAttrBase; + // This attribute specifies the limits for various resources on the target // architecture. 
// diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -387,6 +387,7 @@ def SPV_INTEL_fp_fast_math_mode : I32EnumAttrCase<"SPV_INTEL_fp_fast_math_mode", 4027>; def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_memory_access_aliasing", 4028>; def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>; +def SPV_INTEL_joint_matrix : I32EnumAttrCase<"SPV_INTEL_joint_matrix", 4030>; def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>; def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>; @@ -442,7 +443,7 @@ SPV_INTEL_fpga_buffer_location, SPV_INTEL_arbitrary_precision_fixed_point, SPV_INTEL_usm_storage_classes, SPV_INTEL_io_pipes, SPV_INTEL_blocking_pipes, SPV_INTEL_fpga_reg, SPV_INTEL_long_constant_composite, SPV_INTEL_optnone, - SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode, + SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode, SPV_INTEL_joint_matrix, SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier, SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix, SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough, @@ -1390,6 +1391,13 @@ ]; } +def SPV_C_JointMatrixINTEL : I32EnumAttrCase<"JointMatrixINTEL", 6118> { + list implies = [SPV_C_Shader]; + list availability = [ + Extension<[SPV_INTEL_joint_matrix]> + ]; +} + def SPV_CapabilityAttr : SPV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [ SPV_C_Matrix, SPV_C_Addresses, SPV_C_Linkage, SPV_C_Kernel, SPV_C_Float16, @@ -1463,6 +1471,7 @@ SPV_C_StorageTexelBufferArrayDynamicIndexing, SPV_C_RayTracingNV, SPV_C_RayTracingMotionBlurNV, SPV_C_PhysicalStorageBufferAddresses, SPV_C_RayTracingProvisionalKHR, SPV_C_CooperativeMatrixNV, + SPV_C_JointMatrixINTEL, 
SPV_C_FragmentShaderSampleInterlockEXT, SPV_C_FragmentShaderShadingRateInterlockEXT, SPV_C_ShaderSMBuiltinsNV, SPV_C_FragmentShaderPixelInterlockEXT, SPV_C_DemoteToHelperInvocation, @@ -4013,6 +4022,8 @@ def SPV_IsCooperativeMatrixType : CPred<"$_self.isa<::mlir::spirv::CooperativeMatrixNVType>()">; def SPV_IsImageType : CPred<"$_self.isa<::mlir::spirv::ImageType>()">; +def SPV_IsJointMatrixType : + CPred<"$_self.isa<::mlir::spirv::JointMatrixINTELType>()">; def SPV_IsMatrixType : CPred<"$_self.isa<::mlir::spirv::MatrixType>()">; def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">; def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">; @@ -4043,6 +4054,9 @@ "any SPIR-V cooperative matrix type">; def SPV_AnyImage : DialectType; +def SPV_AnyJointMatrix : DialectType; def SPV_AnyMatrix : DialectType; def SPV_AnyRTArray : DialectType; def SPV_Composite : AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct, - SPV_AnyCooperativeMatrix, SPV_AnyMatrix]>; + SPV_AnyCooperativeMatrix, SPV_AnyJointMatrix, SPV_AnyMatrix]>; def SPV_Type : AnyTypeOf<[ SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector, SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct, - SPV_AnyCooperativeMatrix, SPV_AnyMatrix, SPV_AnySampledImage + SPV_AnyCooperativeMatrix, SPV_AnyJointMatrix, SPV_AnyMatrix, + SPV_AnySampledImage ]>; def SPV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>; @@ -4072,6 +4087,11 @@ "$_self.cast<::mlir::spirv::CooperativeMatrixNVType>().getElementType()", "Cooperative Matrix">; +class SPV_JointMatrixOfType allowedTypes> : + ContainerType, SPV_IsJointMatrixType, + "$_self.cast<::mlir::spirv::JointMatrixINTELType>().getElementType()", + "Joint Matrix">; + class SPV_ScalarOrVectorOf : AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>; @@ -4079,6 +4099,14 @@ AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>, SPV_CoopMatrixOfType<[type]>]>; +class SPV_ScalarOrVectorOrJointMatrixOf : + 
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>, + SPV_JointMatrixOfType<[type]>]>; + +class SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf : + AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>, + SPV_CoopMatrixOfType<[type]>, SPV_JointMatrixOfType<[type]>]>; + def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>; def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>; @@ -4311,6 +4339,11 @@ def SPV_OC_OpSubgroupBlockWriteINTEL : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>; def SPV_OC_OpAssumeTrueKHR : I32EnumAttrCase<"OpAssumeTrueKHR", 5630>; def SPV_OC_OpAtomicFAddEXT : I32EnumAttrCase<"OpAtomicFAddEXT", 6035>; +def SPV_OC_OpTypeJointMatrixINTEL : I32EnumAttrCase<"OpTypeJointMatrixINTEL", 6119>; +def SPV_OC_OpJointMatrixLoadINTEL : I32EnumAttrCase<"OpJointMatrixLoadINTEL", 6120>; +def SPV_OC_OpJointMatrixStoreINTEL : I32EnumAttrCase<"OpJointMatrixStoreINTEL", 6121>; +def SPV_OC_OpJointMatrixMadINTEL : I32EnumAttrCase<"OpJointMatrixMadINTEL", 6122>; +def SPV_OC_OpJointMatrixWorkItemLengthINTEL : I32EnumAttrCase<"OpJointMatrixWorkItemLengthINTEL", 6410>; def SPV_OpcodeAttr : SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [ @@ -4376,7 +4409,10 @@ SPV_OC_OpCooperativeMatrixLoadNV, SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV, SPV_OC_OpCooperativeMatrixLengthNV, SPV_OC_OpSubgroupBlockReadINTEL, SPV_OC_OpSubgroupBlockWriteINTEL, - SPV_OC_OpAssumeTrueKHR, SPV_OC_OpAtomicFAddEXT + SPV_OC_OpAssumeTrueKHR, SPV_OC_OpAtomicFAddEXT, + SPV_OC_OpTypeJointMatrixINTEL, SPV_OC_OpJointMatrixLoadINTEL, + SPV_OC_OpJointMatrixStoreINTEL, SPV_OC_OpJointMatrixMadINTEL, + SPV_OC_OpJointMatrixWorkItemLengthINTEL ]>; // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td @@ -23,11 +23,11 @@ !listconcat(traits, [NoSideEffect, SameOperandsAndResultShape])> { let arguments = (ins - SPV_ScalarOrVectorOrCoopMatrixOf:$operand + SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf:$operand ); let results = (outs - SPV_ScalarOrVectorOrCoopMatrixOf:$result + SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf:$result ); let assemblyFormat = [{ $operand attr-dict `:` type($operand) `to` type($result) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td @@ -0,0 +1,269 @@ +//===- SPIRVJointMatrixOps.td - joint matmul ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the op definition spec of joint matrix multiply extension ops. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS +#define MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS + +// ----- + +def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTEL", + [NoSideEffect]> { + let summary = "See extension SPV_INTEL_joint_matrix"; + + let description = [{ + Return number of components owned by the current work-item in + a joint matrix. + + Result Type must be a 32-bit unsigned integer type scalar. + + Type is a joint matrix type. 
+ + ``` {.ebnf} + joint-matrix-length-op ::= ssa-id `=` `spv.JointMatrixWorkItemLengthINTEL + ` : ` joint-matrix-type + ``` + + For example: + + ``` + %0 = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix + ``` + }]; + + let assemblyFormat = "attr-dict `:` $type"; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_INTEL_joint_matrix]>, + Capability<[SPV_C_JointMatrixINTEL]> + ]; + + let arguments = (ins + TypeAttr:$type + ); + + let results = (outs + SPV_Int32:$result + ); + let hasVerifier = 0; +} + +// ----- + +def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> { + let summary = "See extension SPV_INTEL_joint_matrix"; + + let description = [{ + Load a matrix through a pointer. + + Result Type must be OpTypeJointMatrixINTEL. Pointer is the pointer + to load through; it specifies the start of the memory region where + elements of the matrix are stored, arranged according to Layout. + + Stride is the number of elements in memory between beginnings of + successive rows, columns (or words) in the result. It must be a + scalar integer type. + + Layout indicates how the values loaded from memory are arranged. + It must be the result of a constant instruction. + + Scope is synchronization scope for operation on the matrix. It must + be the result of a constant instruction with scalar integer type. + + If present, any Memory Operands must begin with a memory operand + literal. If not present, it is the same as the memory operand None. + + ### Custom assembly form + + ``` {.ebnf} + joint-matrixload-op ::= ssa-id `=` `spv.JointMatrixLoadINTEL` + ssa-use `,` ssa-use `,` ssa-use `, + ` ssa-use + (`[` memory-access `]`)?
` : ` + pointer-type `as` + joint-matrix-type + ``` + + For example: + + ``` + %0 = spv.JointMatrixLoadINTEL %ptr, %stride, %Layout, %Scope + : !spv.ptr as !spv.jointmatrix + ``` + }]; + + let assemblyFormat = [{ + $scope operands attr-dict `:` `(` type(operands) `)` `->` type($result) + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_INTEL_joint_matrix]>, + Capability<[SPV_C_JointMatrixINTEL]> + ]; + + let arguments = (ins + SPV_ScopeAttr:$scope, + SPV_AnyPtr:$pointer, + SPV_Integer:$stride, + SPV_Integer:$layout, + OptionalAttr:$memory_access + ); + + let results = (outs + SPV_AnyJointMatrix:$result + ); +} + +// ----- + +def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL", + [NoSideEffect, AllTypesMatch<["c", "result"]>]> { + let summary = "See extension SPV_INTEL_joint_matrix"; + + let description = [{ + Multiply matrix A by matrix B and add matrix C to the result + of the multiplication: A*B+C. Here A is an M x K matrix, B is + a K x N matrix and C is an M x N matrix. + + Behavior is undefined if sizes of operands do not meet the + conditions above. All operands and the Result Type must be + OpTypeJointMatrixINTEL. + + A must be an OpTypeJointMatrixINTEL whose Component Type is a + signed numerical type, Row Count equals to M and Column Count + equals to K + + B must be an OpTypeJointMatrixINTEL whose Component Type is a + signed numerical type, Row Count equals to K and Column Count + equals to N + + C and Result Type must be an OpTypeJointMatrixINTEL with Row + Count equals to M and Column Count equals to N + + Scope is synchronization scope for operation on the matrix. + It must be the result of a constant instruction with scalar + integer type.
+ + ``` {.ebnf} + joint-matrixmuladd-op ::= ssa-id `=` `spv.JointMatrixMadINTEL` + ssa-use `,` ssa-use `,` ssa-use ` + ,` ssa_use` : ` + a-joint-matrix-type, + b-joint-matrix-type -> + result-joint-matrix-type + ``` + For example: + + ``` + %0 = spv.JointMatrixMadINTEL %arg0, %arg1, %arg2, %scope : + !spv.jointmatrix + ``` + }]; + + let assemblyFormat = [{ + $scope operands attr-dict`:` type($a) `,` type($b) `->` type($c) + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_INTEL_joint_matrix]>, + Capability<[SPV_C_JointMatrixINTEL]> + ]; + + let arguments = (ins + SPV_ScopeAttr:$scope, + SPV_AnyJointMatrix:$a, + SPV_AnyJointMatrix:$b, + SPV_AnyJointMatrix:$c + ); + + let results = (outs + SPV_AnyJointMatrix:$result + ); +} + +// ----- + +def SPV_JointMatrixStoreINTELOp : SPV_Op<"JointMatrixStoreINTEL", []> { + let summary = "See extension SPV_INTEL_joint_matrix"; + + let description = [{ + Store a matrix through a pointer. + + Pointer is the pointer to store through. It specifies + start of memory region where elements of the matrix must + be stored and arranged according to Layout. + + Object is the matrix to store. It must be + OpTypeJointMatrixINTEL. + + Stride is the number of elements in memory between beginnings + of successive rows, columns (or words) of the Object. It must + be a scalar integer type. + + Layout indicates how the values stored to memory are arranged. + It must be the result of a constant instruction. + + Scope is synchronization scope for operation on the matrix. + It must be the result of a constant instruction with scalar + integer type. + + If present, any Memory Operands must begin with a memory operand + literal. If not present, it is the same as specifying the memory + operand None. + + ``` {.ebnf} + joint-matrix-store-op ::= `spv.JointMatrixStoreINTEL ` + ssa-use `, ` ssa-use `, ` + ssa-use `, ` ssa-use `, ` + ssa-use `, `(`[` memory-access `]`)?
`:` + pointer-type `,` spirv-element-type + ``` + + For example: + + ``` + spv.JointMatrixStoreINTEL %arg0, %arg2, %arg1, %arg3, %arg4 : + !spv.ptr, !spv.jointmatrix + ``` + }]; + + let assemblyFormat = [{ + $scope operands attr-dict `:` `(` type(operands) `)` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_INTEL_joint_matrix]>, + Capability<[SPV_C_JointMatrixINTEL]> + ]; + + let arguments = (ins + SPV_ScopeAttr:$scope, + SPV_AnyPtr:$pointer, + SPV_AnyJointMatrix:$object, + SPV_Integer:$stride, + SPV_Integer:$layout, + OptionalAttr:$memory_access + ); + + let results = (outs); +} + +// ----- + +#endif // MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td @@ -30,6 +30,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td" +include "mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td" diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -29,6 +29,7 @@ struct ArrayTypeStorage; struct CooperativeMatrixTypeStorage; struct ImageTypeStorage; +struct JointMatrixTypeStorage; struct MatrixTypeStorage; struct PointerTypeStorage; struct RuntimeArrayTypeStorage; @@ -420,6 +421,30 @@ Optional storage = llvm::None); }; +// SPIR-V joint matrix type from the SPV_INTEL_joint_matrix extension. +class JointMatrixINTELType + : public Type::TypeBase { +public: + using Base::Base; + + static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, + unsigned columns); + 
Type getElementType() const; + + /// Return the scope of the joint matrix. + Scope getScope() const; + /// Return the number of rows of the matrix. + unsigned getRows() const; + /// Return the number of columns of the matrix. + unsigned getColumns() const; + + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage = llvm::None); + void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage = llvm::None); +}; + // SPIR-V matrix type class MatrixType : public Type::TypeBase { diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -348,6 +348,36 @@ return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]); } +// joint-matrix-type ::= `!spv.jointmatrix` `<` rows `x` columns `x` +// element-type `,` scope `>` +static Type parseJointMatrixType(SPIRVDialect const &dialect, + DialectAsmParser &parser) { + if (parser.parseLess()) + return Type(); + + SmallVector dims; + SMLoc countLoc = parser.getCurrentLocation(); + if (parser.parseDimensionList(dims, /*allowDynamic=*/false)) + return Type(); + + if (dims.size() != 2) { + parser.emitError(countLoc, "expected rows and columns size"); + return Type(); + } + + auto elementTy = parseAndVerifyType(dialect, parser); + if (!elementTy) + return Type(); + + Scope scope; + if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope ")) + return Type(); + + if (parser.parseGreater()) + return Type(); + return JointMatrixINTELType::get(elementTy, scope, dims[0], dims[1]); +} + // TODO: Reorder methods to be utilities first and parse*Type // methods in alphabetical order // @@ -753,6 +783,8 @@ return parseArrayType(*this, parser); if (keyword == "coopmatrix") return parseCooperativeMatrixType(*this, parser); + if (keyword == "jointmatrix") + return parseJointMatrixType(*this, parser); if (keyword == "image")
return parseImageType(*this, parser); if (keyword == "ptr") @@ -859,6 +891,12 @@ os << ">"; } +static void print(JointMatrixINTELType type, DialectAsmPrinter &os) { + os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x"; + os << type.getElementType() << ", " << stringifyScope(type.getScope()); + os << ">"; +} + static void print(MatrixType type, DialectAsmPrinter &os) { os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType(); os << ">"; @@ -866,9 +904,9 @@ void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const { TypeSwitch(type) - .Case( - [&](auto type) { print(type, os); }) + .Case([&](auto type) { print(type, os); }) .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); }); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -436,6 +436,13 @@ resultType.cast().getElementType(); } + if (auto jointMatrixType = + operandType.dyn_cast()) { + operandType = jointMatrixType.getElementType(); + resultType = + resultType.cast().getElementType(); + } + auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth(); auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth(); auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth; @@ -1637,6 +1644,17 @@ return success(); } + if (auto jointType = cType.dyn_cast()) { + if (constituents.size() != 1) + return emitOpError("has incorrect number of operands: expected ") + << "1, but provided " << constituents.size(); + if (jointType.getElementType() != constituents.front().getType()) + return emitOpError("operand type mismatch: expected operand type ") + << jointType.getElementType() << ", but provided " + << constituents.front().getType(); + return success(); + } + if (constituents.size() == cType.getNumElements()) { for (auto index : llvm::seq(0, constituents.size())) { if (constituents[index].getType() != 
cType.getElementType(index)) { @@ -3893,6 +3911,72 @@ return verifyCoopMatrixMulAdd(*this); } +static LogicalResult +verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) { + Type pointeeType = pointer.cast().getPointeeType(); + if (!pointeeType.isa() && !pointeeType.isa()) + return op->emitError( + "Pointer must point to a scalar or vector type but provided ") + << pointeeType; + spirv::StorageClass storage = + pointer.cast().getStorageClass(); + if (storage != spirv::StorageClass::Workgroup && + storage != spirv::StorageClass::StorageBuffer && + storage != spirv::StorageClass::PhysicalStorageBuffer) + return op->emitError( + "Pointer storage class must be Workgroup, StorageBuffer or " + "PhysicalStorageBufferEXT but provided ") + << stringifyStorageClass(storage); + return success(); +} + +//===----------------------------------------------------------------------===// +// spv.JointMatrixLoadINTEL +//===----------------------------------------------------------------------===// + +LogicalResult spirv::JointMatrixLoadINTELOp::verify() { + return verifyPointerAndJointMatrixType(*this, pointer().getType(), + result().getType()); +} + +//===----------------------------------------------------------------------===// +// spv.JointMatrixStoreINTEL +//===----------------------------------------------------------------------===// + +LogicalResult spirv::JointMatrixStoreINTELOp::verify() { + return verifyPointerAndJointMatrixType(*this, pointer().getType(), + object().getType()); +} + +//===----------------------------------------------------------------------===// +// spv.JointMatrixMadINTEL +//===----------------------------------------------------------------------===// + +static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) { + if (op.c().getType() != op.result().getType()) + return op.emitOpError("result and third operand must have the same type"); + auto typeA = op.a().getType().cast(); + auto typeB = 
op.b().getType().cast(); + auto typeC = op.c().getType().cast(); + auto typeR = op.result().getType().cast(); + if (typeA.getRows() != typeR.getRows() || + typeA.getColumns() != typeB.getRows() || + typeB.getColumns() != typeR.getColumns()) + return op.emitOpError("matrix size must match"); + if (typeR.getScope() != typeA.getScope() || + typeR.getScope() != typeB.getScope() || + typeR.getScope() != typeC.getScope()) + return op.emitOpError("matrix scope must match"); + if (typeA.getElementType() != typeB.getElementType() || + typeR.getElementType() != typeC.getElementType()) + return op.emitOpError("matrix element type must match"); + return success(); +} + +LogicalResult spirv::JointMatrixMadINTELOp::verify() { + return verifyJointMatrixMad(*this); +} + //===----------------------------------------------------------------------===// // spv.MatrixTimesScalar //===----------------------------------------------------------------------===// @@ -4150,6 +4234,8 @@ if (cType.isa()) return emitError("unsupported composite type ") << cType; + if (cType.isa()) + return emitError("unsupported composite type ") << cType; if (constituents.size() != cType.getNumElements()) return emitError("has incorrect number of operands: expected ") << cType.getNumElements() << ", but provided " diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -89,9 +89,9 @@ bool CompositeType::classof(Type type) { if (auto vectorType = type.dyn_cast()) return isValid(vectorType); - return type - .isa(); + return type.isa(); } bool CompositeType::isValid(VectorType type) { @@ -110,7 +110,8 @@ Type CompositeType::getElementType(unsigned index) const { return TypeSwitch(*this) - .Case( + .Case( [](auto type) { return type.getElementType(); }) .Case([](MatrixType type) { return type.getColumnType(); }) .Case( @@ -132,6 +133,10 @@ llvm_unreachable( "invalid to query 
number of elements of spirv::CooperativeMatrix type"); } + if (isa()) { + llvm_unreachable( + "invalid to query number of elements of spirv::JointMatrix type"); + } if (isa()) { llvm_unreachable( "invalid to query number of elements of spirv::RuntimeArray type"); @@ -140,15 +145,16 @@ } bool CompositeType::hasCompileTimeKnownNumElements() const { - return !isa(); + return !isa(); } void CompositeType::getExtensions( SPIRVType::ExtensionArrayRefVector &extensions, Optional storage) { TypeSwitch(*this) - .Case( + .Case( [&](auto type) { type.getExtensions(extensions, storage); }) .Case([&](VectorType type) { return type.getElementType().cast().getExtensions( @@ -161,8 +167,8 @@ SPIRVType::CapabilityArrayRefVector &capabilities, Optional storage) { TypeSwitch(*this) - .Case( + .Case( [&](auto type) { type.getCapabilities(capabilities, storage); }) .Case([&](VectorType type) { auto vecSize = getNumElements(); @@ -255,6 +261,67 @@ capabilities.push_back(ref); } +//===----------------------------------------------------------------------===// +// JointMatrixType +//===----------------------------------------------------------------------===// + +struct spirv::detail::JointMatrixTypeStorage : public TypeStorage { + using KeyTy = std::tuple; + + static JointMatrixTypeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) + JointMatrixTypeStorage(key); + } + + bool operator==(const KeyTy &key) const { + return key == KeyTy(elementType, scope, rows, columns); + } + + JointMatrixTypeStorage(const KeyTy &key) + : elementType(std::get<0>(key)), rows(std::get<2>(key)), + columns(std::get<3>(key)), scope(std::get<1>(key)) {} + + Type elementType; + unsigned rows; + unsigned columns; + Scope scope; +}; + +JointMatrixINTELType JointMatrixINTELType::get(Type elementType, Scope scope, + unsigned rows, + unsigned columns) { + return Base::get(elementType.getContext(), elementType, scope, rows, columns); +} + +Type 
JointMatrixINTELType::getElementType() const { + return getImpl()->elementType; +} + +Scope JointMatrixINTELType::getScope() const { return getImpl()->scope; } + +unsigned JointMatrixINTELType::getRows() const { return getImpl()->rows; } + +unsigned JointMatrixINTELType::getColumns() const { return getImpl()->columns; } + +void JointMatrixINTELType::getExtensions( + SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage) { + getElementType().cast().getExtensions(extensions, storage); + static const Extension exts[] = {Extension::SPV_INTEL_joint_matrix}; + ArrayRef ref(exts, llvm::array_lengthof(exts)); + extensions.push_back(ref); +} + +void JointMatrixINTELType::getCapabilities( + SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage) { + getElementType().cast().getCapabilities(capabilities, storage); + static const Capability caps[] = {Capability::JointMatrixINTEL}; + ArrayRef ref(caps, llvm::array_lengthof(caps)); + capabilities.push_back(ref); +} + //===----------------------------------------------------------------------===// // ImageType //===----------------------------------------------------------------------===// @@ -1172,6 +1239,7 @@ //===----------------------------------------------------------------------===// void SPIRVDialect::registerTypes() { - addTypes(); + addTypes(); } diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -168,6 +168,8 @@ return processType(opcode, operands); case spirv::Opcode::OpTypeForwardPointer: return processTypeForwardPointer(operands); + case spirv::Opcode::OpTypeJointMatrixINTEL: + return processType(opcode, operands); case spirv::Opcode::OpConstant: return processConstant(operands, /*isSpec=*/false); case spirv::Opcode::OpSpecConstant: diff --git 
a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -257,6 +257,8 @@ LogicalResult processFunctionType(ArrayRef operands); + LogicalResult processJointMatrixType(ArrayRef operands); + LogicalResult processImageType(ArrayRef operands); LogicalResult processSampledImageType(ArrayRef operands); diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -730,6 +730,8 @@ return processCooperativeMatrixType(operands); case spirv::Opcode::OpTypeFunction: return processFunctionType(operands); + case spirv::Opcode::OpTypeJointMatrixINTEL: + return processJointMatrixType(operands); case spirv::Opcode::OpTypeImage: return processImageType(operands); case spirv::Opcode::OpTypeSampledImage: @@ -888,6 +890,44 @@ return success(); } +LogicalResult +spirv::Deserializer::processJointMatrixType(ArrayRef operands) { + if (operands.size() != 5) { + return emitError(unknownLoc, "OpTypeJointMatrix must have element " + "type, scope and row x column parameters"); + } + + Type elementTy = getType(operands[1]); + if (!elementTy) { + return emitError(unknownLoc, "OpTypeJointMatrix references undefined ") + << operands[1]; + } + + // Scope, rows and columns are ids of integer constants. getConstantInt() + // yields a null attribute otherwise, so diagnose that before getInt(). 
+ IntegerAttr scopeAttr = getConstantInt(operands[2]); + IntegerAttr rowsAttr = getConstantInt(operands[3]); + IntegerAttr columnsAttr = getConstantInt(operands[4]); + if (!scopeAttr || !rowsAttr || !columnsAttr) { + return emitError(unknownLoc, "OpTypeJointMatrix scope, rows and columns " + "must be integer constants"); + } + + auto scope = spirv::symbolizeScope(scopeAttr.getInt()); + if (!scope) { + return emitError(unknownLoc, + "OpTypeJointMatrix references undefined scope ") + << operands[2]; + } + + unsigned rows = rowsAttr.getInt(); + unsigned columns = columnsAttr.getInt(); + + typeMap[operands[0]] = + spirv::JointMatrixINTELType::get(elementTy, scope.value(), rows, columns); + return success(); +} + LogicalResult spirv::Deserializer::processRuntimeArrayType(ArrayRef
operands) { if (operands.size() != 2) { diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -598,6 +598,25 @@ return success(); } + if (auto jointMatrixType = type.dyn_cast()) { + uint32_t elementTypeID = 0; + if (failed(processTypeImpl(loc, jointMatrixType.getElementType(), + elementTypeID, serializationCtx))) { + return failure(); + } + typeEnum = spirv::Opcode::OpTypeJointMatrixINTEL; + auto getConstantOp = [&](uint32_t id) { + auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id); + return prepareConstantInt(loc, attr); + }; + operands.push_back(elementTypeID); + operands.push_back( + getConstantOp(static_cast(jointMatrixType.getScope()))); + operands.push_back(getConstantOp(jointMatrixType.getRows())); + operands.push_back(getConstantOp(jointMatrixType.getColumns())); + return success(); + } + if (auto matrixType = type.dyn_cast()) { uint32_t elementTypeID = 0; if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID, diff --git a/mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir @@ -0,0 +1,158 @@ +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: @joint_matrix_load +spv.func @joint_matrix_load(%ptr : !spv.ptr, %stride : i32, %layout : i32) "None" { + // CHECK: {{%.*}} = spv.JointMatrixLoadINTEL {{%.*}}, {{%.*}}, {{%.*}} : (!spv.ptr, i32, i32) -> !spv.jointmatrix<16x8xi32, Workgroup> + %0 = spv.JointMatrixLoadINTEL %ptr, %stride, %layout : (!spv.ptr, i32, i32) -> !spv.jointmatrix<16x8xi32, Workgroup> + spv.Return +} + +// ----- +// CHECK-LABEL: @joint_matrix_load_memaccess +spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr, %stride : 
i32, %layout : i32) "None" { + // CHECK: {{%.*}} = spv.JointMatrixLoadINTEL {{%.*}}, {{%.*}}, {{%.*}} {Volatile} : (!spv.ptr, i32, i32) -> !spv.jointmatrix<8x16xi32, Subgroup> + %0 = spv.JointMatrixLoadINTEL %ptr, %stride, %layout {Volatile} : (!spv.ptr, i32, i32) -> !spv.jointmatrix<8x16xi32, Subgroup> + spv.Return +} + +// CHECK-LABEL: @joint_matrix_load_diff_ptr_type +spv.func @joint_matrix_load_diff_ptr_type(%ptr : !spv.ptr, StorageBuffer>, %stride : i32, %layout : i32) "None" { + // CHECK: {{%.*}} = spv.JointMatrixLoadINTEL {{%.*}}, {{%.*}}, {{%.*}} {Volatile} : (!spv.ptr, StorageBuffer>, i32, i32) -> !spv.jointmatrix<8x16xi32, Subgroup> + %0 = spv.JointMatrixLoadINTEL %ptr, %stride, %layout {Volatile} : (!spv.ptr, StorageBuffer>, i32, i32) -> !spv.jointmatrix<8x16xi32, Subgroup> + spv.Return +} + +// CHECK-LABEL: @joint_matrix_store +spv.func @joint_matrix_store(%ptr : !spv.ptr, %stride : i32, %m : !spv.jointmatrix<8x16xi32, Workgroup>, %layout : i32) "None" { + // CHECK: spv.JointMatrixStoreINTEL {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : (!spv.ptr, !spv.jointmatrix<8x16xi32, Workgroup>, i32, i32) + spv.JointMatrixStoreINTEL %ptr, %m, %stride, %layout : (!spv.ptr, !spv.jointmatrix<8x16xi32, Workgroup> , i32, i32) + spv.Return +} + +// CHECK-LABEL: @joint_matrix_store_memaccess +spv.func @joint_matrix_store_memaccess(%ptr : !spv.ptr, %m : !spv.jointmatrix<8x16xi32, Subgroup>, %stride : i32, %layout : i32) "None" { + // CHECK: spv.JointMatrixStoreINTEL {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} {Volatile} : (!spv.ptr, !spv.jointmatrix<8x16xi32, Subgroup>, i32, i32) + spv.JointMatrixStoreINTEL %ptr, %m, %stride, %layout {Volatile} : (!spv.ptr, !spv.jointmatrix<8x16xi32, Subgroup>, i32, i32) + spv.Return +} + +// CHECK-LABEL: @joint_matrix_length +spv.func @joint_matrix_length() -> i32 "None" { + // CHECK: {{%.*}} = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, Subgroup> + %0 = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, Subgroup> + 
spv.ReturnValue %0 : i32 +} + +// CHECK-LABEL: @joint_matrix_muladd +spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x32xi8, Subgroup>, %b : !spv.jointmatrix<32x8xi8, Subgroup>, %c : !spv.jointmatrix<8x8xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.JointMatrixMadINTEL {{%.*}}, {{%.*}}, {{%.*}} : !spv.jointmatrix<8x32xi8, Subgroup>, !spv.jointmatrix<32x8xi8, Subgroup> -> !spv.jointmatrix<8x8xi32, Subgroup> + %r = spv.JointMatrixMadINTEL %a, %b, %c : !spv.jointmatrix<8x32xi8, Subgroup>, !spv.jointmatrix<32x8xi8, Subgroup> -> !spv.jointmatrix<8x8xi32, Subgroup> + spv.Return +} + +// CHECK-LABEL: @joint_matrix_add +spv.func @joint_matrix_add(%a : !spv.jointmatrix<8x16xi32, Subgroup>, %b : !spv.jointmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, Subgroup> + %r = spv.IAdd %a, %b : !spv.jointmatrix<8x16xi32, Subgroup> + spv.Return +} + +// CHECK-LABEL: @joint_matrix_sub +spv.func @joint_matrix_sub(%a : !spv.jointmatrix<8x16xi32, Subgroup>, %b : !spv.jointmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, Subgroup> + %r = spv.ISub %a, %b : !spv.jointmatrix<8x16xi32, Subgroup> + spv.Return +} + +// CHECK-LABEL: @joint_matrix_sdiv +spv.func @joint_matrix_sdiv(%a : !spv.jointmatrix<8x16xi32, Subgroup>, %b : !spv.jointmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.SDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, Subgroup> + %r = spv.SDiv %a, %b : !spv.jointmatrix<8x16xi32, Subgroup> + spv.Return +} + +// CHECK-LABEL: @joint_matrix_udiv +spv.func @joint_matrix_udiv(%a : !spv.jointmatrix<8x16xi32, Subgroup>, %b : !spv.jointmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.UDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, Subgroup> + %r = spv.UDiv %a, %b : !spv.jointmatrix<8x16xi32, Subgroup> + spv.Return +} + +// CHECK-LABEL: @joint_matrix_fadd +spv.func @joint_matrix_fadd(%a : !spv.jointmatrix<8x16xf32, 
Subgroup>, %b : !spv.jointmatrix<8x16xf32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, Subgroup> + %r = spv.FAdd %a, %b : !spv.jointmatrix<8x16xf32, Subgroup> + spv.Return +} + +// CHECK-LABEL: @joint_matrix_fsub +spv.func @joint_matrix_fsub(%a : !spv.jointmatrix<8x16xf32, Subgroup>, %b : !spv.jointmatrix<8x16xf32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, Subgroup> + %r = spv.FSub %a, %b : !spv.jointmatrix<8x16xf32, Subgroup> + spv.Return +} + +// CHECK-LABEL: @joint_matrix_fdiv +spv.func @joint_matrix_fdiv(%a : !spv.jointmatrix<8x16xf32, Subgroup>, %b : !spv.jointmatrix<8x16xf32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, Subgroup> + %r = spv.FDiv %a, %b : !spv.jointmatrix<8x16xf32, Subgroup> + spv.Return +} + +// ----- + +// CHECK-LABEL: @joint_matrix_access_chain +spv.func @joint_matrix_access_chain(%a : !spv.ptr, Function>) -> !spv.ptr "None" { + %0 = spv.Constant 0: i32 + // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr, Function>, i32 + %1 = spv.AccessChain %a[%0] : !spv.ptr, Function>, i32 + spv.ReturnValue %1 : !spv.ptr +} + +// ----- + +spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<16x16xi32, Subgroup>, %b : !spv.jointmatrix<16x8xi32, Subgroup>, %c : !spv.jointmatrix<8x8xi32, Subgroup>) "None" { + // expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix size must match}} + %r = spv.JointMatrixMadINTEL %a, %b, %c : !spv.jointmatrix<16x16xi32, Subgroup>, !spv.jointmatrix<16x8xi32, Subgroup> -> !spv.jointmatrix<8x8xi32, Subgroup> + spv.Return +} + +// ----- + +spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xi32, Subgroup>, %b : !spv.jointmatrix<8x8xi32, Subgroup>, %c : !spv.jointmatrix<8x8xi32, Subgroup>) "None" { + // expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix size must match}} + %r = spv.JointMatrixMadINTEL %a, %b, %c : !spv.jointmatrix<8x16xi32, 
Subgroup>, !spv.jointmatrix<8x8xi32, Subgroup> -> !spv.jointmatrix<8x8xi32, Subgroup> + spv.Return +} + +// ----- + +spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xi32, Subgroup>, %b : !spv.jointmatrix<16x8xi32, Workgroup>, %c : !spv.jointmatrix<8x8xi32, Subgroup>) "None" { + // expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix scope must match}} + %r = spv.JointMatrixMadINTEL %a, %b, %c : !spv.jointmatrix<8x16xi32, Subgroup>, !spv.jointmatrix<16x8xi32, Workgroup> -> !spv.jointmatrix<8x8xi32, Subgroup> + spv.Return +} + +// ----- + +spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xf32, Subgroup>, %b : !spv.jointmatrix<16x8xi32, Subgroup>, %c : !spv.jointmatrix<8x8xi32, Subgroup>) "None" { + // expected-error @+1 {{matrix element type must match}} + %r = spv.JointMatrixMadINTEL %a, %b, %c : !spv.jointmatrix<8x16xf32, Subgroup>, !spv.jointmatrix<16x8xi32, Subgroup> -> !spv.jointmatrix<8x8xi32, Subgroup> + spv.Return +} + +// ----- + +spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr, StorageBuffer>, %stride : i32, %layout : i32) "None" { + // expected-error @+1 {{Pointer must point to a scalar or vector type}} + %0 = spv.JointMatrixLoadINTEL %ptr, %stride, %layout : (!spv.ptr, StorageBuffer>, i32, i32 )-> !spv.jointmatrix<8x16xi32, Subgroup> + spv.Return +} + +// ----- + +spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr, %stride : i32, %layout : i32) "None" { + // expected-error @+1 {{Pointer storage class must be Workgroup, StorageBuffer or PhysicalStorageBufferEXT}} + %0 = spv.JointMatrixLoadINTEL %ptr, %stride, %layout : (!spv.ptr, i32, i32) -> !spv.jointmatrix<8x16xi32, Subgroup> + spv.Return +}