Index: flang/include/flang/Optimizer/Support/KindMapping.h
===================================================================
--- flang/include/flang/Optimizer/Support/KindMapping.h
+++ flang/include/flang/Optimizer/Support/KindMapping.h
@@ -5,6 +5,10 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
 
 #ifndef OPTIMIZER_SUPPORT_KINDMAPPING_H
 #define OPTIMIZER_SUPPORT_KINDMAPPING_H
@@ -36,7 +40,8 @@
 /// 'c' : COMPLEX (encoding value)
 ///
 /// kind-value is either an unsigned integer (for 'i', 'l', and 'a') or one of
-/// 'Half', 'Float', 'Double', 'X86_FP80', or 'FP128' (for 'r' and 'c').
+/// 'Half', 'BFloat', 'Float', 'Double', 'X86_FP80', 'FP128', or 'PPC_FP128'
+/// (for 'r' and 'c').
 ///
 /// If LLVM adds support for new floating-point types, the final list should be
 /// extended.
@@ -47,8 +52,15 @@
   using LLVMTypeID = llvm::Type::TypeID;
   using MatchResult = mlir::ParseResult;
 
-  explicit KindMapping(mlir::MLIRContext *context);
-  explicit KindMapping(mlir::MLIRContext *context, llvm::StringRef map);
+  /// Construct a kind mapping. The optional `defs` overrides the default kinds
+  /// of the intrinsic types: it is either empty (generic defaults) or holds
+  /// exactly six kinds, in the order reported by defaultCharacterKind,
+  /// defaultComplexKind, defaultDoubleKind, defaultIntegerKind,
+  /// defaultLogicalKind, and defaultRealKind.
+  explicit KindMapping(mlir::MLIRContext *context,
+                       llvm::ArrayRef<KindTy> defs = llvm::None);
+  explicit KindMapping(mlir::MLIRContext *context, llvm::StringRef map,
+                       llvm::ArrayRef<KindTy> defs = llvm::None);
 
   /// Get the size in bits of !fir.char<kind>
   Bitsize getCharacterBitsize(KindTy kind) const;
@@ -73,13 +85,26 @@
   /// Get the float semantics of !fir.real<kind>
   const llvm::fltSemantics &getFloatSemantics(KindTy kind) const;
 
+  //===--------------------------------------------------------------------===//
+  // Default kinds of intrinsic types
+  //===--------------------------------------------------------------------===//
+
+  KindTy defaultCharacterKind() const;
+  KindTy defaultComplexKind() const;
+  KindTy defaultDoubleKind() const;
+  KindTy defaultIntegerKind() const;
+  KindTy defaultLogicalKind() const;
+  KindTy defaultRealKind() const;
+
 private:
   MatchResult badMapString(const llvm::Twine &ptr);
   MatchResult parse(llvm::StringRef kindMap);
+  mlir::LogicalResult setDefaultKinds(llvm::ArrayRef<KindTy> defs);
 
   mlir::MLIRContext *context;
   llvm::DenseMap<std::pair<char, KindTy>, Bitsize> intMap;
   llvm::DenseMap<std::pair<char, KindTy>, LLVMTypeID> floatMap;
+  llvm::DenseMap<char, KindTy> defaultMap;
 };
 
 } // namespace fir
Index: flang/lib/Optimizer/Support/KindMapping.cpp
===================================================================
--- flang/lib/Optimizer/Support/KindMapping.cpp
+++ flang/lib/Optimizer/Support/KindMapping.cpp
@@ -5,6 +5,10 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
 
 #include "flang/Optimizer/Support/KindMapping.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -32,12 +36,14 @@
 }
 
 /// Floating-point types default to the kind value being the size of the value
