diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -455,6 +455,31 @@
   );
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: erf
+//===----------------------------------------------------------------------===//
+def Tosa_ErfOp : Tosa_Op<"erf", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    Pure]> {
+  let summary = "Computes gauss error function of input";
+
+  let description = [{
+    Gauss error function: $ erf(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2} \,dt $
+    For quantized integer data types, the TABLE operator should be used instead,
+    with an erf_table of 513 entries, each of 16-bit/8-bit precision, covering
+    the input range -4.0 to +4.0 in steps of 1/64.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor:$input
+  );
+
+  let results = (outs
+    Tosa_Tensor:$output
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Spec Section 2.4
 // Operator Class: Elementwise unary/binary/ternary operators.
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -299,6 +299,10 @@
   if (isa(op) && isa(elementTy))
     return rewriter.create(loc, resultTypes, args);
 
+  // tosa::ErfOp
+  if (isa<tosa::ErfOp>(op) && isa<FloatType>(elementTy))
+    return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
+
   // tosa::GreaterOp
   if (isa(op) && isa(elementTy))
     return rewriter.create(loc, arith::CmpFPredicate::OGT,
@@ -2044,6 +2048,7 @@
       PointwiseConverter,
       PointwiseConverter,
       PointwiseConverter,
+      PointwiseConverter<tosa::ErfOp>,
       PointwiseConverter,
       PointwiseConverter,
       PointwiseConverter,
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1058,6 +1058,7 @@
 NARY_SHAPE_INFER(tosa::SelectOp)
 NARY_SHAPE_INFER(tosa::SubOp)
 NARY_SHAPE_INFER(tosa::TanhOp)
+NARY_SHAPE_INFER(tosa::ErfOp)
 NARY_SHAPE_INFER(tosa::SigmoidOp)
 #undef PRED_SHAPE_INFER
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -258,6 +258,10 @@
   // CHECK: arith.divf
   %23 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32>
 
+  // CHECK: linalg.generic
+  // CHECK: math.erf
+  %24 = "tosa.erf"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+
   return
 }
 
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -114,6 +114,13 @@
   return %0 : tensor<13x21x3xf32>
 }
 
+// -----
+// CHECK-LABEL: erf
+func.func @test_erf(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  %0 = "tosa.erf"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
 // -----
 // CHECK-LABEL: add
func.func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -65,6 +65,9 @@ // CHECK: "tosa.cast"(%arg0) : (tensor<4xf32>) -> tensor<4xi32> %12 = "tosa.cast"(%arg0) : (tensor<4xf32>) -> tensor<*xi32> + + // CHECK: "tosa.erf"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %13 = "tosa.erf"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> return }