diff --git a/mlir/include/mlir/Conversion/ComplexToLibm/ComplexToLibm.h b/mlir/include/mlir/Conversion/ComplexToLibm/ComplexToLibm.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/ComplexToLibm/ComplexToLibm.h @@ -0,0 +1,27 @@ +//===- ComplexToLibm.h - Utils to convert from the complex dialect --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_COMPLEXTOLIBM_COMPLEXTOLIBM_H_ +#define MLIR_CONVERSION_COMPLEXTOLIBM_COMPLEXTOLIBM_H_ + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +template +class OperationPass; + +/// Populate the given list with patterns that convert from Complex to Libm +/// calls. +void populateComplexToLibmConversionPatterns(RewritePatternSet &patterns, + PatternBenefit benefit); + +/// Create a pass to convert Complex operations to libm calls. 
+std::unique_ptr> createConvertComplexToLibmPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_COMPLEXTOLIBM_COMPLEXTOLIBM_H_ diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -17,6 +17,7 @@ #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h" #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -196,6 +196,21 @@ let dependentDialects = ["LLVM::LLVMDialect"]; } +//===----------------------------------------------------------------------===// +// ComplexToLibm +//===----------------------------------------------------------------------===// + +def ConvertComplexToLibm : Pass<"convert-complex-to-libm", "ModuleOp"> { + let summary = "Convert Complex dialect to libm calls"; + let description = [{ + This pass converts supported Complex ops to libm calls. 
+ }]; + let constructor = "mlir::createConvertComplexToLibmPass()"; + let dependentDialects = [ + "func::FuncDialect", + ]; +} + //===----------------------------------------------------------------------===// // ComplexToStandard //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -346,6 +346,24 @@ let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)"; } +//===----------------------------------------------------------------------===// +// PowOp +//===----------------------------------------------------------------------===// + +def PowOp : ComplexArithmeticOp<"pow"> { + let summary = "complex power function"; + let description = [{ + The `pow` operation takes a complex number and raises it to the given + complex exponent. + + Example: + + ```mlir + %a = complex.pow %b, %c : complex + ``` + }]; +} + //===----------------------------------------------------------------------===// // ReOp //===----------------------------------------------------------------------===// @@ -409,6 +427,25 @@ let results = (outs Complex:$result); } +//===----------------------------------------------------------------------===// +// SqrtOp +//===----------------------------------------------------------------------===// + +def SqrtOp : ComplexUnaryOp<"sqrt", [SameOperandsAndResultType]> { + let summary = "complex square root"; + let description = [{ + The `sqrt` operation takes a complex number and returns its square root. 
+ + Example: + + ```mlir + %a = complex.sqrt %b : complex + ``` + }]; + + let results = (outs Complex:$result); +} + //===----------------------------------------------------------------------===// // SubOp //===----------------------------------------------------------------------===// @@ -426,4 +463,24 @@ }]; } +//===----------------------------------------------------------------------===// +// TanhOp +//===----------------------------------------------------------------------===// + +def TanhOp : ComplexUnaryOp<"tanh", [SameOperandsAndResultType]> { + let summary = "complex hyperbolic tangent"; + let description = [{ + The `tanh` operation takes a complex number and returns its hyperbolic + tangent. + + Example: + + ```mlir + %a = complex.tanh %b : complex + ``` + }]; + + let results = (outs Complex:$result); +} + #endif // COMPLEX_OPS diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -6,6 +6,7 @@ add_subdirectory(AsyncToLLVM) add_subdirectory(BufferizationToMemRef) add_subdirectory(ComplexToLLVM) +add_subdirectory(ComplexToLibm) add_subdirectory(ComplexToStandard) add_subdirectory(ControlFlowToLLVM) add_subdirectory(ControlFlowToSPIRV) diff --git a/mlir/lib/Conversion/ComplexToLibm/CMakeLists.txt b/mlir/lib/Conversion/ComplexToLibm/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ComplexToLibm/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_conversion_library(MLIRComplexToLibm + ComplexToLibm.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToLibm + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRDialectUtils + MLIRFunc + MLIRComplex + MLIRTransformUtils + ) diff --git a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp new file mode 100644 --- /dev/null +++ 
b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp @@ -0,0 +1,101 @@ +//===-- ComplexToLibm.cpp - conversion from Complex to libm calls ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h" + +#include "../PassDetail.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" + +using namespace mlir; + +namespace { +// Pattern that rewrites a scalar complex operation into a call to a libm +// function, forward declaring that function as a private symbol if needed. +template +struct ScalarOpToLibmCall : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc, + StringRef doubleFunc, PatternBenefit benefit) + : OpRewritePattern(context, benefit), floatFunc(floatFunc), + doubleFunc(doubleFunc){}; + + LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; + +private: + std::string floatFunc, doubleFunc; +}; +} // namespace + +template +LogicalResult +ScalarOpToLibmCall::matchAndRewrite(Op op, + PatternRewriter &rewriter) const { + auto module = SymbolTable::getNearestSymbolTable(op); + auto type = op.getType().template cast(); + Type elementType = type.getElementType(); + if (!elementType.isa()) + return failure(); + + auto name = + elementType.getIntOrFloatBitWidth() == 64 ? 
doubleFunc : floatFunc; + auto opFunc = dyn_cast_or_null( + SymbolTable::lookupSymbolIn(module, name)); + // Forward declare the libm function if it has not been declared yet. + if (!opFunc) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&module->getRegion(0).front()); + auto opFunctionTy = FunctionType::get( + rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); + opFunc = rewriter.create(rewriter.getUnknownLoc(), name, + opFunctionTy); + opFunc.setPrivate(); + } + assert(isa(SymbolTable::lookupSymbolIn(module, name))); + + rewriter.replaceOpWithNewOp(op, name, type, op->getOperands()); + + return success(); +} + +void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add>(patterns.getContext(), + "cpowf", "cpow", benefit); + patterns.add>(patterns.getContext(), + "csqrtf", "csqrt", benefit); + patterns.add>(patterns.getContext(), + "ctanhf", "ctanh", benefit); +} + +namespace { +struct ConvertComplexToLibmPass + : public ConvertComplexToLibmBase { + void runOnOperation() override; +}; +} // namespace + +void ConvertComplexToLibmPass::runOnOperation() { + auto module = getOperation(); + + RewritePatternSet patterns(&getContext()); + populateComplexToLibmConversionPatterns(patterns, /*benefit=*/1); + + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalOp(); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +std::unique_ptr> +mlir::createConvertComplexToLibmPass() { + return std::make_unique(); +} diff --git a/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir b/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir @@ -0,0 +1,44 @@ +// RUN: mlir-opt %s -convert-complex-to-libm -canonicalize | FileCheck %s + +// CHECK-DAG: @cpowf(complex, complex) -> complex +// CHECK-DAG: 
@cpow(complex, complex) -> complex +// CHECK-DAG: @csqrtf(complex) -> complex +// CHECK-DAG: @csqrt(complex) -> complex +// CHECK-DAG: @ctanhf(complex) -> complex +// CHECK-DAG: @ctanh(complex) -> complex + +// CHECK-LABEL: func @cpow_caller +// CHECK-SAME: %[[FLOAT:.*]]: complex +// CHECK-SAME: %[[DOUBLE:.*]]: complex +func.func @cpow_caller(%float: complex, %double: complex) -> (complex, complex) { + // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @cpowf(%[[FLOAT]], %[[FLOAT]]) : (complex, complex) -> complex + %float_result = complex.pow %float, %float : complex + // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @cpow(%[[DOUBLE]], %[[DOUBLE]]) : (complex, complex) -> complex + %double_result = complex.pow %double, %double : complex + // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] + return %float_result, %double_result : complex, complex +} + +// CHECK-LABEL: func @csqrt_caller +// CHECK-SAME: %[[FLOAT:.*]]: complex +// CHECK-SAME: %[[DOUBLE:.*]]: complex +func.func @csqrt_caller(%float: complex, %double: complex) -> (complex, complex) { + // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @csqrtf(%[[FLOAT]]) : (complex) -> complex + %float_result = complex.sqrt %float : complex + // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @csqrt(%[[DOUBLE]]) : (complex) -> complex + %double_result = complex.sqrt %double : complex + // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] + return %float_result, %double_result : complex, complex +} + +// CHECK-LABEL: func @ctanh_caller +// CHECK-SAME: %[[FLOAT:.*]]: complex +// CHECK-SAME: %[[DOUBLE:.*]]: complex +func.func @ctanh_caller(%float: complex, %double: complex) -> (complex, complex) { + // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @ctanhf(%[[FLOAT]]) : (complex) -> complex + %float_result = complex.tanh %float : complex + // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @ctanh(%[[DOUBLE]]) : (complex) -> complex + %double_result = complex.tanh %double : complex + // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] + return %float_result, 
%double_result : complex, complex +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2456,6 +2456,7 @@ ":AsyncToLLVM", ":BufferizationToMemRef", ":ComplexToLLVM", + ":ComplexToLibm", ":ComplexToStandard", ":ControlFlowToLLVM", ":ControlFlowToSPIRV", @@ -4707,7 +4708,6 @@ ], ) - cc_library( name = "TensorToSPIRV", srcs = glob([ @@ -6188,6 +6188,7 @@ ":BufferizationTransforms", ":ComplexDialect", ":ComplexToLLVM", + ":ComplexToLibm", ":ControlFlowOps", ":ConversionPasses", ":DLTIDialect", @@ -8060,6 +8061,30 @@ ], ) +cc_library( + name = "ComplexToLibm", + srcs = glob([ + "lib/Conversion/ComplexToLibm/*.cpp", + "lib/Conversion/ComplexToLibm/*.h", + ]) + [":ConversionPassDetail"], + hdrs = glob([ + "include/mlir/Conversion/ComplexToLibm/*.h", + ]), + includes = ["include"], + deps = [ + ":ComplexDialect", + ":ConversionPassIncGen", + ":DialectUtils", + ":FuncDialect", + ":IR", + ":Pass", + ":Support", + ":Transforms", + "//llvm:Core", + "//llvm:Support", + ], +) + cc_library( name = "ComplexToStandard", srcs = glob([