diff --git a/mlir/include/mlir-c/Interfaces.h b/mlir/include/mlir-c/Interfaces.h
--- a/mlir/include/mlir-c/Interfaces.h
+++ b/mlir/include/mlir-c/Interfaces.h
@@ -60,6 +60,34 @@
     intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback,
     void *userData);

+//===----------------------------------------------------------------------===//
+// InferShapedTypeOpInterface.
+//===----------------------------------------------------------------------===//
+
+/// Returns the interface TypeID of the InferShapedTypeOpInterface.
+MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID(void);
+
+/// Callback used to report inferred ShapedTypeComponents, called once per
+/// component. Its arguments are, in order: the rank, a pointer to `rank`
+/// dimension sizes, the element type, an attribute, and the opaque user data
+/// pointer supplied by the caller. Unranked components are reported with a
+/// rank of 0 and a null shape pointer; ranked components always receive a
+/// non-null shape pointer, even for rank 0. The shape pointer is only valid
+/// for the duration of the call, so copy the data to keep it.
+typedef void (*MlirShapedTypeComponentsCallback)(intptr_t, const int64_t *,
+                                                 MlirType, MlirAttribute,
+                                                 void *);
+
+/// Infers the return shaped type components of the operation. Calls `callback`
+/// once for each inferred shaped type component on success. Returns failure if
+/// the operation is not registered or if inference fails.
+MLIR_CAPI_EXPORTED MlirLogicalResult
+mlirInferShapedTypeOpInterfaceInferReturnTypes(
+    MlirStringRef opName, MlirContext context, MlirLocation location,
+    intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
+    intptr_t nRegions, MlirRegion *regions,
+    MlirShapedTypeComponentsCallback callback, void *userData);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//

-#include <utility>
 #include <optional>
+#include <utility>

 #include "IRModule.h"
 #include "mlir-c/BuiltinAttributes.h"
@@ -35,6 +35,76 @@
     R"(Given the arguments required to build an operation, attempts to infer
 its return types. Raises ValueError on failure.)";

+constexpr static const char *inferReturnTypeComponentsDoc =
+    R"(Given the arguments required to build an operation, attempts to infer
+its return shaped type components. Raises ValueError on failure.)";
+
+namespace {
+
+/// Converts `operandList` into MlirValues. Each element must be a Value or a
+/// sequence of Values (flattened in order); otherwise throws py::value_error.
+llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) {
+  llvm::SmallVector<MlirValue> mlirOperands;
+
+  if (operandList && !operandList->empty()) {
+    // Note: as the list may contain other lists this may not be final size.
+    mlirOperands.reserve(operandList->size());
+    for (const auto &it : llvm::enumerate(*operandList)) {
+      auto invalidOperand = [&](const py::cast_error &err) {
+        return py::value_error(
+            (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+             " must be a Value or Sequence of Values (" + err.what() + ")")
+                .str());
+      };
+      PyValue *val;
+      try {
+        val = py::cast<PyValue *>(it.value());
+        if (!val)
+          throw py::cast_error();
+        mlirOperands.push_back(val->get());
+        continue;
+      } catch (py::cast_error &) {
+        // Intentionally unhandled to try sequence below first.
+      }
+
+      try {
+        auto vals = py::cast<py::sequence>(it.value());
+        for (py::object v : vals) {
+          try {
+            val = py::cast<PyValue *>(v);
+            if (!val)
+              throw py::cast_error();
+            mlirOperands.push_back(val->get());
+          } catch (py::cast_error &err) {
+            throw invalidOperand(err);
+          }
+        }
+      } catch (py::cast_error &err) {
+        throw invalidOperand(err);
+      }
+    }
+  }
+
+  return mlirOperands;
+}
+
+/// Converts an optional list of regions into MlirRegions.
+llvm::SmallVector<MlirRegion>
+wrapRegions(std::optional<std::vector<PyRegion>> regions) {
+  llvm::SmallVector<MlirRegion> mlirRegions;
+
+  if (regions) {
+    mlirRegions.reserve(regions->size());
+    for (PyRegion &region : *regions) {
+      mlirRegions.push_back(region);
+    }
+  }
+
+  return mlirRegions;
+}
+
+} // namespace
+
 /// CRTP base class for Python classes representing MLIR Op interfaces.
 /// Interface hierarchies are flat so no base class is expected here. The
 /// derived class is expected to define the following static fields:
