diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -447,6 +447,10 @@
 /// Checks whether a region is null.
 static inline bool mlirRegionIsNull(MlirRegion region) { return !region.ptr; }
 
+/// Checks whether two region handles point to the same region. This does not
+/// perform deep comparison.
+MLIR_CAPI_EXPORTED bool mlirRegionEqual(MlirRegion region, MlirRegion other);
+
 /// Gets the first block in the region.
 MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region);
 
@@ -496,6 +500,9 @@
 /// Returns the closest surrounding operation that contains this block.
 MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock);
 
+/// Returns the region that contains this block.
+MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block);
+
 /// Returns the block immediately following the given block in its parent
 /// region.
 MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -969,7 +969,6 @@
   }
   // Unpack/validate successors.
   if (successors) {
-    llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
     mlirSuccessors.reserve(successors->size());
     for (auto *successor : *successors) {
       // TODO: Verify successor originate from the same context.
@@ -1003,9 +1002,10 @@
     mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                     mlirNamedAttributes.data());
   }
-  if (!mlirSuccessors.empty())
+  if (!mlirSuccessors.empty()) {
     mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                     mlirSuccessors.data());
+  }
   if (regions) {
     llvm::SmallVector<MlirRegion, 4> mlirRegions;
     mlirRegions.resize(regions);
@@ -2206,6 +2206,13 @@
             return self.getParentOperation()->createOpView();
           },
           "Returns the owning operation of this block.")
+      .def_property_readonly(
+          "region",
+          [](PyBlock &self) {
+            MlirRegion region = mlirBlockGetParentRegion(self.get());
+            return PyRegion(self.getParentOperation(), region);
+          },
+          "Returns the owning region of this block.")
       .def_property_readonly(
           "arguments",
           [](PyBlock &self) {
@@ -2218,6 +2225,40 @@
             return PyOperationList(self.getParentOperation(), self.get());
           },
           "Returns a forward-optimized sequence of operations.")
+      .def("create_before",
+           [](PyBlock &self, py::args pyArgTypes) {
+             self.checkValid();
+             llvm::SmallVector<MlirType, 4> argTypes;
+             argTypes.reserve(pyArgTypes.size());
+             for (const auto &pyArg : pyArgTypes) {
+               argTypes.push_back(pyArg.cast<PyType &>());
+             }
+
+             MlirBlock block =
+                 mlirBlockCreate(argTypes.size(), argTypes.data());
+             MlirRegion region = mlirBlockGetParentRegion(self.get());
+             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
+             return PyBlock(self.getParentOperation(), block);
+           },
+           "Creates a new block with the given argument types immediately "
+           "before this block and returns it.")
+      .def("create_after",
+           [](PyBlock &self, py::args pyArgTypes) {
+             self.checkValid();
+             llvm::SmallVector<MlirType, 4> argTypes;
+             argTypes.reserve(pyArgTypes.size());
+             for (const auto &pyArg : pyArgTypes) {
+               argTypes.push_back(pyArg.cast<PyType &>());
+             }
+
+             MlirBlock block =
+                 mlirBlockCreate(argTypes.size(), argTypes.data());
+             MlirRegion region = mlirBlockGetParentRegion(self.get());
+             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
+             return PyBlock(self.getParentOperation(), block);
+           },
+           "Creates a new block with the given argument types immediately "
+           "after this block and returns it.")
       .def(
           "__iter__",
           [](PyBlock &self) {
@@ -2270,7 +2311,10 @@
       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
                   py::arg("block"), "Inserts before the block terminator.")
       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
-           "Inserts an operation.");
+           "Inserts an operation.")
+      .def_property_readonly(
+          "block", [](PyInsertionPoint &self) { return self.getBlock(); },
+          "Returns the block that this InsertionPoint points to.");
 
   //----------------------------------------------------------------------------
   // Mapping of PyAttribute.
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -427,6 +427,10 @@
 
 MlirRegion mlirRegionCreate() { return wrap(new Region); }
 
+bool mlirRegionEqual(MlirRegion region, MlirRegion other) {
+  return unwrap(region) == unwrap(other);
+}
+
 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
   Region *cppRegion = unwrap(region);
   if (cppRegion->empty())
@@ -492,6 +496,10 @@
   return wrap(unwrap(block)->getParentOp());
 }
 
+MlirRegion mlirBlockGetParentRegion(MlirBlock block) {
+  return wrap(unwrap(block)->getParent());
+}
+
 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
   return wrap(unwrap(block)->getNextNode());
 }
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -323,13 +323,20 @@
   assert(mlirModuleIsNull(mlirModuleFromOperation(operation)));
 
   // Verify that parent operation and block report correctly.
