Index: google3/third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td =================================================================== --- google3/third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td +++ google3/third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td @@ -22,7 +22,18 @@ // Operands type same as result type. SPV_BinaryOp; + [NoSideEffect, SameOperandsAndResultType])> { + // In addition to normal types, arithmetic instructions can also accept + // cooperative matrices as operands and results. + let arguments = (ins + SPV_ScalarOrVectorOrCoopMatrixOf:$operand1, + SPV_ScalarOrVectorOrCoopMatrixOf:$operand2 + ); + + let results = (outs + SPV_ScalarOrVectorOrCoopMatrixOf:$result + ); +} class SPV_ArithmeticUnaryOp traits = []> : Index: google3/third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td =================================================================== --- google3/third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ google3/third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -3003,6 +3003,7 @@ def SPV_Void : TypeAlias; def SPV_Bool : TypeAlias; +def SPV_Int32 : TypeAlias; def SPV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>; def SPV_Float : FloatOfWidths<[16, 32, 64]>; def SPV_Float16or32 : FloatOfWidths<[16, 32]>; @@ -3034,9 +3035,18 @@ def SPV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>; +class SPV_CoopMatrixOfType allowedTypes> : + ContainerType, SPV_IsCooperativeMatrixType, + "$_self.cast<::mlir::spirv::CooperativeMatrixNVType>().getElementType()", + "Cooperative Matrix">; + class SPV_ScalarOrVectorOf : AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4], [type]>]>; +class SPV_ScalarOrVectorOrCoopMatrixOf : + AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4], [type]>, + SPV_CoopMatrixOfType<[type]>]>; + def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>; def 
SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>; @@ -3227,6 +3237,9 @@ def SPV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>; def SPV_OC_OpTypeCooperativeMatrixNV : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>; def SPV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>; +def SPV_OC_OpCooperativeMatrixStoreNV : I32EnumAttrCase<"OpCooperativeMatrixStoreNV", 5360>; +def SPV_OC_OpCooperativeMatrixMulAddNV : I32EnumAttrCase<"OpCooperativeMatrixMulAddNV", 5361>; +def SPV_OC_OpCooperativeMatrixLengthNV : I32EnumAttrCase<"OpCooperativeMatrixLengthNV", 5362>; def SPV_OpcodeAttr : SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [ @@ -3279,7 +3292,9 @@ SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin, SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax, SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR, - SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV + SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV, + SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV, + SPV_OC_OpCooperativeMatrixLengthNV ]>; // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
Index: google3/third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td =================================================================== --- google3/third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td +++ google3/third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td @@ -23,11 +23,11 @@ !listconcat(traits, [NoSideEffect, SameOperandsAndResultShape])> { let arguments = (ins - SPV_ScalarOrVectorOf:$operand + SPV_ScalarOrVectorOrCoopMatrixOf:$operand ); let results = (outs - SPV_ScalarOrVectorOf:$result + SPV_ScalarOrVectorOrCoopMatrixOf:$result ); let parser = [{ return mlir::impl::parseCastOp(parser, result); }]; Index: google3/third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td =================================================================== --- google3/third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td +++ google3/third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td @@ -15,6 +15,49 @@ // ----- +def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV", + [NoSideEffect]> { + let summary = "See extension SPV_NV_cooperative_matrix"; + + let description = [{ + Number of components of a cooperative matrix type accessible to each + invocation when treated as a composite. + + Result Type must be an OpTypeInt with 32-bit Width and 0 Signedness. + + Type is a cooperative matrix type.
+ + ``` {.ebnf} + cooperative-matrix-length-op ::= ssa-id `=` `spv.CooperativeMatrixLengthNV` + ` : ` cooperative-matrix-type + ``` + + For example: + + ``` + %0 = spv.CooperativeMatrixLengthNV : !spv.coopmatrix + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_NV_cooperative_matrix]>, + Capability<[SPV_C_CooperativeMatrixNV]> + ]; + + let arguments = (ins + TypeAttr:$type + ); + + let results = (outs + SPV_Int32:$result + ); + let verifier = [{ return success(); }]; +} + +// ----- + def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> { let summary = "See extension SPV_NV_cooperative_matrix"; @@ -55,9 +98,10 @@ ### Custom assembly form ``` {.ebnf} - cooperative-matrix-op ::= ssa-id `=` `spv.CooperativeMatrixLoadNV` - storage-class ssa-use (`[` memory-access `]`)? ` - : ` cooperative-matrix-type + cooperative-matrixload-op ::= ssa-id `=` `spv.CooperativeMatrixLoadNV` + storage-class ssa-use `,` ssa-use `,` ssa-use + (`[` memory-access `]`)? ` : ` + cooperative-matrix-type ``` For example: @@ -91,4 +135,139 @@ // ----- +// TODO: SameOperandsAndResultType forces A, B, C and Result to share one type, +// so only square (M == N == K) multiplies can be expressed, whereas the +// description allows A (MxK), B (KxN) and C/Result (MxN) to differ. +def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV", + [NoSideEffect, SameOperandsAndResultType]> { + let summary = "See extension SPV_NV_cooperative_matrix"; + + let description = [{ + Linear-algebraic matrix multiply of A by B and then component-wise add C. + The order of the operations is implementation-dependent. The internal + precision of floating-point operations is defined by the client API. + Integer operations are performed at the precision of the Result Type and are + exact unless there is overflow or underflow, in which case the result is + undefined. + + Result Type must be a cooperative matrix type with M rows and N columns. + + A is a cooperative matrix with M rows and K columns. 
+ + B is a cooperative matrix with K rows and N columns. + + C is a cooperative matrix with M rows and N columns.
+ + The values of M, N, and K must be consistent across the result and operands. + This is referred to as an MxNxK matrix multiply. + + A, B, C, and Result Type must have the same scope, and this defines the + scope of the operation. A, B, C, and Result Type need not necessarily have + the same component type, this is defined by the client API. + + If the Component Type of any matrix operand is an integer type, then its + components are treated as signed if its Component Type has Signedness of 1 + and are treated as unsigned otherwise. + + For a given dynamic instance of this instruction, all invocations in a given + scope instance must be active or all must be inactive (where the scope is + the scope of the operation). + + ``` {.ebnf} + cooperative-matrixmuladd-op ::= ssa-id `=` `spv.CooperativeMatrixMulAddNV` + ssa-use `,` ssa-use `,` ssa-use ` : ` + cooperative-matrix-type + ``` + For example: + + ``` + %0 = spv.CooperativeMatrixMulAddNV %arg0, %arg1, %arg2 : + !spv.coopmatrix + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_NV_cooperative_matrix]>, + Capability<[SPV_C_CooperativeMatrixNV]> + ]; + + let arguments = (ins + SPV_AnyCooperativeMatrix:$a, + SPV_AnyCooperativeMatrix:$b, + SPV_AnyCooperativeMatrix:$c + ); + + let results = (outs + SPV_AnyCooperativeMatrix:$result + ); + + let verifier = [{ return success(); }]; +} + +// ----- + +def SPV_CooperativeMatrixStoreNVOp : SPV_Op<"CooperativeMatrixStoreNV", []> { + let summary = "See extension SPV_NV_cooperative_matrix"; + + let description = [{ + Store a cooperative matrix through a pointer. + + Pointer is a pointer into an array. Its type must be an OpTypePointer whose + Type operand is a scalar or vector type. The storage class of Pointer must + be Workgroup, StorageBuffer, or (if SPV_EXT_physical_storage_buffer is + supported) PhysicalStorageBufferEXT. + + Object is the object to store. Its type must be an + OpTypeCooperativeMatrixNV.
+ + Stride is the number of elements in the array in memory between the first + component of consecutive rows (or columns) in the result. It must be a + scalar integer type. + + ColumnMajor indicates whether the values stored to memory are arranged in + column-major or row-major order. It must be a boolean constant instruction, + with false indicating row major and true indicating column major. + + Memory Access must be a Memory Access literal. If not present, it is the + same as specifying None. + + ``` {.ebnf} + coop-matrix-store-op ::= `spv.CooperativeMatrixStoreNV ` + storage-class ssa-use `, ` ssa-use `, ` + ssa-use `, ` ssa-use + (`[` memory-access `]`)? `:` cooperative-matrix-type + ``` + + For example: + + ``` + spv.CooperativeMatrixStoreNV "StorageBuffer" %arg0, %arg2, %arg1, %arg3 : + !spv.coopmatrix + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_NV_cooperative_matrix]>, + Capability<[SPV_C_CooperativeMatrixNV]> + ]; + + let arguments = (ins + SPV_AnyPtr:$pointer, + SPV_AnyCooperativeMatrix:$object, + SPV_Integer:$stride, + SPV_Bool:$columnmajor, + OptionalAttr:$memory_access + ); + + let results = (outs); + + let verifier = [{ return success(); }]; +} + +// ----- + #endif // SPIRV_COOPERATIVE_MATRIX_OPS Index: google3/third_party/llvm/llvm-project/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp =================================================================== --- google3/third_party/llvm/llvm-project/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ google3/third_party/llvm/llvm-project/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -213,6 +213,15 @@ resultType = resultType.cast().getElementType(); } + // NOTE(review): `cast` asserts if the operand is a cooperative matrix but the + // result is not; confirm an earlier check rejects that, else use dyn_cast. + if (auto coopMatrixType = + operandType.dyn_cast()) { + operandType = coopMatrixType.getElementType(); + resultType = + resultType.cast().getElementType(); + } + auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth(); auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth(); auto isSameBitWidth = operandTypeBitWidth ==
resultTypeBitWidth; @@ -2662,6 +2671,91 @@ printer << " : " << M.getType(); } +//===----------------------------------------------------------------------===// +// spv.CooperativeMatrixStoreNV +//===----------------------------------------------------------------------===// + +static ParseResult parseCooperativeMatrixStoreNVOp(OpAsmParser &parser, + OperationState &state) { + spirv::StorageClass storageClass; + SmallVector operandInfo; + Type strideType = parser.getBuilder().getIntegerType(32); + Type columnMajorType = parser.getBuilder().getIntegerType(1); + Type matrixType; + if (parseEnumStrAttr(storageClass, parser) || + parser.parseOperandList(operandInfo, 4) || + parseMemoryAccessAttributes(parser, state) || parser.parseColon() || + parser.parseType(matrixType)) { + return failure(); + } + + auto ptrType = spirv::PointerType::get( + matrixType.cast().getElementType(), + storageClass); + SmallVector operandTypes = {ptrType, matrixType, strideType, + columnMajorType}; + if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(), + state.operands)) { + return failure(); + } + + return success(); +} + +static void print(spirv::CooperativeMatrixStoreNVOp M, OpAsmPrinter &printer) { + StringRef sc = stringifyStorageClass( + M.pointer().getType().cast().getStorageClass()); + printer << spirv::CooperativeMatrixStoreNVOp::getOperationName() << " \"" + << sc << "\" " << M.pointer() << ", " << M.object() << ", " + << M.stride() << ", " << M.columnmajor(); + // Print optional memory access attribute.
+ if (auto memAccess = M.memory_access()) + printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; + printer << " : " << M.object().getType(); +} + +//===----------------------------------------------------------------------===// +// spv.CooperativeMatrixLengthNV +//===----------------------------------------------------------------------===// + +static ParseResult parseCooperativeMatrixLengthNVOp(OpAsmParser &parser, + OperationState &state) { + Type dstType = parser.getBuilder().getIntegerType(32); + Type type; + if (parser.parseColonType(type)) { + return failure(); + } + state.addAttribute(kTypeAttrName, TypeAttr::get(type)); + state.addTypes(dstType); + return success(); +} + +static void print(spirv::CooperativeMatrixLengthNVOp M, OpAsmPrinter &printer) { + printer << M.getOperationName() << " : " << M.type(); +} + +//===----------------------------------------------------------------------===// +// spv.CooperativeMatrixMulAddNV +//===----------------------------------------------------------------------===// + +static ParseResult parseCooperativeMatrixMulAddNVOp(OpAsmParser &parser, + OperationState &state) { + SmallVector ops; + Type type; + if (parser.parseOperandList(ops, 3) || parser.parseColonType(type) || + parser.resolveOperands(ops, type, state.operands)) { + return failure(); + } + state.addTypes(type); + return success(); +} + +static void print(spirv::CooperativeMatrixMulAddNVOp M, OpAsmPrinter &printer) { + printer << M.getOperationName() << ' ' << M.getOperand(0) << ", " + << M.getOperand(1) << ", " << M.getOperand(2) + << " : " << M.getOperand(0).getType(); +} + namespace mlir { namespace spirv { Index: google3/third_party/llvm/llvm-project/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir =================================================================== --- google3/third_party/llvm/llvm-project/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir +++
google3/third_party/llvm/llvm-project/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir @@ -14,4 +14,81 @@ %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup> spv.Return } + + // CHECK-LABEL: @cooperative_matrix_store + spv.func @cooperative_matrix_store(%ptr : !spv.ptr, %stride : i32, %m : !spv.coopmatrix<16x8xi32, Workgroup>, %b : i1) "None" { + // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup> + spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup> + spv.Return + } + + // CHECK-LABEL: @cooperative_matrix_store_memaccess + spv.func @cooperative_matrix_store_memaccess(%ptr : !spv.ptr, %m : !spv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" { + // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup> + spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup> + spv.Return + } + + // CHECK-LABEL: @cooperative_matrix_length + spv.func @cooperative_matrix_length() -> i32 "None" { + // CHECK: {{%.*}} = spv.CooperativeMatrixLengthNV : !spv.coopmatrix<8x16xi32, Subgroup> + %0 = spv.CooperativeMatrixLengthNV : !spv.coopmatrix<8x16xi32, Subgroup> + spv.ReturnValue %0 : i32 + } + + // CHECK-LABEL: @cooperative_matrix_muladd + spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<8x16xi32, Subgroup>, %c : !spv.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup> + %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup> + spv.Return + } + + // CHECK-LABEL: @cooperative_matrix_add + spv.func @cooperative_matrix_add(%a : !spv.coopmatrix<8x16xi32, Subgroup>,
%b : !spv.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup> + %r = spv.IAdd %a, %b : !spv.coopmatrix<8x16xi32, Subgroup> + spv.Return + } + + // CHECK-LABEL: @cooperative_matrix_sub + spv.func @cooperative_matrix_sub(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup> + %r = spv.ISub %a, %b : !spv.coopmatrix<8x16xi32, Subgroup> + spv.Return + } + + // CHECK-LABEL: @cooperative_matrix_sdiv + spv.func @cooperative_matrix_sdiv(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.SDiv {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup> + %r = spv.SDiv %a, %b : !spv.coopmatrix<8x16xi32, Subgroup> + spv.Return + } + + // CHECK-LABEL: @cooperative_matrix_udiv + spv.func @cooperative_matrix_udiv(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.UDiv {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup> + %r = spv.UDiv %a, %b : !spv.coopmatrix<8x16xi32, Subgroup> + spv.Return + } + + // CHECK-LABEL: @cooperative_matrix_fadd + spv.func @cooperative_matrix_fadd(%a : !spv.coopmatrix<8x16xf32, Subgroup>, %b : !spv.coopmatrix<8x16xf32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup> + %r = spv.FAdd %a, %b : !spv.coopmatrix<8x16xf32, Subgroup> + spv.Return + } + + // CHECK-LABEL: @cooperative_matrix_fsub + spv.func @cooperative_matrix_fsub(%a : !spv.coopmatrix<8x16xf32, Subgroup>, %b : !spv.coopmatrix<8x16xf32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup> + %r = spv.FSub %a, %b : !spv.coopmatrix<8x16xf32, Subgroup> + spv.Return + } + + // CHECK-LABEL: @cooperative_matrix_fdiv + spv.func @cooperative_matrix_fdiv(%a :
!spv.coopmatrix<8x16xf32, Subgroup>, %b : !spv.coopmatrix<8x16xf32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup> + %r = spv.FDiv %a, %b : !spv.coopmatrix<8x16xf32, Subgroup> + spv.Return + } } Index: google3/third_party/llvm/llvm-project/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir =================================================================== --- google3/third_party/llvm/llvm-project/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir +++ google3/third_party/llvm/llvm-project/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir @@ -14,3 +14,80 @@ %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup> spv.Return } + +// CHECK-LABEL: @cooperative_matrix_store +spv.func @cooperative_matrix_store(%ptr : !spv.ptr, %stride : i32, %m : !spv.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" { + // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Workgroup> + spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b : !spv.coopmatrix<8x16xi32, Workgroup> + spv.Return +} + +// CHECK-LABEL: @cooperative_matrix_store_memaccess +spv.func @cooperative_matrix_store_memaccess(%ptr : !spv.ptr, %m : !spv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" { + // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup> + spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup> + spv.Return +} + +// CHECK-LABEL: @cooperative_matrix_length +spv.func @cooperative_matrix_length() -> i32 "None" { + // CHECK: {{%.*}} = spv.CooperativeMatrixLengthNV : !spv.coopmatrix<8x16xi32, Subgroup> + %0 = spv.CooperativeMatrixLengthNV : !spv.coopmatrix<8x16xi32, Subgroup> + spv.ReturnValue %0 : i32 +} + +// CHECK-LABEL: @cooperative_matrix_muladd +spv.func
@cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<8x16xi32, Subgroup>, %c : !spv.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup> + %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup> + spv.Return +} + +// CHECK-LABEL: @cooperative_matrix_add +spv.func @cooperative_matrix_add(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup> + %r = spv.IAdd %a, %b : !spv.coopmatrix<8x16xi32, Subgroup> + spv.Return +} + +// CHECK-LABEL: @cooperative_matrix_sub +spv.func @cooperative_matrix_sub(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup> + %r = spv.ISub %a, %b : !spv.coopmatrix<8x16xi32, Subgroup> + spv.Return +} + +// CHECK-LABEL: @cooperative_matrix_sdiv +spv.func @cooperative_matrix_sdiv(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.SDiv {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup> + %r = spv.SDiv %a, %b : !spv.coopmatrix<8x16xi32, Subgroup> + spv.Return +} + +// CHECK-LABEL: @cooperative_matrix_udiv +spv.func @cooperative_matrix_udiv(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<8x16xi32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.UDiv {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup> + %r = spv.UDiv %a, %b : !spv.coopmatrix<8x16xi32, Subgroup> + spv.Return +} + +// CHECK-LABEL: @cooperative_matrix_fadd +spv.func @cooperative_matrix_fadd(%a : !spv.coopmatrix<8x16xf32, Subgroup>, %b : !spv.coopmatrix<8x16xf32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup> + %r = spv.FAdd %a, %b :
!spv.coopmatrix<8x16xf32, Subgroup> + spv.Return +} + +// CHECK-LABEL: @cooperative_matrix_fsub +spv.func @cooperative_matrix_fsub(%a : !spv.coopmatrix<8x16xf32, Subgroup>, %b : !spv.coopmatrix<8x16xf32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup> + %r = spv.FSub %a, %b : !spv.coopmatrix<8x16xf32, Subgroup> + spv.Return +} + +// CHECK-LABEL: @cooperative_matrix_fdiv +spv.func @cooperative_matrix_fdiv(%a : !spv.coopmatrix<8x16xf32, Subgroup>, %b : !spv.coopmatrix<8x16xf32, Subgroup>) "None" { + // CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup> + %r = spv.FDiv %a, %b : !spv.coopmatrix<8x16xf32, Subgroup> + spv.Return +} Index: google3/third_party/llvm/llvm-project/mlir/test/Dialect/SPIRV/ops.mlir =================================================================== --- google3/third_party/llvm/llvm-project/mlir/test/Dialect/SPIRV/ops.mlir +++ google3/third_party/llvm/llvm-project/mlir/test/Dialect/SPIRV/ops.mlir @@ -328,6 +328,14 @@ // ----- +func @convert_f_to_u_coopmatrix(%arg0 : !spv.coopmatrix<8x16xf32, Subgroup>) { + // CHECK: {{%.*}} = spv.ConvertFToU {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup> to !spv.coopmatrix<8x16xi32, Subgroup> + %0 = spv.ConvertFToU %arg0 : !spv.coopmatrix<8x16xf32, Subgroup> to !spv.coopmatrix<8x16xi32, Subgroup> + spv.Return +} + +// ----- + func @convert_f_to_u_scalar_invalid(%arg0 : f16) -> i32 { // expected-error @+1 {{expected the same bit widths for operand type and result type, but provided 'f16' and 'i32'}} %0 = spv.ConvertFToU %arg0 : f16 to i32 @@ -380,6 +388,14 @@ // ----- +func @f_convert_coop_matrix(%arg0 : !spv.coopmatrix<8x16xf32, Subgroup>) { + // CHECK: {{%.*}} = spv.FConvert {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup> to !spv.coopmatrix<8x16xf64, Subgroup> + %0 = spv.FConvert %arg0 : !spv.coopmatrix<8x16xf32, Subgroup> to !spv.coopmatrix<8x16xf64, Subgroup> + spv.Return +} + +// ----- + func @f_convert_vector(%arg0 : f32)
-> f32 { // expected-error @+1 {{expected the different bit widths for operand type and result type, but provided 'f32' and 'f32'}} %0 = spv.FConvert %arg0 : f32 to f32 Index: google3/third_party/llvm/llvm-project/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp =================================================================== --- google3/third_party/llvm/llvm-project/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ google3/third_party/llvm/llvm-project/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -508,6 +508,11 @@ << formatv(" {0}.push_back(static_cast(" "attr.cast().getValue().getZExtValue()));\n", operandList); + } else if (attr.getAttrDefName() == "TypeAttr") { + os << tabs + << formatv(" {0}.push_back(static_cast(" + "getTypeID(attr.cast().getValue())));\n", + operandList); } else { PrintFatalError( loc, @@ -769,6 +774,11 @@ << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " "opBuilder.getI32IntegerAttr({2}[{3}++])));\n", attrList, attrName, words, wordIndex); + } else if (attr.getAttrDefName() == "TypeAttr") { + os << tabs + << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " + "TypeAttr::get(getType({2}[{3}++]))));\n", + attrList, attrName, words, wordIndex); } else { PrintFatalError( loc, llvm::Twine(