diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -285,6 +285,14 @@
  * not perform deep comparison. */
 int mlirOperationEqual(MlirOperation op, MlirOperation other);
 
+/** Gets the block that owns this operation, returning null if the operation is
+ * not owned. */
+MlirBlock mlirOperationGetBlock(MlirOperation op);
+
+/** Gets the operation that owns this operation, returning null if the operation
+ * is not owned. */
+MlirOperation mlirOperationGetParentOperation(MlirOperation op);
+
 /** Returns the number of regions attached to the given operation. */
 intptr_t mlirOperationGetNumRegions(MlirOperation op);
 
@@ -408,6 +416,9 @@
 /** Returns the first operation in the block. */
 MlirOperation mlirBlockGetFirstOperation(MlirBlock block);
 
+/** Returns the terminator operation in the block or null if no terminator. */
+MlirOperation mlirBlockGetTerminator(MlirBlock block);
+
 /** Takes an operation owned by the caller and appends it to the block. */
 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation);
 
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -9,6 +9,8 @@
 #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
 #define MLIR_BINDINGS_PYTHON_IRMODULES_H
 
+#include <vector>
+
 #include <pybind11/pybind11.h>
 
 #include "mlir-c/IR.h"
@@ -18,6 +20,7 @@
 namespace python {
 
 class PyBlock;
+class PyInsertionPoint;
 class PyLocation;
 class PyMlirContext;
 class PyModule;
@@ -61,6 +64,7 @@
     return stolen;
   }
 
+  T *get() { return referrent; }
   T *operator->() {
     assert(referrent && object);
     return referrent;
@@ -76,9 +80,48 @@
   pybind11::object object;
 };
 
-using PyMlirContextRef = PyObjectRef<PyMlirContext>;
+/// Tracks an entry in the thread context stack. New entries are pushed onto
+/// here for each with block that activates a new InsertionPoint or Context.
+/// Pushing either a context or an insertion point resets the other:
+///   - a new context activates a new entry with a null insertion point.
+///   - a new insertion point activates a new entry with the context that the
+///     insertion point is bound to.
+class PyThreadContextEntry {
+public:
+  PyThreadContextEntry(pybind11::object context,
+                       pybind11::object insertionPoint)
+      : context(std::move(context)), insertionPoint(std::move(insertionPoint)) {
+  }
+
+  /// Gets the top of stack context and returns nullptr if not defined.
+  /// If required is true and there is no default, a nice user-facing exception
+  /// is raised.
+  static PyMlirContext *getDefaultContext(bool required);
+
+  /// Gets the top of stack insertion point and returns nullptr if not defined.
+  /// If required is true and there is no default, a nice user-facing exception
+  /// is raised.
+  static PyInsertionPoint *getDefaultInsertionPoint(bool required);
+
+  PyMlirContext *getContext();
+  PyInsertionPoint *getInsertionPoint();
+
+  /// Stack management.
+  static PyThreadContextEntry *getTos();
+  static void push(pybind11::object context, pybind11::object insertionPoint);
+
+  /// Gets the thread local stack.
+  static std::vector<PyThreadContextEntry> &getStack();
+
+private:
+  /// An object reference to the PyContext.
+  pybind11::object context;
+  /// An object reference to the current insertion point.
+  pybind11::object insertionPoint;
+};
 
 /// Wrapper around MlirContext.
+using PyMlirContextRef = PyObjectRef<PyMlirContext>;
 class PyMlirContext {
 public:
   PyMlirContext() = delete;
@@ -287,8 +330,7 @@
 public:
   ~PyOperation();
   /// Returns a PyOperation for the given MlirOperation, optionally associating
-  /// it with a parentKeepAlive (which must match on all such calls for the
-  /// same operation).
+  /// it with a parentKeepAlive.
   static PyOperationRef
   forOperation(PyMlirContextRef contextRef, MlirOperation operation,
                pybind11::object parentKeepAlive = pybind11::object());
@@ -326,6 +368,14 @@
                           bool enableDebugInfo, bool prettyDebugInfo,
                           bool printGenericOpForm, bool useLocalScope);
 
