diff --git a/mlir/include/mlir-c/Interfaces.h b/mlir/include/mlir-c/Interfaces.h
--- a/mlir/include/mlir-c/Interfaces.h
+++ b/mlir/include/mlir-c/Interfaces.h
@@ -22,6 +22,29 @@
 extern "C" {
 #endif
 
+//===----------------------------------------------------------------------===//
+// Opaque type declarations.
+//
+// Types are exposed to C bindings as structs containing opaque pointers. They
+// are not supposed to be inspected from C. This allows the underlying
+// representation to change without affecting the API users. The use of structs
+// instead of typedefs enables some type safety as structs are not implicitly
+// convertible to each other.
+//
+// Instances of these types may or may not own the underlying object. The
+// ownership semantics is defined by how an instance of the type was obtained.
+//===----------------------------------------------------------------------===//
+
+#define DEFINE_C_API_STRUCT(name, storage)                                     \
+  struct name {                                                                \
+    storage *ptr;                                                              \
+  };                                                                           \
+  typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirShapedTypeComponents, void);
+
+#undef DEFINE_C_API_STRUCT
+
 /// Returns `true` if the given operation implements an interface identified by
 /// its TypeID.
 MLIR_CAPI_EXPORTED bool
@@ -60,6 +83,77 @@
     intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback,
     void *userData);
 
+//===----------------------------------------------------------------------===//
+// ShapedTypeComponents.
+//===----------------------------------------------------------------------===//
+
+/// Returns true if the given ShapedTypeComponents wraps a null pointer.
+static inline bool
+mlirShapedTypeComponentsIsNull(MlirShapedTypeComponents shapedTypeComponents) {
+  return !shapedTypeComponents.ptr;
+}
+
+/// Creates an unranked ShapedTypeComponents from an element type. The caller
+/// owns the result and must release it with mlirShapedTypeComponentsDestroy.
+MLIR_CAPI_EXPORTED MlirShapedTypeComponents
+mlirShapedTypeComponentsGet(MlirType elementType);
+
+/// Creates a ranked ShapedTypeComponents from `rank` dimension sizes pointed to
+/// by `shape` and an element type. The caller owns the result and must release
+/// it with mlirShapedTypeComponentsDestroy.
+MLIR_CAPI_EXPORTED MlirShapedTypeComponents mlirShapedTypeComponentsGetRanked(
+    intptr_t rank, const int64_t *shape, MlirType elementType);
+
+/// Returns whether the ShapedTypeComponents has a rank.
+MLIR_CAPI_EXPORTED bool
+mlirShapedTypeComponentsHasRank(MlirShapedTypeComponents shapedTypeComponents);
+
+/// Returns the rank of the ShapedTypeComponents. The result is only meaningful
+/// if mlirShapedTypeComponentsHasRank returns true.
+MLIR_CAPI_EXPORTED int64_t
+mlirShapedTypeComponentsGetRank(MlirShapedTypeComponents shapedTypeComponents);
+
+/// Returns the dim-th dimension of the given ShapedTypeComponents. `dim` must
+/// be in the range [0, rank).
+MLIR_CAPI_EXPORTED int64_t mlirShapedTypeComponentsGetDimSize(
+    MlirShapedTypeComponents shapedTypeComponents, intptr_t dim);
+
+/// Returns the element type component.
+MLIR_CAPI_EXPORTED MlirType mlirShapedTypeComponentsGetElementType(
+    MlirShapedTypeComponents shapedTypeComponents);
+
+/// Destroys a ShapedTypeComponents owned by the caller.
+MLIR_CAPI_EXPORTED void
+mlirShapedTypeComponentsDestroy(MlirShapedTypeComponents shapedTypeComponents);
+
+//===----------------------------------------------------------------------===//
+// InferShapedTypeOpInterface.
+//===----------------------------------------------------------------------===//
+
+/// Returns the interface TypeID of the InferShapedTypeOpInterface.
+MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID(void);
+
+/// These callbacks are used to return multiple ShapedTypeComponents from
+/// functions while transferring ownership to the caller. The first argument is
+/// the number of consecutive elements pointed to by the second argument. The
+/// third argument is an opaque pointer forwarded to the callback by the caller.
+typedef void (*MlirShapedTypeComponentsCallback)(intptr_t,
+                                                 MlirShapedTypeComponents *,
+                                                 void *);
+
+/// Infers the return shaped type components of the operation identified by its
+/// canonical name, given the arguments that will be supplied to its generic
+/// builder. Calls `callback` with the inferred components, potentially several
+/// times, on success; the callee receives ownership of every component and
+/// must eventually release it with mlirShapedTypeComponentsDestroy. Returns
+/// failure otherwise.
+MLIR_CAPI_EXPORTED MlirLogicalResult
+mlirInferShapedTypeOpInterfaceInferReturnTypes(
+    MlirStringRef opName, MlirContext context, MlirLocation location,
+    intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
+    intptr_t nRegions, MlirRegion *regions,
+    MlirShapedTypeComponentsCallback callback, void *userData);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/include/mlir/CAPI/Interfaces.h b/mlir/include/mlir/CAPI/Interfaces.h
--- a/mlir/include/mlir/CAPI/Interfaces.h
+++ b/mlir/include/mlir/CAPI/Interfaces.h
@@ -15,4 +15,10 @@
 #ifndef MLIR_CAPI_INTERFACES_H
 #define MLIR_CAPI_INTERFACES_H
 