-/// in bytes. The default is to translate kinds of 2, 4, 8, 10, and 16 to a
+/// in bytes. The default is to translate kinds of 2, 3, 4, 8, 10, and 16 to a
 /// valid llvm::Type::TypeID value. Otherwise, the default is FloatTyID.
 static LLVMTypeID defaultRealKind(KindTy kind) {
   switch (kind) {
   case 2:
     return LLVMTypeID::HalfTyID;
+  case 3:
+    return LLVMTypeID::BFloatTyID;
   case 4:
     return LLVMTypeID::FloatTyID;
   case 8:
@@ -81,6 +87,8 @@
   switch (doLookup(defaultRealKind, map, kind)) {
   case LLVMTypeID::HalfTyID:
     return llvm::APFloat::IEEEhalf();
+  case LLVMTypeID::BFloatTyID:
+    return llvm::APFloat::BFloat();
   case LLVMTypeID::FloatTyID:
     return llvm::APFloat::IEEEsingle();
   case LLVMTypeID::DoubleTyID:
@@ -148,6 +156,10 @@
     result = LLVMTypeID::HalfTyID;
     return mlir::success();
   }
+  if (mlir::succeeded(matchString(ptr, "BFloat"))) {
+    result = LLVMTypeID::BFloatTyID;
+    return mlir::success();
+  }
   if (mlir::succeeded(matchString(ptr, "Float"))) {
     result = LLVMTypeID::FloatTyID;
     return mlir::success();
@@ -171,16 +183,25 @@
   return mlir::failure();
 }
 
-fir::KindMapping::KindMapping(mlir::MLIRContext *context, llvm::StringRef map)
+fir::KindMapping::KindMapping(mlir::MLIRContext *context, llvm::StringRef map,
+                              llvm::ArrayRef<KindTy> defs)
     : context{context} {
+  if (mlir::failed(setDefaultKinds(defs))) {
+    mlir::emitError(mlir::UnknownLoc::get(context), "bad default kinds");
+    // Fall back to the generic defaults so the default*Kind() accessors never
+    // look up a missing entry; the kind map itself is still honored.
+    (void)setDefaultKinds({});
+  }
   if (mlir::failed(parse(map))) {
+    mlir::emitError(mlir::UnknownLoc::get(context), "could not parse kind map");
     intMap.clear();
     floatMap.clear();
   }
 }
 
-fir::KindMapping::KindMapping(mlir::MLIRContext *context)
-    : KindMapping{context, clKindMapping} {}
+fir::KindMapping::KindMapping(mlir::MLIRContext *context,
+                              llvm::ArrayRef<KindTy> defs)
+    : KindMapping{context, clKindMapping, defs} {}
 
 MatchResult fir::KindMapping::badMapString(const llvm::Twine &ptr) {
   auto unknown = mlir::UnknownLoc::get(context);
@@ -248,3 +269,66 @@
 fir::KindMapping::getFloatSemantics(KindTy kind) const {
   return getFloatSemanticsOfKind<'r'>(kind, floatMap);
 }
+
+/// Populate the default kinds of the intrinsic types. An empty `defs` selects
+/// the generic front-end defaults; otherwise `defs` must hold exactly six
+/// kinds, ordered as 'a' (character), 'c' (complex), 'd' (double),
+/// 'i' (integer), 'l' (logical), and 'r' (real).
+mlir::LogicalResult
+fir::KindMapping::setDefaultKinds(llvm::ArrayRef<KindTy> defs) {
+  if (defs.empty()) {
+    // generic front-end defaults
+    const KindTy genericKind = 4;
+    defaultMap.insert({'a', 1});
+    defaultMap.insert({'c', genericKind});
+    defaultMap.insert({'d', 2 * genericKind});
+    defaultMap.insert({'i', genericKind});
+    defaultMap.insert({'l', genericKind});
+    defaultMap.insert({'r', genericKind});
+    return mlir::success();
+  }
+  if (defs.size() != 6)
+    return mlir::failure();
+
+  // defaults determined after command-line processing
+  defaultMap.insert({'a', defs[0]});
+  defaultMap.insert({'c', defs[1]});
+  defaultMap.insert({'d', defs[2]});
+  defaultMap.insert({'i', defs[3]});
+  defaultMap.insert({'l', defs[4]});
+  defaultMap.insert({'r', defs[5]});
+  return mlir::success();
+}
+
+/// Return the default kind registered under `code`. Every successful
+/// setDefaultKinds call registers all six codes, so a miss is an internal error.
+static KindTy getDefaultKind(const llvm::DenseMap<char, KindTy> &map,
+                             char code) {
+  auto iter = map.find(code);
+  assert(iter != map.end() && "default kind was not set");
+  return iter->second;
+}
+
+KindTy fir::KindMapping::defaultCharacterKind() const {
+  return getDefaultKind(defaultMap, 'a');
+}
+
+KindTy fir::KindMapping::defaultComplexKind() const {
+  return getDefaultKind(defaultMap, 'c');
+}
+
+KindTy fir::KindMapping::defaultDoubleKind() const {
+  return getDefaultKind(defaultMap, 'd');
+}
+
+KindTy fir::KindMapping::defaultIntegerKind() const {
+  return getDefaultKind(defaultMap, 'i');
+}
+
+KindTy fir::KindMapping::defaultLogicalKind() const {
+  return getDefaultKind(defaultMap, 'l');
+}
+
+KindTy fir::KindMapping::defaultRealKind() const {
+  return getDefaultKind(defaultMap, 'r');
+}
Index: flang/unittests/Optimizer/CMakeLists.txt
===================================================================
--- flang/unittests/Optimizer/CMakeLists.txt
+++ flang/unittests/Optimizer/CMakeLists.txt
@@ -7,6 +7,7 @@
 
 add_flang_unittest(FlangOptimizerTests
   InternalNamesTest.cpp
+  KindMappingTest.cpp
 )
 target_link_libraries(FlangOptimizerTests
   PRIVATE
Index: flang/unittests/Optimizer/KindMappingTest.cpp
===================================================================
--- /dev/null
+++ flang/unittests/Optimizer/KindMappingTest.cpp
@@ -0,0 +1,163 @@
+//===- KindMappingTest.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Support/KindMapping.h"
+#include "gtest/gtest.h"
+#include <string>
+
+using namespace fir;
+namespace llvm {
+struct fltSemantics;
+} // namespace llvm
+
+namespace mlir {
+class MLIRContext;
+} // namespace mlir
+
+using Bitsize = fir::KindMapping::Bitsize;
+using LLVMTypeID = fir::KindMapping::LLVMTypeID;
+
+struct DefaultStringTests : public testing::Test {
+public:
+  void SetUp() { defaultString = new KindMapping(context); }
+  void TearDown() { delete defaultString; }
+
+  KindMapping *defaultString{};
+  mlir::MLIRContext *context{};
+};
+
+struct commandLineStringTests : public testing::Test {
+public:
+  void SetUp() {
+    commandLineString = new KindMapping(context,
+        "i10:80,l3:24,a1:8,r54:Double,c20:X86_FP80,r11:PPC_FP128,"
+        "r12:FP128,r13:X86_FP80,r14:Double,r15:Float,r16:Half,r23:BFloat");
+    clStringConflict =
+        new KindMapping(context, "i10:80,i10:40,r54:Double,r54:X86_FP80");
+  }
+  void TearDown() {
+    delete commandLineString;
+    delete clStringConflict;
+  }
+
+  KindMapping *commandLineString{};
+  KindMapping *clStringConflict{};
+  mlir::MLIRContext *context{};
+};
+
+TEST_F(DefaultStringTests, getIntegerBitsizeTest) {
+  EXPECT_EQ(defaultString->getIntegerBitsize(10), 80u);
+  EXPECT_EQ(defaultString->getIntegerBitsize(0), 0u);
+}
+
+TEST_F(DefaultStringTests, getCharacterBitsizeTest) {
+  EXPECT_EQ(defaultString->getCharacterBitsize(10), 80u);
+  EXPECT_EQ(defaultString->getCharacterBitsize(0), 0u);
+}
+
+TEST_F(DefaultStringTests, getLogicalBitsizeTest) {
+  EXPECT_EQ(defaultString->getLogicalBitsize(10), 80u);
+  // Unsigned values are expected
+  std::string actual = std::to_string(defaultString->getLogicalBitsize(-10));
+  std::string expect = "-80";
+  EXPECT_NE(actual, expect);
+}
+
+TEST_F(DefaultStringTests, getRealTypeIDTest) {
+  EXPECT_EQ(defaultString->getRealTypeID(2), LLVMTypeID::HalfTyID);
+  EXPECT_EQ(defaultString->getRealTypeID(3), LLVMTypeID::BFloatTyID);
+  EXPECT_EQ(defaultString->getRealTypeID(4), LLVMTypeID::FloatTyID);
+  EXPECT_EQ(defaultString->getRealTypeID(8), LLVMTypeID::DoubleTyID);
+  EXPECT_EQ(defaultString->getRealTypeID(10), LLVMTypeID::X86_FP80TyID);
+  EXPECT_EQ(defaultString->getRealTypeID(16), LLVMTypeID::FP128TyID);
+  // Default cases
+  EXPECT_EQ(defaultString->getRealTypeID(-1), LLVMTypeID::FloatTyID);
+  EXPECT_EQ(defaultString->getRealTypeID(1), LLVMTypeID::FloatTyID);
+}
+
+TEST_F(DefaultStringTests, getComplexTypeIDTest) {
+  EXPECT_EQ(defaultString->getComplexTypeID(2), LLVMTypeID::HalfTyID);
+  EXPECT_EQ(defaultString->getComplexTypeID(3), LLVMTypeID::BFloatTyID);
+  EXPECT_EQ(defaultString->getComplexTypeID(4), LLVMTypeID::FloatTyID);
+  EXPECT_EQ(defaultString->getComplexTypeID(8), LLVMTypeID::DoubleTyID);
+  EXPECT_EQ(defaultString->getComplexTypeID(10), LLVMTypeID::X86_FP80TyID);
+  EXPECT_EQ(defaultString->getComplexTypeID(16), LLVMTypeID::FP128TyID);
+  // Default cases
+  EXPECT_EQ(defaultString->getComplexTypeID(-1), LLVMTypeID::FloatTyID);
+  EXPECT_EQ(defaultString->getComplexTypeID(1), LLVMTypeID::FloatTyID);
+}
+
+TEST_F(DefaultStringTests, getFloatSemanticsTest) {
+  EXPECT_EQ(&defaultString->getFloatSemantics(2), &llvm::APFloat::IEEEhalf());
+  EXPECT_EQ(&defaultString->getFloatSemantics(3), &llvm::APFloat::BFloat());
+  EXPECT_EQ(&defaultString->getFloatSemantics(4), &llvm::APFloat::IEEEsingle());
+  EXPECT_EQ(&defaultString->getFloatSemantics(8), &llvm::APFloat::IEEEdouble());
+  EXPECT_EQ(&defaultString->getFloatSemantics(10),
+      &llvm::APFloat::x87DoubleExtended());
+  EXPECT_EQ(&defaultString->getFloatSemantics(16), &llvm::APFloat::IEEEquad());
+
+  // Default cases
+  EXPECT_EQ(
+      &defaultString->getFloatSemantics(-1), &llvm::APFloat::IEEEsingle());
+  EXPECT_EQ(&defaultString->getFloatSemantics(1), &llvm::APFloat::IEEEsingle());
+}
+
+TEST_F(commandLineStringTests, getIntegerBitsizeTest) {
+  // KEY is present in map.
+  EXPECT_EQ(commandLineString->getIntegerBitsize(10), 80u);
+  EXPECT_EQ(commandLineString->getCharacterBitsize(1), 8u);
+  EXPECT_EQ(commandLineString->getLogicalBitsize(3), 24u);
+  EXPECT_EQ(commandLineString->getComplexTypeID(20), LLVMTypeID::X86_FP80TyID);
+  EXPECT_EQ(commandLineString->getRealTypeID(54), LLVMTypeID::DoubleTyID);
+  EXPECT_EQ(commandLineString->getRealTypeID(11), LLVMTypeID::PPC_FP128TyID);
+  EXPECT_EQ(&commandLineString->getFloatSemantics(11),
+      &llvm::APFloat::PPCDoubleDouble());
+  EXPECT_EQ(
+      &commandLineString->getFloatSemantics(12), &llvm::APFloat::IEEEquad());
+  EXPECT_EQ(&commandLineString->getFloatSemantics(13),
+      &llvm::APFloat::x87DoubleExtended());
+  EXPECT_EQ(
+      &commandLineString->getFloatSemantics(14), &llvm::APFloat::IEEEdouble());
+  EXPECT_EQ(
+      &commandLineString->getFloatSemantics(15), &llvm::APFloat::IEEEsingle());
+  EXPECT_EQ(
+      &commandLineString->getFloatSemantics(16), &llvm::APFloat::IEEEhalf());
+  EXPECT_EQ(
+      &commandLineString->getFloatSemantics(23), &llvm::APFloat::BFloat());
+
+  // Converts to default case
+  EXPECT_EQ(
+      &commandLineString->getFloatSemantics(20), &llvm::APFloat::IEEEsingle());
+
+  // KEY is absent from map, Default values are expected.
+  EXPECT_EQ(commandLineString->getIntegerBitsize(9), 72u);
+  EXPECT_EQ(commandLineString->getCharacterBitsize(9), 72u);
+  EXPECT_EQ(commandLineString->getLogicalBitsize(9), 72u);
+  EXPECT_EQ(commandLineString->getComplexTypeID(9), LLVMTypeID::FloatTyID);
+  EXPECT_EQ(commandLineString->getRealTypeID(9), LLVMTypeID::FloatTyID);
+
+  // KEY repeats in map; the first definition is not the one kept.
+  EXPECT_NE(clStringConflict->getIntegerBitsize(10), 80u);
+  EXPECT_NE(clStringConflict->getRealTypeID(54), LLVMTypeID::DoubleTyID);
+}
+
+TEST(KindMappingDeathTests, mapTest) {
+  mlir::MLIRContext *context{};
+  // NOTE: `context` is null, so the crash presumably comes from emitting the
+  // parse diagnostic; a real MLIRContext would make these tests less brittle.
+  // Catch parsing errors
+  ASSERT_DEATH(new KindMapping(context, "r10:Double,r20:Doubl"), "");
+  ASSERT_DEATH(new KindMapping(context, "10:Double"), "");
+  ASSERT_DEATH(new KindMapping(context, "rr:Double"), "");
+  ASSERT_DEATH(new KindMapping(context, "rr:"), "");
+  ASSERT_DEATH(new KindMapping(context, "rr:Double MoreContent"), "");
+  // length of 'size' > 10
+  ASSERT_DEATH(new KindMapping(context, "i11111111111:10"), "");
+}
+
+// main() from gtest_main