diff --git a/flang/include/flang/Optimizer/Builder/Complex.h b/flang/include/flang/Optimizer/Builder/Complex.h
new file mode 100644
--- /dev/null
+++ b/flang/include/flang/Optimizer/Builder/Complex.h
@@ -0,0 +1,104 @@
+//===-- Complex.h -- lowering of complex values -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_OPTIMIZER_BUILDER_COMPLEX_H
+#define FORTRAN_OPTIMIZER_BUILDER_COMPLEX_H
+
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include <utility>
+
+namespace fir::factory {
+
+/// Helper to facilitate lowering of COMPLEX manipulations in FIR.
+///
+/// Every operation is created through the wrapped FirOpBuilder (at its
+/// current insertion point) with the location given at construction.
+class Complex {
+public:
+  explicit Complex(FirOpBuilder &builder, mlir::Location loc)
+      : builder(builder), loc(loc) {}
+  Complex(const Complex &) = delete;
+
+  // The values of part enum members are meaningful for
+  // InsertValueOp and ExtractValueOp so they are explicit.
+  enum class Part { Real = 0, Imag = 1 };
+
+  /// Return the type of the real and imaginary parts of a COMPLEX value or
+  /// type. These only compute a type and do not create MLIR operations.
+  /// The type must be a fir::ComplexType: the implementation casts to it,
+  /// and the part type is the REAL type of the same kind.
+  mlir::Type getComplexPartType(mlir::Value cplx);
+  mlir::Type getComplexPartType(mlir::Type complexType);
+
+  /// Create a COMPLEX value of kind \p kind from its \p real and \p imag
+  /// parts. This creates MLIR operations.
+  mlir::Value createComplex(fir::KindTy kind, mlir::Value real,
+                            mlir::Value imag);
+
+  /// Create a COMPLEX value of type \p complexType from its \p real and
+  /// \p imag parts. This creates MLIR operations.
+  mlir::Value createComplex(mlir::Type complexType, mlir::Value real,
+                            mlir::Value imag);
+
+  /// Extract the imaginary part of \p cplx if \p isImagPart is true,
+  /// otherwise its real part.
+  mlir::Value extractComplexPart(mlir::Value cplx, bool isImagPart) {
+    return isImagPart ? extract<Part::Imag>(cplx) : extract<Part::Real>(cplx);
+  }
+
+  /// Returns (Real, Imag) pair of \p cplx
+  std::pair<mlir::Value, mlir::Value> extractParts(mlir::Value cplx) {
+    return {extract<Part::Real>(cplx), extract<Part::Imag>(cplx)};
+  }
+
+  /// Return a new COMPLEX value equal to \p cplx except that its imaginary
+  /// part (if \p isImagPart) or real part (otherwise) is \p part.
+  mlir::Value insertComplexPart(mlir::Value cplx, mlir::Value part,
+                                bool isImagPart) {
+    return isImagPart ? insert<Part::Imag>(cplx, part)
+                      : insert<Part::Real>(cplx, part);
+  }
+
+protected:
+  /// Create a fir::ExtractValueOp reading part \p partId of \p cplx.
+  template <Part partId>
+  mlir::Value extract(mlir::Value cplx) {
+    return builder.create<fir::ExtractValueOp>(
+        loc, getComplexPartType(cplx), cplx,
+        builder.getArrayAttr({builder.getIntegerAttr(
+            builder.getIndexType(), static_cast<int>(partId))}));
+  }
+
+  /// Create a fir::InsertValueOp setting part \p partId of \p cplx to \p part.
+  template <Part partId>
+  mlir::Value insert(mlir::Value cplx, mlir::Value part) {
+    return builder.create<fir::InsertValueOp>(
+        loc, cplx.getType(), cplx, part,
+        builder.getArrayAttr({builder.getIntegerAttr(
+            builder.getIndexType(), static_cast<int>(partId))}));
+  }
+
+  /// Materialize \p partId as an index-typed integer constant.
+  template <Part partId>
+  mlir::Value createPartId() {
+    return builder.createIntegerConstant(loc, builder.getIndexType(),
+                                         static_cast<int>(partId));
+  }
+
+private:
+  FirOpBuilder &builder;
+  mlir::Location loc;
+};
+
+} // namespace fir::factory
+
+#endif // FORTRAN_OPTIMIZER_BUILDER_COMPLEX_H
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -57,6 +57,13 @@
   /// Get a reference to the kind map.
   const fir::KindMapping &getKindMap() { return kindMap; }
 
