diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h new file mode 100644 --- /dev/null +++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h @@ -0,0 +1,90 @@ +//===-- HLFIRTools.h -- HLFIR tools -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_OPTIMIZER_BUILDER_HLFIRTOOLS_H +#define FORTRAN_OPTIMIZER_BUILDER_HLFIRTOOLS_H + +#include "flang/Optimizer/Builder/BoxValue.h" +#include "flang/Optimizer/Dialect/FortranVariableInterface.h" +#include "flang/Optimizer/HLFIR/HLFIRDialect.h" + +namespace fir { +class FirOpBuilder; +} + +namespace hlfir { + +/// Is this an SSA value type for the value of a Fortran expression? +inline bool isFortranValueType(mlir::Type type) { + return type.isa() || fir::isa_trivial(type); +} + +/// Is this the value of a Fortran expression in an SSA value form? +inline bool isFortranValue(mlir::Value value) { + return isFortranValueType(value.getType()); +} + +/// Is this a Fortran variable? +/// Note that by "variable", it must be understood that the mlir::Value is +/// a memory value of a storage that can be reasoned about as a Fortran object +/// (its bounds, shape, and type parameters, if any, are retrievable). +/// This does not imply that the mlir::Value points to a variable from the +/// original source or can be legally defined: temporaries created to store +/// expression values are considered to be variables, and so are the global +/// constant addresses of PARAMETERs.
+inline bool isFortranVariable(mlir::Value value) { + return value.getDefiningOp(); +} + +/// Is this a Fortran variable or expression value? +inline bool isFortranEntity(mlir::Value value) { + return isFortranValue(value) || isFortranVariable(value); +} + +/// Wrapper over an mlir::Value that can be viewed as a Fortran entity. +/// This provides some Fortran specific helpers as well as a guarantee +/// in the compiler source that a certain mlir::Value must be a Fortran +/// entity. +class FortranEntity : public mlir::Value { +public: + explicit FortranEntity(mlir::Value value) : mlir::Value(value) { + assert(isFortranEntity(value) && + "must be a value representing a Fortran value or variable"); + } + FortranEntity(fir::FortranVariableOpInterface variable) + : mlir::Value(variable.getBase()) {} + bool isValue() const { return isFortranValue(*this); } + bool isVariable() const { return !isValue(); } + fir::FortranVariableOpInterface getIfVariable() const { + return this->getDefiningOp(); + } + mlir::Value getBase() const { return *this; } +}; + +/// Functions to translate hlfir::FortranEntity to fir::ExtendedValue. +/// For Fortran arrays, character, and derived type values, this requires +/// allocating storage since these can only be represented in memory in FIR. +/// In that case, a cleanup function is provided to generate the finalization +/// code after the end of the fir::ExtendedValue use. +using CleanupFunction = std::function; +std::pair> +translateToExtendedValue(mlir::Location loc, fir::FirOpBuilder &builder, + FortranEntity entity); + +/// Function to translate FortranVariableOpInterface to fir::ExtendedValue. +/// It does not generate any IR, and is a simple packaging operation.
+fir::ExtendedValue +translateToExtendedValue(fir::FortranVariableOpInterface fortranVariable); + +} // namespace hlfir + +#endif // FORTRAN_OPTIMIZER_BUILDER_HLFIRTOOLS_H diff --git a/flang/lib/Optimizer/Builder/CMakeLists.txt b/flang/lib/Optimizer/Builder/CMakeLists.txt --- a/flang/lib/Optimizer/Builder/CMakeLists.txt +++ b/flang/lib/Optimizer/Builder/CMakeLists.txt @@ -6,6 +6,7 @@ Complex.cpp DoLoopHelper.cpp FIRBuilder.cpp + HLFIRTools.cpp LowLevelIntrinsics.cpp MutableBox.cpp Runtime/Assign.cpp diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp new file mode 100644 --- /dev/null +++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp @@ -0,0 +1,100 @@ +//===-- HLFIRTools.cpp +//------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Tools to manipulate HLFIR variables and expressions +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Builder/HLFIRTools.h" +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Builder/Todo.h" + +// Return explicit extents. If the base is a fir.box, this won't read it to +// return the extents and will instead return an empty vector.
+static llvm::SmallVector +getExplicitExtents(fir::FortranVariableOpInterface var) { + llvm::SmallVector result; + if (mlir::Value shape = var.getShape()) { + auto *shapeOp = shape.getDefiningOp(); + if (auto s = mlir::dyn_cast_or_null(shapeOp)) { + auto e = s.getExtents(); + result.append(e.begin(), e.end()); + } else if (auto s = mlir::dyn_cast_or_null(shapeOp)) { + auto e = s.getExtents(); + result.append(e.begin(), e.end()); + } else if (mlir::dyn_cast_or_null(shapeOp)) { + return {}; + } else { + TODO(var->getLoc(), "read fir.shape to get extents"); + } + } + return result; +} + +// Return explicit lower bounds. For pointers and allocatables, this will not +// read the lower bounds and instead return an empty vector. +static llvm::SmallVector +getExplicitLbounds(fir::FortranVariableOpInterface var) { + llvm::SmallVector result; + if (mlir::Value shape = var.getShape()) { + auto *shapeOp = shape.getDefiningOp(); + if (auto s = mlir::dyn_cast_or_null(shapeOp)) { + return {}; + } else if (auto s = mlir::dyn_cast_or_null(shapeOp)) { + auto e = s.getOrigins(); + result.append(e.begin(), e.end()); + } else if (auto s = mlir::dyn_cast_or_null(shapeOp)) { + auto e = s.getOrigins(); + result.append(e.begin(), e.end()); + } else { + TODO(var->getLoc(), "read fir.shape to get lower bounds"); + } + } + return result; +} + +static llvm::SmallVector +getExplicitTypeParams(fir::FortranVariableOpInterface var) { + llvm::SmallVector res; + mlir::OperandRange range = var.getExplicitTypeParams(); + res.append(range.begin(), range.end()); + return res; +} + +std::pair> +hlfir::translateToExtendedValue(mlir::Location loc, fir::FirOpBuilder &, + hlfir::FortranEntity entity) { + if (auto variable = entity.getIfVariable()) + return {hlfir::translateToExtendedValue(variable), {}}; + if (entity.getType().isa()) + TODO(loc, "hlfir.expr to fir::ExtendedValue"); // use hlfir.associate + return {{static_cast(entity)}, {}}; +} + +fir::ExtendedValue
+hlfir::translateToExtendedValue(fir::FortranVariableOpInterface variable) { + if (variable.isPointer() || variable.isAllocatable()) + TODO(variable->getLoc(), "pointer or allocatable " + "FortranVariableOpInterface to extendedValue"); + if (variable.getBase().getType().isa()) + return fir::BoxValue(variable.getBase(), getExplicitLbounds(variable), + getExplicitTypeParams(variable), + getExplicitExtents(variable)); + if (variable.isCharacter()) { + if (variable.isArray()) + return fir::CharArrayBoxValue( + variable.getBase(), variable.getExplicitCharLen(), + getExplicitExtents(variable), getExplicitLbounds(variable)); + return fir::CharBoxValue(variable.getBase(), variable.getExplicitCharLen()); + } + if (variable.isArray()) + return fir::ArrayBoxValue(variable.getBase(), getExplicitExtents(variable), + getExplicitLbounds(variable)); + return variable.getBase(); +} diff --git a/flang/unittests/Optimizer/Builder/HLFIRToolsTest.cpp b/flang/unittests/Optimizer/Builder/HLFIRToolsTest.cpp new file mode 100644 --- /dev/null +++ b/flang/unittests/Optimizer/Builder/HLFIRToolsTest.cpp @@ -0,0 +1,218 @@ +//===- HLFIRToolsTest.cpp -- HLFIR tools unit tests -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Builder/HLFIRTools.h" +#include "gtest/gtest.h" +#include "flang/Optimizer/Builder/BoxValue.h" +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Support/InitFIR.h" +#include "flang/Optimizer/Support/KindMapping.h" + +struct HLFIRToolsTest : public testing::Test { +public: + void SetUp() override { + fir::support::loadDialects(context); + + llvm::ArrayRef defs; + fir::KindMapping kindMap(&context, defs); + mlir::OpBuilder builder(&context); + auto loc = builder.getUnknownLoc(); + + // Set up a Module with a dummy function operation inside. + // Set the insertion point in the function entry block. + mlir::ModuleOp mod = builder.create(loc); + mlir::func::FuncOp func = mlir::func::FuncOp::create( + loc, "func1", builder.getFunctionType(llvm::None, llvm::None)); + auto *entryBlock = func.addEntryBlock(); + mod.push_back(func); + builder.setInsertionPointToStart(entryBlock); + + firBuilder = std::make_unique(mod, kindMap); + } + + mlir::Value createDeclare(fir::ExtendedValue exv) { + mlir::Value addr = fir::getBase(exv); + mlir::Location loc = getLoc(); + mlir::Value shape; + if (exv.rank() > 0) + shape = firBuilder->createShape(loc, exv); + llvm::SmallVector typeParams; + exv.match( + [&](const fir::CharBoxValue &x) { + typeParams.emplace_back(x.getLen()); + }, + [&](const fir::CharArrayBoxValue &x) { + typeParams.emplace_back(x.getLen()); + }, + [&](const fir::BoxValue &x) { + typeParams.append(x.getExplicitParameters().begin(), + x.getExplicitParameters().end()); + }, + [&](const fir::MutableBoxValue &x) { + typeParams.append( + x.nonDeferredLenParams().begin(), x.nonDeferredLenParams().end()); + }, + [](const auto &) {}); + auto name = + mlir::StringAttr::get(&context, "x" + std::to_string(varCounter++)); + return firBuilder->create(loc, addr.getType(), addr, shape,
 typeParams, name, + /*fortran_attrs=*/fir::FortranVariableFlagsAttr{}); + } + + mlir::Value createConstant(std::int64_t cst) { + mlir::Type indexType = firBuilder->getIndexType(); + return firBuilder->create( + getLoc(), indexType, firBuilder->getIntegerAttr(indexType, cst)); + } + + mlir::Location getLoc() { return firBuilder->getUnknownLoc(); } + fir::FirOpBuilder &getBuilder() { return *firBuilder; } + + int varCounter = 0; + mlir::MLIRContext context; + std::unique_ptr firBuilder; +}; + +TEST_F(HLFIRToolsTest, testScalarRoundTrip) { + auto &builder = getBuilder(); + mlir::Location loc = getLoc(); + mlir::Type f32Type = mlir::FloatType::getF32(&context); + mlir::Type scalarf32Type = builder.getRefType(f32Type); + mlir::Value scalarf32Addr = builder.create(loc, scalarf32Type); + fir::ExtendedValue scalarf32{scalarf32Addr}; + hlfir::FortranEntity scalarf32Entity(createDeclare(scalarf32)); + auto [scalarf32Result, cleanup] = + hlfir::translateToExtendedValue(loc, builder, scalarf32Entity); + auto *unboxed = scalarf32Result.getUnboxed(); + EXPECT_FALSE(cleanup.has_value()); + ASSERT_NE(unboxed, nullptr); + EXPECT_TRUE(*unboxed == scalarf32Entity.getBase()); + EXPECT_TRUE(scalarf32Entity.isVariable()); + EXPECT_FALSE(scalarf32Entity.isValue()); +} + +TEST_F(HLFIRToolsTest, testArrayRoundTrip) { + auto &builder = getBuilder(); + mlir::Location loc = getLoc(); + llvm::SmallVector extents{ + createConstant(20), createConstant(30)}; + llvm::SmallVector lbounds{ + createConstant(-1), createConstant(-2)}; + + mlir::Type f32Type = mlir::FloatType::getF32(&context); + mlir::Type seqf32Type = builder.getVarLenSeqTy(f32Type, 2); + mlir::Type arrayf32Type = builder.getRefType(seqf32Type); + mlir::Value arrayf32Addr = builder.create(loc, arrayf32Type); + fir::ArrayBoxValue arrayf32{arrayf32Addr, extents, lbounds}; + hlfir::FortranEntity arrayf32Entity(createDeclare(arrayf32)); + auto [arrayf32Result, cleanup] = + hlfir::translateToExtendedValue(loc, builder, arrayf32Entity); +
 auto *res = arrayf32Result.getBoxOf(); + EXPECT_FALSE(cleanup.has_value()); + ASSERT_NE(res, nullptr); + // gtest has a terrible time printing mlir::Value in case of failing + // EXPECT_EQ(mlir::Value, mlir::Value). So use EXPECT_TRUE instead. + EXPECT_TRUE(fir::getBase(*res) == arrayf32Entity.getBase()); + ASSERT_EQ(res->getExtents().size(), arrayf32.getExtents().size()); + for (unsigned i = 0; i < arrayf32.getExtents().size(); ++i) + EXPECT_TRUE(res->getExtents()[i] == arrayf32.getExtents()[i]); + ASSERT_EQ(res->getLBounds().size(), arrayf32.getLBounds().size()); + for (unsigned i = 0; i < arrayf32.getLBounds().size(); ++i) + EXPECT_TRUE(res->getLBounds()[i] == arrayf32.getLBounds()[i]); + EXPECT_TRUE(arrayf32Entity.isVariable()); + EXPECT_FALSE(arrayf32Entity.isValue()); +} + +TEST_F(HLFIRToolsTest, testScalarCharRoundTrip) { + auto &builder = getBuilder(); + mlir::Location loc = getLoc(); + mlir::Value len = createConstant(42); + mlir::Type charType = fir::CharacterType::getUnknownLen(&context, 1); + mlir::Type scalarCharType = builder.getRefType(charType); + mlir::Value scalarCharAddr = + builder.create(loc, scalarCharType); + fir::CharBoxValue scalarChar{scalarCharAddr, len}; + hlfir::FortranEntity scalarCharEntity(createDeclare(scalarChar)); + auto [scalarCharResult, cleanup] = + hlfir::translateToExtendedValue(loc, builder, scalarCharEntity); + auto *res = scalarCharResult.getBoxOf(); + EXPECT_FALSE(cleanup.has_value()); + ASSERT_NE(res, nullptr); + EXPECT_TRUE(fir::getBase(*res) == scalarCharEntity.getBase()); + EXPECT_TRUE(res->getLen() == scalarChar.getLen()); + EXPECT_TRUE(scalarCharEntity.isVariable()); + EXPECT_FALSE(scalarCharEntity.isValue()); +} + +TEST_F(HLFIRToolsTest, testArrayCharRoundTrip) { + auto &builder = getBuilder(); + mlir::Location loc = getLoc(); + llvm::SmallVector extents{ + createConstant(20), createConstant(30)}; + llvm::SmallVector lbounds{ + createConstant(-1), createConstant(-2)}; + mlir::Value len = createConstant(42); +
 mlir::Type charType = fir::CharacterType::getUnknownLen(&context, 1); + mlir::Type seqCharType = builder.getVarLenSeqTy(charType, 2); + mlir::Type arrayCharType = builder.getRefType(seqCharType); + mlir::Value arrayCharAddr = builder.create(loc, arrayCharType); + fir::CharArrayBoxValue arrayChar{arrayCharAddr, len, extents, lbounds}; + hlfir::FortranEntity arrayCharEntity(createDeclare(arrayChar)); + auto [arrayCharResult, cleanup] = + hlfir::translateToExtendedValue(loc, builder, arrayCharEntity); + auto *res = arrayCharResult.getBoxOf(); + EXPECT_FALSE(cleanup.has_value()); + ASSERT_NE(res, nullptr); + // gtest has a terrible time printing mlir::Value in case of failing + // EXPECT_EQ(mlir::Value, mlir::Value). So use EXPECT_TRUE instead. + EXPECT_TRUE(fir::getBase(*res) == arrayCharEntity.getBase()); + EXPECT_TRUE(res->getLen() == arrayChar.getLen()); + ASSERT_EQ(res->getExtents().size(), arrayChar.getExtents().size()); + for (unsigned i = 0; i < arrayChar.getExtents().size(); ++i) + EXPECT_TRUE(res->getExtents()[i] == arrayChar.getExtents()[i]); + ASSERT_EQ(res->getLBounds().size(), arrayChar.getLBounds().size()); + for (unsigned i = 0; i < arrayChar.getLBounds().size(); ++i) + EXPECT_TRUE(res->getLBounds()[i] == arrayChar.getLBounds()[i]); + EXPECT_TRUE(arrayCharEntity.isVariable()); + EXPECT_FALSE(arrayCharEntity.isValue()); +} + +TEST_F(HLFIRToolsTest, testArrayCharBoxRoundTrip) { + auto &builder = getBuilder(); + mlir::Location loc = getLoc(); + llvm::SmallVector lbounds{ + createConstant(-1), createConstant(-2)}; + mlir::Value len = createConstant(42); + mlir::Type charType = fir::CharacterType::getUnknownLen(&context, 1); + mlir::Type seqCharType = builder.getVarLenSeqTy(charType, 2); + mlir::Type arrayCharBoxType = fir::BoxType::get(seqCharType); + mlir::Value arrayCharAddr = + builder.create(loc, arrayCharBoxType); + llvm::SmallVector explicitTypeParams{len}; + fir::BoxValue arrayChar{arrayCharAddr, lbounds, explicitTypeParams}; + hlfir::FortranEntity
 arrayCharEntity(createDeclare(arrayChar)); + auto [arrayCharResult, cleanup] = + hlfir::translateToExtendedValue(loc, builder, arrayCharEntity); + auto *res = arrayCharResult.getBoxOf(); + EXPECT_FALSE(cleanup.has_value()); + ASSERT_NE(res, nullptr); + // gtest has a terrible time printing mlir::Value in case of failing + // EXPECT_EQ(mlir::Value, mlir::Value). So use EXPECT_TRUE instead. + EXPECT_TRUE(fir::getBase(*res) == arrayCharEntity.getBase()); + ASSERT_EQ(res->getExplicitParameters().size(), + arrayChar.getExplicitParameters().size()); + for (unsigned i = 0; i < arrayChar.getExplicitParameters().size(); ++i) + EXPECT_TRUE(res->getExplicitParameters()[i] == + arrayChar.getExplicitParameters()[i]); + ASSERT_EQ(res->getLBounds().size(), arrayChar.getLBounds().size()); + for (unsigned i = 0; i < arrayChar.getLBounds().size(); ++i) + EXPECT_TRUE(res->getLBounds()[i] == arrayChar.getLBounds()[i]); + EXPECT_TRUE(arrayCharEntity.isVariable()); + EXPECT_FALSE(arrayCharEntity.isValue()); +} diff --git a/flang/unittests/Optimizer/CMakeLists.txt b/flang/unittests/Optimizer/CMakeLists.txt --- a/flang/unittests/Optimizer/CMakeLists.txt +++ b/flang/unittests/Optimizer/CMakeLists.txt @@ -14,6 +14,7 @@ Builder/ComplexTest.cpp Builder/DoLoopHelperTest.cpp Builder/FIRBuilderTest.cpp + Builder/HLFIRToolsTest.cpp Builder/Runtime/AssignTest.cpp Builder/Runtime/CommandTest.cpp Builder/Runtime/CharacterTest.cpp