diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -52,6 +52,21 @@
          mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop,
          "any-storage-any-loop",
          "Enable sparse parallelization for any storage and loop."))};
+  PassOptions::Option<mlir::GPUDataTransferStrategy> gpuDataTransfer{
+      *this, "gpu-data-transfer-strategy",
+      ::llvm::cl::desc(
+          "Set the data transfer strategy between the host and the GPUs"),
+      ::llvm::cl::init(mlir::GPUDataTransferStrategy::kRegularDMA),
+      llvm::cl::values(
+          clEnumValN(mlir::GPUDataTransferStrategy::kRegularDMA, "regular-dma",
+                     "Default option: malloc on host without additional "
+                     "options or care and then use DMA to copy the data"),
+          clEnumValN(mlir::GPUDataTransferStrategy::kPinnedDMA, "pinned-dma",
+                     "Based on the default option, pin the host memory to "
+                     "accelerate the data transfer"),
+          clEnumValN(mlir::GPUDataTransferStrategy::kZeroCopy, "zero-copy",
+                     "Use zero-copy to perform the data transfer from the host "
+                     "to the GPU"))};

   PassOptions::Option<bool> enableIndexReduction{
       *this, "enable-index-reduction",
@@ -138,8 +153,9 @@
   /// Projects out the options for `createSparsificationPass`.
   SparsificationOptions sparsificationOptions() const {
-    return SparsificationOptions(parallelization, enableIndexReduction,
-                                 enableGPULibgen, enableRuntimeLibrary);
+    return SparsificationOptions(parallelization, gpuDataTransfer,
+                                 enableIndexReduction, enableGPULibgen,
+                                 enableRuntimeLibrary);
   }

   /// Projects out the options for `createSparseTensorConversionPass`.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -44,19 +44,26 @@
   // TODO: support reduction parallelization too?
 };

+// TODO: Zero copy is disabled due to correctness bugs. Tracker #64316
+enum class GPUDataTransferStrategy { kRegularDMA, kZeroCopy, kPinnedDMA };
+
 #define GEN_PASS_DECL
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"

 /// Options for the Sparsification pass.
 struct SparsificationOptions {
-  SparsificationOptions(SparseParallelizationStrategy p, bool idxReduc,
+  SparsificationOptions(SparseParallelizationStrategy p,
+                        GPUDataTransferStrategy t, bool idxReduc,
                         bool gpuLibgen, bool enableRT)
-      : parallelizationStrategy(p), enableIndexReduction(idxReduc),
-        enableGPULibgen(gpuLibgen), enableRuntimeLibrary(enableRT) {}
+      : parallelizationStrategy(p), gpuDataTransferStrategy(t),
+        enableIndexReduction(idxReduc), enableGPULibgen(gpuLibgen),
+        enableRuntimeLibrary(enableRT) {}
   SparsificationOptions()
-      : SparsificationOptions(SparseParallelizationStrategy::kNone, false,
+      : SparsificationOptions(SparseParallelizationStrategy::kNone,
+                              GPUDataTransferStrategy::kRegularDMA, false,
                               false, true) {}
   SparseParallelizationStrategy parallelizationStrategy;
+  GPUDataTransferStrategy gpuDataTransferStrategy;
   bool enableIndexReduction;
   bool enableGPULibgen;
   bool enableRuntimeLibrary;
@@ -211,8 +218,8 @@
 void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
                                       unsigned numThreads);

-void populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
-                                     bool enableRT);
+void populateSparseGPULibgenPatterns(RewritePatternSet &patterns, bool enableRT,
+                                     GPUDataTransferStrategy gpuDataTransfer);

 std::unique_ptr<Pass> createSparseGPUCodegenPass();
 std::unique_ptr<Pass> createSparseGPUCodegenPass(unsigned numThreads);
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -102,6 +102,19 @@
              clEnumValN(mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop,
                         "any-storage-any-loop",
                         "Enable sparse parallelization for any storage and loop."))}]>,
+    Option<"gpuDataTransfer", "gpu-data-transfer-strategy", "mlir::GPUDataTransferStrategy",
+           "mlir::GPUDataTransferStrategy::kRegularDMA",
+           "Set the data transfer strategy", [{llvm::cl::values(
+             clEnumValN(mlir::GPUDataTransferStrategy::kRegularDMA,
+                        "regular-dma",
+                        "Default option: malloc on host without additional "
+                        "options or care and then use DMA to copy the data"),
+             clEnumValN(mlir::GPUDataTransferStrategy::kPinnedDMA, "pinned-dma",
+                        "Based on the default option, pin the host memory to "
+                        "accelerate the data transfer"),
+             clEnumValN(mlir::GPUDataTransferStrategy::kZeroCopy, "zero-copy",
+                        "Use zero-copy to perform the data transfer from the host "
+                        "to the GPU"))}]>,
     Option<"enableGPULibgen", "enable-gpu-libgen", "bool", "false",
            "Enable GPU acceleration by means of direct library calls (like cuSPARSE)">,
@@ -110,6 +123,7 @@
   ];
 }
+
 def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp"> {
   let summary = "Applies sparse tensor rewriting rules after sparsification";
   let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -461,14 +461,18 @@
 }

 /// Match and rewrite SpMV kernel.
-static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
-                                 linalg::GenericOp op, bool enableRT) {
+static LogicalResult
+rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
+            GPUDataTransferStrategy gpuDataTransferStrategy) {
   Location loc = op.getLoc();
   Value a = op.getOperand(0);
   Value x = op.getOperand(1);
   Value y = op.getOperand(2); // we have y = Ax
   SmallVector<Value> tokens;

+  bool isZeroCopy =
+      gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
+
   // Only admissible sparse matrix format and dense vectors.
   bool isCOO = false;
   SparseTensorType aTp = getSparseTensorType(a);
@@ -487,12 +491,27 @@
   Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
   Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
   Value memV = genToValues(rewriter, loc, a);
+  Value memX, memY;
+  Value castR, castC, castV, castX, castY;
+  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
+    memX = genTensorToMemref(rewriter, loc, x);
+    memY = genTensorToMemref(rewriter, loc, y);
+    castR = genHostRegisterMemref(rewriter, loc, memR);
+    if (memC)
+      castC = genHostRegisterMemref(rewriter, loc, memC);
+    castV = genHostRegisterMemref(rewriter, loc, memV);
+    castX = genHostRegisterMemref(rewriter, loc, memX);
+    castY = genHostRegisterMemref(rewriter, loc, memY);
+  }
+
   Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
   Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
   Value valA = genAllocCopy(rewriter, loc, memV, tokens);
-  Value memX = genTensorToMemref(rewriter, loc, x);
-  Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
-  Value memY = genTensorToMemref(rewriter, loc, y);
+  if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
+    memX = genTensorToMemref(rewriter, loc, x);
+  Value vecX = isZeroCopy ? memX : genAllocCopy(rewriter, loc, memX, tokens);
+  if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
+    memY = genTensorToMemref(rewriter, loc, y);
   Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
   genBlockingWait(rewriter, loc, tokens);
   tokens.clear();
@@ -546,11 +565,20 @@
   token = genDeallocMemRef(rewriter, loc, colA, token);
   token = genDeallocMemRef(rewriter, loc, valA, token);
   token = genDeallocMemRef(rewriter, loc, buffer, token);
-  token = genDeallocMemRef(rewriter, loc, vecX, token);
+  if (!isZeroCopy)
+    token = genDeallocMemRef(rewriter, loc, vecX, token);
   token = genCopyMemRef(rewriter, loc, memY, vecY, token);
   token = genDeallocMemRef(rewriter, loc, vecY, token);
   tokens.push_back(token);
   genBlockingWait(rewriter, loc, tokens);
+  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
+    genHostUnregisterMemref(rewriter, loc, castR);
+    if (memC)
+      genHostUnregisterMemref(rewriter, loc, castC);
+    genHostUnregisterMemref(rewriter, loc, castV);
+    genHostUnregisterMemref(rewriter, loc, castX);
+    genHostUnregisterMemref(rewriter, loc, castY);
+  }
   tokens.clear();

   // Done.
@@ -559,14 +587,18 @@
 }

 /// Match and rewrite SpMM kernel.
-static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
-                                 linalg::GenericOp op, bool enableRT) {
+static LogicalResult
+rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
+            GPUDataTransferStrategy gpuDataTransferStrategy) {
   Location loc = op.getLoc();
   Value a = op.getOperand(0);
   Value b = op.getOperand(1);
   Value c = op.getOperand(2); // we have C = AB
   SmallVector<Value> tokens;

+  bool isZeroCopy =
+      gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
+
   // Only admissible sparse matrix format and dense matrices.
   bool isCOO = false;
   SparseTensorType aTp = getSparseTensorType(a);
@@ -586,12 +618,27 @@
   Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
   Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
   Value memV = genToValues(rewriter, loc, a);
+  Value bufB, bufC;
+  Value castR, castC, castV, castB, castBufC;
+  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
+    bufB = genTensorToMemref(rewriter, loc, b);
+    bufC = genTensorToMemref(rewriter, loc, c);
+    castR = genHostRegisterMemref(rewriter, loc, memR);
+    if (memC)
+      castC = genHostRegisterMemref(rewriter, loc, memC);
+    castV = genHostRegisterMemref(rewriter, loc, memV);
+    castB = genHostRegisterMemref(rewriter, loc, bufB);
+    castBufC = genHostRegisterMemref(rewriter, loc, bufC);
+  }
+
   Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
   Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
   Value valA = genAllocCopy(rewriter, loc, memV, tokens);
-  Value bufB = genTensorToMemref(rewriter, loc, b);
-  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
-  Value bufC = genTensorToMemref(rewriter, loc, c);
+  if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
+    bufB = genTensorToMemref(rewriter, loc, b);
+  Value matB = isZeroCopy ? bufB : genAllocCopy(rewriter, loc, bufB, tokens);
+  if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
+    bufC = genTensorToMemref(rewriter, loc, c);
   Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
   genBlockingWait(rewriter, loc, tokens);
   tokens.clear();
@@ -649,11 +696,20 @@
   token = genDeallocMemRef(rewriter, loc, colA, token);
   token = genDeallocMemRef(rewriter, loc, valA, token);
   token = genDeallocMemRef(rewriter, loc, buffer, token);
-  token = genDeallocMemRef(rewriter, loc, matB, token);
+  if (!isZeroCopy)
+    token = genDeallocMemRef(rewriter, loc, matB, token);
   token = genCopyMemRef(rewriter, loc, bufC, matC, token);
   token = genDeallocMemRef(rewriter, loc, matC, token);
   tokens.push_back(token);
   genBlockingWait(rewriter, loc, tokens);
+  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
+    genHostUnregisterMemref(rewriter, loc, castR);
+    if (memC)
+      genHostUnregisterMemref(rewriter, loc, castC);
+    genHostUnregisterMemref(rewriter, loc, castV);
+    genHostUnregisterMemref(rewriter, loc, castB);
+    genHostUnregisterMemref(rewriter, loc, castBufC);
+  }
   tokens.clear();

   // Done.
@@ -662,23 +718,41 @@
 }

 // Match and rewrite 2:4 SpMM kernels.
-static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
-                                     linalg::GenericOp op) {
+static LogicalResult
+rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
+                GPUDataTransferStrategy gpuDataTransferStrategy) {
   Location loc = op.getLoc();
   Value A = op.getOperand(0);
   Value B = op.getOperand(1);
   Value C = op.getOperand(2); // we have C = AB
   SmallVector<Value> tokens;

+  bool isZeroCopy =
+      gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
+
   // All input should be dense tensors.
   if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
     return failure();

+  Value matA, matB;
   Value bufA = genTensorToMemref(rewriter, loc, A);
-  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
+  if (!isZeroCopy)
+    matA = genAllocCopy(rewriter, loc, bufA, tokens);
   Value bufB = genTensorToMemref(rewriter, loc, B);
-  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
+  if (!isZeroCopy)
+    matB = genAllocCopy(rewriter, loc, bufB, tokens);
   Value bufC = genTensorToMemref(rewriter, loc, C);
+  Value castA, castB, castC;
+  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
+    castA = genHostRegisterMemref(rewriter, loc, bufA);
+    castB = genHostRegisterMemref(rewriter, loc, bufB);
+    castC = genHostRegisterMemref(rewriter, loc, bufC);
+  }
+
+  if (isZeroCopy) {
+    matA = bufA;
+    matB = bufB;
+  }
   Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
   genBlockingWait(rewriter, loc, tokens);
   tokens.clear();
@@ -754,26 +828,38 @@
   token = genDeallocMemRef(rewriter, loc, buffer, token);
   token = genDeallocMemRef(rewriter, loc, buffer2, token);
   token = genDeallocMemRef(rewriter, loc, buffer3, token);
-  token = genDeallocMemRef(rewriter, loc, matA, token);
-  token = genDeallocMemRef(rewriter, loc, matB, token);
+
+  if (!isZeroCopy)
+    token = genDeallocMemRef(rewriter, loc, matA, token);
+  if (!isZeroCopy)
+    token = genDeallocMemRef(rewriter, loc, matB, token);
   token = genCopyMemRef(rewriter, loc, bufC, matC, token);
   token = genDeallocMemRef(rewriter, loc, matC, token);
   tokens.push_back(token);
   genBlockingWait(rewriter, loc, tokens);
+  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
+    genHostUnregisterMemref(rewriter, loc, castA);
+    genHostUnregisterMemref(rewriter, loc, castB);
+    genHostUnregisterMemref(rewriter, loc, castC);
+  }
   tokens.clear();
   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
   return success();
 }

 /// Match and rewrite SDDMM kernel.
-static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
-                                  linalg::GenericOp op, bool enableRT) {
+static LogicalResult
+rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
+             GPUDataTransferStrategy gpuDataTransferStrategy) {
   Location loc = op.getLoc();
   Value a = op.getOperand(0);
   Value b = op.getOperand(1);
   Value c = op.getOperand(2);
   SmallVector<Value> tokens;

+  bool isZeroCopy =
+      gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
+
   // Only admissible sparse matrix format and dense matrices, no COO.
   bool isCOO = false;
   SparseTensorType aTp = getSparseTensorType(a);
@@ -793,13 +879,31 @@
   Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
+  Value matA, matB;
   Value bufA = genTensorToMemref(rewriter, loc, a);
-  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
+  if (!isZeroCopy)
+    matA = genAllocCopy(rewriter, loc, bufA, tokens);
   Value bufB = genTensorToMemref(rewriter, loc, b);
-  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
+  if (!isZeroCopy)
+    matB = genAllocCopy(rewriter, loc, bufB, tokens);
   Value memR = genFirstPosOrCrds(rewriter, loc, c, isCOO, enableRT);
   Value memC = genSecondCrds(rewriter, loc, c, isCOO, enableRT);
   Value memV = genToValues(rewriter, loc, c);
+
+  Value castB, castA, castR, castC, castV;
+  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
+    castB = genHostRegisterMemref(rewriter, loc, bufB);
+    castA = genHostRegisterMemref(rewriter, loc, bufA);
+    castR = genHostRegisterMemref(rewriter, loc, memR);
+    if (memC)
+      castC = genHostRegisterMemref(rewriter, loc, memC);
+    castV = genHostRegisterMemref(rewriter, loc, memV);
+  }
+
+  if (isZeroCopy) {
+    matA = bufA;
+    matB = bufB;
+  }
   Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
   Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
   Value valC = genAllocCopy(rewriter, loc, memV, tokens);
@@ -850,8 +954,10 @@
   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
               .getAsyncToken();
   token = genDeallocMemRef(rewriter, loc, buffer, token);
-  token = genDeallocMemRef(rewriter, loc, matA, token);
-  token = genDeallocMemRef(rewriter, loc, matB, token);
+  if (!isZeroCopy) {
+    token = genDeallocMemRef(rewriter, loc, matA, token);
+    token = genDeallocMemRef(rewriter, loc, matB, token);
+  }
   token = genDeallocMemRef(rewriter, loc, rowC, token);
   if (colC)
     token = genDeallocMemRef(rewriter, loc, colC, token);
@@ -859,6 +965,14 @@
   token = genDeallocMemRef(rewriter, loc, valC, token);
   tokens.push_back(token);
   genBlockingWait(rewriter, loc, tokens);
+  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
+    genHostUnregisterMemref(rewriter, loc, castB);
+    genHostUnregisterMemref(rewriter, loc, castA);
+    genHostUnregisterMemref(rewriter, loc, castR);
+    if (memC)
+      genHostUnregisterMemref(rewriter, loc, castC);
+    genHostUnregisterMemref(rewriter, loc, castV);
+  }
   tokens.clear();

   // Done.
@@ -977,8 +1091,8 @@
 struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
   using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

-  LinalgOpRewriter(MLIRContext *context, bool rt)
-      : OpRewritePattern(context), enableRT(rt) {}
+  LinalgOpRewriter(MLIRContext *context, bool rt, GPUDataTransferStrategy t)
+      : OpRewritePattern(context), enableRT(rt), gpuDataTransferStrategy(t) {}

   LogicalResult matchAndRewrite(linalg::GenericOp op,
                                 PatternRewriter &rewriter) const override {
@@ -1004,7 +1118,7 @@
         linalg::isReductionIterator(iteratorTypes[1]) &&
         // TODO: add transposed {i, j}
         maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
-      return rewriteSpMV(rewriter, op, enableRT);
+      return rewriteSpMV(rewriter, op, enableRT, gpuDataTransferStrategy);
     }

     // Recognize a SpMM kernel.
@@ -1016,9 +1130,9 @@
         // TODO: maybe add transposed {i, j} in future
         maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
       if (op->getAttr("DENSE24"))
-        return rewrite2To4SpMM(rewriter, op);
+        return rewrite2To4SpMM(rewriter, op, gpuDataTransferStrategy);
-      return rewriteSpMM(rewriter, op, enableRT);
+      return rewriteSpMM(rewriter, op, enableRT, gpuDataTransferStrategy);
     }

     // Recognize a SDDMM kernel.
@@ -1030,7 +1144,7 @@
         // TODO: maybe add transposed {i, j} in future
         maps == infer({{i, k}, {k, j}, {i, j}}) &&
         matchSumReductionOfMulUnary(op)) {
-      return rewriteSDDMM(rewriter, op, enableRT);
+      return rewriteSDDMM(rewriter, op, enableRT, gpuDataTransferStrategy);
     }

     return failure();
@@ -1038,6 +1152,7 @@

 private:
   bool enableRT;
+  GPUDataTransferStrategy gpuDataTransferStrategy;
 };

 } // namespace
@@ -1057,7 +1172,9 @@
   patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
 }

-void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
-                                           bool enableRT) {
-  patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
+void mlir::populateSparseGPULibgenPatterns(
+    RewritePatternSet &patterns, bool enableRT,
+    GPUDataTransferStrategy gpuDataTransfer) {
+  patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT,
+                                 gpuDataTransfer);
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -65,6 +65,7 @@
   SparsificationPass(const SparsificationPass &pass) = default;
   SparsificationPass(const SparsificationOptions &options) {
     parallelization = options.parallelizationStrategy;
+    gpuDataTransfer = options.gpuDataTransferStrategy;
     enableIndexReduction = options.enableIndexReduction;
     enableGPULibgen = options.enableGPULibgen;
     enableRuntimeLibrary = options.enableRuntimeLibrary;
@@ -73,12 +74,17 @@
   void runOnOperation() override {
     auto *ctx = &getContext();
     // Translate strategy flags to strategy options.
-    SparsificationOptions options(parallelization, enableIndexReduction,
-                                  enableGPULibgen, enableRuntimeLibrary);
+    SparsificationOptions options(parallelization, gpuDataTransfer,
+                                  enableIndexReduction, enableGPULibgen,
+                                  enableRuntimeLibrary);
     // Apply GPU libgen (if requested), sparsification, and cleanup rewriting.
     RewritePatternSet patterns(ctx);
     if (enableGPULibgen) {
-      populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary);
+      // TODO: Zero copy is disabled due to correctness bugs. Tracker #64316
+      assert(gpuDataTransfer != GPUDataTransferStrategy::kZeroCopy &&
+             "zero-copy transfer not supported with GPU libgen");
+      populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary,
+                                      gpuDataTransfer);
     }
     populateSparsificationPatterns(patterns, options);
     scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib-from-linalg.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib-from-linalg.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib-from-linalg.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib-from-linalg.mlir
@@ -1,13 +1,18 @@
 //
 // NOTE: this test requires gpu-sm80 and cusparselt
 //
-// RUN: mlir-opt %s \
-// RUN:   --sparse-compiler="enable-runtime-library=true enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71" \
-// RUN: | mlir-cpu-runner \
-// RUN:   --shared-libs=%mlir_cuda_runtime \
-// RUN:   --shared-libs=%mlir_c_runner_utils \
-// RUN:   --e main --entry-point-result=void \
-// RUN: | FileCheck %s
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --sparse-compiler="enable-runtime-library=true enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71
+// DEFINE: %{run} = mlir-cpu-runner \
+// DEFINE:   --shared-libs=%mlir_cuda_runtime \
+// DEFINE:   --shared-libs=%mlir_c_runner_utils \
+// DEFINE:   --e main --entry-point-result=void \
+// DEFINE: | FileCheck %s
+
+// RUN: %{compile}" | %{run}
+// RUN: %{compile} gpu-data-transfer-strategy=pinned-dma" | %{run}
+// Tracker #64316
+// RUNNOT: %{compile} gpu-data-transfer-strategy=zero-copy" | %{run}

 #map = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
@@ -1,14 +1,19 @@
 //
 // NOTE: this test requires gpu-sm80 and cusparselt
 //
-// RUN: mlir-opt --convert-scf-to-cf -convert-cf-to-llvm --convert-vector-to-llvm \
-// RUN:   --convert-arith-to-llvm --gpu-to-llvm --reconcile-unrealized-casts \
-// RUN:   %s \
-// RUN: | mlir-cpu-runner \
-// RUN:   --shared-libs=%mlir_cuda_runtime \
-// RUN:   --shared-libs=%mlir_c_runner_utils \
-// RUN:   --e main --entry-point-result=void \
-// RUN: | FileCheck %s
+// DEFINE: %{compile} = mlir-opt --convert-scf-to-cf -convert-cf-to-llvm --convert-vector-to-llvm \
+// DEFINE:   --convert-arith-to-llvm --gpu-to-llvm --reconcile-unrealized-casts \
+// DEFINE:   %s
+// DEFINE: %{run} = mlir-cpu-runner \
+// DEFINE:   --shared-libs=%mlir_cuda_runtime \
+// DEFINE:   --shared-libs=%mlir_c_runner_utils \
+// DEFINE:   --e main --entry-point-result=void \
+// DEFINE: | FileCheck %s
+
+// RUN: %{compile} | %{run}
+// RUN: %{compile} --sparse-compiler="gpu-data-transfer-strategy=pinned-dma" | %{run}
+// RUNNOT: %{compile} --sparse-compiler="gpu-data-transfer-strategy=zero-copy" | %{run}
+
 module {
   llvm.func @mgpuCreateSparseLtEnv()
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir
@@ -1,25 +1,28 @@
 //
 // NOTE: this test requires gpu-sm80
 //
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --sparse-compiler="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71
+// DEFINE: %{run} = mlir-cpu-runner \
+// DEFINE:   --shared-libs=%mlir_cuda_runtime \
+// DEFINE:   --shared-libs=%mlir_c_runner_utils \
+// DEFINE:   --e main --entry-point-result=void \
+// DEFINE: | FileCheck %s
+//
+//
 // with RT lib (SoA COO):
//
-// RUN: mlir-opt %s \
-// RUN:   --sparse-compiler="enable-runtime-library=true enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71" \
-// RUN: | mlir-cpu-runner \
-// RUN:   --shared-libs=%mlir_cuda_runtime \
-// RUN:   --shared-libs=%mlir_c_runner_utils \
-// RUN:   --e main --entry-point-result=void \
-// RUN: | FileCheck %s
+// RUN: %{compile} enable-runtime-library=true" | %{run}
+// RUN: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=pinned-dma" | %{run}
+// Tracker #64316
+// RUNNOT: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=zero-copy" | %{run}
 //
 // without RT lib (AoS COO): note, may fall back to CPU
 //
-// RUN: mlir-opt %s \
-// RUN:   --sparse-compiler="enable-runtime-library=false enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71" \
-// RUN: | mlir-cpu-runner \
-// RUN:   --shared-libs=%mlir_cuda_runtime \
-// RUN:   --shared-libs=%mlir_c_runner_utils \
-// RUN:   --e main --entry-point-result=void \
-// RUN: | FileCheck %s
+// RUN: %{compile} enable-runtime-library=false" | %{run}
+// RUN: %{compile} enable-runtime-library=false gpu-data-transfer-strategy=pinned-dma" | %{run}
+// Tracker #64316
+// RUNNOT: %{compile} enable-runtime-library=false gpu-data-transfer-strategy=zero-copy" | %{run}

 #SortedCOO = #sparse_tensor.encoding<{
   lvlTypes = [ "compressed-nu", "singleton" ]
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-lib.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-lib.mlir
@@ -1,25 +1,28 @@
 //
 // NOTE: this test requires gpu-sm80
 //
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --sparse-compiler="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71
+// DEFINE: %{run} = mlir-cpu-runner \
+// DEFINE:   --shared-libs=%mlir_cuda_runtime \
+// DEFINE:   --shared-libs=%mlir_c_runner_utils \
+// DEFINE:   --e main --entry-point-result=void \
+// DEFINE: | FileCheck %s
+//
 // with RT lib (SoA COO):
 //
-// RUN: mlir-opt %s \
-// RUN:   --sparse-compiler="enable-runtime-library=true enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71" \
-// RUN: | mlir-cpu-runner \
-// RUN:   --shared-libs=%mlir_cuda_runtime \
-// RUN:   --shared-libs=%mlir_c_runner_utils \
-// RUN:   --e main --entry-point-result=void \
-// RUN: | FileCheck %s
+// RUN: %{compile} enable-runtime-library=true" | %{run}
+// RUN: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=pinned-dma" | %{run}
+// Tracker #64316
+// RUNNOT: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=zero-copy" | %{run}
 //
 // without RT lib (AoS COO): note, may fall back to CPU
 //
-// RUN: mlir-opt %s \
-// RUN:   --sparse-compiler="enable-runtime-library=false enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71" \
-// RUN: | mlir-cpu-runner \
-// RUN:   --shared-libs=%mlir_cuda_runtime \
-// RUN:   --shared-libs=%mlir_c_runner_utils \
-// RUN:   --e main --entry-point-result=void \
-// RUN: | FileCheck %s
+// RUN: %{compile} enable-runtime-library=false" | %{run}
+// RUN: %{compile} enable-runtime-library=false gpu-data-transfer-strategy=pinned-dma" | %{run}
+// Tracker #64316
+// RUNNOT: %{compile} enable-runtime-library=false gpu-data-transfer-strategy=zero-copy" | %{run}
+//

 #SortedCOO = #sparse_tensor.encoding<{
   lvlTypes = [ "compressed-nu", "singleton" ]
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir
@@ -1,28 +1,29 @@
 //
 // NOTE: this test requires gpu-sm80
 //
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --sparse-compiler="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71
+// DEFINE: %{run} = TENSOR0="%mlir_src_dir/test/Integration/data/test.mtx" \
+// DEFINE:   mlir-cpu-runner \
+// DEFINE:   --shared-libs=%mlir_cuda_runtime \
+// DEFINE:   --shared-libs=%mlir_c_runner_utils \
+// DEFINE:   --e entry --entry-point-result=void \
+// DEFINE: | FileCheck %s
+//
 // with RT lib:
 //
-// RUN: mlir-opt %s \
-// RUN:   --sparse-compiler="enable-runtime-library=true enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71" \
-// RUN: | TENSOR0="%mlir_src_dir/test/Integration/data/test.mtx" \
-// RUN:   mlir-cpu-runner \
-// RUN:   --shared-libs=%mlir_cuda_runtime \
-// RUN:   --shared-libs=%mlir_c_runner_utils \
-// RUN:   --e entry --entry-point-result=void \
-// RUN: | FileCheck %s
+// RUN: %{compile} enable-runtime-library=true" | %{run}
+// RUN: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=pinned-dma" | %{run}
+// Tracker #64316
+// RUNNOT: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=zero-copy" | %{run}
 //
 // without RT lib:
 //
-// RUN: mlir-opt %s \
-// RUN:   --sparse-compiler="enable-runtime-library=false enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71" \
-// RUN: | TENSOR0="%mlir_src_dir/test/Integration/data/test.mtx" \
-// RUN:   mlir-cpu-runner \
-// RUN:   --shared-libs=%mlir_cuda_runtime \
-// RUN:   --shared-libs=%mlir_c_runner_utils \
-// RUN:   --e entry --entry-point-result=void \
-// RUN: | FileCheck %s
-//
+// RUN: %{compile} enable-runtime-library=false" | %{run}
+// RUN: %{compile} enable-runtime-library=false gpu-data-transfer-strategy=pinned-dma" | %{run}
+// Tracker #64316
+// RUNNOT: %{compile} enable-runtime-library=false gpu-data-transfer-strategy=zero-copy" | %{run}
+//

 !Filename = !llvm.ptr
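
For reference, a minimal sketch of how the extended SparsificationOptions constructor from this patch could be driven from a downstream C++ pipeline to select the pinned-DMA strategy. The helper name addSparsifierWithPinnedDMA and the specific flag values are illustrative assumptions, not part of the patch; the argument order mirrors the new constructor (parallelization, transfer strategy, idxReduc, gpuLibgen, enableRT).

  // Sketch only: assumes the SparseTensor pass headers and a module-level
  // PassManager; createSparsificationPass(options) is the existing factory.
  #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
  #include "mlir/Pass/PassManager.h"

  static void addSparsifierWithPinnedDMA(mlir::PassManager &pm) {
    // Pin host buffers before DMA transfers; zero-copy stays disabled
    // per the tracker note above.
    mlir::SparsificationOptions options(
        mlir::SparseParallelizationStrategy::kNone,
        mlir::GPUDataTransferStrategy::kPinnedDMA,
        /*idxReduc=*/false, /*gpuLibgen=*/true, /*enableRT=*/true);
    pm.addPass(mlir::createSparsificationPass(options));
  }

From the command line the same choice is spelled gpu-data-transfer-strategy=pinned-dma inside the sparse-compiler or sparsification option string, as the updated RUN lines in the tests above show.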