diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -109,10 +109,19 @@
   "::llvm::cast<::mlir::gpu::MMAMatrixType>($_self).getElementType()",
   "gpu.mma_matrix", "::mlir::gpu::MMAMatrixType">;
 
-// Generic type for all sparse handles (could be refined).
-def GPU_SparseHandle : DialectType<
-  GPU_Dialect, CPred<"$_self.isa<::mlir::gpu::SparseHandleType>()">, "sparse handle type">,
-  BuildableType<"mlir::gpu::SparseHandleType::get($_builder.getContext())">;
+// Types for all sparse handles, one per kind of handle.
+def GPU_SparseEnvHandle : DialectType<
+  GPU_Dialect, CPred<"::llvm::isa<::mlir::gpu::SparseEnvHandleType>($_self)">, "sparse environment handle type">,
+  BuildableType<"mlir::gpu::SparseEnvHandleType::get($_builder.getContext())">;
+def GPU_SparseDnVecHandle : DialectType<
+  GPU_Dialect, CPred<"::llvm::isa<::mlir::gpu::SparseDnVecHandleType>($_self)">, "sparse dense vector handle type">,
+  BuildableType<"mlir::gpu::SparseDnVecHandleType::get($_builder.getContext())">;
+def GPU_SparseDnMatHandle : DialectType<
+  GPU_Dialect, CPred<"::llvm::isa<::mlir::gpu::SparseDnMatHandleType>($_self)">, "sparse dense matrix handle type">,
+  BuildableType<"mlir::gpu::SparseDnMatHandleType::get($_builder.getContext())">;
+def GPU_SparseSpMatHandle : DialectType<
+  GPU_Dialect, CPred<"::llvm::isa<::mlir::gpu::SparseSpMatHandleType>($_self)">, "sparse matrix handle type">,
+  BuildableType<"mlir::gpu::SparseSpMatHandleType::get($_builder.getContext())">;
 
 //===----------------------------------------------------------------------===//
 // GPU Interfaces.
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -163,14 +163,31 @@
 // Adds a `gpu.async.token` to the front of the argument list.
 void addAsyncDependency(Operation *op, Value token);
 
-// Represents any sparse handle.
-class SparseHandleType
-    : public Type::TypeBase<SparseHandleType, Type, TypeStorage> {
+// Kinds of opaque handles used by the sparse GPU operations.
+enum class SparseHandleKind { Env, DnVec, DnMat, SpMat };
+
+// Singleton type for one kind of sparse handle. Each `K` instantiates a
+// distinct class with its own TypeID, so the TypeID-based `classof` that
+// TypeBase provides already tells the kinds apart; no per-instance state is
+// needed (or allowed: a Type only wraps a pointer to its uniqued storage).
+template <SparseHandleKind K>
+class SparseHandleConcreteType
+    : public Type::TypeBase<SparseHandleConcreteType<K>, Type, TypeStorage> {
 public:
   // Used for generic hooks in TypeBase.
+  using Base = typename Type::TypeBase<SparseHandleConcreteType<K>, Type,
+                                       TypeStorage>;
   using Base::Base;
+
+  // Returns the kind of handle this type denotes.
+  static constexpr SparseHandleKind getSparseKind() { return K; }
 };
 
+using SparseEnvHandleType = SparseHandleConcreteType<SparseHandleKind::Env>;
+using SparseDnVecHandleType = SparseHandleConcreteType<SparseHandleKind::DnVec>;
+using SparseDnMatHandleType = SparseHandleConcreteType<SparseHandleKind::DnMat>;
+using SparseSpMatHandleType = SparseHandleConcreteType<SparseHandleKind::SpMat>;
+
 } // namespace gpu
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1557,7 +1557,7 @@
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies);
-  let results = (outs Res<GPU_SparseHandle>:$env, Optional<GPU_AsyncToken>:$asyncToken);
+  let results = (outs Res<GPU_SparseEnvHandle>:$env, Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies) attr-dict
@@ -1583,7 +1583,7 @@
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
-                       Arg<GPU_SparseHandle>:$env);
+                       Arg<GPU_SparseEnvHandle>:$env);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
@@ -1612,7 +1612,7 @@
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
                        AnyMemRef:$memref,
                        Index:$size);
-  let results = (outs Res<GPU_SparseHandle>:$dvec, Optional<GPU_AsyncToken>:$asyncToken);
+  let results = (outs Res<GPU_SparseDnVecHandle>:$dvec, Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
@@ -1639,7 +1639,7 @@
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
-                       Arg<GPU_SparseHandle>:$dvec);
+                       Arg<GPU_SparseDnVecHandle>:$dvec);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
@@ -1670,7 +1670,7 @@
                        Index:$rows,
                        Index:$cols,
                        AnyMemRef:$memref);
