diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -754,6 +754,9 @@
 /// symbol tables.
 MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName();
 
+/// Returns the name of the attribute used to store symbol visibility.
+MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetVisibilityAttributeName();
+
 /// Creates a symbol table for the given operation. If the operation does not
 /// have the SymbolTable trait, returns a null symbol table.
 MLIR_CAPI_EXPORTED MlirSymbolTable
@@ -787,6 +790,23 @@
 MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable,
                                              MlirOperation operation);
 
+/// Attempt to replace all uses that are nested within the given operation
+/// of the given symbol 'oldSymbol' with the provided 'newSymbol'. This does
+/// not traverse into nested symbol tables. Will fail atomically if there are
+/// any unknown operations that may be potential symbol tables.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(
+    MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from);
+
+/// Walks all symbol table operations nested within, and including, `from`.
+/// For each symbol table operation, `callback` is invoked with the op, a
+/// boolean signifying if the symbols within that symbol table can be treated
+/// as if all uses within the IR are visible to the caller, and `userData`.
+/// `allSymUsesVisible` identifies whether all of the symbol uses of symbols
+/// within `from` are visible.
+MLIR_CAPI_EXPORTED void mlirSymbolTableWalkSymbolTables(
+    MlirOperation from, bool allSymUsesVisible,
+    void (*callback)(MlirOperation, bool, void *userData), void *userData);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1596,6 +1596,112 @@
       mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
 }
 
+PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
+  // Op must already be a symbol.
+  PyOperation &operation = symbol.getOperation();
+  operation.checkValid();
+  MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
+  MlirAttribute existingNameAttr =
+      mlirOperationGetAttributeByName(operation.get(), attrName);
+  if (mlirAttributeIsNull(existingNameAttr))
+    throw py::value_error("Expected operation to have a symbol name.");
+  return PyAttribute(operation.getContext(), existingNameAttr);
+}
+
+void PySymbolTable::setSymbolName(PyOperationBase &symbol,
+                                  const std::string &name) {
+  // Op must already be a symbol.
+  PyOperation &operation = symbol.getOperation();
+  operation.checkValid();
+  MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
+  MlirAttribute existingNameAttr =
+      mlirOperationGetAttributeByName(operation.get(), attrName);
+  if (mlirAttributeIsNull(existingNameAttr))
+    throw py::value_error("Expected operation to have a symbol name.");
+  MlirAttribute newNameAttr =
+      mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
+  mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
+}
+
+PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
+  PyOperation &operation = symbol.getOperation();
+  operation.checkValid();
+  MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
+  MlirAttribute existingVisAttr =
+      mlirOperationGetAttributeByName(operation.get(), attrName);
+  if (mlirAttributeIsNull(existingVisAttr))
+    throw py::value_error("Expected operation to have a symbol visibility.");
+  return PyAttribute(operation.getContext(), existingVisAttr);
+}
+
+void PySymbolTable::setVisibility(PyOperationBase &symbol,
+                                  const std::string &visibility) {
+  if (visibility != "public" && visibility != "private" &&
+      visibility != "nested")
+    throw py::value_error(
+        "Expected visibility to be 'public', 'private' or 'nested'");
+  PyOperation &operation = symbol.getOperation();
+  operation.checkValid();
+  MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
+  MlirAttribute existingVisAttr =
+      mlirOperationGetAttributeByName(operation.get(), attrName);
+  if (mlirAttributeIsNull(existingVisAttr))
+    throw py::value_error("Expected operation to have a symbol visibility.");
+  MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
+                                               toMlirStringRef(visibility));
+  mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
+}
+
+void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
+                                         const std::string &newSymbol,
+                                         PyOperationBase &from) {
+  PyOperation &fromOperation = from.getOperation();
+  fromOperation.checkValid();
+  // Per the C API contract, failure is atomic: no uses have been replaced.
+  if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
+          toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
+          fromOperation.get())))
+    throw py::value_error("Symbol rename failed");
+}
+
+void PySymbolTable::walkSymbolTables(PyOperationBase &from,
+                                     bool allSymUsesVisible,
+                                     py::object callback) {
+  PyOperation &fromOperation = from.getOperation();
+  fromOperation.checkValid();
+  struct UserData { // Passed through the C API's opaque `void *userData`.
+    PyMlirContextRef context;
+    py::object callback;
+    bool gotException;         // Set once `callback` has raised.
+    std::string exceptionWhat;
+    py::object exceptionType;  // Captured, but not currently rethrown.
+  };
+  UserData userData{
+      fromOperation.getContext(), std::move(callback), false, {}, {}};
+  mlirSymbolTableWalkSymbolTables(
+      fromOperation.get(), allSymUsesVisible,
+      [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
+        auto *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
+        if (calleeUserData->gotException)
+          return; // The C walk cannot stop early; skip remaining tables.
+        auto pyFoundOp =
+            PyOperation::forOperation(calleeUserData->context, foundOp);
+        try {
+          calleeUserData->callback(pyFoundOp.getObject(), isVisible);
+        } catch (py::error_already_set &e) {
+          calleeUserData->gotException = true;
+          calleeUserData->exceptionWhat = e.what();
+          calleeUserData->exceptionType = e.type();
+        }
+      },
+      static_cast<void *>(&userData));
+  if (userData.gotException) {
+    std::string message("Exception raised in callback: ");
+    message.append(userData.exceptionWhat);
+    throw std::runtime_error(std::move(message));
+  }
+}
+
 namespace {
 /// CRTP base class for Python MLIR values that subclass Value and should be
 /// castable from it.
The value hierarchy is one level deep and is not supposed
@@ -2773,10 +2879,26 @@
       .def("insert", &PySymbolTable::insert, py::arg("operation"))
       .def("erase", &PySymbolTable::erase, py::arg("operation"))
       .def("__delitem__", &PySymbolTable::dunderDel)
-      .def("__contains__", [](PySymbolTable &table, const std::string &name) {
-        return !mlirOperationIsNull(mlirSymbolTableLookup(
-            table, mlirStringRefCreate(name.data(), name.length())));
-      });
+      .def("__contains__",
+           [](PySymbolTable &table, const std::string &name) {
+             return !mlirOperationIsNull(mlirSymbolTableLookup(
+                 table, mlirStringRefCreate(name.data(), name.length())));
+           })
+      // Static helpers.
+      .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
+                  py::arg("symbol"), py::arg("name"))
+      .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
+                  py::arg("symbol"))
+      .def_static("get_visibility", &PySymbolTable::getVisibility,
+                  py::arg("symbol"))
+      .def_static("set_visibility", &PySymbolTable::setVisibility,
+                  py::arg("symbol"), py::arg("visibility"))
+      .def_static("replace_all_symbol_uses",
+                  &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
+                  py::arg("new_symbol"), py::arg("from_op"))
+      .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
+                  py::arg("from_op"), py::arg("all_sym_uses_visible"),
+                  py::arg("callback"));
 
   // Container bindings.
   PyBlockArgumentList::bind(m);
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -910,6 +910,25 @@
   /// the symbol trait.
   PyAttribute insert(PyOperationBase &symbol);
 
+  /// Gets and sets the name of a symbol op.
+  static PyAttribute getSymbolName(PyOperationBase &symbol);
+  static void setSymbolName(PyOperationBase &symbol, const std::string &name);
+
+  /// Gets and sets the visibility of a symbol op.
+  static PyAttribute getVisibility(PyOperationBase &symbol);
+  static void setVisibility(PyOperationBase &symbol,
+                            const std::string &visibility);
+
+  /// Replaces all symbol uses within an operation. See the API
+  /// mlirSymbolTableReplaceAllSymbolUses for all caveats.
+  static void replaceAllSymbolUses(const std::string &oldSymbol,
+                                   const std::string &newSymbol,
+                                   PyOperationBase &from);
+
+  /// Walks all symbol tables under and including 'from'.
+  static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible,
+                               pybind11::object callback);
+
   /// Casts the bindings class into the C API structure.
   operator MlirSymbolTable() { return symbolTable; }
 
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -786,6 +786,10 @@
   return wrap(SymbolTable::getSymbolAttrName());
 }
 
