diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2096,14 +2096,176 @@
   }
 
   void genFIR(const Fortran::parser::SelectTypeConstruct &selectTypeConstruct) {
-    setCurrentPositionAt(selectTypeConstruct);
-    TODO(toLocation(), "SelectTypeConstruct implementation");
-  }
-  void genFIR(const Fortran::parser::SelectTypeStmt &) {
-    TODO(toLocation(), "SelectTypeStmt implementation");
-  }
-  void genFIR(const Fortran::parser::TypeGuardStmt &) {
-    TODO(toLocation(), "TypeGuardStmt implementation");
+    mlir::Location loc = toLocation();
+    mlir::MLIRContext *context = builder->getContext();
+    Fortran::lower::StatementContext stmtCtx;
+    mlir::Value selector;
+    llvm::SmallVector<mlir::Attribute> attrList;
+    llvm::SmallVector<mlir::Block *> blockList;
+    unsigned typeGuardIdx = 0;
+    // Source position of the CLASS DEFAULT guard among all type guards, if
+    // any. fir.select_type requires the default (unit) case to be last,
+    // but Fortran allows CLASS DEFAULT to appear anywhere in the construct.
+    std::optional<unsigned> defaultGuardIdx;
+    bool hasLocalScope = false;
+
+    for (Fortran::lower::pft::Evaluation &eval :
+         getEval().getNestedEvaluations()) {
+      if (auto *selectTypeStmt =
+              eval.getIf<Fortran::parser::SelectTypeStmt>()) {
+        // Retrieve the selector
+        const auto &s = std::get<Fortran::parser::Selector>(selectTypeStmt->t);
+        if (const auto *v = std::get_if<Fortran::parser::Variable>(&s.u)) {
+          fir::ExtendedValue exv =
+              genExprAddr(*Fortran::semantics::GetExpr(*v), stmtCtx, &loc);
+          if (auto polyBox = exv.getBoxOf<fir::PolymorphicValue>()) {
+            auto elementType =
+                fir::dyn_cast_ptrEleTy(fir::getBase(exv).getType());
+            mlir::Type classTy = fir::ClassType::get(elementType);
+            llvm::SmallVector<mlir::Value> lenParams{};
+            selector = builder->create<fir::EmboxOp>(
+                loc, classTy, fir::getBase(exv), mlir::Value{}, mlir::Value{},
+                lenParams, polyBox->getTdesc());
+          } else {
+            selector = fir::getBase(exv);
+          }
+        } else {
+          fir::emitFatalError(
+              loc, "selector with expr not expected in select type statement");
+        }
+
+        // Going through the controlSuccessor first to create the
+        // fir.select_type operation.
+        mlir::Block *defaultBlock = nullptr;
+        unsigned guardPos = 0;
+        for (Fortran::lower::pft::Evaluation *e = eval.controlSuccessor; e;
+             e = e->controlSuccessor, ++guardPos) {
+          const auto *typeGuardStmt =
+              e->getIf<Fortran::parser::TypeGuardStmt>();
+          assert(typeGuardStmt && "expected a TypeGuardStmt");
+          const auto &guard = std::get<Fortran::parser::TypeGuardStmt::Guard>(
+              typeGuardStmt->t);
+          assert(e->block && "missing TypeGuardStmt block");
+          // CLASS DEFAULT
+          if (std::holds_alternative<Fortran::parser::Default>(guard.u)) {
+            defaultBlock = e->block;
+            defaultGuardIdx = guardPos;
+            continue;
+          }
+
+          blockList.push_back(e->block);
+          if (const auto *typeSpec =
+                  std::get_if<Fortran::parser::TypeSpec>(&guard.u)) {
+            // TYPE IS
+            assert(typeSpec->declTypeSpec && "TYPE IS type spec is null");
+            mlir::Type ty;
+            if (std::holds_alternative<Fortran::parser::IntrinsicTypeSpec>(
+                    typeSpec->u)) {
+              const Fortran::semantics::IntrinsicTypeSpec *intrinsic =
+                  typeSpec->declTypeSpec->AsIntrinsic();
+              // The kind of an intrinsic type spec must fold to a constant;
+              // fail loudly instead of reading an indeterminate value.
+              std::optional<std::int64_t> kindValue =
+                  Fortran::evaluate::ToInt64(intrinsic->kind());
+              if (!kindValue)
+                fir::emitFatalError(loc, "TYPE IS kind is not a constant");
+              int kind = static_cast<int>(*kindValue);
+              llvm::SmallVector<std::int64_t> params;
+              // CHARACTER is the only intrinsic type with length parameters.
+              if (intrinsic->category() ==
+                  Fortran::common::TypeCategory::Character)
+                TODO(loc, "typeSpec with length parameters");
+              ty = genType(intrinsic->category(), kind, params);
+            } else {
+              const Fortran::semantics::DerivedTypeSpec *derived =
+                  typeSpec->declTypeSpec->AsDerived();
+              ty = genType(*derived);
+            }
+            attrList.push_back(fir::ExactTypeAttr::get(ty));
+          } else if (const auto *derived =
+                         std::get_if<Fortran::parser::DerivedTypeSpec>(
+                             &guard.u)) {
+            // CLASS IS
+            assert(derived->derivedTypeSpec && "derived type spec is null");
+            mlir::Type ty = genType(*(derived->derivedTypeSpec));
+            attrList.push_back(fir::SubclassAttr::get(ty));
+          }
+        }
+        // fir.select_type requires the default case to be the last one.
+        if (defaultBlock) {
+          attrList.push_back(mlir::UnitAttr::get(context));
+          blockList.push_back(defaultBlock);
+        }
+        builder->create<fir::SelectTypeOp>(loc, selector, attrList, blockList);
+      } else if (auto *typeGuardStmt =
+                     eval.getIf<Fortran::parser::TypeGuardStmt>()) {
+        // Map the type guard local symbol for the selector to a more precise
+        // typed entity in the TypeGuardStmt when necessary.
+        const auto &guard =
+            std::get<Fortran::parser::TypeGuardStmt::Guard>(typeGuardStmt->t);
+        if (hasLocalScope)
+          localSymbols.popScope();
+        localSymbols.pushScope();
+        hasLocalScope = true;
+        // Type guards are visited in source order, but CLASS DEFAULT was
+        // moved to the end of the fir.select_type case list: translate the
+        // source position of this guard into its case index.
+        unsigned caseIdx = typeGuardIdx;
+        if (defaultGuardIdx) {
+          if (typeGuardIdx == *defaultGuardIdx)
+            caseIdx = static_cast<unsigned>(attrList.size() - 1);
+          else if (typeGuardIdx > *defaultGuardIdx)
+            caseIdx = typeGuardIdx - 1;
+        }
+        assert(caseIdx < attrList.size() && "TypeGuard attribute missing");
+        mlir::Attribute typeGuardAttr = attrList[caseIdx];
+        mlir::Block *typeGuardBlock = blockList[caseIdx];
+        const Fortran::semantics::Scope &guardScope =
+            bridge.getSemanticsContext().FindScope(eval.position);
+        mlir::OpBuilder::InsertPoint crtInsPt = builder->saveInsertionPoint();
+        builder->setInsertionPointToStart(typeGuardBlock);
+
+        auto addAssocEntitySymbol = [&](mlir::Value val) {
+          for (auto &symbol : guardScope.GetSymbols()) {
+            if (symbol->GetUltimate()
+                    .detailsIf<Fortran::semantics::AssocEntityDetails>()) {
+              localSymbols.addSymbol(symbol, val);
+            }
+          }
+        };
+
+        if (std::holds_alternative<Fortran::parser::Default>(guard.u)) {
+          // CLASS DEFAULT
+          addAssocEntitySymbol(selector);
+        } else if (const auto *typeSpec =
+                       std::get_if<Fortran::parser::TypeSpec>(&guard.u)) {
+          // TYPE IS
+          auto attr = typeGuardAttr.dyn_cast<fir::ExactTypeAttr>();
+          assert(attr && "expected a fir.type_is attribute");
+          mlir::Value exactValue;
+          if (std::holds_alternative<Fortran::parser::IntrinsicTypeSpec>(
+                  typeSpec->u)) {
+            exactValue = builder->create<fir::BoxAddrOp>(
+                loc, fir::ReferenceType::get(attr.getType()), selector);
+          } else {
+            exactValue = builder->create<fir::ConvertOp>(
+                loc, fir::BoxType::get(attr.getType()), selector);
+          }
+          addAssocEntitySymbol(exactValue);
+        } else if (std::holds_alternative<Fortran::parser::DerivedTypeSpec>(
+                       guard.u)) {
+          // CLASS IS
+          addAssocEntitySymbol(selector);
+        }
+        builder->restoreInsertionPoint(crtInsPt);
+        ++typeGuardIdx;
+      } else if (eval.getIf<Fortran::parser::EndSelectStmt>()) {
+        if (hasLocalScope)
+          localSymbols.popScope();
+        stmtCtx.finalize();
+      }
+      genFIR(eval);
+    }
   }
 
   //===--------------------------------------------------------------------===//