+  /// Gets the owning block or raises an exception if the operation has no
+  /// owning block.
+  PyBlock getBlock();
+
+  /// Gets the parent operation or raises an exception if the operation has
+  /// no parent.
+  PyOperationRef getParentOperation();
+
 private:
   PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
   static PyOperationRef createInstance(PyMlirContextRef contextRef,
@@ -403,6 +453,41 @@
   MlirBlock block;
 };
 
+/// An insertion point maintains a pointer to a Block and a reference operation.
+/// Calls to insert() will insert a new operation before the
+/// reference operation. If the reference operation is null, then appends to
+/// the end of the block.
+class PyInsertionPoint {
+public:
+  /// Creates an insertion point positioned after the last operation in the
+  /// block, but still inside the block.
+  PyInsertionPoint(PyBlock &block);
+  /// Creates an insertion point positioned before a reference operation.
+  PyInsertionPoint(PyOperation &beforeOperation);
+
+  /// Shortcut to create an insertion point at the beginning of the block.
+  static PyInsertionPoint atBlockBegin(PyBlock &block);
+  /// Shortcut to create an insertion point before the block terminator.
+  static PyInsertionPoint atBlockTerminator(PyBlock &block);
+
+  /// Inserts an operation.
+  void insert(PyOperation &operation);
+
+  /// Enter and exit the context manager.
+  pybind11::object contextEnter();
+  void contextExit(pybind11::object excType, pybind11::object excVal,
+                   pybind11::object excTb);
+
+private:
+  // Trampoline constructor that avoids null initializing members while
+  // looking up parents.
+  PyInsertionPoint(PyBlock block, llvm::Optional<PyOperationRef> refOperation)
+      : block(std::move(block)), refOperation(std::move(refOperation)) {}
+
+  PyBlock block;
+  llvm::Optional<PyOperationRef> refOperation;
+};
+
 /// Wrapper around the generic MlirAttribute.
 /// The lifetime of a type is bound by the PyContext that created it.
 class PyAttribute : public BaseContextObject {
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -467,41 +467,11 @@
                      "attempt to access out of bounds operation");
   }
 
-  void insert(int index, PyOperation &newOperation) {
-    parentOperation->checkValid();
-    newOperation.checkValid();
-    if (index < 0) {
-      throw SetPyError(
-          PyExc_IndexError,
-          "only positive insertion indices are supported for operations");
-    }
-    if (newOperation.isAttached()) {
-      throw SetPyError(
-          PyExc_ValueError,
-          "attempt to insert an operation that has already been inserted");
-    }
-    // TODO: Needing to do this check is unfortunate, especially since it will
-    // be a forward-scan, just like the following call to
-    // mlirBlockInsertOwnedOperation. Switch to insert before/after once
-    // D88148 lands.
-    if (index > dunderLen()) {
-      throw SetPyError(PyExc_IndexError,
-                       "attempt to insert operation past end");
-    }
-    mlirBlockInsertOwnedOperation(block, index, newOperation.get());
-    newOperation.setAttached();
-    // TODO: Rework the parentKeepAlive so as to avoid ownership hazards under
-    // the new ownership.
-  }
-
   static void bind(py::module &m) {
     py::class_<PyOperationList>(m, "OperationList")
         .def("__getitem__", &PyOperationList::dunderGetItem)
         .def("__iter__", &PyOperationList::dunderIter)
-        .def("__len__", &PyOperationList::dunderLen)
-        .def("insert", &PyOperationList::insert, py::arg("index"),
-             py::arg("operation"),
-             "Inserts an operation at an indexed position");
+        .def("__len__", &PyOperationList::dunderLen);
   }
 
 private:
@@ -668,7 +638,75 @@
 
   // Construct the operation.
   MlirOperation operation = mlirOperationCreate(&state);
