[mlir][spirv] Add the KHR cooperative matrix store op.

This patch defines OpCooperativeMatrixStoreKHR (opcode 4458) in the SPIR-V
opcode table and registers it in the opcode enum list. It also adds the
spirv.KHR.CooperativeMatrixStore op definition with a custom
parser/printer/verifier, and round-trip plus negative tests.

NOTE(review): this hunk defines opcodes 4457, 4458 and 4460 but not 4459.
Presumably OpCooperativeMatrixMulAddKHR lands in a separate change; confirm.

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4447,6 +4447,7 @@ def SPIRV_OC_OpSUDotAccSat : I32EnumAttrCase<"OpSUDotAccSat", 4455>; def SPIRV_OC_OpTypeCooperativeMatrixKHR : I32EnumAttrCase<"OpTypeCooperativeMatrixKHR", 4456>; def SPIRV_OC_OpCooperativeMatrixLoadKHR : I32EnumAttrCase<"OpCooperativeMatrixLoadKHR", 4457>; +def SPIRV_OC_OpCooperativeMatrixStoreKHR : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>; def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>; def SPIRV_OC_OpTypeCooperativeMatrixNV : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>; def SPIRV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>; @@ -4546,11 +4547,12 @@ SPIRV_OC_OpGroupNonUniformUMax, SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot, SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat, - SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR, - SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR, - SPIRV_OC_OpTypeCooperativeMatrixNV, - SPIRV_OC_OpCooperativeMatrixLoadNV, SPIRV_OC_OpCooperativeMatrixStoreNV, - SPIRV_OC_OpCooperativeMatrixMulAddNV, SPIRV_OC_OpCooperativeMatrixLengthNV, + SPIRV_OC_OpSUDotAccSat, + SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR, + SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR, + SPIRV_OC_OpTypeCooperativeMatrixNV, SPIRV_OC_OpCooperativeMatrixLoadNV, + SPIRV_OC_OpCooperativeMatrixStoreNV, SPIRV_OC_OpCooperativeMatrixMulAddNV, + SPIRV_OC_OpCooperativeMatrixLengthNV, SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL, SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR, diff --git
a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td @@ -134,6 +134,75 @@ ); } +// ----- + +def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStore", []> { + let summary = "Stores a cooperative matrix through a pointer"; + + let description = [{ + Store a cooperative matrix through a pointer. + Pointer is a pointer. Its type must be an OpTypePointer whose Type operand + is a scalar or vector type. If the Shader capability was declared, Pointer + must point into an array and any ArrayStride decoration on Pointer is + ignored. + + Object is the object to store. Its type must be an + OpTypeCooperativeMatrixKHR. + + MemoryLayout specifies how matrix elements are laid out in memory. It must + come from a 32-bit integer constant instruction whose value corresponds to a + Cooperative Matrix Layout. See the Cooperative Matrix Layout table for a + description of the layouts and detailed layout-specific rules. + + Stride further qualifies how matrix elements are laid out in memory. It must + be a scalar integer type and its exact semantics depend on MemoryLayout. + + Memory Operand must be a Memory Operand literal. If not present, it is the + same as specifying None. + + NOTE: In earlier versions of the SPIR-V spec, 'Memory Operand' was known + as 'Memory Access'. + + For a given dynamic instance of this instruction, all operands of this + instruction must be the same for all invocations in a given scope instance + (where the scope is the scope the cooperative matrix type was created with). + All invocations in a given scope instance must be active or all must be + inactive.
+ + ``` {.ebnf} + coop-matrix-store-op ::= `spirv.KHR.CooperativeMatrixStore ` + ssa-use `, ` ssa-use `, ` + ssa-use `, ` cooperative-matrix-layout + (`[` memory-operand `]`)? `:` + pointer-type `,` coop-matrix-type + ``` + + #### Example: + + ``` + spirv.KHR.CooperativeMatrixStore %ptr, %obj, %stride, RowMajor : + !spirv.ptr, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA> + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_KHR_cooperative_matrix]>, + Capability<[SPIRV_C_CooperativeMatrixKHR]> + ]; + + let arguments = (ins + SPIRV_AnyPtr:$pointer, + SPIRV_AnyCooperativeMatrix:$object, + SPIRV_Integer:$stride, + SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout, + OptionalAttr:$memory_operand + ); + + let results = (outs); +} + //===----------------------------------------------------------------------===// // SPV_NV_cooperative_matrix extension ops. //===----------------------------------------------------------------------===// @@ -364,7 +433,7 @@ ssa-use `, ` ssa-use `, ` ssa-use `, ` ssa-use `, ` (`[` memory-access `]`)?
`:` - pointer-type `,` spirv-element-type + pointer-type `,` coop-matrix-type ``` For example: diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -4111,6 +4111,73 @@ getResult().getType()); } +//===----------------------------------------------------------------------===// +// spirv.KHR.CooperativeMatrixStore +//===----------------------------------------------------------------------===// + +// Parses the custom assembly form: +// %ptr, %object, %stride, matrix-layout (`[` memory-operand `]`)? `:` +// pointer-type `,` coop-matrix-type (`,` stride-type)? +// The trailing stride type may be omitted, in which case it defaults to i32. +ParseResult spirv::KHRCooperativeMatrixStoreOp::parse(OpAsmParser &parser, + OperationState &result) { + std::array<OpAsmParser::UnresolvedOperand, 3> operandInfo = {}; + for (auto &op : operandInfo) { + if (parser.parseOperand(op) || parser.parseComma()) + return failure(); + } + + spirv::CooperativeMatrixLayoutKHR layout; + if (::parseEnumKeywordAttr<spirv::CooperativeMatrixLayoutKHRAttr>( + layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) { + return failure(); + } + + if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName)) + return failure(); + + Type ptrType; + Type objectType; + if (parser.parseColon() || parser.parseType(ptrType) || parser.parseComma() || + parser.parseType(objectType)) { + return failure(); + } + + // The stride type is optional in the textual form and defaults to i32. + Type strideType = parser.getBuilder().getIntegerType(32); + if (succeeded(parser.parseOptionalComma()) && parser.parseType(strideType)) + return failure(); + + if (parser.resolveOperands(operandInfo, {ptrType, objectType, strideType}, + parser.getNameLoc(), result.operands)) { + return failure(); + } + + return success(); +} + +void spirv::KHRCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) { + printer << " " << getPointer() << ", " << getObject() << ", " << getStride() + << ", " << getMatrixLayout(); + + // Print optional memory operand attribute.
+ if (auto memOperand = getMemoryOperand()) + printer << " [\"" << *memOperand << "\"]"; + printer << " : " << getPointer().getType() << ", " << getObject().getType(); + + // The parser defaults the stride type to i32, so only print it when it + // differs. This keeps the common form unchanged while letting other + // integer stride types (e.g. i64) round-trip. + Type strideType = getStride().getType(); + if (!strideType.isSignlessInteger(32)) + printer << ", " << strideType; +} + +LogicalResult spirv::KHRCooperativeMatrixStoreOp::verify() { + return verifyPointerAndCoopMatrixType(*this, getPointer().getType(), + getObject().getType()); +} + //===----------------------------------------------------------------------===// // spirv.NV.CooperativeMatrixLength //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir @@ -57,6 +57,37 @@ spirv.Return } +// CHECK-LABEL: @cooperative_matrix_store +spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32, + %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" { + // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, RowMajor : + // CHECK-SAME: !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA> + spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, RowMajor : + !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA> + spirv.Return +} + +// CHECK-LABEL: @cooperative_matrix_store_memoperand +spirv.func @cooperative_matrix_store_memoperand(%ptr : !spirv.ptr, + %m : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, + %stride : i32) "None" { + // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, ColumnMajor ["Volatile"] : + // CHECK-SAME: !spirv.ptr, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB> + spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, ColumnMajor ["Volatile"] : + !spirv.ptr, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB> + spirv.Return +} + +// CHECK-LABEL: @cooperative_matrix_store_i64_stride +spirv.func @cooperative_matrix_store_i64_stride(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i64, + %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" { + // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, RowMajor : + // CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i64 + spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, RowMajor : + !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i64 + spirv.Return +} + // ----- spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr, StorageBuffer>, %stride : i32) "None" { @@ -95,6 +126,36 @@ // ----- +spirv.func
@cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr, %stride : i32, + %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" { + // expected-error @+1 {{expected ','}} + spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride : + !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA> + spirv.Return +} + +// ----- + +spirv.func @cooperative_matrix_store_missing_layout(%ptr : !spirv.ptr, %stride : i32, + %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" { + // expected-error @+1 {{expected valid keyword}} + spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, : + !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA> + spirv.Return +} + +// ----- + +spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr, + %stride : i32) "None" { + // expected-error @+1 {{op operand #1 must be any SPIR-V cooperative matrix type}} + spirv.KHR.CooperativeMatrixStore %ptr, %stride, %stride, RowMajor : + !spirv.ptr, i32 + spirv.Return +} + +// ----- + //===----------------------------------------------------------------------===// // NV.CooperativeMatrix //===----------------------------------------------------------------------===//