+#include "mlir-c/Interfaces.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+
+DEFINE_C_API_PTR_METHODS(MlirShapedTypeComponents, mlir::ShapedTypeComponents)
+
 #endif // MLIR_CAPI_INTERFACES_H
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <utility>
 #include <optional>
+#include <utility>
 
 #include "IRModule.h"
 #include "mlir-c/BuiltinAttributes.h"
@@ -35,6 +35,10 @@
     R"(Given the arguments required to build an operation, attempts to infer
 its return types. Raises ValueError on failure.)";
 
+constexpr static const char *inferReturnTypeComponentsDoc =
+    R"(Given the arguments required to build an operation, attempts to infer
+its return shaped type components. Raises ValueError on failure.)";
+
 /// CRTP base class for Python classes representing MLIR Op interfaces.
 /// Interface hierarchies are flat so no base class is expected here. The
 /// derived class is expected to define the following static fields:
@@ -104,7 +108,7 @@
 
   /// Creates the Python bindings for this class in the given module.
   static void bind(py::module &m) {
-    py::class_<ConcreteIface> cls(m, "InferTypeOpInterface",
+    py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName,
                                   py::module_local());
     cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
             py::arg("context") = py::none(), constructorDoc)
@@ -155,7 +159,7 @@
   py::object obj;
 };
 
