diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -119,6 +119,13 @@ MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name); +/// Returns whether the given fully-qualified operation (i.e. +/// 'dialect.operation') is registered with the context. This will return true +/// if the dialect is registered and the operation is registered within the +/// dialect. +MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, + MlirStringRef name); + //===----------------------------------------------------------------------===// // Dialect API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1752,7 +1752,16 @@ }, [](PyMlirContext &self, bool value) { mlirContextSetAllowUnregisteredDialects(self.get(), value); - }); + }) + .def( + "is_registered_operation", + [](PyMlirContext &self, const std::string &name) { + return mlirContextIsRegisteredOperation( + self.get(), toMlirStringRef(name)); + }, + py::arg("operation_name"), + "Returns whether the given fully-qualified operation name " + "(i.e. 'dialect.operation') is registered with the context."); //---------------------------------------------------------------------------- // Mapping of PyDialectDescriptor
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -60,6 +60,10 @@ return wrap(unwrap(context)->getOrLoadDialect(unwrap(name))); } +bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) { + return unwrap(context)->isOperationRegistered(unwrap(name)); +} + //===----------------------------------------------------------------------===// // Dialect API. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Bindings/Python/dialects.py b/mlir/test/Bindings/Python/dialects.py --- a/mlir/test/Bindings/Python/dialects.py +++ b/mlir/test/Bindings/Python/dialects.py @@ -3,14 +3,17 @@ import gc from mlir.ir import * + def run(f): print("\nTEST:", f.__name__) f() gc.collect() assert Context._get_live_count() == 0 + return f # CHECK-LABEL: TEST: testDialectDescriptor +@run def testDialectDescriptor(): ctx = Context() d = ctx.get_dialect_descriptor("std") @@ -25,10 +28,9 @@ else: assert False, "Expected exception" -run(testDialectDescriptor) - # CHECK-LABEL: TEST: testUserDialectClass +@run def testUserDialectClass(): ctx = Context() # Access using attribute. @@ -60,14 +62,14 @@ # CHECK: print(d) -run(testUserDialectClass) - # CHECK-LABEL: TEST: testCustomOpView # This test uses the standard dialect AddFOp as an example of a user op. # TODO: Op creation and access is still quite verbose: simplify this test as # additional capabilities come online.
+@run def testCustomOpView(): + def createInput(): op = Operation.create("pytest_dummy.intinput", results=[f32]) # TODO: Auto result cast from operation @@ -95,4 +97,15 @@ m.operation.print() -run(testCustomOpView) +# CHECK-LABEL: TEST: testIsRegisteredOperation +@run +def testIsRegisteredOperation(): + ctx = Context() + + # CHECK: std.cond_br: True + print(f"std.cond_br: {ctx.is_registered_operation('std.cond_br')}") + # CHECK: std.not_existing: False + print(f"std.not_existing: {ctx.is_registered_operation('std.not_existing')}") + # CHECK: not_existing_dialect.not_existing_op: False + name = "not_existing_dialect.not_existing_op" + print(f"{name}: {ctx.is_registered_operation(name)}") diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -1442,6 +1442,22 @@ fprintf(stderr, "@registration\n"); // CHECK-LABEL: @registration + // CHECK: std.cond_br is_registered: 1 + fprintf(stderr, "std.cond_br is_registered: %d\n", + mlirContextIsRegisteredOperation( + ctx, mlirStringRefCreateFromCString("std.cond_br"))); + + // CHECK: std.not_existing_op is_registered: 0 + fprintf(stderr, "std.not_existing_op is_registered: %d\n", + mlirContextIsRegisteredOperation( + ctx, mlirStringRefCreateFromCString("std.not_existing_op"))); + + // CHECK: not_existing_dialect.not_existing_op is_registered: 0 + fprintf(stderr, "not_existing_dialect.not_existing_op is_registered: %d\n", + mlirContextIsRegisteredOperation( + ctx, mlirStringRefCreateFromCString( + "not_existing_dialect.not_existing_op"))); + return 0; }