diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -32,27 +32,28 @@
 using TypePredicate = llvm::function_ref<bool(Type)>;
 
-// Returns vector width if the element type is matching the predicate (scalars
-// that do match the predicate have width equal to `1`).
-static Optional<int> vectorWidth(Type type, TypePredicate pred) {
-  // If the type matches the predicate then its width is `1`.
+// Returns vector shape if the element type is matching the predicate (scalars
+// that do match the predicate have shape equal to `{1}`).
+static Optional<SmallVector<int64_t, 2>> vectorShape(Type type,
+                                                     TypePredicate pred) {
+  // If the type matches the predicate then its shape is `{1}`.
   if (pred(type))
-    return 1;
+    return SmallVector<int64_t, 2>{1};
 
   // Otherwise check if the type is a vector type.
   auto vectorType = type.dyn_cast<VectorType>();
   if (vectorType && pred(vectorType.getElementType())) {
-    assert(vectorType.getRank() == 1 && "only 1d vectors are supported");
-    return vectorType.getDimSize(0);
+    return llvm::to_vector<2>(vectorType.getShape());
   }
 
   return llvm::None;
 }
 
-// Returns vector width of the type. If the type is a scalar returns `1`.
-static int vectorWidth(Type type) {
+// Returns vector shape of the type. If the type is a scalar returns `1`.
+static SmallVector<int64_t, 2> vectorShape(Type type) {
   auto vectorType = type.dyn_cast<VectorType>();
-  return vectorType ? vectorType.getDimSize(0) : 1;
+  return vectorType ? llvm::to_vector<2>(vectorType.getShape())
+                    : SmallVector<int64_t, 2>{1};
 }
 
 // Returns vector element type. If the type is a scalar returns the argument.
 static Type elementType(Type type) {
@@ -71,17 +72,24 @@
 // Broadcast scalar types and values into vector types and values.
 //----------------------------------------------------------------------------//
 
-// Broadcasts scalar type into vector type (iff width is greater then 1).
-static Type broadcast(Type type, int width) {
+// Returns true if shape != {1}.
+static bool isNonScalarShape(ArrayRef<int64_t> shape) {
+  return shape.size() > 1 || shape[0] > 1;
+}
+
+// Broadcasts scalar type into vector type (iff shape is non-scalar).
+static Type broadcast(Type type, ArrayRef<int64_t> shape) {
   assert(!type.isa<VectorType>() && "must be scalar type");
-  return width > 1 ? VectorType::get({width}, type) : type;
+  return isNonScalarShape(shape) ? VectorType::get(shape, type) : type;
 }
 
-// Broadcasts scalar value into vector (iff width is greater then 1).
-static Value broadcast(ImplicitLocOpBuilder &builder, Value value, int width) {
+// Broadcasts scalar value into vector (iff shape is non-scalar).
+static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
+                       ArrayRef<int64_t> shape) {
   assert(!value.getType().isa<VectorType>() && "must be scalar value");
-  auto type = broadcast(value.getType(), width);
-  return width > 1 ? builder.create<BroadcastOp>(type, value) : value;
+  auto type = broadcast(value.getType(), shape);
+  return isNonScalarShape(shape) ? builder.create<BroadcastOp>(type, value)
+                                 : value;
 }
 
 //----------------------------------------------------------------------------//
@@ -126,15 +134,15 @@
                    bool is_positive = false) {
   assert(isF32(elementType(arg.getType())) && "argument must be f32 type");
 
-  int width = vectorWidth(arg.getType());
+  auto shape = vectorShape(arg.getType());
 
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, width);
+    return broadcast(builder, value, shape);
   };
 
   auto i32 = builder.getIntegerType(32);
-  auto i32Vec = broadcast(i32, width);
-  auto f32Vec = broadcast(builder.getF32Type(), width);
+  auto i32Vec = broadcast(i32, shape);
+  auto f32Vec = broadcast(builder.getF32Type(), shape);
 
   Value cst126f = f32Cst(builder, 126.0f);
   Value cstHalf = f32Cst(builder, 0.5f);
@@ -167,13 +175,13 @@
 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
   assert(isI32(elementType(arg.getType())) && "argument must be i32 type");
 
-  int width = vectorWidth(arg.getType());
+  auto shape = vectorShape(arg.getType());
 
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, width);
+    return broadcast(builder, value, shape);
   };
 
-  auto f32Vec = broadcast(builder.getF32Type(), width);
+  auto f32Vec = broadcast(builder.getF32Type(), shape);
   // The exponent of f32 located at 23-bit.
   auto exponetBitLocation = bcast(i32Cst(builder, 23));
   // Set the exponent bias to zero.
@@ -222,13 +230,13 @@
 LogicalResult
 TanhApproximation::matchAndRewrite(math::TanhOp op,
                                    PatternRewriter &rewriter) const {
-  auto width = vectorWidth(op.operand().getType(), isF32);
-  if (!width.hasValue())
+  auto shape = vectorShape(op.operand().getType(), isF32);
+  if (!shape.hasValue())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *width);
+    return broadcast(builder, value, *shape);
   };
 
   // Clamp operand into [plusClamp, minusClamp] range.
