diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -273,7 +273,7 @@ /// Returns the type of this function. /// FIXME: We should drive this via the ODS `type` param. - FunctionType getType() { + FunctionType getType() { return getTypeAttr().getValue().cast<FunctionType>(); } @@ -1006,6 +1006,18 @@ let hasFolder = 1; } +def GPU_SetDefaultDeviceOp : GPU_Op<"set_default_device", + [MemoryEffects<[MemWrite]>]>, + Arguments<(ins I32:$devIndex)> { + let summary = "Set default GPU for operations after this by index"; + let description = [{ + Operation that sets the current default GPU, using a zero-based index + into the set of GPUs on the system. The default GPU setting may be + thread-local. + }]; + let assemblyFormat = "attr-dict $devIndex"; +} + def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix", [MemoryEffects<[MemRead]>]>{ diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -185,6 +185,10 @@ {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */, llvmIntPtrType /* intptr_t sizeBytes */, llvmPointerType /* void *stream */}}; + FunctionCallBuilder setDefaultDeviceCallBuilder = { + "mgpuSetDefaultDevice", + llvmVoidType, + {llvmInt32Type /* int32_t devIndex */}}; }; /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime @@ -342,6 +346,21 @@ matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; + +/// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
+/// Currently supports CUDA and ROCm (HIP). +class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern + : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> { +public: + ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern( + LLVMTypeConverter &typeConverter) + : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>( + typeConverter) {} + + LogicalResult + matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; } // namespace void GpuToLLVMConversionPass::runOnOperation() { @@ -844,6 +863,15 @@ return success(); } +LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::SetDefaultDeviceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op.getLoc(); + setDefaultDeviceCallBuilder.create(loc, rewriter, {adaptor.devIndex()}); + rewriter.eraseOp(op); + return success(); +} + std::unique_ptr<OperationPass<ModuleOp>> mlir::createGpuToLLVMConversionPass() { return std::make_unique<GpuToLLVMConversionPass>(); @@ -861,6 +889,7 @@ ConvertHostRegisterOpToGpuRuntimeCallPattern, ConvertMemcpyOpToGpuRuntimeCallPattern, ConvertMemsetOpToGpuRuntimeCallPattern, + ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern, ConvertWaitAsyncOpToGpuRuntimeCallPattern, ConvertWaitOpToGpuRuntimeCallPattern, ConvertAsyncYieldToGpuRuntimeCallPattern>(converter); diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp --- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp +++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp @@ -35,16 +35,20 @@ fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \ }(expr) -// Make the primary context of device 0 current for the duration of the instance -// and restore the previous context on destruction. +thread_local static int32_t defaultDevice = 0; + +// Make the primary context of the default device current for the duration of +// the instance and restore the previous context on destruction. +// FIXME: the context is cached on first use; later device changes are ignored.
class ScopedContext { public: ScopedContext() { - // Static reference to CUDA primary context for device ordinal 0. + // Static reference to the CUDA primary context for device ordinal + // `defaultDevice` as seen by the first thread to construct a ScopedContext. static CUcontext context = [] { CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0)); CUdevice device; - CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0)); + CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice)); CUcontext ctx; // Note: this does not affect the current context. CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&ctx, device)); @@ -187,3 +191,9 @@ auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes; mgpuMemHostRegister(ptr, sizeBytes); } + +/// Sets this thread's `defaultDevice`. ScopedContext reads it only once, when +/// its static context is first initialized; later calls don't switch devices. +extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) { + defaultDevice = device; +} diff --git a/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp --- a/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp +++ b/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp @@ -30,16 +30,18 @@ fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \ }(expr) +thread_local static int32_t defaultDevice = 0; + // Sets the `Context` for the duration of the instance and restores the previous // context on destruction. class ScopedContext { public: ScopedContext() { - // Static reference to HIP primary context for device ordinal 0. + // Static reference to HIP primary context for device ordinal defaultDevice.
static hipCtx_t context = [] { HIP_REPORT_IF_ERROR(hipInit(/*flags=*/0)); hipDevice_t device; - HIP_REPORT_IF_ERROR(hipDeviceGet(&device, /*ordinal=*/0)); + HIP_REPORT_IF_ERROR(hipDeviceGet(&device, /*ordinal=*/defaultDevice)); hipCtx_t ctx; HIP_REPORT_IF_ERROR(hipDevicePrimaryCtxRetain(&ctx, device)); return ctx; @@ -199,3 +201,10 @@ mgpuMemGetDevicePointer(aligned, &devicePtr); return {devicePtr, devicePtr, offset, {size}, {stride}}; } + +/// Sets this thread's `defaultDevice` and HIP's current device. ScopedContext +/// reads `defaultDevice` only when its static context is first initialized. +extern "C" void mgpuSetDefaultDevice(int32_t device) { + defaultDevice = device; + HIP_REPORT_IF_ERROR(hipSetDevice(device)); +} diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir --- a/mlir/test/Dialect/GPU/ops.mlir +++ b/mlir/test/Dialect/GPU/ops.mlir @@ -252,4 +252,11 @@ gpu.device_async_wait %token {numGroups = 1 : i32} return } + + // CHECK-LABEL: func @set_default_device + func @set_default_device(%arg0: i32) { + // CHECK: gpu.set_default_device %{{.*}} + gpu.set_default_device %arg0 + return + } }