diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -7,6 +7,7 @@ LINK_LIBS PUBLIC MLIRIR + MLIRLLVMIR MLIRMath MLIRPass MLIRStandard diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -11,6 +11,8 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/Vector/VectorOps.h" @@ -21,97 +23,129 @@ using namespace mlir; using namespace mlir::vector; -static bool isValidFloatType(Type type) { - if (auto vectorType = type.dyn_cast()) - return vectorType.getElementType().isa(); - return type.isa(); -} +// Returns a vector length if the type is a f32 vector type, 1 if the type is +// a f32 scalar, and None otherwise. +static Optional isValidF32Type(Type type) { + if (type.isF32()) + return 1; -//----------------------------------------------------------------------------// -// A PatternRewriter wrapper that provides concise API for building expansions -// for operations on float scalars or vectors. 
-//----------------------------------------------------------------------------// - -namespace { -class FloatApproximationBuilder { -public: - FloatApproximationBuilder(Location loc, Type type, PatternRewriter &rewriter); + if (auto vectorType = type.dyn_cast()) { + auto elementType = vectorType.getElementType(); + if (vectorType.getRank() != 1 || !elementType.isF32()) + return llvm::None; + return vectorType.getDimSize(0); + } - Value constant(double value) const; - - Value abs(Value a) const; - Value min(Value a, Value b) const; - Value max(Value a, Value b) const; - Value mul(Value a, Value b) const; - Value div(Value a, Value b) const; - - // Fused multiple-add operation: a * b + c. - Value madd(Value a, Value b, Value c) const; - - // Compares values `a` and `b` with the given `predicate`. - Value cmp(CmpFPredicate predicate, Value a, Value b) const; - - // Selects values from `a` or `b` based on the `predicate`. - Value select(Value predicate, Value a, Value b) const; - -private: - Location loc; - PatternRewriter &rewriter; - VectorType vectorType; // can be null for scalar type - FloatType elementType; -}; -} // namespace - -FloatApproximationBuilder::FloatApproximationBuilder(Location loc, Type type, - PatternRewriter &rewriter) - : loc(loc), rewriter(rewriter) { - vectorType = type.dyn_cast(); - - if (vectorType) - elementType = vectorType.getElementType().cast(); - else - elementType = type.cast(); + return llvm::None; } -Value FloatApproximationBuilder::constant(double value) const { - auto attr = rewriter.getFloatAttr(elementType, value); - Value scalar = rewriter.create(loc, attr); - - if (vectorType) - return rewriter.create(loc, vectorType, scalar); - return scalar; +// Broadcasts scalar type into vector type (or returns the scalar type itself if the +// vector size is `1`). +static Type broadcast(int vectorSize, Type type) { + assert(!type.isa() && "must be scalar type"); + return vectorSize > 1 ?
VectorType::get({vectorSize}, type) : type; } -Value FloatApproximationBuilder::abs(Value a) const { - return rewriter.create(loc, a); +// Broadcasts scalar value into the vector value (or returns the value itself +// if the vector size is `1`). +static Value broadcast(int vectorSize, Location loc, PatternRewriter &rewriter, + Value value) { + assert(!value.getType().isa() && "must be scalar value"); + return vectorSize > 1 + ? rewriter.create( + loc, broadcast(vectorSize, value.getType()), value) + : value; } -Value FloatApproximationBuilder::min(Value a, Value b) const { - return select(cmp(CmpFPredicate::OLT, a, b), a, b); -} -Value FloatApproximationBuilder::max(Value a, Value b) const { - return select(cmp(CmpFPredicate::OGT, a, b), a, b); +static int vectorSize(Type type) { + if (auto vectorType = type.dyn_cast()) { + assert(vectorType.getRank() == 1); + return vectorType.getDimSize(0); + } + return 1; } -Value FloatApproximationBuilder::mul(Value a, Value b) const { - return rewriter.create(loc, a, b); + +static Type elementType(Type type) { + if (auto vectorType = type.dyn_cast()) + return vectorType.getElementType(); + return type; } -Value FloatApproximationBuilder::div(Value a, Value b) const { - return rewriter.create(loc, a, b); +//----------------------------------------------------------------------------// +// Helper functions to build math functions approximations.
+//----------------------------------------------------------------------------// + +static Value min(Location loc, PatternRewriter &rewriter, Value a, Value b) { + return rewriter.create( + loc, rewriter.create(loc, CmpFPredicate::OLT, a, b), a, b); } -Value FloatApproximationBuilder::madd(Value a, Value b, Value c) const { - return rewriter.create(loc, a, b, c); +static Value max(Location loc, PatternRewriter &rewriter, Value a, Value b) { + return rewriter.create( + loc, rewriter.create(loc, CmpFPredicate::OGT, a, b), a, b); } -Value FloatApproximationBuilder::cmp(CmpFPredicate predicate, Value a, - Value b) const { - return rewriter.create(loc, predicate, a, b); +static Value clamp(Location loc, PatternRewriter &rewriter, Value value, + Value lowerBound, Value upperBound) { + return max(loc, rewriter, min(loc, rewriter, value, upperBound), lowerBound); } -Value FloatApproximationBuilder::select(Value predicate, Value a, - Value b) const { - return rewriter.create(loc, predicate, a, b); +// Decomposes given floating point value `arg` into a normalized fraction and +// an integral power of two (see std::frexp). Returned values have float type. 
+static std::pair frexp(Location loc, PatternRewriter &rewriter, + Value arg) { + assert(elementType(arg.getType()).isF32() && "argument must be f32 type"); + + int size = vectorSize(arg.getType()); + auto i32 = broadcast(size, rewriter.getIntegerType(32)); + auto f32 = broadcast(size, rewriter.getF32Type()); + + auto bcast = [&](Value value) -> Value { + return broadcast(size, loc, rewriter, value); + }; + + auto f32Value = [&](float value) -> Value { + return rewriter.create(loc, rewriter.getF32Type(), + rewriter.getF32FloatAttr(value)); + }; + + auto f32FromBits = [&](uint32_t bits) -> Value { + auto i32Value = rewriter.create( + loc, rewriter.getI32Type(), + rewriter.getI32IntegerAttr(static_cast(bits))); + return rewriter.create(loc, rewriter.getF32Type(), + i32Value); + }; + + Value cst126f = bcast(f32Value(126.0f)); + Value cstHalf = bcast(f32Value(0.5f)); + Value cstInvMantMask = bcast(f32FromBits(~0x7f800000u)); + + // Cast to i32 for bitwise operations. + Value i32Half = rewriter.create(loc, i32, cstHalf); + Value i32InvMantMask = + rewriter.create(loc, i32, cstInvMantMask); + Value i32Arg = rewriter.create(loc, i32, arg); + + // Compute normalized fraction. + Value tmp0 = rewriter.create( + loc, rewriter.create(loc, i32Arg, i32InvMantMask), i32Half); + Value normalized_fraction = rewriter.create(loc, f32, tmp0); + + // Compute exponent. 
+ Value shiftBy = bcast(rewriter.create( + loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(23))); + + Value shifted = rewriter.create( + loc, + rewriter.create(loc, i32, + rewriter.create(loc, arg)), + shiftBy); + + Value tmp1 = rewriter.create(loc, f32, shifted); + Value exponent = rewriter.create(loc, tmp1, cst126f); + + return {normalized_fraction, exponent}; } //----------------------------------------------------------------------------// @@ -131,64 +165,216 @@ LogicalResult TanhApproximation::matchAndRewrite(math::TanhOp op, PatternRewriter &rewriter) const { - if (!isValidFloatType(op.operand().getType())) + Value operand = op.operand(); + + auto vectorSize = isValidF32Type(operand.getType()); + if (!vectorSize.hasValue()) return rewriter.notifyMatchFailure(op, "unsupported operand type"); - Value operand = op.operand(); - FloatApproximationBuilder builder(op->getLoc(), operand.getType(), rewriter); + Location loc = op->getLoc(); + + auto bcast = [&](Value value) -> Value { + return broadcast(*vectorSize, loc, rewriter, value); + }; + + auto f32Value = [&](float value) -> Value { + return rewriter.create(loc, rewriter.getF32Type(), + rewriter.getF32FloatAttr(value)); + }; // Clamp operand into [plusClamp, minusClamp] range. - Value plusClamp = builder.constant(7.90531110763549805); - Value minusClamp = builder.constant(-7.9053111076354980); - Value x = builder.max(builder.min(operand, plusClamp), minusClamp); + Value minusClamp = bcast(f32Value(-7.9053111076354980f)); + Value plusClamp = bcast(f32Value(7.90531110763549805f)); + Value x = clamp(loc, rewriter, operand, minusClamp, plusClamp); // Mask for tiny values that are approximated with `operand`. 
- Value tiny = builder.constant(0.0004f); - Value tinyMask = builder.cmp(CmpFPredicate::OLT, builder.abs(operand), tiny); + Value tiny = bcast(f32Value(0.0004f)); + Value tinyMask = rewriter.create( + loc, CmpFPredicate::OLT, rewriter.create(loc, operand), tiny); // The monomial coefficients of the numerator polynomial (odd). - Value alpha1 = builder.constant(4.89352455891786e-03); - Value alpha3 = builder.constant(6.37261928875436e-04); - Value alpha5 = builder.constant(1.48572235717979e-05); - Value alpha7 = builder.constant(5.12229709037114e-08); - Value alpha9 = builder.constant(-8.60467152213735e-11); - Value alpha11 = builder.constant(2.00018790482477e-13); - Value alpha13 = builder.constant(-2.76076847742355e-16); + Value alpha1 = bcast(f32Value(4.89352455891786e-03f)); + Value alpha3 = bcast(f32Value(6.37261928875436e-04f)); + Value alpha5 = bcast(f32Value(1.48572235717979e-05f)); + Value alpha7 = bcast(f32Value(5.12229709037114e-08f)); + Value alpha9 = bcast(f32Value(-8.60467152213735e-11f)); + Value alpha11 = bcast(f32Value(2.00018790482477e-13f)); + Value alpha13 = bcast(f32Value(-2.76076847742355e-16f)); // The monomial coefficients of the denominator polynomial (even). - Value beta0 = builder.constant(4.89352518554385e-03); - Value beta2 = builder.constant(2.26843463243900e-03); - Value beta4 = builder.constant(1.18534705686654e-04); - Value beta6 = builder.constant(1.19825839466702e-06); + Value beta0 = bcast(f32Value(4.89352518554385e-03f)); + Value beta2 = bcast(f32Value(2.26843463243900e-03f)); + Value beta4 = bcast(f32Value(1.18534705686654e-04f)); + Value beta6 = bcast(f32Value(1.19825839466702e-06f)); // Since the polynomials are odd/even, we need x^2. - Value x2 = builder.mul(x, x); + Value x2 = rewriter.create(loc, x, x); // Evaluate the numerator polynomial p. 
- Value p = builder.madd(x2, alpha13, alpha11); - p = builder.madd(x2, p, alpha9); - p = builder.madd(x2, p, alpha7); - p = builder.madd(x2, p, alpha5); - p = builder.madd(x2, p, alpha3); - p = builder.madd(x2, p, alpha1); - p = builder.mul(x, p); + Value p = rewriter.create(loc, x2, alpha13, alpha11); + p = rewriter.create(loc, x2, p, alpha9); + p = rewriter.create(loc, x2, p, alpha7); + p = rewriter.create(loc, x2, p, alpha5); + p = rewriter.create(loc, x2, p, alpha3); + p = rewriter.create(loc, x2, p, alpha1); + p = rewriter.create(loc, x, p); // Evaluate the denominator polynomial q. - Value q = builder.madd(x2, beta6, beta4); - q = builder.madd(x2, q, beta2); - q = builder.madd(x2, q, beta0); + Value q = rewriter.create(loc, x2, beta6, beta4); + q = rewriter.create(loc, x2, q, beta2); + q = rewriter.create(loc, x2, q, beta0); // Divide the numerator by the denominator. - Value res = builder.select(tinyMask, x, builder.div(p, q)); + Value res = rewriter.create(loc, tinyMask, x, + rewriter.create(loc, p, q)); rewriter.replaceOp(op, res); return success(); } +//----------------------------------------------------------------------------// +// LogOp approximation. +//----------------------------------------------------------------------------// + +namespace { + +// This approximation comes from Julien Pommier's SSE math library.
+// Link: http://gruntthepeon.free.fr/ssemath +struct LogApproximation : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(math::LogOp op, + PatternRewriter &rewriter) const final; +}; +} // namespace + +#define LN2_VALUE \ + 0.693147180559945309417232121458176568075500134360255254120680009493393621L + +LogicalResult +LogApproximation::matchAndRewrite(math::LogOp op, + PatternRewriter &rewriter) const { + Location loc = op->getLoc(); + Value operand = op.operand(); + + auto vectorSize = isValidF32Type(operand.getType()); + if (!vectorSize.hasValue()) + return rewriter.notifyMatchFailure(op, "unsupported operand type"); + + auto bcast = [&](Value value) -> Value { + return broadcast(*vectorSize, loc, rewriter, value); + }; + + auto f32Value = [&](float value) -> Value { + return rewriter.create(loc, rewriter.getF32Type(), + rewriter.getF32FloatAttr(value)); + }; + + auto f32FromBits = [&](uint32_t bits) -> Value { + auto i32Value = rewriter.create( + loc, rewriter.getI32Type(), + rewriter.getI32IntegerAttr(static_cast(bits))); + return rewriter.create(loc, rewriter.getF32Type(), + i32Value); + }; + + Value cstZero = bcast(f32Value(0.0f)); + Value cstOne = bcast(f32Value(1.0f)); + Value cstNegHalf = bcast(f32Value(-0.5f)); + + // The smallest non denormalized float number. + Value cstMinNormPos = bcast(f32FromBits(0x00800000u)); + Value cstMinusInf = bcast(f32FromBits(0xff800000u)); + Value cstPosInf = bcast(f32FromBits(0x7f800000u)); + Value cstNan = bcast(f32FromBits(0x7fc00000)); + + // Polynomial coefficients. 
+ Value cstCephesSQRTHF = bcast(f32Value(0.707106781186547524f)); + Value cstCephesLogP0 = bcast(f32Value(7.0376836292E-2f)); + Value cstCephesLogP1 = bcast(f32Value(-1.1514610310E-1f)); + Value cstCephesLogP2 = bcast(f32Value(1.1676998740E-1f)); + Value cstCephesLogP3 = bcast(f32Value(-1.2420140846E-1f)); + Value cstCephesLogP4 = bcast(f32Value(+1.4249322787E-1f)); + Value cstCephesLogP5 = bcast(f32Value(-1.6668057665E-1f)); + Value cstCephesLogP6 = bcast(f32Value(+2.0000714765E-1f)); + Value cstCephesLogP7 = bcast(f32Value(-2.4999993993E-1f)); + Value cstCephesLogP8 = bcast(f32Value(+3.3333331174E-1f)); + + Value x = op.operand(); + + // Truncate input values to the minimum positive normal. + x = max(loc, rewriter, x, cstMinNormPos); + + // Extract significand in the range [0.5,1) and exponent. + std::pair pair = frexp(loc, rewriter, x); + x = pair.first; + Value e = pair.second; + + // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift + // by -1.0. The values are then centered around 0, which improves the + // stability of the polynomial evaluation: + // + // if( x < SQRTHF ) { + // e -= 1; + // x = x + x - 1.0; + // } else { x = x - 1.0; } + Value mask = + rewriter.create(loc, CmpFPredicate::OLT, x, cstCephesSQRTHF); + Value tmp = rewriter.create(loc, mask, x, cstZero); + + x = rewriter.create(loc, x, cstOne); + e = rewriter.create( + loc, e, rewriter.create(loc, mask, cstOne, cstZero)); + x = rewriter.create(loc, x, tmp); + + Value x2 = rewriter.create(loc, x, x); + Value x3 = rewriter.create(loc, x2, x); + + // Evaluate the polynomial approximant of degree 8 in three parts.
+ Value y0, y1, y2; + y0 = rewriter.create(loc, cstCephesLogP0, x, cstCephesLogP1); + y1 = rewriter.create(loc, cstCephesLogP3, x, cstCephesLogP4); + y2 = rewriter.create(loc, cstCephesLogP6, x, cstCephesLogP7); + y0 = rewriter.create(loc, y0, x, cstCephesLogP2); + y1 = rewriter.create(loc, y1, x, cstCephesLogP5); + y2 = rewriter.create(loc, y2, x, cstCephesLogP8); + y0 = rewriter.create(loc, y0, x3, y1); + y0 = rewriter.create(loc, y0, x3, y2); + y0 = rewriter.create(loc, y0, x3); + + y0 = rewriter.create(loc, cstNegHalf, x2, y0); + x = rewriter.create(loc, x, y0); + + Value cstLn2 = bcast(f32Value(static_cast(LN2_VALUE))); + x = rewriter.create(loc, e, cstLn2, x); + + Value invalidMask = + rewriter.create(loc, CmpFPredicate::ULT, op.operand(), cstZero); + Value zeroMask = + rewriter.create(loc, CmpFPredicate::OEQ, op.operand(), cstZero); + Value posInfMask = + rewriter.create(loc, CmpFPredicate::OEQ, op.operand(), cstPosInf); + + // Filter out invalid values: + // • x == 0 -> -INF + // • x < 0 -> NAN + // • x == +INF -> +INF + Value aproximation = rewriter.create( + loc, zeroMask, cstMinusInf, + rewriter.create( + loc, invalidMask, cstNan, + rewriter.create(loc, posInfMask, cstPosInf, x))); + + rewriter.replaceOp(op, aproximation); + + return success(); +} + //----------------------------------------------------------------------------// void mlir::populateMathPolynomialApproximationPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { - patterns.insert(ctx); + patterns.insert(ctx); } diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir --- a/mlir/test/Dialect/Math/polynomial-approximation.mlir +++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir @@ -1,15 +1,22 @@ // RUN: mlir-opt %s -test-math-polynomial-approximation | FileCheck %s -// CHECK-LABEL: @tanh_scalar -func @tanh_scalar(%arg0: f32) -> f32 { +// Check that all math functions are lowered to approximations built from +//
standard operations (add, mul, fma, shift, etc...). + +// CHECK-LABEL: @scalar +func @scalar(%arg0: f32) -> f32 { // CHECK-NOT: tanh %0 = math.tanh %arg0 : f32 - return %0 : f32 + // CHECK-NOT: log + %1 = math.log %0 : f32 + return %1 : f32 } -// CHECK-LABEL: @tanh_vector -func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> { +// CHECK-LABEL: @vector +func @vector(%arg0: vector<8xf32>) -> vector<8xf32> { // CHECK-NOT: tanh %0 = math.tanh %arg0 : vector<8xf32> - return %0 : vector<8xf32> + // CHECK-NOT: log + %1 = math.log %0 : vector<8xf32> + return %1 : vector<8xf32> } diff --git a/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp b/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp --- a/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp +++ b/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/Vector/VectorOps.h" @@ -24,7 +25,8 @@ : public PassWrapper { void runOnFunction() override; void getDependentDialects(DialectRegistry &registry) const override { - registry.insert(); + registry + .insert(); } }; } // end anonymous namespace diff --git a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir --- a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir +++ b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir @@ -7,12 +7,10 @@ // RUN: -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s - -func @main() { - // ------------------------------------------------------------------------ // - // Tanh. - // ------------------------------------------------------------------------ // - +// ------------------------------------------------------------------------ // +// Tanh.
+// ------------------------------------------------------------------------ // +func @tanh() { // CHECK: 0.848284 %0 = constant 1.25 : f32 %1 = math.tanh %0 : f32 @@ -30,3 +28,51 @@ return } + +// ------------------------------------------------------------------------ // +// Log. +// ------------------------------------------------------------------------ // +func @log() { + // CHECK: 2.64704 + %0 = constant 14.112233 : f32 + %1 = math.log %0 : f32 + vector.print %1 : f32 + + // CHECK: -1.38629, -0.287682, 0, 0.223144 + %2 = constant dense<[0.25, 0.75, 1.0, 1.25]> : vector<4xf32> + %3 = math.log %2 : vector<4xf32> + vector.print %3 : vector<4xf32> + + // CHECK: -2.30259, -1.60944, -1.20397, -0.916291, -0.693147, -0.510826, -0.356675, -0.223144 + %4 = constant dense<[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]> : vector<8xf32> + %5 = math.log %4 : vector<8xf32> + vector.print %5 : vector<8xf32> + + // CHECK: -inf + %zero = constant 0.0 : f32 + %log_zero = math.log %zero : f32 + vector.print %log_zero : f32 + + // CHECK: nan + %neg_one = constant -1.0 : f32 + %log_neg_one = math.log %neg_one : f32 + vector.print %log_neg_one : f32 + + // CHECK: inf + %inf = constant 0x7f800000 : f32 + %log_inf = math.log %inf : f32 + vector.print %log_inf : f32 + + // CHECK: -inf, nan, inf, 0.693147 + %special_vec = constant dense<[0.0, -1.0, 0x7f800000, 2.0]> : vector<4xf32> + %log_special_vec = math.log %special_vec : vector<4xf32> + vector.print %log_special_vec : vector<4xf32> + + return +} + +func @main() { + call @tanh(): () -> () + call @log(): () -> () + return +}