diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -418,6 +418,14 @@ /// - Result type inference is enabled and cannot be performed. MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreate(MlirOperationState *state); +/// Parses an operation, giving ownership to the caller. If parsing fails, a +/// null operation is returned and an error diagnostic is emitted. +/// +/// `sourceStr` may be in either the text assembly or binary bytecode format. +/// `sourceName` is used as the file name in parser-generated locations. +MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreateParse( + MlirContext context, MlirStringRef sourceStr, MlirStringRef sourceName); + /// Creates a deep copy of an operation. The operation is not inserted and /// ownership is transferred to the caller. MLIR_CAPI_EXPORTED MlirOperation mlirOperationClone(MlirOperation op); diff --git a/mlir/include/mlir/Parser/Parser.h b/mlir/include/mlir/Parser/Parser.h --- a/mlir/include/mlir/Parser/Parser.h +++ b/mlir/include/mlir/Parser/Parser.h @@ -140,10 +140,12 @@ /// error message is emitted through the error handler registered in the /// context, and failure is returned. If `sourceFileLoc` is non-null, it is /// populated with a file location representing the start of the source file -/// that is being parsed. +/// that is being parsed. `sourceName` is used as the file name in +/// parser-generated locations. LogicalResult parseSourceString(llvm::StringRef sourceStr, Block *block, const ParserConfig &config, - LocationAttr *sourceFileLoc = nullptr); + LocationAttr *sourceFileLoc = nullptr, + StringRef sourceName = ""); namespace detail { /// The internal implementation of the templated `parseSourceFile` methods @@ -234,13 +236,16 @@ /// message is emitted through the error handler registered in the context, and /// failure is returned.
`ContainerOpT` is required to have a single region /// containing a single block, and must implement the -/// `SingleBlockImplicitTerminator` trait. +/// `SingleBlockImplicitTerminator` trait. `sourceName` is used as the file +/// name in parser-generated locations. template inline OwningOpRef parseSourceString(llvm::StringRef sourceStr, - const ParserConfig &config) { + const ParserConfig &config, + StringRef sourceName = "") { LocationAttr sourceFileLoc; Block block; - if (failed(parseSourceString(sourceStr, &block, config, &sourceFileLoc))) + if (failed(parseSourceString(sourceStr, &block, config, &sourceFileLoc, + sourceName))) return OwningOpRef(); return detail::constructContainerOpForParserIfNecessary( &block, config.getContext(), sourceFileLoc); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -20,8 +20,8 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include #include +#include namespace py = pybind11; using namespace mlir; @@ -1059,6 +1059,20 @@ return created; } +PyOperationRef PyOperation::parse(PyMlirContextRef contextRef, + const std::string &sourceStr, + const std::string &sourceName) { + MlirOperation op = + mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr), + toMlirStringRef(sourceName)); + // TODO: Rework error reporting once diagnostic engine is + // exposed in C API.
+ if (mlirOperationIsNull(op)) + throw SetPyError(PyExc_ValueError, + "Unable to parse operation assembly (see diagnostics)"); + return PyOperation::createDetached(std::move(contextRef), op); +} + void PyOperation::checkValid() const { if (!valid) { throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); @@ -2778,6 +2792,17 @@ py::arg("successors") = py::none(), py::arg("regions") = 0, py::arg("loc") = py::none(), py::arg("ip") = py::none(), kOperationCreateDocstring) + .def_static( + "parse", + [](const std::string &sourceStr, const std::string &sourceName, + DefaultingPyMlirContext context) { + return PyOperation::parse(context->getRef(), sourceStr, sourceName) + ->createOpView(); + }, + py::arg("source"), py::arg("source_name") = "", + py::arg("context") = py::none(), + "Parses an operation. Supports both text assembly format and binary " + "bytecode format.") .def_property_readonly("parent", [](PyOperation &self) -> py::object { auto parent = self.getParentOperation(); @@ -2829,6 +2854,31 @@ py::arg("successors") = py::none(), py::arg("regions") = py::none(), py::arg("loc") = py::none(), py::arg("ip") = py::none(), "Builds a specific, generated OpView based on class level attributes."); + opViewClass.attr("parse") = classmethod( + [](const py::object &cls, const std::string &sourceStr, + const std::string &sourceName, DefaultingPyMlirContext context) { + PyOperationRef parsed = + PyOperation::parse(context->getRef(), sourceStr, sourceName); + + // Check if the expected operation was parsed, and cast it to the + // appropriate `OpView` subclass if successful. + // NOTE: This accesses attributes that have been automatically added to + // `OpView` subclasses, and is not intended to be used on `OpView` + // directly.
+ std::string clsOpName = + py::cast(cls.attr("OPERATION_NAME")); + MlirStringRef parsedOpName = + mlirIdentifierStr(mlirOperationGetName(*parsed.get())); + if (!mlirStringRefEqual(parsedOpName, toMlirStringRef(clsOpName))) + throw SetPyError( + PyExc_ValueError, + Twine("Expected a '") + clsOpName + "' op, got: '" + + std::string(parsedOpName.data, parsedOpName.length) + "'"); + return cls.attr("_Raw")(parsed.getObject()); + }, + py::arg("cls"), py::arg("source"), py::arg("source_name") = "", + py::arg("context") = py::none(), + "Parses a specific, generated OpView based on class level attributes."); //---------------------------------------------------------------------------- // Mapping of PyRegion. diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -9,9 +9,9 @@ #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H #define MLIR_BINDINGS_PYTHON_IRMODULES_H +#include #include #include -#include #include "PybindUtils.h" @@ -548,6 +548,12 @@ createDetached(PyMlirContextRef contextRef, MlirOperation operation, pybind11::object parentKeepAlive = pybind11::object()); + /// Parses a source string (either text assembly or bytecode), creating a + /// detached operation. Throws a Python `ValueError` if parsing fails. + static PyOperationRef parse(PyMlirContextRef contextRef, + const std::string &sourceStr, + const std::string &sourceName); + /// Detaches the operation from its parent block and updates its state /// accordingly.
void detachFromParent() { diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -364,6 +364,15 @@ return result; } +MlirOperation mlirOperationCreateParse(MlirContext context, + MlirStringRef sourceStr, + MlirStringRef sourceName) { + // On failure the parser yields an empty OwningOpRef, so a null op results. + return wrap( + parseSourceString(unwrap(sourceStr), unwrap(context), unwrap(sourceName)) + .release()); +} + MlirOperation mlirOperationClone(MlirOperation op) { return wrap(unwrap(op)->clone()); } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -89,8 +89,9 @@ LogicalResult mlir::parseSourceString(llvm::StringRef sourceStr, Block *block, const ParserConfig &config, - LocationAttr *sourceFileLoc) { - auto memBuffer = llvm::MemoryBuffer::getMemBuffer(sourceStr); + LocationAttr *sourceFileLoc, + StringRef sourceName) { + auto memBuffer = llvm::MemoryBuffer::getMemBuffer(sourceStr, sourceName); if (!memBuffer) return failure(); diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -4,6 +4,7 @@ import io import itertools from mlir.ir import * +from mlir.dialects.builtin import ModuleOp def run(f): @@ -900,3 +901,31 @@ with ctx, Location.unknown(): op = Operation.create("custom.op1") assert hash(op) == hash(op.operation) + + +# CHECK-LABEL: TEST: testOperationParse +@run +def testOperationParse(): + with Context() as ctx: + ctx.allow_unregistered_dialects = True + + # Generic operation parsing. + m = Operation.parse('module {}') + o = Operation.parse('"test.foo"() : () -> ()') + assert isinstance(m, ModuleOp) + assert type(o) is OpView + + # Parsing specific operation.
+ m = ModuleOp.parse('module {}') + assert isinstance(m, ModuleOp) + try: + ModuleOp.parse('"test.foo"() : () -> ()') + except ValueError as e: + # CHECK: error: Expected a 'builtin.module' op, got: 'test.foo' + print(f"error: {e}") + else: + assert False, "expected error" + + o = Operation.parse('"test.foo"() : () -> ()', source_name="my-source-string") + # CHECK: op_with_source_name: "test.foo"() : () -> () loc("my-source-string":1:1) + print(f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}")