diff --git a/flang/include/flang/Optimizer/Builder/BoxValue.h b/flang/include/flang/Optimizer/Builder/BoxValue.h
--- a/flang/include/flang/Optimizer/Builder/BoxValue.h
+++ b/flang/include/flang/Optimizer/Builder/BoxValue.h
@@ -517,6 +517,16 @@
                  [](const auto &box) -> unsigned { return box.rank(); });
   }
 
+  /// Is this value polymorphic? True for a PolymorphicValue and for an
+  /// ArrayBoxValue that carries a type descriptor (tdesc).
+  bool isPolymorphic() const {
+    return match([](const fir::PolymorphicValue &) -> bool { return true; },
+                 [](const fir::ArrayBoxValue &box) -> bool {
+                   return static_cast<bool>(box.getTdesc());
+                 },
+                 [](const auto &) -> bool { return false; });
+  }
+
   /// Is this an assumed size array ?
   bool isAssumedSize() const;
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2096,14 +2096,155 @@
   }
 
   void genFIR(const Fortran::parser::SelectTypeConstruct &selectTypeConstruct) {
-    setCurrentPositionAt(selectTypeConstruct);
-    TODO(toLocation(), "SelectTypeConstruct implementation");
-  }
-  void genFIR(const Fortran::parser::SelectTypeStmt &) {
-    TODO(toLocation(), "SelectTypeStmt implementation");
-  }
-  void genFIR(const Fortran::parser::TypeGuardStmt &) {
-    TODO(toLocation(), "TypeGuardStmt implementation");
+    mlir::Location loc = toLocation();
+    mlir::MLIRContext *context = builder->getContext();
+    Fortran::lower::StatementContext stmtCtx;
+    fir::ExtendedValue selector;
+    // fir.select_type guard attributes and their successor blocks. Both lists
+    // are kept in lockstep; CLASS DEFAULT, if present, is always last.
+    llvm::SmallVector<mlir::Attribute> attrList;
+    llvm::SmallVector<mlir::Block *> blockList;
+    bool hasLocalScope = false;
+
+    for (Fortran::lower::pft::Evaluation &eval :
+         getEval().getNestedEvaluations()) {
+      if (auto *selectTypeStmt =
+              eval.getIf<Fortran::parser::SelectTypeStmt>()) {
+        // Retrieve the selector
+        const auto &s = std::get<Fortran::parser::Selector>(selectTypeStmt->t);
+        if (const auto *v = std::get_if<Fortran::parser::Variable>(&s.u))
+          selector = genExprBox(loc, *Fortran::semantics::GetExpr(*v), stmtCtx);
+        else
+          fir::emitFatalError(
+              loc, "selector with expr not expected in select type statement");
+
+        // Going through the controlSuccessor first to create the
+        // fir.select_type operation.
+        mlir::Block *defaultBlock = nullptr;
+        for (Fortran::lower::pft::Evaluation *e = eval.controlSuccessor; e;
+             e = e->controlSuccessor) {
+          const auto *typeGuardStmt =
+              e->getIf<Fortran::parser::TypeGuardStmt>();
+          assert(typeGuardStmt && "expected a TypeGuardStmt");
+          const auto &guard =
+              std::get<Fortran::parser::TypeGuardStmt::Guard>(typeGuardStmt->t);
+          assert(e->block && "missing TypeGuardStmt block");
+          // CLASS DEFAULT
+          if (std::holds_alternative<Fortran::parser::Default>(guard.u)) {
+            defaultBlock = e->block;
+            continue;
+          }
+
+          blockList.push_back(e->block);
+          if (const auto *typeSpec =
+                  std::get_if<Fortran::parser::TypeSpec>(&guard.u)) {
+            // TYPE IS
+            mlir::Type ty;
+            if (std::holds_alternative<Fortran::parser::IntrinsicTypeSpec>(
+                    typeSpec->u)) {
+              const Fortran::semantics::IntrinsicTypeSpec *intrinsic =
+                  typeSpec->declTypeSpec->AsIntrinsic();
+              // The kind must be a compile-time constant; never fall back on
+              // an uninitialized value.
+              std::optional<std::int64_t> kind =
+                  Fortran::evaluate::ToInt64(intrinsic->kind());
+              if (!kind)
+                fir::emitFatalError(loc, "TYPE IS kind is not a constant");
+              llvm::SmallVector<Fortran::lower::LenParameterTy> params;
+              if (intrinsic->category() ==
+                  Fortran::common::TypeCategory::Character)
+                TODO(loc, "typeSpec with length parameters");
+              ty = genType(intrinsic->category(), static_cast<int>(*kind),
+                           params);
+            } else {
+              const Fortran::semantics::DerivedTypeSpec *derived =
+                  typeSpec->declTypeSpec->AsDerived();
+              ty = genType(*derived);
+            }
+            attrList.push_back(fir::ExactTypeAttr::get(ty));
+          } else if (const auto *derived =
+                         std::get_if<Fortran::parser::DerivedTypeSpec>(
+                             &guard.u)) {
+            // CLASS IS
+            assert(derived->derivedTypeSpec && "derived type spec is null");
+            mlir::Type ty = genType(*(derived->derivedTypeSpec));
+            attrList.push_back(fir::SubclassAttr::get(ty));
+          }
+        }
+        if (defaultBlock) {
+          attrList.push_back(mlir::UnitAttr::get(context));
+          blockList.push_back(defaultBlock);
+        }
+        builder->create<fir::SelectTypeOp>(loc, fir::getBase(selector),
+                                           attrList, blockList);
+      } else if (auto *typeGuardStmt =
+                     eval.getIf<Fortran::parser::TypeGuardStmt>()) {
+        // Map the type guard local symbol for the selector to a more precise
+        // typed entity in the TypeGuardStmt when necessary.
+        const auto &guard =
+            std::get<Fortran::parser::TypeGuardStmt::Guard>(typeGuardStmt->t);
+        if (hasLocalScope)
+          localSymbols.popScope();
+        localSymbols.pushScope();
+        hasLocalScope = true;
+        // Find the guard by its block instead of by source order: CLASS
+        // DEFAULT may appear anywhere in the construct but is always the last
+        // entry of attrList/blockList.
+        auto blockIt = llvm::find(blockList, eval.block);
+        assert(blockIt != blockList.end() && "TypeGuard attribute missing");
+        mlir::Attribute typeGuardAttr =
+            attrList[std::distance(blockList.begin(), blockIt)];
+        mlir::Block *typeGuardBlock = *blockIt;
+        const Fortran::semantics::Scope &guardScope =
+            bridge.getSemanticsContext().FindScope(eval.position);
+        mlir::OpBuilder::InsertPoint crtInsPt = builder->saveInsertionPoint();
+        builder->setInsertionPointToStart(typeGuardBlock);
+
+        auto addAssocEntitySymbol = [&](fir::ExtendedValue exv) {
+          for (auto &symbol : guardScope.GetSymbols()) {
+            if (symbol->GetUltimate()
+                    .detailsIf<Fortran::semantics::AssocEntityDetails>()) {
+              localSymbols.addSymbol(symbol, exv);
+              break;
+            }
+          }
+        };
+
+        if (std::holds_alternative<Fortran::parser::Default>(guard.u)) {
+          // CLASS DEFAULT
+          addAssocEntitySymbol(selector);
+        } else if (const auto *typeSpec =
+                       std::get_if<Fortran::parser::TypeSpec>(&guard.u)) {
+          // TYPE IS
+          auto attr = typeGuardAttr.cast<fir::ExactTypeAttr>();
+          mlir::Value exactValue;
+          if (std::holds_alternative<Fortran::parser::IntrinsicTypeSpec>(
+                  typeSpec->u)) {
+            exactValue = builder->create<fir::BoxAddrOp>(
+                loc, fir::ReferenceType::get(attr.getType()),
+                fir::getBase(selector));
+          } else if (std::holds_alternative<Fortran::parser::DerivedTypeSpec>(
+                         typeSpec->u)) {
+            exactValue = builder->create<fir::ConvertOp>(
+                loc, fir::BoxType::get(attr.getType()), fir::getBase(selector));
+          }
+          addAssocEntitySymbol(exactValue);
+        } else if (std::holds_alternative<Fortran::parser::DerivedTypeSpec>(
+                       guard.u)) {
+          // CLASS IS
+          auto attr = typeGuardAttr.cast<fir::SubclassAttr>();
+          mlir::Value derived = builder->create<fir::ConvertOp>(
+              loc, fir::ClassType::get(attr.getType()), fir::getBase(selector));
+          addAssocEntitySymbol(derived);
+        }
+        builder->restoreInsertionPoint(crtInsPt);
+      } else if (eval.getIf<Fortran::parser::EndSelectStmt>()) {
+        if (hasLocalScope)
+          localSymbols.popScope();
+        stmtCtx.finalize();
+      }
+      genFIR(eval);
+    }
   }
 
   //===--------------------------------------------------------------------===//
