diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -32,6 +32,8 @@ static StringRef getDataLayoutAttrName() { return "llvm.data_layout"; } static StringRef getAlignAttrName() { return "llvm.align"; } static StringRef getNoAliasAttrName() { return "llvm.noalias"; } + static StringRef getNoAliasScopesAttrName() { return "noalias_scopes"; } // Emitted as !noalias. + static StringRef getAliasScopesAttrName() { return "alias_scopes"; } // Emitted as !alias.scope. static StringRef getLoopAttrName() { return "llvm.loop"; } static StringRef getParallelAccessAttrName() { return "parallel_access"; } static StringRef getLoopOptionsAttrName() { return "options"; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -286,6 +286,10 @@ code setAccessGroupsMetadataCode = [{ moduleTranslation.setAccessGroupsMetadata(op, inst); }]; + + code setAliasScopeMetadataCode = [{ + moduleTranslation.setAliasScopeMetadata(op, inst); // Sets !alias.scope / !noalias. + }]; } // Memory-related operations.
@@ -329,13 +333,19 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes { let arguments = (ins LLVM_PointerTo:$addr, OptionalAttr:$access_groups, + OptionalAttr:$alias_scopes, + OptionalAttr:$noalias_scopes, OptionalAttr:$alignment, UnitAttr:$volatile_, UnitAttr:$nontemporal); let results = (outs LLVM_LoadableType:$res); string llvmBuilder = [{ auto *inst = builder.CreateLoad( $addr->getType()->getPointerElementType(), $addr, $volatile_); - }] # setAlignmentCode # setNonTemporalMetadataCode # setAccessGroupsMetadataCode # [{ + }] # setAlignmentCode + # setNonTemporalMetadataCode + # setAccessGroupsMetadataCode + # setAliasScopeMetadataCode + # [{ $res = inst; }]; let builders = [ @@ -357,11 +367,16 @@ let arguments = (ins LLVM_LoadableType:$value, LLVM_PointerTo:$addr, OptionalAttr:$access_groups, + OptionalAttr:$alias_scopes, + OptionalAttr:$noalias_scopes, OptionalAttr:$alignment, UnitAttr:$volatile_, UnitAttr:$nontemporal); string llvmBuilder = [{ auto *inst = builder.CreateStore($value, $addr, $volatile_); - }] # setAlignmentCode # setNonTemporalMetadataCode # setAccessGroupsMetadataCode; + }] # setAlignmentCode + # setNonTemporalMetadataCode + # setAccessGroupsMetadataCode + # setAliasScopeMetadataCode; let builders = [ OpBuilder<(ins "Value":$value, "Value":$addr, CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile, @@ -876,8 +891,7 @@ ); let summary = "LLVM dialect metadata."; let description = [{ - llvm.metadata op defines one or more metadata nodes. Currently the - llvm.access_group metadata op is supported. + llvm.metadata op defines one or more metadata nodes.
Example: llvm.metadata @metadata { @@ -890,6 +904,66 @@ let assemblyFormat = "$sym_name attr-dict-with-keyword $body"; } +def LLVM_AliasScopeDomainMetadataOp : LLVM_Op<"alias_scope_domain", [ + HasParent<"MetadataOp">, Symbol +]> { + let arguments = (ins + SymbolNameAttr:$sym_name, + OptionalAttr:$description + ); + let summary = "LLVM dialect alias.scope domain metadata."; + let description = [{ + Defines a domain that may be associated with an alias scope. + + See the following link for more details: + https://llvm.org/docs/LangRef.html#noalias-and-alias-scope-metadata + }]; + let assemblyFormat = "$sym_name attr-dict"; +} + +def LLVM_AliasScopeMetadataOp : LLVM_Op<"alias_scope", [ + HasParent<"MetadataOp">, Symbol +]> { + let arguments = (ins + SymbolNameAttr:$sym_name, + FlatSymbolRefAttr:$domain, + OptionalAttr:$description + ); + let summary = "LLVM dialect alias.scope metadata."; + let description = [{ + Defines an alias scope that can be attached to a memory-accessing operation. + Such scopes can be used in combination with `noalias` metadata to indicate + that sets of memory-affecting operations in one scope do not alias with + memory-affecting operations in another scope.
+ + Example: + module { + llvm.func @foo(%ptr1 : !llvm.ptr) { + %c0 = llvm.mlir.constant(0 : i32) : i32 + %c4 = llvm.mlir.constant(4 : i32) : i32 + %1 = llvm.ptrtoint %ptr1 : !llvm.ptr to i32 + %2 = llvm.add %1, %c4 : i32 + %ptr2 = llvm.inttoptr %2 : i32 to !llvm.ptr + llvm.store %c0, %ptr1 { alias_scopes = [@metadata::@scope1], noalias_scopes = [@metadata::@scope2] } : !llvm.ptr + llvm.store %c4, %ptr2 { alias_scopes = [@metadata::@scope2], noalias_scopes = [@metadata::@scope1] } : !llvm.ptr + llvm.return + } + + llvm.metadata @metadata { + llvm.alias_scope_domain @unused_domain + llvm.alias_scope_domain @domain { description = "Optional domain description"} + llvm.alias_scope @scope1 { domain = @domain } + llvm.alias_scope @scope2 { domain = @domain, description = "Optional scope description" } + llvm.return + } + } + + See the following link for more details: + https://llvm.org/docs/LangRef.html#noalias-and-alias-scope-metadata + }]; + let assemblyFormat = "$sym_name attr-dict"; +} + def LLVM_AccessGroupMetadataOp : LLVM_Op<"access_group", [ HasParent<"MetadataOp">, Symbol ]> { diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -115,6 +115,11 @@ llvm::MDNode *getAccessGroup(Operation &opInst, SymbolRefAttr accessGroupRef) const; + /// Returns the LLVM metadata corresponding to a reference to an mlir LLVM + /// dialect alias scope operation. + llvm::MDNode *getAliasScope(Operation &opInst, + SymbolRefAttr aliasScopeRef) const; + /// Returns the LLVM metadata corresponding to a llvm loop's codegen /// options attribute. llvm::MDNode *lookupLoopOptionsMetadata(Attribute options) const { @@ -131,6 +136,9 @@ // Sets LLVM metadata for memory operations that are in a parallel loop.
void setAccessGroupsMetadata(Operation *op, llvm::Instruction *inst); + // Sets LLVM metadata for memory operations that have alias scope information. + void setAliasScopeMetadata(Operation *op, llvm::Instruction *inst); + /// Converts the type from MLIR LLVM dialect to LLVM. llvm::Type *convertType(Type type); @@ -268,6 +276,10 @@ /// metadata nodes. LogicalResult createAccessGroupMetadata(); + /// Process alias.scope LLVM Metadata operations and create LLVM + /// metadata nodes for them and their domains. + LogicalResult createAliasScopeMetadata(); + /// Translates dialect attributes attached to the given operation. LogicalResult convertDialectAttributes(Operation *op); @@ -300,7 +312,7 @@ /// values after all operations are converted. DenseMap branchMapping; - /// Mapping from an access group metadata optation to its LLVM metadata. + /// Mapping from an access group metadata operation to its LLVM metadata. /// This map is populated on module entry and is used to annotate loops (as /// identified via their branches) and contained memory accesses. DenseMap accessGroupMetadataMapping; @@ -310,6 +322,10 @@ /// attribute. DenseMap loopOptionsMetadataMapping; + /// Mapping from an alias scope metadata operation to its LLVM metadata. + /// This map is populated on module entry. + DenseMap aliasScopeMetadataMapping; + /// Stack of user-specified state elements, useful when translating operations /// with regions. SmallVector> stack; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -335,32 +335,76 @@ // Builder, printer and parser for for LLVM::LoadOp.
//===----------------------------------------------------------------------===// -static LogicalResult verifyAccessGroups(Operation *op) { - if (Attribute attribute = - op->getAttr(LLVMDialect::getAccessGroupsAttrName())) { +static LogicalResult verifySymbolAttribute( + Operation *op, StringRef attributeName, + std::function verifySymbolType) { + if (Attribute attribute = op->getAttr(attributeName)) { // The attribute is already verified to be a symbol ref array attribute via // a constraint in the operation definition. - for (SymbolRefAttr accessGroupRef : + for (SymbolRefAttr symbolRef : attribute.cast().getAsRange()) { - StringRef metadataName = accessGroupRef.getRootReference(); + StringRef metadataName = symbolRef.getRootReference(); + StringRef symbolName = symbolRef.getLeafReference(); + // We want @metadata::@symbol, not just @symbol (@foo::@foo is valid). + if (symbolRef.getNestedReferences().empty()) { + return op->emitOpError() << "expected '" << symbolRef + << "' to specify a fully qualified reference"; + } auto metadataOp = SymbolTable::lookupNearestSymbolFrom( op->getParentOp(), metadataName); if (!metadataOp) - return op->emitOpError() << "expected '" << accessGroupRef - << "' to reference a metadata op"; - StringRef accessGroupName = accessGroupRef.getLeafReference(); - Operation *accessGroupOp = - SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); - if (!accessGroupOp) - return op->emitOpError() << "expected '" << accessGroupRef - << "' to reference an access_group op"; + return op->emitOpError() + << "expected '" << symbolRef << "' to reference a metadata op"; + Operation *symbolOp = + SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName); + if (!symbolOp) + return op->emitOpError() + << "expected '" << symbolRef << "' to be a valid reference"; + if (failed(verifySymbolType(symbolOp, symbolRef))) { + return failure(); + } } } return success(); } +// Verifies that metadata ops are wired up properly.
+template +static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) { + auto verifySymbolType = [op](Operation *symbolOp, + SymbolRefAttr symbolRef) -> LogicalResult { + if (!isa(symbolOp)) { + return op->emitOpError() + << "expected '" << symbolRef << "' to resolve to a " + << OpTy::getOperationName(); + } + return success(); + }; + + return verifySymbolAttribute(op, attributeName, verifySymbolType); +} + +static LogicalResult verifyMemoryOpMetadata(Operation *op) { + // access_groups + if (failed(verifyOpMetadata( + op, LLVMDialect::getAccessGroupsAttrName()))) + return failure(); + + // alias_scopes + if (failed(verifyOpMetadata( + op, LLVMDialect::getAliasScopesAttrName()))) + return failure(); + + // noalias_scopes + if (failed(verifyOpMetadata( + op, LLVMDialect::getNoAliasScopesAttrName()))) + return failure(); + + return success(); +} + static LogicalResult verify(LoadOp op) { - return verifyAccessGroups(op.getOperation()); + return verifyMemoryOpMetadata(op.getOperation()); } void LoadOp::build(OpBuilder &builder, OperationState &result, Type t, @@ -422,7 +466,7 @@ //===----------------------------------------------------------------------===// static LogicalResult verify(StoreOp op) { - return verifyAccessGroups(op.getOperation()); + return verifyMemoryOpMetadata(op.getOperation()); } void StoreOp::build(OpBuilder &builder, OperationState &result, Value value, diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -774,6 +774,78 @@ } } +LogicalResult ModuleTranslation::createAliasScopeMetadata() { + mlirModule->walk([&](LLVM::MetadataOp metadatas) { + // Create the domains first, so they can be referenced below in the scopes.
+ DenseMap aliasScopeDomainMetadataMapping; + metadatas.walk([&](LLVM::AliasScopeDomainMetadataOp op) { + llvm::LLVMContext &ctx = llvmModule->getContext(); + llvm::SmallVector operands; + operands.push_back({}); // Placeholder for self-reference + if (Optional description = op.description()) + operands.push_back(llvm::MDString::get(ctx, description.getValue())); + llvm::MDNode *domain = llvm::MDNode::get(ctx, operands); + domain->replaceOperandWith(0, domain); // Self-reference for uniqueness + aliasScopeDomainMetadataMapping.insert({op, domain}); + }); + + // Now create the scopes, referencing the domains created above. + metadatas.walk([&](LLVM::AliasScopeMetadataOp op) { + llvm::LLVMContext &ctx = llvmModule->getContext(); + // HasParent<"MetadataOp"> guarantees the parent is an llvm.metadata op. + auto metadataOp = cast(op->getParentOp()); + Operation *domainOp = + SymbolTable::lookupNearestSymbolFrom(metadataOp, op.domainAttr()); + llvm::MDNode *domain = aliasScopeDomainMetadataMapping.lookup(domainOp); + assert(domain && "Scope's domain should already be valid"); + llvm::SmallVector operands; + operands.push_back({}); // Placeholder for self-reference + operands.push_back(domain); + if (Optional description = op.description()) + operands.push_back(llvm::MDString::get(ctx, description.getValue())); + llvm::MDNode *scope = llvm::MDNode::get(ctx, operands); + scope->replaceOperandWith(0, scope); // Self-reference for uniqueness + aliasScopeMetadataMapping.insert({op, scope}); + }); + }); + return success(); +} + +llvm::MDNode * +ModuleTranslation::getAliasScope(Operation &opInst, + SymbolRefAttr aliasScopeRef) const { + StringRef metadataName = aliasScopeRef.getRootReference(); + StringRef scopeName = aliasScopeRef.getLeafReference(); + auto metadataOp = SymbolTable::lookupNearestSymbolFrom( + opInst.getParentOp(), metadataName); + Operation *aliasScopeOp = + SymbolTable::lookupNearestSymbolFrom(metadataOp, scopeName); + return aliasScopeMetadataMapping.lookup(aliasScopeOp); +} + +void
ModuleTranslation::setAliasScopeMetadata(Operation *op, + llvm::Instruction *inst) { + auto populateScopeMetadata = [this, op, inst](StringRef attrName, + StringRef llvmMetadataName) { + auto scopes = op->getAttrOfType(attrName); + if (!scopes || scopes.empty()) + return; + llvm::Module *module = inst->getModule(); + SmallVector scopeMDs; + for (SymbolRefAttr scopeRef : scopes.getAsRange()) + scopeMDs.push_back(getAliasScope(*op, scopeRef)); + // The LLVM LangRef requires !alias.scope and !noalias to be lists of + // scopes. Always wrap the scopes in a list node, even when there is only + // one: attaching a bare scope node would be read as a list whose entries + // are the scope's own operands (itself, its domain, its description). + llvm::MDNode *unionMD = llvm::MDNode::get(module->getContext(), scopeMDs); + inst->setMetadata(module->getMDKindID(llvmMetadataName), unionMD); + }; + + populateScopeMetadata(LLVMDialect::getAliasScopesAttrName(), "alias.scope"); + populateScopeMetadata(LLVMDialect::getNoAliasScopesAttrName(), "noalias"); +} + llvm::Type *ModuleTranslation::convertType(Type type) { return typeTranslator.translateType(type); } @@ -842,6 +914,8 @@ return nullptr; if (failed(translator.createAccessGroupMetadata())) return nullptr; + if (failed(translator.createAliasScopeMetadata())) + return nullptr; if (failed(translator.convertFunctions())) return nullptr; if (llvm::verifyModule(*translator.llvmModule, &llvm::errs())) diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -909,7 +909,7 @@ module { llvm.func @accessGroups(%arg0 : !llvm.ptr) { - // expected-error@below {{expected '@func1' to reference a metadata op}} + // expected-error@below {{expected '@func1' to specify a fully qualified reference}} %0 = llvm.load %arg0 { "access_groups" = [@func1] } : !llvm.ptr llvm.return } @@ -922,11 +922,87 @@ module { llvm.func @accessGroups(%arg0 : !llvm.ptr) { - // expected-error@below {{expected '@metadata' to reference an access_group op}} - %0 = llvm.load %arg0 { "access_groups" = [@metadata] } :
!llvm.ptr + // expected-error@below {{expected '@accessGroups::@group1' to reference a metadata op}} + %0 = llvm.load %arg0 { "access_groups" = [@accessGroups::@group1] } : !llvm.ptr + llvm.return + } + llvm.metadata @metadata { + llvm.return + } +} + +// ----- + +module { + llvm.func @accessGroups(%arg0 : !llvm.ptr) { + // expected-error@below {{expected '@metadata::@group1' to be a valid reference}} + %0 = llvm.load %arg0 { "access_groups" = [@metadata::@group1] } : !llvm.ptr + llvm.return + } + llvm.metadata @metadata { + llvm.return + } +} + +// ----- + +module { + llvm.func @accessGroups(%arg0 : !llvm.ptr) { + // expected-error@below {{expected '@metadata::@scope' to resolve to a llvm.access_group}} + %0 = llvm.load %arg0 { "access_groups" = [@metadata::@scope] } : !llvm.ptr + llvm.return + } + llvm.metadata @metadata { + llvm.alias_scope_domain @domain + llvm.alias_scope @scope { domain = @domain } + llvm.return + } +} + +// ----- + +module { + llvm.func @aliasScope(%arg0 : !llvm.ptr) { + // expected-error@below {{attribute 'alias_scopes' failed to satisfy constraint: symbol ref array attribute}} + %0 = llvm.load %arg0 { "alias_scopes" = "test" } : !llvm.ptr + llvm.return + } +} + +// ----- + +module { + llvm.func @aliasScope(%arg0 : !llvm.ptr) { + // expected-error@below {{attribute 'noalias_scopes' failed to satisfy constraint: symbol ref array attribute}} + %0 = llvm.load %arg0 { "noalias_scopes" = "test" } : !llvm.ptr + llvm.return + } +} + +// ----- + +module { + llvm.func @aliasScope(%arg0 : !llvm.ptr) { + // expected-error@below {{expected '@metadata::@group' to resolve to a llvm.alias_scope}} + %0 = llvm.load %arg0 { "alias_scopes" = [@metadata::@group] } : !llvm.ptr + llvm.return + } + llvm.metadata @metadata { + llvm.access_group @group + llvm.return + } +} + +// ----- + +module { + llvm.func @aliasScope(%arg0 : !llvm.ptr) { + // expected-error@below {{expected '@metadata::@group' to resolve to a llvm.alias_scope}} + %0 = llvm.load %arg0 {
"noalias_scopes" = [@metadata::@group] } : !llvm.ptr llvm.return } llvm.metadata @metadata { + llvm.access_group @group llvm.return } } diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1594,3 +1594,42 @@ // CHECK: ![[PIPELINE_DISABLE_NODE]] = !{!"llvm.loop.pipeline.disable", i1 true} // CHECK: ![[II_NODE]] = !{!"llvm.loop.pipeline.initiationinterval", i32 2} // CHECK: ![[ACCESS_GROUPS_NODE]] = !{![[GROUP_NODE1]], ![[GROUP_NODE2]]} + +// ----- + +module { + llvm.func @aliasScope(%arg1 : !llvm.ptr, %arg2 : !llvm.ptr, %arg3 : !llvm.ptr) { + %0 = llvm.mlir.constant(0 : i32) : i32 + llvm.store %0, %arg1 { alias_scopes = [@metadata::@scope1], noalias_scopes = [@metadata::@scope2, @metadata::@scope3] } : !llvm.ptr + llvm.store %0, %arg2 { alias_scopes = [@metadata::@scope2], noalias_scopes = [@metadata::@scope1, @metadata::@scope3] } : !llvm.ptr + %1 = llvm.load %arg3 { alias_scopes = [@metadata::@scope3], noalias_scopes = [@metadata::@scope1, @metadata::@scope2] } : !llvm.ptr + llvm.return + } + + llvm.metadata @metadata { + llvm.alias_scope_domain @domain { description = "The domain"} + llvm.alias_scope @scope1 { domain = @domain, description = "The first scope" } + llvm.alias_scope @scope2 { domain = @domain } + llvm.alias_scope @scope3 { domain = @domain } + llvm.return + } +} + +// Function +// CHECK-LABEL: aliasScope +// CHECK: store {{.*}}, !alias.scope ![[SCOPES1:[0-9]+]], !noalias ![[SCOPES23:[0-9]+]] +// CHECK: store {{.*}}, !alias.scope ![[SCOPES2:[0-9]+]], !noalias ![[SCOPES13:[0-9]+]] +// CHECK: load {{.*}}, !alias.scope ![[SCOPES3:[0-9]+]], !noalias ![[SCOPES12:[0-9]+]] + +// Metadata: alias.scope and noalias always reference lists of scopes, even for a single scope. +// CHECK-DAG: ![[DOMAIN:[0-9]+]] = distinct !{![[DOMAIN]], !"The domain"} +// CHECK-DAG: ![[SCOPES1]] = !{![[SCOPE1:[0-9]+]]} +// CHECK-DAG: ![[SCOPES2]] = !{![[SCOPE2:[0-9]+]]} +// CHECK-DAG: ![[SCOPES3]] = !{![[SCOPE3:[0-9]+]]}
+// CHECK-DAG: ![[SCOPE1]] = distinct !{![[SCOPE1]], ![[DOMAIN]], !"The first scope"} +// CHECK-DAG: ![[SCOPE2]] = distinct !{![[SCOPE2]], ![[DOMAIN]]} +// CHECK-DAG: ![[SCOPE3]] = distinct !{![[SCOPE3]], ![[DOMAIN]]} +// CHECK-DAG: ![[SCOPES12]] = !{![[SCOPE1]], ![[SCOPE2]]} +// CHECK-DAG: ![[SCOPES13]] = !{![[SCOPE1]], ![[SCOPE3]]} +// CHECK-DAG: ![[SCOPES23]] = !{![[SCOPE2]], ![[SCOPE3]]} +