-  return PyOperation::createDetached(getRef(), operation).releaseObject();
+  PyOperationRef created = PyOperation::createDetached(getRef(), operation);
+
+  // InsertPoint active?
+  PyInsertionPoint *ip =
+      PyThreadContextEntry::getDefaultInsertionPoint(/*required=*/false);
+  if (ip)
+    ip->insert(*created.get());
+
+  return created.releaseObject();
+}
+
+//------------------------------------------------------------------------------
+// PyThreadContextEntry management
+//------------------------------------------------------------------------------
+
+std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
+  static thread_local std::vector<PyThreadContextEntry> stack;
+  return stack;
+}
+
+PyThreadContextEntry *PyThreadContextEntry::getTos() {
+  auto &stack = getStack();
+  if (stack.empty())
+    return nullptr;
+  return &stack.back();
+}
+
+void PyThreadContextEntry::push(pybind11::object context,
+                                pybind11::object insertionPoint) {
+  auto &stack = getStack();
+  stack.emplace_back(std::move(context), std::move(insertionPoint));
+}
+
+PyMlirContext *PyThreadContextEntry::getContext() {
+  if (!context)
+    return nullptr;
+  return py::cast<PyMlirContext *>(context);
+}
+
+PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
+  if (!insertionPoint)
+    return nullptr;
+  return py::cast<PyInsertionPoint *>(insertionPoint);
+}
+
+PyMlirContext *PyThreadContextEntry::getDefaultContext(bool required) {
+  auto *tos = getTos();
+  PyMlirContext *context = tos ? tos->getContext() : nullptr;
+  if (required && !context) {
+    throw SetPyError(
+        PyExc_RuntimeError,
+        "A default context is required for this call but is not provided. "
+        "Establish a default by surrounding the code with "
+        "'with context:'");
+  }
+  return context;
+}
+
+PyInsertionPoint *
+PyThreadContextEntry::getDefaultInsertionPoint(bool required) {
+  auto *tos = getTos();
+  PyInsertionPoint *ip = tos ? tos->getInsertionPoint() : nullptr;
+  if (required && !ip)
+    throw SetPyError(PyExc_RuntimeError,
+                     "A default insertion point is required for this call but "
+                     "is not provided. "
+                     "Establish a default by surrounding the code with "
+                     "'with InsertionPoint(...):'");
+  return ip;
 }
 
 //------------------------------------------------------------------------------
@@ -791,7 +829,6 @@
   }
   // Use existing.
   PyOperation *existing = it->second.second;
-  assert(existing->parentKeepAlive.is(parentKeepAlive));
   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
   return PyOperationRef(existing, std::move(pyRef));
 }
@@ -858,6 +895,22 @@
   return fileObject.attr("getvalue")();
 }
 
+PyOperationRef PyOperation::getParentOperation() {
+  if (!isAttached())
+    throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
+  MlirOperation operation = mlirOperationGetParentOperation(get());
+  if (mlirOperationIsNull(operation))
+    throw SetPyError(PyExc_ValueError, "Operation has no parent.");
+  return PyOperation::forOperation(getContext(), operation);
+}
+
+PyBlock PyOperation::getBlock() {
+  PyOperationRef parentOperation = getParentOperation();
+  MlirBlock block = mlirOperationGetBlock(get());
+  assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
+  return PyBlock{std::move(parentOperation), block};
+}
+
 PyOpView::PyOpView(py::object operation)
     : operationObject(std::move(operation)),
       operation(py::cast<PyOperation *>(this->operationObject)) {}
@@ -897,6 +950,78 @@
   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
 }
 