+MlirStringRef mlirSymbolTableGetVisibilityAttributeName() {
+  return wrap(SymbolTable::getVisibilityAttrName());
+}
+
 MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) {
   if (!unwrap(operation)->hasTrait<OpTrait::SymbolTable>())
     return wrap(static_cast<SymbolTable *>(nullptr));
@@ -810,3 +814,25 @@
                           MlirOperation operation) {
   unwrap(symbolTable)->erase(unwrap(operation));
 }
+
+MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol,
+                                                      MlirStringRef newSymbol,
+                                                      MlirOperation from) {
+  auto *cppFrom = unwrap(from);
+  auto *context = cppFrom->getContext();
+  auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol));
+  auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol));
+  return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr,
+                                                cppFrom));
+}
+
+void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible,
+                                     void (*callback)(MlirOperation, bool,
+                                                      void *userData),
+                                     void *userData) {
+  SymbolTable::walkSymbolTables(unwrap(from), allSymUsesVisible,
+                                [&](Operation *foundOpCpp, bool isVisible) {
+                                  callback(wrap(foundOpCpp), isVisible,
+                                           userData);
+                                });
+}
diff
--git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -835,79 +835,6 @@
     # CHECK-NOT: func private @foo
 
 
-# CHECK-LABEL: TEST: testSymbolTable
-@run
-def testSymbolTable():
-  with Context() as ctx:
-    ctx.allow_unregistered_dialects = True
-    m1 = Module.parse("""
-      func private @foo()
-      func private @bar()""")
-    m2 = Module.parse("""
-      func private @qux()
-      func private @foo()
-      "foo.bar"() : () -> ()""")
-
-    symbol_table = SymbolTable(m1.operation)
-
-    # CHECK: func private @foo
-    # CHECK: func private @bar
-    assert "foo" in symbol_table
-    print(symbol_table["foo"])
-    assert "bar" in symbol_table
-    bar = symbol_table["bar"]
-    print(symbol_table["bar"])
-
-    assert "qux" not in symbol_table
-
-    del symbol_table["bar"]
-    try:
-      symbol_table.erase(symbol_table["bar"])
-    except KeyError:
-      pass
-    else:
-      assert False, "expected KeyError"
-
-    # CHECK: module
-    # CHECK:   func private @foo()
-    print(m1)
-    assert "bar" not in symbol_table
-
-    try:
-      print(bar)
-    except RuntimeError as e:
-      if "the operation has been invalidated" not in str(e):
-        raise
-    else:
-      assert False, "expected RuntimeError due to invalidated operation"
-
-    qux = m2.body.operations[0]
-    m1.body.append(qux)
-    symbol_table.insert(qux)
-    assert "qux" in symbol_table
-
-    # Check that insertion actually renames this symbol in the symbol table.
-    foo2 = m2.body.operations[0]
-    m1.body.append(foo2)
-    updated_name = symbol_table.insert(foo2)
-    assert foo2.name.value != "foo"
-    assert foo2.name == updated_name
-
-    # CHECK: module
-    # CHECK:   func private @foo()
-    # CHECK:   func private @qux()
-    # CHECK:   func private @foo{{.*}}
-    print(m1)
-
-    try:
-      symbol_table.insert(m2.body.operations[0])
-    except ValueError as e:
-      if "Expected operation to have a symbol name" not in str(e):
-        raise
-    else:
-      assert False, "exepcted ValueError when adding a non-symbol"
-
-
 # CHECK-LABEL: TEST: testOperationHash
 @run
 def testOperationHash():
diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/python/ir/symbol_table.py
@@ -0,0 +1,156 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+import io
+import itertools
+from mlir.ir import *
+
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+  gc.collect()
+  assert Context._get_live_count() == 0
+  return f
+
+
+# CHECK-LABEL: TEST: testSymbolTableInsert
+@run
+def testSymbolTableInsert():
+  with Context() as ctx:
+    ctx.allow_unregistered_dialects = True
+    m1 = Module.parse("""
+      func private @foo()
+      func private @bar()""")
+    m2 = Module.parse("""
+      func private @qux()
+      func private @foo()
+      "foo.bar"() : () -> ()""")
+
+    symbol_table = SymbolTable(m1.operation)
+
+    # CHECK: func private @foo
+    # CHECK: func private @bar
+    assert "foo" in symbol_table
+    print(symbol_table["foo"])
+    assert "bar" in symbol_table
+    bar = symbol_table["bar"]
+    print(symbol_table["bar"])
+
+    assert "qux" not in symbol_table
+
+    del symbol_table["bar"]
+    try:
+      symbol_table.erase(symbol_table["bar"])
+    except KeyError:
+      pass
+    else:
+      assert False, "expected KeyError"
+
+    # CHECK: module
+    # CHECK:   func private @foo()
+    print(m1)
+    assert "bar" not in symbol_table
+
+    try:
+      print(bar)
+    except RuntimeError as e:
+      if "the operation has been invalidated" not in str(e):
+        raise
+    else:
+      assert False, "expected RuntimeError due to invalidated operation"
+
+    qux = m2.body.operations[0]
+    m1.body.append(qux)
+    symbol_table.insert(qux)
+    assert "qux" in symbol_table
+
+    # Check that insertion actually renames this symbol in the symbol table.
+    foo2 = m2.body.operations[0]
+    m1.body.append(foo2)
+    updated_name = symbol_table.insert(foo2)
+    assert foo2.name.value != "foo"
+    assert foo2.name == updated_name
+
+    # CHECK: module
+    # CHECK:   func private @foo()
+    # CHECK:   func private @qux()
+    # CHECK:   func private @foo{{.*}}
+    print(m1)
+
+    try:
+      symbol_table.insert(m2.body.operations[0])
+    except ValueError as e:
+      if "Expected operation to have a symbol name" not in str(e):
+        raise
+    else:
+      assert False, "expected ValueError when adding a non-symbol"
+
+
+# CHECK-LABEL: TEST: testSymbolTableRAUW
+@run
+def testSymbolTableRAUW():
+  with Context() as ctx:
+    m = Module.parse("""
+      func private @foo() {
+        call @bar() : () -> ()
+        return
+      }
+      func private @bar()
+      """)
+    foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2]
+    SymbolTable.set_symbol_name(bar, "bam")
+    # Note that module.operation counts as a "nested symbol table" which won't
+    # be traversed into, so it is necessary to traverse its children.
+    SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
+    # CHECK: call @bam()
+    # CHECK: func private @bam
+    print(m)
+    # CHECK: Foo symbol: "foo"
+    # CHECK: Bar symbol: "bam"
+    print(f"Foo symbol: {SymbolTable.get_symbol_name(foo)}")
+    print(f"Bar symbol: {SymbolTable.get_symbol_name(bar)}")
+
+
+# CHECK-LABEL: TEST: testSymbolTableVisibility
+@run
+def testSymbolTableVisibility():
+  with Context() as ctx:
+    m = Module.parse("""
+      func private @foo() {
+        return
+      }
+      """)
+    foo = m.operation.regions[0].blocks[0].operations[0]
+    # CHECK: Existing visibility: "private"
+    print(f"Existing visibility: {SymbolTable.get_visibility(foo)}")
+    SymbolTable.set_visibility(foo, "public")
+    # CHECK: func public @foo
+    print(m)
+
+
+# CHECK-LABEL: TEST: testWalkSymbolTables
+@run
+def testWalkSymbolTables():
+  with Context() as ctx:
+    m = Module.parse("""
+      module @outer {
+        module @inner{
+        }
+      }
+      """)
+    def callback(symbol_table_op, uses_visible):
+      print(f"SYMBOL TABLE: {uses_visible}: {symbol_table_op}")
+    # CHECK: SYMBOL TABLE: True: module @inner
+    # CHECK: SYMBOL TABLE: True: module @outer
+    SymbolTable.walk_symbol_tables(m.operation, True, callback)
+
+    # Make sure exceptions in the callback are surfaced as a RuntimeError.
+    def error_callback(symbol_table_op, uses_visible):
+      assert False, "Raised from python"
+    try:
+      SymbolTable.walk_symbol_tables(m.operation, True, error_callback)
+    except RuntimeError as e:
+      # CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python
+      print(f"GOT EXCEPTION: {e}")
+