diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -71,6 +71,9 @@ pybind11::object pyClass, pybind11::object rawOpViewClass); + /// Returns the namespace object holding custom attribute builders. + pybind11::object attributeBuilderMap(); + /// Looks up a registered dialect class by namespace. Note that this may /// trigger loading of the defining module and can arbitrarily re-enter. llvm::Optional<pybind11::object> @@ -92,6 +95,8 @@ /// Map of operation name to custom subclass that directly initializes /// the OpView base class (bypassing the user class constructor). llvm::StringMap<pybind11::object> rawOpViewClassMap; + /// Namespace object (types.SimpleNamespace) for custom attribute builders. + pybind11::object attributeBuilder; /// Set of dialect namespaces that we have attempted to import implementation /// modules for. diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -194,6 +194,18 @@ } }; +struct PyAttrBuilder { + static py::object get(const py::object &) { + return PyGlobals::get().attributeBuilderMap(); + } + + static void bind(py::module &m) { + py::class_<PyAttrBuilder>(m, "AttrBuilder", py::module_local()) + .def_property_readonly_static("__get__", &PyAttrBuilder::get, + "User-friendly attribute builders"); + } +}; + //------------------------------------------------------------------------------ // Collections. //------------------------------------------------------------------------------ @@ -3223,4 +3235,7 @@ // Debug bindings. PyGlobalDebugFlag::bind(m); + + // Attribute builder getter.
+ PyAttrBuilder::bind(m); } diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -30,6 +30,9 @@ // The default search path include {mlir.}dialects, where {mlir.} is the // package prefix configured at compile time. dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects")); + py::object simpleNamespace = + py::module_::import("types").attr("SimpleNamespace"); + attributeBuilder = simpleNamespace(); } PyGlobals::~PyGlobals() { instance = nullptr; } @@ -84,6 +87,8 @@ rawOpViewClassMap[operationName] = std::move(rawOpViewClass); } +py::object PyGlobals::attributeBuilderMap() { return attributeBuilder; } + llvm::Optional<py::object> PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { loadDialectModule(dialectNamespace); diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -202,7 +202,7 @@ // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op" def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> { - // CHECK: def __init__(self, type, *, loc=None, ip=None): + // CHECK: def __init__(self, type_, *, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: _ods_result_type_source_attr = attributes["type"] @@ -217,7 +217,7 @@ // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op" def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> { - // CHECK: def __init__(self, res, _gen_res_1, type, *, loc=None, ip=None): + // CHECK: def __init__(self, res, _gen_res_1, type_, *, loc=None, ip=None): let arguments = (ins TypeAttr:$type); let results = (outs AnyType:$res, Variadic<AnyType>); } diff --git a/mlir/test/python/dialects/shape.py
b/mlir/test/python/dialects/shape.py --- a/mlir/test/python/dialects/shape.py +++ b/mlir/test/python/dialects/shape.py @@ -6,6 +6,17 @@ import mlir.dialects.shape as shape +# FIXME: These builders belong in a common location; defined here for convenience. +AttrBuilder.BoolAttr = lambda x: BoolAttr.get(x) +AttrBuilder.IndexAttr = lambda x: IntegerAttr.get(IndexType.get(), x) +AttrBuilder.I32Attr = lambda x: IntegerAttr.get(IntegerType.get_signless(32), x) +AttrBuilder.I64Attr = lambda x: IntegerAttr.get(IntegerType.get_signless(64), x) +AttrBuilder.SymbolNameAttr = lambda x: StringAttr.get(x) + +AttrBuilder.IndexElementsAttr = lambda x: DenseElementsAttr.get( + np.array(x, dtype=np.int64), type=IndexType.get()) + + def run(f): print("\nTEST:", f.__name__) f() @@ -22,9 +33,18 @@ @func.FuncOp.from_py_func( RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32)) def const_shape_tensor(arg): + shape.ConstWitnessOp(False) + shape.ConstSizeOp(30) + shape.ConstSizeOp(IntegerAttr.get(IndexType.get(), 40)) + shape.ConstShapeOp([1, 2]) return shape.ConstShapeOp( - DenseElementsAttr.get(np.array([10, 20], dtype=np.int64), type=IndexType.get())) + DenseElementsAttr.get( + np.array([3, 4], dtype=np.int64), type=IndexType.get())) # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>) - # CHECK: shape.const_shape [10, 20] : tensor<2xindex> + # CHECK-DAG: shape.const_witness false + # CHECK-DAG: shape.const_size 30 + # CHECK-DAG: shape.const_size 40 + # CHECK-DAG: shape.const_shape [1, 2] : tensor<2xindex> + # CHECK-DAG: shape.const_shape [3, 4] : tensor<2xindex> print(module) diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -280,15 +280,16 @@ using AttributeClasses = DenseMap<StringRef, StringRef>; -/// Checks whether `str` is a Python keyword.
-static bool isPythonKeyword(StringRef str) { - static llvm::StringSet<> keywords( - {"and", "as", "assert", "break", "class", "continue", - "def", "del", "elif", "else", "except", "finally", - "for", "from", "global", "if", "import", "in", - "is", "lambda", "nonlocal", "not", "or", "pass", - "raise", "return", "try", "while", "with", "yield"}); - return keywords.contains(str); +/// Checks whether `str` is a Python keyword or would shadow a builtin that the generated builder code calls (callable, getattr, issubclass, type). +static bool isPythonReserved(StringRef str) { + static llvm::StringSet<> reserved( + {"and", "as", "assert", "async", "await", "break", "callable", + "class", "continue", "def", "del", "elif", "else", "except", + "False", "finally", "for", "from", "getattr", "global", "if", + "import", "in", "is", "issubclass", "lambda", "None", "nonlocal", + "not", "or", "pass", "raise", "return", "True", "try", + "type", "while", "with", "yield"}); + return reserved.contains(str); } /// Checks whether `str` would shadow a generated variable or attribute @@ -306,7 +307,7 @@ /// (does not change the `name` if it already is suitable) and returns the /// modified version. static std::string sanitizeName(StringRef name) { - if (isPythonKeyword(name) || isODSReserved(name)) + if (isPythonReserved(name) || isODSReserved(name)) return (name + "_").str(); return name.str(); } @@ -536,11 +537,15 @@ /// {1} is the builder argument name. constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py"; -/// Template for setting an optional attribute in the operation builder. -/// {0} is the attribute name; -/// {1} is the builder argument name. -constexpr const char *initOptionalAttributeTemplate = - R"Py(if {1} is not None: attributes["{0}"] = {1})Py"; +/// Template for building an attribute from raw input in the operation builder. +/// {0} is the builder argument name; +/// {1} is the ODS attribute def name, looked up as a builder on AttrBuilder. +/// The argument is used as-is if it already is an Attribute or if no callable +/// builder is registered under {1}; otherwise that builder is applied to it.
+// TODO: After the Attribute check, should we always try AttrBuilder? +// Should a distinct error be raised if neither condition holds? +constexpr const char *initAttributeWithBuilderTemplate = + R"Py({0} if (issubclass(type({0}), _ods_ir.Attribute) or not callable(getattr(_ods_ir.AttrBuilder, '{1}', None))) else _ods_ir.AttrBuilder.{1}({0}))Py"; constexpr const char *initUnitAttributeTemplate = R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get( @@ -669,11 +674,16 @@ continue; } - builderLines.push_back(llvm::formatv( - (attribute->attr.isOptional() || attribute->attr.hasDefaultValue()) - ? initOptionalAttributeTemplate - : initAttributeTemplate, - attribute->name, argNames[i])); + std::string builderLine; + if (attribute->attr.isOptional() || attribute->attr.hasDefaultValue()) { + builderLine = llvm::formatv(R"Py(if {0} is not None: )Py", argNames[i]); + } + std::string val = + llvm::formatv(initAttributeWithBuilderTemplate, argNames[i], + attribute->attr.getAttrDefName()); + builderLine.append( + llvm::formatv(initAttributeTemplate, attribute->name, val)); + builderLines.push_back(std::move(builderLine)); } }