diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -24,13 +24,14 @@ // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} + // CHECK: regions = None // CHECK: operands.append(_get_op_results_or_values(variadic1)) // CHECK: operands.append(_get_op_result_or_value(non_variadic)) // CHECK: if variadic2 is not None: operands.append(_get_op_result_or_value(variadic2)) // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, - // CHECK: successors=_ods_successors, loc=loc, ip=ip)) + // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) // CHECK: @builtins.property // CHECK: def variadic1(self): @@ -66,13 +67,14 @@ // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} + // CHECK: regions = None // CHECK: if variadic1 is not None: results.append(variadic1) // CHECK: results.append(non_variadic) // CHECK: if variadic2 is not None: results.append(variadic2) // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, - // CHECK: successors=_ods_successors, loc=loc, ip=ip)) + // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) // CHECK: @builtins.property // CHECK: def variadic1(self): @@ -109,6 +111,7 @@ // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} + // CHECK: regions = None // CHECK: attributes["i32attr"] = i32attr // CHECK: if optionalF32Attr is not None: attributes["optionalF32Attr"] = optionalF32Attr // CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get( @@ -117,7 +120,7 @@ // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, 
- // CHECK: successors=_ods_successors, loc=loc, ip=ip)) + // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) // CHECK: @builtins.property // CHECK: def i32attr(self): @@ -150,6 +153,7 @@ // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} + // CHECK: regions = None // CHECK: operands.append(_get_op_result_or_value(_gen_arg_0)) // CHECK: operands.append(_get_op_result_or_value(_gen_arg_2)) // CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get( @@ -158,7 +162,7 @@ // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, - // CHECK: successors=_ods_successors, loc=loc, ip=ip)) + // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) // CHECK: @builtins.property // CHECK: def in_(self): @@ -181,10 +185,11 @@ // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} + // CHECK: regions = None // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, - // CHECK: successors=_ods_successors, loc=loc, ip=ip)) + // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class MissingNamesOp(_ods_ir.OpView): @@ -194,6 +199,7 @@ // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} + // CHECK: regions = None // CHECK: results.append(i32) // CHECK: results.append(_gen_res_1) // CHECK: results.append(i64) @@ -203,7 +209,7 @@ // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, - // CHECK: successors=_ods_successors, loc=loc, ip=ip)) + // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) // CHECK: @builtins.property // CHECK: def f32(self): @@ -230,12 +236,13 @@ // CHECK: operands = [] // CHECK: results = [] // 
CHECK: attributes = {} + // CHECK: regions = None // CHECK: operands.append(_get_op_result_or_value(non_variadic)) // CHECK: operands.extend(_get_op_results_or_values(variadic)) // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, - // CHECK: successors=_ods_successors, loc=loc, ip=ip)) + // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) // CHECK: @builtins.property // CHECK: def non_variadic(self): @@ -258,12 +265,13 @@ // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} + // CHECK: regions = None // CHECK: results.extend(variadic) // CHECK: results.append(non_variadic) // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, - // CHECK: successors=_ods_successors, loc=loc, ip=ip)) + // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) // CHECK: @builtins.property // CHECK: def variadic(self): @@ -285,11 +293,12 @@ // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} + // CHECK: regions = None // CHECK: operands.append(_get_op_result_or_value(in_)) // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, - // CHECK: successors=_ods_successors, loc=loc, ip=ip)) + // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) // CHECK: @builtins.property // CHECK: def in_(self): @@ -351,6 +360,7 @@ // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} + // CHECK: regions = None // CHECK: results.append(i64) // CHECK: results.append(f64) // CHECK: operands.append(_get_op_result_or_value(i32)) @@ -358,7 +368,7 @@ // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, - // CHECK: 
successors=_ods_successors, loc=loc, ip=ip)) + // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) // CHECK: @builtins.property // CHECK: def i32(self): @@ -379,6 +389,50 @@ let results = (outs I64:$i64, F64:$f64); } +// CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView): +// CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region" +def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> { + // CHECK: def __init__(self, num_variadic, *, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: attributes = {} + // CHECK: regions = None + // CHECK: _ods_successors = None + // CHECK: regions = 2 + num_variadic + // CHECK: super().__init__(self.build_generic( + // CHECK: attributes=attributes, results=results, operands=operands, + // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) + let regions = (region AnyRegion:$region, AnyRegion, VariadicRegion:$variadic); + + // CHECK: @builtins.property + // CHECK: def region(self): + // CHECK: return self.regions[0] + + // CHECK: @builtins.property + // CHECK: def variadic(self): + // CHECK: return self.regions[2:] +} + +// CHECK: class VariadicRegionOp(_ods_ir.OpView): +// CHECK-LABEL: OPERATION_NAME = "test.variadic_region" +def VariadicRegionOp : TestOp<"variadic_region"> { + // CHECK: def __init__(self, num_variadic, *, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: attributes = {} + // CHECK: regions = None + // CHECK: _ods_successors = None + // CHECK: regions = 0 + num_variadic + // CHECK: super().__init__(self.build_generic( + // CHECK: attributes=attributes, results=results, operands=operands, + // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) + let regions = (region VariadicRegion:$Variadic); + + // CHECK: @builtins.property + // CHECK: def Variadic(self): + // CHECK: return self.regions[0:] +} + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class
WithSuccessorsOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.with_successors" @@ -390,3 +444,4 @@ let successors = (successor AnySuccessor:$successor, VariadicSuccessor:$successors); } + diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -252,6 +252,12 @@ del self.operation.attributes["{1}"] )Py"; +constexpr const char *regionAccessorTemplate = R"PY( + @builtins.property + def {0}(self): + return self.regions[{1}] +)PY"; + static llvm::cl::OptionCategory clOpPythonBindingCat("Options for -gen-python-op-bindings"); @@ -482,10 +488,11 @@ operands = [] results = [] attributes = {{} + regions = None {1} super().__init__(self.build_generic( attributes=attributes, results=results, operands=operands, - successors=_ods_successors, loc=loc, ip=ip)) + successors=_ods_successors, regions=regions, loc=loc, ip=ip)) )Py"; /// Template for appending a single element to the operand/result list. @@ -697,6 +704,30 @@ } } +/// If the operation has a (trailing) variadic region, adds a `num_<name>` builder +/// argument giving the number of such regions, and a builder line that forwards +/// the total region count to the generic constructor. +static void +populateBuilderRegions(const Operator &op, + llvm::SmallVectorImpl<std::string> &builderArgs, + llvm::SmallVectorImpl<std::string> &builderLines) { + if (op.hasNoVariadicRegions()) + return; + + // This is currently enforced when Operator is constructed.
+ assert(op.getNumVariadicRegions() == 1 && + op.getRegion(op.getNumRegions() - 1).isVariadic() && + "expected the last region to be variadic"); + + const NamedRegion &region = op.getRegion(op.getNumRegions() - 1); + std::string name = + ("num_" + region.name.take_front().lower() + region.name.drop_front()) + .str(); + builderArgs.push_back(name); + builderLines.push_back( + llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name)); +} + /// Emits a default builder constructing an operation from the list of its /// result types, followed by a list of its operands. static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { @@ -720,6 +751,7 @@ op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()), builderLines); populateBuilderLinesSuccessors(op, successorArgNames, builderLines); + populateBuilderRegions(op, builderArgs, builderLines); builderArgs.push_back("*"); builderArgs.push_back("loc=None"); @@ -767,6 +799,21 @@ op.hasNoVariadicRegions() ? "True" : "False"); } +/// Emits named property accessors for regions; a variadic region yields a slice. +static void emitRegionAccessors(const Operator &op, raw_ostream &os) { + for (auto en : llvm::enumerate(op.getRegions())) { + const NamedRegion &region = en.value(); + if (region.name.empty()) + continue; + + assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) && + "expected only the last region to be variadic"); + os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name), + std::to_string(en.index()) + + (region.isVariadic() ? ":" : "")); + } +} + /// Emits bindings for a specific Op to the given output stream. static void emitOpBindings(const Operator &op, const AttributeClasses &attributeClasses, @@ -787,6 +834,7 @@ emitOperandAccessors(op, os); emitAttributeAccessors(op, attributeClasses, os); emitResultAccessors(op, os); + emitRegionAccessors(op, os); } /// Emits bindings for the dialect specified in the command line, including file