diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -54,6 +54,7 @@
 DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void);
 DEFINE_C_API_STRUCT(MlirBlock, void);
 DEFINE_C_API_STRUCT(MlirRegion, void);
+DEFINE_C_API_STRUCT(MlirSymbolTable, void);
 
 DEFINE_C_API_STRUCT(MlirAttribute, const void);
 DEFINE_C_API_STRUCT(MlirIdentifier, const void);
@@ -738,6 +739,46 @@
 /// Returns the hash value of the type id.
 MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID);
 
+//===----------------------------------------------------------------------===//
+// Symbol and SymbolTable API.
+//===----------------------------------------------------------------------===//
+
+/// Returns the name of the attribute used to store symbol names compatible with
+/// symbol tables.
+MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(void);
+
+/// Creates a symbol table for the given operation. If the operation does not
+/// have the SymbolTable trait, returns a null symbol table.
+MLIR_CAPI_EXPORTED MlirSymbolTable
+mlirSymbolTableCreate(MlirOperation operation);
+
+/// Returns true if the symbol table is null.
+static inline bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable) {
+  return !symbolTable.ptr;
+}
+
+/// Destroys the symbol table created with mlirSymbolTableCreate. This does not
+/// affect the operations in the table.
+MLIR_CAPI_EXPORTED void mlirSymbolTableDestroy(MlirSymbolTable symbolTable);
+
+/// Looks up a symbol with the given name in the given symbol table and returns
+/// the operation that corresponds to the symbol. If the symbol cannot be found,
+/// returns a null operation.
+MLIR_CAPI_EXPORTED MlirOperation
+mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name);
+
+/// Inserts the given operation into the given symbol table. The operation must
+/// have the symbol trait. If the symbol table already has a symbol with the
+/// same name, renames the symbol being inserted to ensure name uniqueness. Note
+/// that this does not move the operation itself into the block of the symbol
+/// table operation, this should be done separately.
+MLIR_CAPI_EXPORTED void mlirSymbolTableInsert(MlirSymbolTable symbolTable,
+                                              MlirOperation operation);
+
+/// Removes the given operation from the symbol table and erases it.
+MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable,
+                                             MlirOperation operation);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -27,6 +27,7 @@
 DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
 DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags)
 DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
+DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable)
 
 DEFINE_C_API_METHODS(MlirAttribute, mlir::Attribute)
 DEFINE_C_API_METHODS(MlirIdentifier, mlir::Identifier)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1530,6 +1530,55 @@
   return PyValue(ownerRef, value);
 }
 
