diff --git a/mlir/include/mlir/Conversion/ComplexToStandard/ComplexToStandard.h b/mlir/include/mlir/Conversion/ComplexToStandard/ComplexToStandard.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ComplexToStandard/ComplexToStandard.h
@@ -0,0 +1,31 @@
+//===- ComplexToStandard.h - Utils to convert from the complex dialect ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_COMPLEXTOSTANDARD_COMPLEXTOSTANDARD_H_
+#define MLIR_CONVERSION_COMPLEXTOSTANDARD_COMPLEXTOSTANDARD_H_
+
+#include <memory>
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+class FuncOp;
+class RewritePatternSet;
+template <typename T>
+class OperationPass;
+
+/// Populate the given list with patterns that convert from Complex to Standard
+/// (and Math) dialect operations.
+void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns);
+
+/// Create a function pass that converts Complex operations to the Standard
+/// dialect.
+std::unique_ptr<OperationPass<FuncOp>> createConvertComplexToStandardPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_COMPLEXTOSTANDARD_COMPLEXTOSTANDARD_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -12,6 +12,7 @@
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
 #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -99,6 +99,20 @@
   let dependentDialects = ["LLVM::LLVMDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// ComplexToStandard
+//===----------------------------------------------------------------------===//
+
+def ConvertComplexToStandard : FunctionPass<"convert-complex-to-standard"> {
+  let summary = "Convert Complex dialect to standard dialect";
+  let constructor = "mlir::createConvertComplexToStandardPass()";
+  let dependentDialects = [
+    "complex::ComplexDialect",
+    "math::MathDialect",
+    "StandardOpsDialect"
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // GPUCommon
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_subdirectory(AffineToStandard)
 add_subdirectory(AsyncToLLVM)
 add_subdirectory(ComplexToLLVM)
+add_subdirectory(ComplexToStandard)
add_subdirectory(GPUCommon) add_subdirectory(GPUToNVVM) add_subdirectory(GPUToROCDL) diff --git a/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt b/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_conversion_library(MLIRComplexToStandard + ComplexToStandard.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToStandard + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRComplex + MLIRIR + MLIRMath + MLIRStandard + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -0,0 +1,77 @@ +//===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" + +#include + +#include "../PassDetail.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +namespace { +struct AbsOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(complex::AbsOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + complex::AbsOp::Adaptor transformed(operands); + auto loc = op.getLoc(); + auto type = op.getType(); + + Value real = + rewriter.create(loc, type, transformed.complex()); + Value imag = + rewriter.create(loc, type, transformed.complex()); + Value realSqr = rewriter.create(loc, real, real); + Value imagSqr = rewriter.create(loc, imag, imag); + Value sqNorm = rewriter.create(loc, realSqr, imagSqr); + + rewriter.replaceOpWithNewOp(op, sqNorm); + return success(); + } +}; +} // namespace + +void mlir::populateComplexToStandardConversionPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +namespace { +struct ConvertComplexToStandardPass + : public ConvertComplexToStandardBase { + void runOnFunction() override; +}; + +void ConvertComplexToStandardPass::runOnFunction() { + auto function = getFunction(); + + // Convert to the Standard dialect using the converter defined above. 
+ RewritePatternSet patterns(&getContext()); + populateComplexToStandardConversionPatterns(patterns); + + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalOp(); + if (failed(applyPartialConversion(function, target, std::move(patterns)))) + signalPassFailure(); +} +} // namespace + +std::unique_ptr> +mlir::createConvertComplexToStandardPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h --- a/mlir/lib/Conversion/PassDetail.h +++ b/mlir/lib/Conversion/PassDetail.h @@ -36,6 +36,10 @@ class NVVMDialect; } // end namespace NVVM +namespace math { +class MathDialect; +} // end namespace math + namespace memref { class MemRefDialect; } // end namespace memref diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt %s -convert-complex-to-standard | FileCheck %s + +// CHECK-LABEL: func @complex_abs +// CHECK-SAME: %[[ARG:.*]]: complex +func @complex_abs(%arg: complex) -> f32 { + %abs = complex.abs %arg: complex + return %abs : f32 +} +// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex +// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex +// CHECK-DAG: %[[REAL_SQ:.*]] = mulf %[[REAL]], %[[REAL]] : f32 +// CHECK-DAG: %[[IMAG_SQ:.*]] = mulf %[[IMAG]], %[[IMAG]] : f32 +// CHECK: %[[SQ_NORM:.*]] = addf %[[REAL_SQ]], %[[IMAG_SQ]] : f32 +// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 +// CHECK: return %[[NORM]] : f32 + diff --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt %s -convert-complex-to-standard -convert-complex-to-llvm 
-convert-std-to-llvm | FileCheck %s + +// CHECK-LABEL: llvm.func @complex_abs +// CHECK-SAME: %[[ARG:.*]]: ![[C_TY:.*]]) +func @complex_abs(%arg: complex) -> f32 { + %abs = complex.abs %arg: complex + return %abs : f32 +} +// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]] +// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]] +// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]] : f32 +// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]] : f32 +// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]] : f32 +// CHECK: %[[NORM:.*]] = "llvm.intr.sqrt"(%[[SQ_NORM]]) : (f32) -> f32 +// CHECK: llvm.return %[[NORM]] : f32 +