diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -5,6 +5,7 @@ add_subdirectory(AVX512) add_subdirectory(Complex) add_subdirectory(GPU) +add_subdirectory(Math) add_subdirectory(Linalg) add_subdirectory(LLVMIR) add_subdirectory(OpenACC) diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h b/mlir/include/mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h --- a/mlir/include/mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h @@ -10,6 +10,7 @@ #include "mlir/Dialect/Linalg/EDSC/Builders.h" #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/FoldUtils.h" @@ -33,6 +34,7 @@ Value value; }; +using folded_math_tanh = FoldedValueBuilder; using folded_std_constant_index = FoldedValueBuilder; using folded_std_constant_float = FoldedValueBuilder; using folded_std_constant_int = FoldedValueBuilder; @@ -55,7 +57,6 @@ using folded_std_load = FoldedValueBuilder; using folded_std_subi = FoldedValueBuilder; using folded_std_sub_view = FoldedValueBuilder; -using folded_std_tanh = FoldedValueBuilder; using folded_std_tensor_load = FoldedValueBuilder; using folded_std_view = FoldedValueBuilder; using folded_std_zero_extendi = FoldedValueBuilder; diff --git a/mlir/include/mlir/Dialect/Math/CMakeLists.txt b/mlir/include/mlir/Dialect/Math/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Math/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/include/mlir/Dialect/Math/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Math/EDSC/Intrinsics.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Math/EDSC/Intrinsics.h @@ -0,0 +1,25 @@ +//===- Intrinsics.h - MLIR EDSC Intrinsics for Math ops ---------*- C++ -*-===// +// +// 
Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_MATH_EDSC_INTRINSICS_H_ +#define MLIR_DIALECT_MATH_EDSC_INTRINSICS_H_ + +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/EDSC/Builders.h" + +namespace mlir { +namespace edsc { +namespace intrinsics { + +using math_rsqrt = ValueBuilder; +using math_tanh = ValueBuilder; + +} // namespace intrinsics +} // namespace edsc +} // namespace mlir + +#endif // MLIR_DIALECT_MATH_EDSC_INTRINSICS_H_ diff --git a/mlir/include/mlir/Dialect/Math/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Math/IR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Math/IR/CMakeLists.txt @@ -0,0 +1,2 @@ +add_mlir_dialect(MathOps math) +add_mlir_doc(MathOps -gen-dialect-doc MathOps Dialects/) diff --git a/mlir/include/mlir/Dialect/Math/IR/Math.h b/mlir/include/mlir/Dialect/Math/IR/Math.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Math/IR/Math.h @@ -0,0 +1,32 @@ +//===- Math.h - Math dialect --------------------------------------*- C++-*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_MATH_IR_MATH_H_ +#define MLIR_DIALECT_MATH_IR_MATH_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/VectorInterfaces.h" + +//===----------------------------------------------------------------------===// +// Math Dialect +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Math/IR/MathOpsDialect.h.inc" + +//===----------------------------------------------------------------------===// +// Math Dialect Operations +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/Math/IR/MathOps.h.inc" + +#endif // MLIR_DIALECT_MATH_IR_MATH_H_ diff --git a/mlir/include/mlir/Dialect/Math/IR/MathBase.td b/mlir/include/mlir/Dialect/Math/IR/MathBase.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Math/IR/MathBase.td @@ -0,0 +1,19 @@ +//===- MathBase.td - Base definitions for math dialect ------*- tablegen -*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MATH_BASE +#define MATH_BASE +include "mlir/IR/OpBase.td" +def Math_Dialect : Dialect { + let name = "math"; + let cppNamespace = "::mlir::math"; + let description = [{ + The math dialect is intended to hold mathematical operations on integer and + floating type beyond simple arithmetics. 
+ }]; +} +#endif // MATH_BASE diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -0,0 +1,430 @@ +//===- MathOps.td - Math op definitions --------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MATH_OPS +#define MATH_OPS + +include "mlir/Dialect/Math/IR/MathBase.td" +include "mlir/Interfaces/VectorInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +class MathOp traits = []> + : Op; + +class FloatUnaryOp traits = []> : + MathOp, + ElementwiseMappable, + SameOperandsAndResultType]> { + let arguments = (ins FloatLike:$operand); + + let results = (outs FloatLike:$result); + + let assemblyFormat = "$operand attr-dict `:` type($result)"; +} + +class FloatBinaryOp traits = []> : + MathOp, + ElementwiseMappable, + SameOperandsAndResultType]> { + let arguments = (ins FloatLike:$lhs, FloatLike:$rhs); + let results = (outs FloatLike:$result); + + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; +} + +//===----------------------------------------------------------------------===// +// AtanOp +//===----------------------------------------------------------------------===// + +def AtanOp : FloatUnaryOp<"atan">{ + let summary = "arcus tangent of the given value"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `math.atan` ssa-use `:` type + ``` + + The `atan` operation computes the arcus tangent of a given value. It takes + one operand and returns one result of the same type. This type may be a + float scalar type, a vector whose element type is float, or a tensor of + floats. 
It has no standard attributes. + + Example: + + ```mlir + // Arcus tangent of scalar value. + %a = math.atan %b : f64 + + // SIMD vector element-wise arcus tangent. + %f = math.atan %g : vector<4xf32> + + // Tensor element-wise arcus tangent. + %x = math.atan %y : tensor<4x?xf8> + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// Atan2Op +//===----------------------------------------------------------------------===// + +def Atan2Op : FloatBinaryOp<"atan2">{ + let summary = "2-argument arcus tangent of the given values"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `math.atan2` ssa-use `,` ssa-use `:` type + ``` + + The `atan2` operation takes two operands and returns one result, all of + which must be of the same type. This type may be a floating point scalar + type, a vector whose element type is a floating point type, or a floating + point tensor. + + The 2-argument arcus tangent `atan2(y, x)` returns the angle in the + Euclidean plane between the positive x-axis and the ray through the point + (x, y). It is a generalization of the 1-argument arcus tangent which + returns the angle on the basis of the ratio y/x. + + See also https://en.wikipedia.org/wiki/Atan2 + + Example: + + ```mlir + // Scalar variant. + %a = math.atan2 %b, %c : f32 + + // SIMD vector variant. + %f = math.atan2 %g, %h : vector<4xf32> + + // Tensor variant. + %x = math.atan2 %y, %z : tensor<4x?xf32> + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// CosOp +//===----------------------------------------------------------------------===// + +def CosOp : FloatUnaryOp<"cos"> { + let summary = "cosine of the specified value"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `math.cos` ssa-use `:` type + ``` + + The `cos` operation computes the cosine of a given value. It takes one + operand and returns one result of the same type. 
This type may be a float + scalar type, a vector whose element type is float, or a tensor of floats. + It has no standard attributes. + + Example: + + ```mlir + // Scalar cosine value. + %a = math.cos %b : f64 + + // SIMD vector element-wise cosine value. + %f = math.cos %g : vector<4xf32> + + // Tensor element-wise cosine value. + %x = math.cos %y : tensor<4x?xf8> + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// SinOp +//===----------------------------------------------------------------------===// + +def SinOp : FloatUnaryOp<"sin"> { + let summary = "sine of the specified value"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `math.sin` ssa-use `:` type + ``` + + The `sin` operation computes the sine of a given value. It takes one + operand and returns one result of the same type. This type may be a float + scalar type, a vector whose element type is float, or a tensor of floats. + It has no standard attributes. + + Example: + + ```mlir + // Scalar sine value. + %a = math.sin %b : f64 + + // SIMD vector element-wise sine value. + %f = math.sin %g : vector<4xf32> + + // Tensor element-wise sine value. + %x = math.sin %y : tensor<4x?xf8> + ``` + }]; +} + + +//===----------------------------------------------------------------------===// +// ExpOp +//===----------------------------------------------------------------------===// + +def ExpOp : FloatUnaryOp<"exp"> { + let summary = "base-e exponential of the specified value"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `math.exp` ssa-use `:` type + ``` + + The `exp` operation takes one operand and returns one result of the same + type. This type may be a float scalar type, a vector whose element type is + float, or a tensor of floats. It has no standard attributes. + + Example: + + ```mlir + // Scalar natural exponential. + %a = math.exp %b : f64 + + // SIMD vector element-wise natural exponential. 
+ %f = math.exp %g : vector<4xf32> + + // Tensor element-wise natural exponential. + %x = math.exp %y : tensor<4x?xf8> + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// Exp2Op +//===----------------------------------------------------------------------===// + +def Exp2Op : FloatUnaryOp<"exp2"> { + let summary = "base-2 exponential of the specified value"; + + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `math.exp2` ssa-use `:` type + ``` + + The `exp2` operation takes one operand and returns one result of the same + type. This type may be a float scalar type, a vector whose element type is + float, or a tensor of floats. It has no standard attributes. + + Example: + + ```mlir + // Scalar base-2 exponential. + %a = math.exp2 %b : f64 + + // SIMD vector element-wise base-2 exponential. + %f = math.exp2 %g : vector<4xf32> + + // Tensor element-wise base-2 exponential. + %x = math.exp2 %y : tensor<4x?xf8> + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// LogOp +//===----------------------------------------------------------------------===// + +def LogOp : FloatUnaryOp<"log"> { + let summary = "base-e logarithm of the specified value"; + + let description = [{ + Computes the base-e logarithm of the given value. It takes one operand and + returns one result of the same type. + + Example: + + ```mlir + %y = math.log %x : f64 + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// Log10Op +//===----------------------------------------------------------------------===// + +def Log10Op : FloatUnaryOp<"log10"> { + let summary = "base-10 logarithm of the specified value"; + + let description = [{ + Computes the base-10 logarithm of the given value. It takes one operand and + returns one result of the same type. 
+ + Example: + + ```mlir + %y = math.log10 %x : f64 + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// Log1pOp +//===----------------------------------------------------------------------===// + +def Log1pOp : FloatUnaryOp<"log1p"> { + let summary = "Computes the natural logarithm of one plus the given value"; + + let description = [{ + Computes the base-e logarithm of one plus the given value. It takes one + operand and returns one result of the same type. + + log1p(x) := log(1 + x) + + Example: + + ```mlir + %y = math.log1p %x : f64 + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// Log2Op +//===----------------------------------------------------------------------===// + +def Log2Op : FloatUnaryOp<"log2"> { + let summary = "base-2 logarithm of the specified value"; + + let description = [{ + Computes the base-2 logarithm of the given value. It takes one operand and + returns one result of the same type. + + Example: + + ```mlir + %y = math.log2 %x : f64 + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// PowFOp +//===----------------------------------------------------------------------===// + +def PowFOp : FloatBinaryOp<"powf"> { + let summary = "floating point raised to the power of operation"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `math.powf` ssa-use `,` ssa-use `:` type + ``` + + The `powf` operation takes two operands and returns one result, each of + these is required to be the same type. This type may be a floating point + scalar type, a vector whose element type is a floating point type, or a + floating point tensor. + + Example: + + ```mlir + // Scalar exponentiation. + %a = math.powf %b, %c : f64 + + // SIMD pointwise vector exponentiation + %f = math.powf %g, %h : vector<4xf32> + + // Tensor pointwise exponentiation. 
+ %x = math.powf %y, %z : tensor<4x?xbf16> + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// RsqrtOp +//===----------------------------------------------------------------------===// + +def RsqrtOp : FloatUnaryOp<"rsqrt"> { + let summary = "reciprocal of sqrt (1 / sqrt of the specified value)"; + let description = [{ + The `rsqrt` operation computes the reciprocal of the square root. It takes + one operand and returns one result of the same type. This type may be a + float scalar type, a vector whose element type is float, or a tensor of + floats. It has no standard attributes. + }]; +} + +//===----------------------------------------------------------------------===// +// SqrtOp +//===----------------------------------------------------------------------===// + +def SqrtOp : FloatUnaryOp<"sqrt"> { + let summary = "sqrt of the specified value"; + let description = [{ + The `sqrt` operation computes the square root. It takes one operand and + returns one result of the same type. This type may be a float scalar type, a + vector whose element type is float, or a tensor of floats. It has no standard + attributes. + + Example: + + ```mlir + // Scalar square root value. + %a = math.sqrt %b : f64 + // SIMD vector element-wise square root value. + %f = math.sqrt %g : vector<4xf32> + // Tensor element-wise square root value. + %x = math.sqrt %y : tensor<4x?xf32> + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// TanhOp +//===----------------------------------------------------------------------===// + +def TanhOp : FloatUnaryOp<"tanh"> { + let summary = "hyperbolic tangent of the specified value"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `math.tanh` ssa-use `:` type + ``` + + The `tanh` operation computes the hyperbolic tangent. It takes one operand + and returns one result of the same type. 
This type may be a float scalar + type, a vector whose element type is float, or a tensor of floats. It has + no standard attributes. + + Example: + + ```mlir + // Scalar hyperbolic tangent value. + %a = math.tanh %b : f64 + + // SIMD vector element-wise hyperbolic tangent value. + %f = math.tanh %g : vector<4xf32> + + // Tensor element-wise hyperbolic tangent value. + %x = math.tanh %y : tensor<4x?xf8> + ``` + }]; +} + +#endif // MATH_OPS diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -0,0 +1,24 @@ +//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_ +#define MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/Bufferize.h" + +namespace mlir { + +class OwningRewritePatternList; + +void populateExpandTanhPattern(OwningRewritePatternList &patterns, + MLIRContext *ctx); + +} // namespace mlir + +#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h --- a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h +++ b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h @@ -35,7 +35,6 @@ using std_mulf = ValueBuilder; using std_memref_cast = ValueBuilder; using std_ret = OperationBuilder; -using std_rsqrt = ValueBuilder; using std_select = ValueBuilder; using std_load = ValueBuilder; using std_sign_extendi = ValueBuilder; @@ -44,7 +43,6 @@ using std_subf = ValueBuilder; using 
std_subi = ValueBuilder; using std_sub_view = ValueBuilder; -using std_tanh = ValueBuilder; using std_tensor_load = ValueBuilder; using std_tensor_store = OperationBuilder; using std_view = ValueBuilder; diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -499,79 +499,6 @@ let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)"; } -//===----------------------------------------------------------------------===// -// AtanOp -//===----------------------------------------------------------------------===// - -def AtanOp : FloatUnaryOp<"atan", []>{ - let summary = "arcus tangent of the given value"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.atan` ssa-use `:` type - ``` - - The `atan` operation computes the arcus tangent of a given value. It takes - one operand and returns one result of the same type. This type may be a - float scalar type, a vector whose element type is float, or a tensor of - floats. It has no standard attributes. - - Example: - - ```mlir - // Arcus tangent of scalar value. - %a = atan %b : f64 - - // SIMD vector element-wise arcus tangent. - %f = atan %g : vector<4xf32> - - // Tensor element-wise arcus tangent. - %x = atan %y : tensor<4x?xf8> - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// Atan2Op -//===----------------------------------------------------------------------===// - -def Atan2Op : FloatArithmeticOp<"atan2">{ - let summary = "2-argument arcus tangent of the given values"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.atan2` ssa-use `,` ssa-use `:` type - ``` - - The `atan2` operation takes two operands and returns one result, all of - which must be of the same type. 
This type may be a floating point scalar - type, a vector whose element type is a floating point type, or a floating - point tensor. - - The 2-argument arcus tangent `atan2(y, x)` returns the angle in the - Euclidian plane between the positive x-axis and the ray through the point - (x, y). It is a generalization of the 1-argument arcus tangent which - returns the angle on the basis of the ratio y/x. - - See also https://en.wikipedia.org/wiki/Atan2 - - Example: - - ```mlir - // Scalar variant. - %a = atan2 %b, %c : f32 - - // SIMD vector variant. - %f = atan2 %g, %h : vector<4xf32> - - // Tensor variant. - %x = atan2 %y, %z : tensor<4x?xf32> - ``` - }]; -} - //===----------------------------------------------------------------------===// // AtomicRMWOp //===----------------------------------------------------------------------===// @@ -1372,72 +1299,6 @@ }]; } -//===----------------------------------------------------------------------===// -// CosOp -//===----------------------------------------------------------------------===// - -def CosOp : FloatUnaryOp<"cos"> { - let summary = "cosine of the specified value"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.cos` ssa-use `:` type - ``` - - The `cos` operation computes the cosine of a given value. It takes one - operand and returns one result of the same type. This type may be a float - scalar type, a vector whose element type is float, or a tensor of floats. - It has no standard attributes. - - Example: - - ```mlir - // Scalar cosine value. - %a = cos %b : f64 - - // SIMD vector element-wise cosine value. - %f = cos %g : vector<4xf32> - - // Tensor element-wise cosine value. 
- %x = cos %y : tensor<4x?xf8> - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// SinOp -//===----------------------------------------------------------------------===// - -def SinOp : FloatUnaryOp<"sin"> { - let summary = "sine of the specified value"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.sin` ssa-use `:` type - ``` - - The `sin` operation computes the sine of a given value. It takes one - operand and returns one result of the same type. This type may be a float - scalar type, a vector whose element type is float, or a tensor of floats. - It has no standard attributes. - - Example: - - ```mlir - // Scalar sine value. - %a = sin %b : f64 - - // SIMD vector element-wise sine value. - %f = sin %g : vector<4xf32> - - // Tensor element-wise sine value. - %x = sin %y : tensor<4x?xf8> - ``` - }]; -} - //===----------------------------------------------------------------------===// // DeallocOp //===----------------------------------------------------------------------===// @@ -1528,46 +1389,6 @@ let hasFolder = 1; } -//===----------------------------------------------------------------------===// -// ExpOp -//===----------------------------------------------------------------------===// - -def ExpOp : FloatUnaryOp<"exp"> { - let summary = "base-e exponential of the specified value"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.exp` ssa-use `:` type - ``` - - The `exp` operation takes one operand and returns one result of the same - type. This type may be a float scalar type, a vector whose element type is - float, or a tensor of floats. It has no standard attributes. - - Example: - - ```mlir - // Scalar natural exponential. - %a = exp %b : f64 - - // SIMD vector element-wise natural exponential. - %f = exp %g : vector<4xf32> - - // Tensor element-wise natural exponential. 
- %x = exp %y : tensor<4x?xf8> - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// ExpOp -//===----------------------------------------------------------------------===// - -def Exp2Op : FloatUnaryOp<"exp2"> { - let summary = "base-2 exponential of the specified value"; -} - //===----------------------------------------------------------------------===// // FPExtOp //===----------------------------------------------------------------------===// @@ -1802,51 +1623,6 @@ let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)"; } -//===----------------------------------------------------------------------===// -// LogOp -//===----------------------------------------------------------------------===// - -def LogOp : FloatUnaryOp<"log"> { - let summary = "base-e logarithm of the specified value"; -} - -//===----------------------------------------------------------------------===// -// Log10Op -//===----------------------------------------------------------------------===// - -def Log10Op : FloatUnaryOp<"log10"> { - let summary = "base-10 logarithm of the specified value"; -} - -//===----------------------------------------------------------------------===// -// Log1pOp -//===----------------------------------------------------------------------===// - -def Log1pOp : FloatUnaryOp<"log1p"> { - let summary = "Computes the natural logarithm of one plus the given value"; - - let description = [{ - Computes the base-e logarithm of one plus the given value. It takes one - operand and returns one result of the same type. 
- - log1p(x) := log(1 + x) - - Example: - - ```mlir - %y = log1p %x : f64 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// Log2Op -//===----------------------------------------------------------------------===// - -def Log2Op : FloatUnaryOp<"log2"> { - let summary = "base-2 logarithm of the specified value"; -} - //===----------------------------------------------------------------------===// // MemRefCastOp //===----------------------------------------------------------------------===// @@ -2187,39 +1963,6 @@ let hasFolder = 1; } -//===----------------------------------------------------------------------===// -// PowFOp -//===----------------------------------------------------------------------===// - -def PowFOp : FloatArithmeticOp<"powf"> { - let summary = "floating point raised to the power of operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.powf` ssa-use `,` ssa-use `:` type - ``` - - The `powf` operation takes two operands and returns one result, each of - these is required to be the same type. This type may be a floating point - scalar type, a vector whose element type is a floating point type, or a - floating point tensor. - - Example: - - ```mlir - // Scalar exponentiation. - %a = powf %b, %c : f64 - - // SIMD pointwise vector exponentiation - %f = powf %g, %h : vector<4xf32> - - // Tensor pointwise exponentiation. 
- %x = powf %y, %z : tensor<4x?xbf16> - ``` - }]; -} - //===----------------------------------------------------------------------===// // PrefetchOp //===----------------------------------------------------------------------===// @@ -2333,20 +2076,6 @@ let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; } -//===----------------------------------------------------------------------===// -// RsqrtOp -//===----------------------------------------------------------------------===// - -def RsqrtOp : FloatUnaryOp<"rsqrt"> { - let summary = "reciprocal of sqrt (1 / sqrt of the specified value)"; - let description = [{ - The `rsqrt` operation computes the reciprocal of the square root. It takes - one operand and returns one result of the same type. This type may be a - float scalar type, a vector whose element type is float, or a tensor of - floats. It has no standard attributes. - }]; -} - //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// @@ -2684,31 +2413,6 @@ let assemblyFormat = "$input attr-dict `:` type($aggregate)"; } -//===----------------------------------------------------------------------===// -// SqrtOp -//===----------------------------------------------------------------------===// - -def SqrtOp : FloatUnaryOp<"sqrt"> { - let summary = "sqrt of the specified value"; - let description = [{ - The `sqrt` operation computes the square root. It takes one operand and - returns one result of the same type. This type may be a float scalar type, a - vector whose element type is float, or a tensor of floats. It has no standard - attributes. - - Example: - - ```mlir - // Scalar square root value. - %a = sqrt %b : f64 - // SIMD vector element-wise square root value. - %f = sqrt %g : vector<4xf32> - // Tensor element-wise square root value. 
- %x = sqrt %y : tensor<4x?xf32> - ``` - }]; -} - //===----------------------------------------------------------------------===// // StoreOp //===----------------------------------------------------------------------===// @@ -3237,39 +2941,6 @@ }]; } -//===----------------------------------------------------------------------===// -// TanhOp -//===----------------------------------------------------------------------===// - -def TanhOp : FloatUnaryOp<"tanh"> { - let summary = "hyperbolic tangent of the specified value"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.tanh` ssa-use `:` type - ``` - - The `tanh` operation computes the hyperbolic tangent. It takes one operand - and returns one result of the same type. This type may be a float scalar - type, a vector whose element type is float, or a tensor of floats. It has - no standard attributes. - - Example: - - ```mlir - // Scalar hyperbolic tangent value. - %a = tanh %b : f64 - - // SIMD vector element-wise hyperbolic tangent value. - %f = tanh %g : vector<4xf32> - - // Tensor element-wise hyperbolic tangent value. - %x = tanh %y : tensor<4x?xf8> - ``` - }]; -} - //===----------------------------------------------------------------------===// // TensorLoadOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -1,4 +1,3 @@ - //===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
@@ -22,9 +21,6 @@ class OwningRewritePatternList; -void populateExpandTanhPattern(OwningRewritePatternList &patterns, - MLIRContext *ctx); - void populateStdBufferizePatterns(MLIRContext *context, BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns); diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -28,6 +28,7 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/PDL/IR/PDL.h" @@ -60,6 +61,7 @@ LLVM::LLVMArmNeonDialect, LLVM::LLVMArmSVEDialect, linalg::LinalgDialect, + math::MathDialect, scf::SCFDialect, omp::OpenMPDialect, pdl::PDLDialect, diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -173,36 +174,36 @@ GPUFuncOpLowering<0>>(converter); patterns.insert>(converter, "__nv_fabsf", "__nv_fabs"); - patterns.insert>(converter, "__nv_atanf", - "__nv_atan"); - patterns.insert>(converter, "__nv_atan2f", - "__nv_atan2"); + patterns.insert>(converter, "__nv_atanf", + "__nv_atan"); + patterns.insert>(converter, "__nv_atan2f", + "__nv_atan2"); patterns.insert>(converter, "__nv_ceilf", "__nv_ceil"); - patterns.insert>(converter, "__nv_cosf", - "__nv_cos"); - 
patterns.insert>(converter, "__nv_expf", - "__nv_exp"); + patterns.insert>(converter, "__nv_cosf", + "__nv_cos"); + patterns.insert>(converter, "__nv_expf", + "__nv_exp"); patterns.insert>(converter, "__nv_floorf", "__nv_floor"); - patterns.insert>(converter, "__nv_logf", - "__nv_log"); - patterns.insert>(converter, "__nv_log1pf", - "__nv_log1p"); - patterns.insert>(converter, "__nv_log10f", - "__nv_log10"); - patterns.insert>(converter, "__nv_log2f", - "__nv_log2"); - patterns.insert>(converter, "__nv_powf", - "__nv_pow"); - patterns.insert>(converter, "__nv_rsqrtf", - "__nv_rsqrt"); - patterns.insert>(converter, "__nv_sinf", - "__nv_sin"); - patterns.insert>(converter, "__nv_sqrtf", - "__nv_sqrt"); - patterns.insert>(converter, "__nv_tanhf", - "__nv_tanh"); + patterns.insert>(converter, "__nv_logf", + "__nv_log"); + patterns.insert>(converter, "__nv_log1pf", + "__nv_log1p"); + patterns.insert>(converter, "__nv_log10f", + "__nv_log10"); + patterns.insert>(converter, "__nv_log2f", + "__nv_log2"); + patterns.insert>(converter, "__nv_powf", + "__nv_pow"); + patterns.insert>(converter, "__nv_rsqrtf", + "__nv_rsqrt"); + patterns.insert>(converter, "__nv_sinf", + "__nv_sin"); + patterns.insert>(converter, "__nv_sqrtf", + "__nv_sqrt"); + patterns.insert>(converter, "__nv_tanhf", + "__nv_tanh"); } std::unique_ptr> diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -105,36 +106,36 @@ GPUFuncOpLowering<5>, GPUReturnOpLowering>(converter); patterns.insert>(converter, 
"__ocml_fabs_f32", "__ocml_fabs_f64"); - patterns.insert>(converter, "__ocml_atan_f32", - "__ocml_atan_f64"); - patterns.insert>(converter, "__ocml_atan2_f32", - "__ocml_atan2_f64"); + patterns.insert>( + converter, "__ocml_atan_f32", "__ocml_atan_f64"); + patterns.insert>( + converter, "__ocml_atan2_f32", "__ocml_atan2_f64"); patterns.insert>(converter, "__ocml_ceil_f32", "__ocml_ceil_f64"); - patterns.insert>(converter, "__ocml_cos_f32", - "__ocml_cos_f64"); - patterns.insert>(converter, "__ocml_exp_f32", - "__ocml_exp_f64"); + patterns.insert>( + converter, "__ocml_cos_f32", "__ocml_cos_f64"); + patterns.insert>( + converter, "__ocml_exp_f32", "__ocml_exp_f64"); patterns.insert>(converter, "__ocml_floor_f32", "__ocml_floor_f64"); - patterns.insert>(converter, "__ocml_log_f32", - "__ocml_log_f64"); - patterns.insert>(converter, "__ocml_log10_f32", - "__ocml_log10_f64"); - patterns.insert>(converter, "__ocml_log1p_f32", - "__ocml_log1p_f64"); - patterns.insert>(converter, "__ocml_log2_f32", - "__ocml_log2_f64"); - patterns.insert>(converter, "__ocml_pow_f32", - "__ocml_pow_f64"); - patterns.insert>(converter, "__ocml_rsqrt_f32", - "__ocml_rsqrt_f64"); - patterns.insert>(converter, "__ocml_sin_f32", - "__ocml_sin_f64"); - patterns.insert>(converter, "__ocml_sqrt_f32", - "__ocml_sqrt_f64"); - patterns.insert>(converter, "__ocml_tanh_f32", - "__ocml_tanh_f64"); + patterns.insert>( + converter, "__ocml_log_f32", "__ocml_log_f64"); + patterns.insert>( + converter, "__ocml_log10_f32", "__ocml_log10_f64"); + patterns.insert>( + converter, "__ocml_log1p_f32", "__ocml_log1p_f64"); + patterns.insert>( + converter, "__ocml_log2_f32", "__ocml_log2_f64"); + patterns.insert>( + converter, "__ocml_pow_f32", "__ocml_pow_f64"); + patterns.insert>( + converter, "__ocml_rsqrt_f32", "__ocml_rsqrt_f64"); + patterns.insert>( + converter, "__ocml_sin_f32", "__ocml_sin_f64"); + patterns.insert>( + converter, "__ocml_sqrt_f32", "__ocml_sqrt_f64"); + patterns.insert>( + converter, 
"__ocml_tanh_f32", "__ocml_tanh_f64"); } std::unique_ptr> diff --git a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt @@ -13,5 +13,6 @@ LINK_LIBS PUBLIC MLIRLLVMIR + MLIRMath MLIRTransforms ) diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -15,6 +15,7 @@ #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BlockAndValueMapping.h" @@ -1655,19 +1656,20 @@ using CeilFOpLowering = VectorConvertToLLVMPattern; using CopySignOpLowering = VectorConvertToLLVMPattern; -using CosOpLowering = VectorConvertToLLVMPattern; +using CosOpLowering = VectorConvertToLLVMPattern; using DivFOpLowering = VectorConvertToLLVMPattern; -using ExpOpLowering = VectorConvertToLLVMPattern; -using Exp2OpLowering = VectorConvertToLLVMPattern; +using ExpOpLowering = VectorConvertToLLVMPattern; +using Exp2OpLowering = VectorConvertToLLVMPattern; using FloorFOpLowering = VectorConvertToLLVMPattern; -using Log10OpLowering = VectorConvertToLLVMPattern; -using Log2OpLowering = VectorConvertToLLVMPattern; -using LogOpLowering = VectorConvertToLLVMPattern; +using Log10OpLowering = + VectorConvertToLLVMPattern; +using Log2OpLowering = VectorConvertToLLVMPattern; +using LogOpLowering = VectorConvertToLLVMPattern; using MulFOpLowering = VectorConvertToLLVMPattern; using MulIOpLowering = VectorConvertToLLVMPattern; using NegFOpLowering = VectorConvertToLLVMPattern; using OrOpLowering = 
VectorConvertToLLVMPattern; -using PowFOpLowering = VectorConvertToLLVMPattern; +using PowFOpLowering = VectorConvertToLLVMPattern; using RemFOpLowering = VectorConvertToLLVMPattern; using SelectOpLowering = OneToOneConvertToLLVMPattern; using SignExtendIOpLowering = @@ -1680,8 +1682,8 @@ VectorConvertToLLVMPattern; using SignedShiftRightOpLowering = OneToOneConvertToLLVMPattern; -using SinOpLowering = VectorConvertToLLVMPattern; -using SqrtOpLowering = VectorConvertToLLVMPattern; +using SinOpLowering = VectorConvertToLLVMPattern; +using SqrtOpLowering = VectorConvertToLLVMPattern; using SubFOpLowering = VectorConvertToLLVMPattern; using SubIOpLowering = VectorConvertToLLVMPattern; using UnsignedDivIOpLowering = @@ -2335,13 +2337,13 @@ }; // A `rsqrt` is converted into `1 / sqrt`. -struct RsqrtOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +struct RsqrtOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(RsqrtOp op, ArrayRef operands, + matchAndRewrite(math::RsqrtOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - RsqrtOp::Adaptor transformed(operands); + math::RsqrtOp::Adaptor transformed(operands); auto operandType = transformed.operand().getType(); if (!operandType || !LLVM::isCompatibleType(operandType)) @@ -4025,7 +4027,7 @@ : ConversionTarget(ctx) { this->addLegalDialect(); this->addIllegalOp(); - this->addIllegalOp(); + this->addIllegalOp(); } std::unique_ptr> diff --git a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt @@ -12,6 +12,7 @@ LINK_LIBS PUBLIC MLIRIR + MLIRMath MLIRPass MLIRSPIRV MLIRSPIRVConversion diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp 
b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" @@ -1131,6 +1132,15 @@ SPIRVTypeConverter &typeConverter, OwningRewritePatternList &patterns) { patterns.insert< + // Math dialect operations. + // TODO: Move to separate pass. + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, // Unary and binary patterns BitwiseOpPattern, BitwiseOpPattern, @@ -1138,25 +1148,18 @@ UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt --- a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt +++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt @@ -13,6 +13,7 @@ MLIRIR MLIRLinalg MLIRLinalgUtils + MLIRMath MLIRPass MLIRTosa MLIRTosaTransforms diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp 
b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -12,6 +12,7 @@ #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/PatternMatch.h" @@ -64,15 +65,15 @@ // tosa::PowOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::LogOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::ExpOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::SubOp if (isa(op) && elementTy.isa()) @@ -83,7 +84,7 @@ // tosa::TanhOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::GreaterOp if (isa(op) && elementTy.isa()) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -13,6 +13,7 @@ #include "../PassDetail.h" #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Transforms/PassDetail.h" @@ -30,7 +31,8 @@ : public TosaToLinalgOnTensorsBase { public: void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry + .insert(); } void runOnFunction() override { diff --git a/mlir/lib/Dialect/CMakeLists.txt 
b/mlir/lib/Dialect/CMakeLists.txt --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -7,6 +7,7 @@ add_subdirectory(GPU) add_subdirectory(Linalg) add_subdirectory(LLVMIR) +add_subdirectory(Math) add_subdirectory(OpenACC) add_subdirectory(OpenMP) add_subdirectory(PDL) diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Affine/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/EDSC/Builders.h" #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" +#include "mlir/Dialect/Math/EDSC/Intrinsics.h" #include "mlir/Dialect/SCF/EDSC/Builders.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" @@ -113,7 +114,7 @@ Operation *mlir::edsc::ops::linalg_generic_pointwise_tanh(StructuredIndexed I, StructuredIndexed O) { - UnaryPointwiseOpBuilder unOp([](Value a) -> Value { return std_tanh(a); }); + UnaryPointwiseOpBuilder unOp([](Value a) -> Value { return math_tanh(a); }); return linalg_generic_pointwise(unOp, I, O); } diff --git a/mlir/lib/Dialect/Linalg/EDSC/CMakeLists.txt b/mlir/lib/Dialect/Linalg/EDSC/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/EDSC/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/EDSC/CMakeLists.txt @@ -10,6 +10,7 @@ MLIRAffine MLIRAffineEDSC MLIRLinalg + MLIRMath MLIRSCF MLIRStandard ) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/Utils/Utils.h" 
#include "mlir/Dialect/Vector/VectorOps.h" @@ -293,7 +294,8 @@ BufferizeTypeConverter typeConverter; // Mark all Standard operations legal. - target.addLegalDialect(); + target.addLegalDialect(); target.addIllegalOp(); // Mark all Linalg operations illegal as long as they work on tensors. diff --git a/mlir/lib/Dialect/Math/CMakeLists.txt b/mlir/lib/Dialect/Math/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Math/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/Math/IR/CMakeLists.txt b/mlir/lib/Dialect/Math/IR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Math/IR/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library(MLIRMath + MathOps.cpp + MathDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Math + + DEPENDS + MLIRMathOpsIncGen + + LINK_LIBS PUBLIC + MLIRDialect + MLIRIR + ) diff --git a/mlir/lib/Dialect/Math/IR/MathDialect.cpp b/mlir/lib/Dialect/Math/IR/MathDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Math/IR/MathDialect.cpp @@ -0,0 +1,16 @@ +//===- MathDialect.cpp - MLIR dialect for Math implementation -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Math/IR/Math.h" + +void mlir::math::MathDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/Math/IR/MathOps.cpp.inc" + >(); +} diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -0,0 +1,19 @@ +//===- MathOps.cpp - MLIR operations for math implementation --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Math/IR/Math.h" + +using namespace mlir; +using namespace mlir::math; + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/Math/IR/MathOps.cpp.inc" diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_dialect_library(MLIRMathTransforms + ExpandTanh.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Math/Transforms + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRStandard + MLIRTransforms + ) diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandTanh.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp rename from mlir/lib/Dialect/StandardOps/Transforms/ExpandTanh.cpp rename to mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp --- 
a/mlir/lib/Dialect/StandardOps/Transforms/ExpandTanh.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp @@ -10,27 +10,23 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; namespace { - /// Expands tanh op into /// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0 /// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0 -struct TanhOpConverter : public OpRewritePattern { +struct TanhOpConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TanhOp op, + LogicalResult matchAndRewrite(math::TanhOp op, PatternRewriter &rewriter) const final { auto floatType = op.operand().getType(); Location loc = op.getLoc(); @@ -42,13 +38,13 @@ // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x} Value negDoubledX = rewriter.create(loc, doubledX); - Value exp2x = rewriter.create(loc, negDoubledX); + Value exp2x = rewriter.create(loc, negDoubledX); Value dividend = rewriter.create(loc, one, exp2x); Value divisor = rewriter.create(loc, one, exp2x); Value positiveRes = rewriter.create(loc, dividend, divisor); // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1 - exp2x = rewriter.create(loc, doubledX); + exp2x = rewriter.create(loc, doubledX); dividend = rewriter.create(loc, exp2x, one); divisor = rewriter.create(loc, exp2x, one); Value negativeRes = rewriter.create(loc, dividend, divisor); diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt +++ 
b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt @@ -2,7 +2,6 @@ Bufferize.cpp DecomposeCallGraphTypes.cpp ExpandOps.cpp - ExpandTanh.cpp FuncBufferize.cpp FuncConversions.cpp TensorConstantBufferize.cpp diff --git a/mlir/test/Analysis/test-shape-fn-report.mlir b/mlir/test/Analysis/test-shape-fn-report.mlir --- a/mlir/test/Analysis/test-shape-fn-report.mlir +++ b/mlir/test/Analysis/test-shape-fn-report.mlir @@ -6,7 +6,7 @@ func @tanh(%arg: tensor<10x20xf32>) -> tensor<10x20xf32> attributes {shape.function = @shape_lib::@same_result_shape} { // expected-remark@+1 {{no associated way}} - %0 = tanh %arg : tensor<10x20xf32> + %0 = math.tanh %arg : tensor<10x20xf32> // expected-remark@+1 {{associated shape function: same_result_shape}} %1 = "test.same_operand_result_type"(%0) : (tensor<10x20xf32>) -> tensor<10x20xf32> return %1 : tensor<10x20xf32> diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir --- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -192,9 +192,9 @@ // CHECK: llvm.func @__nv_cos(f64) -> f64 // CHECK-LABEL: func @gpu_cos func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = std.cos %arg_f32 : f32 + %result32 = math.cos %arg_f32 : f32 // CHECK: llvm.call @__nv_cosf(%{{.*}}) : (f32) -> f32 - %result64 = std.cos %arg_f64 : f64 + %result64 = math.cos %arg_f64 : f64 // CHECK: llvm.call @__nv_cos(%{{.*}}) : (f64) -> f64 std.return %result32, %result64 : f32, f64 } @@ -206,9 +206,9 @@ // CHECK: llvm.func @__nv_exp(f64) -> f64 // CHECK-LABEL: func @gpu_exp func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = std.exp %arg_f32 : f32 + %result32 = math.exp %arg_f32 : f32 // CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32 - %result64 = std.exp %arg_f64 : f64 + %result64 = math.exp %arg_f64 : f64 // CHECK: llvm.call @__nv_exp(%{{.*}}) : (f64) -> f64 std.return %result32, %result64 : f32, f64 } 
@@ -221,9 +221,9 @@ // CHECK: llvm.func @__nv_log(f64) -> f64 // CHECK-LABEL: func @gpu_log func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = std.log %arg_f32 : f32 + %result32 = math.log %arg_f32 : f32 // CHECK: llvm.call @__nv_logf(%{{.*}}) : (f32) -> f32 - %result64 = std.log %arg_f64 : f64 + %result64 = math.log %arg_f64 : f64 // CHECK: llvm.call @__nv_log(%{{.*}}) : (f64) -> f64 std.return %result32, %result64 : f32, f64 } @@ -236,9 +236,9 @@ // CHECK: llvm.func @__nv_log10(f64) -> f64 // CHECK-LABEL: func @gpu_log10 func @gpu_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = std.log10 %arg_f32 : f32 + %result32 = math.log10 %arg_f32 : f32 // CHECK: llvm.call @__nv_log10f(%{{.*}}) : (f32) -> f32 - %result64 = std.log10 %arg_f64 : f64 + %result64 = math.log10 %arg_f64 : f64 // CHECK: llvm.call @__nv_log10(%{{.*}}) : (f64) -> f64 std.return %result32, %result64 : f32, f64 } @@ -251,9 +251,9 @@ // CHECK: llvm.func @__nv_log1p(f64) -> f64 // CHECK-LABEL: func @gpu_log1p func @gpu_log1p(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = std.log1p %arg_f32 : f32 + %result32 = math.log1p %arg_f32 : f32 // CHECK: llvm.call @__nv_log1pf(%{{.*}}) : (f32) -> f32 - %result64 = std.log1p %arg_f64 : f64 + %result64 = math.log1p %arg_f64 : f64 // CHECK: llvm.call @__nv_log1p(%{{.*}}) : (f64) -> f64 std.return %result32, %result64 : f32, f64 } @@ -266,9 +266,9 @@ // CHECK: llvm.func @__nv_log2(f64) -> f64 // CHECK-LABEL: func @gpu_log2 func @gpu_log2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = std.log2 %arg_f32 : f32 + %result32 = math.log2 %arg_f32 : f32 // CHECK: llvm.call @__nv_log2f(%{{.*}}) : (f32) -> f32 - %result64 = std.log2 %arg_f64 : f64 + %result64 = math.log2 %arg_f64 : f64 // CHECK: llvm.call @__nv_log2(%{{.*}}) : (f64) -> f64 std.return %result32, %result64 : f32, f64 } @@ -281,9 +281,9 @@ // CHECK: llvm.func @__nv_sin(f64) -> f64 // CHECK-LABEL: func @gpu_sin func @gpu_sin(%arg_f32 : f32, 
%arg_f64 : f64) -> (f32, f64) { - %result32 = std.sin %arg_f32 : f32 + %result32 = math.sin %arg_f32 : f32 // CHECK: llvm.call @__nv_sinf(%{{.*}}) : (f32) -> f32 - %result64 = std.sin %arg_f64 : f64 + %result64 = math.sin %arg_f64 : f64 // CHECK: llvm.call @__nv_sin(%{{.*}}) : (f64) -> f64 std.return %result32, %result64 : f32, f64 } @@ -296,13 +296,13 @@ // CHECK: llvm.func @__nv_tanh(f64) -> f64 // CHECK-LABEL: func @gpu_tanh func @gpu_tanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = std.tanh %arg_f16 : f16 + %result16 = math.tanh %arg_f16 : f16 // CHECK: llvm.fpext %{{.*}} : f16 to f32 // CHECK-NEXT: llvm.call @__nv_tanhf(%{{.*}}) : (f32) -> f32 // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to f16 - %result32 = std.tanh %arg_f32 : f32 + %result32 = math.tanh %arg_f32 : f32 // CHECK: llvm.call @__nv_tanhf(%{{.*}}) : (f32) -> f32 - %result64 = std.tanh %arg_f64 : f64 + %result64 = math.tanh %arg_f64 : f64 // CHECK: llvm.call @__nv_tanh(%{{.*}}) : (f64) -> f64 std.return %result16, %result32, %result64 : f16, f32, f64 } @@ -316,13 +316,13 @@ // CHECK-LABEL: func @gpu_rsqrt func @gpu_rsqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = std.rsqrt %arg_f16 : f16 + %result16 = math.rsqrt %arg_f16 : f16 // CHECK: llvm.fpext %{{.*}} : f16 to f32 // CHECK-NEXT: llvm.call @__nv_rsqrtf(%{{.*}}) : (f32) -> f32 // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to f16 - %result32 = std.rsqrt %arg_f32 : f32 + %result32 = math.rsqrt %arg_f32 : f32 // CHECK: llvm.call @__nv_rsqrtf(%{{.*}}) : (f32) -> f32 - %result64 = std.rsqrt %arg_f64 : f64 + %result64 = math.rsqrt %arg_f64 : f64 // CHECK: llvm.call @__nv_rsqrt(%{{.*}}) : (f64) -> f64 std.return %result16, %result32, %result64 : f16, f32, f64 } @@ -336,13 +336,13 @@ // CHECK-LABEL: func @gpu_sqrt func @gpu_sqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = std.sqrt %arg_f16 : f16 + %result16 = math.sqrt %arg_f16 : f16 // CHECK: 
llvm.fpext %{{.*}} : f16 to f32 // CHECK-NEXT: llvm.call @__nv_sqrtf(%{{.*}}) : (f32) -> f32 // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to f16 - %result32 = std.sqrt %arg_f32 : f32 + %result32 = math.sqrt %arg_f32 : f32 // CHECK: llvm.call @__nv_sqrtf(%{{.*}}) : (f32) -> f32 - %result64 = std.sqrt %arg_f64 : f64 + %result64 = math.sqrt %arg_f64 : f64 // CHECK: llvm.call @__nv_sqrt(%{{.*}}) : (f64) -> f64 std.return %result16, %result32, %result64 : f16, f32, f64 } @@ -356,13 +356,13 @@ // CHECK-LABEL: func @gpu_atan func @gpu_atan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = std.atan %arg_f16 : f16 + %result16 = math.atan %arg_f16 : f16 // CHECK: llvm.fpext %{{.*}} : f16 to f32 // CHECK-NEXT: llvm.call @__nv_atanf(%{{.*}}) : (f32) -> f32 // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to f16 - %result32 = std.atan %arg_f32 : f32 + %result32 = math.atan %arg_f32 : f32 // CHECK: llvm.call @__nv_atanf(%{{.*}}) : (f32) -> f32 - %result64 = std.atan %arg_f64 : f64 + %result64 = math.atan %arg_f64 : f64 // CHECK: llvm.call @__nv_atan(%{{.*}}) : (f64) -> f64 std.return %result16, %result32, %result64 : f16, f32, f64 } @@ -376,14 +376,14 @@ // CHECK-LABEL: func @gpu_atan2 func @gpu_atan2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = std.atan2 %arg_f16, %arg_f16 : f16 + %result16 = math.atan2 %arg_f16, %arg_f16 : f16 // CHECK: llvm.fpext %{{.*}} : f16 to f32 // CHECK: llvm.fpext %{{.*}} : f16 to f32 // CHECK-NEXT: llvm.call @__nv_atan2f(%{{.*}}) : (f32, f32) -> f32 // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to f16 - %result32 = std.atan2 %arg_f32, %arg_f32 : f32 + %result32 = math.atan2 %arg_f32, %arg_f32 : f32 // CHECK: llvm.call @__nv_atan2f(%{{.*}}) : (f32, f32) -> f32 - %result64 = std.atan2 %arg_f64, %arg_f64 : f64 + %result64 = math.atan2 %arg_f64, %arg_f64 : f64 // CHECK: llvm.call @__nv_atan2(%{{.*}}) : (f64, f64) -> f64 std.return %result16, %result32, %result64 : f16, f32, f64 } @@ -399,9 +399,9 @@ 
// CHECK: llvm.func @__nv_exp(f64) -> f64 // CHECK-LABEL: func @gpu_exp func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = std.exp %arg_f32 : f32 + %result32 = math.exp %arg_f32 : f32 // CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32 - %result64 = std.exp %arg_f64 : f64 + %result64 = math.exp %arg_f64 : f64 // CHECK: llvm.call @__nv_exp(%{{.*}}) : (f64) -> f64 std.return %result32, %result64 : f32, f64 } @@ -416,9 +416,9 @@ // CHECK: llvm.func @__nv_pow(f64, f64) -> f64 // CHECK-LABEL: func @gpu_pow func @gpu_pow(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = std.powf %arg_f32, %arg_f32 : f32 + %result32 = math.powf %arg_f32, %arg_f32 : f32 // CHECK: llvm.call @__nv_powf(%{{.*}}, %{{.*}}) : (f32, f32) -> f32 - %result64 = std.powf %arg_f64, %arg_f64 : f64 + %result64 = math.powf %arg_f64, %arg_f64 : f64 // CHECK: llvm.call @__nv_pow(%{{.*}}, %{{.*}}) : (f64, f64) -> f64 std.return %result32, %result64 : f32, f64 } diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir --- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir +++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir @@ -134,9 +134,9 @@ // CHECK: llvm.func @__ocml_cos_f64(f64) -> f64 // CHECK-LABEL: func @gpu_cos func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = std.cos %arg_f32 : f32 + %result32 = math.cos %arg_f32 : f32 // CHECK: llvm.call @__ocml_cos_f32(%{{.*}}) : (f32) -> f32 - %result64 = std.cos %arg_f64 : f64 + %result64 = math.cos %arg_f64 : f64 // CHECK: llvm.call @__ocml_cos_f64(%{{.*}}) : (f64) -> f64 std.return %result32, %result64 : f32, f64 } @@ -148,11 +148,11 @@ // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64 // CHECK-LABEL: func @gpu_exp func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %exp_f32 = std.exp %arg_f32 : f32 + %exp_f32 = math.exp %arg_f32 : f32 // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32 - %result32 = std.exp 
%exp_f32 : f32 + %result32 = math.exp %exp_f32 : f32 // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32 - %result64 = std.exp %arg_f64 : f64 + %result64 = math.exp %arg_f64 : f64 // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64 std.return %result32, %result64 : f32, f64 } @@ -169,11 +169,11 @@ // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64 // CHECK-LABEL: func @gpu_exp func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %exp_f32 = std.exp %arg_f32 : f32 + %exp_f32 = math.exp %arg_f32 : f32 // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32 - %result32 = std.exp %exp_f32 : f32 + %result32 = math.exp %exp_f32 : f32 // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32 - %result64 = std.exp %arg_f64 : f64 + %result64 = math.exp %arg_f64 : f64 // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64 std.return %result32, %result64 : f32, f64 } @@ -188,9 +188,9 @@ // CHECK: llvm.func @__ocml_log_f64(f64) -> f64 // CHECK-LABEL: func @gpu_log func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = std.log %arg_f32 : f32 + %result32 = math.log %arg_f32 : f32 // CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32 - %result64 = std.log %arg_f64 : f64 + %result64 = math.log %arg_f64 : f64 // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64 std.return %result32, %result64 : f32, f64 } @@ -203,9 +203,9 @@ // CHECK: llvm.func @__ocml_log1p_f64(f64) -> f64 // CHECK-LABEL: func @gpu_log1p func @gpu_log1p(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = std.log1p %arg_f32 : f32 + %result32 = math.log1p %arg_f32 : f32 // CHECK: llvm.call @__ocml_log1p_f32(%{{.*}}) : (f32) -> f32 - %result64 = std.log1p %arg_f64 : f64 + %result64 = math.log1p %arg_f64 : f64 // CHECK: llvm.call @__ocml_log1p_f64(%{{.*}}) : (f64) -> f64 std.return %result32, %result64 : f32, f64 } @@ -218,9 +218,9 @@ // CHECK: llvm.func @__ocml_log10_f64(f64) -> f64 // CHECK-LABEL: func @gpu_log10 func @gpu_log10(%arg_f32 : 
f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = std.log10 %arg_f32 : f32 + %result32 = math.log10 %arg_f32 : f32 // CHECK: llvm.call @__ocml_log10_f32(%{{.*}}) : (f32) -> f32 - %result64 = std.log10 %arg_f64 : f64 + %result64 = math.log10 %arg_f64 : f64 // CHECK: llvm.call @__ocml_log10_f64(%{{.*}}) : (f64) -> f64 std.return %result32, %result64 : f32, f64 } @@ -233,9 +233,9 @@ // CHECK: llvm.func @__ocml_log2_f64(f64) -> f64 // CHECK-LABEL: func @gpu_log2 func @gpu_log2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = std.log2 %arg_f32 : f32 + %result32 = math.log2 %arg_f32 : f32 // CHECK: llvm.call @__ocml_log2_f32(%{{.*}}) : (f32) -> f32 - %result64 = std.log2 %arg_f64 : f64 + %result64 = math.log2 %arg_f64 : f64 // CHECK: llvm.call @__ocml_log2_f64(%{{.*}}) : (f64) -> f64 std.return %result32, %result64 : f32, f64 } @@ -249,13 +249,13 @@ // CHECK-LABEL: func @gpu_rsqrt func @gpu_rsqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = std.rsqrt %arg_f16 : f16 + %result16 = math.rsqrt %arg_f16 : f16 // CHECK: llvm.fpext %{{.*}} : f16 to f32 // CHECK-NEXT: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32 // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to f16 - %result32 = std.rsqrt %arg_f32 : f32 + %result32 = math.rsqrt %arg_f32 : f32 // CHECK: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32 - %result64 = std.rsqrt %arg_f64 : f64 + %result64 = math.rsqrt %arg_f64 : f64 // CHECK: llvm.call @__ocml_rsqrt_f64(%{{.*}}) : (f64) -> f64 std.return %result16, %result32, %result64 : f16, f32, f64 } @@ -269,13 +269,13 @@ // CHECK-LABEL: func @gpu_sqrt func @gpu_sqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = std.sqrt %arg_f16 : f16 + %result16 = math.sqrt %arg_f16 : f16 // CHECK: llvm.fpext %{{.*}} : f16 to f32 // CHECK-NEXT: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32 // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to f16 - %result32 = std.sqrt %arg_f32 : f32 + %result32 = math.sqrt 
%arg_f32 : f32 // CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32 - %result64 = std.sqrt %arg_f64 : f64 + %result64 = math.sqrt %arg_f64 : f64 // CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64 std.return %result16, %result32, %result64 : f16, f32, f64 } @@ -288,9 +288,9 @@ // CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64 // CHECK-LABEL: func @gpu_tanh func @gpu_tanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = std.tanh %arg_f32 : f32 + %result32 = math.tanh %arg_f32 : f32 // CHECK: llvm.call @__ocml_tanh_f32(%{{.*}}) : (f32) -> f32 - %result64 = std.tanh %arg_f64 : f64 + %result64 = math.tanh %arg_f64 : f64 // CHECK: llvm.call @__ocml_tanh_f64(%{{.*}}) : (f64) -> f64 std.return %result32, %result64 : f32, f64 } @@ -303,9 +303,9 @@ // CHECK: llvm.func @__ocml_atan_f64(f64) -> f64 // CHECK-LABEL: func @gpu_atan func @gpu_atan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = std.atan %arg_f32 : f32 + %result32 = math.atan %arg_f32 : f32 // CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32 - %result64 = std.atan %arg_f64 : f64 + %result64 = math.atan %arg_f64 : f64 // CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64 std.return %result32, %result64 : f32, f64 } @@ -318,9 +318,9 @@ // CHECK: llvm.func @__ocml_atan2_f64(f64, f64) -> f64 // CHECK-LABEL: func @gpu_atan2 func @gpu_atan2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = std.atan2 %arg_f32, %arg_f32 : f32 + %result32 = math.atan2 %arg_f32, %arg_f32 : f32 // CHECK: llvm.call @__ocml_atan2_f32(%{{.*}}) : (f32, f32) -> f32 - %result64 = std.atan2 %arg_f64, %arg_f64 : f64 + %result64 = math.atan2 %arg_f64, %arg_f64 : f64 // CHECK: llvm.call @__ocml_atan2_f64(%{{.*}}) : (f64, f64) -> f64 std.return %result32, %result64 : f32, f64 } @@ -333,9 +333,9 @@ // CHECK: llvm.func @__ocml_pow_f64(f64, f64) -> f64 // CHECK-LABEL: func @gpu_pow func @gpu_pow(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = std.powf %arg_f32, %arg_f32 : 
f32 + %result32 = math.powf %arg_f32, %arg_f32 : f32 // CHECK: llvm.call @__ocml_pow_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32 - %result64 = std.powf %arg_f64, %arg_f64 : f64 + %result64 = math.powf %arg_f64, %arg_f64 : f64 // CHECK: llvm.call @__ocml_pow_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64 std.return %result32, %result64 : f32, f64 } diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -490,9 +490,9 @@ // CHECK-NEXT: %12 = llvm.xor %arg2, %arg3 : i32 %12 = xor %arg2, %arg3 : i32 // CHECK-NEXT: %13 = "llvm.intr.exp"(%arg0) : (f32) -> f32 - %13 = std.exp %arg0 : f32 + %13 = math.exp %arg0 : f32 // CHECK-NEXT: %14 = "llvm.intr.exp2"(%arg0) : (f32) -> f32 - %14 = std.exp2 %arg0 : f32 + %14 = math.exp2 %arg0 : f32 // CHECK-NEXT: %15 = llvm.mlir.constant(7.900000e-01 : f64) : f64 %15 = constant 7.9e-01 : f64 // CHECK-NEXT: %16 = llvm.shl %arg2, %arg3 : i32 @@ -502,9 +502,9 @@ // CHECK-NEXT: %18 = llvm.lshr %arg2, %arg3 : i32 %18 = shift_right_unsigned %arg2, %arg3 : i32 // CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg0) : (f32) -> f32 - %19 = std.sqrt %arg0 : f32 + %19 = math.sqrt %arg0 : f32 // CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg4) : (f64) -> f64 - %20 = std.sqrt %arg4 : f64 + %20 = math.sqrt %arg4 : f64 return %0, %4 : f32, i32 } diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir --- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir +++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir @@ -18,7 +18,7 @@ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (f32) -> f32 // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : f32 - %0 = rsqrt %arg0 : f32 + %0 = math.rsqrt %arg0 
: f32 std.return } @@ -28,7 +28,7 @@ // CHECK-SAME: f32 func @sine(%arg0 : f32) { // CHECK: "llvm.intr.sin"(%arg0) : (f32) -> f32 - %0 = sin %arg0 : f32 + %0 = math.sin %arg0 : f32 std.return } @@ -61,7 +61,7 @@ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f64) : f64 // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (f64) -> f64 // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : f64 - %0 = rsqrt %arg0 : f64 + %0 = math.rsqrt %arg0 : f64 std.return } @@ -73,7 +73,7 @@ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32> // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (vector<4xf32>) -> vector<4xf32> // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<4xf32> - %0 = rsqrt %arg0 : vector<4xf32> + %0 = math.rsqrt %arg0 : vector<4xf32> std.return } @@ -87,7 +87,7 @@ // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%[[EXTRACT]]) : (vector<3xf32>) -> vector<3xf32> // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<3xf32> // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[DIV]], %0[0] : !llvm.array<4 x vector<3xf32>> - %0 = rsqrt %arg0 : vector<4x3xf32> + %0 = math.rsqrt %arg0 : vector<4x3xf32> std.return } @@ -220,6 +220,6 @@ // CHECK-SAME: f64 func @powf(%arg0 : f64) { // CHECK: %[[POWF:.*]] = "llvm.intr.pow"(%arg0, %arg0) : (f64, f64) -> f64 - %0 = std.powf %arg0, %arg0 : f64 + %0 = math.powf %arg0, %arg0 : f64 std.return } diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -48,21 +48,21 @@ // CHECK: spv.GLSL.Ceil %{{.*}}: f32 %1 = ceilf %arg0 : f32 // CHECK: spv.GLSL.Cos %{{.*}}: f32 - %2 = cos %arg0 : f32 + %2 = math.cos %arg0 : f32 // CHECK: spv.GLSL.Exp %{{.*}}: f32 - %3 = exp %arg0 : f32 + %3 = math.exp %arg0 : f32 // CHECK: spv.GLSL.Log %{{.*}}: f32 - %4 = log %arg0 : f32 + %4 = 
math.log %arg0 : f32 // CHECK: spv.FNegate %{{.*}}: f32 %5 = negf %arg0 : f32 // CHECK: spv.GLSL.InverseSqrt %{{.*}}: f32 - %6 = rsqrt %arg0 : f32 + %6 = math.rsqrt %arg0 : f32 // CHECK: spv.GLSL.Sqrt %{{.*}}: f32 - %7 = sqrt %arg0 : f32 + %7 = math.sqrt %arg0 : f32 // CHECK: spv.GLSL.Tanh %{{.*}}: f32 - %8 = tanh %arg0 : f32 + %8 = math.tanh %arg0 : f32 // CHECK: spv.GLSL.Sin %{{.*}}: f32 - %9 = sin %arg0 : f32 + %9 = math.sin %arg0 : f32 // CHECK: spv.GLSL.Floor %{{.*}}: f32 %10 = floorf %arg0 : f32 return diff --git a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir @@ -22,7 +22,7 @@ // CHECK: %[[T2:.*]] = muli %[[ARG4]], %[[ARG2]] // CHECK: %[[T3:.*]] = addi %[[T2]], %[[C2]] // CHECK: %[[LOADVAL:.*]] = load %[[ARG0]][%[[T1]], %[[T3]]] - // CHECK: %[[STOREVAL:.*]] = sqrt %[[LOADVAL]] + // CHECK: %[[STOREVAL:.*]] = math.sqrt %[[LOADVAL]] // CHECK: %[[T6:.*]] = muli %[[ARG3]], %[[C3]] // CHECK: %[[T7:.*]] = addi %[[ARG1]], %[[T6]] // CHECK: %[[T8:.*]] = muli %[[ARG4]], %[[ARG2]] @@ -30,7 +30,7 @@ // CHECK: store %[[STOREVAL]], %[[ARG0]][%[[T7]], %[[T9]]] %0 = subview %arg0[%arg1, 2][4, 4][3, %arg2] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [96, ?]> %1 = load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [96, ?]> - %2 = sqrt %1 : f32 + %2 = math.sqrt %1 : f32 store %2, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [96, ?]> return } diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir --- a/mlir/test/Dialect/Linalg/bufferize.mlir +++ b/mlir/test/Dialect/Linalg/bufferize.mlir @@ -17,7 +17,7 @@ // CHECK-SAME: ins(%[[MEMREF]] : memref<4xf32>) // CHECK-SAME: outs(%[[RESULT_MEMREF]] : memref<4xf32>) { // CHECK: ^bb0(%[[RESULT1:.*]]: f32, %[[UNUSED:.*]]: f32): -// CHECK: %[[DIM1:.*]] = exp %[[RESULT1]] 
: f32 +// CHECK: %[[DIM1:.*]] = math.exp %[[RESULT1]] : f32 // CHECK: linalg.yield %[[DIM1]] : f32 // CHECK: } // CHECK: %[[RESULT:.*]] = tensor_load %[[RESULT_MEMREF]] : memref<4xf32> @@ -29,7 +29,7 @@ } ins(%arg0 : tensor<4xf32>) outs(%arg0 : tensor<4xf32>) { ^bb0(%gen_arg1: f32, %out: f32): - %tmp1 = exp %gen_arg1 : f32 + %tmp1 = math.exp %gen_arg1 : f32 linalg.yield %tmp1 : f32 } -> tensor<4xf32> return %0 : tensor<4xf32> @@ -58,7 +58,7 @@ } ins(%in : tensor) outs(%init : tensor) { ^bb0(%gen_arg1: f32, %out: f32): - %tmp1 = exp %gen_arg1 : f32 + %tmp1 = math.exp %gen_arg1 : f32 linalg.yield %tmp1 : f32 } -> tensor return %0 : tensor @@ -83,7 +83,7 @@ } ins(%arg0 : tensor<4xf32>) outs (%arg0, %arg0 : tensor<4xf32>, tensor<4xf32>) { ^bb0(%gen_arg1: f32, %out1: f32, %out2: f32): - %tmp1 = exp %gen_arg1 : f32 + %tmp1 = math.exp %gen_arg1 : f32 linalg.yield %tmp1, %tmp1 : f32, f32 } -> tensor<4xf32>, tensor<4xf32> return %0, %1 : tensor<4xf32>, tensor<4xf32> @@ -142,7 +142,7 @@ } ins(%arg0 : tensor) outs (%arg0, %arg0 : tensor, tensor) { ^bb0(%gen_arg1: f32, %out1: f32, %out2: f32): - %tmp1 = exp %gen_arg1 : f32 + %tmp1 = math.exp %gen_arg1 : f32 linalg.yield %tmp1, %tmp1 : f32, f32 } -> tensor, tensor return %0, %1 : tensor, tensor diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir --- a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir +++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir @@ -44,9 +44,9 @@ // CHECK-SAME: ins(%[[ARG0]] // CHECK-SAME: outs(%[[ARG0]] // CHECK: ^bb0(%[[SCALAR:.*]]: f32, %{{.*}}: f32): - // CHECK: %[[YIELD:.*]] = exp %[[SCALAR]] : f32 + // CHECK: %[[YIELD:.*]] = math.exp %[[SCALAR]] : f32 // CHECK: linalg.yield %[[YIELD]] : f32 - %0 = exp %arg0 : tensor + %0 = math.exp %arg0 : tensor return %0 : tensor } diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir --- 
a/mlir/test/Dialect/Linalg/fusion.mlir +++ b/mlir/test/Dialect/Linalg/fusion.mlir @@ -637,7 +637,7 @@ ins(%6 : memref) outs(%7 : memref) { ^bb0(%arg3: f32, %arg4: f32): // no predecessors - %8 = exp %arg3 : f32 + %8 = math.exp %arg3 : f32 linalg.yield %8 : f32 } } diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -223,19 +223,19 @@ %8 = constant 2.0 : f32 // CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32> %9 = divf %arg5, %i : f32 - // CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32> - %10 = exp2 %arg5 : f32 + // CHECK: %[[EXP:.*]] = math.exp2 %[[V3]] : vector<4x256xf32> + %10 = math.exp2 %arg5 : f32 // CHECK: %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32> %11 = mulf %arg5, %8 : f32 - // CHECK: %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32> - %12 = rsqrt %arg5 : f32 + // CHECK: %[[RSQRT:.*]] = math.rsqrt %[[V3]] : vector<4x256xf32> + %12 = math.rsqrt %arg5 : f32 // CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32> %13 = select %7, %arg5, %arg6 : f32 // CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32> // CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32> %14 = subf %arg5, %arg4 : f32 - // CHECK: %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32> - %15 = tanh %arg5 : f32 + // CHECK: %[[TAN:.*]] = math.tanh %[[V3]] : vector<4x256xf32> + %15 = math.tanh %arg5 : f32 // CHECK: vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> // CHECK: vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> // CHECK: vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32> @@ -304,19 +304,19 @@ %8 = constant 2.0 : f32 // CHECK: %[[DIV:.*]] = divf %[[V3]], 
%[[ARG3B]] : vector<4x256xf32> %9 = divf %arg5, %i : f32 - // CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32> - %10 = exp2 %arg5 : f32 + // CHECK: %[[EXP:.*]] = math.exp2 %[[V3]] : vector<4x256xf32> + %10 = math.exp2 %arg5 : f32 // CHECK: %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32> %11 = mulf %arg5, %8 : f32 - // CHECK: %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32> - %12 = rsqrt %arg5 : f32 + // CHECK: %[[RSQRT:.*]] = math.rsqrt %[[V3]] : vector<4x256xf32> + %12 = math.rsqrt %arg5 : f32 // CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32> %13 = select %7, %arg5, %arg6 : f32 // CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32> // CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32> %14 = subf %arg5, %arg4 : f32 - // CHECK: %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32> - %15 = tanh %arg5 : f32 + // CHECK: %[[TAN:.*]] = math.tanh %[[V3]] : vector<4x256xf32> + %15 = math.tanh %arg5 : f32 // CHECK: %[[R0:.*]] = vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> // CHECK: %[[R1:.*]] = vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> // CHECK: %[[R2:.*]] = vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Math/ops.mlir @@ -0,0 +1,172 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s +// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s + +// CHECK-LABEL: func @atan( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func @atan(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.atan %[[F]] : f32 + %0 = math.atan %f : f32 + // CHECK: %{{.*}} = math.atan %[[V]] : vector<4xf32> + 
%1 = math.atan %v : vector<4xf32> + // CHECK: %{{.*}} = math.atan %[[T]] : tensor<4x4x?xf32> + %2 = math.atan %t : tensor<4x4x?xf32> + return +} + + +// CHECK-LABEL: func @atan2( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func @atan2(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.atan2 %[[F]], %[[F]] : f32 + %0 = math.atan2 %f, %f : f32 + // CHECK: %{{.*}} = math.atan2 %[[V]], %[[V]] : vector<4xf32> + %1 = math.atan2 %v, %v : vector<4xf32> + // CHECK: %{{.*}} = math.atan2 %[[T]], %[[T]] : tensor<4x4x?xf32> + %2 = math.atan2 %t, %t : tensor<4x4x?xf32> + return +} + +// CHECK-LABEL: func @cos( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func @cos(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.cos %[[F]] : f32 + %0 = math.cos %f : f32 + // CHECK: %{{.*}} = math.cos %[[V]] : vector<4xf32> + %1 = math.cos %v : vector<4xf32> + // CHECK: %{{.*}} = math.cos %[[T]] : tensor<4x4x?xf32> + %2 = math.cos %t : tensor<4x4x?xf32> + return +} + +// CHECK-LABEL: func @sin( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func @sin(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.sin %[[F]] : f32 + %0 = math.sin %f : f32 + // CHECK: %{{.*}} = math.sin %[[V]] : vector<4xf32> + %1 = math.sin %v : vector<4xf32> + // CHECK: %{{.*}} = math.sin %[[T]] : tensor<4x4x?xf32> + %2 = math.sin %t : tensor<4x4x?xf32> + return +} + +// CHECK-LABEL: func @exp( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func @exp(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.exp %[[F]] : f32 + %0 = math.exp %f : f32 + // CHECK: %{{.*}} = math.exp %[[V]] : vector<4xf32> + %1 = math.exp %v : vector<4xf32> + // CHECK: %{{.*}} = math.exp %[[T]] : tensor<4x4x?xf32> + %2 = math.exp %t : tensor<4x4x?xf32> + return +} + 
+// CHECK-LABEL: func @exp2( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func @exp2(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.exp2 %[[F]] : f32 + %0 = math.exp2 %f : f32 + // CHECK: %{{.*}} = math.exp2 %[[V]] : vector<4xf32> + %1 = math.exp2 %v : vector<4xf32> + // CHECK: %{{.*}} = math.exp2 %[[T]] : tensor<4x4x?xf32> + %2 = math.exp2 %t : tensor<4x4x?xf32> + return +} + +// CHECK-LABEL: func @log( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func @log(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.log %[[F]] : f32 + %0 = math.log %f : f32 + // CHECK: %{{.*}} = math.log %[[V]] : vector<4xf32> + %1 = math.log %v : vector<4xf32> + // CHECK: %{{.*}} = math.log %[[T]] : tensor<4x4x?xf32> + %2 = math.log %t : tensor<4x4x?xf32> + return +} + +// CHECK-LABEL: func @log10( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func @log10(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.log10 %[[F]] : f32 + %0 = math.log10 %f : f32 + // CHECK: %{{.*}} = math.log10 %[[V]] : vector<4xf32> + %1 = math.log10 %v : vector<4xf32> + // CHECK: %{{.*}} = math.log10 %[[T]] : tensor<4x4x?xf32> + %2 = math.log10 %t : tensor<4x4x?xf32> + return +} + +// CHECK-LABEL: func @log1p( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func @log1p(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.log1p %[[F]] : f32 + %0 = math.log1p %f : f32 + // CHECK: %{{.*}} = math.log1p %[[V]] : vector<4xf32> + %1 = math.log1p %v : vector<4xf32> + // CHECK: %{{.*}} = math.log1p %[[T]] : tensor<4x4x?xf32> + %2 = math.log1p %t : tensor<4x4x?xf32> + return +} + +// CHECK-LABEL: func @log2( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func @log2(%f: f32, %v: vector<4xf32>, %t: 
tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.log2 %[[F]] : f32 + %0 = math.log2 %f : f32 + // CHECK: %{{.*}} = math.log2 %[[V]] : vector<4xf32> + %1 = math.log2 %v : vector<4xf32> + // CHECK: %{{.*}} = math.log2 %[[T]] : tensor<4x4x?xf32> + %2 = math.log2 %t : tensor<4x4x?xf32> + return +} + +// CHECK-LABEL: func @powf( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func @powf(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.powf %[[F]], %[[F]] : f32 + %0 = math.powf %f, %f : f32 + // CHECK: %{{.*}} = math.powf %[[V]], %[[V]] : vector<4xf32> + %1 = math.powf %v, %v : vector<4xf32> + // CHECK: %{{.*}} = math.powf %[[T]], %[[T]] : tensor<4x4x?xf32> + %2 = math.powf %t, %t : tensor<4x4x?xf32> + return +} + +// CHECK-LABEL: func @rsqrt( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func @rsqrt(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.rsqrt %[[F]] : f32 + %0 = math.rsqrt %f : f32 + // CHECK: %{{.*}} = math.rsqrt %[[V]] : vector<4xf32> + %1 = math.rsqrt %v : vector<4xf32> + // CHECK: %{{.*}} = math.rsqrt %[[T]] : tensor<4x4x?xf32> + %2 = math.rsqrt %t : tensor<4x4x?xf32> + return +} + + +// CHECK-LABEL: func @sqrt( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func @sqrt(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.sqrt %[[F]] : f32 + %0 = math.sqrt %f : f32 + // CHECK: %{{.*}} = math.sqrt %[[V]] : vector<4xf32> + %1 = math.sqrt %v : vector<4xf32> + // CHECK: %{{.*}} = math.sqrt %[[T]] : tensor<4x4x?xf32> + %2 = math.sqrt %t : tensor<4x4x?xf32> + return +} + +// CHECK-LABEL: func @tanh( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func @tanh(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.tanh %[[F]] : f32 + %0 = math.tanh %f : f32 + // CHECK: %{{.*}} = math.tanh 
%[[V]] : vector<4xf32> + %1 = math.tanh %v : vector<4xf32> + // CHECK: %{{.*}} = math.tanh %[[T]] : tensor<4x4x?xf32> + %2 = math.tanh %t : tensor<4x4x?xf32> + return +} diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir --- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir +++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir @@ -92,7 +92,7 @@ } scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) { %diff_elem = load %diff[%i, %j] : memref<100x10xf32> - %exp_elem = exp %diff_elem : f32 + %exp_elem = math.exp %diff_elem : f32 store %exp_elem, %result[%i, %j] : memref<100x10xf32> scf.yield } @@ -118,7 +118,7 @@ // CHECK: [[DIFF_ELEM:%.*]] = subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]] // CHECK: store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]] // CHECK: [[DIFF_ELEM_:%.*]] = load [[DIFF]]{{\[}}[[I]], [[J]]] -// CHECK: [[EXP_ELEM:%.*]] = exp [[DIFF_ELEM_]] +// CHECK: [[EXP_ELEM:%.*]] = math.exp [[DIFF_ELEM_]] // CHECK: store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]] // CHECK: scf.yield // CHECK: } diff --git a/mlir/test/Dialect/Standard/expand-tanh.mlir b/mlir/test/Dialect/Standard/expand-tanh.mlir --- a/mlir/test/Dialect/Standard/expand-tanh.mlir +++ b/mlir/test/Dialect/Standard/expand-tanh.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: func @tanh func @tanh(%arg: f32) -> f32 { - %res = tanh %arg : f32 + %res = math.tanh %arg : f32 return %res : f32 } // CHECK-DAG: %[[ZERO:.+]] = constant 0.000000e+00 : f32 @@ -10,11 +10,11 @@ // CHECK-DAG: %[[TWO:.+]] = constant 2.000000e+00 : f32 // CHECK: %[[DOUBLEDX:.+]] = mulf %arg0, %[[TWO]] : f32 // CHECK: %[[NEGDOUBLEDX:.+]] = negf %[[DOUBLEDX]] : f32 -// CHECK: %[[EXP1:.+]] = exp %[[NEGDOUBLEDX]] : f32 +// CHECK: %[[EXP1:.+]] = math.exp %[[NEGDOUBLEDX]] : f32 // CHECK: %[[DIVIDEND1:.+]] = subf %[[ONE]], %[[EXP1]] : f32 // CHECK: %[[DIVISOR1:.+]] = addf %[[ONE]], %[[EXP1]] : f32 // CHECK: %[[RES1:.+]] = divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32 -// CHECK: 
%[[EXP2:.+]] = exp %[[DOUBLEDX]] : f32 +// CHECK: %[[EXP2:.+]] = math.exp %[[DOUBLEDX]] : f32 // CHECK: %[[DIVIDEND2:.+]] = subf %[[EXP2]], %[[ONE]] : f32 // CHECK: %[[DIVISOR2:.+]] = addf %[[EXP2]], %[[ONE]] : f32 // CHECK: %[[RES2:.+]] = divf %[[DIVIDEND2]], %[[DIVISOR2]] : f32 diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir --- a/mlir/test/Dialect/Standard/ops.mlir +++ b/mlir/test/Dialect/Standard/ops.mlir @@ -34,13 +34,13 @@ // CHECK-LABEL: @atan func @atan(%arg : f32) -> f32 { - %result = atan %arg : f32 + %result = math.atan %arg : f32 return %result : f32 } // CHECK-LABEL: @atan2 func @atan2(%arg0 : f32, %arg1 : f32) -> f32 { - %result = atan2 %arg0, %arg1 : f32 + %result = math.atan2 %arg0, %arg1 : f32 return %result : f32 } diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Affine/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/EDSC/Builders.h" #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/EDSC/Intrinsics.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/Dialect/Vector/EDSC/Intrinsics.h" @@ -42,6 +43,7 @@ context.loadDialect(); // clang-format on diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -90,9 +90,6 @@ // CHECK: %[[I6:.*]] = muli %[[I2]], %[[I2]] : i32 %i6 = muli %i2, %i2 : i32 - // CHECK: %[[F7:.*]] = powf %[[F2]], %[[F2]] : f32 - %f7 = powf %f2, %f2 : f32 - // CHECK: %c42_i32 = constant 42 : i32 %x = "std.constant"(){value = 42 : i32} : () -> i32 @@ -372,18 +369,6 @@ // CHECK: = fptrunc {{.*}} : f32 to f16 %95 = fptrunc %f : f32 to f16 - // CHECK: %{{[0-9]+}} = exp %arg1 : f32 - %96 = "std.exp"(%f) : (f32) -> f32 - - // CHECK: %{{[0-9]+}} = exp %arg1 : f32 - %97 = exp %f : 
f32 - - // CHECK: %{{[0-9]+}} = exp %cst_8 : vector<4xf32> - %98 = exp %vcf32 : vector<4xf32> - - // CHECK: %{{[0-9]+}} = exp %arg0 : tensor<4x4x?xf32> - %99 = exp %t : tensor<4x4x?xf32> - // CHECK: %{{[0-9]+}} = absf %arg1 : f32 %100 = "std.absf"(%f) : (f32) -> f32 @@ -408,18 +393,6 @@ // CHECK: %{{[0-9]+}} = ceilf %arg0 : tensor<4x4x?xf32> %107 = ceilf %t : tensor<4x4x?xf32> - // CHECK: %{{[0-9]+}} = cos %arg1 : f32 - %108 = "std.cos"(%f) : (f32) -> f32 - - // CHECK: %{{[0-9]+}} = cos %arg1 : f32 - %109 = cos %f : f32 - - // CHECK: %{{[0-9]+}} = cos %cst_8 : vector<4xf32> - %110 = cos %vcf32 : vector<4xf32> - - // CHECK: %{{[0-9]+}} = cos %arg0 : tensor<4x4x?xf32> - %111 = cos %t : tensor<4x4x?xf32> - // CHECK: %{{[0-9]+}} = negf %arg1 : f32 %112 = "std.negf"(%f) : (f32) -> f32 @@ -444,18 +417,6 @@ // CHECK: %{{[0-9]+}} = copysign %arg0, %arg0 : tensor<4x4x?xf32> %119 = copysign %t, %t : tensor<4x4x?xf32> - // CHECK: %{{[0-9]+}} = tanh %arg1 : f32 - %120 = "std.tanh"(%f) : (f32) -> f32 - - // CHECK: %{{[0-9]+}} = tanh %arg1 : f32 - %121 = tanh %f : f32 - - // CHECK: %{{[0-9]+}} = tanh %cst_8 : vector<4xf32> - %122 = tanh %vcf32 : vector<4xf32> - - // CHECK: %{{[0-9]+}} = tanh %arg0 : tensor<4x4x?xf32> - %123 = tanh %t : tensor<4x4x?xf32> - // CHECK: %{{[0-9]+}} = shift_left %arg2, %arg2 : i32 %124 = "std.shift_left"(%i, %i) : (i32, i32) -> i32 @@ -501,38 +462,14 @@ // CHECK: %{{[0-9]+}} = shift_right_unsigned %cst_4, %cst_4 : tensor<42xi32> %138 = shift_right_unsigned %tci32, %tci32 : tensor<42 x i32> - // CHECK: %{{[0-9]+}} = sqrt %arg1 : f32 - %139 = "std.sqrt"(%f) : (f32) -> f32 - - // CHECK: %{{[0-9]+}} = sqrt %arg1 : f32 - %140 = sqrt %f : f32 - - // CHECK: %{{[0-9]+}} = sqrt %cst_8 : vector<4xf32> - %141 = sqrt %vcf32 : vector<4xf32> - - // CHECK: %{{[0-9]+}} = sqrt %arg0 : tensor<4x4x?xf32> - %142 = sqrt %t : tensor<4x4x?xf32> - // CHECK: = fpext {{.*}} : vector<4xf32> to vector<4xf64> %143 = fpext %vcf32 : vector<4xf32> to vector<4xf64> // CHECK: = 
fptrunc {{.*}} : vector<4xf32> to vector<4xf16> %144 = fptrunc %vcf32 : vector<4xf32> to vector<4xf16> - // CHECK: %{{[0-9]+}} = rsqrt %arg1 : f32 - %145 = rsqrt %f : f32 - - // CHECK: %{{[0-9]+}} = sin %arg1 : f32 - %146 = "std.sin"(%f) : (f32) -> f32 - - // CHECK: %{{[0-9]+}} = sin %arg1 : f32 - %147 = sin %f : f32 - - // CHECK: %{{[0-9]+}} = sin %cst_8 : vector<4xf32> - %148 = sin %vcf32 : vector<4xf32> - - // CHECK: %{{[0-9]+}} = sin %arg0 : tensor<4x4x?xf32> - %149 = sin %t : tensor<4x4x?xf32> + // CHECK: %{{[0-9]+}} = math.rsqrt %arg1 : f32 + %145 = math.rsqrt %f : f32 // CHECK: = fptosi {{.*}} : f32 to i32 %159 = fptosi %f : f32 to i32 @@ -582,9 +519,6 @@ // CHECK: %{{[0-9]+}} = ceildivi_signed %cst_4, %cst_4 : tensor<42xi32> %174 = ceildivi_signed %tci32, %tci32 : tensor<42 x i32> - // CHECK: %{{[0-9]+}} = log1p %arg1 : f32 - %175 = log1p %f : f32 - return } diff --git a/mlir/test/Transforms/buffer-deallocation.mlir b/mlir/test/Transforms/buffer-deallocation.mlir --- a/mlir/test/Transforms/buffer-deallocation.mlir +++ b/mlir/test/Transforms/buffer-deallocation.mlir @@ -535,7 +535,7 @@ ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %1 = alloc() : memref<2xf32> test.buffer_based in(%arg1: memref<2xf32>) out(%1: memref<2xf32>) - %tmp1 = exp %gen1_arg0 : f32 + %tmp1 = math.exp %gen1_arg0 : f32 test.region_yield %tmp1 : f32 } br ^bb3(%0 : memref<2xf32>) @@ -553,7 +553,7 @@ // CHECK: %[[ALLOC2:.*]] = alloc() // CHECK-NEXT: test.buffer_based in(%[[ARG1]]{{.*}}out(%[[ALLOC2]] // CHECK: dealloc %[[ALLOC2]] -// CHECK-NEXT: %{{.*}} = exp +// CHECK-NEXT: %{{.*}} = math.exp // CHECK: %[[ALLOC3:.*]] = alloc() // CHECK-NEXT: linalg.copy(%[[ALLOC1]], %[[ALLOC3]]) // CHECK-NEXT: dealloc %[[ALLOC1]] @@ -812,7 +812,7 @@ ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %1 = alloca() : memref<2xf32> test.buffer_based in(%arg1: memref<2xf32>) out(%1: memref<2xf32>) - %tmp1 = exp %gen1_arg0 : f32 + %tmp1 = math.exp %gen1_arg0 : f32 test.region_yield %tmp1 : f32 } br ^bb3(%0 : memref<2xf32>) @@ 
-830,7 +830,7 @@ // CHECK-NEXT: test.region_buffer_based in(%[[ARG1]]{{.*}}out(%[[ALLOC1]] // CHECK: %[[ALLOCA:.*]] = alloca() // CHECK-NEXT: test.buffer_based in(%[[ARG1]]{{.*}}out(%[[ALLOCA]] -// CHECK: %{{.*}} = exp +// CHECK: %{{.*}} = math.exp // CHECK: %[[ALLOC2:.*]] = alloc() // CHECK-NEXT: linalg.copy // CHECK-NEXT: dealloc %[[ALLOC1]] diff --git a/mlir/test/Transforms/buffer-hoisting.mlir b/mlir/test/Transforms/buffer-hoisting.mlir --- a/mlir/test/Transforms/buffer-hoisting.mlir +++ b/mlir/test/Transforms/buffer-hoisting.mlir @@ -360,7 +360,7 @@ ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %1 = alloc() : memref<2xf32> test.buffer_based in(%arg1: memref<2xf32>) out(%1: memref<2xf32>) - %tmp1 = exp %gen1_arg0 : f32 + %tmp1 = math.exp %gen1_arg0 : f32 test.region_yield %tmp1 : f32 } br ^bb3(%0 : memref<2xf32>) @@ -592,7 +592,7 @@ ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %1 = alloca() : memref<2xf32> test.buffer_based in(%arg1: memref<2xf32>) out(%1: memref<2xf32>) - %tmp1 = exp %gen1_arg0 : f32 + %tmp1 = math.exp %gen1_arg0 : f32 test.region_yield %tmp1 : f32 } br ^bb3(%0 : memref<2xf32>) diff --git a/mlir/test/Transforms/buffer-loop-hoisting.mlir b/mlir/test/Transforms/buffer-loop-hoisting.mlir --- a/mlir/test/Transforms/buffer-loop-hoisting.mlir +++ b/mlir/test/Transforms/buffer-loop-hoisting.mlir @@ -86,7 +86,7 @@ ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %1 = alloc() : memref<2xf32> test.buffer_based in(%arg1: memref<2xf32>) out(%1: memref<2xf32>) - %tmp1 = exp %gen1_arg0 : f32 + %tmp1 = math.exp %gen1_arg0 : f32 test.region_yield %tmp1 : f32 } br ^bb3(%0 : memref<2xf32>) diff --git a/mlir/test/Transforms/canonicalize-dce.mlir b/mlir/test/Transforms/canonicalize-dce.mlir --- a/mlir/test/Transforms/canonicalize-dce.mlir +++ b/mlir/test/Transforms/canonicalize-dce.mlir @@ -53,7 +53,7 @@ func @f(%arg0: f32) { br ^loop(%arg0: f32) ^loop(%0: f32): - %1 = "std.exp"(%0) : (f32) -> f32 + %1 = "math.exp"(%0) : (f32) -> f32 br ^loop(%1: f32) } @@ -65,7 +65,7 @@ // 
CHECK-NEXT: return func @f(%arg0: f32, %pred: i1) { - %exp = "std.exp"(%arg0) : (f32) -> f32 + %exp = "math.exp"(%arg0) : (f32) -> f32 cond_br %pred, ^true(%exp: f32), ^false(%exp: f32) ^true(%0: f32): return @@ -124,9 +124,9 @@ // CHECK-NEXT: "foo.return" func @f(%arg0: f32) { - %0 = "std.exp"(%arg0) : (f32) -> f32 + %0 = "math.exp"(%arg0) : (f32) -> f32 "foo.has_region"() ({ - %1 = "std.exp"(%0) : (f32) -> f32 + %1 = "math.exp"(%0) : (f32) -> f32 "foo.return"() : () -> () }) : () -> () return diff --git a/mlir/test/Transforms/copy-removal.mlir b/mlir/test/Transforms/copy-removal.mlir --- a/mlir/test/Transforms/copy-removal.mlir +++ b/mlir/test/Transforms/copy-removal.mlir @@ -174,7 +174,7 @@ ins(%temp : memref<5xf32>) outs(%res : memref<5xf32>) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): - %tmp1 = exp %gen1_arg0 : f32 + %tmp1 = math.exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 } dealloc %ret : memref<5xf32> @@ -253,7 +253,7 @@ ins(%arg0 : memref<2xf32>) outs(%temp : memref<2xf32>) { ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): - %tmp2 = exp %gen2_arg0 : f32 + %tmp2 = math.exp %gen2_arg0 : f32 linalg.yield %tmp2 : f32 } "linalg.copy"(%temp, %result) : (memref<2xf32>, memref<2xf32>) -> () @@ -279,7 +279,7 @@ ins(%arg0 : memref<2xf32>) outs(%temp : memref<2xf32>) { ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): - %tmp1 = exp %gen1_arg0 : f32 + %tmp1 = math.exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 } linalg.generic { @@ -288,7 +288,7 @@ ins(%arg0 : memref<2xf32>) outs(%to : memref<2xf32>) { ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): - %tmp2 = exp %gen2_arg0 : f32 + %tmp2 = math.exp %gen2_arg0 : f32 linalg.yield %tmp2 : f32 } // CHECK: linalg.copy diff --git a/mlir/test/Transforms/promote-buffers-to-stack.mlir b/mlir/test/Transforms/promote-buffers-to-stack.mlir --- a/mlir/test/Transforms/promote-buffers-to-stack.mlir +++ b/mlir/test/Transforms/promote-buffers-to-stack.mlir @@ -367,7 +367,7 @@ ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): %1 = alloc() : memref<2xf32> test.buffer_based 
in(%arg1: memref<2xf32>) out(%1: memref<2xf32>) - %tmp1 = exp %gen1_arg0 : f32 + %tmp1 = math.exp %gen1_arg0 : f32 test.region_yield %tmp1 : f32 } br ^bb3(%0 : memref<2xf32>) diff --git a/mlir/test/lib/Transforms/TestExpandTanh.cpp b/mlir/test/lib/Transforms/TestExpandTanh.cpp --- a/mlir/test/lib/Transforms/TestExpandTanh.cpp +++ b/mlir/test/lib/Transforms/TestExpandTanh.cpp @@ -1,4 +1,4 @@ -//===- TestExpandTanh.cpp - Test expand tanh op into exp form ------===// +//===- TestExpandTanh.cpp - Test expand tanh op into exp form -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -10,7 +10,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/StandardOps/Transforms/Passes.h" +#include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir --- a/mlir/test/mlir-opt/commandline.mlir +++ b/mlir/test/mlir-opt/commandline.mlir @@ -13,6 +13,7 @@ // CHECK-NEXT: llvm_arm_neon // CHECK-NEXT: llvm_arm_sve // CHECK-NEXT: llvm_avx512 +// CHECK-NEXT: math // CHECK-NEXT: nvvm // CHECK-NEXT: omp // CHECK-NEXT: pdl