Index: mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
===================================================================
--- mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -109,6 +109,15 @@
     return success();
   }
 
+  // Replace `pow(x, 0.75)` with `sqrt(x) * sqrt(sqrt(x))`.
+  if (isExponentValue(0.75)) {
+    Value powHalf = rewriter.create<math::SqrtOp>(op.getLoc(), x);
+    Value powQuarter = rewriter.create<math::SqrtOp>(op.getLoc(), powHalf);
+    rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
+                                               ValueRange{powHalf, powQuarter});
+    return success();
+  }
+
   return failure();
 }
 
Index: mlir/test/Dialect/Math/algebraic-simplification.mlir
===================================================================
--- mlir/test/Dialect/Math/algebraic-simplification.mlir
+++ mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -74,6 +74,22 @@
   return %0, %1 : f32, vector<4xf32>
 }
 
+// CHECK-LABEL: @pow_0_75
+func.func @pow_0_75(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[SQRT1S:.*]] = math.sqrt %arg0
+  // CHECK: %[[SQRT2S:.*]] = math.sqrt %[[SQRT1S]]
+  // CHECK: %[[SCALAR:.*]] = arith.mulf %[[SQRT1S]], %[[SQRT2S]]
+  // CHECK: %[[SQRT1V:.*]] = math.sqrt %arg1
+  // CHECK: %[[SQRT2V:.*]] = math.sqrt %[[SQRT1V]]
+  // CHECK: %[[VECTOR:.*]] = arith.mulf %[[SQRT1V]], %[[SQRT2V]]
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant 0.75 : f32
+  %v = arith.constant dense <0.75> : vector<4xf32>
+  %0 = math.powf %arg0, %c : f32
+  %1 = math.powf %arg1, %v : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
 // CHECK-LABEL: @ipowi_zero_exp(
 // CHECK-SAME: %[[ARG0:.+]]: i32
 // CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>