diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md --- a/mlir/docs/Bindings/Python.md +++ b/mlir/docs/Bindings/Python.md @@ -110,47 +110,6 @@ from _mlir import * ``` -### Limited use of globals - -For normal operations, parent-child constructor relationships are realized with -constructor methods on a parent class as opposed to requiring -invocation/creation from a global symbol. - -For example, consider two code fragments: - -```python - -op = build_my_op() - -region = mlir.Region(op) - -``` - -vs - -```python - -op = build_my_op() - -region = op.new_region() - -``` - -For tightly coupled data structures like `Operation`, the latter is generally -preferred because: - -* It is syntactically less possible to create something that is going to access - illegal memory (less error handling in the bindings, less testing, etc). - -* It reduces the global-API surface area for creating related entities. This - makes it more likely that if constructing IR based on an Operation instance of - unknown providence, receiving code can just call methods on it to do what they - want versus needing to reach back into the global namespace and find the right - `Region` class. - -* It leaks fewer things that are in place for C++ convenience (i.e. default - constructors to invalid instances). - ### Use the C-API The Python APIs should seek to layer on top of the C-API to the degree possible. @@ -171,6 +130,20 @@ All other objects are dependent. All objects maintain a back-reference (keep-alive) to their closest containing top-level object. Further, dependent objects fall into two categories: a) uniqued (which live for the life-time of the context) and b) mutable. Mutable objects need additional machinery for keeping track of when the C++ instance that backs their Python object is no longer valid (typically due to some specific mutation of the IR, deletion, or bulk operation). 
+### Optionality and argument ordering in the Core IR + +The following types can be bound to the current thread by using them as a context manager (i.e. in a `with` block): + +* `PyLocation` (`loc: mlir.ir.Location = None`) +* `PyInsertionPoint` (`ip: mlir.ir.InsertionPoint = None`) +* `PyMlirContext` (`context: mlir.ir.Context = None`) + +To keep function arguments composable, whenever these types appear as function arguments they should come last, in the order listed above, and use the given names (this is generally the order in which callers are expected to need to pass them explicitly in special cases). Each should default to `py::none()` and use either a manual or an automatic conversion that resolves to the explicitly passed value or, if none was passed, to the value from the thread context manager (e.g. `DefaultingPyMlirContext` or `DefaultingPyLocation`). + +The rationale is that, unlike C++, where arguments to the *left* are the most composable via templating, in Python trailing keyword arguments to the *right* are the most composable, enabling a variety of strategies such as kwarg passthrough, default values, etc. Keeping function signatures composable increases the chances that interesting DSLs and higher-level APIs can be constructed without a lot of exotic boilerplate. + +Used consistently, this enables a style of IR construction that rarely needs explicit contexts, locations, or insertion points but is free to use them when extra control is needed. + #### Operation hierarchy As mentioned above, `PyOperation` is special because it can exist in either a top-level or dependent state. The life-cycle is unidirectional: operations can be created detached (top-level) and once added to another operation, they are then dependent for the remainder of their lifetime.
The situation is more complicated when considering construction scenarios where an operation is added to a transitive parent that is still detached, necessitating further accounting at such transition points (i.e. all such added children are initially added to the IR with a parent of their outer-most detached operation, but then once it is added to an attached operation, they need to be re-parented to the containing module). diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModules.h @@ -11,7 +11,7 @@ #include -#include +#include "PybindUtils.h" #include "mlir-c/IR.h" #include "llvm/ADT/DenseMap.h" @@ -22,7 +22,9 @@ class PyBlock; class PyInsertionPoint; class PyLocation; +class DefaultingPyLocation; class PyMlirContext; +class DefaultingPyMlirContext; class PyModule; class PyOperation; class PyType; @@ -81,43 +83,65 @@ }; /// Tracks an entry in the thread context stack. New entries are pushed onto -/// here for each with block that activates a new InsertionPoint or Context. -/// Pushing either a context or an insertion point resets the other: -/// - a new context activates a new entry with a null insertion point. -/// - a new insertion point activates a new entry with the context that the -/// insertion point is bound to. +/// here for each with block that activates a new InsertionPoint, Context or +/// Location. +/// +/// Pushing either a Location or InsertionPoint also pushes its associated +/// Context. Pushing a Context will not modify the Location or InsertionPoint +/// unless if they are from a different context, in which case, they are +/// cleared. 
class PyThreadContextEntry { public: - PyThreadContextEntry(pybind11::object context, - pybind11::object insertionPoint) - : context(std::move(context)), insertionPoint(std::move(insertionPoint)) { - } + enum class FrameKind { + kContext, + kInsertionPoint, + kLocation, + }; + + PyThreadContextEntry(FrameKind frameKind, pybind11::object context, + pybind11::object insertionPoint, + pybind11::object location) + : context(std::move(context)), insertionPoint(std::move(insertionPoint)), + location(std::move(location)), frameKind(frameKind) {} /// Gets the top of stack context and return nullptr if not defined. - /// If required is true and there is no default, a nice user-facing exception - /// is raised. - static PyMlirContext *getDefaultContext(bool required); + static PyMlirContext *getDefaultContext(); /// Gets the top of stack insertion point and return nullptr if not defined. - /// If required is true and there is no default, a nice user-facing exception - /// is raised. - static PyInsertionPoint *getDefaultInsertionPoint(bool required); + static PyInsertionPoint *getDefaultInsertionPoint(); + + /// Gets the top of stack location and returns nullptr if not defined. + static PyLocation *getDefaultLocation(); PyMlirContext *getContext(); PyInsertionPoint *getInsertionPoint(); + PyLocation *getLocation(); + FrameKind getFrameKind() { return frameKind; } /// Stack management. static PyThreadContextEntry *getTos(); - static void push(pybind11::object context, pybind11::object insertionPoint); + static pybind11::object pushContext(PyMlirContext &context); + static void popContext(PyMlirContext &context); + static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint); + static void popInsertionPoint(PyInsertionPoint &insertionPoint); + static pybind11::object pushLocation(PyLocation &location); + static void popLocation(PyLocation &location); /// Gets the thread local stack. 
static std::vector &getStack(); private: + static void push(FrameKind frameKind, pybind11::object context, + pybind11::object insertionPoint, pybind11::object location); + /// An object reference to the PyContext. pybind11::object context; /// An object reference to the current insertion point. pybind11::object insertionPoint; + /// An object reference to the current location. + pybind11::object location; + // The kind of push that was performed. + FrameKind frameKind; }; /// Wrapper around MlirContext. @@ -172,14 +196,10 @@ /// Used for testing. size_t getLiveModuleCount(); - /// Creates an operation. See corresponding python docstring. - pybind11::object - createOperation(std::string name, PyLocation location, - llvm::Optional> operands, - llvm::Optional> results, - llvm::Optional attributes, - llvm::Optional> successors, - int regions); + /// Enter and exit the context manager. + pybind11::object contextEnter(); + void contextExit(pybind11::object excType, pybind11::object excVal, + pybind11::object excTb); private: PyMlirContext(MlirContext context); @@ -213,6 +233,17 @@ friend class PyOperation; }; +/// Used in function arguments when None should resolve to the current context +/// manager set instance. +class DefaultingPyMlirContext + : public Defaulting { +public: + using Defaulting::Defaulting; + static constexpr const char kTypeDescription[] = + "[ThreadContextAware] mlir.ir.Context"; + static PyMlirContext &resolve(); +}; + /// Base class for all objects that directly or indirectly depend on an /// MlirContext. The lifetime of the context will extend at least to the /// lifetime of these instances. @@ -275,9 +306,26 @@ public: PyLocation(PyMlirContextRef contextRef, MlirLocation loc) : BaseContextObject(std::move(contextRef)), loc(loc) {} + + /// Enter and exit the context manager. 
+ pybind11::object contextEnter(); + void contextExit(pybind11::object excType, pybind11::object excVal, + pybind11::object excTb); + MlirLocation loc; }; +/// Used in function arguments when None should resolve to the current context +/// manager set instance. +class DefaultingPyLocation + : public Defaulting { +public: + using Defaulting::Defaulting; + static constexpr const char kTypeDescription[] = + "[ThreadContextAware] mlir.ir.Location"; + static PyLocation &resolve(); +}; + /// Wrapper around MlirModule. /// This is the top-level, user-owned object that contains regions/ops/blocks. class PyModule; @@ -376,6 +424,14 @@ /// no parent. PyOperationRef getParentOperation(); + /// Creates an operation. See corresponding python docstring. + static pybind11::object + create(std::string name, llvm::Optional> operands, + llvm::Optional> results, + llvm::Optional attributes, + llvm::Optional> successors, int regions, + DefaultingPyLocation location, pybind11::object ip); + private: PyOperation(PyMlirContextRef contextRef, MlirOperation operation); static PyOperationRef createInstance(PyMlirContextRef contextRef, @@ -478,6 +534,8 @@ void contextExit(pybind11::object excType, pybind11::object excVal, pybind11::object excTb); + PyBlock &getBlock() { return block; } + private: // Trampoline constructor that avoids null initializing members while // looking up parents. @@ -560,4 +618,17 @@ } // namespace python } // namespace mlir +namespace pybind11 { +namespace detail { + +template <> +struct type_caster + : MlirDefaultingCaster {}; +template <> +struct type_caster + : MlirDefaultingCaster {}; + +} // namespace detail +} // namespace pybind11 + #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -28,23 +28,18 @@ // Docstrings (trivial, non-duplicated docstrings are included inline). 
//------------------------------------------------------------------------------ -static const char kContextCreateOperationDocstring[] = - R"(Creates a new operation. +static const char kContextParseTypeDocstring[] = + R"(Parses the assembly form of a type. -Args: - name: Operation name (e.g. "dialect.operation"). - location: A Location object. - results: Sequence of Type representing op result types. - attributes: Dict of str:Attribute. - successors: List of Block for the operation's successors. - regions: Number of regions to create. +Returns a Type object or raises a ValueError if the type cannot be parsed. -Returns: - A new "detached" Operation object. Detached operations can be added - to blocks, which causes them to become "attached." +See also: https://mlir.llvm.org/docs/LangRef/#type-system )"; -static const char kContextParseDocstring[] = +static const char kContextGetFileLocationDocstring[] = + R"(Gets a Location representing a file, line and column)"; + +static const char kModuleParseDocstring[] = R"(Parses a module's assembly format from a string. Returns a new MlirModule or raises a ValueError if the parsing fails. @@ -52,20 +47,24 @@ See also: https://mlir.llvm.org/docs/LangRef/ )"; -static const char kContextParseTypeDocstring[] = - R"(Parses the assembly form of a type. - -Returns a Type object or raises a ValueError if the type cannot be parsed. +static const char kOperationCreateDocstring[] = + R"(Creates a new operation. -See also: https://mlir.llvm.org/docs/LangRef/#type-system +Args: + name: Operation name (e.g. "dialect.operation"). + results: Sequence of Type representing op result types. + attributes: Dict of str:Attribute. + successors: List of Block for the operation's successors. + regions: Number of regions to create. + location: A Location object (defaults to resolve from context manager). 
+ ip: An InsertionPoint (defaults to resolve from context manager or set to + False to disable insertion, even with an insertion point set in the + context manager). +Returns: + A new "detached" Operation object. Detached operations can be added + to blocks, which causes them to become "attached." )"; -static const char kContextGetUnknownLocationDocstring[] = - R"(Gets a Location representing an unknown location)"; - -static const char kContextGetFileLocationDocstring[] = - R"(Gets a Location representing a file, line and column)"; - static const char kOperationPrintDocstring[] = R"(Prints the assembly form of the operation to a file like object. @@ -545,108 +544,26 @@ size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } -py::object PyMlirContext::createOperation( - std::string name, PyLocation location, - llvm::Optional> operands, - llvm::Optional> results, - llvm::Optional attributes, - llvm::Optional> successors, int regions) { - llvm::SmallVector mlirOperands; - llvm::SmallVector mlirResults; - llvm::SmallVector mlirSuccessors; - llvm::SmallVector, 4> mlirAttributes; - - // General parameter validation. - if (regions < 0) - throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); - - // Unpack/validate operands. - if (operands) { - mlirOperands.reserve(operands->size()); - for (PyValue *operand : *operands) { - if (!operand) - throw SetPyError(PyExc_ValueError, "operand value cannot be None"); - mlirOperands.push_back(operand->get()); - } - } - - // Unpack/validate results. - if (results) { - mlirResults.reserve(results->size()); - for (PyType *result : *results) { - // TODO: Verify result type originate from the same context. - if (!result) - throw SetPyError(PyExc_ValueError, "result type cannot be None"); - mlirResults.push_back(result->type); - } - } - // Unpack/validate attributes. 
- if (attributes) { - mlirAttributes.reserve(attributes->size()); - for (auto &it : *attributes) { +pybind11::object PyMlirContext::contextEnter() { + return PyThreadContextEntry::pushContext(*this); +} - auto name = it.first.cast(); - auto &attribute = it.second.cast(); - // TODO: Verify attribute originates from the same context. - mlirAttributes.emplace_back(std::move(name), attribute.attr); - } - } - // Unpack/validate successors. - if (successors) { - llvm::SmallVector mlirSuccessors; - mlirSuccessors.reserve(successors->size()); - for (auto *successor : *successors) { - // TODO: Verify successor originate from the same context. - if (!successor) - throw SetPyError(PyExc_ValueError, "successor block cannot be None"); - mlirSuccessors.push_back(successor->get()); - } - } +void PyMlirContext::contextExit(pybind11::object excType, + pybind11::object excVal, + pybind11::object excTb) { + PyThreadContextEntry::popContext(*this); +} - // Apply unpacked/validated to the operation state. Beyond this - // point, exceptions cannot be thrown or else the state will leak. - MlirOperationState state = mlirOperationStateGet(name.c_str(), location.loc); - if (!mlirOperands.empty()) - mlirOperationStateAddOperands(&state, mlirOperands.size(), - mlirOperands.data()); - if (!mlirResults.empty()) - mlirOperationStateAddResults(&state, mlirResults.size(), - mlirResults.data()); - if (!mlirAttributes.empty()) { - // Note that the attribute names directly reference bytes in - // mlirAttributes, so that vector must not be changed from here - // on. 
- llvm::SmallVector mlirNamedAttributes; - mlirNamedAttributes.reserve(mlirAttributes.size()); - for (auto &it : mlirAttributes) - mlirNamedAttributes.push_back( - mlirNamedAttributeGet(it.first.c_str(), it.second)); - mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), - mlirNamedAttributes.data()); - } - if (!mlirSuccessors.empty()) - mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), - mlirSuccessors.data()); - if (regions) { - llvm::SmallVector mlirRegions; - mlirRegions.resize(regions); - for (int i = 0; i < regions; ++i) - mlirRegions[i] = mlirRegionCreate(); - mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), - mlirRegions.data()); +PyMlirContext &DefaultingPyMlirContext::resolve() { + PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); + if (!context) { + throw SetPyError( + PyExc_RuntimeError, + "An MLIR function requires a Context but none was provided in the call " + "or from the surrounding environment. Either pass to the function with " + "a 'context=' argument or establish a default using 'with Context():'"); } - - // Construct the operation. - MlirOperation operation = mlirOperationCreate(&state); - PyOperationRef created = PyOperation::createDetached(getRef(), operation); - - // InsertPoint active? 
- PyInsertionPoint *ip = - PyThreadContextEntry::getDefaultInsertionPoint(/*required=*/false); - if (ip) - ip->insert(*created.get()); - - return created.releaseObject(); + return *context; } //------------------------------------------------------------------------------ @@ -665,10 +582,24 @@ return &stack.back(); } -void PyThreadContextEntry::push(pybind11::object context, - pybind11::object insertionPoint) { +void PyThreadContextEntry::push(FrameKind frameKind, py::object context, + py::object insertionPoint, + py::object location) { auto &stack = getStack(); - stack.emplace_back(std::move(context), std::move(insertionPoint)); + stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), + std::move(location)); + + if (stack.size() > 1) { + auto &prev = *(stack.rbegin() + 1); + auto &current = stack.back(); + if (current.context.is(prev.context)) { + // Default non-context objects from the previous entry. + if (!current.insertionPoint) + current.insertionPoint = prev.insertionPoint; + if (!current.location) + current.location = prev.location; + } + } } PyMlirContext *PyThreadContextEntry::getContext() { @@ -683,30 +614,87 @@ return py::cast(insertionPoint); } -PyMlirContext *PyThreadContextEntry::getDefaultContext(bool required) { +PyLocation *PyThreadContextEntry::getLocation() { + if (!location) + return nullptr; + return py::cast(location); +} + +PyMlirContext *PyThreadContextEntry::getDefaultContext() { auto *tos = getTos(); - PyMlirContext *context = tos ? tos->getContext() : nullptr; - if (required && !context) { - throw SetPyError( - PyExc_RuntimeError, - "A default context is required for this call but is not provided. " - "Establish a default by surrounding the code with " - "'with context:'"); - } - return context; + return tos ? tos->getContext() : nullptr; +} + +PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { + auto *tos = getTos(); + return tos ?
tos->getInsertionPoint() : nullptr; } -PyInsertionPoint * -PyThreadContextEntry::getDefaultInsertionPoint(bool required) { +PyLocation *PyThreadContextEntry::getDefaultLocation() { auto *tos = getTos(); - PyInsertionPoint *ip = tos ? tos->getInsertionPoint() : nullptr; - if (required && !ip) + return tos ? tos->getLocation() : nullptr; +} + +py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { + py::object contextObj = py::cast(context); + push(FrameKind::kContext, /*context=*/contextObj, + /*insertionPoint=*/py::object(), + /*location=*/py::object()); + return contextObj; +} + +void PyThreadContextEntry::popContext(PyMlirContext &context) { + auto &stack = getStack(); + if (stack.empty()) + throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); + auto &tos = stack.back(); + if (tos.frameKind != FrameKind::kContext || tos.getContext() != &context) + throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); + stack.pop_back(); +} + +py::object +PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { + py::object contextObj = + insertionPoint.getBlock().getParentOperation()->getContext().getObject(); + py::object insertionPointObj = py::cast(insertionPoint); + push(FrameKind::kInsertionPoint, + /*context=*/contextObj, + /*insertionPoint=*/insertionPointObj, + /*location=*/py::object()); + return insertionPointObj; +} + +void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { + auto &stack = getStack(); + if (stack.empty()) + throw SetPyError(PyExc_RuntimeError, + "Unbalanced InsertionPoint enter/exit"); + auto &tos = stack.back(); + if (tos.frameKind != FrameKind::kInsertionPoint || + tos.getInsertionPoint() != &insertionPoint) throw SetPyError(PyExc_RuntimeError, - "A default insertion point is required for this call but " - "is not provided.
" - "Establish a default by surrounding the code with " - "'with InsertionPoint(...):'"); - return ip; + "Unbalanced InsertionPoint enter/exit"); + stack.pop_back(); +} + +py::object PyThreadContextEntry::pushLocation(PyLocation &location) { + py::object contextObj = location.getContext().getObject(); + py::object locationObj = py::cast(location); + push(FrameKind::kLocation, /*context=*/contextObj, + /*insertionPoint=*/py::object(), + /*location=*/locationObj); + return locationObj; +} + +void PyThreadContextEntry::popLocation(PyLocation &location) { + auto &stack = getStack(); + if (stack.empty()) + throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); + auto &tos = stack.back(); + if (tos.frameKind != FrameKind::kLocation || tos.getLocation() != &location) + throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); + stack.pop_back(); } //------------------------------------------------------------------------------ @@ -727,6 +715,32 @@ return dialect; } +//------------------------------------------------------------------------------ +// PyLocation +//------------------------------------------------------------------------------ + +py::object PyLocation::contextEnter() { + return PyThreadContextEntry::pushLocation(*this); +} + +void PyLocation::contextExit(py::object excType, py::object excVal, + py::object excTb) { + PyThreadContextEntry::popLocation(*this); +} + +PyLocation &DefaultingPyLocation::resolve() { + auto *location = PyThreadContextEntry::getDefaultLocation(); + if (!location) { + throw SetPyError( + PyExc_RuntimeError, + "An MLIR function requires a Location but none was provided in the " + "call or from the surrounding environment.
Either pass to the function " + "with a 'loc=' argument or establish a default using 'with loc:'"); + } + return *location; +} + //------------------------------------------------------------------------------ // PyModule //------------------------------------------------------------------------------ @@ -911,6 +925,117 @@ return PyBlock{std::move(parentOperation), block}; } +py::object PyOperation::create( + std::string name, llvm::Optional> operands, + llvm::Optional> results, + llvm::Optional attributes, + llvm::Optional> successors, int regions, + DefaultingPyLocation location, py::object maybeIp) { + llvm::SmallVector mlirOperands; + llvm::SmallVector mlirResults; + llvm::SmallVector mlirSuccessors; + llvm::SmallVector, 4> mlirAttributes; + + // General parameter validation. + if (regions < 0) + throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); + + // Unpack/validate operands. + if (operands) { + mlirOperands.reserve(operands->size()); + for (PyValue *operand : *operands) { + if (!operand) + throw SetPyError(PyExc_ValueError, "operand value cannot be None"); + mlirOperands.push_back(operand->get()); + } + } + + // Unpack/validate results. + if (results) { + mlirResults.reserve(results->size()); + for (PyType *result : *results) { + // TODO: Verify result type originate from the same context. + if (!result) + throw SetPyError(PyExc_ValueError, "result type cannot be None"); + mlirResults.push_back(result->type); + } + } + // Unpack/validate attributes. + if (attributes) { + mlirAttributes.reserve(attributes->size()); + for (auto &it : *attributes) { + + auto name = it.first.cast(); + auto &attribute = it.second.cast(); + // TODO: Verify attribute originates from the same context. + mlirAttributes.emplace_back(std::move(name), attribute.attr); + } + } + // Unpack/validate successors. 
+ if (successors) { + // Append to the function-scope mlirSuccessors; do not shadow it here. + mlirSuccessors.reserve(successors->size()); + for (auto *successor : *successors) { + // TODO: Verify successor originate from the same context. + if (!successor) + throw SetPyError(PyExc_ValueError, "successor block cannot be None"); + mlirSuccessors.push_back(successor->get()); + } + } + + // Apply unpacked/validated to the operation state. Beyond this + // point, exceptions cannot be thrown or else the state will leak. + MlirOperationState state = mlirOperationStateGet(name.c_str(), location->loc); + if (!mlirOperands.empty()) + mlirOperationStateAddOperands(&state, mlirOperands.size(), + mlirOperands.data()); + if (!mlirResults.empty()) + mlirOperationStateAddResults(&state, mlirResults.size(), + mlirResults.data()); + if (!mlirAttributes.empty()) { + // Note that the attribute names directly reference bytes in + // mlirAttributes, so that vector must not be changed from here + // on. + llvm::SmallVector mlirNamedAttributes; + mlirNamedAttributes.reserve(mlirAttributes.size()); + for (auto &it : mlirAttributes) + mlirNamedAttributes.push_back( + mlirNamedAttributeGet(it.first.c_str(), it.second)); + mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), + mlirNamedAttributes.data()); + } + if (!mlirSuccessors.empty()) + mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), + mlirSuccessors.data()); + if (regions) { + llvm::SmallVector mlirRegions; + mlirRegions.resize(regions); + for (int i = 0; i < regions; ++i) + mlirRegions[i] = mlirRegionCreate(); + mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), + mlirRegions.data()); + } + + // Construct the operation. + MlirOperation operation = mlirOperationCreate(&state); + PyOperationRef created = + PyOperation::createDetached(location->getContext(), operation); + + // InsertPoint active?
+ if (!maybeIp.is(py::cast(false))) { + PyInsertionPoint *ip; + if (maybeIp.is_none()) { + ip = PyThreadContextEntry::getDefaultInsertionPoint(); + } else { + ip = py::cast(maybeIp); + } + if (ip) + ip->insert(*created.get()); + } + + return created.releaseObject(); +} + PyOpView::PyOpView(py::object operation) : operationObject(std::move(operation)), operation(py::cast(this->operationObject)) {} @@ -998,26 +1123,13 @@ } py::object PyInsertionPoint::contextEnter() { - auto context = block.getParentOperation()->getContext().getObject(); - py::object self = py::cast(this); - PyThreadContextEntry::push(/*context=*/std::move(context), - /*insertionPoint=*/self); - return self; + return PyThreadContextEntry::pushInsertionPoint(*this); } void PyInsertionPoint::contextExit(pybind11::object excType, pybind11::object excVal, pybind11::object excTb) { - auto &stack = PyThreadContextEntry::getStack(); - if (stack.empty()) - throw SetPyError(PyExc_RuntimeError, - "Unbalanced insertion point enter/exit"); - auto &tos = stack.back(); - PyInsertionPoint *current = tos.getInsertionPoint(); - if (current != this) - throw SetPyError(PyExc_RuntimeError, - "Unbalanced insertion point enter/exit"); - stack.pop_back(); + PyThreadContextEntry::popInsertionPoint(*this); } //------------------------------------------------------------------------------ @@ -1299,10 +1411,9 @@ static void bindDerived(ClassTy &c) { c.def_static( "get", - // TODO: Make the location optional and create a default location. - [](PyType &type, double value, PyLocation &loc) { + [](PyType &type, double value, DefaultingPyLocation loc) { MlirAttribute attr = - mlirFloatAttrDoubleGetChecked(type.type, value, loc.loc); + mlirFloatAttrDoubleGetChecked(type.type, value, loc->loc); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. 
if (mlirAttributeIsNull(attr)) { @@ -1313,25 +1424,25 @@ } return PyFloatAttribute(type.getContext(), attr); }, - py::arg("type"), py::arg("value"), py::arg("loc"), + py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), "Gets an uniqued float point attribute associated to a type"); c.def_static( "get_f32", - [](PyMlirContext &context, double value) { + [](double value, DefaultingPyMlirContext context) { MlirAttribute attr = mlirFloatAttrDoubleGet( - context.get(), mlirF32TypeGet(context.get()), value); - return PyFloatAttribute(context.getRef(), attr); + context->get(), mlirF32TypeGet(context->get()), value); + return PyFloatAttribute(context->getRef(), attr); }, - py::arg("context"), py::arg("value"), + py::arg("value"), py::arg("context") = py::none(), "Gets an uniqued float point attribute associated to a f32 type"); c.def_static( "get_f64", - [](PyMlirContext &context, double value) { + [](double value, DefaultingPyMlirContext context) { MlirAttribute attr = mlirFloatAttrDoubleGet( - context.get(), mlirF64TypeGet(context.get()), value); - return PyFloatAttribute(context.getRef(), attr); + context->get(), mlirF64TypeGet(context->get()), value); + return PyFloatAttribute(context->getRef(), attr); }, - py::arg("context"), py::arg("value"), + py::arg("value"), py::arg("context") = py::none(), "Gets an uniqued float point attribute associated to a f64 type"); c.def_property_readonly( "value", @@ -1377,11 +1488,12 @@ static void bindDerived(ClassTy &c) { c.def_static( "get", - [](PyMlirContext &context, bool value) { - MlirAttribute attr = mlirBoolAttrGet(context.get(), value); - return PyBoolAttribute(context.getRef(), attr); + [](bool value, DefaultingPyMlirContext context) { + MlirAttribute attr = mlirBoolAttrGet(context->get(), value); + return PyBoolAttribute(context->getRef(), attr); }, - py::arg("context"), py::arg("value"), "Gets an uniqued bool attribute"); + py::arg("value"), py::arg("context") = py::none(), + "Gets an uniqued bool attribute"); 
c.def_property_readonly( "value", [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self.attr); }, @@ -1398,11 +1510,12 @@ static void bindDerived(ClassTy &c) { c.def_static( "get", - [](PyMlirContext &context, std::string value) { + [](std::string value, DefaultingPyMlirContext context) { MlirAttribute attr = - mlirStringAttrGet(context.get(), value.size(), &value[0]); - return PyStringAttribute(context.getRef(), attr); + mlirStringAttrGet(context->get(), value.size(), &value[0]); + return PyStringAttribute(context->getRef(), attr); }, + py::arg("value"), py::arg("context") = py::none(), "Gets a uniqued string attribute"); c.def_static( "get_typed", @@ -1432,9 +1545,9 @@ static constexpr const char *pyClassName = "DenseElementsAttr"; using PyConcreteAttribute::PyConcreteAttribute; - static PyDenseElementsAttribute getFromBuffer(PyMlirContext &contextWrapper, - py::buffer array, - bool signless) { + static PyDenseElementsAttribute + getFromBuffer(py::buffer array, bool signless, + DefaultingPyMlirContext contextWrapper) { // Request a contiguous view. In exotic cases, this will cause a copy. int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; Py_buffer *view = new Py_buffer(); @@ -1444,21 +1557,21 @@ } py::buffer_info arrayInfo(view); - MlirContext context = contextWrapper.get(); + MlirContext context = contextWrapper->get(); // Switch on the types that can be bulk loaded between the Python and // MLIR-C APIs. 
if (arrayInfo.format == "f") { // f32 assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); return PyDenseElementsAttribute( - contextWrapper.getRef(), + contextWrapper->getRef(), bulkLoad(context, mlirDenseElementsAttrFloatGet, mlirF32TypeGet(context), arrayInfo)); } else if (arrayInfo.format == "d") { // f64 assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); return PyDenseElementsAttribute( - contextWrapper.getRef(), + contextWrapper->getRef(), bulkLoad(context, mlirDenseElementsAttrDoubleGet, mlirF64TypeGet(context), arrayInfo)); } else if (arrayInfo.format == "i") { @@ -1466,7 +1579,7 @@ assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); MlirType elementType = signless ? mlirIntegerTypeGet(context, 32) : mlirIntegerTypeSignedGet(context, 32); - return PyDenseElementsAttribute(contextWrapper.getRef(), + return PyDenseElementsAttribute(contextWrapper->getRef(), bulkLoad(context, mlirDenseElementsAttrInt32Get, elementType, arrayInfo)); @@ -1475,7 +1588,7 @@ assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); MlirType elementType = signless ? mlirIntegerTypeGet(context, 32) : mlirIntegerTypeUnsignedGet(context, 32); - return PyDenseElementsAttribute(contextWrapper.getRef(), + return PyDenseElementsAttribute(contextWrapper->getRef(), bulkLoad(context, mlirDenseElementsAttrUInt32Get, elementType, arrayInfo)); @@ -1484,7 +1597,7 @@ assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); MlirType elementType = signless ? mlirIntegerTypeGet(context, 64) : mlirIntegerTypeSignedGet(context, 64); - return PyDenseElementsAttribute(contextWrapper.getRef(), + return PyDenseElementsAttribute(contextWrapper->getRef(), bulkLoad(context, mlirDenseElementsAttrInt64Get, elementType, arrayInfo)); @@ -1493,7 +1606,7 @@ assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); MlirType elementType = signless ? 
mlirIntegerTypeGet(context, 64) : mlirIntegerTypeUnsignedGet(context, 64); - return PyDenseElementsAttribute(contextWrapper.getRef(), + return PyDenseElementsAttribute(contextWrapper->getRef(), bulkLoad(context, mlirDenseElementsAttrUInt64Get, elementType, arrayInfo)); @@ -1540,8 +1653,9 @@ static void bindDerived(ClassTy &c) { c.def_static("get", PyDenseElementsAttribute::getFromBuffer, - py::arg("context"), py::arg("array"), - py::arg("signless") = true, "Gets from a buffer or ndarray") + py::arg("array"), py::arg("signless") = true, + py::arg("context") = py::none(), + "Gets from a buffer or ndarray") .def_static("get_splat", PyDenseElementsAttribute::getSplat, py::arg("shaped_type"), py::arg("element_attr"), "Gets a DenseElementsAttr where all values are the same") @@ -1624,24 +1738,27 @@ static void bindDerived(ClassTy &c) { c.def_static( "get_signless", - [](PyMlirContext &context, unsigned width) { - MlirType t = mlirIntegerTypeGet(context.get(), width); - return PyIntegerType(context.getRef(), t); + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeGet(context->get(), width); + return PyIntegerType(context->getRef(), t); }, + py::arg("width"), py::arg("context") = py::none(), "Create a signless integer type"); c.def_static( "get_signed", - [](PyMlirContext &context, unsigned width) { - MlirType t = mlirIntegerTypeSignedGet(context.get(), width); - return PyIntegerType(context.getRef(), t); + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeSignedGet(context->get(), width); + return PyIntegerType(context->getRef(), t); }, + py::arg("width"), py::arg("context") = py::none(), "Create a signed integer type"); c.def_static( "get_unsigned", - [](PyMlirContext &context, unsigned width) { - MlirType t = mlirIntegerTypeUnsignedGet(context.get(), width); - return PyIntegerType(context.getRef(), t); + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = 
mlirIntegerTypeUnsignedGet(context->get(), width); + return PyIntegerType(context->getRef(), t); }, + py::arg("width"), py::arg("context") = py::none(), "Create an unsigned integer type"); c.def_property_readonly( "width", @@ -1678,11 +1795,11 @@ static void bindDerived(ClassTy &c) { c.def_static( "get", - [](PyMlirContext &context) { - MlirType t = mlirIndexTypeGet(context.get()); - return PyIndexType(context.getRef(), t); + [](DefaultingPyMlirContext context) { + MlirType t = mlirIndexTypeGet(context->get()); + return PyIndexType(context->getRef(), t); }, - "Create a index type."); + py::arg("context") = py::none(), "Create a index type."); } }; @@ -1696,11 +1813,11 @@ static void bindDerived(ClassTy &c) { c.def_static( "get", - [](PyMlirContext &context) { - MlirType t = mlirBF16TypeGet(context.get()); - return PyBF16Type(context.getRef(), t); + [](DefaultingPyMlirContext context) { + MlirType t = mlirBF16TypeGet(context->get()); + return PyBF16Type(context->getRef(), t); }, - "Create a bf16 type."); + py::arg("context") = py::none(), "Create a bf16 type."); } }; @@ -1714,11 +1831,11 @@ static void bindDerived(ClassTy &c) { c.def_static( "get", - [](PyMlirContext &context) { - MlirType t = mlirF16TypeGet(context.get()); - return PyF16Type(context.getRef(), t); + [](DefaultingPyMlirContext context) { + MlirType t = mlirF16TypeGet(context->get()); + return PyF16Type(context->getRef(), t); }, - "Create a f16 type."); + py::arg("context") = py::none(), "Create a f16 type."); } }; @@ -1732,11 +1849,11 @@ static void bindDerived(ClassTy &c) { c.def_static( "get", - [](PyMlirContext &context) { - MlirType t = mlirF32TypeGet(context.get()); - return PyF32Type(context.getRef(), t); + [](DefaultingPyMlirContext context) { + MlirType t = mlirF32TypeGet(context->get()); + return PyF32Type(context->getRef(), t); }, - "Create a f32 type."); + py::arg("context") = py::none(), "Create a f32 type."); } }; @@ -1750,11 +1867,11 @@ static void bindDerived(ClassTy &c) { 
c.def_static( "get", - [](PyMlirContext &context) { - MlirType t = mlirF64TypeGet(context.get()); - return PyF64Type(context.getRef(), t); + [](DefaultingPyMlirContext context) { + MlirType t = mlirF64TypeGet(context->get()); + return PyF64Type(context->getRef(), t); }, - "Create a f64 type."); + py::arg("context") = py::none(), "Create a f64 type."); } }; @@ -1768,11 +1885,11 @@ static void bindDerived(ClassTy &c) { c.def_static( "get", - [](PyMlirContext &context) { - MlirType t = mlirNoneTypeGet(context.get()); - return PyNoneType(context.getRef(), t); + [](DefaultingPyMlirContext context) { + MlirType t = mlirNoneTypeGet(context->get()); + return PyNoneType(context->getRef(), t); }, - "Create a none type."); + py::arg("context") = py::none(), "Create a none type."); } }; @@ -1892,10 +2009,10 @@ static void bindDerived(ClassTy &c) { c.def_static( "get", - // TODO: Make the location optional and create a default location. - [](std::vector shape, PyType &elementType, PyLocation &loc) { + [](std::vector shape, PyType &elementType, + DefaultingPyLocation loc) { MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(), - elementType.type, loc.loc); + elementType.type, loc->loc); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirTypeIsNull(t)) { @@ -1907,6 +2024,7 @@ } return PyVectorType(elementType.getContext(), t); }, + py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(), "Create a vector type"); } }; @@ -1922,10 +2040,10 @@ static void bindDerived(ClassTy &c) { c.def_static( "get", - // TODO: Make the location optional and create a default location. 
- [](std::vector shape, PyType &elementType, PyLocation &loc) { + [](std::vector shape, PyType &elementType, + DefaultingPyLocation loc) { MlirType t = mlirRankedTensorTypeGetChecked( - shape.size(), shape.data(), elementType.type, loc.loc); + shape.size(), shape.data(), elementType.type, loc->loc); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirTypeIsNull(t)) { @@ -1939,6 +2057,7 @@ } return PyRankedTensorType(elementType.getContext(), t); }, + py::arg("shape"), py::arg("element_type"), py::arg("loc") = py::none(), "Create a ranked tensor type"); } }; @@ -1954,10 +2073,9 @@ static void bindDerived(ClassTy &c) { c.def_static( "get", - // TODO: Make the location optional and create a default location. - [](PyType &elementType, PyLocation &loc) { + [](PyType &elementType, DefaultingPyLocation loc) { MlirType t = - mlirUnrankedTensorTypeGetChecked(elementType.type, loc.loc); + mlirUnrankedTensorTypeGetChecked(elementType.type, loc->loc); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirTypeIsNull(t)) { @@ -1971,6 +2089,7 @@ } return PyUnrankedTensorType(elementType.getContext(), t); }, + py::arg("element_type"), py::arg("loc") = py::none(), "Create a unranked tensor type"); } }; @@ -1989,10 +2108,10 @@ "get_contiguous_memref", // TODO: Make the location optional and create a default location. [](PyType &elementType, std::vector shape, - unsigned memorySpace, PyLocation &loc) { + unsigned memorySpace, DefaultingPyLocation loc) { MlirType t = mlirMemRefTypeContiguousGetChecked( elementType.type, shape.size(), shape.data(), memorySpace, - loc.loc); + loc->loc); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. 
if (mlirTypeIsNull(t)) { @@ -2006,7 +2125,8 @@ } return PyMemRefType(elementType.getContext(), t); }, - "Create a memref type") + py::arg("element_type"), py::arg("shape"), py::arg("memory_space"), + py::arg("loc") = py::none(), "Create a memref type") .def_property_readonly( "num_affine_maps", [](PyMemRefType &self) -> intptr_t { @@ -2034,10 +2154,10 @@ static void bindDerived(ClassTy &c) { c.def_static( "get", - // TODO: Make the location optional and create a default location. - [](PyType &elementType, unsigned memorySpace, PyLocation &loc) { + [](PyType &elementType, unsigned memorySpace, + DefaultingPyLocation loc) { MlirType t = mlirUnrankedMemRefTypeGetChecked(elementType.type, - memorySpace, loc.loc); + memorySpace, loc->loc); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirTypeIsNull(t)) { @@ -2051,7 +2171,8 @@ } return PyUnrankedMemRefType(elementType.getContext(), t); }, - "Create a unranked memref type") + py::arg("element_type"), py::arg("memory_space"), + py::arg("loc") = py::none(), "Create a unranked memref type") .def_property_readonly( "memory_space", [](PyUnrankedMemRefType &self) -> unsigned { @@ -2071,15 +2192,16 @@ static void bindDerived(ClassTy &c) { c.def_static( "get_tuple", - [](PyMlirContext &context, py::list elementList) { + [](py::list elementList, DefaultingPyMlirContext context) { intptr_t num = py::len(elementList); // Mapping py::list to SmallVector. 
SmallVector elements; for (auto element : elementList) elements.push_back(element.cast().type); - MlirType t = mlirTupleTypeGet(context.get(), num, elements.data()); - return PyTupleType(context.getRef(), t); + MlirType t = mlirTupleTypeGet(context->get(), num, elements.data()); + return PyTupleType(context->getRef(), t); }, + py::arg("elements"), py::arg("context") = py::none(), "Create a tuple type"); c.def( "get_type", @@ -2107,16 +2229,16 @@ static void bindDerived(ClassTy &c) { c.def_static( "get", - [](PyMlirContext &context, std::vector inputs, - std::vector results) { + [](std::vector inputs, std::vector results, + DefaultingPyMlirContext context) { SmallVector inputsRaw(inputs.begin(), inputs.end()); SmallVector resultsRaw(results.begin(), results.end()); - MlirType t = mlirFunctionTypeGet(context.get(), inputsRaw.size(), + MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(), inputsRaw.data(), resultsRaw.size(), resultsRaw.data()); - return PyFunctionType(context.getRef(), t); + return PyFunctionType(context->getRef(), t); }, - py::arg("context"), py::arg("inputs"), py::arg("results"), + py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(), "Gets a FunctionType from a list of input and result types"); c.def_property_readonly( "inputs", @@ -2170,6 +2292,17 @@ .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) + .def("__enter__", &PyMlirContext::contextEnter) + .def("__exit__", &PyMlirContext::contextExit) + .def_property_readonly_static( + "current", + [](py::object & /*class*/) { + auto *context = PyThreadContextEntry::getDefaultContext(); + if (!context) + throw SetPyError(PyExc_ValueError, "No current Context"); + return context; + }, + "Gets the Context bound to the current thread or raises ValueError") .def_property_readonly( "dialects", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, @@ -2188,7 +2321,8 @@ } 
return PyDialectDescriptor(self.getRef(), dialect); }, - "Gets or loads a dialect by name, returning its descriptor object") + "Gets or loads a dialect by name, returning its descriptor " + "object") .def_property( "allow_unregistered_dialects", [](PyMlirContext &self) -> bool { @@ -2196,79 +2330,7 @@ }, [](PyMlirContext &self, bool value) { mlirContextSetAllowUnregisteredDialects(self.get(), value); - }) - .def("create_operation", &PyMlirContext::createOperation, py::arg("name"), - py::arg("location"), py::arg("operands") = py::none(), - py::arg("results") = py::none(), py::arg("attributes") = py::none(), - py::arg("successors") = py::none(), py::arg("regions") = 0, - kContextCreateOperationDocstring) - .def( - "parse_module", - [](PyMlirContext &self, const std::string moduleAsm) { - MlirModule module = - mlirModuleCreateParse(self.get(), moduleAsm.c_str()); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirModuleIsNull(module)) { - throw SetPyError( - PyExc_ValueError, - "Unable to parse module assembly (see diagnostics)"); - } - return PyModule::forModule(module).releaseObject(); - }, - kContextParseDocstring) - .def( - "create_module", - [](PyMlirContext &self, PyLocation &loc) { - MlirModule module = mlirModuleCreateEmpty(loc.loc); - return PyModule::forModule(module).releaseObject(); - }, - py::arg("loc"), "Creates an empty module") - .def( - "parse_attr", - [](PyMlirContext &self, std::string attrSpec) { - MlirAttribute type = - mlirAttributeParseGet(self.get(), attrSpec.c_str()); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. 
- if (mlirAttributeIsNull(type)) { - throw SetPyError(PyExc_ValueError, - llvm::Twine("Unable to parse attribute: '") + - attrSpec + "'"); - } - return PyAttribute(self.getRef(), type); - }, - py::keep_alive<0, 1>()) - .def( - "parse_type", - [](PyMlirContext &self, std::string typeSpec) { - MlirType type = mlirTypeParseGet(self.get(), typeSpec.c_str()); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(type)) { - throw SetPyError(PyExc_ValueError, - llvm::Twine("Unable to parse type: '") + - typeSpec + "'"); - } - return PyType(self.getRef(), type); - }, - kContextParseTypeDocstring) - .def( - "get_unknown_location", - [](PyMlirContext &self) { - return PyLocation(self.getRef(), - mlirLocationUnknownGet(self.get())); - }, - kContextGetUnknownLocationDocstring) - .def( - "get_file_location", - [](PyMlirContext &self, std::string filename, int line, int col) { - return PyLocation(self.getRef(), - mlirLocationFileLineColGet( - self.get(), filename.c_str(), line, col)); - }, - kContextGetFileLocationDocstring, py::arg("filename"), - py::arg("line"), py::arg("col")); + }); //---------------------------------------------------------------------------- // Mapping of PyDialectDescriptor @@ -2327,6 +2389,35 @@ // Mapping of Location //---------------------------------------------------------------------------- py::class_(m, "Location") + .def("__enter__", &PyLocation::contextEnter) + .def("__exit__", &PyLocation::contextExit) + .def_property_readonly_static( + "current", + [](py::object & /*class*/) { + auto *loc = PyThreadContextEntry::getDefaultLocation(); + if (!loc) + throw SetPyError(PyExc_ValueError, "No current Location"); + return loc; + }, + "Gets the Location bound to the current thread or raises ValueError") + .def_static( + "unknown", + [](DefaultingPyMlirContext context) { + return PyLocation(context->getRef(), + mlirLocationUnknownGet(context->get())); + }, + py::arg("context") = py::none(), + "Gets a 
Location representing an unknown location") + .def_static( + "file", + [](std::string filename, int line, int col, + DefaultingPyMlirContext context) { + return PyLocation(context->getRef(), + mlirLocationFileLineColGet( + context->get(), filename.c_str(), line, col)); + }, + py::arg("filename"), py::arg("line"), py::arg("col"), + py::arg("context") = py::none(), kContextGetFileLocationDocstring) .def_property_readonly( "context", [](PyLocation &self) { return self.getContext().getObject(); }, @@ -2344,6 +2435,29 @@ py::class_(m, "Module") .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) + .def_static( + "parse", + [](const std::string moduleAsm, DefaultingPyMlirContext context) { + MlirModule module = + mlirModuleCreateParse(context->get(), moduleAsm.c_str()); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. + if (mlirModuleIsNull(module)) { + throw SetPyError( + PyExc_ValueError, + "Unable to parse module assembly (see diagnostics)"); + } + return PyModule::forModule(module).releaseObject(); + }, + py::arg("asm"), py::arg("context") = py::none(), + kModuleParseDocstring) + .def_static( + "create", + [](DefaultingPyLocation loc) { + MlirModule module = mlirModuleCreateEmpty(loc->loc); + return PyModule::forModule(module).releaseObject(); + }, + py::arg("loc") = py::none(), "Creates an empty module") .def_property_readonly( "context", [](PyModule &self) { return self.getContext().getObject(); }, @@ -2388,6 +2502,13 @@ // Mapping of Operation. 
//---------------------------------------------------------------------------- py::class_(m, "Operation") + .def_static("create", &PyOperation::create, py::arg("name"), + py::arg("operands") = py::none(), + py::arg("results") = py::none(), + py::arg("attributes") = py::none(), + py::arg("successors") = py::none(), py::arg("regions") = 0, + py::arg("loc") = py::none(), py::arg("ip") = py::none(), + kOperationCreateDocstring) .def_property_readonly( "context", [](PyOperation &self) { return self.getContext().getObject(); }, @@ -2520,6 +2641,16 @@ "Inserts after the last operation but still inside the block.") .def("__enter__", &PyInsertionPoint::contextEnter) .def("__exit__", &PyInsertionPoint::contextExit) + .def_property_readonly_static( + "current", + [](py::object & /*class*/) { + auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); + if (!ip) + throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); + return ip; + }, + "Gets the InsertionPoint bound to the current thread or raises " + "ValueError") .def(py::init(), py::arg("beforeOperation"), "Inserts before a referenced operation.") .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, @@ -2533,6 +2664,22 @@ // Mapping of PyAttribute. //---------------------------------------------------------------------------- py::class_(m, "Attribute") + .def_static( + "parse", + [](std::string attrSpec, DefaultingPyMlirContext context) { + MlirAttribute type = + mlirAttributeParseGet(context->get(), attrSpec.c_str()); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. 
+ if (mlirAttributeIsNull(type)) { + throw SetPyError(PyExc_ValueError, + llvm::Twine("Unable to parse attribute: '") + + attrSpec + "'"); + } + return PyAttribute(context->getRef(), type); + }, + py::arg("asm"), py::arg("context") = py::none(), + "Parses an attribute from an assembly form") .def_property_readonly( "context", [](PyAttribute &self) { return self.getContext().getObject(); }, @@ -2628,6 +2775,21 @@ // Mapping of PyType. //---------------------------------------------------------------------------- py::class_(m, "Type") + .def_static( + "parse", + [](std::string typeSpec, DefaultingPyMlirContext context) { + MlirType type = mlirTypeParseGet(context->get(), typeSpec.c_str()); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. + if (mlirTypeIsNull(type)) { + throw SetPyError(PyExc_ValueError, + llvm::Twine("Unable to parse type: '") + + typeSpec + "'"); + } + return PyType(context->getRef(), type); + }, + py::arg("asm"), py::arg("context") = py::none(), + kContextParseTypeDocstring) .def_property_readonly( "context", [](PyType &self) { return self.getContext().getObject(); }, "Context that owns the Type") diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -15,13 +15,6 @@ #include "llvm/ADT/Optional.h" #include "llvm/ADT/Twine.h" -namespace pybind11 { -namespace detail { -template -struct type_caster> : optional_caster> {}; -} // namespace detail -} // namespace pybind11 - namespace mlir { namespace python { @@ -32,7 +25,78 @@ pybind11::error_already_set SetPyError(PyObject *excClass, const llvm::Twine &message); +/// CRTP template for special wrapper types that are allowed to be passed in as +/// 'None' function arguments and can be resolved by some global mechanic if +/// so. 
Such types will raise an error if this global resolution fails, and +/// it is actually illegal for them to ever be unresolved. From a user +/// perspective, they behave like a smart ptr to the underlying type (i.e. +/// 'get' method and operator-> overloaded). +/// +/// Derived types must provide a method, which is called when an environmental +/// resolution is required. It must raise an exception if resolution fails: +/// static ReferrentTy &resolve() +/// +/// They must also provide a parameter description that will be used in +/// error messages about mismatched types: +/// static constexpr const char kTypeDescription[] = ""; + +template +class Defaulting { +public: + using ReferrentTy = T; + /// Type casters require the type to be default constructible, but using + /// such an instance is illegal. + Defaulting() = default; + Defaulting(ReferrentTy &referrent) : referrent(&referrent) {} + + ReferrentTy *get() { return referrent; } + ReferrentTy *operator->() { return referrent; } + +private: + ReferrentTy *referrent = nullptr; +}; + } // namespace python } // namespace mlir +namespace pybind11 { +namespace detail { + +template +struct MlirDefaultingCaster { + PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription)); + + bool load(pybind11::handle src, bool) { + if (src.is_none()) { + // Note that we do want an exception to propagate from here as it will be + // the most informative. + value = DefaultingTy{DefaultingTy::resolve()}; + return true; + } + + // Unlike many casters that chain, these casters are expected to always + // succeed, so instead of doing an isinstance check followed by a cast, + // just cast in one step and handle the exception. Returning false (vs + // letting the exception propagate) causes higher level signature parsing + // code to produce nice error messages (other than "Cannot cast..."). 
+ try { + value = DefaultingTy{ + pybind11::cast(src)}; + return true; + } catch (std::exception &e) { + return false; + } + } + + static handle cast(DefaultingTy src, return_value_policy policy, + handle parent) { + return pybind11::cast(src, policy); + } +}; + +template +struct type_caster> : optional_caster> {}; +} // namespace detail +} // namespace pybind11 + #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H diff --git a/mlir/lib/Bindings/Python/mlir/dialects/std.py b/mlir/lib/Bindings/Python/mlir/dialects/std.py --- a/mlir/lib/Bindings/Python/mlir/dialects/std.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/std.py @@ -5,20 +5,22 @@ # TODO: This file should be auto-generated. from . import _cext +_ir = _cext.ir @_cext.register_dialect -class _Dialect(_cext.ir.Dialect): +class _Dialect(_ir.Dialect): # Special case: 'std' namespace aliases to the empty namespace. DIALECT_NAMESPACE = "std" pass @_cext.register_operation(_Dialect) -class AddFOp(_cext.ir.OpView): +class AddFOp(_ir.OpView): OPERATION_NAME = "std.addf" - def __init__(self, loc, lhs, rhs): - super().__init__(loc.context.create_operation( - "std.addf", loc, operands=[lhs, rhs], results=[lhs.type])) + def __init__(self, lhs, rhs, loc=None, ip=None): + super().__init__(_ir.Operation.create( + "std.addf", operands=[lhs, rhs], results=[lhs.type], + loc=loc, ip=ip)) @property def lhs(self): diff --git a/mlir/test/Bindings/Python/context_managers.py b/mlir/test/Bindings/Python/context_managers.py new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/context_managers.py @@ -0,0 +1,99 @@ +# RUN: %PYTHON %s | FileCheck %s + +import gc +from mlir.ir import * + +def run(f): + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + + +# CHECK-LABEL: TEST: testContextEnterExit +def testContextEnterExit(): + with Context() as ctx: + assert Context.current is ctx + try: + _ = Context.current + except ValueError as e: + # CHECK: No current Context + print(e) + else: assert 
False, "Expected exception" + +run(testContextEnterExit) + + +# CHECK-LABEL: TEST: testLocationEnterExit +def testLocationEnterExit(): + ctx1 = Context() + with Location.unknown(ctx1) as loc1: + assert Context.current is ctx1 + assert Location.current is loc1 + + # Re-asserting the same context should not change the location. + with ctx1: + assert Context.current is ctx1 + assert Location.current is loc1 + # Asserting a different context should clear it. + with Context() as ctx2: + assert Context.current is ctx2 + try: + _ = Location.current + except ValueError: pass + else: assert False, "Expected exception" + + # And should restore. + assert Context.current is ctx1 + assert Location.current is loc1 + + # All should clear. + try: + _ = Location.current + except ValueError as e: + # CHECK: No current Location + print(e) + else: assert False, "Expected exception" + +run(testLocationEnterExit) + + +# CHECK-LABEL: TEST: testInsertionPointEnterExit +def testInsertionPointEnterExit(): + ctx1 = Context() + m = Module.create(Location.unknown(ctx1)) + ip = InsertionPoint.at_block_terminator(m.body) + + with ip: + assert InsertionPoint.current is ip + # Asserting a location from the same context should preserve. + with Location.unknown(ctx1) as loc1: + assert InsertionPoint.current is ip + assert Location.current is loc1 + # Location should clear. + try: + _ = Location.current + except ValueError: pass + else: assert False, "Expected exception" + + # Asserting the same Context should preserve. + with ctx1: + assert InsertionPoint.current is ip + + # Asserting a different context should clear it. + with Context() as ctx2: + assert Context.current is ctx2 + try: + _ = InsertionPoint.current + except ValueError: pass + else: assert False, "Expected exception" + + # All should clear. 
+ try: + _ = InsertionPoint.current + except ValueError as e: + # CHECK: No current InsertionPoint + print(e) + else: assert False, "Expected exception" + +run(testInsertionPointEnterExit) diff --git a/mlir/test/Bindings/Python/dialects.py b/mlir/test/Bindings/Python/dialects.py --- a/mlir/test/Bindings/Python/dialects.py +++ b/mlir/test/Bindings/Python/dialects.py @@ -1,18 +1,18 @@ # RUN: %PYTHON %s | FileCheck %s import gc -import mlir +from mlir.ir import * def run(f): print("\nTEST:", f.__name__) f() gc.collect() - assert mlir.ir.Context._get_live_count() == 0 + assert Context._get_live_count() == 0 # CHECK-LABEL: TEST: testDialectDescriptor def testDialectDescriptor(): - ctx = mlir.ir.Context() + ctx = Context() d = ctx.get_dialect_descriptor("std") # CHECK: print(d) @@ -30,7 +30,7 @@ # CHECK-LABEL: TEST: testUserDialectClass def testUserDialectClass(): - ctx = mlir.ir.Context() + ctx = Context() # Access using attribute. d = ctx.dialects.std # Note that the standard dialect namespace prints as ''. Others will print @@ -68,26 +68,25 @@ # TODO: Op creation and access is still quite verbose: simplify this test as # additional capabilities come online. def testCustomOpView(): - ctx = mlir.ir.Context() - ctx.allow_unregistered_dialects = True - f32 = mlir.ir.F32Type.get(ctx) - loc = ctx.get_unknown_location() - m = ctx.create_module(loc) - def createInput(): - op = ctx.create_operation("pytest_dummy.intinput", loc, results=[f32]) + op = Operation.create("pytest_dummy.intinput", results=[f32]) # TODO: Auto result cast from operation return op.results[0] - with mlir.ir.InsertionPoint.at_block_terminator(m.body): - # Create via dialects context collection. 
- input1 = createInput() - input2 = createInput() - op1 = ctx.dialects.std.AddFOp(loc, input1, input2) + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + m = Module.create() + + with InsertionPoint.at_block_terminator(m.body): + f32 = F32Type.get() + # Create via dialects context collection. + input1 = createInput() + input2 = createInput() + op1 = ctx.dialects.std.AddFOp(input1, input2) - # Create via an import - from mlir.dialects.std import AddFOp - AddFOp(loc, input1, op1.result) + # Create via an import + from mlir.dialects.std import AddFOp + AddFOp(input1, op1.result) # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput" # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput" diff --git a/mlir/test/Bindings/Python/insertion_point.py b/mlir/test/Bindings/Python/insertion_point.py --- a/mlir/test/Bindings/Python/insertion_point.py +++ b/mlir/test/Bindings/Python/insertion_point.py @@ -16,18 +16,18 @@ def test_insert_at_block_end(): ctx = Context() ctx.allow_unregistered_dialects = True - loc = ctx.get_unknown_location() - module = ctx.parse_module(r""" - func @foo() -> () { - "custom.op1"() : () -> () - } - """) - entry_block = module.body.operations[0].regions[0].blocks[0] - ip = InsertionPoint(entry_block) - ip.insert(ctx.create_operation("custom.op2", loc)) - # CHECK: "custom.op1" - # CHECK: "custom.op2" - module.operation.print() + with Location.unknown(ctx): + module = Module.parse(r""" + func @foo() -> () { + "custom.op1"() : () -> () + } + """) + entry_block = module.body.operations[0].regions[0].blocks[0] + ip = InsertionPoint(entry_block) + ip.insert(Operation.create("custom.op2")) + # CHECK: "custom.op1" + # CHECK: "custom.op2" + module.operation.print() run(test_insert_at_block_end) @@ -36,20 +36,20 @@ def test_insert_before_operation(): ctx = Context() ctx.allow_unregistered_dialects = True - loc = ctx.get_unknown_location() - module = ctx.parse_module(r""" - func @foo() -> () { - "custom.op1"() : () -> () - 
"custom.op2"() : () -> () - } - """) - entry_block = module.body.operations[0].regions[0].blocks[0] - ip = InsertionPoint(entry_block.operations[1]) - ip.insert(ctx.create_operation("custom.op3", loc)) - # CHECK: "custom.op1" - # CHECK: "custom.op3" - # CHECK: "custom.op2" - module.operation.print() + with Location.unknown(ctx): + module = Module.parse(r""" + func @foo() -> () { + "custom.op1"() : () -> () + "custom.op2"() : () -> () + } + """) + entry_block = module.body.operations[0].regions[0].blocks[0] + ip = InsertionPoint(entry_block.operations[1]) + ip.insert(Operation.create("custom.op3")) + # CHECK: "custom.op1" + # CHECK: "custom.op3" + # CHECK: "custom.op2" + module.operation.print() run(test_insert_before_operation) @@ -58,18 +58,18 @@ def test_insert_at_block_begin(): ctx = Context() ctx.allow_unregistered_dialects = True - loc = ctx.get_unknown_location() - module = ctx.parse_module(r""" - func @foo() -> () { - "custom.op2"() : () -> () - } - """) - entry_block = module.body.operations[0].regions[0].blocks[0] - ip = InsertionPoint.at_block_begin(entry_block) - ip.insert(ctx.create_operation("custom.op1", loc)) - # CHECK: "custom.op1" - # CHECK: "custom.op2" - module.operation.print() + with Location.unknown(ctx): + module = Module.parse(r""" + func @foo() -> () { + "custom.op2"() : () -> () + } + """) + entry_block = module.body.operations[0].regions[0].blocks[0] + ip = InsertionPoint.at_block_begin(entry_block) + ip.insert(Operation.create("custom.op1")) + # CHECK: "custom.op1" + # CHECK: "custom.op2" + module.operation.print() run(test_insert_at_block_begin) @@ -86,19 +86,19 @@ def test_insert_at_terminator(): ctx = Context() ctx.allow_unregistered_dialects = True - loc = ctx.get_unknown_location() - module = ctx.parse_module(r""" - func @foo() -> () { - "custom.op1"() : () -> () - return - } - """) - entry_block = module.body.operations[0].regions[0].blocks[0] - ip = InsertionPoint.at_block_terminator(entry_block) - 
ip.insert(ctx.create_operation("custom.op2", loc)) - # CHECK: "custom.op1" - # CHECK: "custom.op2" - module.operation.print() + with Location.unknown(ctx): + module = Module.parse(r""" + func @foo() -> () { + "custom.op1"() : () -> () + return + } + """) + entry_block = module.body.operations[0].regions[0].blocks[0] + ip = InsertionPoint.at_block_terminator(entry_block) + ip.insert(Operation.create("custom.op2")) + # CHECK: "custom.op1" + # CHECK: "custom.op2" + module.operation.print() run(test_insert_at_terminator) @@ -107,20 +107,20 @@ def test_insert_at_block_terminator_missing(): ctx = Context() ctx.allow_unregistered_dialects = True - loc = ctx.get_unknown_location() - module = ctx.parse_module(r""" - func @foo() -> () { - "custom.op1"() : () -> () - } - """) - entry_block = module.body.operations[0].regions[0].blocks[0] - try: - ip = InsertionPoint.at_block_terminator(entry_block) - except ValueError as e: - # CHECK: Block has no terminator - print(e) - else: - assert False, "Expected exception" + with ctx: + module = Module.parse(r""" + func @foo() -> () { + "custom.op1"() : () -> () + } + """) + entry_block = module.body.operations[0].regions[0].blocks[0] + try: + ip = InsertionPoint.at_block_terminator(entry_block) + except ValueError as e: + # CHECK: Block has no terminator + print(e) + else: + assert False, "Expected exception" run(test_insert_at_block_terminator_missing) @@ -129,24 +129,24 @@ def test_insertion_point_context(): ctx = Context() ctx.allow_unregistered_dialects = True - loc = ctx.get_unknown_location() - module = ctx.parse_module(r""" - func @foo() -> () { - "custom.op1"() : () -> () - } - """) - entry_block = module.body.operations[0].regions[0].blocks[0] - with InsertionPoint(entry_block): - ctx.create_operation("custom.op2", loc) - with InsertionPoint.at_block_begin(entry_block): - ctx.create_operation("custom.opa", loc) - ctx.create_operation("custom.opb", loc) - ctx.create_operation("custom.op3", loc) - # CHECK: "custom.opa" - # 
CHECK: "custom.opb" - # CHECK: "custom.op1" - # CHECK: "custom.op2" - # CHECK: "custom.op3" - module.operation.print() + with Location.unknown(ctx): + module = Module.parse(r""" + func @foo() -> () { + "custom.op1"() : () -> () + } + """) + entry_block = module.body.operations[0].regions[0].blocks[0] + with InsertionPoint(entry_block): + Operation.create("custom.op2") + with InsertionPoint.at_block_begin(entry_block): + Operation.create("custom.opa") + Operation.create("custom.opb") + Operation.create("custom.op3") + # CHECK: "custom.opa" + # CHECK: "custom.opb" + # CHECK: "custom.op1" + # CHECK: "custom.op2" + # CHECK: "custom.op3" + module.operation.print() run(test_insertion_point_context) diff --git a/mlir/test/Bindings/Python/ir_array_attributes.py b/mlir/test/Bindings/Python/ir_array_attributes.py --- a/mlir/test/Bindings/Python/ir_array_attributes.py +++ b/mlir/test/Bindings/Python/ir_array_attributes.py @@ -3,27 +3,27 @@ # and we may want to disable if not available. import gc -import mlir +from mlir.ir import * import numpy as np def run(f): print("\nTEST:", f.__name__) f() gc.collect() - assert mlir.ir.Context._get_live_count() == 0 + assert Context._get_live_count() == 0 ################################################################################ # Tests of the array/buffer .get() factory method on unsupported dtype. 
################################################################################ def testGetDenseElementsUnsupported(): - ctx = mlir.ir.Context() - array = np.array([["hello", "goodbye"]]) - try: - attr = mlir.ir.DenseElementsAttr.get(ctx, array) - except ValueError as e: - # CHECK: unimplemented array format conversion from format: - print(e) + with Context(): + array = np.array([["hello", "goodbye"]]) + try: + attr = DenseElementsAttr.get(array) + except ValueError as e: + # CHECK: unimplemented array format conversion from format: + print(e) run(testGetDenseElementsUnsupported) @@ -33,63 +33,60 @@ # CHECK-LABEL: TEST: testGetDenseElementsSplatInt def testGetDenseElementsSplatInt(): - ctx = mlir.ir.Context() - loc = ctx.get_unknown_location() - t = mlir.ir.IntegerType.get_signless(ctx, 32) - element = mlir.ir.IntegerAttr.get(t, 555) - shaped_type = mlir.ir.RankedTensorType.get((2, 3, 4), t, loc) - attr = mlir.ir.DenseElementsAttr.get_splat(shaped_type, element) - # CHECK: dense<555> : tensor<2x3x4xi32> - print(attr) - # CHECK: is_splat: True - print("is_splat:", attr.is_splat) + with Context(), Location.unknown(): + t = IntegerType.get_signless(32) + element = IntegerAttr.get(t, 555) + shaped_type = RankedTensorType.get((2, 3, 4), t) + attr = DenseElementsAttr.get_splat(shaped_type, element) + # CHECK: dense<555> : tensor<2x3x4xi32> + print(attr) + # CHECK: is_splat: True + print("is_splat:", attr.is_splat) run(testGetDenseElementsSplatInt) # CHECK-LABEL: TEST: testGetDenseElementsSplatFloat def testGetDenseElementsSplatFloat(): - ctx = mlir.ir.Context() - loc = ctx.get_unknown_location() - t = mlir.ir.F32Type.get(ctx) - element = mlir.ir.FloatAttr.get(t, 1.2, loc) - shaped_type = mlir.ir.RankedTensorType.get((2, 3, 4), t, loc) - attr = mlir.ir.DenseElementsAttr.get_splat(shaped_type, element) - # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32> - print(attr) + with Context(), Location.unknown(): + t = F32Type.get() + element = FloatAttr.get(t, 1.2) + shaped_type = 
RankedTensorType.get((2, 3, 4), t) + attr = DenseElementsAttr.get_splat(shaped_type, element) + # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32> + print(attr) run(testGetDenseElementsSplatFloat) # CHECK-LABEL: TEST: testGetDenseElementsSplatErrors def testGetDenseElementsSplatErrors(): - ctx = mlir.ir.Context() - loc = ctx.get_unknown_location() - t = mlir.ir.F32Type.get(ctx) - other_t = mlir.ir.F64Type.get(ctx) - element = mlir.ir.FloatAttr.get(t, 1.2, loc) - other_element = mlir.ir.FloatAttr.get(other_t, 1.2, loc) - shaped_type = mlir.ir.RankedTensorType.get((2, 3, 4), t, loc) - dynamic_shaped_type = mlir.ir.UnrankedTensorType.get(t, loc) - non_shaped_type = t - - try: - attr = mlir.ir.DenseElementsAttr.get_splat(non_shaped_type, element) - except ValueError as e: - # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32) - print(e) - - try: - attr = mlir.ir.DenseElementsAttr.get_splat(dynamic_shaped_type, element) - except ValueError as e: - # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>) - print(e) - - try: - attr = mlir.ir.DenseElementsAttr.get_splat(shaped_type, other_element) - except ValueError as e: - # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64) - print(e) + with Context(), Location.unknown(): + t = F32Type.get() + other_t = F64Type.get() + element = FloatAttr.get(t, 1.2) + other_element = FloatAttr.get(other_t, 1.2) + shaped_type = RankedTensorType.get((2, 3, 4), t) + dynamic_shaped_type = UnrankedTensorType.get(t) + non_shaped_type = t + + try: + attr = DenseElementsAttr.get_splat(non_shaped_type, element) + except ValueError as e: + # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32) + print(e) + + try: + attr = DenseElementsAttr.get_splat(dynamic_shaped_type, element) + except ValueError as e: + # CHECK: Expected a static ShapedType for the shaped_type parameter: 
Type(tensor<*xf32>) + print(e) + + try: + attr = DenseElementsAttr.get_splat(shaped_type, other_element) + except ValueError as e: + # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64) + print(e) run(testGetDenseElementsSplatErrors) @@ -102,24 +99,24 @@ # CHECK-LABEL: TEST: testGetDenseElementsF32 def testGetDenseElementsF32(): - ctx = mlir.ir.Context() - array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32) - attr = mlir.ir.DenseElementsAttr.get(ctx, array) - # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32> - print(attr) - # CHECK: is_splat: False - print("is_splat:", attr.is_splat) + with Context(): + array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32) + attr = DenseElementsAttr.get(array) + # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32> + print(attr) + # CHECK: is_splat: False + print("is_splat:", attr.is_splat) run(testGetDenseElementsF32) # CHECK-LABEL: TEST: testGetDenseElementsF64 def testGetDenseElementsF64(): - ctx = mlir.ir.Context() - array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64) - attr = mlir.ir.DenseElementsAttr.get(ctx, array) - # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64> - print(attr) + with Context(): + array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64) + attr = DenseElementsAttr.get(array) + # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64> + print(attr) run(testGetDenseElementsF64) @@ -127,43 +124,43 @@ ### 32 bit integer arrays # CHECK-LABEL: TEST: testGetDenseElementsI32Signless def testGetDenseElementsI32Signless(): - ctx = mlir.ir.Context() - array = np.array([[1, 2, 
3], [4, 5, 6]], dtype=np.int32) - attr = mlir.ir.DenseElementsAttr.get(ctx, array) - # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> - print(attr) + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) + attr = DenseElementsAttr.get(array) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> + print(attr) run(testGetDenseElementsI32Signless) # CHECK-LABEL: TEST: testGetDenseElementsUI32Signless def testGetDenseElementsUI32Signless(): - ctx = mlir.ir.Context() - array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32) - attr = mlir.ir.DenseElementsAttr.get(ctx, array) - # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> - print(attr) + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32) + attr = DenseElementsAttr.get(array) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> + print(attr) run(testGetDenseElementsUI32Signless) # CHECK-LABEL: TEST: testGetDenseElementsI32 def testGetDenseElementsI32(): - ctx = mlir.ir.Context() - array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) - attr = mlir.ir.DenseElementsAttr.get(ctx, array, signless=False) - # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32> - print(attr) + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) + attr = DenseElementsAttr.get(array, signless=False) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32> + print(attr) run(testGetDenseElementsI32) # CHECK-LABEL: TEST: testGetDenseElementsUI32 def testGetDenseElementsUI32(): - ctx = mlir.ir.Context() - array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32) - attr = mlir.ir.DenseElementsAttr.get(ctx, array, signless=False) - # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32> - print(attr) + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32) + attr = DenseElementsAttr.get(array, signless=False) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32> + print(attr) 
run(testGetDenseElementsUI32) @@ -171,43 +168,43 @@ ## 64bit integer arrays # CHECK-LABEL: TEST: testGetDenseElementsI64Signless def testGetDenseElementsI64Signless(): - ctx = mlir.ir.Context() - array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) - attr = mlir.ir.DenseElementsAttr.get(ctx, array) - # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64> - print(attr) + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) + attr = DenseElementsAttr.get(array) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64> + print(attr) run(testGetDenseElementsI64Signless) # CHECK-LABEL: TEST: testGetDenseElementsUI64Signless def testGetDenseElementsUI64Signless(): - ctx = mlir.ir.Context() - array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64) - attr = mlir.ir.DenseElementsAttr.get(ctx, array) - # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64> - print(attr) + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64) + attr = DenseElementsAttr.get(array) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64> + print(attr) run(testGetDenseElementsUI64Signless) # CHECK-LABEL: TEST: testGetDenseElementsI64 def testGetDenseElementsI64(): - ctx = mlir.ir.Context() - array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) - attr = mlir.ir.DenseElementsAttr.get(ctx, array, signless=False) - # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64> - print(attr) + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) + attr = DenseElementsAttr.get(array, signless=False) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64> + print(attr) run(testGetDenseElementsI64) # CHECK-LABEL: TEST: testGetDenseElementsUI64 def testGetDenseElementsUI64(): - ctx = mlir.ir.Context() - array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64) - attr = mlir.ir.DenseElementsAttr.get(ctx, array, signless=False) - # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64> - 
print(attr) + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64) + attr = DenseElementsAttr.get(array, signless=False) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64> + print(attr) run(testGetDenseElementsUI64) diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py --- a/mlir/test/Bindings/Python/ir_attributes.py +++ b/mlir/test/Bindings/Python/ir_attributes.py @@ -1,19 +1,19 @@ # RUN: %PYTHON %s | FileCheck %s import gc -import mlir +from mlir.ir import * def run(f): print("\nTEST:", f.__name__) f() gc.collect() - assert mlir.ir.Context._get_live_count() == 0 + assert Context._get_live_count() == 0 # CHECK-LABEL: TEST: testParsePrint def testParsePrint(): - ctx = mlir.ir.Context() - t = ctx.parse_attr('"hello"') + with Context() as ctx: + t = Attribute.parse('"hello"') assert t.context is ctx ctx = None gc.collect() @@ -29,156 +29,155 @@ # TODO: Hook the diagnostic manager to capture a more meaningful error # message. 
def testParseError(): - ctx = mlir.ir.Context() - try: - t = ctx.parse_attr("BAD_ATTR_DOES_NOT_EXIST") - except ValueError as e: - # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST' - print("testParseError:", e) - else: - print("Exception not produced") + with Context(): + try: + t = Attribute.parse("BAD_ATTR_DOES_NOT_EXIST") + except ValueError as e: + # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST' + print("testParseError:", e) + else: + print("Exception not produced") run(testParseError) # CHECK-LABEL: TEST: testAttrEq def testAttrEq(): - ctx = mlir.ir.Context() - a1 = ctx.parse_attr('"attr1"') - a2 = ctx.parse_attr('"attr2"') - a3 = ctx.parse_attr('"attr1"') - # CHECK: a1 == a1: True - print("a1 == a1:", a1 == a1) - # CHECK: a1 == a2: False - print("a1 == a2:", a1 == a2) - # CHECK: a1 == a3: True - print("a1 == a3:", a1 == a3) - # CHECK: a1 == None: False - print("a1 == None:", a1 == None) + with Context(): + a1 = Attribute.parse('"attr1"') + a2 = Attribute.parse('"attr2"') + a3 = Attribute.parse('"attr1"') + # CHECK: a1 == a1: True + print("a1 == a1:", a1 == a1) + # CHECK: a1 == a2: False + print("a1 == a2:", a1 == a2) + # CHECK: a1 == a3: True + print("a1 == a3:", a1 == a3) + # CHECK: a1 == None: False + print("a1 == None:", a1 == None) run(testAttrEq) # CHECK-LABEL: TEST: testAttrEqDoesNotRaise def testAttrEqDoesNotRaise(): - ctx = mlir.ir.Context() - a1 = ctx.parse_attr('"attr1"') - not_an_attr = "foo" - # CHECK: False - print(a1 == not_an_attr) - # CHECK: False - print(a1 == None) - # CHECK: True - print(a1 != None) + with Context(): + a1 = Attribute.parse('"attr1"') + not_an_attr = "foo" + # CHECK: False + print(a1 == not_an_attr) + # CHECK: False + print(a1 == None) + # CHECK: True + print(a1 != None) run(testAttrEqDoesNotRaise) # CHECK-LABEL: TEST: testStandardAttrCasts def testStandardAttrCasts(): - ctx = mlir.ir.Context() - a1 = ctx.parse_attr('"attr1"') - astr = mlir.ir.StringAttr(a1) - aself = mlir.ir.StringAttr(astr) - # 
CHECK: Attribute("attr1") - print(repr(astr)) - try: - tillegal = mlir.ir.StringAttr(ctx.parse_attr("1.0")) - except ValueError as e: - # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64)) - print("ValueError:", e) - else: - print("Exception not produced") + with Context(): + a1 = Attribute.parse('"attr1"') + astr = StringAttr(a1) + aself = StringAttr(astr) + # CHECK: Attribute("attr1") + print(repr(astr)) + try: + tillegal = StringAttr(Attribute.parse("1.0")) + except ValueError as e: + # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64)) + print("ValueError:", e) + else: + print("Exception not produced") run(testStandardAttrCasts) # CHECK-LABEL: TEST: testFloatAttr def testFloatAttr(): - ctx = mlir.ir.Context() - fattr = mlir.ir.FloatAttr(ctx.parse_attr("42.0 : f32")) - # CHECK: fattr value: 42.0 - print("fattr value:", fattr.value) - - # Test factory methods. - loc = ctx.get_unknown_location() - # CHECK: default_get: 4.200000e+01 : f32 - print("default_get:", mlir.ir.FloatAttr.get( - mlir.ir.F32Type.get(ctx), 42.0, loc)) - # CHECK: f32_get: 4.200000e+01 : f32 - print("f32_get:", mlir.ir.FloatAttr.get_f32(ctx, 42.0)) - # CHECK: f64_get: 4.200000e+01 : f64 - print("f64_get:", mlir.ir.FloatAttr.get_f64(ctx, 42.0)) - try: - fattr_invalid = mlir.ir.FloatAttr.get( - mlir.ir.IntegerType.get_signless(ctx, 32), 42, loc) - except ValueError as e: - # CHECK: invalid 'Type(i32)' and expected floating point type. - print(e) - else: - print("Exception not produced") + with Context(), Location.unknown(): + fattr = FloatAttr(Attribute.parse("42.0 : f32")) + # CHECK: fattr value: 42.0 + print("fattr value:", fattr.value) + + # Test factory methods. 
+ # CHECK: default_get: 4.200000e+01 : f32 + print("default_get:", FloatAttr.get( + F32Type.get(), 42.0)) + # CHECK: f32_get: 4.200000e+01 : f32 + print("f32_get:", FloatAttr.get_f32(42.0)) + # CHECK: f64_get: 4.200000e+01 : f64 + print("f64_get:", FloatAttr.get_f64(42.0)) + try: + fattr_invalid = FloatAttr.get( + IntegerType.get_signless(32), 42) + except ValueError as e: + # CHECK: invalid 'Type(i32)' and expected floating point type. + print(e) + else: + print("Exception not produced") run(testFloatAttr) # CHECK-LABEL: TEST: testIntegerAttr def testIntegerAttr(): - ctx = mlir.ir.Context() - iattr = mlir.ir.IntegerAttr(ctx.parse_attr("42")) - # CHECK: iattr value: 42 - print("iattr value:", iattr.value) - # CHECK: iattr type: i64 - print("iattr type:", iattr.type) - - # Test factory methods. - # CHECK: default_get: 42 : i32 - print("default_get:", mlir.ir.IntegerAttr.get( - mlir.ir.IntegerType.get_signless(ctx, 32), 42)) + with Context() as ctx: + iattr = IntegerAttr(Attribute.parse("42")) + # CHECK: iattr value: 42 + print("iattr value:", iattr.value) + # CHECK: iattr type: i64 + print("iattr type:", iattr.type) + + # Test factory methods. + # CHECK: default_get: 42 : i32 + print("default_get:", IntegerAttr.get( + IntegerType.get_signless(32), 42)) run(testIntegerAttr) # CHECK-LABEL: TEST: testBoolAttr def testBoolAttr(): - ctx = mlir.ir.Context() - battr = mlir.ir.BoolAttr(ctx.parse_attr("true")) - # CHECK: iattr value: 1 - print("iattr value:", battr.value) + with Context() as ctx: + battr = BoolAttr(Attribute.parse("true")) + # CHECK: iattr value: 1 + print("iattr value:", battr.value) - # Test factory methods. - # CHECK: default_get: true - print("default_get:", mlir.ir.BoolAttr.get(ctx, True)) + # Test factory methods. 
+ # CHECK: default_get: true + print("default_get:", BoolAttr.get(True)) run(testBoolAttr) # CHECK-LABEL: TEST: testStringAttr def testStringAttr(): - ctx = mlir.ir.Context() - sattr = mlir.ir.StringAttr(ctx.parse_attr('"stringattr"')) - # CHECK: sattr value: stringattr - print("sattr value:", sattr.value) - - # Test factory methods. - # CHECK: default_get: "foobar" - print("default_get:", mlir.ir.StringAttr.get(ctx, "foobar")) - # CHECK: typed_get: "12345" : i32 - print("typed_get:", mlir.ir.StringAttr.get_typed( - mlir.ir.IntegerType.get_signless(ctx, 32), "12345")) + with Context() as ctx: + sattr = StringAttr(Attribute.parse('"stringattr"')) + # CHECK: sattr value: stringattr + print("sattr value:", sattr.value) + + # Test factory methods. + # CHECK: default_get: "foobar" + print("default_get:", StringAttr.get("foobar")) + # CHECK: typed_get: "12345" : i32 + print("typed_get:", StringAttr.get_typed( + IntegerType.get_signless(32), "12345")) run(testStringAttr) # CHECK-LABEL: TEST: testNamedAttr def testNamedAttr(): - ctx = mlir.ir.Context() - a = ctx.parse_attr('"stringattr"') - named = a.get_named("foobar") # Note: under the small object threshold - # CHECK: attr: "stringattr" - print("attr:", named.attr) - # CHECK: name: foobar - print("name:", named.name) - # CHECK: named: NamedAttribute(foobar="stringattr") - print("named:", named) + with Context(): + a = Attribute.parse('"stringattr"') + named = a.get_named("foobar") # Note: under the small object threshold + # CHECK: attr: "stringattr" + print("attr:", named.attr) + # CHECK: name: foobar + print("name:", named.name) + # CHECK: named: NamedAttribute(foobar="stringattr") + print("named:", named) run(testNamedAttr) diff --git a/mlir/test/Bindings/Python/ir_location.py b/mlir/test/Bindings/Python/ir_location.py --- a/mlir/test/Bindings/Python/ir_location.py +++ b/mlir/test/Bindings/Python/ir_location.py @@ -1,19 +1,19 @@ # RUN: %PYTHON %s | FileCheck %s import gc -import mlir +from mlir.ir import * def 
run(f): print("\nTEST:", f.__name__) f() gc.collect() - assert mlir.ir.Context._get_live_count() == 0 + assert Context._get_live_count() == 0 # CHECK-LABEL: TEST: testUnknown def testUnknown(): - ctx = mlir.ir.Context() - loc = ctx.get_unknown_location() + with Context() as ctx: + loc = Location.unknown() assert loc.context is ctx ctx = None gc.collect() @@ -27,8 +27,8 @@ # CHECK-LABEL: TEST: testFileLineCol def testFileLineCol(): - ctx = mlir.ir.Context() - loc = ctx.get_file_location("foo.txt", 123, 56) + with Context() as ctx: + loc = Location.file("foo.txt", 123, 56) ctx = None gc.collect() # CHECK: file str: loc("foo.txt":123:56) diff --git a/mlir/test/Bindings/Python/ir_module.py b/mlir/test/Bindings/Python/ir_module.py --- a/mlir/test/Bindings/Python/ir_module.py +++ b/mlir/test/Bindings/Python/ir_module.py @@ -1,21 +1,21 @@ # RUN: %PYTHON %s | FileCheck %s import gc -import mlir +from mlir.ir import * def run(f): print("\nTEST:", f.__name__) f() gc.collect() - assert mlir.ir.Context._get_live_count() == 0 + assert Context._get_live_count() == 0 # Verify successful parse. # CHECK-LABEL: TEST: testParseSuccess # CHECK: module @successfulParse def testParseSuccess(): - ctx = mlir.ir.Context() - module = ctx.parse_module(r"""module @successfulParse {}""") + ctx = Context() + module = Module.parse(r"""module @successfulParse {}""", ctx) assert module.context is ctx print("CLEAR CONTEXT") ctx = None # Ensure that module captures the context. 
@@ -30,9 +30,9 @@ # CHECK-LABEL: TEST: testParseError # CHECK: testParseError: Unable to parse module assembly (see diagnostics) def testParseError(): - ctx = mlir.ir.Context() + ctx = Context() try: - module = ctx.parse_module(r"""}SYNTAX ERROR{""") + module = Module.parse(r"""}SYNTAX ERROR{""", ctx) except ValueError as e: print("testParseError:", e) else: @@ -45,9 +45,9 @@ # CHECK-LABEL: TEST: testCreateEmpty # CHECK: module { def testCreateEmpty(): - ctx = mlir.ir.Context() - loc = ctx.get_unknown_location() - module = ctx.create_module(loc) + ctx = Context() + loc = Location.unknown(ctx) + module = Module.create(loc) print("CLEAR CONTEXT") ctx = None # Ensure that module captures the context. gc.collect() @@ -63,10 +63,10 @@ # CHECK: func @roundtripUnicode() # CHECK: foo = "\F0\9F\98\8A" def testRoundtripUnicode(): - ctx = mlir.ir.Context() - module = ctx.parse_module(r""" + ctx = Context() + module = Module.parse(r""" func @roundtripUnicode() attributes { foo = "😊" } - """) + """, ctx) print(str(module)) run(testRoundtripUnicode) @@ -75,8 +75,8 @@ # Tests that module.operation works and correctly interns instances. 
# CHECK-LABEL: TEST: testModuleOperation def testModuleOperation(): - ctx = mlir.ir.Context() - module = ctx.parse_module(r"""module @successfulParse {}""") + ctx = Context() + module = Module.parse(r"""module @successfulParse {}""", ctx) assert ctx._get_live_module_count() == 1 op1 = module.operation assert ctx._get_live_operation_count() == 1 @@ -106,13 +106,13 @@ # CHECK-LABEL: TEST: testModuleCapsule def testModuleCapsule(): - ctx = mlir.ir.Context() - module = ctx.parse_module(r"""module @successfulParse {}""") + ctx = Context() + module = Module.parse(r"""module @successfulParse {}""", ctx) assert ctx._get_live_module_count() == 1 # CHECK: "mlir.ir.Module._CAPIPtr" module_capsule = module._CAPIPtr print(module_capsule) - module_dup = mlir.ir.Module._CAPICreate(module_capsule) + module_dup = Module._CAPICreate(module_capsule) assert module is module_dup assert module_dup.context is ctx # Gc and verify destructed. diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py --- a/mlir/test/Bindings/Python/ir_operation.py +++ b/mlir/test/Bindings/Python/ir_operation.py @@ -3,26 +3,26 @@ import gc import io import itertools -import mlir +from mlir.ir import * def run(f): print("\nTEST:", f.__name__) f() gc.collect() - assert mlir.ir.Context._get_live_count() == 0 + assert Context._get_live_count() == 0 # Verify iterator based traversal of the op/region/block hierarchy. # CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators def testTraverseOpRegionBlockIterators(): - ctx = mlir.ir.Context() + ctx = Context() ctx.allow_unregistered_dialects = True - module = ctx.parse_module(r""" + module = Module.parse(r""" func @f1(%arg0: i32) -> i32 { %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32 return %1 : i32 } - """) + """, ctx) op = module.operation assert op.context is ctx # Get the block using iterators off of the named collections. @@ -69,14 +69,14 @@ # Verify index based traversal of the op/region/block hierarchy. 
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices def testTraverseOpRegionBlockIndices(): - ctx = mlir.ir.Context() + ctx = Context() ctx.allow_unregistered_dialects = True - module = ctx.parse_module(r""" + module = Module.parse(r""" func @f1(%arg0: i32) -> i32 { %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32 return %1 : i32 } - """) + """, ctx) def walk_operations(indent, op): for i in range(len(op.regions)): @@ -105,28 +105,28 @@ # CHECK-LABEL: TEST: testBlockArgumentList def testBlockArgumentList(): - ctx = mlir.ir.Context() - module = ctx.parse_module(r""" - func @f1(%arg0: i32, %arg1: f64, %arg2: index) { - return - } - """) - func = module.body.operations[0] - entry_block = func.regions[0].blocks[0] - assert len(entry_block.arguments) == 3 - # CHECK: Argument 0, type i32 - # CHECK: Argument 1, type f64 - # CHECK: Argument 2, type index - for arg in entry_block.arguments: - print(f"Argument {arg.arg_number}, type {arg.type}") - new_type = mlir.ir.IntegerType.get_signless(ctx, 8 * (arg.arg_number + 1)) - arg.set_type(new_type) - - # CHECK: Argument 0, type i8 - # CHECK: Argument 1, type i16 - # CHECK: Argument 2, type i24 - for arg in entry_block.arguments: - print(f"Argument {arg.arg_number}, type {arg.type}") + with Context() as ctx: + module = Module.parse(r""" + func @f1(%arg0: i32, %arg1: f64, %arg2: index) { + return + } + """, ctx) + func = module.body.operations[0] + entry_block = func.regions[0].blocks[0] + assert len(entry_block.arguments) == 3 + # CHECK: Argument 0, type i32 + # CHECK: Argument 1, type f64 + # CHECK: Argument 2, type index + for arg in entry_block.arguments: + print(f"Argument {arg.arg_number}, type {arg.type}") + new_type = IntegerType.get_signless(8 * (arg.arg_number + 1)) + arg.set_type(new_type) + + # CHECK: Argument 0, type i8 + # CHECK: Argument 1, type i16 + # CHECK: Argument 2, type i24 + for arg in entry_block.arguments: + print(f"Argument {arg.arg_number}, type {arg.type}") run(testBlockArgumentList) @@ -134,18 
+134,18 @@ # CHECK-LABEL: TEST: testDetachedOperation def testDetachedOperation(): - ctx = mlir.ir.Context() + ctx = Context() ctx.allow_unregistered_dialects = True - loc = ctx.get_unknown_location() - i32 = mlir.ir.IntegerType.get_signed(ctx, 32) - op1 = ctx.create_operation( - "custom.op1", loc, results=[i32, i32], regions=1, attributes={ - "foo": mlir.ir.StringAttr.get(ctx, "foo_value"), - "bar": mlir.ir.StringAttr.get(ctx, "bar_value"), - }) - # CHECK: %0:2 = "custom.op1"() ( { - # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32) - print(op1) + with Location.unknown(ctx): + i32 = IntegerType.get_signed(32) + op1 = Operation.create( + "custom.op1", results=[i32, i32], regions=1, attributes={ + "foo": StringAttr.get("foo_value"), + "bar": StringAttr.get("bar_value"), + }) + # CHECK: %0:2 = "custom.op1"() ( { + # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32) + print(op1) # TODO: Check successors once enough infra exists to do it properly. @@ -154,30 +154,30 @@ # CHECK-LABEL: TEST: testOperationInsertionPoint def testOperationInsertionPoint(): - ctx = mlir.ir.Context() + ctx = Context() ctx.allow_unregistered_dialects = True - module = ctx.parse_module(r""" + module = Module.parse(r""" func @f1(%arg0: i32) -> i32 { %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32 return %1 : i32 } - """) + """, ctx) # Create test op. 
- loc = ctx.get_unknown_location() - op1 = ctx.create_operation("custom.op1", loc) - op2 = ctx.create_operation("custom.op2", loc) - - func = module.body.operations[0] - entry_block = func.regions[0].blocks[0] - ip = mlir.ir.InsertionPoint.at_block_begin(entry_block) - ip.insert(op1) - ip.insert(op2) - # CHECK: func @f1 - # CHECK: "custom.op1"() - # CHECK: "custom.op2"() - # CHECK: %0 = "custom.addi" - print(module) + with Location.unknown(ctx): + op1 = Operation.create("custom.op1") + op2 = Operation.create("custom.op2") + + func = module.body.operations[0] + entry_block = func.regions[0].blocks[0] + ip = InsertionPoint.at_block_begin(entry_block) + ip.insert(op1) + ip.insert(op2) + # CHECK: func @f1 + # CHECK: "custom.op1"() + # CHECK: "custom.op2"() + # CHECK: %0 = "custom.addi" + print(module) # Trying to add a previously added op should raise. try: @@ -192,55 +192,55 @@ # CHECK-LABEL: TEST: testOperationWithRegion def testOperationWithRegion(): - ctx = mlir.ir.Context() + ctx = Context() ctx.allow_unregistered_dialects = True - loc = ctx.get_unknown_location() - i32 = mlir.ir.IntegerType.get_signed(ctx, 32) - op1 = ctx.create_operation("custom.op1", loc, regions=1) - block = op1.regions[0].blocks.append(i32, i32) - # CHECK: "custom.op1"() ( { - # CHECK: ^bb0(%arg0: si32, %arg1: si32): // no predecessors - # CHECK: "custom.terminator"() : () -> () - # CHECK: }) : () -> () - terminator = ctx.create_operation("custom.terminator", loc) - ip = mlir.ir.InsertionPoint(block) - ip.insert(terminator) - print(op1) - - # Now add the whole operation to another op. - # TODO: Verify lifetime hazard by nulling out the new owning module and - # accessing op1. - # TODO: Also verify accessing the terminator once both parents are nulled - # out. 
- module = ctx.parse_module(r""" - func @f1(%arg0: i32) -> i32 { - %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32 - return %1 : i32 - } - """) - func = module.body.operations[0] - entry_block = func.regions[0].blocks[0] - ip = mlir.ir.InsertionPoint.at_block_begin(entry_block) - ip.insert(op1) - # CHECK: func @f1 - # CHECK: "custom.op1"() - # CHECK: "custom.terminator" - # CHECK: %0 = "custom.addi" - print(module) + with Location.unknown(ctx): + i32 = IntegerType.get_signed(32) + op1 = Operation.create("custom.op1", regions=1) + block = op1.regions[0].blocks.append(i32, i32) + # CHECK: "custom.op1"() ( { + # CHECK: ^bb0(%arg0: si32, %arg1: si32): // no predecessors + # CHECK: "custom.terminator"() : () -> () + # CHECK: }) : () -> () + terminator = Operation.create("custom.terminator") + ip = InsertionPoint(block) + ip.insert(terminator) + print(op1) + + # Now add the whole operation to another op. + # TODO: Verify lifetime hazard by nulling out the new owning module and + # accessing op1. + # TODO: Also verify accessing the terminator once both parents are nulled + # out. 
+ module = Module.parse(r""" + func @f1(%arg0: i32) -> i32 { + %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32 + return %1 : i32 + } + """) + func = module.body.operations[0] + entry_block = func.regions[0].blocks[0] + ip = InsertionPoint.at_block_begin(entry_block) + ip.insert(op1) + # CHECK: func @f1 + # CHECK: "custom.op1"() + # CHECK: "custom.terminator" + # CHECK: %0 = "custom.addi" + print(module) run(testOperationWithRegion) # CHECK-LABEL: TEST: testOperationResultList def testOperationResultList(): - ctx = mlir.ir.Context() - module = ctx.parse_module(r""" + ctx = Context() + module = Module.parse(r""" func @f1() { %0:3 = call @f2() : () -> (i32, f64, index) return } func @f2() -> (i32, f64, index) - """) + """, ctx) caller = module.body.operations[0] call = caller.regions[0].blocks[0].operations[0] assert len(call.results) == 3 @@ -256,13 +256,13 @@ # CHECK-LABEL: TEST: testOperationPrint def testOperationPrint(): - ctx = mlir.ir.Context() - module = ctx.parse_module(r""" + ctx = Context() + module = Module.parse(r""" func @f1(%arg0: i32) -> i32 { %0 = constant dense<[1, 2, 3, 4]> : tensor<4xi32> return %arg0 : i32 } - """) + """, ctx) # Test print to stdout. # CHECK: return %arg0 : i32 diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py --- a/mlir/test/Bindings/Python/ir_types.py +++ b/mlir/test/Bindings/Python/ir_types.py @@ -1,19 +1,19 @@ # RUN: %PYTHON %s | FileCheck %s import gc -import mlir +from mlir.ir import * def run(f): print("\nTEST:", f.__name__) f() gc.collect() - assert mlir.ir.Context._get_live_count() == 0 + assert Context._get_live_count() == 0 # CHECK-LABEL: TEST: testParsePrint def testParsePrint(): - ctx = mlir.ir.Context() - t = ctx.parse_type("i32") + ctx = Context() + t = Type.parse("i32", ctx) assert t.context is ctx ctx = None gc.collect() @@ -29,9 +29,9 @@ # TODO: Hook the diagnostic manager to capture a more meaningful error # message. 
def testParseError(): - ctx = mlir.ir.Context() + ctx = Context() try: - t = ctx.parse_type("BAD_TYPE_DOES_NOT_EXIST") + t = Type.parse("BAD_TYPE_DOES_NOT_EXIST", ctx) except ValueError as e: # CHECK: Unable to parse type: 'BAD_TYPE_DOES_NOT_EXIST' print("testParseError:", e) @@ -43,10 +43,10 @@ # CHECK-LABEL: TEST: testTypeEq def testTypeEq(): - ctx = mlir.ir.Context() - t1 = ctx.parse_type("i32") - t2 = ctx.parse_type("f32") - t3 = ctx.parse_type("i32") + ctx = Context() + t1 = Type.parse("i32", ctx) + t2 = Type.parse("f32", ctx) + t3 = Type.parse("i32", ctx) # CHECK: t1 == t1: True print("t1 == t1:", t1 == t1) # CHECK: t1 == t2: False @@ -61,8 +61,8 @@ # CHECK-LABEL: TEST: testTypeEqDoesNotRaise def testTypeEqDoesNotRaise(): - ctx = mlir.ir.Context() - t1 = ctx.parse_type("i32") + ctx = Context() + t1 = Type.parse("i32", ctx) not_a_type = "foo" # CHECK: False print(t1 == not_a_type) @@ -76,14 +76,14 @@ # CHECK-LABEL: TEST: testStandardTypeCasts def testStandardTypeCasts(): - ctx = mlir.ir.Context() - t1 = ctx.parse_type("i32") - tint = mlir.ir.IntegerType(t1) - tself = mlir.ir.IntegerType(tint) + ctx = Context() + t1 = Type.parse("i32", ctx) + tint = IntegerType(t1) + tself = IntegerType(tint) # CHECK: Type(i32) print(repr(tint)) try: - tillegal = mlir.ir.IntegerType(ctx.parse_type("f32")) + tillegal = IntegerType(Type.parse("f32", ctx)) except ValueError as e: # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32)) print("ValueError:", e) @@ -95,91 +95,91 @@ # CHECK-LABEL: TEST: testIntegerType def testIntegerType(): - ctx = mlir.ir.Context() - i32 = mlir.ir.IntegerType(ctx.parse_type("i32")) - # CHECK: i32 width: 32 - print("i32 width:", i32.width) - # CHECK: i32 signless: True - print("i32 signless:", i32.is_signless) - # CHECK: i32 signed: False - print("i32 signed:", i32.is_signed) - # CHECK: i32 unsigned: False - print("i32 unsigned:", i32.is_unsigned) - - s32 = mlir.ir.IntegerType(ctx.parse_type("si32")) - # CHECK: s32 signless: False - 
print("s32 signless:", s32.is_signless) - # CHECK: s32 signed: True - print("s32 signed:", s32.is_signed) - # CHECK: s32 unsigned: False - print("s32 unsigned:", s32.is_unsigned) - - u32 = mlir.ir.IntegerType(ctx.parse_type("ui32")) - # CHECK: u32 signless: False - print("u32 signless:", u32.is_signless) - # CHECK: u32 signed: False - print("u32 signed:", u32.is_signed) - # CHECK: u32 unsigned: True - print("u32 unsigned:", u32.is_unsigned) - - # CHECK: signless: i16 - print("signless:", mlir.ir.IntegerType.get_signless(ctx, 16)) - # CHECK: signed: si8 - print("signed:", mlir.ir.IntegerType.get_signed(ctx, 8)) - # CHECK: unsigned: ui64 - print("unsigned:", mlir.ir.IntegerType.get_unsigned(ctx, 64)) + with Context() as ctx: + i32 = IntegerType(Type.parse("i32")) + # CHECK: i32 width: 32 + print("i32 width:", i32.width) + # CHECK: i32 signless: True + print("i32 signless:", i32.is_signless) + # CHECK: i32 signed: False + print("i32 signed:", i32.is_signed) + # CHECK: i32 unsigned: False + print("i32 unsigned:", i32.is_unsigned) + + s32 = IntegerType(Type.parse("si32")) + # CHECK: s32 signless: False + print("s32 signless:", s32.is_signless) + # CHECK: s32 signed: True + print("s32 signed:", s32.is_signed) + # CHECK: s32 unsigned: False + print("s32 unsigned:", s32.is_unsigned) + + u32 = IntegerType(Type.parse("ui32")) + # CHECK: u32 signless: False + print("u32 signless:", u32.is_signless) + # CHECK: u32 signed: False + print("u32 signed:", u32.is_signed) + # CHECK: u32 unsigned: True + print("u32 unsigned:", u32.is_unsigned) + + # CHECK: signless: i16 + print("signless:", IntegerType.get_signless(16)) + # CHECK: signed: si8 + print("signed:", IntegerType.get_signed(8)) + # CHECK: unsigned: ui64 + print("unsigned:", IntegerType.get_unsigned(64)) run(testIntegerType) # CHECK-LABEL: TEST: testIndexType def testIndexType(): - ctx = mlir.ir.Context() - # CHECK: index type: index - print("index type:", mlir.ir.IndexType.get(ctx)) + with Context() as ctx: + # CHECK: index 
type: index + print("index type:", IndexType.get()) run(testIndexType) # CHECK-LABEL: TEST: testFloatType def testFloatType(): - ctx = mlir.ir.Context() - # CHECK: float: bf16 - print("float:", mlir.ir.BF16Type.get(ctx)) - # CHECK: float: f16 - print("float:", mlir.ir.F16Type.get(ctx)) - # CHECK: float: f32 - print("float:", mlir.ir.F32Type.get(ctx)) - # CHECK: float: f64 - print("float:", mlir.ir.F64Type.get(ctx)) + with Context(): + # CHECK: float: bf16 + print("float:", BF16Type.get()) + # CHECK: float: f16 + print("float:", F16Type.get()) + # CHECK: float: f32 + print("float:", F32Type.get()) + # CHECK: float: f64 + print("float:", F64Type.get()) run(testFloatType) # CHECK-LABEL: TEST: testNoneType def testNoneType(): - ctx = mlir.ir.Context() - # CHECK: none type: none - print("none type:", mlir.ir.NoneType.get(ctx)) + with Context(): + # CHECK: none type: none + print("none type:", NoneType.get()) run(testNoneType) # CHECK-LABEL: TEST: testComplexType def testComplexType(): - ctx = mlir.ir.Context() - complex_i32 = mlir.ir.ComplexType(ctx.parse_type("complex")) - # CHECK: complex type element: i32 - print("complex type element:", complex_i32.element_type) - - f32 = mlir.ir.F32Type.get(ctx) - # CHECK: complex type: complex - print("complex type:", mlir.ir.ComplexType.get(f32)) - - index = mlir.ir.IndexType.get(ctx) - try: - complex_invalid = mlir.ir.ComplexType.get(index) - except ValueError as e: - # CHECK: invalid 'Type(index)' and expected floating point or integer type. 
- print(e) - else: - print("Exception not produced") + with Context() as ctx: + complex_i32 = ComplexType(Type.parse("complex")) + # CHECK: complex type element: i32 + print("complex type element:", complex_i32.element_type) + + f32 = F32Type.get() + # CHECK: complex type: complex + print("complex type:", ComplexType.get(f32)) + + index = IndexType.get() + try: + complex_invalid = ComplexType.get(index) + except ValueError as e: + # CHECK: invalid 'Type(index)' and expected floating point or integer type. + print(e) + else: + print("Exception not produced") run(testComplexType) @@ -188,26 +188,26 @@ # vectors, memrefs and tensors, so this test case uses an instance of vector # to test the shaped type. The class hierarchy is preserved on the python side. def testConcreteShapedType(): - ctx = mlir.ir.Context() - vector = mlir.ir.VectorType(ctx.parse_type("vector<2x3xf32>")) - # CHECK: element type: f32 - print("element type:", vector.element_type) - # CHECK: whether the given shaped type is ranked: True - print("whether the given shaped type is ranked:", vector.has_rank) - # CHECK: rank: 2 - print("rank:", vector.rank) - # CHECK: whether the shaped type has a static shape: True - print("whether the shaped type has a static shape:", vector.has_static_shape) - # CHECK: whether the dim-th dimension is dynamic: False - print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0)) - # CHECK: dim size: 3 - print("dim size:", vector.get_dim_size(1)) - # CHECK: is_dynamic_size: False - print("is_dynamic_size:", vector.is_dynamic_size(3)) - # CHECK: is_dynamic_stride_or_offset: False - print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1)) - # CHECK: isinstance(ShapedType): True - print("isinstance(ShapedType):", isinstance(vector, mlir.ir.ShapedType)) + with Context() as ctx: + vector = VectorType(Type.parse("vector<2x3xf32>")) + # CHECK: element type: f32 + print("element type:", vector.element_type) + # CHECK: whether the given shaped 
type is ranked: True + print("whether the given shaped type is ranked:", vector.has_rank) + # CHECK: rank: 2 + print("rank:", vector.rank) + # CHECK: whether the shaped type has a static shape: True + print("whether the shaped type has a static shape:", vector.has_static_shape) + # CHECK: whether the dim-th dimension is dynamic: False + print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0)) + # CHECK: dim size: 3 + print("dim size:", vector.get_dim_size(1)) + # CHECK: is_dynamic_size: False + print("is_dynamic_size:", vector.is_dynamic_size(3)) + # CHECK: is_dynamic_stride_or_offset: False + print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1)) + # CHECK: isinstance(ShapedType): True + print("isinstance(ShapedType):", isinstance(vector, ShapedType)) run(testConcreteShapedType) @@ -215,8 +215,8 @@ # Tests that ShapedType operates as an abstract base class of a concrete # shaped type (using vector as an example). def testAbstractShapedType(): - ctx = mlir.ir.Context() - vector = mlir.ir.ShapedType(ctx.parse_type("vector<2x3xf32>")) + ctx = Context() + vector = ShapedType(Type.parse("vector<2x3xf32>", ctx)) # CHECK: element type: f32 print("element type:", vector.element_type) @@ -224,186 +224,184 @@ # CHECK-LABEL: TEST: testVectorType def testVectorType(): - ctx = mlir.ir.Context() - f32 = mlir.ir.F32Type.get(ctx) - shape = [2, 3] - loc = ctx.get_unknown_location() - # CHECK: vector type: vector<2x3xf32> - print("vector type:", mlir.ir.VectorType.get(shape, f32, loc)) - - none = mlir.ir.NoneType.get(ctx) - try: - vector_invalid = mlir.ir.VectorType.get(shape, none, loc) - except ValueError as e: - # CHECK: invalid 'Type(none)' and expected floating point or integer type. 
- print(e) - else: - print("Exception not produced") + with Context(), Location.unknown(): + f32 = F32Type.get() + shape = [2, 3] + # CHECK: vector type: vector<2x3xf32> + print("vector type:", VectorType.get(shape, f32)) + + none = NoneType.get() + try: + vector_invalid = VectorType.get(shape, none) + except ValueError as e: + # CHECK: invalid 'Type(none)' and expected floating point or integer type. + print(e) + else: + print("Exception not produced") run(testVectorType) # CHECK-LABEL: TEST: testRankedTensorType def testRankedTensorType(): - ctx = mlir.ir.Context() - f32 = mlir.ir.F32Type.get(ctx) - shape = [2, 3] - loc = ctx.get_unknown_location() - # CHECK: ranked tensor type: tensor<2x3xf32> - print("ranked tensor type:", - mlir.ir.RankedTensorType.get(shape, f32, loc)) - - none = mlir.ir.NoneType.get(ctx) - try: - tensor_invalid = mlir.ir.RankedTensorType.get(shape, none, loc) - except ValueError as e: - # CHECK: invalid 'Type(none)' and expected floating point, integer, vector - # CHECK: or complex type. - print(e) - else: - print("Exception not produced") + with Context(), Location.unknown(): + f32 = F32Type.get() + shape = [2, 3] + loc = Location.unknown() + # CHECK: ranked tensor type: tensor<2x3xf32> + print("ranked tensor type:", + RankedTensorType.get(shape, f32)) + + none = NoneType.get() + try: + tensor_invalid = RankedTensorType.get(shape, none) + except ValueError as e: + # CHECK: invalid 'Type(none)' and expected floating point, integer, vector + # CHECK: or complex type. 
+ print(e) + else: + print("Exception not produced") run(testRankedTensorType) # CHECK-LABEL: TEST: testUnrankedTensorType def testUnrankedTensorType(): - ctx = mlir.ir.Context() - f32 = mlir.ir.F32Type.get(ctx) - loc = ctx.get_unknown_location() - unranked_tensor = mlir.ir.UnrankedTensorType.get(f32, loc) - # CHECK: unranked tensor type: tensor<*xf32> - print("unranked tensor type:", unranked_tensor) - try: - invalid_rank = unranked_tensor.rank - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") - try: - invalid_is_dynamic_dim = unranked_tensor.is_dynamic_dim(0) - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") - try: - invalid_get_dim_size = unranked_tensor.get_dim_size(1) - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") - - none = mlir.ir.NoneType.get(ctx) - try: - tensor_invalid = mlir.ir.UnrankedTensorType.get(none, loc) - except ValueError as e: - # CHECK: invalid 'Type(none)' and expected floating point, integer, vector - # CHECK: or complex type. - print(e) - else: - print("Exception not produced") + with Context(), Location.unknown(): + f32 = F32Type.get() + loc = Location.unknown() + unranked_tensor = UnrankedTensorType.get(f32) + # CHECK: unranked tensor type: tensor<*xf32> + print("unranked tensor type:", unranked_tensor) + try: + invalid_rank = unranked_tensor.rank + except ValueError as e: + # CHECK: calling this method requires that the type has a rank. + print(e) + else: + print("Exception not produced") + try: + invalid_is_dynamic_dim = unranked_tensor.is_dynamic_dim(0) + except ValueError as e: + # CHECK: calling this method requires that the type has a rank. 
+ print(e) + else: + print("Exception not produced") + try: + invalid_get_dim_size = unranked_tensor.get_dim_size(1) + except ValueError as e: + # CHECK: calling this method requires that the type has a rank. + print(e) + else: + print("Exception not produced") + + none = NoneType.get() + try: + tensor_invalid = UnrankedTensorType.get(none) + except ValueError as e: + # CHECK: invalid 'Type(none)' and expected floating point, integer, vector + # CHECK: or complex type. + print(e) + else: + print("Exception not produced") run(testUnrankedTensorType) # CHECK-LABEL: TEST: testMemRefType def testMemRefType(): - ctx = mlir.ir.Context() - f32 = mlir.ir.F32Type.get(ctx) - shape = [2, 3] - loc = ctx.get_unknown_location() - memref = mlir.ir.MemRefType.get_contiguous_memref(f32, shape, 2, loc) - # CHECK: memref type: memref<2x3xf32, 2> - print("memref type:", memref) - # CHECK: number of affine layout maps: 0 - print("number of affine layout maps:", memref.num_affine_maps) - # CHECK: memory space: 2 - print("memory space:", memref.memory_space) - - none = mlir.ir.NoneType.get(ctx) - try: - memref_invalid = mlir.ir.MemRefType.get_contiguous_memref(none, shape, 2, - loc) - except ValueError as e: - # CHECK: invalid 'Type(none)' and expected floating point, integer, vector - # CHECK: or complex type. 
- print(e) - else: - print("Exception not produced") + with Context(), Location.unknown(): + f32 = F32Type.get() + shape = [2, 3] + loc = Location.unknown() + memref = MemRefType.get_contiguous_memref(f32, shape, 2) + # CHECK: memref type: memref<2x3xf32, 2> + print("memref type:", memref) + # CHECK: number of affine layout maps: 0 + print("number of affine layout maps:", memref.num_affine_maps) + # CHECK: memory space: 2 + print("memory space:", memref.memory_space) + + none = NoneType.get() + try: + memref_invalid = MemRefType.get_contiguous_memref(none, shape, 2) + except ValueError as e: + # CHECK: invalid 'Type(none)' and expected floating point, integer, vector + # CHECK: or complex type. + print(e) + else: + print("Exception not produced") run(testMemRefType) # CHECK-LABEL: TEST: testUnrankedMemRefType def testUnrankedMemRefType(): - ctx = mlir.ir.Context() - f32 = mlir.ir.F32Type.get(ctx) - loc = ctx.get_unknown_location() - unranked_memref = mlir.ir.UnrankedMemRefType.get(f32, 2, loc) - # CHECK: unranked memref type: memref<*xf32, 2> - print("unranked memref type:", unranked_memref) - try: - invalid_rank = unranked_memref.rank - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") - try: - invalid_is_dynamic_dim = unranked_memref.is_dynamic_dim(0) - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") - try: - invalid_get_dim_size = unranked_memref.get_dim_size(1) - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") - - none = mlir.ir.NoneType.get(ctx) - try: - memref_invalid = mlir.ir.UnrankedMemRefType.get(none, 2, loc) - except ValueError as e: - # CHECK: invalid 'Type(none)' and expected floating point, integer, vector - # CHECK: or complex type. 
- print(e) - else: - print("Exception not produced") + with Context(), Location.unknown(): + f32 = F32Type.get() + loc = Location.unknown() + unranked_memref = UnrankedMemRefType.get(f32, 2) + # CHECK: unranked memref type: memref<*xf32, 2> + print("unranked memref type:", unranked_memref) + try: + invalid_rank = unranked_memref.rank + except ValueError as e: + # CHECK: calling this method requires that the type has a rank. + print(e) + else: + print("Exception not produced") + try: + invalid_is_dynamic_dim = unranked_memref.is_dynamic_dim(0) + except ValueError as e: + # CHECK: calling this method requires that the type has a rank. + print(e) + else: + print("Exception not produced") + try: + invalid_get_dim_size = unranked_memref.get_dim_size(1) + except ValueError as e: + # CHECK: calling this method requires that the type has a rank. + print(e) + else: + print("Exception not produced") + + none = NoneType.get() + try: + memref_invalid = UnrankedMemRefType.get(none, 2) + except ValueError as e: + # CHECK: invalid 'Type(none)' and expected floating point, integer, vector + # CHECK: or complex type. 
+ print(e) + else: + print("Exception not produced") run(testUnrankedMemRefType) # CHECK-LABEL: TEST: testTupleType def testTupleType(): - ctx = mlir.ir.Context() - i32 = mlir.ir.IntegerType(ctx.parse_type("i32")) - f32 = mlir.ir.F32Type.get(ctx) - vector = mlir.ir.VectorType(ctx.parse_type("vector<2x3xf32>")) - l = [i32, f32, vector] - tuple_type = mlir.ir.TupleType.get_tuple(ctx, l) - # CHECK: tuple type: tuple> - print("tuple type:", tuple_type) - # CHECK: number of types: 3 - print("number of types:", tuple_type.num_types) - # CHECK: pos-th type in the tuple type: f32 - print("pos-th type in the tuple type:", tuple_type.get_type(1)) + with Context() as ctx: + i32 = IntegerType(Type.parse("i32")) + f32 = F32Type.get() + vector = VectorType(Type.parse("vector<2x3xf32>")) + l = [i32, f32, vector] + tuple_type = TupleType.get_tuple(l) + # CHECK: tuple type: tuple> + print("tuple type:", tuple_type) + # CHECK: number of types: 3 + print("number of types:", tuple_type.num_types) + # CHECK: pos-th type in the tuple type: f32 + print("pos-th type in the tuple type:", tuple_type.get_type(1)) run(testTupleType) # CHECK-LABEL: TEST: testFunctionType def testFunctionType(): - ctx = mlir.ir.Context() - input_types = [mlir.ir.IntegerType.get_signless(ctx, 32), - mlir.ir.IntegerType.get_signless(ctx, 16)] - result_types = [mlir.ir.IndexType.get(ctx)] - func = mlir.ir.FunctionType.get(ctx, input_types, result_types) - # CHECK: INPUTS: [Type(i32), Type(i16)] - print("INPUTS:", func.inputs) - # CHECK: RESULTS: [Type(index)] - print("RESULTS:", func.results) + with Context() as ctx: + input_types = [IntegerType.get_signless(32), + IntegerType.get_signless(16)] + result_types = [IndexType.get()] + func = FunctionType.get(input_types, result_types) + # CHECK: INPUTS: [Type(i32), Type(i16)] + print("INPUTS:", func.inputs) + # CHECK: RESULTS: [Type(index)] + print("RESULTS:", func.results) run(testFunctionType)