diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -15,6 +15,8 @@
 
 void populateExpandTanhPattern(RewritePatternSet &patterns);
 
+void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
+
 void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns);
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -0,0 +1,112 @@
+//===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrites based on the basic rules of algebra
+// (Commutativity, associativity, etc...) and strength reductions for math
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/TypeUtilities.h"
+#include <climits>
+
+using namespace mlir;
+
+//----------------------------------------------------------------------------//
+// PowFOp strength reduction.
+//----------------------------------------------------------------------------//
+
+namespace {
+struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::PowFOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+LogicalResult
+PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
+                                       PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  Value x = op.lhs();
+
+  FloatAttr scalarExponent;
+  DenseFPElementsAttr vectorExponent;
+
+  bool isScalar = matchPattern(op.rhs(), m_Constant(&scalarExponent));
+  bool isVector = matchPattern(op.rhs(), m_Constant(&vectorExponent));
+
+  // Returns true if the exponent is a scalar or splat constant equal to `value`.
+  auto isExponentValue = [&](double value) -> bool {
+    if (isScalar)
+      return scalarExponent.getValue().isExactlyValue(value);
+
+    if (isVector && vectorExponent.isSplat())
+      return vectorExponent.getSplatValue<FloatAttr>()
+          .getValue()
+          .isExactlyValue(value);
+
+    return false;
+  };
+
+  // Maybe broadcasts scalar value into vector type compatible with `op`.
+  auto bcast = [&](Value value) -> Value {
+    if (auto vec = op.getType().dyn_cast<VectorType>())
+      return rewriter.create<vector::BroadcastOp>(loc, vec, value);
+    return value;
+  };
+
+  // Replace `pow(x, 1.0)` with `x`.
+  if (isExponentValue(1.0)) {
+    rewriter.replaceOp(op, x);
+    return success();
+  }
+
+  // Replace `pow(x, 2.0)` with `x * x`.
+  if (isExponentValue(2.0)) {
+    rewriter.replaceOpWithNewOp<MulFOp>(op, ValueRange({x, x}));
+    return success();
+  }
+
+  // Replace `pow(x, 3.0)` with `x * x * x`.
+  if (isExponentValue(3.0)) {
+    Value square = rewriter.create<MulFOp>(loc, ValueRange({x, x}));
+    rewriter.replaceOpWithNewOp<MulFOp>(op, ValueRange({x, square}));
+    return success();
+  }
+
+  // Replace `pow(x, -1.0)` with `1.0 / x`.
+  if (isExponentValue(-1.0)) {
+    Value one = rewriter.create<ConstantOp>(
+        loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
+    rewriter.replaceOpWithNewOp<DivFOp>(op, ValueRange({bcast(one), x}));
+    return success();
+  }
+
+  // Replace `pow(x, 0.5)` with `sqrt(x)`.
+  if (isExponentValue(0.5)) {
+    rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
+    return success();
+  }
+
+  return failure();
+}
+
+//----------------------------------------------------------------------------//
+
+void mlir::populateMathAlgebraicSimplificationPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<PowFStrengthReduction>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRMathTransforms
+  AlgebraicSimplification.cpp
   ExpandTanh.cpp
   PolynomialApproximation.cpp
 
diff --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt %s -test-math-algebraic-simplification | FileCheck %s --dump-input=always
+
+// CHECK-LABEL: @pow_noop
+func @pow_noop(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: return %arg0, %arg1
+  %c = constant 1.0 : f32
+  %v = constant dense <1.0> : vector<4xf32>
+  %0 = math.powf %arg0, %c : f32
+  %1 = math.powf %arg1, %v : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @pow_square
+func @pow_square(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[SCALAR:.*]] = mulf %arg0, %arg0
+  // CHECK: %[[VECTOR:.*]] = mulf %arg1, %arg1
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = constant 2.0 : f32
+  %v = constant dense <2.0> : vector<4xf32>
+  %0 = math.powf %arg0, %c : f32
+  %1 = math.powf %arg1, %v : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @pow_cube
+func @pow_cube(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[TMP_S:.*]] = mulf %arg0, %arg0
+  // CHECK: %[[SCALAR:.*]] = mulf %arg0, %[[TMP_S]]
+  // CHECK: %[[TMP_V:.*]] = mulf %arg1, %arg1
+  // CHECK: %[[VECTOR:.*]] = mulf %arg1, %[[TMP_V]]
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = constant 3.0 : f32
+  %v = constant dense <3.0> : vector<4xf32>
+  %0 = math.powf %arg0, %c : f32
+  %1 = math.powf %arg1, %v : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @pow_recip
+func @pow_recip(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[CST_S:.*]] = constant 1.0{{.*}} : f32
+  // CHECK: %[[CST_V:.*]] = constant dense<1.0{{.*}}> : vector<4xf32>
+  // CHECK: %[[SCALAR:.*]] = divf %[[CST_S]], %arg0
+  // CHECK: %[[VECTOR:.*]] = divf %[[CST_V]], %arg1
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = constant -1.0 : f32
+  %v = constant dense <-1.0> : vector<4xf32>
+  %0 = math.powf %arg0, %c : f32
+  %1 = math.powf %arg1, %v : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @pow_sqrt
+func @pow_sqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[SCALAR:.*]] = math.sqrt %arg0
+  // CHECK: %[[VECTOR:.*]] = math.sqrt %arg1
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = constant 0.5 : f32
+  %v = constant dense <0.5> : vector<4xf32>
+  %0 = math.powf %arg0, %c : f32
+  %1 = math.powf %arg1, %v : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
diff --git a/mlir/test/lib/Dialect/Math/CMakeLists.txt b/mlir/test/lib/Dialect/Math/CMakeLists.txt
--- a/mlir/test/lib/Dialect/Math/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Math/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRMathTestPasses
+  TestAlgebraicSimplification.cpp
   TestExpandTanh.cpp
   TestPolynomialApproximation.cpp
 
diff --git a/mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp b/mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp
@@ -0,0 +1,50 @@
+//===- TestAlgebraicSimplification.cpp - Test algebraic simplification ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes for algebraic simplification patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+struct TestMathAlgebraicSimplificationPass
+    : public PassWrapper<TestMathAlgebraicSimplificationPass, FunctionPass> {
+  void runOnFunction() override;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect, math::MathDialect>();
+  }
+  StringRef getArgument() const final {
+    return "test-math-algebraic-simplification";
+  }
+  StringRef getDescription() const final {
+    return "Test math algebraic simplification";
+  }
+};
+} // end anonymous namespace
+
+void TestMathAlgebraicSimplificationPass::runOnFunction() {
+  RewritePatternSet patterns(&getContext());
+  populateMathAlgebraicSimplificationPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+namespace mlir {
+namespace test {
+void registerTestMathAlgebraicSimplificationPass() {
+  PassRegistration<TestMathAlgebraicSimplificationPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -92,6 +92,7 @@
 void registerTestLoopFusion();
 void registerTestLoopMappingPass();
 void registerTestLoopUnrollingPass();
+void registerTestMathAlgebraicSimplificationPass();
 void registerTestMathPolynomialApproximationPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
@@ -173,6 +174,7 @@
   test::registerTestLoopFusion();
   test::registerTestLoopMappingPass();
   test::registerTestLoopUnrollingPass();
+  test::registerTestMathAlgebraicSimplificationPass();
   test::registerTestMathPolynomialApproximationPass();
   test::registerTestMemRefDependenceCheck();
   test::registerTestMemRefStrideCalculation();