diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._python_test_ops_gen import * -from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestIntegerRankedTensorType +from .._mlir_libs._mlirPythonTest import * def register_python_test_dialect(context, load=True): diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -247,7 +247,6 @@ module = Module.create() with InsertionPoint(module.body): - op1 = test.OptionalOperandOp() # CHECK: op1.input is None: True print(f"op1.input is None: {op1.input is None}") @@ -487,3 +486,77 @@ two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero) # CHECK: f32 print(two_operands.result.type) + + +@run +def testTypeCasting(): + def try_print(x): + try: + test.print(x) + except Exception as ex: + print("error:", ex) + + # CHECK: bool: True + try_print(True) + + # CHECK: int: 42 + try_print(42) + + # CHECK: float: 4.25 + try_print(4.25) + + # CHECK: str: hello + try_print("hello") + + # CHECK: error: print(): incompatible function arguments. + try_print(object()) + + # CHECK: str: