# NOTE(review): this patch text had lost its line breaks, indentation and every
# `<...>` template-argument list. Template arguments on added lines below are
# restored from the FileCheck expectations in the accompanying tests. Blank
# context lines are inferred from the hunk line counts.
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -17,7 +17,18 @@
 
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
 
-void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns);
+/// Options selecting which approximation patterns get populated.
+struct MathPolynomialApproximationOptions {
+  /// When true, also add the approximations that emit AVX2 intrinsics via
+  /// the X86Vector dialect. Off by default.
+  bool avx2 = false;
+};
+
+/// Adds the math polynomial approximation patterns to `patterns`; `options`
+/// enables the optional target-specific ones.
+void populateMathPolynomialApproximationPatterns(
+    RewritePatternSet &patterns,
+    const MathPolynomialApproximationOptions &options = {});
 
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -13,4 +13,5 @@
   MLIRPass
   MLIRStandard
   MLIRTransforms
+  MLIRX86Vector
   )
# NOTE(review): the `patterns.add<...>` context lines in the second hunk below
# were reconstructed from the hunk's line counts only (the pattern list itself
# was lost in extraction) -- confirm against the base revision before applying.
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Transforms/Bufferize.h"
@@ -773,13 +774,81 @@
   return success();
 }
 
+//----------------------------------------------------------------------------//
+// Rsqrt approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+/// Approximates `math.rsqrt` on 8 x f32 vectors with the x86vector rsqrt
+/// intrinsic, refined by one Newton-Raphson step.
+struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::RsqrtOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+LogicalResult
+RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
+                                    PatternRewriter &rewriter) const {
+  auto width = vectorWidth(op.operand().getType(), isF32);
+  // Only support already-vectorized rsqrt's.
+  if (!width.hasValue() || *width != 8)
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, *width);
+  };
+
+  Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
+  Value cstOnePointFive = bcast(f32Cst(builder, 1.5f));
+  Value cstNegHalf = bcast(f32Cst(builder, -0.5f));
+  Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
+
+  Value negHalf = builder.create<arith::MulFOp>(op.operand(), cstNegHalf);
+
+  // Select only the inverse sqrt of positive normals (denormals are
+  // flushed to zero).
+  Value ltMinMask = builder.create<arith::CmpFOp>(
+      arith::CmpFPredicate::OLT, op.operand(), cstMinNormPos);
+  Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
+                                                op.operand(), cstPosInf);
+  Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask);
+
+  // Compute an approximate result.
+  Value yApprox = builder.create<x86vector::RsqrtOp>(op.operand());
+
+  // Do a single step of Newton-Raphson iteration to improve the approximation.
+  // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
+  // It is essential to evaluate the inner term like this because forming
+  // y_n^2 may over- or underflow.
+  Value inner = builder.create<arith::MulFOp>(negHalf, yApprox);
+  Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive);
+  Value yNewton = builder.create<arith::MulFOp>(yApprox, fma);
+
+  // Select the result of the Newton-Raphson step for positive normal arguments.
+  // For other arguments, choose the output of the intrinsic. This will
+  // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
+  // x is zero or a positive denormalized float (equivalent to flushing positive
+  // denormalized inputs to zero).
+  Value res = builder.create<SelectOp>(notNormalFiniteMask, yApprox, yNewton);
+  rewriter.replaceOp(op, res);
+
+  return success();
+}
+
 //----------------------------------------------------------------------------//
 
 void mlir::populateMathPolynomialApproximationPatterns(
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns,
+    const MathPolynomialApproximationOptions &options) {
   patterns.add<TanhApproximation, LogApproximation, Log2Approximation,
                Log1pApproximation, ExpApproximation, ExpM1Approximation,
                SinAndCosApproximation<true, math::SinOp>,
                SinAndCosApproximation<false, math::CosOp>>(
       patterns.getContext());
+  if (options.avx2)
+    patterns.add<RsqrtApproximation>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Math/polynomial-approximation-avx2.mlir b/mlir/test/Dialect/Math/polynomial-approximation-avx2.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Math/polynomial-approximation-avx2.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt %s -test-math-polynomial-approximation=avx2 | FileCheck %s
+
+// We only approximate rsqrt for vectors.
+// CHECK-LABEL: func @rsqrt_scalar
+// CHECK: math.rsqrt
+func @rsqrt_scalar(%arg0: f32) -> f32 {
+  %0 = math.rsqrt %arg0 : f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: func @rsqrt_vector(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<0x7F800000> : vector<8xf32>
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<1.500000e+00> : vector<8xf32>
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<-5.000000e-01> : vector<8xf32>
+// CHECK: %[[VAL_4:.*]] = arith.constant dense<1.17549435E-38> : vector<8xf32>
+// CHECK: %[[VAL_5:.*]] = arith.mulf %[[VAL_0]], %[[VAL_3]] : vector<8xf32>
+// CHECK: %[[VAL_6:.*]] = arith.cmpf olt, %[[VAL_0]], %[[VAL_4]] : vector<8xf32>
+// CHECK: %[[VAL_7:.*]] = arith.cmpf oeq, %[[VAL_0]], %[[VAL_1]] : vector<8xf32>
+// CHECK: %[[VAL_8:.*]] = arith.ori %[[VAL_6]], %[[VAL_7]] : vector<8xi1>
+// CHECK: %[[VAL_9:.*]] = x86vector.avx.rsqrt %[[VAL_0]] : vector<8xf32>
+// CHECK: %[[VAL_10:.*]] = arith.mulf %[[VAL_5]], %[[VAL_9]] : vector<8xf32>
+// CHECK: %[[VAL_11:.*]] = math.fma %[[VAL_9]], %[[VAL_10]], %[[VAL_2]] : vector<8xf32>
+// CHECK: %[[VAL_12:.*]] = arith.mulf %[[VAL_9]], %[[VAL_11]] : vector<8xf32>
+// CHECK: %[[VAL_13:.*]] = select %[[VAL_8]], %[[VAL_9]], %[[VAL_12]] : vector<8xi1>, vector<8xf32>
+// CHECK: return %[[VAL_13]] : vector<8xf32>
+// CHECK: }
+func @rsqrt_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
+  %0 = math.rsqrt %arg0 : vector<8xf32>
+  return %0 : vector<8xf32>
+}
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -300,3 +300,19 @@
   %0 = math.tanh %arg0 : vector<8xf32>
   return %0 : vector<8xf32>
 }
+
+// We only approximate rsqrt for vectors and when the AVX2 option is enabled.
+// CHECK-LABEL: func @rsqrt_scalar
+// CHECK: math.rsqrt
+func @rsqrt_scalar(%arg0: f32) -> f32 {
+  %0 = math.rsqrt %arg0 : f32
+  return %0 : f32
+}
+
+// We only approximate rsqrt for vectors and when the AVX2 option is enabled.
+// CHECK-LABEL: func @rsqrt_vector
+// CHECK: math.rsqrt
+func @rsqrt_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
+  %0 = math.rsqrt %arg0 : vector<8xf32>
+  return %0 : vector<8xf32>
+}
diff --git a/mlir/test/lib/Dialect/Math/CMakeLists.txt b/mlir/test/lib/Dialect/Math/CMakeLists.txt
--- a/mlir/test/lib/Dialect/Math/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Math/CMakeLists.txt
@@ -11,4 +11,5 @@
   MLIRPass
   MLIRTransformUtils
   MLIRVector
