diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -440,6 +440,10 @@
                                 ConversionPatternRewriter &rewriter,
                                 SmallVectorImpl<Value> &sizes) const;
 
+  /// Computes the size of the given type in bytes.
+  Value getSizeInBytes(Location loc, Type type,
+                       ConversionPatternRewriter &rewriter) const;
+
   /// Computes the total size in bytes required to store the given shape.
   Value getCumulativeSizeInBytes(Location loc, Type elementType,
                                  ArrayRef<Value> shape,
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -741,4 +741,16 @@
   let printer = [{ p << getOperationName(); }];
 }
 
+def GPU_HostRegisterOp : GPU_Op<"host_register">,
+    Arguments<(ins AnyUnrankedMemRef:$value)> {
+  let summary = "Registers a memref for access from device.";
+  let description = [{
+    This op registers the host memory pointed to by a memref so that it can be
+    accessed from a device.
+  }];
+
+  let assemblyFormat = "$value attr-dict `:` type($value)";
+  let verifier = [{ return success(); }];
+}
+
 #endif // GPU_OPS
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -117,6 +117,26 @@
       "mgpuStreamSynchronize",
       llvmVoidType,
       {llvmPointerType /* void *stream */}};
+  FunctionCallBuilder hostRegisterCallBuilder = {
+      "mgpuMemHostRegisterMemRef",
+      llvmVoidType,
+      {llvmIntPtrType /* intptr_t rank */,
+       llvmPointerType /* void *memrefDesc */,
+       llvmIntPtrType /* intptr_t elementSizeBytes */}};
+};
+
+/// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
+/// call. Currently it supports CUDA and ROCm (HIP).
+class ConvertHostRegisterOpToGpuRuntimeCallPattern
+    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
+public:
+  ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
+
+private:
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
 };
 
 /// A rewrite pattern to convert gpu.launch_func operations into a sequence of
@@ -192,6 +212,33 @@
       builder.getSymbolRefAttr(function), arguments);
 }
 
+// Returns whether value is of LLVM type.
+static bool isLLVMType(Value value) {
+  return value.getType().isa<LLVM::LLVMType>();
+}
+
+LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  if (!llvm::all_of(operands, isLLVMType))
+    return rewriter.notifyMatchFailure(
+        op, "Cannot convert if operands aren't of LLVM type.");
+
+  Location loc = op->getLoc();
+
+  auto memRefType = cast<gpu::HostRegisterOp>(op).value().getType();
+  auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
+  auto elementSize = getSizeInBytes(loc, elementType, rewriter);
+
+  auto arguments =
+      typeConverter.promoteOperands(loc, op->getOperands(), operands, rewriter);
+  arguments.push_back(elementSize);
+  hostRegisterCallBuilder.create(loc, rewriter, arguments);
+
+  rewriter.eraseOp(op);
+  return success();
+}
+
 // Creates a struct containing all kernel parameters on the stack and returns
 // an array of type-erased pointers to the fields of the struct. The array can
 // then be passed to the CUDA / ROCm (HIP) kernel launch calls.
@@ -269,11 +316,6 @@
                                       LLVM::Linkage::Internal);
 }
 
-// Returns whether value is of LLVM type.
-static bool isLLVMType(Value value) {
-  return value.getType().isa<LLVM::LLVMType>();
-}
-
 // Emits LLVM IR to launch a kernel function. Expects the module that contains
 // the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
 // hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
@@ -351,6 +393,7 @@
 void mlir::populateGpuToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
     StringRef gpuBinaryAnnotation) {
+  patterns.insert<ConvertHostRegisterOpToGpuRuntimeCallPattern>(converter);
   patterns.insert<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
       converter, gpuBinaryAnnotation);
   patterns.insert<EraseGpuModuleOpPattern>(&converter.getContext());
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -927,30 +927,32 @@
                            : createIndexConstant(rewriter, loc, s));
 }
 
-Value ConvertToLLVMPattern::getCumulativeSizeInBytes(
-    Location loc, Type elementType, ArrayRef<Value> sizes,
-    ConversionPatternRewriter &rewriter) const {
-  // Compute the total number of memref elements.
-  Value cumulativeSizeInBytes =
-      sizes.empty() ? createIndexConstant(rewriter, loc, 1) : sizes.front();
-  for (unsigned i = 1, e = sizes.size(); i < e; ++i)
-    cumulativeSizeInBytes = rewriter.create<LLVM::MulOp>(
-        loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, sizes[i]});
-
+Value ConvertToLLVMPattern::getSizeInBytes(
+    Location loc, Type type, ConversionPatternRewriter &rewriter) const {
   // Compute the size of an individual element. This emits the MLIR equivalent
   // of the following sizeof(...) implementation in LLVM IR:
  //   %0 = getelementptr %elementType* null, %indexType 1
  //   %1 = ptrtoint %elementType* %0 to %indexType
   // which is a common pattern of getting the size of a type in bytes.
-  auto convertedPtrType = typeConverter.convertType(elementType)
-                              .cast<LLVM::LLVMType>()
-                              .getPointerTo();
+  auto convertedPtrType =
+      typeConverter.convertType(type).cast<LLVM::LLVMType>().getPointerTo();
   auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
   auto gep = rewriter.create<LLVM::GEPOp>(
       loc, convertedPtrType,
       ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)});
-  auto elementSize =
-      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
+  return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
+}
+
+Value ConvertToLLVMPattern::getCumulativeSizeInBytes(
+    Location loc, Type elementType, ArrayRef<Value> sizes,
+    ConversionPatternRewriter &rewriter) const {
+  // Compute the total number of memref elements.
+  Value cumulativeSizeInBytes =
+      sizes.empty() ? createIndexConstant(rewriter, loc, 1) : sizes.front();
+  for (unsigned i = 1, e = sizes.size(); i < e; ++i)
+    cumulativeSizeInBytes = rewriter.create<LLVM::MulOp>(
+        loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, sizes[i]});
+  auto elementSize = this->getSizeInBytes(loc, elementType, rewriter);
   return rewriter.create<LLVM::MulOp>(
       loc, getIndexType(),
       ArrayRef<Value>{cumulativeSizeInBytes, elementSize});
 }
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
@@ -25,9 +25,9 @@
   %c6 = constant 6 : index
 
   %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_data : memref<*xi32>
   %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_sum : memref<*xi32>
 
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@
   return
 }
 
-func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
@@ -25,9 +25,9 @@
   %c6 = constant 6 : index
 
   %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_data : memref<*xi32>
   %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_sum : memref<*xi32>
 
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@
   return
 }
 
-func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
@@ -25,9 +25,9 @@
   %c6 = constant 6 : index
 
   %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_data : memref<*xi32>
   %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_sum : memref<*xi32>
 
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@
   return
 }
 
-func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-op.mlir b/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
@@ -11,7 +11,7 @@
   %sy = dim %dst, %c1 : memref<?x?x?xf32>
   %sz = dim %dst, %c0 : memref<?x?x?xf32>
   %cast_dst = memref_cast %dst : memref<?x?x?xf32> to memref<*xf32>
-  call @mgpuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> ()
+  gpu.host_register %cast_dst : memref<*xf32>
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
              threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %sy, %block_z = %sz) {
     %t0 = muli %tz, %block_y : index
@@ -28,5 +28,4 @@
   return
 }
 
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @print_memref_f32(%ptr : memref<*xf32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
@@ -25,9 +25,9 @@
   %c6 = constant 6 : index
 
   %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_data : memref<*xi32>
   %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_sum : memref<*xi32>
 
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@
   return
 }
 
-func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-region.mlir b/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
@@ -8,7 +8,7 @@
   %c0 = constant 0 : index
   %sx = dim %dst, %c0 : memref<?xf32>
   %cast_dst = memref_cast %dst : memref<?xf32> to memref<*xf32>
-  call @mgpuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> ()
+  gpu.host_register %cast_dst : memref<*xf32>
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
              threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
     %val = index_cast %tx : index to i32
@@ -25,5 +25,4 @@
   return
 }
 
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @print_memref_f32(memref<*xf32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
@@ -25,9 +25,9 @@
   %c6 = constant 6 : index
 
   %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_data : memref<*xi32>
   %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_sum : memref<*xi32>
 
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@
   return
 }
 
-func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir b/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
--- a/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
+++ b/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
@@ -18,7 +18,7 @@
   %21 = constant 5 : i32
   %22 = memref_cast %arg0 : memref<5xf32> to memref<?xf32>
   %23 = memref_cast %22 : memref<?xf32> to memref<*xf32>
-  call @mgpuMemHostRegisterFloat(%23) : (memref<*xf32>) -> ()
+  gpu.host_register %23 : memref<*xf32>
   call @print_memref_f32(%23) : (memref<*xf32>) -> ()
   %24 = constant 1.0 : f32
   call @other_func(%24, %22) : (f32, memref<?xf32>) -> ()
@@ -26,5 +26,4 @@
   return
 }
 
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @print_memref_f32(%ptr : memref<*xf32>)
diff --git a/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir b/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
--- a/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
+++ b/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
@@ -26,11 +26,11 @@
   %c6 = constant 6 : index
 
   %cast_data = memref_cast %data : memref<2x6xf32> to memref<*xf32>
-  call @mgpuMemHostRegisterFloat(%cast_data) : (memref<*xf32>) -> ()
+  gpu.host_register %cast_data : memref<*xf32>
   %cast_sum = memref_cast %sum : memref<2xf32> to memref<*xf32>
-  call @mgpuMemHostRegisterFloat(%cast_sum) : (memref<*xf32>) -> ()
+  gpu.host_register %cast_sum : memref<*xf32>
   %cast_mul = memref_cast %mul : memref<2xf32> to memref<*xf32>
-  call @mgpuMemHostRegisterFloat(%cast_mul) : (memref<*xf32>) -> ()
+  gpu.host_register %cast_mul : memref<*xf32>
 
   store %cst0, %data[%c0, %c0] : memref<2x6xf32>
   store %cst1, %data[%c0, %c1] : memref<2x6xf32>
@@ -66,5 +66,4 @@
   return
 }
 
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @print_memref_f32(memref<*xf32>)
diff --git a/mlir/test/mlir-cuda-runner/shuffle.mlir b/mlir/test/mlir-cuda-runner/shuffle.mlir
--- a/mlir/test/mlir-cuda-runner/shuffle.mlir
+++ b/mlir/test/mlir-cuda-runner/shuffle.mlir
@@ -7,8 +7,8 @@
   %one = constant 1 : index
   %c0 = constant 0 : index
   %sx = dim %dst, %c0 : memref<?xf32>
-  %cast_dest = memref_cast %dst : memref<?xf32> to memref<*xf32>
-  call @mgpuMemHostRegisterFloat(%cast_dest) : (memref<*xf32>) -> ()
+  %cast_dst = memref_cast %dst : memref<?xf32> to memref<*xf32>
+  gpu.host_register %cast_dst : memref<*xf32>
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
              threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
     %t0 = index_cast %tx : index to i32
@@ -24,9 +24,8 @@
     store %value, %dst[%tx] : memref<?xf32>
     gpu.terminator
   }
-  call @print_memref_f32(%cast_dest) : (memref<*xf32>) -> ()
+  call @print_memref_f32(%cast_dst) : (memref<*xf32>) -> ()
   return
 }
 
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @print_memref_f32(%ptr : memref<*xf32>)
diff --git a/mlir/test/mlir-cuda-runner/two-modules.mlir b/mlir/test/mlir-cuda-runner/two-modules.mlir
--- a/mlir/test/mlir-cuda-runner/two-modules.mlir
+++ b/mlir/test/mlir-cuda-runner/two-modules.mlir
@@ -8,7 +8,7 @@
   %c0 = constant 0 : index
   %sx = dim %dst, %c0 : memref<?xi32>
   %cast_dst = memref_cast %dst : memref<?xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_dst) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_dst : memref<*xi32>
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
              threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
     %t0 = index_cast %tx : index to i32
@@ -25,5 +25,4 @@
   return
 }
 
-func @mgpuMemHostRegisterInt32(%memref : memref<*xi32>)
 func @print_memref_i32(%memref : memref<*xi32>)
diff --git a/mlir/test/mlir-rocm-runner/gpu-to-hsaco.mlir b/mlir/test/mlir-rocm-runner/gpu-to-hsaco.mlir
--- a/mlir/test/mlir-rocm-runner/gpu-to-hsaco.mlir
+++ b/mlir/test/mlir-rocm-runner/gpu-to-hsaco.mlir
@@ -18,7 +18,7 @@
   %21 = constant 5 : i32
   %22 = memref_cast %arg0 : memref<5xf32> to memref<?xf32>
   %cast = memref_cast %22 : memref<?xf32> to memref<*xf32>
-  call @mgpuMemHostRegisterFloat(%cast) : (memref<*xf32>) -> ()
+  gpu.host_register %cast : memref<*xf32>
   %23 = memref_cast %22 : memref<?xf32> to memref<*xf32>
   call @print_memref_f32(%23) : (memref<*xf32>) -> ()
   %24 = constant 1.0 : f32
@@ -28,6 +28,5 @@
   return
 }
 
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref<?xf32>) -> (memref<?xf32>)
 func @print_memref_f32(%ptr : memref<*xf32>)
diff --git a/mlir/test/mlir-rocm-runner/two-modules.mlir b/mlir/test/mlir-rocm-runner/two-modules.mlir
--- a/mlir/test/mlir-rocm-runner/two-modules.mlir
+++ b/mlir/test/mlir-rocm-runner/two-modules.mlir
@@ -8,7 +8,7 @@
   %c1 = constant 1 : index
   %sx = dim %dst, %c0 : memref<?xi32>
   %cast_dst = memref_cast %dst : memref<?xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_dst) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_dst : memref<*xi32>
   %dst_device = call @mgpuMemGetDeviceMemRef1dInt32(%dst) : (memref<?xi32>) -> (memref<?xi32>)
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
              threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %c1, %block_z = %c1) {
@@ -26,6 +26,5 @@
   return
 }
 
-func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @mgpuMemGetDeviceMemRef1dInt32(%ptr : memref<?xi32>) -> (memref<?xi32>)
 func @print_memref_i32(%ptr : memref<*xi32>)
diff --git a/mlir/test/mlir-rocm-runner/vecadd.mlir b/mlir/test/mlir-rocm-runner/vecadd.mlir
--- a/mlir/test/mlir-rocm-runner/vecadd.mlir
+++ b/mlir/test/mlir-rocm-runner/vecadd.mlir
@@ -26,9 +26,9 @@
   %6 = memref_cast %3 : memref<?xf32> to memref<*xf32>
   %7 = memref_cast %4 : memref<?xf32> to memref<*xf32>
   %8 = memref_cast %5 : memref<?xf32> to memref<*xf32>
-  call @mgpuMemHostRegisterFloat(%6) : (memref<*xf32>) -> ()
-  call @mgpuMemHostRegisterFloat(%7) : (memref<*xf32>) -> ()
-  call @mgpuMemHostRegisterFloat(%8) : (memref<*xf32>) -> ()
+  gpu.host_register %6 : memref<*xf32>
+  gpu.host_register %7 : memref<*xf32>
+  gpu.host_register %8 : memref<*xf32>
   %9 = call @mgpuMemGetDeviceMemRef1dFloat(%3) : (memref<?xf32>) -> (memref<?xf32>)
   %10 = call @mgpuMemGetDeviceMemRef1dFloat(%4) : (memref<?xf32>) -> (memref<?xf32>)
   %11 = call @mgpuMemGetDeviceMemRef1dFloat(%5) : (memref<?xf32>) -> (memref<?xf32>)
@@ -38,6 +38,5 @@
   return
 }
 
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref<?xf32>) -> (memref<?xf32>)
 func @print_memref_f32(%ptr : memref<*xf32>)
diff --git a/mlir/test/mlir-rocm-runner/vector-transferops.mlir b/mlir/test/mlir-rocm-runner/vector-transferops.mlir
--- a/mlir/test/mlir-rocm-runner/vector-transferops.mlir
+++ b/mlir/test/mlir-rocm-runner/vector-transferops.mlir
@@ -55,8 +55,8 @@
   %cast0 = memref_cast %22 : memref<?xf32> to memref<*xf32>
   %cast1 = memref_cast %23 : memref<?xf32> to memref<*xf32>
 
-  call @mgpuMemHostRegisterFloat(%cast0) : (memref<*xf32>) -> ()
-  call @mgpuMemHostRegisterFloat(%cast1) : (memref<*xf32>) -> ()
+  gpu.host_register %cast0 : memref<*xf32>
+  gpu.host_register %cast1 : memref<*xf32>
 
   %24 = call @mgpuMemGetDeviceMemRef1dFloat(%22) : (memref<?xf32>) -> (memref<?xf32>)
   %26 = call @mgpuMemGetDeviceMemRef1dFloat(%23) : (memref<?xf32>) -> (memref<?xf32>)
@@ -71,6 +71,5 @@
   return
 }
 
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref<?xf32>) -> (memref<?xf32>)
 func @print_memref_f32(%ptr : memref<*xf32>)
diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -75,17 +75,19 @@
   CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0));
 }
 
-// Allows to register a MemRef with the CUDA runtime. Initializes array with
-// value. Helpful until we have transfer functions implemented.
-template <typename T>
-void mgpuMemHostRegisterMemRef(const DynamicMemRefType<T> &memRef, T value) {
-  llvm::SmallVector<int64_t, 4> denseStrides(memRef.rank);
-  llvm::ArrayRef<int64_t> sizes(memRef.sizes, memRef.rank);
-  llvm::ArrayRef<int64_t> strides(memRef.strides, memRef.rank);
+// Registers a MemRef with the CUDA runtime. Helpful until we have
+// transfer functions implemented.
+extern "C" void
+mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
+                          int64_t elementSizeBytes) {
+
+  llvm::SmallVector<int64_t, 4> denseStrides(rank);
+  llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank);
+  llvm::ArrayRef<int64_t> strides(sizes.end(), rank);
   std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
                    std::multiplies<int64_t>());
-  auto count = denseStrides.front();
+  auto sizeBytes = denseStrides.front() * elementSizeBytes;
 
   // Only densely packed tensors are currently supported.
   std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
@@ -93,17 +95,6 @@
   denseStrides.back() = 1;
   assert(strides == llvm::makeArrayRef(denseStrides));
 
-  auto *pointer = memRef.data + memRef.offset;
-  std::fill_n(pointer, count, value);
-  mgpuMemHostRegister(pointer, count * sizeof(T));
-}
-
-extern "C" void mgpuMemHostRegisterFloat(int64_t rank, void *ptr) {
-  UnrankedMemRefType<float> memRef = {rank, ptr};
-  mgpuMemHostRegisterMemRef(DynamicMemRefType<float>(memRef), 1.23f);
-}
-
-extern "C" void mgpuMemHostRegisterInt32(int64_t rank, void *ptr) {
-  UnrankedMemRefType<int32_t> memRef = {rank, ptr};
-  mgpuMemHostRegisterMemRef(DynamicMemRefType<int32_t>(memRef), 123);
+  auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
+  mgpuMemHostRegister(ptr, sizeBytes);
 }
diff --git a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
--- a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
@@ -76,17 +76,19 @@
   HIP_REPORT_IF_ERROR(hipHostRegister(ptr, sizeBytes, /*flags=*/0));
 }
 
-// Allows to register a MemRef with the ROCM runtime. Initializes array with
-// value. Helpful until we have transfer functions implemented.
-template <typename T>
-void mgpuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
-                               llvm::ArrayRef<int64_t> strides, T value) {
-  assert(sizes.size() == strides.size());
-  llvm::SmallVector<int64_t, 4> denseStrides(strides.size());
+// Registers a MemRef with the ROCm runtime. Helpful until we have
+// transfer functions implemented.
+extern "C" void
+mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
+                          int64_t elementSizeBytes) {
+
+  llvm::SmallVector<int64_t, 4> denseStrides(rank);
+  llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank);
+  llvm::ArrayRef<int64_t> strides(sizes.end(), rank);
   std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
                    std::multiplies<int64_t>());
-  auto count = denseStrides.front();
+  auto sizeBytes = denseStrides.front() * elementSizeBytes;
 
   // Only densely packed tensors are currently supported.
   std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
@@ -94,22 +96,8 @@
   denseStrides.back() = 1;
   assert(strides == llvm::makeArrayRef(denseStrides));
 
-  std::fill_n(pointer, count, value);
-  mgpuMemHostRegister(pointer, count * sizeof(T));
-}
-
-extern "C" void mgpuMemHostRegisterFloat(int64_t rank, void *ptr) {
-  auto *desc = static_cast<StridedMemRefType<float, 1> *>(ptr);
-  auto sizes = llvm::ArrayRef<int64_t>(desc->sizes, rank);
-  auto strides = llvm::ArrayRef<int64_t>(desc->sizes + rank, rank);
-  mgpuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 1.23f);
-}
-
-extern "C" void mgpuMemHostRegisterInt32(int64_t rank, void *ptr) {
-  auto *desc = static_cast<StridedMemRefType<int32_t, 1> *>(ptr);
-  auto sizes = llvm::ArrayRef<int64_t>(desc->sizes, rank);
-  auto strides = llvm::ArrayRef<int64_t>(desc->sizes + rank, rank);
-  mgpuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 123);
+  auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
+  mgpuMemHostRegister(ptr, sizeBytes);
 }
 
 template <typename T>
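Editor's note (not part of the patch): the new runtime wrapper receives exactly the three arguments the lowering builds for `hostRegisterCallBuilder`, namely the rank, a type-erased pointer to the memref descriptor, and the element size in bytes that `getSizeInBytes` computes via the null-GEP trick. The sketch below is a minimal standalone illustration of the dense-stride and byte-size arithmetic that `mgpuMemHostRegisterMemRef` performs; it uses plain standard C++ containers instead of MLIR's `StridedMemRefType` and `llvm::ArrayRef`, and the function name `registeredSizeInBytes` and the 2x6 f32 shape are made up for the example.

```c++
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Mirrors the byte-size computation in mgpuMemHostRegisterMemRef: the element
// count is the product of the sizes, and for a densely packed memref the
// strides must equal the suffix products of the sizes.
int64_t registeredSizeInBytes(const std::vector<int64_t> &sizes,
                              const std::vector<int64_t> &strides,
                              int64_t elementSizeBytes) {
  std::vector<int64_t> denseStrides(sizes.size());
  // Running products of the sizes, computed back to front: {2, 6} -> {12, 6}.
  std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
                   std::multiplies<int64_t>());
  int64_t sizeBytes = denseStrides.front() * elementSizeBytes;
  // Turn the running products into strides, {12, 6} -> {6, 1}, and check that
  // the memref really is densely packed, as the wrapper asserts.
  std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
              denseStrides.end());
  denseStrides.back() = 1;
  assert(strides == denseStrides);
  return sizeBytes;
}

int main() {
  // Hypothetical 2x6 memref of f32: sizes = {2, 6}, strides = {6, 1}.
  // 2 * 6 elements * 4 bytes = 48 bytes would be handed to
  // cuMemHostRegister / hipHostRegister.
  std::cout << registeredSizeInBytes({2, 6}, {6, 1}, /*elementSizeBytes=*/4)
            << "\n";
  return 0;
}
```

Passing the element size as an explicit runtime argument is what lets this single `extern "C"` entry point replace the per-type `mgpuMemHostRegisterFloat` / `mgpuMemHostRegisterInt32` wrappers and drop their host-side array initialization.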