+//------------------------------------------------------------------------------
+// PySymbolTable.
+//------------------------------------------------------------------------------
+
+PySymbolTable::PySymbolTable(PyOperationBase &operation)
+    : operation(operation.getOperation().getRef()) {
+  symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
+  if (mlirSymbolTableIsNull(symbolTable)) {
+    throw py::cast_error("Operation is not a Symbol Table.");
+  }
+}
+
+py::object PySymbolTable::dunderGetItem(const std::string &name) {
+  operation->checkValid();
+  MlirOperation symbol = mlirSymbolTableLookup(
+      symbolTable, mlirStringRefCreate(name.data(), name.length()));
+  if (mlirOperationIsNull(symbol))
+    throw py::key_error("Symbol '" + name + "' not in the symbol table.");
+
+  return PyOperation::forOperation(operation->getContext(), symbol,
+                                   operation.getObject())
+      ->createOpView();
+}
+
+void PySymbolTable::erase(PyOperationBase &symbol) {
+  operation->checkValid();
+  symbol.getOperation().checkValid();
+  mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
+  // The operation is also erased, so we must invalidate it. There may be Python
+  // references to this operation so we don't want to delete it from the list of
+  // live operations here.
+  symbol.getOperation().valid = false;
+}
+
+void PySymbolTable::dunderDel(const std::string &name) {
+  py::object symbol = dunderGetItem(name);
+  erase(py::cast<PyOperationBase &>(symbol));
+}
+
+void PySymbolTable::insert(PyOperationBase &symbol) {
+  operation->checkValid();
+  symbol.getOperation().checkValid();
+  MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
+      symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
+  if (mlirAttributeIsNull(symbolAttr))
+    throw py::value_error("Expected operation to have a symbol name.");
+  mlirSymbolTableInsert(symbolTable, symbol.getOperation().get());
+}
+
 namespace {
 /// CRTP base class for Python MLIR values that subclass Value and should be
 /// castable from it. The value hierarchy is one level deep and is not supposed
@@ -2669,6 +2718,20 @@
   PyBlockArgument::bind(m);
   PyOpResult::bind(m);
 
+  //----------------------------------------------------------------------------
+  // Mapping of SymbolTable.
+  //----------------------------------------------------------------------------
+  py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
+      .def(py::init<PyOperationBase &>())
+      .def("__getitem__", &PySymbolTable::dunderGetItem)
+      .def("insert", &PySymbolTable::insert)
+      .def("erase", &PySymbolTable::erase)
+      .def("__delitem__", &PySymbolTable::dunderDel)
+      .def("__contains__", [](PySymbolTable &table, const std::string &name) {
+        return !mlirOperationIsNull(mlirSymbolTableLookup(
+            table, mlirStringRefCreate(name.data(), name.length())));
+      });
+
   // Container bindings.
   PyBlockArgumentList::bind(m);
   PyBlockIterator::bind(m);
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -32,6 +32,7 @@
 class PyModule;
 class PyOperation;
 class PyType;
+class PySymbolTable;
 class PyValue;
 
 /// Template for a reference to a concrete type which captures a python
@@ -513,6 +514,7 @@
   bool valid = true;
 
   friend class PyOperationBase;
+  friend class PySymbolTable;
 };
 
 /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
@@ -876,6 +878,43 @@
   MlirIntegerSet integerSet;
 };
 
+/// Bindings for MLIR symbol tables.
+class PySymbolTable {
+public:
+  /// Constructs a symbol table for the given operation.
+  explicit PySymbolTable(PyOperationBase &operation);
+
+  /// Destroys the symbol table.
+  ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); }
+
+  /// The wrapper owns the C symbol table and releases it in the destructor, so
+  /// a copy would destroy the same table twice.
+  PySymbolTable(const PySymbolTable &) = delete;
+  PySymbolTable &operator=(const PySymbolTable &) = delete;
+
+  /// Returns the symbol (opview) with the given name, throws if there is no
+  /// such symbol in the table.
+  pybind11::object dunderGetItem(const std::string &name);
+
+  /// Removes the given operation from the symbol table and erases it.
+  void erase(PyOperationBase &symbol);
+
+  /// Removes the operation with the given name from the symbol table and erases
+  /// it, throws if there is no such symbol in the table.
+  void dunderDel(const std::string &name);
+
+  /// Inserts the given operation into the symbol table. The operation must have
+  /// the symbol trait.
+  void insert(PyOperationBase &symbol);
+
+  /// Casts the bindings class into the C API structure.
+  operator MlirSymbolTable() { return symbolTable; }
+
+private:
+  PyOperationRef operation;
+  MlirSymbolTable symbolTable;
+};
+
 void populateIRAffine(pybind11::module &m);
 void populateIRAttributes(pybind11::module &m);
 void populateIRCore(pybind11::module &m);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -763,3 +763,36 @@
 size_t mlirTypeIDHashValue(MlirTypeID typeID) {
   return hash_value(unwrap(typeID));
 }
