diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -93,9 +93,10 @@ static MlirAffineExpr castFrom(PyAffineExpr &orig) { if (!DerivedTy::isaFunction(orig)) { auto origRepr = py::repr(py::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, - Twine("Cannot cast affine expression to ") + - DerivedTy::pyClassName + " (from " + origRepr + ")"); + throw py::value_error((Twine("Cannot cast affine expression to ") + + DerivedTy::pyClassName + " (from " + origRepr + + ")") + .str()); } return orig; } diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#include #include +#include #include "IRModule.h" @@ -666,14 +666,14 @@ !mlirAttributeIsAFloat(elementAttr)) { std::string message = "Illegal element type for DenseElementsAttr: "; message.append(py::repr(py::cast(elementAttr))); - throw SetPyError(PyExc_ValueError, message); + throw py::value_error(message); } if (!mlirTypeIsAShaped(shapedType) || !mlirShapedTypeHasStaticShape(shapedType)) { std::string message = "Expected a static ShapedType for the shaped_type parameter: "; message.append(py::repr(py::cast(shapedType))); - throw SetPyError(PyExc_ValueError, message); + throw py::value_error(message); } MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); MlirType attrType = mlirAttributeGetType(elementAttr); @@ -683,7 +683,7 @@ message.append(py::repr(py::cast(shapedType))); message.append(", element="); message.append(py::repr(py::cast(elementAttr))); - throw SetPyError(PyExc_ValueError, message); + throw py::value_error(message); } MlirAttribute elements = @@ -783,8 +783,7 @@ .def("get_splat_value", [](PyDenseElementsAttribute &self) -> PyAttribute { if (!mlirDenseElementsAttrIsSplat(self)) { - throw SetPyError( - PyExc_ValueError, + throw py::value_error( "get_splat_value called on a non-splat attribute"); } return PyAttribute(self.getContext(), @@ -861,8 +860,7 @@ /// out of range. py::int_ dunderGetItem(intptr_t pos) { if (pos < 0 || pos >= dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds element"); + throw py::index_error("attempt to access out of bounds element"); } MlirType type = mlirAttributeGetType(*this); @@ -909,7 +907,7 @@ return mlirDenseElementsAttrGetInt64Value(*this, pos); } } - throw SetPyError(PyExc_TypeError, "Unsupported integer type"); + throw py::type_error("Unsupported integer type"); } static void bindDerived(ClassTy &c) { @@ -957,15 +955,13 @@ MlirAttribute attr = mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) { - throw SetPyError(PyExc_KeyError, - "attempt to access a non-existent attribute"); + throw py::key_error("attempt to access a non-existent attribute"); } return PyAttribute(self.getContext(), attr); }); c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { if (index < 0 || index >= self.dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds attribute"); + throw py::index_error("attempt to access out of bounds attribute"); } MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); return PyNamedAttribute( @@ -987,8 +983,7 @@ py::float_ dunderGetItem(intptr_t pos) { if (pos < 0 || pos >= dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds element"); + throw py::index_error("attempt to access out of bounds element"); } MlirType type = mlirAttributeGetType(*this); @@ -1004,7 +999,7 @@ if (mlirTypeIsAF64(type)) { return mlirDenseElementsAttrGetDoubleValue(*this, pos); } - throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); + throw py::type_error("Unsupported floating-point type"); } static void bindDerived(ClassTy &c) { diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -278,8 +278,7 @@ PyRegion dunderGetItem(intptr_t index) { // dunderLen checks validity. if (index < 0 || index >= dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds region"); + throw py::index_error("attempt to access out of bounds region"); } MlirRegion region = mlirOperationGetRegion(operation->get(), index); return PyRegion(operation, region); @@ -351,8 +350,7 @@ PyBlock dunderGetItem(intptr_t index) { operation->checkValid(); if (index < 0) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds block"); + throw py::index_error("attempt to access out of bounds block"); } MlirBlock block = mlirRegionGetFirstBlock(region); while (!mlirBlockIsNull(block)) { @@ -362,7 +360,7 @@ block = mlirBlockGetNextInRegion(block); index -= 1; } - throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); + throw py::index_error("attempt to access out of bounds block"); } PyBlock appendBlock(const py::args &pyArgTypes) { @@ -456,8 +454,7 @@ py::object dunderGetItem(intptr_t index) { parentOperation->checkValid(); if (index < 0) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds operation"); + throw py::index_error("attempt to access out of bounds operation"); } MlirOperation childOp = mlirBlockGetFirstOperation(block); while (!mlirOperationIsNull(childOp)) { @@ -468,8 +465,7 @@ childOp = mlirOperationGetNextInBlock(childOp); index -= 1; } - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds operation"); + throw py::index_error("attempt to access out of bounds operation"); } static void bind(py::module &m) { @@ -684,8 +680,7 @@ PyMlirContext &DefaultingPyMlirContext::resolve() { PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); if (!context) { - throw SetPyError( - PyExc_RuntimeError, + throw std::runtime_error( "An MLIR function requires a Context but none was provided in the call " "or from the surrounding environment. Either pass to the function with " "a 'context=' argument or establish a default using 'with Context():'"); @@ -775,10 +770,10 @@ void PyThreadContextEntry::popContext(PyMlirContext &context) { auto &stack = getStack(); if (stack.empty()) - throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); + throw std::runtime_error("Unbalanced Context enter/exit"); auto &tos = stack.back(); if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) - throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); + throw std::runtime_error("Unbalanced Context enter/exit"); stack.pop_back(); } @@ -797,13 +792,11 @@ void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { auto &stack = getStack(); if (stack.empty()) - throw SetPyError(PyExc_RuntimeError, - "Unbalanced InsertionPoint enter/exit"); + throw std::runtime_error("Unbalanced InsertionPoint enter/exit"); auto &tos = stack.back(); if (tos.frameKind != FrameKind::InsertionPoint && tos.getInsertionPoint() != &insertionPoint) - throw SetPyError(PyExc_RuntimeError, - "Unbalanced InsertionPoint enter/exit"); + throw std::runtime_error("Unbalanced InsertionPoint enter/exit"); stack.pop_back(); } @@ -819,10 +812,10 @@ void PyThreadContextEntry::popLocation(PyLocation &location) { auto &stack = getStack(); if (stack.empty()) - throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); + throw std::runtime_error("Unbalanced Location enter/exit"); auto &tos = stack.back(); if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) - throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); + throw std::runtime_error("Unbalanced Location enter/exit"); stack.pop_back(); } @@ -913,8 +906,11 @@ MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), {key.data(), key.size()}); if (mlirDialectIsNull(dialect)) { - throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, - Twine("Dialect '") + key + "' not found"); + std::string msg = (Twine("Dialect '") + key + "' not found").str(); + if (attrError) + throw py::attribute_error(msg); + else + throw py::index_error(msg); } return dialect; } @@ -961,8 +957,7 @@ PyLocation &DefaultingPyLocation::resolve() { auto *location = PyThreadContextEntry::getDefaultLocation(); if (!location) { - throw SetPyError( - PyExc_RuntimeError, + throw std::runtime_error( "An MLIR function requires a Location but none was provided in the " "call or from the surrounding environment. Either pass to the function " "with a 'loc=' argument or establish a default using 'with loc:'"); @@ -1107,7 +1102,7 @@ void PyOperation::checkValid() const { if (!valid) { - throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); + throw std::runtime_error("the operation has been invalidated"); } } @@ -1211,7 +1206,7 @@ std::optional PyOperation::getParentOperation() { checkValid(); if (!isAttached()) - throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); + throw py::value_error("Detached operations have no parent"); MlirOperation operation = mlirOperationGetParentOperation(get()); if (mlirOperationIsNull(operation)) return {}; @@ -1270,14 +1265,14 @@ // General parameter validation. if (regions < 0) - throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); + throw py::value_error("number of regions must be >= 0"); // Unpack/validate operands. if (operands) { mlirOperands.reserve(operands->size()); for (PyValue *operand : *operands) { if (!operand) - throw SetPyError(PyExc_ValueError, "operand value cannot be None"); + throw py::value_error("operand value cannot be None"); mlirOperands.push_back(operand->get()); } } @@ -1288,7 +1283,7 @@ for (PyType *result : *results) { // TODO: Verify result type originate from the same context. if (!result) - throw SetPyError(PyExc_ValueError, "result type cannot be None"); + throw py::value_error("result type cannot be None"); mlirResults.push_back(*result); } } @@ -1329,7 +1324,7 @@ for (auto *successor : *successors) { // TODO: Verify successor originate from the same context. if (!successor) - throw SetPyError(PyExc_ValueError, "successor block cannot be None"); + throw py::value_error("successor block cannot be None"); mlirSuccessors.push_back(successor->get()); } } @@ -1701,8 +1696,8 @@ void PyInsertionPoint::insert(PyOperationBase &operationBase) { PyOperation &operation = operationBase.getOperation(); if (operation.isAttached()) - throw SetPyError(PyExc_ValueError, - "Attempt to insert operation that is already attached"); + throw py::value_error( + "Attempt to insert operation that is already attached"); block.getParentOperation()->checkValid(); MlirOperation beforeOp = {nullptr}; if (refOperation) { @@ -1740,7 +1735,7 @@ PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { MlirOperation terminator = mlirBlockGetTerminator(block.get()); if (mlirOperationIsNull(terminator)) - throw SetPyError(PyExc_ValueError, "Block has no terminator"); + throw py::value_error("Block has no terminator"); PyOperationRef terminatorOpRef = PyOperation::forOperation( block.getParentOperation()->getContext(), terminator); return PyInsertionPoint{block, std::move(terminatorOpRef)}; @@ -2033,9 +2028,10 @@ static MlirValue castFrom(PyValue &orig) { if (!DerivedTy::isaFunction(orig.get())) { auto origRepr = py::repr(py::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + - DerivedTy::pyClassName + - " (from " + origRepr + ")"); + throw py::value_error((Twine("Cannot cast value to ") + + DerivedTy::pyClassName + " (from " + origRepr + + ")") + .str()); } return orig.get(); } @@ -2273,16 +2269,14 @@ MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) { - throw SetPyError(PyExc_KeyError, - "attempt to access a non-existent attribute"); + throw py::key_error("attempt to access a non-existent attribute"); } return PyAttribute(operation->getContext(), attr); } PyNamedAttribute dunderGetItemIndexed(intptr_t index) { if (index < 0 || index >= dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds attribute"); + throw py::index_error("attempt to access out of bounds attribute"); } MlirNamedAttribute namedAttr = mlirOperationGetAttribute(operation->get(), index); @@ -2301,8 +2295,7 @@ int removed = mlirOperationRemoveAttributeByName(operation->get(), toMlirStringRef(name)); if (!removed) - throw SetPyError(PyExc_KeyError, - "attempt to delete a non-existent attribute"); + throw py::key_error("attempt to delete a non-existent attribute"); } intptr_t dunderLen() { @@ -2402,7 +2395,7 @@ [](py::object & /*class*/) { auto *context = PyThreadContextEntry::getDefaultContext(); if (!context) - throw SetPyError(PyExc_ValueError, "No current Context"); + throw py::value_error("No current Context"); return context; }, "Gets the Context bound to the current thread or raises ValueError") @@ -2419,8 +2412,8 @@ MlirDialect dialect = mlirContextGetOrLoadDialect( self.get(), {name.data(), name.size()}); if (mlirDialectIsNull(dialect)) { - throw SetPyError(PyExc_ValueError, - Twine("Dialect '") + name + "' not found"); + throw py::value_error( + (Twine("Dialect '") + name + "' not found").str()); } return PyDialectDescriptor(self.getRef(), dialect); }, @@ -2545,7 +2538,7 @@ [](py::object & /*class*/) { auto *loc = PyThreadContextEntry::getDefaultLocation(); if (!loc) - throw SetPyError(PyExc_ValueError, "No current Location"); + throw py::value_error("No current Location"); return loc; }, "Gets the Location bound to the current thread or raises ValueError") @@ -2752,13 +2745,13 @@ auto numResults = mlirOperationGetNumResults(operation); if (numResults != 1) { auto name = mlirIdentifierStr(mlirOperationGetName(operation)); - throw SetPyError( - PyExc_ValueError, - Twine("Cannot call .result on operation ") + - StringRef(name.data, name.length) + " which has " + - Twine(numResults) + - " results (it is only valid for operations with a " - "single result)"); + throw py::value_error( + (Twine("Cannot call .result on operation ") + + StringRef(name.data, name.length) + " which has " + + Twine(numResults) + + " results (it is only valid for operations with a " + "single result)") + .str()); } return PyOpResult(operation.getRef(), mlirOperationGetResult(operation, 0)); @@ -3119,7 +3112,7 @@ [](py::object & /*class*/) { auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); if (!ip) - throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); + throw py::value_error("No current InsertionPoint"); return ip; }, "Gets the InsertionPoint bound to the current thread or raises " diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -877,9 +877,10 @@ static MlirType castFrom(PyType &orig) { if (!DerivedTy::isaFunction(orig)) { auto origRepr = pybind11::repr(pybind11::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") + - DerivedTy::pyClassName + - " (from " + origRepr + ")"); + throw py::value_error((llvm::Twine("Cannot cast type to ") + + DerivedTy::pyClassName + " (from " + origRepr + + ")") + .str()); } return orig; } @@ -898,9 +899,8 @@ "static_typeid", [](py::object & /*class*/) -> MlirTypeID { if (DerivedTy::getTypeIdFunction) return DerivedTy::getTypeIdFunction(); - throw SetPyError(PyExc_AttributeError, - DerivedTy::pyClassName + - llvm::Twine(" has no typeid.")); + throw py::attribute_error( + (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str()); }); cls.def_property_readonly("typeid", [](PyType &self) { return py::cast(self).attr("typeid").cast(); @@ -990,9 +990,10 @@ static MlirAttribute castFrom(PyAttribute &orig) { if (!DerivedTy::isaFunction(orig)) { auto origRepr = pybind11::repr(pybind11::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, - llvm::Twine("Cannot cast attribute to ") + - DerivedTy::pyClassName + " (from " + origRepr + ")"); + throw py::value_error((llvm::Twine("Cannot cast attribute to ") + + DerivedTy::pyClassName + " (from " + origRepr + + ")") + .str()); } return orig; } diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -10,8 +10,8 @@ #include "Globals.h" #include "PybindUtils.h" -#include #include +#include #include "mlir-c/Bindings/Python/Interop.h" @@ -76,9 +76,9 @@ py::object pyClass) { py::object &found = dialectClassMap[dialectNamespace]; if (found) { - throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + - dialectNamespace + - "' is already registered."); + throw std::runtime_error((llvm::Twine("Dialect namespace '") + + dialectNamespace + "' is already registered.") + .str()); } found = std::move(pyClass); } @@ -87,9 +87,9 @@ py::object pyClass) { py::object &found = operationClassMap[operationName]; if (found) { - throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + - operationName + - "' is already registered."); + throw std::runtime_error((llvm::Twine("Operation '") + operationName + + "' is already registered.") + .str()); } found = std::move(pyClass); } diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -325,11 +325,11 @@ MlirType t = mlirComplexTypeGet(elementType); return PyComplexType(elementType.getContext(), t); } - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point or integer type."); + throw py::value_error( + (Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point or integer type.") + .str()); }, "Create a complex type"); c.def_property_readonly( @@ -432,8 +432,7 @@ private: void requireHasRank() { if (!mlirShapedTypeHasRank(*this)) { - throw SetPyError( - PyExc_ValueError, + throw py::value_error( "calling this method requires that the type has a rank."); } } diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -93,7 +93,7 @@ mlirStringRefCreate(pipeline.data(), pipeline.size()), errorMsg.getCallback(), errorMsg.getUserData()); if (mlirLogicalResultIsFailure(status)) - throw SetPyError(PyExc_ValueError, std::string(errorMsg.join())); + throw py::value_error(std::string(errorMsg.join())); return new PyPassManager(passManager); }, py::arg("pipeline"), py::arg("context") = py::none(), @@ -109,7 +109,7 @@ mlirStringRefCreate(pipeline.data(), pipeline.size()), errorMsg.getCallback(), errorMsg.getUserData()); if (mlirLogicalResultIsFailure(status)) - throw SetPyError(PyExc_ValueError, std::string(errorMsg.join())); + throw py::value_error(std::string(errorMsg.join())); }, py::arg("pipeline"), "Add textual pipeline elements to the pass manager. Throws a " diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -19,13 +19,6 @@ namespace mlir { namespace python { -// Sets a python error, ready to be thrown to return control back to the -// python runtime. -// Correct usage: -// throw SetPyError(PyExc_ValueError, "Foobar'd"); -pybind11::error_already_set SetPyError(PyObject *excClass, - const llvm::Twine &message); - /// CRTP template for special wrapper types that are allowed to be passed in as /// 'None' function arguments and can be resolved by some global mechanic if /// so. Such types will raise an error if this global resolution fails, and diff --git a/mlir/lib/Bindings/Python/PybindUtils.cpp b/mlir/lib/Bindings/Python/PybindUtils.cpp deleted file mode 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.cpp +++ /dev/null @@ -1,16 +0,0 @@ -//===- PybindUtils.cpp - Utilities for interop with pybind11 --------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "PybindUtils.h" - -pybind11::error_already_set -mlir::python::SetPyError(PyObject *excClass, const llvm::Twine &message) { - auto messageStr = message.str(); - PyErr_SetString(excClass, messageStr.c_str()); - return pybind11::error_already_set(); -} diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -263,7 +263,6 @@ IRInterfaces.cpp IRModule.cpp IRTypes.cpp - PybindUtils.cpp Pass.cpp # Headers must be included explicitly so they are installed.