diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -115,6 +115,13 @@ return rewriter.create(loc, resultTypes, args); } + // tosa::ReciprocalOp + if (isa(op) && elementTy.isa()) { + auto one = + rewriter.create(loc, FloatAttr::get(elementTy, 1)); + return rewriter.create(loc, resultTypes, one, args[0]); + } + if (isa(op) && elementTy.isa()) { Value a = args[0]; Value b = args[1]; @@ -325,6 +332,16 @@ rewriter); } + // tosa::SigmoidOp + if (isa(op) && elementTy.isa()) { + auto one = + rewriter.create(loc, FloatAttr::get(elementTy, 1)); + auto negate = rewriter.create(loc, resultTypes, args[0]); + auto exp = rewriter.create(loc, resultTypes, negate); + auto added = rewriter.create(loc, resultTypes, exp, one); + return rewriter.create(loc, resultTypes, one, added); + } + // tosa::CastOp if (isa(op)) { Type srcTy = elementTy; @@ -1382,11 +1399,11 @@ RewritePatternSet *patterns) { patterns->add< PointwiseConverter, PointwiseConverter, - PointwiseConverter, PointwiseConverter, - PointwiseConverter, PointwiseConverter, - PointwiseConverter, PointwiseConverter, - PointwiseConverter, PointwiseConverter, - PointwiseConverter, + PointwiseConverter, PointwiseConverter, + PointwiseConverter, PointwiseConverter, + PointwiseConverter, PointwiseConverter, + PointwiseConverter, PointwiseConverter, + PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, @@ -1401,11 +1418,11 @@ PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, - IdentityNConverter, + PointwiseConverter, IdentityNConverter, IdentityNConverter, ReduceConverter, ReduceConverter, ReduceConverter, - ReduceConverter, ArgMaxConverter, ConcatConverter, PadConverter, - ReshapeConverter, RescaleConverter, ReverseConverter, TileConverter, - TransposeConverter, MatMulConverter, FullyConnectedConverter>( - patterns->getContext()); + ReduceConverter, ArgMaxConverter, ConcatConverter, + PadConverter, ReshapeConverter, RescaleConverter, ReverseConverter, + TileConverter, TransposeConverter, MatMulConverter, + FullyConnectedConverter>(patterns->getContext()); } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -180,22 +180,33 @@ // CHECK: select %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32> + // CHECK: linalg.generic + // CHECK: negf + // CHECK: exp + // CHECK: addf + // CHECK: divf + %19 = "tosa.sigmoid"(%0) : (tensor<1xf32>) -> tensor<1xf32> + // CHECK: linalg.generic // CHECK: fptosi - %19 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32> + %20 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: constant 0 // CHECK: cmpf - %20 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi1> + %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi1> // CHECK: linalg.generic // CHECK: fptrunc - %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16> + %22 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16> // CHECK: linalg.generic // CHECK: yield - %22 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf32> + %23 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf32> + + // CHECK: linalg.generic + // CHECK: divf + %24 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32> return }