diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -20,10 +20,17 @@
 namespace mlir {
 
 class Pass;
+class MLIRContext;
+class OwningRewritePatternList;
 
 /// Creates an instance of the ExpandAtomic pass.
 std::unique_ptr<Pass> createExpandAtomicPass();
 
+/// Adds to `patterns` a rewrite that expands scalar-float std.tanh ops into an
+/// exp-based formula selected by the sign of the operand.
+void populateExpandTanhPattern(OwningRewritePatternList &patterns,
+                               MLIRContext *ctx);
+
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRStandardOpsTransforms
   ExpandAtomic.cpp
+  ExpandTanh.cpp
   FuncConversions.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandTanh.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandTanh.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandTanh.cpp
@@ -0,0 +1,77 @@
+//===- ExpandTanh.cpp - Code to perform expanding tanh op -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements expansion of tanh op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Expands a scalar-float tanh op into
+///   1) (1 - exp(-2x)) / (1 + exp(-2x)), if x >= 0
+///   2) (exp(2x) - 1) / (exp(2x) + 1),   if x < 0
+/// Both forms are computed unconditionally and a select on `x >= 0` picks the
+/// result. The selected form only exponentiates a non-positive value, so its
+/// exp() cannot overflow; inf/NaN from the other form is discarded.
+struct TanhOpConverter : public OpRewritePattern<TanhOp> {
+  using OpRewritePattern<TanhOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TanhOp op,
+                                PatternRewriter &rewriter) const final {
+    Type floatType = op.operand().getType();
+    // getFloatAttr below requires a scalar FloatType; vector/tensor operands
+    // would need splat constants, so they are not matched.
+    if (!floatType.isa<FloatType>())
+      return failure();
+
+    Location loc = op.getLoc();
+    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+    auto floatTwo = rewriter.getFloatAttr(floatType, 2.0);
+    Value one = rewriter.create<ConstantOp>(loc, floatOne);
+    Value two = rewriter.create<ConstantOp>(loc, floatTwo);
+    Value doubledX = rewriter.create<MulFOp>(loc, op.operand(), two);
+
+    // Case 1: tanh(x) = (1 - exp(-2x)) / (1 + exp(-2x))
+    Value negDoubledX = rewriter.create<NegFOp>(loc, doubledX);
+    Value expNegDoubledX = rewriter.create<ExpOp>(loc, negDoubledX);
+    Value dividend = rewriter.create<SubFOp>(loc, one, expNegDoubledX);
+    Value divisor = rewriter.create<AddFOp>(loc, one, expNegDoubledX);
+    Value positiveRes = rewriter.create<DivFOp>(loc, dividend, divisor);
+
+    // Case 2: tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
+    Value expDoubledX = rewriter.create<ExpOp>(loc, doubledX);
+    dividend = rewriter.create<SubFOp>(loc, expDoubledX, one);
+    divisor = rewriter.create<AddFOp>(loc, expDoubledX, one);
+    Value negativeRes = rewriter.create<DivFOp>(loc, dividend, divisor);
+
+    // tanh(x) = x >= 0 ? positiveRes : negativeRes
+    auto floatZero = rewriter.getFloatAttr(floatType, 0.0);
+    Value zero = rewriter.create<ConstantOp>(loc, floatZero);
+    Value cmpRes =
+        rewriter.create<CmpFOp>(loc, CmpFPredicate::OGE, op.operand(), zero);
+    rewriter.replaceOpWithNewOp<SelectOp>(op, cmpRes, positiveRes, negativeRes);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateExpandTanhPattern(OwningRewritePatternList &patterns,
+                                     MLIRContext *ctx) {
+  patterns.insert<TanhOpConverter>(ctx);
+}
diff --git a/mlir/test/Dialect/Standard/expand-tanh.mlir b/mlir/test/Dialect/Standard/expand-tanh.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Standard/expand-tanh.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s -test-expand-tanh | FileCheck %s
+
+// CHECK-LABEL: func @tanh
+func @tanh(%arg: f32) -> f32 {
+  %res = tanh %arg : f32
+  return %res : f32
+}
+// CHECK-DAG: %[[ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK-DAG: %[[ONE:.+]] = constant 1.000000e+00 : f32
+// CHECK-DAG: %[[TWO:.+]] = constant 2.000000e+00 : f32
+// CHECK: %[[DOUBLEDX:.+]] = mulf %arg0, %[[TWO]] : f32
+// CHECK: %[[NEGDOUBLEDX:.+]] = negf %[[DOUBLEDX]] : f32
+// CHECK: %[[EXP1:.+]] = exp %[[NEGDOUBLEDX]] : f32
+// CHECK: %[[DIVIDEND1:.+]] = subf %[[ONE]], %[[EXP1]] : f32
+// CHECK: %[[DIVISOR1:.+]] = addf %[[ONE]], %[[EXP1]] : f32
+// CHECK: %[[RES1:.+]] = divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32
+// CHECK: %[[EXP2:.+]] = exp %[[DOUBLEDX]] : f32
+// CHECK: %[[DIVIDEND2:.+]] = subf %[[EXP2]], %[[ONE]] : f32
+// CHECK: %[[DIVISOR2:.+]] = addf %[[EXP2]], %[[ONE]] : f32
+// CHECK: %[[RES2:.+]] = divf %[[DIVIDEND2]], %[[DIVISOR2]] : f32
+// CHECK: %[[COND:.+]] = cmpf "oge", %arg0, %[[ZERO]] : f32
+// CHECK: %[[RESULT:.+]] = select %[[COND]], %[[RES1]], %[[RES2]] : f32
+// CHECK: return %[[RESULT]]
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
 add_mlir_library(MLIRTestTransforms
   TestAllReduceLowering.cpp
   TestBufferPlacement.cpp
+  TestExpandTanh.cpp
   TestCallGraph.cpp
   TestConstantFold.cpp
   TestConvertGPUKernelToCubin.cpp
diff --git a/mlir/test/lib/Transforms/TestExpandTanh.cpp b/mlir/test/lib/Transforms/TestExpandTanh.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestExpandTanh.cpp
@@ -0,0 +1,39 @@
+//===- TestExpandTanh.cpp - Test expand tanh op into exp form -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a test pass for expanding tanh.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Applies the tanh expansion pattern greedily to the ops of a function.
+struct TestExpandTanhPass
+    : public PassWrapper<TestExpandTanhPass, FunctionPass> {
+  void runOnFunction() override;
+};
+} // end anonymous namespace
+
+void TestExpandTanhPass::runOnFunction() {
+  OwningRewritePatternList patterns;
+  populateExpandTanhPattern(patterns, &getContext());
+  applyPatternsAndFoldGreedily(getOperation(), patterns);
+}
+
+namespace mlir {
+/// Registers the test pass under the `-test-expand-tanh` option.
+void registerTestExpandTanhPass() {
+  PassRegistration<TestExpandTanhPass> pass("test-expand-tanh",
+                                            "Test expanding tanh");
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -48,6 +48,7 @@
 void registerTestConvertGPUKernelToCubinPass();
 void registerTestConvertGPUKernelToHsacoPass();
 void registerTestDominancePass();
+void registerTestExpandTanhPass();
 void registerTestFunc();
 void registerTestGpuMemoryPromotionPass();
 void registerTestLinalgHoisting();
@@ -122,6 +123,7 @@
   registerTestBufferPlacementPreparationPass();
   registerTestDominancePass();
+  registerTestExpandTanhPass();
   registerTestFunc();
   registerTestGpuMemoryPromotionPass();
   registerTestLinalgHoisting();
   registerTestLinalgTransforms();