diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -136,7 +136,9 @@
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/StandardOps.td
-  SOURCES dialects/std.py
+  SOURCES
+    dialects/std.py
+    dialects/_std_ops_ext.py
   DIALECT_NAME std)
 
 declare_mlir_dialect_python_bindings(
diff --git a/mlir/python/mlir/dialects/_std_ops_ext.py b/mlir/python/mlir/dialects/_std_ops_ext.py
new file mode 100644
--- /dev/null
+++ b/mlir/python/mlir/dialects/_std_ops_ext.py
@@ -0,0 +1,87 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+  from ..ir import *
+  from .builtin import FuncOp
+  from ._ods_common import get_default_loc_context as _get_default_loc_context
+
+  from typing import Any, List, Optional, Union
+except ImportError as e:
+  raise RuntimeError("Error loading imports from extension module") from e
+
+
+def _isa(obj: Any, cls: type):
+  """Returns True if `cls(obj)` succeeds, False if it raises ValueError."""
+  try:
+    cls(obj)
+  except ValueError:
+    return False
+  return True
+
+
+def _is_any_of(obj: Any, classes: List[type]):
+  """Returns True if `obj` passes the `_isa` check for any of `classes`."""
+  return any(_isa(obj, cls) for cls in classes)
+
+
+def _is_integer_like_type(type: Type):
+  """Returns True if `type` converts to an IntegerType or an IndexType."""
+  return _is_any_of(type, [IntegerType, IndexType])
+
+
+def _is_float_type(type: Type):
+  """Returns True if `type` converts to BF16Type, F16Type, F32Type or F64Type."""
+  return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
+
+
+class ConstantOp:
+  """Specialization for the constant op class."""
+
+  def __init__(self,
+               result: Type,
+               value: Union[int, float, Attribute],
+               *,
+               loc=None,
+               ip=None):
+    """Builds the op, wrapping Python literals into attributes.
+
+    An `int` value is wrapped with `IntegerAttr.get(result, value)` and a
+    `float` value with `FloatAttr.get(result, value)`; any other value is
+    forwarded to the base constructor unchanged.
+    """
+    if isinstance(value, int):
+      super().__init__(result, IntegerAttr.get(result, value), loc=loc, ip=ip)
+    elif isinstance(value, float):
+      super().__init__(result, FloatAttr.get(result, value), loc=loc, ip=ip)
+    else:
+      super().__init__(result, value, loc=loc, ip=ip)
+
+  @classmethod
+  def create_index(cls, value: int, *, loc=None, ip=None):
+    """Create an index-typed constant."""
+    return cls(
+        IndexType.get(context=_get_default_loc_context(loc)),
+        value,
+        loc=loc,
+        ip=ip)
+
+  @property
+  def type(self):
+    """The type of the op's first result."""
+    return self.results[0].type
+
+  @property
+  def literal_value(self) -> Union[int, float]:
+    """The Python value of an integer-like or float constant.
+
+    Raises:
+      ValueError: if the result type is neither integer-like nor float.
+    """
+    if _is_integer_like_type(self.type):
+      return IntegerAttr(self.value).value
+    elif _is_float_type(self.type):
+      return FloatAttr(self.value).value
+    else:
+      raise ValueError("only integer and float constants have literal values")
diff --git a/mlir/test/python/dialects/std.py b/mlir/test/python/dialects/std.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/python/dialects/std.py
@@ -0,0 +1,65 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import std
+
+
+def constructAndPrintInModule(f):
+  """Runs `f` at a fresh module's insertion point, then prints the module."""
+  print("\nTEST:", f.__name__)
+  with Context(), Location.unknown():
+    module = Module.create()
+    with InsertionPoint(module.body):
+      f()
+    print(module)
+  return f
+
+# CHECK-LABEL: TEST: testConstantOp
+
+@constructAndPrintInModule
+def testConstantOp():
+  c1 = std.ConstantOp(IntegerType.get_signless(32), 42)
+  c2 = std.ConstantOp(IntegerType.get_signless(64), 100)
+  c3 = std.ConstantOp(F32Type.get(), 3.14)
+  c4 = std.ConstantOp(F64Type.get(), 1.23)
+  # CHECK: 42
+  print(c1.literal_value)
+
+  # CHECK: 100
+  print(c2.literal_value)
+
+  # CHECK: 3.140000104904175
+  print(c3.literal_value)
+
+  # CHECK: 1.23
+  print(c4.literal_value)
+
+# CHECK: = constant 42 : i32
+# CHECK: = constant 100 : i64
+# CHECK: = constant 3.140000e+00 : f32
+# CHECK: = constant 1.230000e+00 : f64
+
+# CHECK-LABEL: TEST: testVectorConstantOp
+@constructAndPrintInModule
+def testVectorConstantOp():
+  int_type = IntegerType.get_signless(32)
+  vec_type = VectorType.get([2, 2], int_type)
+  c1 = std.ConstantOp(vec_type,
+      DenseElementsAttr.get_splat(vec_type, IntegerAttr.get(int_type, 42)))
+  try:
+    print(c1.literal_value)
+  except ValueError as e:
+    assert "only integer and float constants have literal values" in str(e)
+  else:
+    assert False
+
+# CHECK: = constant dense<42> : vector<2x2xi32>
+
+# CHECK-LABEL: TEST: testConstantIndexOp
+@constructAndPrintInModule
+def testConstantIndexOp():
+  c1 = std.ConstantOp.create_index(10)
+  # CHECK: 10
+  print(c1.literal_value)
+
+# CHECK: = constant 10 : index