diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1670,11 +1670,12 @@
     Example:

     ```mlir
-    %dmat, %token = gpu.create_dn_mat async [%dep] %mem, %size : memref
+    %dmat, %token = gpu.create_dn_mat async [%dep] %handle, %rows, %cols, %mem : memref
     ```
   }];

   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+                       GPU_SparseEnvHandle:$env,
                        Index:$rows,
                        Index:$cols,
                        AnyMemRef:$memref);
@@ -1682,7 +1683,7 @@
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $rows `,` $cols `,` $memref attr-dict `:` type($memref)
+    $env `,` $rows `,` $cols `,` $memref attr-dict `:` type($memref)
   }];
 }

@@ -1789,6 +1790,41 @@
   }];
 }

+
+def GPU_Create2To4SpMatOp : GPU_Op<"create_2to4_spmat", [GPU_AsyncOpInterface]> {
+  let summary = "Create sparse matrix with 2:4 sparsity operation";
+  let description = [{
+    The `gpu.create_2to4_spmat` operation initializes a sparse matrix in dense
+    format with 2:4 sparsity.
+    The buffers must already be copied from the host to the device prior to
+    using this operation. The operation returns a handle to the sparse
+    matrix descriptor.
+
+    If the `async` keyword is present, the op is executed asynchronously (i.e.
+    it does not block until the execution has finished on the device). In
+    that case, it returns a !gpu.async.token in addition to the sparse matrix
+    descriptor.
+
+    Example:
+
+    ```mlir
+    %spmat, %token = gpu.create_2to4_spmat async [%dep] %env, %rows, %cols, %mem : memref
+    ```
+  }];
+
+  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+                       GPU_SparseEnvHandle:$env,
+                       Index:$rows,
+                       Index:$cols,
+                       AnyMemRef:$memref);
+  let results = (outs Res<GPU_SparseSpMatHandle>:$spMat,
+                      Optional<GPU_AsyncToken>:$asyncToken);
+
+  let assemblyFormat = [{
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+    $env `,` $rows `,` $cols `,` $memref attr-dict `:` type($memref)
+  }];
+}
+
 def GPU_DestroySpMatOp : GPU_Op<"destroy_sp_mat", [GPU_AsyncOpInterface]> {
   let summary = "Destroy sparse matrix operation";
   let description = [{
@@ -1944,7 +1980,7 @@
   }];
 }

-def GPU_SpMMBufferSizeOp : GPU_Op<"spmm_buffer_size", [GPU_AsyncOpInterface]> {
+def GPU_SpMMBufferSizeOp : GPU_Op<"spmm_buffer_size", [GPU_AsyncOpInterface, AttrSizedResultSegments]> {
   let summary = "Precompute buffersize for SpMM operation";
   let description = [{
     The `gpu.spmm_buffer_size` operation returns the buffer size required
@@ -1963,7 +1999,7 @@
     Example:

     ```mlir
-    %buffersz, %token = gpu.spmm_buffer_size async [%dep] %env, %spmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %dnmatC into f32
+    %bufferszs, %token = gpu.spmm_buffer_size async [%dep] %env, %spmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %dnmatC : i64 into f32
     ```
   }];
@@ -1975,11 +2011,11 @@
                        GPU_SparseDnMatHandle:$dnmatB,
                        GPU_SparseDnMatHandle:$dnmatC,
                        OptionalAttr<TypeAttr>:$computeType);
-  let results = (outs Res<Index>:$bufferSz,
+  let results = (outs Variadic<Index>:$bufferSzs,
                       Optional<GPU_AsyncToken>:$asyncToken);
   let builders = [OpBuilder<(ins
-      "Type":$bufferSz,
+      "ValueRange":$bufferSzs,
       "Type":$asyncToken,
       "ValueRange":$asyncDependencies,
       "Value":$env,
@@ -1994,11 +2030,11 @@
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC attr-dict ( `into` $computeType^)?
+    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC attr-dict `:` type($bufferSzs) ( `into` $computeType^)?
}]; } -def GPU_SpMMOp : GPU_Op<"spmm", [GPU_AsyncOpInterface]> { +def GPU_SpMMOp : GPU_Op<"spmm", [GPU_AsyncOpInterface, AttrSizedOperandSegments]> { let summary = "SpMM operation"; let description = [{ The `gpu.spmm` operation performs the SpMM operation on the given sparse and @@ -2029,7 +2065,7 @@ GPU_SparseDnMatHandle:$dnmatB, GPU_SparseDnMatHandle:$dnmatC, OptionalAttr:$computeType, - AnyMemRef:$buffer); + Variadic:$buffers); let results = (outs Optional:$asyncToken); let builders = [OpBuilder<(ins @@ -2039,16 +2075,16 @@ "Value":$spmatA, "Value":$dnmatB, "Value":$dnmatC, - "Value":$buffer), [{ + "ValueRange":$buffers), [{ auto modeA = gpu::TransposeMode::NON_TRANSPOSE; auto modeB = gpu::TransposeMode::NON_TRANSPOSE; return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA, - modeB, spmatA, dnmatB, dnmatC, {}, buffer);}]> + modeB, spmatA, dnmatB, dnmatC, {}, buffers);}]> ]; let assemblyFormat = [{ custom(type($asyncToken), $asyncDependencies) - $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC `,` $buffer attr-dict `:` type($buffer) ( `into` $computeType^)? + $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC `,` $buffers attr-dict `:` type($buffers) ( `into` $computeType^)? }]; } diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -230,6 +230,42 @@ {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type, llvmInt32Type, llvmPointerType /* void *stream */}}; + FunctionCallBuilder createSparseLtEnvCallBuilder = { + "mgpuCreateSparseLtEnv", + llvmPointerType, + {llvmPointerType /* void *stream */}}; + FunctionCallBuilder destroySparseLtEnvCallBuilder = { + "mgpuDestroySparseLtEnv", + llvmVoidType, + {llvmPointerType, llvmPointerType /* void *stream */}}; + FunctionCallBuilder createLtDnMatCallBuilder = { + "mgpuCreateLtDnMat", + llvmPointerType, + {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, + llvmInt32Type, llvmPointerType /* void *stream */}}; + FunctionCallBuilder destroyCuSparseLtSpMatBuilder = { + "mgpuDestroyCuSparseLtSpMat", + llvmVoidType, + {llvmPointerType, llvmPointerType /* void *stream */}}; + FunctionCallBuilder destroyCuSparseLtDnMatBuilder = { + "mgpuDestroyCuSparseLtDnMat", + llvmVoidType, + {llvmPointerType, llvmPointerType /* void *stream */}}; + FunctionCallBuilder create2To4SpMatCallBuilder = { + "mgpuCusparseLtCreate2To4SpMat", + llvmPointerType, + {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, + llvmInt32Type, llvmPointerType /* void *stream */}}; + FunctionCallBuilder cuSparseLtSpmmBufferSizeBuilder = { + "mgpuCuSparseLtSpmmBufferSize", + llvmPointerType, + {llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}}; + FunctionCallBuilder cuSparseLtSpmmBuilder = { + "mgpuCuSparseLtSpmm", + llvmVoidType, + {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType, + llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType, + llvmPointerType /*void *stream*/}}; FunctionCallBuilder destroySpMatCallBuilder = { "mgpuDestroySpMat", llvmVoidType, @@ -559,6 +595,20 @@ ConversionPatternRewriter &rewriter) const override; }; +class ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern + : public ConvertOpToGpuRuntimeCallPattern { +public: + 
ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern(
+      LLVMTypeConverter &typeConverter)
+      : ConvertOpToGpuRuntimeCallPattern<gpu::Create2To4SpMatOp>(
+            typeConverter) {}
+
+private:
+  LogicalResult
+  matchAndRewrite(gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 class ConvertDestroySpMatOpToGpuRuntimeCallPattern
     : public ConvertOpToGpuRuntimeCallPattern<gpu::DestroySpMatOp> {
 public:
@@ -733,955 +783,1098 @@
     return 10; // CUDA_R_32I
   llvm_unreachable("unsupported element type");
-}
-
-// Returns whether all operands are of LLVM type.
-static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
-                                     ConversionPatternRewriter &rewriter) {
-  if (!llvm::all_of(operands, [](Value value) {
-        return LLVM::isCompatibleType(value.getType());
-      }))
-    return rewriter.notifyMatchFailure(
-        op, "Cannot convert if operands aren't of LLVM type.");
-  return success();
-}
+  // TODO: We may want a warning or disablement at MLIR compile time:
+  // cusparseLt currently does not work for CUDA architectures below 8.0 and
+  // will trigger an error at run time of the CUDA program. It would be good
+  // to at least emit a warning when the target architecture is below 8.0 and
+  // the user still wants to use cusparseLt, and to disable the cusparseLt
+  // calls for such architectures when lowering the GPU sparse dialect to
+  // LLVM calls.
+  static bool is2To4Sparsity(Value spMat) {
+    if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
+      return true;
+    if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
+      return false;
+    if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
+      return false;
+    // Print the defining op of spMat before failing.
+    spMat.getDefiningOp()->print(llvm::errs());
+    llvm_unreachable("cannot find spmat def");
+  }
-static LogicalResult
-isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
-                         gpu::AsyncOpInterface op) {
-  if (op.getAsyncDependencies().size() != 1)
-    return rewriter.notifyMatchFailure(
-        op, "Can only convert with exactly one async dependency.");
+  static std::string inferSpMMType(Value op) {
+    for (Operation *user : op.getUsers()) {
+      auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
+      // If the sparse operand has 2:4 (50%) sparsity, we should use cusparseLt.
+      if (!spmmOp)
+        continue;
+      if (is2To4Sparsity(spmmOp.getSpmatA()))
+        return "cusparseLt";
+    }
+    return "cusparse";
+  }
-  if (!op.getAsyncToken())
-    return rewriter.notifyMatchFailure(op, "Can convert only async version.");
+  // Returns whether all operands are of LLVM type.
+ static LogicalResult areAllLLVMTypes(Operation * op, ValueRange operands, + ConversionPatternRewriter & rewriter) { + if (!llvm::all_of(operands, [](Value value) { + return LLVM::isCompatibleType(value.getType()); + })) + return rewriter.notifyMatchFailure( + op, "Cannot convert if operands aren't of LLVM type."); + return success(); + } - return success(); -} + static LogicalResult isAsyncWithOneDependency( + ConversionPatternRewriter & rewriter, gpu::AsyncOpInterface op) { + if (op.getAsyncDependencies().size() != 1) + return rewriter.notifyMatchFailure( + op, "Can only convert with exactly one async dependency."); -LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto *op = hostRegisterOp.getOperation(); - if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) - return failure(); + if (!op.getAsyncToken()) + return rewriter.notifyMatchFailure(op, "Can convert only async version."); - Location loc = op->getLoc(); + return success(); + } - auto memRefType = hostRegisterOp.getValue().getType(); - auto elementType = cast(memRefType).getElementType(); - auto elementSize = getSizeInBytes(loc, elementType, rewriter); + LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + auto *op = hostRegisterOp.getOperation(); + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) + return failure(); - auto arguments = getTypeConverter()->promoteOperands( - loc, op->getOperands(), adaptor.getOperands(), rewriter); - arguments.push_back(elementSize); - hostRegisterCallBuilder.create(loc, rewriter, arguments); + Location loc = op->getLoc(); - rewriter.eraseOp(op); - return success(); -} + auto memRefType = hostRegisterOp.getValue().getType(); + auto elementType = cast(memRefType).getElementType(); + auto elementSize = getSizeInBytes(loc, elementType, rewriter); -LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Operation *op = hostUnregisterOp.getOperation(); - if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) - return failure(); + auto arguments = getTypeConverter()->promoteOperands( + loc, op->getOperands(), adaptor.getOperands(), rewriter); + arguments.push_back(elementSize); + hostRegisterCallBuilder.create(loc, rewriter, arguments); - Location loc = op->getLoc(); + rewriter.eraseOp(op); + return success(); + } - auto memRefType = hostUnregisterOp.getValue().getType(); - auto elementType = cast(memRefType).getElementType(); - auto elementSize = getSizeInBytes(loc, elementType, rewriter); + LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + Operation *op = hostUnregisterOp.getOperation(); + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) + return failure(); - auto arguments = getTypeConverter()->promoteOperands( - loc, op->getOperands(), adaptor.getOperands(), rewriter); - arguments.push_back(elementSize); - hostUnregisterCallBuilder.create(loc, rewriter, arguments); + Location loc = op->getLoc(); - rewriter.eraseOp(op); - return success(); -} + auto memRefType = hostUnregisterOp.getValue().getType(); + auto 
elementType = cast(memRefType).getElementType(); + auto elementSize = getSizeInBytes(loc, elementType, rewriter); -LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::AllocOp allocOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (adaptor.getHostShared()) - return rewriter.notifyMatchFailure( - allocOp, "host_shared allocation is not supported"); - - MemRefType memRefType = allocOp.getType(); - - if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) || - !isConvertibleAndHasIdentityMaps(memRefType) || - failed(isAsyncWithOneDependency(rewriter, allocOp))) - return failure(); - - auto loc = allocOp.getLoc(); - - // Get shape of the memref as values: static sizes are constant - // values and dynamic sizes are passed to 'alloc' as operands. - SmallVector shape; - SmallVector strides; - Value sizeBytes; - getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter, - shape, strides, sizeBytes); - - // Allocate the underlying buffer and store a pointer to it in the MemRef - // descriptor. - Type elementPtrType = this->getElementPtrType(memRefType); - auto stream = adaptor.getAsyncDependencies().front(); - Value allocatedPtr = - allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(); - if (!getTypeConverter()->useOpaquePointers()) - allocatedPtr = - rewriter.create(loc, elementPtrType, allocatedPtr); - - // No alignment. - Value alignedPtr = allocatedPtr; - - // Create the MemRef descriptor. - auto memRefDescriptor = this->createMemRefDescriptor( - loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter); - - rewriter.replaceOp(allocOp, {memRefDescriptor, stream}); - - return success(); -} + auto arguments = getTypeConverter()->promoteOperands( + loc, op->getOperands(), adaptor.getOperands(), rewriter); + arguments.push_back(elementSize); + hostUnregisterCallBuilder.create(loc, rewriter, arguments); -LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::DeallocOp deallocOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) || - failed(isAsyncWithOneDependency(rewriter, deallocOp))) - return failure(); + rewriter.eraseOp(op); + return success(); + } - Location loc = deallocOp.getLoc(); + LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::AllocOp allocOp, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + if (adaptor.getHostShared()) + return rewriter.notifyMatchFailure( + allocOp, "host_shared allocation is not supported"); + + MemRefType memRefType = allocOp.getType(); + + if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) || + !isConvertibleAndHasIdentityMaps(memRefType) || + failed(isAsyncWithOneDependency(rewriter, allocOp))) + return failure(); + + auto loc = allocOp.getLoc(); + + // Get shape of the memref as values: static sizes are constant + // values and dynamic sizes are passed to 'alloc' as operands. + SmallVector shape; + SmallVector strides; + Value sizeBytes; + getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), + rewriter, shape, strides, sizeBytes); + + // Allocate the underlying buffer and store a pointer to it in the MemRef + // descriptor. 
+ Type elementPtrType = this->getElementPtrType(memRefType); + auto stream = adaptor.getAsyncDependencies().front(); + Value allocatedPtr = + allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(); + if (!getTypeConverter()->useOpaquePointers()) + allocatedPtr = + rewriter.create(loc, elementPtrType, allocatedPtr); - Value pointer = - MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc); - if (!getTypeConverter()->useOpaquePointers()) - pointer = rewriter.create(loc, llvmPointerType, pointer); - Value stream = adaptor.getAsyncDependencies().front(); - deallocCallBuilder.create(loc, rewriter, {pointer, stream}); + // No alignment. + Value alignedPtr = allocatedPtr; - rewriter.replaceOp(deallocOp, {stream}); - return success(); -} + // Create the MemRef descriptor. + auto memRefDescriptor = this->createMemRefDescriptor( + loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter); -static bool isGpuAsyncTokenType(Value value) { - return isa(value.getType()); -} + rewriter.replaceOp(allocOp, {memRefDescriptor, stream}); -// Converts !gpu.async.token operands of `async.yield` to runtime calls. The -// !gpu.async.token are lowered to stream within the async.execute region, but -// are passed as events between them. For each !gpu.async.token operand, we -// create an event and record it on the stream. -LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite( - async::YieldOp yieldOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType)) - return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand"); - - Location loc = yieldOp.getLoc(); - SmallVector newOperands(adaptor.getOperands()); - llvm::SmallDenseSet streams; - for (auto &operand : yieldOp->getOpOperands()) { - if (!isGpuAsyncTokenType(operand.get())) - continue; - auto idx = operand.getOperandNumber(); - auto stream = adaptor.getOperands()[idx]; - auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(); - eventRecordCallBuilder.create(loc, rewriter, {event, stream}); - newOperands[idx] = event; - streams.insert(stream); + return success(); } - for (auto stream : streams) - streamDestroyCallBuilder.create(loc, rewriter, {stream}); - rewriter.updateRootInPlace(yieldOp, - [&] { yieldOp->setOperands(newOperands); }); - return success(); -} + LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::DeallocOp deallocOp, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) || + failed(isAsyncWithOneDependency(rewriter, deallocOp))) + return failure(); -// Returns whether `value` is the result of an LLVM::CallOp to `functionName`. -static bool isDefinedByCallTo(Value value, StringRef functionName) { - assert(isa(value.getType())); - if (auto defOp = value.getDefiningOp()) - return defOp.getCallee()->equals(functionName); - return false; -} + Location loc = deallocOp.getLoc(); -// Converts `gpu.wait` to runtime calls. The converted op synchronizes the host -// with the stream/event operands. The operands are destroyed. That is, it -// assumes that it is not used afterwards or elsewhere. Otherwise we will get a -// runtime error. Eventually, we should guarantee this property. 
-LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::WaitOp waitOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (waitOp.getAsyncToken()) - return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op."); - - Location loc = waitOp.getLoc(); - - for (auto operand : adaptor.getOperands()) { - if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) { - // The converted operand's definition created a stream. - streamSynchronizeCallBuilder.create(loc, rewriter, {operand}); - streamDestroyCallBuilder.create(loc, rewriter, {operand}); - } else { - // Otherwise the converted operand is an event. This assumes that we use - // events in control flow code as well. - eventSynchronizeCallBuilder.create(loc, rewriter, {operand}); - eventDestroyCallBuilder.create(loc, rewriter, {operand}); - } + Value pointer = + MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc); + if (!getTypeConverter()->useOpaquePointers()) + pointer = rewriter.create(loc, llvmPointerType, pointer); + Value stream = adaptor.getAsyncDependencies().front(); + deallocCallBuilder.create(loc, rewriter, {pointer, stream}); + + rewriter.replaceOp(deallocOp, {stream}); + return success(); } - rewriter.eraseOp(waitOp); - return success(); -} + static bool isGpuAsyncTokenType(Value value) { + return isa(value.getType()); + } -// Converts `gpu.wait async` to runtime calls. The converted op creates a new -// stream that is synchronized with stream/event operands. The operands are -// destroyed. That is, it assumes that it is not used afterwards or elsewhere. -// Otherwise we will get a runtime error. Eventually, we should guarantee this -// property. -LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::WaitOp waitOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (!waitOp.getAsyncToken()) - return rewriter.notifyMatchFailure(waitOp, "Can only convert async op."); - - Location loc = waitOp.getLoc(); - - auto insertionPoint = rewriter.saveInsertionPoint(); - SmallVector events; - for (auto pair : - llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) { - auto operand = std::get<1>(pair); - if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) { - // The converted operand's definition created a stream. Insert an event - // into the stream just after the last use of the original token operand. - auto *defOp = std::get<0>(pair).getDefiningOp(); - rewriter.setInsertionPointAfter(defOp); + // Converts !gpu.async.token operands of `async.yield` to runtime calls. The + // !gpu.async.token are lowered to stream within the async.execute region, but + // are passed as events between them. For each !gpu.async.token operand, we + // create an event and record it on the stream. 
+ LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite( + async::YieldOp yieldOp, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType)) + return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand"); + + Location loc = yieldOp.getLoc(); + SmallVector newOperands(adaptor.getOperands()); + llvm::SmallDenseSet streams; + for (auto &operand : yieldOp->getOpOperands()) { + if (!isGpuAsyncTokenType(operand.get())) + continue; + auto idx = operand.getOperandNumber(); + auto stream = adaptor.getOperands()[idx]; auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(); - eventRecordCallBuilder.create(loc, rewriter, {event, operand}); - events.push_back(event); - } else { - // Otherwise the converted operand is an event. This assumes that we use - // events in control flow code as well. - events.push_back(operand); + eventRecordCallBuilder.create(loc, rewriter, {event, stream}); + newOperands[idx] = event; + streams.insert(stream); } - } - rewriter.restoreInsertionPoint(insertionPoint); - auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(); - for (auto event : events) - streamWaitEventCallBuilder.create(loc, rewriter, {stream, event}); - for (auto event : events) - eventDestroyCallBuilder.create(loc, rewriter, {event}); - rewriter.replaceOp(waitOp, {stream}); - - return success(); -} + for (auto stream : streams) + streamDestroyCallBuilder.create(loc, rewriter, {stream}); -// Creates a struct containing all kernel parameters on the stack and returns -// an array of type-erased pointers to the fields of the struct. The array can -// then be passed to the CUDA / ROCm (HIP) kernel launch calls. -// The generated code is essentially as follows: -// -// %struct = alloca(sizeof(struct { Parameters... 
})) -// %array = alloca(NumParameters * sizeof(void *)) -// for (i : [0, NumParameters)) -// %fieldPtr = llvm.getelementptr %struct[0, i] -// llvm.store parameters[i], %fieldPtr -// %elementPtr = llvm.getelementptr %array[i] -// llvm.store %fieldPtr, %elementPtr -// return %array -Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray( - gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const { - auto loc = launchOp.getLoc(); - auto numKernelOperands = launchOp.getNumKernelOperands(); - SmallVector arguments; - if (kernelBarePtrCallConv) { - // Hack the bare pointer value on just for the argument promotion - LLVMTypeConverter *converter = getTypeConverter(); - LowerToLLVMOptions options = converter->getOptions(); - LowerToLLVMOptions overrideToMatchKernelOpts = options; - overrideToMatchKernelOpts.useBarePtrCallConv = true; - converter->dangerousSetOptions(overrideToMatchKernelOpts); - arguments = converter->promoteOperands( - loc, launchOp.getOperands().take_back(numKernelOperands), - adaptor.getOperands().take_back(numKernelOperands), builder); - converter->dangerousSetOptions(options); - } else { - arguments = getTypeConverter()->promoteOperands( - loc, launchOp.getOperands().take_back(numKernelOperands), - adaptor.getOperands().take_back(numKernelOperands), builder); + rewriter.updateRootInPlace(yieldOp, + [&] { yieldOp->setOperands(newOperands); }); + return success(); } - auto numArguments = arguments.size(); - SmallVector argumentTypes; - argumentTypes.reserve(numArguments); - for (auto argument : arguments) - argumentTypes.push_back(argument.getType()); - auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(), - argumentTypes); - auto one = builder.create(loc, llvmInt32Type, 1); - auto structPtr = builder.create( - loc, getTypeConverter()->getPointerType(structType), structType, one, - /*alignment=*/0); - auto arraySize = - builder.create(loc, llvmInt32Type, numArguments); - auto arrayPtr = builder.create( - loc, llvmPointerPointerType, llvmPointerType, arraySize, /*alignment=*/0); - for (const auto &en : llvm::enumerate(arguments)) { - Value fieldPtr = builder.create( - loc, getTypeConverter()->getPointerType(argumentTypes[en.index()]), - structType, structPtr, ArrayRef{0, en.index()}); - builder.create(loc, en.value(), fieldPtr); - auto elementPtr = builder.create( - loc, llvmPointerPointerType, llvmPointerType, arrayPtr, - ArrayRef{en.index()}); - if (!getTypeConverter()->useOpaquePointers()) - fieldPtr = - builder.create(loc, llvmPointerType, fieldPtr); - builder.create(loc, fieldPtr, elementPtr); + // Returns whether `value` is the result of an LLVM::CallOp to `functionName`. + static bool isDefinedByCallTo(Value value, StringRef functionName) { + assert(isa(value.getType())); + if (auto defOp = value.getDefiningOp()) + return defOp.getCallee()->equals(functionName); + return false; } - return arrayPtr; -} -// Generates an LLVM IR dialect global that contains the name of the given -// kernel function as a C string, and returns a pointer to its beginning. -// The code is essentially: -// -// llvm.global constant @kernel_name("function_name\00") -// func(...) { -// %0 = llvm.addressof @kernel_name -// %1 = llvm.constant (0 : index) -// %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*"> -// } -Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant( - StringRef moduleName, StringRef name, Location loc, - OpBuilder &builder) const { - // Make sure the trailing zero is included in the constant. 
- std::vector kernelName(name.begin(), name.end()); - kernelName.push_back('\0'); - - std::string globalName = - std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name)); - return LLVM::createGlobalString( - loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()), - LLVM::Linkage::Internal, getTypeConverter()->useOpaquePointers()); -} + // Converts `gpu.wait` to runtime calls. The converted op synchronizes the + // host with the stream/event operands. The operands are destroyed. That is, + // it assumes that it is not used afterwards or elsewhere. Otherwise we will + // get a runtime error. Eventually, we should guarantee this property. + LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::WaitOp waitOp, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + if (waitOp.getAsyncToken()) + return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op."); + + Location loc = waitOp.getLoc(); + + for (auto operand : adaptor.getOperands()) { + if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) { + // The converted operand's definition created a stream. + streamSynchronizeCallBuilder.create(loc, rewriter, {operand}); + streamDestroyCallBuilder.create(loc, rewriter, {operand}); + } else { + // Otherwise the converted operand is an event. This assumes that we use + // events in control flow code as well. + eventSynchronizeCallBuilder.create(loc, rewriter, {operand}); + eventDestroyCallBuilder.create(loc, rewriter, {operand}); + } + } -// Emits LLVM IR to launch a kernel function. Expects the module that contains -// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a -// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR. -// -// %0 = call %binarygetter -// %1 = call %moduleLoad(%0) -// %2 = -// %3 = call %moduleGetFunction(%1, %2) -// %4 = call %streamCreate() -// %5 = -// call %launchKernel(%3, , 0, %4, %5, nullptr) -// call %streamSynchronize(%4) -// call %streamDestroy(%4) -// call %moduleUnload(%1) -// -// If the op is async, the stream corresponds to the (single) async dependency -// as well as the async token the op produces. -LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter))) - return failure(); - - if (launchOp.getAsyncDependencies().size() > 1) - return rewriter.notifyMatchFailure( - launchOp, "Cannot convert with more than one async dependency."); - - // Fail when the synchronous version of the op has async dependencies. The - // lowering destroys the stream, and we do not want to check that there is no - // use of the stream after this op. - if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty()) - return rewriter.notifyMatchFailure( - launchOp, "Cannot convert non-async op with async dependencies."); - - Location loc = launchOp.getLoc(); - - // Create an LLVM global with CUBIN extracted from the kernel annotation and - // obtain a pointer to the first byte in it. 
- auto kernelModule = SymbolTable::lookupNearestSymbolFrom( - launchOp, launchOp.getKernelModuleName()); - assert(kernelModule && "expected a kernel module"); - - auto binaryAttr = - kernelModule->getAttrOfType(gpuBinaryAnnotation); - if (!binaryAttr) { - kernelModule.emitOpError() - << "missing " << gpuBinaryAnnotation << " attribute"; - return failure(); + rewriter.eraseOp(waitOp); + return success(); } - SmallString<128> nameBuffer(kernelModule.getName()); - nameBuffer.append(kGpuBinaryStorageSuffix); - Value data = LLVM::createGlobalString( - loc, rewriter, nameBuffer.str(), binaryAttr.getValue(), - LLVM::Linkage::Internal, getTypeConverter()->useOpaquePointers()); - - auto module = moduleLoadCallBuilder.create(loc, rewriter, data); - // Get the function from the module. The name corresponds to the name of - // the kernel function. - auto kernelName = generateKernelNameConstant( - launchOp.getKernelModuleName().getValue(), - launchOp.getKernelName().getValue(), loc, rewriter); - auto function = moduleGetFunctionCallBuilder.create( - loc, rewriter, {module.getResult(), kernelName}); - Value zero = rewriter.create(loc, llvmInt32Type, 0); - Value stream = - adaptor.getAsyncDependencies().empty() - ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult() - : adaptor.getAsyncDependencies().front(); - // Create array of pointers to kernel arguments. - auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter); - auto nullpointer = rewriter.create(loc, llvmPointerPointerType); - Value dynamicSharedMemorySize = launchOp.getDynamicSharedMemorySize() - ? launchOp.getDynamicSharedMemorySize() - : zero; - launchKernelCallBuilder.create( - loc, rewriter, - {function.getResult(), adaptor.getGridSizeX(), adaptor.getGridSizeY(), - adaptor.getGridSizeZ(), adaptor.getBlockSizeX(), adaptor.getBlockSizeY(), - adaptor.getBlockSizeZ(), dynamicSharedMemorySize, stream, kernelParams, - /*extra=*/nullpointer}); - - if (launchOp.getAsyncToken()) { - // Async launch: make dependent ops use the same stream. - rewriter.replaceOp(launchOp, {stream}); - } else { - // Synchronize with host and destroy stream. This must be the stream created - // above (with no other uses) because we check that the synchronous version - // does not have any async dependencies. - streamSynchronizeCallBuilder.create(loc, rewriter, stream); - streamDestroyCallBuilder.create(loc, rewriter, stream); - rewriter.eraseOp(launchOp); + // Converts `gpu.wait async` to runtime calls. The converted op creates a new + // stream that is synchronized with stream/event operands. The operands are + // destroyed. That is, it assumes that it is not used afterwards or elsewhere. + // Otherwise we will get a runtime error. Eventually, we should guarantee this + // property. + LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::WaitOp waitOp, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + if (!waitOp.getAsyncToken()) + return rewriter.notifyMatchFailure(waitOp, "Can only convert async op."); + + Location loc = waitOp.getLoc(); + + auto insertionPoint = rewriter.saveInsertionPoint(); + SmallVector events; + for (auto pair : + llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) { + auto operand = std::get<1>(pair); + if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) { + // The converted operand's definition created a stream. Insert an event + // into the stream just after the last use of the original token + // operand. 
+ auto *defOp = std::get<0>(pair).getDefiningOp(); + rewriter.setInsertionPointAfter(defOp); + auto event = + eventCreateCallBuilder.create(loc, rewriter, {}).getResult(); + eventRecordCallBuilder.create(loc, rewriter, {event, operand}); + events.push_back(event); + } else { + // Otherwise the converted operand is an event. This assumes that we use + // events in control flow code as well. + events.push_back(operand); + } + } + rewriter.restoreInsertionPoint(insertionPoint); + auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(); + for (auto event : events) + streamWaitEventCallBuilder.create(loc, rewriter, {stream, event}); + for (auto event : events) + eventDestroyCallBuilder.create(loc, rewriter, {event}); + rewriter.replaceOp(waitOp, {stream}); + + return success(); } - moduleUnloadCallBuilder.create(loc, rewriter, module.getResult()); - return success(); -} + // Creates a struct containing all kernel parameters on the stack and returns + // an array of type-erased pointers to the fields of the struct. The array can + // then be passed to the CUDA / ROCm (HIP) kernel launch calls. + // The generated code is essentially as follows: + // + // %struct = alloca(sizeof(struct { Parameters... })) + // %array = alloca(NumParameters * sizeof(void *)) + // for (i : [0, NumParameters)) + // %fieldPtr = llvm.getelementptr %struct[0, i] + // llvm.store parameters[i], %fieldPtr + // %elementPtr = llvm.getelementptr %array[i] + // llvm.store %fieldPtr, %elementPtr + // return %array + Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray( + gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder & builder) + const { + auto loc = launchOp.getLoc(); + auto numKernelOperands = launchOp.getNumKernelOperands(); + SmallVector arguments; + if (kernelBarePtrCallConv) { + // Hack the bare pointer value on just for the argument promotion + LLVMTypeConverter *converter = getTypeConverter(); + LowerToLLVMOptions options = converter->getOptions(); + LowerToLLVMOptions overrideToMatchKernelOpts = options; + overrideToMatchKernelOpts.useBarePtrCallConv = true; + converter->dangerousSetOptions(overrideToMatchKernelOpts); + arguments = converter->promoteOperands( + loc, launchOp.getOperands().take_back(numKernelOperands), + adaptor.getOperands().take_back(numKernelOperands), builder); + converter->dangerousSetOptions(options); + } else { + arguments = getTypeConverter()->promoteOperands( + loc, launchOp.getOperands().take_back(numKernelOperands), + adaptor.getOperands().take_back(numKernelOperands), builder); + } -static Value bitAndAddrspaceCast(Location loc, - ConversionPatternRewriter &rewriter, - LLVM::LLVMPointerType destinationType, - Value sourcePtr, - LLVMTypeConverter &typeConverter) { - auto sourceTy = cast(sourcePtr.getType()); - if (destinationType.getAddressSpace() != sourceTy.getAddressSpace()) - sourcePtr = rewriter.create( - loc, - typeConverter.getPointerType(sourceTy.getElementType(), - destinationType.getAddressSpace()), - sourcePtr); - - if (typeConverter.useOpaquePointers()) - return sourcePtr; - - return rewriter.create(loc, destinationType, sourcePtr); -} + auto numArguments = arguments.size(); + SmallVector argumentTypes; + argumentTypes.reserve(numArguments); + for (auto argument : arguments) + argumentTypes.push_back(argument.getType()); + auto structType = LLVM::LLVMStructType::getNewIdentified( + context, StringRef(), argumentTypes); + auto one = builder.create(loc, llvmInt32Type, 1); + auto structPtr = builder.create( + loc, 
getTypeConverter()->getPointerType(structType), structType, one, + /*alignment=*/0); + auto arraySize = + builder.create(loc, llvmInt32Type, numArguments); + auto arrayPtr = builder.create(loc, llvmPointerPointerType, + llvmPointerType, arraySize, + /*alignment=*/0); + for (const auto &en : llvm::enumerate(arguments)) { + Value fieldPtr = builder.create( + loc, getTypeConverter()->getPointerType(argumentTypes[en.index()]), + structType, structPtr, ArrayRef{0, en.index()}); + builder.create(loc, en.value(), fieldPtr); + auto elementPtr = builder.create( + loc, llvmPointerPointerType, llvmPointerType, arrayPtr, + ArrayRef{en.index()}); + if (!getTypeConverter()->useOpaquePointers()) + fieldPtr = + builder.create(loc, llvmPointerType, fieldPtr); + builder.create(loc, fieldPtr, elementPtr); + } + return arrayPtr; + } -LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::MemcpyOp memcpyOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto memRefType = cast(memcpyOp.getSrc().getType()); + // Generates an LLVM IR dialect global that contains the name of the given + // kernel function as a C string, and returns a pointer to its beginning. + // The code is essentially: + // + // llvm.global constant @kernel_name("function_name\00") + // func(...) { + // %0 = llvm.addressof @kernel_name + // %1 = llvm.constant (0 : index) + // %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*"> + // } + Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant( + StringRef moduleName, StringRef name, Location loc, OpBuilder & builder) + const { + // Make sure the trailing zero is included in the constant. + std::vector kernelName(name.begin(), name.end()); + kernelName.push_back('\0'); + + std::string globalName = + std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name)); + return LLVM::createGlobalString( + loc, builder, globalName, + StringRef(kernelName.data(), kernelName.size()), + LLVM::Linkage::Internal, getTypeConverter()->useOpaquePointers()); + } - if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) || - !isConvertibleAndHasIdentityMaps(memRefType) || - failed(isAsyncWithOneDependency(rewriter, memcpyOp))) - return failure(); + // Emits LLVM IR to launch a kernel function. Expects the module that contains + // the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a + // hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR. + // + // %0 = call %binarygetter + // %1 = call %moduleLoad(%0) + // %2 = + // %3 = call %moduleGetFunction(%1, %2) + // %4 = call %streamCreate() + // %5 = + // call %launchKernel(%3, , 0, %4, %5, nullptr) + // call %streamSynchronize(%4) + // call %streamDestroy(%4) + // call %moduleUnload(%1) + // + // If the op is async, the stream corresponds to the (single) async dependency + // as well as the async token the op produces. + LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter))) + return failure(); + + if (launchOp.getAsyncDependencies().size() > 1) + return rewriter.notifyMatchFailure( + launchOp, "Cannot convert with more than one async dependency."); + + // Fail when the synchronous version of the op has async dependencies. The + // lowering destroys the stream, and we do not want to check that there is + // no use of the stream after this op. 
+ if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty()) + return rewriter.notifyMatchFailure( + launchOp, "Cannot convert non-async op with async dependencies."); + + Location loc = launchOp.getLoc(); + + // Create an LLVM global with CUBIN extracted from the kernel annotation and + // obtain a pointer to the first byte in it. + auto kernelModule = SymbolTable::lookupNearestSymbolFrom( + launchOp, launchOp.getKernelModuleName()); + assert(kernelModule && "expected a kernel module"); + + auto binaryAttr = + kernelModule->getAttrOfType(gpuBinaryAnnotation); + if (!binaryAttr) { + kernelModule.emitOpError() + << "missing " << gpuBinaryAnnotation << " attribute"; + return failure(); + } - auto loc = memcpyOp.getLoc(); + SmallString<128> nameBuffer(kernelModule.getName()); + nameBuffer.append(kGpuBinaryStorageSuffix); + Value data = LLVM::createGlobalString( + loc, rewriter, nameBuffer.str(), binaryAttr.getValue(), + LLVM::Linkage::Internal, getTypeConverter()->useOpaquePointers()); + + auto module = moduleLoadCallBuilder.create(loc, rewriter, data); + // Get the function from the module. The name corresponds to the name of + // the kernel function. + auto kernelName = generateKernelNameConstant( + launchOp.getKernelModuleName().getValue(), + launchOp.getKernelName().getValue(), loc, rewriter); + auto function = moduleGetFunctionCallBuilder.create( + loc, rewriter, {module.getResult(), kernelName}); + Value zero = rewriter.create(loc, llvmInt32Type, 0); + Value stream = + adaptor.getAsyncDependencies().empty() + ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult() + : adaptor.getAsyncDependencies().front(); + // Create array of pointers to kernel arguments. + auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter); + auto nullpointer = + rewriter.create(loc, llvmPointerPointerType); + Value dynamicSharedMemorySize = launchOp.getDynamicSharedMemorySize() + ? launchOp.getDynamicSharedMemorySize() + : zero; + launchKernelCallBuilder.create( + loc, rewriter, + {function.getResult(), adaptor.getGridSizeX(), adaptor.getGridSizeY(), + adaptor.getGridSizeZ(), adaptor.getBlockSizeX(), + adaptor.getBlockSizeY(), adaptor.getBlockSizeZ(), + dynamicSharedMemorySize, stream, kernelParams, + /*extra=*/nullpointer}); + + if (launchOp.getAsyncToken()) { + // Async launch: make dependent ops use the same stream. + rewriter.replaceOp(launchOp, {stream}); + } else { + // Synchronize with host and destroy stream. This must be the stream + // created above (with no other uses) because we check that the + // synchronous version does not have any async dependencies. 
+ streamSynchronizeCallBuilder.create(loc, rewriter, stream); + streamDestroyCallBuilder.create(loc, rewriter, stream); + rewriter.eraseOp(launchOp); + } + moduleUnloadCallBuilder.create(loc, rewriter, module.getResult()); + + return success(); + } - MemRefDescriptor srcDesc(adaptor.getSrc()); - Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc); + static Value bitAndAddrspaceCast( + Location loc, ConversionPatternRewriter & rewriter, + LLVM::LLVMPointerType destinationType, Value sourcePtr, + LLVMTypeConverter & typeConverter) { + auto sourceTy = cast(sourcePtr.getType()); + if (destinationType.getAddressSpace() != sourceTy.getAddressSpace()) + sourcePtr = rewriter.create( + loc, + typeConverter.getPointerType(sourceTy.getElementType(), + destinationType.getAddressSpace()), + sourcePtr); + + if (typeConverter.useOpaquePointers()) + return sourcePtr; + + return rewriter.create(loc, destinationType, sourcePtr); + } - Type elementPtrType = getElementPtrType(memRefType); - Value nullPtr = rewriter.create(loc, elementPtrType); - Value gepPtr = rewriter.create( - loc, elementPtrType, - typeConverter->convertType(memRefType.getElementType()), nullPtr, - numElements); - auto sizeBytes = - rewriter.create(loc, getIndexType(), gepPtr); + LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::MemcpyOp memcpyOp, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + auto memRefType = cast(memcpyOp.getSrc().getType()); - auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, - srcDesc.alignedPtr(rewriter, loc), - *getTypeConverter()); - auto dst = bitAndAddrspaceCast( - loc, rewriter, llvmPointerType, - MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc), - *getTypeConverter()); + if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) || + !isConvertibleAndHasIdentityMaps(memRefType) || + failed(isAsyncWithOneDependency(rewriter, memcpyOp))) + return failure(); - auto stream = adaptor.getAsyncDependencies().front(); - memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream}); + auto loc = memcpyOp.getLoc(); - rewriter.replaceOp(memcpyOp, {stream}); + MemRefDescriptor srcDesc(adaptor.getSrc()); + Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc); - return success(); -} + Type elementPtrType = getElementPtrType(memRefType); + Value nullPtr = rewriter.create(loc, elementPtrType); + Value gepPtr = rewriter.create( + loc, elementPtrType, + typeConverter->convertType(memRefType.getElementType()), nullPtr, + numElements); + auto sizeBytes = + rewriter.create(loc, getIndexType(), gepPtr); -LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::MemsetOp memsetOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto memRefType = cast(memsetOp.getDst().getType()); + auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, + srcDesc.alignedPtr(rewriter, loc), + *getTypeConverter()); + auto dst = bitAndAddrspaceCast( + loc, rewriter, llvmPointerType, + MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc), + *getTypeConverter()); - if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) || - !isConvertibleAndHasIdentityMaps(memRefType) || - failed(isAsyncWithOneDependency(rewriter, memsetOp))) - return failure(); + auto stream = adaptor.getAsyncDependencies().front(); + memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream}); - auto loc = memsetOp.getLoc(); + rewriter.replaceOp(memcpyOp, {stream}); - Type 
valueType = adaptor.getValue().getType(); - if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) { - return rewriter.notifyMatchFailure(memsetOp, - "value must be a 32 bit scalar"); + return success(); } - MemRefDescriptor dstDesc(adaptor.getDst()); - Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc); + LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::MemsetOp memsetOp, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + auto memRefType = cast(memsetOp.getDst().getType()); - auto value = - rewriter.create(loc, llvmInt32Type, adaptor.getValue()); - auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, - dstDesc.alignedPtr(rewriter, loc), - *getTypeConverter()); + if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) || + !isConvertibleAndHasIdentityMaps(memRefType) || + failed(isAsyncWithOneDependency(rewriter, memsetOp))) + return failure(); - auto stream = adaptor.getAsyncDependencies().front(); - memsetCallBuilder.create(loc, rewriter, {dst, value, numElements, stream}); + auto loc = memsetOp.getLoc(); - rewriter.replaceOp(memsetOp, {stream}); - return success(); -} + Type valueType = adaptor.getValue().getType(); + if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) { + return rewriter.notifyMatchFailure(memsetOp, + "value must be a 32 bit scalar"); + } -LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::SetDefaultDeviceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Location loc = op.getLoc(); - setDefaultDeviceCallBuilder.create(loc, rewriter, {adaptor.getDevIndex()}); - rewriter.replaceOp(op, {}); - return success(); -} + MemRefDescriptor dstDesc(adaptor.getDst()); + Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc); -// Returns the element type of the defining spmat op. -// TODO: safer and more flexible to store data type in actual op instead? -static Type getSpMatElemType(Value spMat) { - if (auto op = spMat.getDefiningOp()) - return llvm::cast(op.getValues().getType()).getElementType(); - if (auto op = spMat.getDefiningOp()) - return llvm::cast(op.getValues().getType()).getElementType(); - llvm_unreachable("cannot find spmat def"); -} + auto value = rewriter.create(loc, llvmInt32Type, + adaptor.getValue()); + auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, + dstDesc.alignedPtr(rewriter, loc), + *getTypeConverter()); -// Returns the element type of the defining dnmat or dnvec op. 
-static Type getDnElemType(Value dn) { - if (auto op = dn.getDefiningOp()) - return op.getMemref().getType().getElementType(); - if (auto op = dn.getDefiningOp()) - return op.getMemref().getType().getElementType(); - llvm_unreachable("cannot find dn def"); -} + auto stream = adaptor.getAsyncDependencies().front(); + memsetCallBuilder.create(loc, rewriter, {dst, value, numElements, stream}); -template -static Value genConstInt32From(OpBuilder &builder, Location loc, T TValue) { - Type llvmInt32Type = builder.getIntegerType(32); - return builder.create(loc, llvmInt32Type, - static_cast(TValue)); -} + rewriter.replaceOp(memsetOp, {stream}); + return success(); + } -static Value -genConstInt32FromOptionalComputeMode(OpBuilder &builder, Location loc, - std::optional computeTypeOptional, - Type defaultType) { - auto computeTypeInt = - getCuSparseDataTypeFrom(computeTypeOptional.value_or(defaultType)); - auto computeType = genConstInt32From(builder, loc, computeTypeInt); - return computeType; -} + LogicalResult + ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::SetDefaultDeviceOp op, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + Location loc = op.getLoc(); + setDefaultDeviceCallBuilder.create(loc, rewriter, {adaptor.getDevIndex()}); + rewriter.replaceOp(op, {}); + return success(); + } -LogicalResult ConvertCreateSparseEnvOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::CreateSparseEnvOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || - failed(isAsyncWithOneDependency(rewriter, op))) - return failure(); - Location loc = op.getLoc(); - auto stream = adaptor.getAsyncDependencies().front(); - auto handle = - createSparseEnvCallBuilder.create(loc, rewriter, {stream}).getResult(); - rewriter.replaceOp(op, {handle, stream}); - return success(); -} + // Returns the element type of the defining spmat op. + // TODO: safer and more flexible to store data type in actual op instead? + static Type getSpMatElemType(Value spMat) { + if (auto op = spMat.getDefiningOp()) + return llvm::cast(op.getValues().getType()).getElementType(); + if (auto op = spMat.getDefiningOp()) + return llvm::cast(op.getValues().getType()).getElementType(); + if (auto op = spMat.getDefiningOp()) + return op.getMemref().getType().getElementType(); + llvm_unreachable("cannot find spmat def"); + } -LogicalResult ConvertDestroySparseEnvOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::DestroySparseEnvOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || - failed(isAsyncWithOneDependency(rewriter, op))) - return failure(); - Location loc = op.getLoc(); - auto stream = adaptor.getAsyncDependencies().front(); - destroySparseEnvCallBuilder.create(loc, rewriter, {adaptor.getEnv(), stream}); - rewriter.replaceOp(op, {stream}); - return success(); -} + // Returns the element type of the defining dnmat or dnvec op. 
+ static Type getDnElemType(Value dn) { + if (auto op = dn.getDefiningOp()) + return op.getMemref().getType().getElementType(); + if (auto op = dn.getDefiningOp()) + return op.getMemref().getType().getElementType(); + llvm_unreachable("cannot find dn def"); + } -LogicalResult ConvertCreateDnVecOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::CreateDnVecOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || - failed(isAsyncWithOneDependency(rewriter, op))) - return failure(); - Location loc = op.getLoc(); - auto stream = adaptor.getAsyncDependencies().front(); - Value pVec = - MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc); - if (!getTypeConverter()->useOpaquePointers()) - pVec = rewriter.create(loc, llvmPointerType, pVec); - Type dType = op.getMemref().getType().getElementType(); - auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); - auto handle = - createDnVecCallBuilder - .create(loc, rewriter, {adaptor.getSize(), pVec, dtp, stream}) - .getResult(); - rewriter.replaceOp(op, {handle, stream}); - return success(); -} + template + static Value genConstInt32From(OpBuilder & builder, Location loc, T TValue) { + Type llvmInt32Type = builder.getIntegerType(32); + return builder.create(loc, llvmInt32Type, + static_cast(TValue)); + } -LogicalResult ConvertDestroyDnVecOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::DestroyDnVecOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || - failed(isAsyncWithOneDependency(rewriter, op))) - return failure(); - Location loc = op.getLoc(); - auto stream = adaptor.getAsyncDependencies().front(); - destroyDnVecCallBuilder.create(loc, rewriter, {adaptor.getDvec(), stream}); - rewriter.replaceOp(op, {stream}); - return success(); -} + static Value genConstInt32FromOptionalComputeMode( + OpBuilder & builder, Location loc, + std::optional computeTypeOptional, Type defaultType) { + auto computeTypeInt = + getCuSparseDataTypeFrom(computeTypeOptional.value_or(defaultType)); + auto computeType = genConstInt32From(builder, loc, computeTypeInt); + return computeType; + } -LogicalResult ConvertCreateDnMatOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::CreateDnMatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || - failed(isAsyncWithOneDependency(rewriter, op))) - return failure(); - Location loc = op.getLoc(); - auto stream = adaptor.getAsyncDependencies().front(); - Value pMat = - MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc); - if (!getTypeConverter()->useOpaquePointers()) - pMat = rewriter.create(loc, llvmPointerType, pMat); - Type dType = op.getMemref().getType().getElementType(); - auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); - auto handle = - createDnMatCallBuilder - .create(loc, rewriter, - {adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream}) - .getResult(); - rewriter.replaceOp(op, {handle, stream}); - return success(); -} + LogicalResult + ConvertCreateSparseEnvOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::CreateSparseEnvOp op, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || + failed(isAsyncWithOneDependency(rewriter, op))) + return failure(); + Location loc = op.getLoc(); + auto stream = 
adaptor.getAsyncDependencies().front(); + // use the cusparseLt create call if the dnmat is used with spmat with + // 2:4 sparsity + Value handle; + if (inferSpMMType(op.getEnv()) == "cusparseLt") { + handle = createSparseLtEnvCallBuilder.create(loc, rewriter, {stream}) + .getResult(); + } else { + handle = createSparseEnvCallBuilder.create(loc, rewriter, {stream}) + .getResult(); + } + rewriter.replaceOp(op, {handle, stream}); + return success(); + } -LogicalResult ConvertDestroyDnMatOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::DestroyDnMatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || - failed(isAsyncWithOneDependency(rewriter, op))) - return failure(); - Location loc = op.getLoc(); - auto stream = adaptor.getAsyncDependencies().front(); - destroyDnMatCallBuilder.create(loc, rewriter, {adaptor.getDmat(), stream}); - rewriter.replaceOp(op, {stream}); - return success(); -} + LogicalResult + ConvertDestroySparseEnvOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::DestroySparseEnvOp op, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || + failed(isAsyncWithOneDependency(rewriter, op))) + return failure(); + Location loc = op.getLoc(); + auto stream = adaptor.getAsyncDependencies().front(); + // use the cusparseLt destroy call if the dnmat is used with spmat with + // 2:4 sparsity + if (inferSpMMType(op.getEnv()) == "cusparseLt") { + destroySparseLtEnvCallBuilder.create(loc, rewriter, + {adaptor.getEnv(), stream}); + } else { + destroySparseEnvCallBuilder.create(loc, rewriter, + {adaptor.getEnv(), stream}); + } + rewriter.replaceOp(op, {stream}); + return success(); + } -LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::CreateCooOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || - failed(isAsyncWithOneDependency(rewriter, op))) - return failure(); - Location loc = op.getLoc(); - auto stream = adaptor.getAsyncDependencies().front(); - Value pRowIdxs = - MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc); - Value pColIdxs = - MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc); - Value pValues = - MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc); - if (!getTypeConverter()->useOpaquePointers()) { - pRowIdxs = rewriter.create(loc, llvmPointerType, pRowIdxs); - pColIdxs = rewriter.create(loc, llvmPointerType, pColIdxs); - pValues = rewriter.create(loc, llvmPointerType, pValues); + LogicalResult ConvertCreateDnVecOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::CreateDnVecOp op, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || + failed(isAsyncWithOneDependency(rewriter, op))) + return failure(); + Location loc = op.getLoc(); + auto stream = adaptor.getAsyncDependencies().front(); + Value pVec = + MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc); + if (!getTypeConverter()->useOpaquePointers()) + pVec = rewriter.create(loc, llvmPointerType, pVec); + Type dType = op.getMemref().getType().getElementType(); + auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); + auto handle = + createDnVecCallBuilder + .create(loc, rewriter, {adaptor.getSize(), pVec, dtp, stream}) + .getResult(); + rewriter.replaceOp(op, {handle, 
stream}); + return success(); } - Type iType = - llvm::cast(op.getColIdxs().getType()).getElementType(); - Type dType = - llvm::cast(op.getValues().getType()).getElementType(); - auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType)); - auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); - auto handle = - createCooCallBuilder - .create(loc, rewriter, - {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(), - pRowIdxs, pColIdxs, pValues, itp, dtp, stream}) - .getResult(); - rewriter.replaceOp(op, {handle, stream}); - return success(); -} -LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::CreateCsrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || - failed(isAsyncWithOneDependency(rewriter, op))) - return failure(); - Location loc = op.getLoc(); - auto stream = adaptor.getAsyncDependencies().front(); - Value pRowPos = - MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc); - Value pColIdxs = - MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc); - Value pValues = - MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc); - if (!getTypeConverter()->useOpaquePointers()) { - pRowPos = rewriter.create(loc, llvmPointerType, pRowPos); - pColIdxs = rewriter.create(loc, llvmPointerType, pColIdxs); - pValues = rewriter.create(loc, llvmPointerType, pValues); + LogicalResult ConvertDestroyDnVecOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::DestroyDnVecOp op, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || + failed(isAsyncWithOneDependency(rewriter, op))) + return failure(); + Location loc = op.getLoc(); + auto stream = adaptor.getAsyncDependencies().front(); + destroyDnVecCallBuilder.create(loc, rewriter, {adaptor.getDvec(), stream}); + rewriter.replaceOp(op, {stream}); + return success(); } - Type pType = - llvm::cast(op.getRowPos().getType()).getElementType(); - Type iType = - llvm::cast(op.getColIdxs().getType()).getElementType(); - Type dType = - llvm::cast(op.getValues().getType()).getElementType(); - auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType)); - auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType)); - auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); - auto handle = - createCsrCallBuilder - .create(loc, rewriter, - {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(), - pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream}) - .getResult(); - rewriter.replaceOp(op, {handle, stream}); - return success(); -} -LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::DestroySpMatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || - failed(isAsyncWithOneDependency(rewriter, op))) - return failure(); - Location loc = op.getLoc(); - auto stream = adaptor.getAsyncDependencies().front(); - destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream}); - rewriter.replaceOp(op, {stream}); - return success(); -} + LogicalResult ConvertCreateDnMatOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::CreateDnMatOp op, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || + failed(isAsyncWithOneDependency(rewriter, op))) + 
return failure(); + Location loc = op.getLoc(); + auto stream = adaptor.getAsyncDependencies().front(); + Value pMat = + MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc); + if (!getTypeConverter()->useOpaquePointers()) + pMat = rewriter.create(loc, llvmPointerType, pMat); + // TODO: For now, we track the use of the handle and lower it to cusparse / + // cusparseLt accordingly. If in a block, both cusparse and cusparseLt are + // used, we require two separate Creation ops to be the correct logic. In + // future, we may add support to using one handle in sparse tensor / GPU + // dialect in both cusparse and cusparseLt. use the cusparseLt create call + // if the dnmat is used with spmat with 2:4 sparsity + Value handle; + Type dType = op.getMemref().getType().getElementType(); + auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); + if (inferSpMMType(op.getDmat()) == "cusparseLt") { + auto envHandle = adaptor.getEnv(); + handle = createLtDnMatCallBuilder + .create(loc, rewriter, + {envHandle, adaptor.getRows(), adaptor.getCols(), + pMat, dtp, stream}) + .getResult(); + } else { + handle = + createDnMatCallBuilder + .create(loc, rewriter, + {adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream}) + .getResult(); + } + rewriter.replaceOp(op, {handle, stream}); + return success(); + } -LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::SpMVBufferSizeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || - failed(isAsyncWithOneDependency(rewriter, op))) - return failure(); - Location loc = op.getLoc(); - auto modeA = genConstInt32From(rewriter, loc, op.getModeA()); - // retrieve the compute type, notice that it may be optional - auto computeType = genConstInt32FromOptionalComputeMode( - rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnY())); - auto stream = adaptor.getAsyncDependencies().front(); - auto bufferSize = - spMVBufferSizeCallBuilder - .create(loc, rewriter, - {adaptor.getEnv(), modeA, adaptor.getSpmatA(), - adaptor.getDnX(), adaptor.getDnY(), computeType, stream}) - .getResult(); - rewriter.replaceOp(op, {bufferSize, stream}); - return success(); -} + LogicalResult ConvertDestroyDnMatOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::DestroyDnMatOp op, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || + failed(isAsyncWithOneDependency(rewriter, op))) + return failure(); + Location loc = op.getLoc(); + auto stream = adaptor.getAsyncDependencies().front(); + // use the cusparseLt destroy call if the dnmat is used with spmat with + // 2:4 sparsity + if (inferSpMMType(op.getDmat()) == "cusparseLt") { + destroyCuSparseLtDnMatBuilder.create(loc, rewriter, + {adaptor.getDmat(), stream}); + } else { + destroyDnMatCallBuilder.create(loc, rewriter, + {adaptor.getDmat(), stream}); + } + rewriter.replaceOp(op, {stream}); + return success(); + } -LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::SpMVOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || - failed(isAsyncWithOneDependency(rewriter, op))) - return failure(); - Location loc = op.getLoc(); - auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); - // retrieve the compute type, notice that it may be optional - auto computeType = 
genConstInt32FromOptionalComputeMode( - rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnY())); - auto stream = adaptor.getAsyncDependencies().front(); - Value pBuf = - MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc); - if (!getTypeConverter()->useOpaquePointers()) - pBuf = rewriter.create(loc, llvmPointerType, pBuf); - spMVCallBuilder.create(loc, rewriter, - {adaptor.getEnv(), modeA, adaptor.getSpmatA(), - adaptor.getDnX(), adaptor.getDnY(), computeType, pBuf, - stream}); - rewriter.replaceOp(op, {stream}); - return success(); -} + LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::CreateCooOp op, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || + failed(isAsyncWithOneDependency(rewriter, op))) + return failure(); + Location loc = op.getLoc(); + auto stream = adaptor.getAsyncDependencies().front(); + Value pRowIdxs = + MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc); + Value pColIdxs = + MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc); + Value pValues = + MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc); + if (!getTypeConverter()->useOpaquePointers()) { + pRowIdxs = + rewriter.create(loc, llvmPointerType, pRowIdxs); + pColIdxs = + rewriter.create(loc, llvmPointerType, pColIdxs); + pValues = rewriter.create(loc, llvmPointerType, pValues); + } + Type iType = + llvm::cast(op.getColIdxs().getType()).getElementType(); + Type dType = + llvm::cast(op.getValues().getType()).getElementType(); + auto itp = + genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType)); + auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); + auto handle = + createCooCallBuilder + .create(loc, rewriter, + {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(), + pRowIdxs, pColIdxs, pValues, itp, dtp, stream}) + .getResult(); + rewriter.replaceOp(op, {handle, stream}); + return success(); + } -LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::SpMMBufferSizeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || - failed(isAsyncWithOneDependency(rewriter, op))) - return failure(); - Location loc = op.getLoc(); - auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); - auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB()); - auto stream = adaptor.getAsyncDependencies().front(); - // retrieve the compute type, notice that it may be optional - auto computeType = genConstInt32FromOptionalComputeMode( - rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnmatC())); - - auto bufferSize = spMMBufferSizeCallBuilder - .create(loc, rewriter, - {adaptor.getEnv(), modeA, modeB, - adaptor.getSpmatA(), adaptor.getDnmatB(), - adaptor.getDnmatC(), computeType, stream}) - .getResult(); - rewriter.replaceOp(op, {bufferSize, stream}); - return success(); -} + LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::CreateCsrOp op, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || + failed(isAsyncWithOneDependency(rewriter, op))) + return failure(); + Location loc = op.getLoc(); + auto stream = adaptor.getAsyncDependencies().front(); + Value pRowPos = + MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc); + Value pColIdxs = + 
MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc); + Value pValues = + MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc); + if (!getTypeConverter()->useOpaquePointers()) { + pRowPos = rewriter.create(loc, llvmPointerType, pRowPos); + pColIdxs = + rewriter.create(loc, llvmPointerType, pColIdxs); + pValues = rewriter.create(loc, llvmPointerType, pValues); + } + Type pType = + llvm::cast(op.getRowPos().getType()).getElementType(); + Type iType = + llvm::cast(op.getColIdxs().getType()).getElementType(); + Type dType = + llvm::cast(op.getValues().getType()).getElementType(); + auto ptp = + genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType)); + auto itp = + genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType)); + auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); + auto handle = + createCsrCallBuilder + .create(loc, rewriter, + {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(), + pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream}) + .getResult(); + rewriter.replaceOp(op, {handle, stream}); + return success(); + } -LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || - failed(isAsyncWithOneDependency(rewriter, op))) - return failure(); - Location loc = op.getLoc(); - auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); - auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB()); - auto computeType = genConstInt32FromOptionalComputeMode( - rewriter, loc, adaptor.getComputeType(), - getSpMatElemType(op.getSpmatC())); - auto stream = adaptor.getAsyncDependencies().front(); - auto bufferSize = SDDMMBufferSizeCallBuilder - .create(loc, rewriter, - {adaptor.getEnv(), modeA, modeB, - adaptor.getDnmatA(), adaptor.getDnmatB(), - adaptor.getSpmatC(), computeType, stream}) - .getResult(); - rewriter.replaceOp(op, {bufferSize, stream}); - return success(); -} + LogicalResult + ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::Create2To4SpMatOp op, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || + failed(isAsyncWithOneDependency(rewriter, op))) + return failure(); + Location loc = op.getLoc(); + auto stream = adaptor.getAsyncDependencies().front(); + Value pMat = + MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc); + if (!getTypeConverter()->useOpaquePointers()) + pMat = rewriter.create(loc, llvmPointerType, pMat); + Type dType = + llvm::cast(op.getMemref().getType()).getElementType(); + auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); + auto envHandle = adaptor.getEnv(); + auto handle = create2To4SpMatCallBuilder + .create(loc, rewriter, + {envHandle, adaptor.getRows(), adaptor.getCols(), + pMat, dw, stream}) + .getResult(); + rewriter.replaceOp(op, {handle, stream}); + return success(); + } -LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::SpMMOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || - failed(isAsyncWithOneDependency(rewriter, op))) - return failure(); - Location loc = op.getLoc(); - auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); - auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB()); - // 
retrieve the compute type, notice that it may be optional - auto computeType = genConstInt32FromOptionalComputeMode( - rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnmatC())); - - auto stream = adaptor.getAsyncDependencies().front(); - Value pBuf = - MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc); - if (!getTypeConverter()->useOpaquePointers()) - pBuf = rewriter.create(loc, llvmPointerType, pBuf); - spMMCallBuilder.create(loc, rewriter, - {adaptor.getEnv(), modeA, modeB, adaptor.getSpmatA(), - adaptor.getDnmatB(), adaptor.getDnmatC(), computeType, - pBuf, stream}); - rewriter.replaceOp(op, {stream}); - return success(); -} + LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::DestroySpMatOp op, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || + failed(isAsyncWithOneDependency(rewriter, op))) + return failure(); + Location loc = op.getLoc(); + auto stream = adaptor.getAsyncDependencies().front(); + // use the cusparseLt destroy call if the spmat is 2:4 sparsity + if (is2To4Sparsity(op.getSpmat())) { + destroyCuSparseLtSpMatBuilder.create(loc, rewriter, + {adaptor.getSpmat(), stream}); -template -static void addOpaquePointerConversion(LLVMTypeConverter &converter) { - converter.addConversion([&converter](T) -> Type { - return converter.getPointerType( - IntegerType::get(&converter.getContext(), 8)); - }); -} + } else { + destroySpMatCallBuilder.create(loc, rewriter, + {adaptor.getSpmat(), stream}); + } + rewriter.replaceOp(op, {stream}); + return success(); + } -LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite( - gpu::SDDMMOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || - failed(isAsyncWithOneDependency(rewriter, op))) - return failure(); - Location loc = op.getLoc(); - auto computeType = genConstInt32FromOptionalComputeMode( - rewriter, loc, adaptor.getComputeType(), - getSpMatElemType(op.getSpmatC())); - auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); - auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB()); - auto stream = adaptor.getAsyncDependencies().front(); - Value pBuf = - MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc); - if (!getTypeConverter()->useOpaquePointers()) - pBuf = rewriter.create(loc, llvmPointerType, pBuf); - SDDMMCallBuilder.create(loc, rewriter, - {adaptor.getEnv(), modeA, modeB, adaptor.getDnmatA(), - adaptor.getDnmatB(), adaptor.getSpmatC(), - computeType, pBuf, stream}); - rewriter.replaceOp(op, {stream}); - return success(); -} + LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::SpMVBufferSizeOp op, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || + failed(isAsyncWithOneDependency(rewriter, op))) + return failure(); + Location loc = op.getLoc(); + auto modeA = genConstInt32From(rewriter, loc, op.getModeA()); + // retrieve the compute type, notice that it may be optional + auto computeType = genConstInt32FromOptionalComputeMode( + rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnY())); + auto stream = adaptor.getAsyncDependencies().front(); + auto bufferSize = + spMVBufferSizeCallBuilder + .create(loc, rewriter, + {adaptor.getEnv(), modeA, adaptor.getSpmatA(), + adaptor.getDnX(), adaptor.getDnY(), computeType, 
stream}) + .getResult(); + rewriter.replaceOp(op, {bufferSize, stream}); + return success(); + } -void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns, - StringRef gpuBinaryAnnotation, - bool kernelBarePtrCallConv) { - addOpaquePointerConversion(converter); - addOpaquePointerConversion(converter); - addOpaquePointerConversion(converter); - addOpaquePointerConversion(converter); - addOpaquePointerConversion(converter); - - patterns.add(converter); - patterns.add( - converter, gpuBinaryAnnotation, kernelBarePtrCallConv); - patterns.add(&converter.getContext()); -} + LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::SpMVOp op, OpAdaptor adaptor, ConversionPatternRewriter & rewriter) + const { + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || + failed(isAsyncWithOneDependency(rewriter, op))) + return failure(); + Location loc = op.getLoc(); + auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); + // retrieve the compute type, notice that it may be optional + auto computeType = genConstInt32FromOptionalComputeMode( + rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnY())); + auto stream = adaptor.getAsyncDependencies().front(); + Value pBuf = + MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc); + if (!getTypeConverter()->useOpaquePointers()) + pBuf = rewriter.create(loc, llvmPointerType, pBuf); + spMVCallBuilder.create(loc, rewriter, + {adaptor.getEnv(), modeA, adaptor.getSpmatA(), + adaptor.getDnX(), adaptor.getDnY(), computeType, + pBuf, stream}); + rewriter.replaceOp(op, {stream}); + return success(); + } + + LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::SpMMBufferSizeOp op, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || + failed(isAsyncWithOneDependency(rewriter, op))) + return failure(); + Location loc = op.getLoc(); + auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); + auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB()); + auto stream = adaptor.getAsyncDependencies().front(); + auto computeType = genConstInt32FromOptionalComputeMode( + rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnmatC())); + + Value bufferSize; + if (is2To4Sparsity(op.getSpmatA())) { + bufferSize = cuSparseLtSpmmBufferSizeBuilder + .create(loc, rewriter, + {adaptor.getEnv(), adaptor.getSpmatA(), stream}) + .getResult(); + } else { + bufferSize = spMMBufferSizeCallBuilder + .create(loc, rewriter, + {adaptor.getEnv(), modeA, modeB, + adaptor.getSpmatA(), adaptor.getDnmatB(), + adaptor.getDnmatC(), computeType, stream}) + .getResult(); + } + rewriter.replaceOp(op, {bufferSize, stream}); + return success(); + } + + LogicalResult + ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor, + ConversionPatternRewriter & rewriter) const { + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || + failed(isAsyncWithOneDependency(rewriter, op))) + return failure(); + Location loc = op.getLoc(); + auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); + auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB()); + auto computeType = genConstInt32FromOptionalComputeMode( + rewriter, loc, adaptor.getComputeType(), + getSpMatElemType(op.getSpmatC())); + auto stream = adaptor.getAsyncDependencies().front(); + auto bufferSize = 
SDDMMBufferSizeCallBuilder + .create(loc, rewriter, + {adaptor.getEnv(), modeA, modeB, + adaptor.getDnmatA(), adaptor.getDnmatB(), + adaptor.getSpmatC(), computeType, stream}) + .getResult(); + rewriter.replaceOp(op, {bufferSize, stream}); + return success(); + } + + LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::SpMMOp op, OpAdaptor adaptor, ConversionPatternRewriter & rewriter) + const { + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || + failed(isAsyncWithOneDependency(rewriter, op))) + return failure(); + Location loc = op.getLoc(); + auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); + auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB()); + // retrieve the compute type, notice that it may be optional + auto computeType = genConstInt32FromOptionalComputeMode( + rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnmatC())); + + auto stream = adaptor.getAsyncDependencies().front(); + + // lower to cusparseLt if applicable + if (is2To4Sparsity(op.getSpmatA())) { + SmallVector pBufs; + for (Value buffer : adaptor.getBuffers()) { + Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc); + if (!getTypeConverter()->useOpaquePointers()) + pBuf = rewriter.create(loc, llvmPointerType, pBuf); + pBufs.push_back(pBuf); + } + cuSparseLtSpmmBuilder.create(loc, rewriter, + {adaptor.getEnv(), adaptor.getSpmatA(), + adaptor.getDnmatB(), adaptor.getDnmatC(), + computeType, pBufs[0], pBufs[1], pBufs[2], + stream}); + } else { + Value pBuf = MemRefDescriptor(adaptor.getBuffers().front()) + .allocatedPtr(rewriter, loc); + if (!getTypeConverter()->useOpaquePointers()) + pBuf = rewriter.create(loc, llvmPointerType, pBuf); + spMMCallBuilder.create(loc, rewriter, + {adaptor.getEnv(), modeA, modeB, + adaptor.getSpmatA(), adaptor.getDnmatB(), + adaptor.getDnmatC(), computeType, pBuf, stream}); + } + rewriter.replaceOp(op, {stream}); + return success(); + } + + template + static void addOpaquePointerConversion(LLVMTypeConverter & converter) { + converter.addConversion([&converter](T) -> Type { + return converter.getPointerType( + IntegerType::get(&converter.getContext(), 8)); + }); + } + + LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::SDDMMOp op, OpAdaptor adaptor, ConversionPatternRewriter & rewriter) + const { + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || + failed(isAsyncWithOneDependency(rewriter, op))) + return failure(); + Location loc = op.getLoc(); + auto computeType = genConstInt32FromOptionalComputeMode( + rewriter, loc, adaptor.getComputeType(), + getSpMatElemType(op.getSpmatC())); + auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); + auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB()); + auto stream = adaptor.getAsyncDependencies().front(); + Value pBuf = + MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc); + if (!getTypeConverter()->useOpaquePointers()) + pBuf = rewriter.create(loc, llvmPointerType, pBuf); + SDDMMCallBuilder.create(loc, rewriter, + {adaptor.getEnv(), modeA, modeB, + adaptor.getDnmatA(), adaptor.getDnmatB(), + adaptor.getSpmatC(), computeType, pBuf, stream}); + rewriter.replaceOp(op, {stream}); + return success(); + } + + void mlir::populateGpuToLLVMConversionPatterns( + LLVMTypeConverter & converter, RewritePatternSet & patterns, + StringRef gpuBinaryAnnotation, bool kernelBarePtrCallConv) { + addOpaquePointerConversion(converter); + addOpaquePointerConversion(converter); + 
addOpaquePointerConversion(converter); + addOpaquePointerConversion(converter); + addOpaquePointerConversion(converter); + + patterns.add(converter); + patterns.add( + converter, gpuBinaryAnnotation, kernelBarePtrCallConv); + patterns.add(&converter.getContext()); + } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp @@ -556,12 +556,12 @@ rowA, colA, valA, isCOO, enableRT); Value spMatA = spGenA->getResult(0); token = spGenA->getResult(1); - auto dmatB = rewriter.create(loc, dnMatHandleTp, tokenTp, - token, szk, szn, matB); + auto dmatB = rewriter.create( + loc, dnMatHandleTp, tokenTp, token, handle, szk, szn, matB); Value dnB = dmatB.getResult(0); token = dmatB.getAsyncToken(); - auto dmatC = rewriter.create(loc, dnMatHandleTp, tokenTp, - token, szm, szn, matC); + auto dmatC = rewriter.create( + loc, dnMatHandleTp, tokenTp, token, handle, szm, szn, matC); Value dnC = dmatC.getResult(0); token = dmatC.getAsyncToken(); diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt --- a/mlir/lib/ExecutionEngine/CMakeLists.txt +++ b/mlir/lib/ExecutionEngine/CMakeLists.txt @@ -200,15 +200,36 @@ EXCLUDE_FROM_LIBMLIR ) set_property(TARGET mlir_cuda_runtime PROPERTY CXX_STANDARD 14) - target_include_directories(mlir_cuda_runtime - PRIVATE - ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} - ) - target_link_libraries(mlir_cuda_runtime - PRIVATE - ${CUDA_RUNTIME_LIBRARY} - ${CUDA_CUSPARSE_LIBRARY} - ) + + + # We need the cusparseLT to provide 2:4 sparsity support. + # As of the pre-1.0 version, we suppose the cusparselt is downloaded as an + # archive and extracted in an exclusive directory CUDA_CUSPARSELT_DIR, rather + # than installed by the package manager. This is the same as Nvidia examples. + if (DEFINED CUDA_CUSPARSELT_DIR) + target_include_directories(mlir_cuda_runtime + PRIVATE + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} + ${CUDA_CUSPARSELT_DIR}/include + ) + target_link_libraries(mlir_cuda_runtime + PRIVATE + ${CUDA_RUNTIME_LIBRARY} + ${CUDA_CUSPARSE_LIBRARY} + ${CUDA_CUSPARSELT_DIR}/lib64 + ) + else() + target_include_directories(mlir_cuda_runtime + PRIVATE + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} + ) + target_link_libraries(mlir_cuda_runtime + PRIVATE + ${CUDA_RUNTIME_LIBRARY} + ${CUDA_CUSPARSE_LIBRARY} + ) + endif() + add_definitions(-DMLIR_CUDA_CUSPARSELT_ENABLED=(defined(CUDA_CUSPARSELT_DIR))) endif() if(MLIR_ENABLE_ROCM_RUNNER) diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp --- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp +++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp @@ -21,6 +21,10 @@ #include "cuda_fp16.h" #include "cusparse.h" +#if MLIR_CUDA_CUSPARSELT_ENABLED +#include "cusparseLt.h" +#endif // MLIR_CUDA_CUSPARSELT_ENABLED + #ifdef _WIN32 #define MLIR_CUDA_WRAPPERS_EXPORT __declspec(dllexport) #else @@ -432,3 +436,148 @@ matB, betap, matC, cTp, CUSPARSE_SDDMM_ALG_DEFAULT, buf)) } + +/// +/// Wrapper methods for the cuSparseLt library. +/// +#if MLIR_CUDA_CUSPARSELT_ENABLED +struct cusparseLtSpMatHandleAndData { + cusparseLtMatDescriptor_t mat; + void *values{nullptr}; + // TODO: the following is associated with the SpMM operator rather than the + // sparse matrix. Create workspace buffers and pass them to the SpMM + // execution. 
+  cusparseLtMatmulAlgSelection_t alg_sel;
+  cusparseLtMatmulPlan_t plan;
+  cusparseLtMatmulDescriptor_t matmul;
+};
+struct cusparseLtDnMatHandleAndData {
+  cusparseLtMatDescriptor_t mat;
+  void *values{nullptr};
+};
+struct cusparseLtWorkspaceSizes {
+  size_t workspace_size;
+  size_t compressed_size;
+  size_t compressed_buffer_size;
+};
+
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
+mgpuCreateSparseLtEnv(CUstream /*stream*/) {
+  // cusparseLtHandle_t is an opaque struct, so allocate it on the heap and
+  // hand the pointer back as the environment handle.
+  auto *handle = new cusparseLtHandle_t();
+  // Note that cuSparseLt still uses cusparseStatus_t.
+  CUSPARSE_REPORT_IF_ERROR(cusparseLtInit(handle))
+  return reinterpret_cast<void *>(handle);
+}
+
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuDestroySparseLtEnv(void *h, CUstream /*stream*/) {
+  auto *handle = reinterpret_cast<cusparseLtHandle_t *>(h);
+  CUSPARSE_REPORT_IF_ERROR(cusparseLtDestroy(handle))
+  delete handle;
+}
+
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
+mgpuCreateCuSparseLtDnMat(void *h, intptr_t rows, intptr_t cols, void *values,
+                          int32_t dw, CUstream /*stream*/) {
+  auto *handle = reinterpret_cast<cusparseLtHandle_t *>(h);
+  cudaDataType_t dtp = dataTp(dw);
+  // The descriptor must outlive this call, so allocate the handle-and-data
+  // struct on the heap instead of returning the address of a local.
+  auto *matWithData = new cusparseLtDnMatHandleAndData;
+  matWithData->values = values;
+  // Assuming row-major when deciding lda.
+  CUSPARSE_REPORT_IF_ERROR(cusparseLtDenseDescriptorInit(
+      handle, &(matWithData->mat), rows, cols, /*lda=*/cols, /*alignment=*/16,
+      dtp, CUSPARSE_ORDER_ROW))
+  return reinterpret_cast<void *>(matWithData);
+}
+
+// Releases a cusparseLt sparse matrix descriptor together with the matmul
+// plan associated with it; dense matrices are destroyed by the dnmat variant
+// below.
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuDestroyCuSparseLtSpMat(void *m, CUstream /*stream*/) {
+  auto *matAndData = reinterpret_cast<cusparseLtSpMatHandleAndData *>(m);
+  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(matAndData->mat)))
+  // Destroy the plan associated with the sparse matrix.
+  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulPlanDestroy(&(matAndData->plan)))
+  delete matAndData;
+}
+
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuDestroyCuSparseLtDnMat(void *m, CUstream /*stream*/) {
+  auto *matAndData = reinterpret_cast<cusparseLtDnMatHandleAndData *>(m);
+  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(matAndData->mat)))
+  delete matAndData;
+}
+
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
+mgpuCusparseLtCreate2To4SpMat(void *h, intptr_t rows, intptr_t cols,
+                              void *values, int32_t dw, CUstream /*stream*/) {
+  auto *handle = reinterpret_cast<cusparseLtHandle_t *>(h);
+  // Heap-allocate the handle-and-data struct so the returned pointer stays
+  // valid after this call returns.
+  auto *matWithData = new cusparseLtSpMatHandleAndData;
+  matWithData->values = values;
+  cudaDataType_t dtp = dataTp_cusparseLt(dw);
+  // Assuming row-major when deciding lda.
+  CUSPARSE_REPORT_IF_ERROR(cusparseLtStructuredDescriptorInit(
+      handle, &(matWithData->mat), rows, cols, /*ld=*/cols, /*alignment=*/16,
+      dtp, CUSPARSE_ORDER_ROW, CUSPARSELT_SPARSITY_50_PERCENT))
+  return reinterpret_cast<void *>(matWithData);
+}
+
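For reference, the 2:4 ("50% structured sparsity") pattern that `CUSPARSELT_SPARSITY_50_PERCENT` selects above requires at most two non-zero values in every aligned group of four consecutive elements along a row. A minimal host-side sketch of that constraint check, purely illustrative and not part of the wrappers (the helper name and the use of `float` are assumptions):

```c++
#include <cstddef>

// Returns true if a row-major rows x cols matrix obeys the 2:4 pattern:
// every aligned group of four consecutive elements of a row contains at
// most two non-zeros. cuSPARSELt performs its own validation; this only
// illustrates the constraint the op description refers to.
static bool isTwoToFourSparse(const float *data, size_t rows, size_t cols) {
  for (size_t i = 0; i < rows; ++i) {
    for (size_t j = 0; j + 4 <= cols; j += 4) {
      int nonZeros = 0;
      for (size_t k = 0; k < 4; ++k)
        if (data[i * cols + j + k] != 0.0f)
          ++nonZeros;
      if (nonZeros > 2)
        return false;
    }
  }
  return true;
}
```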
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void * +mgpuCuSparseLtSpMMBufferSize(void *h, void *a, CUstream /*stream*/) { + // TODO: support more advanced settings, e.g., the input right operand is a + // sparse matrix assuming matA is the sparse matrix + auto handle = reinterpret_cast(h); + auto matA = reinterpret_cast(a); + + CHECK_CUSPARSE(cusparseLtMatmulAlgSelectionInit( + handle, &(matWithData.alg_sel), &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT)) + int alg = 0; + CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute( + handle, &(matWithData.alg_sel), CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg, + sizeof(alg))) + // TODO: add transpose support + CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit( + handle, &(matA.matmul), c, CUSPARSE_OPERATION_NON_TRANSPOSE, &(matA->mat), + &matB, &matC, &matC, compute_type)) + CHECK_CUSPARSE(cusparseLtMatmulPlanInit(handle, &(matWithData.plan), &matmul, + &(matWithData.alg_sel))) + + CHECK_CUSPARSE(cusparseLtMatmulGetWorkspace(handle, &(matA.plan), + &(sizes.workspace_size))) + CHECK_CUSPARSE(cusparseLtSpMMACompressedSize(handle, &(matA.plan), + &(sizes.compressed_size), + &(sizes.compressed_buffer_size))) + // avoid zero-alloc + sizes.workspace_size = (sizes.workspace_size == 0 ? 1 : sizes.workspace_size); + sizes.compressed_size = + (sizes.compressed_size == 0 ? 1 : sizes.compressed_size); + sizes.compressed_buffer_size = + (sizes.compressed_buffer_size == 0 ? 1 : sizes.compressed_buffer_size); + return reinterpret_cast(sizes); +} + +extern "C" MLIR_CUDA_WRAPPERS_EXPORT void +mgpuCuSparseLtSpMM(void *h, void *a, void *b, void *c, int32_t dw, void *buf, + void *dA_compressed, void *dA_compressedBuffer, + CUstream stream) { + auto handle = reinterpret_cast(h); + auto matA = reinterpret_cast(a); + auto matB = reinterpret_cast(b); + auto matC = reinterpret_cast(c); + ALPHABETA(dw, alpha, beta) + + CHECK_CUSPARSE(cusparseLtSpMMACompress(handle, &(matA->plan), &(matA->values), + dA_compressed, dA_compressedBuffer, + stream)) + + // TODO: add support to multi-stream execution + // Perform the matrix multiplication. 
+
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuCuSparseLtSpMM(void *h, void *a, void *b, void *c, int32_t dw, void *buf,
+                   void *dA_compressed, void *dA_compressedBuffer,
+                   CUstream stream) {
+  auto *handle = reinterpret_cast<cusparseLtHandle_t *>(h);
+  auto *matA = reinterpret_cast<cusparseLtSpMatHandleAndData *>(a);
+  auto *matB = reinterpret_cast<cusparseLtDnMatHandleAndData *>(b);
+  auto *matC = reinterpret_cast<cusparseLtDnMatHandleAndData *>(c);
+  ALPHABETA(dw, alpha, beta)
+
+  // Compress the sparse operand into dA_compressed, using dA_compressedBuffer
+  // as scratch space.
+  CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMACompress(
+      handle, &(matA->plan), matA->values, dA_compressed, dA_compressedBuffer,
+      stream))
+
+  // TODO: add support for multi-stream execution.
+  // Perform the matrix multiplication D = A*B+C, using C as D for now, with
+  // buf as the cusparseLt workspace.
+  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmul(
+      handle, &(matA->plan), &alpha, dA_compressed, matB->values, &beta,
+      matC->values, /*dD=*/matC->values, buf, &stream, 1))
+}
+
+#endif // MLIR_CUDA_CUSPARSELT_ENABLED
\ No newline at end of file
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -28,6 +28,7 @@
 option(MLIR_RUN_X86VECTOR_TESTS "Run X86Vector tests.")
 option(MLIR_RUN_CUDA_TENSOR_CORE_TESTS "Run CUDA Tensor core WMMA tests.")
 option(MLIR_RUN_CUDA_SM80_TESTS "Run CUDA A100 tests.")
+option(MLIR_RUN_CUDA_SM80_SPARSE_TESTS "Run CUDA 2:4 sparsity tests.")
 option(MLIR_RUN_ARM_SVE_TESTS "Run Arm SVE tests.")
 option(MLIR_RUN_ARM_SME_TESTS "Run Arm SME tests.")
@@ -55,6 +56,7 @@
   MLIR_RUN_ARM_SVE_TESTS
   MLIR_RUN_ARM_SME_TESTS
   MLIR_RUN_CUDA_SM80_TESTS
+  MLIR_RUN_CUDA_SM80_SPARSE_TESTS
 )
 configure_lit_site_cfg(
diff --git a/mlir/test/Conversion/GPUCommon/2To4Sparsity/lit.local.cfg b/mlir/test/Conversion/GPUCommon/2To4Sparsity/lit.local.cfg
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/GPUCommon/2To4Sparsity/lit.local.cfg
@@ -0,0 +1,2 @@
+if not config.enable_cuda_runner or not config.mlir_run_cuda_sm80_tests or not config.mlir_run_cuda_sm80_sparse_tests:
+    config.unsupported = True
diff --git a/mlir/test/Conversion/GPUCommon/2To4Sparsity/lower-2to4-sparse-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/2To4Sparsity/lower-2to4-sparse-to-gpu-runtime-calls.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/GPUCommon/2To4Sparsity/lower-2to4-sparse-to-gpu-runtime-calls.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s --gpu-to-llvm='use-opaque-pointers=1' | FileCheck %s
+
+module attributes {gpu.container_module} {
+
+  // CHECK-LABEL: func @matmul
+  // CHECK: llvm.call @mgpuStreamCreate
+  // CHECK: llvm.call @mgpuMemAlloc
+  // CHECK: llvm.call @mgpuMemAlloc
+  // CHECK: llvm.call @mgpuCreateSparseLtEnv
+  // CHECK: llvm.call @mgpuCusparseLtCreate2To4SpMat
+  // CHECK: llvm.call @mgpuCreateCuSparseLtDnMat
+  // CHECK: llvm.call @mgpuCuSparseLtSpMMBufferSize
+  // CHECK: llvm.call @mgpuCuSparseLtSpMM
+  // CHECK: llvm.call @mgpuDestroyCuSparseLtSpMat
+  // CHECK: llvm.call @mgpuDestroyCuSparseLtDnMat
+  // CHECK: llvm.call @mgpuDestroySparseLtEnv
+  // CHECK: llvm.call @mgpuStreamSynchronize
+  // CHECK: llvm.call @mgpuStreamDestroy
+  func.func @matmul(%arg0: index) {
+    %token0 = gpu.wait async
+    %mem1, %token1 = gpu.alloc async [%token0] (%arg0) : memref
+    %mem2, %token2 = gpu.alloc async [%token1] (%arg0) : memref
+    %env, %token3 = gpu.create_sparse_env async [%token2]
+    %spmat, %token4 = gpu.create_2to4_spmat async [%token3] %env, %arg0, %arg0, %mem1 : memref
+    %dnmat, %token5 = gpu.create_dn_mat async [%token4] %env, %arg0, %arg0, %mem2 : memref
+    %bufferSzs, %token6 = gpu.spmm_buffer_size async [%token5] %env, %spmat, %dnmat, %dnmat : index, index, index
+    %token7 = gpu.spmm async [%token6] %env, %spmat, %dnmat, %dnmat, %mem2, %mem2, %mem2 : memref, memref, memref
+    %token8 = gpu.destroy_sp_mat async [%token7] %spmat
+    %token9 = gpu.destroy_dn_mat async [%token8] %dnmat
+    %token10 = gpu.destroy_sparse_env async [%token9] %env
+    gpu.wait [%token10]
+    return
+  }
+
+}
diff --git a/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
--- a/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
@@ -52,8 +52,8 @@
   %mem2, %token2 = gpu.alloc async [%token1] (%arg0) : memref
   %env, %token3 = gpu.create_sparse_env async [%token2]
   %spmat, %token4 = gpu.create_csr async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref, memref, memref
-  %dnmat, %token5 = gpu.create_dn_mat async [%token4] %arg0, %arg0, %mem2 : memref
-  %bufferSz, %token6 = gpu.spmm_buffer_size async [%token5] %env, %spmat, %dnmat, %dnmat
+  %dnmat, %token5 = gpu.create_dn_mat async [%token4] %env, %arg0, %arg0, %mem2 : memref
+  %bufferSz, %token6 = gpu.spmm_buffer_size async [%token5] %env, %spmat, %dnmat, %dnmat : index
   %token7 = gpu.spmm async [%token6] %env, %spmat, %dnmat, %dnmat, %mem2 : memref
   %token8 = gpu.destroy_sp_mat async [%token7] %spmat
   %token9 = gpu.destroy_dn_mat async [%token8] %dnmat
@@ -82,7 +82,7 @@
   %mem2, %token2 = gpu.alloc async [%token1] (%arg0) : memref
   %env, %token3 = gpu.create_sparse_env async [%token2]
   %spmat, %token4 = gpu.create_csr async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref, memref, memref
-  %dnmat, %token5 = gpu.create_dn_mat async [%token4] %arg0, %arg0, %mem2 : memref
+  %dnmat, %token5 = gpu.create_dn_mat async [%token4] %env, %arg0, %arg0, %mem2 : memref
   %bufferSz, %token6 = gpu.sddmm_buffer_size async [%token5] %env, %dnmat, %dnmat, %spmat
   %token7 = gpu.sddmm async [%token6] %env, %dnmat, %dnmat, %spmat, %mem2 : memref
   %token8 = gpu.destroy_sp_mat async [%token7] %spmat
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -339,9 +339,9 @@
   // CHECK: gpu.spmv async
   %token8 = gpu.spmv async [%token7] %env, %spmat, %dnvec, %dnvec, %mem2 : memref
   // CHECK: gpu.create_dn_mat async
-  %dnmat, %token9 = gpu.create_dn_mat async [%token8] %arg0, %arg0, %mem2 : memref
+  %dnmat, %token9 = gpu.create_dn_mat async [%token8] %env, %arg0, %arg0, %mem2 : memref
   // CHECK: gpu.spmm_buffer_size async
-  %bufferSz2, %token10 = gpu.spmm_buffer_size async [%token9] %env, %spmat, %dnmat, %dnmat
+  %bufferSz2, %token10 = gpu.spmm_buffer_size async [%token9] %env, %spmat, %dnmat, %dnmat : index
   // CHECK: gpu.spmm async
   %token11 = gpu.spmm async [%token10] %env, %spmat, %dnmat, %dnmat, %mem2 : memref
   // CHECK: gpu.sddmm_buffer_size async
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib.mlir
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib.mlir
@@ -47,9 +47,9 @@
 // CHECK: %[[VAL_41:.*]] = gpu.wait async
 // CHECK: %[[VAL_42:.*]], %[[VAL_43:.*]] = gpu.create_sparse_env async {{\[}}%[[VAL_41]]]
 // CHECK: %[[VAL_44:.*]], %[[VAL_45:.*]] = gpu.create_csr async {{\[}}%[[VAL_43]]] %[[VAL_6]], %[[VAL_7]], %[[VAL_5]], %[[VAL_14]], %[[VAL_19]], %[[VAL_24]] : memref, memref, memref
-// CHECK: %[[VAL_46:.*]], %[[VAL_47:.*]] = gpu.create_dn_mat async {{\[}}%[[VAL_45]]] %[[VAL_7]], %[[VAL_8]], %[[VAL_31]] : memref
-// CHECK: %[[VAL_48:.*]], %[[VAL_49:.*]] = gpu.create_dn_mat async {{\[}}%[[VAL_47]]] %[[VAL_6]], %[[VAL_8]], %[[VAL_38]] : memref
-// CHECK: %[[VAL_50:.*]], %[[VAL_51:.*]] = gpu.spmm_buffer_size async {{\[}}%[[VAL_49]]] %[[VAL_42]], %[[VAL_44]], %[[VAL_46]], %[[VAL_48]]
+// CHECK: %[[VAL_46:.*]], %[[VAL_47:.*]] = gpu.create_dn_mat async {{\[}}%[[VAL_45]]] %[[VAL_42]], %[[VAL_7]], %[[VAL_8]], %[[VAL_31]] : memref
+// CHECK: %[[VAL_48:.*]], %[[VAL_49:.*]] = gpu.create_dn_mat async {{\[}}%[[VAL_47]]] %[[VAL_42]], %[[VAL_6]], %[[VAL_8]], %[[VAL_38]] : memref
+// CHECK: %[[VAL_50:.*]], %[[VAL_51:.*]] = gpu.spmm_buffer_size async {{\[}}%[[VAL_49]]] %[[VAL_42]], %[[VAL_44]], %[[VAL_46]], %[[VAL_48]] : index
 // CHECK: %[[VAL_52:.*]], %[[VAL_53:.*]] = gpu.alloc async {{\[}}%[[VAL_51]]] (%[[VAL_50]]) : memref
 // CHECK: %[[VAL_54:.*]] = gpu.spmm async {{\[}}%[[VAL_53]]] %[[VAL_42]], %[[VAL_44]], %[[VAL_46]], %[[VAL_48]], %[[VAL_52]] : memref
 // CHECK: %[[VAL_55:.*]] = gpu.destroy_sp_mat async {{\[}}%[[VAL_54]]] %[[VAL_44]]
diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in
--- a/mlir/test/lit.site.cfg.py.in
+++ b/mlir/test/lit.site.cfg.py.in
@@ -41,6 +41,7 @@
 config.mlir_run_riscv_vector_tests = "@MLIR_RUN_RISCV_VECTOR_TESTS@"
 config.mlir_run_cuda_tensor_core_tests = @MLIR_RUN_CUDA_TENSOR_CORE_TESTS@
 config.mlir_run_cuda_sm80_tests = @MLIR_RUN_CUDA_SM80_TESTS@
+config.mlir_run_cuda_sm80_sparse_tests = @MLIR_RUN_CUDA_SM80_SPARSE_TESTS@
 config.mlir_include_integration_tests = @MLIR_INCLUDE_INTEGRATION_TESTS@
 config.arm_emulator_executable = "@ARM_EMULATOR_EXECUTABLE@"
 config.arm_emulator_options = "@ARM_EMULATOR_OPTIONS@"
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -49,6 +49,7 @@
         "@MLIR_RUN_X86VECTOR_TESTS@": "0",
         "@MLIR_RUN_CUDA_TENSOR_CORE_TESTS@": "0",
         "@MLIR_RUN_CUDA_SM80_TESTS@": "0",
+        "@MLIR_RUN_CUDA_SM80_SPARSE_TESTS@": "0",
         "@MLIR_INCLUDE_INTEGRATION_TESTS@": "0",
         "@SHLIBDIR@": package_path("//llvm:BUILD"),
     },
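Taken together, the lowering exercised by the new 2:4 test corresponds to the following host-side call order on the wrappers added in CudaRuntimeWrappers.cpp. This is a sketch only: the prototypes mirror the patch, while the driver function, its argument names, and the assumption that the device buffers are already populated are illustrative.

```c++
#include <cstdint>
#include <cuda.h>

// Prototypes of the cuSPARSELt wrappers added by this patch.
extern "C" void *mgpuCreateSparseLtEnv(CUstream stream);
extern "C" void *mgpuCusparseLtCreate2To4SpMat(void *env, intptr_t rows,
                                               intptr_t cols, void *values,
                                               int32_t dw, CUstream stream);
extern "C" void *mgpuCreateCuSparseLtDnMat(void *env, intptr_t rows,
                                           intptr_t cols, void *values,
                                           int32_t dw, CUstream stream);
extern "C" void *mgpuCuSparseLtSpMMBufferSize(void *env, void *spmat,
                                              CUstream stream);
extern "C" void mgpuCuSparseLtSpMM(void *env, void *a, void *b, void *c,
                                   int32_t dw, void *workspace,
                                   void *aCompressed, void *aCompressedBuffer,
                                   CUstream stream);
extern "C" void mgpuDestroyCuSparseLtSpMat(void *m, CUstream stream);
extern "C" void mgpuDestroyCuSparseLtDnMat(void *m, CUstream stream);
extern "C" void mgpuDestroySparseLtEnv(void *env, CUstream stream);

// Illustrative driver mirroring the call order checked by the lit test:
// create env -> create 2:4 A -> create dense B and C -> query sizes ->
// SpMM -> destroy everything. Device pointers, the scratch buffers, and the
// data-type code dw are assumed to be prepared by the caller.
static void run2To4SpMM(void *dA, void *dB, void *dC, intptr_t m, intptr_t n,
                        intptr_t k, int32_t dw, void *dWorkspace,
                        void *dACompressed, void *dACompressedBuffer,
                        CUstream stream) {
  void *env = mgpuCreateSparseLtEnv(stream);
  void *spA = mgpuCusparseLtCreate2To4SpMat(env, m, k, dA, dw, stream);
  void *dnB = mgpuCreateCuSparseLtDnMat(env, k, n, dB, dw, stream);
  void *dnC = mgpuCreateCuSparseLtDnMat(env, m, n, dC, dw, stream);
  // In the real lowering the returned sizes drive the buffer allocations;
  // here the buffers are assumed to be preallocated.
  (void)mgpuCuSparseLtSpMMBufferSize(env, spA, stream);
  mgpuCuSparseLtSpMM(env, spA, dnB, dnC, dw, dWorkspace, dACompressed,
                     dACompressedBuffer, stream);
  mgpuDestroyCuSparseLtSpMat(spA, stream);
  mgpuDestroyCuSparseLtDnMat(dnB, stream);
  mgpuDestroyCuSparseLtDnMat(dnC, stream);
  mgpuDestroySparseLtEnv(env, stream);
}
```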