diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -1220,7 +1220,30 @@ } void genFIR(const Fortran::parser::AssociateConstruct &) { - TODO(toLocation(), "AssociateConstruct lowering"); + Fortran::lower::StatementContext stmtCtx; + Fortran::lower::pft::Evaluation &eval = getEval(); + for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations()) { + if (auto *stmt = e.getIf()) { + if (eval.lowerAsUnstructured()) + maybeStartBlock(e.block); + localSymbols.pushScope(); + for (const Fortran::parser::Association &assoc : + std::get>(stmt->t)) { + Fortran::semantics::Symbol &sym = + *std::get(assoc.t).symbol; + const Fortran::lower::SomeExpr &selector = + *sym.get().expr(); + localSymbols.addSymbol(sym, genAssociateSelector(selector, stmtCtx)); + } + } else if (e.getIf()) { + if (eval.lowerAsUnstructured()) + maybeStartBlock(e.block); + stmtCtx.finalize(); + localSymbols.popScope(); + } else { + genFIR(e); + } + } } void genFIR(const Fortran::parser::BlockConstruct &blockConstruct) { @@ -1571,10 +1594,6 @@ genFIRBranch(getEval().controlSuccessor->block); } - void genFIR(const Fortran::parser::AssociateStmt &) { - TODO(toLocation(), "AssociateStmt lowering"); - } - void genFIR(const Fortran::parser::CaseStmt &) { TODO(toLocation(), "CaseStmt lowering"); } @@ -1587,10 +1606,6 @@ TODO(toLocation(), "ElseStmt lowering"); } - void genFIR(const Fortran::parser::EndAssociateStmt &) { - TODO(toLocation(), "EndAssociateStmt lowering"); - } - void genFIR(const Fortran::parser::EndDoStmt &) { TODO(toLocation(), "EndDoStmt lowering"); } @@ -1604,7 +1619,9 @@ } // Nop statements - No code, or code is generated at the construct level. + void genFIR(const Fortran::parser::AssociateStmt &) {} // nop void genFIR(const Fortran::parser::ContinueStmt &) {} // nop + void genFIR(const Fortran::parser::EndAssociateStmt &) {} // nop void genFIR(const Fortran::parser::EndFunctionStmt &) {} // nop void genFIR(const Fortran::parser::EndIfStmt &) {} // nop void genFIR(const Fortran::parser::EndSubroutineStmt &) {} // nop diff --git a/flang/lib/Lower/Mangler.cpp b/flang/lib/Lower/Mangler.cpp --- a/flang/lib/Lower/Mangler.cpp +++ b/flang/lib/Lower/Mangler.cpp @@ -21,31 +21,36 @@ // recursively build the vector of module scopes static void moduleNames(const Fortran::semantics::Scope &scope, - llvm::SmallVector &result) { - if (scope.IsTopLevel()) { + llvm::SmallVector &result) { + if (scope.IsTopLevel()) return; - } moduleNames(scope.parent(), result); if (scope.kind() == Fortran::semantics::Scope::Kind::Module) - if (auto *symbol = scope.symbol()) + if (const Fortran::semantics::Symbol *symbol = scope.symbol()) result.emplace_back(toStringRef(symbol->name())); } -static llvm::SmallVector +static llvm::SmallVector moduleNames(const Fortran::semantics::Symbol &symbol) { - const auto &scope = symbol.owner(); - llvm::SmallVector result; + const Fortran::semantics::Scope &scope = symbol.owner(); + llvm::SmallVector result; moduleNames(scope, result); return result; } static llvm::Optional hostName(const Fortran::semantics::Symbol &symbol) { - const auto &scope = symbol.owner(); + const Fortran::semantics::Scope &scope = symbol.owner(); if (scope.kind() == Fortran::semantics::Scope::Kind::Subprogram) { assert(scope.symbol() && "subprogram scope must have a symbol"); - return {toStringRef(scope.symbol()->name())}; + return toStringRef(scope.symbol()->name()); } + if (scope.kind() == Fortran::semantics::Scope::Kind::MainProgram) + // Do not use the main program name, if any, because it may lead to name + // collision with procedures with the same name in other compilation units + // (technically illegal, but all compilers are able to compile and link + // properly these programs). + return llvm::StringRef(""); return {}; } diff --git a/flang/test/Lower/associate-construct.f90 b/flang/test/Lower/associate-construct.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/associate-construct.f90 @@ -0,0 +1,93 @@ +! RUN: bbc -emit-fir -o - %s | FileCheck %s + +! CHECK-LABEL: func @_QQmain +program p + ! CHECK-DAG: [[I:%[0-9]+]] = fir.alloca i32 {{{.*}}uniq_name = "_QFEi"} + ! CHECK-DAG: [[N:%[0-9]+]] = fir.alloca i32 {{{.*}}uniq_name = "_QFEn"} + ! CHECK: [[T:%[0-9]+]] = fir.address_of(@_QFEt) : !fir.ref> + integer :: n, foo, t(3) + ! CHECK: [[N]] + ! CHECK-COUNT-3: fir.coordinate_of [[T]] + n = 100; t(1) = 111; t(2) = 222; t(3) = 333 + ! CHECK: fir.load [[N]] + ! CHECK: addi {{.*}} %c5 + ! CHECK: fir.store %{{[0-9]*}} to [[B:%[0-9]+]] + ! CHECK: [[C:%[0-9]+]] = fir.coordinate_of [[T]] + ! CHECK: fir.call @_QPfoo + ! CHECK: fir.store %{{[0-9]*}} to [[D:%[0-9]+]] + associate (a => n, b => n+5, c => t(2), d => foo(7)) + ! CHECK: fir.load [[N]] + ! CHECK: addi %{{[0-9]*}}, %c1 + ! CHECK: fir.store %{{[0-9]*}} to [[N]] + a = a + 1 + ! CHECK: fir.load [[C]] + ! CHECK: addi %{{[0-9]*}}, %c1 + ! CHECK: fir.store %{{[0-9]*}} to [[C]] + c = c + 1 + ! CHECK: fir.load [[N]] + ! CHECK: addi %{{[0-9]*}}, %c1 + ! CHECK: fir.store %{{[0-9]*}} to [[N]] + n = n + 1 + ! CHECK: fir.load [[N]] + ! CHECK: fir.embox [[T]] + ! CHECK: fir.load [[N]] + ! CHECK: fir.load [[B]] + ! CHECK: fir.load [[C]] + ! CHECK: fir.load [[D]] + print*, n, t, a, b, c, d ! expect: 102 111 223 333 102 105 223 7 + end associate + + call nest + + associate (x=>i) + ! CHECK: [[IVAL:%[0-9]+]] = fir.load [[I]] : !fir.ref + ! CHECK: [[TWO:%.*]] = arith.constant 2 : i32 + ! CHECK: arith.cmpi eq, [[IVAL]], [[TWO]] : i32 + ! CHECK: ^bb + if (x==2) goto 9 + ! CHECK: [[IVAL:%[0-9]+]] = fir.load [[I]] : !fir.ref + ! CHECK: [[THREE:%.*]] = arith.constant 3 : i32 + ! CHECK: arith.cmpi eq, [[IVAL]], [[THREE]] : i32 + ! CHECK: ^bb + ! CHECK: fir.call @_FortranAStopStatementText + ! CHECK: fir.unreachable + ! CHECK: ^bb + if (x==3) stop 'Halt' + ! CHECK: fir.call @_FortranAioOutputAscii + print*, "ok" + 9 end associate + end + + ! CHECK-LABEL: func @_QPfoo + integer function foo(x) + integer x + integer, save :: i = 0 + i = i + x + foo = i + end function foo + + ! CHECK-LABEL: func @_QPnest( + subroutine nest + integer, parameter :: n = 10 + integer :: a(5), b(n) + associate (s => sequence(size(a))) + a = s + associate(t => sequence(n)) + b = t + ! CHECK: cond_br %{{.*}}, [[BB1:\^bb[0-9]]], [[BB2:\^bb[0-9]]] + ! CHECK: [[BB1]]: + ! CHECK: br [[BB3:\^bb[0-9]]] + ! CHECK: [[BB2]]: + if (a(1) > b(1)) goto 9 + end associate + a = a * a + end associate + ! CHECK: br [[BB3]] + ! CHECK: [[BB3]]: + 9 print *, sum(a), sum(b) ! expect: 55 55 + contains + function sequence(n) + integer sequence(n) + sequence = [(i,i=1,n)] + end function + end subroutine nest