@@ -2751,6 +2913,8 @@
   void genFIR(const Fortran::parser::IfThenStmt &) {}          // nop
   void genFIR(const Fortran::parser::NonLabelDoStmt &) {}      // nop
   void genFIR(const Fortran::parser::OmpEndLoopDirective &) {} // nop
+  void genFIR(const Fortran::parser::SelectTypeStmt &) {}      // nop
+  void genFIR(const Fortran::parser::TypeGuardStmt &) {}       // nop
 
   void genFIR(const Fortran::parser::NamelistStmt &) {
     TODO(toLocation(), "NamelistStmt lowering");
diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -4162,7 +4162,7 @@
   mlir::Value convertElementForUpdate(mlir::Location loc, mlir::Type eleTy,
                                       mlir::Value origVal) {
     if (auto origEleTy = fir::dyn_cast_ptrEleTy(origVal.getType()))
-      if (origEleTy.isa<fir::BoxType>()) {
+      if (origEleTy.isa<fir::BaseBoxType>()) {
         // If origVal is a box variable, load it so it is in the value domain.
         origVal = builder.create<fir::LoadOp>(loc, origVal);
       }
diff --git a/flang/lib/Lower/SymbolMap.cpp b/flang/lib/Lower/SymbolMap.cpp
--- a/flang/lib/Lower/SymbolMap.cpp
+++ b/flang/lib/Lower/SymbolMap.cpp
@@ -26,6 +26,7 @@
         [&](const fir::CharArrayBoxValue &v) { makeSym(sym, v, force); },
         [&](const fir::BoxValue &v) { makeSym(sym, v, force); },
         [&](const fir::MutableBoxValue &v) { makeSym(sym, v, force); },
+        [&](const fir::PolymorphicValue &v) { makeSym(sym, v, force); },
         [](auto) {
           llvm::report_fatal_error("value not added to symbol table");
         });
diff --git a/flang/test/Lower/select-type.f90 b/flang/test/Lower/select-type.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Lower/select-type.f90
@@ -0,0 +1,190 @@
+! RUN: bbc -polymorphic-type -emit-fir %s -o - | FileCheck %s
+
+module m1
+  type p1
+    integer :: a
+    integer :: b
+  end type
+
+  type, extends(p1) :: p2
+    integer :: c
+  end type
+
+  type, extends(p1) :: p3(k)
+    integer, kind :: k
+    real(k) :: r
+  end type
+
+contains
+
+  function get_class()
+    class(p1), pointer :: get_class
+  end function
+
+  subroutine select_type1(a)
+    class(p1), intent(in) :: a
+
+    select type (a)
+    type is (p1)
+      print*, 'type is p1'
+    class is (p1)
+      print*, 'class is p1'
+    class default
+      print*,'default'
+    end select
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMm1Pselect_type1(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.class<!fir.type<_QMm1Tp1{a:i32,b:i32}>> {fir.bindc_name = "a"})
+
+! CHECK: fir.select_type %[[ARG0]] : !fir.class<!fir.type<_QMm1Tp1{a:i32,b:i32}>>
+! CHECK-SAME: [#fir.type_is<!fir.type<_QMm1Tp1{a:i32,b:i32}>>, ^[[TYPE_IS_BLOCK:.*]], #fir.class_is<!fir.type<_QMm1Tp1{a:i32,b:i32}>>, ^[[CLASS_IS_BLOCK:.*]], unit, ^[[DEFAULT_BLOCK:.*]]]
+! CHECK: ^[[TYPE_IS_BLOCK]]
+! CHECK: ^[[CLASS_IS_BLOCK]]
+! CHECK: ^[[DEFAULT_BLOCK]]
+
+  subroutine select_type2()
+    select type (a => get_class())
+    type is (p1)
+      print*, 'type is p1'
+    class is (p1)
+      print*, 'class is p1'
+    class default
+      print*,'default'
+    end select
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMm1Pselect_type2()
+! CHECK: %[[RESULT:.*]] = fir.alloca !fir.class<!fir.ptr<!fir.type<_QMm1Tp1{a:i32,b:i32}>>> {bindc_name = ".result"}
+! CHECK: %[[FCTCALL:.*]] = fir.call @_QMm1Pget_class() : () -> !fir.class<!fir.ptr<!fir.type<_QMm1Tp1{a:i32,b:i32}>>>
+! CHECK: fir.save_result %[[FCTCALL]] to %[[RESULT]] : !fir.class<!fir.ptr<!fir.type<_QMm1Tp1{a:i32,b:i32}>>>, !fir.ref<!fir.class<!fir.ptr<!fir.type<_QMm1Tp1{a:i32,b:i32}>>>>
+! CHECK: %[[SELECTOR:.*]] = fir.load %[[RESULT]] : !fir.ref<!fir.class<!fir.ptr<!fir.type<_QMm1Tp1{a:i32,b:i32}>>>>
+! CHECK: fir.select_type %[[SELECTOR]] : !fir.class<!fir.ptr<!fir.type<_QMm1Tp1{a:i32,b:i32}>>>
+! CHECK-SAME: [#fir.type_is<!fir.type<_QMm1Tp1{a:i32,b:i32}>>, ^[[TYPE_IS_BLOCK:.*]], #fir.class_is<!fir.type<_QMm1Tp1{a:i32,b:i32}>>, ^[[CLASS_IS_BLOCK:.*]], unit, ^[[DEFAULT_BLOCK:.*]]]
+! CHECK: ^[[TYPE_IS_BLOCK]]
+! CHECK: ^[[CLASS_IS_BLOCK]]
+! CHECK: ^[[DEFAULT_BLOCK]]
+
+  subroutine select_type3(a)
+    class(p1), pointer, intent(in) :: a(:)
+
+    select type (x => a(1))
+    type is (p1)
+      print*, 'type is p1'
+    class is (p1)
+      print*, 'class is p1'
+    class default
+      print*,'default'
+    end select
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMm1Pselect_type3(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.class<!fir.ptr<!fir.array<?x!fir.type<_QMm1Tp1{a:i32,b:i32}>>>>> {fir.bindc_name = "a"})
+! CHECK: %[[ARG0_LOAD:.*]] = fir.load %[[ARG0]] : !fir.ref<!fir.class<!fir.ptr<!fir.array<?x!fir.type<_QMm1Tp1{a:i32,b:i32}>>>>>
+! CHECK: %[[COORD:.*]] = fir.coordinate_of %[[ARG0_LOAD]], %{{.*}} : (!fir.class<!fir.ptr<!fir.array<?x!fir.type<_QMm1Tp1{a:i32,b:i32}>>>>, i64) -> !fir.ref<!fir.type<_QMm1Tp1{a:i32,b:i32}>>
+! CHECK: %[[TDESC:.*]] = fir.box_tdesc %[[ARG0_LOAD]] : (!fir.class<!fir.ptr<!fir.array<?x!fir.type<_QMm1Tp1{a:i32,b:i32}>>>>) -> !fir.tdesc<none>
+! CHECK: %[[SELECTOR:.*]] = fir.embox %[[COORD]] tdesc %[[TDESC]] : (!fir.ref<!fir.type<_QMm1Tp1{a:i32,b:i32}>>, !fir.tdesc<none>) -> !fir.class<!fir.type<_QMm1Tp1{a:i32,b:i32}>>
+! CHECK: fir.select_type %[[SELECTOR]] : !fir.class<!fir.type<_QMm1Tp1{a:i32,b:i32}>>
+! CHECK-SAME: [#fir.type_is<!fir.type<_QMm1Tp1{a:i32,b:i32}>>, ^[[TYPE_IS_BLOCK:.*]], #fir.class_is<!fir.type<_QMm1Tp1{a:i32,b:i32}>>, ^[[CLASS_IS_BLOCK:.*]], unit, ^[[DEFAULT_BLOCK:.*]]]
+! CHECK: ^[[TYPE_IS_BLOCK]]
+! CHECK: ^[[CLASS_IS_BLOCK]]
+! CHECK: ^[[DEFAULT_BLOCK]]
+
+  subroutine select_type4(a)
+    class(p1), intent(in) :: a
+    select type(a)
+    type is(p3(8))
+      print*, 'type is p3(8)'
+    type is(p3(4))
+      print*, 'type is p3(4)'
+    class is (p1)
+      print*, 'class is p1'
+    end select
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMm1Pselect_type4(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.class<!fir.type<_QMm1Tp1{a:i32,b:i32}>> {fir.bindc_name = "a"})
+! CHECK: fir.select_type %[[ARG0]] : !fir.class<!fir.type<_QMm1Tp1{a:i32,b:i32}>>
+! CHECK-SAME: [#fir.type_is<!fir.type<_QMm1Tp3K8{a:i32,b:i32,r:f64}>>, ^[[P3_8:.*]], #fir.type_is<!fir.type<_QMm1Tp3K4{a:i32,b:i32,r:f32}>>, ^[[P3_4:.*]], #fir.class_is<!fir.type<_QMm1Tp1{a:i32,b:i32}>>, ^[[P1:.*]]]
+! CHECK: ^[[P3_8]]
+! CHECK: ^[[P3_4]]
+! CHECK: ^[[P1]]
+
+  subroutine select_type5(a)
+    class(*), intent(in) :: a
+
+    select type (x => a)
+    type is (integer(1))
+      print*, 'type is integer(1)'
+    type is (integer(4))
+      print*, 'type is integer(4)'
+    type is (real(4))
+      print*, 'type is real'
+    type is (logical)
+      print*, 'type is logical'
+    class default
+      print*,'default'
+    end select
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMm1Pselect_type5(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.class<none> {fir.bindc_name = "a"})
+! CHECK: fir.select_type %[[ARG0]] : !fir.class<none>
+! CHECK-SAME: [#fir.type_is<i8>, ^[[I8_BLK:.*]], #fir.type_is<i32>, ^[[I32_BLK:.*]], #fir.type_is<f32>, ^[[F32_BLK:.*]], #fir.type_is<!fir.logical<4>>, ^[[LOG_BLK:.*]], unit, ^[[DEFAULT:.*]]]
+! CHECK: ^[[I8_BLK]]
+! CHECK: ^[[I32_BLK]]
+! CHECK: ^[[F32_BLK]]
+! CHECK: ^[[LOG_BLK]]
+! CHECK: ^[[DEFAULT]]
+
+
+  subroutine select_type6(a)
+    class(*), intent(out) :: a
+
+    select type(a)
+    type is (integer)
+      a = 100
+    type is (real)
+      a = 2.0
+    class default
+      stop 'error'
+    end select
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMm1Pselect_type6(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.class<none> {fir.bindc_name = "a"})
+
+! CHECK: fir.select_type %[[ARG0]] : !fir.class<none> [#fir.type_is<i32>, ^[[INT_BLK:.*]], #fir.type_is<f32>, ^[[REAL_BLK:.*]], unit, ^[[DEFAULT_BLK:.*]]]
+! CHECK: ^[[INT_BLK]]
+! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[ARG0]] : (!fir.class<none>) -> !fir.ref<i32>
+! CHECK: %[[C100:.*]] = arith.constant 100 : i32
+! CHECK: fir.store %[[C100]] to %[[BOX_ADDR]] : !fir.ref<i32>
+
+! CHECK: ^[[REAL_BLK]]:  // pred: ^bb0
+! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[ARG0]] : (!fir.class<none>) -> !fir.ref<f32>
+! CHECK: %[[C2:.*]] = arith.constant 2.000000e+00 : f32
+! CHECK: fir.store %[[C2]] to %[[BOX_ADDR]] : !fir.ref<f32>
+
+  subroutine select_type7(a)
+    class(*), intent(out) :: a
+
+    select type(a)
+    class default
+      print*, 'default'
+    type is (integer)
+      a = 1
+    end select
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMm1Pselect_type7(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.class<none> {fir.bindc_name = "a"})
+! CHECK: fir.select_type %[[ARG0]] : !fir.class<none> [#fir.type_is<i32>, ^[[INT_BLK:.*]], unit, ^[[DEFAULT_BLK:.*]]]
+! CHECK: ^[[DEFAULT_BLK]]
+! CHECK: ^[[INT_BLK]]
+! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[ARG0]] : (!fir.class<none>) -> !fir.ref<i32>
+! CHECK: %[[C1:.*]] = arith.constant 1 : i32
+! CHECK: fir.store %[[C1]] to %[[BOX_ADDR]] : !fir.ref<i32>
+
+end module
+
+