Index: mlir/include/mlir/Dialect/Math/IR/MathOps.td =================================================================== --- mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -719,4 +719,51 @@ }]; } +//===----------------------------------------------------------------------===// +// FPowSIOp +//===----------------------------------------------------------------------===// + +def Math_FPowSIOp : Math_Op<"fpowsi", + [SameOperandsAndResultShape, + TypesMatchWith<"result type matches type of lhs", + "lhs", "result", "$_self">]> { + let summary = "floating point raised to the signed integer power"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `math.fpowsi` ssa-use `,` ssa-use `:` function-type + ``` + + The `fpowsi` operation takes a `base` operand of floating point type + (i.e. scalar, tensor or vector) and a `power` operand of integer type + (also scalar, tensor or vector) and returns one result of the same type + as `base`. The result is `base` raised to the power of `power`. + The operation is elementwise for non-scalars, e.g.: + ```mlir + %v = math.fpowsi %base, %power : + (vector<2xf32>, vector<2xi32>) -> vector<2xf32> + ``` + The result is a vector of: + ``` + [<math.fpowsi %base[0], %power[0]>, <math.fpowsi %base[1], %power[1]>] + ``` + + Example: + + ```mlir + // Scalar exponentiation. + %a = math.fpowsi %base, %power : (f64, i32) -> f64 + ``` + }]; + + let arguments = (ins FloatLike:$lhs, SignlessIntegerLike:$rhs); + let results = (outs FloatLike:$result); + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, $result)"; + + // TODO: add a constant folder using pow[f] for cases, when + // the power argument is exactly representable in floating + // point type of the base. 
+} + #endif // MATH_OPS Index: mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp =================================================================== --- mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp +++ mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp @@ -113,29 +113,38 @@ } //----------------------------------------------------------------------------// -// IPowSIOp strength reduction. +// FPowSIOp/IPowSIOp strength reduction. //----------------------------------------------------------------------------// namespace { -struct IPowSIStrengthReduction : public OpRewritePattern { +template +struct PowSIStrengthReduction : public OpRewritePattern { + static_assert(std::is_same::value || + std::is_same::value, + "Only FPowSIOp and IPowSIOp are supported."); + constexpr static bool baseIsFP = std::is_same::value; + using DivOpTy = + typename std::conditional::type; + using MulOpTy = + typename std::conditional::type; + unsigned exponentThreshold; public: - using OpRewritePattern::OpRewritePattern; - - IPowSIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3, - PatternBenefit benefit = 1, - ArrayRef generatedNames = {}) - : OpRewritePattern(context, benefit, generatedNames), + PowSIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3, + PatternBenefit benefit = 1, + ArrayRef generatedNames = {}) + : OpRewritePattern(context, benefit, generatedNames), exponentThreshold(exponentThreshold) {} - LogicalResult matchAndRewrite(math::IPowSIOp op, + + LogicalResult matchAndRewrite(PowSITy op, PatternRewriter &rewriter) const final; }; } // namespace -LogicalResult -IPowSIStrengthReduction::matchAndRewrite(math::IPowSIOp op, - PatternRewriter &rewriter) const { +template +LogicalResult PowSIStrengthReduction::matchAndRewrite( + PowSITy op, PatternRewriter &rewriter) const { Location loc = op.getLoc(); Value base = op.getLhs(); @@ -155,14 +164,20 @@ // Maybe broadcasts scalar value into vector type 
compatible with `op`. auto bcast = [&](Value value) -> Value { - if (auto vec = op.getType().dyn_cast()) + if (auto vec = op.getType().template dyn_cast()) return rewriter.create(loc, vec, value); return value; }; - if (exponentValue == 0) { - Value one = rewriter.create( + Value one; + if (baseIsFP) + one = rewriter.create( + loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0)); + else + one = rewriter.create( loc, rewriter.getIntegerAttr(getElementTypeOrSelf(op.getType()), 1)); + + if (exponentValue == 0) { rewriter.replaceOp(op, bcast(one)); return success(); } @@ -177,16 +192,13 @@ return failure(); // Inverse the base for negative exponent. - if (exponentIsNegative) { - Value one = rewriter.create( - loc, rewriter.getIntegerAttr(getElementTypeOrSelf(op.getType()), 1)); - base = rewriter.create(loc, bcast(one), base); - } + if (exponentIsNegative) + base = rewriter.create(loc, bcast(one), base); Value result = base; // Transform to naive sequence of multiplications. for (unsigned i = 1; i < exponentValue; ++i) - result = rewriter.create(loc, result, base); + result = rewriter.create(loc, result, base); rewriter.replaceOp(op, result); return success(); @@ -197,5 +209,6 @@ void mlir::populateMathAlgebraicSimplificationPatterns( RewritePatternSet &patterns) { patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); } Index: mlir/test/Dialect/Math/algebraic-simplification.mlir =================================================================== --- mlir/test/Dialect/Math/algebraic-simplification.mlir +++ mlir/test/Dialect/Math/algebraic-simplification.mlir @@ -163,3 +163,93 @@ %3 = math.ipowsi %arg1, %vm1 : vector<4xi32> return %0, %1, %2, %3 : i32, vector<4xi32>, i32, vector<4xi32> } + +// CHECK-LABEL: @fpowi_zero_exp( +// CHECK-SAME: %[[ARG0:.+]]: f32 +// CHECK-SAME: %[[ARG1:.+]]: vector<4xf32> +// CHECK-SAME: -> (f32, vector<4xf32>) { +func.func 
@fpowi_zero_exp(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) { + // CHECK: %[[CST_S:.*]] = arith.constant 1.000000e+00 : f32 + // CHECK: %[[CST_V:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> + // CHECK: return %[[CST_S]], %[[CST_V]] + %c = arith.constant 0 : i32 + %v = arith.constant dense <0> : vector<4xi32> + %0 = math.fpowsi %arg0, %c : (f32, i32) -> f32 + %1 = math.fpowsi %arg1, %v : (vector<4xf32>, vector<4xi32>) -> vector<4xf32> + return %0, %1 : f32, vector<4xf32> +} + +// CHECK-LABEL: @fpowi_exp_one( +// CHECK-SAME: %[[ARG0:.+]]: f32 +// CHECK-SAME: %[[ARG1:.+]]: vector<4xf32> +// CHECK-SAME: -> (f32, vector<4xf32>, f32, vector<4xf32>) { +func.func @fpowi_exp_one(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>, f32, vector<4xf32>) { + // CHECK: %[[CST_S:.*]] = arith.constant 1.000000e+00 : f32 + // CHECK: %[[CST_V:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> + // CHECK: %[[SCALAR:.*]] = arith.divf %[[CST_S]], %[[ARG0]] + // CHECK: %[[VECTOR:.*]] = arith.divf %[[CST_V]], %[[ARG1]] + // CHECK: return %[[ARG0]], %[[ARG1]], %[[SCALAR]], %[[VECTOR]] + %c1 = arith.constant 1 : i32 + %v1 = arith.constant dense <1> : vector<4xi32> + %0 = math.fpowsi %arg0, %c1 : (f32, i32) -> f32 + %1 = math.fpowsi %arg1, %v1 : (vector<4xf32>, vector<4xi32>) -> vector<4xf32> + %cm1 = arith.constant -1 : i32 + %vm1 = arith.constant dense <-1> : vector<4xi32> + %2 = math.fpowsi %arg0, %cm1 : (f32, i32) -> f32 + %3 = math.fpowsi %arg1, %vm1 : (vector<4xf32>, vector<4xi32>) -> vector<4xf32> + return %0, %1, %2, %3 : f32, vector<4xf32>, f32, vector<4xf32> +} + +// CHECK-LABEL: @fpowi_exp_two( +// CHECK-SAME: %[[ARG0:.+]]: f32 +// CHECK-SAME: %[[ARG1:.+]]: vector<4xf32> +// CHECK-SAME: -> (f32, vector<4xf32>, f32, vector<4xf32>) { +func.func @fpowi_exp_two(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>, f32, vector<4xf32>) { + // CHECK: %[[CST_S:.*]] = arith.constant 1.000000e+00 : f32 + // CHECK: %[[CST_V:.*]] = arith.constant 
dense<1.000000e+00> : vector<4xf32> + // CHECK: %[[SCALAR0:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] + // CHECK: %[[VECTOR0:.*]] = arith.mulf %[[ARG1]], %[[ARG1]] + // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[ARG0]] + // CHECK: %[[SMUL:.*]] = arith.mulf %[[SCALAR1]], %[[SCALAR1]] + // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[ARG1]] + // CHECK: %[[VMUL:.*]] = arith.mulf %[[VECTOR1]], %[[VECTOR1]] + // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL]], %[[VMUL]] + %c1 = arith.constant 2 : i32 + %v1 = arith.constant dense <2> : vector<4xi32> + %0 = math.fpowsi %arg0, %c1 : (f32, i32) -> f32 + %1 = math.fpowsi %arg1, %v1 : (vector<4xf32>, vector<4xi32>) -> vector<4xf32> + %cm1 = arith.constant -2 : i32 + %vm1 = arith.constant dense <-2> : vector<4xi32> + %2 = math.fpowsi %arg0, %cm1 : (f32, i32) -> f32 + %3 = math.fpowsi %arg1, %vm1 : (vector<4xf32>, vector<4xi32>) -> vector<4xf32> + return %0, %1, %2, %3 : f32, vector<4xf32>, f32, vector<4xf32> +} + +// CHECK-LABEL: @fpowi_exp_three( +// CHECK-SAME: %[[ARG0:.+]]: f32 +// CHECK-SAME: %[[ARG1:.+]]: vector<4xf32> +// CHECK-SAME: -> (f32, vector<4xf32>, f32, vector<4xf32>) { +func.func @fpowi_exp_three(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>, f32, vector<4xf32>) { + // CHECK: %[[CST_S:.*]] = arith.constant 1.000000e+00 : f32 + // CHECK: %[[CST_V:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> + // CHECK: %[[SMUL0:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] + // CHECK: %[[SCALAR0:.*]] = arith.mulf %[[SMUL0]], %[[ARG0]] + // CHECK: %[[VMUL0:.*]] = arith.mulf %[[ARG1]], %[[ARG1]] + // CHECK: %[[VECTOR0:.*]] = arith.mulf %[[VMUL0]], %[[ARG1]] + // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[ARG0]] + // CHECK: %[[SMUL1:.*]] = arith.mulf %[[SCALAR1]], %[[SCALAR1]] + // CHECK: %[[SMUL2:.*]] = arith.mulf %[[SMUL1]], %[[SCALAR1]] + // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[ARG1]] + // CHECK: %[[VMUL1:.*]] = arith.mulf %[[VECTOR1]], %[[VECTOR1]] + // CHECK: %[[VMUL2:.*]] = 
arith.mulf %[[VMUL1]], %[[VECTOR1]] + // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL2]], %[[VMUL2]] + %c1 = arith.constant 3 : i32 + %v1 = arith.constant dense <3> : vector<4xi32> + %0 = math.fpowsi %arg0, %c1 : (f32, i32) -> f32 + %1 = math.fpowsi %arg1, %v1 : (vector<4xf32>, vector<4xi32>) -> vector<4xf32> + %cm1 = arith.constant -3 : i32 + %vm1 = arith.constant dense <-3> : vector<4xi32> + %2 = math.fpowsi %arg0, %cm1 : (f32, i32) -> f32 + %3 = math.fpowsi %arg1, %vm1 : (vector<4xf32>, vector<4xi32>) -> vector<4xf32> + return %0, %1, %2, %3 : f32, vector<4xf32>, f32, vector<4xf32> +} Index: mlir/test/Dialect/Math/ops.mlir =================================================================== --- mlir/test/Dialect/Math/ops.mlir +++ mlir/test/Dialect/Math/ops.mlir @@ -158,6 +158,20 @@ return } +// CHECK-LABEL: func @fpowsi( +// CHECK-SAME: %[[SB:.*]]: f32, %[[SP:.*]]: i32, +// CHECK-SAME: %[[VB:.*]]: vector<4xf64>, %[[VP:.*]]: vector<4xi16>, +// CHECK-SAME: %[[TB:.*]]: tensor<4x3x?xf16>, %[[TP:.*]]: tensor<4x3x?xi64>) { +func.func @fpowsi(%b: f32, %p: i32, %vb: vector<4xf64>, %vp: vector<4xi16>, %tb: tensor<4x3x?xf16>, %tp: tensor<4x3x?xi64>) { +// CHECK: {{.*}} = math.fpowsi %[[SB]], %[[SP]] : (f32, i32) -> f32 + %0 = math.fpowsi %b, %p : (f32, i32) -> f32 +// CHECK: {{.*}} = math.fpowsi %[[VB]], %[[VP]] : (vector<4xf64>, vector<4xi16>) -> vector<4xf64> + %1 = math.fpowsi %vb, %vp : (vector<4xf64>, vector<4xi16>) -> vector<4xf64> +// CHECK: {{.*}} = math.fpowsi %[[TB]], %[[TP]] : (tensor<4x3x?xf16>, tensor<4x3x?xi64>) -> tensor<4x3x?xf16> + %2 = math.fpowsi %tb, %tp : (tensor<4x3x?xf16>, tensor<4x3x?xi64>) -> tensor<4x3x?xf16> + return +} + // CHECK-LABEL: func @rsqrt( // CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) func.func @rsqrt(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {