diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -35,6 +35,7 @@ static StringRef getLoopAttrName() { return "llvm.loop"; } static StringRef getParallelAccessAttrName() { return "parallel_access"; } static StringRef getLoopOptionsAttrName() { return "options"; } + static StringRef getAccessGroupsAttrName() { return "access_groups"; } /// Verifies if the given string is a well-formed data layout descriptor. /// Uses `reportError` to report errors. @@ -247,7 +248,8 @@ // `llvm::Intrinsic` enum; one usually wants these to be related. class LLVM_IntrOpBase overloadedResults, list overloadedOperands, - list traits, int numResults> + list traits, int numResults, + bit requiresAccessGroup = 0> : LLVM_OpBase, Results { string resultPattern = !if(!gt(numResults, 1), @@ -264,19 +266,21 @@ overloadedOperands>.lst), ", ") # [{ }); auto operands = moduleTranslation.lookupValues(opInst.getOperands()); - }] # !if(!gt(numResults, 0), "$res = ", "") - # [{builder.CreateCall(fn, operands); - }]; + }] # [{auto *inst = builder.CreateCall(fn, operands); + }] # !if(!gt(requiresAccessGroup, 0), + "moduleTranslation.setAccessGroupsMetadata(op, inst);", + "(void) inst;") + # !if(!gt(numResults, 0), "$res = inst;", ""); } // Base class for LLVM intrinsic operations, should not be used directly. Places // the intrinsic into the LLVM dialect and prefixes its name with "intr.". class LLVM_IntrOp overloadedResults, list overloadedOperands, list traits, - int numResults> + int numResults, bit requiresAccessGroup = 0> : LLVM_IntrOpBase; + numResults, requiresAccessGroup>; // Base class for LLVM intrinsic operations returning no results. Places the // intrinsic into the LLVM dialect and prefixes its name with "intr.".
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -287,6 +287,10 @@ inst->setMetadata(module->getMDKindID("nontemporal"), metadata); } }]; + + code setAccessGroupsMetadataCode = [{ + moduleTranslation.setAccessGroupsMetadata(op, inst); + }]; } // Memory-related operations. @@ -326,12 +330,13 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes { let arguments = (ins LLVM_PointerTo:$addr, + OptionalAttr:$access_groups, OptionalAttr:$alignment, UnitAttr:$volatile_, UnitAttr:$nontemporal); let results = (outs LLVM_Type:$res); string llvmBuilder = [{ auto *inst = builder.CreateLoad($addr, $volatile_); - }] # setAlignmentCode # setNonTemporalMetadataCode # [{ + }] # setAlignmentCode # setNonTemporalMetadataCode # setAccessGroupsMetadataCode # [{ $res = inst; }]; let builders = [ @@ -346,16 +351,18 @@ CArg<"bool", "false">:$isNonTemporal)>]; let parser = [{ return parseLoadOp(parser, result); }]; let printer = [{ printLoadOp(p, *this); }]; + let verifier = [{ return ::verify(*this); }]; } def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes { let arguments = (ins LLVM_LoadableType:$value, LLVM_PointerTo:$addr, + OptionalAttr:$access_groups, OptionalAttr:$alignment, UnitAttr:$volatile_, UnitAttr:$nontemporal); string llvmBuilder = [{ auto *inst = builder.CreateStore($value, $addr, $volatile_); - }] # setAlignmentCode # setNonTemporalMetadataCode; + }] # setAlignmentCode # setNonTemporalMetadataCode # setAccessGroupsMetadataCode; let builders = [ OpBuilder<(ins "Value":$value, "Value":$addr, CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile, @@ -363,6 +370,7 @@ ]; let parser = [{ return parseStoreOp(parser, result); }]; let printer = [{ printStoreOp(p, *this); }]; + let verifier = [{ return ::verify(*this); }]; } // Casts.
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -128,6 +128,9 @@ "attempting to map loop options that was already mapped"); } + /// Attaches `op`'s access groups to `inst` as llvm.access.group metadata. + void setAccessGroupsMetadata(Operation *op, llvm::Instruction *inst); + /// Converts the type from MLIR LLVM dialect to LLVM. llvm::Type *convertType(Type type); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -404,6 +404,34 @@ // Builder, printer and parser for for LLVM::LoadOp. //===----------------------------------------------------------------------===// +static LogicalResult verifyAccessGroups(Operation *op) { + if (Attribute attribute = + op->getAttr(LLVMDialect::getAccessGroupsAttrName())) { + // The attribute is already verified to be a symbol ref array attribute via + // a constraint in the operation definition.
+ for (SymbolRefAttr accessGroupRef : + attribute.cast().getAsRange()) { + StringRef metadataName = accessGroupRef.getRootReference(); + auto metadataOp = SymbolTable::lookupNearestSymbolFrom( + op->getParentOp(), metadataName); + if (!metadataOp) + return op->emitOpError() << "expected '" << accessGroupRef + << "' to reference a metadata op"; + StringRef accessGroupName = accessGroupRef.getLeafReference(); + Operation *accessGroupOp = + SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); + if (!accessGroupOp || !isa<AccessGroupMetadataOp>(accessGroupOp)) + return op->emitOpError() << "expected '" << accessGroupRef + << "' to reference an access_group op"; + } + } + return success(); +} + +static LogicalResult verify(LoadOp op) { + return verifyAccessGroups(op.getOperation()); +} + void LoadOp::build(OpBuilder &builder, OperationState &result, Type t, Value addr, unsigned alignment, bool isVolatile, bool isNonTemporal) { @@ -462,6 +490,10 @@ // Builder, printer and parser for LLVM::StoreOp. //===----------------------------------------------------------------------===// +static LogicalResult verify(StoreOp op) { + return verifyAccessGroups(op.getOperation()); +} + void StoreOp::build(OpBuilder &builder, OperationState &result, Value value, Value addr, unsigned alignment, bool isVolatile, bool isNonTemporal) { diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -656,6 +656,27 @@ return success(); } +void ModuleTranslation::setAccessGroupsMetadata(Operation *op, + llvm::Instruction *inst) { + auto accessGroups = + op->getAttrOfType(LLVMDialect::getAccessGroupsAttrName()); + if (accessGroups && !accessGroups.empty()) { + llvm::Module *module = inst->getModule(); + SmallVector metadatas; + for (SymbolRefAttr accessGroupRef : + accessGroups.getAsRange()) + metadatas.push_back(getAccessGroup(*op, accessGroupRef)); + + llvm::MDNode *unionMD =
nullptr; + if (metadatas.size() == 1) + unionMD = llvm::cast(metadatas.front()); + else if (metadatas.size() >= 2) + unionMD = llvm::MDNode::get(module->getContext(), metadatas); + + inst->setMetadata(module->getMDKindID("llvm.access.group"), unionMD); + } +} + llvm::Type *ModuleTranslation::convertType(Type type) { return typeTranslator.translateType(type); } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -796,3 +796,39 @@ llvm.return } } + +// ----- + +module { + llvm.func @accessGroups(%arg0 : !llvm.ptr) { + // expected-error@below {{attribute 'access_groups' failed to satisfy constraint: symbol ref array attribute}} + %0 = llvm.load %arg0 { "access_groups" = "test" } : !llvm.ptr + llvm.return + } +} + +// ----- + +module { + llvm.func @accessGroups(%arg0 : !llvm.ptr) { + // expected-error@below {{expected '@func1' to reference a metadata op}} + %0 = llvm.load %arg0 { "access_groups" = [@func1] } : !llvm.ptr + llvm.return + } + llvm.func @func1() { + llvm.return + } +} + +// ----- + +module { + llvm.func @accessGroups(%arg0 : !llvm.ptr) { + // expected-error@below {{expected '@metadata' to reference an access_group op}} + %0 = llvm.load %arg0 { "access_groups" = [@metadata] } : !llvm.ptr + llvm.return + } + llvm.metadata @metadata { + llvm.return + } +} diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir --- a/mlir/test/Target/llvmir.mlir +++ b/mlir/test/Target/llvmir.mlir @@ -1483,6 +1483,7 @@ llvm.cond_br %2, ^bb4, ^bb5 {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = [#llvm.loopopt, #llvm.loopopt, #llvm.loopopt]}} ^bb4: %3 = llvm.add %1, %arg2 : i32 + // CHECK: = load i32, i32* %{{.*}} !llvm.access.group ![[ACCESS_GROUPS_NODE:[0-9]+]] %5 = llvm.load %4 { access_groups = [@metadata::@group1, @metadata::@group2] } : !llvm.ptr // CHECK: br label {{.*}} !llvm.loop
![[LOOP_NODE]] llvm.br ^bb3(%3 : i32) {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = [#llvm.loopopt, #llvm.loopopt, #llvm.loopopt]}} @@ -1504,3 +1505,4 @@ // CHECK: ![[UNROLL_DISABLE_NODE]] = !{!"llvm.loop.unroll.disable", i1 true} // CHECK: ![[LICM_DISABLE_NODE]] = !{!"llvm.licm.disable", i1 true} // CHECK: ![[INTERLEAVE_NODE]] = !{!"llvm.loop.interleave.count", i32 1} +// CHECK: ![[ACCESS_GROUPS_NODE]] = !{![[GROUP_NODE1]], ![[GROUP_NODE2]]} diff --git a/mlir/test/mlir-tblgen/llvm-intrinsics.td b/mlir/test/mlir-tblgen/llvm-intrinsics.td --- a/mlir/test/mlir-tblgen/llvm-intrinsics.td +++ b/mlir/test/mlir-tblgen/llvm-intrinsics.td @@ -23,11 +23,33 @@ // It has no side effects. // CHECK: [NoSideEffect] // It has a result. -// CHECK: 1> +// CHECK: 1, +// It does not require an access group. +// CHECK: 0> // CHECK: Arguments<(ins LLVM_Type, LLVM_Type //---------------------------------------------------------------------------// +// This checks that we can define an op that takes an optional access group attribute. +// +// RUN: cat %S/../../../llvm/include/llvm/IR/Intrinsics.td \ +// RUN: | grep -v "llvm/IR/Intrinsics" \ +// RUN: | mlir-tblgen -gen-llvmir-intrinsics -I %S/../../../llvm/include/ --llvmir-intrinsics-filter=ptrmask --llvmir-intrinsics-access-group-regexp=ptrmask \ +// RUN: | FileCheck --check-prefix=GROUPS %s + +// GROUPS-LABEL: def LLVM_ptrmask +// GROUPS: LLVM_IntrOp<"ptrmask +// It has no side effects. +// GROUPS: [NoSideEffect] +// It has a result. +// GROUPS: 1, +// It requires generation of an access group LLVM metadata. +// GROUPS: 1> +// It has an access group attribute. +// GROUPS: OptionalAttr:$access_groups + +//---------------------------------------------------------------------------// + // This checks that the ODS we produce can be consumed by MLIR tablegen. We only // make sure the entire process does not fail and produces some C++. The shape // of this C++ code is tested by ODS tests.
diff --git a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp --- a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp @@ -17,6 +17,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/MachineValueType.h" #include "llvm/Support/PrettyStackTrace.h" +#include "llvm/Support/Regex.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Main.h" @@ -37,6 +38,12 @@ "are planning to emit"), llvm::cl::init("LLVM_IntrOp"), llvm::cl::cat(IntrinsicGenCat)); +static llvm::cl::opt accessGroupRegexp( + "llvmir-intrinsics-access-group-regexp", + llvm::cl::desc("Mark intrinsics that match the specified " + "regexp as taking an access group metadata"), + llvm::cl::cat(IntrinsicGenCat)); + // Used to represent the indices of overloadable operands/results. using IndicesTy = llvm::SmallBitVector; @@ -185,6 +192,10 @@ static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) { LLVMIntrinsic intr(record); + llvm::Regex accessGroupMatcher(accessGroupRegexp); + bool requiresAccessGroup = + !accessGroupRegexp.empty() && accessGroupMatcher.match(record.getName()); + // Prepare strings for traits, if any. llvm::SmallVector traits; if (intr.isCommutative()) @@ -195,6 +206,8 @@ // Prepare strings for operands. llvm::SmallVector operands(intr.getNumOperands(), "LLVM_Type"); + if (requiresAccessGroup) + operands.push_back("OptionalAttr:$access_groups"); // Emit the definition. os << "def LLVM_" << intr.getProperRecordName() << " : " << opBaseClass @@ -204,7 +217,8 @@ printBracketedRange(intr.getOverloadableOperandsIdxs().set_bits(), os); os << ", "; printBracketedRange(traits, os); - os << ", " << intr.getNumResults() << ">, Arguments<(ins" + os << ", " << intr.getNumResults() << ", " + << (requiresAccessGroup ? "1" : "0") << ">, Arguments<(ins" << (operands.empty() ? "" : " "); llvm::interleaveComma(operands, os); os << ")>;\n\n";