+ // CHECK: Parent operation eq: 1 fprintf(stderr, "Parent operation eq: %d\n", mlirOperationEqual(mlirOperationGetParentOperation(operation), parentOperation)); + // CHECK: Block eq: 1 fprintf(stderr, "Block eq: %d\n", mlirBlockEqual(mlirOperationGetBlock(operation), block)); - // CHECK: Parent operation eq: 1 - // CHECK: Block eq: 1 + // CHECK: Block parent operation eq: 1 + fprintf( + stderr, "Block parent operation eq: %d\n", + mlirOperationEqual(mlirBlockGetParentOperation(block), parentOperation)); + // CHECK: Block parent region eq: 1 + fprintf(stderr, "Block parent region eq: %d\n", + mlirRegionEqual(mlirBlockGetParentRegion(block), region)); // In the module we created, the first operation of the first function is // an "memref.dim", which has an attribute and a single result that we can @@ -441,7 +448,8 @@ operation, mlirStringRefCreateFromCString("elts"), mlirDenseElementsAttrInt32Get( mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32), - mlirAttributeGetNull()), 4, eltsData)); + mlirAttributeGetNull()), + 4, eltsData)); MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 2); mlirOpPrintingFlagsPrintGenericOpForm(flags); @@ -909,25 +917,25 @@ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding), 2, ints8); MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get( - mlirRankedTensorTypeGet(2, shape, - mlirIntegerTypeUnsignedGet(ctx, 32), encoding), + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32), + encoding), 2, uints32); MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get( mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding), 2, ints32); MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get( - mlirRankedTensorTypeGet(2, shape, - mlirIntegerTypeUnsignedGet(ctx, 64), encoding), + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64), + encoding), 2, uints64); MlirAttribute int64Elements = 
mlirDenseElementsAttrInt64Get( mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding), 2, ints64); MlirAttribute floatElements = mlirDenseElementsAttrFloatGet( - mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), - 2, floats); + mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), 2, + floats); MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet( - mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding), - 2, doubles); + mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding), 2, + doubles); if (!mlirAttributeIsADenseElements(boolElements) || !mlirAttributeIsADenseElements(uint8Elements) || @@ -1084,8 +1092,8 @@ mlirRankedTensorTypeGet(1, &two, mlirIntegerTypeGet(ctx, 64), encoding), 2, indices); MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet( - mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx), encoding), - 2, floats); + mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx), encoding), 2, + floats); MlirAttribute sparseAttr = mlirSparseElementsAttribute( mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), indicesAttr, valuesAttr); @@ -1635,11 +1643,12 @@ mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("std")); MlirLocation loc = mlirLocationUnknownGet(ctx); MlirType indexType = mlirIndexTypeGet(ctx); - MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value"); + MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value"); MlirAttribute indexZeroLiteral = mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index")); - MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral); + MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet( + mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral); MlirOperationState constZeroState = mlirOperationStateGet( mlirStringRefCreateFromCString("std.constant"), loc); mlirOperationStateAddResults(&constZeroState, 
1, &indexType); diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -27,9 +27,10 @@ // CHECK: operands.append(variadic1) // CHECK: operands.append(non_variadic) // CHECK: if variadic2 is not None: operands.append(variadic2) + // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, - // CHECK: loc=loc, ip=ip)) + // CHECK: successors=_ods_successors, loc=loc, ip=ip)) // CHECK: @builtins.property // CHECK: def variadic1(self): @@ -68,9 +69,10 @@ // CHECK: if variadic1 is not None: results.append(variadic1) // CHECK: results.append(non_variadic) // CHECK: if variadic2 is not None: results.append(variadic2) + // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, - // CHECK: loc=loc, ip=ip)) + // CHECK: successors=_ods_successors, loc=loc, ip=ip)) // CHECK: @builtins.property // CHECK: def variadic1(self): @@ -112,9 +114,10 @@ // CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get( // CHECK: _ods_get_default_loc_context(loc)) // CHECK: attributes["in"] = in_ + // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, - // CHECK: loc=loc, ip=ip)) + // CHECK: successors=_ods_successors, loc=loc, ip=ip)) // CHECK: @builtins.property // CHECK: def i32attr(self): @@ -152,9 +155,10 @@ // CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get( // CHECK: _ods_get_default_loc_context(loc)) // CHECK: if is_ is not None: attributes["is"] = is_ + // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, - // CHECK: loc=loc, ip=ip)) + // CHECK: 
successors=_ods_successors, loc=loc, ip=ip)) // CHECK: @builtins.property // CHECK: def in_(self): @@ -177,9 +181,10 @@ // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} + // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, - // CHECK: loc=loc, ip=ip)) + // CHECK: successors=_ods_successors, loc=loc, ip=ip)) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class MissingNamesOp(_ods_ir.OpView): @@ -195,9 +200,10 @@ // CHECK: operands.append(_gen_arg_0) // CHECK: operands.append(f32) // CHECK: operands.append(_gen_arg_2) + // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, - // CHECK: loc=loc, ip=ip)) + // CHECK: successors=_ods_successors, loc=loc, ip=ip)) // CHECK: @builtins.property // CHECK: def f32(self): @@ -226,9 +232,10 @@ // CHECK: attributes = {} // CHECK: operands.append(non_variadic) // CHECK: operands.extend(variadic) + // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, - // CHECK: loc=loc, ip=ip)) + // CHECK: successors=_ods_successors, loc=loc, ip=ip)) // CHECK: @builtins.property // CHECK: def non_variadic(self): @@ -253,9 +260,10 @@ // CHECK: attributes = {} // CHECK: results.extend(variadic) // CHECK: results.append(non_variadic) + // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, - // CHECK: loc=loc, ip=ip)) + // CHECK: successors=_ods_successors, loc=loc, ip=ip)) // CHECK: @builtins.property // CHECK: def variadic(self): @@ -278,9 +286,10 @@ // CHECK: results = [] // CHECK: attributes = {} // CHECK: operands.append(in_) + // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: 
attributes=attributes, results=results, operands=operands, - // CHECK: loc=loc, ip=ip)) + // CHECK: successors=_ods_successors, loc=loc, ip=ip)) // CHECK: @builtins.property // CHECK: def in_(self): @@ -346,9 +355,10 @@ // CHECK: results.append(f64) // CHECK: operands.append(i32) // CHECK: operands.append(f32) + // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, - // CHECK: loc=loc, ip=ip)) + // CHECK: successors=_ods_successors, loc=loc, ip=ip)) // CHECK: @builtins.property // CHECK: def i32(self): @@ -368,3 +378,15 @@ // CHECK: return self.operation.results[1] let results = (outs I64:$i64, F64:$f64); } + +// CHECK: @_ods_cext.register_operation(_Dialect) +// CHECK: class WithSuccessorsOp(_ods_ir.OpView): +// CHECK-LABEL: OPERATION_NAME = "test.with_successors" +def WithSuccessorsOp : TestOp<"with_successors"> { + // CHECK-NOT: _ods_successors = None + // CHECK: _ods_successors = [] + // CHECK-NEXT: _ods_successors.append(successor) + // CHECK-NEXT: _ods_successors.extend(successors) + let successors = (successor AnySuccessor:$successor, + VariadicSuccessor:$successors); +} diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py new file mode 100644 --- /dev/null +++ b/mlir/test/python/ir/blocks.py @@ -0,0 +1,50 @@ +# RUN: %PYTHON %s | FileCheck %s + +import gc +import io +import itertools +from mlir.ir import * +from mlir.dialects import builtin +# Note: std dialect needed for terminators. 
+from mlir.dialects import std
+
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+  gc.collect()
+  assert Context._get_live_count() == 0
+  return f
+
+
+# CHECK-LABEL: TEST: testBlockCreation
+# CHECK: func @test(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16)
+# CHECK:   br ^bb1(%[[ARG1]] : i16)
+# CHECK: ^bb1(%[[PHI0:.*]]: i16):
+# CHECK:   br ^bb2(%[[ARG0]] : i32)
+# CHECK: ^bb2(%[[PHI1:.*]]: i32):
+# CHECK:   return
+@run
+def testBlockCreation():
+  with Context(), Location.unknown():
+    module = Module.create()
+    with InsertionPoint(module.body):
+      f_type = FunctionType.get(
+          [IntegerType.get_signless(32),
+           IntegerType.get_signless(16)], [])
+      f_op = builtin.FuncOp("test", f_type)
+      entry_block = f_op.add_entry_block()
+      i32_arg, i16_arg = entry_block.arguments
+      successor_block = entry_block.create_after(i32_arg.type)
+      with InsertionPoint(successor_block):
+        std.ReturnOp([])
+      middle_block = successor_block.create_before(i16_arg.type)
+
+      with InsertionPoint(entry_block):
+        std.BranchOp([i16_arg], dest=middle_block)
+
+      with InsertionPoint(middle_block):
+        std.BranchOp([i32_arg], dest=successor_block)
+    print(module.operation)
+    # Ensure region back references are coherent.
+    assert entry_block.region == middle_block.region == successor_block.region
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -475,7 +475,8 @@
 /// Template for the default auto-generated builder.
 ///   {0} is a comma-separated list of builder arguments, including the trailing
 ///       `loc` and `ip`;