+  MLIRX86Vector
   )
# NOTE(review): this patch text had lost its line breaks, indentation and every
# `<...>` template-argument list, and `&registry` had been decoded into a
# registered-trademark sign. The `PassWrapper<...>` and `registry.insert<...>`
# context lines below are reconstructed from the hunk's line counts -- confirm
# against the base revision before applying.
# The copy constructor now forwards to the PassWrapper base instead of leaving
# the base to be default-constructed.
diff --git a/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp b/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp
--- a/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp
+++ b/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -23,10 +24,17 @@
 namespace {
 struct TestMathPolynomialApproximationPass
     : public PassWrapper<TestMathPolynomialApproximationPass, FunctionPass> {
+  TestMathPolynomialApproximationPass() = default;
+  TestMathPolynomialApproximationPass(
+      const TestMathPolynomialApproximationPass &pass)
+      : PassWrapper(pass) {}
+
   void runOnFunction() override;
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<arith::ArithmeticDialect, math::MathDialect,
                     vector::VectorDialect>();
+    if (avx2)
+      registry.insert<x86vector::X86VectorDialect>();
   }
   StringRef getArgument() const final {
     return "test-math-polynomial-approximation";
@@ -34,12 +42,20 @@
   StringRef getDescription() const final {
     return "Test math polynomial approximations";
   }
+
+  Option<bool> avx2{
+      *this, "avx2",
+      llvm::cl::desc("Enable approximations that emit AVX2 intrinsics via the "
+                     "X86Vector dialect"),
+      llvm::cl::init(false)};
 };
 } // end anonymous namespace
 
 void TestMathPolynomialApproximationPass::runOnFunction() {
   RewritePatternSet patterns(&getContext());
-  populateMathPolynomialApproximationPatterns(patterns);
+  MathPolynomialApproximationOptions approxOptions;
+  approxOptions.avx2 = avx2;
+  populateMathPolynomialApproximationPatterns(patterns, approxOptions);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 
diff --git a/mlir/test/mlir-cpu-runner/X86Vector/lit.local.cfg b/mlir/test/mlir-cpu-runner/X86Vector/lit.local.cfg
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/X86Vector/lit.local.cfg
@@ -0,0 +1,3 @@
+# X86Vector tests must be enabled via build flag.
+if not config.mlir_run_x86vector_tests:
+    config.unsupported = True
diff --git a/mlir/test/mlir-cpu-runner/X86Vector/math_polynomial_approx_avx2.mlir b/mlir/test/mlir-cpu-runner/X86Vector/math_polynomial_approx_avx2.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/X86Vector/math_polynomial_approx_avx2.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt %s -test-math-polynomial-approximation="avx2" \
+// RUN:   -convert-arith-to-llvm \
+// RUN:   -convert-vector-to-llvm="enable-x86vector" \
+// RUN:   -convert-math-to-llvm \
+// RUN:   -convert-std-to-llvm \
+// RUN:   -reconcile-unrealized-casts \
+// RUN: | mlir-cpu-runner \
+// RUN:   -e main -entry-point-result=void -O0 \
+// RUN:   -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \
+// RUN:   -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// -------------------------------------------------------------------------- //
+// rsqrt.
+// -------------------------------------------------------------------------- //
+
+func @rsqrt() {
+  // Sanity-check that the scalar rsqrt still works OK.
+  // CHECK: inf
+  %0 = arith.constant 0.0 : f32
+  %rsqrt_0 = math.rsqrt %0 : f32
+  vector.print %rsqrt_0 : f32
+  // CHECK: 0.707107
+  %two = arith.constant 2.0: f32
+  %rsqrt_two = math.rsqrt %two : f32
+  vector.print %rsqrt_two : f32
+
+  // Check that the vectorized approximation is reasonably accurate.
+  // CHECK: 0.707107, 0.707107, 0.707107, 0.707107, 0.707107, 0.707107, 0.707107, 0.707107
+  %vec8 = arith.constant dense<2.0> : vector<8xf32>
+  %rsqrt_vec8 = math.rsqrt %vec8 : vector<8xf32>
+  vector.print %rsqrt_vec8 : vector<8xf32>
+
+  return
+}
+
+func @main() {
+  call @rsqrt(): () -> ()
+  return
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7057,6 +7057,7 @@
         ":Support",
         ":Transforms",
         ":VectorOps",
+        ":X86Vector",
         "//llvm:Support",
     ],
 )
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -406,6 +406,7 @@
         "//mlir:Pass",
         "//mlir:TransformUtils",
         "//mlir:VectorOps",
+        "//mlir:X86Vector",
     ],
 )
 