diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -903,6 +903,40 @@
   let hasFolder = 1;
 }
 
+def GPU_MemsetOp : GPU_Op<"memset", [GPU_AsyncOpInterface]> {
+
+  let summary = "GPU memset operation";
+
+  let description = [{
+    The `gpu.memset` operation sets the content of a memref to a scalar value.
+
+    The op does not execute before all async dependencies have finished
+    executing.
+
+    If the `async` keyword is present, the op is executed asynchronously (i.e.
+    it does not block until the execution has finished on the device). In
+    that case, it returns a !gpu.async.token.
+
+    Example:
+
+    ```mlir
+    %token = gpu.memset async [%dep] %dst, %value : memref<?xf32>, memref<1xf32>
+    ```
+  }];
+
+  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+                   Arg<AnyMemRef, "", [MemWrite]>:$dst,
+                   Arg<AnyMemRef, "", [MemRead]>:$value);
+  let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
+
+  let assemblyFormat = [{
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+    $dst`,` $value `:` type($dst)`,` type($value) attr-dict
+  }];
+  let verifier = [{ return ::verify(*this); }];
+  let hasFolder = 1;
+}
+
 def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
     [MemoryEffects<[MemRead]>]>{
 
@@ -1040,7 +1074,7 @@
     the same value. This op is meant to be used along with
     `gpu.subgroup_mma_compute`.
-    
+
     Example:
 
     ```mlir
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -75,8 +75,8 @@
 template <typename OpTy>
 class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
 public:
-  explicit ConvertOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
-      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}
+  explicit ConvertOpToGpuRuntimeCallPattern(LLVMTypeConverter &converter)
+      : ConvertOpToLLVMPattern<OpTy>(converter) {}
 
 protected:
   MLIRContext *context = &this->getTypeConverter()->getContext();
@@ -165,6 +165,12 @@
       {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
        llvmIntPtrType /* intptr_t sizeBytes */,
        llvmPointerType /* void *stream */}};
+  FunctionCallBuilder memsetCallBuilder = {
+      "mgpuMemset",
+      llvmVoidType,
+      {llvmPointerType /* void *dst */, llvmIntPtrType /* intptr_t value */,
+       llvmIntPtrType /* intptr_t sizeBytes */,
+       llvmPointerType /* void *stream */}};
 };
 
 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
@@ -172,8 +178,8 @@
 class ConvertHostRegisterOpToGpuRuntimeCallPattern
     : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
 public:
-  ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
-      : ConvertOpToGpuRuntimeCallPattern(typeConverter) {}
+  ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &converter)
+      : ConvertOpToGpuRuntimeCallPattern(converter) {}
 
 private:
   LogicalResult
@@ -186,8 +192,8 @@
 class ConvertAllocOpToGpuRuntimeCallPattern
     : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
 public:
-  ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
-      : ConvertOpToGpuRuntimeCallPattern(typeConverter) {}
+  ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter &converter)
+      : ConvertOpToGpuRuntimeCallPattern(converter) {}
 
 private:
   LogicalResult
@@ -200,8 +206,8 @@
 class ConvertDeallocOpToGpuRuntimeCallPattern
     : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
 public:
-  ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
-      : ConvertOpToGpuRuntimeCallPattern(typeConverter) {}
+  ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter &converter)
+      : ConvertOpToGpuRuntimeCallPattern(converter) {}
 
 private:
   LogicalResult
@@ -212,8 +218,8 @@
 class ConvertAsyncYieldToGpuRuntimeCallPattern
     : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
 public:
-  ConvertAsyncYieldToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
-      : ConvertOpToGpuRuntimeCallPattern(typeConverter) {}
+  ConvertAsyncYieldToGpuRuntimeCallPattern(LLVMTypeConverter &converter)
+      : ConvertOpToGpuRuntimeCallPattern(converter) {}
 
 private:
   LogicalResult
@@ -226,8 +232,8 @@
 class ConvertWaitOpToGpuRuntimeCallPattern
     : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
 public:
-  ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
-      : ConvertOpToGpuRuntimeCallPattern(typeConverter) {}
+  ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter &converter)
+      : ConvertOpToGpuRuntimeCallPattern(converter) {}
 
 private:
   LogicalResult
@@ -240,8 +246,8 @@
 class ConvertWaitAsyncOpToGpuRuntimeCallPattern
     : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
 public:
-  ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
-      : ConvertOpToGpuRuntimeCallPattern(typeConverter) {}
+  ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter &converter)
+      : ConvertOpToGpuRuntimeCallPattern(converter) {}
 
 private:
   LogicalResult
@@ -265,9 +271,9 @@
 class ConvertLaunchFuncOpToGpuRuntimeCallPattern
     : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
 public:
-  ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter,
+  ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &converter,
                                              StringRef gpuBinaryAnnotation)
-      : ConvertOpToGpuRuntimeCallPattern(typeConverter),
+      : ConvertOpToGpuRuntimeCallPattern(converter),
         gpuBinaryAnnotation(gpuBinaryAnnotation) {}
 
 private:
@@ -300,14 +306,28 @@
 class ConvertMemcpyOpToGpuRuntimeCallPattern
     : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
 public:
-  ConvertMemcpyOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
-      : ConvertOpToGpuRuntimeCallPattern(typeConverter) {}
+  ConvertMemcpyOpToGpuRuntimeCallPattern(LLVMTypeConverter &converter)
+      : ConvertOpToGpuRuntimeCallPattern(converter) {}
 
 private:
   LogicalResult
   matchAndRewrite(gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
+
+/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
+/// call. Currently it supports CUDA and ROCm (HIP).
+class ConvertMemsetOpToGpuRuntimeCallPattern
+    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
+public:
+  ConvertMemsetOpToGpuRuntimeCallPattern(LLVMTypeConverter &converter)
+      : ConvertOpToGpuRuntimeCallPattern(converter) {}
+
+private:
+  LogicalResult
+  matchAndRewrite(gpu::MemsetOp memsetOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
 } // namespace
 
 void GpuToLLVMConversionPass::runOnOperation() {
@@ -332,6 +352,7 @@
                ConvertDeallocOpToGpuRuntimeCallPattern,
                ConvertHostRegisterOpToGpuRuntimeCallPattern,
                ConvertMemcpyOpToGpuRuntimeCallPattern,
+               ConvertMemsetOpToGpuRuntimeCallPattern,
                ConvertWaitAsyncOpToGpuRuntimeCallPattern,
                ConvertWaitOpToGpuRuntimeCallPattern,
                ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
@@ -800,6 +821,53 @@
   return success();
 }
 
+LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
+    gpu::MemsetOp memsetOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  auto memRefType = memsetOp.dst().getType().cast<MemRefType>();
+
+  if (failed(areAllLLVMTypes(memsetOp, operands, rewriter)) ||
+      !isConvertibleAndHasIdentityMaps(memRefType) ||
+      failed(isAsyncWithOneDependency(rewriter, memsetOp)))
+    return failure();
+
+  auto loc = memsetOp.getLoc();
+  auto adaptor = gpu::MemsetOpAdaptor(operands, memsetOp->getAttrDictionary());
+
+  MemRefDescriptor dstDesc(adaptor.dst());
+
+  Value numElements =
+      memRefType.hasStaticShape()
+          ? createIndexConstant(rewriter, loc, memRefType.getNumElements())
+          // For identity layouts (verified above), the number of elements is
+          // stride[0] * size[0].
+          : rewriter.create<LLVM::MulOp>(loc, dstDesc.stride(rewriter, loc, 0),
+                                         dstDesc.size(rewriter, loc, 0));
+
+  Type elementPtrType = getElementPtrType(memRefType);
+  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
+  Value sizeBytesGepPtr = rewriter.create<LLVM::GEPOp>(
+      loc, elementPtrType, ArrayRef<Value>{nullPtr, numElements});
+  Value valueGepPtr = rewriter.create<LLVM::GEPOp>(
+      loc, elementPtrType,
+      ArrayRef<Value>{
+          MemRefDescriptor(adaptor.value()).alignedPtr(rewriter, loc),
+          createIndexConstant(rewriter, loc, 0)});
+
+  auto sizeBytes =
+      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), sizeBytesGepPtr);
+  auto value =
+      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), valueGepPtr);
+  auto dst = rewriter.create<LLVM::BitcastOp>(
+      loc, llvmPointerType, dstDesc.alignedPtr(rewriter, loc));
+
+  auto stream = adaptor.asyncDependencies().front();
+  memsetCallBuilder.create(loc, rewriter, {dst, value, sizeBytes, stream});
+
+  rewriter.replaceOp(memsetOp, {stream});
+  return success();
+}
+
 std::unique_ptr<OperationPass<ModuleOp>> mlir::createGpuToLLVMConversionPass() {
   return std::make_unique<GpuToLLVMConversionPass>();
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -955,6 +955,20 @@
   return success();
 }
 
+static LogicalResult verify(MemsetOp op) {
+  auto dstType = op.dst().getType();
+  auto valueType = op.value().getType();
+
+  if (getElementTypeOrSelf(valueType) != getElementTypeOrSelf(dstType))
+    return op.emitOpError("arguments have incompatible element type");
+
+  ArrayRef<int64_t> valueShape = valueType.dyn_cast<MemRefType>().getShape();
+  if (valueShape.size() != 1 || valueShape.front() != 1)
+    return op.emitOpError("value argument must be a scalar");
+
+  return success();
+}
+
 static ParseResult parseAsyncDependencies(
     OpAsmParser &parser, Type &asyncTokenType,
     SmallVectorImpl<OpAsmParser::OperandType> &asyncDependencies) {
@@ -1089,6 +1103,11 @@
   return foldMemRefCast(*this);
 }
 
+LogicalResult MemsetOp::fold(ArrayRef<Attribute> operands,
+                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
+  return foldMemRefCast(*this);
+}
+
 #include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -19,6 +19,7 @@
 #include "llvm/ADT/ArrayRef.h"
 
 #include "cuda.h"
+#include "cuda_runtime_api.h"
 
 #ifdef _WIN32
 #define MLIR_CUDA_WRAPPERS_EXPORT __declspec(dllexport)
@@ -37,6 +38,16 @@
     fprintf(stderr, "'%s' failed with '%s'\n", #expr, name);                  \
   }(expr)
 
+#define CUDA_RUNTIME_REPORT_IF_ERROR(expr)                                    \
+  [](cudaError_t result) {                                                    \
+    if (result == cudaSuccess)                                                \
+      return;                                                                 \
+    const char *name = cudaGetErrorName(result);                              \
+    if (!name)                                                                \
+      name = "<unknown>";                                                     \
+    fprintf(stderr, "'%s' failed with '%s'\n", #expr, name);                  \
+  }(expr)
+
 // Make the primary context of device 0 current for the duration of the instance
 // and restore the previous context on destruction.
 class ScopedContext {
@@ -150,6 +161,11 @@
                                      sizeBytes, stream));
 }
 
+extern "C" void mgpuMemset(void *dst, int value, uint64_t sizeBytes,
+                           CUstream stream) {
+  CUDA_RUNTIME_REPORT_IF_ERROR(cudaMemsetAsync(dst, value, sizeBytes, stream));
+}
+
 /// Helper functions for writing mlir example code
 
 // Allows to register byte array with the CUDA runtime. Helpful until we have
diff --git a/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp
--- a/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp
@@ -139,6 +139,10 @@
       hipMemcpyAsync(dst, src, sizeBytes, hipMemcpyDefault, stream));
 }
 
+extern "C" void mgpuMemset(void *dst, int value, uint64_t sizeBytes,
+                           hipStream_t stream) {
+  HIP_REPORT_IF_ERROR(hipMemsetAsync(dst, value, sizeBytes, stream));
+}
 /// Helper functions for writing mlir example code
 
 // Allows to register byte array with the ROCM runtime. Helpful until we have
diff --git a/mlir/test/Conversion/GPUCommon/lower-memset-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-memset-to-gpu-runtime-calls.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/GPUCommon/lower-memset-to-gpu-runtime-calls.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
+
+module attributes {gpu.container_module} {
+
+  // CHECK: func @foo
+  func @foo(%dst : memref<7xf32, 1>, %value : memref<1xf32>) {
+    // CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
+    %t0 = gpu.wait async
+    // CHECK: %[[size_bytes:.*]] = llvm.ptrtoint
+    // CHECK: %[[value:.*]] = llvm.ptrtoint
+    // CHECK: %[[dst:.*]] = llvm.bitcast
+    // CHECK: llvm.call @mgpuMemset(%[[dst]], %[[value]], %[[size_bytes]], %[[t0]])
+    %t1 = gpu.memset async [%t0] %dst, %value : memref<7xf32, 1>, memref<1xf32>
+    // CHECK: llvm.call @mgpuStreamSynchronize(%[[t0]])
+    // CHECK: llvm.call @mgpuStreamDestroy(%[[t0]])
+    gpu.wait [%t1]
+    return
+  }
+}
diff --git a/mlir/test/Dialect/GPU/canonicalize.mlir b/mlir/test/Dialect/GPU/canonicalize.mlir
--- a/mlir/test/Dialect/GPU/canonicalize.mlir
+++ b/mlir/test/Dialect/GPU/canonicalize.mlir
@@ -9,3 +9,13 @@
   gpu.memcpy %0,%1 : memref<?xf32>, memref<?xf32>
   return
 }
+
+// CHECK-LABEL: @memset_after_cast
+func @memset_after_cast(%arg0: memref<10xf32>, %arg1: memref<1xf32>) {
+  // CHECK-NOT: memref.cast
+  // CHECK: gpu.memset
+  %0 = memref.cast %arg0 : memref<10xf32> to memref<?xf32>
+  %1 = memref.cast %arg1 : memref<1xf32> to memref<1xf32>
+  gpu.memset %0,%1 : memref<?xf32>, memref<1xf32>
+  return
+}
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -467,6 +467,20 @@
 
 // -----
 
+func @memset_incompatible_type(%dst : memref<?xf32>, %value : memref<1xi32>) {
+  // expected-error @+1 {{'gpu.memset' op arguments have incompatible element type}}
+  gpu.memset %dst, %value : memref<?xf32>, memref<1xi32>
+}
+
+// -----
+
+func @memset_incompatible_shape(%dst : memref<?xf32>, %value : memref<9xf32>) {
+  // expected-error @+1 {{'gpu.memset' op value argument must be a scalar}}
+  gpu.memset %dst, %value : memref<?xf32>, memref<9xf32>
+}
+
+// -----
+
 func @mmamatrix_invalid_shape(){
   %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
   %i = constant 16 : index
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -195,6 +195,17 @@
     return
   }
 
+  func @memset(%dst : memref<3x7xf32>, %value : memref<1xf32, 1>) {
+    // CHECK-LABEL: func @memset
+    // CHECK: gpu.memset {{.*}}, {{.*}} : memref<3x7xf32>, memref<1xf32, 1>
+    gpu.memset %dst, %value : memref<3x7xf32>, memref<1xf32, 1>
+    // CHECK: %[[t0:.*]] = gpu.wait async
+    %0 = gpu.wait async
+    // CHECK: {{.*}} = gpu.memset async [%[[t0]]] {{.*}}, {{.*}} : memref<3x7xf32>, memref<1xf32, 1>
+    %1 = gpu.memset async [%0] %dst, %value : memref<3x7xf32>, memref<1xf32, 1>
+    return
+  }
+
   func @mmamatrix_valid_element_type(){
     // CHECK-LABEL: func @mmamatrix_valid_element_type
     %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>