diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -126,10 +126,12 @@ llvm::MDNode *getAliasScope(Operation *op, SymbolRefAttr aliasScopeRef) const; // Sets LLVM metadata for memory operations that are in a parallel loop. - void setAccessGroupsMetadata(Operation *op, llvm::Instruction *inst); + void setAccessGroupsMetadata(AccessGroupOpInterface op, + llvm::Instruction *inst); // Sets LLVM metadata for memory operations that have alias scope information. - void setAliasScopeMetadata(Operation *op, llvm::Instruction *inst); + void setAliasScopeMetadata(AliasAnalysisOpInterface op, + llvm::Instruction *inst); /// Sets LLVM TBAA metadata for memory operations that have TBAA attributes. void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst); diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h --- a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h +++ b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h @@ -39,10 +39,9 @@ llvm::MDNode *getAccessGroup(Operation *op, SymbolRefAttr accessGroupRef) const; - /// Returns the LLVM metadata corresponding to a list of symbol reference to - /// an mlir LLVM dialect access group operation. Returns nullptr if - /// `accessGroupRefs` is null or empty. - llvm::MDNode *getAccessGroups(Operation *op, ArrayAttr accessGroupRefs) const; + /// Returns the LLVM metadata corresponding to the access group operations + /// referenced by the AccessGroupOpInterface or null if there are none. + llvm::MDNode *getAccessGroups(AccessGroupOpInterface op) const; private: /// Returns the LLVM metadata corresponding to a llvm loop metadata attribute. diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp --- a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp @@ -278,8 +278,8 @@ } llvm::MDNode * -LoopAnnotationTranslation::getAccessGroups(Operation *op, - ArrayAttr accessGroupRefs) const { +LoopAnnotationTranslation::getAccessGroups(AccessGroupOpInterface op) const { + ArrayAttr accessGroupRefs = op.getAccessGroupsOrNull(); if (!accessGroupRefs || accessGroupRefs.empty()) return nullptr; diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -998,17 +998,10 @@ return loopAnnotationTranslation->createAccessGroupMetadata(); } -void ModuleTranslation::setAccessGroupsMetadata(Operation *op, +void ModuleTranslation::setAccessGroupsMetadata(AccessGroupOpInterface op, llvm::Instruction *inst) { - auto populateGroupsMetadata = [&](ArrayAttr groupRefs) { - if (llvm::MDNode *node = - loopAnnotationTranslation->getAccessGroups(op, groupRefs)) - inst->setMetadata(llvm::LLVMContext::MD_access_group, node); - }; - - auto groupRefs = - op->getAttrOfType(LLVMDialect::getAccessGroupsAttrName()); - populateGroupsMetadata(groupRefs); + if (llvm::MDNode *node = loopAnnotationTranslation->getAccessGroups(op)) + inst->setMetadata(llvm::LLVMContext::MD_access_group, node); } LogicalResult ModuleTranslation::createAliasScopeMetadata() { @@ -1060,7 +1053,7 @@ return aliasScopeMetadataMapping.lookup(aliasScopeOp); } -void ModuleTranslation::setAliasScopeMetadata(Operation *op, +void ModuleTranslation::setAliasScopeMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst) { auto populateScopeMetadata = [&](ArrayAttr scopeRefs, unsigned kind) { if (!scopeRefs || scopeRefs.empty()) @@ -1073,13 +1066,10 @@ inst->setMetadata(kind, node); }; - auto aliasScopeRefs = - op->getAttrOfType(LLVMDialect::getAliasScopesAttrName()); - populateScopeMetadata(aliasScopeRefs, llvm::LLVMContext::MD_alias_scope); - - auto noaliasScopeRefs = - op->getAttrOfType(LLVMDialect::getNoAliasScopesAttrName()); - populateScopeMetadata(noaliasScopeRefs, llvm::LLVMContext::MD_noalias); + populateScopeMetadata(op.getAliasScopesOrNull(), + llvm::LLVMContext::MD_alias_scope); + populateScopeMetadata(op.getNoAliasScopesOrNull(), + llvm::LLVMContext::MD_noalias); } llvm::MDNode *ModuleTranslation::getTBAANode(Operation *op,