diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4061,6 +4061,17 @@
       SPIRV_KHR_CMU_MatrixA, SPIRV_KHR_CMU_MatrixB, SPIRV_KHR_CMU_MatrixAcc
     ]>;
 
+// Cooperative Matrix Layout for the SPV_KHR_cooperative_matrix extension.
+def SPIRV_KHR_CML_RowMajor    : I32EnumAttrCase<"RowMajor", 0>;
+def SPIRV_KHR_CML_ColumnMajor : I32EnumAttrCase<"ColumnMajor", 1>;
+
+def SPIRV_KHR_CooperativeMatrixLayoutAttr :
+    SPIRV_I32EnumAttr<"CooperativeMatrixLayoutKHR",
+                      "valid SPIR-V Cooperative Matrix Layout (KHR)",
+                      "coop_matrix_layout_khr", [
+      SPIRV_KHR_CML_RowMajor, SPIRV_KHR_CML_ColumnMajor
+    ]>;
+
 //===----------------------------------------------------------------------===//
 // SPIR-V attribute definitions
 //===----------------------------------------------------------------------===//
@@ -4435,6 +4446,7 @@
 def SPIRV_OC_OpUDotAccSat                 : I32EnumAttrCase<"OpUDotAccSat", 4454>;
 def SPIRV_OC_OpSUDotAccSat                : I32EnumAttrCase<"OpSUDotAccSat", 4455>;
 def SPIRV_OC_OpTypeCooperativeMatrixKHR   : I32EnumAttrCase<"OpTypeCooperativeMatrixKHR", 4456>;
+def SPIRV_OC_OpCooperativeMatrixLoadKHR   : I32EnumAttrCase<"OpCooperativeMatrixLoadKHR", 4457>;
 def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
 def SPIRV_OC_OpTypeCooperativeMatrixNV    : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
 def SPIRV_OC_OpCooperativeMatrixLoadNV    : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
@@ -4534,7 +4546,8 @@
       SPIRV_OC_OpGroupNonUniformUMax, SPIRV_OC_OpGroupNonUniformFMax,
       SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
       SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
-      SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpCooperativeMatrixLengthKHR,
+      SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
+      SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
       SPIRV_OC_OpTypeCooperativeMatrixNV, SPIRV_OC_OpCooperativeMatrixLoadNV,
       SPIRV_OC_OpCooperativeMatrixStoreNV, SPIRV_OC_OpCooperativeMatrixMulAddNV,
       SPIRV_OC_OpCooperativeMatrixLengthNV,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -63,6 +63,77 @@
   );
 }
 
