diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -61,7 +61,11 @@
 // Define the function attribute enums matching the OpDSL functions.
 def UnaryFn : I32EnumAttr<"UnaryFn", "", [
   I32EnumAttrCase<"exp", 0>,
-  I32EnumAttrCase<"log", 1>
+  I32EnumAttrCase<"log", 1>,
+  I32EnumAttrCase<"abs", 2>,
+  I32EnumAttrCase<"ceil", 3>,
+  I32EnumAttrCase<"floor", 4>,
+  I32EnumAttrCase<"negf", 5>
 ]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::linalg";
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -144,6 +144,14 @@
     return builder.create<math::ExpOp>(arg.getLoc(), arg);
   case UnaryFn::log:
     return builder.create<math::LogOp>(arg.getLoc(), arg);
+  case UnaryFn::abs:
+    return builder.create<math::AbsOp>(arg.getLoc(), arg);
+  case UnaryFn::ceil:
+    return builder.create<math::CeilOp>(arg.getLoc(), arg);
+  case UnaryFn::floor:
+    return builder.create<math::FloorOp>(arg.getLoc(), arg);
+  case UnaryFn::negf:
+    return builder.create<arith::NegFOp>(arg.getLoc(), arg);
   }
   llvm_unreachable("unsupported unary function");
 }
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -274,6 +274,10 @@
   """Unary function namespace."""
   exp = UnaryFnType("exp")
   log = UnaryFnType("log")
+  abs = UnaryFnType("abs")
+  ceil = UnaryFnType("ceil")
+  floor = UnaryFnType("floor")
+  negf = UnaryFnType("negf")


 class BinaryFnType:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -390,6 +390,26 @@
       return math.LogOp(x).result
     raise NotImplementedError("Unsupported 'log' operand: {x}")

+  def _unary_abs(self, x: Value) -> Value:
+    if _is_floating_point_type(x.type):
+      return math.AbsOp(x).result
+    raise NotImplementedError(f"Unsupported 'abs' operand: {x}")
+
+  def _unary_ceil(self, x: Value) -> Value:
+    if _is_floating_point_type(x.type):
+      return math.CeilOp(x).result
+    raise NotImplementedError(f"Unsupported 'ceil' operand: {x}")
+
+  def _unary_floor(self, x: Value) -> Value:
+    if _is_floating_point_type(x.type):
+      return math.FloorOp(x).result
+    raise NotImplementedError(f"Unsupported 'floor' operand: {x}")
+
+  def _unary_negf(self, x: Value) -> Value:
+    if _is_floating_point_type(x.type):
+      return arith.NegFOp(x).result
+    raise NotImplementedError(f"Unsupported 'negf' operand: {x}")
+
   def _binary_add(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
       return arith.AddFOp(lhs, rhs).result
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -298,6 +298,54 @@

 // -----

+// Verifies the fun attribute controls the unary function used.
+func @generalize_elemwise_abs(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<abs>}
+                             ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_abs
+// CHECK: = math.abs
+
+// -----
+
+// Verifies the fun attribute controls the unary function used.
+func @generalize_elemwise_ceil(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<ceil>}
+                             ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_ceil
+// CHECK: = math.ceil
+
+// -----
+
+// Verifies the fun attribute controls the unary function used.
+func @generalize_elemwise_floor(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<floor>}
+                             ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_floor
+// CHECK: = math.floor
+
+// -----
+
+// Verifies the fun attribute controls the unary function used.
+func @generalize_elemwise_negf(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<negf>}
+                             ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_negf
+// CHECK: = arith.negf
+
+// -----
+
 // Verifies the default value of the fun attribute is an add op.
 func @generalize_elemwise_add(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
   %0 = linalg.elemwise_binary ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
--- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
@@ -11,7 +11,7 @@
 # fill, matmul, convolution, or pooling tests. The features include:
 # - constant defined in the body
 # - fix/predefined types
-# - exponential functions
+# - some math/arith functions, including abs, ceil, exp, floor, log, and negf
 # - custom op names.
@@ -89,6 +89,46 @@
   def test_f32_elemwise_log(input, init_result):
     return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)

+  # CHECK-LABEL: @test_f32_elemwise_abs
+  # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+  # CHECK-NEXT: %[[ABS:.+]] = math.abs %[[IN]] : f32
+  # CHECK-NEXT: linalg.yield %[[ABS]] : f32
+  # CHECK-NEXT: -> tensor<4x16xf32>
+  @builtin.FuncOp.from_py_func(
+      RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+  def test_f32_elemwise_abs(input, init_result):
+    return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
+
+  # CHECK-LABEL: @test_f32_elemwise_ceil
+  # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+  # CHECK-NEXT: %[[CEIL:.+]] = math.ceil %[[IN]] : f32
+  # CHECK-NEXT: linalg.yield %[[CEIL]] : f32
+  # CHECK-NEXT: -> tensor<4x16xf32>
+  @builtin.FuncOp.from_py_func(
+      RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+  def test_f32_elemwise_ceil(input, init_result):
+    return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.ceil)
+
+  # CHECK-LABEL: @test_f32_elemwise_floor
+  # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+  # CHECK-NEXT: %[[FLOOR:.+]] = math.floor %[[IN]] : f32
+  # CHECK-NEXT: linalg.yield %[[FLOOR]] : f32
+  # CHECK-NEXT: -> tensor<4x16xf32>
+  @builtin.FuncOp.from_py_func(
+      RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+  def test_f32_elemwise_floor(input, init_result):
+    return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.floor)
+
+  # CHECK-LABEL: @test_f32_elemwise_negf
+  # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+  # CHECK-NEXT: %[[NEG:.+]] = arith.negf %[[IN]] : f32
+  # CHECK-NEXT: linalg.yield %[[NEG]] : f32
+  # CHECK-NEXT: -> tensor<4x16xf32>
+  @builtin.FuncOp.from_py_func(
+      RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+  def test_f32_elemwise_negf(input, init_result):
+    return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
+
   # Just check that we don't assert out on name mismatch.
   # CHECK-LABEL: @test_non_default_op_name
   @builtin.FuncOp.from_py_func(
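Usage note (a sketch, not part of the patch): the `elemwise_unary_poly` op exercised by the Python tests above is an OpDSL definition along the following lines. The attribute-definition helpers (`UnaryFnAttrDef`, `TypeFnAttrDef`) and the defaults shown are assumptions based on the surrounding OpDSL API; the test file's own definition is authoritative.

from mlir.dialects.linalg.opdsl.lang import *

@linalg_structured_op
def elemwise_unary_poly(
    I=TensorDef(T1),
    O=TensorDef(U, output=True),
    fun=UnaryFnAttrDef(default=UnaryFn.exp),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
  # Rank-polymorphic body: cast each element of I to the output element
  # type U, then apply the unary function selected by the `fun` attribute.
  O[None] = fun(cast(U, I[None]))

With the patch applied, each new function becomes selectable at emission time in the same way the tests do, for example:

  elemwise_unary_poly(input_tensor, outs=[init_tensor], fun=UnaryFn.floor)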