+//------------------------------------------------------------------------------
+// PyInsertionPoint.
+//------------------------------------------------------------------------------
+
+PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
+
+PyInsertionPoint::PyInsertionPoint(PyOperation &beforeOperation)
+    : block(beforeOperation.getBlock()),
+      refOperation(beforeOperation.getRef()) {}
+
+void PyInsertionPoint::insert(PyOperation &operation) {
+  // Reject handles to invalidated operations before touching the C API.
+  operation.checkValid();
+  if (operation.isAttached())
+    throw SetPyError(PyExc_ValueError,
+                     "Attempt to insert operation that is already attached");
+  block.getParentOperation()->checkValid();
+  MlirOperation beforeOp = {nullptr};
+  if (refOperation) {
+    // Insert before operation.
+    (*refOperation)->checkValid();
+    beforeOp = (*refOperation)->get();
+  }
+  mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation.get());
+  operation.setAttached();
+}
+
+PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
+  MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
+  if (mlirOperationIsNull(firstOp)) {
+    // Just insert at end.
+    return PyInsertionPoint(block);
+  }
+
+  // Insert before first op.
+  PyOperationRef firstOpRef = PyOperation::forOperation(
+      block.getParentOperation()->getContext(), firstOp);
+  return PyInsertionPoint{block, std::move(firstOpRef)};
+}
+
+PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
+  MlirOperation terminator = mlirBlockGetTerminator(block.get());
+  if (mlirOperationIsNull(terminator))
+    throw SetPyError(PyExc_ValueError, "Block has no terminator");
+  PyOperationRef terminatorOpRef = PyOperation::forOperation(
+      block.getParentOperation()->getContext(), terminator);
+  return PyInsertionPoint{block, std::move(terminatorOpRef)};
+}
+
+py::object PyInsertionPoint::contextEnter() {
+  auto context = block.getParentOperation()->getContext().getObject();
+  py::object self = py::cast(this);
+  PyThreadContextEntry::push(/*context=*/std::move(context),
+                             /*insertionPoint=*/self);
+  return self;
+}
+
+void PyInsertionPoint::contextExit(pybind11::object excType,
+                                   pybind11::object excVal,
+                                   pybind11::object excTb) {
+  auto &stack = PyThreadContextEntry::getStack();
+  if (stack.empty())
+    throw SetPyError(PyExc_RuntimeError,
+                     "Unbalanced insertion point enter/exit");
+  auto &tos = stack.back();
+  PyInsertionPoint *current = tos.getInsertionPoint();
+  if (current != this)
+    throw SetPyError(PyExc_RuntimeError,
+                     "Unbalanced insertion point enter/exit");
+  stack.pop_back();
+}
+
 //------------------------------------------------------------------------------
 // PyAttribute.
 //------------------------------------------------------------------------------
@@ -2388,6 +2513,24 @@
           },
           "Returns the assembly form of the block.");
 
+  //----------------------------------------------------------------------------
+  // Mapping of PyInsertionPoint.
+  //----------------------------------------------------------------------------
+
+  py::class_<PyInsertionPoint>(m, "InsertionPoint")
+      .def(py::init<PyBlock &>(), py::arg("block"),
+           "Inserts after the last operation but still inside the block.")
+      .def("__enter__", &PyInsertionPoint::contextEnter)
+      .def("__exit__", &PyInsertionPoint::contextExit)
+      .def(py::init<PyOperation &>(), py::arg("beforeOperation"),
+           "Inserts before a referenced operation.")
+      .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
+                  py::arg("block"), "Inserts at the beginning of the block.")
+      .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
+                  py::arg("block"), "Inserts before the block terminator.")
+      .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
+           "Inserts an operation.");
+
   //----------------------------------------------------------------------------
   // Mapping of PyAttribute.
   //----------------------------------------------------------------------------
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -249,6 +249,14 @@
   return unwrap(op) == unwrap(other);
 }
 
+MlirBlock mlirOperationGetBlock(MlirOperation op) {
+  return wrap(unwrap(op)->getBlock());
+}
+
+MlirOperation mlirOperationGetParentOperation(MlirOperation op) {
+  return wrap(unwrap(op)->getParentOp());
+}
+
 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
 }
@@ -403,6 +411,16 @@
   return wrap(&cppBlock->front());
 }
 
+MlirOperation mlirBlockGetTerminator(MlirBlock block) {
+  Block *cppBlock = unwrap(block);
+  if (cppBlock->empty())
+    return wrap(static_cast<Operation *>(nullptr));
+  Operation &back = cppBlock->back();
+  if (!back.isKnownTerminator())
+    return wrap(static_cast<Operation *>(nullptr));
+  return wrap(&back);
+}
+
 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
   unwrap(block)->push_back(unwrap(operation));
 }