-///   {1} is the code populating `operands`, `results` and `attributes` fields.
+///   {1} is the code populating the `operands`, `results`, `attributes` and
+///       `successors` fields.
constexpr const char *initTemplate = R"Py( def __init__(self, {0}): operands = [] @@ -484,7 +485,7 @@ {1} super().__init__(self.build_generic( attributes=attributes, results=results, operands=operands, - loc=loc, ip=ip)) + successors=_ods_successors, loc=loc, ip=ip)) )Py"; /// Template for appending a single element to the operand/result list. @@ -518,6 +519,16 @@ R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get( _ods_get_default_loc_context(loc)))Py"; +/// Template to initialize the successors list in the builder if there are any +/// successors. +/// {0} is the value to initialize the successors list to. +constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py"; + +/// Template to append or extend the list of successors in the builder. +/// {0} is the list method ('append' or 'extend'); +/// {1} is the value to add. +constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py"; + /// Populates `builderArgs` with the Python-compatible names of builder function /// arguments, first the results, then the intermixed attributes and operands in /// the same order as they appear in the `arguments` field of the op definition. 
@@ -526,7 +537,8 @@
 static void
 populateBuilderArgs(const Operator &op,
                     llvm::SmallVectorImpl<std::string> &builderArgs,
-                    llvm::SmallVectorImpl<std::string> &operandNames) {
+                    llvm::SmallVectorImpl<std::string> &operandNames,
+                    llvm::SmallVectorImpl<std::string> &successorArgNames) {
   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
     std::string name = op.getResultName(i).str();
     if (name.empty()) {
@@ -550,6 +562,16 @@
     if (!op.getArg(i).is<NamedAttribute *>())
       operandNames.push_back(name);
   }
+
+  for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
+    const NamedSuccessor &successor = op.getSuccessor(i);
+    std::string name = std::string(successor.name);
+    if (name.empty())
+      name = llvm::formatv("_gen_successor_{0}", i);
+    name = sanitizeName(name);
+    builderArgs.push_back(name);
+    successorArgNames.push_back(name);
+  }
 }
 
 /// Populates `builderLines` with additional lines that are required in the
