diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -272,8 +272,12 @@
 
 class UnaryFn:
   """Unary function namespace."""
+  abs = UnaryFnType("abs")
+  ceil = UnaryFnType("ceil")
   exp = UnaryFnType("exp")
+  floor = UnaryFnType("floor")
   log = UnaryFnType("log")
+  neg = UnaryFnType("neg")
 
 
 class BinaryFnType:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -380,16 +380,36 @@
   def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
     return self._cast(type_var_name, operand, True)
 
+  def _unary_abs(self, x: Value) -> Value:
+    if _is_floating_point_type(x.type):
+      return math.AbsOp(x).result
+    raise NotImplementedError(f"Unsupported 'abs' operand: {x}")
+
+  def _unary_ceil(self, x: Value) -> Value:
+    if _is_floating_point_type(x.type):
+      return math.CeilOp(x).result
+    raise NotImplementedError(f"Unsupported 'ceil' operand: {x}")
+
   def _unary_exp(self, x: Value) -> Value:
     if _is_floating_point_type(x.type):
       return math.ExpOp(x).result
     raise NotImplementedError("Unsupported 'exp' operand: {x}")
 
+  def _unary_floor(self, x: Value) -> Value:
+    if _is_floating_point_type(x.type):
+      return math.FloorOp(x).result
+    raise NotImplementedError(f"Unsupported 'floor' operand: {x}")
+
   def _unary_log(self, x: Value) -> Value:
     if _is_floating_point_type(x.type):
       return math.LogOp(x).result
     raise NotImplementedError("Unsupported 'log' operand: {x}")
 
+  def _unary_neg(self, x: Value) -> Value:
+    if _is_floating_point_type(x.type):
+      return arith.NegFOp(x).result
+    raise NotImplementedError(f"Unsupported 'neg' operand: {x}")
+
   def _binary_add(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
       return arith.AddFOp(lhs, rhs).result
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
--- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
@@ -11,7 +11,7 @@
 # fill, matmul, convolution, or pooling tests. The features include:
 # - constant defined in the body
 # - fix/predefined types
-# - exponential functions
+# - some math/arith functions, including abs, ceil, exp, floor, log, and neg
 # - custom op names.
 
 
@@ -69,6 +69,26 @@
   def test_i32_index(init_result):
     return test_index(outs=[init_result])
 
+  # CHECK-LABEL: @test_f32_elemwise_abs
+  # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+  # CHECK-NEXT: %[[ABS:.+]] = math.abs %[[IN]] : f32
+  # CHECK-NEXT: linalg.yield %[[ABS]] : f32
+  # CHECK-NEXT: -> tensor<4x16xf32>
+  @builtin.FuncOp.from_py_func(
+      RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+  def test_f32_elemwise_abs(input, init_result):
+    return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
+
+  # CHECK-LABEL: @test_f32_elemwise_ceil
+  # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+  # CHECK-NEXT: %[[CEIL:.+]] = math.ceil %[[IN]] : f32
+  # CHECK-NEXT: linalg.yield %[[CEIL]] : f32
+  # CHECK-NEXT: -> tensor<4x16xf32>
+  @builtin.FuncOp.from_py_func(
+      RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+  def test_f32_elemwise_ceil(input, init_result):
+    return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.ceil)
+
   # CHECK-LABEL: @test_f32_elemwise_exp
   # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
   # CHECK-NEXT: %[[EXP:.+]] = math.exp %[[IN]] : f32
@@ -79,6 +99,16 @@
   def test_f32_elemwise_exp(input, init_result):
     return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
 
+  # CHECK-LABEL: @test_f32_elemwise_floor
+  # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+  # CHECK-NEXT: %[[FLOOR:.+]] = math.floor %[[IN]] : f32
+  # CHECK-NEXT: linalg.yield %[[FLOOR]] : f32
+  # CHECK-NEXT: -> tensor<4x16xf32>
+  @builtin.FuncOp.from_py_func(
+      RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+  def test_f32_elemwise_floor(input, init_result):
+    return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.floor)
+
   # CHECK-LABEL: @test_f32_elemwise_log
   # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
   # CHECK-NEXT: %[[LOG:.+]] = math.log %[[IN]] : f32
@@ -89,6 +119,16 @@
   def test_f32_elemwise_log(input, init_result):
     return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
 
+  # CHECK-LABEL: @test_f32_elemwise_neg
+  # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+  # CHECK-NEXT: %[[NEG:.+]] = arith.negf %[[IN]] : f32
+  # CHECK-NEXT: linalg.yield %[[NEG]] : f32
+  # CHECK-NEXT: -> tensor<4x16xf32>
+  @builtin.FuncOp.from_py_func(
+      RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+  def test_f32_elemwise_neg(input, init_result):
+    return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.neg)
+
   # Just check that we don't assert out on name mismatch.
   # CHECK-LABEL: @test_non_default_op_name
   @builtin.FuncOp.from_py_func(
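
Note: the tests above call elemwise_unary_poly, which is defined earlier in emit_misc.py, outside the hunks shown here. A minimal sketch of what such an OpDSL definition looks like follows, assuming the usual opdsl prelude (linalg_structured_op, TensorDef, UnaryFnAttrDef) from the lang package; the parameter names, type variable, and default below are illustrative, not the exact upstream definition, which may additionally take a TypeFn cast attribute.

# Illustrative sketch only; not the verbatim elemwise_unary_poly from emit_misc.py.
from mlir.dialects.linalg.opdsl.lang import *


@linalg_structured_op
def elemwise_unary_poly(
    I=TensorDef(TV.T1),
    O=TensorDef(TV.T1, output=True),
    fun=UnaryFnAttrDef(default=UnaryFn.exp)):
  # Apply the unary function attribute elementwise over the identity-mapped
  # iteration space; `fun` resolves to one of the UnaryFn members at emission.
  O[None] = fun(I[None])

Each test then instantiates the op on 4x16 f32 tensors and selects the body function via the fun attribute, e.g. elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.neg), which is what drives the emitter's new _unary_* handlers and the math.abs/math.ceil/math.floor/arith.negf ops checked above.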