Index: include/clang/AST/OpenMPClause.h =================================================================== --- include/clang/AST/OpenMPClause.h +++ include/clang/AST/OpenMPClause.h @@ -2774,6 +2774,495 @@ } }; +/// \brief Struct that defines common infrastructure to handle mappable +/// expressions used in OpenMP clauses. +class OMPClauseMappableExprCommon { +public: + // \brief Class that represents a component of a mappable expression. E.g. + // for an expression S.a, the first component is a declaration reference + // expression associated with 'S' and the second is a member expression + // associated with the field declaration 'a'. If the expression is an array + // subscript it may not have any associated declaration. In that case the + // associated declaration is set to nullptr. + class MappableComponent { + // \brief Expression associated with the component. + Expr *AssociatedExpression = nullptr; + // \brief Declaration associated with the declaration. If the component does + // not have a declaration (e.g. array subscripts or section), this is set to + // nullptr. + ValueDecl *AssociatedDeclaration = nullptr; + + public: + explicit MappableComponent() {} + explicit MappableComponent(Expr *AssociatedExpression, + ValueDecl *AssociatedDeclaration) + : AssociatedExpression(AssociatedExpression), + AssociatedDeclaration( + AssociatedDeclaration + ? cast(AssociatedDeclaration->getCanonicalDecl()) + : nullptr) {} + + Expr *getAssociatedExpression() const { return AssociatedExpression; } + ValueDecl *getAssociatedDeclaration() const { + return AssociatedDeclaration; + } + }; + + // \brief List of components of an expression. This first one is the whole + // expression and the last one is the base expression. + typedef SmallVector MappableExprComponentList; + typedef ArrayRef MappableExprComponentListRef; + + // \brief List of all component lists associated to the same base declaration. + // E.g. if both 'S.a' and 'S.b' are a mappable expressions, each will have + // their component list but the same base declaration 'S'. + typedef SmallVector MappableExprComponentLists; + typedef ArrayRef MappableExprComponentListsRef; + +protected: + // \brief Return the total number of elements in a list of component lists. + static unsigned + getComponentsTotalNumber(MappableExprComponentListsRef ComponentLists); + + // \brief Return the total number of elements in a list of declarations. All + // declarations are expected to be canonical. + static unsigned + getUniqueDeclarationsTotalNumber(ArrayRef Declarations); +}; + +/// \brief This represents clauses with a list of expressions that are mappable. +/// Examples of these clauses are 'map' in +/// '#pragma omp target [enter|exit] [data]...' directives, and 'to' and 'from +/// in '#pragma omp target update...' directives. +template +class OMPMappableExprListClause : public OMPVarListClause, + public OMPClauseMappableExprCommon { + friend class OMPClauseReader; + + /// \brief Number of unique declarations in this clause. + unsigned NumUniqueDeclarations; + + /// \brief Number of component lists in this clause. + unsigned NumComponentLists; + + /// \brief Total number of components in this clause. + unsigned NumComponents; + +protected: + /// \brief Get the unique declarations that are in the trailing objects of the + /// class. + MutableArrayRef getUniqueDeclsRef() { + return MutableArrayRef( + static_cast(this)->template getTrailingObjects(), + NumUniqueDeclarations); + } + + /// \brief Get the unique declarations that are in the trailing objects of the + /// class. + ArrayRef getUniqueDeclsRef() const { + return ArrayRef( + static_cast(this) + ->template getTrailingObjects(), + NumUniqueDeclarations); + } + + /// \brief Set the unique declarations that are in the trailing objects of the + /// class. + void setUniqueDecls(ArrayRef UDs) { + assert(UDs.size() == NumUniqueDeclarations && + "Unexpected amount of unique declarations."); + std::copy(UDs.begin(), UDs.end(), getUniqueDeclsRef().begin()); + } + + /// \brief Get the number of lists per declaration that are in the trailing + /// objects of the class. + MutableArrayRef getDeclNumListsRef() { + return MutableArrayRef( + static_cast(this)->template getTrailingObjects(), + NumUniqueDeclarations); + } + + /// \brief Get the number of lists per declaration that are in the trailing + /// objects of the class. + ArrayRef getDeclNumListsRef() const { + return ArrayRef( + static_cast(this)->template getTrailingObjects(), + NumUniqueDeclarations); + } + + /// \brief Set the number of lists per declaration that are in the trailing + /// objects of the class. + void setDeclNumLists(ArrayRef DNLs) { + assert(DNLs.size() == NumUniqueDeclarations && + "Unexpected amount of list numbers."); + std::copy(DNLs.begin(), DNLs.end(), getDeclNumListsRef().begin()); + } + + /// \brief Get the cumulative component lists sizes that are in the trailing + /// objects of the class. They are appended after the number of lists. + MutableArrayRef getComponentListSizesRef() { + return MutableArrayRef( + static_cast(this)->template getTrailingObjects() + + NumUniqueDeclarations, + NumComponentLists); + } + + /// \brief Get the cumulative component lists sizes that are in the trailing + /// objects of the class. They are appended after the number of lists. + ArrayRef getComponentListSizesRef() const { + return ArrayRef( + static_cast(this)->template getTrailingObjects() + + NumUniqueDeclarations, + NumComponentLists); + } + + /// \brief Set the cumulative component lists sizes that are in the trailing + /// objects of the class. + void setComponentListSizes(ArrayRef CLSs) { + assert(CLSs.size() == NumComponentLists && + "Unexpected amount of component lists."); + std::copy(CLSs.begin(), CLSs.end(), getComponentListSizesRef().begin()); + } + + /// \brief Get the components that are in the trailing objects of the class. + MutableArrayRef getComponentsRef() { + return MutableArrayRef( + static_cast(this) + ->template getTrailingObjects(), + NumComponents); + } + + /// \brief Get the components that are in the trailing objects of the class. + ArrayRef getComponentsRef() const { + return ArrayRef( + static_cast(this) + ->template getTrailingObjects(), + NumComponents); + } + + /// \brief Set the components that are in the trailing objects of the class. + /// This requires the list sizes so that it can also fill the original + /// expressions, which are the first component of each list. + void setComponents(ArrayRef Components, + ArrayRef CLSs) { + assert(Components.size() == NumComponents && + "Unexpected amount of component lists."); + assert(CLSs.size() == NumComponentLists && + "Unexpected amount of list sizes."); + std::copy(Components.begin(), Components.end(), getComponentsRef().begin()); + } + + /// \brief Fill the clause information from the list of declarations and + /// associated component lists. + void setClauseInfo(ArrayRef Declarations, + MappableExprComponentListsRef ComponentLists) { + // Perform some checks to make sure the data sizes are consistent with the + // information available when the clause was created. + assert(getUniqueDeclarationsTotalNumber(Declarations) == + NumUniqueDeclarations && + "Unexpected number of mappable expression info entries!"); + assert(getComponentsTotalNumber(ComponentLists) == NumComponents && + "Unexpected total number of components!"); + assert(Declarations.size() == ComponentLists.size() && + "Declaration and component lists size is not consistent!"); + assert(Declarations.size() == NumComponentLists && + "Unexpected declaration and component lists size!"); + + // Organize the components by declaration and retrieve the original + // expression. Original expressions are always the first component of the + // mappable component list. + llvm::DenseMap> + ComponentListMap; + { + auto CI = ComponentLists.begin(); + for (auto DI = Declarations.begin(), DE = Declarations.end(); DI != DE; + ++DI, ++CI) { + assert(!CI->empty() && "Invalid component list!"); + ComponentListMap[*DI].push_back(*CI); + } + } + + // Iterators of the target storage. + auto UniqueDeclarations = getUniqueDeclsRef(); + auto UDI = UniqueDeclarations.begin(); + + auto DeclNumLists = getDeclNumListsRef(); + auto DNLI = DeclNumLists.begin(); + + auto ComponentListSizes = getComponentListSizesRef(); + auto CLSI = ComponentListSizes.begin(); + + auto Components = getComponentsRef(); + auto CI = Components.begin(); + + // Variable to compute the accumulation of the number of components. + unsigned PrevSize = 0u; + + // Scan all the declarations and associated component lists. + for (auto &M : ComponentListMap) { + // The declaration. + auto *D = M.first; + // The component lists. + auto CL = M.second; + + // Initialize the entry. + *UDI = D; + ++UDI; + + *DNLI = CL.size(); + ++DNLI; + + // Obtain the cumulative sizes and concatenate all the components in the + // reserved storage. + for (auto C : CL) { + // Accumulate with the previous size. + PrevSize += C.size(); + + // Save the size. + *CLSI = PrevSize; + ++CLSI; + + // Append components after the current components iterator. + CI = std::copy(C.begin(), C.end(), CI); + } + } + } + + /// \brief Build a clause for \a NumUniqueDeclarations declarations, \a + /// NumComponentLists total component lists, and \a NumComponents total + /// components. + /// + /// \param K Kind of the clause. + /// \param StartLoc Starting location of the clause (the clause keyword). + /// \param LParenLoc Location of '('. + /// \param EndLoc Ending location of the clause. + /// \param NumVars Number of expressions listed in the clause. + /// \param NumUniqueDeclarations Number of unique base declarations in this + /// clause. + /// \param NumComponentLists Number of component lists in this clause - one + /// list for each expression in the clause. + /// \param NumComponents Total number of expression components in the clause. + /// + OMPMappableExprListClause(OpenMPClauseKind K, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation EndLoc, + unsigned NumVars, unsigned NumUniqueDeclarations, + unsigned NumComponentLists, unsigned NumComponents) + : OMPVarListClause(K, StartLoc, LParenLoc, EndLoc, NumVars), + NumUniqueDeclarations(NumUniqueDeclarations), + NumComponentLists(NumComponentLists), NumComponents(NumComponents) {} + +public: + /// \brief Return the number of unique base declarations in this clause. + unsigned getUniqueDeclarationsNum() const { return NumUniqueDeclarations; } + /// \brief Return the number of lists derived from the clause expressions. + unsigned getTotalComponentListNum() const { return NumComponentLists; } + /// \brief Return the total number of components in all lists derived from the + /// clause. + unsigned getTotalComponentsNum() const { return NumComponents; } + + /// \brief Iterator that browse the components by lists. It also allows + /// browsing components of a single declaration. + class const_component_lists_iterator + : public llvm::iterator_adaptor_base< + const_component_lists_iterator, + MappableExprComponentListRef::const_iterator, + std::forward_iterator_tag, MappableComponent, ptrdiff_t, + MappableComponent, MappableComponent> { + // The declaration the iterator currently refers to. + ArrayRef::iterator DeclCur; + + // The list number associated with the current declaration. + ArrayRef::iterator NumListsCur; + + // Remaining lists for the current declaration. + unsigned RemainingLists; + + // The cumulative size of the previous list, or zero if there is no previous + // list. + unsigned PrevListSize; + + // The cumulative sizes of the current list - it will delimit the remaining + // range of interest. + ArrayRef::const_iterator ListSizeCur; + ArrayRef::const_iterator ListSizeEnd; + + // Iterator to the end of the components storage. + MappableExprComponentListRef::const_iterator End; + + public: + /// \brief Construct an iterator that scans all lists. + explicit const_component_lists_iterator( + ArrayRef UniqueDecls, ArrayRef DeclsListNum, + ArrayRef CumulativeListSizes, + MappableExprComponentListRef Components) + : const_component_lists_iterator::iterator_adaptor_base( + Components.begin()), + DeclCur(UniqueDecls.begin()), NumListsCur(DeclsListNum.begin()), + RemainingLists(0u), PrevListSize(0u), + ListSizeCur(CumulativeListSizes.begin()), + ListSizeEnd(CumulativeListSizes.end()), End(Components.end()) { + assert(UniqueDecls.size() == DeclsListNum.size() && + "Inconsistent number of declarations and list sizes!"); + if (!DeclsListNum.empty()) + RemainingLists = *NumListsCur; + } + + /// \brief Construct an iterator that scan lists for a given declaration \a + /// Declaration. + explicit const_component_lists_iterator( + const ValueDecl *Declaration, ArrayRef UniqueDecls, + ArrayRef DeclsListNum, ArrayRef CumulativeListSizes, + MappableExprComponentListRef Components) + : const_component_lists_iterator(UniqueDecls, DeclsListNum, + CumulativeListSizes, Components) { + + // Look for the desired declaration. While we are looking for it, we + // update the state so that we know the component where a given list + // starts. + for (; DeclCur != UniqueDecls.end(); ++DeclCur, ++NumListsCur) { + if (*DeclCur == Declaration) + break; + + assert(*NumListsCur > 0 && "No lists associated with declaration??"); + + // Skip the lists associated with the current declaration, but save the + // last list size that was skipped. + std::advance(ListSizeCur, *NumListsCur - 1); + PrevListSize = *ListSizeCur; + ++ListSizeCur; + } + + // If we didn't find any declaration, advance the iterator to after the + // last component and set remaining lists to zero. + if (ListSizeCur == CumulativeListSizes.end()) { + this->I = End; + RemainingLists = 0u; + return; + } + + // Set the remaining lists with the total number of lists of the current + // declaration. + RemainingLists = *NumListsCur; + + // Adjust the list size end iterator to the end of the relevant range. + ListSizeEnd = ListSizeCur; + std::advance(ListSizeEnd, RemainingLists); + + // Given that the list sizes are cumulative, the index of the component + // that start the list is the size of the previous list. + std::advance(this->I, PrevListSize); + } + + // Return the array with the current list. The sizes are cumulative, so the + // array size is the difference between the current size and previous one. + std::pair + operator*() const { + assert(ListSizeCur != ListSizeEnd && "Invalid iterator!"); + return std::make_pair( + *DeclCur, + MappableExprComponentListRef(&*this->I, *ListSizeCur - PrevListSize)); + } + std::pair + operator->() const { + return **this; + } + + // Skip the components of the current list. + const_component_lists_iterator &operator++() { + assert(ListSizeCur != ListSizeEnd && RemainingLists && + "Invalid iterator!"); + + // If we don't have more lists just skip all the components. Otherwise, + // advance the iterator by the number of components in the current list. + if (std::next(ListSizeCur) == ListSizeEnd) { + this->I = End; + RemainingLists = 0; + } else { + std::advance(this->I, *ListSizeCur - PrevListSize); + PrevListSize = *ListSizeCur; + + // We are done with a declaration, move to the next one. + if (!(--RemainingLists)) { + ++DeclCur; + ++NumListsCur; + RemainingLists = *NumListsCur; + assert(RemainingLists && "No lists in the following declaration??"); + } + } + + ++ListSizeCur; + return *this; + } + }; + + typedef llvm::iterator_range + const_component_lists_range; + + /// \brief Iterators for all component lists. + const_component_lists_iterator component_lists_begin() const { + return const_component_lists_iterator( + getUniqueDeclsRef(), getDeclNumListsRef(), getComponentListSizesRef(), + getComponentsRef()); + } + const_component_lists_iterator component_lists_end() const { + return const_component_lists_iterator( + ArrayRef(), ArrayRef(), ArrayRef(), + MappableExprComponentListRef(getComponentsRef().end(), + getComponentsRef().end())); + } + const_component_lists_range component_lists() const { + return {component_lists_begin(), component_lists_end()}; + } + + /// \brief Iterators for component lists associated with the provided + /// declaration. + const_component_lists_iterator + decl_component_lists_begin(const ValueDecl *VD) const { + return const_component_lists_iterator( + VD, getUniqueDeclsRef(), getDeclNumListsRef(), + getComponentListSizesRef(), getComponentsRef()); + } + const_component_lists_iterator decl_component_lists_end() const { + return component_lists_end(); + } + const_component_lists_range decl_component_lists(const ValueDecl *VD) const { + return {decl_component_lists_begin(VD), decl_component_lists_end()}; + } + + /// Iterators to access all the declarations, number of lists, list sizes, and + /// components. + typedef ArrayRef::iterator const_all_decls_iterator; + typedef llvm::iterator_range const_all_decls_range; + const_all_decls_range all_decls() const { + auto A = getUniqueDeclsRef(); + return const_all_decls_range(A.begin(), A.end()); + } + + typedef ArrayRef::iterator const_all_num_lists_iterator; + typedef llvm::iterator_range + const_all_num_lists_range; + const_all_num_lists_range all_num_lists() const { + auto A = getDeclNumListsRef(); + return const_all_num_lists_range(A.begin(), A.end()); + } + + typedef ArrayRef::iterator const_all_lists_sizes_iterator; + typedef llvm::iterator_range + const_all_lists_sizes_range; + const_all_lists_sizes_range all_lists_sizes() const { + auto A = getComponentListSizesRef(); + return const_all_lists_sizes_range(A.begin(), A.end()); + } + + typedef ArrayRef::iterator const_all_components_iterator; + typedef llvm::iterator_range + const_all_components_range; + const_all_components_range all_components() const { + auto A = getComponentsRef(); + return const_all_components_range(A.begin(), A.end()); + } +}; + /// \brief This represents clause 'map' in the '#pragma omp ...' /// directives. /// @@ -2783,12 +3272,27 @@ /// In this example directive '#pragma omp target' has clause 'map' /// with the variables 'a' and 'b'. /// -class OMPMapClause final : public OMPVarListClause, - private llvm::TrailingObjects { +class OMPMapClause final : public OMPMappableExprListClause, + private llvm::TrailingObjects< + OMPMapClause, Expr *, ValueDecl *, unsigned, + OMPClauseMappableExprCommon::MappableComponent> { friend TrailingObjects; friend OMPVarListClause; + friend OMPMappableExprListClause; friend class OMPClauseReader; + /// Define the sizes of each trailing object array except the last one. This + /// is required for TrailingObjects to work properly. + size_t numTrailingObjects(OverloadToken) const { + return varlist_size(); + } + size_t numTrailingObjects(OverloadToken) const { + return getUniqueDeclarationsNum(); + } + size_t numTrailingObjects(OverloadToken) const { + return getUniqueDeclarationsNum() + getTotalComponentListNum(); + } + /// \brief Map type modifier for the 'map' clause. OpenMPMapClauseKind MapTypeModifier; /// \brief Map type for the 'map' clause. @@ -2821,7 +3325,9 @@ /// \brief Set colon location. void setColonLoc(SourceLocation Loc) { ColonLoc = Loc; } - /// \brief Build clause with number of variables \a N. + /// \brief Build a clause for \a NumVars listed expressions, \a + /// NumUniqueDeclarations declarations, \a NumComponentLists total component + /// lists, and \a NumComponents total expression components. /// /// \param MapTypeModifier Map type modifier. /// \param MapType Map type. @@ -2829,25 +3335,37 @@ /// \param MapLoc Location of the map type. /// \param StartLoc Starting location of the clause. /// \param EndLoc Ending location of the clause. - /// \param N Number of the variables in the clause. + /// \param NumVars Number of expressions listed in this clause. + /// \param NumUniqueDeclarations Number of unique base declarations in this + /// clause. + /// \param NumComponentLists Number of component lists in this clause. + /// \param NumComponents Total number of expression components in the clause. /// explicit OMPMapClause(OpenMPMapClauseKind MapTypeModifier, OpenMPMapClauseKind MapType, bool MapTypeIsImplicit, SourceLocation MapLoc, SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc, - unsigned N) - : OMPVarListClause(OMPC_map, StartLoc, LParenLoc, EndLoc, - N), + unsigned NumVars, unsigned NumUniqueDeclarations, + unsigned NumComponentLists, unsigned NumComponents) + : OMPMappableExprListClause(OMPC_map, StartLoc, LParenLoc, EndLoc, + NumVars, NumUniqueDeclarations, + NumComponentLists, NumComponents), MapTypeModifier(MapTypeModifier), MapType(MapType), MapTypeIsImplicit(MapTypeIsImplicit), MapLoc(MapLoc) {} /// \brief Build an empty clause. /// - /// \param N Number of variables. - /// - explicit OMPMapClause(unsigned N) - : OMPVarListClause(OMPC_map, SourceLocation(), - SourceLocation(), SourceLocation(), N), + /// \param NumVars Number of expressions listed in this clause. + /// \param NumUniqueDeclarations Number of unique base declarations in this + /// clause. + /// \param NumComponentLists Number of component lists in this clause. + /// \param NumComponents Total number of expression components in the clause. + /// + explicit OMPMapClause(unsigned NumVars, unsigned NumUniqueDeclarations, + unsigned NumComponentLists, unsigned NumComponents) + : OMPMappableExprListClause( + OMPC_map, SourceLocation(), SourceLocation(), SourceLocation(), + NumVars, NumUniqueDeclarations, NumComponentLists, NumComponents), MapTypeModifier(OMPC_MAP_unknown), MapType(OMPC_MAP_unknown), MapTypeIsImplicit(false), MapLoc() {} @@ -2857,7 +3375,9 @@ /// \param C AST context. /// \param StartLoc Starting location of the clause. /// \param EndLoc Ending location of the clause. - /// \param VL List of references to the variables. + /// \param Vars The original expression used in the clause. + /// \param Declarations Declarations used in the clause. + /// \param ComponentLists Component lists used in the clause. /// \param TypeModifier Map type modifier. /// \param Type Map type. /// \param TypeIsImplicit Map type is inferred implicitly. @@ -2865,16 +3385,28 @@ /// static OMPMapClause *Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc, - ArrayRef VL, + ArrayRef Vars, + ArrayRef Declarations, + MappableExprComponentListsRef ComponentLists, OpenMPMapClauseKind TypeModifier, OpenMPMapClauseKind Type, bool TypeIsImplicit, SourceLocation TypeLoc); - /// \brief Creates an empty clause with the place for \a N variables. + /// \brief Creates an empty clause with the place for for \a NumVars original + /// expressions, \a NumUniqueDeclarations declarations, \NumComponentLists + /// lists, and \a NumComponents expression components. /// /// \param C AST context. - /// \param N The number of variables. - /// - static OMPMapClause *CreateEmpty(const ASTContext &C, unsigned N); + /// \param NumVars Number of expressions listed in the clause. + /// \param NumUniqueDeclarations Number of unique base declarations in this + /// clause. + /// \param NumComponentLists Number of unique base declarations in this + /// clause. + /// \param NumComponents Total number of expression components in the clause. + /// + static OMPMapClause *CreateEmpty(const ASTContext &C, unsigned NumVars, + unsigned NumUniqueDeclarations, + unsigned NumComponentLists, + unsigned NumComponents); /// \brief Fetches mapping kind for the clause. OpenMPMapClauseKind getMapType() const LLVM_READONLY { return MapType; } Index: lib/AST/OpenMPClause.cpp =================================================================== --- lib/AST/OpenMPClause.cpp +++ lib/AST/OpenMPClause.cpp @@ -530,25 +530,78 @@ return new (Mem) OMPDependClause(N); } -OMPMapClause *OMPMapClause::Create(const ASTContext &C, SourceLocation StartLoc, - SourceLocation LParenLoc, - SourceLocation EndLoc, ArrayRef VL, - OpenMPMapClauseKind TypeModifier, - OpenMPMapClauseKind Type, - bool TypeIsImplicit, - SourceLocation TypeLoc) { - void *Mem = C.Allocate(totalSizeToAlloc(VL.size())); - OMPMapClause *Clause = - new (Mem) OMPMapClause(TypeModifier, Type, TypeIsImplicit, TypeLoc, - StartLoc, LParenLoc, EndLoc, VL.size()); - Clause->setVarRefs(VL); +unsigned OMPClauseMappableExprCommon::getComponentsTotalNumber( + MappableExprComponentListsRef ComponentLists) { + unsigned TotalNum = 0u; + for (auto &C : ComponentLists) + TotalNum += C.size(); + return TotalNum; +} + +unsigned OMPClauseMappableExprCommon::getUniqueDeclarationsTotalNumber( + ArrayRef Declarations) { + unsigned TotalNum = 0u; + llvm::SmallPtrSet Cache; + for (auto *D : Declarations) { + const ValueDecl *VD = D ? cast(D->getCanonicalDecl()) : nullptr; + if (Cache.count(VD)) + continue; + ++TotalNum; + Cache.insert(VD); + } + return TotalNum; +} + +OMPMapClause * +OMPMapClause::Create(const ASTContext &C, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation EndLoc, + ArrayRef Vars, ArrayRef Declarations, + MappableExprComponentListsRef ComponentLists, + OpenMPMapClauseKind TypeModifier, OpenMPMapClauseKind Type, + bool TypeIsImplicit, SourceLocation TypeLoc) { + + unsigned NumVars = Vars.size(); + unsigned NumUniqueDeclarations = + getUniqueDeclarationsTotalNumber(Declarations); + unsigned NumComponentLists = ComponentLists.size(); + unsigned NumComponents = getComponentsTotalNumber(ComponentLists); + + // We need to allocate: + // NumVars x Expr* - we have an original list expression for each clause list + // entry. + // NumUniqueDeclarations x ValueDecl* - unique base declarations associated + // with each component list. + // (NumUniqueDeclarations + NumComponentLists) x unsigned - we specify the + // number of lists for each unique declaration and the size of each component + // list. + // NumComponents x MappableComponent - the total of all the components in all + // the lists. + void *Mem = C.Allocate( + totalSizeToAlloc( + NumVars, NumUniqueDeclarations, + NumUniqueDeclarations + NumComponentLists, NumComponents)); + OMPMapClause *Clause = new (Mem) OMPMapClause( + TypeModifier, Type, TypeIsImplicit, TypeLoc, StartLoc, LParenLoc, EndLoc, + NumVars, NumUniqueDeclarations, NumComponentLists, NumComponents); + + Clause->setVarRefs(Vars); + Clause->setClauseInfo(Declarations, ComponentLists); Clause->setMapTypeModifier(TypeModifier); Clause->setMapType(Type); Clause->setMapLoc(TypeLoc); return Clause; } -OMPMapClause *OMPMapClause::CreateEmpty(const ASTContext &C, unsigned N) { - void *Mem = C.Allocate(totalSizeToAlloc(N)); - return new (Mem) OMPMapClause(N); +OMPMapClause *OMPMapClause::CreateEmpty(const ASTContext &C, unsigned NumVars, + unsigned NumUniqueDeclarations, + unsigned NumComponentLists, + unsigned NumComponents) { + void *Mem = C.Allocate( + totalSizeToAlloc( + NumVars, NumUniqueDeclarations, + NumUniqueDeclarations + NumComponentLists, NumComponents)); + return new (Mem) OMPMapClause(NumVars, NumUniqueDeclarations, + NumComponentLists, NumComponents); } Index: lib/Sema/SemaOpenMP.cpp =================================================================== --- lib/Sema/SemaOpenMP.cpp +++ lib/Sema/SemaOpenMP.cpp @@ -81,8 +81,6 @@ }; private: - typedef SmallVector MapInfo; - struct DSAInfo { OpenMPClauseKind Attributes; Expr *RefExpr; @@ -92,14 +90,16 @@ typedef llvm::DenseMap AlignedMapTy; typedef std::pair LCDeclInfo; typedef llvm::DenseMap LoopControlVariablesMapTy; - typedef llvm::DenseMap MappedDeclsTy; + typedef llvm::DenseMap< + ValueDecl *, OMPClauseMappableExprCommon::MappableExprComponentLists> + MappedExprComponentsTy; typedef llvm::StringMap> CriticalsWithHintsTy; struct SharingMapTy { DeclSAMapTy SharingMap; AlignedMapTy AlignedMap; - MappedDeclsTy MappedDecls; + MappedExprComponentsTy MappedExprComponents; LoopControlVariablesMapTy LCVMap; DefaultDataSharingAttributes DefaultAttr; SourceLocation DefaultAttrLoc; @@ -340,11 +340,12 @@ Scope *getCurScope() { return Stack.back().CurScope; } SourceLocation getConstructLoc() { return Stack.back().ConstructLoc; } - // Do the check specified in MapInfoCheck and return true if any issue is - // found. - template - bool checkMapInfoForVar(ValueDecl *VD, bool CurrentRegionOnly, - MapInfoCheck Check) { + // Do the check specified in \a Check to all component lists and return true + // if any issue is found. + bool checkMappableExprComponentListsForDecl( + ValueDecl *VD, bool CurrentRegionOnly, + const llvm::function_ref &Check) { auto SI = Stack.rbegin(); auto SE = Stack.rend(); @@ -358,21 +359,26 @@ } for (; SI != SE; ++SI) { - auto MI = SI->MappedDecls.find(VD); - if (MI != SI->MappedDecls.end()) { - for (Expr *E : MI->second) { - if (Check(E)) + auto MI = SI->MappedExprComponents.find(VD); + if (MI != SI->MappedExprComponents.end()) + for (auto &L : MI->second) + if (Check(L)) return true; - } - } } return false; } - void addExprToVarMapInfo(ValueDecl *VD, Expr *E) { - if (Stack.size() > 1) { - Stack.back().MappedDecls[VD].push_back(E); - } + // Create a new mappable expression component list associated with a given + // declaration and initialize it with the provided list of components. + void addMappableExpressionComponents( + ValueDecl *VD, + OMPClauseMappableExprCommon::MappableExprComponentListRef Components) { + assert(Stack.size() > 1 && + "Not expecting to retrieve components from a empty stack!"); + auto &MEC = Stack.back().MappedExprComponents[VD]; + // Create new entry and append the new components there. + MEC.resize(MEC.size() + 1); + MEC.back().append(Components.begin(), Components.end()); } }; bool isParallelOrTaskRegion(OpenMPDirectiveKind DKind) { @@ -7554,8 +7560,10 @@ // A list item cannot appear in both a map clause and a data-sharing // attribute clause on the same construct if (DSAStack->getCurrentDirective() == OMPD_target) { - if(DSAStack->checkMapInfoForVar(VD, /* CurrentRegionOnly = */ true, - [&](Expr *RE) -> bool {return true;})) { + if (DSAStack->checkMappableExprComponentListsForDecl( + VD, /* CurrentRegionOnly = */ true, + [&](OMPClauseMappableExprCommon::MappableExprComponentListRef) + -> bool { return true; })) { Diag(ELoc, diag::err_omp_variable_in_map_and_dsa) << getOpenMPClauseName(OMPC_private) << getOpenMPDirectiveName(DSAStack->getCurrentDirective()); @@ -7799,8 +7807,10 @@ // A list item cannot appear in both a map clause and a data-sharing // attribute clause on the same construct if (CurrDir == OMPD_target) { - if(DSAStack->checkMapInfoForVar(VD, /* CurrentRegionOnly = */ true, - [&](Expr *RE) -> bool {return true;})) { + if (DSAStack->checkMappableExprComponentListsForDecl( + VD, /* CurrentRegionOnly = */ true, + [&](OMPClauseMappableExprCommon::MappableExprComponentListRef) + -> bool { return true; })) { Diag(ELoc, diag::err_omp_variable_in_map_and_dsa) << getOpenMPClauseName(OMPC_firstprivate) << getOpenMPDirectiveName(DSAStack->getCurrentDirective()); @@ -9706,8 +9716,11 @@ // Return the expression of the base of the map clause or null if it cannot // be determined and do all the necessary checks to see if the expression is -// valid as a standalone map clause expression. -static Expr *CheckMapClauseExpressionBase(Sema &SemaRef, Expr *E) { +// valid as a standalone map clause expression. In the process, record all the +// components of the expression. +static Expr *CheckMapClauseExpressionBase( + Sema &SemaRef, Expr *E, + OMPClauseMappableExprCommon::MappableExprComponentList &CurComponents) { SourceLocation ELoc = E->getExprLoc(); SourceRange ERange = E->getSourceRange(); @@ -9765,6 +9778,10 @@ // section before that. AllowUnitySizeArraySection = false; AllowWholeSizeArraySection = false; + + // Record the component. + CurComponents.push_back(OMPClauseMappableExprCommon::MappableComponent( + CurE, CurE->getDecl())); continue; } @@ -9819,6 +9836,10 @@ // AllowUnitySizeArraySection = false; AllowWholeSizeArraySection = false; + + // Record the component. + CurComponents.push_back( + OMPClauseMappableExprCommon::MappableComponent(CurE, FD)); continue; } @@ -9837,6 +9858,10 @@ if (CheckArrayExpressionDoesNotReferToWholeSize(SemaRef, CurE, E->getType())) AllowWholeSizeArraySection = false; + + // Record the component - we don't have any declaration associated. + CurComponents.push_back( + OMPClauseMappableExprCommon::MappableComponent(CurE, nullptr)); continue; } @@ -9882,6 +9907,10 @@ << CurE->getSourceRange(); break; } + + // Record the component - we don't have any declaration associated. + CurComponents.push_back( + OMPClauseMappableExprCommon::MappableComponent(CurE, nullptr)); continue; } @@ -9897,57 +9926,11 @@ // Return true if expression E associated with value VD has conflicts with other // map information. -static bool CheckMapConflicts(Sema &SemaRef, DSAStackTy *DSAS, ValueDecl *VD, - Expr *E, bool CurrentRegionOnly) { +static bool CheckMapConflicts( + Sema &SemaRef, DSAStackTy *DSAS, ValueDecl *VD, Expr *E, + bool CurrentRegionOnly, + OMPClauseMappableExprCommon::MappableExprComponentListRef CurComponents) { assert(VD && E); - - // Types used to organize the components of a valid map clause. - typedef std::pair MapExpressionComponent; - typedef SmallVector MapExpressionComponents; - - // Helper to extract the components in the map clause expression E and store - // them into MEC. This assumes that E is a valid map clause expression, i.e. - // it has already passed the single clause checks. - auto ExtractMapExpressionComponents = [](Expr *TE, - MapExpressionComponents &MEC) { - while (true) { - TE = TE->IgnoreParenImpCasts(); - - if (auto *CurE = dyn_cast(TE)) { - MEC.push_back( - MapExpressionComponent(CurE, cast(CurE->getDecl()))); - break; - } - - if (auto *CurE = dyn_cast(TE)) { - auto *BaseE = CurE->getBase()->IgnoreParenImpCasts(); - - MEC.push_back(MapExpressionComponent( - CurE, cast(CurE->getMemberDecl()))); - if (isa(BaseE)) - break; - - TE = BaseE; - continue; - } - - if (auto *CurE = dyn_cast(TE)) { - MEC.push_back(MapExpressionComponent(CurE, nullptr)); - TE = CurE->getBase()->IgnoreParenImpCasts(); - continue; - } - - if (auto *CurE = dyn_cast(TE)) { - MEC.push_back(MapExpressionComponent(CurE, nullptr)); - TE = CurE->getBase()->IgnoreParenImpCasts(); - continue; - } - - llvm_unreachable( - "Expecting only valid map clause expressions at this point!"); - } - }; - SourceLocation ELoc = E->getExprLoc(); SourceRange ERange = E->getSourceRange(); @@ -9955,26 +9938,27 @@ // the expression under test with the components of the expressions that are // already in the stack. - MapExpressionComponents CurComponents; - ExtractMapExpressionComponents(E, CurComponents); - assert(!CurComponents.empty() && "Map clause expression with no components!"); - assert(CurComponents.back().second == VD && + assert(CurComponents.back().getAssociatedDeclaration() == VD && "Map clause expression with unexpected base!"); // Variables to help detecting enclosing problems in data environment nests. bool IsEnclosedByDataEnvironmentExpr = false; - Expr *EnclosingExpr = nullptr; + const Expr *EnclosingExpr = nullptr; + + bool FoundError = DSAS->checkMappableExprComponentListsForDecl( + VD, CurrentRegionOnly, + [&](OMPClauseMappableExprCommon::MappableExprComponentListRef + StackComponents) -> bool { - bool FoundError = - DSAS->checkMapInfoForVar(VD, CurrentRegionOnly, [&](Expr *RE) -> bool { - MapExpressionComponents StackComponents; - ExtractMapExpressionComponents(RE, StackComponents); assert(!StackComponents.empty() && "Map clause expression with no components!"); - assert(StackComponents.back().second == VD && + assert(StackComponents.back().getAssociatedDeclaration() == VD && "Map clause expression with unexpected base!"); + // The whole expression in the stack. + auto *RE = StackComponents.front().getAssociatedExpression(); + // Expressions must start from the same base. Here we detect at which // point both expressions diverge from each other and see if we can // detect if the memory referred to both expressions is contiguous and @@ -9988,25 +9972,27 @@ // OpenMP 4.5 [2.15.5.1, map Clause, Restrictions, p.3] // At most one list item can be an array item derived from a given // variable in map clauses of the same construct. - if (CurrentRegionOnly && (isa(CI->first) || - isa(CI->first)) && - (isa(SI->first) || - isa(SI->first))) { - SemaRef.Diag(CI->first->getExprLoc(), + if (CurrentRegionOnly && + (isa(CI->getAssociatedExpression()) || + isa(CI->getAssociatedExpression())) && + (isa(SI->getAssociatedExpression()) || + isa(SI->getAssociatedExpression()))) { + SemaRef.Diag(CI->getAssociatedExpression()->getExprLoc(), diag::err_omp_multiple_array_items_in_map_clause) - << CI->first->getSourceRange(); - ; - SemaRef.Diag(SI->first->getExprLoc(), diag::note_used_here) - << SI->first->getSourceRange(); + << CI->getAssociatedExpression()->getSourceRange(); + SemaRef.Diag(SI->getAssociatedExpression()->getExprLoc(), + diag::note_used_here) + << SI->getAssociatedExpression()->getSourceRange(); return true; } // Do both expressions have the same kind? - if (CI->first->getStmtClass() != SI->first->getStmtClass()) + if (CI->getAssociatedExpression()->getStmtClass() != + SI->getAssociatedExpression()->getStmtClass()) break; // Are we dealing with different variables/fields? - if (CI->second != SI->second) + if (CI->getAssociatedDeclaration() != SI->getAssociatedDeclaration()) break; } @@ -10030,14 +10016,15 @@ } } - QualType DerivedType = std::prev(CI)->first->getType(); - SourceLocation DerivedLoc = std::prev(CI)->first->getExprLoc(); + QualType DerivedType = + std::prev(CI)->getAssociatedDeclaration()->getType(); + SourceLocation DerivedLoc = + std::prev(CI)->getAssociatedExpression()->getExprLoc(); // OpenMP 4.5 [2.15.5.1, map Clause, Restrictions, C++, p.1] // If the type of a list item is a reference to a type T then the type // will be considered to be T for all purposes of this clause. - if (DerivedType->isReferenceType()) - DerivedType = DerivedType->getPointeeType(); + DerivedType = DerivedType.getNonReferenceType(); // OpenMP 4.5 [2.15.5.1, map Clause, Restrictions, C/C++, p.1] // A variable for which the type is pointer and an array section @@ -10079,7 +10066,7 @@ } // The current expression uses the same base as other expression in the - // data environment but does not contain it completelly. + // data environment but does not contain it completely. if (!CurrentRegionOnly && SI != SE) EnclosingExpr = RE; @@ -10102,7 +10089,7 @@ // If a list item is an element of a structure, and a different element of // the structure has a corresponding list item in the device data environment // prior to a task encountering the construct associated with the map clause, - // then the list item must also have a correspnding list item in the device + // then the list item must also have a corresponding list item in the device // data environment prior to the task encountering the construct. // if (EnclosingExpr && !IsEnclosedByDataEnvironmentExpr) { @@ -10125,6 +10112,17 @@ SourceLocation LParenLoc, SourceLocation EndLoc) { SmallVector Vars; + // Keep track of the mappable components and base declarations in this clause. + // Each entry in the list is going to have a list of components associated. We + // record each set of the components so that we can build the clause later on. + // In the end we should have the same amount of declarations and component + // lists. + OMPClauseMappableExprCommon::MappableExprComponentLists ClauseComponents; + SmallVector ClauseBaseDeclarations; + + ClauseComponents.reserve(VarList.size()); + ClauseBaseDeclarations.reserve(VarList.size()); + for (auto &RE : VarList) { assert(RE && "Null expr in omp map"); if (isa(RE)) { @@ -10153,25 +10151,29 @@ continue; } - // Obtain the array or member expression bases if required. - auto *BE = CheckMapClauseExpressionBase(*this, SimpleExpr); + OMPClauseMappableExprCommon::MappableExprComponentList CurComponents; + ValueDecl *CurDeclaration = nullptr; + + // Obtain the array or member expression bases if required. Also, fill the + // components array with all the components identified in the process. + auto *BE = CheckMapClauseExpressionBase(*this, SimpleExpr, CurComponents); if (!BE) continue; - // If the base is a reference to a variable, we rely on that variable for - // the following checks. If it is a 'this' expression we rely on the field. - ValueDecl *D = nullptr; - if (auto *DRE = dyn_cast(BE)) { - D = DRE->getDecl(); - } else { - auto *ME = cast(BE); - assert(isa(ME->getBase()) && "Unexpected expression!"); - D = ME->getMemberDecl(); - } - assert(D && "Null decl on map clause."); + assert(!CurComponents.empty() && + "Invalid mappable expression information."); - auto *VD = dyn_cast(D); - auto *FD = dyn_cast(D); + // For the following checks, we rely on the base declaration which is + // expected to be associated with the last component. The declaration is + // expected to be a variable or a field (if 'this' is being mapped). + CurDeclaration = CurComponents.back().getAssociatedDeclaration(); + assert(CurDeclaration && "Null decl on map clause."); + assert( + CurDeclaration->isCanonicalDecl() && + "Expecting components to have associated only canonical declarations."); + + auto *VD = dyn_cast(CurDeclaration); + auto *FD = dyn_cast(CurDeclaration); assert((VD || FD) && "Only variables or fields are expected here!"); (void)FD; @@ -10196,19 +10198,17 @@ // Check conflicts with other map clause expressions. We check the conflicts // with the current construct separately from the enclosing data // environment, because the restrictions are different. - if (CheckMapConflicts(*this, DSAStack, D, SimpleExpr, - /*CurrentRegionOnly=*/true)) + if (CheckMapConflicts(*this, DSAStack, CurDeclaration, SimpleExpr, + /*CurrentRegionOnly=*/true, CurComponents)) break; - if (CheckMapConflicts(*this, DSAStack, D, SimpleExpr, - /*CurrentRegionOnly=*/false)) + if (CheckMapConflicts(*this, DSAStack, CurDeclaration, SimpleExpr, + /*CurrentRegionOnly=*/false, CurComponents)) break; // OpenMP 4.5 [2.15.5.1, map Clause, Restrictions, C++, p.1] // If the type of a list item is a reference to a type T then the type will // be considered to be T for all purposes of this clause. - QualType Type = D->getType(); - if (Type->isReferenceType()) - Type = Type->getPointeeType(); + QualType Type = CurDeclaration->getType().getNonReferenceType(); // OpenMP 4.5 [2.15.5.1, map Clause, Restrictions, p.9] // A list item must have a mappable type. @@ -10254,20 +10254,32 @@ Diag(ELoc, diag::err_omp_variable_in_map_and_dsa) << getOpenMPClauseName(DVar.CKind) << getOpenMPDirectiveName(DSAStack->getCurrentDirective()); - ReportOriginalDSA(*this, DSAStack, D, DVar); + ReportOriginalDSA(*this, DSAStack, CurDeclaration, DVar); continue; } } + // Save the current expression. Vars.push_back(RE); - DSAStack->addExprToVarMapInfo(D, RE); + + // Store the components in the stack so that they can be used to check + // against other clauses later on. + DSAStack->addMappableExpressionComponents(CurDeclaration, CurComponents); + + // Save the components and declaration to create the clause. For purposes of + // the clause creation, any component list that has has base 'this' uses + // null has + ClauseComponents.resize(ClauseComponents.size() + 1); + ClauseComponents.back().append(CurComponents.begin(), CurComponents.end()); + ClauseBaseDeclarations.push_back(isa(BE) ? nullptr + : CurDeclaration); } // We need to produce a map clause even if we don't have variables so that // other diagnostics related with non-existing map clauses are accurate. - return OMPMapClause::Create(Context, StartLoc, LParenLoc, EndLoc, Vars, - MapTypeModifier, MapType, IsMapTypeImplicit, - MapLoc); + return OMPMapClause::Create( + Context, StartLoc, LParenLoc, EndLoc, Vars, ClauseBaseDeclarations, + ClauseComponents, MapTypeModifier, MapType, IsMapTypeImplicit, MapLoc); } QualType Sema::ActOnOpenMPDeclareReductionType(SourceLocation TyLoc, Index: lib/Serialization/ASTReaderStmt.cpp =================================================================== --- lib/Serialization/ASTReaderStmt.cpp +++ lib/Serialization/ASTReaderStmt.cpp @@ -1861,9 +1861,15 @@ case OMPC_device: C = new (Context) OMPDeviceClause(); break; - case OMPC_map: - C = OMPMapClause::CreateEmpty(Context, Record[Idx++]); + case OMPC_map: { + unsigned NumVars = Record[Idx++]; + unsigned NumDeclarations = Record[Idx++]; + unsigned NumLists = Record[Idx++]; + unsigned NumComponents = Record[Idx++]; + C = OMPMapClause::CreateEmpty(Context, NumVars, NumDeclarations, NumLists, + NumComponents); break; + } case OMPC_num_teams: C = new (Context) OMPNumTeamsClause(); break; @@ -2225,12 +2231,45 @@ C->setMapLoc(Reader->ReadSourceLocation(Record, Idx)); C->setColonLoc(Reader->ReadSourceLocation(Record, Idx)); auto NumVars = C->varlist_size(); + auto UniqueDecls = C->getUniqueDeclarationsNum(); + auto TotalLists = C->getTotalComponentListNum(); + auto TotalComponents = C->getTotalComponentsNum(); + SmallVector Vars; Vars.reserve(NumVars); - for (unsigned i = 0; i != NumVars; ++i) { + for (unsigned i = 0; i != NumVars; ++i) Vars.push_back(Reader->Reader.ReadSubExpr()); - } C->setVarRefs(Vars); + + SmallVector Decls; + Decls.reserve(UniqueDecls); + for (unsigned i = 0; i < UniqueDecls; ++i) + Decls.push_back( + Reader->Reader.ReadDeclAs(Reader->F, Record, Idx)); + C->setUniqueDecls(Decls); + + SmallVector ListsPerDecl; + ListsPerDecl.reserve(UniqueDecls); + for (unsigned i = 0; i < UniqueDecls; ++i) + ListsPerDecl.push_back(Record[Idx++]); + C->setDeclNumLists(ListsPerDecl); + + SmallVector ListSizes; + ListSizes.reserve(TotalLists); + for (unsigned i = 0; i < TotalLists; ++i) + ListSizes.push_back(Record[Idx++]); + C->setComponentListSizes(ListSizes); + + SmallVector Components; + Components.reserve(TotalComponents); + for (unsigned i = 0; i < TotalComponents; ++i) { + Expr *AssociatedExpr = Reader->Reader.ReadSubExpr(); + ValueDecl *AssociatedDecl = + Reader->Reader.ReadDeclAs(Reader->F, Record, Idx); + Components.push_back(OMPClauseMappableExprCommon::MappableComponent( + AssociatedExpr, AssociatedDecl)); + } + C->setComponents(Components, ListSizes); } void OMPClauseReader::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) { Index: lib/Serialization/ASTWriterStmt.cpp =================================================================== --- lib/Serialization/ASTWriterStmt.cpp +++ lib/Serialization/ASTWriterStmt.cpp @@ -2021,13 +2021,26 @@ void OMPClauseWriter::VisitOMPMapClause(OMPMapClause *C) { Record.push_back(C->varlist_size()); + Record.push_back(C->getUniqueDeclarationsNum()); + Record.push_back(C->getTotalComponentListNum()); + Record.push_back(C->getTotalComponentsNum()); Record.AddSourceLocation(C->getLParenLoc()); Record.push_back(C->getMapTypeModifier()); Record.push_back(C->getMapType()); Record.AddSourceLocation(C->getMapLoc()); Record.AddSourceLocation(C->getColonLoc()); - for (auto *VE : C->varlists()) - Record.AddStmt(VE); + for (auto *E : C->varlists()) + Record.AddStmt(E); + for (auto *D : C->all_decls()) + Record.AddDeclRef(D); + for (auto N : C->all_num_lists()) + Record.push_back(N); + for (auto N : C->all_lists_sizes()) + Record.push_back(N); + for (auto &M : C->all_components()) { + Record.AddStmt(M.getAssociatedExpression()); + Record.AddDeclRef(M.getAssociatedDeclaration()); + } } void OMPClauseWriter::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) {