diff --git a/mlir/test/Bindings/Python/dialects.py b/mlir/test/Bindings/Python/dialects.py --- a/mlir/test/Bindings/Python/dialects.py +++ b/mlir/test/Bindings/Python/dialects.py @@ -82,11 +82,11 @@ # Create via dialects context collection. input1 = createInput() input2 = createInput() - op1 = ctx.dialects.std.AddFOp(input1.type, input1, input2) + op1 = ctx.dialects.std.AddFOp(input1, input2, result=input1.type) # Create via an import from mlir.dialects.std import AddFOp - AddFOp(input1.type, input1, op1.result) + AddFOp(input1, op1.result, input1.type) # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput" # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput" diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -20,7 +20,7 @@ // CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,-1,] def AttrSizedOperandsOp : TestOp<"attr_sized_operands", [AttrSizedOperandSegments]> { - // CHECK: def __init__(self, variadic1, non_variadic, variadic2, loc=None, ip=None): + // CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} @@ -61,7 +61,7 @@ // CHECK: _ODS_RESULT_SEGMENTS = [-1,1,-1,] def AttrSizedResultsOp : TestOp<"attr_sized_results", [AttrSizedResultSegments]> { - // CHECK: def __init__(self, variadic1, non_variadic, variadic2, loc=None, ip=None): + // CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} @@ -103,7 +103,7 @@ // CHECK-NOT: _ODS_OPERAND_SEGMENTS // CHECK-NOT: _ODS_RESULT_SEGMENTS def AttributedOp : TestOp<"attributed_op"> { - // CHECK: def __init__(self, i32attr, optionalF32Attr, unitAttr, in_, loc=None, ip=None): + // CHECK: def __init__(self, i32attr, optionalF32Attr, unitAttr, in_, *, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} @@ -143,7 +143,7 @@ // CHECK-NOT: _ODS_OPERAND_SEGMENTS // CHECK-NOT: _ODS_RESULT_SEGMENTS def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> { - // CHECK: def __init__(self, _gen_arg_0, in_, _gen_arg_2, is_, loc=None, ip=None): + // CHECK: def __init__(self, _gen_arg_0, in_, _gen_arg_2, is_, *, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} @@ -173,7 +173,7 @@ // CHECK: class EmptyOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.empty" def EmptyOp : TestOp<"empty">; - // CHECK: def __init__(self, loc=None, ip=None): + // CHECK: def __init__(self, *, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} @@ -185,16 +185,16 @@ // CHECK: class MissingNamesOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.missing_names" def MissingNamesOp : TestOp<"missing_names"> { - // CHECK: def __init__(self, i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, loc=None, ip=None): + // CHECK: def __init__(self, _gen_arg_0, f32, _gen_arg_2, i32, _gen_res_1, i64, *, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} - // CHECK: results.append(i32) - // CHECK: results.append(_gen_res_1) - // CHECK: results.append(i64) // CHECK: operands.append(_gen_arg_0) // CHECK: operands.append(f32) // CHECK: operands.append(_gen_arg_2) + // CHECK: results.append(i32) + // CHECK: results.append(_gen_res_1) + // CHECK: results.append(i64) // CHECK: super().__init__(self._ods_build_default( // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) @@ -220,7 +220,7 @@ // CHECK-NOT: _ODS_OPERAND_SEGMENTS // CHECK-NOT: _ODS_RESULT_SEGMENTS def OneVariadicOperandOp : TestOp<"one_variadic_operand"> { - // CHECK: def __init__(self, non_variadic, variadic, loc=None, ip=None): + // CHECK: def __init__(self, non_variadic, variadic, *, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} @@ -247,7 +247,7 @@ // CHECK-NOT: _ODS_OPERAND_SEGMENTS // CHECK-NOT: _ODS_RESULT_SEGMENTS def OneVariadicResultOp : TestOp<"one_variadic_result"> { - // CHECK: def __init__(self, variadic, non_variadic, loc=None, ip=None): + // CHECK: def __init__(self, variadic, non_variadic, *, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} @@ -273,7 +273,7 @@ // CHECK: class PythonKeywordOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.python_keyword" def PythonKeywordOp : TestOp<"python_keyword"> { - // CHECK: def __init__(self, in_, loc=None, ip=None): + // CHECK: def __init__(self, in_, *, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} @@ -338,14 +338,14 @@ // CHECK: class SimpleOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.simple" def SimpleOp : TestOp<"simple"> { - // CHECK: def __init__(self, i64, f64, i32, f32, loc=None, ip=None): + // CHECK: def __init__(self, i32, f32, i64, f64, *, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} - // CHECK: results.append(i64) - // CHECK: results.append(f64) // CHECK: operands.append(i32) // CHECK: operands.append(f32) + // CHECK: results.append(i64) + // CHECK: results.append(f64) // CHECK: super().__init__(self._ods_build_default( // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) @@ -368,3 +368,11 @@ // CHECK: return self.operation.results[1] let results = (outs I64:$i64, F64:$f64); } + +// CHECK: @_ods_cext.register_operation(_Dialect) +// CHECK: class UnnamedUnaryResultOp(_ods_ir.OpView): +// CHECK-LABEL: OPERATION_NAME = "test.unnamed_unary_result" +def UnnamedUnaryResultOp : TestOp<"unnamed_unary_result"> { + // CHECK: def __init__(self, result, *, loc=None, ip=None): + let results = (outs I32); +} diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -277,7 +277,7 @@ static bool isODSReserved(StringRef str) { static llvm::StringSet<> reserved( {"attributes", "create", "context", "ip", "operands", "print", "get_asm", - "loc", "verify", "regions", "result", "results", "self", "operation", + "loc", "verify", "regions", "results", "self", "operation", "DIALECT_NAMESPACE", "OPERATION_NAME"}); return str.startswith("_ods_") || str.endswith("_ods") || reserved.contains(str); @@ -525,14 +525,8 @@ static void populateBuilderArgs(const Operator &op, llvm::SmallVectorImpl &builderArgs, - llvm::SmallVectorImpl &operandNames) { - for (int i = 0, e = op.getNumResults(); i < e; ++i) { - std::string name = op.getResultName(i).str(); - if (name.empty()) - name = llvm::formatv("_gen_res_{0}", i); - name = sanitizeName(name); - builderArgs.push_back(name); - } + llvm::SmallVectorImpl &operandNames, + llvm::SmallVectorImpl &resultNames) { for (int i = 0, e = op.getNumArgs(); i < e; ++i) { std::string name = op.getArgName(i).str(); if (name.empty()) @@ -542,6 +536,21 @@ if (!op.getArg(i).is()) operandNames.push_back(name); } + for (int i = 0, e = op.getNumResults(); i < e; ++i) { + std::string name = op.getResultName(i).str(); + if (name.empty()) { + if (op.getNumResults() == 1) { + // Special case for one result, make the default name be 'result' + // to properly match the built-in result accessor. + name = "result"; + } else { + name = llvm::formatv("_gen_res_{0}", i); + } + } + name = sanitizeName(name); + builderArgs.push_back(name); + resultNames.push_back(name); + } } /// Populates `builderLines` with additional lines that are required in the @@ -624,19 +633,19 @@ llvm::SmallVector builderArgs; llvm::SmallVector builderLines; llvm::SmallVector operandArgNames; + llvm::SmallVector resultNames; builderArgs.reserve(op.getNumOperands() + op.getNumResults() + op.getNumNativeAttributes()); - populateBuilderArgs(op, builderArgs, operandArgNames); - populateBuilderLines( - op, "result", - llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()), - builderLines, getNumResults, getResult); + populateBuilderArgs(op, builderArgs, operandArgNames, resultNames); populateBuilderLines(op, "operand", operandArgNames, builderLines, getNumOperands, getOperand); + populateBuilderLines(op, "result", resultNames, builderLines, getNumResults, + getResult); populateBuilderLinesAttr( - op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()), + op, llvm::makeArrayRef(builderArgs).drop_back(resultNames.size()), builderLines); + builderArgs.push_back("*"); builderArgs.push_back("loc=None"); builderArgs.push_back("ip=None"); os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "),