diff --git a/mlir/include/mlir/Dialect/Math/CMakeLists.txt b/mlir/include/mlir/Dialect/Math/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Math/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Math/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Math) +add_public_tablegen_target(MLIRMathTransformsIncGen) + +add_mlir_doc(Passes MathPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -9,7 +9,17 @@ #ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_ #define MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_ +#include "mlir/Pass/Pass.h" + namespace mlir { +namespace math { +#define GEN_PASS_DECL +#include "mlir/Dialect/Math/Transforms/Passes.h.inc" +#define GEN_PASS_DECL_MATHUPLIFTTOFMA +#include "mlir/Dialect/Math/Transforms/Passes.h.inc" +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Math/Transforms/Passes.h.inc" +} // namespace math class RewritePatternSet; @@ -34,6 +44,8 @@ RewritePatternSet &patterns, const MathPolynomialApproximationOptions &options = {}); +void populateUpliftToFMAPatterns(RewritePatternSet &patterns); + } // namespace mlir #endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td @@ -0,0 +1,22 @@ +//===-- Passes.td - Math pass definition file --------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES +#define MLIR_DIALECT_MATH_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def MathUpliftToFMA : Pass<"math-uplift-to-fma"> { + let summary = "Uplift arith ops to math.fma."; + let description = [{ + Uplift sequence of addf and mulf ops to math.fma if fastmath flags allows it. + }]; + let dependentDialects = ["math::MathDialect"]; +} + +#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -25,6 +25,7 @@ #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/NVGPU/Passes.h" #include "mlir/Dialect/SCF/Transforms/Passes.h" @@ -70,6 +71,7 @@ registerNVGPUPasses(); registerSparseTensorPasses(); LLVM::registerLLVMPasses(); + math::registerMathPasses(); memref::registerMemRefPasses(); registerSCFPasses(); registerShapePasses(); diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -2,10 +2,14 @@ AlgebraicSimplification.cpp ExpandPatterns.cpp PolynomialApproximation.cpp + UpliftToFMA.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Math/Transforms + DEPENDS + MLIRMathTransformsIncGen + LINK_LIBS PUBLIC MLIRArithDialect MLIRDialectUtils diff --git a/mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp b/mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp @@ -0,0 +1,79 @@ +//===- UpliftToFMA.cpp - Arith to FMA uplifting ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements uplifting from arith ops to math.fma. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Math/Transforms/Passes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir::math { +#define GEN_PASS_DEF_MATHUPLIFTTOFMA +#include "mlir/Dialect/Math/Transforms/Passes.h.inc" +} // namespace mlir::math + +using namespace mlir; + +template +static bool isValidForFMA(Op op) { + return static_cast(op.getFastmath() & arith::FastMathFlags::contract); +} + +namespace { + +struct UpliftFma final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::AddFOp op, + PatternRewriter &rewriter) const override { + if (!isValidForFMA(op)) + return rewriter.notifyMatchFailure(op, "addf op is not suitable for fma"); + + Value c; + arith::MulFOp ab; + if ((ab = op.getLhs().getDefiningOp())) { + c = op.getRhs(); + } else if ((ab = op.getRhs().getDefiningOp())) { + c = op.getLhs(); + } else { + return rewriter.notifyMatchFailure(op, "no mulf op"); + } + + if (!isValidForFMA(ab)) + return rewriter.notifyMatchFailure(ab, "mulf op is not suitable for fma"); + + Value a = ab.getLhs(); + Value b = ab.getRhs(); + arith::FastMathFlags fmf = op.getFastmath() & ab.getFastmath(); + rewriter.replaceOpWithNewOp(op, a, b, c, fmf); + return success(); + } +}; + +struct MathUpliftToFMA final + : math::impl::MathUpliftToFMABase { + using MathUpliftToFMABase::MathUpliftToFMABase; + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateUpliftToFMAPatterns(patterns); + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +void mlir::populateUpliftToFMAPatterns(RewritePatternSet &patterns) { + patterns.insert(patterns.getContext()); +} diff --git a/mlir/test/Dialect/Math/uplift-to-fma.mlir b/mlir/test/Dialect/Math/uplift-to-fma.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Math/uplift-to-fma.mlir @@ -0,0 +1,37 @@ +// RUN: mlir-opt %s --split-input-file --math-uplift-to-fma | FileCheck %s + +// No uplifting without fastmath flags. +// CHECK-LABEL: func @test +// CHECK-SAME: (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32) +// CHECK: %[[V1:.*]] = arith.mulf %[[ARG1]], %[[ARG2]] +// CHECK: %[[V2:.*]] = arith.addf %[[V1]], %[[ARG3]] +// CHECK: return %[[V2]] +func.func @test(%arg1: f32, %arg2: f32, %arg3: f32) -> f32 { + %1 = arith.mulf %arg1, %arg2 : f32 + %2 = arith.addf %1, %arg3 : f32 + return %2 : f32 +} + +// ----- + +// CHECK-LABEL: func @test +// CHECK-SAME: (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32) +// CHECK: %[[RES:.*]] = math.fma %[[ARG1]], %[[ARG2]], %[[ARG3]] fastmath : f32 +// CHECK: return %[[RES]] +func.func @test(%arg1: f32, %arg2: f32, %arg3: f32) -> f32 { + %1 = arith.mulf %arg1, %arg2 fastmath : f32 + %2 = arith.addf %1, %arg3 fastmath : f32 + return %2 : f32 +} + +// ----- + +// CHECK-LABEL: func @test +// CHECK-SAME: (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32) +// CHECK: %[[RES:.*]] = math.fma %[[ARG1]], %[[ARG2]], %[[ARG3]] fastmath : f32 +// CHECK: return %[[RES]] +func.func @test(%arg1: f32, %arg2: f32, %arg3: f32) -> f32 { + %1 = arith.mulf %arg1, %arg2 fastmath : f32 + %2 = arith.addf %arg3, %1 fastmath : f32 + return %2 : f32 +}