diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -33,8 +33,8 @@
   }];
   let cppNamespace = "::mlir::linalg";
   let dependentDialects = [
-    "AffineDialect", "memref::MemRefDialect", "StandardOpsDialect",
-    "tensor::TensorDialect"
+    "AffineDialect", "math::MathDialect", "memref::MemRefDialect",
+    "StandardOpsDialect", "tensor::TensorDialect"
   ];
   let hasCanonicalizer = 1;
   let hasOperationAttrVerify = 1;
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -887,3 +887,58 @@
                       scalar_const: '2.3283063999999999E-10 : f64'
             - !ScalarExpression
               scalar_arg: min
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: soft_plus_2d
+  cpp_class_name: SoftPlus2DOp
+  doc: |-
+    Implements the soft plus operator.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    usage: InputOperand
+    type_var: T
+    shape_map: affine_map<()[s0, s1] -> (s0, s1)>
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1] -> (s0, s1)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+    - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+  iterator_types:
+  - parallel
+  - parallel
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: log
+        operands:
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: add
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_const: '1.000000e+00 : f64'
+            - !ScalarExpression
+              scalar_apply:
+                fn_name: exp
+                operands:
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: I
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_LINALG_LINALGTYPES_H_
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -20,6 +20,7 @@
   MLIRSideEffectInterfaces
   MLIRViewLikeInterface
   MLIRStandard
+  MLIRMath
   MLIRMemRef
   MLIRTensor
   )
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -256,6 +256,20 @@
     llvm_unreachable("unsupported non numeric type");
   }
 
+  Value applyfn__exp(Value x) {
+    OpBuilder builder = getBuilder();
+    if (isFloatingPoint(x))
+      return builder.create<math::ExpOp>(x.getLoc(), x);
+    llvm_unreachable("unsupported non numeric type");
+  }
+
+  Value applyfn__log(Value x) {
+    OpBuilder builder = getBuilder();
+    if (isFloatingPoint(x))
+      return builder.create<math::LogOp>(x.getLoc(), x);
+    llvm_unreachable("unsupported non numeric type");
+  }
+
   Value applyfn__sub(Value lhs, Value rhs) {
     OpBuilder builder = getBuilder();
     if (isFloatingPoint(lhs))
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -7,6 +7,7 @@
 from mlir.ir import *
 from mlir.dialects import linalg
 from mlir.dialects import std
+from mlir.dialects import math
 # TODO: resolve name collision for Linalg functionality that is injected inside
 # the _mlir.dialects.linalg directly via pybind.
 from _mlir.dialects.linalg import fill_builtin_region
@@ -293,6 +294,16 @@
       return std.AddIOp(lhs.type, lhs, rhs).result
     raise NotImplementedError("Unsupported 'add' operand: {lhs}")
 
+  def _eval_exp(self, x: Value) -> Value:
+    if _is_floating_point_type(x.type):
+      return math.ExpOp(x.type, x).result
+    raise NotImplementedError(f"Unsupported 'exp' operand: {x}")
+
+  def _eval_log(self, x: Value) -> Value:
+    if _is_floating_point_type(x.type):
+      return math.LogOp(x.type, x).result
+    raise NotImplementedError(f"Unsupported 'log' operand: {x}")
+
   def _eval_sub(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
       return std.SubFOp(lhs.type, lhs, rhs).result
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -209,3 +209,16 @@
   offset = cast(F64, const(2147483647))
   scaling = (max - min) * inv_range
   O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min)
+
+
+@linalg_structured_op
+def soft_plus_2d(
+    I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)):
+  """Implements the soft plus operator.
+
+  Numeric casting is performed on the input operand, promoting it to the same
+  data type as the accumulator/output.
+  """
+  domain(D.m, D.n)
+  O[D.m, D.n] = \
+      PrimFn.log(cast(U, const(1.0)) + PrimFn.exp(cast(U, I[D.m, D.n])))
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -188,6 +188,23 @@
 // CHECK-NEXT:   linalg.yield %[[VAL6]] : i32
 // CHECK-NEXT: -> tensor<16x32xi32>
 
+// -----
+
+func @generalize_soft_plus_2d_f32(%input: tensor<16x32xf32>, %output: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.soft_plus_2d ins(%input: tensor<16x32xf32>) outs(%output: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK-LABEL: @generalize_soft_plus_2d_f32
+//      CHECK: %[[C1:.+]] = constant 1.000000e+00 : f64
+//      CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32
+// CHECK-NEXT:   %[[C1_CAST:.+]] = fptrunc %[[C1]] : f64 to f32
+// CHECK-NEXT:   %[[EXP:.+]] = math.exp %[[IN]] : f32
+// CHECK-NEXT:   %[[SUM:.+]] = addf %[[C1_CAST]], %[[EXP]] : f32
+// CHECK-NEXT:   %[[LOG:.+]] = math.log %[[SUM]] : f32
+// CHECK-NEXT:   linalg.yield %[[LOG]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
+
 // -----
 // Verifies floating point to integer cast.
 func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -84,6 +84,13 @@
   O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min)
 
 
+@linalg_structured_op
+def soft_plus_poly(
+    I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)):
+  O[D.m, D.n] = \
+      PrimFn.log(cast(U, const(1.0)) + cast(U, PrimFn.exp(I[D.m, D.n])))
+
+
 with Context() as ctx, Location.unknown():
   module = Module.create()
   f16 = F16Type.get()
@@ -299,5 +306,19 @@
     def test_i32_fill_rng(min, max, seed, init_result):
       return fill_rng_poly(min, max, seed, outs=[init_result])
 
+    # CHECK-LABEL: @test_f32_soft_plus
+    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+    # CHECK-NEXT:   %[[C1:.+]] = constant 1.000000e+00 : f64
+    # CHECK-NEXT:   %[[C1_CAST:.+]] = fptrunc %[[C1]] : f64 to f32
+    # CHECK-NEXT:   %[[EXP:.+]] = math.exp %[[IN]] : f32
+    # CHECK-NEXT:   %[[SUM:.+]] = addf %[[C1_CAST]], %[[EXP]] : f32
+    # CHECK-NEXT:   %[[LOG:.+]] = math.log %[[SUM]] : f32
+    # CHECK-NEXT:   linalg.yield %[[LOG]] : f32
+    # CHECK-NEXT: -> tensor<4x16xf32>
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+    def test_f32_soft_plus(input, init_result):
+      return soft_plus_poly(input, outs=[init_result])
+
 
 print(module)