diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td --- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -285,6 +285,39 @@ }]; } +//===----------------------------------------------------------------------===// +// ErfOp +//===----------------------------------------------------------------------===// + +def Math_ErfOp : Math_FloatUnaryOp<"erf"> { + let summary = "error function of the specified value"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `math.erf` ssa-use `:` type + ``` + + The `erf` operation computes the error function. It takes one operand + and returns one result of the same type. This type may be a float scalar + type, a vector whose element type is float, or a tensor of floats. It has + no standard attributes. + + Example: + + ```mlir + // Scalar error function value. + %a = math.erf %b : f64 + + // SIMD vector element-wise error function value. + %f = math.erf %g : vector<4xf32> + + // Tensor element-wise error function value. + %x = math.erf %y : tensor<4x?xf32> + ``` + }]; +} + //===----------------------------------------------------------------------===// // ExpOp diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Approximation.h b/mlir/include/mlir/Dialect/Math/Transforms/Approximation.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Math/Transforms/Approximation.h @@ -0,0 +1,29 @@ +//===- Approximation.h - Math dialect -----------------------------*- C++-*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_MATH_TRANSFORMATIONS_APPROXIMATION_H_ +#define MLIR_DIALECT_MATH_TRANSFORMATIONS_APPROXIMATION_H_ + +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace math { + +/// Lowers f32 scalar/vector `math.erf` to a piecewise rational approximation. +struct ErfPolynomialApproximation : public OpRewritePattern<math::ErfOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(math::ErfOp op, + PatternRewriter &rewriter) const final; +}; + +} // namespace math +} // namespace mlir + +#endif // MLIR_DIALECT_MATH_TRANSFORMATIONS_APPROXIMATION_H_ diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -116,6 +116,8 @@ VecOpToScalarOp>(patterns.getContext(), benefit); patterns.add>(patterns.getContext(), "atan2f", "atan2", benefit); + patterns.add>(patterns.getContext(), "erff", + "erf", benefit); patterns.add>(patterns.getContext(), "expm1f", "expm1", benefit); patterns.add>(patterns.getContext(), "tanhf", diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Math/Transforms/Approximation.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Dialect/X86Vector/X86VectorDialect.h" @@ -21,9 +22,12 @@ #include "mlir/Transforms/Bufferize.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include 
"llvm/ADT/ArrayRef.h" #include +#include using namespace mlir; +using namespace mlir::math; using namespace mlir::vector; using TypePredicate = llvm::function_ref; @@ -183,6 +187,24 @@ return exp2ValueF32; } +// Evaluates the polynomial sum_i coeffs[i] * x^i (constant term first) using +// Horner's scheme, emitting one math.fma per step. An empty coefficient list +// yields an f32 zero broadcast to the width of x; a single coefficient is +// returned unchanged. Assumes x and every coefficient share one f32 type. +static Value makePolynomialCalculation(ImplicitLocOpBuilder &builder, + llvm::ArrayRef<Value> coeffs, Value x) { + if (coeffs.empty()) + return broadcast(builder, f32Cst(builder, 0.0f), + *vectorWidth(x.getType(), isF32)); + if (coeffs.size() == 1) + return coeffs[0]; + Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1], + coeffs[coeffs.size() - 2]); + for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) + res = builder.create<math::FmaOp>(x, res, coeffs[i]); + return res; +} + //----------------------------------------------------------------------------// // TanhOp approximation. //----------------------------------------------------------------------------// @@ -465,6 +487,122 @@ return success(); } +//----------------------------------------------------------------------------// +// Erf approximation. +//----------------------------------------------------------------------------// + +// Approximates erf(|x|) with +// a + P(|x|)/Q(|x|) +// where P and Q are polynomials of degree 4; the sign of x is restored last. +// Coefficients and offsets a are picked per |x| interval; |x| >= 3.75 gives 1. +// The approximation error is ~2.5e-07. +// Boost's minimax tool that utilizes the Remez method was used to find the +// coefficients.
+LogicalResult +ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op, + PatternRewriter &rewriter) const { + auto width = vectorWidth(op.operand().getType(), isF32); + if (!width.hasValue()) + return rewriter.notifyMatchFailure(op, "unsupported operand type"); + // All constants below are f32 values broadcast to the operand's width. + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); + auto bcast = [&](Value value) -> Value { + return broadcast(builder, value, *width); + }; + // Each |x| interval has its own P and Q with polyDegree + 1 coefficients. + const int intervalsCount = 3; + const int polyDegree = 4; + // pp[j][k] is the coefficient of x^k in the numerator P of interval j. + Value zero = bcast(f32Cst(builder, 0)); + Value one = bcast(f32Cst(builder, 1)); + Value pp[intervalsCount][polyDegree + 1]; + pp[0][0] = bcast(f32Cst(builder, +0.00000000000000000e+00)); + pp[0][1] = bcast(f32Cst(builder, +1.12837916222975858e+00)); + pp[0][2] = bcast(f32Cst(builder, -5.23018562988006470e-01)); + pp[0][3] = bcast(f32Cst(builder, +2.09741709609267072e-01)); + pp[0][4] = bcast(f32Cst(builder, +2.58146801602987875e-02)); + pp[1][0] = bcast(f32Cst(builder, +0.00000000000000000e+00)); + pp[1][1] = bcast(f32Cst(builder, +1.12750687816789140e+00)); + pp[1][2] = bcast(f32Cst(builder, -3.64721408487825775e-01)); + pp[1][3] = bcast(f32Cst(builder, +1.18407396425136952e-01)); + pp[1][4] = bcast(f32Cst(builder, +3.70645533056476558e-02)); + pp[2][0] = bcast(f32Cst(builder, -3.30093071049483172e-03)); + pp[2][1] = bcast(f32Cst(builder, +3.51961938357697011e-03)); + pp[2][2] = bcast(f32Cst(builder, -1.41373622814988039e-03)); + pp[2][3] = bcast(f32Cst(builder, +2.53447094961941348e-04)); + pp[2][4] = bcast(f32Cst(builder, -1.71048029455037401e-05)); + // qq[j][k] is the coefficient of x^k in the denominator Q of interval j. + Value qq[intervalsCount][polyDegree + 1]; + qq[0][0] = bcast(f32Cst(builder, +1.000000000000000000e+00)); + qq[0][1] = bcast(f32Cst(builder, -4.635138185962547255e-01)); + qq[0][2] = bcast(f32Cst(builder, +5.192301327279782447e-01)); + qq[0][3] = bcast(f32Cst(builder, -1.318089722204810087e-01)); + qq[0][4] = bcast(f32Cst(builder, +7.397964654672315005e-02)); + qq[1][0] = bcast(f32Cst(builder, +1.00000000000000000e+00)); + qq[1][1] =
bcast(f32Cst(builder, -3.27607011824493086e-01)); + qq[1][2] = bcast(f32Cst(builder, +4.48369090658821977e-01)); + qq[1][3] = bcast(f32Cst(builder, -8.83462621207857930e-02)); + qq[1][4] = bcast(f32Cst(builder, +5.72442770283176093e-02)); + qq[2][0] = bcast(f32Cst(builder, +1.00000000000000000e+00)); + qq[2][1] = bcast(f32Cst(builder, -2.06069165953913769e+00)); + qq[2][2] = bcast(f32Cst(builder, +1.62705939945477759e+00)); + qq[2][3] = bcast(f32Cst(builder, -5.83389859211130017e-01)); + qq[2][4] = bcast(f32Cst(builder, +8.21908939856640930e-02)); + // offsets[j] is the additive term a in a + P(x)/Q(x) for interval j. + Value offsets[intervalsCount]; + offsets[0] = bcast(f32Cst(builder, 0)); + offsets[1] = bcast(f32Cst(builder, 0)); + offsets[2] = bcast(f32Cst(builder, 1)); + // bounds[j] is the exclusive upper end of interval j; past bounds[2] erf is 1. + Value bounds[intervalsCount]; + bounds[0] = bcast(f32Cst(builder, 0.8)); + bounds[1] = bcast(f32Cst(builder, 2)); + bounds[2] = bcast(f32Cst(builder, 3.75)); + // erf is odd, so work on x = |operand| and restore the sign at the end. + Value isNegativeArg = builder.create(arith::CmpFPredicate::OLT, + op.operand(), zero); + Value negArg = builder.create(op.operand()); + Value x = builder.create(isNegativeArg, negArg, op.operand()); + // Start with interval 0, then select later intervals as x passes each bound. + Value offset = offsets[0]; + Value p[polyDegree + 1]; + Value q[polyDegree + 1]; + for (int i = 0; i <= polyDegree; ++i) { + p[i] = pp[0][i]; + q[i] = qq[0][i]; + } + + // TODO: maybe use vector stacking to reduce the number of selects.
+ Value isLessThanBound[intervalsCount]; + for (int j = 0; j < intervalsCount - 1; ++j) { + isLessThanBound[j] = + builder.create(arith::CmpFPredicate::OLT, x, bounds[j]); + for (int i = 0; i <= polyDegree; ++i) { + p[i] = builder.create(isLessThanBound[j], p[i], pp[j + 1][i]); + q[i] = builder.create(isLessThanBound[j], q[i], qq[j + 1][i]); + } + offset = + builder.create(isLessThanBound[j], offset, offsets[j + 1]); + } + isLessThanBound[intervalsCount - 1] = builder.create( + arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]); + // The last compare is ULT (true for NaN), so NaN reaches the formula. + Value pPoly = makePolynomialCalculation(builder, p, x); + Value qPoly = makePolynomialCalculation(builder, q, x); + Value rationalPoly = builder.create(pPoly, qPoly); + Value formula = builder.create(offset, rationalPoly); + formula = builder.create(isLessThanBound[intervalsCount - 1], + formula, one); + + // erf is an odd function: erf(-x) = -erf(x), so negate for negative inputs. + Value negFormula = builder.create(formula); + Value res = builder.create(isNegativeArg, negFormula, formula); + + rewriter.replaceOp(op, res); + + return success(); +} + //----------------------------------------------------------------------------// // Exp approximation.
//----------------------------------------------------------------------------// @@ -848,8 +986,8 @@ RewritePatternSet &patterns, const MathPolynomialApproximationOptions &options) { patterns.add, + Log1pApproximation, ErfPolynomialApproximation, ExpApproximation, + ExpM1Approximation, SinAndCosApproximation, SinAndCosApproximation>( patterns.getContext()); if (options.enableAvx2) diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir --- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir +++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir @@ -1,5 +1,7 @@ // RUN: mlir-opt %s -convert-math-to-libm -canonicalize | FileCheck %s +// CHECK-DAG: @erf(f64) -> f64 +// CHECK-DAG: @erff(f32) -> f32 // CHECK-DAG: @expm1(f64) -> f64 // CHECK-DAG: @expm1f(f32) -> f32 // CHECK-DAG: @atan2(f64, f64) -> f64 @@ -32,6 +34,18 @@ return %float_result, %double_result : f32, f64 } +// CHECK-LABEL: func @erf_caller +// CHECK-SAME: %[[FLOAT:.*]]: f32 +// CHECK-SAME: %[[DOUBLE:.*]]: f64 +func @erf_caller(%float: f32, %double: f64) -> (f32, f64) { + // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @erff(%[[FLOAT]]) : (f32) -> f32 + %float_result = math.erf %float : f32 + // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @erf(%[[DOUBLE]]) : (f64) -> f64 + %double_result = math.erf %double : f64 + // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] + return %float_result, %double_result : f32, f64 +} + // CHECK-LABEL: func @expm1_caller // CHECK-SAME: %[[FLOAT:.*]]: f32 // CHECK-SAME: %[[DOUBLE:.*]]: f64 diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir --- a/mlir/test/Dialect/Math/ops.mlir +++ b/mlir/test/Dialect/Math/ops.mlir @@ -50,6 +50,18 @@ return } +// CHECK-LABEL: func @erf( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func @erf(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.erf %[[F]] : f32 + %0 = math.erf %f : f32
+ // CHECK: %{{.*}} = math.erf %[[V]] : vector<4xf32> + %1 = math.erf %v : vector<4xf32> + // CHECK: %{{.*}} = math.erf %[[T]] : tensor<4x4x?xf32> + %2 = math.erf %t : tensor<4x4x?xf32> + return +} + // CHECK-LABEL: func @exp( // CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) func @exp(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir --- a/mlir/test/Dialect/Math/polynomial-approximation.mlir +++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir @@ -5,6 +5,95 @@ // Check that all math functions lowered to approximations built from // standard operations (add, mul, fma, shift, etc...). +// CHECK-LABEL: func @erf_scalar( +// CHECK-SAME: %[[val_arg0:.*]]: f32) -> f32 { +// CHECK-DAG: %[[val_cst:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[val_cst_0:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[val_cst_1:.*]] = arith.constant 1.12837911 : f32 +// CHECK-DAG: %[[val_cst_2:.*]] = arith.constant -0.523018539 : f32 +// CHECK-DAG: %[[val_cst_3:.*]] = arith.constant 0.209741712 : f32 +// CHECK-DAG: %[[val_cst_4:.*]] = arith.constant 0.0258146804 : f32 +// CHECK-DAG: %[[val_cst_5:.*]] = arith.constant 1.12750685 : f32 +// CHECK-DAG: %[[val_cst_6:.*]] = arith.constant -0.364721417 : f32 +// CHECK-DAG: %[[val_cst_7:.*]] = arith.constant 0.118407398 : f32 +// CHECK-DAG: %[[val_cst_8:.*]] = arith.constant 0.0370645523 : f32 +// CHECK-DAG: %[[val_cst_9:.*]] = arith.constant -0.00330093061 : f32 +// CHECK-DAG: %[[val_cst_10:.*]] = arith.constant 0.00351961935 : f32 +// CHECK-DAG: %[[val_cst_11:.*]] = arith.constant -0.00141373626 : f32 +// CHECK-DAG: %[[val_cst_12:.*]] = arith.constant 2.53447099E-4 : f32 +// CHECK-DAG: %[[val_cst_13:.*]] = arith.constant -1.71048032E-5 : f32 +// CHECK-DAG: %[[val_cst_14:.*]] = arith.constant -0.463513821 : f32 +// CHECK-DAG: %[[val_cst_15:.*]] =
arith.constant 0.519230127 : f32 +// CHECK-DAG: %[[val_cst_16:.*]] = arith.constant -0.131808966 : f32 +// CHECK-DAG: %[[val_cst_17:.*]] = arith.constant 0.0739796459 : f32 +// CHECK-DAG: %[[val_cst_18:.*]] = arith.constant -3.276070e-01 : f32 +// CHECK-DAG: %[[val_cst_19:.*]] = arith.constant 0.448369086 : f32 +// CHECK-DAG: %[[val_cst_20:.*]] = arith.constant -0.0883462652 : f32 +// CHECK-DAG: %[[val_cst_21:.*]] = arith.constant 0.0572442785 : f32 +// CHECK-DAG: %[[val_cst_22:.*]] = arith.constant -2.0606916 : f32 +// CHECK-DAG: %[[val_cst_23:.*]] = arith.constant 1.62705934 : f32 +// CHECK-DAG: %[[val_cst_24:.*]] = arith.constant -0.583389878 : f32 +// CHECK-DAG: %[[val_cst_25:.*]] = arith.constant 0.0821908935 : f32 +// CHECK-DAG: %[[val_cst_26:.*]] = arith.constant 8.000000e-01 : f32 +// CHECK-DAG: %[[val_cst_27:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[val_cst_28:.*]] = arith.constant 3.750000e+00 : f32 +// CHECK: %[[val_0:.*]] = arith.cmpf olt, %[[val_arg0]], %[[val_cst]] : f32 +// CHECK: %[[val_1:.*]] = arith.negf %[[val_arg0]] : f32 +// CHECK: %[[val_2:.*]] = select %[[val_0]], %[[val_1]], %[[val_arg0]] : f32 +// CHECK: %[[val_3:.*]] = arith.cmpf olt, %[[val_2]], %[[val_cst_26]] : f32 +// CHECK: %[[val_4:.*]] = select %[[val_3]], %[[val_cst_1]], %[[val_cst_5]] : f32 +// CHECK: %[[val_5:.*]] = select %[[val_3]], %[[val_cst_14]], %[[val_cst_18]] : f32 +// CHECK: %[[val_6:.*]] = select %[[val_3]], %[[val_cst_2]], %[[val_cst_6]] : f32 +// CHECK: %[[val_7:.*]] = select %[[val_3]], %[[val_cst_15]], %[[val_cst_19]] : f32 +// CHECK: %[[val_8:.*]] = select %[[val_3]], %[[val_cst_3]], %[[val_cst_7]] : f32 +// CHECK: %[[val_9:.*]] = select %[[val_3]], %[[val_cst_16]], %[[val_cst_20]] : f32 +// CHECK: %[[val_10:.*]] = select %[[val_3]], %[[val_cst_4]], %[[val_cst_8]] : f32 +// CHECK: %[[val_11:.*]] = select %[[val_3]], %[[val_cst_17]], %[[val_cst_21]] : f32 +// CHECK: %[[val_12:.*]] = arith.cmpf olt, %[[val_2]], %[[val_cst_27]] : f32 +// CHECK:
%[[val_13:.*]] = select %[[val_12]], %[[val_cst]], %[[val_cst_9]] : f32 +// CHECK: %[[val_14:.*]] = select %[[val_12]], %[[val_4]], %[[val_cst_10]] : f32 +// CHECK: %[[val_15:.*]] = select %[[val_12]], %[[val_5]], %[[val_cst_22]] : f32 +// CHECK: %[[val_16:.*]] = select %[[val_12]], %[[val_6]], %[[val_cst_11]] : f32 +// CHECK: %[[val_17:.*]] = select %[[val_12]], %[[val_7]], %[[val_cst_23]] : f32 +// CHECK: %[[val_18:.*]] = select %[[val_12]], %[[val_8]], %[[val_cst_12]] : f32 +// CHECK: %[[val_19:.*]] = select %[[val_12]], %[[val_9]], %[[val_cst_24]] : f32 +// CHECK: %[[val_20:.*]] = select %[[val_12]], %[[val_10]], %[[val_cst_13]] : f32 +// CHECK: %[[val_21:.*]] = select %[[val_12]], %[[val_11]], %[[val_cst_25]] : f32 +// CHECK: %[[val_22:.*]] = select %[[val_12]], %[[val_cst]], %[[val_cst_0]] : f32 +// CHECK: %[[val_23:.*]] = arith.cmpf ult, %[[val_2]], %[[val_cst_28]] : f32 +// CHECK: %[[val_24:.*]] = math.fma %[[val_2]], %[[val_20]], %[[val_18]] : f32 +// CHECK: %[[val_25:.*]] = math.fma %[[val_2]], %[[val_24]], %[[val_16]] : f32 +// CHECK: %[[val_26:.*]] = math.fma %[[val_2]], %[[val_25]], %[[val_14]] : f32 +// CHECK: %[[val_27:.*]] = math.fma %[[val_2]], %[[val_26]], %[[val_13]] : f32 +// CHECK: %[[val_28:.*]] = math.fma %[[val_2]], %[[val_21]], %[[val_19]] : f32 +// CHECK: %[[val_29:.*]] = math.fma %[[val_2]], %[[val_28]], %[[val_17]] : f32 +// CHECK: %[[val_30:.*]] = math.fma %[[val_2]], %[[val_29]], %[[val_15]] : f32 +// CHECK: %[[val_31:.*]] = math.fma %[[val_2]], %[[val_30]], %[[val_cst_0]] : f32 +// CHECK: %[[val_32:.*]] = arith.divf %[[val_27]], %[[val_31]] : f32 +// CHECK: %[[val_33:.*]] = arith.addf %[[val_22]], %[[val_32]] : f32 +// CHECK: %[[val_34:.*]] = select %[[val_23]], %[[val_33]], %[[val_cst_0]] : f32 +// CHECK: %[[val_35:.*]] = arith.negf %[[val_34]] : f32 +// CHECK: %[[val_36:.*]] = select %[[val_0]], %[[val_35]], %[[val_34]] : f32 +// CHECK: return %[[val_36]] : f32 +// CHECK: } +func @erf_scalar(%arg0: f32) -> f32 { + %0 = math.erf
%arg0 : f32 + return %0 : f32 +} +// As in the scalar case above, 21 selects: match 20, then bind the 21st. +// CHECK-LABEL: func @erf_vector( +// CHECK-SAME: %[[arg0:.*]]: vector<8xf32>) -> vector<8xf32> { +// CHECK: %[[zero:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32> +// CHECK-NOT: erf +// CHECK-COUNT-20: select +// CHECK: %[[res:.*]] = select +// CHECK: return %[[res]] : vector<8xf32> +// CHECK: } +func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> { + %0 = math.erf %arg0 : vector<8xf32> + return %0 : vector<8xf32> +} + // CHECK-LABEL: func @exp_scalar( // CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 { // CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0.693147182 : f32 diff --git a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir --- a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir +++ b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir @@ -152,6 +152,78 @@ return } +// -------------------------------------------------------------------------- // +// Erf. +// -------------------------------------------------------------------------- // +func @erf() { + // CHECK: -0.000274406 + %val1 = arith.constant -2.431864e-4 : f32 + %erfVal1 = math.erf %val1 : f32 + vector.print %erfVal1 : f32 + // Probe just below and exactly at each interval bound (0.8, 2, 3.75). + // CHECK: 0.742095 + %val2 = arith.constant 0.79999 : f32 + %erfVal2 = math.erf %val2 : f32 + vector.print %erfVal2 : f32 + + // CHECK: 0.742101 + %val3 = arith.constant 0.8 : f32 + %erfVal3 = math.erf %val3 : f32 + vector.print %erfVal3 : f32 + + // CHECK: 0.995322 + %val4 = arith.constant 1.99999 : f32 + %erfVal4 = math.erf %val4 : f32 + vector.print %erfVal4 : f32 + + // CHECK: 0.995322 + %val5 = arith.constant 2.0 : f32 + %erfVal5 = math.erf %val5 : f32 + vector.print %erfVal5 : f32 + + // CHECK: 1 + %val6 = arith.constant 3.74999 : f32 + %erfVal6 = math.erf %val6 : f32 + vector.print %erfVal6 : f32 + + // CHECK: 1 + %val7 = arith.constant 3.75 : f32 + %erfVal7 = math.erf %val7 : f32 + vector.print %erfVal7 : f32 + + // CHECK: -1 + %negativeInf = arith.constant
0xff800000 : f32 + %erfNegativeInf = math.erf %negativeInf : f32 + vector.print %erfNegativeInf : f32 + + // CHECK: -1, -1, -0.913759, -0.731446 + %vecVals1 = arith.constant dense<[-3.4028235e+38, -4.54318, -1.2130899, -7.8234202e-01]> : vector<4xf32> + %erfVecVals1 = math.erf %vecVals1 : vector<4xf32> + vector.print %erfVecVals1 : vector<4xf32> + + // CHECK: -1.3264e-38, 0, 1.3264e-38, 0.121319 + %vecVals2 = arith.constant dense<[-1.1754944e-38, 0.0, 1.1754944e-38, 1.0793410e-01]> : vector<4xf32> + %erfVecVals2 = math.erf %vecVals2 : vector<4xf32> + vector.print %erfVecVals2 : vector<4xf32> + + // CHECK: 0.919477, 0.999069, 1, 1 + %vecVals3 = arith.constant dense<[1.23578, 2.34093, 3.82342, 3.4028235e+38]> : vector<4xf32> + %erfVecVals3 = math.erf %vecVals3 : vector<4xf32> + vector.print %erfVecVals3 : vector<4xf32> + // Special values: positive infinity saturates to 1 and NaN propagates. + // CHECK: 1 + %inf = arith.constant 0x7f800000 : f32 + %erfInf = math.erf %inf : f32 + vector.print %erfInf : f32 + + // CHECK: nan + %nan = arith.constant 0x7fc00000 : f32 + %erfNan = math.erf %nan : f32 + vector.print %erfNan : f32 + + return +} + // -------------------------------------------------------------------------- // // Exp. // -------------------------------------------------------------------------- // @@ -305,6 +377,7 @@ call @log(): () -> () call @log2(): () -> () call @log1p(): () -> () + call @erf(): () -> () call @exp(): () -> () call @expm1(): () -> () call @sin(): () -> () diff --git a/mlir/utils/vim/syntax/mlir.vim b/mlir/utils/vim/syntax/mlir.vim --- a/mlir/utils/vim/syntax/mlir.vim +++ b/mlir/utils/vim/syntax/mlir.vim @@ -43,6 +43,9 @@ syn keyword mlirOps splat store select sqrt subf subi subview tanh syn keyword mlirOps view +" Math ops. +syn match mlirOps /\/ + " Affine ops. syn match mlirOps /\/ syn match mlirOps /\/