diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -7,6 +7,7 @@ LINK_LIBS PUBLIC MLIRIR + MLIRLLVMIR MLIRMath MLIRPass MLIRStandard diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -11,107 +11,146 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; using namespace mlir::vector; -static bool isValidFloatType(Type type) { - if (auto vectorType = type.dyn_cast()) - return vectorType.getElementType().isa(); - return type.isa(); -} - -//----------------------------------------------------------------------------// -// A PatternRewriter wrapper that provides concise API for building expansions -// for operations on float scalars or vectors. -//----------------------------------------------------------------------------// +using TypePredicate = llvm::function_ref; -namespace { -class FloatApproximationBuilder { -public: - FloatApproximationBuilder(Location loc, Type type, PatternRewriter &rewriter); +static bool isF32(Type type) { return type.isF32(); } - Value constant(double value) const; +// Returns vector width if the element type is matching the predicate (scalars +// that do match the predicate have width equal to `1`). 
+static Optional vectorWidth(Type type, TypePredicate pred) { + // If the type matches the predicate then its width is `1`. + if (pred(type)) + return 1; - Value abs(Value a) const; - Value min(Value a, Value b) const; - Value max(Value a, Value b) const; - Value mul(Value a, Value b) const; - Value div(Value a, Value b) const; + // Otherwise check if the type is a vector type. + auto vectorType = type.dyn_cast(); + if (vectorType && pred(vectorType.getElementType())) { + assert(vectorType.getRank() == 1 && "only 1d vectors are supported"); + return vectorType.getDimSize(0); + } - // Fused multiple-add operation: a * b + c. - Value madd(Value a, Value b, Value c) const; + return llvm::None; +} - // Compares values `a` and `b` with the given `predicate`. - Value cmp(CmpFPredicate predicate, Value a, Value b) const; +// Returns vector width of the type. If the type is a scalar returns `1`. +static int vectorWidth(Type type) { + auto vectorType = type.dyn_cast(); + return vectorType ? vectorType.getDimSize(0) : 1; +} - // Selects values from `a` or `b` based on the `predicate`. - Value select(Value predicate, Value a, Value b) const; +// Returns vector element type. If the type is a scalar returns the argument. +static Type elementType(Type type) { + auto vectorType = type.dyn_cast(); + return vectorType ? vectorType.getElementType() : type; +} -private: - Location loc; - PatternRewriter &rewriter; - VectorType vectorType; // can be null for scalar type - FloatType elementType; -}; -} // namespace +//----------------------------------------------------------------------------// +// Broadcast scalar types and values into vector types and values. 
+//----------------------------------------------------------------------------// -FloatApproximationBuilder::FloatApproximationBuilder(Location loc, Type type, - PatternRewriter &rewriter) - : loc(loc), rewriter(rewriter) { - vectorType = type.dyn_cast(); +// Broadcasts scalar type into vector type (iff width is greater than 1). +static Type broadcast(Type type, int width) { + assert(!type.isa() && "must be scalar type"); + return width > 1 ? VectorType::get({width}, type) : type; +} - if (vectorType) - elementType = vectorType.getElementType().cast(); - else - elementType = type.cast(); +// Broadcasts scalar value into vector (iff width is greater than 1). +static Value broadcast(ImplicitLocOpBuilder &builder, Value value, int width) { + assert(!value.getType().isa() && "must be scalar value"); + auto type = broadcast(value.getType(), width); + return width > 1 ? builder.create(type, value) : value; } -Value FloatApproximationBuilder::constant(double value) const { - auto attr = rewriter.getFloatAttr(elementType, value); - Value scalar = rewriter.create(loc, attr); +//----------------------------------------------------------------------------// +// Helper functions to create constants.
+//----------------------------------------------------------------------------// - if (vectorType) - return rewriter.create(loc, vectorType, scalar); - return scalar; +static Value f32Cst(ImplicitLocOpBuilder &builder, float value) { + return builder.create(builder.getF32Type(), + builder.getF32FloatAttr(value)); } -Value FloatApproximationBuilder::abs(Value a) const { - return rewriter.create(loc, a); +static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) { + return builder.create(builder.getI32Type(), + builder.getI32IntegerAttr(value)); } -Value FloatApproximationBuilder::min(Value a, Value b) const { - return select(cmp(CmpFPredicate::OLT, a, b), a, b); -} -Value FloatApproximationBuilder::max(Value a, Value b) const { - return select(cmp(CmpFPredicate::OGT, a, b), a, b); -} -Value FloatApproximationBuilder::mul(Value a, Value b) const { - return rewriter.create(loc, a, b); +static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) { + Value i32Value = i32Cst(builder, static_cast(bits)); + return builder.create(builder.getF32Type(), i32Value); } -Value FloatApproximationBuilder::div(Value a, Value b) const { - return rewriter.create(loc, a, b); +//----------------------------------------------------------------------------// +// Helper functions to build math functions approximations. 
+//----------------------------------------------------------------------------// + +static Value min(ImplicitLocOpBuilder &builder, Value a, Value b) { + return builder.create( + builder.create(CmpFPredicate::OLT, a, b), a, b); } -Value FloatApproximationBuilder::madd(Value a, Value b, Value c) const { - return rewriter.create(loc, a, b, c); +static Value max(ImplicitLocOpBuilder &builder, Value a, Value b) { + return builder.create( + builder.create(CmpFPredicate::OGT, a, b), a, b); } -Value FloatApproximationBuilder::cmp(CmpFPredicate predicate, Value a, - Value b) const { - return rewriter.create(loc, predicate, a, b); +static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, + Value upperBound) { + return max(builder, min(builder, value, upperBound), lowerBound); } -Value FloatApproximationBuilder::select(Value predicate, Value a, - Value b) const { - return rewriter.create(loc, predicate, a, b); +// Decomposes given floating point value `arg` into a normalized fraction and +// an integral power of two (see std::frexp). Returned values have float type. +static std::pair frexp(ImplicitLocOpBuilder &builder, Value arg, + bool is_positive = false) { + assert(isF32(elementType(arg.getType())) && "argument must be f32 type"); + + int width = vectorWidth(arg.getType()); + + auto bcast = [&](Value value) -> Value { + return broadcast(builder, value, width); + }; + + auto i32 = builder.getIntegerType(32); + auto i32Vec = broadcast(i32, width); + auto f32Vec = broadcast(builder.getF32Type(), width); + + Value cst126f = f32Cst(builder, 126.0f); + Value cstHalf = f32Cst(builder, 0.5f); + Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u); + + // Bitcast to i32 for bitwise operations. + Value i32Half = builder.create(i32, cstHalf); + Value i32InvMantMask = builder.create(i32, cstInvMantMask); + Value i32Arg = builder.create(i32Vec, arg); + + // Compute normalized fraction. 
+ Value tmp0 = builder.create(i32Arg, bcast(i32InvMantMask)); + Value tmp1 = builder.create(tmp0, bcast(i32Half)); + Value normalizedFraction = builder.create(f32Vec, tmp1); + + // Compute exponent. + Value arg0 = is_positive ? arg : builder.create(arg); + Value biasedExponentBits = builder.create( + builder.create(i32Vec, arg0), + bcast(i32Cst(builder, 23))); + Value biasedExponent = builder.create(f32Vec, biasedExponentBits); + Value exponent = builder.create(biasedExponent, bcast(cst126f)); + + return {normalizedFraction, exponent}; } //----------------------------------------------------------------------------// @@ -131,64 +170,192 @@ LogicalResult TanhApproximation::matchAndRewrite(math::TanhOp op, PatternRewriter &rewriter) const { - if (!isValidFloatType(op.operand().getType())) + auto width = vectorWidth(op.operand().getType(), isF32); + if (!width.hasValue()) return rewriter.notifyMatchFailure(op, "unsupported operand type"); - Value operand = op.operand(); - FloatApproximationBuilder builder(op->getLoc(), operand.getType(), rewriter); + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); + auto bcast = [&](Value value) -> Value { + return broadcast(builder, value, *width); + }; // Clamp operand into [plusClamp, minusClamp] range. - Value plusClamp = builder.constant(7.90531110763549805); - Value minusClamp = builder.constant(-7.9053111076354980); - Value x = builder.max(builder.min(operand, plusClamp), minusClamp); + Value minusClamp = bcast(f32Cst(builder, -7.9053111076354980f)); + Value plusClamp = bcast(f32Cst(builder, 7.90531110763549805f)); + Value x = clamp(builder, op.operand(), minusClamp, plusClamp); // Mask for tiny values that are approximated with `operand`. 
- Value tiny = builder.constant(0.0004f); - Value tinyMask = builder.cmp(CmpFPredicate::OLT, builder.abs(operand), tiny); + Value tiny = bcast(f32Cst(builder, 0.0004f)); + Value tinyMask = builder.create( + CmpFPredicate::OLT, builder.create(op.operand()), tiny); // The monomial coefficients of the numerator polynomial (odd). - Value alpha1 = builder.constant(4.89352455891786e-03); - Value alpha3 = builder.constant(6.37261928875436e-04); - Value alpha5 = builder.constant(1.48572235717979e-05); - Value alpha7 = builder.constant(5.12229709037114e-08); - Value alpha9 = builder.constant(-8.60467152213735e-11); - Value alpha11 = builder.constant(2.00018790482477e-13); - Value alpha13 = builder.constant(-2.76076847742355e-16); + Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f)); + Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f)); + Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f)); + Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f)); + Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f)); + Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f)); + Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f)); // The monomial coefficients of the denominator polynomial (even). - Value beta0 = builder.constant(4.89352518554385e-03); - Value beta2 = builder.constant(2.26843463243900e-03); - Value beta4 = builder.constant(1.18534705686654e-04); - Value beta6 = builder.constant(1.19825839466702e-06); + Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f)); + Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f)); + Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f)); + Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f)); // Since the polynomials are odd/even, we need x^2. - Value x2 = builder.mul(x, x); + Value x2 = builder.create(x, x); // Evaluate the numerator polynomial p. 
- Value p = builder.madd(x2, alpha13, alpha11); - p = builder.madd(x2, p, alpha9); - p = builder.madd(x2, p, alpha7); - p = builder.madd(x2, p, alpha5); - p = builder.madd(x2, p, alpha3); - p = builder.madd(x2, p, alpha1); - p = builder.mul(x, p); + Value p = builder.create(x2, alpha13, alpha11); + p = builder.create(x2, p, alpha9); + p = builder.create(x2, p, alpha7); + p = builder.create(x2, p, alpha5); + p = builder.create(x2, p, alpha3); + p = builder.create(x2, p, alpha1); + p = builder.create(x, p); // Evaluate the denominator polynomial q. - Value q = builder.madd(x2, beta6, beta4); - q = builder.madd(x2, q, beta2); - q = builder.madd(x2, q, beta0); + Value q = builder.create(x2, beta6, beta4); + q = builder.create(x2, q, beta2); + q = builder.create(x2, q, beta0); // Divide the numerator by the denominator. - Value res = builder.select(tinyMask, x, builder.div(p, q)); + Value res = + builder.create(tinyMask, x, builder.create(p, q)); rewriter.replaceOp(op, res); return success(); } +//----------------------------------------------------------------------------// +// LogOp approximation. +//----------------------------------------------------------------------------// + +namespace { + +// This approximation comes from Julien Pommier's SSE math library.
+// Link: http://gruntthepeon.free.fr/ssemath +struct LogApproximation : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(math::LogOp op, + PatternRewriter &rewriter) const final; +}; +} // namespace + +#define LN2_VALUE \ + 0.693147180559945309417232121458176568075500134360255254120680009493393621L + +LogicalResult +LogApproximation::matchAndRewrite(math::LogOp op, + PatternRewriter &rewriter) const { + auto width = vectorWidth(op.operand().getType(), isF32); + if (!width.hasValue()) + return rewriter.notifyMatchFailure(op, "unsupported operand type"); + + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); + auto bcast = [&](Value value) -> Value { + return broadcast(builder, value, *width); + }; + + Value cstZero = bcast(f32Cst(builder, 0.0f)); + Value cstOne = bcast(f32Cst(builder, 1.0f)); + Value cstNegHalf = bcast(f32Cst(builder, -0.5f)); + + // The smallest positive normalized (non-denormal) float number. + Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u)); + Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u)); + Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u)); + Value cstNan = bcast(f32FromBits(builder, 0x7fc00000)); + + // Polynomial coefficients.
+ Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f)); + Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f)); + Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f)); + Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f)); + Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f)); + Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f)); + Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f)); + Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f)); + Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f)); + Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f)); + + Value x = op.operand(); + + // Clamp input values from below to the minimum positive normal value. + x = max(builder, x, cstMinNormPos); + + // Extract the significand in the range [0.5,1) and the exponent. + std::pair pair = frexp(builder, x, /*is_positive=*/true); + x = pair.first; + Value e = pair.second; + + // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift + // by -1.0. The values are then centered around 0, which improves the + // stability of the polynomial evaluation: + // + // if( x < SQRTHF ) { + // e -= 1; + // x = x + x - 1.0; + // } else { x = x - 1.0; } + Value mask = builder.create(CmpFPredicate::OLT, x, cstCephesSQRTHF); + Value tmp = builder.create(mask, x, cstZero); + + x = builder.create(x, cstOne); + e = builder.create(e, + builder.create(mask, cstOne, cstZero)); + x = builder.create(x, tmp); + + Value x2 = builder.create(x, x); + Value x3 = builder.create(x2, x); + + // Evaluate the polynomial approximant of degree 8 in three parts.
+ Value y0, y1, y2; + y0 = builder.create(cstCephesLogP0, x, cstCephesLogP1); + y1 = builder.create(cstCephesLogP3, x, cstCephesLogP4); + y2 = builder.create(cstCephesLogP6, x, cstCephesLogP7); + y0 = builder.create(y0, x, cstCephesLogP2); + y1 = builder.create(y1, x, cstCephesLogP5); + y2 = builder.create(y2, x, cstCephesLogP8); + y0 = builder.create(y0, x3, y1); + y0 = builder.create(y0, x3, y2); + y0 = builder.create(y0, x3); + + y0 = builder.create(cstNegHalf, x2, y0); + x = builder.create(x, y0); + + Value cstLn2 = bcast(f32Cst(builder, static_cast(LN2_VALUE))); + x = builder.create(e, cstLn2, x); + + Value invalidMask = + builder.create(CmpFPredicate::ULT, op.operand(), cstZero); + Value zeroMask = + builder.create(CmpFPredicate::OEQ, op.operand(), cstZero); + Value posInfMask = + builder.create(CmpFPredicate::OEQ, op.operand(), cstPosInf); + + // Filter out invalid values: + // • x == 0 -> -INF + // • x < 0 -> NAN + // • x == +INF -> +INF + Value aproximation = builder.create( + zeroMask, cstMinusInf, + builder.create( + invalidMask, cstNan, + builder.create(posInfMask, cstPosInf, x))); + + rewriter.replaceOp(op, aproximation); + + return success(); +} + //----------------------------------------------------------------------------// void mlir::populateMathPolynomialApproximationPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { - patterns.insert(ctx); + patterns.insert(ctx); } diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir --- a/mlir/test/Dialect/Math/polynomial-approximation.mlir +++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir @@ -1,15 +1,22 @@ // RUN: mlir-opt %s -test-math-polynomial-approximation | FileCheck %s -// CHECK-LABEL: @tanh_scalar -func @tanh_scalar(%arg0: f32) -> f32 { +// Check that all math functions are lowered to approximations built from +// standard operations (add, mul, fma, shift, etc...).
+ +// CHECK-LABEL: @scalar +func @scalar(%arg0: f32) -> f32 { // CHECK-NOT: tanh %0 = math.tanh %arg0 : f32 - return %0 : f32 + // CHECK-NOT: log + %1 = math.log %0 : f32 + return %1 : f32 } -// CHECK-LABEL: @tanh_vector -func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> { +// CHECK-LABEL: @vector +func @vector(%arg0: vector<8xf32>) -> vector<8xf32> { // CHECK-NOT: tanh %0 = math.tanh %arg0 : vector<8xf32> - return %0 : vector<8xf32> + // CHECK-NOT: log + %1 = math.log %0 : vector<8xf32> + return %1 : vector<8xf32> } diff --git a/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp b/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp --- a/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp +++ b/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/Vector/VectorOps.h" @@ -24,7 +25,8 @@ : public PassWrapper { void runOnFunction() override; void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry + .insert(); } }; } // end anonymous namespace diff --git a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir --- a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir +++ b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir @@ -7,12 +7,10 @@ // RUN: -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s - -func @main() { - // ------------------------------------------------------------------------ // - // Tanh. - // ------------------------------------------------------------------------ // - +// -------------------------------------------------------------------------- // +// Tanh. 
+// -------------------------------------------------------------------------- // +func @tanh() { // CHECK: 0.848284 %0 = constant 1.25 : f32 %1 = math.tanh %0 : f32 @@ -30,3 +28,51 @@ return } + +// -------------------------------------------------------------------------- // +// Log. +// -------------------------------------------------------------------------- // +func @log() { + // CHECK: 2.64704 + %0 = constant 14.112233 : f32 + %1 = math.log %0 : f32 + vector.print %1 : f32 + + // CHECK: -1.38629, -0.287682, 0, 0.223144 + %2 = constant dense<[0.25, 0.75, 1.0, 1.25]> : vector<4xf32> + %3 = math.log %2 : vector<4xf32> + vector.print %3 : vector<4xf32> + + // CHECK: -2.30259, -1.60944, -1.20397, -0.916291, -0.693147, -0.510826, -0.356675, -0.223144 + %4 = constant dense<[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]> : vector<8xf32> + %5 = math.log %4 : vector<8xf32> + vector.print %5 : vector<8xf32> + + // CHECK: -inf + %zero = constant 0.0 : f32 + %log_zero = math.log %zero : f32 + vector.print %log_zero : f32 + + // CHECK: nan + %neg_one = constant -1.0 : f32 + %log_neg_one = math.log %neg_one : f32 + vector.print %log_neg_one : f32 + + // CHECK: inf + %inf = constant 0x7f800000 : f32 + %log_inf = math.log %inf : f32 + vector.print %log_inf : f32 + + // CHECK: -inf, nan, inf, 0.693147 + %special_vec = constant dense<[0.0, -1.0, 0x7f800000, 2.0]> : vector<4xf32> + %log_special_vec = math.log %special_vec : vector<4xf32> + vector.print %log_special_vec : vector<4xf32> + + return +} + +func @main() { + call @tanh(): () -> () + call @log(): () -> () + return +}