diff --git a/mlir/test/Bindings/Python/dialects.py b/mlir/test/Bindings/Python/dialects.py
--- a/mlir/test/Bindings/Python/dialects.py
+++ b/mlir/test/Bindings/Python/dialects.py
@@ -73,30 +73,21 @@
   f32 = mlir.ir.F32Type.get(ctx)
   loc = ctx.get_unknown_location()
   m = ctx.create_module(loc)
-  m_block = m.body
-  # TODO: Remove integer insertion in favor of InsertionPoint and/or op-based.
-  ip = [0]
+
   def createInput():
     op = ctx.create_operation("pytest_dummy.intinput", loc, results=[f32])
-    m_block.operations.insert(ip[0], op)
-    ip[0] += 1
     # TODO: Auto result cast from operation
     return op.results[0]
 
-  # Create via dialects context collection.
-  input1 = createInput()
-  input2 = createInput()
-  op1 = ctx.dialects.std.AddFOp(loc, input1, input2)
-  # TODO: Auto operation cast from OpView
-  # TODO: Context manager insertion point
-  m_block.operations.insert(ip[0], op1.operation)
-  ip[0] += 1
-
-  # Create via an import
-  from mlir.dialects.std import AddFOp
-  op2 = AddFOp(loc, input1, op1.result)
-  m_block.operations.insert(ip[0], op2.operation)
-  ip[0] += 1
+  with mlir.ir.InsertionPoint.at_block_terminator(m.body):
+    # Create via dialects context collection.
+    input1 = createInput()
+    input2 = createInput()
+    op1 = ctx.dialects.std.AddFOp(loc, input1, input2)
+
+    # Create via an import
+    from mlir.dialects.std import AddFOp
+    AddFOp(loc, input1, op1.result)
 
   # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
   # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
