diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -120,11 +120,6 @@ /// in these blocks. void forgetMapping(Region &region); - /// Returns the LLVM metadata corresponding to a symbol reference to an mlir - /// LLVM dialect access group operation. - llvm::MDNode *getAccessGroup(Operation *op, - SymbolRefAttr accessGroupRef) const; - /// Returns the LLVM metadata corresponding to a symbol reference to an mlir /// LLVM dialect alias scope operation llvm::MDNode *getAliasScope(Operation *op, SymbolRefAttr aliasScopeRef) const; diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h --- a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h +++ b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h @@ -21,14 +21,28 @@ namespace LLVM { namespace detail { -/// A helper class that converts a LoopAnnotationAttr into a corresponding -/// llvm::MDNode. +/// A helper class that converts LoopAnnotationAttrs and AccessGroupMetadataOps +/// into corresponding llvm::MDNodes. class LoopAnnotationTranslation { public: - LoopAnnotationTranslation(LLVM::ModuleTranslation &moduleTranslation) - : moduleTranslation(moduleTranslation) {} + LoopAnnotationTranslation(Operation *mlirModule, llvm::Module &llvmModule) + : mlirModule(mlirModule), llvmModule(llvmModule) {} - llvm::MDNode *translate(LoopAnnotationAttr attr, Operation *op); + llvm::MDNode *translateLoopAnnotation(LoopAnnotationAttr attr, Operation *op); + + /// Traverses the global access group metadata operations in the `mlirModule` + /// and creates corresponding LLVM metadata nodes. + LogicalResult createAccessGroupMetadata(); + + /// Returns the LLVM metadata corresponding to a symbol reference to an mlir + /// LLVM dialect access group operation. 
+ llvm::MDNode *getAccessGroup(Operation *op, + SymbolRefAttr accessGroupRef) const; + + /// Returns the LLVM metadata corresponding to a list of symbol references to + /// mlir LLVM dialect access group operations. Returns nullptr if + /// `accessGroupRefs` is null or empty. + llvm::MDNode *getAccessGroups(Operation *op, ArrayAttr accessGroupRefs) const; private: /// Returns the LLVM metadata corresponding to a llvm loop metadata attribute. @@ -47,7 +61,13 @@ /// The metadata is attached to Latch block branches with this attribute. DenseMap loopMetadataMapping; - LLVM::ModuleTranslation &moduleTranslation; + /// Mapping from an access group metadata operation to its LLVM metadata. + /// This map is populated on module entry and is used to annotate loops (as + /// identified via their branches) and contained memory accesses. + DenseMap accessGroupMetadataMapping; + + Operation *mlirModule; + llvm::Module &llvmModule; }; } // namespace detail diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp --- a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp @@ -15,12 +15,11 @@ namespace { /// Helper class that keeps the state of one attribute to metadata conversion. 
struct LoopAnnotationConversion { - LoopAnnotationConversion(LoopAnnotationAttr attr, - ModuleTranslation &moduleTranslation, Operation *op, - LoopAnnotationTranslation &loopAnnotationTranslation) - : attr(attr), moduleTranslation(moduleTranslation), op(op), - loopAnnotationTranslation(loopAnnotationTranslation), - ctx(moduleTranslation.getLLVMContext()) {} + LoopAnnotationConversion(LoopAnnotationAttr attr, Operation *op, + LoopAnnotationTranslation &loopAnnotationTranslation, + llvm::LLVMContext &ctx) + : attr(attr), op(op), + loopAnnotationTranslation(loopAnnotationTranslation), ctx(ctx) {} /// Converts this struct's loop annotation into a corresponding LLVMIR /// metadata representation. @@ -46,7 +45,6 @@ void convertLoopOptions(LoopUnswitchAttr options); LoopAnnotationAttr attr; - ModuleTranslation &moduleTranslation; Operation *op; LoopAnnotationTranslation &loopAnnotationTranslation; llvm::LLVMContext &ctx; @@ -95,7 +93,8 @@ if (!attr) return; - llvm::MDNode *node = loopAnnotationTranslation.translate(attr, op); + llvm::MDNode *node = + loopAnnotationTranslation.translateLoopAnnotation(attr, op); metadataNodes.push_back( llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name), node})); @@ -225,7 +224,7 @@ llvm::MDString::get(ctx, "llvm.loop.parallel_accesses")); for (SymbolRefAttr accessGroupRef : parallelAccessGroups) parallelAccess.push_back( - moduleTranslation.getAccessGroup(op, accessGroupRef)); + loopAnnotationTranslation.getAccessGroup(op, accessGroupRef)); metadataNodes.push_back(llvm::MDNode::get(ctx, parallelAccess)); } @@ -236,7 +235,8 @@ return loopMD; } -llvm::MDNode *LoopAnnotationTranslation::translate(LoopAnnotationAttr attr, +llvm::MDNode * +LoopAnnotationTranslation::translateLoopAnnotation(LoopAnnotationAttr attr, Operation *op) { if (!attr) return nullptr; @@ -246,9 +246,47 @@ return loopMD; loopMD = - LoopAnnotationConversion(attr, moduleTranslation, op, *this).convert(); + LoopAnnotationConversion(attr, op, *this, 
this->llvmModule.getContext()) + .convert(); // Store a map from this Attribute to the LLVM metadata in case we // encounter it again. mapLoopMetadata(attr, loopMD); return loopMD; } + +LogicalResult LoopAnnotationTranslation::createAccessGroupMetadata() { + mlirModule->walk([&](LLVM::MetadataOp metadatas) { + metadatas.walk([&](LLVM::AccessGroupMetadataOp op) { + llvm::MDNode *accessGroup = + llvm::MDNode::getDistinct(llvmModule.getContext(), {}); + accessGroupMetadataMapping.insert({op, accessGroup}); + }); + }); + return success(); +} + +llvm::MDNode * +LoopAnnotationTranslation::getAccessGroup(Operation *op, + SymbolRefAttr accessGroupRef) const { + auto metadataName = accessGroupRef.getRootReference(); + auto accessGroupName = accessGroupRef.getLeafReference(); + auto metadataOp = SymbolTable::lookupNearestSymbolFrom( + op->getParentOp(), metadataName); + auto *accessGroupOp = + SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); + return accessGroupMetadataMapping.lookup(accessGroupOp); +} + +llvm::MDNode * +LoopAnnotationTranslation::getAccessGroups(Operation *op, + ArrayAttr accessGroupRefs) const { + if (!accessGroupRefs || accessGroupRefs.empty()) + return nullptr; + + SmallVector groupMDs; + for (SymbolRefAttr groupRef : accessGroupRefs.getAsRange()) + groupMDs.push_back(getAccessGroup(op, groupRef)); + if (groupMDs.size() == 1) + return llvm::cast(groupMDs.front()); + return llvm::MDNode::get(llvmModule.getContext(), groupMDs); +} diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -421,8 +421,8 @@ : mlirModule(module), llvmModule(std::move(llvmModule)), debugTranslation( std::make_unique(module, *this->llvmModule)), - loopAnnotationTranslation( - std::make_unique(*this)), + loopAnnotationTranslation(std::make_unique( + module, *this->llvmModule)), 
typeTranslator(this->llvmModule->getContext()), iface(module->getContext()) { assert(satisfiesLLVMModule(mlirModule) && @@ -449,7 +449,6 @@ branchMapping.erase(&op); if (isa(op)) globalsMapping.erase(&op); - accessGroupMetadataMapping.erase(&op); llvm::append_range( toProcess, llvm::map_range(op.getRegions(), [](Region &r) { return &r; })); @@ -994,47 +993,16 @@ return success(); } -llvm::MDNode * -ModuleTranslation::getAccessGroup(Operation *op, - SymbolRefAttr accessGroupRef) const { - auto metadataName = accessGroupRef.getRootReference(); - auto accessGroupName = accessGroupRef.getLeafReference(); - auto metadataOp = SymbolTable::lookupNearestSymbolFrom( - op->getParentOp(), metadataName); - auto *accessGroupOp = - SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); - return accessGroupMetadataMapping.lookup(accessGroupOp); -} - LogicalResult ModuleTranslation::createAccessGroupMetadata() { - mlirModule->walk([&](LLVM::MetadataOp metadatas) { - metadatas.walk([&](LLVM::AccessGroupMetadataOp op) { - llvm::LLVMContext &ctx = llvmModule->getContext(); - llvm::MDNode *accessGroup = llvm::MDNode::getDistinct(ctx, {}); - accessGroupMetadataMapping.insert({op, accessGroup}); - }); - }); - return success(); + return loopAnnotationTranslation->createAccessGroupMetadata(); } void ModuleTranslation::setAccessGroupsMetadata(Operation *op, llvm::Instruction *inst) { auto populateGroupsMetadata = [&](ArrayAttr groupRefs) { - if (!groupRefs || groupRefs.empty()) - return; - - llvm::Module *module = inst->getModule(); - SmallVector groupMDs; - for (SymbolRefAttr groupRef : groupRefs.getAsRange()) - groupMDs.push_back(getAccessGroup(op, groupRef)); - - llvm::MDNode *node = nullptr; - if (groupMDs.size() == 1) - node = llvm::cast(groupMDs.front()); - else if (groupMDs.size() >= 2) - node = llvm::MDNode::get(module->getContext(), groupMDs); - - inst->setMetadata(llvm::LLVMContext::MD_access_group, node); + if (llvm::MDNode *node = + 
loopAnnotationTranslation->getAccessGroups(op, groupRefs)) + inst->setMetadata(llvm::LLVMContext::MD_access_group, node); }; auto groupRefs = @@ -1250,7 +1218,8 @@ [](auto branchOp) { return branchOp.getLoopAnnotationAttr(); }); if (!attr) return; - llvm::MDNode *loopMD = loopAnnotationTranslation->translate(attr, op); + llvm::MDNode *loopMD = + loopAnnotationTranslation->translateLoopAnnotation(attr, op); inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD); }