diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -2,13 +2,19 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Sequence
+from typing import Any, Sequence
 
 import os
 
 _this_dir = os.path.dirname(__file__)
 
 
+# These submodules have no type stubs and are thus opaque to the type checker.
+_mlirConversions: Any
+_mlirTransforms: Any
+_mlirAllPassesRegistration: Any
+
+
 def get_lib_dirs() -> Sequence[str]:
   """Gets the lib directory for linking to shared libraries.
 
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -7,7 +7,10 @@
 # * Local edits to signatures and types that MyPy did not auto detect (or
 # detected incorrectly).
 
-from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple
+from typing import (
+    Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple,
+    Type as _Type, TypeVar
+)
 from typing import overload
 
 
@@ -121,6 +124,8 @@
     @property
     def results(self) -> OpResultList: ...
 
+_TOperation = TypeVar("_TOperation", bound=_OperationBase)
+
 # TODO: Auto-generated. Audit and fix.
 class AffineExpr:
     def __init__(self, *args, **kwargs) -> None: ...
@@ -379,7 +384,7 @@
     def isinstance(arg: Any) -> bool: ...
 
 class Block:
-    __hash__: ClassVar[None] = ...
+    __hash__: ClassVar[None] = ... # type: ignore
     def append(self, operation: _OperationBase) -> None: ...
     def create_after(self, *args: Type) -> Block: ...
     @staticmethod
@@ -406,7 +411,7 @@
     @property
     def arg_number(self) -> int: ...
     @property
-    def owner(self) -> Block: ...
+    def owner(self) -> Block: ... # type: ignore[override]
 
 class BlockArgumentList:
     def __add__(self, arg0: BlockArgumentList) -> List[BlockArgument]: ...
@@ -463,7 +468,7 @@
     def _get_live_operation_count(self) -> int: ...
     def attach_diagnostic_handler(self, callback: Callable[[Diagnostic], bool]) -> DiagnosticHandler: ...
     def enable_multithreading(self, enable: bool) -> None: ...
-    def get_dialect_descriptor(dialect_name: str) -> DialectDescriptor: ...
+    def get_dialect_descriptor(self, dialect_name: str) -> DialectDescriptor: ...
     def is_registered_operation(self, operation_name: str) -> bool: ...
     def __enter__(self) -> Context: ...
     def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ...
@@ -748,7 +753,7 @@
 
 class Location:
     current: ClassVar[Location] = ... # read-only
-    __hash__: ClassVar[None] = ...
+    __hash__: ClassVar[None] = ... # type: ignore
     def _CAPICreate(self) -> Location: ...
     @staticmethod
     def callsite(callee: Location, frames: Sequence[Location], context: Optional[Context] = None) -> Location: ...
@@ -787,6 +792,7 @@
 
 class Module:
     def _CAPICreate(self) -> object: ...
+    @staticmethod
     def create(loc: Optional[Location] = None) -> Module: ...
     def dump(self) -> None: ...
     @staticmethod
@@ -858,17 +864,19 @@
     _ODS_RESULT_SEGMENTS: ClassVar[None] = ...
     def __init__(self, operation: _OperationBase) -> None: ...
     @classmethod
-    def build_generic(cls, results: Optional[Sequence[Type]] = None,
+    def build_generic(
+        cls: _Type[_TOperation],
+        results: Optional[Sequence[Type]] = None,
                       operands: Optional[Sequence[Value]] = None,
                       attributes: Optional[Dict[str, Attribute]] = None,
                       successors: Optional[Sequence[Block]] = None,
                       regions: Optional[int] = None,
                       loc: Optional[Location] = None,
-                      ip: Optional[InsertionPoint] = None) -> _OperationBase: ...
+                      ip: Optional[InsertionPoint] = None) -> _TOperation: ...
     @property
     def context(self) -> Context: ...
     @property
-    def operation(self) -> _OperationBase: ...
+    def operation(self) -> Operation: ...
 
 class Operation(_OperationBase):
     def _CAPICreate(self) -> object: ...
@@ -912,7 +920,7 @@
     def encoding(self) -> Optional[Attribute]: ...
 
 class Region:
-    __hash__: ClassVar[None] = ...
+    __hash__: ClassVar[None] = ... # type: ignore
     @overload
     def __eq__(self, arg0: Region) -> bool: ...
     @overload