diff --git a/mlir/test/Bindings/Python/insertion_point.py b/mlir/test/Bindings/Python/insertion_point.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bindings/Python/insertion_point.py
@@ -0,0 +1,152 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+import io
+import itertools
+from mlir.ir import *
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+  gc.collect()
+  assert Context._get_live_count() == 0
+
+
+# CHECK-LABEL: TEST: test_insert_at_block_end
+def test_insert_at_block_end():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  loc = ctx.get_unknown_location()
+  module = ctx.parse_module(r"""
+    func @foo() -> () {
+      "custom.op1"() : () -> ()
+    }
+  """)
+  entry_block = module.body.operations[0].regions[0].blocks[0]
+  ip = InsertionPoint(entry_block)
+  ip.insert(ctx.create_operation("custom.op2", loc))
+  # CHECK: "custom.op1"
+  # CHECK: "custom.op2"
+  module.operation.print()
+
+run(test_insert_at_block_end)
+
+
+# CHECK-LABEL: TEST: test_insert_before_operation
+def test_insert_before_operation():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  loc = ctx.get_unknown_location()
+  module = ctx.parse_module(r"""
+    func @foo() -> () {
+      "custom.op1"() : () -> ()
+      "custom.op2"() : () -> ()
+    }
+  """)
+  entry_block = module.body.operations[0].regions[0].blocks[0]
+  ip = InsertionPoint(entry_block.operations[1])
+  ip.insert(ctx.create_operation("custom.op3", loc))
+  # CHECK: "custom.op1"
+  # CHECK: "custom.op3"
+  # CHECK: "custom.op2"
+  module.operation.print()
+
+run(test_insert_before_operation)
+
+
+# CHECK-LABEL: TEST: test_insert_at_block_begin
+def test_insert_at_block_begin():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  loc = ctx.get_unknown_location()
+  module = ctx.parse_module(r"""
+    func @foo() -> () {
+      "custom.op2"() : () -> ()
+    }
+  """)
+  entry_block = module.body.operations[0].regions[0].blocks[0]
+  ip = InsertionPoint.at_block_begin(entry_block)
+  ip.insert(ctx.create_operation("custom.op1", loc))
+  # CHECK: "custom.op1"
+  # CHECK: "custom.op2"
+  module.operation.print()
+
+run(test_insert_at_block_begin)
+
+
+# CHECK-LABEL: TEST: test_insert_at_block_begin_empty
+def test_insert_at_block_begin_empty():
+  # TODO: Write this test case when we can create such a situation.
+  pass
+
+run(test_insert_at_block_begin_empty)
+
+
+# CHECK-LABEL: TEST: test_insert_at_terminator
+def test_insert_at_terminator():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  loc = ctx.get_unknown_location()
+  module = ctx.parse_module(r"""
+    func @foo() -> () {
+      "custom.op1"() : () -> ()
+      return
+    }
+  """)
+  entry_block = module.body.operations[0].regions[0].blocks[0]
+  ip = InsertionPoint.at_block_terminator(entry_block)
+  ip.insert(ctx.create_operation("custom.op2", loc))
+  # CHECK: "custom.op1"
+  # CHECK: "custom.op2"
+  module.operation.print()
+
+run(test_insert_at_terminator)
+
+
+# CHECK-LABEL: TEST: test_insert_at_block_terminator_missing
+def test_insert_at_block_terminator_missing():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  loc = ctx.get_unknown_location()
+  module = ctx.parse_module(r"""
+    func @foo() -> () {
+      "custom.op1"() : () -> ()
+    }
+  """)
+  entry_block = module.body.operations[0].regions[0].blocks[0]
+  try:
+    ip = InsertionPoint.at_block_terminator(entry_block)
+  except ValueError as e:
+    # CHECK: Block has no terminator
+    print(e)
+  else:
+    assert False, "Expected exception"
+
+run(test_insert_at_block_terminator_missing)
+
+
+# CHECK-LABEL: TEST: test_insertion_point_context
+def test_insertion_point_context():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  loc = ctx.get_unknown_location()
+  module = ctx.parse_module(r"""
+    func @foo() -> () {
+      "custom.op1"() : () -> ()
+    }
+  """)
+  entry_block = module.body.operations[0].regions[0].blocks[0]
+  with InsertionPoint(entry_block):
+    ctx.create_operation("custom.op2", loc)
+    with InsertionPoint.at_block_begin(entry_block):
+      ctx.create_operation("custom.opa", loc)
+      ctx.create_operation("custom.opb", loc)
+    ctx.create_operation("custom.op3", loc)
+  # CHECK: "custom.opa"
+  # CHECK: "custom.opb"
+  # CHECK: "custom.op1"
+  # CHECK: "custom.op2"
+  # CHECK: "custom.op3"
+  module.operation.print()
+
+run(test_insertion_point_context)
diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -111,7 +111,7 @@
       return
     }
   """)
-  func = module.operation.regions[0].blocks[0].operations[0]
+  func = module.body.operations[0]
   entry_block = func.regions[0].blocks[0]
   assert len(entry_block.arguments) == 3
   # CHECK: Argument 0, type i32
@@ -152,8 +152,8 @@
 run(testDetachedOperation)
 
 
-# CHECK-LABEL: TEST: testOperationInsert
-def testOperationInsert():
+# CHECK-LABEL: TEST: testOperationInsertionPoint
+def testOperationInsertionPoint():
   ctx = mlir.ir.Context()
   ctx.allow_unregistered_dialects = True
   module = ctx.parse_module(r"""
@@ -168,10 +168,11 @@
   op1 = ctx.create_operation("custom.op1", loc)
   op2 = ctx.create_operation("custom.op2", loc)
 
-  func = module.operation.regions[0].blocks[0].operations[0]
+  func = module.body.operations[0]
   entry_block = func.regions[0].blocks[0]
-  entry_block.operations.insert(0, op1)
-  entry_block.operations.insert(1, op2)
+  ip = mlir.ir.InsertionPoint.at_block_begin(entry_block)
+  ip.insert(op1)
+  ip.insert(op2)
   # CHECK: func @f1
   # CHECK: "custom.op1"()
   # CHECK: "custom.op2"()
