diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -34,6 +34,29 @@
     /// Get the name of the attribute used to annotate external kernel
     /// functions.
     static StringRef getKernelFuncAttrName() { return "nvvm.kernel"; }
+    /// Get the name of the attribute used to annotate max threads required
+    /// per CTA for kernel functions.
+    static StringRef getMaxntidAttrName() { return "nvvm.maxntid"; }
+    /// Get the metadata names emitted for each dimension of maxntid.
+    static StringRef getMaxntidXName() { return "maxntidx"; }
+    static StringRef getMaxntidYName() { return "maxntidy"; }
+    static StringRef getMaxntidZName() { return "maxntidz"; }
+
+    /// Get the name of the attribute used to annotate exact threads required
+    /// per CTA for kernel functions.
+    static StringRef getReqntidAttrName() { return "nvvm.reqntid"; }
+    /// Get the metadata names emitted for each dimension of reqntid.
+    static StringRef getReqntidXName() { return "reqntidx"; }
+    static StringRef getReqntidYName() { return "reqntidy"; }
+    static StringRef getReqntidZName() { return "reqntidz"; }
+
+    /// Get the name of the attribute used to annotate min CTA required
+    /// per SM for kernel functions.
+    static StringRef getMinctasmAttrName() { return "nvvm.minctasm"; }
+
+    /// Get the name of the attribute used to annotate max number of
+    /// registers that can be allocated per thread.
+    static StringRef getMaxnregAttrName() { return "nvvm.maxnreg"; }
   }];
 
   let useDefaultAttributePrinterParser = 1;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -16,7 +16,9 @@
 
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/MLIRContext.h"
@@ -27,6 +29,7 @@
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Type.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/SourceMgr.h"
 
 using namespace mlir;
@@ -672,13 +675,37 @@
 
 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute attr) {
+  StringAttr attrName = attr.getName();
   // Kernel function attribute should be attached to functions.
-  if (attr.getName() == NVVMDialect::getKernelFuncAttrName()) {
+  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
     if (!isa<LLVM::LLVMFuncOp>(op)) {
       return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                              << "' attribute attached to unexpected op";
     }
   }
+  // If maxntid or reqntid exists, it must be an array of 1 to 3 integers.
+  if (attrName == NVVMDialect::getMaxntidAttrName() ||
+      attrName == NVVMDialect::getReqntidAttrName()) {
+    auto values = attr.getValue().dyn_cast<ArrayAttr>();
+    if (!values || values.empty() || values.size() > 3)
+      return op->emitError()
+             << "'" << attrName
+             << "' attribute must be integer array with maximum 3 index";
+    for (Attribute val : values) {
+      if (!val.isa<IntegerAttr>())
+        return op->emitError()
+               << "'" << attrName
+               << "' attribute must be integer array with maximum 3 index";
+    }
+  }
+  // If minctasm or maxnreg exists, it must be an integer constant.
+  if (attrName == NVVMDialect::getMinctasmAttrName() ||
+      attrName == NVVMDialect::getMaxnregAttrName()) {
+    if (!attr.getValue().isa<IntegerAttr>())
+      return op->emitError()
+             << "'" << attrName << "' attribute must be integer constant";
+  }
+
   return success();
 }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -13,7 +13,9 @@
 
 #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 
 #include "llvm/IR/IRBuilder.h"
