diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -112,9 +112,107 @@
   return failure();
 }
 
+//----------------------------------------------------------------------------//
+// IPowIOp strength reduction.
+//----------------------------------------------------------------------------//
+
+namespace {
+/// Rewrites `math.ipowi` with a constant (scalar or splat vector) exponent
+/// into an explicit sequence of `arith` multiplications. Exponents whose
+/// absolute value exceeds `exponentThreshold` are left untouched.
+struct IPowIStrengthReduction : public OpRewritePattern<math::IPowIOp> {
+  unsigned exponentThreshold;
+
+public:
+  IPowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3,
+                         PatternBenefit benefit = 1,
+                         ArrayRef<StringRef> generatedNames = {})
+      : OpRewritePattern<math::IPowIOp>(context, benefit, generatedNames),
+        exponentThreshold(exponentThreshold) {}
+  LogicalResult matchAndRewrite(math::IPowIOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+/// Replaces `ipowi(x, c)` with `1` for `c == 0`, and otherwise with a chain of
+/// `arith.muli` on `x` (or on `1 / x`, via `arith.divsi`, for negative `c`).
+/// Fails to match if the exponent is not a constant or `abs(c)` is above the
+/// threshold.
+LogicalResult
+IPowIStrengthReduction::matchAndRewrite(math::IPowIOp op,
+                                        PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  Value base = op.getLhs();
+
+  IntegerAttr scalarExponent;
+  DenseIntElementsAttr vectorExponent;
+
+  bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
+  bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
+
+  // Simplify cases with known exponent value.
+  int64_t exponentValue = 0;
+  if (isScalar)
+    exponentValue = scalarExponent.getInt();
+  else if (isVector && vectorExponent.isSplat())
+    exponentValue = vectorExponent.getSplatValue<IntegerAttr>().getInt();
+  else
+    return failure();
+
+  // Maybe broadcasts scalar value into vector type compatible with `op`.
+  auto bcast = [&](Value value) -> Value {
+    if (auto vec = op.getType().dyn_cast<VectorType>())
+      return rewriter.create<vector::BroadcastOp>(loc, vec, value);
+    return value;
+  };
+
+  if (exponentValue == 0) {
+    // Replace `ipowi(x, 0)` with `1`.
+    Value one = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIntegerAttr(getElementTypeOrSelf(op.getType()), 1));
+    rewriter.replaceOp(op, bcast(one));
+    return success();
+  }
+
+  // Bail out if `abs(exponent)` exceeds the threshold. The comparison is done
+  // on the signed value, before any negation, because `-INT64_MIN` overflows.
+  const int64_t threshold = exponentThreshold;
+  if (exponentValue > threshold || exponentValue < -threshold)
+    return failure();
+
+  const bool exponentIsNegative = exponentValue < 0;
+  if (exponentIsNegative)
+    exponentValue = -exponentValue;
+
+  // Inverse the base for negative exponent, i.e. for
+  // `ipowi(x, negative_exponent)` set `x` to `1 / x`.
+  if (exponentIsNegative) {
+    Value one = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIntegerAttr(getElementTypeOrSelf(op.getType()), 1));
+    base = rewriter.create<arith::DivSIOp>(loc, bcast(one), base);
+  }
+
+  Value result = base;
+  // Transform to naive sequence of multiplications:
+  //   * For positive exponent case replace:
+  //       `ipowi(x, positive_exponent)`
+  //     with:
+  //       x * x * x * ...
+  //   * For negative exponent case replace:
+  //       `ipowi(x, negative_exponent)`
+  //     with:
+  //       (1 / x) * (1 / x) * (1 / x) * ...
+  for (int64_t i = 1; i < exponentValue; ++i)
+    result = rewriter.create<arith::MulIOp>(loc, result, base);
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 //----------------------------------------------------------------------------//
 
 void mlir::populateMathAlgebraicSimplificationPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<PowFStrengthReduction>(patterns.getContext());
+  patterns.add<PowFStrengthReduction, IPowIStrengthReduction>(
+      patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
--- a/mlir/test/Dialect/Math/algebraic-simplification.mlir
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -73,3 +73,93 @@
   %1 = math.powf %arg1, %v : vector<4xf32>
   return %0, %1 : f32, vector<4xf32>
 }
+
+// CHECK-LABEL: @ipowi_zero_exp(
+// CHECK-SAME: %[[ARG0:.+]]: i32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
+// CHECK-SAME: -> (i32, vector<4xi32>) {
+func.func @ipowi_zero_exp(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>) {
+  // CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
+  // CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
+  // CHECK: return %[[CST_S]], %[[CST_V]]
+  %c = arith.constant 0 : i32
+  %v = arith.constant dense <0> : vector<4xi32>
+  %0 = math.ipowi %arg0, %c : i32
+  %1 = math.ipowi %arg1, %v : vector<4xi32>
+  return %0, %1 : i32, vector<4xi32>
+}
+
+// CHECK-LABEL: @ipowi_exp_one(
+// CHECK-SAME: %[[ARG0:.+]]: i32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
+// CHECK-SAME: -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+func.func @ipowi_exp_one(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+  // CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
+  // CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
+  // CHECK: %[[SCALAR:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
+  // CHECK: %[[VECTOR:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
+  // CHECK: return %[[ARG0]], %[[ARG1]], %[[SCALAR]], %[[VECTOR]]
+  %c1 = arith.constant 1 : i32
+  %v1 = arith.constant dense <1> : vector<4xi32>
+  %0 = math.ipowi %arg0, %c1 : i32
+  %1 = math.ipowi %arg1, %v1 : vector<4xi32>
+  %cm1 = arith.constant -1 : i32
+  %vm1 = arith.constant dense <-1> : vector<4xi32>
+  %2 = math.ipowi %arg0, %cm1 : i32
+  %3 = math.ipowi %arg1, %vm1 : vector<4xi32>
+  return %0, %1, %2, %3 : i32, vector<4xi32>, i32, vector<4xi32>
+}
+
+// CHECK-LABEL: @ipowi_exp_two(
+// CHECK-SAME: %[[ARG0:.+]]: i32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
+// CHECK-SAME: -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+func.func @ipowi_exp_two(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+  // CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
+  // CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
+  // CHECK: %[[SCALAR0:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
+  // CHECK: %[[VECTOR0:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
+  // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
+  // CHECK: %[[SMUL:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]]
+  // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
+  // CHECK: %[[VMUL:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]]
+  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL]], %[[VMUL]]
+  %c1 = arith.constant 2 : i32
+  %v1 = arith.constant dense <2> : vector<4xi32>
+  %0 = math.ipowi %arg0, %c1 : i32
+  %1 = math.ipowi %arg1, %v1 : vector<4xi32>
+  %cm1 = arith.constant -2 : i32
+  %vm1 = arith.constant dense <-2> : vector<4xi32>
+  %2 = math.ipowi %arg0, %cm1 : i32
+  %3 = math.ipowi %arg1, %vm1 : vector<4xi32>
+  return %0, %1, %2, %3 : i32, vector<4xi32>, i32, vector<4xi32>
+}
+
+// CHECK-LABEL: @ipowi_exp_three(
+// CHECK-SAME: %[[ARG0:.+]]: i32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
+// CHECK-SAME: -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+func.func @ipowi_exp_three(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+  // CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
+  // CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
+  // CHECK: %[[SMUL0:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
+  // CHECK: %[[SCALAR0:.*]] = arith.muli %[[SMUL0]], %[[ARG0]]
+  // CHECK: %[[VMUL0:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
+  // CHECK: %[[VECTOR0:.*]] = arith.muli %[[VMUL0]], %[[ARG1]]
+  // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
+  // CHECK: %[[SMUL1:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]]
+  // CHECK: %[[SMUL2:.*]] = arith.muli %[[SMUL1]], %[[SCALAR1]]
+  // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
+  // CHECK: %[[VMUL1:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]]
+  // CHECK: %[[VMUL2:.*]] = arith.muli %[[VMUL1]], %[[VECTOR1]]
+  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL2]], %[[VMUL2]]
+  %c1 = arith.constant 3 : i32
+  %v1 = arith.constant dense <3> : vector<4xi32>
+  %0 = math.ipowi %arg0, %c1 : i32
+  %1 = math.ipowi %arg1, %v1 : vector<4xi32>
+  %cm1 = arith.constant -3 : i32
+  %vm1 = arith.constant dense <-3> : vector<4xi32>
+  %2 = math.ipowi %arg0, %cm1 : i32
+  %3 = math.ipowi %arg1, %vm1 : vector<4xi32>
+  return %0, %1, %2, %3 : i32, vector<4xi32>, i32, vector<4xi32>
+}