Depends on D150839.
This diff uses `MlirTypeID` to register type casters (i.e., `[](PyType pyType) -> DerivedTy { return pyType; }`) for all concrete types (i.e., `PyConcrete<...>` subclasses). `type_caster<MlirType>::cast` then looks up the caster registered for a type's `MlirTypeID` and calls it. As a result, wherever a Python binding returns an `MlirType`, the value is automatically cast to the correct concrete type. For example:
```python
c0 = arith.ConstantOp(f32, 0.0)
# CHECK: F32Type(f32)
print(repr(c0.result.type))

unranked_tensor_type = UnrankedTensorType.get(f32)
unranked_tensor = tensor.FromElementsOp(unranked_tensor_type, [c0]).result
# CHECK: UnrankedTensorType
print(type(unranked_tensor.type).__name__)
# CHECK: UnrankedTensorType(tensor<*xf32>)
print(repr(unranked_tensor.type))
```
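To make the lookup-then-downcast flow concrete, here is a pure-Python model of it. It is only a sketch under stated assumptions, not the bindings' actual code: `Type`, `F32Type`, `register_type_caster`, and `cast_type` are hypothetical stand-ins for `PyType`, a `PyConcrete<...>` class, the registration step, and `type_caster<MlirType>::cast`, and type IDs are modeled as plain strings rather than `MlirTypeID` values.

```python
from typing import Callable, Dict


class Type:
    """Hypothetical generic type handle, standing in for PyType / MlirType."""

    def __init__(self, typeid: str, asm: str):
        self.typeid = typeid  # models MlirTypeID as a string (assumption)
        self.asm = asm

    def __repr__(self) -> str:
        return f"Type({self.asm})"


class F32Type(Type):
    """Hypothetical concrete subclass, standing in for a PyConcrete<...> class."""

    def __init__(self, generic: Type):
        super().__init__(generic.typeid, generic.asm)

    def __repr__(self) -> str:
        return f"F32Type({self.asm})"


Caster = Callable[[Type], Type]

# Registry mapping a type ID to the caster that downcasts a generic Type.
_TYPE_CASTERS: Dict[str, Caster] = {}


def register_type_caster(typeid: str, caster: Caster) -> None:
    """Record the caster for one concrete type (done once per class)."""
    _TYPE_CASTERS[typeid] = caster


def cast_type(generic: Type) -> Type:
    """Model of type_caster<MlirType>::cast: look up by type ID, else return as-is."""
    caster = _TYPE_CASTERS.get(generic.typeid)
    return caster(generic) if caster is not None else generic


register_type_caster("f32", F32Type)

print(repr(cast_type(Type("f32", "f32"))))  # F32Type(f32)
print(repr(cast_type(Type("index", "index"))))  # Type(index)
```

In this model, any accessor that returns a type (an op result's type, or a typed attribute's type) would route through `cast_type`, and unregistered type IDs fall through unchanged as a plain `Type`.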
This also applies directly to typed attributes, e.g., `attr.type`.
The diff also implements similar functionality for types defined with `mlir_type_subclass`, but slightly differently: because such types have no corresponding C++ class or struct, the user must provide the type caster, either in Python (similar to how `AttrBuilder` works) or in C++ as a `py::cpp_function`.