diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -149,17 +149,27 @@
   /// implement the fastmath interface.
   void setFastmathFlagsAttr(llvm::Instruction *inst, Operation *op) const;
 
-  /// Converts LLVM metadata to corresponding MLIR representation,
-  /// e.g. metadata nodes referenced via !tbaa are converted to
-  /// TBAA operations hosted inside a MetadataOp.
+  /// Converts all LLVM metadata nodes that translate to operations nested in a
+  /// global metadata operation, such as alias analysis or access group
+  /// metadata, and builds a map from the metadata nodes to the symbols pointing
+  /// to the converted operations. Returns success if all conversions succeed
+  /// and failure otherwise.
+  // Note: All metadata is nested inside a single global metadata operation to
+  // minimize the number of symbols that pollute the global namespace.
   LogicalResult convertMetadata();
 
-  /// Returns SymbolRefAttr representing TBAA metadata `node`
-  /// in `tbaaMapping`.
-  SymbolRefAttr lookupTBAAAttr(const llvm::MDNode *node) {
+  /// Returns the MLIR symbol reference mapped to the given LLVM TBAA
+  /// metadata `node`.
+  SymbolRefAttr lookupTBAAAttr(const llvm::MDNode *node) const {
     return tbaaMapping.lookup(node);
   }
 
+  /// Returns the MLIR symbol reference mapped to the given LLVM access
+  /// group metadata `node`.
+  SymbolRefAttr lookupAccessGroupAttr(const llvm::MDNode *node) const {
+    return accessGroupMapping.lookup(node);
+  }
+
 private:
   /// Clears the block and value mapping before processing a new region.
   void clearBlockAndValueMapping() {
@@ -237,25 +247,21 @@
   /// them fails. All operations are inserted at the start of the current
   /// function entry block.
   FailureOr<Value> convertConstantExpr(llvm::Constant *constant);
-  /// Returns symbol name to be used for MetadataOp containing
-  /// TBAA metadata operations. It must not conflict with the user
-  /// name space.
-  StringRef getTBAAMetadataOpName() const { return "__tbaa"; }
-  /// Returns a terminated MetadataOp into which TBAA metadata
-  /// operations can be placed. The MetadataOp is created
-  /// on the first invocation of this function.
-  MetadataOp getTBAAMetadataOp();
+  /// Returns a global metadata operation that serves as a container for LLVM
+  /// metadata that converts to MLIR operations. Creates the global metadata
+  /// operation on the first invocation.
+  MetadataOp getGlobalMetadataOp();
   /// Performs conversion of LLVM TBAA metadata starting from
   /// `node`. On exit from this function all nodes reachable
   /// from `node` are converted, and tbaaMapping map is updated
   /// (unless all dependencies have been converted by a previous
   /// invocation of this function).
   LogicalResult processTBAAMetadata(const llvm::MDNode *node);
-  /// Returns unique string name of a symbol that may be used
-  /// for a TBAA metadata operation. The name will contain
-  /// the provided `basename` and will be uniqued via
-  /// tbaaNodeCounter (see below).
-  std::string getNewTBAANodeName(StringRef basename);
+  /// Converts all LLVM access groups starting from `node` to MLIR access group
+  /// operations and stores a mapping from every nested access group node to the
+  /// symbol pointing to the translated operation. Returns success if all
+  /// conversions succeed and failure otherwise.
+  LogicalResult processAccessGroupMetadata(const llvm::MDNode *node);
 
   /// Builder pointing at where the next instruction should be generated.
   OpBuilder builder;
@@ -265,6 +271,8 @@
   Operation *constantInsertionOp = nullptr;
   /// Operation to insert the next global after.
   Operation *globalInsertionOp = nullptr;
+  /// Operation to insert metadata operations into.
+  MetadataOp globalMetadataOp = nullptr;
   /// The current context.
   MLIRContext *context;
   /// The MLIR module being created.
@@ -284,20 +292,17 @@
   /// operations for all operations that return no result. All operations that
   /// return a result have a valueMapping entry instead.
   DenseMap<llvm::Instruction *, Operation *> noResultOpMapping;
-  /// The stateful type translator (contains named structs).
-  LLVM::TypeFromLLVMIRTranslator typeTranslator;
-  /// Stateful debug information importer.
-  std::unique_ptr<detail::DebugImporter> debugImporter;
-  /// A terminated MetadataOp where TBAA metadata operations
-  /// can be inserted.
-  MetadataOp tbaaMetadataOp{};
   /// Mapping between LLVM TBAA metadata nodes and symbol references
   /// to the LLVMIR dialect TBAA operations corresponding to these
   /// nodes.
   DenseMap<const llvm::MDNode *, SymbolRefAttr> tbaaMapping;
-  /// A counter to be used as a unique suffix for symbols
-  /// defined by TBAA operations.
-  unsigned tbaaNodeCounter = 0;
+  /// Mapping between original LLVM access group metadata nodes and the symbol
+  /// references pointing to the imported MLIR access group operations.
+  DenseMap<const llvm::MDNode *, SymbolRefAttr> accessGroupMapping;
+  /// The stateful type translator (contains named structs).
+  LLVM::TypeFromLLVMIRTranslator typeTranslator;
+  /// Stateful debug information importer.
+  std::unique_ptr<detail::DebugImporter> debugImporter;
 };
 
 } // namespace LLVM
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -71,17 +71,17 @@
 /// dialect attributes.
 static ArrayRef<unsigned> getSupportedMetadataImpl() {
   static const SmallVector<unsigned> convertibleMetadata = {
-      llvm::LLVMContext::MD_prof, // profiling metadata
-      llvm::LLVMContext::MD_tbaa};
+      llvm::LLVMContext::MD_prof, llvm::LLVMContext::MD_tbaa,
+      llvm::LLVMContext::MD_access_group};
   return convertibleMetadata;
 }
 
-/// Attaches the given profiling metadata to the imported operation if a
-/// conversion to an MLIR profiling attribute exists and succeeds. Returns
-/// failure otherwise.
-static LogicalResult setProfilingAttrs(OpBuilder &builder, llvm::MDNode *node,
-                                       Operation *op,
-                                       LLVM::ModuleImport &moduleImport) {
+/// Converts the given profiling metadata `node` to an MLIR profiling attribute
+/// and attaches it to the imported operation if the translation succeeds.
+/// Returns failure otherwise.
+static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
+                                      Operation *op,
+                                      LLVM::ModuleImport &moduleImport) {
   // Return success for empty metadata nodes since there is nothing to import.
   if (!node->getNumOperands())
     return success();
@@ -127,11 +127,11 @@
       .Default([](auto) { return failure(); });
 }
 
-/// Attaches the given TBAA metadata `node` to the imported operation.
-/// Returns success, if the metadata has been converted and the attachment
-/// succeeds, failure - otherwise.
-static LogicalResult setTBAAAttrs(const llvm::MDNode *node, Operation *op,
-                                  LLVM::ModuleImport &moduleImport) {
+/// Searches the symbol reference pointing to the metadata operation that
+/// maps to the given TBAA metadata `node` and attaches it to the imported
+/// operation if the lookup succeeds. Returns failure otherwise.
+static LogicalResult setTBAAAttr(const llvm::MDNode *node, Operation *op,
+                                 LLVM::ModuleImport &moduleImport) {
   SymbolRefAttr tbaaTagSym = moduleImport.lookupTBAAAttr(node);
   if (!tbaaTagSym)
     return failure();
@@ -141,6 +141,28 @@
   return success();
 }
 
+/// Searches the symbol references pointing to the access group operations that
+/// map to the access group nodes starting from the access group metadata
+/// `node`, and attaches all of them to the imported operation if the lookups
+/// succeed. Returns failure otherwise.
+static LogicalResult setAccessGroupAttr(const llvm::MDNode *node, Operation *op,
+                                        LLVM::ModuleImport &moduleImport) {
+  // An access group node is either an access group or an access group list.
+  SmallVector<Attribute> accessGroups;
+  if (!node->getNumOperands())
+    accessGroups.push_back(moduleImport.lookupAccessGroupAttr(node));
+  for (const llvm::MDOperand &operand : node->operands()) {
+    auto *accessGroup = dyn_cast_or_null<llvm::MDNode>(operand.get());
+    accessGroups.push_back(moduleImport.lookupAccessGroupAttr(accessGroup));
+  }
+  // Exit if an operand is not a node or one of the lookups failed.
+  if (llvm::is_contained(accessGroups, nullptr))
+    return failure();
+
+  op->setAttr(LLVMDialect::getAccessGroupsAttrName(),
+              ArrayAttr::get(op->getContext(), accessGroups));
+  return success();
+}
 
 namespace {
 /// Implementation of the dialect interface that converts operations belonging
@@ -164,9 +186,11 @@
                                  LLVM::ModuleImport &moduleImport) const final {
     // Call metadata specific handlers.
     if (kind == llvm::LLVMContext::MD_prof)
-      return setProfilingAttrs(builder, node, op, moduleImport);
+      return setProfilingAttr(builder, node, op, moduleImport);
     if (kind == llvm::LLVMContext::MD_tbaa)
-      return setTBAAAttrs(node, op, moduleImport);
+      return setTBAAAttr(node, op, moduleImport);
+    if (kind == llvm::LLVMContext::MD_access_group)
+      return setAccessGroupAttr(node, op, moduleImport);
 
     // A handler for a supported metadata kind is missing.
     llvm_unreachable("unknown metadata type");
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -71,6 +71,12 @@
   return "llvm.global_dtors";
 }
 
+/// Returns the symbol name for the module-level metadata operation. It must not
+/// conflict with the user namespace.
+static constexpr StringRef getGlobalMetadataOpName() {
+  return "__llvm_global_metadata";
+}
+
 /// Returns a supported MLIR floating point type of the given bit width or null
 /// if the bit width is not supported.
 static FloatType getDLFloatType(MLIRContext &ctx, int32_t bitwidth) {
@@ -359,23 +365,17 @@
   builder.setInsertionPointToStart(mlirModule.getBody());
 }
 
-MetadataOp ModuleImport::getTBAAMetadataOp() {
-  if (tbaaMetadataOp)
-    return tbaaMetadataOp;
+MetadataOp ModuleImport::getGlobalMetadataOp() {
+  if (globalMetadataOp)
+    return globalMetadataOp;
 
   OpBuilder::InsertionGuard guard(builder);
-  Location loc = mlirModule.getLoc();
+  // Always create the operation at the end of the module body, independent of
+  // the insertion point the caller left the builder at.
   builder.setInsertionPointToEnd(mlirModule.getBody());
-  tbaaMetadataOp = builder.create<MetadataOp>(loc, getTBAAMetadataOpName());
-
-  return tbaaMetadataOp;
-}
-
-std::string ModuleImport::getNewTBAANodeName(StringRef basename) {
-  return (Twine("tbaa_") + Twine(basename) + Twine('_') +
-          Twine(tbaaNodeCounter++))
-      .str();
+  return globalMetadataOp = builder.create<MetadataOp>(
+             mlirModule.getLoc(), getGlobalMetadataOpName());
 }
 
 LogicalResult ModuleImport::processTBAAMetadata(const llvm::MDNode *node) {
@@ -534,10 +534,18 @@
     return true;
   };
 
+  // Helper to compute a unique symbol name that includes the given `baseName`.
+  // Uses the size of the mapping to unique the symbol name.
+  auto getUniqueSymbolName = [&](StringRef baseName) {
+    return (Twine("tbaa_") + Twine(baseName) + Twine('_') +
+            Twine(tbaaMapping.size()))
+        .str();
+  };
+
   // Insert new operations at the end of the MetadataOp.
   OpBuilder::InsertionGuard guard(builder);
-  builder.setInsertionPointToEnd(&getTBAAMetadataOp().getBody().back());
-  StringAttr metadataOpName = SymbolTable::getSymbolName(getTBAAMetadataOp());
+  builder.setInsertionPointToEnd(&getGlobalMetadataOp().getBody().back());
+  StringAttr metadataOpName = SymbolTable::getSymbolName(getGlobalMetadataOp());
 
   // On the first walk, create SymbolRefAttr's and map them
   // to nodes in `nodesToConvert`.
@@ -550,7 +558,7 @@
       // The root nodes do not have operands, so we can create
       // the TBAARootMetadataOp on the first walk.
       auto rootNode = builder.create<TBAARootMetadataOp>(
-          loc, getNewTBAANodeName("root"), identity.value());
+          loc, getUniqueSymbolName("root"), identity.value());
       tbaaMapping.try_emplace(current, FlatSymbolRefAttr::get(rootNode));
       continue;
     }
