diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUCompilationAttr.td b/mlir/include/mlir/Dialect/GPU/IR/GPUCompilationAttr.td
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUCompilationAttr.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUCompilationAttr.td
@@ -179,4 +179,31 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// GPU object manager attributes.
+//===----------------------------------------------------------------------===//
+
+def GPU_SelectObjectAttr : GPU_Attr<"SelectObject", "select_object", [
+    DeclareAttrInterfaceMethods<ObjectManagerAttrInterface>
+  ]> {
+  let description = [{
+    This GPU object manager selects a single GPU object for embedding. The
+    object is selected based on the `target` parameter. This parameter can be
+    either a number, i.e. select the i-th target, or the target itself, i.e.
+    search for the specified target in the object array.
+
+    If no target is given, the first object in the array is selected.
+  }];
+  let parameters = (ins
+    OptionalParameter<"Attribute", "Object target to embed.">:
+      $target
+  );
+  let assemblyFormat = [{
+    (`<` $target^ `>`)?
+  }];
+  let genVerifyDecl = 1;
+}
+
 #endif // GPU_COMPILATIONATTR
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1109,9 +1109,18 @@
     - An attribute implementing the `ObjectManagerAttrInterface` interface.
    - An array of GPU object attributes.
 
+    If no `objectManager` is present in the assembly format, an implicit
+    `#gpu.select_object` attribute is added.
+
+    Examples:
+      1. GPU binary with an implicit `#gpu.select_object` object manager attribute.
     ```
       gpu.binary @myobject [#gpu.object<...>, #gpu.object<...>]
     ```
+      2. GPU binary with an explicit `#gpu.select_object<1>` object manager attribute.
+      ```
+      gpu.binary @myobject <#gpu.select_object<1>> [#gpu.object<...>, #gpu.object<...>]
+      ```
   }];
   let builders = [
     OpBuilder<(ins "StringRef":$name, "Attribute":$objectManager,
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -50,6 +50,7 @@
 add_mlir_dialect_library(MLIRGPUTargets
   Targets/AMDGPUTarget.cpp
   Targets/NVPTXTarget.cpp
+  Targets/ObjectHandler.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1601,19 +1601,28 @@
   result.attributes.push_back(builder.getNamedAttr(
       SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
   properties.objects = objects;
-  properties.objectManager = manager;
+  if (manager)
+    properties.objectManager = manager;
+  else
+    properties.objectManager = builder.getAttr<SelectObjectAttr>(nullptr);
 }
 
 static ParseResult parseObjectManager(OpAsmParser &parser,
                                       Attribute &objectManager) {
-  if (parser.parseAttribute(objectManager))
-    return failure();
+  if (succeeded(parser.parseOptionalLess())) {
+    if (parser.parseAttribute(objectManager))
+      return failure();
+    if (parser.parseGreater())
+      return failure();
+  }
+  if (!objectManager)
+    objectManager = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
   return success();
 }
 
 static void printObjectManager(OpAsmPrinter &printer, Operation *op,
                                Attribute objectManager) {
-  if (objectManager)
+  if (objectManager != SelectObjectAttr::get(op->getContext(), nullptr))
     printer << '<' << objectManager << '>';
 }
diff --git a/mlir/lib/Dialect/GPU/Targets/ObjectHandler.cpp b/mlir/lib/Dialect/GPU/Targets/ObjectHandler.cpp
new file
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Targets/ObjectHandler.cpp
@@ -0,0 +1,350 @@
+//===- ObjectHandler.cpp - Implements base ObjectManager attributes -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements base ObjectManager attributes, like the default
+// SelectObject attribute.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+
+#include "mlir/Target/LLVMIR/Export.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+
+namespace {
+// Returns the name of the global holding the serialized binary of the given
+// GPU module.
+std::string getBinaryIdentifier(StringRef moduleName) {
+  return moduleName.str() + "_bin_cst";
+}
+} // namespace
+
+LogicalResult
+gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                              Attribute target) {
+  // Check `target`: it can be null, an integer attribute, or an attribute
+  // implementing `TargetAttrInterface`.
+  if (target) {
+    if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
+      if (intAttr.getInt() < 0) {
+        return emitError() << "The object index must be non-negative.";
+      }
+    } else if (!target.hasTrait<TargetAttrInterface::Trait>()) {
+      return emitError()
+             << "The target attribute must implement the `TargetAttrInterface` "
+                "interface or be an `IntegerAttr`.";
+    }
+  }
+  return success();
+}
+
+LogicalResult gpu::SelectObjectAttr::embedBinary(
+    Operation *op, llvm::IRBuilderBase &builder,
+    LLVM::ModuleTranslation &moduleTranslation) const {
+
+  auto binaryOp = mlir::dyn_cast<gpu::BinaryOp>(op);
+  assert(binaryOp && "Op is not a BinaryOp.");
+
+  ArrayRef<Attribute> objects = binaryOp.getObjectsAttr().getValue();
+
+  // Obtain the index of the object to select.
+  int64_t index = -1;
+  if (Attribute target = getTarget()) {
+    // If the target attribute is a number, it is the index. Otherwise compare
+    // the attribute to every target inside the object array to find the index.
+    if (auto indexAttr = mlir::dyn_cast<IntegerAttr>(target)) {
+      index = indexAttr.getInt();
+    } else {
+      for (auto [i, attr] : llvm::enumerate(objects)) {
+        auto obj = mlir::dyn_cast<gpu::ObjectAttr>(attr);
+        if (obj.getTarget() == target) {
+          index = i;
+        }
+      }
+    }
+  } else {
+    // If the target attribute is null, select the first object in the object
+    // array.
+    index = 0;
+  }
+
+  if (index < 0 || index >= static_cast<int64_t>(objects.size())) {
+    op->emitError("The requested target object couldn't be found.");
+    return failure();
+  }
+  ObjectAttr object = mlir::dyn_cast<gpu::ObjectAttr>(objects[index]);
+
+  llvm::Module *module = moduleTranslation.getLLVMModule();
+
+  // Embed the object as a global string.
+  llvm::Constant *binary = llvm::ConstantDataArray::getString(
+      builder.getContext(), object.getObject().getValue(), false);
+  llvm::GlobalVariable *serializedObj =
+      new llvm::GlobalVariable(*module, binary->getType(), true,
+                               llvm::GlobalValue::LinkageTypes::InternalLinkage,
+                               binary, getBinaryIdentifier(binaryOp.getName()));
+  serializedObj->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage);
+  serializedObj->setAlignment(llvm::MaybeAlign(8));
+  serializedObj->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::None);
+  return success();
+}
+
+namespace llvm {
+namespace {
+class LaunchKernel {
+public:
+  LaunchKernel(Module &module, IRBuilderBase &builder,
+               mlir::LLVM::ModuleTranslation &moduleTranslation);
+  // Get the kernel launch callee.
+  FunctionCallee getKernelLaunchFn();
+
+  // Get the module function callee.
+  FunctionCallee getModuleFunctionFn();
+
+  // Get the module load callee.
+  FunctionCallee getModuleLoadFn();
+
+  // Get the module unload callee.
+  FunctionCallee getModuleUnloadFn();
+
+  // Get the stream create callee.
+  FunctionCallee getStreamCreateFn();
+
+  // Get the stream destroy callee.
+  FunctionCallee getStreamDestroyFn();
+
+  // Get the stream sync callee.
+  FunctionCallee getStreamSyncFn();
+
+  // Get or create the function name global string.
+  Value *getOrCreateFunctionName(StringRef moduleName, StringRef kernelName);
+
+  // Create the void* kernel array for passing the arguments.
+  Value *createKernelArgArray(mlir::gpu::LaunchFuncOp op);
+
+  // Create the full kernel launch.
+  mlir::LogicalResult createKernelLaunch(mlir::gpu::LaunchFuncOp op);
+
+private:
+  Module &module;
+  IRBuilderBase &builder;
+  mlir::LLVM::ModuleTranslation &moduleTranslation;
+  Type *i32Ty{};
+  Type *voidTy{};
+  Type *intPtrTy{};
+  PointerType *ptrTy{};
+};
+} // namespace
+} // namespace llvm
+
+llvm::LaunchKernel::LaunchKernel(
+    Module &module, IRBuilderBase &builder,
+    mlir::LLVM::ModuleTranslation &moduleTranslation)
+    : module(module), builder(builder), moduleTranslation(moduleTranslation) {
+  i32Ty = builder.getInt32Ty();
+  ptrTy = builder.getPtrTy(0);
+  voidTy = builder.getVoidTy();
+  intPtrTy = builder.getIntPtrTy(module.getDataLayout());
+}
+
+llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() {
+  return module.getOrInsertFunction(
+      "mgpuLaunchKernel",
+      FunctionType::get(
+          voidTy,
+          ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
+                            intPtrTy, intPtrTy, i32Ty, ptrTy, ptrTy, ptrTy}),
+          false));
+}
+
+llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
+  return module.getOrInsertFunction(
+      "mgpuModuleGetFunction",
+      FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, ptrTy}), false));
+}
+
+llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadFn() {
+  return module.getOrInsertFunction(
+      "mgpuModuleLoad",
+      FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy}), false));
+}
+
+llvm::FunctionCallee llvm::LaunchKernel::getModuleUnloadFn() {
+  return module.getOrInsertFunction(
+      "mgpuModuleUnload",
+      FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
+}
+
+llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn() {
+  return module.getOrInsertFunction("mgpuStreamCreate",
+                                    FunctionType::get(ptrTy, false));
+}
+
+llvm::FunctionCallee llvm::LaunchKernel::getStreamDestroyFn() {
+  return module.getOrInsertFunction(
+      "mgpuStreamDestroy",
+      FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
+}
+
+llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
+  return module.getOrInsertFunction(
+      "mgpuStreamSynchronize",
+      FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
+}
+
+// Generates an LLVM IR global that contains the name of the given kernel
+// function as a C string, and returns a pointer to its beginning.
+llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName,
+                                                         StringRef kernelName) {
+  std::string globalName =
+      std::string(formatv("{0}_{1}_kernel_name", moduleName, kernelName));
+
+  if (GlobalVariable *gv = module.getGlobalVariable(globalName))
+    return gv;
+
+  return builder.CreateGlobalString(kernelName, globalName);
+}
+
+// Creates a struct containing all kernel parameters on the stack and returns
+// an array of type-erased pointers to the fields of the struct. The array can
+// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
+// The generated code is essentially as follows:
+//
+// %struct = alloca(sizeof(struct { Parameters... }))
+// %array = alloca(NumParameters * sizeof(void *))
+// for (i : [0, NumParameters))
+//   %fieldPtr = llvm.getelementptr %struct[0, i]
+//   llvm.store parameters[i], %fieldPtr
+//   %elementPtr = llvm.getelementptr %array[i]
+//   llvm.store %fieldPtr, %elementPtr
+// return %array
+llvm::Value *
+llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
+  SmallVector<Value *> args =
+      moduleTranslation.lookupValues(op.getKernelOperands());
+  SmallVector<Type *> structTypes(args.size(), nullptr);
+
+  for (auto [i, arg] : llvm::enumerate(args))
+    structTypes[i] = arg->getType();
+
+  Type *structTy = StructType::create(module.getContext(), structTypes);
+  Value *argStruct = builder.CreateAlloca(structTy, 0u);
+  Value *argArray = builder.CreateAlloca(
+      ptrTy, ConstantInt::get(intPtrTy, structTypes.size()));
+
+  for (auto [i, arg] : enumerate(args)) {
+    Value *structMember = builder.CreateStructGEP(structTy, argStruct, i);
+    builder.CreateStore(arg, structMember);
+    Value *arrayMember = builder.CreateConstGEP1_32(ptrTy, argArray, i);
+    builder.CreateStore(structMember, arrayMember);
+  }
+  return argArray;
+}
+
+// Emits LLVM IR to launch a kernel function. Expects the serialized kernel
+// binary to have been embedded as a global string by `embedBinary`, under the
+// name returned by `getBinaryIdentifier`. The generated code is essentially
+// as follows:
+//
+// %0 = <the embedded binary global>
+// %1 = call %moduleLoad(%0)
+// %2 = <see getOrCreateFunctionName>
+// %3 = call %moduleGetFunction(%1, %2)
+// %4 = call %streamCreate()
+// %5 = <see createKernelArgArray>
+// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
+// call %streamSynchronize(%4)
+// call %streamDestroy(%4)
+// call %moduleUnload(%1)
+mlir::LogicalResult
+llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op) {
+  auto llvmValue = [&](mlir::Value value) -> Value * {
+    Value *v = moduleTranslation.lookupValue(value);
+    assert(v && "Value has not been translated.");
+    return v;
+  };
+
+  // Get grid dimensions.
+  mlir::gpu::KernelDim3 grid = op.getGridSizeOperandValues();
+  Value *gx = llvmValue(grid.x), *gy = llvmValue(grid.y),
+        *gz = llvmValue(grid.z);
+
+  // Get block dimensions.
+  mlir::gpu::KernelDim3 block = op.getBlockSizeOperandValues();
+  Value *bx = llvmValue(block.x), *by = llvmValue(block.y),
+        *bz = llvmValue(block.z);
+
+  // Get dynamic shared memory size.
+  Value *dynamicMemorySize = nullptr;
+  if (mlir::Value dynSz = op.getDynamicSharedMemorySize())
+    dynamicMemorySize = llvmValue(dynSz);
+  else
+    dynamicMemorySize = ConstantInt::get(i32Ty, 0);
+
+  // Create the argument array.
+  Value *argArray = createKernelArgArray(op);
+
+  // Load the kernel module.
+  StringRef moduleName = op.getKernelModuleName().getValue();
+  std::string binaryIdentifier = getBinaryIdentifier(moduleName);
+  Value *binary = module.getGlobalVariable(binaryIdentifier, true);
+  if (!binary)
+    return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;
+  Value *moduleObject = builder.CreateCall(getModuleLoadFn(), {binary});
+
+  // Load the kernel function.
+  Value *moduleFunction = builder.CreateCall(
+      getModuleFunctionFn(),
+      {moduleObject,
+       getOrCreateFunctionName(moduleName, op.getKernelName().getValue())});
+
+  // Get the stream to use for execution. If there is no async object, create
+  // a stream and make the kernel launch synchronous.
+  Value *stream = nullptr;
+  bool handleStream = false;
+  if (mlir::Value asyncObject = op.getAsyncObject()) {
+    stream = llvmValue(asyncObject);
+  } else {
+    handleStream = true;
+    stream = builder.CreateCall(getStreamCreateFn(), {});
+  }
+
+  // Create the launch call.
+  Value *nullPtr = ConstantPointerNull::get(ptrTy);
+  builder.CreateCall(
+      getKernelLaunchFn(),
+      ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by, bz,
+                         dynamicMemorySize, stream, argArray, nullPtr}));
+
+  // Sync and destroy the stream for synchronous launches.
+  if (handleStream) {
+    builder.CreateCall(getStreamSyncFn(), {stream});
+    builder.CreateCall(getStreamDestroyFn(), {stream});
+  }
+
+  // Unload the kernel module.
+  builder.CreateCall(getModuleUnloadFn(), {moduleObject});
+
+  return success();
+}
+
+LogicalResult gpu::SelectObjectAttr::launchKernel(
+    Operation *launchFuncOperation, Operation *binaryOperation,
+    llvm::IRBuilderBase &builder,
+    LLVM::ModuleTranslation &moduleTranslation) const {
+  llvm::Module *module = moduleTranslation.getLLVMModule();
+  auto launchFuncOp = mlir::dyn_cast<gpu::LaunchFuncOp>(launchFuncOperation);
+  assert(launchFuncOp && "Op is not a LaunchFuncOp.");
+  return llvm::LaunchKernel(*module, builder, moduleTranslation)
+      .createKernelLaunch(launchFuncOp);
+}
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -626,3 +626,17 @@
 gpu.module @gpu_funcs [1] {
 }
 }
+
+// -----
+
+module {
+  // expected-error @+1 {{'gpu.binary' op attribute 'objects' failed to satisfy constraint: An array of GPU object attributes with at least 1 elements}}
+  gpu.binary @binary []
+}
+
+// -----
+
+module {
+  // expected-error @+1 {{'gpu.binary' op attribute 'objectManager' failed to satisfy constraint: any attribute Attribute implementing the `ObjectManagerAttrInterface` interface.}}
+  gpu.binary @binary <1> [#gpu.object<#gpu.nvptx, "">]
+}
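For reference, a minimal sketch of IR that exercises the new attribute together with `gpu.launch_func`; the module name, kernel name, and object payloads below are illustrative placeholders and are not part of this patch:

```mlir
// Embed the second object (index 1); the explicit #gpu.select_object<1>
// object manager decides which object is embedded and how launches are
// translated into the mgpu* runtime calls.
gpu.binary @kernels <#gpu.select_object<1>> [
  #gpu.object<#gpu.nvptx, "ptx...">,
  #gpu.object<#gpu.nvptx, "cubin...">
]

func.func @main() {
  %c1 = arith.constant 1 : index
  // Launches @kernels::@kernel with a 1x1x1 grid and block; without an async
  // token the translation creates, synchronizes, and destroys a stream.
  gpu.launch_func @kernels::@kernel
      blocks in (%c1, %c1, %c1)
      threads in (%c1, %c1, %c1)
  return
}
```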