diff --git a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
@@ -0,0 +1,30 @@
+//===- MathToLLVM.h - Utils to convert from the Math dialect --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H_
+#define MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H_
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+
+template <typename T>
+class OperationPass;
+
+/// Populate `patterns` with rewrites that lower the Math ops lacking an LLVM
+/// intrinsic (atan2, expm1, tanh) on f32/f64 scalars, and 1-D vectors of
+/// them, to libm calls. `converter` is currently unused.
+void populateMathToLibmConversionPatterns(LLVMTypeConverter &converter,
+                                          RewritePatternSet &patterns);
+
+/// Create a pass that converts Math operations to calls to libm functions.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToLibmPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -20,6 +20,7 @@
 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
 #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -228,6 +228,23 @@
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// MathToLibm
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToLibm : Pass<"convert-math-to-libm", "ModuleOp"> {
+  let summary = "Convert Math dialect to LLVM based libm calls";
+  let description = [{
+    This pass converts the Math ops that have no LLVM intrinsic equivalent
+    (currently atan2, expm1 and tanh on f32/f64 scalars and 1-D vectors) to
+    calls to the corresponding libm functions. Vector ops are unrolled into
+    one call per element. All other Math ops are left untouched.
+  }];
+  let constructor = "mlir::createConvertMathToLibmPass()";
+  let dependentDialects = ["LLVM::LLVMDialect", "StandardOpsDialect",
+                           "vector::VectorDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // OpenMPToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -9,6 +9,7 @@
 add_subdirectory(LinalgToLLVM)
 add_subdirectory(LinalgToSPIRV)
 add_subdirectory(LinalgToStandard)
+add_subdirectory(MathToLLVM)
 add_subdirectory(OpenMPToLLVM)
 add_subdirectory(PDLToPDLInterp)
 add_subdirectory(SCFToGPU)
diff --git a/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRMathToLLVM
+  MathToLibm.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRMath
+  MLIRLLVMIR
+  MLIRStandardOpsTransforms
+  MLIRStandardToLLVM
+  MLIRVector
+  )
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLibm.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLibm.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLibm.cpp
@@ -0,0 +1,153 @@
+//===- MathToLibm.cpp - conversion from Math to LLVM dialect libm calls ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+
+using namespace mlir;
+
+namespace {
+/// Unrolls a Math op on a 1-D vector into one scalar op per element. libm
+/// functions only accept scalars; the scalar ops created here are then
+/// rewritten by ScalarOpToLibmCall. Multi-dimensional vectors are not handled.
+template <typename Op>
+struct VecOpToScalarOp : public OpRewritePattern<Op> {
+  explicit VecOpToScalarOp(MLIRContext *context)
+      : OpRewritePattern<Op>(context, /*benefit=*/0) {}
+
+  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+};
+
+/// Rewrites a scalar f32/f64 Math op into an `llvm.call` to the libm function
+/// `floatFunc` (f32) or `doubleFunc` (f64). The callee is forward-declared as
+/// an `llvm.func` at the start of the enclosing module if not already present.
+template <typename Op>
+struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
+  ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc,
+                     StringRef doubleFunc)
+      : OpRewritePattern<Op>(context, /*benefit=*/0),
+        floatFunc(floatFunc.str()), doubleFunc(doubleFunc.str()) {}
+
+  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+
+private:
+  /// libm function names used for f32 and f64 operands respectively.
+  std::string floatFunc, doubleFunc;
+};
+} // namespace
+
+template <typename Op>
+LogicalResult
+VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
+  auto vecType = op.getType().template dyn_cast<VectorType>();
+  if (!vecType)
+    return failure();
+  ArrayRef<int64_t> shape = vecType.getShape();
+  // TODO: support multidimensional vectors.
+  if (shape.size() != 1)
+    return failure();
+
+  Location loc = op.getLoc();
+  Type elementType = vecType.getElementType();
+  // Seed the result with a zero vector. The splat must be an attribute of the
+  // element type: `DenseElementsAttr::get(vecType, 0.0)` is only valid for f64
+  // elements and asserts on e.g. vector<4xf32>.
+  Value result = rewriter.create<ConstantOp>(
+      loc, vecType,
+      DenseElementsAttr::get(vecType, rewriter.getFloatAttr(elementType, 0.0)));
+  for (int64_t i = 0, e = shape.front(); i < e; ++i) {
+    SmallVector<Value> operands;
+    for (Value input : op->getOperands())
+      operands.push_back(
+          rewriter.create<vector::ExtractElementOp>(loc, input, i));
+    Value scalarOp = rewriter.create<Op>(loc, elementType, operands);
+    result =
+        rewriter.create<vector::InsertElementOp>(loc, scalarOp, result, i);
+  }
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
+template <typename Op>
+LogicalResult
+ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
+                                        PatternRewriter &rewriter) const {
+  Type opType = op.getType();
+  // TODO: support other float types (e.g. f16) by extending them to f32.
+  if (!opType.isa<Float32Type, Float64Type>())
+    return failure();
+  auto module = op->template getParentOfType<ModuleOp>();
+  if (!module)
+    return failure();
+
+  StringRef name = opType.isF64() ? doubleFunc : floatFunc;
+  // Builtin f32/f64 are valid LLVM dialect types, so the libm signature takes
+  // one `opType` parameter per operand and returns `opType`.
+  SmallVector<Type> operandTypes(op->getNumOperands(), opType);
+  auto funcType = LLVM::LLVMFunctionType::get(opType, operandTypes);
+
+  // Declaring a second symbol with the same name would produce invalid IR, so
+  // reuse an existing matching declaration and bail out on any other symbol.
+  Operation *existing = module.lookupSymbol(name);
+  auto opFunc = dyn_cast_or_null<LLVM::LLVMFuncOp>(existing);
+  if (existing && (!opFunc || opFunc.getType() != funcType))
+    return rewriter.notifyMatchFailure(
+        op, "libm function name is taken by an incompatible symbol");
+  if (!opFunc) {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+    opFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), name,
+                                               funcType);
+  }
+  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, opFunc, op->getOperands());
+  return success();
+}
+
+void mlir::populateMathToLibmConversionPatterns(LLVMTypeConverter &converter,
+                                                RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
+               VecOpToScalarOp<math::TanhOp>>(ctx);
+  patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(ctx, "atan2f", "atan2");
+  patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(ctx, "expm1f", "expm1");
+  patterns.add<ScalarOpToLibmCall<math::TanhOp>>(ctx, "tanhf", "tanh");
+}
+
+namespace {
+struct ConvertMathToLibmPass
+    : public ConvertMathToLibmBase<ConvertMathToLibmPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToLibmPass::runOnOperation() {
+  ModuleOp module = getOperation();
+
+  RewritePatternSet patterns(&getContext());
+  LLVMTypeConverter converter(&getContext());
+  populateMathToLibmConversionPatterns(converter, patterns);
+
+  // Only the ops lowered by the patterns above are illegal. Other Math ops
+  // (e.g. math.exp, which has an LLVM intrinsic) stay untouched for a later
+  // pass instead of making the whole conversion fail.
+  ConversionTarget target(getContext());
+  target.addLegalDialect<LLVM::LLVMDialect, StandardOpsDialect,
+                         vector::VectorDialect>();
+  target.addIllegalOp<math::Atan2Op, math::ExpM1Op, math::TanhOp>();
+  if (failed(applyPartialConversion(module, target, std::move(patterns))))
+    signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() {
+  return std::make_unique<ConvertMathToLibmPass>();
+}
diff --git a/mlir/test/Conversion/MathToLLVM/convert-to-libm.mlir b/mlir/test/Conversion/MathToLLVM/convert-to-libm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/MathToLLVM/convert-to-libm.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt %s -convert-math-to-libm | FileCheck %s
+
+// CHECK-DAG: @expm1(f64) -> f64
+// CHECK-DAG: @expm1f(f32) -> f32
+// CHECK-DAG: @atan2(f64, f64) -> f64
+// CHECK-DAG: @atan2f(f32, f32) -> f32
+// CHECK-DAG: @tanh(f64) -> f64
+// CHECK-DAG: @tanhf(f32) -> f32
+
+// CHECK-LABEL: func @tanh_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func @tanh_caller(%float: f32, %double: f64) -> (f32, f64) {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = llvm.call @tanhf(%[[FLOAT]]) : (f32) -> f32
+  %float_result = math.tanh %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = llvm.call @tanh(%[[DOUBLE]]) : (f64) -> f64
+  %double_result = math.tanh %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}
+
+// CHECK-LABEL: func @atan2_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func @atan2_caller(%float: f32, %double: f64) -> (f32, f64) {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = llvm.call @atan2f(%[[FLOAT]], %[[FLOAT]]) : (f32, f32) -> f32
+  %float_result = math.atan2 %float, %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = llvm.call @atan2(%[[DOUBLE]], %[[DOUBLE]]) : (f64, f64) -> f64
+  %double_result = math.atan2 %double, %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}
+
+// CHECK-LABEL: func @expm1_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func @expm1_caller(%float: f32, %double: f64) -> (f32, f64) {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = llvm.call @expm1f(%[[FLOAT]]) : (f32) -> f32
+  %float_result = math.expm1 %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = llvm.call @expm1(%[[DOUBLE]]) : (f64) -> f64
+  %double_result = math.expm1 %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}
+
+// CHECK-LABEL: func @tanh_vec_caller
+// CHECK-SAME: %[[VAL:.*]]: vector<2xf32>
+func @tanh_vec_caller(%vec: vector<2xf32>) -> vector<2xf32> {
+  // CHECK: %[[ZERO:.*]] = constant dense<0.000000e+00> : vector<2xf32>
+  // CHECK: %[[IN0:.*]] = vector.extractelement %[[VAL]]
+  // CHECK: %[[OUT0:.*]] = llvm.call @tanhf(%[[IN0]]) : (f32) -> f32
+  // CHECK: %[[RES0:.*]] = vector.insertelement %[[OUT0]], %[[ZERO]]
+  // CHECK: %[[IN1:.*]] = vector.extractelement %[[VAL]]
+  // CHECK: %[[OUT1:.*]] = llvm.call @tanhf(%[[IN1]]) : (f32) -> f32
+  // CHECK: %[[RES1:.*]] = vector.insertelement %[[OUT1]], %[[RES0]]
+  // CHECK: return %[[RES1]]
+  %result = math.tanh %vec : vector<2xf32>
+  return %result : vector<2xf32>
+}
+
+// Math ops that map to LLVM intrinsics are not handled by this pass; they must
+// be left untouched rather than making the conversion fail.
+// CHECK-LABEL: func @exp_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+func @exp_caller(%float: f32) -> f32 {
+  // CHECK: %[[RESULT:.*]] = math.exp %[[FLOAT]] : f32
+  %result = math.exp %float : f32
+  // CHECK: return %[[RESULT]]
+  return %result : f32
+}