@@ -559,7 +567,7 @@
         return failure();
       tbaaMapping.try_emplace(
           current, FlatSymbolRefAttr::get(builder.getContext(),
-                                          getNewTBAANodeName("type_desc")));
+                                          getUniqueSymbolName("type_desc")));
       continue;
     }
     if (std::optional isValid = isTagNode(current)) {
@@ -571,7 +579,7 @@
           current, SymbolRefAttr::get(
                        builder.getContext(), metadataOpName,
                        FlatSymbolRefAttr::get(builder.getContext(),
-                                              getNewTBAANodeName("tag"))));
+                                              getUniqueSymbolName("tag"))));
       continue;
     }
     return emitError(loc) << "unsupported TBAA node format: "
@@ -611,21 +619,69 @@
   return success();
 }
 
+LogicalResult
+ModuleImport::processAccessGroupMetadata(const llvm::MDNode *node) {
+  Location loc = mlirModule.getLoc();
+  // An access group node is either an access group or an access group list.
+  // Start by collecting all access groups to translate.
+  SmallVector<const llvm::MDNode *> accessGroups;
+  if (!node->getNumOperands())
+    accessGroups.push_back(node);
+  for (const llvm::MDOperand &operand : node->operands()) {
+    // Reject list entries that are not metadata nodes (e.g. strings or null
+    // operands) with a diagnostic instead of asserting in a `cast`.
+    auto *accessGroup = dyn_cast_or_null<llvm::MDNode>(operand.get());
+    if (!accessGroup)
+      return emitError(loc) << "unsupported access group node: "
+                            << diagMD(node, llvmModule.get());
+    accessGroups.push_back(accessGroup);
+  }
+
+  // Convert all entries of the access group list to access group operations.
+  for (const llvm::MDNode *accessGroup : accessGroups) {
+    if (accessGroupMapping.count(accessGroup))
+      continue;
+    // Verify the access group node is distinct and empty.
+    if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct())
+      return emitError(loc) << "unsupported access group node: "
+                            << diagMD(accessGroup, llvmModule.get());
+
+    MetadataOp metadataOp = getGlobalMetadataOp();
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToEnd(&metadataOp.getBody().back());
+    auto groupOp = builder.create<AccessGroupMetadataOp>(
+        loc, (Twine("group_") + Twine(accessGroupMapping.size())).str());
+    // Add a mapping from the access group node to the symbol reference pointing
+    // to the newly created operation.
+    accessGroupMapping[accessGroup] = SymbolRefAttr::get(
+        builder.getContext(), metadataOp.getSymName(),
+        FlatSymbolRefAttr::get(builder.getContext(), groupOp.getSymName()));
+  }
+  return success();
+}
+
 LogicalResult ModuleImport::convertMetadata() {
   OpBuilder::InsertionGuard guard(builder);
   builder.setInsertionPointToEnd(mlirModule.getBody());
-  for (const llvm::Function &func : llvmModule->functions())
+  for (const llvm::Function &func : llvmModule->functions()) {
     for (const llvm::Instruction &inst : llvm::instructions(func)) {
-      llvm::AAMDNodes nodes = inst.getAAMetadata();
-      if (!nodes)
-        continue;
+      // Convert access group metadata nodes.
+      if (llvm::MDNode *node =
+              inst.getMetadata(llvm::LLVMContext::MD_access_group))
+        if (failed(processAccessGroupMetadata(node)))
+          return failure();
 
-      if (const llvm::MDNode *tbaaMD = nodes.TBAA)
-        if (failed(processTBAAMetadata(tbaaMD)))
+      // Convert alias analysis metadata nodes.
+      llvm::AAMDNodes aliasAnalysisNodes = inst.getAAMetadata();
+      if (!aliasAnalysisNodes)
+        continue;
+      if (aliasAnalysisNodes.TBAA)
+        if (failed(processTBAAMetadata(aliasAnalysisNodes.TBAA)))
           return failure();
-      // TODO: only TBAA metadata is currently supported.
-    }
+      // TODO: Support noalias and scope metadata nodes.
+    }
+  }
   return success();
 }
 
diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll
--- a/mlir/test/Target/LLVMIR/Import/import-failure.ll
+++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll
@@ -239,3 +239,26 @@
 !3 = !{!4, i64 4, !"int"}
 !4 = !{!5, i64 1, !"omnipotent char"}
 !5 = !{!"Simple C++ TBAA"}
+
+; // -----
+
+; CHECK: import-failure.ll
+; CHECK-SAME: error: unsupported access group node: !0 = !{}
+define void @access_group(ptr %arg1) {
+  %1 = load i32, ptr %arg1, !llvm.access.group !0
+  ret void
+}
+
+!0 = !{}
+
+; // -----
+
+; CHECK: import-failure.ll
+; CHECK-SAME: error: unsupported access group node: !1 = distinct !{!"unsupported access group"}
+define void @access_group(ptr %arg1) {
+  %1 = load i32, ptr %arg1, !llvm.access.group !0
+  ret void
+}
+
+!0 = !{!1}
+!1 = distinct !{!"unsupported access group"}
diff --git a/mlir/test/Target/LLVMIR/Import/metadata-loop.ll b/mlir/test/Target/LLVMIR/Import/metadata-loop.ll
new file mode 100644
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/metadata-loop.ll
@@ -0,0 +1,27 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK: llvm.metadata @__llvm_global_metadata {
+; CHECK: llvm.access_group @[[$GROUP0:.*]]
+; CHECK: llvm.access_group @[[$GROUP1:.*]]
+; CHECK: llvm.access_group @[[$GROUP2:.*]]
+; CHECK: llvm.access_group @[[$GROUP3:.*]]
+; CHECK: }
+
+; CHECK-LABEL: llvm.func @access_group
+; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+define void @access_group(ptr %arg1) {
+  ; CHECK: llvm.load %[[ARG1]] {access_groups = [@__llvm_global_metadata::@[[$GROUP0]], @__llvm_global_metadata::@[[$GROUP1]]]}
+  %1 = load i32, ptr %arg1, !llvm.access.group !0
+  ; CHECK: llvm.load %[[ARG1]] {access_groups = [@__llvm_global_metadata::@[[$GROUP2]], @__llvm_global_metadata::@[[$GROUP0]]]}
+  %2 = load i32, ptr %arg1, !llvm.access.group !1
+  ; CHECK: llvm.load %[[ARG1]] {access_groups = [@__llvm_global_metadata::@[[$GROUP3]]]}
+  %3 = load i32, ptr %arg1, !llvm.access.group !2
+  ret void
+}
+
+!0 = !{!3, !4}
+!1 = !{!5, !3}
+!2 = distinct !{}
+!3 = distinct !{}
+!4 = distinct !{}
+!5 = distinct !{}
diff --git a/mlir/test/Target/LLVMIR/Import/profiling-metadata.ll b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
rename from mlir/test/Target/LLVMIR/Import/profiling-metadata.ll
rename to mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
diff --git a/mlir/test/Target/LLVMIR/Import/tbaa.ll b/mlir/test/Target/LLVMIR/Import/metadata-tbaa.ll
rename from mlir/test/Target/LLVMIR/Import/tbaa.ll
rename to mlir/test/Target/LLVMIR/Import/metadata-tbaa.ll
--- a/mlir/test/Target/LLVMIR/Import/tbaa.ll
+++ b/mlir/test/Target/LLVMIR/Import/metadata-tbaa.ll
@@ -2,7 +2,7 @@
 
 // -----
 
