diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -34,6 +34,29 @@
     /// Get the name of the attribute used to annotate external kernel
     /// functions.
     static StringRef getKernelFuncAttrName() { return "nvvm.kernel"; }
+    /// Get the name of the attribute used to annotate max threads required
+    /// per CTA for kernel functions.
+    static StringRef getMaxntidAttrName() { return "nvvm.maxntid"; }
+    /// Get the names of the metadata entries for each dimension.
+    static StringRef getMaxntidXName() { return "maxntidx"; }
+    static StringRef getMaxntidYName() { return "maxntidy"; }
+    static StringRef getMaxntidZName() { return "maxntidz"; }
+
+    /// Get the name of the attribute used to annotate exact threads required
+    /// per CTA for kernel functions.
+    static StringRef getReqntidAttrName() { return "nvvm.reqntid"; }
+    /// Get the names of the metadata entries for each dimension.
+    static StringRef getReqntidXName() { return "reqntidx"; }
+    static StringRef getReqntidYName() { return "reqntidy"; }
+    static StringRef getReqntidZName() { return "reqntidz"; }
+
+    /// Get the name of the attribute used to annotate min CTA required
+    /// per SM for kernel functions.
+    static StringRef getMinctasmAttrName() { return "nvvm.minctasm"; }
+
+    /// Get the name of the attribute used to annotate max number of
+    /// registers that can be allocated per thread.
+    static StringRef getMaxnregAttrName() { return "nvvm.maxnreg"; }
   }];
 
   let useDefaultAttributePrinterParser = 1;
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 
@@ -116,21 +117,67 @@
   LogicalResult
   amendOperation(Operation *op, NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const final {
-    if (attribute.getName() == NVVM::NVVMDialect::getKernelFuncAttrName()) {
-      auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
-      if (!func)
-        return failure();
+    auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
+    if (!func)
+      return failure();
+    llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
+    llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName());
 
-      llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
-      llvm::Function *llvmFunc =
-          moduleTranslation.lookupFunction(func.getName());
+    // Append a {function, name, i32 value} entry to the module-level
+    // "nvvm.annotations" named metadata.
+    auto generateMetadata = [&](int dim, StringRef name) {
       llvm::Metadata *llvmMetadata[] = {
+          llvm::ValueAsMetadata::get(llvmFunc),
+          llvm::MDString::get(llvmContext, name),
+          llvm::ValueAsMetadata::get(llvm::ConstantInt::get(
+              llvm::Type::getInt32Ty(llvmContext), dim))};
+      llvm::MDNode *llvmMetadataNode =
+          llvm::MDNode::get(llvmContext, llvmMetadata);
+      moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations")
+          ->addOperand(llvmMetadataNode);
+    };
+    if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
+      if (!attribute.getValue().dyn_cast<ArrayAttr>())
+        return failure();
+      SmallVector<int64_t> values =
+          extractFromI64ArrayAttr(attribute.getValue());
+      if (!values.empty())
+        generateMetadata(values[0], NVVM::NVVMDialect::getMaxntidXName());
+      if (values.size() > 1)
+        generateMetadata(values[1], NVVM::NVVMDialect::getMaxntidYName());
+      if (values.size() > 2)
+        generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName());
+    } else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
+      if (!attribute.getValue().dyn_cast<ArrayAttr>())
+        return failure();
+      SmallVector<int64_t> values =
+          extractFromI64ArrayAttr(attribute.getValue());
+      if (!values.empty())
+        generateMetadata(values[0], NVVM::NVVMDialect::getReqntidXName());
+      if (values.size() > 1)
+        generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName());
+      if (values.size() > 2)
+        generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName());
+    } else if (attribute.getName() ==
+               NVVM::NVVMDialect::getMinctasmAttrName()) {
+      auto value = attribute.getValue().dyn_cast<IntegerAttr>();
+      if (!value)
+        return failure();
+      generateMetadata(value.getInt(), "minctasm");
+    } else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
+      auto value = attribute.getValue().dyn_cast<IntegerAttr>();
+      if (!value)
+        return failure();
+      generateMetadata(value.getInt(), "maxnreg");
+    } else if (attribute.getName() ==
+               NVVM::NVVMDialect::getKernelFuncAttrName()) {
+      llvm::Metadata *llvmMetadataKernel[] = {
           llvm::ValueAsMetadata::get(llvmFunc),
           llvm::MDString::get(llvmContext, "kernel"),
           llvm::ValueAsMetadata::get(
               llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 1))};
       llvm::MDNode *llvmMetadataNode =
-          llvm::MDNode::get(llvmContext, llvmMetadata);
+          llvm::MDNode::get(llvmContext, llvmMetadataKernel);
       moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations")
           ->addOperand(llvmMetadataNode);
     }
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-llvmir %s -split-input-file | FileCheck %s
 
 // CHECK-LABEL: @nvvm_special_regs
 llvm.func @nvvm_special_regs() -> i32 {
@@ -349,3 +349,64 @@
 // CHECK:     !nvvm.annotations =
 // CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
 // CHECK:     {ptr @kernel_func, !"kernel", i32 1}
+
+// -----
+
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [1,23,32]} {
+  llvm.return
+}
+
+// CHECK:     !nvvm.annotations =
+// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"maxntidx", i32 1}
+// CHECK:     {ptr @kernel_func, !"maxntidy", i32 23}
+// CHECK:     {ptr @kernel_func, !"maxntidz", i32 32}
+// -----
+
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = [1,23,32]} {
+  llvm.return
+}
+
+// CHECK:     !nvvm.annotations =
+// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"reqntidx", i32 1}
+// CHECK:     {ptr @kernel_func, !"reqntidy", i32 23}
+// CHECK:     {ptr @kernel_func, !"reqntidz", i32 32}
+// -----
+
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.minctasm = 16} {
+  llvm.return
+}
+
+// CHECK:     !nvvm.annotations =
+// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"minctasm", i32 16}
+// -----
+
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxnreg = 16} {
+  llvm.return
+}
+
+// CHECK:     !nvvm.annotations =
+// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"maxnreg", i32 16}
+// -----
+
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [1,23,32],
+                                     nvvm.minctasm = 16, nvvm.maxnreg = 32} {
+  llvm.return
+}
+
+// CHECK:     !nvvm.annotations =
+// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"kernel", i32 1}
+// CHECK:     {ptr @kernel_func, !"maxnreg", i32 32}
+// CHECK:     {ptr @kernel_func, !"maxntidx", i32 1}
+// CHECK:     {ptr @kernel_func, !"maxntidy", i32 23}
+// CHECK:     {ptr @kernel_func, !"maxntidz", i32 32}
+// CHECK:     {ptr @kernel_func, !"minctasm", i32 16}
+