diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -206,7 +206,7 @@ OptionalParameter<"LoopUnswitchAttr">:$unswitch, OptionalParameter<"BoolAttr">:$mustProgress, OptionalParameter<"BoolAttr">:$isVectorized, - OptionalArrayRefParameter<"SymbolRefAttr">:$parallelAccesses + OptionalArrayRefParameter<"AccessGroupAttr">:$parallelAccesses ); let assemblyFormat = "`<` struct(params) `>`"; @@ -500,13 +500,39 @@ OptionalArrayRefParameter<"DITypeAttr">:$types ); let builders = [ - TypeBuilder<(ins "ArrayRef":$types), [{ + AttrBuilder<(ins "ArrayRef":$types), [{ return $_get($_ctxt, /*callingConvention=*/0, types); }]> ]; let assemblyFormat = "`<` struct(params) `>`"; } +//===----------------------------------------------------------------------===// +// AccessGroupAttr +//===----------------------------------------------------------------------===// + +def LLVM_AccessGroupAttr : LLVM_Attr<"AccessGroup", "access_group"> { + let parameters = (ins + "int64_t":$id, + "DistinctSequenceAttr":$elem_of + ); + let builders = [ + AttrBuilder<(ins "DistinctSequenceAttr":$sequence), [{ + return $_get($_ctxt, sequence.getNextID(), sequence); + }]> + ]; + let assemblyFormat = "`<` struct(params) `>`"; +} + +//===----------------------------------------------------------------------===// +// AccessGroupArrayAttr +//===----------------------------------------------------------------------===// + +def AccessGroupArrayAttr : + TypedArrayAttrBase { + let constBuilderCall = ?; +} + //===----------------------------------------------------------------------===// // MemoryEffectsAttr //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h +++ 
b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h @@ -60,6 +60,57 @@ static bool classof(Attribute attr); }; +namespace detail { +class DistinctSequenceAttrStorage; +} // namespace detail + +/// This class is a helper attribute to generate a sequence of unique +/// identifiers that can be used to model distinct metadata nodes. The +/// attribute has a scope that limits the validity of the generated sequence to +/// a function since generating unique identifiers at the module level could +/// lead to non determinism due to the parallel processing of functions. A +/// mutable state can be incremented to generate the next unique identifier. +/// +/// Example: +/// ``` +/// #distinct_sequence = #llvm.distinct_sequence +/// #access_group = #llvm.access_group +/// #access_group1 = #llvm.access_group +/// +/// llvm.func @foo(%arg0: !llvm.ptr) { +/// %0 = llvm.load %arg0 {access_groups = [#access_group, #access_group1]} +/// } +/// ``` +class DistinctSequenceAttr + : public Attribute::AttrBase { +public: + // Inherit Base constructors. + using Base::Base; + + /// Returns a distinct sequences attribute for the given scope. + static DistinctSequenceAttr get(SymbolRefAttr scope); + + /// Returns the keyword used when printing and parsing the attribute. + static constexpr StringLiteral getMnemonic() { return {"distinct_sequence"}; } + + /// Returns the symbol that limits the scope of the sequence. + SymbolRefAttr getScope() const; + + /// Returns the next identifier without incrementing the mutable state. + int64_t getState() const; + + /// Returns the next identifier and increments the mutable state. + int64_t getNextID(); + + /// Parses an instance of this attribute. + static Attribute parse(AsmParser &parser, Type type); + + /// Prints this attribute. + void print(AsmPrinter &os) const; +}; + // Inline the LLVM generated Linkage enum and utility. // This is only necessary to isolate the "enum generated code" from the // attribute definition itself. 
@@ -73,4 +124,12 @@ #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.h.inc" +namespace mlir { +namespace LLVM { +/// Verifies the access groups attached to the given operation. +LogicalResult verifyAccessGroups(Operation *op, + ArrayRef accessGroups); +} // namespace LLVM +} // namespace mlir + #endif // MLIR_DIALECT_LLVMIR_LLVMATTRS_H_ diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td @@ -15,7 +15,6 @@ let name = "llvm"; let cppNamespace = "::mlir::LLVM"; - let useDefaultAttributePrinterParser = 1; let hasRegionArgAttrVerify = 1; let hasRegionResultAttrVerify = 1; let hasOperationAttrVerify = 1; @@ -76,6 +75,8 @@ Type parseType(DialectAsmParser &p) const override; void printType(Type, DialectAsmPrinter &p) const override; + Attribute parseAttribute(DialectAsmParser &parser, Type type) const override; + void printAttribute(Attribute attr, DialectAsmPrinter &os) const override; private: /// Verifies a parameter attribute attached to a parameter of type diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.h @@ -19,8 +19,8 @@ namespace LLVM { namespace detail { -/// Verifies the access groups attribute of memory operations that implement the -/// access group interface. +/// Verifies the access group attributes of memory operations that implement +/// the access group interface. 
LogicalResult verifyAccessGroupOpInterface(Operation *op); /// Verifies the alias analysis attributes of memory operations that implement diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td @@ -53,7 +53,7 @@ An interface for memory operations that can carry access groups metadata. It provides setters and getters for the operation's access groups attribute. The default implementations of the interface methods expect the operation - to have an attribute of type ArrayAttr named access_groups. + to have an array attribute named access_groups. }]; let cppNamespace = "::mlir::LLVM"; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -263,7 +263,7 @@ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods], traits)>, LLVM_MemOpPatterns { - dag aliasAttrs = (ins OptionalAttr:$access_groups, + dag aliasAttrs = (ins OptionalAttr:$access_groups, OptionalAttr:$alias_scopes, OptionalAttr:$noalias_scopes, OptionalAttr:$tbaa); @@ -306,7 +306,7 @@ Results { dag aliasAttrs = !con( !if(!gt(requiresAccessGroup, 0), - (ins OptionalAttr:$access_groups), + (ins OptionalAttr:$access_groups), (ins )), !if(!gt(requiresAccessGroup, 0), (ins OptionalAttr:$alias_scopes, diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1126,26 +1126,6 @@ let hasVerifier = 1; } -def LLVM_AccessGroupMetadataOp : LLVM_Op<"access_group", [ - HasParent<"MetadataOp">, Symbol -]> { - let arguments = (ins - SymbolNameAttr:$sym_name - ); - let summary = "LLVM dialect access group metadata."; - let description = [{ - Defines an 
access group metadata that can be attached to any instruction - that potentially accesses memory. The access group may be attached to a - memory accessing instruction via the `llvm.access.group` metadata and - a branch instruction in the loop latch block via the - `llvm.loop.parallel_accesses` metadata. - - See the following link for more details: - https://llvm.org/docs/LangRef.html#llvm-access-group-metadata - }]; - let assemblyFormat = "$sym_name attr-dict"; -} - def LLVM_TBAARootMetadataOp : LLVM_Op<"tbaa_root", [ HasParent<"MetadataOp">, Symbol ]> { diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -14,6 +14,7 @@ #ifndef MLIR_TARGET_LLVMIR_MODULEIMPORT_H #define MLIR_TARGET_LLVMIR_MODULEIMPORT_H +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Target/LLVMIR/Import.h" @@ -180,10 +181,10 @@ return tbaaMapping.lookup(node); } - /// Returns the symbol references pointing to the access group operations that - /// map to the access group nodes starting from the access group metadata - /// `node`. Returns failure, if any of the symbol references cannot be found. - FailureOr> + /// Returns the access group attributes that map to the access group nodes + /// starting from the access group metadata `node`. Returns failure if the + /// lookup fails for any of the access groups. + FailureOr> lookupAccessGroupAttrs(const llvm::MDNode *node) const; /// Returns the loop annotation attribute that corresponds to the given LLVM @@ -285,10 +286,13 @@ /// invocation of this function). 
LogicalResult processTBAAMetadata(const llvm::MDNode *node); /// Converts all LLVM access groups starting from `node` to MLIR access group - /// operations and stores a mapping from every nested access group node to the - /// symbol pointing to the translated operation. Returns success if all - /// conversions succeed and failure otherwise. - LogicalResult processAccessGroupMetadata(const llvm::MDNode *node); + /// operations and stores a mapping from every nested access group to the + /// translated attribute. Uses `distinctSequence` to generate the function + /// specific access group identifiers. Returns success if all conversions + /// succeed and failure otherwise. + LogicalResult + processAccessGroupMetadata(const llvm::MDNode *node, + DistinctSequenceAttr distinctSequence); /// Converts all LLVM alias scopes and domains starting from `node` to MLIR /// alias scope and domain operations and stores a mapping from every nested /// alias scope or alias domain node to the symbol pointing to the translated diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -280,10 +280,6 @@ LogicalResult convertGlobals(); LogicalResult convertOneFunction(LLVMFuncOp func); - /// Process access_group LLVM Metadata operations and create LLVM - /// metadata nodes. - LogicalResult createAccessGroupMetadata(); - /// Process alias.scope LLVM Metadata operations and create LLVM /// metadata nodes for them and their domains. 
LogicalResult createAliasScopeMetadata(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp @@ -27,16 +27,42 @@ #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc" //===----------------------------------------------------------------------===// -// LLVMDialect registration +// LLVMDialect //===----------------------------------------------------------------------===// void LLVMDialect::registerAttributes() { + addAttributes(); addAttributes< #define GET_ATTRDEF_LIST #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc" >(); } +Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser, + Type type) const { + StringRef mnemonic; + Attribute attr; + OptionalParseResult result = + generatedAttributeParser(parser, &mnemonic, type, attr); + if (result.has_value()) + return attr; + + if (mnemonic == DistinctSequenceAttr::getMnemonic()) + return DistinctSequenceAttr::parse(parser, type); + + llvm_unreachable("unhandled LLVM attribute kind"); +} + +void LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const { + if (succeeded(generatedAttributePrinter(attr, os))) + return; + + if (auto composite = dyn_cast(attr)) + composite.print(os); + else + llvm_unreachable("unhandled LLVM attribute kind"); +} + //===----------------------------------------------------------------------===// // DINodeAttr //===----------------------------------------------------------------------===// @@ -100,3 +126,181 @@ return false; return true; } + +//===----------------------------------------------------------------------===// +// DistinctSequenceAttr +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace LLVM { +namespace detail { +/// Attribute storage class of the distinct sequence attribute that stores the +/// sequence scope and a mutable state that is used to get the 
next identifier. +class DistinctSequenceAttrStorage : public AttributeStorage { +public: + using KeyTy = SymbolRefAttr; + + DistinctSequenceAttrStorage(SymbolRefAttr scope) : scope(scope) {} + + static DistinctSequenceAttrStorage * + construct(AttributeStorageAllocator &allocator, const KeyTy &key) { + return new (allocator.allocate()) + DistinctSequenceAttrStorage(key); + } + + /// Stores the next identifier value that matches the state variable to + /// `nextID` and post increments the state (incrementing is thread safe since + /// the storage uniquer acquires a mutex before calling the mutate method). + LogicalResult mutate(AttributeStorageAllocator &allocator, int64_t *nextID) { + *nextID = state++; + return success(); + } + + /// Sets the state to `state` after parsing an attribute instance (setting + /// the state is thread safe since the storage uniquer acquires a mutex before + /// calling the mutate method). + LogicalResult mutate(AttributeStorageAllocator &allocator, int64_t state) { + this->state = state; + return success(); + } + + /// Returns the scope of the sequence. + SymbolRefAttr getScope() const { return scope; } + + /// Returns the state of the sequence. + int64_t getState() const { return state; } + + /// Compares the non-mutable part of the attribute. 
+ bool operator==(const KeyTy &other) const { return scope == other; } + +private: + SymbolRefAttr scope; + int64_t state = 0; +}; +} // namespace detail +} // namespace LLVM +} // namespace mlir + +DistinctSequenceAttr DistinctSequenceAttr::get(SymbolRefAttr scope) { + return Base::get(scope.getContext(), scope); +} + +SymbolRefAttr DistinctSequenceAttr::getScope() const { + return getImpl()->getScope(); +} + +int64_t DistinctSequenceAttr::getState() const { return getImpl()->getState(); } + +int64_t DistinctSequenceAttr::getNextID() { + int64_t nextID; + (void)Base::mutate(&nextID); + return nextID; +} + +Attribute DistinctSequenceAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess()) + return {}; + + // A helper function to parse a struct parameter. + auto parseParameter = + [&](StringRef name, StringRef type, bool &seen, + function_ref parseFn) -> ParseResult { + if (seen) { + return parser.emitError(parser.getCurrentLocation()) + << "struct has duplicate parameter '" << name << "'"; + } + if (failed(parseFn())) { + return parser.emitError(parser.getCurrentLocation()) + << "failed to parse DistinctSequenceAttr parameter '" << name + << "' which is to be a '" << type << "'"; + } + seen = true; + return success(); + }; + + std::pair scope = {nullptr, false}; + std::pair state = {0, false}; + do { + std::string keyword; + if (failed(parser.parseKeywordOrString(&keyword))) { + parser.emitError(parser.getCurrentLocation()) + << "expected a parameter name in struct"; + return {}; + } + if (parser.parseEqual()) + return {}; + + if (keyword == "scope") { + if (failed(parseParameter(keyword, "SymbolRefAttr", scope.second, [&]() { + return parser.parseAttribute(scope.first); + }))) + return {}; + } else if (keyword == "state") { + if (failed(parseParameter(keyword, "int64_t", state.second, [&]() { + return parser.parseInteger(state.first); + }))) + return {}; + } else { + parser.emitError(parser.getCurrentLocation()) + << "expected a parameter name in 
struct"; + return {}; + } + } while (succeeded(parser.parseOptionalComma())); + + if (!scope.second) { + parser.emitError(parser.getCurrentLocation()) + << "struct is missing required parameter 'scope'"; + return {}; + } + if (!state.second) { + parser.emitError(parser.getCurrentLocation()) + << "struct is missing required parameter 'state'"; + return {}; + } + + if (parser.parseGreater()) + return {}; + + DistinctSequenceAttr distinctSeqAttr = get(scope.first); + (void)distinctSeqAttr.mutate(state.first); + return distinctSeqAttr; +} + +void DistinctSequenceAttr::print(AsmPrinter &os) const { + os << DistinctSequenceAttr::getMnemonic() << "<"; + os << "scope = "; + os.printAttribute(getImpl()->getScope()); + os << ", state = " << getImpl()->getState(); + os << ">"; +} + +LogicalResult +mlir::LLVM::verifyAccessGroups(Operation *op, + ArrayRef accessGroups) { + // Search the top-level symbol table. + Operation *topLevelSymbolTable = SymbolTable::getNearestSymbolTable(op); + Operation *parentOp = topLevelSymbolTable->getParentOp(); + while (parentOp && parentOp->hasTrait()) { + topLevelSymbolTable = parentOp; + parentOp = parentOp->getParentOp(); + } + + // Verify the access group sequence scope matches the parent operation and the + // unique identifiers are smaller than the sequence state. 
+ for (AccessGroupAttr group : accessGroups) { + DistinctSequenceAttr sequence = group.getElemOf(); + Operation *scopeOp = SymbolTable::lookupNearestSymbolFrom( + topLevelSymbolTable, sequence.getScope()); + if (scopeOp != op->getParentOp()) { + return op->emitOpError() + << "expected distinct sequence scope '" << sequence.getScope() + << "' to resolve to the parent operation"; + } + if (group.getId() >= sequence.getState()) { + return op->emitOpError() << "expected access group id '" << group.getId() + << "' to be lower than the sequence state '" + << sequence.getState() << "'"; + } + } + return success(); +} diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -274,31 +274,12 @@ // LLVM::BrOp //===----------------------------------------------------------------------===// -/// Check if the `loopAttr` references correct symbols. -static LogicalResult verifyLoopAnnotationAttr(LoopAnnotationAttr loopAttr, - Operation *op) { +/// Verify the access group attributes referenced by the loop annotations. +static LogicalResult verifyLoopAnnotationAttr(Operation *op, + LoopAnnotationAttr loopAttr) { if (!loopAttr) return success(); - // If the `llvm.loop` attribute is present, enforce the following structure, - // which the module translation can assume. 
- ArrayRef parallelAccesses = loopAttr.getParallelAccesses(); - if (parallelAccesses.empty()) - return success(); - for (SymbolRefAttr accessGroupRef : parallelAccesses) { - StringAttr metadataName = accessGroupRef.getRootReference(); - auto metadataOp = SymbolTable::lookupNearestSymbolFrom( - op->getParentOp(), metadataName); - if (!metadataOp) - return op->emitOpError() << "expected '" << accessGroupRef - << "' to reference a metadata op"; - StringAttr accessGroupName = accessGroupRef.getLeafReference(); - Operation *accessGroupOp = - SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); - if (!accessGroupOp) - return op->emitOpError() << "expected '" << accessGroupRef - << "' to reference an access_group op"; - } - return success(); + return verifyAccessGroups(op, loopAttr.getParallelAccesses()); } SuccessorOperands BrOp::getSuccessorOperands(unsigned index) { @@ -307,7 +288,7 @@ } LogicalResult BrOp::verify() { - return verifyLoopAnnotationAttr(getLoopAnnotationAttr(), *this); + return verifyLoopAnnotationAttr(*this, getLoopAnnotationAttr()); } //===----------------------------------------------------------------------===// @@ -320,10 +301,6 @@ : getFalseDestOperandsMutable()); } -LogicalResult CondBrOp::verify() { - return verifyLoopAnnotationAttr(getLoopAnnotationAttr(), *this); -} - void CondBrOp::build(OpBuilder &builder, OperationState &result, Value condition, Block *trueDest, ValueRange trueOperands, Block *falseDest, ValueRange falseOperands, @@ -338,6 +315,10 @@ /*loop_annotation=*/{}, trueDest, falseDest); } +LogicalResult CondBrOp::verify() { + return verifyLoopAnnotationAttr(*this, getLoopAnnotationAttr()); +} + //===----------------------------------------------------------------------===// // LLVM::SwitchOp //===----------------------------------------------------------------------===// @@ -2789,17 +2770,18 @@ AliasResult getAlias(Attribute attr, raw_ostream &os) const override { return TypeSwitch(attr) - .Case([&](auto attr) { - os 
<< decltype(attr)::getMnemonic(); - return AliasResult::OverridableAlias; - }) + .Case( + [&](auto attr) { + os << decltype(attr)::getMnemonic(); + return AliasResult::OverridableAlias; + }) .Default([](Attribute) { return AliasResult::NoAlias; }); } }; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" using namespace mlir; @@ -81,18 +82,6 @@ return verifySymbolRefs(op, name, symbolRefs, verifySymbolType); } -//===----------------------------------------------------------------------===// -// AccessGroupOpInterface -//===----------------------------------------------------------------------===// - -LogicalResult mlir::LLVM::detail::verifyAccessGroupOpInterface(Operation *op) { - auto iface = cast(op); - if (failed(verifySymbolRefsPointTo( - iface, "access groups", iface.getAccessGroupsOrNull()))) - return failure(); - return success(); -} - //===----------------------------------------------------------------------===// // AliasAnalysisOpInterface //===----------------------------------------------------------------------===// @@ -112,4 +101,16 @@ return success(); } +//===----------------------------------------------------------------------===// +// AccessGroupOpInterface +//===----------------------------------------------------------------------===// + +LogicalResult mlir::LLVM::detail::verifyAccessGroupOpInterface(Operation *op) { + auto iface = cast(op); + if (ArrayAttr groups = iface.getAccessGroupsOrNull()) + return verifyAccessGroups( + op, llvm::to_vector(groups.getAsRange())); + return success(); +} + #include "mlir/Dialect/LLVMIR/LLVMInterfaces.cpp.inc" diff --git 
a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp @@ -158,18 +158,18 @@ static LogicalResult setAccessGroupsAttr(const llvm::MDNode *node, Operation *op, LLVM::ModuleImport &moduleImport) { - FailureOr> accessGroups = + FailureOr> groups = moduleImport.lookupAccessGroupAttrs(node); - if (failed(accessGroups)) + if (failed(groups)) return failure(); auto iface = dyn_cast(op); if (!iface) return failure(); - iface.setAccessGroups(ArrayAttr::get( - iface.getContext(), - SmallVector{accessGroups->begin(), accessGroups->end()})); + iface.setAccessGroups( + ArrayAttr::get(iface->getContext(), + SmallVector{groups->begin(), groups->end()})); return success(); } diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.h b/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.h --- a/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.h +++ b/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.h @@ -14,6 +14,7 @@ #ifndef MLIR_LIB_TARGET_LLVMIR_LOOPANNOTATIONIMPORTER_H_ #define MLIR_LIB_TARGET_LLVMIR_LOOPANNOTATIONIMPORTER_H_ +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Target/LLVMIR/ModuleImport.h" @@ -26,23 +27,23 @@ /// AccessGroupMetadataOps. class LoopAnnotationImporter { public: - explicit LoopAnnotationImporter(OpBuilder &builder) : builder(builder) {} + explicit LoopAnnotationImporter(MLIRContext *context) : context(context) {} LoopAnnotationAttr translateLoopAnnotation(const llvm::MDNode *node, Location loc); - /// Converts all LLVM access groups starting from node to MLIR access group - /// operations mested in the region of metadataOp. It stores a mapping from - /// every nested access group nod to the symbol pointing to the translated - /// operation. 
Returns success if all conversions succeed and failure - /// otherwise. + /// Converts all LLVM access groups starting from `node` to MLIR access group + /// attributes. Uses `distinctSequence` to generate the function specific + /// access group identifiers. It stores a mapping from every nested access + /// group node to the translated access group attribute. Returns success if + /// all conversions succeed and failure otherwise. LogicalResult translateAccessGroup(const llvm::MDNode *node, Location loc, - MetadataOp metadataOp); + DistinctSequenceAttr distinctSequence); - /// Returns the symbol references pointing to the access group operations that - /// map to the access group nodes starting from the access group metadata - /// node. Returns failure, if any of the symbol references cannot be found. - FailureOr> - lookupAccessGroupAttrs(const llvm::MDNode *node) const; + /// Returns the access group attributes that map to the access group nodes + /// starting from the access group metadata `node`. Returns failure if the + /// lookup fails for any of the access groups. + FailureOr> + lookupAccessGroupsAttr(const llvm::MDNode *node) const; private: /// Returns the LLVM metadata corresponding to a llvm loop metadata attribute. @@ -57,11 +58,11 @@ "attempting to map loop options that was already mapped"); } - OpBuilder &builder; + MLIRContext *context; DenseMap loopMetadataMapping; - /// Mapping between original LLVM access group metadata nodes and the symbol - /// references pointing to the imported MLIR access group operations. - DenseMap accessGroupMapping; + /// Mapping between original LLVM access group metadata nodes and the + /// corresponding MLIR access group attribute. 
+ DenseMap accessGroupMapping; }; } // namespace detail diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp b/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp --- a/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp +++ b/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "LoopAnnotationImporter.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "llvm/IR/Constants.h" using namespace mlir; @@ -28,7 +29,7 @@ /// Helper function to get and erase a property. const llvm::MDNode *lookupAndEraseProperty(StringRef name); - /// Helper functions to lookup and convert MDNodes into a specifc attribute + /// Helper functions to lookup and convert MDNodes into a specific attribute /// kind. These functions return null-attributes if there is no node with the /// specified name, or failure, if the node is ill-formatted. FailureOr lookupUnitNode(StringRef name); @@ -52,7 +53,7 @@ FailureOr convertPipelineAttr(); FailureOr convertPeeledAttr(); FailureOr convertUnswitchAttr(); - FailureOr> convertParallelAccesses(); + FailureOr> convertParallelAccesses(); llvm::StringMap propertyMap; const llvm::MDNode *node; @@ -265,7 +266,7 @@ } /// Helper function that only creates and attribute of type T if all argument -/// conversion were successfull and at least one of them holds a non-null value. +/// conversion were successful and at least one of them holds a non-null value. 
template static T createIfNonNull(MLIRContext *ctx, const P &...args) { bool anyFailed = (failed(args) || ...); @@ -386,21 +387,21 @@ return createIfNonNull(ctx, partialDisable); } -FailureOr> +FailureOr> LoopMetadataConversion::convertParallelAccesses() { FailureOr> nodes = lookupMDNodes("llvm.loop.parallel_accesses"); if (failed(nodes)) return failure(); - SmallVector refs; + SmallVector accessGroups; for (llvm::MDNode *node : *nodes) { - FailureOr> accessGroups = - loopAnnotationImporter.lookupAccessGroupAttrs(node); - if (failed(accessGroups)) - return emitWarning(loc) << "could not lookup access group"; - llvm::append_range(refs, *accessGroups); + FailureOr> groups = + loopAnnotationImporter.lookupAccessGroupsAttr(node); + if (failed(groups)) + return emitWarning(loc) << "could not lookup access groups"; + llvm::append_range(accessGroups, *groups); } - return refs; + return accessGroups; } LoopAnnotationAttr LoopMetadataConversion::convert() { @@ -421,7 +422,7 @@ FailureOr mustProgress = lookupUnitNode("llvm.loop.mustprogress"); FailureOr isVectorized = lookupIntNodeAsBoolAttr("llvm.loop.isvectorized"); - FailureOr> parallelAccesses = + FailureOr> parallelAccesses = convertParallelAccesses(); // Drop the metadata if there are parts that cannot be imported. @@ -456,7 +457,8 @@ } LogicalResult LoopAnnotationImporter::translateAccessGroup( - const llvm::MDNode *node, Location loc, MetadataOp metadataOp) { + const llvm::MDNode *node, Location loc, + DistinctSequenceAttr distinctSequence) { SmallVector accessGroups; if (!node->getNumOperands()) accessGroups.push_back(node); @@ -469,31 +471,22 @@ // Convert all entries of the access group list to access group operations. for (const llvm::MDNode *accessGroup : accessGroups) { - if (accessGroupMapping.count(accessGroup)) + if (accessGroupMapping.contains(accessGroup)) continue; // Verify the access group node is distinct and empty. 
if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct()) return emitWarning(loc) << "expected an access group node to be empty and distinct"; - - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToEnd(&metadataOp.getBody().back()); - auto groupOp = builder.create( - loc, llvm::formatv("group_{0}", accessGroupMapping.size()).str()); - // Add a mapping from the access group node to the symbol reference pointing - // to the newly created operation. - accessGroupMapping[accessGroup] = SymbolRefAttr::get( - builder.getContext(), metadataOp.getSymName(), - FlatSymbolRefAttr::get(builder.getContext(), groupOp.getSymName())); + accessGroupMapping[accessGroup] = + AccessGroupAttr::get(context, distinctSequence); } return success(); } -FailureOr> -LoopAnnotationImporter::lookupAccessGroupAttrs(const llvm::MDNode *node) const { - // An access group node is either a single access group or an access group - // list. - SmallVector accessGroups; +FailureOr> +LoopAnnotationImporter::lookupAccessGroupsAttr(const llvm::MDNode *node) const { + // An access group node is a single access group or an access group list. + SmallVector accessGroups; if (!node->getNumOperands()) accessGroups.push_back(accessGroupMapping.lookup(node)); for (const llvm::MDOperand &operand : node->operands()) { diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h --- a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h +++ b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h @@ -25,23 +25,17 @@ /// into a corresponding llvm::MDNodes. 
class LoopAnnotationTranslation { public: - LoopAnnotationTranslation(Operation *mlirModule, llvm::Module &llvmModule) - : mlirModule(mlirModule), llvmModule(llvmModule) {} + LoopAnnotationTranslation(llvm::Module &llvmModule) + : llvmModule(llvmModule) {} llvm::MDNode *translateLoopAnnotation(LoopAnnotationAttr attr, Operation *op); - /// Traverses the global access group metadata operation in the `mlirModule` - /// and creates corresponding LLVM metadata nodes. - LogicalResult createAccessGroupMetadata(); - - /// Returns the LLVM metadata corresponding to a symbol reference to an mlir - /// LLVM dialect access group operation. - llvm::MDNode *getAccessGroup(Operation *op, - SymbolRefAttr accessGroupRef) const; + /// Returns the LLVM metadata corresponding to an MLIR access group attribute. + llvm::MDNode *getAccessGroup(AccessGroupAttr group); - /// Returns the LLVM metadata corresponding to the access group operations - /// referenced by the AccessGroupOpInterface or null if there are none. - llvm::MDNode *getAccessGroups(AccessGroupOpInterface op) const; + /// Returns the LLVM metadata corresponding to the access group attributes + /// referenced by the AccessGroupOpInterface or null if there are none. + llvm::MDNode *getAccessGroups(AccessGroupOpInterface op); private: /// Returns the LLVM metadata corresponding to a llvm loop metadata attribute. @@ -60,12 +54,11 @@ /// The metadata is attached to Latch block branches with this attribute. DenseMap loopMetadataMapping; - /// Mapping from an access group metadata operation to its LLVM metadata. - /// This map is populated on module entry and is used to annotate loops (as - /// identified via their branches) and contained memory accesses. - DenseMap accessGroupMetadataMapping; + /// Mapping from access group attributes to the corresponding LLVM metadata + /// nodes. The map is populated lazily by getAccessGroup and is used to + /// annotate loops (as identified via their branches) and memory accesses.
+ DenseMap accessGroupMetadataMapping; - Operation *mlirModule; llvm::Module &llvmModule; }; diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp --- a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "LoopAnnotationTranslation.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" using namespace mlir; using namespace mlir::LLVM; @@ -217,14 +218,13 @@ if (auto options = attr.getUnswitch()) convertLoopOptions(options); - ArrayRef parallelAccessGroups = attr.getParallelAccesses(); + ArrayRef parallelAccessGroups = attr.getParallelAccesses(); if (!parallelAccessGroups.empty()) { SmallVector parallelAccess; parallelAccess.push_back( llvm::MDString::get(ctx, "llvm.loop.parallel_accesses")); - for (SymbolRefAttr accessGroupRef : parallelAccessGroups) - parallelAccess.push_back( - loopAnnotationTranslation.getAccessGroup(op, accessGroupRef)); + for (AccessGroupAttr group : parallelAccessGroups) + parallelAccess.push_back(loopAnnotationTranslation.getAccessGroup(group)); metadataNodes.push_back(llvm::MDNode::get(ctx, parallelAccess)); } @@ -254,38 +254,26 @@ return loopMD; } -LogicalResult LoopAnnotationTranslation::createAccessGroupMetadata() { - mlirModule->walk([&](LLVM::MetadataOp metadatas) { - metadatas.walk([&](LLVM::AccessGroupMetadataOp op) { - llvm::MDNode *accessGroup = - llvm::MDNode::getDistinct(llvmModule.getContext(), {}); - accessGroupMetadataMapping.insert({op, accessGroup}); - }); - }); - return success(); -} +llvm::MDNode *LoopAnnotationTranslation::getAccessGroup(AccessGroupAttr group) { + auto it = accessGroupMetadataMapping.find(group); + if (it != accessGroupMetadataMapping.end()) + return it->getSecond(); -llvm::MDNode * -LoopAnnotationTranslation::getAccessGroup(Operation *op, - SymbolRefAttr accessGroupRef) const { - auto 
metadataName = accessGroupRef.getRootReference(); - auto accessGroupName = accessGroupRef.getLeafReference(); - auto metadataOp = SymbolTable::lookupNearestSymbolFrom( - op->getParentOp(), metadataName); - auto *accessGroupOp = - SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); - return accessGroupMetadataMapping.lookup(accessGroupOp); + llvm::MDNode *accessGroup = + llvm::MDNode::getDistinct(llvmModule.getContext(), {}); + accessGroupMetadataMapping.insert({group, accessGroup}); + return accessGroup; } llvm::MDNode * -LoopAnnotationTranslation::getAccessGroups(AccessGroupOpInterface op) const { - ArrayAttr accessGroupRefs = op.getAccessGroupsOrNull(); - if (!accessGroupRefs || accessGroupRefs.empty()) +LoopAnnotationTranslation::getAccessGroups(AccessGroupOpInterface op) { + ArrayAttr accessGroups = op.getAccessGroupsOrNull(); + if (!accessGroups || accessGroups.empty()) return nullptr; SmallVector groupMDs; - for (SymbolRefAttr groupRef : accessGroupRefs.getAsRange()) - groupMDs.push_back(getAccessGroup(op, groupRef)); + for (auto group : accessGroups.getAsRange()) + groupMDs.push_back(getAccessGroup(group)); if (groupMDs.size() == 1) return llvm::cast(groupMDs.front()); return llvm::MDNode::get(llvmModule.getContext(), groupMDs); diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Target/LLVMIR/ModuleImport.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Target/LLVMIR/Import.h" #include "AttrKindDetail.h" @@ -153,7 +154,7 @@ typeTranslator(*mlirModule->getContext()), debugImporter(std::make_unique(mlirModule)), loopAnnotationImporter( - std::make_unique(builder)) { + std::make_unique(mlirModule.getContext())) { builder.setInsertionPointToStart(mlirModule.getBody()); } @@ -408,11 +409,11 
@@ return success(); } -LogicalResult -ModuleImport::processAccessGroupMetadata(const llvm::MDNode *node) { +LogicalResult ModuleImport::processAccessGroupMetadata( + const llvm::MDNode *node, DistinctSequenceAttr distinctSequence) { Location loc = mlirModule.getLoc(); - if (failed(loopAnnotationImporter->translateAccessGroup( - node, loc, getGlobalMetadataOp()))) + if (failed(loopAnnotationImporter->translateAccessGroup(node, loc, + distinctSequence))) return emitError(loc) << "unsupported access group node: " << diagMD(node, llvmModule.get()); return success(); @@ -515,11 +516,13 @@ OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToEnd(mlirModule.getBody()); for (const llvm::Function &func : llvmModule->functions()) { + auto distinctSequence = DistinctSequenceAttr::get( + SymbolRefAttr::get(StringAttr::get(context, func.getName()))); for (const llvm::Instruction &inst : llvm::instructions(func)) { // Convert access group metadata nodes. if (llvm::MDNode *node = inst.getMetadata(llvm::LLVMContext::MD_access_group)) - if (failed(processAccessGroupMetadata(node))) + if (failed(processAccessGroupMetadata(node, distinctSequence))) return failure(); // Convert alias analysis metadata nodes. 
@@ -1650,9 +1653,9 @@ return success(); } -FailureOr> +FailureOr> ModuleImport::lookupAccessGroupAttrs(const llvm::MDNode *node) const { - return loopAnnotationImporter->lookupAccessGroupAttrs(node); + return loopAnnotationImporter->lookupAccessGroupsAttr(node); } LoopAnnotationAttr diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -442,8 +442,8 @@ : mlirModule(module), llvmModule(std::move(llvmModule)), debugTranslation( std::make_unique(module, *this->llvmModule)), - loopAnnotationTranslation(std::make_unique( - module, *this->llvmModule)), + loopAnnotationTranslation( + std::make_unique(*this->llvmModule)), typeTranslator(this->llvmModule->getContext()), iface(module->getContext()) { assert(satisfiesLLVMModule(mlirModule) && @@ -1017,10 +1017,6 @@ return success(); } -LogicalResult ModuleTranslation::createAccessGroupMetadata() { - return loopAnnotationTranslation->createAccessGroupMetadata(); -} - void ModuleTranslation::setAccessGroupsMetadata(AccessGroupOpInterface op, llvm::Instruction *inst) { if (llvm::MDNode *node = loopAnnotationTranslation->getAccessGroups(op)) @@ -1326,8 +1322,6 @@ return nullptr; if (failed(translator.convertGlobals())) return nullptr; - if (failed(translator.createAccessGroupMetadata())) - return nullptr; if (failed(translator.createAliasScopeMetadata())) return nullptr; if (failed(translator.createTBAAMetadata())) diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir --- a/mlir/test/Dialect/LLVMIR/inlining.mlir +++ b/mlir/test/Dialect/LLVMIR/inlining.mlir @@ -59,16 +59,14 @@ // ----- -llvm.metadata @metadata { - llvm.access_group @group - llvm.return -} +#distinct_sequence = #llvm.distinct_sequence +#access_group = #llvm.access_group func.func private @with_mem_attr(%ptr : !llvm.ptr) { %0 = llvm.mlir.constant(42 : i32) : i32 // Do not 
inline load/store operations that carry attributes requiring // handling while inlining, until this is supported by the inliner. - llvm.store %0, %ptr { access_groups = [@metadata::@group] }: i32, !llvm.ptr + llvm.store %0, %ptr { access_groups = [#access_group] }: i32, !llvm.ptr return } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -887,80 +887,80 @@ // ----- -module { - llvm.func @loopOptions() { - // expected-error@below {{expected '@func1' to reference a metadata op}} - llvm.br ^bb4 {loop_annotation = #llvm.loop_annotation} - ^bb4: - llvm.return - } - llvm.func @func1() { +llvm.func @loopOptions() { + // expected-error@below {{expected '<'}} + // expected-error@below {{failed to parse LoopAnnotationAttr parameter 'parallelAccesses' which is to be a `::llvm::ArrayRef`}} + llvm.br ^bb4 {loop_annotation = #llvm.loop_annotation} + ^bb4: llvm.return - } } // ----- -module { - llvm.func @loopOptions() { - // expected-error@below {{expected '@metadata' to reference an access_group op}} - llvm.br ^bb4 {loop_annotation = #llvm.loop_annotation} - ^bb4: - llvm.return - } - llvm.metadata @metadata { - } -} +// expected-error@below {{struct has duplicate parameter 'state'}} +#distinct_sequence = #llvm.distinct_sequence // ----- -module { - llvm.func @accessGroups(%arg0 : !llvm.ptr) { - // expected-error@below {{expected '@func1' to specify a fully qualified reference}} - %0 = llvm.load %arg0 { "access_groups" = [@func1] } : !llvm.ptr -> i32 - llvm.return - } - llvm.func @func1() { +// expected-error@below {{expected integer value}} +// expected-error@below {{failed to parse DistinctSequenceAttr parameter 'state' which is to be a 'int64_t'}} +#distinct_sequence = #llvm.distinct_sequence + +// ----- + +// expected-error@below {{expected a parameter name in struct}} +#distinct_sequence = #llvm.distinct_sequence + +// ----- + +// 
expected-error@below {{struct is missing required parameter 'scope'}} +#distinct_sequence = #llvm.distinct_sequence + +// ----- + +// expected-error@below {{struct is missing required parameter 'state'}} +#distinct_sequence = #llvm.distinct_sequence + +// ----- + +#distinct_sequence = #llvm.distinct_sequence +#access_group = #llvm.access_group + +llvm.func @parallel_accesses() { + // expected-error@below {{op expected distinct sequence scope '@bar' to resolve to the parent operation}} + llvm.br ^bb4 {loop_annotation = #llvm.loop_annotation} + ^bb4: llvm.return - } } // ----- -module { - llvm.func @accessGroups(%arg0 : i32, %arg1 : !llvm.ptr) { - // expected-error@below {{expected '@accessGroups::@group1' to reference a metadata op}} - llvm.store %arg0, %arg1 { "access_groups" = [@accessGroups::@group1] } : i32, !llvm.ptr - llvm.return - } - llvm.metadata @metadata { - } +#distinct_sequence = #llvm.distinct_sequence +#access_group = #llvm.access_group + +llvm.func @access_groups(%arg0 : !llvm.ptr) { + // expected-error@below {{op expected distinct sequence scope '@foo' to resolve to the parent operation}} + %0 = llvm.load %arg0 { "access_groups" = [#access_group] } : !llvm.ptr -> i32 + llvm.return } // ----- -module { - llvm.func @accessGroups(%arg0 : !llvm.ptr, %arg1 : f32) { - // expected-error@below {{expected '@metadata::@group1' to be a valid reference}} - %0 = llvm.atomicrmw fadd %arg0, %arg1 monotonic { "access_groups" = [@metadata::@group1] } : !llvm.ptr, f32 - llvm.return - } - llvm.metadata @metadata { - } +#distinct_sequence = #llvm.distinct_sequence +#access_group = #llvm.access_group + +llvm.func @access_groups(%arg0 : !llvm.ptr) { + // expected-error@below {{op expected access group id '1' to be lower than the sequence state '1'}} + %0 = llvm.load %arg0 { "access_groups" = [#access_group] } : !llvm.ptr -> i32 + llvm.return } // ----- -module { - llvm.func @accessGroups(%arg0 : !llvm.ptr, %arg1 : i32, %arg2 : i32) { - // expected-error@below {{expected 
'@metadata::@scope' to resolve to a llvm.access_group}} - %0 = llvm.cmpxchg %arg0, %arg1, %arg2 acq_rel monotonic { "access_groups" = [@metadata::@scope] } : !llvm.ptr, i32 - llvm.return - } - llvm.metadata @metadata { - llvm.alias_scope_domain @domain - llvm.alias_scope @scope { domain = @domain } - } +llvm.func @access_groups(%arg0 : !llvm.ptr) { + // expected-error@below {{attribute 'access_groups' failed to satisfy constraint: access group array attribute}} + %0 = llvm.load %arg0 { "access_groups" = [42 : i32] } : !llvm.ptr -> i32 + llvm.return } // ----- @@ -987,12 +987,12 @@ module { llvm.func @aliasScope(%arg0 : i32, %arg1 : !llvm.ptr) { - // expected-error@below {{expected '@metadata::@group' to resolve to a llvm.alias_scope}} - llvm.store %arg0, %arg1 { "alias_scopes" = [@metadata::@group] } : i32, !llvm.ptr + // expected-error@below {{expected '@metadata::@root' to resolve to a llvm.alias_scope}} + llvm.store %arg0, %arg1 { "alias_scopes" = [@metadata::@root] } : i32, !llvm.ptr llvm.return } llvm.metadata @metadata { - llvm.access_group @group + llvm.tbaa_root @root {id = "Simple C/C++ TBAA"} } } @@ -1000,12 +1000,12 @@ module { llvm.func @aliasScope(%arg0 : !llvm.ptr, %arg1 : f32) { - // expected-error@below {{expected '@metadata::@group' to resolve to a llvm.alias_scope}} - %0 = llvm.atomicrmw fadd %arg0, %arg1 monotonic { "noalias_scopes" = [@metadata::@group] } : !llvm.ptr, f32 + // expected-error@below {{expected '@metadata::@root' to resolve to a llvm.alias_scope}} + %0 = llvm.atomicrmw fadd %arg0, %arg1 monotonic { "noalias_scopes" = [@metadata::@root] } : !llvm.ptr, f32 llvm.return } llvm.metadata @metadata { - llvm.access_group @group + llvm.tbaa_root @root {id = "Simple C/C++ TBAA"} } } @@ -1013,9 +1013,9 @@ module { llvm.metadata @metadata { - llvm.access_group @group - // expected-error@below {{expected 'group' to reference a domain operation in the same region}} - llvm.alias_scope @scope { domain = @group } + llvm.tbaa_root @root {id = 
"Simple C/C++ TBAA"} + // expected-error@below {{expected 'root' to reference a domain operation in the same region}} + llvm.alias_scope @scope { domain = @root } } } diff --git a/mlir/test/Dialect/LLVMIR/loop-metadata.mlir b/mlir/test/Dialect/LLVMIR/loop-metadata.mlir --- a/mlir/test/Dialect/LLVMIR/loop-metadata.mlir +++ b/mlir/test/Dialect/LLVMIR/loop-metadata.mlir @@ -42,6 +42,14 @@ // CHECK-DAG: #[[UNSWITCH:.*]] = #llvm.loop_unswitch #unswitch = #llvm.loop_unswitch +// CHECK-DAG: #[[DISTINCT:.*]] = #llvm.distinct_sequence +#distinct_sequence = #llvm.distinct_sequence + +// CHECK-DAG: #[[GROUP0:.*]] = #llvm.access_group +#access_group0 = #llvm.access_group +// CHECK-DAG: #[[GROUP1:.*]] = #llvm.access_group +#access_group1 = #llvm.access_group + // CHECK: #[[LOOP_ANNOT:.*]] = #llvm.loop_annotation< // CHECK-DAG: disableNonforced = false // CHECK-DAG: mustProgress = true @@ -53,7 +61,7 @@ // CHECK-DAG: peeled = #[[PEELED]] // CHECK-DAG: unswitch = #[[UNSWITCH]] // CHECK-DAG: isVectorized = false -// CHECK-DAG: parallelAccesses = @metadata::@group1, @metadata::@group2> +// CHECK-DAG: parallelAccesses = #[[GROUP0]], #[[GROUP1]]> #loopMD = #llvm.loop_annotation + parallelAccesses = #access_group0, #access_group1> // CHECK: llvm.func @loop_annotation llvm.func @loop_annotation() { @@ -75,8 +83,3 @@ ^bb1: llvm.return } - -llvm.metadata @metadata { - llvm.access_group @group1 - llvm.access_group @group2 -} diff --git a/mlir/test/Dialect/LLVMIR/tbaa-invalid.mlir b/mlir/test/Dialect/LLVMIR/tbaa-invalid.mlir --- a/mlir/test/Dialect/LLVMIR/tbaa-invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/tbaa-invalid.mlir @@ -27,12 +27,12 @@ module { llvm.func @tbaa(%arg0: !llvm.ptr) { %0 = llvm.mlir.constant(1 : i8) : i8 - // expected-error@below {{expected '@metadata::@group1' to resolve to a llvm.tbaa_tag}} - llvm.store %0, %arg0 {tbaa = [@metadata::@group1]} : i8, !llvm.ptr + // expected-error@below {{expected '@metadata::@domain' to resolve to a llvm.tbaa_tag}} + llvm.store %0, %arg0 
{tbaa = [@metadata::@domain]} : i8, !llvm.ptr llvm.return } llvm.metadata @metadata { - llvm.access_group @group1 + llvm.alias_scope_domain @domain } } diff --git a/mlir/test/Target/LLVMIR/Import/metadata-loop.ll b/mlir/test/Target/LLVMIR/Import/metadata-loop.ll --- a/mlir/test/Target/LLVMIR/Import/metadata-loop.ll +++ b/mlir/test/Target/LLVMIR/Import/metadata-loop.ll @@ -1,20 +1,26 @@ ; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s -; CHECK: llvm.metadata @__llvm_global_metadata { -; CHECK: llvm.access_group @[[$GROUP0:.*]] -; CHECK: llvm.access_group @[[$GROUP1:.*]] -; CHECK: llvm.access_group @[[$GROUP2:.*]] -; CHECK: llvm.access_group @[[$GROUP3:.*]] -; CHECK: } - -; CHECK-LABEL: llvm.func @access_group -define void @access_group(ptr %arg1) { - ; CHECK: access_groups = [@__llvm_global_metadata::@[[$GROUP0]], @__llvm_global_metadata::@[[$GROUP1]]] +; CHECK: #[[DISTINCT0:.*]] = #llvm.distinct_sequence +; CHECK: #[[DISTINCT1:.*]] = #llvm.distinct_sequence + +; CHECK: #[[$GROUP0:.*]] = #llvm.access_group +; CHECK: #[[$GROUP1:.*]] = #llvm.access_group +; CHECK: #[[$GROUP2:.*]] = #llvm.access_group +; CHECK: #[[$GROUP3:.*]] = #llvm.access_group + +; CHECK-LABEL: llvm.func @foo +define void @foo(ptr %arg1) { + ; CHECK: access_groups = [#[[$GROUP0]], #[[$GROUP1]]] %1 = load i32, ptr %arg1, !llvm.access.group !0 - ; CHECK: access_groups = [@__llvm_global_metadata::@[[$GROUP2]], @__llvm_global_metadata::@[[$GROUP0]]] + ; CHECK: access_groups = [#[[$GROUP2]], #[[$GROUP0]]] %2 = load i32, ptr %arg1, !llvm.access.group !1 - ; CHECK: access_groups = [@__llvm_global_metadata::@[[$GROUP3]]] - %3 = load i32, ptr %arg1, !llvm.access.group !2 + ret void +} + +; CHECK-LABEL: llvm.func @bar +define void @bar(ptr %arg1) { + ; CHECK: access_groups = [#[[$GROUP3]]] + %1 = load i32, ptr %arg1, !llvm.access.group !2 ret void } @@ -281,12 +287,11 @@ ; // ----- -; CHECK: #[[$ANNOT_ATTR:.*]] = #llvm.loop_annotation - -; CHECK: llvm.metadata @__llvm_global_metadata { 
-; CHECK: llvm.access_group @[[GROUP0]] +; CHECK: #[[DISTINCT:.*]] = #llvm.distinct_sequence +; CHECK: #[[GROUP0:.*]] = #llvm.access_group +; CHECK: #[[$ANNOT_ATTR:.*]] = #llvm.loop_annotation -; CHECK-LABEL: @parallel_accesses +; CHECK-LABEL: llvm.func @parallel_accesses define void @parallel_accesses(ptr %arg) { entry: %0 = load i32, ptr %arg, !llvm.access.group !0 @@ -302,13 +307,12 @@ ; // ----- -; CHECK: #[[$ANNOT_ATTR:.*]] = #llvm.loop_annotation - -; CHECK: llvm.metadata @__llvm_global_metadata { -; CHECK: llvm.access_group @[[GROUP0]] -; CHECK: llvm.access_group @[[GROUP1]] +; CHECK: #[[DISTINCT:.*]] = #llvm.distinct_sequence +; CHECK: #[[GROUP0:.*]] = #llvm.access_group +; CHECK: #[[GROUP1:.*]] = #llvm.access_group +; CHECK: #[[$ANNOT_ATTR:.*]] = #llvm.loop_annotation -; CHECK-LABEL: @multiple_parallel_accesses +; CHECK-LABEL: llvm.func @multiple_parallel_accesses define void @multiple_parallel_accesses(ptr %arg) { entry: %0 = load i32, ptr %arg, !llvm.access.group !0 diff --git a/mlir/test/Target/LLVMIR/loop-metadata.mlir b/mlir/test/Target/LLVMIR/loop-metadata.mlir --- a/mlir/test/Target/LLVMIR/loop-metadata.mlir +++ b/mlir/test/Target/LLVMIR/loop-metadata.mlir @@ -233,6 +233,10 @@ // ----- +#distinct_sequence = #llvm.distinct_sequence +#access_group0 = #llvm.access_group +#access_group1 = #llvm.access_group + // CHECK-LABEL: @loopOptions llvm.func @loopOptions(%arg1 : i32, %arg2 : i32) { %0 = llvm.mlir.constant(0 : i32) : i32 @@ -245,39 +249,34 @@ licm = , interleave = , unroll = , pipeline = , - parallelAccesses = @metadata::@group1, @metadata::@group2>} + parallelAccesses = #access_group0, #access_group1>} ^bb4: %3 = llvm.add %1, %arg2 : i32 // CHECK: = load i32, ptr %{{.*}} !llvm.access.group ![[ACCESS_GROUPS_NODE:[0-9]+]] - %5 = llvm.load %4 {access_groups = [@metadata::@group1, @metadata::@group2]} : !llvm.ptr -> i32 + %5 = llvm.load %4 {access_groups = [#access_group0, #access_group1]} : !llvm.ptr -> i32 // CHECK: store i32 %{{.*}}, ptr %{{.*}} 
!llvm.access.group ![[ACCESS_GROUPS_NODE]] - llvm.store %5, %4 {access_groups = [@metadata::@group1, @metadata::@group2]} : i32, !llvm.ptr + llvm.store %5, %4 {access_groups = [#access_group0, #access_group1]} : i32, !llvm.ptr // CHECK: = atomicrmw add ptr %{{.*}}, i32 %{{.*}} !llvm.access.group ![[ACCESS_GROUPS_NODE]] - %6 = llvm.atomicrmw add %4, %5 monotonic {access_groups = [@metadata::@group1, @metadata::@group2]} : !llvm.ptr, i32 + %6 = llvm.atomicrmw add %4, %5 monotonic {access_groups = [#access_group0, #access_group1]} : !llvm.ptr, i32 // CHECK: = cmpxchg ptr %{{.*}}, i32 %{{.*}}, i32 %{{.*}} !llvm.access.group ![[ACCESS_GROUPS_NODE]] - %7 = llvm.cmpxchg %4, %5, %6 acq_rel monotonic {access_groups = [@metadata::@group1, @metadata::@group2]} : !llvm.ptr, i32 + %7 = llvm.cmpxchg %4, %5, %6 acq_rel monotonic {access_groups = [#access_group0, #access_group1]} : !llvm.ptr, i32 %8 = llvm.mlir.constant(0 : i1) : i1 %9 = llvm.mlir.constant(42 : i8) : i8 // CHECK: llvm.memcpy{{.*}} !llvm.access.group ![[ACCESS_GROUPS_NODE]] - "llvm.intr.memcpy"(%4, %4, %0, %8) {access_groups = [@metadata::@group1, @metadata::@group2]} : (!llvm.ptr, !llvm.ptr, i32, i1) -> () + "llvm.intr.memcpy"(%4, %4, %0, %8) {access_groups = [#access_group0, #access_group1]} : (!llvm.ptr, !llvm.ptr, i32, i1) -> () // CHECK: llvm.memset{{.*}} !llvm.access.group ![[ACCESS_GROUPS_NODE]] - "llvm.intr.memset"(%4, %9, %0, %8) {access_groups = [@metadata::@group1, @metadata::@group2]} : (!llvm.ptr, i8, i32, i1) -> () + "llvm.intr.memset"(%4, %9, %0, %8) {access_groups = [#access_group0, #access_group1]} : (!llvm.ptr, i8, i32, i1) -> () // CHECK: br label {{.*}} !llvm.loop ![[LOOP_NODE]] llvm.br ^bb3(%3 : i32) {loop_annotation = #llvm.loop_annotation< licm = , interleave = , unroll = , pipeline = , - parallelAccesses = @metadata::@group1, @metadata::@group2>} + parallelAccesses = #access_group0, #access_group1>} ^bb5: llvm.return } -llvm.metadata @metadata { - llvm.access_group @group1 - 
llvm.access_group @group2 -} - // CHECK: ![[LOOP_NODE]] = distinct !{![[LOOP_NODE]], !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}} // CHECK-DAG: ![[PA_NODE:[0-9]+]] = !{!"llvm.loop.parallel_accesses", ![[GROUP_NODE1:[0-9]+]], ![[GROUP_NODE2:[0-9]+]]} // CHECK-DAG: ![[GROUP_NODE1:[0-9]+]] = distinct !{} diff --git a/mlir/test/mlir-tblgen/llvm-intrinsics.td b/mlir/test/mlir-tblgen/llvm-intrinsics.td --- a/mlir/test/mlir-tblgen/llvm-intrinsics.td +++ b/mlir/test/mlir-tblgen/llvm-intrinsics.td @@ -50,7 +50,7 @@ // It does not implement the alias analysis interface. // GROUPS: 0> // It has an access group attribute. -// GROUPS: OptionalAttr:$access_groups +// GROUPS: OptionalAttr:$access_groups //---------------------------------------------------------------------------// diff --git a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp --- a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp @@ -217,7 +217,7 @@ llvm::SmallVector operands(intr.getNumOperands(), "LLVM_Type"); if (requiresAccessGroup) - operands.push_back("OptionalAttr:$access_groups"); + operands.push_back("OptionalAttr:$access_groups"); if (requiresAliasAnalysis) { operands.push_back("OptionalAttr:$alias_scopes"); operands.push_back("OptionalAttr:$noalias_scopes");