diff --git a/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h b/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h
@@ -0,0 +1,29 @@
+//===- ComplexToLLVM.h - Utils to convert from the complex dialect --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_
+#define MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+class MLIRContext;
+class ModuleOp;
+template <typename T>
+class OperationPass;
+
+/// Populate the given list with patterns that convert from Complex to LLVM.
+void populateComplexToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
+/// Create a pass to convert Complex operations to the LLVMIR dialect.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertComplexToLLVMPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
+#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
 #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -88,6 +88,16 @@
   let dependentDialects = ["LLVM::LLVMDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// ComplexToLLVM
+//===----------------------------------------------------------------------===//
+
+def ConvertComplexToLLVM : Pass<"convert-complex-to-llvm", "ModuleOp"> {
+  let summary = "Convert Complex dialect to LLVM dialect";
+  let constructor = "mlir::createConvertComplexToLLVMPass()";
+  let dependentDialects = ["LLVM::LLVMDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // GPUCommon
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -3,6 +3,7 @@
 add_subdirectory(ArmNeon)
 add_subdirectory(ArmSVE)
 add_subdirectory(AVX512)
+add_subdirectory(Complex)
 add_subdirectory(GPU)
 add_subdirectory(Linalg)
 add_subdirectory(LLVMIR)
diff --git a/mlir/include/mlir/Dialect/Complex/CMakeLists.txt b/mlir/include/mlir/Dialect/Complex/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Complex/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_mlir_dialect(ComplexOps complex)
+add_mlir_doc(ComplexOps -gen-dialect-doc ComplexOps Dialects/)
diff --git a/mlir/include/mlir/Dialect/Complex/IR/Complex.h b/mlir/include/mlir/Dialect/Complex/IR/Complex.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Complex/IR/Complex.h
@@ -0,0 +1,32 @@
+//===- Complex.h - Complex dialect --------------------------------*- C++-*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_
+#define MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+
+//===----------------------------------------------------------------------===//
+// Complex Dialect
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Complex/IR/ComplexOpsDialect.h.inc"
+
+//===----------------------------------------------------------------------===//
+// Complex Dialect Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Complex/IR/ComplexOps.h.inc"
+
+#endif // MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
@@ -0,0 +1,23 @@
+//===- ComplexBase.td - Base definitions for complex dialect -*- tablegen -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef COMPLEX_BASE
+#define COMPLEX_BASE
+
+include "mlir/IR/OpBase.td"
+
+def Complex_Dialect : Dialect {
+  let name = "complex";
+  let cppNamespace = "::mlir::complex";
+  let description = [{
+    The complex dialect is intended to hold complex number creation and
+    arithmetic ops.
+  }];
+}
+
+#endif // COMPLEX_BASE
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -0,0 +1,151 @@
+//===- ComplexOps.td - Complex op definitions ----------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef COMPLEX_OPS
+#define COMPLEX_OPS
+
+include "mlir/Dialect/Complex/IR/ComplexBase.td"
+include "mlir/Interfaces/VectorInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+class Complex_Op<string mnemonic, list<OpTrait> traits = []>
+    : Op<Complex_Dialect, mnemonic, traits>;
+
+// Base class for standard arithmetic operations on complex numbers with a
+// floating-point element type. These operations take two operands and return
+// one result, all of which must be complex numbers of the same type.
+class ComplexArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
+    Complex_Op<mnemonic, !listconcat(traits, [NoSideEffect,
+    SameOperandsAndResultType,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
+    ElementwiseMappable])> {
+  let arguments = (ins Complex<AnyFloat>:$lhs, Complex<AnyFloat>:$rhs);
+  let results = (outs Complex<AnyFloat>:$result);
+  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
+  let verifier = ?;
+}
+
+//===----------------------------------------------------------------------===//
+// AddOp
+//===----------------------------------------------------------------------===//
+
+def AddOp : ComplexArithmeticOp<"add"> {
+  let summary = "complex addition";
+  let description = [{
+    The `add` operation takes two complex numbers and returns their sum.
+
+    Example:
+
+    ```mlir
+    %a = complex.add %b, %c : complex<f32>
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// CreateOp
+//===----------------------------------------------------------------------===//
+
+def CreateOp : Complex_Op<"create",
+    [NoSideEffect,
+     AllTypesMatch<["real", "imaginary"]>,
+     TypesMatchWith<"complex element type matches real operand type",
+                    "complex", "real",
+                    "$_self.cast<ComplexType>().getElementType()">,
+     TypesMatchWith<"complex element type matches imaginary operand type",
+                    "complex", "imaginary",
+                    "$_self.cast<ComplexType>().getElementType()">]> {
+
+  let summary = "complex number creation operation";
+  let description = [{
+    The `complex.create` operation creates a complex number from two
+    floating-point operands, the real and the imaginary part.
+
+    Example:
+
+    ```mlir
+    %a = complex.create %b, %c : complex<f32>
+    ```
+  }];
+
+  let arguments = (ins AnyFloat:$real, AnyFloat:$imaginary);
+  let results = (outs Complex<AnyFloat>:$complex);
+
+  let assemblyFormat = "$real `,` $imaginary attr-dict `:` type($complex)";
+}
+
+//===----------------------------------------------------------------------===//
+// ImOp
+//===----------------------------------------------------------------------===//
+
+def ImOp : Complex_Op<"im",
+    [NoSideEffect,
+     TypesMatchWith<"complex element type matches result type",
+                    "complex", "imaginary",
+                    "$_self.cast<ComplexType>().getElementType()">]> {
+  let summary = "extracts the imaginary part of a complex number";
+  let description = [{
+    The `im` op takes a single complex number and extracts the imaginary part.
+
+    Example:
+
+    ```mlir
+    %a = complex.im %b : complex<f32>
+    ```
+  }];
+
+  let arguments = (ins Complex<AnyFloat>:$complex);
+  let results = (outs AnyFloat:$imaginary);
+
+  let assemblyFormat = "$complex attr-dict `:` type($complex)";
+}
+
+//===----------------------------------------------------------------------===//
+// ReOp
+//===----------------------------------------------------------------------===//
+
+def ReOp : Complex_Op<"re",
+    [NoSideEffect,
+     TypesMatchWith<"complex element type matches result type",
+                    "complex", "real",
+                    "$_self.cast<ComplexType>().getElementType()">]> {
+  let summary = "extracts the real part of a complex number";
+  let description = [{
+    The `re` op takes a single complex number and extracts the real part.
+
+    Example:
+
+    ```mlir
+    %a = complex.re %b : complex<f32>
+    ```
+  }];
+
+  let arguments = (ins Complex<AnyFloat>:$complex);
+  let results = (outs AnyFloat:$real);
+
+  let assemblyFormat = "$complex attr-dict `:` type($complex)";
+}
+
+//===----------------------------------------------------------------------===//
+// SubOp
+//===----------------------------------------------------------------------===//
+
+def SubOp : ComplexArithmeticOp<"sub"> {
+  let summary = "complex subtraction";
+  let description = [{
+    The `sub` operation takes two complex numbers and returns their difference.
+
+    Example:
+
+    ```mlir
+    %a = complex.sub %b, %c : complex<f32>
+    ```
+  }];
+}
+
+#endif // COMPLEX_OPS
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
 #include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
@@ -52,6 +53,7 @@
                   arm_neon::ArmNeonDialect,
                   async::AsyncDialect,
                   avx512::AVX512Dialect,
+                  complex::ComplexDialect,
                   gpu::GPUDialect,
                   LLVM::LLVMAVX512Dialect,
                   LLVM::LLVMDialect,
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -2,6 +2,7 @@
 add_subdirectory(ArmNeonToLLVM)
 add_subdirectory(AsyncToLLVM)
 add_subdirectory(AVX512ToLLVM)
+add_subdirectory(ComplexToLLVM)
 add_subdirectory(GPUCommon)
 add_subdirectory(GPUToNVVM)
 add_subdirectory(GPUToROCDL)
diff --git a/mlir/lib/Conversion/ComplexToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ComplexToLLVM/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToLLVM/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRComplexToLLVM
+  ComplexToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRComplex
+  MLIRLLVMIR
+  MLIRStandardOpsTransforms
+  MLIRStandardToLLVM
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -0,0 +1,214 @@
+//===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+
+/// Lowers `complex.create` by inserting the real and imaginary operands into
+/// an undef LLVM struct of the converted complex type.
+struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
+  using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::CreateOp complexOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::CreateOp::Adaptor transformed(operands);
+
+    // Pack real and imaginary part in a complex number struct.
+    auto loc = complexOp.getLoc();
+    auto structType = typeConverter->convertType(complexOp.getType());
+    auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
+    complexStruct.setReal(rewriter, loc, transformed.real());
+    complexStruct.setImaginary(rewriter, loc, transformed.imaginary());
+
+    rewriter.replaceOp(complexOp, {complexStruct});
+    return success();
+  }
+};
+
+/// Lowers `complex.re` to an extraction of the real part of the struct.
+struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
+  using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::ReOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::ReOp::Adaptor transformed(operands);
+
+    // Extract real part from the complex number struct.
+    ComplexStructBuilder complexStruct(transformed.complex());
+    Value real = complexStruct.real(rewriter, op.getLoc());
+    rewriter.replaceOp(op, real);
+
+    return success();
+  }
+};
+
+/// Lowers `complex.im` to an extraction of the imaginary part of the struct.
+struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
+  using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::ImOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::ImOp::Adaptor transformed(operands);
+
+    // Extract imaginary part from the complex number struct.
+    ComplexStructBuilder complexStruct(transformed.complex());
+    Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
+    rewriter.replaceOp(op, imaginary);
+
+    return success();
+  }
+};
+
+/// Real and imaginary parts of one complex value, already extracted from its
+/// LLVM struct representation.
+///
+/// This is a plain struct rather than `std::complex<Value>`: the C++ standard
+/// leaves instantiating `std::complex` with anything other than float, double
+/// or long double unspecified, and this file never included <complex>.
+struct ComplexParts {
+  Value real;
+  Value imag;
+};
+
+/// Unpacked operands of a binary complex operation.
+struct BinaryComplexOperands {
+  ComplexParts lhs;
+  ComplexParts rhs;
+};
+
+/// Extracts the real and imaginary parts of the already-converted `lhs` and
+/// `rhs` operands of `op`. The lhs extractions are emitted before the rhs ones.
+template <typename OpTy>
+BinaryComplexOperands
+unpackBinaryComplexOperands(OpTy op, ArrayRef<Value> operands,
+                            ConversionPatternRewriter &rewriter) {
+  auto loc = op.getLoc();
+  typename OpTy::Adaptor transformed(operands);
+
+  // Extract real and imaginary values from operands.
+  BinaryComplexOperands unpacked;
+  ComplexStructBuilder lhs(transformed.lhs());
+  unpacked.lhs.real = lhs.real(rewriter, loc);
+  unpacked.lhs.imag = lhs.imaginary(rewriter, loc);
+  ComplexStructBuilder rhs(transformed.rhs());
+  unpacked.rhs.real = rhs.real(rewriter, loc);
+  unpacked.rhs.imag = rhs.imaginary(rewriter, loc);
+
+  return unpacked;
+}
+
+/// Lowers `complex.add` to component-wise `llvm.fadd`.
+struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
+  using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::AddOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    BinaryComplexOperands arg =
+        unpackBinaryComplexOperands<complex::AddOp>(op, operands, rewriter);
+
+    // Initialize complex number struct for result.
+    auto structType = typeConverter->convertType(op.getType());
+    auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
+
+    // Emit IR to add complex numbers.
+    auto fmf = LLVM::FMFAttr::get({}, op.getContext());
+    Value real =
+        rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real, arg.rhs.real, fmf);
+    Value imag =
+        rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag, arg.rhs.imag, fmf);
+    result.setReal(rewriter, loc, real);
+    result.setImaginary(rewriter, loc, imag);
+
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+};
+
+/// Lowers `complex.sub` to component-wise `llvm.fsub`.
+struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
+  using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::SubOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    BinaryComplexOperands arg =
+        unpackBinaryComplexOperands<complex::SubOp>(op, operands, rewriter);
+
+    // Initialize complex number struct for result.
+    auto structType = typeConverter->convertType(op.getType());
+    auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
+
+    // Emit IR to subtract complex numbers.
+    auto fmf = LLVM::FMFAttr::get({}, op.getContext());
+    Value real =
+        rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real, arg.rhs.real, fmf);
+    Value imag =
+        rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag, arg.rhs.imag, fmf);
+    result.setReal(rewriter, loc, real);
+    result.setImaginary(rewriter, loc, imag);
+
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateComplexToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  // clang-format off
+  patterns.insert<
+      AddOpConversion,
+      CreateOpConversion,
+      ImOpConversion,
+      ReOpConversion,
+      SubOpConversion
+    >(converter);
+  // clang-format on
+}
+
+namespace {
+struct ConvertComplexToLLVMPass
+    : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertComplexToLLVMPass::runOnOperation() {
+  auto module = getOperation();
+
+  // Std patterns are registered too so that surrounding Standard ops (e.g.
+  // func, constant, return) are lowered; a full conversion needs all legal.
+  OwningRewritePatternList patterns;
+  LLVMTypeConverter converter(&getContext());
+  populateStdToLLVMConversionPatterns(converter, patterns);
+  populateComplexToLLVMConversionPatterns(converter, patterns);
+
+  LLVMConversionTarget target(getContext());
+  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
+  if (failed(applyFullConversion(module, target, std::move(patterns))))
+    signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createConvertComplexToLLVMPass() {
+  return std::make_unique<ConvertComplexToLLVMPass>();
+}
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -19,6 +19,10 @@
 template <typename ConcreteDialect>
 void registerDialect(DialectRegistry &registry);
 
+namespace complex {
+class ComplexDialect;
+} // end namespace complex
+
 namespace gpu {
 class GPUDialect;
 class GPUModuleOp;
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -3,6 +3,7 @@
 add_subdirectory(ArmSVE)
 add_subdirectory(Async)
 add_subdirectory(AVX512)
+add_subdirectory(Complex)
 add_subdirectory(GPU)
 add_subdirectory(Linalg)
 add_subdirectory(LLVMIR)
diff --git a/mlir/lib/Dialect/Complex/CMakeLists.txt b/mlir/lib/Dialect/Complex/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Complex/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/Complex/IR/CMakeLists.txt b/mlir/lib/Dialect/Complex/IR/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Complex/IR/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRComplex
+  ComplexOps.cpp
+  ComplexDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Complex
+
+  DEPENDS
+  MLIRComplexOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRDialect
+  MLIRIR
+  )
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
@@ -0,0 +1,16 @@
+//===- ComplexDialect.cpp - MLIR Complex Dialect --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Complex/IR/Complex.h"
+
+void mlir::complex::ComplexDialect::initialize() {
+  addOperations< // Ops listed by the GET_OP_LIST section of ComplexOps.cpp.inc.
+#define GET_OP_LIST
+#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
+      >();
+}
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -0,0 +1,19 @@
+//===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Complex/IR/Complex.h"
+
+using namespace mlir;
+using namespace mlir::complex;
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt %s -split-input-file -convert-complex-to-llvm | FileCheck %s
+
+// CHECK-LABEL: llvm.func @complex_numbers()
+// CHECK-NEXT:    %[[REAL0:.*]] = llvm.mlir.constant(1.200000e+00 : f32) : f32
+// CHECK-NEXT:    %[[IMAG0:.*]] = llvm.mlir.constant(3.400000e+00 : f32) : f32
+// CHECK-NEXT:    %[[CPLX0:.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32)>
+// CHECK-NEXT:    %[[CPLX1:.*]] = llvm.insertvalue %[[REAL0]], %[[CPLX0]][0] : !llvm.struct<(f32, f32)>
+// CHECK-NEXT:    %[[CPLX2:.*]] = llvm.insertvalue %[[IMAG0]], %[[CPLX1]][1] : !llvm.struct<(f32, f32)>
+// CHECK-NEXT:    %[[REAL1:.*]] = llvm.extractvalue %[[CPLX2]][0] : !llvm.struct<(f32, f32)>
+// CHECK-NEXT:    %[[IMAG1:.*]] = llvm.extractvalue %[[CPLX2]][1] : !llvm.struct<(f32, f32)>
+// CHECK-NEXT:    llvm.return
+func @complex_numbers() {
+  %real0 = constant 1.2 : f32
+  %imag0 = constant 3.4 : f32
+  %cplx2 = complex.create %real0, %imag0 : complex<f32>
+  %real1 = complex.re %cplx2 : complex<f32>
+  %imag1 = complex.im %cplx2 : complex<f32>
+  return
+}
+
+// CHECK-LABEL: llvm.func @complex_addition()
+// CHECK-DAG:     %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm.struct<(f64, f64)>
+// CHECK:         %[[C0:.*]] = llvm.mlir.undef : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[C_REAL:.*]] = llvm.fadd %[[A_REAL]], %[[B_REAL]] : f64
+// CHECK-DAG:     %[[C_IMAG:.*]] = llvm.fadd %[[A_IMAG]], %[[B_IMAG]] : f64
+// CHECK:         %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm.struct<(f64, f64)>
+// CHECK:         %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm.struct<(f64, f64)>
+func @complex_addition() {
+  %a_re = constant 1.2 : f64
+  %a_im = constant 3.4 : f64
+  %a = complex.create %a_re, %a_im : complex<f64>
+  %b_re = constant 5.6 : f64
+  %b_im = constant 7.8 : f64
+  %b = complex.create %b_re, %b_im : complex<f64>
+  %c = complex.add %a, %b : complex<f64>
+  return
+}
+
+// CHECK-LABEL: llvm.func @complex_subtraction()
+// CHECK-DAG:     %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm.struct<(f64, f64)>
+// CHECK:         %[[C0:.*]] = llvm.mlir.undef : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[C_REAL:.*]] = llvm.fsub %[[A_REAL]], %[[B_REAL]] : f64
+// CHECK-DAG:     %[[C_IMAG:.*]] = llvm.fsub %[[A_IMAG]], %[[B_IMAG]] : f64
+// CHECK:         %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm.struct<(f64, f64)>
+// CHECK:         %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm.struct<(f64, f64)>
+func @complex_subtraction() {
+  %a_re = constant 1.2 : f64
+  %a_im = constant 3.4 : f64
+  %a = complex.create %a_re, %a_im : complex<f64>
+  %b_re = constant 5.6 : f64
+  %b_im = constant 7.8 : f64
+  %b = complex.create %b_re, %b_im : complex<f64>
+  %c = complex.sub %a, %b : complex<f64>
+  return
+}
diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Complex/ops.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+
+
+// CHECK-LABEL: func @ops(
+// CHECK-SAME:    [[F:%.*]]: f32) {
+func @ops(%f: f32) {
+  // CHECK: [[C:%.*]] = complex.create [[F]], [[F]] : complex<f32>
+  %complex = complex.create %f, %f : complex<f32>
+
+  // CHECK: complex.re [[C]] : complex<f32>
+  %real = complex.re %complex : complex<f32>
+
+  // CHECK: complex.im [[C]] : complex<f32>
+  %imag = complex.im %complex : complex<f32>
+
+  // CHECK: complex.add [[C]], [[C]] : complex<f32>
+  %sum = complex.add %complex, %complex : complex<f32>
+
+  // CHECK: complex.sub [[C]], [[C]] : complex<f32>
+  %diff = complex.sub %complex, %complex : complex<f32>
+  return
+}
+
diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir
--- a/mlir/test/mlir-opt/commandline.mlir
+++ b/mlir/test/mlir-opt/commandline.mlir
@@ -6,6 +6,7 @@
 // CHECK-NEXT: arm_sve
 // CHECK-NEXT: async
 // CHECK-NEXT: avx512
+// CHECK-NEXT: complex
 // CHECK-NEXT: gpu
 // CHECK-NEXT: linalg
 // CHECK-NEXT: llvm