+
+//===----------------------------------------------------------------------===//
+// Symbol and SymbolTable API.
+//===----------------------------------------------------------------------===//
+
+MlirStringRef mlirSymbolTableGetSymbolAttributeName() {
+  return wrap(SymbolTable::getSymbolAttrName());
+}
+
+MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) {
+  if (!unwrap(operation)->hasTrait<OpTrait::SymbolTable>())
+    return wrap(static_cast<SymbolTable *>(nullptr));
+  return wrap(new SymbolTable(unwrap(operation)));
+}
+
+void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) {
+  delete unwrap(symbolTable);
+}
+
+MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable,
+                                    MlirStringRef name) {
+  return wrap(unwrap(symbolTable)->lookup(StringRef(name.data, name.length)));
+}
+
+void mlirSymbolTableInsert(MlirSymbolTable symbolTable,
+                           MlirOperation operation) {
+  unwrap(symbolTable)->insert(unwrap(operation));
+}
+
+void mlirSymbolTableErase(MlirSymbolTable symbolTable,
+                          MlirOperation operation) {
+  unwrap(symbolTable)->erase(unwrap(operation));
+}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1692,57 +1692,6 @@
           (intptr_t)userData);
 }
 
-void testDiagnostics() {
-  MlirContext ctx = mlirContextCreate();
-  MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
-      ctx, errorHandler, (void *)42, deleteUserData);
-  fprintf(stderr, "@test_diagnostics\n");
-  MlirLocation unknownLoc = mlirLocationUnknownGet(ctx);
-  mlirEmitError(unknownLoc, "test diagnostics");
-  MlirLocation fileLineColLoc = mlirLocationFileLineColGet(
-      ctx, mlirStringRefCreateFromCString("file.c"), 1, 2);
-  mlirEmitError(fileLineColLoc, "test diagnostics");
-  MlirLocation callSiteLoc = mlirLocationCallSiteGet(
-      mlirLocationFileLineColGet(
-          ctx, mlirStringRefCreateFromCString("other-file.c"), 2, 3),
-      fileLineColLoc);
-  mlirEmitError(callSiteLoc, "test diagnostics");
-  MlirLocation null = {0};
-  MlirLocation nameLoc =
-      mlirLocationNameGet(ctx, mlirStringRefCreateFromCString("named"), null);
-  mlirEmitError(nameLoc, "test diagnostics");
-  MlirLocation locs[2] = {nameLoc, callSiteLoc};
-  MlirAttribute nullAttr = {0};
-  MlirLocation fusedLoc = mlirLocationFusedGet(ctx, 2, locs, nullAttr);
-  mlirEmitError(fusedLoc, "test diagnostics");
-  mlirContextDetachDiagnosticHandler(ctx, id);
-  mlirEmitError(unknownLoc, "more test diagnostics");
-  // CHECK-LABEL: @test_diagnostics
-  // CHECK: processing diagnostic (userData: 42) <<
-  // CHECK: test diagnostics
-  // CHECK: loc(unknown)
-  // CHECK: >> end of diagnostic (userData: 42)
-  // CHECK: processing diagnostic (userData: 42) <<
-  // CHECK: test diagnostics
-  // CHECK: loc("file.c":1:2)
-  // CHECK: >> end of diagnostic (userData: 42)
-  // CHECK: processing diagnostic (userData: 42) <<
-  // CHECK: test diagnostics
-  // CHECK: loc(callsite("other-file.c":2:3 at "file.c":1:2))
-  // CHECK: >> end of diagnostic (userData: 42)
-  // CHECK: processing diagnostic (userData: 42) <<
-  // CHECK: test diagnostics
-  // CHECK: loc("named")
-  // CHECK: >> end of diagnostic (userData: 42)
-  // CHECK: processing diagnostic (userData: 42) <<
-  // CHECK: test diagnostics
-  // CHECK: loc(fused["named", callsite("other-file.c":2:3 at "file.c":1:2)])
-  // CHECK: deleting user data (userData: 42)
-  // CHECK-NOT: processing diagnostic
-  // CHECK: more test diagnostics
-  mlirContextDestroy(ctx);
-}
-
 int testTypeID(MlirContext ctx) {
   fprintf(stderr, "@testTypeID\n");
 
@@ -1841,6 +1790,132 @@
   return 0;
 }
 
