diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -1356,37 +1356,25 @@ /// Resolve a list of operands to SSA values, emitting an error on failure, or /// appending the results to the list on success. This method should be used /// when all operands have the same type. - ParseResult resolveOperands(ArrayRef operands, Type type, + template > + ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl &result) { - for (auto elt : operands) - if (resolveOperand(elt, type, result)) + for (const UnresolvedOperand &operand : operands) + if (resolveOperand(operand, type, result)) return failure(); return success(); } + template > + ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc, + SmallVectorImpl &result) { + return resolveOperands(std::forward(operands), type, result); + } /// Resolve a list of operands and a list of operand types to SSA values, /// emitting an error and returning failure, or appending the results /// to the list on success. - ParseResult resolveOperands(ArrayRef operands, - ArrayRef types, SMLoc loc, - SmallVectorImpl &result) { - if (operands.size() != types.size()) - return emitError(loc) - << operands.size() << " operands present, but expected " - << types.size(); - - for (unsigned i = 0, e = operands.size(); i != e; ++i) - if (resolveOperand(operands[i], types[i], result)) - return failure(); - return success(); - } - template - ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc, - SmallVectorImpl &result) { - return resolveOperands(std::forward(operands), - ArrayRef(type), loc, result); - } - template + template , + typename Types = ArrayRef> std::enable_if_t::value, ParseResult> resolveOperands(Operands &&operands, Types &&types, SMLoc loc, SmallVectorImpl &result) { @@ -1396,8 +1384,8 @@ return emitError(loc) << operandSize << " operands present, but expected " << typeSize; - for (auto it : llvm::zip(operands, types)) - if (resolveOperand(std::get<0>(it), std::get<1>(it), result)) + for (auto [operand, type] : llvm::zip(operands, types)) + if (resolveOperand(operand, type, result)) return failure(); return success(); } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2128,7 +2128,7 @@ [SameOperandsAndResultType]> { let arguments = (ins Variadic:$args); let results = (outs AnyType:$result); - let assemblyFormat = "$args attr-dict `:` type($result)"; + let assemblyFormat = "operands attr-dict `:` type($result)"; } def FormatOptionalUnitAttr : TEST_Op<"format_optional_unit_attribute"> { diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -165,12 +165,12 @@ // Check that we can infer type equalities from certain traits. def ZCoverageValidD : TestFormat_Op<[{ operands type($result) attr-dict -}], [SameOperandsAndResultType]>, Arguments<(ins AnyMemRef:$operand)>, +}], [SameOperandsAndResultType]>, Arguments<(ins AnyMemRef)>, Results<(outs AnyMemRef:$result)>; def ZCoverageValidE : TestFormat_Op<[{ $operand type($operand) attr-dict }], [SameOperandsAndResultType]>, Arguments<(ins AnyMemRef:$operand)>, - Results<(outs AnyMemRef:$result)>; + Results<(outs AnyMemRef)>; def ZCoverageValidF : TestFormat_Op<[{ operands type($other) attr-dict }], [SameTypeOperands]>, Arguments<(ins AnyMemRef:$operand, AnyMemRef:$other)>; diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -1389,6 +1389,8 @@ body << tgfmt(*tform, &fmtContext); } else { body << var->name << "Types"; + if (!var->isVariadic()) + body << "[0]"; } } else if (const NamedAttribute *attr = resolver.getAttribute()) { if (Optional tform = resolver.getVarTransformer()) @@ -1477,8 +1479,8 @@ emitTypeResolver(operandTypes.front(), op.getOperand(0).name); } - body << ", allOperandLoc, result.operands))\n" - << " return ::mlir::failure();\n"; + body << ", allOperandLoc, result.operands))\n return " + "::mlir::failure();\n"; return; } @@ -1492,25 +1494,8 @@ TypeResolution &operandType = operandTypes[i]; emitTypeResolver(operandType, operand.name); - // If the type is resolved by a non-variadic variable, index into the - // resolved type list. This allows for resolving the types of a variadic - // operand list from a non-variadic variable. - bool verifyOperandAndTypeSize = true; - if (auto *resolverVar = operandType.getVariable()) { - if (!resolverVar->isVariadic() && !operandType.getVarTransformer()) { - body << "[0]"; - verifyOperandAndTypeSize = false; - } - } else { - verifyOperandAndTypeSize = !operandType.getBuilderIdx(); - } - - // Check to see if the sizes between the types and operands must match. If - // they do, provide the operand location to select the proper resolution - // overload. - if (verifyOperandAndTypeSize) - body << ", " << operand.name << "OperandsLoc"; - body << ", result.operands))\n return ::mlir::failure();\n"; + body << ", " << operand.name + << "OperandsLoc, result.operands))\n return ::mlir::failure();\n"; } } @@ -2671,11 +2656,11 @@ // Set the resolvers for each operand and result. for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) - if (!seenOperandTypes.test(i) && !op.getOperand(i).name.empty()) + if (!seenOperandTypes.test(i)) variableTyResolver[op.getOperand(i).name] = {resolver, llvm::None}; if (includeResults) { for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) - if (!seenResultTypes.test(i) && !op.getResultName(i).empty()) + if (!seenResultTypes.test(i)) variableTyResolver[op.getResultName(i)] = {resolver, llvm::None}; } }