-/// Python wrapper for InterTypeOpInterface. This interface has only static
+/// Python wrapper for InferTypeOpInterface. This interface has only static
 /// methods.
 class PyInferTypeOpInterface
     : public PyConcreteOpInterface<PyInferTypeOpInterface> {
@@ -197,8 +201,8 @@
     if (operandList && !operandList->empty()) {
       // Note: as the list may contain other lists this may not be final size.
       mlirOperands.reserve(operandList->size());
-      for (const auto& it : llvm::enumerate(*operandList)) {
-        PyValue* val;
+      for (const auto &it : llvm::enumerate(*operandList)) {
+        PyValue *val;
         try {
           val = py::cast<PyValue *>(it.value());
           if (!val)
@@ -274,7 +278,217 @@
   }
 };
 
-void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); }
+/// Wrapper around an MlirShapedTypeComponents.
+/// Upon construction, the Python wrapper takes ownership of the
+/// underlying MlirShapedTypeComponents.
+class PyShapedTypeComponents {
+public:
+  explicit PyShapedTypeComponents(MlirShapedTypeComponents shapedTypeComponents)
+      : shapedTypeComponents(shapedTypeComponents) {}
+  ~PyShapedTypeComponents() {
+    if (!mlirShapedTypeComponentsIsNull(shapedTypeComponents))
+      mlirShapedTypeComponentsDestroy(shapedTypeComponents);
+  }
+  // The wrapper is the sole owner of its handle: copies are forbidden and a
+  // move nulls out the source so only one destructor releases the handle.
+  PyShapedTypeComponents(const PyShapedTypeComponents &) = delete;
+  PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept
+      : shapedTypeComponents(other.shapedTypeComponents) {
+    other.shapedTypeComponents = {nullptr};
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents",
+                                       py::module_local())
+        .def_property_readonly(
+            "element_type",
+            [](PyShapedTypeComponents &self) {
+              MlirType t = mlirShapedTypeComponentsGetElementType(self);
+              return PyType(PyMlirContext::forContext(mlirTypeGetContext(t)),
+                            t);
+            },
+            "Returns the element type of the shaped type components.")
+        .def_static(
+            "get",
+            [](PyType &elementType) {
+              MlirShapedTypeComponents shapedTypeComponents =
+                  mlirShapedTypeComponentsGet(elementType);
+              return PyShapedTypeComponents(shapedTypeComponents);
+            },
+            py::arg("element_type"),
+            "Create an unranked shaped type components object")
+        .def_static(
+            "get",
+            [](std::vector<int64_t> shape, PyType &elementType) {
+              MlirShapedTypeComponents shapedTypeComponents =
+                  mlirShapedTypeComponentsGetRanked(shape.size(), shape.data(),
+                                                    elementType);
+              return PyShapedTypeComponents(shapedTypeComponents);
+            },
+            py::arg("shape"), py::arg("element_type"),
+            "Create a ranked shaped type components object")
+        .def_property_readonly(
+            "has_rank",
+            [](PyShapedTypeComponents &self) -> bool {
+              return mlirShapedTypeComponentsHasRank(self);
+            },
+            "Returns whether the given shaped type component is ranked.")
+        .def_property_readonly(
+            "rank",
+            [](PyShapedTypeComponents &self) -> py::object {
+              if (!mlirShapedTypeComponentsHasRank(self))
+                return py::none();
+              return py::int_(mlirShapedTypeComponentsGetRank(self));
+            },
+            "Returns the rank of the components, or None if unranked.")
+        .def_property_readonly(
+            "shape",
+            [](PyShapedTypeComponents &self) -> py::object {
+              if (!mlirShapedTypeComponentsHasRank(self))
+                return py::none();
+              int64_t rank = mlirShapedTypeComponentsGetRank(self);
+              py::list shape;
+              for (int64_t i = 0; i < rank; ++i)
+                shape.append(mlirShapedTypeComponentsGetDimSize(self, i));
+              return shape;
+            },
+            "Returns the shape of the ranked shaped type components as a list "
+            "of integers, or None if unranked.");
+  }
+
+  operator MlirShapedTypeComponents() const { return shapedTypeComponents; }
+  MlirShapedTypeComponents get() const { return shapedTypeComponents; }
+
+private:
+  MlirShapedTypeComponents shapedTypeComponents;
+};
+
+/// Python wrapper for InferShapedTypeOpInterface. This interface has only
+/// static methods.
+class PyInferShapedTypeOpInterface
+    : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> {
+public:
+  using PyConcreteOpInterface<
+      PyInferShapedTypeOpInterface>::PyConcreteOpInterface;
+
+  constexpr static const char *pyClassName = "InferShapedTypeOpInterface";
+  constexpr static GetTypeIDFunctionTy getInterfaceID =
+      &mlirInferShapedTypeOpInterfaceTypeID;
+
+  /// C-style user-data structure for type appending callback.
+  struct AppendResultsCallbackData {
+    std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents;
+  };
+
+  /// Takes ownership of the components provided as the first two arguments and
+  /// appends them to the user data (expects AppendResultsCallbackData).
+  static void
+  appendResultsCallback(intptr_t nTypeComponents,
+                        MlirShapedTypeComponents *shapedTypeComponents,
+                        void *userData) {
+    auto *data = static_cast<AppendResultsCallbackData *>(userData);
+    data->inferredShapedTypeComponents.reserve(
+        data->inferredShapedTypeComponents.size() + nTypeComponents);
+    for (intptr_t i = 0; i < nTypeComponents; ++i) {
+      data->inferredShapedTypeComponents.emplace_back(shapedTypeComponents[i]);
+    }
+  }
+
+  /// Given the arguments required to build an operation, attempts to infer its
+  /// return shaped type components. Throws value_error on failure.
+  std::vector<PyShapedTypeComponents>
+  inferReturnTypeComponents(std::optional<py::list> operandList,
+                            std::optional<PyAttribute> attributes,
+                            std::optional<std::vector<PyRegion>> regions,
+                            DefaultingPyMlirContext context,
+                            DefaultingPyLocation location) {
+    llvm::SmallVector<MlirValue> mlirOperands;
+    llvm::SmallVector<MlirRegion> mlirRegions;
+
+    if (operandList && !operandList->empty()) {
+      // Note: as the list may contain other lists this may not be final size.
+      mlirOperands.reserve(operandList->size());
+      for (const auto &it : llvm::enumerate(*operandList)) {
+        PyValue *val;
+        try {
+          val = py::cast<PyValue *>(it.value());
+          if (!val)
+            throw py::cast_error();
+          mlirOperands.push_back(val->get());
+          continue;
+        } catch (py::cast_error &err) {
+          // Intentionally unhandled to try sequence below first.
+          (void)err;
+        }
+
+        try {
+          auto vals = py::cast<py::sequence>(it.value());
+          for (py::object v : vals) {
+            try {
+              val = py::cast<PyValue *>(v);
+              if (!val)
+                throw py::cast_error();
+              mlirOperands.push_back(val->get());
+            } catch (py::cast_error &err) {
+              throw py::value_error(
+                  (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+                   " must be a Value or Sequence of Values (" + err.what() +
+                   ")")
+                      .str());
+            }
+          }
+          continue;
+        } catch (py::cast_error &err) {
+          throw py::value_error(
+              (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+               " must be a Value or Sequence of Values (" + err.what() + ")")
+                  .str());
+        }
+      }
+    }
+
+    if (regions) {
+      mlirRegions.reserve(regions->size());
+      for (PyRegion &region : *regions) {
+        mlirRegions.push_back(region);
+      }
+    }
+
+    std::vector<PyShapedTypeComponents> inferredShapedTypeComponents;
+    PyMlirContext &pyContext = context.resolve();
+    AppendResultsCallbackData data{inferredShapedTypeComponents};
+    MlirStringRef opNameRef =
+        mlirStringRefCreate(getOpName().data(), getOpName().length());
+    MlirAttribute attributeDict =
+        attributes ? attributes->get() : mlirAttributeGetNull();
+
+    MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes(
+        opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
+        mlirOperands.data(), attributeDict, mlirRegions.size(),
+        mlirRegions.data(), &appendResultsCallback, &data);
+
+    if (mlirLogicalResultIsFailure(result)) {
+      throw py::value_error("Failed to infer result types");
+    }
+
+    return inferredShapedTypeComponents;
+  }
+
+  static void bindDerived(ClassTy &cls) {
+    cls.def("inferReturnTypeComponents",
+            &PyInferShapedTypeOpInterface::inferReturnTypeComponents,
+            py::arg("operands") = py::none(),
+            py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
+            py::arg("context") = py::none(), py::arg("loc") = py::none(),
+            inferReturnTypeComponentsDoc);
+  }
+};
+
+void populateIRInterfaces(py::module &m) {
+  PyInferTypeOpInterface::bind(m);
+  PyShapedTypeComponents::bind(m);
+  PyInferShapedTypeOpInterface::bind(m);
+}
 
 } // namespace python
 } // namespace mlir
diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
--- a/mlir/lib/CAPI/Interfaces/Interfaces.cpp
+++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
@@ -9,8 +9,10 @@
 #include "mlir-c/Interfaces.h"
 
 #include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Interfaces.h"
 #include "mlir/CAPI/Support.h"
 #include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "llvm/ADT/ScopeExit.h"
 #include <optional>
@@ -82,3 +84,100 @@
   callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
   return mlirLogicalResultSuccess();
 }
+
+MlirShapedTypeComponents mlirShapedTypeComponentsGet(MlirType elementType) {
+  ShapedTypeComponents *shapedTypeComponents =
+      new ShapedTypeComponents(unwrap(elementType));
+  return wrap(shapedTypeComponents);
+}
+
+MlirShapedTypeComponents
+mlirShapedTypeComponentsGetRanked(intptr_t rank, const int64_t *shape,
+                                  MlirType elementType) {
+  ShapedTypeComponents *shapedTypeComponents = new ShapedTypeComponents(
+      llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType));
+  return wrap(shapedTypeComponents);
+}
+
+bool mlirShapedTypeComponentsHasRank(
+    MlirShapedTypeComponents shapedTypeComponents) {
+  return unwrap(shapedTypeComponents)->hasRank();
+}
+
+int64_t
+mlirShapedTypeComponentsGetRank(MlirShapedTypeComponents shapedTypeComponents) {
+  return unwrap(shapedTypeComponents)->getDims().size();
+}
+
+int64_t mlirShapedTypeComponentsGetDimSize(
+    MlirShapedTypeComponents shapedTypeComponents, intptr_t dim) {
+  return unwrap(shapedTypeComponents)->getDims()[static_cast<size_t>(dim)];
+}
+
+MlirType mlirShapedTypeComponentsGetElementType(
+    MlirShapedTypeComponents shapedTypeComponents) {
+  return wrap(unwrap(shapedTypeComponents)->getElementType());
+}
+
+void mlirShapedTypeComponentsDestroy(
+    MlirShapedTypeComponents shapedTypeComponents) {
+  delete unwrap(shapedTypeComponents);
+}
+
+MlirTypeID mlirInferShapedTypeOpInterfaceTypeID() {
+  return wrap(InferShapedTypeOpInterface::getInterfaceID());
+}
+
+MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(
+    MlirStringRef opName, MlirContext context, MlirLocation location,
+    intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
+    intptr_t nRegions, MlirRegion *regions,
+    MlirShapedTypeComponentsCallback callback, void *userData) {
+  StringRef name(opName.data, opName.length);
+  std::optional<RegisteredOperationName> info =
+      RegisteredOperationName::lookup(name, unwrap(context));
+  if (!info)
+    return mlirLogicalResultFailure();
+  // Fail gracefully instead of dereferencing null when the operation does not
+  // implement the interface.
+  auto *iface = info->getInterface<InferShapedTypeOpInterface>();
+  if (!iface)
+    return mlirLogicalResultFailure();
+
+  std::optional<Location> maybeLocation;
+  if (!mlirLocationIsNull(location))
+    maybeLocation = unwrap(location);
+  SmallVector<Value> unwrappedOperands;
+  (void)unwrapList(nOperands, operands, unwrappedOperands);
+  DictionaryAttr attributeDict;
+  if (!mlirAttributeIsNull(attributes))
+    attributeDict = unwrap(attributes).cast<DictionaryAttr>();
+
+  // Create a vector of unique pointers to regions and make sure they are not
+  // deleted when exiting the scope. This is a hack caused by C++ API expecting
+  // a list of unique pointers to regions (without ownership transfer
+  // semantics) and C API making ownership transfer explicit.
+  SmallVector<std::unique_ptr<Region>> unwrappedRegions;
+  unwrappedRegions.reserve(nRegions);
+  for (intptr_t i = 0; i < nRegions; ++i)
+    unwrappedRegions.emplace_back(unwrap(*(regions + i)));
+  auto cleaner = llvm::make_scope_exit([&]() {
+    for (auto &region : unwrappedRegions)
+      region.release();
+  });
+
+  SmallVector<ShapedTypeComponents> inferredTypeComponents;
+  if (failed(iface->inferReturnTypeComponents(
+          unwrap(context), maybeLocation,
+          mlir::ValueRange(llvm::ArrayRef(unwrappedOperands)), attributeDict,
+          unwrappedRegions, inferredTypeComponents)))
+    return mlirLogicalResultFailure();
+
+  SmallVector<MlirShapedTypeComponents> wrappedInferredTypeComponents;
+  wrappedInferredTypeComponents.reserve(inferredTypeComponents.size());
+  for (const ShapedTypeComponents &t : inferredTypeComponents)
+    wrappedInferredTypeComponents.push_back(wrap(new ShapedTypeComponents(t)));
+  callback(wrappedInferredTypeComponents.size(),
+           wrappedInferredTypeComponents.data(), userData);
+  return mlirLogicalResultSuccess();
+}
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -63,6 +63,7 @@
     "FunctionType",
     "IndexType",
     "InferTypeOpInterface",
