diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -265,6 +265,24 @@ bool isUnlimitedPolymorphicType(mlir::Type ty) {
   if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
     ty = refTy;
-  if (auto clTy = ty.dyn_cast<fir::ClassType>())
-    return clTy.getEleTy().isa<mlir::NoneType>();
+  if (auto clTy = ty.dyn_cast<fir::ClassType>()) {
+    // CLASS(*): the class wraps `none` directly.
+    if (clTy.getEleTy().isa<mlir::NoneType>())
+      return true;
+    // CLASS(*) with POINTER, ALLOCATABLE and/or DIMENSION attributes: look
+    // through one pointer/heap/sequence wrapper and an optional sequence.
+    mlir::Type innerType =
+        llvm::TypeSwitch<mlir::Type, mlir::Type>(clTy.getEleTy())
+            .Case<fir::PointerType, fir::HeapType, fir::SequenceType>(
+                [](auto wrapperTy) {
+                  mlir::Type eleTy = wrapperTy.getEleTy();
+                  if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>())
+                    return seqTy.getEleTy();
+                  return eleTy;
+                })
+            .Default([](mlir::Type) { return mlir::Type{}; });
+    // The default case yields a null type (any other element type, e.g. a
+    // derived-type record); isa<> asserts on null, so check it first.
+    return innerType && innerType.isa<mlir::NoneType>();
+  }
   return false;
 }
diff --git a/flang/unittests/Optimizer/CMakeLists.txt b/flang/unittests/Optimizer/CMakeLists.txt
--- a/flang/unittests/Optimizer/CMakeLists.txt
+++ b/flang/unittests/Optimizer/CMakeLists.txt
@@ -23,6 +23,7 @@
   Builder/Runtime/StopTest.cpp
   Builder/Runtime/TransformationalTest.cpp
   FIRContextTest.cpp
+  FIRTypesTest.cpp
   InternalNamesTest.cpp
   KindMappingTest.cpp
   RTBuilder.cpp
diff --git a/flang/unittests/Optimizer/FIRTypesTest.cpp b/flang/unittests/Optimizer/FIRTypesTest.cpp
new file mode 100644
--- /dev/null
+++ b/flang/unittests/Optimizer/FIRTypesTest.cpp
@@ -0,0 +1,70 @@
+//===- FIRTypesTest.cpp ---------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "gtest/gtest.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Support/InitFIR.h"
+
+struct FIRTypesTest : public testing::Test {
+public:
+  void SetUp() override { fir::support::loadDialects(context); }
+
+  mlir::MLIRContext context;
+};
+
+// Test fir::isUnlimitedPolymorphicType from flang/Optimizer/Dialect/FIRType.h.
+TEST_F(FIRTypesTest, isUnlimitedPolymorphicTypeTest) {
+  mlir::Type noneTy = mlir::NoneType::get(&context);
+
+  // CLASS(*)
+  mlir::Type ty = fir::ClassType::get(noneTy);
+  EXPECT_TRUE(fir::isUnlimitedPolymorphicType(ty));
+  EXPECT_TRUE(fir::isUnlimitedPolymorphicType(fir::ReferenceType::get(ty)));
+
+  // CLASS(*), DIMENSION(10)
+  ty = fir::ClassType::get(fir::SequenceType::get({10}, noneTy));
+  EXPECT_TRUE(fir::isUnlimitedPolymorphicType(ty));
+
+  // CLASS(*), DIMENSION(:)
+  ty = fir::ClassType::get(
+      fir::SequenceType::get({fir::SequenceType::getUnknownExtent()}, noneTy));
+  EXPECT_TRUE(fir::isUnlimitedPolymorphicType(ty));
+
+  // CLASS(*), ALLOCATABLE
+  ty = fir::ClassType::get(fir::HeapType::get(noneTy));
+  EXPECT_TRUE(fir::isUnlimitedPolymorphicType(ty));
+
+  mlir::Type seqNoneTy =
+      fir::SequenceType::get({fir::SequenceType::getUnknownExtent()}, noneTy);
+  // CLASS(*), ALLOCATABLE, DIMENSION(:)
+  ty = fir::ClassType::get(fir::HeapType::get(seqNoneTy));
+  EXPECT_TRUE(fir::isUnlimitedPolymorphicType(ty));
+
+  // CLASS(*), POINTER
+  ty = fir::ClassType::get(fir::PointerType::get(noneTy));
+  EXPECT_TRUE(fir::isUnlimitedPolymorphicType(ty));
+
+  // CLASS(*), POINTER, DIMENSION(:)
+  ty = fir::ClassType::get(fir::PointerType::get(seqNoneTy));
+  EXPECT_TRUE(fir::isUnlimitedPolymorphicType(ty));
+
+  // false tests
+  EXPECT_FALSE(fir::isUnlimitedPolymorphicType(noneTy));
+  EXPECT_FALSE(fir::isUnlimitedPolymorphicType(fir::BoxType::get(noneTy)));
+  EXPECT_FALSE(fir::isUnlimitedPolymorphicType(fir::BoxType::get(seqNoneTy)));
+  EXPECT_FALSE(fir::isUnlimitedPolymorphicType(seqNoneTy));
+
+  // CLASS(t) for a derived type t is not unlimited polymorphic. The first
+  // check covers an element type that is not a pointer, heap, or sequence.
+  mlir::Type recTy = fir::RecordType::get(&context, "t");
+  EXPECT_FALSE(fir::isUnlimitedPolymorphicType(fir::ClassType::get(recTy)));
+  EXPECT_FALSE(fir::isUnlimitedPolymorphicType(
+      fir::ClassType::get(fir::SequenceType::get({10}, recTy))));
+  EXPECT_FALSE(fir::isUnlimitedPolymorphicType(
+      fir::ClassType::get(fir::PointerType::get(recTy))));
+}