diff --git a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py --- a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py @@ -41,3 +41,14 @@ elements_per_group = total_variadic_length // n_variadic start = n_preceding_simple + n_preceding_variadic * elements_per_group return start, elements_per_group + +def _get_default_loc_context(location = None): + """ + Returns the context in which a defaulted location is created: the context + of `location` if it is given, otherwise that of the current location on the + stack. Raises ValueError if `location` is None and the stack is empty. + """ + if location is None: + # Location.current raises ValueError if there is no current location. + return _cext.ir.Location.current.context + return location.context diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -31,7 +31,7 @@ // CHECK: if variadic2 is not None: operands.append(variadic2) // CHECK: operand_segment_sizes.append(0 if variadic2 is None else 1) // CHECK: attributes["operand_segment_sizes"] = _ir.DenseElementsAttr.get(operand_segment_sizes, - // CHECK: context=Location.current.context if loc is None else loc.context) + // CHECK: context=_get_default_loc_context(loc)) // CHECK: super().__init__(_ir.Operation.create( // CHECK: "test.attr_sized_operands", attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) @@ -77,7 +77,7 @@ // CHECK: if variadic2 is not None: results.append(variadic2) // CHECK: result_segment_sizes.append(0 if variadic2 is None else 1) // CHECK: attributes["result_segment_sizes"] = _ir.DenseElementsAttr.get(result_segment_sizes, - // CHECK: context=Location.current.context if loc is None else loc.context) + // CHECK: context=_get_default_loc_context(loc)) // CHECK: 
super().__init__(_ir.Operation.create( // CHECK: "test.attr_sized_results", attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) @@ -118,7 +118,7 @@ // CHECK: attributes["i32attr"] = i32attr // CHECK: if optionalF32Attr is not None: attributes["optionalF32Attr"] = optionalF32Attr // CHECK: if bool(unitAttr): attributes["unitAttr"] = _ir.UnitAttr.get( - // CHECK: _ir.Location.current.context if loc is None else loc.context) + // CHECK: _get_default_loc_context(loc)) // CHECK: attributes["in"] = in_ // CHECK: super().__init__(_ir.Operation.create( // CHECK: "test.attributed_op", attributes=attributes, operands=operands, results=results, @@ -156,7 +156,7 @@ // CHECK: operands.append(_gen_arg_0) // CHECK: operands.append(_gen_arg_2) // CHECK: if bool(in_): attributes["in"] = _ir.UnitAttr.get( - // CHECK: _ir.Location.current.context if loc is None else loc.context) + // CHECK: _get_default_loc_context(loc)) // CHECK: if is_ is not None: attributes["is"] = is_ // CHECK: super().__init__(_ir.Operation.create( // CHECK: "test.attributed_op_with_operands", attributes=attributes, operands=operands, results=results, diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -28,7 +28,7 @@ import array from . import _cext -from . import _segmented_accessor, _equally_sized_accessor +from . import _segmented_accessor, _equally_sized_accessor, _get_default_loc_context _ir = _cext.ir )Py"; @@ -410,7 +410,7 @@ /// Template for attaching segment sizes to the attribute list. constexpr const char *segmentAttributeTemplate = R"Py(attributes["{0}_segment_sizes"] = _ir.DenseElementsAttr.get({0}_segment_sizes, - context=Location.current.context if loc is None else loc.context))Py"; + context=_get_default_loc_context(loc)))Py"; /// Template for appending the unit size to the segment sizes. 
/// {0} is either 'operand' or 'result'; @@ -443,7 +443,7 @@ constexpr const char *initUnitAttributeTemplate = R"Py(if bool({1}): attributes["{0}"] = _ir.UnitAttr.get( - _ir.Location.current.context if loc is None else loc.context))Py"; + _get_default_loc_context(loc)))Py"; /// Populates `builderArgs` with the Python-compatible names of builder function /// arguments, first the results, then the intermixed attributes and operands in