diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -790,6 +790,31 @@
     return *this;
   }
 
+  /// Add a matchAndRewrite style pattern represented as a C function pointer.
+  /// `implFn` is wrapped in a function-local OpRewritePattern for `OpType`
+  /// whose matchAndRewrite forwards to it; the wrapper is constructed with
+  /// getContext(). Returns a reference to `this` for chaining insertions.
+  template <typename OpType>
+  OwningRewritePatternList &
+  insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) {
+    struct FnPattern final : public OpRewritePattern<OpType> {
+      FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
+                MLIRContext *context)
+          : OpRewritePattern<OpType>(context), implFn(implFn) {}
+
+      LogicalResult matchAndRewrite(OpType op,
+                                    PatternRewriter &rewriter) const override {
+        return implFn(op, rewriter);
+      }
+
+    private:
+      LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
+    };
+    // A function pointer is copied either way, so no std::move is needed.
+    insert(std::make_unique<FnPattern>(implFn, getContext()));
+    return *this;
+  }
+
 private:
   /// Add an instance of the pattern type 'T'. Return a reference to `this` for
   /// chaining insertions.
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
--- a/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
@@ -15,51 +15,42 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Transforms/DialectConversion.h"
-
 using namespace mlir;
 
-namespace {
 /// Expands tanh op into
 /// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
 /// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0
-struct TanhOpConverter : public OpRewritePattern<math::TanhOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(math::TanhOp op,
-                                PatternRewriter &rewriter) const final {
-    auto floatType = op.operand().getType();
-    Location loc = op.getLoc();
-    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
-    auto floatTwo = rewriter.getFloatAttr(floatType, 2.0);
-    Value one = rewriter.create<ConstantOp>(loc, floatOne);
-    Value two = rewriter.create<ConstantOp>(loc, floatTwo);
-    Value doubledX = rewriter.create<MulFOp>(loc, op.operand(), two);
-
-    // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
-    Value negDoubledX = rewriter.create<NegFOp>(loc, doubledX);
-    Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
-    Value dividend = rewriter.create<SubFOp>(loc, one, exp2x);
-    Value divisor = rewriter.create<AddFOp>(loc, one, exp2x);
-    Value positiveRes = rewriter.create<DivFOp>(loc, dividend, divisor);
-
-    // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
-    exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
-    dividend = rewriter.create<SubFOp>(loc, exp2x, one);
-    divisor = rewriter.create<AddFOp>(loc, exp2x, one);
-    Value negativeRes = rewriter.create<DivFOp>(loc, dividend, divisor);
-
-    // tanh(x) = x >= 0 ? positiveRes : negativeRes
-    auto floatZero = rewriter.getFloatAttr(floatType, 0.0);
-    Value zero = rewriter.create<ConstantOp>(loc, floatZero);
-    Value cmpRes =
-        rewriter.create<CmpFOp>(loc, CmpFPredicate::OGE, op.operand(), zero);
-    rewriter.replaceOpWithNewOp<SelectOp>(op, cmpRes, positiveRes, negativeRes);
-    return success();
-  }
-};
-} // namespace
+static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
+  auto floatType = op.operand().getType();
+  Location loc = op.getLoc();
+  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+  auto floatTwo = rewriter.getFloatAttr(floatType, 2.0);
+  Value one = rewriter.create<ConstantOp>(loc, floatOne);
+  Value two = rewriter.create<ConstantOp>(loc, floatTwo);
+  Value doubledX = rewriter.create<MulFOp>(loc, op.operand(), two);
+
+  // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
+  Value negDoubledX = rewriter.create<NegFOp>(loc, doubledX);
+  Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
+  Value dividend = rewriter.create<SubFOp>(loc, one, exp2x);
+  Value divisor = rewriter.create<AddFOp>(loc, one, exp2x);
+  Value positiveRes = rewriter.create<DivFOp>(loc, dividend, divisor);
+
+  // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
+  exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
+  dividend = rewriter.create<SubFOp>(loc, exp2x, one);
+  divisor = rewriter.create<AddFOp>(loc, exp2x, one);
+  Value negativeRes = rewriter.create<DivFOp>(loc, dividend, divisor);
+
+  // tanh(x) = x >= 0 ? positiveRes : negativeRes
+  auto floatZero = rewriter.getFloatAttr(floatType, 0.0);
+  Value zero = rewriter.create<ConstantOp>(loc, floatZero);
+  Value cmpRes =
+      rewriter.create<CmpFOp>(loc, CmpFPredicate::OGE, op.operand(), zero);
+  rewriter.replaceOpWithNewOp<SelectOp>(op, cmpRes, positiveRes, negativeRes);
+  return success();
+}
 
 void mlir::populateExpandTanhPattern(OwningRewritePatternList &patterns) {
-  patterns.insert<TanhOpConverter>(patterns.getContext());
+  patterns.insert(convertTanhOp);
 }