Index: mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td =================================================================== --- mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -2993,6 +2993,8 @@ def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">; def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">; +def SPV_IsCooperativeMatrixType : + CPred<"$_self.isa<::mlir::spirv::CooperativeMatrixNVType>()">; def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">; def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">; @@ -3012,6 +3014,9 @@ "any SPIR-V pointer type">; def SPV_AnyArray : DialectType; +def SPV_AnyCooperativeMatrix : DialectType; def SPV_AnyRTArray : DialectType; def SPV_AnyStruct : DialectType; def SPV_OC_OpGroupNonUniformFMax : I32EnumAttrCase<"OpGroupNonUniformFMax", 358>; def SPV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>; +def SPV_OC_OpTypeCooperativeMatrix : I32EnumAttrCase<"OpTypeCooperativeMatrix", 5358>; +def SPV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>; def SPV_OpcodeAttr : SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [ @@ -3271,7 +3278,8 @@ SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin, SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin, SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax, - SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR + SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR, + SPV_OC_OpTypeCooperativeMatrix, SPV_OC_OpCooperativeMatrixLoadNV ]>; // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY! 
Index: mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td @@ -0,0 +1,72 @@ +//===- SPIRVCooperativeMatrixOps.td - cooperative matmul ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the op definition spec of cooperative matrix multiply extension ops. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRV_COOPERATIVE_MATRIX_OPS +#define SPIRV_COOPERATIVE_MATRIX_OPS + +// ----- + +def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> { + let summary = "See extension SPV_NV_cooperative_matrix"; + + let description = [{ + Load a cooperative matrix through a pointer. + + Stride is the number of elements in the array in memory between the first + component of consecutive rows (or columns) in the result. It must be a + scalar integer type. + + columnMajor must be a Boolean type. + + Result Type is the type of the loaded object. It must be a cooperative + matrix type. + + ### Custom assembly form + + ``` {.ebnf} + cooperative-matrix-op ::= ssa-id `=` `spv.CooperativeMatrixLoadNV` + storage-class ssa-use (`[` memory-access `]`)? 
` + : ` cooperative-matrix-type + ``` + + For example: + + ``` + %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %colMajor + : !spv.coopmatrix + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_KHR_vulkan_memory_model, SPV_EXT_physical_storage_buffer]>, + Capability<[SPV_C_CooperativeMatrixNV]> + ]; + + let arguments = (ins + SPV_AnyPtr:$pointer, + SPV_Integer:$stride, + SPV_Bool:$columnMajor, + OptionalAttr:$memory_access + ); + + let results = (outs + SPV_AnyCooperativeMatrix:$result + ); + + let verifier = [{ return success(); }]; +} + +// ----- + +#endif // SPIRV_COOPERATIVE_MATRIX_OPS Index: mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td =================================================================== --- mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -28,6 +28,7 @@ include "mlir/Dialect/SPIRV/SPIRVCastOps.td" include "mlir/Dialect/SPIRV/SPIRVCompositeOps.td" include "mlir/Dialect/SPIRV/SPIRVControlFlowOps.td" +include "mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td" include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td" include "mlir/Dialect/SPIRV/SPIRVGroupOps.td" include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td" Index: mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h =================================================================== --- mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -54,6 +54,7 @@ namespace detail { struct ArrayTypeStorage; +struct CooperativeMatrixTypeStorage; struct ImageTypeStorage; struct PointerTypeStorage; struct RuntimeArrayTypeStorage; @@ -63,6 +64,7 @@ namespace TypeKind { enum Kind { Array = Type::FIRST_SPIRV_TYPE, + CooperativeMatrix, Image, Pointer, RuntimeArray, @@ -330,6 +332,34 @@ Optional storage = llvm::None); }; +// SPIR-V cooperative matrix type +class CooperativeMatrixNVType + : public Type::TypeBase { +public: + using Base::Base; + + static bool kindof(unsigned kind) { + return kind == 
TypeKind::CooperativeMatrix; + } + + static CooperativeMatrixNVType get(Type elementType, spirv::Scope scope, + unsigned rows, unsigned columns); + Type getElementType() const; + + /// Return the scope of the cooperative matrix. + spirv::Scope getScope() const; + /// Return the number of rows of the matrix. + unsigned getRows() const; + /// Return the number of columns of the matrix. + + unsigned getColumns() const; + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage = llvm::None); + void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage = llvm::None); +}; + } // end namespace spirv } // end namespace mlir Index: mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp =================================================================== --- mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -115,7 +115,8 @@ SPIRVDialect::SPIRVDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { - addTypes(); + addTypes(); addAttributes(); @@ -264,6 +265,41 @@ return ArrayType::get(elementType, count, stride); } +// cooperative-matrix-type ::= `!spv.coopmatrix` `<` element-type ',' scope ',' +// rows ',' columns `>` +static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, + DialectAsmParser &parser) { + if (parser.parseLess()) + return Type(); + + auto elementTy = parseAndVerifyType(dialect, parser); + if (!elementTy) + return Type(); + + StringRef scopeClass; + llvm::SMLoc scopeLoc = parser.getCurrentLocation(); + if (parser.parseComma() || parser.parseKeyword(&scopeClass)) + return Type(); + + auto scope = symbolizeScope(scopeClass); + if (!scope) { + parser.emitError(scopeLoc, "unknown scope class ") << scopeClass; + return Type(); + } + unsigned rows, columns; + if (parser.parseComma() || parser.parseInteger(rows) || parser.parseComma() || + parser.parseInteger(columns)) { + parser.emitError(parser.getCurrentLocation(), + "missing number of rows or columns"); + 
return Type(); + } + + if (parser.parseGreater()) + return Type(); + return CooperativeMatrixNVType::get(elementTy, scope.getValue(), rows, + columns); +} + // TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type // methods in alphabetical order // @@ -525,6 +561,8 @@ if (keyword == "array") return parseArrayType(*this, parser); + if (keyword == "coopmatrix") + return parseCooperativeMatrixType(*this, parser); if (keyword == "image") return parseImageType(*this, parser); if (keyword == "ptr") @@ -595,11 +633,20 @@ os << ">"; } +static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) { + os << "coopmatrix<" << type.getElementType() << ", "; + os << stringifyScope(type.getScope()) << ", " << type.getRows() << ", "; + os << type.getColumns() << ">"; +} + void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const { switch (type.getKind()) { case TypeKind::Array: print(type.cast(), os); return; + case TypeKind::CooperativeMatrix: + print(type.cast(), os); + return; case TypeKind::Pointer: print(type.cast(), os); return; Index: mlir/lib/Dialect/SPIRV/SPIRVOps.cpp =================================================================== --- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -2637,6 +2637,49 @@ return success(); } +//===----------------------------------------------------------------------===// +// spv.CooperativeMatrix +//===----------------------------------------------------------------------===// + +static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser, + OperationState &state) { + spirv::StorageClass storageClass; + SmallVector operandInfo; + Type strideType = parser.getBuilder().getIntegerType(32); + Type columnMajorType = parser.getBuilder().getIntegerType(1); + Type elementType; + if (parseEnumStrAttr(storageClass, parser) || + parser.parseOperandList(operandInfo, 3) || + parseMemoryAccessAttributes(parser, state) || parser.parseColon() || + 
parser.parseType(elementType)) { + return failure(); + } + + auto ptrType = spirv::PointerType::get( + elementType.cast().getElementType(), + storageClass); + SmallVector OperandType = {ptrType, strideType, columnMajorType}; + if (parser.resolveOperands(operandInfo, OperandType, parser.getNameLoc(), + state.operands)) { + return failure(); + } + + state.addTypes(elementType); + return success(); +} + +static void print(spirv::CooperativeMatrixLoadNVOp M, OpAsmPrinter &printer) { + StringRef sc = stringifyStorageClass( + M.pointer().getType().cast().getStorageClass()); + printer << spirv::CooperativeMatrixLoadNVOp::getOperationName() << " \"" << sc + << "\" " << M.pointer() << ", " << M.stride() << ", " + << M.columnMajor(); + // Print optional memory access attribute. + if (auto memAccess = M.memory_access()) + printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; + printer << " : " << M.getType(); +} + namespace mlir { namespace spirv { Index: mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp =================================================================== --- mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -158,6 +158,7 @@ bool CompositeType::classof(Type type) { switch (type.getKind()) { case TypeKind::Array: + case TypeKind::CooperativeMatrix: case TypeKind::RuntimeArray: case TypeKind::Struct: return true; @@ -177,6 +178,8 @@ switch (getKind()) { case spirv::TypeKind::Array: return cast().getElementType(); + case spirv::TypeKind::CooperativeMatrix: + return cast().getElementType(); case spirv::TypeKind::RuntimeArray: return cast().getElementType(); case spirv::TypeKind::Struct: @@ -192,6 +195,9 @@ switch (getKind()) { case spirv::TypeKind::Array: return cast().getNumElements(); + case spirv::TypeKind::CooperativeMatrix: + return cast().getRows() * + cast().getColumns(); case spirv::TypeKind::RuntimeArray: llvm_unreachable( "invalid to query number of elements of spirv::RuntimeArray type"); @@ -211,6 +217,9 @@ case 
spirv::TypeKind::Array: cast().getExtensions(extensions, storage); break; + case spirv::TypeKind::CooperativeMatrix: + cast().getExtensions(extensions, storage); + break; case spirv::TypeKind::RuntimeArray: cast().getExtensions(extensions, storage); break; @@ -233,6 +242,9 @@ case spirv::TypeKind::Array: cast().getCapabilities(capabilities, storage); break; + case spirv::TypeKind::CooperativeMatrix: + cast().getCapabilities(capabilities, storage); + break; case spirv::TypeKind::RuntimeArray: cast().getCapabilities(capabilities, storage); break; @@ -248,6 +260,67 @@ } } +//===----------------------------------------------------------------------===// +// CooperativeMatrixType +//===----------------------------------------------------------------------===// + +struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage { + using KeyTy = std::tuple; + + static CooperativeMatrixTypeStorage * + construct(TypeStorageAllocator &allocator, const KeyTy &key) { + return new (allocator.allocate()) + CooperativeMatrixTypeStorage(key); + } + + bool operator==(const KeyTy &key) const { + return key == KeyTy(elementType, getScope(), rows, columns); + } + + CooperativeMatrixTypeStorage(const KeyTy &key) + : TypeStorage(static_cast(std::get<1>(key))), + elementType(std::get<0>(key)), rows(std::get<2>(key)), + columns(std::get<3>(key)) {} + + Scope getScope() const { return static_cast(getSubclassData()); } + Type elementType; + unsigned rows; + unsigned columns; +}; + +CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType, + Scope scope, unsigned rows, + unsigned columns) { + return Base::get(elementType.getContext(), TypeKind::CooperativeMatrix, + elementType, scope, rows, columns); +} + +Type CooperativeMatrixNVType::getElementType() const { + return getImpl()->elementType; +} + +Scope CooperativeMatrixNVType::getScope() const { + return getImpl()->getScope(); +} + +unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; } + 
+unsigned CooperativeMatrixNVType::getColumns() const { + return getImpl()->columns; +} + +void CooperativeMatrixNVType::getExtensions( + SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage) { + getElementType().cast().getExtensions(extensions, storage); +} + +void CooperativeMatrixNVType::getCapabilities( + SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage) { + getElementType().cast().getCapabilities(capabilities, storage); +} + //===----------------------------------------------------------------------===// // ImageType //===----------------------------------------------------------------------===// Index: mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp =================================================================== --- mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -215,6 +215,8 @@ LogicalResult processArrayType(ArrayRef operands); + LogicalResult processCooperativeMatrixType(ArrayRef operands); + LogicalResult processFunctionType(ArrayRef operands); LogicalResult processRuntimeArrayType(ArrayRef operands); @@ -1158,6 +1160,8 @@ } break; case spirv::Opcode::OpTypeArray: return processArrayType(operands); + case spirv::Opcode::OpTypeCooperativeMatrix: + return processCooperativeMatrixType(operands); case spirv::Opcode::OpTypeFunction: return processFunctionType(operands); case spirv::Opcode::OpTypeRuntimeArray: @@ -1227,6 +1231,35 @@ return success(); } +LogicalResult +Deserializer::processCooperativeMatrixType(ArrayRef operands) { + if (operands.size() != 5) { + return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element " + "type and row x column parameters"); + } + + Type elementTy = getType(operands[1]); + if (!elementTy) { + return emitError(unknownLoc, + "OpTypeCooperativeMatrix references undefined ") + << operands[1]; + } + + auto scope = spirv::symbolizeScope(operands[2]); + if (!scope) { + return emitError(unknownLoc, + 
"OpTypeCooperativeMatrix references undefined scope ") + << operands[2]; + } + + unsigned rows = operands[3]; + unsigned columns = operands[4]; + + typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get( + elementTy, scope.getValue(), rows, columns); + return success(); +} + LogicalResult Deserializer::processRuntimeArrayType(ArrayRef operands) { if (operands.size() != 2) { @@ -2198,6 +2231,7 @@ case spirv::Opcode::OpTypeRuntimeArray: case spirv::Opcode::OpTypeStruct: case spirv::Opcode::OpTypePointer: + case spirv::Opcode::OpTypeCooperativeMatrix: return processType(opcode, operands); case spirv::Opcode::OpConstant: return processConstant(operands, /*isSpec=*/false); Index: mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp =================================================================== --- mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -1090,6 +1090,21 @@ return success(); } + if (auto cooperativeMatrixType = + type.dyn_cast()) { + uint32_t elementTypeID = 0; + if (failed(processType(loc, cooperativeMatrixType.getElementType(), + elementTypeID))) { + return failure(); + } + typeEnum = spirv::Opcode::OpTypeCooperativeMatrix; + operands.push_back(elementTypeID); + operands.push_back(static_cast(cooperativeMatrixType.getScope())); + operands.push_back(cooperativeMatrixType.getRows()); + operands.push_back(cooperativeMatrixType.getColumns()); + return success(); + } + // TODO(ravishankarm) : Handle other types. 
return emitError(loc, "unhandled type in serialization: ") << type; } Index: mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s + +spv.module Logical GLSL450 requires #spv.vce { + // CHECK-LABEL: @cooperative_matrix_load + spv.func @cooperative_matrix_load(%ptr : !spv.ptr, %stride : i32, %b : i1) "None" { + // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix + %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix + spv.Return + } + + // CHECK-LABEL: @cooperative_matrix_load_memaccess + spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr, %stride : i32, %b : i1) "None" { + // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix + %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix + spv.Return + } +}