@@ -116,21 +118,64 @@
   LogicalResult
   amendOperation(Operation *op, NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const final {
-    if (attribute.getName() == NVVM::NVVMDialect::getKernelFuncAttrName()) {
-      auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
-      if (!func)
-        return failure();
+    auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
+    if (!func)
+      return failure();
+    llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
+    llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName());
 
-      llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
-      llvm::Function *llvmFunc =
-          moduleTranslation.lookupFunction(func.getName());
+    // Append {llvmFunc, !"<name>", i32 <value>} to the nvvm.annotations node.
+    auto generateMetadata = [&](int value, StringRef name) {
       llvm::Metadata *llvmMetadata[] = {
+          llvm::ValueAsMetadata::get(llvmFunc),
+          llvm::MDString::get(llvmContext, name),
+          llvm::ValueAsMetadata::get(llvm::ConstantInt::get(
+              llvm::Type::getInt32Ty(llvmContext), value))};
+      llvm::MDNode *llvmMetadataNode =
+          llvm::MDNode::get(llvmContext, llvmMetadata);
+      moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations")
+          ->addOperand(llvmMetadataNode);
+    };
+    if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
+      if (!attribute.getValue().dyn_cast<ArrayAttr>())
+        return failure();
+      SmallVector<int64_t> values =
+          extractFromI64ArrayAttr(attribute.getValue());
+      generateMetadata(values[0], NVVM::NVVMDialect::getMaxntidXName());
+      if (values.size() > 1)
+        generateMetadata(values[1], NVVM::NVVMDialect::getMaxntidYName());
+      if (values.size() > 2)
+        generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName());
+    } else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
+      if (!attribute.getValue().dyn_cast<ArrayAttr>())
+        return failure();
+      SmallVector<int64_t> values =
+          extractFromI64ArrayAttr(attribute.getValue());
+      generateMetadata(values[0], NVVM::NVVMDialect::getReqntidXName());
+      if (values.size() > 1)
+        generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName());
+      if (values.size() > 2)
+        generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName());
+    } else if (attribute.getName() ==
+               NVVM::NVVMDialect::getMinctasmAttrName()) {
+      auto value = attribute.getValue().dyn_cast<IntegerAttr>();
+      if (!value)
+        return failure();
+      generateMetadata(value.getInt(), "minctasm");
+    } else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
+      auto value = attribute.getValue().dyn_cast<IntegerAttr>();
+      if (!value)
+        return failure();
+      generateMetadata(value.getInt(), "maxnreg");
+    } else if (attribute.getName() ==
+               NVVM::NVVMDialect::getKernelFuncAttrName()) {
+      llvm::Metadata *llvmMetadataKernel[] = {
           llvm::ValueAsMetadata::get(llvmFunc),
           llvm::MDString::get(llvmContext, "kernel"),
           llvm::ValueAsMetadata::get(
               llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 1))};
       llvm::MDNode *llvmMetadataNode =
-          llvm::MDNode::get(llvmContext, llvmMetadata);
+          llvm::MDNode::get(llvmContext, llvmMetadataKernel);
       moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations")
           ->addOperand(llvmMetadataNode);
     }
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-llvmir %s -split-input-file --verify-diagnostics | FileCheck %s
 
 // CHECK-LABEL: @nvvm_special_regs
 llvm.func @nvvm_special_regs() -> i32 {
@@ -349,3 +349,90 @@
 // CHECK:     !nvvm.annotations =
 // CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
 // CHECK:     {ptr @kernel_func, !"kernel", i32 1}
+
+// -----
+
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [1,23,32]} {
+  llvm.return
+}
+
+// CHECK:     !nvvm.annotations =
+// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"maxntidx", i32 1}
+// CHECK:     {ptr @kernel_func, !"maxntidy", i32 23}
+// CHECK:     {ptr @kernel_func, !"maxntidz", i32 32}
+// -----
+
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = [1,23,32]} {
+  llvm.return
+}
+
+// CHECK:     !nvvm.annotations =
+// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"reqntidx", i32 1}
+// CHECK:     {ptr @kernel_func, !"reqntidy", i32 23}
+// CHECK:     {ptr @kernel_func, !"reqntidz", i32 32}
+// -----
+
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.minctasm = 16} {
+  llvm.return
+}
+
+// CHECK:     !nvvm.annotations =
+// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"minctasm", i32 16}
+// -----
+
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxnreg = 16} {
+  llvm.return
+}
+
+// CHECK:     !nvvm.annotations =
+// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"maxnreg", i32 16}
+// -----
+
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [1,23,32],
+                                     nvvm.minctasm = 16, nvvm.maxnreg = 32} {
+  llvm.return
+}
+
+// CHECK:     !nvvm.annotations =
+// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"maxnreg", i32 32}
+// CHECK:     {ptr @kernel_func, !"maxntidx", i32 1}
+// CHECK:     {ptr @kernel_func, !"maxntidy", i32 23}
+// CHECK:     {ptr @kernel_func, !"maxntidz", i32 32}
+// CHECK:     {ptr @kernel_func, !"minctasm", i32 16}
+
+// -----
+// expected-error @below {{'"nvvm.minctasm"' attribute must be integer constant}}
+llvm.func @kernel_func() attributes {nvvm.kernel,
+nvvm.minctasm = "foo"} {
+  llvm.return
+}
+
+
+// -----
+// expected-error @below {{'"nvvm.maxnreg"' attribute must be integer constant}}
+llvm.func @kernel_func() attributes {nvvm.kernel,
+nvvm.maxnreg = "boo"} {
+  llvm.return
+}
+// -----
+// expected-error @below {{'"nvvm.reqntid"' attribute must be integer array with maximum 3 index}}
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = [3,4,5,6]} {
+  llvm.return
+}
+
+// -----
+// expected-error @below {{'"nvvm.maxntid"' attribute must be integer array with maximum 3 index}}
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [3,4,5,6]} {
+  llvm.return
+}
+