diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td @@ -39,6 +39,8 @@ ``` }]; + let assemblyFormat = "attr-dict `:` $type"; + let availability = [ MinVersion, MaxVersion, @@ -139,7 +141,7 @@ // ----- def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV", - [NoSideEffect]> { + [NoSideEffect, AllTypesMatch<["c", "result"]>]> { let summary = "See extension SPV_NV_cooperative_matrix"; let description = [{ @@ -188,6 +190,10 @@ ``` }]; + let assemblyFormat = [{ + operands attr-dict `:` type($a) `,` type($b) `->` type($c) + }]; + let availability = [ MinVersion, MaxVersion, diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1134,12 +1134,11 @@ return compositeConstructOp.emitError( "has incorrect number of operands: expected ") << "1, but provided " << constituents.size(); - } else { - if (constituents.size() != cType.getNumElements()) - return compositeConstructOp.emitError( - "has incorrect number of operands: expected ") - << cType.getNumElements() << ", but provided " - << constituents.size(); + } else if (constituents.size() != cType.getNumElements()) { + return compositeConstructOp.emitError( + "has incorrect number of operands: expected ") + << cType.getNumElements() << ", but provided " + << constituents.size(); } for (auto index : llvm::seq(0, constituents.size())) { @@ -2736,56 +2735,9 @@ } //===----------------------------------------------------------------------===// -// spv.CooperativeMatrixLengthNV -//===----------------------------------------------------------------------===// - -static ParseResult parseCooperativeMatrixLengthNVOp(OpAsmParser &parser, - OperationState &state) { - 
OpAsmParser::OperandType operandInfo; - Type dstType = parser.getBuilder().getIntegerType(32); - Type type; - if (parser.parseColonType(type)) { - return failure(); - } - state.addAttribute(kTypeAttrName, TypeAttr::get(type)); - state.addTypes(dstType); - return success(); -} - -static void print(spirv::CooperativeMatrixLengthNVOp coopMatrix, - OpAsmPrinter &printer) { - printer << coopMatrix.getOperationName() << " : " << coopMatrix.type(); -} - -//===----------------------------------------------------------------------===// // spv.CooperativeMatrixMulAddNV //===----------------------------------------------------------------------===// -static ParseResult parseCooperativeMatrixMulAddNVOp(OpAsmParser &parser, - OperationState &state) { - SmallVector ops; - SmallVector types(3); - if (parser.parseOperandList(ops, 3) || parser.parseColon() || - parser.parseType(types[0]) || parser.parseComma() || - parser.parseType(types[1]) || parser.parseArrow() || - parser.parseType(types[2]) || - parser.resolveOperands(ops, types, parser.getNameLoc(), state.operands)) { - return failure(); - } - state.addTypes(types[2]); - return success(); -} - -static void print(spirv::CooperativeMatrixMulAddNVOp coopMatrix, - OpAsmPrinter &printer) { - printer << coopMatrix.getOperationName() << ' ' << coopMatrix.getOperand(0) - << ", " << coopMatrix.getOperand(1) << ", " - << coopMatrix.getOperand(2) << ", " - << " : " << coopMatrix.getOperand(0).getType() << ", " - << coopMatrix.getOperand(1).getType() << " -> " - << coopMatrix.getOperand(2).getType(); -} - static LogicalResult verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) { if (op.c().getType() != op.result().getType()) diff --git a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir @@ -38,7 +38,7 @@ // CHECK-LABEL: 
@cooperative_matrix_muladd spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}}, : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup> + // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup> %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup> spv.Return } diff --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir --- a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir +++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir @@ -38,7 +38,7 @@ // CHECK-LABEL: @cooperative_matrix_muladd spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}}, : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup> + // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup> %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup> spv.Return }