diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -42,8 +42,7 @@ Results<(outs LLVM_Type:$res)>, Arguments<(ins)> { string llvmBuilder = "$res = createIntrinsicCall(builder," # "llvm::Intrinsic::nvvm_" # !subst(".","_", mnemonic) # ");"; - let parser = [{ return parseNVVMSpecialRegisterOp(parser, result); }]; - let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }]; + let assemblyFormat = "attr-dict `:` type($res)"; } //===----------------------------------------------------------------------===// @@ -77,8 +76,7 @@ string llvmBuilder = [{ createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier0); }]; - let parser = [{ return success(); }]; - let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }]; + let assemblyFormat = "attr-dict"; } def NVVM_ShflBflyOp : @@ -129,8 +127,7 @@ $res = createIntrinsicCall( builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_row_f32_f32, $args); }]; - let parser = [{ return parseNVVMMmaOp(parser, result); }]; - let printer = [{ printNVVMMmaOp(p, *this); }]; + let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)"; let verifier = [{ return ::verify(*this); }]; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -42,8 +42,7 @@ Results<(outs LLVM_Type:$res)>, Arguments<(ins)> { string llvmBuilder = "$res = createIntrinsicCall(builder," # "llvm::Intrinsic::amdgcn_" # !subst(".","_", mnemonic) # ");"; - let parser = [{ return parseROCDLOp(parser, result); }]; - let printer = [{ printROCDLOp(p, this->getOperation()); }]; + let assemblyFormat = "attr-dict `:` type($res)"; } class ROCDL_DeviceFunctionOp, Arguments<(ins)> { string llvmBuilder = "$res = createDeviceFunctionCall(builder, \"" # device_function # "\", " # parameter # ");"; - let parser = [{ return parseROCDLOp(parser, result); }]; - let printer = [{ printROCDLOp(p, this->getOperation()); }]; + let assemblyFormat = "attr-dict `:` type($res)"; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -304,6 +304,10 @@ return getAttrOfType("callee"); } }]; + + let assemblyFormat = [{ + $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) + }]; } def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> { @@ -651,6 +655,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let assemblyFormat = "$memref attr-dict `:` type($memref)"; } def DimOp : Std_Op<"dim", [NoSideEffect]> { @@ -987,6 +992,7 @@ }]>]; let hasFolder = 1; + let assemblyFormat = "operands attr-dict `:` type(operands)"; } def RemFOp : FloatArithmeticOp<"remf"> { diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -203,6 +203,7 @@ return vector().getType().cast(); } }]; + let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)"; } def Vector_ShuffleOp : @@ -364,6 +365,10 @@ static StringRef getSizesAttrName() { return "sizes"; } static StringRef getStridesAttrName() { return "strides"; } }]; + let assemblyFormat = [{ + $vector `,` $sizes `,` $strides attr-dict `:` type($vector) `into` + type(results) + }]; } def Vector_InsertElementOp : @@ -482,6 +487,10 @@ static StringRef getSizesAttrName() { return "sizes"; } static StringRef getStridesAttrName() { return "strides"; } }]; + let assemblyFormat = [{ + $vectors `,` $sizes `,` $strides attr-dict `:` type($vectors) `into` + type(results) + }]; } def Vector_InsertStridedSliceOp : @@ -727,6 +736,7 @@ void getOffsets(SmallVectorImpl &results); }]; let hasCanonicalizer = 1; + let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)"; } def Vector_TransferReadOp : @@ -939,6 +949,10 @@ return memref().getType().cast(); } }]; + let assemblyFormat = [{ + $vector `,` $memref `[` $indices `]` attr-dict `:` type($vector) `,` + type($memref) + }]; } def Vector_TypeCastOp : @@ -1017,6 +1031,7 @@ let extraClassDeclaration = [{ static StringRef getMaskDimSizesAttrName() { return "mask_dim_sizes"; } }]; + let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)"; } def Vector_CreateMaskOp : @@ -1048,6 +1063,7 @@ }]; let hasCanonicalizer = 1; + let assemblyFormat = "$operands attr-dict `:` type(results)"; } def Vector_TupleOp : @@ -1148,6 +1164,7 @@ return source().getType(); } }]; + let assemblyFormat = "$source attr-dict `:` type($source)"; } #endif // VECTOR_OPS diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -312,7 +312,8 @@ def AnyInteger : Type()">, "integer">; // Index type. -def Index : Type()">, "index">; +def Index : Type()">, "index">, + BuildableType<"getIndexType()">; // Integer type of a specific width. class I diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -41,18 +41,6 @@ p << " : " << op->getResultTypes(); } -// ::= `llvm.nvvm.XYZ` : type -static ParseResult parseNVVMSpecialRegisterOp(OpAsmParser &parser, - OperationState &result) { - Type type; - if (parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type)) - return failure(); - - result.addTypes(type); - return success(); -} - static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) { return parser.getBuilder() .getContext() @@ -103,41 +91,6 @@ parser.getNameLoc(), result.operands)); } -// ::= `llvm.nvvm.mma.sync %lhs... %rhs... %acc...` -// : signature_type -static ParseResult parseNVVMMmaOp(OpAsmParser &parser, OperationState &result) { - SmallVector ops; - Type type; - llvm::SMLoc typeLoc; - if (parser.parseOperandList(ops) || - parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || - parser.getCurrentLocation(&typeLoc) || parser.parseType(type)) { - return failure(); - } - - auto signature = type.dyn_cast(); - if (!signature) { - return parser.emitError( - typeLoc, "expected the type to be the full list of input and output"); - } - - if (signature.getNumResults() != 1) { - return parser.emitError(typeLoc, "expected single result"); - } - - return failure(parser.addTypeToList(signature.getResult(0), result.types) || - parser.resolveOperands(ops, signature.getInputs(), - parser.getNameLoc(), result.operands)); -} - -static void printNVVMMmaOp(OpAsmPrinter &p, MmaOp &op) { - p << op.getOperationName() << " " << op.getOperands(); - p.printOptionalAttrDict(op.getAttrs()); - p << " : " - << FunctionType::get(llvm::to_vector<12>(op.getOperandTypes()), - op.getType(), op.getContext()); -} - static LogicalResult verify(MmaOp op) { auto dialect = op.getContext()->getRegisteredDialect(); auto f16Ty = LLVM::LLVMType::getHalfTy(dialect); diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp @@ -31,24 +31,6 @@ using namespace ROCDL; //===----------------------------------------------------------------------===// -// Printing/parsing for ROCDL ops -//===----------------------------------------------------------------------===// - -static void printROCDLOp(OpAsmPrinter &p, Operation *op) { - p << op->getName() << " " << op->getOperands(); - if (op->getNumResults() > 0) - p << " : " << op->getResultTypes(); -} - -// ::= `rocdl.XYZ` : type -static ParseResult parseROCDLOp(OpAsmParser &parser, OperationState &result) { - Type type; - return failure(parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type) || - parser.addTypeToList(type, result.types)); -} - -//===----------------------------------------------------------------------===// // ROCDLDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -446,29 +446,6 @@ // CallOp //===----------------------------------------------------------------------===// -static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) { - FlatSymbolRefAttr calleeAttr; - FunctionType calleeType; - SmallVector operands; - auto calleeLoc = parser.getNameLoc(); - if (parser.parseAttribute(calleeAttr, "callee", result.attributes) || - parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(calleeType) || - parser.addTypesToList(calleeType.getResults(), result.types) || - parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc, - result.operands)) - return failure(); - - return success(); -} - -static void print(OpAsmPrinter &p, CallOp op) { - p << "call " << op.getAttr("callee") << '(' << op.getOperands() << ')'; - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); - p << " : " << op.getCalleeType(); -} - static LogicalResult verify(CallOp op) { // Check that the callee attribute was specified. auto fnAttr = op.getAttrOfType("callee"); @@ -1184,19 +1161,6 @@ }; } // end anonymous namespace. -static void print(OpAsmPrinter &p, DeallocOp op) { - p << "dealloc " << op.memref() << " : " << op.memref().getType(); -} - -static ParseResult parseDeallocOp(OpAsmParser &parser, OperationState &result) { - OpAsmParser::OperandType memrefInfo; - MemRefType type; - - return failure(parser.parseOperand(memrefInfo) || - parser.parseColonType(type) || - parser.resolveOperand(memrefInfo, type, result.operands)); -} - static LogicalResult verify(DeallocOp op) { if (!op.memref().getType().isa()) return op.emitOpError("operand must be a memref"); @@ -1844,20 +1808,6 @@ // RankOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, RankOp op) { - p << "rank " << op.getOperand() << " : " << op.getOperand().getType(); -} - -static ParseResult parseRankOp(OpAsmParser &parser, OperationState &result) { - OpAsmParser::OperandType operandInfo; - Type type; - Type indexType = parser.getBuilder().getIndexType(); - return failure(parser.parseOperand(operandInfo) || - parser.parseColonType(type) || - parser.resolveOperand(operandInfo, type, result.operands) || - parser.addTypeToList(indexType, result.types)); -} - OpFoldResult RankOp::fold(ArrayRef operands) { // Constant fold rank when the rank of the tensor is known. auto type = getOperand().getType(); diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -474,38 +474,6 @@ result.addAttribute(getStridesAttrName(), stridesAttr); } -static ParseResult parseExtractSlicesOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType operandInfo; - ArrayAttr sizesAttr; - StringRef sizesAttrName = ExtractSlicesOp::getSizesAttrName(); - ArrayAttr stridesAttr; - StringRef stridesAttrName = ExtractSlicesOp::getStridesAttrName(); - VectorType vectorType; - TupleType resultTupleType; - return failure( - parser.parseOperand(operandInfo) || parser.parseComma() || - parser.parseAttribute(sizesAttr, sizesAttrName, result.attributes) || - parser.parseComma() || - parser.parseAttribute(stridesAttr, stridesAttrName, result.attributes) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(vectorType) || - parser.parseKeywordType("into", resultTupleType) || - parser.resolveOperand(operandInfo, vectorType, result.operands) || - parser.addTypeToList(resultTupleType, result.types)); -} - -static void print(OpAsmPrinter &p, ExtractSlicesOp op) { - p << op.getOperationName() << ' ' << op.vector() << ", "; - p << op.sizes() << ", " << op.strides(); - p.printOptionalAttrDict( - op.getAttrs(), - /*elidedAttrs=*/{ExtractSlicesOp::getSizesAttrName(), - ExtractSlicesOp::getStridesAttrName()}); - p << " : " << op.vector().getType(); - p << " into " << op.getResultTupleType(); -} - static LogicalResult isValidExtractOrInsertSlicesType(Operation *op, VectorType vectorType, TupleType tupleType, ArrayRef sizes, @@ -572,11 +540,6 @@ // BroadcastOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, BroadcastOp op) { - p << op.getOperationName() << " " << op.source() << " : " - << op.getSourceType() << " to " << op.getVectorType(); -} - static LogicalResult verify(BroadcastOp op) { VectorType srcVectorType = op.getSourceType().dyn_cast(); VectorType dstVectorType = op.getVectorType(); @@ -601,18 +564,6 @@ return success(); } -static ParseResult parseBroadcastOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType source; - Type sourceType; - VectorType vectorType; - return failure(parser.parseOperand(source) || - parser.parseColonType(sourceType) || - parser.parseKeywordType("to", vectorType) || - parser.resolveOperand(source, sourceType, result.operands) || - parser.addTypeToList(vectorType, result.types)); -} - //===----------------------------------------------------------------------===// // ShuffleOp //===----------------------------------------------------------------------===// @@ -808,38 +759,6 @@ // InsertSlicesOp //===----------------------------------------------------------------------===// -static ParseResult parseInsertSlicesOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType operandInfo; - ArrayAttr sizesAttr; - StringRef sizesAttrName = InsertSlicesOp::getSizesAttrName(); - ArrayAttr stridesAttr; - StringRef stridesAttrName = InsertSlicesOp::getStridesAttrName(); - TupleType tupleType; - VectorType resultVectorType; - return failure( - parser.parseOperand(operandInfo) || parser.parseComma() || - parser.parseAttribute(sizesAttr, sizesAttrName, result.attributes) || - parser.parseComma() || - parser.parseAttribute(stridesAttr, stridesAttrName, result.attributes) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(tupleType) || - parser.parseKeywordType("into", resultVectorType) || - parser.resolveOperand(operandInfo, tupleType, result.operands) || - parser.addTypeToList(resultVectorType, result.types)); -} - -static void print(OpAsmPrinter &p, InsertSlicesOp op) { - p << op.getOperationName() << ' ' << op.vectors() << ", "; - p << op.sizes() << ", " << op.strides(); - p.printOptionalAttrDict( - op.getAttrs(), - /*elidedAttrs=*/{InsertSlicesOp::getSizesAttrName(), - InsertSlicesOp::getStridesAttrName()}); - p << " : " << op.vectors().getType(); - p << " into " << op.getResultVectorType(); -} - static LogicalResult verify(InsertSlicesOp op) { SmallVector sizes; op.getSizes(sizes); @@ -1231,27 +1150,6 @@ result.addAttribute(getStridesAttrName(), stridesAttr); } -static void print(OpAsmPrinter &p, StridedSliceOp op) { - p << op.getOperationName() << " " << op.vector(); - p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.vector().getType() << " to " << op.getResult().getType(); -} - -static ParseResult parseStridedSliceOp(OpAsmParser &parser, - OperationState &result) { - llvm::SMLoc attributeLoc, typeLoc; - OpAsmParser::OperandType vector; - VectorType vectorType, resultVectorType; - return failure(parser.parseOperand(vector) || - parser.getCurrentLocation(&attributeLoc) || - parser.parseOptionalAttrDict(result.attributes) || - parser.getCurrentLocation(&typeLoc) || - parser.parseColonType(vectorType) || - parser.parseKeywordType("to", resultVectorType) || - parser.resolveOperand(vector, vectorType, result.operands) || - parser.addTypeToList(resultVectorType, result.types)); -} - static LogicalResult verify(StridedSliceOp op) { auto type = op.getVectorType(); auto offsets = op.offsets(); @@ -1519,35 +1417,6 @@ //===----------------------------------------------------------------------===// // TransferWriteOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, TransferWriteOp op) { - p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "[" - << op.indices() << "]"; - p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.getVectorType() << ", " << op.getMemRefType(); -} - -static ParseResult parseTransferWriteOp(OpAsmParser &parser, - OperationState &result) { - llvm::SMLoc typesLoc; - OpAsmParser::OperandType storeValueInfo; - OpAsmParser::OperandType memRefInfo; - SmallVector indexInfo; - SmallVector types; - if (parser.parseOperand(storeValueInfo) || parser.parseComma() || - parser.parseOperand(memRefInfo) || - parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || - parser.parseOptionalAttrDict(result.attributes) || - parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) - return failure(); - if (types.size() != 2) - return parser.emitError(typesLoc, "two types required"); - auto indexType = parser.getBuilder().getIndexType(); - Type vectorType = types[0], memRefType = types[1]; - return failure( - parser.resolveOperand(storeValueInfo, vectorType, result.operands) || - parser.resolveOperand(memRefInfo, memRefType, result.operands) || - parser.resolveOperands(indexInfo, indexType, result.operands)); -} static LogicalResult verify(TransferWriteOp op) { // Consistency of elemental types in memref and vector. @@ -1676,23 +1545,6 @@ // ConstantMaskOp //===----------------------------------------------------------------------===// -static ParseResult parseConstantMaskOp(OpAsmParser &parser, - OperationState &result) { - Type resultType; - ArrayAttr maskDimSizesAttr; - StringRef attrName = ConstantMaskOp::getMaskDimSizesAttrName(); - return failure( - parser.parseOptionalAttrDict(result.attributes) || - parser.parseAttribute(maskDimSizesAttr, attrName, result.attributes) || - parser.parseColonType(resultType) || - parser.addTypeToList(resultType, result.types)); -} - -static void print(OpAsmPrinter &p, ConstantMaskOp op) { - p << op.getOperationName() << ' ' << op.mask_dim_sizes() << " : " - << op.getResult().getType(); -} - static LogicalResult verify(ConstantMaskOp &op) { // Verify that array attr size matches the rank of the vector result. auto resultType = op.getResult().getType().cast(); @@ -1724,23 +1576,6 @@ // CreateMaskOp //===----------------------------------------------------------------------===// -static ParseResult parseCreateMaskOp(OpAsmParser &parser, - OperationState &result) { - auto indexType = parser.getBuilder().getIndexType(); - Type resultType; - SmallVector operandInfo; - return failure( - parser.parseOperandList(operandInfo) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(resultType) || - parser.resolveOperands(operandInfo, indexType, result.operands) || - parser.addTypeToList(resultType, result.types)); -} - -static void print(OpAsmPrinter &p, CreateMaskOp op) { - p << op.getOperationName() << ' ' << op.operands() << " : " << op.getType(); -} - static LogicalResult verify(CreateMaskOp op) { // Verify that an operand was specified for each result vector each dimension. if (op.getNumOperands() != @@ -1750,23 +1585,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// PrintOp -//===----------------------------------------------------------------------===// - -static ParseResult parsePrintOp(OpAsmParser &parser, OperationState &result) { - OpAsmParser::OperandType source; - Type sourceType; - return failure(parser.parseOperand(source) || - parser.parseColonType(sourceType) || - parser.resolveOperand(source, sourceType, result.operands)); -} - -static void print(OpAsmPrinter &p, PrintOp op) { - p << op.getOperationName() << ' ' << op.source() << " : " - << op.getPrintType(); -} - namespace { // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp. diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -366,7 +366,7 @@ %b0 : !llvm<"<2 x half>">, %b1 : !llvm<"<2 x half>">, %c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float, %c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) { - // expected-error@+1 {{expected the type to be the full list of input and output}} + // expected-error@+1 {{invalid kind of type specified}} %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : !llvm<"{ float, float, float, float, float, float, float, float }"> llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }"> } @@ -378,9 +378,9 @@ %b0 : !llvm<"<2 x half>">, %b1 : !llvm<"<2 x half>">, %c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float, %c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) { - // expected-error@+1 {{expected single result}} - %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32) - llvm.return %0 : (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32) + // expected-error@+1 {{op requires one result}} + %0:2 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32) + llvm.return %0#0 : !llvm<"{ float, float, float, float, float, float, float, float }"> } // -----