diff --git a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
@@ -0,0 +1,26 @@
+//===- MathToLibm.h - Utils for converting Math to libm calls ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_
+#define MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+template <typename T>
+class OperationPass;
+
+/// Populate the given list with patterns that convert from Math to Libm calls.
+void populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit benefit);
+
+/// Create a pass to convert Math operations to libm calls.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToLibmPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -20,6 +20,7 @@
 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
 #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -228,6 +228,19 @@
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// MathToLibm
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToLibm : Pass<"convert-math-to-libm", "ModuleOp"> {
+  let summary = "Convert Math dialect to libm calls";
+  let description = [{
+    This pass converts supported Math ops to libm calls.
+  }];
+  let constructor = "mlir::createConvertMathToLibmPass()";
+  let dependentDialects = ["StandardOpsDialect", "vector::VectorDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // OpenMPToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -9,6 +9,7 @@
 add_subdirectory(LinalgToLLVM)
 add_subdirectory(LinalgToSPIRV)
 add_subdirectory(LinalgToStandard)
+add_subdirectory(MathToLibm)
 add_subdirectory(OpenMPToLLVM)
 add_subdirectory(PDLToPDLInterp)
 add_subdirectory(SCFToGPU)
diff --git a/mlir/lib/Conversion/MathToLibm/CMakeLists.txt b/mlir/lib/Conversion/MathToLibm/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/MathToLibm/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_conversion_library(MLIRMathToLibm
+  MathToLibm.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLibm
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRMath
+  MLIRStandardOpsTransforms
+  MLIRVector
+  )
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -0,0 +1,155 @@
+//===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToLibm/MathToLibm.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+namespace {
+// Pattern to convert vector operations to scalar operations. This is needed as
+// libm calls require scalars.
+template <typename Op>
+struct VecOpToScalarOp : public OpRewritePattern<Op> {
+public:
+  using OpRewritePattern<Op>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+};
+// Pattern to convert scalar math operations to calls to libm functions.
+// Additionally the libm function signatures are declared.
+template <typename Op>
+struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
+public:
+  using OpRewritePattern<Op>::OpRewritePattern;
+  ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc,
+                     StringRef doubleFunc, PatternBenefit benefit)
+      : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
+        doubleFunc(doubleFunc) {}
+
+  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+
+private:
+  /// libm entry points used for f32 and f64 operands, respectively.
+  std::string floatFunc, doubleFunc;
+};
+} // namespace
+
+template <typename Op>
+LogicalResult
+VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
+  auto opType = op.getType();
+  auto loc = op.getLoc();
+  auto vecType = opType.template dyn_cast<VectorType>();
+
+  if (!vecType)
+    return failure();
+  if (!vecType.hasRank())
+    return failure();
+  auto shape = vecType.getShape();
+  // TODO: support multidimensional vectors
+  if (shape.size() != 1)
+    return failure();
+
+  // Start from an all-zero vector and overwrite each lane with the scalar op
+  // applied to the corresponding lanes of the inputs.
+  Value result = rewriter.create<ConstantOp>(
+      loc, DenseElementsAttr::get(
+               vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
+  for (auto i = 0; i < shape.front(); ++i) {
+    SmallVector<Value> operands;
+    for (auto input : op->getOperands())
+      operands.push_back(
+          rewriter.create<vector::ExtractElementOp>(loc, input, i));
+    Value scalarOp =
+        rewriter.create<Op>(loc, vecType.getElementType(), operands);
+    result = rewriter.create<vector::InsertElementOp>(loc, scalarOp, result, i);
+  }
+  rewriter.replaceOp(op, {result});
+  return success();
+}
+
+template <typename Op>
+LogicalResult
+ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
+                                        PatternRewriter &rewriter) const {
+  auto module = op->template getParentOfType<ModuleOp>();
+  auto type = op.getType();
+  // libm only provides f32 and f64 variants; other float types (f16, bf16,
+  // f80, f128) must not be routed to the f32 entry point.
+  // TODO: Support Float16 by upcasting to Float32
+  if (!type.isF32() && !type.isF64())
+    return failure();
+
+  StringRef name = type.isF64() ? doubleFunc : floatFunc;
+  auto opFunctionTy = FunctionType::get(
+      rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
+  auto opFunc = module.template lookupSymbol<FuncOp>(name);
+  // Forward declare function if it hasn't already been
+  if (!opFunc) {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+    opFunc =
+        rewriter.create<FuncOp>(rewriter.getUnknownLoc(), name, opFunctionTy);
+    opFunc.setPrivate();
+  } else if (opFunc.getType() != opFunctionTy) {
+    // A function with this name already exists with a different signature;
+    // emitting a call to it would produce invalid IR.
+    return rewriter.notifyMatchFailure(op, "libm function declared with "
+                                           "unexpected signature");
+  }
+
+  rewriter.replaceOpWithNewOp<CallOp>(op, opFunc, op->getOperands());
+
+  return success();
+}
+
+void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
+                                                PatternBenefit benefit) {
+  patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
+               VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit);
+  patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
+                                                  "atan2f", "atan2", benefit);
+  patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),
+                                                  "expm1f", "expm1", benefit);
+  patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
+                                                 "tanh", benefit);
+}
+
+namespace {
+struct ConvertMathToLibmPass
+    : public ConvertMathToLibmBase<ConvertMathToLibmPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToLibmPass::runOnOperation() {
+  auto module = getOperation();
+
+  RewritePatternSet patterns(&getContext());
+  populateMathToLibmConversionPatterns(patterns, /*benefit=*/1);
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<BuiltinDialect, StandardOpsDialect,
+                         vector::VectorDialect>();
+  // Only the ops lowered by the patterns above are illegal. Marking the whole
+  // Math dialect illegal would make the pass fail on any other Math op.
+  target.addIllegalOp<math::Atan2Op, math::ExpM1Op, math::TanhOp>();
+  if (failed(applyPartialConversion(module, target, std::move(patterns))))
+    signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() {
+  return std::make_unique<ConvertMathToLibmPass>();
+}
diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt %s -convert-math-to-libm -canonicalize | FileCheck %s
+
+// CHECK-DAG: @expm1(f64) -> f64
+// CHECK-DAG: @expm1f(f32) -> f32
+// CHECK-DAG: @atan2(f64, f64) -> f64
+// CHECK-DAG: @atan2f(f32, f32) -> f32
+// CHECK-DAG: @tanh(f64) -> f64
+// CHECK-DAG: @tanhf(f32) -> f32
+
+// CHECK-LABEL: func @tanh_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func @tanh_caller(%float: f32, %double: f64) -> (f32, f64) {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @tanhf(%[[FLOAT]]) : (f32) -> f32
+  %float_result = math.tanh %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @tanh(%[[DOUBLE]]) : (f64) -> f64
+  %double_result = math.tanh %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}
+
+
+// CHECK-LABEL: func @atan2_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func @atan2_caller(%float: f32, %double: f64) -> (f32, f64) {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @atan2f(%[[FLOAT]], %[[FLOAT]]) : (f32, f32) -> f32
+  %float_result = math.atan2 %float, %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @atan2(%[[DOUBLE]], %[[DOUBLE]]) : (f64, f64) -> f64
+  %double_result = math.atan2 %double, %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}
+
+// CHECK-LABEL: func @expm1_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func @expm1_caller(%float: f32, %double: f64) -> (f32, f64) {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @expm1f(%[[FLOAT]]) : (f32) -> f32
+  %float_result = math.expm1 %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @expm1(%[[DOUBLE]]) : (f64) -> f64
+  %double_result = math.expm1 %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}
+
+func @expm1_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
+  %float_result = math.expm1 %float : vector<2xf32>
+  %double_result = math.expm1 %double : vector<2xf64>
+  return %float_result, %double_result : vector<2xf32>, vector<2xf64>
+}
+// CHECK-LABEL:   func @expm1_vec_caller(
+// CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
+// CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
+// CHECK:           %[[CVF:.*]] = constant dense<0.000000e+00> : vector<2xf32>
+// CHECK:           %[[CVD:.*]] = constant dense<0.000000e+00> : vector<2xf64>
+// CHECK:           %[[C0:.*]] = constant 0 : i32
+// CHECK:           %[[C1:.*]] = constant 1 : i32
+// CHECK:           %[[IN0_F32:.*]] = vector.extractelement %[[VAL_0]]{{\[}}%[[C0]] : i32] : vector<2xf32>
+// CHECK:           %[[OUT0_F32:.*]] = call @expm1f(%[[IN0_F32]]) : (f32) -> f32
+// CHECK:           %[[VAL_8:.*]] = vector.insertelement %[[OUT0_F32]], %[[CVF]]{{\[}}%[[C0]] : i32] : vector<2xf32>
+// CHECK:           %[[IN1_F32:.*]] = vector.extractelement %[[VAL_0]]{{\[}}%[[C1]] : i32] : vector<2xf32>
+// CHECK:           %[[OUT1_F32:.*]] = call @expm1f(%[[IN1_F32]]) : (f32) -> f32
+// CHECK:           %[[VAL_11:.*]] = vector.insertelement %[[OUT1_F32]], %[[VAL_8]]{{\[}}%[[C1]] : i32] : vector<2xf32>
+// CHECK:           %[[IN0_F64:.*]] = vector.extractelement %[[VAL_1]]{{\[}}%[[C0]] : i32] : vector<2xf64>
+// CHECK:           %[[OUT0_F64:.*]] = call @expm1(%[[IN0_F64]]) : (f64) -> f64
+// CHECK:           %[[VAL_14:.*]] = vector.insertelement %[[OUT0_F64]], %[[CVD]]{{\[}}%[[C0]] : i32] : vector<2xf64>
+// CHECK:           %[[IN1_F64:.*]] = vector.extractelement %[[VAL_1]]{{\[}}%[[C1]] : i32] : vector<2xf64>
+// CHECK:           %[[OUT1_F64:.*]] = call @expm1(%[[IN1_F64]]) : (f64) -> f64
+// CHECK:           %[[VAL_17:.*]] = vector.insertelement %[[OUT1_F64]], %[[VAL_14]]{{\[}}%[[C1]] : i32] : vector<2xf64>
+// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:         }
+
+// Math ops without a libm lowering in this pass must be left untouched.
+// CHECK-LABEL: func @sqrt_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+func @sqrt_caller(%float: f32) -> f32 {
+  // CHECK: %[[RESULT:.*]] = math.sqrt %[[FLOAT]] : f32
+  %float_result = math.sqrt %float : f32
+  // CHECK: return %[[RESULT]]
+  return %float_result : f32
+}