// Patch summary: in GPUToLLVMConversion.cpp the thirteen hand-written
// Convert*ToGpuRuntimeCallPattern class declarations for the sparse ops are
// removed and re-declared via DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN;
// the visible part of the pattern registration list is reordered so that its
// entries follow the same order as those declarations.
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -529,175 +529,8 @@ ConversionPatternRewriter &rewriter) const override; }; -class ConvertCreateDnTensorOpToGpuRuntimeCallPattern - : public ConvertOpToGpuRuntimeCallPattern { -public: - ConvertCreateDnTensorOpToGpuRuntimeCallPattern( - LLVMTypeConverter &typeConverter) - : ConvertOpToGpuRuntimeCallPattern(typeConverter) { - } - -private: - LogicalResult - matchAndRewrite(gpu::CreateDnTensorOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class ConvertDestroyDnTensorOpToGpuRuntimeCallPattern - : public ConvertOpToGpuRuntimeCallPattern { -public: - ConvertDestroyDnTensorOpToGpuRuntimeCallPattern( - LLVMTypeConverter &typeConverter) - : ConvertOpToGpuRuntimeCallPattern( - typeConverter) {} - -private: - LogicalResult - matchAndRewrite(gpu::DestroyDnTensorOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class ConvertCreateCooOpToGpuRuntimeCallPattern - : public ConvertOpToGpuRuntimeCallPattern { -public: - ConvertCreateCooOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) - : ConvertOpToGpuRuntimeCallPattern(typeConverter) {} - -private: - LogicalResult - matchAndRewrite(gpu::CreateCooOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class ConvertCreateCooAoSOpToGpuRuntimeCallPattern - : public ConvertOpToGpuRuntimeCallPattern { -public: - ConvertCreateCooAoSOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) - : ConvertOpToGpuRuntimeCallPattern(typeConverter) {} - -private: - LogicalResult - matchAndRewrite(gpu::CreateCooAoSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class ConvertCreateCsrOpToGpuRuntimeCallPattern - : public 
ConvertOpToGpuRuntimeCallPattern { -public: - ConvertCreateCsrOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) - : ConvertOpToGpuRuntimeCallPattern(typeConverter) {} - -private: - LogicalResult - matchAndRewrite(gpu::CreateCsrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern - : public ConvertOpToGpuRuntimeCallPattern { -public: - ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern( - LLVMTypeConverter &typeConverter) - : ConvertOpToGpuRuntimeCallPattern( - typeConverter) {} - -private: - LogicalResult - matchAndRewrite(gpu::Create2To4SpMatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class ConvertDestroySpMatOpToGpuRuntimeCallPattern - : public ConvertOpToGpuRuntimeCallPattern { -public: - ConvertDestroySpMatOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) - : ConvertOpToGpuRuntimeCallPattern(typeConverter) {} - -private: - LogicalResult - matchAndRewrite(gpu::DestroySpMatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern - : public ConvertOpToGpuRuntimeCallPattern { -public: - ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern( - LLVMTypeConverter &typeConverter) - : ConvertOpToGpuRuntimeCallPattern(typeConverter) { - } - -private: - LogicalResult - matchAndRewrite(gpu::SpMVBufferSizeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class ConvertSpMVOpToGpuRuntimeCallPattern - : public ConvertOpToGpuRuntimeCallPattern { -public: - ConvertSpMVOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) - : ConvertOpToGpuRuntimeCallPattern(typeConverter) {} - -private: - LogicalResult - matchAndRewrite(gpu::SpMVOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern - : public ConvertOpToGpuRuntimeCallPattern { 
-public: - ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern( - LLVMTypeConverter &typeConverter) - : ConvertOpToGpuRuntimeCallPattern(typeConverter) { - } - -private: - LogicalResult - matchAndRewrite(gpu::SpMMBufferSizeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern - : public ConvertOpToGpuRuntimeCallPattern { -public: - ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern( - LLVMTypeConverter &typeConverter) - : ConvertOpToGpuRuntimeCallPattern( - typeConverter) {} - -private: - LogicalResult - matchAndRewrite(gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class ConvertSpMMOpToGpuRuntimeCallPattern - : public ConvertOpToGpuRuntimeCallPattern { -public: - ConvertSpMMOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) - : ConvertOpToGpuRuntimeCallPattern(typeConverter) {} - -private: - LogicalResult - matchAndRewrite(gpu::SpMMOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class ConvertSDDMMOpToGpuRuntimeCallPattern - : public ConvertOpToGpuRuntimeCallPattern { -public: - ConvertSDDMMOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) - : ConvertOpToGpuRuntimeCallPattern(typeConverter) {} - -private: - LogicalResult - matchAndRewrite(gpu::SDDMMOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -// TODO: Apply this pattern to all GPU ops. +/// Generic rewriting rule for operations on sparse matrices. +/// Currently supports CUDA (by means of cuSPARSE and cuSPARSELt). 
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name) \ class Convert##op_name##ToGpuRuntimeCallPattern \ : public ConvertOpToGpuRuntimeCallPattern { \ @@ -712,12 +545,25 @@ ConversionPatternRewriter &rewriter) const override; \ }; +DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateDnTensorOp) +DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroyDnTensorOp) +DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooOp) +DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooAoSOp) +DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCsrOp) +DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(Create2To4SpMatOp) +DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroySpMatOp) +DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVBufferSizeOp) +DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVOp) +DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMBufferSizeOp) +DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMBufferSizeOp) +DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMOp) +DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMOp) DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCreateDescrOp) DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMDestroyDescrOp) DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp) DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMEstimateMemoryOp) -DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMGetSizeOp) DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCopyOp) +DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMGetSizeOp) } // namespace @@ -2054,18 +1900,18 @@ ConvertCreateCsrOpToGpuRuntimeCallPattern, ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern, ConvertDestroySpMatOpToGpuRuntimeCallPattern, + ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern, + ConvertSpMVOpToGpuRuntimeCallPattern, + ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern, + ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern, + ConvertSpMMOpToGpuRuntimeCallPattern, + ConvertSDDMMOpToGpuRuntimeCallPattern, 
ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern, ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern, ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern, ConvertSpGEMMEstimateMemoryOpToGpuRuntimeCallPattern, - ConvertSpGEMMGetSizeOpToGpuRuntimeCallPattern, ConvertSpGEMMCopyOpToGpuRuntimeCallPattern, - ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern, - ConvertSpMVOpToGpuRuntimeCallPattern, - ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern, - ConvertSpMMOpToGpuRuntimeCallPattern, - ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern, - ConvertSDDMMOpToGpuRuntimeCallPattern>(converter); + ConvertSpGEMMGetSizeOpToGpuRuntimeCallPattern>(converter); patterns.add( converter, gpuBinaryAnnotation, kernelBarePtrCallConv); patterns.add(&converter.getContext());