Index: mlir/include/mlir/Dialect/Math/IR/MathOps.td =================================================================== --- mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -770,4 +770,49 @@ }]; } +//===----------------------------------------------------------------------===// +// FPowIOp +//===----------------------------------------------------------------------===// + +def Math_FPowIOp : Math_Op<"fpowi", + [SameOperandsAndResultShape, AllTypesMatch<["lhs", "result"]>]> { + let summary = "floating point raised to the signed integer power"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `math.fpowi` ssa-use `,` ssa-use `:` type `,` type + ``` + + The `fpowi` operation takes a `base` operand of floating point type + (i.e. scalar, tensor or vector) and a `power` operand of integer type + (also scalar, tensor or vector) and returns one result of the same type + as `base`. The result is `base` raised to the power of `power`. + The operation is elementwise for non-scalars, e.g.: + ```mlir + %v = math.fpowi %base, %power : + vector<2xf32>, vector<2xi32> + ``` + The result is a vector of: + ``` + [fpowi(%base[0], %power[0]), fpowi(%base[1], %power[1])] + ``` + + Example: + + ```mlir + // Scalar exponentiation. + %a = math.fpowi %base, %power : f64, i32 + ``` + }]; + + let arguments = (ins FloatLike:$lhs, SignlessIntegerLike:$rhs); + let results = (outs FloatLike:$result); + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)"; + + // TODO: add a constant folder using pow[f] for cases where + // the power argument is exactly representable in the floating + // point type of the base. 
+} + #endif // MATH_OPS Index: mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp =================================================================== --- mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp +++ mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp @@ -113,27 +113,31 @@ } //----------------------------------------------------------------------------// -// IPowIOp strength reduction. +// FPowIOp/IPowIOp strength reduction. //----------------------------------------------------------------------------// namespace { -struct IPowIStrengthReduction : public OpRewritePattern { +template +struct PowIStrengthReduction : public OpRewritePattern { + unsigned exponentThreshold; public: - IPowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3, - PatternBenefit benefit = 1, - ArrayRef generatedNames = {}) - : OpRewritePattern(context, benefit, generatedNames), + PowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3, + PatternBenefit benefit = 1, + ArrayRef generatedNames = {}) + : OpRewritePattern(context, benefit, generatedNames), exponentThreshold(exponentThreshold) {} - LogicalResult matchAndRewrite(math::IPowIOp op, + + LogicalResult matchAndRewrite(PowIOpTy op, PatternRewriter &rewriter) const final; }; } // namespace +template LogicalResult -IPowIStrengthReduction::matchAndRewrite(math::IPowIOp op, - PatternRewriter &rewriter) const { +PowIStrengthReduction::matchAndRewrite( + PowIOpTy op, PatternRewriter &rewriter) const { Location loc = op.getLoc(); Value base = op.getLhs(); @@ -153,16 +157,23 @@ return failure(); // Maybe broadcasts scalar value into vector type compatible with `op`. 
- auto bcast = [&](Value value) -> Value { - if (auto vec = op.getType().dyn_cast()) + auto bcast = [&loc, &op, &rewriter](Value value) -> Value { + if (auto vec = op.getType().template dyn_cast()) return rewriter.create(loc, vec, value); return value; }; + Value one; + Type opType = getElementTypeOrSelf(op.getType()); + if constexpr (std::is_same_v) + one = rewriter.create( + loc, rewriter.getFloatAttr(opType, 1.0)); + else + one = rewriter.create( + loc, rewriter.getIntegerAttr(opType, 1)); + + // Replace `[fi]powi(x, 0)` with `1`. if (exponentValue == 0) { - // Replace `ipowi(x, 0)` with `1`. - Value one = rewriter.create( - loc, rewriter.getIntegerAttr(getElementTypeOrSelf(op.getType()), 1)); rewriter.replaceOp(op, bcast(one)); return success(); } @@ -178,25 +189,22 @@ return failure(); // Inverse the base for negative exponent, i.e. for - // `ipowi(x, negative_exponent)` set `x` to `1 / x`. - if (exponentIsNegative) { - Value one = rewriter.create( - loc, rewriter.getIntegerAttr(getElementTypeOrSelf(op.getType()), 1)); - base = rewriter.create(loc, bcast(one), base); - } + // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`. + if (exponentIsNegative) + base = rewriter.create(loc, bcast(one), base); Value result = base; // Transform to naive sequence of multiplications: // * For positive exponent case replace: - // `ipowi(x, positive_exponent)` + // `[fi]powi(x, positive_exponent)` // with: // x * x * x * ... // * For negative exponent case replace: - // `ipowi(x, negative_exponent)` + // `[fi]powi(x, negative_exponent)` // with: // (1 / x) * (1 / x) * (1 / x) * ... 
for (unsigned i = 1; i < exponentValue; ++i) - result = rewriter.create(loc, result, base); + result = rewriter.create(loc, result, base); rewriter.replaceOp(op, result); return success(); @@ -206,6 +214,9 @@ void mlir::populateMathAlgebraicSimplificationPatterns( RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); + patterns + .add, + PowIStrengthReduction>( + patterns.getContext()); } Index: mlir/test/Dialect/Math/algebraic-simplification.mlir =================================================================== --- mlir/test/Dialect/Math/algebraic-simplification.mlir +++ mlir/test/Dialect/Math/algebraic-simplification.mlir @@ -163,3 +163,93 @@ %3 = math.ipowi %arg1, %vm1 : vector<4xi32> return %0, %1, %2, %3 : i32, vector<4xi32>, i32, vector<4xi32> } + +// CHECK-LABEL: @fpowi_zero_exp( +// CHECK-SAME: %[[ARG0:.+]]: f32 +// CHECK-SAME: %[[ARG1:.+]]: vector<4xf32> +// CHECK-SAME: -> (f32, vector<4xf32>) { +func.func @fpowi_zero_exp(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) { + // CHECK: %[[CST_S:.*]] = arith.constant 1.000000e+00 : f32 + // CHECK: %[[CST_V:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> + // CHECK: return %[[CST_S]], %[[CST_V]] + %c = arith.constant 0 : i32 + %v = arith.constant dense <0> : vector<4xi32> + %0 = math.fpowi %arg0, %c : f32, i32 + %1 = math.fpowi %arg1, %v : vector<4xf32>, vector<4xi32> + return %0, %1 : f32, vector<4xf32> +} + +// CHECK-LABEL: @fpowi_exp_one( +// CHECK-SAME: %[[ARG0:.+]]: f32 +// CHECK-SAME: %[[ARG1:.+]]: vector<4xf32> +// CHECK-SAME: -> (f32, vector<4xf32>, f32, vector<4xf32>) { +func.func @fpowi_exp_one(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>, f32, vector<4xf32>) { + // CHECK: %[[CST_S:.*]] = arith.constant 1.000000e+00 : f32 + // CHECK: %[[CST_V:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> + // CHECK: %[[SCALAR:.*]] = arith.divf %[[CST_S]], %[[ARG0]] + // CHECK: %[[VECTOR:.*]] = arith.divf %[[CST_V]], %[[ARG1]] + // CHECK: return 
%[[ARG0]], %[[ARG1]], %[[SCALAR]], %[[VECTOR]] + %c1 = arith.constant 1 : i32 + %v1 = arith.constant dense <1> : vector<4xi32> + %0 = math.fpowi %arg0, %c1 : f32, i32 + %1 = math.fpowi %arg1, %v1 : vector<4xf32>, vector<4xi32> + %cm1 = arith.constant -1 : i32 + %vm1 = arith.constant dense <-1> : vector<4xi32> + %2 = math.fpowi %arg0, %cm1 : f32, i32 + %3 = math.fpowi %arg1, %vm1 : vector<4xf32>, vector<4xi32> + return %0, %1, %2, %3 : f32, vector<4xf32>, f32, vector<4xf32> +} + +// CHECK-LABEL: @fpowi_exp_two( +// CHECK-SAME: %[[ARG0:.+]]: f32 +// CHECK-SAME: %[[ARG1:.+]]: vector<4xf32> +// CHECK-SAME: -> (f32, vector<4xf32>, f32, vector<4xf32>) { +func.func @fpowi_exp_two(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>, f32, vector<4xf32>) { + // CHECK: %[[CST_S:.*]] = arith.constant 1.000000e+00 : f32 + // CHECK: %[[CST_V:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> + // CHECK: %[[SCALAR0:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] + // CHECK: %[[VECTOR0:.*]] = arith.mulf %[[ARG1]], %[[ARG1]] + // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[ARG0]] + // CHECK: %[[SMUL:.*]] = arith.mulf %[[SCALAR1]], %[[SCALAR1]] + // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[ARG1]] + // CHECK: %[[VMUL:.*]] = arith.mulf %[[VECTOR1]], %[[VECTOR1]] + // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL]], %[[VMUL]] + %c1 = arith.constant 2 : i32 + %v1 = arith.constant dense <2> : vector<4xi32> + %0 = math.fpowi %arg0, %c1 : f32, i32 + %1 = math.fpowi %arg1, %v1 : vector<4xf32>, vector<4xi32> + %cm1 = arith.constant -2 : i32 + %vm1 = arith.constant dense <-2> : vector<4xi32> + %2 = math.fpowi %arg0, %cm1 : f32, i32 + %3 = math.fpowi %arg1, %vm1 : vector<4xf32>, vector<4xi32> + return %0, %1, %2, %3 : f32, vector<4xf32>, f32, vector<4xf32> +} + +// CHECK-LABEL: @fpowi_exp_three( +// CHECK-SAME: %[[ARG0:.+]]: f32 +// CHECK-SAME: %[[ARG1:.+]]: vector<4xf32> +// CHECK-SAME: -> (f32, vector<4xf32>, f32, vector<4xf32>) { +func.func @fpowi_exp_three(%arg0: 
f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>, f32, vector<4xf32>) { + // CHECK: %[[CST_S:.*]] = arith.constant 1.000000e+00 : f32 + // CHECK: %[[CST_V:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> + // CHECK: %[[SMUL0:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] + // CHECK: %[[SCALAR0:.*]] = arith.mulf %[[SMUL0]], %[[ARG0]] + // CHECK: %[[VMUL0:.*]] = arith.mulf %[[ARG1]], %[[ARG1]] + // CHECK: %[[VECTOR0:.*]] = arith.mulf %[[VMUL0]], %[[ARG1]] + // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[ARG0]] + // CHECK: %[[SMUL1:.*]] = arith.mulf %[[SCALAR1]], %[[SCALAR1]] + // CHECK: %[[SMUL2:.*]] = arith.mulf %[[SMUL1]], %[[SCALAR1]] + // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[ARG1]] + // CHECK: %[[VMUL1:.*]] = arith.mulf %[[VECTOR1]], %[[VECTOR1]] + // CHECK: %[[VMUL2:.*]] = arith.mulf %[[VMUL1]], %[[VECTOR1]] + // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL2]], %[[VMUL2]] + %c1 = arith.constant 3 : i32 + %v1 = arith.constant dense <3> : vector<4xi32> + %0 = math.fpowi %arg0, %c1 : f32, i32 + %1 = math.fpowi %arg1, %v1 : vector<4xf32>, vector<4xi32> + %cm1 = arith.constant -3 : i32 + %vm1 = arith.constant dense <-3> : vector<4xi32> + %2 = math.fpowi %arg0, %cm1 : f32, i32 + %3 = math.fpowi %arg1, %vm1 : vector<4xf32>, vector<4xi32> + return %0, %1, %2, %3 : f32, vector<4xf32>, f32, vector<4xf32> +} Index: mlir/test/Dialect/Math/ops.mlir =================================================================== --- mlir/test/Dialect/Math/ops.mlir +++ mlir/test/Dialect/Math/ops.mlir @@ -158,6 +158,20 @@ return } +// CHECK-LABEL: func @fpowi( +// CHECK-SAME: %[[SB:.*]]: f32, %[[SP:.*]]: i32, +// CHECK-SAME: %[[VB:.*]]: vector<4xf64>, %[[VP:.*]]: vector<4xi16>, +// CHECK-SAME: %[[TB:.*]]: tensor<4x3x?xf16>, %[[TP:.*]]: tensor<4x3x?xi64>) { +func.func @fpowi(%b: f32, %p: i32, %vb: vector<4xf64>, %vp: vector<4xi16>, %tb: tensor<4x3x?xf16>, %tp: tensor<4x3x?xi64>) { +// CHECK: {{.*}} = math.fpowi %[[SB]], %[[SP]] : f32, i32 + %0 = math.fpowi %b, %p 
: f32, i32 +// CHECK: {{.*}} = math.fpowi %[[VB]], %[[VP]] : vector<4xf64>, vector<4xi16> + %1 = math.fpowi %vb, %vp : vector<4xf64>, vector<4xi16> +// CHECK: {{.*}} = math.fpowi %[[TB]], %[[TP]] : tensor<4x3x?xf16>, tensor<4x3x?xi64> + %2 = math.fpowi %tb, %tp : tensor<4x3x?xf16>, tensor<4x3x?xi64> + return +} + // CHECK-LABEL: func @rsqrt( // CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) func.func @rsqrt(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {