Index: flang/include/flang/Lower/SymbolMap.h =================================================================== --- flang/include/flang/Lower/SymbolMap.h +++ flang/include/flang/Lower/SymbolMap.h @@ -316,10 +316,10 @@ } private: - /// Add `symbol` to the current map and bind a `box`. + /// Bind `box` to `symRef` in the symbol map. void makeSym(semantics::SymbolRef symRef, const SymbolBox &box, bool force = false) { - const auto *sym = &symRef.get().GetUltimate(); + auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate(); if (force) symbolMapStack.back().erase(sym); assert(box && "cannot add an undefined symbol box"); Index: flang/include/flang/Semantics/symbol.h =================================================================== --- flang/include/flang/Semantics/symbol.h +++ flang/include/flang/Semantics/symbol.h @@ -707,6 +707,9 @@ }, details_); } + bool HasLocalLocality() const { + return test(Flag::LocalityLocal) || test(Flag::LocalityLocalInit); + } bool operator==(const Symbol &that) const { return this == &that; } bool operator!=(const Symbol &that) const { return !(*this == that); } Index: flang/lib/Lower/Bridge.cpp =================================================================== --- flang/lib/Lower/Bridge.cpp +++ flang/lib/Lower/Bridge.cpp @@ -95,6 +95,11 @@ return fir::unwrapRefType(loopVariable.getType()); } + bool hasLocalitySpecs() const { + return !localSymList.empty() || !localInitSymList.empty() || + !sharedSymList.empty(); + } + // Data members common to both structured and unstructured loops. 
const Fortran::semantics::Symbol &loopVariableSym; const Fortran::lower::SomeExpr *lowerExpr; @@ -102,6 +107,7 @@ const Fortran::lower::SomeExpr *stepExpr; const Fortran::lower::SomeExpr *maskExpr = nullptr; bool isUnordered; // do concurrent, forall + llvm::SmallVector localSymList; llvm::SmallVector localInitSymList; llvm::SmallVector sharedSymList; mlir::Value loopVariable = nullptr; @@ -1514,6 +1520,10 @@ info.maskExpr = Fortran::semantics::GetExpr( std::get>(header.t)); for (const Fortran::parser::LocalitySpec &x : localityList) { + if (const auto *localList = + std::get_if(&x.u)) + for (const Fortran::parser::Name &x : localList->v) + info.localSymList.push_back(x.symbol); if (const auto *localInitList = std::get_if(&x.u)) for (const Fortran::parser::Name &x : localInitList->v) @@ -1522,12 +1532,38 @@ std::get_if(&x.u)) for (const Fortran::parser::Name &x : sharedList->v) info.sharedSymList.push_back(x.symbol); - if (std::get_if(&x.u)) - TODO(toLocation(), "do concurrent locality specs not implemented"); } return incrementLoopNestInfo; } + /// Create DO CONCURRENT construct symbol bindings and generate LOCAL_INIT + /// assignments. 
+  void handleLocalitySpecs(const IncrementLoopInfo &info) {
+    for (const Fortran::semantics::Symbol *sym : info.localSymList)
+      createHostAssociateVarClone(*sym);
+    for (const Fortran::semantics::Symbol *sym : info.localInitSymList) {
+      createHostAssociateVarClone(*sym);
+      const auto *hostDetails =
+          sym->detailsIf<Fortran::semantics::HostAssocDetails>();
+      assert(hostDetails && "missing locality spec host symbol");
+      const Fortran::semantics::Symbol *hostSym = &hostDetails->symbol();
+      Fortran::evaluate::ExpressionAnalyzer ea{bridge.getSemanticsContext()};
+      Fortran::evaluate::Assignment assign{
+          ea.Designate(Fortran::evaluate::DataRef{*sym}).value(),
+          ea.Designate(Fortran::evaluate::DataRef{*hostSym}).value()};
+      if (Fortran::semantics::IsPointer(*sym))
+        assign.u = Fortran::evaluate::Assignment::BoundsSpec{};
+      genAssignment(assign);
+    }
+    // SHARED symbols get no clone; they reuse the host symbol's binding.
+    for (const Fortran::semantics::Symbol *sym : info.sharedSymList) {
+      const auto *hostDetails =
+          sym->detailsIf<Fortran::semantics::HostAssocDetails>();
+      assert(hostDetails && "missing locality spec host symbol");
+      copySymbolBinding(hostDetails->symbol(), *sym);
+    }
+  }
+
   /// Generate FIR for a DO construct.
There are six variants: /// - unstructured infinite and while loops /// - structured and unstructured increment loops @@ -1656,25 +1692,6 @@ return builder->createRealConstant(loc, controlType, 1u); return builder->createIntegerConstant(loc, controlType, 1); // step }; - auto handleLocalitySpec = [&](IncrementLoopInfo &info) { - // Generate Local Init Assignments - for (const Fortran::semantics::Symbol *sym : info.localInitSymList) { - const auto *hostDetails = - sym->detailsIf(); - assert(hostDetails && "missing local_init variable host variable"); - const Fortran::semantics::Symbol &hostSym = hostDetails->symbol(); - (void)hostSym; - TODO(loc, "do concurrent locality specs not implemented"); - } - // Handle shared locality spec - for (const Fortran::semantics::Symbol *sym : info.sharedSymList) { - const auto *hostDetails = - sym->detailsIf(); - assert(hostDetails && "missing shared variable host variable"); - const Fortran::semantics::Symbol &hostSym = hostDetails->symbol(); - copySymbolBinding(hostSym, *sym); - } - }; for (IncrementLoopInfo &info : incrementLoopNestInfo) { info.loopVariable = genLoopVariableAddress(loc, info.loopVariableSym, info.isUnordered); @@ -1714,7 +1731,8 @@ /*withElseRegion=*/false); builder->setInsertionPointToStart(&ifOp.getThenRegion().front()); } - handleLocalitySpec(info); + if (info.hasLocalitySpecs()) + handleLocalitySpecs(info); continue; } @@ -1771,10 +1789,10 @@ if (&info != &incrementLoopNestInfo.back()) // not innermost startBlock(info.bodyBlock); // preheader block of enclosed dimension } - if (!info.localInitSymList.empty()) { + if (info.hasLocalitySpecs()) { mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint(); builder->setInsertionPointToStart(info.bodyBlock); - handleLocalitySpec(info); + handleLocalitySpecs(info); builder->restoreInsertionPoint(insertPt); } } Index: flang/lib/Lower/ConvertExpr.cpp =================================================================== --- flang/lib/Lower/ConvertExpr.cpp +++ 
flang/lib/Lower/ConvertExpr.cpp @@ -590,13 +590,15 @@ // associations. template const Fortran::semantics::Symbol &getFirstSym(const A &obj) { - return obj.GetFirstSymbol().GetUltimate(); + const Fortran::semantics::Symbol &sym = obj.GetFirstSymbol(); + return sym.HasLocalLocality() ? sym : sym.GetUltimate(); } // Helper to get the ultimate last symbol. template const Fortran::semantics::Symbol &getLastSym(const A &obj) { - return obj.GetLastSymbol().GetUltimate(); + const Fortran::semantics::Symbol &sym = obj.GetLastSymbol(); + return sym.HasLocalLocality() ? sym : sym.GetUltimate(); } // Return true if TRANSPOSE should be lowered without a runtime call. Index: flang/lib/Lower/SymbolMap.cpp =================================================================== --- flang/lib/Lower/SymbolMap.cpp +++ flang/lib/Lower/SymbolMap.cpp @@ -35,10 +35,10 @@ Fortran::lower::SymbolBox Fortran::lower::SymMap::lookupSymbol(Fortran::semantics::SymbolRef symRef) { - Fortran::semantics::SymbolRef sym = symRef.get().GetUltimate(); + auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate(); for (auto jmap = symbolMapStack.rbegin(), jend = symbolMapStack.rend(); jmap != jend; ++jmap) { - auto iter = jmap->find(&*sym); + auto iter = jmap->find(sym); if (iter != jmap->end()) return iter->second; } @@ -47,8 +47,9 @@ Fortran::lower::SymbolBox Fortran::lower::SymMap::shallowLookupSymbol( Fortran::semantics::SymbolRef symRef) { + auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate(); auto &map = symbolMapStack.back(); - auto iter = map.find(&symRef.get().GetUltimate()); + auto iter = map.find(sym); if (iter != map.end()) return iter->second; return SymbolBox::None{}; @@ -59,14 +60,14 @@ /// host-association in OpenMP code. Fortran::lower::SymbolBox Fortran::lower::SymMap::lookupOneLevelUpSymbol( Fortran::semantics::SymbolRef symRef) { - Fortran::semantics::SymbolRef sym = symRef.get().GetUltimate(); + auto *sym = symRef->HasLocalLocality() ? 
&*symRef : &symRef->GetUltimate(); auto jmap = symbolMapStack.rbegin(); auto jend = symbolMapStack.rend(); if (jmap == jend) return SymbolBox::None{}; // Skip one level in symbol map stack. for (++jmap; jmap != jend; ++jmap) { - auto iter = jmap->find(&*sym); + auto iter = jmap->find(sym); if (iter != jmap->end()) return iter->second; } Index: flang/test/Lower/loops.f90 =================================================================== --- flang/test/Lower/loops.f90 +++ flang/test/Lower/loops.f90 @@ -91,6 +91,71 @@ print '(" F:",X,F3.1,A,I2)', x, ' -', xsum end subroutine loop_test +! CHECK-LABEL: c.func @_QPlis +subroutine lis(n) + ! CHECK-DAG: fir.alloca i32 {bindc_name = "m"} + ! CHECK-DAG: fir.alloca i32 {bindc_name = "j"} + ! CHECK-DAG: fir.alloca i32 {bindc_name = "i"} + ! CHECK-DAG: fir.alloca i8 {bindc_name = "i"} + ! CHECK-DAG: fir.alloca i32 {bindc_name = "j", uniq_name = "_QFlisEj"} + ! CHECK-DAG: fir.alloca i32 {bindc_name = "k", uniq_name = "_QFlisEk"} + ! CHECK-DAG: fir.alloca !fir.box>> {bindc_name = "p", uniq_name = "_QFlisEp"} + ! CHECK-DAG: fir.alloca !fir.array, %{{.*}}, %{{.*}}, %{{.*}} {bindc_name = "a", fir.target, uniq_name = "_QFlisEa"} + ! CHECK-DAG: fir.alloca !fir.array, %{{.*}}, %{{.*}} {bindc_name = "r", uniq_name = "_QFlisEr"} + ! CHECK-DAG: fir.alloca !fir.array, %{{.*}}, %{{.*}} {bindc_name = "s", uniq_name = "_QFlisEs"} + ! CHECK-DAG: fir.alloca !fir.array, %{{.*}}, %{{.*}} {bindc_name = "t", uniq_name = "_QFlisEt"} + integer, target :: a(n,n,n) ! operand via p + integer :: r(n,n) ! result, unspecified locality + integer :: s(n,n) ! shared locality + integer :: t(n,n) ! local locality + integer, pointer :: p(:,:,:) ! local_init locality + + p => a + ! CHECK: fir.do_loop %arg1 = %c0{{.*}} to %{{.*}} step %c1{{.*}} unordered iter_args(%arg2 = %{{.*}}) -> (!fir.array) { + ! CHECK: fir.do_loop %arg3 = %c0{{.*}} to %{{.*}} step %c1{{.*}} unordered iter_args(%arg4 = %arg2) -> (!fir.array) { + ! CHECK: } + ! CHECK: } + r = 0 + + ! 
CHECK: fir.do_loop %arg1 = %{{.*}} to %{{.*}} step %{{.*}} unordered { + ! CHECK: fir.do_loop %arg2 = %{{.*}} to %{{.*}} step %c1{{.*}} iter_args(%arg3 = %{{.*}}) -> (index, i32) { + ! CHECK: } + ! CHECK: } + do concurrent (integer(kind=1)::i=n:1:-1) + do j = 1,n + a(i,j,:) = 2*(i+j) + s(i,j) = -i-j + enddo + enddo + + ! CHECK: fir.do_loop %arg1 = %{{.*}} to %{{.*}} step %c1{{.*}} unordered { + ! CHECK: fir.do_loop %arg2 = %{{.*}} to %{{.*}} step %c1{{.*}} unordered { + ! CHECK: fir.if %{{.*}} { + ! CHECK: %[[V_95:[0-9]+]] = fir.alloca !fir.array, %{{.*}}, %{{.*}} {bindc_name = "t", pinned, uniq_name = "_QFlisEt"} + ! CHECK: %[[V_96:[0-9]+]] = fir.alloca !fir.box>> {bindc_name = "p", pinned, uniq_name = "_QFlisEp"} + ! CHECK: fir.store %{{.*}} to %[[V_96]] : !fir.ref>>> + ! CHECK: fir.do_loop %arg3 = %{{.*}} to %{{.*}} step %c1{{.*}} iter_args(%arg4 = %{{.*}}) -> (index, i32) { + ! CHECK: fir.do_loop %arg5 = %{{.*}} to %{{.*}} step %c1{{.*}} unordered { + ! CHECK: fir.load %[[V_96]] : !fir.ref>>> + ! CHECK: fir.convert %[[V_95]] : (!fir.ref>) -> !fir.ref> + ! CHECK: } + ! CHECK: } + ! CHECK: fir.convert %[[V_95]] : (!fir.ref>) -> !fir.ref> + ! CHECK: } + ! CHECK: } + ! CHECK: } + do concurrent (i=1:n,j=1:n,i.ne.j) local(t) local_init(p) shared(s) + do k=1,n + do concurrent (m=1:n) + t(k,m) = p(k,m,k) + enddo + enddo + r(i,j) = t(i,j) + s(i,j) + enddo + + print*, sum(r) ! n=6 -> 210 +end + ! CHECK-LABEL: print_nothing subroutine print_nothing(k1, k2) if (k1 > 0) then @@ -105,5 +170,6 @@ end call loop_test + call lis(6) call print_nothing(2, 2) end