diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -10,6 +10,7 @@ from .... import linalg from .... import math from .... import arith +from .... import complex from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values from .scalar_expr import * @@ -408,6 +409,8 @@ def _unary_negf(self, x: Value) -> Value: if _is_floating_point_type(x.type): return arith.NegFOp(x).result + if _is_complex_type(x.type): + return complex.NegOp(x).result raise NotImplementedError("Unsupported 'negf' operand: {x}") def _binary_add(self, lhs: Value, rhs: Value) -> Value: @@ -415,6 +418,8 @@ return arith.AddFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.AddIOp(lhs, rhs).result + if _is_complex_type(lhs.type): + return complex.AddOp(lhs, rhs).result raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}") def _binary_sub(self, lhs: Value, rhs: Value) -> Value: @@ -422,6 +427,8 @@ return arith.SubFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.SubIOp(lhs, rhs).result + if _is_complex_type(lhs.type): + return complex.SubOp(lhs, rhs).result raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}") def _binary_mul(self, lhs: Value, rhs: Value) -> Value: @@ -429,6 +436,8 @@ return arith.MulFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MulIOp(lhs, rhs).result + if _is_complex_type(lhs.type): + return complex.MulOp(lhs, rhs).result raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}") def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value: @@ -512,6 +521,10 @@ block_arg_types.append(element_or_self_type) +def _is_complex_type(t: Type) -> bool: + 
return ComplexType.isinstance(t) + + def _is_floating_point_type(t: Type) -> bool: # TODO: Create a FloatType in the Python API and implement the switch # there. diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py @@ -44,6 +44,7 @@ with Context() as ctx, Location.unknown(): module = Module.create() f32 = F32Type.get() + c32 = ComplexType.get(f32) i32 = IntegerType.get_signless(32) with InsertionPoint(module.body): @@ -129,6 +130,16 @@ def test_f32_elemwise_neg(input, init_result): return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf) + # CHECK-LABEL: @test_c32_elemwise_neg + # CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>) + # CHECK-NEXT: %[[EXP:.+]] = complex.neg %[[IN]] : complex<f32> + # CHECK-NEXT: linalg.yield %[[EXP]] : complex<f32> + # CHECK-NEXT: -> tensor<4x16xcomplex<f32>> + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)) + def test_c32_elemwise_neg(input, init_result): + return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf) + # Just check that we don't assert out on name mismatch. # CHECK-LABEL: @test_non_default_op_name @func.FuncOp.from_py_func(