+  /// The LHS and RHS are not always in agreement in terms of
+  /// type. In some cases, the disagreement is between COMPLEX and other scalar
+  /// types. In that case, the conversion must insert/extract out of a COMPLEX
+  /// value to have the proper semantics and be strongly typed.
+  mlir::Value convertWithSemantics(mlir::Location loc, mlir::Type toTy,
+                                   mlir::Value val);
+
   /// Get the entry block of the current Function
   mlir::Block *getEntryBlock() { return &getFunction().front(); }
 
diff --git a/flang/lib/Optimizer/Builder/CMakeLists.txt b/flang/lib/Optimizer/Builder/CMakeLists.txt
--- a/flang/lib/Optimizer/Builder/CMakeLists.txt
+++ b/flang/lib/Optimizer/Builder/CMakeLists.txt
@@ -3,6 +3,7 @@
 add_flang_library(FIRBuilder
   BoxValue.cpp
   Character.cpp
+  Complex.cpp
   DoLoopHelper.cpp
   FIRBuilder.cpp
   MutableBox.cpp
diff --git a/flang/lib/Optimizer/Builder/Complex.cpp b/flang/lib/Optimizer/Builder/Complex.cpp
new file mode 100644
--- /dev/null
+++ b/flang/lib/Optimizer/Builder/Complex.cpp
@@ -0,0 +1,39 @@
+//===-- Complex.cpp -------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/Complex.h"
+
+//===----------------------------------------------------------------------===//
+// Complex Factory implementation
+//===----------------------------------------------------------------------===//
+
+mlir::Type fir::factory::Complex::getComplexPartType(mlir::Type complexType) {
+  // COMPLEX(kind) parts are REAL(kind); only fir::ComplexType is accepted.
+  return builder.getRealType(complexType.cast<fir::ComplexType>().getFKind());
+}
+
+mlir::Type fir::factory::Complex::getComplexPartType(mlir::Value cplx) {
+  return getComplexPartType(cplx.getType());
+}
+
+mlir::Value fir::factory::Complex::createComplex(fir::KindTy kind,
+                                                 mlir::Value real,
+                                                 mlir::Value imag) {
+  // Delegate to the type-based overload so both entry points build the
+  // value with the same operations.
+  mlir::Type complexTy = fir::ComplexType::get(builder.getContext(), kind);
+  return createComplex(complexTy, real, imag);
+}
+
+mlir::Value fir::factory::Complex::createComplex(mlir::Type cplxTy,
+                                                 mlir::Value real,
+                                                 mlir::Value imag) {
+  // Start from an undefined value and insert the real then imaginary part.
+  mlir::Value und = builder.create<fir::UndefOp>(loc, cplxTy);
+  return insert<Part::Imag>(insert<Part::Real>(und, real), imag);
+}
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -9,6 +9,7 @@
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
 #include "flang/Optimizer/Builder/Character.h"
+#include "flang/Optimizer/Builder/Complex.h"
 #include "flang/Optimizer/Builder/MutableBox.h"
 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
 #include "flang/Optimizer/Support/FatalError.h"
@@ -257,6 +258,41 @@
   return glob;
 }
 
