Review notes for this revision of the SymbolRefAttr Python bindings patch:
- FlatSymbolRefAttr.get builds a flat reference directly again instead of
  delegating to SymbolRefAttr's "::"-splitting factory. That delegation could
  return a nested SymbolRefAttr (e.g. for "a::b") from the FlatSymbolRefAttr
  factory.
- The "::"-separated string form of SymbolRefAttr.get rejects empty
  components ("a::", "a::::b") instead of silently dropping a trailing one.
- Leading-'@' stripping is shared through PySymbolRefAttribute::stripSigil.
- Regression tests cover both behaviors.

NOTE(review): this patch was recovered from a whitespace-collapsed copy in
which all text inside angle brackets was lost. Template arguments on lines
this patch adds have been rewritten. The targets of the `#include <...>` lines
in the first hunk could not be recovered and must be restored from the
original change. Indentation and blank lines of context lines were
reconstructed and should be checked against the tree before applying.

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -7,7 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include <...>
+#include <...>
+#include <...>
+#include <...>
 #include <...>
+#include <...>
 
 #include "IRModule.h"
 
@@ -15,7 +19,10 @@
 
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
+#include "mlir-c/Support.h"
 #include "mlir/Bindings/Python/PybindAdaptors.h"
+#include "mlir/CAPI/Support.h"
+#include "llvm/ADT/StringRef.h"
 
 namespace py = pybind11;
 using namespace mlir;
@@ -442,22 +449,117 @@
   }
 };
 
+class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
+  static constexpr const char *pyClassName = "SymbolRefAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Returns `symbol` with one leading '@' sigil removed, if present. The
+  /// result refers to the same storage as `symbol`.
+  static llvm::StringRef stripSigil(llvm::StringRef symbol) {
+    symbol.consume_front("@");
+    return symbol;
+  }
+
+  /// Builds a SymbolRefAttr whose root reference is `symbols[0]` and whose
+  /// nested references are the remaining elements, in order. A leading '@'
+  /// on any element is optional and is dropped. Throws std::runtime_error if
+  /// `symbols` is empty.
+  static MlirAttribute fromList(const std::vector<std::string> &symbols,
+                                PyMlirContext &context) {
+    if (symbols.empty())
+      throw std::runtime_error("SymbolRefAttr must be composed of at least "
+                               "one symbol.");
+    llvm::StringRef root = stripSigil(symbols.front());
+    SmallVector<MlirAttribute> nestedRefs;
+    nestedRefs.reserve(symbols.size() - 1);
+    for (size_t i = 1, e = symbols.size(); i < e; ++i) {
+      llvm::StringRef nested = stripSigil(symbols[i]);
+      nestedRefs.push_back(mlirFlatSymbolRefAttrGet(
+          context.get(), mlirStringRefCreate(nested.data(), nested.size())));
+    }
+    return mlirSymbolRefAttrGet(
+        context.get(), mlirStringRefCreate(root.data(), root.size()),
+        nestedRefs.size(), nestedRefs.data());
+  }
+
+  /// Builds a SymbolRefAttr from "::"-separated symbols such as
+  /// "@root::@nested" or "root::nested". "::" is always treated as a
+  /// separator, so this form cannot express a symbol name containing "::";
+  /// use fromList for those. Throws std::runtime_error if `symbols` is empty
+  /// or has an empty component (e.g. "a::" or "a::::b").
+  static MlirAttribute fromStr(const std::string &symbols,
+                               DefaultingPyMlirContext context) {
+    SmallVector<llvm::StringRef> parts;
+    // Leave `parts` empty for "" so fromList reports the missing root symbol.
+    if (!symbols.empty())
+      llvm::StringRef(symbols).split(parts, "::");
+    std::vector<std::string> symbolList;
+    symbolList.reserve(parts.size());
+    for (llvm::StringRef part : parts) {
+      if (part.empty())
+        throw std::runtime_error("SymbolRefAttr string '" + symbols +
+                                 "' contains an empty symbol.");
+      symbolList.push_back(part.str());
+    }
+    return fromList(symbolList, context.resolve());
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](const std::vector<std::string> &symbols,
+           DefaultingPyMlirContext context) {
+          return PySymbolRefAttribute::fromList(symbols, context.resolve());
+        },
+        py::arg("symbols"), py::arg("context") = py::none(),
+        "Gets a uniqued SymbolRef attribute");
+    // Overload for a single "::"-separated string; see fromStr.
+    c.def_static(
+        "get",
+        [](const std::string &symbols, DefaultingPyMlirContext context) {
+          return PySymbolRefAttribute::fromStr(symbols, context);
+        },
+        py::arg("symbols"), py::arg("context") = py::none(),
+        "Gets a uniqued SymbolRef attribute");
+    c.def_property_readonly(
+        "value",
+        [](PySymbolRefAttribute &self) {
+          std::string symbol = "@";
+          symbol += unwrap(mlirSymbolRefAttrGetRootReference(self)).str();
+          intptr_t numNested = mlirSymbolRefAttrGetNumNestedReferences(self);
+          for (intptr_t i = 0; i < numNested; ++i) {
+            MlirAttribute nested =
+                mlirSymbolRefAttrGetNestedReference(self, i);
+            symbol += "::@";
+            symbol += unwrap(mlirSymbolRefAttrGetRootReference(nested)).str();
+          }
+          return symbol;
+        },
+        "Returns the value of the SymbolRef attribute as a string");
+  }
+};
+
 class PyFlatSymbolRefAttribute