@@ -104,7 +174,7 @@

   /// Creates the Python bindings for this class in the given module.
   static void bind(py::module &m) {
-    py::class_<ConcreteIface> cls(m, "InferTypeOpInterface",
+    py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName,
                                   py::module_local());
     cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
             py::arg("context") = py::none(), constructorDoc)
@@ -155,7 +225,7 @@
   py::object obj;
 };

-/// Python wrapper for InterTypeOpInterface. This interface has only static
+/// Python wrapper for InferTypeOpInterface. This interface has only static
 /// methods.
 class PyInferTypeOpInterface
     : public PyConcreteOpInterface<PyInferTypeOpInterface> {
@@ -191,59 +261,8 @@
                    std::optional<std::vector<PyRegion>> regions,
                    DefaultingPyMlirContext context,
                    DefaultingPyLocation location) {
-    llvm::SmallVector<MlirValue> mlirOperands;
-    llvm::SmallVector<MlirRegion> mlirRegions;
-
-    if (operandList && !operandList->empty()) {
-      // Note: as the list may contain other lists this may not be final size.
-      mlirOperands.reserve(operandList->size());
-      for (const auto& it : llvm::enumerate(*operandList)) {
-        PyValue* val;
-        try {
-          val = py::cast<PyValue *>(it.value());
-          if (!val)
-            throw py::cast_error();
-          mlirOperands.push_back(val->get());
-          continue;
-        } catch (py::cast_error &err) {
-          // Intentionally unhandled to try sequence below first.
-          (void)err;
-        }
-
-        try {
-          auto vals = py::cast<py::sequence>(it.value());
-          for (py::object v : vals) {
-            try {
-              val = py::cast<PyValue *>(v);
-              if (!val)
-                throw py::cast_error();
-              mlirOperands.push_back(val->get());
-            } catch (py::cast_error &err) {
-              throw py::value_error(
-                  (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
-                   " must be a Value or Sequence of Values (" + err.what() +
-                   ")")
-                      .str());
-            }
-          }
-          continue;
-        } catch (py::cast_error &err) {
-          throw py::value_error(
-              (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
-               " must be a Value or Sequence of Values (" + err.what() + ")")
-                  .str());
-        }
-
-        throw py::cast_error();
-      }
-    }
-
-    if (regions) {
-      mlirRegions.reserve(regions->size());
-      for (PyRegion &region : *regions) {
-        mlirRegions.push_back(region);
-      }
-    }
+    llvm::SmallVector<MlirValue> mlirOperands = wrapOperands(operandList);
+    llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(regions);

     std::vector<MlirType> inferredTypes;
     PyMlirContext &pyContext = context.resolve();
@@ -274,7 +293,177 @@
   }
 };