-  let results = (outs Res<GPU_SparseHandle>:$dmat, Optional<GPU_AsyncToken>:$asyncToken);
+  let results = (outs Res<GPU_SparseDnMatHandle>:$dmat, Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
@@ -1697,7 +1697,7 @@
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
-                       Arg<GPU_SparseHandle>:$dmat);
+                       Arg<GPU_SparseDnMatHandle>:$dmat);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
@@ -1732,7 +1732,7 @@
                        AnyMemRef:$rowIdxs,
                        AnyMemRef:$colIdxs,
                        AnyMemRef:$values);
-  let results = (outs Res<GPU_SparseHandle>:$spmat, Optional<GPU_AsyncToken>:$asyncToken);
+  let results = (outs Res<GPU_SparseSpMatHandle>:$spmat, Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
@@ -1769,7 +1769,7 @@
                        AnyMemRef:$rowPos,
                        AnyMemRef:$colIdxs,
                        AnyMemRef:$values);
-  let results = (outs Res<GPU_SparseHandle>:$spmat, Optional<GPU_AsyncToken>:$asyncToken);
+  let results = (outs Res<GPU_SparseSpMatHandle>:$spmat, Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
@@ -1797,7 +1797,7 @@
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
-                       Arg<GPU_SparseHandle>:$spmat);
+                       Arg<GPU_SparseSpMatHandle>:$spmat);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
@@ -1825,10 +1825,10 @@
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
-                       GPU_SparseHandle:$env,
-                       GPU_SparseHandle:$spmatA,
-                       GPU_SparseHandle:$dnX,
-                       GPU_SparseHandle:$dnY);
+                       GPU_SparseEnvHandle:$env,
+                       GPU_SparseSpMatHandle:$spmatA,
+                       GPU_SparseDnVecHandle:$dnX,
+                       GPU_SparseDnVecHandle:$dnY);
   let results = (outs Res<Index>:$bufferSz, Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
@@ -1857,10 +1857,10 @@
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
-                       GPU_SparseHandle:$env,
-                       GPU_SparseHandle:$spmatA,
-                       GPU_SparseHandle:$dnX,
-                       GPU_SparseHandle:$dnY,
+                       GPU_SparseEnvHandle:$env,
+                       GPU_SparseSpMatHandle:$spmatA,
+                       GPU_SparseDnVecHandle:$dnX,
+                       GPU_SparseDnVecHandle:$dnY,
                        AnyMemRef:$buffer);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
@@ -1890,10 +1890,10 @@
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
-                       GPU_SparseHandle:$env,
-                       GPU_SparseHandle:$spmatA,
-                       GPU_SparseHandle:$dnmatB,
-                       GPU_SparseHandle:$dnmatC);
+                       GPU_SparseEnvHandle:$env,
+                       GPU_SparseSpMatHandle:$spmatA,
+                       GPU_SparseDnMatHandle:$dnmatB,
+                       GPU_SparseDnMatHandle:$dnmatC);
   let results = (outs Res<Index>:$bufferSz, Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
@@ -1922,10 +1922,10 @@
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
-                       GPU_SparseHandle:$env,
-                       GPU_SparseHandle:$spmatA,
-                       GPU_SparseHandle:$dnmatB,
-                       GPU_SparseHandle:$dnmatC,
+                       GPU_SparseEnvHandle:$env,
+                       GPU_SparseSpMatHandle:$spmatA,
+                       GPU_SparseDnMatHandle:$dnmatB,
+                       GPU_SparseDnMatHandle:$dnmatC,
                        AnyMemRef:$buffer);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -1478,7 +1478,23 @@
     return converter.getPointerType(
         IntegerType::get(&converter.getContext(), 8));
   });
-  converter.addConversion([&converter](gpu::SparseHandleType type) -> Type {
+  // Every kind of sparse handle lowers to an opaque i8 pointer.
+  converter.addConversion(
+      [&converter](gpu::SparseDnVecHandleType type) -> Type {
+        return converter.getPointerType(
+            IntegerType::get(&converter.getContext(), 8));
+      });
+  converter.addConversion(
+      [&converter](gpu::SparseDnMatHandleType type) -> Type {
+        return converter.getPointerType(
+            IntegerType::get(&converter.getContext(), 8));
+      });
+  converter.addConversion(
+      [&converter](gpu::SparseSpMatHandleType type) -> Type {
+        return converter.getPointerType(
+            IntegerType::get(&converter.getContext(), 8));
+      });
+  converter.addConversion([&converter](gpu::SparseEnvHandleType type) -> Type {
     return converter.getPointerType(
         IntegerType::get(&converter.getContext(), 8));
   });
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -146,7 +146,10 @@
 void GPUDialect::initialize() {
   addTypes<AsyncTokenType>();
   addTypes<MMAMatrixType>();
-  addTypes<SparseHandleType>();
+  addTypes<SparseEnvHandleType>();
+  addTypes<SparseDnVecHandleType>();
+  addTypes<SparseDnMatHandleType>();
+  addTypes<SparseSpMatHandleType>();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
@@ -201,17 +204,28 @@
         shape, elementType, operand);
   }
 
-  if (keyword == "sparse.handle")
-    return SparseHandleType::get(context);
+  if (keyword == "sparse.env_handle")
+    return SparseEnvHandleType::get(context);
+  if (keyword == "sparse.spmat_handle")
+    return SparseSpMatHandleType::get(context);
+  if (keyword == "sparse.dnvec_handle")
+    return SparseDnVecHandleType::get(context);
+  if (keyword == "sparse.dnmat_handle")
+    return SparseDnMatHandleType::get(context);
 
   parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
   return Type();
 }
 