+    "InferShapedTypeOpInterface",
     "InsertionPoint",
     "IntegerAttr",
     "IntegerSet",
@@ -88,6 +89,7 @@
     "RegionIterator",
     "RegionSequence",
     "ShapedType",
+    "ShapedTypeComponents",
     "StringAttr",
     "SymbolTable",
     "TupleType",
@@ -690,6 +692,14 @@
     def isinstance(arg: Any) -> bool: ...
 
+class InferShapedTypeOpInterface:
+    def __init__(self, object: object, context: Optional[Context] = None) -> None: ...
+    def inferReturnTypeComponents(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[ShapedTypeComponents]: ...
+    @property
+    def operation(self) -> Operation: ...
+    @property
+    def opview(self) -> OpView: ...
+
 class InferTypeOpInterface:
     def __init__(self, object: object, context: Optional[Context] = None) -> None: ...
     def inferReturnTypes(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ...
     @property
@@ -1016,6 +1026,18 @@
     @property
     def shape(self) -> List[int]: ...
 
+class ShapedTypeComponents:
+    @property
+    def element_type(self) -> Type: ...
+    @staticmethod
+    def get(*args, **kwargs) -> ShapedTypeComponents: ...
+    @property
+    def has_rank(self) -> bool: ...
+    @property
+    def rank(self) -> Optional[int]: ...
+    @property
+    def shape(self) -> Optional[List[int]]: ...
+
 # TODO: Auto-generated. Audit and fix.
 class StringAttr(Attribute):
     def __init__(self, cast_from_attr: Attribute) -> None: ...
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -1,6 +1,7 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 from mlir.ir import *
+import mlir.dialects.func as func
 import mlir.dialects.python_test as test
 import mlir.dialects.tensor as tensor
 
@@ -330,3 +331,73 @@
 
     # CHECK: False
     print(tt.is_null())
+
+
+# CHECK-LABEL: TEST: inferReturnTypeComponents
+@run
+def inferReturnTypeComponents():
+    with Context() as ctx, Location.unknown(ctx):
+        test.register_python_test_dialect(ctx)
+        module = Module.create()
+        i32 = IntegerType.get_signless(32)
+        with InsertionPoint(module.body):
+            resultType = UnrankedTensorType.get(i32)
+            operandTypes = [
+                RankedTensorType.get([1, 3, 10, 10], i32),
+                UnrankedTensorType.get(i32),
+            ]
+            f = func.FuncOp(
+                "test_inferReturnTypeComponents", (operandTypes, [resultType])
+            )
+            entry_block = Block.create_at_start(f.operation.regions[0], operandTypes)
+
+            with InsertionPoint(entry_block):
+                ranked_op = test.InferShapedTypeComponentsOp(
+                    resultType, entry_block.arguments[0]
+                )
+                unranked_op = test.InferShapedTypeComponentsOp(
+                    resultType, entry_block.arguments[1]
+                )
+
+        # CHECK: has rank: True
+        # CHECK: rank: 4
+        # CHECK: element type: i32
+        # CHECK: shape: [1, 3, 10, 10]
+        iface = InferShapedTypeOpInterface(ranked_op)
+        shaped_type_components = iface.inferReturnTypeComponents(
+            operands=[ranked_op.operand]
+        )[0]
+        print("has rank:", shaped_type_components.has_rank)
+        print("rank:", shaped_type_components.rank)
+        print("element type:", shaped_type_components.element_type)
+        print("shape:", shaped_type_components.shape)
+
+        # The unranked operand exercises the element-type-only branch of the
+        # test op's inferReturnTypeComponents.
+        # CHECK: has rank: False
+        # CHECK: element type: i32
+        iface = InferShapedTypeOpInterface(unranked_op)
+        shaped_type_components = iface.inferReturnTypeComponents(
+            operands=[unranked_op.operand]
+        )[0]
+        print("has rank:", shaped_type_components.has_rank)
+        print("element type:", shaped_type_components.element_type)
+
+
+# CHECK-LABEL: TEST: testShapedTypeComponentsGet
+@run
+def testShapedTypeComponentsGet():
+    with Context():
+        i32 = IntegerType.get_signless(32)
+        # CHECK: has rank: False
+        # CHECK: element type: i32
+        unranked = ShapedTypeComponents.get(i32)
+        print("has rank:", unranked.has_rank)
+        print("element type:", unranked.element_type)
+        # CHECK: has rank: True
+        # CHECK: rank: 2
+        # CHECK: shape: [2, 3]
+        ranked = ShapedTypeComponents.get([2, 3], i32)
+        print("has rank:", ranked.has_rank)
+        print("rank:", ranked.rank)
+        print("shape:", ranked.shape)
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -89,6 +89,33 @@
   let results = (outs I32:$integer, F64:$flt, Index:$index);
 }
 
+def InferShapedTypeComponentsOp : TestOp<"infer_shaped_type_components_op",
+  [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                             ["inferReturnTypeComponents"]>]> {
+  let arguments = (ins AnyTensor:$operand);
+  let results = (outs AnyTensor:$result);
+
+  let extraClassDefinition = [{
+    ::mlir::LogicalResult $cppClass::inferReturnTypeComponents(
+      ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
+      ::mlir::ValueShapeRange operands, ::mlir::DictionaryAttr attributes,
+      ::mlir::RegionRange regions,
+      ::llvm::SmallVectorImpl<
+        ::mlir::ShapedTypeComponents>& inferredShapedTypeComponents) {
+      $cppClass::Adaptor adaptor(operands, attributes, regions);
+      auto operandType =
+          adaptor.getOperand().getType().cast<::mlir::ShapedType>();
+      if (operandType.hasRank()) {
+        inferredShapedTypeComponents.emplace_back(operandType.getShape(),
+            operandType.getElementType());
+      } else {
+        inferredShapedTypeComponents.emplace_back(operandType.getElementType());
+      }
+      return ::mlir::success();
+    }
+  }];
+}
+
 def SameOperandAndResultTypeOp : TestOp<"same_operand_and_result_type_op",
                                         [SameOperandsAndResultType]> {
   let arguments = (ins Variadic<AnyType>);