-void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); }
+/// Python-side value holding the components of a ShapedTypeComponents: a shape
+/// (only meaningful when ranked), an element type and an optional attribute.
+/// The shape is kept in a Python list owned by this object; the element type
+/// and attribute are stored as MLIR C API handles.
+class PyShapedTypeComponents {
+public:
+  explicit PyShapedTypeComponents(MlirType elementType)
+      : elementType(elementType) {}
+  PyShapedTypeComponents(py::list shape, MlirType elementType)
+      : shape(copyList(shape)), elementType(elementType), ranked(true) {}
+  PyShapedTypeComponents(py::list shape, MlirType elementType,
+                         MlirAttribute attribute)
+      : shape(copyList(shape)), elementType(elementType), attribute(attribute),
+        ranked(true) {}
+  PyShapedTypeComponents(const PyShapedTypeComponents &) = delete;
+  PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept
+      : shape(std::move(other.shape)), elementType(other.elementType),
+        attribute(other.attribute), ranked(other.ranked) {}
+
+  static void bind(py::module &m) {
+    py::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents",
+                                       py::module_local())
+        .def_property_readonly(
+            "element_type",
+            [](PyShapedTypeComponents &self) -> py::object {
+              // A null element type has no context to query; expose it as
+              // None instead of dereferencing it.
+              if (mlirTypeIsNull(self.elementType))
+                return py::none();
+              MlirContext ctx = mlirTypeGetContext(self.elementType);
+              return py::cast(
+                  PyType(PyMlirContext::forContext(ctx), self.elementType));
+            },
+            "Returns the element type of the shaped type components, or None "
+            "if it is not set.")
+        .def_static(
+            "get",
+            [](PyType &elementType) {
+              return PyShapedTypeComponents(elementType);
+            },
+            py::arg("element_type"),
+            "Create an unranked shaped type components object with only the "
+            "element type.")
+        .def_static(
+            "get",
+            [](py::list shape, PyType &elementType) {
+              return PyShapedTypeComponents(shape, elementType);
+            },
+            py::arg("shape"), py::arg("element_type"),
+            "Create a ranked shaped type components object.")
+        .def_static(
+            "get",
+            [](py::list shape, PyType &elementType, PyAttribute &attribute) {
+              return PyShapedTypeComponents(shape, elementType, attribute);
+            },
+            py::arg("shape"), py::arg("element_type"), py::arg("attribute"),
+            "Create a ranked shaped type components object with attribute.")
+        .def_property_readonly(
+            "has_rank",
+            [](PyShapedTypeComponents &self) -> bool { return self.ranked; },
+            "Returns whether the given shaped type component is ranked.")
+        .def_property_readonly(
+            "rank",
+            [](PyShapedTypeComponents &self) { return self.shape.size(); },
+            "Returns the rank of the given ranked shaped type components.")
+        .def_property_readonly(
+            "shape",
+            [](PyShapedTypeComponents &self) { return copyList(self.shape); },
+            "Returns the shape of the ranked shaped type components as a list "
+            "of integers.");
+  }
+
+private:
+  /// Returns a new list holding the same elements as `list`, so that mutating
+  /// either list afterwards does not affect the other.
+  static py::list copyList(const py::list &list) {
+    py::list result;
+    for (py::handle item : list)
+      result.append(item);
+    return result;
+  }
+
+  py::list shape;
+  MlirType elementType;
+  MlirAttribute attribute = mlirAttributeGetNull();
+  bool ranked{false};
+};
+
+/// Python wrapper for InferShapedTypeOpInterface. This interface has only
+/// static methods.
+class PyInferShapedTypeOpInterface
+    : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> {
+public:
+  using PyConcreteOpInterface<
+      PyInferShapedTypeOpInterface>::PyConcreteOpInterface;
+
+  constexpr static const char *pyClassName = "InferShapedTypeOpInterface";
+  constexpr static GetTypeIDFunctionTy getInterfaceID =
+      &mlirInferShapedTypeOpInterfaceTypeID;
+
+  /// C-style user-data structure for the shaped type components callback.
+  struct AppendResultsCallbackData {
+    std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents;
+  };
+
+  /// Appends a PyShapedTypeComponents built from the callback arguments to the
+  /// user-data structure (expects AppendResultsCallbackData). A null `shape`
+  /// denotes unranked components; otherwise it points to `rank` dimension
+  /// sizes.
+  static void appendResultsCallback(intptr_t rank, const int64_t *shape,
+                                    MlirType elementType,
+                                    MlirAttribute attribute, void *userData) {
+    auto *data = static_cast<AppendResultsCallbackData *>(userData);
+    if (!shape) {
+      // Keep unranked components unranked instead of reporting rank 0.
+      data->inferredShapedTypeComponents.emplace_back(elementType);
+      return;
+    }
+    py::list shapeList;
+    for (intptr_t i = 0; i < rank; ++i)
+      shapeList.append(shape[i]);
+    data->inferredShapedTypeComponents.emplace_back(shapeList, elementType,
+                                                    attribute);
+  }
+
+  /// Given the arguments required to build an operation, attempts to infer its
+  /// return shaped type components. Throws value_error on failure.
+  std::vector<PyShapedTypeComponents>
+  inferReturnTypeComponents(std::optional<py::list> operandList,
+                            std::optional<PyAttribute> attributes,
+                            std::optional<std::vector<PyRegion>> regions,
+                            DefaultingPyMlirContext context,
+                            DefaultingPyLocation location) {
+    llvm::SmallVector<MlirValue> mlirOperands = wrapOperands(operandList);
+    llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(regions);
+
+    std::vector<PyShapedTypeComponents> inferredShapedTypeComponents;
+    PyMlirContext &pyContext = context.resolve();
+    AppendResultsCallbackData data{inferredShapedTypeComponents};
+    MlirStringRef opNameRef =
+        mlirStringRefCreate(getOpName().data(), getOpName().length());
+    MlirAttribute attributeDict =
+        attributes ? attributes->get() : mlirAttributeGetNull();
+
+    MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes(
+        opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
+        mlirOperands.data(), attributeDict, mlirRegions.size(),
+        mlirRegions.data(), &appendResultsCallback, &data);
+
+    if (mlirLogicalResultIsFailure(result)) {
+      throw py::value_error("Failed to infer result shape type components");
+    }
+
+    return inferredShapedTypeComponents;
+  }
+
+  static void bindDerived(ClassTy &cls) {
+    cls.def("inferReturnTypeComponents",
+            &PyInferShapedTypeOpInterface::inferReturnTypeComponents,
+            py::arg("operands") = py::none(),
+            py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
+            py::arg("context") = py::none(), py::arg("loc") = py::none(),
+            inferReturnTypeComponentsDoc);
+  }
+};
+
+void populateIRInterfaces(py::module &m) {
+  PyInferTypeOpInterface::bind(m);
+  PyShapedTypeComponents::bind(m);
+  PyInferShapedTypeOpInterface::bind(m);
+}

 } // namespace python
 } // namespace mlir
diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
--- a/mlir/lib/CAPI/Interfaces/Interfaces.cpp
+++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
@@ -9,14 +9,65 @@

 #include "mlir-c/Interfaces.h"
 #include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Interfaces.h"
 #include "mlir/CAPI/Support.h"
 #include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "llvm/ADT/ScopeExit.h"
 #include <optional>

 using namespace mlir;

+namespace {
+
+std::optional<RegisteredOperationName>
+getRegisteredOperationName(MlirContext context, MlirStringRef opName) {
+  StringRef name(opName.data, opName.length);
+  std::optional<RegisteredOperationName> info =
+      RegisteredOperationName::lookup(name, unwrap(context));
+  return info;
+}
+
+std::optional<Location> maybeGetLocation(MlirLocation location) {
+  std::optional<Location> maybeLocation;
+  if (!mlirLocationIsNull(location))
+    maybeLocation = unwrap(location);
+  return maybeLocation;
+}
+
+SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) {
+  SmallVector<Value> unwrappedOperands;
+  (void)unwrapList(nOperands, operands, unwrappedOperands);
+  return unwrappedOperands;
+}
+
+DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
+  DictionaryAttr attributeDict;
+  if (!mlirAttributeIsNull(attributes))
+    attributeDict = unwrap(attributes).cast<DictionaryAttr>();
+  return attributeDict;
+}
+
+SmallVector<std::unique_ptr<Region>> unwrapRegions(intptr_t nRegions,
+                                                   MlirRegion *regions) {
+  // Create a vector of unique pointers to regions. This is a hack caused by
+  // the C++ API expecting a list of unique pointers to regions (without
+  // ownership transfer semantics) while the C API makes ownership transfer
+  // explicit. The returned pointers do NOT own the regions: the caller must
+  // release() every element before the vector is destroyed. Releasing them
+  // in this function (e.g. with a scope_exit) would be wrong: with NRVO it
+  // nulls out the returned pointers, and without NRVO it runs on a moved-from
+  // vector, leaving the caller to delete regions it does not own.
+  SmallVector<std::unique_ptr<Region>> unwrappedRegions;
+  unwrappedRegions.reserve(nRegions);
+  for (intptr_t i = 0; i < nRegions; ++i)
+    unwrappedRegions.emplace_back(unwrap(regions[i]));
+  return unwrappedRegions;
+}
+
+} // namespace
+
 bool mlirOperationImplementsInterface(MlirOperation operation,
                                       MlirTypeID interfaceTypeID) {
   std::optional<RegisteredOperationName> info =
@@ -41,33 +92,24 @@
     intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
     intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback,
     void *userData) {
-  StringRef name(opName.data, opName.length);
   std::optional<RegisteredOperationName> info =
-      RegisteredOperationName::lookup(name, unwrap(context));
+      getRegisteredOperationName(context, opName);
   if (!info)
     return mlirLogicalResultFailure();
+  if (!info->getInterface<InferTypeOpInterface>())
+    return mlirLogicalResultFailure();

-  std::optional<Location> maybeLocation;
-  if (!mlirLocationIsNull(location))
-    maybeLocation = unwrap(location);
-  SmallVector<Value> unwrappedOperands;
-  (void)unwrapList(nOperands, operands, unwrappedOperands);
-  DictionaryAttr attributeDict;
-  if (!mlirAttributeIsNull(attributes))
-    attributeDict = unwrap(attributes).cast<DictionaryAttr>();
-
-  // Create a vector of unique pointers to regions and make sure they are not
-  // deleted when exiting the scope. This is a hack caused by C++ API expecting
-  // an list of unique pointers to regions (without ownership transfer
-  // semantics) and C API making ownership transfer explicit.
-  SmallVector<std::unique_ptr<Region>> unwrappedRegions;
-  unwrappedRegions.reserve(nRegions);
-  for (intptr_t i = 0; i < nRegions; ++i)
-    unwrappedRegions.emplace_back(unwrap(*(regions + i)));
-  auto cleaner = llvm::make_scope_exit([&]() {
-    for (auto &region : unwrappedRegions)
-      region.release();
-  });
+  std::optional<Location> maybeLocation = maybeGetLocation(location);
+  SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
+  DictionaryAttr attributeDict = unwrapAttributes(attributes);
+  SmallVector<std::unique_ptr<Region>> unwrappedRegions =
+      unwrapRegions(nRegions, regions);
+  // The regions are owned by the caller: release them from the unique_ptrs on
+  // every exit path so that they are not deleted here.
+  auto releaseRegions = llvm::make_scope_exit([&]() {
+    for (auto &region : unwrappedRegions)
+      (void)region.release();
+  });

   SmallVector<Type> inferredTypes;
   if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes(
@@ -82,3 +124,57 @@
   callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
   return mlirLogicalResultSuccess();
 }
+
+MlirTypeID mlirInferShapedTypeOpInterfaceTypeID() {
+  return wrap(InferShapedTypeOpInterface::getInterfaceID());
+}
+
+MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(
+    MlirStringRef opName, MlirContext context, MlirLocation location,
+    intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
+    intptr_t nRegions, MlirRegion *regions,
+    MlirShapedTypeComponentsCallback callback, void *userData) {
+  std::optional<RegisteredOperationName> info =
+      getRegisteredOperationName(context, opName);
+  if (!info)
+    return mlirLogicalResultFailure();
+  auto *iface = info->getInterface<InferShapedTypeOpInterface>();
+  if (!iface)
+    return mlirLogicalResultFailure();
+
+  std::optional<Location> maybeLocation = maybeGetLocation(location);
+  SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
+  DictionaryAttr attributeDict = unwrapAttributes(attributes);
+  SmallVector<std::unique_ptr<Region>> unwrappedRegions =
+      unwrapRegions(nRegions, regions);
+  // The regions are owned by the caller: release them from the unique_ptrs on
+  // every exit path so that they are not deleted here.
+  auto releaseRegions = llvm::make_scope_exit([&]() {
+    for (auto &region : unwrappedRegions)
+      (void)region.release();
+  });
+
+  SmallVector<ShapedTypeComponents> inferredTypeComponents;
+  if (failed(iface->inferReturnTypeComponents(
+          unwrap(context), maybeLocation,
+          mlir::ValueRange(llvm::ArrayRef(unwrappedOperands)), attributeDict,
+          unwrappedRegions, inferredTypeComponents)))
+    return mlirLogicalResultFailure();
+
+  // Unranked components are reported with rank 0 and a null shape pointer.
+  // Ranked components always get a non-null shape pointer (even for rank 0)
+  // so that callers can tell the two cases apart.
+  static constexpr int64_t kNoDims = 0;
+  for (const ShapedTypeComponents &t : inferredTypeComponents) {
+    intptr_t rank = 0;
+    const int64_t *shapeData = nullptr;
+    if (t.hasRank()) {
+      ArrayRef<int64_t> dims = t.getDims();
+      rank = static_cast<intptr_t>(dims.size());
+      shapeData = dims.empty() ? &kNoDims : dims.data();
+    }
+    callback(rank, shapeData, wrap(t.getElementType()), wrap(t.getAttribute()),
+             userData);
+  }
+  return mlirLogicalResultSuccess();
+}
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -62,6 +62,7 @@
     "FloatAttr",
     "FunctionType",
     "IndexType",
+    "InferShapedTypeOpInterface",
     "InferTypeOpInterface",
     "InsertionPoint",
     "IntegerAttr",
@@ -88,6 +89,7 @@
     "RegionIterator",
     "RegionSequence",
     "ShapedType",
+    "ShapedTypeComponents",
     "StringAttr",
     "SymbolTable",
     "TupleType",
@@ -689,6 +691,14 @@
     @staticmethod
     def isinstance(arg: Any) -> bool: ...

+class InferShapedTypeOpInterface:
+    def __init__(self, object: object, context: Optional[Context] = None) -> None: ...
+    def inferReturnTypeComponents(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[ShapedTypeComponents]: ...
+    @property
+    def operation(self) -> Operation: ...
+    @property
+    def opview(self) -> OpView: ...
+
 class InferTypeOpInterface:
     def __init__(self, object: object, context: Optional[Context] = None) -> None: ...
     def inferReturnTypes(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ...
@@ -1016,6 +1026,18 @@
     @property
     def shape(self) -> List[int]: ...

+class ShapedTypeComponents:
+    @property
+    def element_type(self) -> Type: ...
+    @staticmethod
+    def get(*args, **kwargs) -> ShapedTypeComponents: ...
+    @property
+    def has_rank(self) -> bool: ...
+    @property
+    def rank(self) -> int: ...
+    @property
+    def shape(self) -> List[int]: ...
+
 # TODO: Auto-generated. Audit and fix.
 class StringAttr(Attribute):
     def __init__(self, cast_from_attr: Attribute) -> None: ...
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -1,6 +1,7 @@
 # RUN: %PYTHON %s | FileCheck %s

 from mlir.ir import *
+import mlir.dialects.func as func
 import mlir.dialects.python_test as test
 import mlir.dialects.tensor as tensor

@@ -330,3 +331,55 @@

     # CHECK: False
     print(tt.is_null())
+
+
+# CHECK-LABEL: TEST: inferReturnTypeComponents
+@run
+def inferReturnTypeComponents():
+    with Context() as ctx, Location.unknown(ctx):
+        test.register_python_test_dialect(ctx)
+        module = Module.create()
+        i32 = IntegerType.get_signless(32)
+        with InsertionPoint(module.body):
+            resultType = UnrankedTensorType.get(i32)
+            operandType = RankedTensorType.get([1, 3, 10, 10], i32)
+            f = func.FuncOp("test_inferReturnTypeComponents", ([operandType], [resultType]))
+            entry_block = Block.create_at_start(f.operation.regions[0], [operandType])
+
+            with InsertionPoint(entry_block):
+                op = test.InferShapedTypeComponentsOp(resultType, entry_block.arguments[0])
+
+                # CHECK: has rank: True
+                # CHECK: rank: 4
+                # CHECK: element type: i32
+                # CHECK: shape: [1, 3, 10, 10]
+                iface = InferShapedTypeOpInterface(op)
+                shaped_type_components = iface.inferReturnTypeComponents(operands=[op.operand])[0]
+                print("has rank:", shaped_type_components.has_rank)
+                print("rank:", shaped_type_components.rank)
+                print("element type:", shaped_type_components.element_type)
+                print("shape:", shaped_type_components.shape)
+
+
+# CHECK-LABEL: TEST: shapedTypeComponentsGet
+@run
+def shapedTypeComponentsGet():
+    with Context() as ctx, Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+
+        # Components built from an element type alone are unranked.
+        # CHECK: has rank: False
+        # CHECK: element type: i32
+        unranked = ShapedTypeComponents.get(i32)
+        print("has rank:", unranked.has_rank)
+        print("element type:", unranked.element_type)
+
+        # CHECK: has rank: True
+        # CHECK: rank: 2
+        # CHECK: element type: i32
+        # CHECK: shape: [2, 3]
+        ranked = ShapedTypeComponents.get([2, 3], i32)
+        print("has rank:", ranked.has_rank)
+        print("rank:", ranked.rank)
+        print("element type:", ranked.element_type)
+        print("shape:", ranked.shape)
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -89,6 +89,33 @@
   let results = (outs I32:$integer, F64:$flt, Index:$index);
 }

+def InferShapedTypeComponentsOp : TestOp<"infer_shaped_type_components_op",
+  [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                             ["inferReturnTypeComponents"]>]> {
+  let arguments = (ins AnyTensor:$operand);
+  let results = (outs AnyTensor:$result);
+
+  let extraClassDefinition = [{
+    ::mlir::LogicalResult $cppClass::inferReturnTypeComponents(
+      ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
+      ::mlir::ValueShapeRange operands, ::mlir::DictionaryAttr attributes,
+      ::mlir::RegionRange regions,
+      ::llvm::SmallVectorImpl<
+        ::mlir::ShapedTypeComponents>& inferredShapedTypeComponents) {
+      $cppClass::Adaptor adaptor(operands, attributes, regions);
+      auto operandType =
+          adaptor.getOperand().getType().cast<::mlir::ShapedType>();
+      if (operandType.hasRank()) {
+        inferredShapedTypeComponents.emplace_back(operandType.getShape(),
+            operandType.getElementType());
+      } else {
+        inferredShapedTypeComponents.emplace_back(operandType.getElementType());
+      }
+      return ::mlir::success();
+    }
+  }];
+}
+
 def SameOperandAndResultTypeOp : TestOp<"same_operand_and_result_type_op",
                                         [SameOperandsAndResultType]> {
   let arguments = (ins Variadic<AnyType>);