diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -343,6 +343,7 @@ def SPV_KHR_subgroup_rotate : I32EnumAttrCase<"SPV_KHR_subgroup_rotate", 28>; def SPV_KHR_non_semantic_info : I32EnumAttrCase<"SPV_KHR_non_semantic_info", 29>; def SPV_KHR_terminate_invocation : I32EnumAttrCase<"SPV_KHR_terminate_invocation", 30>; +def SPV_KHR_cooperative_matrix : I32EnumAttrCase<"SPV_KHR_cooperative_matrix", 31>; def SPV_EXT_demote_to_helper_invocation : I32EnumAttrCase<"SPV_EXT_demote_to_helper_invocation", 1000>; def SPV_EXT_descriptor_indexing : I32EnumAttrCase<"SPV_EXT_descriptor_indexing", 1001>; @@ -435,6 +436,7 @@ SPV_KHR_fragment_shader_barycentric, SPV_KHR_ray_cull_mask, SPV_KHR_uniform_group_instructions, SPV_KHR_subgroup_rotate, SPV_KHR_non_semantic_info, SPV_KHR_terminate_invocation, + SPV_KHR_cooperative_matrix, SPV_EXT_demote_to_helper_invocation, SPV_EXT_descriptor_indexing, SPV_EXT_fragment_fully_covered, SPV_EXT_fragment_invocation_density, SPV_EXT_fragment_shader_interlock, SPV_EXT_physical_storage_buffer, @@ -835,6 +837,12 @@ Extension<[SPV_KHR_ray_cull_mask]> ]; } +def SPIRV_C_CooperativeMatrixKHR : I32EnumAttrCase<"CooperativeMatrixKHR", 6022> { + list availability = [ + Extension<[SPV_KHR_cooperative_matrix]>, + MinVersion + ]; +} def SPIRV_C_BitInstructions : I32EnumAttrCase<"BitInstructions", 6025> { list availability = [ Extension<[SPV_KHR_bit_instructions]> @@ -1457,6 +1465,7 @@ SPIRV_C_USMStorageClassesINTEL, SPIRV_C_IOPipesINTEL, SPIRV_C_BlockingPipesINTEL, SPIRV_C_FPGARegINTEL, SPIRV_C_DotProductInputAll, SPIRV_C_DotProductInput4x8BitPacked, SPIRV_C_DotProduct, SPIRV_C_RayCullMaskKHR, + SPIRV_C_CooperativeMatrixKHR, SPIRV_C_BitInstructions, SPIRV_C_AtomicFloat32AddEXT, SPIRV_C_AtomicFloat64AddEXT, SPIRV_C_LongConstantCompositeINTEL, SPIRV_C_OptNoneINTEL, SPIRV_C_AtomicFloat16AddEXT, 
SPIRV_C_DebugInfoModuleINTEL, SPIRV_C_SplitBarrierINTEL, @@ -4069,6 +4078,8 @@ !interleave(widths, "/") # "-bit signless/unsigned integer">; def SPIRV_IsArrayType : CPred<"::llvm::isa<::mlir::spirv::ArrayType>($_self)">; +def SPIRV_IsCooperativeMatrixType : + CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixType>($_self)">; def SPIRV_IsCooperativeMatrixNVType : CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixNVType>($_self)">; def SPIRV_IsImageType : CPred<"::llvm::isa<::mlir::spirv::ImageType>($_self)">; @@ -4100,6 +4111,9 @@ "any SPIR-V pointer type">; def SPIRV_AnyArray : DialectType; +def SPIRV_AnyCooperativeMatrix : DialectType; def SPIRV_AnyCooperativeMatrixNV : DialectType; @@ -4121,17 +4135,23 @@ def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>; def SPIRV_Composite : AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct, - SPIRV_AnyCooperativeMatrixNV, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>; + SPIRV_AnyCooperativeMatrix, SPIRV_AnyCooperativeMatrixNV, + SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>; def SPIRV_Type : AnyTypeOf<[ SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector, SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct, - SPIRV_AnyCooperativeMatrixNV, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix, - SPIRV_AnySampledImage + SPIRV_AnyCooperativeMatrix, SPIRV_AnyCooperativeMatrixNV, + SPIRV_AnyJointMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage ]>; def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>; def SPIRV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>; +class SPIRV_CoopMatrixOfType allowedTypes> : + ContainerType, SPIRV_IsCooperativeMatrixType, + "::llvm::cast<::mlir::spirv::CooperativeMatrixType>($_self).getElementType()", + "Cooperative Matrix">; + class SPIRV_CoopMatrixNVOfType allowedTypes> : ContainerType, SPIRV_IsCooperativeMatrixNVType, "::llvm::cast<::mlir::spirv::CooperativeMatrixNVType>($_self).getElementType()", @@ -4147,10 +4167,12 
@@ class SPIRV_ScalarOrVectorOrCoopMatrixOf : AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>, - SPIRV_CoopMatrixNVOfType<[type]>]>; + SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>; class SPIRV_MatrixOrCoopMatrixOf : - AnyTypeOf<[SPIRV_AnyMatrix, SPIRV_CoopMatrixNVOfType<[type]>]>; + AnyTypeOf<[SPIRV_AnyMatrix, + SPIRV_CoopMatrixOfType<[type]>, + SPIRV_CoopMatrixNVOfType<[type]>]>; def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>; def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>; @@ -4400,6 +4422,8 @@ def SPIRV_OC_OpSDotAccSat : I32EnumAttrCase<"OpSDotAccSat", 4453>; def SPIRV_OC_OpUDotAccSat : I32EnumAttrCase<"OpUDotAccSat", 4454>; def SPIRV_OC_OpSUDotAccSat : I32EnumAttrCase<"OpSUDotAccSat", 4455>; +def SPIRV_OC_OpTypeCooperativeMatrixKHR : I32EnumAttrCase<"OpTypeCooperativeMatrixKHR", 4456>; +def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>; def SPIRV_OC_OpTypeCooperativeMatrixNV : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>; def SPIRV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>; def SPIRV_OC_OpCooperativeMatrixStoreNV : I32EnumAttrCase<"OpCooperativeMatrixStoreNV", 5360>; @@ -4498,7 +4522,8 @@ SPIRV_OC_OpGroupNonUniformUMax, SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot, SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat, - SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixNV, + SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR, + SPIRV_OC_OpCooperativeMatrixLengthKHR, SPIRV_OC_OpTypeCooperativeMatrixNV, SPIRV_OC_OpCooperativeMatrixLoadNV, SPIRV_OC_OpCooperativeMatrixStoreNV, SPIRV_OC_OpCooperativeMatrixMulAddNV, SPIRV_OC_OpCooperativeMatrixLengthNV, SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td @@ -7,12 +7,62 @@ //===----------------------------------------------------------------------===// // // This is the op definition spec of cooperative matrix multiply extension ops. +// We support both cooperative matrix extensions: +// - SPV_NV_cooperative_matrix +// - SPV_KHR_cooperative_matrix // //===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS #define MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS +//===----------------------------------------------------------------------===// +// SPV_KHR_cooperative_matrix extension ops. +//===----------------------------------------------------------------------===// + +// ----- + +def SPIRV_KHRCooperativeMatrixLengthOp : + SPIRV_KhrVendorOp<"CooperativeMatrixLength", [Pure]> { + let summary = "Queries the number of cooperative matrix components"; + + let description = [{ + Number of components of a cooperative matrix type accessible to each + invocation when treated as a composite. + + The type attribute must be a cooperative matrix type.
+ + ``` {.ebnf} + cooperative-matrix-length-op ::= ssa-id `=` `spirv.KHR.CooperativeMatrixLength` + `:` cooperative-matrix-type + ``` + + #### Example: + + ``` + %0 = spirv.KHR.CooperativeMatrixLength : + !spirv.coopmatrix<8x16xi32, 0, Subgroup> + ``` + }]; + + let assemblyFormat = "attr-dict `:` $cooperative_matrix_type"; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_KHR_cooperative_matrix]>, + Capability<[SPIRV_C_CooperativeMatrixKHR]> + ]; + + let arguments = (ins + TypeAttr:$cooperative_matrix_type + ); + + let results = (outs + SPIRV_Int32:$result + ); +} + //===----------------------------------------------------------------------===// // SPV_NV_cooperative_matrix extension ops. //===----------------------------------------------------------------------===// @@ -59,7 +109,6 @@ let results = (outs SPIRV_Int32:$result ); - let hasVerifier = 0; } // ----- diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -20,6 +20,7 @@ #include "mlir/IR/TypeSupport.h" #include "mlir/IR/Types.h" +#include #include namespace mlir { @@ -27,6 +28,7 @@ namespace detail { struct ArrayTypeStorage; +struct CooperativeMatrixTypeStorage; struct CooperativeMatrixNVTypeStorage; struct ImageTypeStorage; struct JointMatrixTypeStorage; @@ -398,6 +400,32 @@ llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo); +// SPIR-V KHR cooperative matrix type +class CooperativeMatrixType + : public Type::TypeBase { +public: + using Base::Base; + + static CooperativeMatrixType get(Type elementType, Scope scope, uint32_t rows, + uint32_t columns, uint32_t use); + Type getElementType() const; + + /// Returns the scope of the matrix. + Scope getScope() const; + /// Returns the number of rows of the matrix.
+ uint32_t getRows() const; + /// Returns the number of columns of the matrix. + uint32_t getColumns() const; + /// Returns the use parameter of the cooperative matrix. + uint32_t getUse() const; + + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + std::optional storage = std::nullopt); + void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, + std::optional storage = std::nullopt); +}; + // SPIR-V NV cooperative matrix type class CooperativeMatrixNVType : public Type::TypeBase +// `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,` +// use `,` scope `>` static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser) { + if (parser.parseLess()) + return {}; + + SmallVector dims; + SMLoc countLoc = parser.getCurrentLocation(); + if (parser.parseDimensionList(dims, /*allowDynamic=*/false)) + return {}; + + if (dims.size() != 2) { + parser.emitError(countLoc, "expected row and column count"); + return {}; + } + + auto elementTy = parseAndVerifyType(dialect, parser); + if (!elementTy) + return {}; + + if (parser.parseComma()) + return {}; + + countLoc = parser.getCurrentLocation(); + uint32_t use = 0; + if (parser.parseInteger(use)) + return {}; + + Scope scope; + if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope ")) + return {}; + + if (parser.parseGreater()) + return {}; + + return CooperativeMatrixType::get(elementTy, scope, dims[0], dims[1], use); +} + +// nv-cooperative-matrix-type ::= +// `!spirv.NV.coopmatrix` `<` element-type ',' scope ',' rows ',' columns>` +static Type parseCooperativeMatrixNVType(SPIRVDialect const &dialect, + DialectAsmParser &parser) { if (parser.parseLess()) return Type(); @@ -785,8 +825,10 @@ if (keyword == "array") return parseArrayType(*this, parser); - if (keyword == "NV.coopmatrix") + if (keyword == "coopmatrix") return parseCooperativeMatrixType(*this, parser); + if (keyword == "NV.coopmatrix") + return 
parseCooperativeMatrixNVType(*this, parser); if (keyword == "jointmatrix") return parseJointMatrixType(*this, parser); if (keyword == "image") @@ -889,6 +931,13 @@ structContext.remove(type.getIdentifier()); } +static void print(CooperativeMatrixType type, DialectAsmPrinter &os) { + os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"; + os << type.getElementType() << ", " << type.getUse() << ", " + << stringifyScope(type.getScope()); + os << ">"; +} + static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) { os << "NV.coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"; os << type.getElementType() << ", " << stringifyScope(type.getScope()); @@ -909,9 +958,10 @@ void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const { TypeSwitch(type) - .Case([&](auto type) { print(type, os); }) + .Case( + [&](auto type) { print(type, os); }) .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); }); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -32,6 +32,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/FormatVariadic.h" #include #include @@ -429,28 +430,30 @@ Type operandType = op->getOperand(0).getType(); Type resultType = op->getResult(0).getType(); - // ODS checks that result type and operand type have the same shape. 
- if (auto vectorType = llvm::dyn_cast(operandType)) { - operandType = vectorType.getElementType(); - resultType = llvm::cast(resultType).getElementType(); - } - - if (auto coopMatrixType = - llvm::dyn_cast(operandType)) { - operandType = coopMatrixType.getElementType(); - resultType = - llvm::cast(resultType).getElementType(); - } - - if (auto jointMatrixType = - llvm::dyn_cast(operandType)) { - operandType = jointMatrixType.getElementType(); - resultType = - llvm::cast(resultType).getElementType(); - } - - auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth(); - auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth(); + // ODS checks that result type and operand type have the same shape. Check + // that composite types match and extract the element types, if any. + using TypePair = std::pair; + auto [operandElemTy, resultElemTy] = + TypeSwitch(operandType) + .Case( + [resultType](auto concreteOperandTy) -> TypePair { + if (auto concreteResultTy = + dyn_cast(resultType)) { + return {concreteOperandTy.getElementType(), + concreteResultTy.getElementType()}; + } + return {}; + }) + .Default([resultType](Type operandType) -> TypePair { + return {operandType, resultType}; + }); + + if (!operandElemTy || !resultElemTy) + return op->emitOpError("incompatible operand and result types"); + + auto operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth(); + auto resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth(); auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth; if (requireSameBitWidth) { @@ -458,7 +461,7 @@ return op->emitOpError( "expected the same bit widths for operand type and result " "type, but provided ") - << operandType << " and " << resultType; + << operandElemTy << " and " << resultElemTy; } return success(); } @@ -467,7 +470,7 @@ return op->emitOpError( "expected the different bit widths for operand type and result " "type, but provided ") - << operandType << " and " << resultType; + << operandElemTy << " and " << 
resultElemTy; } return success(); } @@ -4018,6 +4021,34 @@ return success(); } +//===----------------------------------------------------------------------===// +// spirv.KHR.CooperativeMatrixLength +//===----------------------------------------------------------------------===// + +LogicalResult spirv::KHRCooperativeMatrixLengthOp::verify() { + if (!isa(getCooperativeMatrixType())) { + return emitOpError( + "type attribute must be a '!spirv.coopmatrix' type, found ") + << getCooperativeMatrixType() << " instead"; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.NV.CooperativeMatrixLength +//===----------------------------------------------------------------------===// + +LogicalResult spirv::NVCooperativeMatrixLengthOp::verify() { + if (!isa(getCooperativeMatrixType())) { + return emitOpError( + "type attribute must be a '!spirv.NV.coopmatrix' type, found ") + << getCooperativeMatrixType() << " instead"; + } + + return success(); +} + //===----------------------------------------------------------------------===// // spirv.NV.CooperativeMatrixLoad //===----------------------------------------------------------------------===// @@ -4053,8 +4084,8 @@ printer << " : " << getPointer().getType() << " as " << getType(); } -static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer, - Type coopMatrix) { +static LogicalResult +verifyPointerAndCoopMatrixNVType(Operation *op, Type pointer, Type coopMatrix) { Type pointeeType = llvm::cast(pointer).getPointeeType(); if (!llvm::isa(pointeeType) && !llvm::isa(pointeeType)) @@ -4074,8 +4105,8 @@ } LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() { - return verifyPointerAndCoopMatrixType(*this, getPointer().getType(), - getResult().getType()); + return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(), + getResult().getType()); } //===----------------------------------------------------------------------===// @@ 
-4114,8 +4145,8 @@ } LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() { - return verifyPointerAndCoopMatrixType(*this, getPointer().getType(), - getObject().getType()); + return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(), + getObject().getType()); } //===----------------------------------------------------------------------===// @@ -4123,7 +4154,7 @@ //===----------------------------------------------------------------------===// static LogicalResult -verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) { +verifyCoopMatrixMulAddNV(spirv::NVCooperativeMatrixMulAddOp op) { if (op.getC().getType() != op.getResult().getType()) return op.emitOpError("result and third operand must have the same type"); auto typeA = llvm::cast(op.getA().getType()); @@ -4156,9 +4187,13 @@ } LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() { - return verifyCoopMatrixMulAdd(*this); + return verifyCoopMatrixMulAddNV(*this); } +//===----------------------------------------------------------------------===// +// spirv.INTEL.JointMatrixLoad +//===----------------------------------------------------------------------===// + static LogicalResult verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) { Type pointeeType = llvm::cast(pointer).getPointeeType(); @@ -4179,10 +4214,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// spirv.INTEL.JointMatrixLoad -//===----------------------------------------------------------------------===// - LogicalResult spirv::INTELJointMatrixLoadOp::verify() { return verifyPointerAndJointMatrixType(*this, getPointer().getType(), getResult().getType()); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -18,7 +18,9 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include 
#include +#include using namespace mlir; using namespace mlir::spirv; @@ -93,9 +95,10 @@ bool CompositeType::classof(Type type) { if (auto vectorType = llvm::dyn_cast(type)) return isValid(vectorType); - return llvm::isa(type); + return llvm::isa(type); } bool CompositeType::isValid(VectorType type) { @@ -114,8 +117,8 @@ Type CompositeType::getElementType(unsigned index) const { return TypeSwitch(*this) - .Case( + .Case( [](auto type) { return type.getElementType(); }) .Case([](MatrixType type) { return type.getColumnType(); }) .Case( @@ -133,9 +136,9 @@ return structType.getNumElements(); if (auto vectorType = llvm::dyn_cast(*this)) return vectorType.getNumElements(); - if (llvm::isa(*this)) { + if (llvm::isa(*this)) { llvm_unreachable( - "invalid to query number of elements of spirv::CooperativeMatrix type"); + "invalid to query number of elements of spirv Cooperative Matrix type"); } if (llvm::isa(*this)) { llvm_unreachable( @@ -149,16 +152,16 @@ } bool CompositeType::hasCompileTimeKnownNumElements() const { - return !llvm::isa(*this); + return !llvm::isa(*this); } void CompositeType::getExtensions( SPIRVType::ExtensionArrayRefVector &extensions, std::optional storage) { TypeSwitch(*this) - .Case( + .Case( [&](auto type) { type.getExtensions(extensions, storage); }) .Case([&](VectorType type) { return llvm::cast(type.getElementType()) @@ -171,8 +174,8 @@ SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage) { TypeSwitch(*this) - .Case( + .Case( [&](auto type) { type.getCapabilities(capabilities, storage); }) .Case([&](VectorType type) { auto vecSize = getNumElements(); @@ -202,6 +205,74 @@ return std::nullopt; } +//===----------------------------------------------------------------------===// +// CooperativeMatrixType +//===----------------------------------------------------------------------===// + +struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage { + using KeyTy = std::tuple; + + static CooperativeMatrixTypeStorage * 
+ construct(TypeStorageAllocator &allocator, const KeyTy &key) { + return new (allocator.allocate()) + CooperativeMatrixTypeStorage(key); + } + + bool operator==(const KeyTy &key) const { + return key == KeyTy(elementType, scope, rows, columns, use); + } + + CooperativeMatrixTypeStorage(const KeyTy &key) + : elementType(std::get<0>(key)), scope(std::get<1>(key)), + rows(std::get<2>(key)), columns(std::get<3>(key)), + use(std::get<4>(key)) {} + + Type elementType; + Scope scope; + uint32_t rows; + uint32_t columns; + uint32_t use; +}; + +CooperativeMatrixType CooperativeMatrixType::get(Type elementType, Scope scope, + uint32_t rows, + uint32_t columns, + uint32_t use) { + return Base::get(elementType.getContext(), elementType, scope, rows, columns, + use); +} + +Type CooperativeMatrixType::getElementType() const { + return getImpl()->elementType; +} + +Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; } + +uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; } + +uint32_t CooperativeMatrixType::getColumns() const { + return getImpl()->columns; +} + +uint32_t CooperativeMatrixType::getUse() const { return getImpl()->use; } + +void CooperativeMatrixType::getExtensions( + SPIRVType::ExtensionArrayRefVector &extensions, + std::optional storage) { + llvm::cast(getElementType()).getExtensions(extensions, storage); + static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix}; + extensions.push_back(exts); +} + +void CooperativeMatrixType::getCapabilities( + SPIRVType::CapabilityArrayRefVector &capabilities, + std::optional storage) { + llvm::cast(getElementType()) + .getCapabilities(capabilities, storage); + static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR}; + capabilities.push_back(caps); +} + //===----------------------------------------------------------------------===// // CooperativeMatrixNVType //===----------------------------------------------------------------------===// @@ -1247,7 
+1318,7 @@ //===----------------------------------------------------------------------===// void SPIRVDialect::registerTypes() { - addTypes(); + addTypes(); } diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir @@ -138,6 +138,14 @@ // ----- +func.func @convert_f_to_u.coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, 0, Subgroup>) { + // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : !spirv.coopmatrix<8x16xf32, 0, Subgroup> to !spirv.coopmatrix<8x16xi32, 0, Subgroup> + %0 = spirv.ConvertFToU %arg0 : !spirv.coopmatrix<8x16xf32, 0, Subgroup> to !spirv.coopmatrix<8x16xi32, 0, Subgroup> + spirv.Return +} + +// ----- + func.func @convert_f_to_u_NV.coopmatrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) { // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup> %0 = spirv.ConvertFToU %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup> @@ -222,7 +230,15 @@ // ----- -func.func @f_convert_coop_matrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) { +func.func @f_convert_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, 1, Subgroup>) { + // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.coopmatrix<8x16xf32, 1, Subgroup> to !spirv.coopmatrix<8x16xf64, 1, Subgroup> + %0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, 1, Subgroup> to !spirv.coopmatrix<8x16xf64, 1, Subgroup> + spirv.Return +} + +// ----- + +func.func @f_convert_coop_matrix_nv(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) { // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup> %0 = spirv.FConvert %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup> spirv.Return @@ -238,6 +254,14 @@ // ----- +func.func 
@f_convert_coop_matrix_to_nv_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, 1, Subgroup>) { + // expected-error @+1 {{incompatible operand and result types}} + %0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, 1, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup> + spirv.Return +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.SConvert //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir @@ -1,4 +1,29 @@ -// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect --split-input-file --verify-diagnostics %s | FileCheck %s + +//===----------------------------------------------------------------------===// +// CooperativeMatrix (KHR) +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @cooperative_matrix_length +spirv.func @cooperative_matrix_length() -> i32 "None" { + // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, 0, Subgroup> + %0 = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, 0, Subgroup> + spirv.ReturnValue %0 : i32 +} + +// ----- + +spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" { + // expected-error @+1 {{'spirv.KHR.CooperativeMatrixLength' op type attribute must be a '!spirv.coopmatrix'}} + %0 = spirv.KHR.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup> + spirv.ReturnValue %0 : i32 +} + +// ----- + +//===----------------------------------------------------------------------===// +// NV.CooperativeMatrix +//===----------------------------------------------------------------------===// // CHECK-LABEL: 
@cooperative_matrix_load spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { @@ -7,7 +32,6 @@ spirv.Return } -// ----- // CHECK-LABEL: @cooperative_matrix_load_memaccess spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> @@ -164,3 +188,11 @@ %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } + +// ----- + +spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" { + // expected-error @+1 {{'spirv.NV.CooperativeMatrixLength' op type attribute must be a '!spirv.NV.coopmatrix'}} + %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, 0, Subgroup> + spirv.ReturnValue %0 : i32 +} diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir --- a/mlir/test/Dialect/SPIRV/IR/types.mlir +++ b/mlir/test/Dialect/SPIRV/IR/types.mlir @@ -436,11 +436,50 @@ // ----- //===----------------------------------------------------------------------===// -// CooperativeMatrix +// CooperativeMatrix (KHR) //===----------------------------------------------------------------------===// -// CHECK: func private @coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>) -func.func private @coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>) -> () +// CHECK: func private @coop_matrix_type(!spirv.coopmatrix<8x16xi32, 0, Subgroup>, !spirv.coopmatrix<8x8xf32, 1, Workgroup>) +func.func private @coop_matrix_type(!spirv.coopmatrix<8x16xi32, 0, Subgroup>, !spirv.coopmatrix<8x8xf32, 1, Workgroup>) -> () + +// ----- + +// expected-error @+1 {{expected ','}} +func.func private @missing_scope(!spirv.coopmatrix<8x16xi32, 2>) -> () + +// ----- + +// expected-error 
@+1 {{expected valid keyword}} +func.func private @missing_scope2(!spirv.coopmatrix<8x8xi32, 3, >) -> () + +// ----- + +// expected-error @+1 {{expected row and column count}} +func.func private @missing_count(!spirv.coopmatrix<8xi32, Subgroup>) -> () + +// ----- + +// expected-error @+1 {{expected row and column count}} +func.func private @too_many_dims(!spirv.coopmatrix<8x16x32xi32, 2, Subgroup>) -> () + +// ----- + +// expected-error @+1 {{expected integer value}} +func.func private @missing_use(!spirv.coopmatrix<8x8xi32, >) -> () + +// ----- + +// expected-error @+1 {{expected integer value}} +func.func private @use_not_integer(!spirv.coopmatrix<8x8xi32, Subgroup, Subgroup>) -> () + +// ----- + +//===----------------------------------------------------------------------===// +// NV.CooperativeMatrix +//===----------------------------------------------------------------------===// + +// CHECK: func private @nv_coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>) +func.func private @nv_coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>) -> () // ----- diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -857,7 +857,7 @@ << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " "opBuilder.getI32IntegerAttr({2}[{3}++])));\n", attrList, attrName, words, wordIndex); - } else if (attr.isEnumAttr() || attr.getAttrDefName() == "TypeAttr") { + } else if (attr.isEnumAttr() || attr.isTypeAttr()) { os << tabs << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " "TypeAttr::get(getType({2}[{3}++]))));\n", @@ -866,7 +866,7 @@ PrintFatalError( loc, llvm::Twine( "unhandled attribute type in deserialization generation : '") + - attr.getAttrDefName() + llvm::Twine("'")); + attrName + llvm::Twine("'")); } }