diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -302,9 +302,6 @@ /// to the LLVMIR dialect TBAA operations corresponding to these /// nodes. DenseMap tbaaMapping; - /// Mapping between original LLVM access group metadata nodes and the symbol - /// references pointing to the imported MLIR access group operations. - DenseMap accessGroupMapping; /// The stateful type translator (contains named structs). LLVM::TypeFromLLVMIRTranslator typeTranslator; /// Stateful debug information importer. diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.h b/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.h --- a/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.h +++ b/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.h @@ -21,13 +21,28 @@ namespace LLVM { namespace detail { -/// A helper class that converts a `llvm.loop` metadata node into a -/// corresponding LoopAnnotationAttr. +/// A helper class that converts `llvm.loop` metadata nodes into corresponding +/// LoopAnnotationAttrs and `llvm.access.group` nodes into +/// `AccessGroupMetadataOp`s. class LoopAnnotationImporter { public: - explicit LoopAnnotationImporter(ModuleImport &moduleImport) - : moduleImport(moduleImport) {} - LoopAnnotationAttr translate(const llvm::MDNode *node, Location loc); + explicit LoopAnnotationImporter(OpBuilder &builder) : builder(builder) {} + LoopAnnotationAttr translateLoopAnnotation(const llvm::MDNode *node, + Location loc); + + /// Converts all LLVM access groups starting from `node` to MLIR access group + /// operations and creates the corresponding access group operations in + /// `metadataOp`. It stores a mapping from every nested access group node to + /// the symbol pointing to the translated operation. Returns success if all + /// conversions succeed and failure otherwise.
+ LogicalResult translateAccessGroup(const llvm::MDNode *node, Location loc, + MetadataOp metadataOp); + + /// Returns the symbol references pointing to the access group operations that + /// map to the access group nodes starting from the access group metadata + /// `node`. Returns failure if any of the symbol references cannot be found. + FailureOr> + lookupAccessGroupAttrs(const llvm::MDNode *node) const; private: /// Returns the LLVM metadata corresponding to a llvm loop metadata attribute. @@ -42,8 +57,11 @@ "attempting to map loop options that was already mapped"); } - ModuleImport &moduleImport; + OpBuilder &builder; DenseMap loopMetadataMapping; + /// Mapping between original LLVM access group metadata nodes and the symbol + /// references pointing to the imported MLIR access group operations. + DenseMap accessGroupMapping; }; } // namespace detail diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp b/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp --- a/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp +++ b/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp @@ -16,11 +16,9 @@ namespace { /// Helper class that keeps the state of one metadata to attribute conversion. struct LoopMetadataConversion { - LoopMetadataConversion(const llvm::MDNode *node, ModuleImport &moduleImport, - Location loc, + LoopMetadataConversion(const llvm::MDNode *node, Location loc, LoopAnnotationImporter &loopAnnotationImporter) - : node(node), moduleImport(moduleImport), loc(loc), - loopAnnotationImporter(loopAnnotationImporter), + : node(node), loc(loc), loopAnnotationImporter(loopAnnotationImporter), ctx(loc->getContext()){}; /// Converts this structs loop metadata node into a LoopAnnotationAttr.
LoopAnnotationAttr convert(); @@ -55,7 +53,6 @@ llvm::StringMap propertyMap; const llvm::MDNode *node; - ModuleImport &moduleImport; Location loc; LoopAnnotationImporter &loopAnnotationImporter; MLIRContext *ctx; @@ -233,7 +230,7 @@ if (*node == nullptr) return LoopAnnotationAttr(nullptr); - return loopAnnotationImporter.translate(*node, loc); + return loopAnnotationImporter.translateLoopAnnotation(*node, loc); } static bool isEmptyOrNull(const Attribute attr) { return !attr; } @@ -360,7 +357,7 @@ SmallVector refs; for (llvm::MDNode *node : *nodes) { FailureOr> accessGroups = - moduleImport.lookupAccessGroupAttrs(node); + loopAnnotationImporter.lookupAccessGroupAttrs(node); if (failed(accessGroups)) return emitWarning(loc) << "could not lookup access group"; llvm::append_range(refs, *accessGroups); @@ -398,8 +395,9 @@ parallelAccesses); } -LoopAnnotationAttr LoopAnnotationImporter::translate(const llvm::MDNode *node, - Location loc) { +LoopAnnotationAttr +LoopAnnotationImporter::translateLoopAnnotation(const llvm::MDNode *node, + Location loc) { if (!node) return {}; @@ -409,9 +407,60 @@ if (it != loopMetadataMapping.end()) return it->getSecond(); - LoopAnnotationAttr attr = - LoopMetadataConversion(node, moduleImport, loc, *this).convert(); + LoopAnnotationAttr attr = LoopMetadataConversion(node, loc, *this).convert(); mapLoopMetadata(node, attr); return attr; } + +LogicalResult LoopAnnotationImporter::translateAccessGroup( + const llvm::MDNode *node, Location loc, MetadataOp metadataOp) { + SmallVector accessGroups; + if (!node->getNumOperands()) + accessGroups.push_back(node); + for (const llvm::MDOperand &operand : node->operands()) { + auto *childNode = dyn_cast(operand); + if (!childNode) + return emitWarning(loc) + << "expected access group operands to be metadata nodes"; + accessGroups.push_back(childNode); + } + + // Convert all entries of the access group list to access group operations.
+ for (const llvm::MDNode *accessGroup : accessGroups) { + if (accessGroupMapping.count(accessGroup)) + continue; + // Verify the access group node is distinct and empty. + if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct()) + return emitWarning(loc) + << "expected an access group node to be empty and distinct"; + + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToEnd(&metadataOp.getBody().back()); + auto groupOp = builder.create( + loc, llvm::formatv("group_{0}", accessGroupMapping.size()).str()); + // Add a mapping from the access group node to the symbol reference pointing + // to the newly created operation. + accessGroupMapping[accessGroup] = SymbolRefAttr::get( + builder.getContext(), metadataOp.getSymName(), + FlatSymbolRefAttr::get(builder.getContext(), groupOp.getSymName())); + } + return success(); +} + +FailureOr> +LoopAnnotationImporter::lookupAccessGroupAttrs(const llvm::MDNode *node) const { + // An access group node is either a single access group or an access group + // list. + SmallVector accessGroups; + if (!node->getNumOperands()) + accessGroups.push_back(accessGroupMapping.lookup(node)); + for (const llvm::MDOperand &operand : node->operands()) { + auto *childNode = dyn_cast_or_null(operand.get()); + accessGroups.push_back(accessGroupMapping.lookup(childNode)); + } + // Exit if an operand is not a node or one of the node lookups failed.
+ if (llvm::is_contained(accessGroups, nullptr)) + return failure(); + return accessGroups; +} diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -243,7 +243,8 @@ iface(mlirModule->getContext()), typeTranslator(*mlirModule->getContext()), debugImporter(std::make_unique(mlirModule)), - loopAnnotationImporter(std::make_unique(*this)) { + loopAnnotationImporter( + std::make_unique(builder)) { builder.setInsertionPointToStart(mlirModule.getBody()); } @@ -500,35 +501,11 @@ LogicalResult ModuleImport::processAccessGroupMetadata(const llvm::MDNode *node) { - // An access group node is either access group or an access group list. Start - // by collecting all access groups to translate. - SmallVector accessGroups; - if (!node->getNumOperands()) - accessGroups.push_back(node); - for (const llvm::MDOperand &operand : node->operands()) - accessGroups.push_back(cast(operand.get())); - - // Convert all entries of the access group list to access group operations. - for (const llvm::MDNode *accessGroup : accessGroups) { - if (accessGroupMapping.count(accessGroup)) - continue; - // Verify the access group node is distinct and empty. - Location loc = mlirModule.getLoc(); - if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct()) - return emitError(loc) << "unsupported access group node: " - << diagMD(accessGroup, llvmModule.get()); - - MetadataOp metadataOp = getGlobalMetadataOp(); - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToEnd(&metadataOp.getBody().back()); - auto groupOp = builder.create( - loc, (Twine("group_") + Twine(accessGroupMapping.size())).str()); - // Add a mapping from the access group node to the symbol reference pointing - // to the newly created operation.
- accessGroupMapping[accessGroup] = SymbolRefAttr::get( - builder.getContext(), metadataOp.getSymName(), - FlatSymbolRefAttr::get(builder.getContext(), groupOp.getSymName())); - } + Location loc = mlirModule.getLoc(); + if (failed(loopAnnotationImporter->translateAccessGroup( + node, loc, getGlobalMetadataOp()))) + return emitError(loc) << "unsupported access group node: " + << diagMD(node, llvmModule.get()); return success(); } @@ -1575,25 +1552,13 @@ FailureOr> ModuleImport::lookupAccessGroupAttrs(const llvm::MDNode *node) const { - // An access group node is either a single access group or an access group - // list. - SmallVector accessGroups; - if (!node->getNumOperands()) - accessGroups.push_back(accessGroupMapping.lookup(node)); - for (const llvm::MDOperand &operand : node->operands()) { - auto *node = cast(operand.get()); - accessGroups.push_back(accessGroupMapping.lookup(node)); - } - // Exit if one of the access group node lookups failed. - if (llvm::is_contained(accessGroups, nullptr)) - return failure(); - return accessGroups; + return loopAnnotationImporter->lookupAccessGroupAttrs(node); } LoopAnnotationAttr ModuleImport::translateLoopAnnotationAttr(const llvm::MDNode *node, Location loc) const { - return loopAnnotationImporter->translate(node, loc); + return loopAnnotationImporter->translateLoopAnnotation(node, loc); } OwningOpRef diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll --- a/mlir/test/Target/LLVMIR/Import/import-failure.ll +++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll @@ -243,7 +243,8 @@ ; // ----- ; CHECK: import-failure.ll -; CHECK-SAME: error: unsupported access group node: !0 = !{} +; CHECK-SAME: warning: expected an access group node to be empty and distinct +; CHECK: error: unsupported access group node: !0 = !{} define void @access_group(ptr %arg1) { %1 = load i32, ptr %arg1, !llvm.access.group !0 ret void @@ -254,7 +255,8 @@ ; // ----- ; CHECK: import-failure.ll
-; CHECK-SAME: error: unsupported access group node: !1 = distinct !{!"unsupported access group"} +; CHECK-SAME: warning: expected an access group node to be empty and distinct +; CHECK: error: unsupported access group node: !0 = !{!1} define void @access_group(ptr %arg1) { %1 = load i32, ptr %arg1, !llvm.access.group !0 ret void @@ -265,6 +267,18 @@ ; // ----- +; CHECK: import-failure.ll +; CHECK-SAME: warning: expected access group operands to be metadata nodes +; CHECK: error: unsupported access group node: !0 = !{i1 false} +define void @access_group(ptr %arg1) { + %1 = load i32, ptr %arg1, !llvm.access.group !0 + ret void +} + +!0 = !{i1 false} + +; // ----- + ; CHECK: import-failure.ll ; CHECK-SAME: warning: expected all loop properties to be either debug locations or metadata nodes ; CHECK: import-failure.ll