diff --git a/mlir/include/mlir/Dialect/Math/CMakeLists.txt b/mlir/include/mlir/Dialect/Math/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Math/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Math/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Math) +add_public_tablegen_target(MLIRMathPassIncGen) + +add_mlir_doc(Passes -gen-pass-doc MathPasses ./) diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -19,6 +19,16 @@ void populateExpandTanhPattern(OwningRewritePatternList &patterns, MLIRContext *ctx); +std::unique_ptr> createFastMathExpansionPass(); + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Math/Transforms/Passes.h.inc" + } // namespace mlir #endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td @@ -0,0 +1,21 @@ +//===-- Passes.td - Math pass definition file --------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_MATH_TRANSORMS_PASSES +#define MLIR_DIALECT_MATH_TRANSORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def FastMathExpansion : FunctionPass<"fast-math-expansion"> { + let summary = "Expands math operations to fast approximations that do not " + "rely on any library functions."; + let constructor = "mlir::createFastMathExpansionPass()"; + let dependentDialects = ["math::MathDialect", "vector::VectorDialect"]; +} + +#endif // MLIR_DIALECT_MATH_TRANSORMS_PASSES diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -20,6 +20,7 @@ #include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/Quant/Passes.h" #include "mlir/Dialect/SCF/Passes.h" #include "mlir/Dialect/SPIRV/Transforms/Passes.h" @@ -54,6 +55,7 @@ registerLinalgPasses(); LLVM::registerLLVMPasses(); quant::registerQuantPasses(); + registerMathPasses(); registerSCFPasses(); registerShapePasses(); spirv::registerSPIRVPasses(); diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -1,9 +1,13 @@ add_mlir_dialect_library(MLIRMathTransforms ExpandTanh.cpp + FastMathExpansion.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Math/Transforms + DEPENDS + MLIRMathPassIncGen + LINK_LIBS PUBLIC MLIRIR MLIRMath diff --git a/mlir/lib/Dialect/Math/Transforms/FastMathExpansion.cpp b/mlir/lib/Dialect/Math/Transforms/FastMathExpansion.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Math/Transforms/FastMathExpansion.cpp 
@@ -0,0 +1,220 @@ +//===- FastMathExpansion.cpp - Expand math ops to fast approximations -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements expansion of math operations to fast approximations +// that do not rely on any of the library functions. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Math/Transforms/Passes.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::vector; + +static bool isValidFloatType(Type type) { + if (auto vectorType = type.dyn_cast()) + return vectorType.getElementType().isa(); + return type.isa(); +} + +//----------------------------------------------------------------------------// +// A PatternRewriter wrapper that provides a concise API for building expansions +// for operations on float scalars or vectors. +//----------------------------------------------------------------------------// + +namespace { +class FloatExpansionBuilder { +public: + FloatExpansionBuilder(Location loc, Type type, PatternRewriter &rewriter); + + Value constant(double value) const; + + Value abs(Value a) const; // abs(a) + Value min(Value a, Value b) const; // min(a, b) + Value max(Value a, Value b) const; // max(a, b) + Value mul(Value a, Value b) const; // a * b + Value div(Value a, Value b) const; // a / b + Value madd(Value a, Value b, Value c) const; // a * b + c + + // Compares values `a` and `b` with the given `predicate`. 
+ Value cmp(CmpFPredicate predicate, Value a, Value b) const; + + // Selects values from `a` or `b` based on the `predicate`. + Value select(Value predicate, Value a, Value b) const; + +private: + Location loc; + PatternRewriter &rewriter; + VectorType vectorType; // can be null for scalar type + FloatType elementType; +}; +} // namespace + +FloatExpansionBuilder::FloatExpansionBuilder(Location loc, Type type, + PatternRewriter &rewriter) + : loc(loc), rewriter(rewriter) { + vectorType = type.dyn_cast(); + + if (vectorType) + elementType = vectorType.getElementType().cast(); + else + elementType = type.cast(); +} + +Value FloatExpansionBuilder::constant(double value) const { + auto attr = rewriter.getFloatAttr(elementType, value); + Value scalar = rewriter.create(loc, attr); + + if (vectorType) + return rewriter.create(loc, vectorType, scalar); + return scalar; +} + +Value FloatExpansionBuilder::abs(Value a) const { + return rewriter.create(loc, a); +} + +Value FloatExpansionBuilder::min(Value a, Value b) const { + return select(cmp(CmpFPredicate::OLT, a, b), a, b); +} +Value FloatExpansionBuilder::max(Value a, Value b) const { + return select(cmp(CmpFPredicate::OGT, a, b), a, b); +} +Value FloatExpansionBuilder::mul(Value a, Value b) const { + return rewriter.create(loc, a, b); +} + +Value FloatExpansionBuilder::div(Value a, Value b) const { + return rewriter.create(loc, a, b); +} + +Value FloatExpansionBuilder::madd(Value a, Value b, Value c) const { + if (vectorType) + return rewriter.create(loc, a, b, c); + + Value mul = rewriter.create(loc, a, b); + return rewriter.create(loc, mul, c); +} + +Value FloatExpansionBuilder::cmp(CmpFPredicate predicate, Value a, + Value b) const { + return rewriter.create(loc, predicate, a, b); +} + +Value FloatExpansionBuilder::select(Value predicate, Value a, Value b) const { + return rewriter.create(loc, predicate, a, b); +} + +//----------------------------------------------------------------------------// +// TanhOp expansion. 
+//----------------------------------------------------------------------------// + +namespace { +struct TanhExpansion : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(math::TanhOp op, + PatternRewriter &rewriter) const final; +}; +} // namespace + +LogicalResult TanhExpansion::matchAndRewrite(math::TanhOp op, + PatternRewriter &rewriter) const { + if (!isValidFloatType(op.operand().getType())) + return rewriter.notifyMatchFailure(op, "unsupported operand type"); + + Value operand = op.operand(); + FloatExpansionBuilder builder(op->getLoc(), operand.getType(), rewriter); + + // Clamp operand into [minusClamp, plusClamp] range. + Value plusClamp = builder.constant(7.90531110763549805); + Value minusClamp = builder.constant(-7.9053111076354980); + Value x = builder.max(builder.min(operand, plusClamp), minusClamp); + + // The monomial coefficients of the numerator polynomial (odd). + Value alpha_1 = builder.constant(4.89352455891786e-03); + Value alpha_3 = builder.constant(6.37261928875436e-04); + Value alpha_5 = builder.constant(1.48572235717979e-05); + Value alpha_7 = builder.constant(5.12229709037114e-08); + Value alpha_9 = builder.constant(-8.60467152213735e-11); + Value alpha_11 = builder.constant(2.00018790482477e-13); + Value alpha_13 = builder.constant(-2.76076847742355e-16); + + // The monomial coefficients of the denominator polynomial (even). + Value beta_0 = builder.constant(4.89352518554385e-03); + Value beta_2 = builder.constant(2.26843463243900e-03); + Value beta_4 = builder.constant(1.18534705686654e-04); + Value beta_6 = builder.constant(1.19825839466702e-06); + + // Since the polynomials are odd/even, we need x^2. + Value x2 = builder.mul(x, x); + + // Evaluate the numerator polynomial p. 
+ Value p = builder.madd(x2, alpha_13, alpha_11); + p = builder.madd(x2, p, alpha_9); + p = builder.madd(x2, p, alpha_7); + p = builder.madd(x2, p, alpha_5); + p = builder.madd(x2, p, alpha_3); + p = builder.madd(x2, p, alpha_1); + p = builder.mul(x, p); + + // Evaluate the denominator polynomial q. + Value q = builder.madd(x2, beta_6, beta_4); + q = builder.madd(x2, q, beta_2); + q = builder.madd(x2, q, beta_0); + + // Divide the numerator by the denominator; pick x when abs(operand) < tiny. + Value tiny = builder.constant(0.0004f); + Value tinyMask = builder.cmp(CmpFPredicate::OLT, builder.abs(operand), tiny); + Value res = builder.select(tinyMask, x, builder.div(p, q)); + + rewriter.replaceOp(op, res); + + return success(); +} + +//----------------------------------------------------------------------------// +// Function pass that expands math operations to fast approximations. +//----------------------------------------------------------------------------// + +namespace { +class FastMathExpansionPass + : public FastMathExpansionBase { +public: + FastMathExpansionPass() = default; + void runOnFunction() override; +}; +} // namespace + +void FastMathExpansionPass::runOnFunction() { + mlir::MLIRContext *ctx = &getContext(); + mlir::OwningRewritePatternList patterns; + patterns.insert(ctx); + + ConversionTarget target(*ctx); + target.addLegalDialect(); + target.addLegalDialect(); + target.addIllegalOp(); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + signalPassFailure(); + return; + } +} + +std::unique_ptr> mlir::createFastMathExpansionPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Math/Transforms/PassDetail.h b/mlir/lib/Dialect/Math/Transforms/PassDetail.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Math/Transforms/PassDetail.h @@ -0,0 +1,30 @@ +//===- PassDetail.h - Math Pass class details -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_MATH_TRANSFORMS_PASSDETAIL_H_ +#define DIALECT_MATH_TRANSFORMS_PASSDETAIL_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { + +namespace math { +class MathDialect; +} // namespace math + +namespace vector { +class VectorDialect; +} // namespace vector + +#define GEN_PASS_CLASSES +#include "mlir/Dialect/Math/Transforms/Passes.h.inc" + +} // namespace mlir + +#endif // DIALECT_MATH_TRANSFORMS_PASSDETAIL_H_ diff --git a/mlir/test/Dialect/Math/fast-math.mlir b/mlir/test/Dialect/Math/fast-math.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Math/fast-math.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt %s -fast-math-expansion | FileCheck %s + +// CHECK-LABEL: @tanh_scalar +func @tanh_scalar(%arg0: f32) -> f32 { + // CHECK-NOT: tanh + %0 = math.tanh %arg0 : f32 + return %0 : f32 +} + +// CHECK-LABEL: @tanh_vector +func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> { + // CHECK-NOT: tanh + %0 = math.tanh %arg0 : vector<8xf32> + return %0 : vector<8xf32> +} diff --git a/mlir/test/mlir-cpu-runner/fast-math.mlir b/mlir/test/mlir-cpu-runner/fast-math.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-cpu-runner/fast-math.mlir @@ -0,0 +1,32 @@ +// RUN: mlir-opt %s -fast-math-expansion \ +// RUN: -convert-vector-to-llvm \ +// RUN: -convert-std-to-llvm \ +// RUN: | mlir-cpu-runner \ +// RUN: -e main -entry-point-result=void -O0 \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \ +// RUN: | FileCheck %s --dump-input=always + + +func @main() { + // ------------------------------------------------------------------------ // + // Tanh. 
+ // ------------------------------------------------------------------------ // + + // CHECK: 0.848284 + %0 = constant 1.25 : f32 + %1 = math.tanh %0 : f32 + vector.print %1 : f32 + + // CHECK: 0.244919, 0.635149, 0.761594, 0.848284 + %2 = constant dense<[0.25, 0.75, 1.0, 1.25]> : vector<4xf32> + %3 = math.tanh %2 : vector<4xf32> + vector.print %3 : vector<4xf32> + + // CHECK: 0.099668, 0.197375, 0.291313, 0.379949, 0.462117, 0.53705, 0.604368, 0.664037 + %4 = constant dense<[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]> : vector<8xf32> + %5 = math.tanh %4 : vector<8xf32> + vector.print %5 : vector<8xf32> + + return +}