-; CHECK-LABEL: llvm.metadata @__tbaa {
+; CHECK-LABEL: llvm.metadata @__llvm_global_metadata {
 ; CHECK-NEXT: llvm.tbaa_root @[[R0:tbaa_root_[0-9]+]] {id = "Simple C/C++ TBAA"}
 ; CHECK-NEXT: llvm.tbaa_tag @[[T0:tbaa_tag_[0-9]+]] {access_type = @[[R0]], base_type = @[[R0]], offset = 0 : i64}
 ; CHECK-NEXT: llvm.tbaa_root @[[R1:tbaa_root_[0-9]+]] {id = "Other language TBAA"}
@@ -10,10 +10,10 @@
 ; CHECK-NEXT: }
 ; CHECK: llvm.func @tbaa1
 ; CHECK: llvm.store %{{.*}}, %{{.*}} {
-; CHECK-SAME: tbaa = [@__tbaa::@[[T0]]]
+; CHECK-SAME: tbaa = [@__llvm_global_metadata::@[[T0]]]
 ; CHECK-SAME: } : i8, !llvm.ptr
 ; CHECK: llvm.store %{{.*}}, %{{.*}} {
-; CHECK-SAME: tbaa = [@__tbaa::@[[T1]]]
+; CHECK-SAME: tbaa = [@__llvm_global_metadata::@[[T1]]]
 ; CHECK-SAME: } : i8, !llvm.ptr
 define dso_local void @tbaa1(ptr %0, ptr %1) {
   store i8 1, ptr %0, align 4, !tbaa !1
@@ -28,7 +28,7 @@
 
 // -----
 
-; CHECK-LABEL: llvm.metadata @__tbaa {
+; CHECK-LABEL: llvm.metadata @__llvm_global_metadata {
 ; CHECK-NEXT: llvm.tbaa_root @[[R0:tbaa_root_[0-9]+]] {id = "Simple C/C++ TBAA"}
 ; CHECK-NEXT: llvm.tbaa_tag @[[T0:tbaa_tag_[0-9]+]] {access_type = @[[D1:tbaa_type_desc_[0-9]+]], base_type = @[[D2:tbaa_type_desc_[0-9]+]], offset = 8 : i64}
 ; CHECK-NEXT: llvm.tbaa_type_desc @[[D1]] {id = "long long", members = {<@[[D0:tbaa_type_desc_[0-9]+]], 0>}}
@@ -40,10 +40,10 @@
 ; CHECK-NEXT: }
 ; CHECK: llvm.func @tbaa2
 ; CHECK: llvm.load %{{.*}} {
-; CHECK-SAME: tbaa = [@__tbaa::@[[T0]]]
+; CHECK-SAME: tbaa = [@__llvm_global_metadata::@[[T0]]]
 ; CHECK-SAME: } : !llvm.ptr -> i64
 ; CHECK: llvm.store %{{.*}}, %{{.*}} {
-; CHECK-SAME: tbaa = [@__tbaa::@[[T1]]]
+; CHECK-SAME: tbaa = [@__llvm_global_metadata::@[[T1]]]
 ; CHECK-SAME: } : i32, !llvm.ptr
 %struct.agg2_t = type { i64, i64 }
 %struct.agg1_t = type { i32, i32 }