diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td --- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td @@ -109,10 +109,30 @@ "::llvm::cast<::mlir::gpu::MMAMatrixType>($_self).getElementType()", "gpu.mma_matrix", "::mlir::gpu::MMAMatrixType">; -// Generic type for all sparse handles (could be refined). -def GPU_SparseHandle : DialectType< - GPU_Dialect, CPred<"$_self.isa<::mlir::gpu::SparseHandleType>()">, "sparse handle type">, - BuildableType<"mlir::gpu::SparseHandleType::get($_builder.getContext())">; +// Types for all sparse handles. +def GPU_SparseEnvHandle : + DialectType()">, + "sparse environment handle type">, + BuildableType<"mlir::gpu::SparseEnvHandleType::get($_builder.getContext())">; + +def GPU_SparseDnVecHandle : + DialectType()">, + "sparse dense vector handle type">, + BuildableType<"mlir::gpu::SparseDnVecHandleType::get($_builder.getContext())">; + +def GPU_SparseDnMatHandle : + DialectType()">, + "sparse dense matrix handle type">, + BuildableType<"mlir::gpu::SparseDnMatHandleType::get($_builder.getContext())">; + +def GPU_SparseSpMatHandle : + DialectType()">, + "sparse matrix handle type in sparse">, + BuildableType<"mlir::gpu::SparseSpMatHandleType::get($_builder.getContext())">; //===----------------------------------------------------------------------===// // GPU Interfaces. diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h --- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h @@ -163,14 +163,23 @@ // Adds a `gpu.async.token` to the front of the argument list. void addAsyncDependency(Operation *op, Value token); -// Represents any sparse handle. +// Handle types for sparse. 
+enum class SparseHandleKind { Env, DnVec, DnMat, SpMat }; + +template class SparseHandleType - : public Type::TypeBase { + : public Type::TypeBase, Type, TypeStorage> { public: - // Used for generic hooks in TypeBase. + using Base = + typename Type::TypeBase, Type, TypeStorage>::Base; using Base::Base; }; +using SparseEnvHandleType = SparseHandleType; +using SparseDnVecHandleType = SparseHandleType; +using SparseDnMatHandleType = SparseHandleType; +using SparseSpMatHandleType = SparseHandleType; + } // namespace gpu } // namespace mlir diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1557,14 +1557,16 @@ }]; let arguments = (ins Variadic:$asyncDependencies); - let results = (outs Res:$env, Optional:$asyncToken); - + let results = (outs Res:$env, + Optional:$asyncToken); let assemblyFormat = [{ custom(type($asyncToken), $asyncDependencies) attr-dict }]; } -def GPU_DestroySparseEnvOp : GPU_Op<"destroy_sparse_env", [GPU_AsyncOpInterface]> { +def GPU_DestroySparseEnvOp : GPU_Op< + "destroy_sparse_env", + [GPU_AsyncOpInterface]> { let summary = "Destroy sparse environment operation"; let description = [{ The `gpu.destroy_sparse_env` operation releases all resources of a sparse @@ -1583,11 +1585,12 @@ }]; let arguments = (ins Variadic:$asyncDependencies, - Arg:$env); + Arg:$env); let results = (outs Optional:$asyncToken); let assemblyFormat = [{ - custom(type($asyncToken), $asyncDependencies) $env attr-dict + custom(type($asyncToken), $asyncDependencies) + $env attr-dict }]; } @@ -1612,7 +1615,8 @@ let arguments = (ins Variadic:$asyncDependencies, AnyMemRef:$memref, Index:$size); - let results = (outs Res:$dvec, Optional:$asyncToken); + let results = (outs Res:$dvec, + Optional:$asyncToken); let assemblyFormat = [{ custom(type($asyncToken), $asyncDependencies) @@ -1639,11 +1643,12 @@ }]; let arguments = (ins 
Variadic:$asyncDependencies, - Arg:$dvec); + Arg:$dvec); let results = (outs Optional:$asyncToken); let assemblyFormat = [{ - custom(type($asyncToken), $asyncDependencies) $dvec attr-dict + custom(type($asyncToken), $asyncDependencies) + $dvec attr-dict }]; } @@ -1667,10 +1672,10 @@ }]; let arguments = (ins Variadic:$asyncDependencies, - Index:$rows, - Index:$cols, - AnyMemRef:$memref); - let results = (outs Res:$dmat, Optional:$asyncToken); + Index:$rows, + Index:$cols, + AnyMemRef:$memref); + let results = (outs Res:$dmat, Optional:$asyncToken); let assemblyFormat = [{ custom(type($asyncToken), $asyncDependencies) @@ -1697,11 +1702,12 @@ }]; let arguments = (ins Variadic:$asyncDependencies, - Arg:$dmat); + Arg:$dmat); let results = (outs Optional:$asyncToken); let assemblyFormat = [{ - custom(type($asyncToken), $asyncDependencies) $dmat attr-dict + custom(type($asyncToken), $asyncDependencies) + $dmat attr-dict }]; } @@ -1732,7 +1738,7 @@ AnyMemRef:$rowIdxs, AnyMemRef:$colIdxs, AnyMemRef:$values); - let results = (outs Res:$spmat, Optional:$asyncToken); + let results = (outs Res:$spmat, Optional:$asyncToken); let assemblyFormat = [{ custom(type($asyncToken), $asyncDependencies) @@ -1769,7 +1775,7 @@ AnyMemRef:$rowPos, AnyMemRef:$colIdxs, AnyMemRef:$values); - let results = (outs Res:$spmat, Optional:$asyncToken); + let results = (outs Res:$spmat, Optional:$asyncToken); let assemblyFormat = [{ custom(type($asyncToken), $asyncDependencies) @@ -1797,7 +1803,7 @@ }]; let arguments = (ins Variadic:$asyncDependencies, - Arg:$spmat); + Arg:$spmat); let results = (outs Optional:$asyncToken); let assemblyFormat = [{ @@ -1823,12 +1829,11 @@ %buffersz, %token = gpu.spmv_buffersize async [%dep] %env, %spmatA, %dnX, %dnY ``` }]; - let arguments = (ins Variadic:$asyncDependencies, - GPU_SparseHandle:$env, - GPU_SparseHandle:$spmatA, - GPU_SparseHandle:$dnX, - GPU_SparseHandle:$dnY); + GPU_SparseEnvHandle:$env, + GPU_SparseSpMatHandle:$spmatA, + GPU_SparseDnVecHandle:$dnX, + 
GPU_SparseDnVecHandle:$dnY); let results = (outs Res:$bufferSz, Optional:$asyncToken); let assemblyFormat = [{ @@ -1855,12 +1860,11 @@ %token = gpu.spmv async [%dep] %env, %spmatA, %dnX, %dnY : memref ``` }]; - let arguments = (ins Variadic:$asyncDependencies, - GPU_SparseHandle:$env, - GPU_SparseHandle:$spmatA, - GPU_SparseHandle:$dnX, - GPU_SparseHandle:$dnY, + GPU_SparseEnvHandle:$env, + GPU_SparseSpMatHandle:$spmatA, + GPU_SparseDnVecHandle:$dnX, + GPU_SparseDnVecHandle:$dnY, AnyMemRef:$buffer); let results = (outs Optional:$asyncToken); @@ -1890,10 +1894,10 @@ }]; let arguments = (ins Variadic:$asyncDependencies, - GPU_SparseHandle:$env, - GPU_SparseHandle:$spmatA, - GPU_SparseHandle:$dnmatB, - GPU_SparseHandle:$dnmatC); + GPU_SparseEnvHandle:$env, + GPU_SparseSpMatHandle:$spmatA, + GPU_SparseDnMatHandle:$dnmatB, + GPU_SparseDnMatHandle:$dnmatC); let results = (outs Res:$bufferSz, Optional:$asyncToken); let assemblyFormat = [{ @@ -1922,10 +1926,10 @@ }]; let arguments = (ins Variadic:$asyncDependencies, - GPU_SparseHandle:$env, - GPU_SparseHandle:$spmatA, - GPU_SparseHandle:$dnmatB, - GPU_SparseHandle:$dnmatC, + GPU_SparseEnvHandle:$env, + GPU_SparseSpMatHandle:$spmatA, + GPU_SparseDnMatHandle:$dnmatB, + GPU_SparseDnMatHandle:$dnmatC, AnyMemRef:$buffer); let results = (outs Optional:$asyncToken); diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -1478,7 +1478,22 @@ return converter.getPointerType( IntegerType::get(&converter.getContext(), 8)); }); - converter.addConversion([&converter](gpu::SparseHandleType type) -> Type { + converter.addConversion( + [&converter](gpu::SparseDnVecHandleType type) -> Type { + return converter.getPointerType( + IntegerType::get(&converter.getContext(), 8)); + }); + converter.addConversion( + [&converter](gpu::SparseDnMatHandleType type) 
-> Type { + return converter.getPointerType( + IntegerType::get(&converter.getContext(), 8)); + }); + converter.addConversion( + [&converter](gpu::SparseSpMatHandleType type) -> Type { + return converter.getPointerType( + IntegerType::get(&converter.getContext(), 8)); + }); + converter.addConversion([&converter](gpu::SparseEnvHandleType type) -> Type { return converter.getPointerType( IntegerType::get(&converter.getContext(), 8)); }); diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -146,7 +146,10 @@ void GPUDialect::initialize() { addTypes(); addTypes(); - addTypes(); + addTypes(); + addTypes(); + addTypes(); + addTypes(); addOperations< #define GET_OP_LIST #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc" @@ -201,17 +204,27 @@ shape, elementType, operand); } - if (keyword == "sparse.handle") - return SparseHandleType::get(context); + if (keyword == "sparse.env_handle") + return SparseEnvHandleType::get(context); + if (keyword == "sparse.spmat_handle") + return SparseSpMatHandleType::get(context); + if (keyword == "sparse.dnvec_handle") + return SparseDnVecHandleType::get(context); + if (keyword == "sparse.dnmat_handle") + return SparseDnMatHandleType::get(context); parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword); return Type(); } - +// TODO: print refined type here. 
Note that the printed keywords must match the +// parser void GPUDialect::printType(Type type, DialectAsmPrinter &os) const { TypeSwitch(type) .Case([&](Type) { os << "async.token"; }) - .Case([&](Type) { os << "sparse.handle"; }) + .Case([&](Type) { os << "sparse.env_handle"; }) + .Case([&](Type) { os << "sparse.spmat_handle"; }) + .Case([&](Type) { os << "sparse.dnmat_handle"; }) + .Case([&](Type) { os << "sparse.dnvec_handle"; }) .Case([&](MMAMatrixType fragTy) { os << "mma_matrix<"; auto shape = fragTy.getShape(); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp @@ -436,22 +436,25 @@ // Create sparse environment and sparse matrix/dense vector handles. Type indexTp = rewriter.getIndexType(); - Type handleTp = rewriter.getType(); + Type envHandleTp = rewriter.getType(); + Type dnVecHandleTp = rewriter.getType(); + Type spMatHandleTp = rewriter.getType(); Type tokenTp = rewriter.getType(); Value token = genFirstWait(rewriter, loc); auto env = - rewriter.create(loc, handleTp, tokenTp, token); + rewriter.create(loc, envHandleTp, tokenTp, token); Value handle = env.getResult(0); token = env.getAsyncToken(); - Operation *spGenA = genSpMat(rewriter, loc, handleTp, tokenTp, token, szY, - szX, nseA, rowA, colA, valA, isCOO, enableRT); + Operation *spGenA = + genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szY, szX, nseA, + rowA, colA, valA, isCOO, enableRT); Value spMatA = spGenA->getResult(0); token = spGenA->getResult(1); - auto dvecX = rewriter.create(loc, handleTp, tokenTp, + auto dvecX = rewriter.create(loc, dnVecHandleTp, tokenTp, token, vecX, szX); Value dnX = dvecX.getResult(0); token = dvecX.getAsyncToken(); - auto dvecY = rewriter.create(loc, handleTp, tokenTp, + auto dvecY = rewriter.create(loc, dnVecHandleTp, tokenTp, token, vecY, 
szY); Value dnY = dvecY.getResult(0); token = dvecY.getAsyncToken(); @@ -540,22 +543,24 @@ // Create sparse environment and sparse matrix/dense matrix handles. Type indexTp = rewriter.getIndexType(); - Type handleTp = rewriter.getType(); + Type envHandleTp = rewriter.getType(); + Type dnMatHandleTp = rewriter.getType(); + Type spMatHandleTp = rewriter.getType(); Type tokenTp = rewriter.getType(); Value token = genFirstWait(rewriter, loc); auto env = - rewriter.create(loc, handleTp, tokenTp, token); + rewriter.create(loc, envHandleTp, tokenTp, token); Value handle = env.getResult(0); token = env.getAsyncToken(); - Operation *spGenA = genSpMat(rewriter, loc, handleTp, tokenTp, token, szm, + Operation *spGenA = genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szk, nseA, rowA, colA, valA, isCOO, enableRT); Value spMatA = spGenA->getResult(0); token = spGenA->getResult(1); - auto dmatB = rewriter.create(loc, handleTp, tokenTp, + auto dmatB = rewriter.create(loc, dnMatHandleTp, tokenTp, token, szk, szn, matB); Value dnB = dmatB.getResult(0); token = dmatB.getAsyncToken(); - auto dmatC = rewriter.create(loc, handleTp, tokenTp, + auto dmatC = rewriter.create(loc, dnMatHandleTp, tokenTp, token, szm, szn, matC); Value dnC = dmatC.getResult(0); token = dmatC.getAsyncToken();