diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -7,6 +7,7 @@ LINK_LIBS PUBLIC MLIRIR + MLIRLLVMIR MLIRMath MLIRPass MLIRStandard diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -11,6 +11,8 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/Vector/VectorOps.h" @@ -21,30 +23,69 @@ using namespace mlir; using namespace mlir::vector; -static bool isValidFloatType(Type type) { - if (auto vectorType = type.dyn_cast()) - return vectorType.getElementType().isa(); - return type.isa(); +// Returns a vector length if the type is a f32 vector type, 1 if the type is +// a f32 scalar, and None otherwise. +static Optional isValidF32Type(Type type) { + if (type.isF32()) + return 1; + + if (auto vectorType = type.dyn_cast()) { + auto elementType = vectorType.getElementType(); + if (vectorType.getRank() != 1 || !elementType.isF32()) + return llvm::None; + return vectorType.getDimSize(0); + } + + return llvm::None; } //----------------------------------------------------------------------------// -// A PatternRewriter wrapper that provides concise API for building expansions -// for operations on float scalars or vectors. +// A PatternRewriter wrapper that provides concise API for building function +// approximations for scalars and vectors. //----------------------------------------------------------------------------// namespace { -class FloatApproximationBuilder { +// A helper class that allows to build function approximation IR that works +// on scalars and on vectors (e.g. all created constants automatically +// broadcasted to the vector size). +// +// All type arguments to functions are expected to be scalar types (f32, i32, +// etc...) and are automatically converted to vectors. +class ApproximationsBuilder { public: - FloatApproximationBuilder(Location loc, Type type, PatternRewriter &rewriter); + ApproximationsBuilder(int vectorSize, Location loc, + PatternRewriter &rewriter); + + // Create constants from values. + Value f32Cst(float value) const; + Value i32Cst(int32_t value) const; - Value constant(double value) const; + // Create constants from bit patterns. + Value f32FromBits(uint32_t bits); Value abs(Value a) const; Value min(Value a, Value b) const; Value max(Value a, Value b) const; + Value add(Value a, Value b) const; + Value sub(Value a, Value b) const; Value mul(Value a, Value b) const; Value div(Value a, Value b) const; + // Casts from signed integer to float data type. + Value castSiToFp(Type type, Value arg) const; + + // Bitwise operations that do rely on the LLVM dialect. Arguments must be of + // integer type. + Value bitwiseOr(Value a, Value b) const; + Value bitwiseAnd(Value a, Value b) const; + + // Returns the `arg` shifted to the right a specified number of bits with zero + // fill. + Value logicalShiftRight(Value arg, int n) const; + + // Bitcasting operation that do rely on the LLVM dialect. + Value bitcast(Type type, Value arg) const; + // Fused multiple-add operation: a * b + c. Value madd(Value a, Value b, Value c) const; @@ -54,66 +95,153 @@ // Selects values from `a` or `b` based on the `predicate`. Value select(Value predicate, Value a, Value b) const; + // Decomposes given floating point value `arg` into a normalized fraction and + // an integral power of two (see std::frexp). Returned values have float type. + std::pair frexp(Value arg); + private: + // Converts scalar type/value into vector type/value. + Type vectorize(Type type) const; + Value vectorize(Value value) const; + + // Returns vector element type or `type` itself if it is a scalar type. + Type elementType(Type type) const; + + bool isVectorized() const { return vectorSize > 1; } + + int vectorSize; + MLIRContext *ctx; Location loc; PatternRewriter &rewriter; - VectorType vectorType; // can be null for scalar type - FloatType elementType; }; } // namespace -FloatApproximationBuilder::FloatApproximationBuilder(Location loc, Type type, - PatternRewriter &rewriter) - : loc(loc), rewriter(rewriter) { - vectorType = type.dyn_cast(); +ApproximationsBuilder::ApproximationsBuilder(int vectorSize, Location loc, + PatternRewriter &rewriter) + : vectorSize(vectorSize), ctx(rewriter.getContext()), loc(loc), + rewriter(rewriter) { + assert(vectorSize >= 0); +} + +Type ApproximationsBuilder::vectorize(Type type) const { + assert(!type.isa() && "must be scalar type"); + if (isVectorized()) + return VectorType::get({vectorSize}, type); + return type; +} + +Value ApproximationsBuilder::vectorize(Value value) const { + assert(!value.getType().isa() && "must be scalar value"); + if (isVectorized()) + return rewriter.create(loc, vectorize(value.getType()), value); + return value; +} + +Type ApproximationsBuilder::elementType(Type type) const { + if (auto vectorType = type.dyn_cast()) + return vectorType.getElementType(); + return type; +} - if (vectorType) - elementType = vectorType.getElementType().cast(); - else - elementType = type.cast(); +Value ApproximationsBuilder::f32Cst(float value) const { + auto attr = rewriter.getF32FloatAttr(value); + return vectorize(rewriter.create(loc, attr)); } -Value FloatApproximationBuilder::constant(double value) const { - auto attr = rewriter.getFloatAttr(elementType, value); - Value scalar = rewriter.create(loc, attr); +Value ApproximationsBuilder::i32Cst(int32_t value) const { + auto attr = rewriter.getI32IntegerAttr(value); + return vectorize(rewriter.create(loc, attr)); +} - if (vectorType) - return rewriter.create(loc, vectorType, scalar); - return scalar; +Value ApproximationsBuilder::f32FromBits(uint32_t bits) { + return bitcast(rewriter.getF32Type(), i32Cst(static_cast(bits))); } -Value FloatApproximationBuilder::abs(Value a) const { +Value ApproximationsBuilder::abs(Value a) const { return rewriter.create(loc, a); } -Value FloatApproximationBuilder::min(Value a, Value b) const { +Value ApproximationsBuilder::castSiToFp(Type type, Value arg) const { + return rewriter.create(loc, vectorize(type), arg); +} + +Value ApproximationsBuilder::bitwiseOr(Value a, Value b) const { + return rewriter.create(loc, a, b); +} + +Value ApproximationsBuilder::bitwiseAnd(Value a, Value b) const { + return rewriter.create(loc, a, b); +} + +Value ApproximationsBuilder::logicalShiftRight(Value arg, int n) const { + return rewriter.create(loc, arg, i32Cst(n)); +} + +Value ApproximationsBuilder::bitcast(Type type, Value arg) const { + return rewriter.create(loc, vectorize(type), arg); +} + +Value ApproximationsBuilder::min(Value a, Value b) const { return select(cmp(CmpFPredicate::OLT, a, b), a, b); } -Value FloatApproximationBuilder::max(Value a, Value b) const { +Value ApproximationsBuilder::max(Value a, Value b) const { return select(cmp(CmpFPredicate::OGT, a, b), a, b); } -Value FloatApproximationBuilder::mul(Value a, Value b) const { +Value ApproximationsBuilder::mul(Value a, Value b) const { return rewriter.create(loc, a, b); } -Value FloatApproximationBuilder::div(Value a, Value b) const { +Value ApproximationsBuilder::add(Value a, Value b) const { + return rewriter.create(loc, a, b); +} + +Value ApproximationsBuilder::sub(Value a, Value b) const { + return rewriter.create(loc, a, b); +} + +Value ApproximationsBuilder::div(Value a, Value b) const { return rewriter.create(loc, a, b); } -Value FloatApproximationBuilder::madd(Value a, Value b, Value c) const { +Value ApproximationsBuilder::madd(Value a, Value b, Value c) const { return rewriter.create(loc, a, b, c); } -Value FloatApproximationBuilder::cmp(CmpFPredicate predicate, Value a, - Value b) const { +Value ApproximationsBuilder::cmp(CmpFPredicate predicate, Value a, + Value b) const { return rewriter.create(loc, predicate, a, b); } -Value FloatApproximationBuilder::select(Value predicate, Value a, - Value b) const { +Value ApproximationsBuilder::select(Value predicate, Value a, Value b) const { return rewriter.create(loc, predicate, a, b); } +std::pair ApproximationsBuilder::frexp(Value arg) { + assert(elementType(arg.getType()).isF32() && "argument must be f32 type"); + + auto i32 = rewriter.getIntegerType(32); + auto f32 = rewriter.getF32Type(); + + Value cst126f = f32Cst(126.0f); + Value cstHalf = f32Cst(0.5f); + Value cstInvMantMask = f32FromBits(~0x7f800000u); + + // Cast to i32 for bitwise operations. + Value i32Half = bitcast(i32, cstHalf); + Value i32InvMantMask = bitcast(i32, cstInvMantMask); + Value i32Arg = bitcast(i32, arg); + + // Compute normalized fraction. + Value tmp0 = bitwiseOr(bitwiseAnd(i32Arg, i32InvMantMask), i32Half); + Value normalized_fraction = bitcast(f32, tmp0); + + // Compute exponent. + Value tmp1 = castSiToFp(f32, logicalShiftRight(bitcast(i32, abs(arg)), 23)); + Value exponent = sub(tmp1, cst126f); + + return {normalized_fraction, exponent}; +} + //----------------------------------------------------------------------------// // TanhOp approximation. //----------------------------------------------------------------------------// @@ -131,35 +259,37 @@ LogicalResult TanhApproximation::matchAndRewrite(math::TanhOp op, PatternRewriter &rewriter) const { - if (!isValidFloatType(op.operand().getType())) + Value operand = op.operand(); + + auto vectorSize = isValidF32Type(operand.getType()); + if (!vectorSize.hasValue()) return rewriter.notifyMatchFailure(op, "unsupported operand type"); - Value operand = op.operand(); - FloatApproximationBuilder builder(op->getLoc(), operand.getType(), rewriter); + ApproximationsBuilder builder(*vectorSize, op->getLoc(), rewriter); // Clamp operand into [plusClamp, minusClamp] range. - Value plusClamp = builder.constant(7.90531110763549805); - Value minusClamp = builder.constant(-7.9053111076354980); + Value plusClamp = builder.f32Cst(7.90531110763549805f); + Value minusClamp = builder.f32Cst(-7.9053111076354980f); Value x = builder.max(builder.min(operand, plusClamp), minusClamp); // Mask for tiny values that are approximated with `operand`. - Value tiny = builder.constant(0.0004f); + Value tiny = builder.f32Cst(0.0004f); Value tinyMask = builder.cmp(CmpFPredicate::OLT, builder.abs(operand), tiny); // The monomial coefficients of the numerator polynomial (odd). - Value alpha1 = builder.constant(4.89352455891786e-03); - Value alpha3 = builder.constant(6.37261928875436e-04); - Value alpha5 = builder.constant(1.48572235717979e-05); - Value alpha7 = builder.constant(5.12229709037114e-08); - Value alpha9 = builder.constant(-8.60467152213735e-11); - Value alpha11 = builder.constant(2.00018790482477e-13); - Value alpha13 = builder.constant(-2.76076847742355e-16); + Value alpha1 = builder.f32Cst(4.89352455891786e-03f); + Value alpha3 = builder.f32Cst(6.37261928875436e-04f); + Value alpha5 = builder.f32Cst(1.48572235717979e-05f); + Value alpha7 = builder.f32Cst(5.12229709037114e-08f); + Value alpha9 = builder.f32Cst(-8.60467152213735e-11f); + Value alpha11 = builder.f32Cst(2.00018790482477e-13f); + Value alpha13 = builder.f32Cst(-2.76076847742355e-16f); // The monomial coefficients of the denominator polynomial (even). - Value beta0 = builder.constant(4.89352518554385e-03); - Value beta2 = builder.constant(2.26843463243900e-03); - Value beta4 = builder.constant(1.18534705686654e-04); - Value beta6 = builder.constant(1.19825839466702e-06); + Value beta0 = builder.f32Cst(4.89352518554385e-03f); + Value beta2 = builder.f32Cst(2.26843463243900e-03f); + Value beta4 = builder.f32Cst(1.18534705686654e-04f); + Value beta6 = builder.f32Cst(1.19825839466702e-06f); // Since the polynomials are odd/even, we need x^2. Value x2 = builder.mul(x, x); @@ -186,9 +316,126 @@ return success(); } +//----------------------------------------------------------------------------// +// LogOp approximation. +//----------------------------------------------------------------------------// + +namespace { + +// This approximations comes from the Julien Pommier's SSE math library. +// Link: http://gruntthepeon.free.fr/ssemath +struct LogApproximation : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(math::LogOp op, + PatternRewriter &rewriter) const final; +}; +} // namespace + +#define LN2_VALUE \ + 0.693147180559945309417232121458176568075500134360255254120680009493393621L + +LogicalResult +LogApproximation::matchAndRewrite(math::LogOp op, + PatternRewriter &rewriter) const { + Value operand = op.operand(); + + auto vectorSize = isValidF32Type(operand.getType()); + if (!vectorSize.hasValue()) + return rewriter.notifyMatchFailure(op, "unsupported operand type"); + + ApproximationsBuilder builder(*vectorSize, op->getLoc(), rewriter); + + Value cstZero = builder.f32Cst(0.0f); + Value cstOne = builder.f32Cst(1.0f); + Value cstNegHalf = builder.f32Cst(-0.5f); + + // The smallest non denormalized float number. + Value cstMinNormPos = builder.f32FromBits(0x00800000u); + Value cstMinusInf = builder.f32FromBits(0xff800000u); + Value cstPosInf = builder.f32FromBits(0x7f800000u); + Value cstNan = builder.f32FromBits(0x7fc00000); + + // Polynomial coefficients. + Value cstCephesSQRTHF = builder.f32Cst(0.707106781186547524f); + Value cstCephesLogP0 = builder.f32Cst(7.0376836292E-2f); + Value cstCephesLogP1 = builder.f32Cst(-1.1514610310E-1f); + Value cstCephesLogP2 = builder.f32Cst(1.1676998740E-1f); + Value cstCephesLogP3 = builder.f32Cst(-1.2420140846E-1f); + Value cstCephesLogP4 = builder.f32Cst(+1.4249322787E-1f); + Value cstCephesLogP5 = builder.f32Cst(-1.6668057665E-1f); + Value cstCephesLogP6 = builder.f32Cst(+2.0000714765E-1f); + Value cstCephesLogP7 = builder.f32Cst(-2.4999993993E-1f); + Value cstCephesLogP8 = builder.f32Cst(+3.3333331174E-1f); + + Value x = op.operand(); + + // Truncate input values to the minimum positive normal. + x = builder.max(x, cstMinNormPos); + + // Extract significant in the range [0.5,1) and exponent. + std::pair pair = builder.frexp(x); + x = pair.first; + Value e = pair.second; + + // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift + // by -1.0. The values are then centered around 0, which improves the + // stability of the polynomial evaluation: + // + // if( x < SQRTHF ) { + // e -= 1; + // x = x + x - 1.0; + // } else { x = x - 1.0; } + Value mask = builder.cmp(CmpFPredicate::OLT, x, cstCephesSQRTHF); + Value tmp = builder.select(mask, x, cstZero); + + x = builder.sub(x, cstOne); + e = builder.sub(e, builder.select(mask, cstOne, cstZero)); + x = builder.add(x, tmp); + + Value x2 = builder.mul(x, x); + Value x3 = builder.mul(x2, x); + + // Evaluate the polynomial approximant of degree 8 in three parts. + Value y0, y1, y2; + y0 = builder.madd(cstCephesLogP0, x, cstCephesLogP1); + y1 = builder.madd(cstCephesLogP3, x, cstCephesLogP4); + y2 = builder.madd(cstCephesLogP6, x, cstCephesLogP7); + y0 = builder.madd(y0, x, cstCephesLogP2); + y1 = builder.madd(y1, x, cstCephesLogP5); + y2 = builder.madd(y2, x, cstCephesLogP8); + y0 = builder.madd(y0, x3, y1); + y0 = builder.madd(y0, x3, y2); + y0 = builder.mul(y0, x3); + + y0 = builder.madd(cstNegHalf, x2, y0); + x = builder.add(x, y0); + + Value cstLn2 = builder.f32Cst(static_cast(LN2_VALUE)); + x = builder.madd(e, cstLn2, x); + + Value invalidMask = builder.cmp(CmpFPredicate::ULT, op.operand(), cstZero); + Value zeroMask = builder.cmp(CmpFPredicate::OEQ, op.operand(), cstZero); + Value posInfMask = builder.cmp(CmpFPredicate::OEQ, op.operand(), cstPosInf); + + // Filter out invalid values: + // • x == 0 -> -INF + // • x < 0 -> NAN + // • x == +INF -> +INF + Value aproximation = + builder.select(zeroMask, cstMinusInf, + builder.select(invalidMask, cstNan, + builder.select(posInfMask, cstPosInf, x))); + + rewriter.replaceOp(op, aproximation); + + return success(); +} + //----------------------------------------------------------------------------// void mlir::populateMathPolynomialApproximationPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { - patterns.insert(ctx); + patterns.insert(ctx); } diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir --- a/mlir/test/Dialect/Math/polynomial-approximation.mlir +++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir @@ -1,15 +1,22 @@ // RUN: mlir-opt %s -test-math-polynomial-approximation | FileCheck %s -// CHECK-LABEL: @tanh_scalar -func @tanh_scalar(%arg0: f32) -> f32 { +// Check that all math functions lowered to approximations built from +// standard operations (add, mul, fma, shift, etc...). + +// CHECK-LABEL: @scalar +func @scalar(%arg0: f32) -> f32 { // CHECK-NOT: tanh %0 = math.tanh %arg0 : f32 - return %0 : f32 + // CHECK-NOT: log + %1 = math.log %0 : f32 + return %1 : f32 } -// CHECK-LABEL: @tanh_vector -func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> { +// CHECK-LABEL: @vector +func @vector(%arg0: vector<8xf32>) -> vector<8xf32> { // CHECK-NOT: tanh %0 = math.tanh %arg0 : vector<8xf32> - return %0 : vector<8xf32> + // CHECK-NOT: log + %1 = math.log %0 : vector<8xf32> + return %1 : vector<8xf32> } diff --git a/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp b/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp --- a/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp +++ b/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/Vector/VectorOps.h" @@ -24,7 +25,8 @@ : public PassWrapper { void runOnFunction() override; void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry + .insert(); } }; } // end anonymous namespace diff --git a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir --- a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir +++ b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir @@ -7,12 +7,10 @@ // RUN: -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s - -func @main() { - // ------------------------------------------------------------------------ // - // Tanh. - // ------------------------------------------------------------------------ // - +// ------------------------------------------------------------------------ // +// Tanh. +// ------------------------------------------------------------------------ // +func @tanh() { // CHECK: 0.848284 %0 = constant 1.25 : f32 %1 = math.tanh %0 : f32 @@ -30,3 +28,51 @@ return } + +// ------------------------------------------------------------------------ // +// Log. +// ------------------------------------------------------------------------ // +func @log() { + // CHECK: 2.64704 + %0 = constant 14.112233 : f32 + %1 = math.log %0 : f32 + vector.print %1 : f32 + + // CHECK: -1.38629, -0.287682, 0, 0.223144 + %2 = constant dense<[0.25, 0.75, 1.0, 1.25]> : vector<4xf32> + %3 = math.log %2 : vector<4xf32> + vector.print %3 : vector<4xf32> + + // CHECK: -2.30259, -1.60944, -1.20397, -0.916291, -0.693147, -0.510826, -0.356675, -0.223144 + %4 = constant dense<[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]> : vector<8xf32> + %5 = math.log %4 : vector<8xf32> + vector.print %5 : vector<8xf32> + + // CHECK: -inf + %zero = constant 0.0 : f32 + %log_zero = math.log %zero : f32 + vector.print %log_zero : f32 + + // CHECK: nan + %neg_one = constant -1.0 : f32 + %log_neg_one = math.log %neg_one : f32 + vector.print %log_neg_one : f32 + + // CHECK: inf + %inf = constant 0x7f800000 : f32 + %log_inf = math.log %inf : f32 + vector.print %log_inf : f32 + + // CHECK: -inf, nan, inf, 0.693147 + %special_vec = constant dense<[0.0, -1.0, 0x7f800000, 2.0]> : vector<4xf32> + %log_special_vec = math.log %special_vec : vector<4xf32> + vector.print %log_special_vec : vector<4xf32> + + return +} + +func @main() { + call @tanh(): () -> () + call @log(): () -> () + return +}