diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1902,6 +1902,8 @@ }]; } +// TODO: add the GPU_SDDMMOp op definition here + def GPU_SpMMOp : GPU_Op<"spmm", [GPU_AsyncOpInterface]> { let summary = "SpMM operation"; let description = [{ diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -596,6 +596,8 @@ ConversionPatternRewriter &rewriter) const override; }; +// TODO: implement the SDDMM op lowering pattern here + class ConvertSpMMOpToGpuRuntimeCallPattern : public ConvertOpToGpuRuntimeCallPattern { public: diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp @@ -500,6 +500,8 @@ return success(); } +// TODO: implement the SDDMM rewriter here + /// Match and rewrite SpMM kernel. static LogicalResult rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT) {