diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -179,6 +179,25 @@
   let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr<F32Attr>:$is);
 }
 
+// CHECK: @_ods_cext.register_operation(_Dialect)
+// CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.default_valued_attrs"
+def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
+  // CHECK: def __init__(self, *, arr=None, loc=None, ip=None):
+  // CHECK:   operands = []
+  // CHECK:   results = []
+  // CHECK:   attributes = {}
+  // CHECK:   regions = None
+  // CHECK:   attributes["arr"] = arr if arr is not None else _ods_ir.ArrayAttr.get([])
+  // CHECK:   _ods_successors = None
+  // CHECK:   super().__init__(self.build_generic(
+  // CHECK:     attributes=attributes, results=results, operands=operands,
+  // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
+
+  let arguments = (ins DefaultValuedAttr<I64ArrayAttr, "{}">:$arr);
+  let results = (outs);
+}
+
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
 def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
   // CHECK: def __init__(self, type, *, loc=None, ip=None):
@@ -544,4 +563,3 @@
 
   let successors = (successor AnySuccessor:$successor, VariadicSuccessor<AnySuccessor>:$successors);
 }
-
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -542,6 +542,14 @@
 constexpr const char *initOptionalAttributeTemplate =
     R"Py(if {1} is not None: attributes["{0}"] = {1})Py";
 
+/// Template for setting an attribute with a default value in the operation
+/// builder. Uses an identity check against `None`, like the optional-attribute
+/// template above.
+///   {0} is the attribute name;
+///   {1} is the builder argument name;
+///   {2} is the Python expression producing the default value.
+constexpr const char *initDefaultValuedAttributeTemplate =
+    R"Py(attributes["{0}"] = {1} if {1} is not None else {2})Py";
 constexpr const char *initUnitAttributeTemplate =
     R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
       _ods_get_default_loc_context(loc)))Py";
@@ -669,6 +677,25 @@
       continue;
     }
 
+    // Attributes with a default value are handled specially.
+    if (attribute->attr.hasDefaultValue()) {
+      // TODO: When no Python expression is known for the C++ default value,
+      // the generated builder assigns `None` to the attribute if the user
+      // omits it, which fails at runtime; such attributes must be passed
+      // explicitly.
+      std::string defaultValPy = "None";
+      llvm::StringRef storageType = attribute->attr.getStorageType().trim();
+      llvm::StringRef defaultValCpp = attribute->attr.getDefaultValue().trim();
+      // Commonly used (storage type, C++ default) pairs for which an
+      // equivalent Python expression can be generated. Extend as needed.
+      if (storageType == "::mlir::ArrayAttr" && defaultValCpp == "{}")
+        defaultValPy = "_ods_ir.ArrayAttr.get([])";
+      builderLines.push_back(llvm::formatv(initDefaultValuedAttributeTemplate,
+                                           attribute->name, argNames[i],
+                                           defaultValPy));
+      continue;
+    }
+
     builderLines.push_back(llvm::formatv(attribute->attr.isOptional()
                                              ? initOptionalAttributeTemplate
                                              : initAttributeTemplate,