diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -743,6 +743,35 @@
 dictionary = DictAttr.get({"array": array, "unit": UnitAttr.get()})
 ```
 
+Custom builders for Attributes to be used during Operation creation can be
+registered by way of the `register_attribute_builder` decorator. In particular,
+the following is how a custom builder is registered for `I32Attr`:
+
+```python
+@register_attribute_builder("I32Attr")
+def _i32Attr(x: int, context: Context):
+  return IntegerAttr.get(
+      IntegerType.get_signless(32, context=context), x)
+```
+
+This allows an op that has an `I32Attr` attribute to be created with
+
+```python
+foo.Op(30)
+```
+
+instead of
+
+```python
+foo.Op(IntegerAttr.get(IntegerType.get_signless(32, context=context), 30))
+```
+
+The registration is keyed by the ODS attribute name, while the builder itself
+is a plain Python function. Only a single custom builder may be registered per
+ODS attribute name; registering a second one raises an error. Note that several
+ODS attribute names (e.g., `I32Attr` and `I64Attr`) may map to the same
+underlying `IntegerAttr` type, and each of them needs its own builder.
+
 ## Style
 
 In general, for the core parts of MLIR, the Python bindings should be largely
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -58,6 +58,12 @@
   /// have a DIALECT_NAMESPACE attribute.
   pybind11::object registerDialectDecorator(pybind11::object pyClass);
 
+  /// Adds a user-friendly builder for the ODS attribute named `attributeKind`.
+  /// Raises an exception if the mapping already exists.
+  /// This is intended to be called by implementation code.
+  void registerAttributeBuilder(const std::string &attributeKind,
+                                pybind11::function pyFunc);
+
   /// Adds a concrete implementation dialect class.
   /// Raises an exception if the mapping already exists.
   /// This is intended to be called by implementation code.
@@ -71,6 +77,11 @@
                              pybind11::object pyClass,
                              pybind11::object rawOpViewClass);
 
+  /// Returns the custom Attribute builder registered for `attributeKind`, or
+  /// std::nullopt if none has been registered.
+  std::optional<pybind11::function>
+  lookupAttributeBuilder(const std::string &attributeKind);
+
   /// Looks up a registered dialect class by namespace. Note that this may
   /// trigger loading of the defining module and can arbitrarily re-enter.
   llvm::Optional<pybind11::object>
@@ -92,6 +103,8 @@
   /// Map of operation name to custom subclass that directly initializes
   /// the OpView base class (bypassing the user class constructor).
   llvm::StringMap<pybind11::object> rawOpViewClassMap;
+  /// Map of attribute ODS name to custom builder.
+  llvm::StringMap<pybind11::function> attributeBuilderMap;
 
   /// Set of dialect namespaces that we have attempted to import implementation
   /// modules for.
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -194,6 +194,32 @@
   }
 };
 
+/// Exposes the global Attribute builder registry to Python as the static-only
+/// `AttrBuilder` class: generated op builders call `contains`/`get`, and
+/// `register_attribute_builder` in ir.py calls `insert`.
+struct PyAttrBuilderMap {
+  static bool dunderContains(const std::string &attributeKind) {
+    return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
+  }
+  static py::function dunderGetItemNamed(const std::string &attributeKind) {
+    auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
+    if (!builder)
+      throw py::key_error(attributeKind);
+    return *builder;
+  }
+  static void dunderSetItemNamed(const std::string &attributeKind,
+                                 py::function func) {
+    PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func));
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyAttrBuilderMap>(m, "AttrBuilder", py::module_local())
+        .def_static("contains", &PyAttrBuilderMap::dunderContains)
+        .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed)
+        .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed);
+  }
+};
+
 //------------------------------------------------------------------------------
 // Collections.
 //------------------------------------------------------------------------------
@@ -3283,4 +3309,7 @@
 
   // Debug bindings.
   PyGlobalDebugFlag::bind(m);
+
+  // Attribute builder registry.
+  PyAttrBuilderMap::bind(m);
 }
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -60,6 +60,17 @@
   loadedDialectModulesCache.insert(dialectNamespace);
 }
 
+void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
+                                         py::function pyFunc) {
+  py::function &found = attributeBuilderMap[attributeKind];
+  if (found) {
+    throw std::runtime_error((llvm::Twine("Attribute builder for '") +
+                              attributeKind + "' is already registered")
+                                 .str());
+  }
+  found = std::move(pyFunc);
+}
+
 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
                                     py::object pyClass) {
   py::object &found = dialectClassMap[dialectNamespace];
@@ -84,6 +95,18 @@
   rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
 }
 
+std::optional<py::function>
+PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
+  // Misses are deliberately not cached: storing a placeholder such as
+  // py::none() would make `if (found)` in registerAttributeBuilder succeed and
+  // reject any later registration for the same kind as already registered.
+  const auto foundIt = attributeBuilderMap.find(attributeKind);
+  if (foundIt == attributeBuilderMap.end())
+    return std::nullopt;
+  assert(foundIt->second && "py::function is defined");
+  return foundIt->second;
+}
+
 llvm::Optional<py::object>
 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
   loadDialectModule(dialectNamespace);
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -4,3 +4,53 @@
 
 from ._mlir_libs._mlir.ir import *
 from ._mlir_libs._mlir.ir import _GlobalDebug
+# Underscore alias keeps `List` out of `from mlir.ir import *`.
+from typing import List as _List
+
+
+# Convenience decorator for registering user-friendly Attribute builders.
+def register_attribute_builder(kind):
+  """Registers the decorated function as the builder for ODS attribute `kind`.
+
+  Generated op builders call it as `func(value, context=ctx)` when `value` is
+  not already an Attribute. Raises if a builder for `kind` already exists.
+  """
+  def decorator_builder(func):
+    AttrBuilder.insert(kind, func)
+    return func
+  return decorator_builder
+
+
+@register_attribute_builder("BoolAttr")
+def _boolAttr(x: bool, context: Context):
+  return BoolAttr.get(x, context=context)
+
+@register_attribute_builder("IndexAttr")
+def _indexAttr(x: int, context: Context):
+  return IntegerAttr.get(IndexType.get(context=context), x)
+
+@register_attribute_builder("I32Attr")
+def _i32Attr(x: int, context: Context):
+  return IntegerAttr.get(
+      IntegerType.get_signless(32, context=context), x)
+
+@register_attribute_builder("I64Attr")
+def _i64Attr(x: int, context: Context):
+  return IntegerAttr.get(
+      IntegerType.get_signless(64, context=context), x)
+
+@register_attribute_builder("SymbolNameAttr")
+def _symbolNameAttr(x: str, context: Context):
+  return StringAttr.get(x, context=context)
+
+try:
+  import numpy as np
+  # `_List[int]` instead of `list[int]`: annotations are evaluated at definition
+  # time and subscripting the builtin `list` raises TypeError before Python 3.9.
+  @register_attribute_builder("IndexElementsAttr")
+  def _indexElementsAttr(x: _List[int], context: Context):
+    return DenseElementsAttr.get(
+        np.array(x, dtype=np.int64), type=IndexType.get(context=context),
+        context=context)
+except ImportError:
+  pass
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -115,11 +115,15 @@
   // CHECK: results = []
   // CHECK: attributes = {}
   // CHECK: regions = None
-  // CHECK: attributes["i32attr"] = i32attr
-  // CHECK: if optionalF32Attr is not None: attributes["optionalF32Attr"] = optionalF32Attr
+  // CHECK: _ods_context = _ods_get_default_loc_context(loc)
+  // CHECK: attributes["i32attr"] = (i32attr if (
+  // CHECK-NEXT: issubclass(type(i32attr), _ods_ir.Attribute) or
+  // CHECK-NEXT: not _ods_ir.AttrBuilder.contains('I32Attr')
+  // CHECK-NEXT: _ods_ir.AttrBuilder.get('I32Attr')(i32attr, context=_ods_context)
+  // CHECK: if optionalF32Attr is not None: attributes["optionalF32Attr"] = (optionalF32Attr
   // CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
   // CHECK:   _ods_get_default_loc_context(loc))
-  // CHECK: attributes["in"] = in_
+  // CHECK: attributes["in"] = (in_
   // CHECK: _ods_successors = None
   // CHECK: super().__init__(self.build_generic(
   // CHECK:   attributes=attributes, results=results, operands=operands,
@@ -161,7 +165,7 @@
   // CHECK: operands.append(_get_op_result_or_value(_gen_arg_2))
   // CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
   // CHECK:   _ods_get_default_loc_context(loc))
-  // CHECK: if is_ is not None: attributes["is"] = is_
+  // CHECK: if is_ is not None: attributes["is"] = (is_
   // CHECK: _ods_successors = None
   // CHECK: super().__init__(self.build_generic(
   // CHECK:   attributes=attributes, results=results, operands=operands,
@@ -188,8 +192,8 @@
   // CHECK: results = []
   // CHECK: attributes = {}
   // CHECK: regions = None
-  // CHECK: if arr is not None: attributes["arr"] = arr
-  // CHECK: if unsupported is not None: attributes["unsupported"] = unsupported
+  // CHECK: if arr is not None: attributes["arr"] = (arr
+  // CHECK: if unsupported is not None: attributes["unsupported"] = (unsupported
   // CHECK: _ods_successors = None
   // CHECK: super().__init__(self.build_generic(
   // CHECK:   attributes=attributes, results=results, operands=operands,
@@ -202,7 +206,7 @@
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
 def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
-  // CHECK: def __init__(self, type, *, loc=None, ip=None):
+  // CHECK: def __init__(self, type_, *, loc=None, ip=None):
   // CHECK: operands = []
   // CHECK: results = []
   // CHECK: _ods_result_type_source_attr = attributes["type"]
@@ -217,7 +221,7 @@
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
 def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
-  // CHECK: def __init__(self, res, _gen_res_1, type, *, loc=None, ip=None):
+  // CHECK: def __init__(self, res, _gen_res_1, type_, *, loc=None, ip=None):
   let arguments = (ins TypeAttr:$type);
   let results = (outs AnyType:$res, Variadic<AnyType>);
 }
diff --git a/mlir/test/python/dialects/shape.py b/mlir/test/python/dialects/shape.py
--- a/mlir/test/python/dialects/shape.py
+++ b/mlir/test/python/dialects/shape.py
@@ -22,9 +22,18 @@
       @func.FuncOp.from_py_func(
           RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32))
       def const_shape_tensor(arg):
+        shape.ConstWitnessOp(False)
+        shape.ConstSizeOp(30)
+        shape.ConstSizeOp(IntegerAttr.get(IndexType.get(), 40))
+        shape.ConstShapeOp([1, 2])
         return shape.ConstShapeOp(
-          DenseElementsAttr.get(np.array([10, 20], dtype=np.int64), type=IndexType.get()))
+            DenseElementsAttr.get(
+                np.array([3, 4], dtype=np.int64), type=IndexType.get()))
 
     # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
-    # CHECK: shape.const_shape [10, 20] : tensor<2xindex>
+    # CHECK-DAG: shape.const_witness false
+    # CHECK-DAG: shape.const_size 30
+    # CHECK-DAG: shape.const_size 40
+    # CHECK-DAG: shape.const_shape [1, 2] : tensor<2xindex>
+    # CHECK-DAG: shape.const_shape [3, 4] : tensor<2xindex>
     print(module)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -280,15 +280,16 @@
 
 using AttributeClasses = DenseMap<StringRef, StringRef>;
 
-/// Checks whether `str` is a Python keyword.
-static bool isPythonKeyword(StringRef str) {
-  static llvm::StringSet<> keywords(
-      {"and",   "as",     "assert",   "break", "class",  "continue",
-       "def",   "del",    "elif",     "else",  "except", "finally",
-       "for",   "from",   "global",   "if",    "import", "in",
-       "is",    "lambda", "nonlocal", "not",   "or",     "pass",
-       "raise", "return", "try",      "while", "with",   "yield"});
-  return keywords.contains(str);
+/// Checks whether `str` is a Python keyword or would shadow a Python builtin.
+static bool isPythonReserved(StringRef str) {
+  static llvm::StringSet<> reserved(
+      {"and",      "as",   "assert",     "break",  "callable", "class",
+       "continue", "def",  "del",        "elif",   "else",     "except",
+       "finally",  "for",  "from",       "global", "if",       "import",
+       "in",       "is",   "issubclass", "lambda", "nonlocal", "not",
+       "or",       "pass", "raise",      "return", "try",      "type",
+       "while",    "with", "yield"});
+  return reserved.contains(str);
 }
 
 /// Checks whether `str` would shadow a generated variable or attribute
@@ -306,7 +307,7 @@
 /// (does not change the `name` if it already is suitable) and returns the
 /// modified version.
 static std::string sanitizeName(StringRef name) {
-  if (isPythonKeyword(name) || isODSReserved(name))
+  if (isPythonReserved(name) || isODSReserved(name))
     return (name + "_").str();
   return name.str();
 }
@@ -531,16 +532,30 @@
     "operands.append(_get_op_results_or_values({0}))";
 constexpr const char *multiResultAppendTemplate = "results.extend({0})";
 
-/// Template for setting an attribute in the operation builder.
-/// {0} is the attribute name;
-/// {1} is the builder argument name.
-constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";
-
-/// Template for setting an optional attribute in the operation builder.
-/// {0} is the attribute name;
-/// {1} is the builder argument name.
-constexpr const char *initOptionalAttributeTemplate =
-    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";
+/// Template for setting an attribute in the operation builder from raw input.
+/// {0} is the builder argument name;
+/// {1} is the attribute name;
+/// {2} is the ODS attribute definition name (the AttrBuilder registry key).
+/// Use the value the user passed in if either it is already an Attribute or
+/// there is no method registered to make it an Attribute.
+constexpr const char *initAttributeWithBuilderTemplate =
+    R"Py(attributes["{1}"] = ({0} if (
+        issubclass(type({0}), _ods_ir.Attribute) or
+        not _ods_ir.AttrBuilder.contains('{2}')) else
+          _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
+
+/// Template for setting an optional attribute (skipped when the argument is
+/// None) in the operation builder from raw input.
+/// {0} is the builder argument name;
+/// {1} is the attribute name;
+/// {2} is the ODS attribute definition name (the AttrBuilder registry key).
+/// Use the value the user passed in if either it is already an Attribute or
+/// there is no method registered to make it an Attribute.
+constexpr const char *initOptionalAttributeWithBuilderTemplate =
+    R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
+        issubclass(type({0}), _ods_ir.Attribute) or
+        not _ods_ir.AttrBuilder.contains('{2}')) else
+          _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
 
 constexpr const char *initUnitAttributeTemplate =
     R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
@@ -656,6 +671,8 @@
 populateBuilderLinesAttr(const Operator &op,
                          llvm::ArrayRef<std::string> argNames,
                          llvm::SmallVectorImpl<std::string> &builderLines) {
+  // `_ods_context` is used by the attribute templates and by type inference.
+  builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
     Argument arg = op.getArg(i);
     auto *attribute = arg.dyn_cast<NamedAttribute *>();
@@ -670,10 +687,10 @@
     }
 
     builderLines.push_back(llvm::formatv(
-        (attribute->attr.isOptional() || attribute->attr.hasDefaultValue())
-            ? initOptionalAttributeTemplate
-            : initAttributeTemplate,
-        attribute->name, argNames[i]));
+        (attribute->attr.isOptional() || attribute->attr.hasDefaultValue())
+            ? initOptionalAttributeWithBuilderTemplate
+            : initAttributeWithBuilderTemplate,
+        argNames[i], attribute->name, attribute->attr.getAttrDefName()));
   }
 }
 
@@ -753,8 +770,7 @@
 /// corresponding interface:
 /// - {0} is the name of the class for which the types are inferred.
 constexpr const char *inferTypeInterfaceTemplate =
-    R"PY(_ods_context = _ods_get_default_loc_context(loc)
-results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
+    R"PY(results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
     operands=operands,
     attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
     context=_ods_context,