diff --git a/flang/include/flang/Optimizer/Builder/BoxValue.h b/flang/include/flang/Optimizer/Builder/BoxValue.h
--- a/flang/include/flang/Optimizer/Builder/BoxValue.h
+++ b/flang/include/flang/Optimizer/Builder/BoxValue.h
@@ -467,6 +467,34 @@
       [](const fir::UnboxedValue &box) { return box ? true : false; },
       [](const auto &) { return false; });
 }
+
+/// Returns the base type of \p exv. This is the type of \p exv
+/// without any memory or box type. The sequence type, if any, is kept.
+/// For fir::MutableBoxValue and fir::BoxValue this is their getBaseTy();
+/// for every other kind it is the type of the base value of \p exv with
+/// fir::unwrapRefType applied.
+inline mlir::Type getBaseTypeOf(const ExtendedValue &exv) {
+  return exv.match(
+      [](const fir::MutableBoxValue &box) { return box.getBaseTy(); },
+      [](const fir::BoxValue &box) { return box.getBaseTy(); },
+      [&](const auto &) {
+        return fir::unwrapRefType(fir::getBase(exv).getType());
+      });
+}
+
+/// Returns the scalar type of \p exv. This is the base type of \p exv (see
+/// getBaseTypeOf) with the sequence type, if any, removed.
+inline mlir::Type getElementTypeOf(const ExtendedValue &exv) {
+  return fir::unwrapSequenceType(getBaseTypeOf(exv));
+}
+
+/// Returns true if the element type of \p exv is a derived type
+/// (fir::RecordType) that has at least one length parameter.
+inline bool isDerivedWithLengthParameters(const ExtendedValue &exv) {
+  auto record = getElementTypeOf(exv).dyn_cast<fir::RecordType>();
+  return record && record.getNumLenParams() != 0;
+}
+
 } // namespace fir
 
 #endif // FORTRAN_OPTIMIZER_BUILDER_BOXVALUE_H
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -416,6 +416,12 @@
 /// flang/include/flang/Runtime/ragged.h.
 mlir::TupleType getRaggedArrayHeaderType(fir::FirOpBuilder &builder);
 
+/// Create the zero value of the given numerical or logical \p type (`false`
+/// for logical types). A fatal error is emitted if \p type is not an integer,
+/// real, complex or logical type.
+mlir::Value createZeroValue(fir::FirOpBuilder &builder, mlir::Location loc,
+                            mlir::Type type);
+
 } // namespace fir::factory
 
 #endif // FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -604,3 +604,24 @@
   auto shTy = fir::HeapType::get(extTy);
   return mlir::TupleType::get(builder.getContext(), {i64Ty, buffTy, shTy});
 }
+
+mlir::Value fir::factory::createZeroValue(fir::FirOpBuilder &builder,
+                                          mlir::Location loc, mlir::Type type) {
+  mlir::Type i1 = builder.getIntegerType(1);
+  // Logical types and i1 both get `false`, converted to the requested type.
+  if (type.isa<fir::LogicalType>() || type == i1)
+    return builder.createConvert(loc, type, builder.createBool(loc, false));
+  if (fir::isa_integer(type))
+    return builder.createIntegerConstant(loc, type, 0);
+  if (fir::isa_real(type))
+    return builder.createRealZeroConstant(loc, type);
+  if (fir::isa_complex(type)) {
+    // Both parts of a complex zero are the real zero of the part type.
+    fir::factory::Complex complexHelper(builder, loc);
+    mlir::Type partType = complexHelper.getComplexPartType(type);
+    mlir::Value zeroPart = builder.createRealZeroConstant(loc, partType);
+    return complexHelper.createComplex(type, zeroPart, zeroPart);
+  }
+  fir::emitFatalError(loc, "internal: trying to generate zero value of non "
+                           "numeric or logical type");
+}
diff --git a/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp b/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp
--- a/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp
+++ b/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp
@@ -414,3 +414,120 @@
   auto readExtents = fir::factory::getExtents(builder, loc, ex);
   EXPECT_EQ(2u, readExtents.size());
 }