@@ -581,6 +603,27 @@
   }
 }
 
+/// Populates `builderLines` with additional lines that are required in the
+/// builder to set up successors. `successorArgNames` must hold, in order, the
+/// Python argument name of each successor of the op.
+static void populateBuilderLinesSuccessors(
+    const Operator &op, llvm::ArrayRef<std::string> successorArgNames,
+    llvm::SmallVectorImpl<std::string> &builderLines) {
+  if (successorArgNames.empty()) {
+    builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None"));
+    return;
+  }
+
+  builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]"));
+  for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
+    const std::string &argName = successorArgNames[i];
+    const NamedSuccessor &successor = op.getSuccessor(i);
+    builderLines.push_back(
+        llvm::formatv(addSuccessorTemplate,
+                      successor.isVariadic() ? "extend" : "append", argName));
+  }
+}
+
 /// Populates `builderLines` with additional lines that are required in the
 /// builder. `kind` must be either "operand" or "result". `names` contains the
 /// names of init arguments that correspond to the elements.
@@ -629,12 +672,14 @@
   if (op.skipDefaultBuilders())
     return;
 
-  llvm::SmallVector<std::string, 8> builderArgs;
-  llvm::SmallVector<std::string, 8> builderLines;
-  llvm::SmallVector<std::string, 4> operandArgNames;
+  llvm::SmallVector<std::string> builderArgs;
+  llvm::SmallVector<std::string> builderLines;
+  llvm::SmallVector<std::string> operandArgNames;
+  llvm::SmallVector<std::string> successorArgNames;
   builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
-                      op.getNumNativeAttributes());
-  populateBuilderArgs(op, builderArgs, operandArgNames);
+                      op.getNumNativeAttributes() + op.getNumSuccessors());
+  populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
+
   populateBuilderLines(
       op, "result",
       llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()),
@@ -644,6 +689,7 @@
   populateBuilderLinesAttr(
       op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()),
       builderLines);
+  populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
 
   builderArgs.push_back("*");
   builderArgs.push_back("loc=None");