diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -179,6 +179,27 @@ let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr:$is); } +// CHECK: @_ods_cext.register_operation(_Dialect) +// CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView): +// CHECK-LABEL: OPERATION_NAME = "test.default_valued_attrs" +def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> { + // CHECK: def __init__(self, *, arr=None, unsupported=None, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: attributes = {} + // CHECK: regions = None + // CHECK: attributes["arr"] = arr if arr is not None else _ods_ir.ArrayAttr.get([]) + // CHECK: unsupported is not None, "attribute unsupported must be specified" + // CHECK: _ods_successors = None + // CHECK: super().__init__(self.build_generic( + // CHECK: attributes=attributes, results=results, operands=operands, + // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) + + let arguments = (ins DefaultValuedAttr:$arr, + DefaultValuedAttr:$unsupported); + let results = (outs); +} + // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op" def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> { // CHECK: def __init__(self, type, *, loc=None, ip=None): @@ -544,4 +565,3 @@ let successors = (successor AnySuccessor:$successor, VariadicSuccessor:$successors); } - diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Support/LogicalResult.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" #include "llvm/ADT/StringSet.h" @@ 
-542,6 +543,21 @@ constexpr const char *initOptionalAttributeTemplate = R"Py(if {1} is not None: attributes["{0}"] = {1})Py"; +/// Template for setting an attribute with a default value in the operation +/// builder. +/// {0} is the attribute name; +/// {1} is the builder argument name; +/// {2} is the default value. +constexpr const char *initDefaultValuedAttributeTemplate = + R"Py(attributes["{0}"] = {1} if {1} is not None else {2})Py"; + +/// Template for asserting that an attribute value was provided when calling a +/// builder. Note: Python drops `assert` statements under -O. +/// {0} is the attribute name; +/// {1} is the builder argument name. +constexpr const char *assertAttributeValueSpecified = + R"Py(assert {1} is not None, "attribute {0} must be specified")Py"; + constexpr const char *initUnitAttributeTemplate = R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get( _ods_get_default_loc_context(loc)))Py"; @@ -647,6 +663,21 @@ } } +/// Returns Python code for `attr`'s default value, or failure if unsupported. +static FailureOr getAttributeDefaultValue(Attribute attr) { + assert(attr.hasDefaultValue() && "expected attribute with default value"); + StringRef storageType = attr.getStorageType().trim(); + StringRef defaultValCpp = attr.getDefaultValue().trim(); + + // A list of commonly used attribute types and default values for which + // we can generate Python code. Extend as needed. + if (storageType.equals("::mlir::ArrayAttr") && defaultValCpp.equals("{}")) + return std::string("_ods_ir.ArrayAttr.get([])"); + + // No Python equivalent known; the builder must require an explicit value. + return failure(); +} + /// Populates `builderLines` with additional lines that are required in the /// builder to set up operation attributes. `argNames` is expected to contain /// the names of builder arguments that correspond to op arguments, i.e. to the @@ -669,6 +700,25 @@ continue; } + // Attributes with default value are handled specially.
+ if (attribute->attr.hasDefaultValue()) { + // In case we cannot generate Python code for the default value, the + // attribute must be specified by the user. + FailureOr defaultValPy = + getAttributeDefaultValue(attribute->attr); + if (succeeded(defaultValPy)) { + builderLines.push_back(llvm::formatv(initDefaultValuedAttributeTemplate, + attribute->name, argNames[i], + *defaultValPy)); + } else { + builderLines.push_back(llvm::formatv(assertAttributeValueSpecified, + attribute->name, argNames[i])); + builderLines.push_back( + llvm::formatv(initAttributeTemplate, attribute->name, argNames[i])); + } + continue; + } + builderLines.push_back(llvm::formatv(attribute->attr.isOptional() ? initOptionalAttributeTemplate : initAttributeTemplate,