+
+/// Check that createZeroValue produces a constant zero of the requested type
+/// for integer, real and i1 types.
+TEST_F(FIRBuilderTest, createZeroValue) {
+  auto builder = getBuilder();
+  auto loc = builder.getUnknownLoc();
+
+  mlir::Type i64Ty = mlir::IntegerType::get(builder.getContext(), 64);
+  mlir::Value zeroInt = fir::factory::createZeroValue(builder, loc, i64Ty);
+  EXPECT_TRUE(zeroInt.getType() == i64Ty);
+  auto cst =
+      mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(zeroInt.getDefiningOp());
+  ASSERT_TRUE(cst);
+  auto intAttr = cst.getValue().dyn_cast<mlir::IntegerAttr>();
+  EXPECT_TRUE(intAttr && intAttr.getInt() == 0);
+
+  mlir::Type f32Ty = mlir::FloatType::getF32(builder.getContext());
+  mlir::Value zeroFloat = fir::factory::createZeroValue(builder, loc, f32Ty);
+  EXPECT_TRUE(zeroFloat.getType() == f32Ty);
+  auto cst2 = mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(
+      zeroFloat.getDefiningOp());
+  ASSERT_TRUE(cst2);
+  auto floatAttr = cst2.getValue().dyn_cast<mlir::FloatAttr>();
+  EXPECT_TRUE(floatAttr && floatAttr.getValueAsDouble() == 0.);
+
+  mlir::Type boolTy = mlir::IntegerType::get(builder.getContext(), 1);
+  mlir::Value falseBool = fir::factory::createZeroValue(builder, loc, boolTy);
+  EXPECT_TRUE(falseBool.getType() == boolTy);
+  auto cst3 = mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(
+      falseBool.getDefiningOp());
+  ASSERT_TRUE(cst3);
+  auto intAttr2 = cst3.getValue().dyn_cast<mlir::IntegerAttr>();
+  EXPECT_TRUE(intAttr2 && intAttr2.getInt() == 0);
+}
+
+/// Check getBaseTypeOf, getElementTypeOf and isDerivedWithLengthParameters on
+/// scalar and array extended values, with and without length parameters.
+TEST_F(FIRBuilderTest, getBaseTypeOf) {
+  auto builder = getBuilder();
+  auto loc = builder.getUnknownLoc();
+
+  // Build, for the given element and array types, one scalar and one array
+  // extended value of each tested kind (unboxed/array box, box, mutable box).
+  auto makeExv = [&](mlir::Type elementType, mlir::Type arrayType)
+      -> std::tuple<llvm::SmallVector<fir::ExtendedValue, 4>,
+          llvm::SmallVector<fir::ExtendedValue, 4>> {
+    auto ptrTyArray = fir::PointerType::get(arrayType);
+    auto ptrTyScalar = fir::PointerType::get(elementType);
+    auto ptrBoxTyArray = fir::BoxType::get(ptrTyArray);
+    auto ptrBoxTyScalar = fir::BoxType::get(ptrTyScalar);
+    auto boxRefTyArray = fir::ReferenceType::get(ptrBoxTyArray);
+    auto boxRefTyScalar = fir::ReferenceType::get(ptrBoxTyScalar);
+    auto boxTyArray = fir::BoxType::get(arrayType);
+    auto boxTyScalar = fir::BoxType::get(elementType);
+
+    auto ptrValArray = builder.create<fir::UndefOp>(loc, ptrTyArray);
+    auto ptrValScalar = builder.create<fir::UndefOp>(loc, ptrTyScalar);
+    auto boxRefValArray = builder.create<fir::UndefOp>(loc, boxRefTyArray);
+    auto boxRefValScalar = builder.create<fir::UndefOp>(loc, boxRefTyScalar);
+    auto boxValArray = builder.create<fir::UndefOp>(loc, boxTyArray);
+    auto boxValScalar = builder.create<fir::UndefOp>(loc, boxTyScalar);
+
+    llvm::SmallVector<fir::ExtendedValue, 4> scalars;
+    scalars.emplace_back(fir::UnboxedValue(ptrValScalar));
+    scalars.emplace_back(fir::BoxValue(boxValScalar));
+    scalars.emplace_back(
+        fir::MutableBoxValue(boxRefValScalar, mlir::ValueRange(), {}));
+
+    llvm::SmallVector<fir::ExtendedValue, 4> arrays;
+    auto extent = builder.create<fir::UndefOp>(loc, builder.getIndexType());
+    llvm::SmallVector<mlir::Value> extents(
+        arrayType.dyn_cast<fir::SequenceType>().getDimension(),
+        extent.getResult());
+    arrays.emplace_back(fir::ArrayBoxValue(ptrValArray, extents));
+    arrays.emplace_back(fir::BoxValue(boxValArray));
+    arrays.emplace_back(
+        fir::MutableBoxValue(boxRefValArray, mlir::ValueRange(), {}));
+    return {scalars, arrays};
+  };
+
+  auto f32Ty = mlir::FloatType::getF32(builder.getContext());
+  mlir::Type f32SeqTy = builder.getVarLenSeqTy(f32Ty);
+  auto [f32Scalars, f32Arrays] = makeExv(f32Ty, f32SeqTy);
+  for (const auto &scalar : f32Scalars) {
+    EXPECT_EQ(fir::getBaseTypeOf(scalar), f32Ty);
+    EXPECT_EQ(fir::getElementTypeOf(scalar), f32Ty);
+    EXPECT_FALSE(fir::isDerivedWithLengthParameters(scalar));
+  }
+  for (const auto &array : f32Arrays) {
+    EXPECT_EQ(fir::getBaseTypeOf(array), f32SeqTy);
+    EXPECT_EQ(fir::getElementTypeOf(array), f32Ty);
+    EXPECT_FALSE(fir::isDerivedWithLengthParameters(array));
+  }
+
+  auto derivedWithLengthTy =
+      fir::RecordType::get(builder.getContext(), "derived_test");
+
+  llvm::SmallVector<std::pair<std::string, mlir::Type>> parameters;
+  llvm::SmallVector<std::pair<std::string, mlir::Type>> components;
+  parameters.emplace_back("p1", builder.getI64Type());
+  components.emplace_back("c1", f32Ty);
+  derivedWithLengthTy.finalize(parameters, components);
+  mlir::Type derivedWithLengthSeqTy =
+      builder.getVarLenSeqTy(derivedWithLengthTy);
+  auto [derivedWithLengthScalars, derivedWithLengthArrays] =
+      makeExv(derivedWithLengthTy, derivedWithLengthSeqTy);
+  for (const auto &scalar : derivedWithLengthScalars) {
+    EXPECT_EQ(fir::getBaseTypeOf(scalar), derivedWithLengthTy);
+    EXPECT_EQ(fir::getElementTypeOf(scalar), derivedWithLengthTy);
+    EXPECT_TRUE(fir::isDerivedWithLengthParameters(scalar));
+  }
+  for (const auto &array : derivedWithLengthArrays) {
+    EXPECT_EQ(fir::getBaseTypeOf(array), derivedWithLengthSeqTy);
+    EXPECT_EQ(fir::getElementTypeOf(array), derivedWithLengthTy);
+    EXPECT_TRUE(fir::isDerivedWithLengthParameters(array));
+  }
+}