@@ -180,13 +181,13 @@
 
   # Trying to add a previously added op should raise.
   try:
-    entry_block.operations.insert(0, op1)
+    ip.insert(op1)
   except ValueError:
     pass
   else:
     assert False, "expected insert of attached op to raise"
 
-run(testOperationInsert)
+run(testOperationInsertionPoint)
 
 
 # CHECK-LABEL: TEST: testOperationWithRegion
@@ -202,7 +203,8 @@
   # CHECK: "custom.terminator"() : () -> ()
   # CHECK: }) : () -> ()
   terminator = ctx.create_operation("custom.terminator", loc)
-  block.operations.insert(0, terminator)
+  ip = mlir.ir.InsertionPoint(block)
+  ip.insert(terminator)
   print(op1)
 
   # Now add the whole operation to another op.
@@ -216,9 +218,10 @@
       return %1 : i32
     }
   """)
-  func = module.operation.regions[0].blocks[0].operations[0]
+  func = module.body.operations[0]
   entry_block = func.regions[0].blocks[0]
-  entry_block.operations.insert(0, op1)
+  ip = mlir.ir.InsertionPoint.at_block_begin(entry_block)
+  ip.insert(op1)
   # CHECK: func @f1
   # CHECK: "custom.op1"()
   # CHECK: "custom.terminator"
@@ -238,7 +241,7 @@
     }
     func @f2() -> (i32, f64, index)
   """)
-  caller = module.operation.regions[0].blocks[0].operations[0]
+  caller = module.body.operations[0]
   call = caller.regions[0].blocks[0].operations[0]
   assert len(call.results) == 3
   # CHECK: Result 0, type i32
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -261,18 +261,32 @@
   MlirBlock block = mlirRegionGetFirstBlock(region);
   operation = mlirBlockGetFirstOperation(block);
   region = mlirOperationGetRegion(operation, 0);
+  MlirOperation parentOperation = operation;
   block = mlirRegionGetFirstBlock(region);
   operation = mlirBlockGetFirstOperation(block);
 
-  // In the module we created, the first operation of the first function is an
-  // "std.dim", which has an attribute and a single result that we can use to
-  // test the printing mechanism.
+  // Verify that parent operation and block report correctly.
+  fprintf(stderr, "Parent operation eq: %d\n",
+          mlirOperationEqual(mlirOperationGetParentOperation(operation),
+                             parentOperation));
+  fprintf(stderr, "Block eq: %d\n",
+          mlirBlockEqual(mlirOperationGetBlock(operation), block));
+
+  // In the module we created, the first operation of the first function is
+  // an "std.dim", which has an attribute and a single result that we can
+  // use to test the printing mechanism.
   mlirBlockPrint(block, printToStderr, NULL);
   fprintf(stderr, "\n");
   fprintf(stderr, "First operation: ");
   mlirOperationPrint(operation, printToStderr, NULL);
   fprintf(stderr, "\n");
 
+  // Get the block terminator and print it.
+  MlirOperation terminator = mlirBlockGetTerminator(block);
+  fprintf(stderr, "Terminator: ");
+  mlirOperationPrint(terminator, printToStderr, NULL);
+  fprintf(stderr, "\n");
+
   // Get the attribute by index.
   MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0);
   fprintf(stderr, "Get attr 0: ");
@@ -1100,6 +1114,8 @@
   printFirstOfEach(ctx, module);
 
   // clang-format off
+  // CHECK: Parent operation eq: 1
+  // CHECK: Block eq: 1
   // CHECK: %[[C0:.*]] = constant 0 : index
   // CHECK: %[[DIM:.*]] = dim %{{.*}}, %[[C0]] : memref<?xf32>
   // CHECK: %[[C1:.*]] = constant 1 : index
@@ -1111,6 +1127,7 @@
   // CHECK: }
   // CHECK: return
   // CHECK: First operation: {{.*}} = constant 0 : index
+  // CHECK: Terminator: return
   // CHECK: Get attr 0: 0 : index
   // CHECK: Get attr 0 by name: 0 : index
   // CHECK: does_not_exist is null: 1