+// -----
+
+def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad", []> {
+  let summary = "Loads a cooperative matrix through a pointer";
+
+  let description = [{
+    Load a cooperative matrix through a pointer.
+
+    Result Type is the type of the loaded object. It must be a cooperative
+    matrix type.
+
+    Pointer is a pointer. Its type must be an OpTypePointer whose Type operand is
+    a scalar or vector type. If the Shader capability was declared, Pointer must
+    point into an array and any ArrayStride decoration on Pointer is ignored.
+
+    MemoryLayout specifies how matrix elements are laid out in memory. It must
+    come from a 32-bit integer constant instruction whose value corresponds to a
+    Cooperative Matrix Layout. See the Cooperative Matrix Layout table for a
+    description of the layouts and detailed layout-specific rules.
+
+    Stride further qualifies how matrix elements are laid out in memory. It must
+    be a scalar integer type and its exact semantics depend on MemoryLayout.
+
+    Memory Operand must be a Memory Operand literal. If not present, it is the
+    same as specifying None.
+
+    NOTE: In earlier versions of the SPIR-V spec, 'Memory Operand' was known
+    as 'Memory Access'.
+
+    For a given dynamic instance of this instruction, all operands of this
+    instruction must be the same for all invocations in a given scope instance
+    (where the scope is the scope the cooperative matrix type was created with).
+    All invocations in a given scope instance must be active or all must be
+    inactive.
+
+    ``` {.ebnf}
+    cooperative-matrix-load-op ::= ssa-id `=` `spirv.KHR.CooperativeMatrixLoad`
+                              ssa-use `,` ssa-use `,`
+                              cooperative-matrix-layout
+                              (`[` memory-operand `]`)? ` : `
+                              pointer-type `as` cooperative-matrix-type
+    ```
+
+    #### Example:
+
+    ```
+    %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor
+      : !spirv.ptr<i32, StorageBuffer>
+          as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_6>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_KHR_cooperative_matrix]>,
+    Capability<[SPIRV_C_CooperativeMatrixKHR]>
+  ];
+
+  let arguments = (ins
+    SPIRV_AnyPtr:$pointer,
+    SPIRV_Integer:$stride,
+    SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
+    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
+  );
+
+  let results = (outs
+    SPIRV_AnyCooperativeMatrix:$result
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // SPV_NV_cooperative_matrix extension ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1445,13 +1445,13 @@
   std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult>
   resolveOperands(Operands &&operands, Types &&types, SMLoc loc,
                   SmallVectorImpl<Value> &result) {
-    size_t operandSize = std::distance(operands.begin(), operands.end());
-    size_t typeSize = std::distance(types.begin(), types.end());
+    size_t operandSize = llvm::range_size(operands);
+    size_t typeSize = llvm::range_size(types);
     if (operandSize != typeSize)
       return emitError(loc)
              << operandSize << " operands present, but expected " << typeSize;
 
-    for (auto [operand, type] : llvm::zip(operands, types))
+    for (auto [operand, type] : llvm::zip_equal(operands, types))
       if (resolveOperand(operand, type, result))
         return failure();
     return success();
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -36,6 +36,7 @@
 #include "llvm/Support/FormatVariadic.h"
 #include <cassert>
 #include <numeric>
+#include <type_traits>
 
 using namespace mlir;
 
@@ -53,7 +54,9 @@
 constexpr char kIndicesAttrName[] = "indices";
 constexpr char kInitializerAttrName[] = "initializer";
 constexpr char kInterfaceAttrName[] = "interface";
+constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout";
 constexpr char kMemoryAccessAttrName[] = "memory_access";
+constexpr char kMemoryOperandAttrName[] = "memory_operand";
 constexpr char kMemoryScopeAttrName[] = "memory_scope";
 constexpr char kPackedVectorFormatAttrName[] = "format";
 constexpr char kSemanticsAttrName[] = "semantics";
@@ -176,6 +179,7 @@
 static ParseResult
 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
                  StringRef attrName = spirv::attributeName<EnumClass>()) {
+  static_assert(std::is_enum_v<EnumClass>);
   Attribute attrVal;
   NamedAttrList attr;
   auto loc = parser.getCurrentLocation();
@@ -202,6 +206,7 @@
 static ParseResult
 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
                  StringRef attrName = spirv::attributeName<EnumClass>()) {
+  static_assert(std::is_enum_v<EnumClass>);
   if (parseEnumStrAttr(value, parser))
     return failure();
   state.addAttribute(attrName,
@@ -218,6 +223,7 @@
 parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
                      OperationState &state,
                      StringRef attrName = spirv::attributeName<EnumClass>()) {
+  static_assert(std::is_enum_v<EnumClass>);
   if (parseEnumKeywordAttr(value, parser))
     return failure();
   state.addAttribute(attrName,
@@ -246,14 +252,15 @@
   return success();
 }
 
-/// Parses optional memory access attributes attached to a memory access
-/// operand/pointer. Specifically, parses the following syntax:
+/// Parses optional memory access (a.k.a. memory operand) attributes attached to
+/// a memory access operand/pointer. Specifically, parses the following syntax:
 ///     (`[` memory-access `]`)?
 /// where:
 ///     memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
 ///         integer-literal | `"NonTemporal"`
-static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
-                                               OperationState &state) {
+static ParseResult
+parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state,
+                            StringRef attrName = kMemoryAccessAttrName) {
   // Parse an optional list of attributes staring with '['
   if (parser.parseOptionalLSquare()) {
     // Nothing to do
@@ -262,7 +269,7 @@
 
   spirv::MemoryAccess memoryAccessAttr;
   if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
-                                                kMemoryAccessAttrName))
+                                                attrName))
     return failure();
 
   if (spirv::bitEnumContainsAll(memoryAccessAttr,
@@ -4035,6 +4042,84 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixLoad
+//===----------------------------------------------------------------------===//
+
+/// Parses the custom assembly form of spirv.KHR.CooperativeMatrixLoad:
+///     ssa-use `,` ssa-use `,` layout (`[` memory-operand `]`)? `:`
+///     pointer-type `as` cooperative-matrix-type
+/// The form carries no stride type, so the stride operand is always resolved
+/// as `i32`.
+ParseResult spirv::KHRCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
+                                                     OperationState &result) {
+  std::array<OpAsmParser::UnresolvedOperand, 2> operandInfo = {};
+  if (parser.parseOperand(operandInfo[0]) || parser.parseComma())
+    return failure();
+  if (parser.parseOperand(operandInfo[1]) || parser.parseComma())
+    return failure();
+
+  spirv::CooperativeMatrixLayoutKHR layout;
+  if (::parseEnumKeywordAttr<spirv::CooperativeMatrixLayoutKHRAttr>(
+          layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) {
+    return failure();
+  }
+
+  if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
+    return failure();
+
+  Type ptrType;
+  Type elementType;
+  if (parser.parseColon() || parser.parseType(ptrType) ||
+      parser.parseKeywordType("as", elementType)) {
+    return failure();
+  }
+  result.addTypes(elementType);
+
+  Type strideType = parser.getBuilder().getIntegerType(32);
+  if (parser.resolveOperands(operandInfo, {ptrType, strideType},
+                             parser.getNameLoc(), result.operands)) {
+    return failure();
+  }
+
+  return success();
+}
+
+/// Prints the custom assembly form accepted by `parse`: the operands, the
+/// layout, the optional memory operand, and the pointer and result types.
+void spirv::KHRCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
+  printer << " " << getPointer() << ", " << getStride() << ", "
+          << getMatrixLayout();
+  // Print optional memory operand attribute.
+  if (auto memOperand = getMemoryOperand())
+    printer << " [\"" << memOperand << "\"]";
+  printer << " : " << getPointer().getType() << " as " << getType();
+}
+
+/// Emits an error on `op` unless `pointer` points to a scalar or vector type.
+/// `coopMatrix` is accepted but not inspected yet.
+static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
+                                                    Type coopMatrix) {
+  auto pointerType = cast<spirv::PointerType>(pointer);
+  Type pointeeType = pointerType.getPointeeType();
+  if (!isa<spirv::ScalarType, VectorType>(pointeeType)) {
+    return op->emitError(
+               "Pointer must point to a scalar or vector type but provided ")
+           << pointeeType;
+  }
+
+  // TODO: Verify the memory object behind the pointer:
+  // > If the Shader capability was declared, Pointer must point into an array
+  // > and any ArrayStride decoration on Pointer is ignored.
+
+  return success();
+}
+
+LogicalResult spirv::KHRCooperativeMatrixLoadOp::verify() {
+  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
+                                        getResult().getType());
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.NV.CooperativeMatrixLength
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect --split-input-file --verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // CooperativeMatrix (KHR)
@@ -21,6 +21,80 @@
 
 // -----
 
+// CHECK-LABEL: @cooperative_matrix_load
+spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor :
+  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor :
+    !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_load_memoperand
+spirv.func @cooperative_matrix_load_memoperand(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, ColumnMajor ["Volatile"] :
+  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor ["Volatile"] :
+    !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_load_vector_ptr_type
+spirv.func @cooperative_matrix_load_vector_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32) "None" {
+  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor ["Volatile"] :
+  // CHECK-SAME:   !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor ["Volatile"] :
+    !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_load_function
+spirv.func @cooperative_matrix_load_function(%ptr : !spirv.ptr<i32, Function>, %stride : i32) "None" {
+  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor :
+  // CHECK-SAME:   !spirv.ptr<i32, Function> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor :
+    !spirv.ptr<i32, Function> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32) "None" {
+  // expected-error @+1 {{Pointer must point to a scalar or vector type}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor :
+    !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // expected-error @+1 {{expected ','}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride :
+    !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // expected-error @+1 {{expected valid keyword}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, :
+    !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // expected-error @+1 {{op result #0 must be any SPIR-V cooperative matrix type}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor :
+    !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  spirv.Return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // NV.CooperativeMatrix
 //===----------------------------------------------------------------------===//