[mlir] Add mlirModuleGetBody to the C API and expose it as Module.body in Python

Summary of this patch:
  * mlir-c/IR.h, CAPI/IR/IR.cpp: declare and define mlirModuleGetBody, which
    wraps the result of calling getBody() on the unwrapped module.
  * Bindings/Python/IRModules.cpp: add a read-only `body` property on the
    module class. The returned PyBlock is constructed from a PyOperationRef
    obtained for the module's operation, together with the block returned by
    mlirModuleGetBody.
  * test/Bindings/Python/dialects.py, test/CAPI/ir.c: replace the manual
    operation -> region 0 -> first block walk with the new accessor.

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -175,6 +175,9 @@
 /** Gets the context that a module was created with. */
 MlirContext mlirModuleGetContext(MlirModule module);
 
+/** Gets the body of the module, i.e. the only block it contains. */
+MlirBlock mlirModuleGetBody(MlirModule module);
+
 /** Checks whether a module is null. */
 static inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; }
 
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -2234,6 +2234,16 @@
                 .releaseObject();
           },
           "Accesses the module as an operation")
+      .def_property_readonly(
+          "body",
+          [](PyModule &self) {
+            PyOperationRef moduleOp = PyOperation::forOperation(
+                self.getContext(), mlirModuleGetOperation(self.get()),
+                self.getRef().releaseObject());
+            PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
+            return returnBlock;
+          },
+          "Return the block for this module")
       .def(
           "dump",
           [](PyModule &self) {
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -148,6 +148,10 @@
   return wrap(unwrap(module).getContext());
 }
 
+MlirBlock mlirModuleGetBody(MlirModule module) {
+  return wrap(unwrap(module).getBody());
+}
+
 void mlirModuleDestroy(MlirModule module) {
   // Transfer ownership to an OwningModuleRef so that its destructor is called.
   OwningModuleRef(unwrap(module));
diff --git a/mlir/test/Bindings/Python/dialects.py b/mlir/test/Bindings/Python/dialects.py
--- a/mlir/test/Bindings/Python/dialects.py
+++ b/mlir/test/Bindings/Python/dialects.py
@@ -73,7 +73,7 @@
   f32 = mlir.ir.F32Type.get(ctx)
   loc = ctx.get_unknown_location()
   m = ctx.create_module(loc)
-  m_block = m.operation.regions[0].blocks[0]
+  m_block = m.body
   # TODO: Remove integer insertion in favor of InsertionPoint and/or op-based.
   ip = [0]
   def createInput():
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -67,9 +67,7 @@
 
 MlirModule makeAdd(MlirContext ctx, MlirLocation location) {
   MlirModule moduleOp = mlirModuleCreateEmpty(location);
-  MlirOperation module = mlirModuleGetOperation(moduleOp);
-  MlirRegion moduleBodyRegion = mlirOperationGetRegion(module, 0);
-  MlirBlock moduleBody = mlirRegionGetFirstBlock(moduleBodyRegion);
+  MlirBlock moduleBody = mlirModuleGetBody(moduleOp);
 
   MlirType memrefType = mlirTypeParseGet(ctx, "memref");
   MlirType funcBodyArgTypes[] = {memrefType, memrefType};