@@ -309,13 +317,13 @@
 LogicalResult
 LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                              bool base2) const {
-  auto width = vectorWidth(op.operand().getType(), isF32);
-  if (!width.hasValue())
+  auto shape = vectorShape(op.operand().getType(), isF32);
+  if (!shape.hasValue())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *width);
+    return broadcast(builder, value, *shape);
   };
 
   Value cstZero = bcast(f32Cst(builder, 0.0f));
@@ -455,13 +463,13 @@
 LogicalResult
 Log1pApproximation::matchAndRewrite(math::Log1pOp op,
                                     PatternRewriter &rewriter) const {
-  auto width = vectorWidth(op.operand().getType(), isF32);
-  if (!width.hasValue())
+  auto shape = vectorShape(op.operand().getType(), isF32);
+  if (!shape.hasValue())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *width);
+    return broadcast(builder, value, *shape);
   };
 
   // Approximate log(1+x) using the following, due to W. Kahan:
@@ -624,15 +632,15 @@
 LogicalResult
 ExpApproximation::matchAndRewrite(math::ExpOp op,
                                   PatternRewriter &rewriter) const {
-  auto width = vectorWidth(op.operand().getType(), isF32);
-  if (!width.hasValue())
+  auto shape = vectorShape(op.operand().getType(), isF32);
+  if (!shape.hasValue())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   // TODO: Consider a common pattern rewriter with all methods below to
   // write the approximations.
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *width);
+    return broadcast(builder, value, *shape);
   };
   auto fmla = [&](Value a, Value b, Value c) {
     return builder.create<math::FmaOp>(a, b, c);
@@ -675,7 +683,7 @@
   Value expY = fmla(q1, y2, q0);
   expY = fmla(q2, y4, expY);
 
-  auto i32Vec = broadcast(builder.getI32Type(), *width);
+  auto i32Vec = broadcast(builder.getI32Type(), *shape);
 
   // exp2(k)
   Value k = builder.create<arith::FPToSIOp>(kF32, i32Vec);
@@ -744,13 +752,13 @@
 LogicalResult
 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
                                     PatternRewriter &rewriter) const {
-  auto width = vectorWidth(op.operand().getType(), isF32);
-  if (!width.hasValue())
+  auto shape = vectorShape(op.operand().getType(), isF32);
+  if (!shape.hasValue())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *width);
+    return broadcast(builder, value, *shape);
   };
 
   // expm1(x) = exp(x) - 1 = u - 1.
@@ -811,13 +819,13 @@
   static_assert(
       llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
       "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
-  auto width = vectorWidth(op.operand().getType(), isF32);
-  if (!width.hasValue())
+  auto shape = vectorShape(op.operand().getType(), isF32);
+  if (!shape.hasValue())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *width);
+    return broadcast(builder, value, *shape);
   };
   auto mul = [&](Value a, Value b) -> Value {
     return builder.create<arith::MulFOp>(a, b);
@@ -827,7 +835,7 @@
   };
   auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
 
-  auto i32Vec = broadcast(builder.getI32Type(), *width);
+  auto i32Vec = broadcast(builder.getI32Type(), *shape);
   auto fPToSingedInteger = [&](Value a) -> Value {
     return builder.create<arith::FPToSIOp>(a, i32Vec);
   };
@@ -933,14 +941,14 @@
 LogicalResult
 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
                                     PatternRewriter &rewriter) const {
-  auto width = vectorWidth(op.operand().getType(), isF32);
+  auto shape = vectorShape(op.operand().getType(), isF32);
   // Only support already-vectorized rsqrt's.
-  if (!width.hasValue() || *width != 8)
+  if (!shape.hasValue() || (*shape)[0] != 8)
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *width);
+    return broadcast(builder, value, *shape);
   };
 
   Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -183,8 +183,8 @@
 }
 
 // CHECK-LABEL: func @expm1_vector(
-// CHECK-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> {
-// CHECK: %[[VAL_1:.*]] = arith.constant dense<-1.000000e+00> : vector<8xf32>
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x8xf32>) -> vector<8x8xf32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<-1.000000e+00> : vector<8x8xf32>
 // CHECK-NOT: exp
 // CHECK-COUNT-4: select
 // CHECK-NOT: log
@@ -192,11 +192,11 @@
 // CHECK-NOT: expm1
 // CHECK-COUNT-3: select
 // CHECK: %[[VAL_115:.*]] = select
-// CHECK: return %[[VAL_115]] : vector<8xf32>
+// CHECK: return %[[VAL_115]] : vector<8x8xf32>
 // CHECK: }
-func @expm1_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
-  %0 = math.expm1 %arg0 : vector<8xf32>
-  return %0 : vector<8xf32>
+func @expm1_vector(%arg0: vector<8x8xf32>) -> vector<8x8xf32> {
+  %0 = math.expm1 %arg0 : vector<8x8xf32>
+  return %0 : vector<8x8xf32>
 }
 
 // CHECK-LABEL: func @log_scalar(