diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -19,6 +19,12 @@
 void populateExpandTanhPattern(OwningRewritePatternList &patterns,
                                MLIRContext *ctx);
 
+/// Populates `patterns` with rewrites that expand math operations (currently
+/// only `math.tanh`) into polynomial approximations built from std and vector
+/// ops, without relying on library functions.
+void populateMathPolynomialApproximationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRMathTransforms
   ExpandTanh.cpp
+  PolynomialApproximation.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Math/Transforms
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -0,0 +1,210 @@
+//===- PolynomialApproximation.cpp - Approximate math operations ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements expansion of math operations to fast approximations
+// that do not rely on any of the library functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+// Returns true if `type` is a float scalar or a vector of floats, i.e. a type
+// that FloatApproximationBuilder can be constructed for.
+static bool isValidFloatType(Type type) {
+  if (auto vectorType = type.dyn_cast<VectorType>())
+    return vectorType.getElementType().isa<FloatType>();
+  return type.isa<FloatType>();
+}
+
+//----------------------------------------------------------------------------//
+// A PatternRewriter wrapper that provides concise API for building expansions
+// for operations on float scalars or vectors.
+//----------------------------------------------------------------------------//
+
+namespace {
+class FloatApproximationBuilder {
+public:
+  // `type` must be a float scalar or float vector (see isValidFloatType).
+  FloatApproximationBuilder(Location loc, Type type, PatternRewriter &rewriter);
+
+  // Returns `value` as a constant of the builder's scalar or vector type.
+  Value constant(double value) const;
+
+  Value abs(Value a) const;                    // abs(a)
+  Value min(Value a, Value b) const;           // min(a, b); NaN in `a` wins
+  Value max(Value a, Value b) const;           // max(a, b); NaN in `a` wins
+  Value mul(Value a, Value b) const;           // a * b
+  Value div(Value a, Value b) const;           // a / b
+  Value madd(Value a, Value b, Value c) const; // a * b + c
+
+  // Compares values `a` and `b` with the given `predicate`.
+  Value cmp(CmpFPredicate predicate, Value a, Value b) const;
+
+  // Selects values from `a` or `b` based on the `predicate`.
+  Value select(Value predicate, Value a, Value b) const;
+
+private:
+  Location loc;
+  PatternRewriter &rewriter;
+  VectorType vectorType; // can be null for scalar type
+  FloatType elementType;
+};
+} // namespace
+
+FloatApproximationBuilder::FloatApproximationBuilder(Location loc, Type type,
+                                                     PatternRewriter &rewriter)
+    : loc(loc), rewriter(rewriter) {
+  vectorType = type.dyn_cast<VectorType>();
+
+  if (vectorType)
+    elementType = vectorType.getElementType().cast<FloatType>();
+  else
+    elementType = type.cast<FloatType>();
+}
+
+Value FloatApproximationBuilder::constant(double value) const {
+  auto attr = rewriter.getFloatAttr(elementType, value);
+  Value scalar = rewriter.create<ConstantOp>(loc, attr);
+
+  // Splat the scalar constant into a vector for vector types.
+  if (vectorType)
+    return rewriter.create<BroadcastOp>(loc, vectorType, scalar);
+  return scalar;
+}
+
+Value FloatApproximationBuilder::abs(Value a) const {
+  return rewriter.create<AbsFOp>(loc, a);
+}
+
+// min/max use unordered predicates (ULT/UGT), which are true when either
+// operand is NaN, so a NaN in `a` is selected and propagated instead of being
+// silently replaced by `b` (e.g. clamping a NaN yields NaN, not a bound).
+Value FloatApproximationBuilder::min(Value a, Value b) const {
+  return select(cmp(CmpFPredicate::ULT, a, b), a, b);
+}
+
+Value FloatApproximationBuilder::max(Value a, Value b) const {
+  return select(cmp(CmpFPredicate::UGT, a, b), a, b);
+}
+
+Value FloatApproximationBuilder::mul(Value a, Value b) const {
+  return rewriter.create<MulFOp>(loc, a, b);
+}
+
+Value FloatApproximationBuilder::div(Value a, Value b) const {
+  return rewriter.create<DivFOp>(loc, a, b);
+}
+
+// Vectors use a fused multiply-add (vector.fma) while scalars use a separate
+// multiply and add, so the two paths may differ in the last bit.
+Value FloatApproximationBuilder::madd(Value a, Value b, Value c) const {
+  if (vectorType)
+    return rewriter.create<FMAOp>(loc, a, b, c);
+
+  Value product = rewriter.create<MulFOp>(loc, a, b);
+  return rewriter.create<AddFOp>(loc, product, c);
+}
+
+Value FloatApproximationBuilder::cmp(CmpFPredicate predicate, Value a,
+                                     Value b) const {
+  return rewriter.create<CmpFOp>(loc, predicate, a, b);
+}
+
+Value FloatApproximationBuilder::select(Value predicate, Value a,
+                                        Value b) const {
+  return rewriter.create<SelectOp>(loc, predicate, a, b);
+}
+
+//----------------------------------------------------------------------------//
+// TanhOp expansion.
+//----------------------------------------------------------------------------//
+
+namespace {
+struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::TanhOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+LogicalResult
+TanhApproximation::matchAndRewrite(math::TanhOp op,
+                                   PatternRewriter &rewriter) const {
+  if (!isValidFloatType(op.operand().getType()))
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  Value operand = op.operand();
+  FloatApproximationBuilder builder(op->getLoc(), operand.getType(), rewriter);
+
+  // Clamp operand into [minusClamp, plusClamp] range. The builder's min/max
+  // propagate a NaN `operand`, so NaN inputs produce NaN results.
+  Value plusClamp = builder.constant(7.90531110763549805);
+  Value minusClamp = builder.constant(-7.90531110763549805);
+  Value x = builder.max(builder.min(operand, plusClamp), minusClamp);
+
+  // The monomial coefficients of the numerator polynomial (odd).
+  Value alpha_1 = builder.constant(4.89352455891786e-03);
+  Value alpha_3 = builder.constant(6.37261928875436e-04);
+  Value alpha_5 = builder.constant(1.48572235717979e-05);
+  Value alpha_7 = builder.constant(5.12229709037114e-08);
+  Value alpha_9 = builder.constant(-8.60467152213735e-11);
+  Value alpha_11 = builder.constant(2.00018790482477e-13);
+  Value alpha_13 = builder.constant(-2.76076847742355e-16);
+
+  // The monomial coefficients of the denominator polynomial (even).
+  Value beta_0 = builder.constant(4.89352518554385e-03);
+  Value beta_2 = builder.constant(2.26843463243900e-03);
+  Value beta_4 = builder.constant(1.18534705686654e-04);
+  Value beta_6 = builder.constant(1.19825839466702e-06);
+
+  // Since the polynomials are odd/even, we need x^2.
+  Value x2 = builder.mul(x, x);
+
+  // Evaluate the numerator polynomial p (Horner's scheme in x^2).
+  Value p = builder.madd(x2, alpha_13, alpha_11);
+  p = builder.madd(x2, p, alpha_9);
+  p = builder.madd(x2, p, alpha_7);
+  p = builder.madd(x2, p, alpha_5);
+  p = builder.madd(x2, p, alpha_3);
+  p = builder.madd(x2, p, alpha_1);
+  p = builder.mul(x, p);
+
+  // Evaluate the denominator polynomial q.
+  Value q = builder.madd(x2, beta_6, beta_4);
+  q = builder.madd(x2, q, beta_2);
+  q = builder.madd(x2, q, beta_0);
+
+  // Divide the numerator by the denominator. For |operand| < tiny the result
+  // is `x` itself (== `operand` in that range) because tanh(x) ~= x near 0.
+  Value tiny = builder.constant(0.0004);
+  Value tinyMask = builder.cmp(CmpFPredicate::OLT, builder.abs(operand), tiny);
+  Value res = builder.select(tinyMask, x, builder.div(p, q));
+
+  rewriter.replaceOp(op, res);
+
+  return success();
+}
+
+//----------------------------------------------------------------------------//
+
+// Adds the math polynomial approximation patterns (currently only math.tanh)
+// to `patterns`.
+void mlir::populateMathPolynomialApproximationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<TanhApproximation>(ctx);
+}
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -test-math-polynomial-approximation | FileCheck %s
+
+// CHECK-LABEL: @tanh_scalar
+func @tanh_scalar(%arg0: f32) -> f32 {
+  // CHECK-NOT: tanh
+  %0 = math.tanh %arg0 : f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @tanh_vector
+func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
+  // CHECK-NOT: tanh
+  %0 = math.tanh %arg0 : vector<8xf32>
+  return %0 : vector<8xf32>
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -26,6 +26,7 @@
   TestLoopUnrolling.cpp
   TestNumberOfExecutions.cpp
   TestOpaqueLoc.cpp
+  TestPolynomialApproximation.cpp
   TestMemRefBoundCheck.cpp
   TestMemRefDependenceCheck.cpp
   TestMemRefStrideCalculation.cpp
diff --git
a/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp b/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp
@@ -0,0 +1,52 @@
+//===- TestPolynomialApproximation.cpp - Test math ops approximations -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes for expanding math operations into
+// polynomial approximations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+// Test pass that greedily applies populateMathPolynomialApproximationPatterns
+// to every function.
+struct TestMathPolynomialApproximationPass
+    : public PassWrapper<TestMathPolynomialApproximationPass, FunctionPass> {
+  void runOnFunction() override;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // The approximation patterns create std ops (constants, compares, selects,
+    // arithmetic) and vector ops (broadcast, fma), so load those dialects.
+    registry.insert<StandardOpsDialect, vector::VectorDialect,
+                    math::MathDialect>();
+  }
+};
+} // end anonymous namespace
+
+void TestMathPolynomialApproximationPass::runOnFunction() {
+  OwningRewritePatternList patterns;
+  populateMathPolynomialApproximationPatterns(patterns, &getContext());
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+namespace mlir {
+namespace test {
+void registerTestMathPolynomialApproximationPass() {
+  PassRegistration<TestMathPolynomialApproximationPass> pass(
+      "test-math-polynomial-approximation",
+      "Test math polynomial approximations");
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
new file mode 100644
--- /dev/null
+++
b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -test-math-polynomial-approximation \
+// RUN:             -convert-vector-to-llvm \
+// RUN:             -convert-std-to-llvm \
+// RUN: | mlir-cpu-runner \
+// RUN:     -e main -entry-point-result=void -O0 \
+// RUN:     -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \
+// RUN:     -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @main() {
+  // ------------------------------------------------------------------------ //
+  // Tanh.
+  // ------------------------------------------------------------------------ //
+
+  // CHECK: 0.848284
+  %0 = constant 1.25 : f32
+  %1 = math.tanh %0 : f32
+  vector.print %1 : f32
+
+  // CHECK: 0.244919, 0.635149, 0.761594, 0.848284
+  %2 = constant dense<[0.25, 0.75, 1.0, 1.25]> : vector<4xf32>
+  %3 = math.tanh %2 : vector<4xf32>
+  vector.print %3 : vector<4xf32>
+
+  // CHECK: 0.099668, 0.197375, 0.291313, 0.379949, 0.462117, 0.53705, 0.604368, 0.664037
+  %4 = constant dense<[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]> : vector<8xf32>
+  %5 = math.tanh %4 : vector<8xf32>
+  vector.print %5 : vector<8xf32>
+
+  // Negative inputs exercise the odd symmetry of the approximation, and
+  // inputs with |x| < 0.0004 are returned unchanged (tanh(x) ~= x).
+  // CHECK: -0.848284, -0.244919, 0.0001, -0.0001
+  %6 = constant dense<[-1.25, -0.25, 1.0e-04, -1.0e-04]> : vector<4xf32>
+  %7 = math.tanh %6 : vector<4xf32>
+  vector.print %7 : vector<4xf32>
+
+  return
+}
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -84,6 +84,7 @@
 void registerTestLoopFusion();
 void registerTestLoopMappingPass();
 void registerTestLoopUnrollingPass();
+void registerTestMathPolynomialApproximationPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
 void registerTestNumberOfBlockExecutionsPass();
@@ -157,6 +158,7 @@
   test::registerTestLoopFusion();
   test::registerTestLoopMappingPass();
   test::registerTestLoopUnrollingPass();
+  test::registerTestMathPolynomialApproximationPass();
   test::registerTestMemRefDependenceCheck();
   test::registerTestMemRefStrideCalculation();
   test::registerTestNumberOfBlockExecutionsPass();