@@ -2751,6 +2892,8 @@
   void genFIR(const Fortran::parser::IfThenStmt &) {} // nop
   void genFIR(const Fortran::parser::NonLabelDoStmt &) {} // nop
   void genFIR(const Fortran::parser::OmpEndLoopDirective &) {} // nop
+  void genFIR(const Fortran::parser::SelectTypeStmt &) {} // nop
+  void genFIR(const Fortran::parser::TypeGuardStmt &) {} // nop
 
   void genFIR(const Fortran::parser::NamelistStmt &) {
     TODO(toLocation(), "NamelistStmt lowering");
diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -4162,7 +4162,7 @@
   mlir::Value convertElementForUpdate(mlir::Location loc, mlir::Type eleTy,
                                       mlir::Value origVal) {
     if (auto origEleTy = fir::dyn_cast_ptrEleTy(origVal.getType()))
-      if (origEleTy.isa<fir::BoxType>()) {
+      if (origEleTy.isa<fir::BaseBoxType>()) {
         // If origVal is a box variable, load it so it is in the value domain.
         origVal = builder.create<fir::LoadOp>(loc, origVal);
       }
@@ -7645,8 +7645,8 @@
   }
   fir::ExtendedValue addr = Fortran::lower::createSomeExtendedAddress(
       loc, converter, expr, symMap, stmtCtx);
-  fir::ExtendedValue result =
-      fir::BoxValue(converter.getFirOpBuilder().createBox(loc, addr));
+  fir::ExtendedValue result = fir::BoxValue(
+      converter.getFirOpBuilder().createBox(loc, addr, addr.isPolymorphic()));
   if (isParentComponent(expr))
     result = updateBoxForParentComponent(converter, result, expr);
   return result;
diff --git a/flang/lib/Lower/SymbolMap.cpp b/flang/lib/Lower/SymbolMap.cpp
--- a/flang/lib/Lower/SymbolMap.cpp
+++ b/flang/lib/Lower/SymbolMap.cpp
@@ -26,6 +26,7 @@
       [&](const fir::CharArrayBoxValue &v) { makeSym(sym, v, force); },
       [&](const fir::BoxValue &v) { makeSym(sym, v, force); },
       [&](const fir::MutableBoxValue &v) { makeSym(sym, v, force); },
+      [&](const fir::PolymorphicValue &v) { makeSym(sym, v, force); },
       [](auto) {
         llvm::report_fatal_error("value not added to symbol table");
       });
diff --git a/flang/test/Lower/select-type.f90 b/flang/test/Lower/select-type.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Lower/select-type.f90
@@ -0,0 +1,176 @@
+! RUN: bbc -polymorphic-type -emit-fir %s -o - | FileCheck %s
+
+module select_type_lower_test
+  type p1
+    integer :: a
+    integer :: b
+  end type
+
+  type, extends(p1) :: p2
+    integer :: c
+  end type
+
+  type, extends(p1) :: p3(k)
+    integer, kind :: k
+    real(k) :: r
+  end type
+
+contains
+
+  function get_class()
+    class(p1), pointer :: get_class
+  end function
+
+  subroutine select_type1(a)
+    class(p1), intent(in) :: a
+
+    select type (a)
+    type is (p1)
+      print*, 'type is p1'
+    class is (p1)
+      print*, 'class is p1'
+    class is (p2)
+      print*, 'class is p2', a%c
+    class default
+      print*,'default'
+    end select
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMselect_type_lower_testPselect_type1(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.class<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>> {fir.bindc_name = "a"})
+
+! CHECK: fir.select_type %[[ARG0]] : !fir.class<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>
+! CHECK-SAME: [#fir.type_is<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>, ^[[TYPE_IS_BLK:.*]], #fir.class_is<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>, ^[[CLASS_IS_P1_BLK:.*]], #fir.class_is<!fir.type<_QMselect_type_lower_testTp2{a:i32,b:i32,c:i32}>>, ^[[CLASS_IS_P2_BLK:.*]], unit, ^[[DEFAULT_BLOCK:.*]]]
+! CHECK: ^[[TYPE_IS_BLK]]
+! CHECK: ^[[CLASS_IS_P1_BLK]]
+! CHECK: ^[[CLASS_IS_P2_BLK]]
+! CHECK: %[[P2:.*]] = fir.convert %[[ARG0]] : (!fir.class<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>) -> !fir.class<!fir.type<_QMselect_type_lower_testTp2{a:i32,b:i32,c:i32}>>
+! CHECK: %[[FIELD:.*]] = fir.field_index c, !fir.type<_QMselect_type_lower_testTp2{a:i32,b:i32,c:i32}>
+! CHECK: %{{.*}} = fir.coordinate_of %[[P2]], %[[FIELD]] : (!fir.class<!fir.type<_QMselect_type_lower_testTp2{a:i32,b:i32,c:i32}>>, !fir.field) -> !fir.ref<i32>
+! CHECK: ^[[DEFAULT_BLOCK]]
+
+  subroutine select_type2()
+    select type (a => get_class())
+    type is (p1)
+      print*, 'type is p1'
+    class is (p1)
+      print*, 'class is p1'
+    class default
+      print*,'default'
+    end select
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMselect_type_lower_testPselect_type2()
+! CHECK: %[[RESULT:.*]] = fir.alloca !fir.class<!fir.ptr<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>> {bindc_name = ".result"}
+! CHECK: %[[FCTCALL:.*]] = fir.call @_QMselect_type_lower_testPget_class() : () -> !fir.class<!fir.ptr<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>>
+! CHECK: fir.save_result %[[FCTCALL]] to %[[RESULT]] : !fir.class<!fir.ptr<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>>, !fir.ref<!fir.class<!fir.ptr<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>>>
+! CHECK: %[[SELECTOR:.*]] = fir.load %[[RESULT]] : !fir.ref<!fir.class<!fir.ptr<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>>>
+! CHECK: fir.select_type %[[SELECTOR]] : !fir.class<!fir.ptr<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>>
+! CHECK-SAME: [#fir.type_is<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>, ^[[TYPE_IS_BLK:.*]], #fir.class_is<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>, ^[[CLASS_IS_BLK:.*]], unit, ^[[DEFAULT_BLK:.*]]]
+! CHECK: ^[[TYPE_IS_BLK]]
+! CHECK: ^[[CLASS_IS_BLK]]
+! CHECK: ^[[DEFAULT_BLK]]
+
+  subroutine select_type3(a)
+    class(p1), pointer, intent(in) :: a(:)
+
+    select type (x => a(1))
+    type is (p1)
+      print*, 'type is p1'
+    class is (p1)
+      print*, 'class is p1'
+    class default
+      print*,'default'
+    end select
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMselect_type_lower_testPselect_type3(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.class<!fir.ptr<!fir.array<?x!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>>>> {fir.bindc_name = "a"})
+! CHECK: %[[ARG0_LOAD:.*]] = fir.load %[[ARG0]] : !fir.ref<!fir.class<!fir.ptr<!fir.array<?x!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>>>>
+! CHECK: %[[COORD:.*]] = fir.coordinate_of %[[ARG0_LOAD]], %{{.*}} : (!fir.class<!fir.ptr<!fir.array<?x!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>>>, i64) -> !fir.ref<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>
+! CHECK: %[[TDESC:.*]] = fir.box_tdesc %[[ARG0_LOAD]] : (!fir.class<!fir.ptr<!fir.array<?x!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>>>) -> !fir.tdesc<none>
+! CHECK: %[[SELECTOR:.*]] = fir.embox %[[COORD]] tdesc %[[TDESC]] : (!fir.ref<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>, !fir.tdesc<none>) -> !fir.class<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>
+! CHECK: fir.select_type %[[SELECTOR]] : !fir.class<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>
+! CHECK-SAME: [#fir.type_is<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>, ^[[TYPE_IS_BLK:.*]], #fir.class_is<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>, ^[[CLASS_IS_BLK:.*]], unit, ^[[DEFAULT_BLK:.*]]]
+! CHECK: ^[[TYPE_IS_BLK]]
+! CHECK: ^[[CLASS_IS_BLK]]
+! CHECK: ^[[DEFAULT_BLK]]
+
+  subroutine select_type4(a)
+    class(p1), intent(in) :: a
+    select type(a)
+    type is(p3(8))
+      print*, 'type is p3(8)'
+    type is(p3(4))
+      print*, 'type is p3(4)'
+    class is (p1)
+      print*, 'class is p1'
+    end select
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMselect_type_lower_testPselect_type4(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.class<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>> {fir.bindc_name = "a"})
+! CHECK: fir.select_type %[[ARG0]] : !fir.class<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>
+! CHECK-SAME: [#fir.type_is<!fir.type<_QMselect_type_lower_testTp3K8{a:i32,b:i32,r:f64}>>, ^[[P3_8:.*]], #fir.type_is<!fir.type<_QMselect_type_lower_testTp3K4{a:i32,b:i32,r:f32}>>, ^[[P3_4:.*]], #fir.class_is<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>, ^[[P1:.*]]]
+! CHECK: ^[[P3_8]]
+! CHECK: ^[[P3_4]]
+! CHECK: ^[[P1]]
+
+  subroutine select_type5(a)
+    class(*), intent(in) :: a
+
+    select type (x => a)
+    type is (integer(1))
+      print*, 'type is integer(1)'
+    type is (integer(4))
+      print*, 'type is integer(4)'
+    type is (real(4))
+      print*, 'type is real'
+    type is (logical)
+      print*, 'type is logical'
+    class default
+      print*,'default'
+    end select
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMselect_type_lower_testPselect_type5(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.class<none> {fir.bindc_name = "a"})
+! CHECK: fir.select_type %[[ARG0]] : !fir.class<none>
+! CHECK-SAME: [#fir.type_is<i8>, ^[[I8_BLK:.*]], #fir.type_is<i32>, ^[[I32_BLK:.*]], #fir.type_is<f32>, ^[[F32_BLK:.*]], #fir.type_is<!fir.logical<4>>, ^[[LOG_BLK:.*]], unit, ^[[DEFAULT_BLK:.*]]]
+! CHECK: ^[[I8_BLK]]
+! CHECK: ^[[I32_BLK]]
+! CHECK: ^[[F32_BLK]]
+! CHECK: ^[[LOG_BLK]]
+! CHECK: ^[[DEFAULT_BLK]]
+
+
+  subroutine select_type6(a)
+    class(*), intent(out) :: a
+
+    select type(a)
+    type is (integer)
+      a = 100
+    type is (real)
+      a = 2.0
+    class default
+      stop 'error'
+    end select
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMselect_type_lower_testPselect_type6(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.class<none> {fir.bindc_name = "a"})
+
+! CHECK: fir.select_type %[[ARG0]] : !fir.class<none> [#fir.type_is<i32>, ^[[INT_BLK:.*]], #fir.type_is<f32>, ^[[REAL_BLK:.*]], unit, ^[[DEFAULT_BLK:.*]]]
+! CHECK: ^[[INT_BLK]]
+! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[ARG0]] : (!fir.class<none>) -> !fir.ref<i32>
+! CHECK: %[[C100:.*]] = arith.constant 100 : i32
+! CHECK: fir.store %[[C100]] to %[[BOX_ADDR]] : !fir.ref<i32>
+
+! CHECK: ^[[REAL_BLK]]:  // pred: ^bb0
+! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[ARG0]] : (!fir.class<none>) -> !fir.ref<f32>
+! CHECK: %[[C2:.*]] = arith.constant 2.000000e+00 : f32
+! CHECK: fir.store %[[C2]] to %[[BOX_ADDR]] : !fir.ref<f32>
+
+end module
+
+