+// Prints the GPU dialect types. The keywords emitted here must stay in sync
+// with the ones accepted by GPUDialect::parseType.
 void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
   TypeSwitch<Type>(type)
       .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
-      .Case<SparseHandleType>([&](Type) { os << "sparse.handle"; })
+      .Case<SparseEnvHandleType>([&](Type) { os << "sparse.env_handle"; })
+      .Case<SparseSpMatHandleType>([&](Type) { os << "sparse.spmat_handle"; })
+      .Case<SparseDnMatHandleType>([&](Type) { os << "sparse.dnmat_handle"; })
+      .Case<SparseDnVecHandleType>([&](Type) { os << "sparse.dnvec_handle"; })
       .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
         os << "mma_matrix<";
         auto shape = fragTy.getShape();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -436,22 +436,25 @@
 
   // Create sparse environment and sparse matrix/dense vector handles.
   Type indexTp = rewriter.getIndexType();
-  Type handleTp = rewriter.getType<gpu::SparseHandleType>();
+  Type envHandleTp = rewriter.getType<gpu::SparseEnvHandleType>();
+  Type dnVecHandleTp = rewriter.getType<gpu::SparseDnVecHandleType>();
+  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
   Value token = genFirstWait(rewriter, loc);
   auto env =
-      rewriter.create<gpu::CreateSparseEnvOp>(loc, handleTp, tokenTp, token);
+      rewriter.create<gpu::CreateSparseEnvOp>(loc, envHandleTp, tokenTp, token);
   Value handle = env.getResult(0);
   token = env.getAsyncToken();
-  Operation *spGenA = genSpMat(rewriter, loc, handleTp, tokenTp, token, szY,
-                               szX, nseA, rowA, colA, valA, isCOO, enableRT);
+  Operation *spGenA =
+      genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szY, szX, nseA,
+               rowA, colA, valA, isCOO, enableRT);
   Value spMatA = spGenA->getResult(0);
   token = spGenA->getResult(1);
-  auto dvecX = rewriter.create<gpu::CreateDnVecOp>(loc, handleTp, tokenTp,
+  auto dvecX = rewriter.create<gpu::CreateDnVecOp>(loc, dnVecHandleTp, tokenTp,
                                                    token, vecX, szX);
   Value dnX = dvecX.getResult(0);
   token = dvecX.getAsyncToken();
-  auto dvecY = rewriter.create<gpu::CreateDnVecOp>(loc, handleTp, tokenTp,
+  auto dvecY = rewriter.create<gpu::CreateDnVecOp>(loc, dnVecHandleTp, tokenTp,
                                                    token, vecY, szY);
   Value dnY = dvecY.getResult(0);
   token = dvecY.getAsyncToken();
@@ -540,22 +543,25 @@
 
   // Create sparse environment and sparse matrix/dense matrix handles.
   Type indexTp = rewriter.getIndexType();
-  Type handleTp = rewriter.getType<gpu::SparseHandleType>();
+  Type envHandleTp = rewriter.getType<gpu::SparseEnvHandleType>();
+  Type dnMatHandleTp = rewriter.getType<gpu::SparseDnMatHandleType>();
+  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
   Value token = genFirstWait(rewriter, loc);
   auto env =
-      rewriter.create<gpu::CreateSparseEnvOp>(loc, handleTp, tokenTp, token);
+      rewriter.create<gpu::CreateSparseEnvOp>(loc, envHandleTp, tokenTp, token);
   Value handle = env.getResult(0);
   token = env.getAsyncToken();
-  Operation *spGenA = genSpMat(rewriter, loc, handleTp, tokenTp, token, szm,
-                               szk, nseA, rowA, colA, valA, isCOO, enableRT);
+  Operation *spGenA =
+      genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szk, nseA,
+               rowA, colA, valA, isCOO, enableRT);
   Value spMatA = spGenA->getResult(0);
   token = spGenA->getResult(1);
-  auto dmatB = rewriter.create<gpu::CreateDnMatOp>(loc, handleTp, tokenTp,
+  auto dmatB = rewriter.create<gpu::CreateDnMatOp>(loc, dnMatHandleTp, tokenTp,
                                                    token, szk, szn, matB);
   Value dnB = dmatB.getResult(0);
   token = dmatB.getAsyncToken();
-  auto dmatC = rewriter.create<gpu::CreateDnMatOp>(loc, handleTp, tokenTp,
+  auto dmatC = rewriter.create<gpu::CreateDnMatOp>(loc, dnMatHandleTp, tokenTp,
                                                    token, szm, szn, matC);
   Value dnC = dmatC.getResult(0);
   token = dmatC.getAsyncToken();