+int testSymbolTable(MlirContext ctx) {
+  fprintf(stderr, "@testSymbolTable\n");
+
+  const char *moduleString = "func private @foo()"
+                             "func private @bar()";
+  const char *otherModuleString = "func private @qux()";
+
+  MlirModule module =
+      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+  MlirModule otherModule = mlirModuleCreateParse(
+      ctx, mlirStringRefCreateFromCString(otherModuleString));
+
+  MlirSymbolTable symbolTable =
+      mlirSymbolTableCreate(mlirModuleGetOperation(module));
+
+  MlirOperation funcFoo =
+      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("foo"));
+  if (mlirOperationIsNull(funcFoo))
+    return 1;
+
+  MlirOperation funcBar =
+      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("bar"));
+  if (mlirOperationIsNull(funcBar) || mlirOperationEqual(funcFoo, funcBar))
+    return 2;
+
+  MlirOperation missing =
+      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
+  if (!mlirOperationIsNull(missing))
+    return 3;
+
+  MlirBlock moduleBody = mlirModuleGetBody(module);
+  MlirBlock otherModuleBody = mlirModuleGetBody(otherModule);
+  MlirOperation operation = mlirBlockGetFirstOperation(otherModuleBody);
+  mlirOperationRemoveFromParent(operation);
+  mlirBlockAppendOwnedOperation(moduleBody, operation);
+
+  // At this moment, the operation is still missing from the symbol table.
+  MlirOperation stillMissing =
+      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
+  if (!mlirOperationIsNull(stillMissing))
+    return 4;
+
+  // After it is added to the symbol table, and not only the operation with
+  // which the table is associated, it can be looked up.
+  mlirSymbolTableInsert(symbolTable, operation);
+  MlirOperation funcQux =
+      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
+  if (!mlirOperationEqual(operation, funcQux))
+    return 5;
+
+  // Erasing from the symbol table also removes the operation.
+  mlirSymbolTableErase(symbolTable, funcBar);
+  MlirOperation nowMissing =
+      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("bar"));
+  if (!mlirOperationIsNull(nowMissing))
+    return 6;
+
+  mlirOperationDump(mlirModuleGetOperation(module));
+  mlirOperationDump(mlirModuleGetOperation(otherModule));
+  // clang-format off
+  // CHECK-LABEL: @testSymbolTable
+  // CHECK: module
+  // CHECK:   func private @foo
+  // CHECK:   func private @qux
+  // CHECK: module
+  // CHECK-NOT: @qux
+  // clang-format on
+
+  mlirSymbolTableDestroy(symbolTable);
+  mlirModuleDestroy(module);
+  mlirModuleDestroy(otherModule);
+
+  return 0;
+}
+
+void testDiagnostics() {
+  MlirContext ctx = mlirContextCreate();
+  MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
+      ctx, errorHandler, (void *)42, deleteUserData);
+  fprintf(stderr, "@test_diagnostics\n");
+  MlirLocation unknownLoc = mlirLocationUnknownGet(ctx);
+  mlirEmitError(unknownLoc, "test diagnostics");
+  MlirLocation fileLineColLoc = mlirLocationFileLineColGet(
+      ctx, mlirStringRefCreateFromCString("file.c"), 1, 2);
+  mlirEmitError(fileLineColLoc, "test diagnostics");
+  MlirLocation callSiteLoc = mlirLocationCallSiteGet(
+      mlirLocationFileLineColGet(
+          ctx, mlirStringRefCreateFromCString("other-file.c"), 2, 3),
+      fileLineColLoc);
+  mlirEmitError(callSiteLoc, "test diagnostics");
+  MlirLocation null = {0};
+  MlirLocation nameLoc =
+      mlirLocationNameGet(ctx, mlirStringRefCreateFromCString("named"), null);
+  mlirEmitError(nameLoc, "test diagnostics");
+  MlirLocation locs[2] = {nameLoc, callSiteLoc};
+  MlirAttribute nullAttr = {0};
+  MlirLocation fusedLoc = mlirLocationFusedGet(ctx, 2, locs, nullAttr);
+  mlirEmitError(fusedLoc, "test diagnostics");
+  mlirContextDetachDiagnosticHandler(ctx, id);
+  mlirEmitError(unknownLoc, "more test diagnostics");
+  // CHECK-LABEL: @test_diagnostics
+  // CHECK: processing diagnostic (userData: 42) <<
+  // CHECK: test diagnostics
+  // CHECK: loc(unknown)
+  // CHECK: >> end of diagnostic (userData: 42)
+  // CHECK: processing diagnostic (userData: 42) <<
+  // CHECK: test diagnostics
+  // CHECK: loc("file.c":1:2)
+  // CHECK: >> end of diagnostic (userData: 42)
+  // CHECK: processing diagnostic (userData: 42) <<
+  // CHECK: test diagnostics
+  // CHECK: loc(callsite("other-file.c":2:3 at "file.c":1:2))
+  // CHECK: >> end of diagnostic (userData: 42)
+  // CHECK: processing diagnostic (userData: 42) <<
+  // CHECK: test diagnostics
+  // CHECK: loc("named")
+  // CHECK: >> end of diagnostic (userData: 42)
+  // CHECK: processing diagnostic (userData: 42) <<
+  // CHECK: test diagnostics
+  // CHECK: loc(fused["named", callsite("other-file.c":2:3 at "file.c":1:2)])
+  // CHECK: deleting user data (userData: 42)
+  // CHECK-NOT: processing diagnostic
+  // CHECK: more test diagnostics
+  mlirContextDestroy(ctx);
+}
+
 int main() {
   MlirContext ctx = mlirContextCreate();
   mlirRegisterAllDialects(ctx);
@@ -1870,9 +1945,10 @@
     return 11;
   if (testClone())
     return 12;
-  if (testTypeID(ctx)) {
+  if (testTypeID(ctx))
     return 13;
-  }
+  if (testSymbolTable(ctx))
+    return 14;
 
   mlirContextDestroy(ctx);
 
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -787,3 +787,75 @@
   # CHECK: module {
   # CHECK-NEXT: }
   print(m1)
+
+
+# CHECK-LABEL: TEST: testSymbolTable
+@run
+def testSymbolTable():
+  with Context() as ctx:
+    ctx.allow_unregistered_dialects = True
+    m1 = Module.parse("""
+      func private @foo()
+      func private @bar()""")
+    m2 = Module.parse("""
+      func private @qux()
+      func private @foo()
+      "foo.bar"() : () -> ()""")
+
+    symbol_table = SymbolTable(m1.operation)
+
+    # CHECK: func private @foo
+    # CHECK: func private @bar
+    assert "foo" in symbol_table
+    print(symbol_table["foo"])
+    assert "bar" in symbol_table
+    bar = symbol_table["bar"]
+    print(symbol_table["bar"])
+
+    assert "qux" not in symbol_table
+
+    del symbol_table["bar"]
+    try:
+      symbol_table.erase(symbol_table["bar"])
+    except KeyError:
+      pass
+    else:
+      assert False, "expected KeyError"
+
+    # CHECK: module
+    # CHECK:   func private @foo()
+    print(m1)
+    assert "bar" not in symbol_table
+
+    try:
+      print(bar)
+    except RuntimeError as e:
+      if "the operation has been invalidated" not in str(e):
+        raise
+    else:
+      assert False, "expected RuntimeError due to invalidated operation"
+
+    qux = m2.body.operations[0].remove_from_parent()
+    m1.body.append(qux)
+    symbol_table.insert(qux)
+    assert "qux" in symbol_table
+
+    # Check that insertion actually renames this symbol in the symbol table.
+    foo2 = m2.body.operations[0].remove_from_parent()
+    m1.body.append(foo2)
+    symbol_table.insert(foo2)
+    assert StringAttr(foo2.attributes["sym_name"]).value != "foo"
+
+    # CHECK: module
+    # CHECK:   func private @foo()
+    # CHECK:   func private @qux()
+    # CHECK:   func private @foo{{.*}}
+    print(m1)
+
+    try:
+      symbol_table.insert(m2.body.operations[0])
+    except ValueError as e:
+      if "Expected operation to have a symbol name" not in str(e):
+        raise
+    else:
+      assert False, "expected ValueError when adding a non-symbol"