diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -97,6 +97,7 @@
   Type llvmPointerPointerType =
       this->getTypeConverter()->getPointerType(llvmPointerType);
   Type llvmInt8Type = IntegerType::get(context, 8);
+  Type llvmInt16Type = IntegerType::get(context, 16);
   Type llvmInt32Type = IntegerType::get(context, 32);
   Type llvmInt64Type = IntegerType::get(context, 64);
   Type llvmIntPtrType = IntegerType::get(
@@ -182,7 +183,14 @@
       {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
        llvmIntPtrType /* intptr_t sizeBytes */,
        llvmPointerType /* void *stream */}};
-  FunctionCallBuilder memsetCallBuilder = {
+  FunctionCallBuilder memset16CallBuilder = {
+      "mgpuMemset16",
+      llvmVoidType,
+      {llvmPointerType /* void *dst */,
+       llvmInt16Type /* unsigned short value */,
+       llvmIntPtrType /* intptr_t sizeBytes */,
+       llvmPointerType /* void *stream */}};
+  FunctionCallBuilder memset32CallBuilder = {
       "mgpuMemset32",
       llvmVoidType,
       {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
@@ -1197,22 +1205,32 @@
   auto loc = memsetOp.getLoc();
 
   Type valueType = adaptor.getValue().getType();
-  if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) {
-    return rewriter.notifyMatchFailure(memsetOp,
-                                       "value must be a 32 bit scalar");
+  // Ints and floats of 16 or 32 bit width are allowed.
+  if (!valueType.isIntOrFloat() || (valueType.getIntOrFloatBitWidth() != 16 &&
+                                    valueType.getIntOrFloatBitWidth() != 32)) {
+    return rewriter.notifyMatchFailure(
+        memsetOp, "value must be a 16 or 32 bit int or float");
   }
 
+  unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
+  Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
+
   MemRefDescriptor dstDesc(adaptor.getDst());
   Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
 
   auto value =
-      rewriter.create<LLVM::BitcastOp>(loc, llvmInt32Type, adaptor.getValue());
+      rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
   auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
                                  dstDesc.alignedPtr(rewriter, loc),
                                  *getTypeConverter());
 
   auto stream = adaptor.getAsyncDependencies().front();
-  memsetCallBuilder.create(loc, rewriter, {dst, value, numElements, stream});
+  if (valueTypeWidth == 32)
+    memset32CallBuilder.create(loc, rewriter,
+                               {dst, value, numElements, stream});
+  else
+    memset16CallBuilder.create(loc, rewriter,
+                               {dst, value, numElements, stream});
 
   rewriter.replaceOp(memsetOp, {stream});
   return success();
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -168,6 +168,12 @@
                                         value, count, stream));
 }
 
+extern "C" void mgpuMemset16(void *dst, unsigned short value, size_t count,
+                             CUstream stream) {
+  CUDA_REPORT_IF_ERROR(cuMemsetD16Async(reinterpret_cast<CUdeviceptr>(dst),
+                                        value, count, stream));
+}
+
 ///
 /// Helper functions for writing mlir example code
 ///
diff --git a/mlir/test/Conversion/GPUCommon/typed-pointers.mlir b/mlir/test/Conversion/GPUCommon/typed-pointers.mlir
--- a/mlir/test/Conversion/GPUCommon/typed-pointers.mlir
+++ b/mlir/test/Conversion/GPUCommon/typed-pointers.mlir
@@ -42,8 +42,8 @@
 
 module attributes {gpu.container_module} {
 
-  // CHECK: func @foo
-  func.func @foo(%dst : memref<7xf32, 1>, %value : f32) {
+  // CHECK: func @memset_f32
+  func.func @memset_f32(%dst : memref<7xf32, 1>, %value : f32) {
     // CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
     %t0 = gpu.wait async
     // CHECK: %[[size_bytes:.*]] = llvm.mlir.constant
@@ -59,3 +59,23 @@
   }
 }
 
+// -----
+
+module attributes {gpu.container_module} {
+
+  // CHECK: func @memset_f16
+  func.func @memset_f16(%dst : memref<7xf16, 1>, %value : f16) {
+    // CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
+    %t0 = gpu.wait async
+    // CHECK: %[[size_bytes:.*]] = llvm.mlir.constant
+    // CHECK: %[[value:.*]] = llvm.bitcast
+    // CHECK: %[[addr_cast:.*]] = llvm.addrspacecast
+    // CHECK: %[[dst:.*]] = llvm.bitcast %[[addr_cast]]
+    // CHECK: llvm.call @mgpuMemset16(%[[dst]], %[[value]], %[[size_bytes]], %[[t0]])
+    %t1 = gpu.memset async [%t0] %dst, %value : memref<7xf16, 1>, f16
+    // CHECK: llvm.call @mgpuStreamSynchronize(%[[t0]])
+    // CHECK: llvm.call @mgpuStreamDestroy(%[[t0]])
+    gpu.wait [%t1]
+    return
+  }
+}
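
For context, a minimal sketch of the kind of IR the extended pattern now accepts. This is not part of the patch; the function name, shape, and memory space are illustrative only:

  module attributes {gpu.container_module} {
    // Hypothetical rank-2 f16 fill. After GPU-to-LLVM conversion, the pattern
    // bitcasts %value to i16 and emits a single call to mgpuMemset16 whose
    // third argument is the total element count (4 x 8 = 32), not a byte size.
    func.func @fill_2d_f16(%dst : memref<4x8xf16, 1>, %value : f16) {
      %t0 = gpu.wait async
      %t1 = gpu.memset async [%t0] %dst, %value : memref<4x8xf16, 1>, f16
      gpu.wait [%t1]
      return
    }
  }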