-    : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
+    : public PyConcreteAttribute<PyFlatSymbolRefAttribute,
+                                 PySymbolRefAttribute> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
   using PyConcreteAttribute::PyConcreteAttribute;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFlatSymbolRefAttrGetTypeID;
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        [](std::string value, DefaultingPyMlirContext context) {
-          MlirAttribute attr =
-              mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
-          return PyFlatSymbolRefAttribute(context->getRef(), attr);
+        [](const std::string &value, DefaultingPyMlirContext context) {
+          // Always build a flat reference: "::" is kept as part of the name
+          // rather than treated as a nesting separator. A leading '@' is
+          // optional.
+          llvm::StringRef name = PySymbolRefAttribute::stripSigil(value);
+          MlirAttribute attr = mlirFlatSymbolRefAttrGet(
+              context->get(), mlirStringRefCreate(name.data(), name.size()));
+          return PyFlatSymbolRefAttribute(context->getRef(), attr);
         },
         py::arg("value"), py::arg("context") = py::none(),
         "Gets a uniqued FlatSymbolRef attribute");
@@ -1167,6 +1269,19 @@
   throw py::cast_error(msg);
 }
 
+// Downcasts an attribute registered under SymbolRefAttr's TypeID to the most
+// specific Python class. Flat references must be tested first because they
+// also satisfy the SymbolRefAttr check.
+py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
+  if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
+    return py::cast(PyFlatSymbolRefAttribute(pyAttribute));
+  if (PySymbolRefAttribute::isaFunction(pyAttribute))
+    return py::cast(PySymbolRefAttribute(pyAttribute));
+  std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
+                    std::string(py::repr(py::cast(pyAttribute))) + ")";
+  throw py::cast_error(msg);
+}
+
 } // namespace
 
 void mlir::python::populateIRAttributes(py::module &m) {
@@ -1201,6 +1316,11 @@
       pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
 
   PyDictAttribute::bind(m);
+  PySymbolRefAttribute::bind(m);
+  PyGlobals::get().registerTypeCaster(
+      mlirSymbolRefAttrGetTypeID(),
+      pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster));
+
   PyFlatSymbolRefAttribute::bind(m);
   PyOpaqueAttribute::bind(m);
   PyFloatAttribute::bind(m);
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -73,6 +73,11 @@
 
 @register_attribute_builder("SymbolRefAttr")
 def _symbolRefAttr(x, context):
+  return SymbolRefAttr.get(x, context=context)
+
+
+@register_attribute_builder("FlatSymbolRefAttr")
+def _flatSymbolRefAttr(x, context):
   return FlatSymbolRefAttr.get(x, context=context)
 
 
@@ -105,6 +110,7 @@
 def _denseI64ArrayAttr(x, context):
   return DenseI64ArrayAttr.get(x, context=context)
 
+
 @register_attribute_builder("DenseBoolArrayAttr")
 def _denseBoolArrayAttr(x, context):
   return DenseBoolArrayAttr.get(x, context=context)
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -228,7 +228,7 @@
 @run
 def testFlatSymbolRefAttr():
   with Context() as ctx:
-    sattr = FlatSymbolRefAttr(Attribute.parse("@symbol"))
+    sattr = Attribute.parse("@symbol")
     # CHECK: symattr value: symbol
     print("symattr value:", sattr.value)
 
@@ -237,6 +237,39 @@
     print("default_get:", FlatSymbolRefAttr.get("foobar"))
 
 
+# CHECK-LABEL: TEST: testSymbolRefAttr
+@run
+def testSymbolRefAttr():
+  with Context() as ctx:
+    sattr = Attribute.parse("@symbol1::@symbol2")
+    # CHECK: symattr value: @symbol1::@symbol2
+    print("symattr value:", sattr.value)
+
+    # Test factory methods.
+    # CHECK: default_get: @symbol1::@symbol2
+    print("default_get:", SymbolRefAttr.get("symbol1::symbol2"))
+
+    # CHECK: default_get: @symbol1::@symbol2
+    print("default_get:", SymbolRefAttr.get(["symbol1", "symbol2"]))
+
+    # CHECK: default_get: @symbol1::@symbol2
+    print("default_get:", SymbolRefAttr.get("@symbol1::@symbol2"))
+
+    # CHECK: default_get: @symbol1::@symbol2
+    print("default_get:", SymbolRefAttr.get(["@symbol1", "@symbol2"]))
+
+    # FlatSymbolRefAttr.get never splits on "::", so the result stays flat.
+    # CHECK: flat_get: @"symbol1::symbol2"
+    print("flat_get:", FlatSymbolRefAttr.get("symbol1::symbol2"))
+
+    # Malformed "::"-separated strings are rejected, not silently truncated.
+    try:
+      SymbolRefAttr.get("symbol1::")
+    except RuntimeError as e:
+      # CHECK: error: SymbolRefAttr string 'symbol1::' contains an empty symbol.
+      print("error:", e)
+
+
 # CHECK-LABEL: TEST: testOpaqueAttr
 @run
 def testOpaqueAttr():