diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -17,7 +17,14 @@ void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns); -void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns); +struct MathPolynomialApproximationOptions { + // Enables the use of AVX2 intrinsics in some of the approximations. + bool enableAvx2 = false; +}; + +void populateMathPolynomialApproximationPatterns( + RewritePatternSet &patterns, + const MathPolynomialApproximationOptions &options = {}); } // namespace mlir diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -13,4 +13,5 @@ MLIRPass MLIRStandard MLIRTransforms + MLIRX86Vector ) diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Transforms/Bufferize.h" @@ -778,13 +779,79 @@ return success(); } +//----------------------------------------------------------------------------// +// Rsqrt approximation. 
+//----------------------------------------------------------------------------// + +namespace { +struct RsqrtApproximation : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(math::RsqrtOp op, + PatternRewriter &rewriter) const final; +}; +} // namespace + +LogicalResult +RsqrtApproximation::matchAndRewrite(math::RsqrtOp op, + PatternRewriter &rewriter) const { + auto width = vectorWidth(op.operand().getType(), isF32); + // Only support already-vectorized rsqrt's. + if (!width.hasValue() || *width != 8) + return rewriter.notifyMatchFailure(op, "unsupported operand type"); + + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); + auto bcast = [&](Value value) -> Value { + return broadcast(builder, value, *width); + }; + + Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u)); + Value cstOnePointFive = bcast(f32Cst(builder, 1.5f)); + Value cstNegHalf = bcast(f32Cst(builder, -0.5f)); + Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u)); + + Value negHalf = builder.create(op.operand(), cstNegHalf); + + // Select only the inverse sqrt of positive normals (denormals are + // flushed to zero). + Value ltMinMask = builder.create(arith::CmpFPredicate::OLT, + op.operand(), cstMinNormPos); + Value infMask = builder.create(arith::CmpFPredicate::OEQ, + op.operand(), cstPosInf); + Value notNormalFiniteMask = builder.create(ltMinMask, infMask); + + // Compute an approximate result. + Value yApprox = builder.create(op.operand()); + + // Do a single step of Newton-Raphson iteration to improve the approximation. + // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n). + // It is essential to evaluate the inner term like this because forming + // y_n^2 may over- or underflow. 
+ Value inner = builder.create(negHalf, yApprox); + Value fma = builder.create(yApprox, inner, cstOnePointFive); + Value yNewton = builder.create(yApprox, fma); + + // Select the result of the Newton-Raphson step for positive normal arguments. + // For other arguments, choose the output of the intrinsic. This will + // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if + // x is zero or a positive denormalized float (equivalent to flushing positive + // denormalized inputs to zero). + Value res = builder.create(notNormalFiniteMask, yApprox, yNewton); + rewriter.replaceOp(op, res); + + return success(); +} + //----------------------------------------------------------------------------// void mlir::populateMathPolynomialApproximationPatterns( - RewritePatternSet &patterns) { + RewritePatternSet &patterns, + const MathPolynomialApproximationOptions &options) { patterns.add, SinAndCosApproximation>( patterns.getContext()); + if (options.enableAvx2) + patterns.add(patterns.getContext()); } diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir --- a/mlir/test/Dialect/Math/polynomial-approximation.mlir +++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir @@ -1,4 +1,6 @@ // RUN: mlir-opt %s -test-math-polynomial-approximation | FileCheck %s +// RUN: mlir-opt %s -test-math-polynomial-approximation=enable-avx2 \ +// RUN: | FileCheck --check-prefix=AVX2 %s // Check that all math functions lowered to approximations built from // standard operations (add, mul, fma, shift, etc...). @@ -300,3 +302,37 @@ %0 = math.tanh %arg0 : vector<8xf32> return %0 : vector<8xf32> } + +// We only approximate rsqrt for vectors and when the AVX2 option is enabled. 
+// CHECK-LABEL: func @rsqrt_scalar +// AVX2-LABEL: func @rsqrt_scalar +// CHECK: math.rsqrt +// AVX2: math.rsqrt +func @rsqrt_scalar(%arg0: f32) -> f32 { + %0 = math.rsqrt %arg0 : f32 + return %0 : f32 +} + +// CHECK-LABEL: func @rsqrt_vector +// CHECK: math.rsqrt +// AVX2-LABEL: func @rsqrt_vector( +// AVX2-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> { +// AVX2: %[[VAL_1:.*]] = arith.constant dense<0x7F800000> : vector<8xf32> +// AVX2: %[[VAL_2:.*]] = arith.constant dense<1.500000e+00> : vector<8xf32> +// AVX2: %[[VAL_3:.*]] = arith.constant dense<-5.000000e-01> : vector<8xf32> +// AVX2: %[[VAL_4:.*]] = arith.constant dense<1.17549435E-38> : vector<8xf32> +// AVX2: %[[VAL_5:.*]] = arith.mulf %[[VAL_0]], %[[VAL_3]] : vector<8xf32> +// AVX2: %[[VAL_6:.*]] = arith.cmpf olt, %[[VAL_0]], %[[VAL_4]] : vector<8xf32> +// AVX2: %[[VAL_7:.*]] = arith.cmpf oeq, %[[VAL_0]], %[[VAL_1]] : vector<8xf32> +// AVX2: %[[VAL_8:.*]] = arith.ori %[[VAL_6]], %[[VAL_7]] : vector<8xi1> +// AVX2: %[[VAL_9:.*]] = x86vector.avx.rsqrt %[[VAL_0]] : vector<8xf32> +// AVX2: %[[VAL_10:.*]] = arith.mulf %[[VAL_5]], %[[VAL_9]] : vector<8xf32> +// AVX2: %[[VAL_11:.*]] = math.fma %[[VAL_9]], %[[VAL_10]], %[[VAL_2]] : vector<8xf32> +// AVX2: %[[VAL_12:.*]] = arith.mulf %[[VAL_9]], %[[VAL_11]] : vector<8xf32> +// AVX2: %[[VAL_13:.*]] = select %[[VAL_8]], %[[VAL_9]], %[[VAL_12]] : vector<8xi1>, vector<8xf32> +// AVX2: return %[[VAL_13]] : vector<8xf32> +// AVX2: } +func @rsqrt_vector(%arg0: vector<8xf32>) -> vector<8xf32> { + %0 = math.rsqrt %arg0 : vector<8xf32> + return %0 : vector<8xf32> +} diff --git a/mlir/test/lib/Dialect/Math/CMakeLists.txt b/mlir/test/lib/Dialect/Math/CMakeLists.txt --- a/mlir/test/lib/Dialect/Math/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Math/CMakeLists.txt @@ -11,4 +11,5 @@ MLIRPass MLIRTransformUtils MLIRVector + MLIRX86Vector ) diff --git a/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp b/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp --- 
a/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp +++ b/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -23,10 +24,16 @@ namespace { struct TestMathPolynomialApproximationPass : public PassWrapper { + TestMathPolynomialApproximationPass() = default; + TestMathPolynomialApproximationPass( + const TestMathPolynomialApproximationPass &pass) {} + void runOnFunction() override; void getDependentDialects(DialectRegistry &registry) const override { registry.insert(); + if (enableAvx2) + registry.insert(); } StringRef getArgument() const final { return "test-math-polynomial-approximation"; @@ -34,12 +41,20 @@ StringRef getDescription() const final { return "Test math polynomial approximations"; } + + Option enableAvx2{ + *this, "enable-avx2", + llvm::cl::desc("Enable approximations that emit AVX2 intrinsics via the " + "X86Vector dialect"), + llvm::cl::init(false)}; }; } // end anonymous namespace void TestMathPolynomialApproximationPass::runOnFunction() { RewritePatternSet patterns(&getContext()); - populateMathPolynomialApproximationPatterns(patterns); + MathPolynomialApproximationOptions approx_options; + approx_options.enableAvx2 = enableAvx2; + populateMathPolynomialApproximationPatterns(patterns, approx_options); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } diff --git a/mlir/test/mlir-cpu-runner/X86Vector/lit.local.cfg b/mlir/test/mlir-cpu-runner/X86Vector/lit.local.cfg new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-cpu-runner/X86Vector/lit.local.cfg @@ -0,0 +1,5 @@ +import sys + +# X86Vector tests must be enabled via build flag. 
+if not config.mlir_run_x86vector_tests: + config.unsupported = True diff --git a/mlir/test/mlir-cpu-runner/X86Vector/math_polynomial_approx_avx2.mlir b/mlir/test/mlir-cpu-runner/X86Vector/math_polynomial_approx_avx2.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-cpu-runner/X86Vector/math_polynomial_approx_avx2.mlir @@ -0,0 +1,40 @@ +// RUN: mlir-opt %s -test-math-polynomial-approximation="enable-avx2" \ +// RUN: -convert-arith-to-llvm \ +// RUN: -convert-vector-to-llvm="enable-x86vector" \ +// RUN: -convert-math-to-llvm \ +// RUN: -convert-std-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner \ +// RUN: -e main -entry-point-result=void -O0 \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \ +// RUN: | FileCheck %s + +// -------------------------------------------------------------------------- // +// rsqrt. +// -------------------------------------------------------------------------- // + +func @rsqrt() { + // Sanity-check that the scalar rsqrt still works OK. + // CHECK: inf + %0 = arith.constant 0.0 : f32 + %rsqrt_0 = math.rsqrt %0 : f32 + vector.print %rsqrt_0 : f32 + // CHECK: 0.707107 + %two = arith.constant 2.0: f32 + %rsqrt_two = math.rsqrt %two : f32 + vector.print %rsqrt_two : f32 + + // Check that the vectorized approximation is reasonably accurate. 
+ // CHECK: 0.707107, 0.707107, 0.707107, 0.707107, 0.707107, 0.707107, 0.707107, 0.707107 + %vec8 = arith.constant dense<2.0> : vector<8xf32> + %rsqrt_vec8 = math.rsqrt %vec8 : vector<8xf32> + vector.print %rsqrt_vec8 : vector<8xf32> + + return +} + +func @main() { + call @rsqrt(): () -> () + return +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -7057,6 +7057,7 @@ ":Support", ":Transforms", ":VectorOps", + ":X86Vector", "//llvm:Support", ], ) diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -406,6 +406,7 @@ "//mlir:Pass", "//mlir:TransformUtils", "//mlir:VectorOps", + "//mlir:X86Vector", ], )