diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -18,6 +18,7 @@
 
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/IntrinsicsAMDGPU.h"
+#include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
 using namespace mlir::LLVM;
@@ -71,15 +72,42 @@
 
       // For GPU kernels,
       // 1. Insert AMDGPU_KERNEL calling convention.
-      // 2. Insert amdgpu-flat-workgroup-size(1, 256) attribute.
+      // 2. Insert amdgpu-flat-work-group-size(1, 256) attribute unless the
+      // user has overridden it via rocdl.max_flat_work_group_size (256 is
+      // clang's default for OpenCL kernels).
       // 3. Insert amdgpu-implicitarg-num-bytes=56 (which must be set on OpenCL
       // and HIP kernels per Clang)
       llvm::Function *llvmFunc =
           moduleTranslation.lookupFunction(func.getName());
       llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
-      llvmFunc->addFnAttr("amdgpu-flat-work-group-size", "1, 256");
+      if (!llvmFunc->hasFnAttribute("amdgpu-flat-work-group-size")) {
+        llvmFunc->addFnAttr("amdgpu-flat-work-group-size", "1, 256");
+      }
       llvmFunc->addFnAttr("amdgpu-implicitarg-num-bytes", "56");
     }
+    // Override the flat work-group size: `rocdl.max_flat_work_group_size = N`
+    // becomes "amdgpu-flat-work-group-size"="1, N" on the LLVM function. The
+    // key must match the one the kernel branch above writes and tests with
+    // hasFnAttribute(); otherwise both attributes would be emitted and the
+    // "1, 256" default would not be replaced. The result does not depend on
+    // the order in which the two attributes are translated: addFnAttr
+    // replaces an existing string attribute with the same key, and the kernel
+    // branch leaves an already-present value alone.
+    if ("rocdl.max_flat_work_group_size" == attribute.getName()) {
+      auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
+      if (!func)
+        return failure();
+      auto value = attribute.getValue().dyn_cast<IntegerAttr>();
+      if (!value)
+        return failure();
+
+      llvm::Function *llvmFunc =
+          moduleTranslation.lookupFunction(func.getName());
+      llvm::SmallString<8> llvmAttrValue;
+      llvm::raw_svector_ostream attrValueStream(llvmAttrValue);
+      attrValueStream << "1, " << value.getInt();
+      llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
+    }
     return success();
   }
 };