diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -11,6 +11,9 @@ // //===----------------------------------------------------------------------===// +#include +#include + #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Approximation.h" @@ -20,68 +23,35 @@ #include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/Bufferize.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/ArrayRef.h" -#include -#include using namespace mlir; using namespace mlir::math; using namespace mlir::vector; -using TypePredicate = llvm::function_ref; - -// Returns vector shape if the element type is matching the predicate (scalars -// that do match the predicate have shape equal to `{1}`). -static Optional> vectorShape(Type type, - TypePredicate pred) { - // If the type matches the predicate then its shape is `{1}`. - if (pred(type)) - return SmallVector{1}; - - // Otherwise check if the type is a vector type. - auto vectorType = type.dyn_cast(); - if (vectorType && pred(vectorType.getElementType())) { - return llvm::to_vector<2>(vectorType.getShape()); - } - - return llvm::None; -} - -// Returns vector shape of the type. If the type is a scalar returns `1`. -static SmallVector vectorShape(Type type) { - auto vectorType = type.dyn_cast(); - return vectorType ? llvm::to_vector<2>(vectorType.getShape()) - : SmallVector{1}; -} - -// Returns vector element type. If the type is a scalar returns the argument. -LLVM_ATTRIBUTE_UNUSED static Type elementType(Type type) { +// Returns vector shape if the type is a vector. Returns an empty shape if it is +// not a vector. +static ArrayRef vectorShape(Type type) { auto vectorType = type.dyn_cast(); - return vectorType ? vectorType.getElementType() : type; + return vectorType ? vectorType.getShape() : ArrayRef(); } -LLVM_ATTRIBUTE_UNUSED static bool isF32(Type type) { return type.isF32(); } - -LLVM_ATTRIBUTE_UNUSED static bool isI32(Type type) { - return type.isInteger(32); +static ArrayRef vectorShape(Value value) { + return vectorShape(value.getType()); } //----------------------------------------------------------------------------// // Broadcast scalar types and values into vector types and values. //----------------------------------------------------------------------------// -// Returns true if shape != {1}. -static bool isNonScalarShape(ArrayRef shape) { - return shape.size() > 1 || shape[0] > 1; -} - // Broadcasts scalar type into vector type (iff shape is non-scalar). static Type broadcast(Type type, ArrayRef shape) { assert(!type.isa() && "must be scalar type"); - return isNonScalarShape(shape) ? VectorType::get(shape, type) : type; + return !shape.empty() ? VectorType::get(shape, type) : type; } // Broadcasts scalar value into vector (iff shape is non-scalar). @@ -89,8 +59,7 @@ ArrayRef shape) { assert(!value.getType().isa() && "must be scalar value"); auto type = broadcast(value.getType(), shape); - return isNonScalarShape(shape) ? builder.create(type, value) - : value; + return !shape.empty() ? builder.create(type, value) : value; } //----------------------------------------------------------------------------// @@ -228,9 +197,8 @@ // an integral power of two (see std::frexp). Returned values have float type. static std::pair frexp(ImplicitLocOpBuilder &builder, Value arg, bool is_positive = false) { - assert(isF32(elementType(arg.getType())) && "argument must be f32 type"); - - auto shape = vectorShape(arg.getType()); + assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type"); + ArrayRef shape = vectorShape(arg); auto bcast = [&](Value value) -> Value { return broadcast(builder, value, shape); @@ -269,9 +237,8 @@ // Computes exp2 for an i32 argument. static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) { - assert(isI32(elementType(arg.getType())) && "argument must be i32 type"); - - auto shape = vectorShape(arg.getType()); + assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type"); + ArrayRef shape = vectorShape(arg); auto bcast = [&](Value value) -> Value { return broadcast(builder, value, shape); @@ -294,12 +261,15 @@ namespace { Value makePolynomialCalculation(ImplicitLocOpBuilder &builder, llvm::ArrayRef coeffs, Value x) { - auto shape = vectorShape(x.getType(), isF32); - if (coeffs.size() == 0) { - return broadcast(builder, f32Cst(builder, 0.0f), *shape); - } else if (coeffs.size() == 1) { + assert(getElementTypeOrSelf(x).isF32() && "x must be f32 type"); + ArrayRef shape = vectorShape(x); + + if (coeffs.empty()) + return broadcast(builder, f32Cst(builder, 0.0f), shape); + + if (coeffs.size() == 1) return coeffs[0]; - } + Value res = builder.create(x, coeffs[coeffs.size() - 1], coeffs[coeffs.size() - 2]); for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) { @@ -326,13 +296,14 @@ LogicalResult TanhApproximation::matchAndRewrite(math::TanhOp op, PatternRewriter &rewriter) const { - auto shape = vectorShape(op.operand().getType(), isF32); - if (!shape.hasValue()) + if (!getElementTypeOrSelf(op.operand()).isF32()) return rewriter.notifyMatchFailure(op, "unsupported operand type"); + ArrayRef shape = vectorShape(op.operand()); + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); auto bcast = [&](Value value) -> Value { - return broadcast(builder, value, *shape); + return broadcast(builder, value, shape); }; // Clamp operand into [plusClamp, minusClamp] range. @@ -413,13 +384,14 @@ LogicalResult LogApproximationBase::logMatchAndRewrite(Op op, PatternRewriter &rewriter, bool base2) const { - auto shape = vectorShape(op.operand().getType(), isF32); - if (!shape.hasValue()) + if (!getElementTypeOrSelf(op.operand()).isF32()) return rewriter.notifyMatchFailure(op, "unsupported operand type"); + ArrayRef shape = vectorShape(op.operand()); + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); auto bcast = [&](Value value) -> Value { - return broadcast(builder, value, *shape); + return broadcast(builder, value, shape); }; Value cstZero = bcast(f32Cst(builder, 0.0f)); @@ -559,13 +531,14 @@ LogicalResult Log1pApproximation::matchAndRewrite(math::Log1pOp op, PatternRewriter &rewriter) const { - auto shape = vectorShape(op.operand().getType(), isF32); - if (!shape.hasValue()) + if (!getElementTypeOrSelf(op.operand()).isF32()) return rewriter.notifyMatchFailure(op, "unsupported operand type"); + ArrayRef shape = vectorShape(op.operand()); + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); auto bcast = [&](Value value) -> Value { - return broadcast(builder, value, *shape); + return broadcast(builder, value, shape); }; // Approximate log(1+x) using the following, due to W. Kahan: @@ -605,13 +578,14 @@ LogicalResult ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op, PatternRewriter &rewriter) const { - auto shape = vectorShape(op.operand().getType(), isF32); - if (!shape.hasValue()) + if (!getElementTypeOrSelf(op.operand()).isF32()) return rewriter.notifyMatchFailure(op, "unsupported operand type"); + ArrayRef shape = vectorShape(op.operand()); + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); auto bcast = [&](Value value) -> Value { - return broadcast(builder, value, *shape); + return broadcast(builder, value, shape); }; const int intervalsCount = 3; @@ -728,15 +702,17 @@ LogicalResult ExpApproximation::matchAndRewrite(math::ExpOp op, PatternRewriter &rewriter) const { - auto shape = vectorShape(op.operand().getType(), isF32); - if (!shape.hasValue()) + if (!getElementTypeOrSelf(op.operand()).isF32()) return rewriter.notifyMatchFailure(op, "unsupported operand type"); + + ArrayRef shape = vectorShape(op.operand()); + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); // TODO: Consider a common pattern rewriter with all methods below to // write the approximations. auto bcast = [&](Value value) -> Value { - return broadcast(builder, value, *shape); + return broadcast(builder, value, shape); }; auto fmla = [&](Value a, Value b, Value c) { return builder.create(a, b, c); @@ -779,7 +755,7 @@ Value expY = fmla(q1, y2, q0); expY = fmla(q2, y4, expY); - auto i32Vec = broadcast(builder.getI32Type(), *shape); + auto i32Vec = broadcast(builder.getI32Type(), shape); // exp2(k) Value k = builder.create(kF32, i32Vec); @@ -848,13 +824,14 @@ LogicalResult ExpM1Approximation::matchAndRewrite(math::ExpM1Op op, PatternRewriter &rewriter) const { - auto shape = vectorShape(op.operand().getType(), isF32); - if (!shape.hasValue()) + if (!getElementTypeOrSelf(op.operand()).isF32()) return rewriter.notifyMatchFailure(op, "unsupported operand type"); + ArrayRef shape = vectorShape(op.operand()); + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); auto bcast = [&](Value value) -> Value { - return broadcast(builder, value, *shape); + return broadcast(builder, value, shape); }; // expm1(x) = exp(x) - 1 = u - 1. @@ -915,13 +892,15 @@ static_assert( llvm::is_one_of::value, "SinAndCosApproximation pattern expects math::SinOp or math::CosOp"); - auto shape = vectorShape(op.operand().getType(), isF32); - if (!shape.hasValue()) + + if (!getElementTypeOrSelf(op.operand()).isF32()) return rewriter.notifyMatchFailure(op, "unsupported operand type"); + ArrayRef shape = vectorShape(op.operand()); + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); auto bcast = [&](Value value) -> Value { - return broadcast(builder, value, *shape); + return broadcast(builder, value, shape); }; auto mul = [&](Value a, Value b) -> Value { return builder.create(a, b); @@ -931,7 +910,7 @@ }; auto floor = [&](Value a) { return builder.create(a); }; - auto i32Vec = broadcast(builder.getI32Type(), *shape); + auto i32Vec = broadcast(builder.getI32Type(), shape); auto fPToSingedInteger = [&](Value a) -> Value { return builder.create(a, i32Vec); }; @@ -1037,14 +1016,18 @@ LogicalResult RsqrtApproximation::matchAndRewrite(math::RsqrtOp op, PatternRewriter &rewriter) const { - auto shape = vectorShape(op.operand().getType(), isF32); + if (!getElementTypeOrSelf(op.operand()).isF32()) + return rewriter.notifyMatchFailure(op, "unsupported operand type"); + + ArrayRef shape = vectorShape(op.operand()); + // Only support already-vectorized rsqrt's. - if (!shape.hasValue() || shape->back() % 8 != 0) + if (shape.empty() || shape.back() % 8 != 0) return rewriter.notifyMatchFailure(op, "unsupported operand type"); ImplicitLocOpBuilder builder(op->getLoc(), rewriter); auto bcast = [&](Value value) -> Value { - return broadcast(builder, value, *shape); + return broadcast(builder, value, shape); }; Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));