+/// Convert \p val to \p toTy. Conversions between COMPLEX and REAL/INTEGER
+/// are expanded into part insertion/extraction so that the resulting IR stays
+/// strongly typed; any other conversion is a plain createConvert.
+mlir::Value fir::FirOpBuilder::convertWithSemantics(mlir::Location loc,
+                                                    mlir::Type toTy,
+                                                    mlir::Value val) {
+  assert(toTy && "store location must be typed");
+  mlir::Type fromTy = val.getType();
+  if (fromTy == toTy)
+    return val;
+  fir::factory::Complex helper{*this, loc};
+  if ((fir::isa_real(fromTy) || fir::isa_integer(fromTy)) &&
+      fir::isa_complex(toTy)) {
+    // REAL/INTEGER -> COMPLEX: the converted value becomes the real part and
+    // the imaginary part is zero.
+    // NOTE(review): getComplexPartType and the cast below require a
+    // fir::ComplexType; confirm fir::isa_complex admits nothing else here.
+    mlir::Type eleTy = helper.getComplexPartType(toTy);
+    mlir::Value realPart = createConvert(loc, eleTy, val);
+    llvm::APFloat zero{
+        kindMap.getFloatSemantics(toTy.cast<fir::ComplexType>().getFKind()), 0};
+    mlir::Value imagPart = createRealConstant(loc, eleTy, zero);
+    return helper.createComplex(toTy, realPart, imagPart);
+  }
+  if (fir::isa_complex(fromTy) &&
+      (fir::isa_integer(toTy) || fir::isa_real(toTy))) {
+    // COMPLEX -> REAL/INTEGER: drop the imaginary part, then convert the
+    // real part to the destination type.
+    mlir::Value realPart =
+        helper.extractComplexPart(val, /*isImagPart=*/false);
+    return createConvert(loc, toTy, realPart);
+  }
+  return createConvert(loc, toTy, val);
+}
+
 mlir::Value fir::FirOpBuilder::createConvert(mlir::Location loc,
                                              mlir::Type toTy, mlir::Value val) {
   if (val.getType() != toTy) {
diff --git a/flang/unittests/Optimizer/Builder/ComplexTest.cpp b/flang/unittests/Optimizer/Builder/ComplexTest.cpp
new file mode 100644
--- /dev/null
+++ b/flang/unittests/Optimizer/Builder/ComplexTest.cpp
@@ -0,0 +1,118 @@
+//===- ComplexTest.cpp -- Complex unit tests ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/Complex.h"
+#include "gtest/gtest.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Support/InitFIR.h"
+#include "flang/Optimizer/Support/KindMapping.h"
+
+struct ComplexTest : public testing::Test {
+public:
+  void SetUp() override {
+    // Load the FIR dialects before any operation is built.
+    fir::support::loadDialects(context);
+
+    mlir::OpBuilder builder(&context);
+    auto loc = builder.getUnknownLoc();
+
+    // Set up a Module with a dummy function operation inside.
+    mod = builder.create<mlir::ModuleOp>(loc);
+    mlir::FuncOp func = mlir::FuncOp::create(
+        loc, "func1", builder.getFunctionType(llvm::None, llvm::None));
+    auto *entryBlock = func.addEntryBlock();
+    // Add the function (not the module itself) to the module body.
+    mod.push_back(func);
+
+    kindMap = std::make_unique<fir::KindMapping>(&context);
+    firBuilder = std::make_unique<fir::FirOpBuilder>(mod, *kindMap);
+    // The FIR builder is constructed from the module, so explicitly point it
+    // at the function entry block where all test operations are created.
+    firBuilder->setInsertionPointToStart(entryBlock);
+    helper = std::make_unique<fir::factory::Complex>(*firBuilder, loc);
+
+    // Init commonly used types
+    realTy1 = mlir::FloatType::getF32(&context);
+    complexTy1 = fir::ComplexType::get(&context, 4);
+    integerTy1 = mlir::IntegerType::get(&context, 32);
+
+    // Create commonly used reals
+    rOne = firBuilder->createRealConstant(loc, realTy1, 1u);
+    rTwo = firBuilder->createRealConstant(loc, realTy1, 2u);
+    rThree = firBuilder->createRealConstant(loc, realTy1, 3u);
+    rFour = firBuilder->createRealConstant(loc, realTy1, 4u);
+  }
+
+  void TearDown() override {
+    // Erasing the module also destroys the function and the operations the
+    // tests created in it.
+    mod.erase();
+  }
+
+  mlir::MLIRContext context;
+  mlir::ModuleOp mod;
+  std::unique_ptr<fir::KindMapping> kindMap;
+  std::unique_ptr<fir::FirOpBuilder> firBuilder;
+  std::unique_ptr<fir::factory::Complex> helper;
+
+  // Commonly used real/complex/integer types
+  mlir::FloatType realTy1;
+  fir::ComplexType complexTy1;
+  mlir::IntegerType integerTy1;
+
+  // Commonly used real numbers
+  mlir::Value rOne;
+  mlir::Value rTwo;
+  mlir::Value rThree;
+  mlir::Value rFour;
+};
+
+TEST_F(ComplexTest, verifyTypes) {
+  mlir::Value cVal1 = helper->createComplex(complexTy1, rOne, rTwo);
+  mlir::Value cVal2 = helper->createComplex(4, rOne, rTwo);
+  EXPECT_TRUE(fir::isa_complex(cVal1.getType()));
+  EXPECT_TRUE(fir::isa_complex(cVal2.getType()));
+  EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal1)));
+  EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal2)));
+
+  mlir::Value real1 = helper->extractComplexPart(cVal1, /*isImagPart=*/false);
+  mlir::Value imag1 = helper->extractComplexPart(cVal1, /*isImagPart=*/true);
+  mlir::Value real2 = helper->extractComplexPart(cVal2, /*isImagPart=*/false);
+  mlir::Value imag2 = helper->extractComplexPart(cVal2, /*isImagPart=*/true);
+  EXPECT_EQ(realTy1, real1.getType());
+  EXPECT_EQ(realTy1, imag1.getType());
+  EXPECT_EQ(realTy1, real2.getType());
+  EXPECT_EQ(realTy1, imag2.getType());
+
+  // extractParts returns the (real, imaginary) pair.
+  auto [real3, imag3] = helper->extractParts(cVal1);
+  EXPECT_EQ(realTy1, real3.getType());
+  EXPECT_EQ(realTy1, imag3.getType());
+
+  mlir::Value cVal3 =
+      helper->insertComplexPart(cVal1, rThree, /*isImagPart=*/false);
+  mlir::Value cVal4 =
+      helper->insertComplexPart(cVal3, rFour, /*isImagPart=*/true);
+  EXPECT_TRUE(fir::isa_complex(cVal4.getType()));
+  EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal4)));
+}
+
+TEST_F(ComplexTest, verifyConvertWithSemantics) {
+  auto loc = firBuilder->getUnknownLoc();
+  // Converting a value to its own type returns it unchanged.
+  EXPECT_TRUE(rOne == firBuilder->convertWithSemantics(loc, realTy1, rOne));
+
+  // Convert real to complex
+  mlir::Value v1 = firBuilder->convertWithSemantics(loc, complexTy1, rOne);
+  EXPECT_TRUE(fir::isa_complex(v1.getType()));
+
+  // Convert complex to integer
+  mlir::Value v2 = firBuilder->convertWithSemantics(loc, integerTy1, v1);
+  EXPECT_TRUE(v2.getType().isa<mlir::IntegerType>());
+  EXPECT_TRUE(mlir::dyn_cast<fir::ConvertOp>(v2.getDefiningOp()));
+}
diff --git a/flang/unittests/Optimizer/CMakeLists.txt b/flang/unittests/Optimizer/CMakeLists.txt
--- a/flang/unittests/Optimizer/CMakeLists.txt
+++ b/flang/unittests/Optimizer/CMakeLists.txt
@@ -10,6 +10,7 @@
 
 add_flang_unittest(FlangOptimizerTests
   Builder/CharacterTest.cpp
+  Builder/ComplexTest.cpp
   Builder/DoLoopHelperTest.cpp
   Builder/FIRBuilderTest.cpp
   FIRContextTest.cpp