Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td =================================================================== --- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -27,12 +27,12 @@ // In addition to normal types arithmetic instructions can support cooperative // matrix. let arguments = (ins - SPV_ScalarOrVectorOrCoopMatrixOf:$operand1, - SPV_ScalarOrVectorOrCoopMatrixOf:$operand2 + SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf:$operand1, + SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf:$operand2 ); let results = (outs - SPV_ScalarOrVectorOrCoopMatrixOf:$result + SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf:$result ); let assemblyFormat = "operands attr-dict `:` type($result)"; } Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td =================================================================== --- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td +++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td @@ -64,6 +64,27 @@ TypedArrayAttrBase; +// Description of the supported joint matrix operations. See +// https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_joint_matrix.asciidoc +def SPV_JointMatrixPropertiesINTELAttr : + SPV_Attr<"JointMatrixPropertiesINTEL", "joint_matrix_props"> { + let parameters = (ins + "int":$m_size, + "int":$n_size, + "int":$k_size, + "mlir::Type":$a_type, + "mlir::Type":$b_type, + "mlir::Type":$c_type, + "mlir::Type":$result_type, + "mlir::spirv::ScopeAttr":$scope + ); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def SPV_JointMatrixPropertiesINTELArrayAttr : + TypedArrayAttrBase; + // This attribute specifies the limits for various resources on the target // architecture. 
// Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td =================================================================== --- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -387,6 +387,7 @@ def SPV_INTEL_fp_fast_math_mode : I32EnumAttrCase<"SPV_INTEL_fp_fast_math_mode", 4027>; def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_memory_access_aliasing", 4028>; def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>; +def SPV_INTEL_joint_matrix : I32EnumAttrCase<"SPV_INTEL_joint_matrix", 4030>; def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>; def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>; @@ -443,7 +444,7 @@ SPV_INTEL_usm_storage_classes, SPV_INTEL_io_pipes, SPV_INTEL_blocking_pipes, SPV_INTEL_fpga_reg, SPV_INTEL_long_constant_composite, SPV_INTEL_optnone, SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode, - SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier, + SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier, SPV_INTEL_joint_matrix, SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix, SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough, SPV_NV_mesh_shader, SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage, @@ -1390,6 +1391,12 @@ ]; } +def SPV_C_JointMatrixINTEL : I32EnumAttrCase<"JointMatrixINTEL", 6118> { + list availability = [ + Extension<[SPV_INTEL_joint_matrix]> + ]; +} + def SPV_CapabilityAttr : SPV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [ SPV_C_Matrix, SPV_C_Addresses, SPV_C_Linkage, SPV_C_Kernel, SPV_C_Float16, @@ -1481,7 +1488,7 @@ SPV_C_UniformTexelBufferArrayNonUniformIndexing, SPV_C_StorageTexelBufferArrayNonUniformIndexing, SPV_C_ShaderViewportIndexLayerEXT, SPV_C_ShaderViewportMaskNV, - SPV_C_ShaderStereoViewNV + SPV_C_ShaderStereoViewNV, SPV_C_JointMatrixINTEL ]>; def SPV_AM_Logical : 
I32EnumAttrCase<"Logical", 0>; @@ -3823,6 +3830,17 @@ SPV_S_Invocation, SPV_S_QueueFamily, SPV_S_ShaderCallKHR ]>; + +def SPV_ML_ColumnMajor : I32EnumAttrCase<"ColumnMajor", 0>; +def SPV_ML_RowMajor : I32EnumAttrCase<"RowMajor", 1>; +def SPV_ML_PackedA : I32EnumAttrCase<"PackedA", 2>; +def SPV_ML_PackedB : I32EnumAttrCase<"PackedB", 3>; + +def SPV_MatrixLayoutAttr : + SPV_I32EnumAttr<"MatrixLayout", "valid SPIR-V MatrixLayout", "matrixLayout", [ + SPV_ML_ColumnMajor, SPV_ML_RowMajor, SPV_ML_PackedA, SPV_ML_PackedB + ]>; + def SPV_SC_None : I32BitEnumAttrCaseNone<"None">; def SPV_SC_Flatten : I32BitEnumAttrCaseBit<"Flatten", 0>; def SPV_SC_DontFlatten : I32BitEnumAttrCaseBit<"DontFlatten", 1>; @@ -4013,6 +4031,8 @@ def SPV_IsCooperativeMatrixType : CPred<"$_self.isa<::mlir::spirv::CooperativeMatrixNVType>()">; def SPV_IsImageType : CPred<"$_self.isa<::mlir::spirv::ImageType>()">; +def SPV_IsJointMatrixType : + CPred<"$_self.isa<::mlir::spirv::JointMatrixINTELType>()">; def SPV_IsMatrixType : CPred<"$_self.isa<::mlir::spirv::MatrixType>()">; def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">; def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">; @@ -4043,6 +4063,8 @@ "any SPIR-V cooperative matrix type">; def SPV_AnyImage : DialectType; +def SPV_AnyJointMatrix : DialectType; def SPV_AnyMatrix : DialectType; def SPV_AnyRTArray : DialectType; def SPV_Composite : AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct, - SPV_AnyCooperativeMatrix, SPV_AnyMatrix]>; + SPV_AnyCooperativeMatrix, SPV_AnyJointMatrix, SPV_AnyMatrix]>; def SPV_Type : AnyTypeOf<[ SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector, SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct, - SPV_AnyCooperativeMatrix, SPV_AnyMatrix, SPV_AnySampledImage + SPV_AnyCooperativeMatrix, SPV_AnyJointMatrix, SPV_AnyMatrix, + SPV_AnySampledImage ]>; def SPV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>; @@ -4072,6 +4095,11 @@ 
"$_self.cast<::mlir::spirv::CooperativeMatrixNVType>().getElementType()", "Cooperative Matrix">; +class SPV_JointMatrixOfType allowedTypes> : + ContainerType, SPV_IsJointMatrixType, + "$_self.cast<::mlir::spirv::JointMatrixINTELType>().getElementType()", + "Joint Matrix">; + class SPV_ScalarOrVectorOf : AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>; @@ -4079,6 +4107,14 @@ AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>, SPV_CoopMatrixOfType<[type]>]>; +class SPV_ScalarOrVectorOrJointMatrixOf : + AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>, + SPV_JointMatrixOfType<[type]>]>; + +class SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf : + AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>, + SPV_CoopMatrixOfType<[type]>, SPV_JointMatrixOfType<[type]> ]>; + def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>; def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>; @@ -4311,6 +4347,11 @@ def SPV_OC_OpSubgroupBlockWriteINTEL : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>; def SPV_OC_OpAssumeTrueKHR : I32EnumAttrCase<"OpAssumeTrueKHR", 5630>; def SPV_OC_OpAtomicFAddEXT : I32EnumAttrCase<"OpAtomicFAddEXT", 6035>; +def SPV_OC_OpTypeJointMatrixINTEL : I32EnumAttrCase<"OpTypeJointMatrixINTEL", 6119>; +def SPV_OC_OpJointMatrixLoadINTEL : I32EnumAttrCase<"OpJointMatrixLoadINTEL", 6120>; +def SPV_OC_OpJointMatrixStoreINTEL : I32EnumAttrCase<"OpJointMatrixStoreINTEL", 6121>; +def SPV_OC_OpJointMatrixMadINTEL : I32EnumAttrCase<"OpJointMatrixMadINTEL", 6122>; +def SPV_OC_OpJointMatrixWorkItemLengthINTEL : I32EnumAttrCase<"OpJointMatrixWorkItemLengthINTEL", 6410>; def SPV_OpcodeAttr : SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [ @@ -4376,7 +4417,10 @@ SPV_OC_OpCooperativeMatrixLoadNV, SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV, SPV_OC_OpCooperativeMatrixLengthNV, SPV_OC_OpSubgroupBlockReadINTEL, 
SPV_OC_OpSubgroupBlockWriteINTEL, - SPV_OC_OpAssumeTrueKHR, SPV_OC_OpAtomicFAddEXT + SPV_OC_OpAssumeTrueKHR, SPV_OC_OpAtomicFAddEXT, + SPV_OC_OpTypeJointMatrixINTEL, SPV_OC_OpJointMatrixLoadINTEL, + SPV_OC_OpJointMatrixStoreINTEL, SPV_OC_OpJointMatrixMadINTEL, + SPV_OC_OpJointMatrixWorkItemLengthINTEL ]>; // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY! Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td =================================================================== --- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td +++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td @@ -23,11 +23,11 @@ !listconcat(traits, [NoSideEffect, SameOperandsAndResultShape])> { let arguments = (ins - SPV_ScalarOrVectorOrCoopMatrixOf:$operand + SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf:$operand ); let results = (outs - SPV_ScalarOrVectorOrCoopMatrixOf:$result + SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf:$result ); let assemblyFormat = [{ $operand attr-dict `:` type($operand) `to` type($result) Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td @@ -0,0 +1,248 @@ +//===- SPIRVJointMatrixOps.td - joint matmul ---------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the op definition spec of joint matrix multiply extension ops. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS +#define MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS + +// ----- + +def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTEL", + [NoSideEffect]> { + let summary = "See extension SPV_INTEL_joint_matrix"; + + let description = [{ + Return the number of components owned by the current work-item in + a joint matrix. + + Result Type must be a 32-bit unsigned integer type scalar. + + Type is a joint matrix type. + + ``` {.ebnf} + joint-matrix-length-op ::= ssa-id `=` `spv.JointMatrixWorkItemLengthINTEL` + `:` joint-matrix-type + ``` + + For example: + + ``` + %0 = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix + ``` + }]; + + let assemblyFormat = "attr-dict `:` $type"; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_INTEL_joint_matrix]>, + Capability<[SPV_C_JointMatrixINTEL]> + ]; + + let arguments = (ins + TypeAttr:$type + ); + + let results = (outs + SPV_Int32:$result + ); + let hasVerifier = 0; +} + +// ----- + +def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> { + let summary = "See extension SPV_INTEL_joint_matrix"; + + let description = [{ + Load a matrix through a pointer. + + Result Type is the type of the loaded matrix. It must be OpTypeJointMatrixINTEL. + + Pointer is the pointer to load through. It specifies the start of the memory region where + elements of the matrix are stored and arranged according to Layout. + + Stride is the number of elements in memory between beginnings of successive rows, + columns (or words) in the result. It must be a scalar integer type. + + Layout indicates how the values loaded from memory are arranged. It must be the + result of a constant instruction. + + Scope is the synchronization scope for operation on the matrix. It must be the result + of a constant instruction with scalar integer type. 
+ + If present, any Memory Operands must begin with a memory operand literal. If not + present, it is the same as specifying the memory operand None. + + #### Example: + ```mlir + %0 = spv.JointMatrixLoadINTEL %ptr, %stride + {memory_access = #spv.memory_access} : + (!spv.ptr, i32) -> + !spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup> + ``` + }]; + + let assemblyFormat = [{ + $scope $layout operands attr-dict `:` `(` type(operands) `)` `->` type($result) + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_INTEL_joint_matrix]>, + Capability<[SPV_C_JointMatrixINTEL]> + ]; + + let arguments = (ins + SPV_ScopeAttr:$scope, + SPV_MatrixLayoutAttr:$layout, + SPV_AnyPtr:$pointer, + SPV_Integer:$stride, + OptionalAttr:$memory_access, + OptionalAttr:$alignment + ); + + let results = (outs + SPV_AnyJointMatrix:$result + ); +} + +// ----- + +def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL", + [NoSideEffect, AllTypesMatch<["c", "result"]>]> { + let summary = "See extension SPV_INTEL_joint_matrix"; + + let description = [{ + Multiply matrix A by matrix B and add matrix C to the result + of the multiplication: A*B+C. Here A is an M x K matrix, B is + a K x N matrix and C is an M x N matrix. + + Behavior is undefined if sizes of operands do not meet the + conditions above. All operands and the Result Type must be + OpTypeJointMatrixINTEL. + + A must be an OpTypeJointMatrixINTEL whose Component Type is a + signed numerical type, Row Count equal to M and Column Count + equal to K. + + B must be an OpTypeJointMatrixINTEL whose Component Type is a + signed numerical type, Row Count equal to K and Column Count + equal to N. + + C and Result Type must be an OpTypeJointMatrixINTEL with Row + Count equal to M and Column Count equal to N. + + Scope is the synchronization scope for operation on the matrix. + It must be the result of a constant instruction with scalar + integer type. 
+ + #### Example: + ```mlir + %r = spv.JointMatrixMadINTEL %a, %b, %c : + !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, + !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> + -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup> + ``` + + }]; + + let assemblyFormat = [{ + $scope operands attr-dict`:` type($a) `,` type($b) `->` type($c) + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_INTEL_joint_matrix]>, + Capability<[SPV_C_JointMatrixINTEL]> + ]; + + let arguments = (ins + SPV_ScopeAttr:$scope, + SPV_AnyJointMatrix:$a, + SPV_AnyJointMatrix:$b, + SPV_AnyJointMatrix:$c + ); + + let results = (outs + SPV_AnyJointMatrix:$result + ); +} + +// ----- + +def SPV_JointMatrixStoreINTELOp : SPV_Op<"JointMatrixStoreINTEL", []> { + let summary = "See extension SPV_INTEL_joint_matrix"; + + let description = [{ + Store a matrix through a pointer. + + Pointer is the pointer to store through. It specifies + the start of the memory region where elements of the matrix must + be stored and arranged according to Layout. + + Object is the matrix to store. It must be + OpTypeJointMatrixINTEL. + + Stride is the number of elements in memory between beginnings + of successive rows, columns (or words) of the Object. It must + be a scalar integer type. + + Layout indicates how the values stored to memory are arranged. + It must be the result of a constant instruction. + + Scope is the synchronization scope for operation on the matrix. + It must be the result of a constant instruction with scalar + integer type. + + If present, any Memory Operands must begin with a memory operand + literal. If not present, it is the same as specifying the memory + operand None. 
+ + #### Example: + ```mlir + spv.JointMatrixStoreINTEL %ptr, %m, %stride + {memory_access = #spv.memory_access} : (!spv.ptr, + !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32) + ``` + + }]; + + let assemblyFormat = [{ + $scope $layout operands attr-dict `:` `(` type(operands) `)` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_INTEL_joint_matrix]>, + Capability<[SPV_C_JointMatrixINTEL]> + ]; + + let arguments = (ins + SPV_ScopeAttr:$scope, + SPV_MatrixLayoutAttr:$layout, + SPV_AnyPtr:$pointer, + SPV_AnyJointMatrix:$object, + SPV_Integer:$stride, + OptionalAttr:$memory_access, + OptionalAttr:$alignment + ); + + let results = (outs); +} + +// ----- + +#endif // MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td =================================================================== --- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td +++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td @@ -30,6 +30,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td" +include "mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td" Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h =================================================================== --- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -29,6 +29,7 @@ struct ArrayTypeStorage; struct CooperativeMatrixTypeStorage; struct ImageTypeStorage; +struct JointMatrixTypeStorage; struct MatrixTypeStorage; struct PointerTypeStorage; struct RuntimeArrayTypeStorage; @@ -420,6 +421,33 @@ Optional storage = llvm::None); }; +// SPIR-V joint matrix type +class JointMatrixINTELType + : public Type::TypeBase { +public: + using Base::Base; + + static 
JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, + unsigned columns, MatrixLayout matrixLayout); + Type getElementType() const; + + /// Return the scope of the joint matrix. + Scope getScope() const; + /// return the number of rows of the matrix. + unsigned getRows() const; + /// return the number of columns of the matrix. + unsigned getColumns() const; + + /// return the layout of the matrix + MatrixLayout getMatrixLayout() const; + + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage = llvm::None); + void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage = llvm::None); +}; + // SPIR-V matrix type class MatrixType : public Type::TypeBase { Index: mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp =================================================================== --- mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -348,6 +348,39 @@ return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]); } +// joint-matrix-type ::= `!spv.jointmatrix` `<`rows `x` columns `x` element-type +// `,` layout `,` scope`>` +static Type parseJointMatrixType(SPIRVDialect const &dialect, + DialectAsmParser &parser) { + if (parser.parseLess()) + return Type(); + + SmallVector dims; + SMLoc countLoc = parser.getCurrentLocation(); + if (parser.parseDimensionList(dims, /*allowDynamic=*/false)) + return Type(); + + if (dims.size() != 2) { + parser.emitError(countLoc, "expected rows and columns size"); + return Type(); + } + + auto elementTy = parseAndVerifyType(dialect, parser); + if (!elementTy) + return Type(); + MatrixLayout matrixLayout; + if (parser.parseComma() || + parseEnumKeywordAttr(matrixLayout, parser, "matrixLayout ")) + return Type(); + Scope scope; + if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope ")) + return Type(); + if (parser.parseGreater()) + return Type(); + return JointMatrixINTELType::get(elementTy, scope, 
dims[0], dims[1], + matrixLayout); +} + // TODO: Reorder methods to be utilities first and parse*Type // methods in alphabetical order // @@ -753,6 +786,8 @@ return parseArrayType(*this, parser); if (keyword == "coopmatrix") return parseCooperativeMatrixType(*this, parser); + if (keyword == "jointmatrix") + return parseJointMatrixType(*this, parser); if (keyword == "image") return parseImageType(*this, parser); if (keyword == "ptr") @@ -859,6 +894,13 @@ os << ">"; } +static void print(JointMatrixINTELType type, DialectAsmPrinter &os) { + os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x"; + os << type.getElementType() << ", " + << stringifyMatrixLayout(type.getMatrixLayout()); + os << ", " << stringifyScope(type.getScope()) << ">"; +} + static void print(MatrixType type, DialectAsmPrinter &os) { os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType(); os << ">"; @@ -866,9 +908,9 @@ void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const { TypeSwitch(type) - .Case( - [&](auto type) { print(type, os); }) + .Case([&](auto type) { print(type, os); }) .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); }); } Index: mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp =================================================================== --- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -436,6 +436,13 @@ resultType.cast().getElementType(); } + if (auto jointMatrixType = + operandType.dyn_cast()) { + operandType = jointMatrixType.getElementType(); + resultType = + resultType.cast().getElementType(); + } + auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth(); auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth(); auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth; @@ -1637,6 +1644,17 @@ return success(); } + if (auto jointType = cType.dyn_cast()) { + if (constituents.size() != 1) + return emitOpError("has incorrect number of operands: expected ") + << "1, but 
provided " << constituents.size(); + if (jointType.getElementType() != constituents.front().getType()) + return emitOpError("operand type mismatch: expected operand type ") + << jointType.getElementType() << ", but provided " + << constituents.front().getType(); + return success(); + } + if (constituents.size() == cType.getNumElements()) { for (auto index : llvm::seq(0, constituents.size())) { if (constituents[index].getType() != cType.getElementType(index)) { @@ -3893,6 +3911,70 @@ return verifyCoopMatrixMulAdd(*this); } +static LogicalResult +verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) { + Type pointeeType = pointer.cast().getPointeeType(); + if (!pointeeType.isa() && !pointeeType.isa()) + return op->emitError( + "Pointer must point to a scalar or vector type but provided ") + << pointeeType; + spirv::StorageClass storage = + pointer.cast().getStorageClass(); + if (storage != spirv::StorageClass::Workgroup && + storage != spirv::StorageClass::CrossWorkgroup) + return op->emitError("Pointer storage class must be Workgroup or " + "CrossWorkgroup but provided ") + << stringifyStorageClass(storage); + return success(); +} + +//===----------------------------------------------------------------------===// +// spv.JointMatrixLoadINTEL +//===----------------------------------------------------------------------===// + +LogicalResult spirv::JointMatrixLoadINTELOp::verify() { + return verifyPointerAndJointMatrixType(*this, pointer().getType(), + result().getType()); +} + +//===----------------------------------------------------------------------===// +// spv.JointMatrixStoreINTEL +//===----------------------------------------------------------------------===// + +LogicalResult spirv::JointMatrixStoreINTELOp::verify() { + return verifyPointerAndJointMatrixType(*this, pointer().getType(), + object().getType()); +} + +//===----------------------------------------------------------------------===// +// spv.JointMatrixMadINTEL 
+//===----------------------------------------------------------------------===// + +static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) { + if (op.c().getType() != op.result().getType()) + return op.emitOpError("result and third operand must have the same type"); + auto typeA = op.a().getType().cast(); + auto typeB = op.b().getType().cast(); + auto typeC = op.c().getType().cast(); + auto typeR = op.result().getType().cast(); + if (typeA.getRows() != typeR.getRows() || + typeA.getColumns() != typeB.getRows() || + typeB.getColumns() != typeR.getColumns()) + return op.emitOpError("matrix size must match"); + if (typeR.getScope() != typeA.getScope() || + typeR.getScope() != typeB.getScope() || + typeR.getScope() != typeC.getScope()) + return op.emitOpError("matrix scope must match"); + if (typeA.getElementType() != typeB.getElementType() || + typeR.getElementType() != typeC.getElementType()) + return op.emitOpError("matrix element type must match"); + return success(); +} + +LogicalResult spirv::JointMatrixMadINTELOp::verify() { + return verifyJointMatrixMad(*this); +} + //===----------------------------------------------------------------------===// // spv.MatrixTimesScalar //===----------------------------------------------------------------------===// @@ -4150,6 +4232,8 @@ if (cType.isa()) return emitError("unsupported composite type ") << cType; + if (cType.isa()) + return emitError("unsupported composite type ") << cType; if (constituents.size() != cType.getNumElements()) return emitError("has incorrect number of operands: expected ") << cType.getNumElements() << ", but provided " Index: mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp =================================================================== --- mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -89,9 +89,9 @@ bool CompositeType::classof(Type type) { if (auto vectorType = type.dyn_cast()) return isValid(vectorType); - return type - .isa(); + return 
type.isa(); } bool CompositeType::isValid(VectorType type) { @@ -110,7 +110,8 @@ Type CompositeType::getElementType(unsigned index) const { return TypeSwitch(*this) - .Case( + .Case( [](auto type) { return type.getElementType(); }) .Case([](MatrixType type) { return type.getColumnType(); }) .Case( @@ -132,6 +133,10 @@ llvm_unreachable( "invalid to query number of elements of spirv::CooperativeMatrix type"); } + if (isa()) { + llvm_unreachable( + "invalid to query number of elements of spirv::JointMatrix type"); + } if (isa()) { llvm_unreachable( "invalid to query number of elements of spirv::RuntimeArray type"); @@ -140,15 +145,16 @@ } bool CompositeType::hasCompileTimeKnownNumElements() const { - return !isa(); + return !isa(); } void CompositeType::getExtensions( SPIRVType::ExtensionArrayRefVector &extensions, Optional storage) { TypeSwitch(*this) - .Case( + .Case( [&](auto type) { type.getExtensions(extensions, storage); }) .Case([&](VectorType type) { return type.getElementType().cast().getExtensions( @@ -161,8 +167,8 @@ SPIRVType::CapabilityArrayRefVector &capabilities, Optional storage) { TypeSwitch(*this) - .Case( + .Case( [&](auto type) { type.getCapabilities(capabilities, storage); }) .Case([&](VectorType type) { auto vecSize = getNumElements(); @@ -255,6 +261,74 @@ capabilities.push_back(ref); } +//===----------------------------------------------------------------------===// +// JointMatrixType +//===----------------------------------------------------------------------===// + +struct spirv::detail::JointMatrixTypeStorage : public TypeStorage { + using KeyTy = std::tuple; + + static JointMatrixTypeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) + JointMatrixTypeStorage(key); + } + + bool operator==(const KeyTy &key) const { + return key == KeyTy(elementType, rows, columns, matrixLayout, scope); + } + + JointMatrixTypeStorage(const KeyTy &key) + : elementType(std::get<0>(key)), 
rows(std::get<1>(key)), + columns(std::get<2>(key)), scope(std::get<4>(key)), + matrixLayout(std::get<3>(key)) {} + + Type elementType; + unsigned rows; + unsigned columns; + Scope scope; + MatrixLayout matrixLayout; +}; + +JointMatrixINTELType JointMatrixINTELType::get(Type elementType, Scope scope, + unsigned rows, unsigned columns, + MatrixLayout matrixLayout) { + return Base::get(elementType.getContext(), elementType, rows, columns, + matrixLayout, scope); +} + +Type JointMatrixINTELType::getElementType() const { + return getImpl()->elementType; +} + +Scope JointMatrixINTELType::getScope() const { return getImpl()->scope; } + +unsigned JointMatrixINTELType::getRows() const { return getImpl()->rows; } + +unsigned JointMatrixINTELType::getColumns() const { return getImpl()->columns; } + +MatrixLayout JointMatrixINTELType::getMatrixLayout() const { + return getImpl()->matrixLayout; +} + +void JointMatrixINTELType::getExtensions( + SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage) { + getElementType().cast().getExtensions(extensions, storage); + static const Extension exts[] = {Extension::SPV_INTEL_joint_matrix}; + ArrayRef ref(exts, llvm::array_lengthof(exts)); + extensions.push_back(ref); +} + +void JointMatrixINTELType::getCapabilities( + SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage) { + getElementType().cast().getCapabilities(capabilities, storage); + static const Capability caps[] = {Capability::JointMatrixINTEL}; + ArrayRef ref(caps, llvm::array_lengthof(caps)); + capabilities.push_back(ref); +} + //===----------------------------------------------------------------------===// // ImageType //===----------------------------------------------------------------------===// @@ -1172,6 +1246,7 @@ //===----------------------------------------------------------------------===// void SPIRVDialect::registerTypes() { - addTypes(); + addTypes(); } Index: mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp 
=================================================================== --- mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -168,6 +168,8 @@ return processType(opcode, operands); case spirv::Opcode::OpTypeForwardPointer: return processTypeForwardPointer(operands); + case spirv::Opcode::OpTypeJointMatrixINTEL: + return processType(opcode, operands); case spirv::Opcode::OpConstant: return processConstant(operands, /*isSpec=*/false); case spirv::Opcode::OpSpecConstant: Index: mlir/lib/Target/SPIRV/Deserialization/Deserializer.h =================================================================== --- mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -257,6 +257,8 @@ LogicalResult processFunctionType(ArrayRef operands); + LogicalResult processJointMatrixType(ArrayRef operands); + LogicalResult processImageType(ArrayRef operands); LogicalResult processSampledImageType(ArrayRef operands); Index: mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp =================================================================== --- mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -730,6 +730,8 @@ return processCooperativeMatrixType(operands); case spirv::Opcode::OpTypeFunction: return processFunctionType(operands); + case spirv::Opcode::OpTypeJointMatrixINTEL: + return processJointMatrixType(operands); case spirv::Opcode::OpTypeImage: return processImageType(operands); case spirv::Opcode::OpTypeSampledImage: @@ -888,6 +890,40 @@ return success(); } +LogicalResult +spirv::Deserializer::processJointMatrixType(ArrayRef operands) { + if (operands.size() != 6) { + return emitError(unknownLoc, "OpTypeJointMatrix must have element " + "type and row x column parameters"); + } + + Type elementTy = getType(operands[1]); + if (!elementTy) { + return emitError(unknownLoc, "OpTypeJointMatrix 
references undefined ") + << operands[1]; + } + + auto scope = spirv::symbolizeScope(getConstantInt(operands[5]).getInt()); + if (!scope) { + return emitError(unknownLoc, + "OpTypeJointMatrix references undefined scope ") + << operands[5]; + } + auto matrixLayout = + spirv::symbolizeMatrixLayout(getConstantInt(operands[4]).getInt()); + if (!matrixLayout) { + return emitError(unknownLoc, + "OpTypeJointMatrix references undefined matrix layout ") + << operands[4]; + } + unsigned rows = getConstantInt(operands[2]).getInt(); + unsigned columns = getConstantInt(operands[3]).getInt(); + + typeMap[operands[0]] = spirv::JointMatrixINTELType::get( + elementTy, scope.value(), rows, columns, matrixLayout.value()); + return success(); +} + LogicalResult spirv::Deserializer::processRuntimeArrayType(ArrayRef operands) { if (operands.size() != 2) { Index: mlir/lib/Target/SPIRV/Serialization/Serializer.cpp =================================================================== --- mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -598,6 +598,27 @@ return success(); } + if (auto jointMatrixType = type.dyn_cast()) { + uint32_t elementTypeID = 0; + if (failed(processTypeImpl(loc, jointMatrixType.getElementType(), + elementTypeID, serializationCtx))) { + return failure(); + } + typeEnum = spirv::Opcode::OpTypeJointMatrixINTEL; + auto getConstantOp = [&](uint32_t id) { + auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id); + return prepareConstantInt(loc, attr); + }; + operands.push_back(elementTypeID); + operands.push_back(getConstantOp(jointMatrixType.getRows())); + operands.push_back(getConstantOp(jointMatrixType.getColumns())); + operands.push_back(getConstantOp( + static_cast(jointMatrixType.getMatrixLayout()))); + operands.push_back( + getConstantOp(static_cast(jointMatrixType.getScope()))); + return success(); + } + if (auto matrixType = type.dyn_cast()) { uint32_t elementTypeID = 0; if 
(failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID, Index: mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir @@ -0,0 +1,158 @@ +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: @joint_matrix_load +spv.func @joint_matrix_load(%ptr : !spv.ptr, %stride : i32) "None" { + // CHECK: {{%.*}} = spv.JointMatrixLoadINTEL {{%.*}}, {{%.*}} : (!spv.ptr, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup> + %0 = spv.JointMatrixLoadINTEL %ptr, %stride : (!spv.ptr, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup> + spv.Return +} + +// ----- +// CHECK-LABEL: @joint_matrix_load_memaccess +spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr, %stride : i32) "None" { + // CHECK: {{%.*}} = spv.JointMatrixLoadINTEL {{%.*}}, {{%.*}} {memory_access = #spv.memory_access} : (!spv.ptr, i32) -> !spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup> + %0 = spv.JointMatrixLoadINTEL %ptr, %stride {memory_access = #spv.memory_access} : (!spv.ptr, i32) -> !spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup> + spv.Return +} + +// CHECK-LABEL: @joint_matrix_load_diff_ptr_type +spv.func @joint_matrix_load_diff_ptr_type(%ptr : !spv.ptr, Workgroup>, %stride : i32) "None" { + // CHECK: {{%.*}} = spv.JointMatrixLoadINTEL {{%.*}}, {{%.*}} {memory_access = #spv.memory_access} : (!spv.ptr, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Workgroup> + %0 = spv.JointMatrixLoadINTEL %ptr, %stride {memory_access = #spv.memory_access} : (!spv.ptr, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Workgroup> + spv.Return +} + +// CHECK-LABEL: @joint_matrix_store +spv.func @joint_matrix_store(%ptr : !spv.ptr, %stride : i32, %m : !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>) "None" { + // CHECK: spv.JointMatrixStoreINTEL {{%.*}}, {{%.*}}, {{%.*}} : (!spv.ptr, 
!spv.jointmatrix<8x16xi32, RowMajor, Workgroup>, i32) + spv.JointMatrixStoreINTEL %ptr, %m, %stride : (!spv.ptr, !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>, i32) + spv.Return +} + +// CHECK-LABEL: @joint_matrix_store_memaccess +spv.func @joint_matrix_store_memaccess(%ptr : !spv.ptr, %m : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %stride : i32) "None" { + // CHECK: spv.JointMatrixStoreINTEL {{%.*}}, {{%.*}}, {{%.*}} {Volatile} : (!spv.ptr, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32) + spv.JointMatrixStoreINTEL %ptr, %m, %stride {Volatile} : (!spv.ptr, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32) + spv.Return +} + +// CHECK-LABEL: @joint_matrix_length +spv.func @joint_matrix_length() -> i32 "None" { + // CHECK: {{%.*}} = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, PackedB, Subgroup> + %0 = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, PackedB, Subgroup> + spv.ReturnValue %0 : i32 +} + +// CHECK-LABEL: @joint_matrix_muladd +spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, %b : !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.JointMatrixMadINTEL {{%.*}}, {{%.*}}, {{%.*}} : !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup> + %r = spv.JointMatrixMadINTEL %a, %b, %c : !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup> + spv.Return +} + +// CHECK-LABEL: @joint_matrix_add +spv.func @joint_matrix_add(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + %r = spv.IAdd %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + spv.Return +} + +// 
CHECK-LABEL: @joint_matrix_sub +spv.func @joint_matrix_sub(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + %r = spv.ISub %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + spv.Return +} + +// CHECK-LABEL: @joint_matrix_sdiv +spv.func @joint_matrix_sdiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.SDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + %r = spv.SDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + spv.Return +} + +// CHECK-LABEL: @joint_matrix_udiv +spv.func @joint_matrix_udiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.UDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + %r = spv.UDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + spv.Return +} + +// CHECK-LABEL: @joint_matrix_fadd +spv.func @joint_matrix_fadd(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup> + %r = spv.FAdd %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup> + spv.Return +} + +// CHECK-LABEL: @joint_matrix_fsub +spv.func @joint_matrix_fsub(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup> + %r = spv.FSub %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup> + spv.Return +} + +// CHECK-LABEL: @joint_matrix_fdiv +spv.func @joint_matrix_fdiv(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) 
"None" { + // CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup> + %r = spv.FDiv %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup> + spv.Return +} + +// ----- + +// CHECK-LABEL: @joint_matrix_access_chain +spv.func @joint_matrix_access_chain(%a : !spv.ptr, Function>) -> !spv.ptr "None" { + %0 = spv.Constant 0: i32 + // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr, Function>, i32 + %1 = spv.AccessChain %a[%0] : !spv.ptr, Function>, i32 + spv.ReturnValue %1 : !spv.ptr +} + +// ----- + +spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<16x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" { + // expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix size must match}} + %r = spv.JointMatrixMadINTEL %a, %b, %c : !spv.jointmatrix<16x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup> + spv.Return +} + +// ----- + +spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" { + // expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix size must match}} + %r = spv.JointMatrixMadINTEL %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<8x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup> + spv.Return +} + +// ----- + +spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" { + // expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix scope must match}} + %r = spv.JointMatrixMadINTEL %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup> -> !spv.jointmatrix<8x8xi32, 
RowMajor, Subgroup> + spv.Return +} + +// ----- + +spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" { + // expected-error @+1 {{matrix element type must match}} + %r = spv.JointMatrixMadINTEL %a, %b, %c : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup> + spv.Return +} + +// ----- + +spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr, Workgroup>, %stride : i32) "None" { + // expected-error @+1 {{Pointer must point to a scalar or vector type}} + %0 = spv.JointMatrixLoadINTEL %ptr, %stride : (!spv.ptr, Workgroup>, i32)-> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + spv.Return +} + +// ----- + +spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr, %stride : i32) "None" { + // expected-error @+1 {{Pointer storage class must be Workgroup or CrossWorkgroup}} + %0 = spv.JointMatrixLoadINTEL %ptr, %stride : (!spv.ptr, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + spv.Return +} Index: mlir/test/Target/SPIRV/joint-matrix-ops.mlir =================================================================== --- /dev/null +++ mlir/test/Target/SPIRV/joint-matrix-ops.mlir @@ -0,0 +1,102 @@ +// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s + +spv.module Logical GLSL450 requires #spv.vce { + // CHECK-LABEL: @joint_matrix_load + spv.func @joint_matrix_load(%ptr : !spv.ptr, %stride : i32) "None" { + // CHECK: {{%.*}} = spv.JointMatrixLoadINTEL {{%.*}}, {{%.*}} : (!spv.ptr, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup> + %0 = spv.JointMatrixLoadINTEL %ptr, %stride : (!spv.ptr, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup> + spv.Return + } + + // CHECK-LABEL: @joint_matrix_load_memaccess + spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr, %stride : i32) "None" { + // CHECK: {{%.*}} = 
spv.JointMatrixLoadINTEL {{%.*}}, {{%.*}} {memory_access = #spv.memory_access} : (!spv.ptr, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + %0 = spv.JointMatrixLoadINTEL %ptr, %stride {memory_access = #spv.memory_access} : (!spv.ptr, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + spv.Return + } + + // CHECK-LABEL: @joint_matrix_store + spv.func @joint_matrix_store(%ptr : !spv.ptr, %stride : i32, %m : !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>) "None" { + // CHECK: spv.JointMatrixStoreINTEL {{%.*}}, {{%.*}}, {{%.*}} : (!spv.ptr, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>, i32) + spv.JointMatrixStoreINTEL %ptr, %m, %stride : (!spv.ptr, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>, i32) + spv.Return + } + + // CHECK-LABEL: @joint_matrix_store_memaccess + spv.func @joint_matrix_store_memaccess(%ptr : !spv.ptr, %m : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %stride : i32) "None" { + // CHECK: spv.JointMatrixStoreINTEL {{%.*}}, {{%.*}}, {{%.*}} {memory_access = #spv.memory_access} : (!spv.ptr, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32) + spv.JointMatrixStoreINTEL %ptr, %m, %stride {memory_access = #spv.memory_access} : (!spv.ptr, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32) + spv.Return + } + + // CHECK-LABEL: @joint_matrix_length + spv.func @joint_matrix_length() -> i32 "None" { + // CHECK: {{%.*}} = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + %0 = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + spv.ReturnValue %0 : i32 + } + + // CHECK-LABEL: @joint_matrix_muladd + spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.JointMatrixMadINTEL {{%.*}}, {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> 
!spv.jointmatrix<8x8xi32, RowMajor, Subgroup> + %r = spv.JointMatrixMadINTEL %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup> + spv.Return + } + + // CHECK-LABEL: @joint_matrix_add + spv.func @joint_matrix_add(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + %r = spv.IAdd %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + spv.Return + } + + // CHECK-LABEL: @joint_matrix_sub + spv.func @joint_matrix_sub(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + %r = spv.ISub %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + spv.Return + } + + // CHECK-LABEL: @joint_matrix_sdiv + spv.func @joint_matrix_sdiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.SDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + %r = spv.SDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + spv.Return + } + + // CHECK-LABEL: @joint_matrix_udiv + spv.func @joint_matrix_udiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.UDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + %r = spv.UDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup> + spv.Return + } + + // CHECK-LABEL: @joint_matrix_fadd + spv.func @joint_matrix_fadd(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, 
Subgroup> + %r = spv.FAdd %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup> + spv.Return + } + + // CHECK-LABEL: @joint_matrix_fsub + spv.func @joint_matrix_fsub(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup> + %r = spv.FSub %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup> + spv.Return + } + + // CHECK-LABEL: @joint_matrix_fdiv + spv.func @joint_matrix_fdiv(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup> + %r = spv.FDiv %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup> + spv.Return + } + + // CHECK-LABEL: @joint_matrix_access_chain + spv.func @joint_matrix_access_chain(%a : !spv.ptr, Function>) -> !spv.ptr "None" { + %0 = spv.Constant 0: i32 + // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr, Function>, i32 + %1 = spv.AccessChain %a[%0] : !spv.ptr, Function>, i32 + spv.ReturnValue %1 : !spv.ptr + } +} Index: mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp =================================================================== --- mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -518,7 +518,8 @@ os << tabs << formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName); if (attr.getAttrDefName() == "SPV_ScopeAttr" || - attr.getAttrDefName() == "SPV_MemorySemanticsAttr") { + attr.getAttrDefName() == "SPV_MemorySemanticsAttr" || + attr.getAttrDefName() == "SPV_MatrixLayoutAttr") { - // These two enums are encoded as <id> to constant values in SPIR-V blob, + // These enums are encoded as <id> to constant values in SPIR-V blob, // but we directly use the constant value as attribute in SPIR-V dialect. So // need to handle them separately from normal enum attributes.
@@ -810,7 +811,8 @@ StringRef words, StringRef wordIndex, raw_ostream &os) { if (attr.getAttrDefName() == "SPV_ScopeAttr" || - attr.getAttrDefName() == "SPV_MemorySemanticsAttr") { + attr.getAttrDefName() == "SPV_MemorySemanticsAttr" || + attr.getAttrDefName() == "SPV_MatrixLayoutAttr") { - // These two enums are encoded as <id> to constant values in SPIR-V blob, + // These enums are encoded as <id> to constant values in SPIR-V blob, // but we directly use the constant value as attribute in SPIR-V dialect. So // need to handle them separately from normal enum attributes.