diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -20,10 +20,19 @@ namespace mlir { class Pass; +class MLIRContext; +class OwningRewritePatternList; /// Creates an instance of the ExpandAtomic pass. std::unique_ptr<Pass> createExpandAtomicPass(); +/// Creates an instance of the ExpandTanh pass. +std::unique_ptr<Pass> createExpandTanhPass(); + +/// Adds the pattern that expands TanhOp into ExpOp form to `patterns`. +void populateExpandTanhPattern(OwningRewritePatternList &patterns, + MLIRContext *ctx); + } // end namespace mlir #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td @@ -16,4 +16,9 @@ let constructor = "mlir::createExpandAtomicPass()"; } +def ExpandTanh : FunctionPass<"expand-tanh"> { + let summary = "Expands TanhOp into ExpOp form."; + let constructor = "mlir::createExpandTanhPass()"; +} + #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" +#include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" @@ -377,6 +378,7 @@ populateAffineToStdConversionPatterns(patterns, &getContext()); populateLoopToStdConversionPatterns(patterns, &getContext()); populateStdToLLVMConversionPatterns(converter, patterns); + 
populateExpandTanhPattern(patterns, &getContext()); populateVectorToSCFConversionPatterns(patterns, &getContext()); populateVectorToLLVMMatrixConversionPatterns(converter, patterns); populateVectorToLLVMConversionPatterns(converter, patterns); diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms ExpandAtomic.cpp + ExpandTanh.cpp FuncConversions.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/test/Dialect/Standard/expand-tanh.mlir b/mlir/test/Dialect/Standard/expand-tanh.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Standard/expand-tanh.mlir @@ -0,0 +1,23 @@ +// RUN: mlir-opt %s -expand-tanh | FileCheck %s + +// CHECK-LABEL: func @tanh +func @tanh(%arg: f32) -> f32 { + %res = tanh %arg : f32 + return %res : f32 +} +// CHECK: %[[ONE:.+]] = constant 1.000000e+00 : f32 +// CHECK: %[[TWO:.+]] = constant 2.000000e+00 : f32 +// CHECK: %[[XX:.+]] = mulf %arg0, %[[TWO]] : f32 +// CHECK: %[[NEGXX:.+]] = negf %[[XX]] : f32 +// CHECK: %[[EXP1:.+]] = exp %[[NEGXX]] : f32 +// CHECK: %[[DIVIDEND1:.+]] = subf %[[ONE]], %[[EXP1]] : f32 +// CHECK: %[[DIVISOR1:.+]] = addf %[[ONE]], %[[EXP1]] : f32 +// CHECK: %[[RES1:.+]] = divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32 +// CHECK: %[[EXP2:.+]] = exp %[[XX]] : f32 +// CHECK: %[[DIVIDEND2:.+]] = subf %[[EXP2]], %[[ONE]] : f32 +// CHECK: %[[DIVISOR2:.+]] = addf %[[EXP2]], %[[ONE]] : f32 +// CHECK: %[[RES2:.+]] = divf %[[DIVIDEND2]], %[[DIVISOR2]] : f32 +// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32 +// CHECK: %[[COND:.+]] = cmpf "oge", %arg0, %[[ZERO]] : f32 +// CHECK: %[[RESULT:.+]] = select %[[COND]], %[[RES1]], %[[RES2]] : f32 +// CHECK: return %[[RESULT]]