diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h --- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h @@ -52,6 +52,22 @@ mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop, "any-storage-any-loop", "Enable sparse parallelization for any storage and loop."))}; + PassOptions::Option gpuDataTransfer{ + *this, "gpu-data-transfer-strategy", + ::llvm::cl::desc( + "Set the data transfer strategy between the host and the GPUs"), + ::llvm::cl::init(mlir::SparseDataTransferStrategy::kRegularDMA), + llvm::cl::values( + clEnumValN(mlir::SparseDataTransferStrategy::kRegularDMA, + "regular-dma", + "Default option: malloc on host without additional " + "options or care and then use DMA to copy the data"), + clEnumValN(mlir::SparseDataTransferStrategy::kPinnedDMA, "pinned-dma", + "Based on the default option, pin the host memory to " + "accelerate the data transfer"), + clEnumValN(mlir::SparseDataTransferStrategy::kZeroCopy, "zero-copy", + "Use zero-copy to perform the data transfer from the host " + "to the GPU"))}; PassOptions::Option enableIndexReduction{ *this, "enable-index-reduction", @@ -138,8 +154,9 @@ /// Projects out the options for `createSparsificationPass`. SparsificationOptions sparsificationOptions() const { - return SparsificationOptions(parallelization, enableIndexReduction, - enableGPULibgen, enableRuntimeLibrary); + return SparsificationOptions(parallelization, gpuDataTransfer, + enableIndexReduction, enableGPULibgen, + enableRuntimeLibrary); } /// Projects out the options for `createSparseTensorConversionPass`. diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -44,19 +44,25 @@ // TODO: support reduction parallelization too? }; +enum class SparseDataTransferStrategy { kRegularDMA, kZeroCopy, kPinnedDMA }; + #define GEN_PASS_DECL #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" /// Options for the Sparsification pass. 
struct SparsificationOptions { - SparsificationOptions(SparseParallelizationStrategy p, bool idxReduc, + SparsificationOptions(SparseParallelizationStrategy p, + SparseDataTransferStrategy t, bool idxReduc, bool gpuLibgen, bool enableRT) - : parallelizationStrategy(p), enableIndexReduction(idxReduc), - enableGPULibgen(gpuLibgen), enableRuntimeLibrary(enableRT) {} + : parallelizationStrategy(p), dataTransferStrategy(t), + enableIndexReduction(idxReduc), enableGPULibgen(gpuLibgen), + enableRuntimeLibrary(enableRT) {} SparsificationOptions() - : SparsificationOptions(SparseParallelizationStrategy::kNone, false, + : SparsificationOptions(SparseParallelizationStrategy::kNone, + SparseDataTransferStrategy::kRegularDMA, false, false, true) {} SparseParallelizationStrategy parallelizationStrategy; + SparseDataTransferStrategy dataTransferStrategy; bool enableIndexReduction; bool enableGPULibgen; bool enableRuntimeLibrary; @@ -211,8 +217,9 @@ void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns, unsigned numThreads); -void populateSparseGPULibgenPatterns(RewritePatternSet &patterns, - bool enableRT); +void populateSparseGPULibgenPatterns( + RewritePatternSet &patterns, bool enableRT, + SparseDataTransferStrategy gpuDataTransfer); std::unique_ptr createSparseGPUCodegenPass(); std::unique_ptr createSparseGPUCodegenPass(unsigned numThreads); diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -102,6 +102,19 @@ clEnumValN(mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop, "any-storage-any-loop", "Enable sparse parallelization for any storage and loop."))}]>, + Option<"gpuDataTransfer", "gpu-data-transfer-strategy", "mlir::SparseDataTransferStrategy", + "mlir::SparseDataTransferStrategy::kRegularDMA", + "Set the data transfer strategy", [{llvm::cl::values( + clEnumValN(mlir::SparseDataTransferStrategy::kRegularDMA, + "regular-dma", + "Default option: malloc on host without additional " + "options or care and then use DMA to copy the data"), + clEnumValN(mlir::SparseDataTransferStrategy::kPinnedDMA, "pinned-dma", + "Based on the default option, pin the host memory to " + "accelerate the data transfer"), + clEnumValN(mlir::SparseDataTransferStrategy::kZeroCopy, "zero-copy", + "Use zero-copy to perform the data transfer from the host " + "to the GPU"))}]>, Option<"enableGPULibgen", "enable-gpu-libgen", "bool", "false", "Enable GPU acceleration by means of direct library calls (like cuSPARSE)">, @@ -110,6 +123,7 @@ ]; } + def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp"> { let summary = "Applies sparse tensor rewriting rules after sparsification"; let description = [{ diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp @@ -461,14 +461,18 @@ } /// Match and rewrite SpMV kernel. 
-static LogicalResult rewriteSpMV(PatternRewriter &rewriter, - linalg::GenericOp op, bool enableRT) { +static LogicalResult +rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT, + SparseDataTransferStrategy dataTransferStrategy) { Location loc = op.getLoc(); Value a = op.getOperand(0); Value x = op.getOperand(1); Value y = op.getOperand(2); // we have y = Ax SmallVector tokens; + bool isZeroCopy = + dataTransferStrategy == SparseDataTransferStrategy::kZeroCopy; + // Only admissible sparse matrix format and dense vectors. bool isCOO = false; SparseTensorType aTp = getSparseTensorType(a); @@ -487,12 +491,22 @@ Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT); Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT); Value memV = genToValues(rewriter, loc, a); + Value memX = genTensorToMemref(rewriter, loc, x); + Value memY = genTensorToMemref(rewriter, loc, y); + Value memR_cast, memC_cast, memV_cast, memX_cast, memY_cast; + if (dataTransferStrategy != SparseDataTransferStrategy::kRegularDMA) { + memR_cast = genHostRegisterMemref(rewriter, loc, memR); + if (memC) + memC_cast = genHostRegisterMemref(rewriter, loc, memC); + memV_cast = genHostRegisterMemref(rewriter, loc, memV); + memX_cast = genHostRegisterMemref(rewriter, loc, memX); + memY_cast = genHostRegisterMemref(rewriter, loc, memY); + } + Value rowA = genAllocCopy(rewriter, loc, memR, tokens); Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value(); Value valA = genAllocCopy(rewriter, loc, memV, tokens); - Value memX = genTensorToMemref(rewriter, loc, x); - Value vecX = genAllocCopy(rewriter, loc, memX, tokens); - Value memY = genTensorToMemref(rewriter, loc, y); + Value vecX = isZeroCopy ? memX : genAllocCopy(rewriter, loc, memX, tokens); Value vecY = genAllocCopy(rewriter, loc, memY, tokens); genBlockingWait(rewriter, loc, tokens); tokens.clear(); @@ -546,11 +560,20 @@ token = genDeallocMemRef(rewriter, loc, colA, token); token = genDeallocMemRef(rewriter, loc, valA, token); token = genDeallocMemRef(rewriter, loc, buffer, token); - token = genDeallocMemRef(rewriter, loc, vecX, token); + if (!isZeroCopy) + token = genDeallocMemRef(rewriter, loc, vecX, token); token = genCopyMemRef(rewriter, loc, memY, vecY, token); token = genDeallocMemRef(rewriter, loc, vecY, token); tokens.push_back(token); genBlockingWait(rewriter, loc, tokens); + if (dataTransferStrategy != SparseDataTransferStrategy::kRegularDMA) { + genHostUnregisterMemref(rewriter, loc, memR_cast); + if (memC) + genHostUnregisterMemref(rewriter, loc, memC_cast); + genHostUnregisterMemref(rewriter, loc, memV_cast); + genHostUnregisterMemref(rewriter, loc, memX_cast); + genHostUnregisterMemref(rewriter, loc, memY_cast); + } tokens.clear(); // Done. @@ -559,14 +582,18 @@ } /// Match and rewrite SpMM kernel. -static LogicalResult rewriteSpMM(PatternRewriter &rewriter, - linalg::GenericOp op, bool enableRT) { +static LogicalResult +rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT, + SparseDataTransferStrategy dataTransferStrategy) { Location loc = op.getLoc(); Value a = op.getOperand(0); Value b = op.getOperand(1); Value c = op.getOperand(2); // we have C = AB SmallVector tokens; + bool isZeroCopy = + dataTransferStrategy == SparseDataTransferStrategy::kZeroCopy; + // Only admissible sparse matrix format and dense matrices. 
bool isCOO = false; SparseTensorType aTp = getSparseTensorType(a); @@ -586,12 +613,22 @@ Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT); Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT); Value memV = genToValues(rewriter, loc, a); + Value bufB = genTensorToMemref(rewriter, loc, b); + Value bufC = genTensorToMemref(rewriter, loc, c); + Value memR_cast, memC_cast, memV_cast, bufB_cast, bufC_cast; + if (dataTransferStrategy != SparseDataTransferStrategy::kRegularDMA) { + memR_cast = genHostRegisterMemref(rewriter, loc, memR); + if (memC) + memC_cast = genHostRegisterMemref(rewriter, loc, memC); + memV_cast = genHostRegisterMemref(rewriter, loc, memV); + bufB_cast = genHostRegisterMemref(rewriter, loc, bufB); + bufC_cast = genHostRegisterMemref(rewriter, loc, bufC); + } + Value rowA = genAllocCopy(rewriter, loc, memR, tokens); Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value(); Value valA = genAllocCopy(rewriter, loc, memV, tokens); - Value bufB = genTensorToMemref(rewriter, loc, b); - Value matB = genAllocCopy(rewriter, loc, bufB, tokens); - Value bufC = genTensorToMemref(rewriter, loc, c); + Value matB = isZeroCopy ? bufB : genAllocCopy(rewriter, loc, bufB, tokens); Value matC = genAllocCopy(rewriter, loc, bufC, tokens); genBlockingWait(rewriter, loc, tokens); tokens.clear(); @@ -649,11 +686,20 @@ token = genDeallocMemRef(rewriter, loc, colA, token); token = genDeallocMemRef(rewriter, loc, valA, token); token = genDeallocMemRef(rewriter, loc, buffer, token); - token = genDeallocMemRef(rewriter, loc, matB, token); + if (!isZeroCopy) + token = genDeallocMemRef(rewriter, loc, matB, token); token = genCopyMemRef(rewriter, loc, bufC, matC, token); token = genDeallocMemRef(rewriter, loc, matC, token); tokens.push_back(token); genBlockingWait(rewriter, loc, tokens); + if (dataTransferStrategy != SparseDataTransferStrategy::kRegularDMA) { + genHostUnregisterMemref(rewriter, loc, memR_cast); + if (memC) + genHostUnregisterMemref(rewriter, loc, memC_cast); + genHostUnregisterMemref(rewriter, loc, memV_cast); + genHostUnregisterMemref(rewriter, loc, bufB_cast); + genHostUnregisterMemref(rewriter, loc, bufC_cast); + } tokens.clear(); // Done. @@ -662,23 +708,34 @@ } // Match and rewrite 2:4 SpMM kernels. -static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter, - linalg::GenericOp op) { +static LogicalResult +rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op, + SparseDataTransferStrategy dataTransferStrategy) { Location loc = op.getLoc(); Value A = op.getOperand(0); Value B = op.getOperand(1); Value C = op.getOperand(2); // we have C = AB SmallVector tokens; + bool isZeroCopy = + dataTransferStrategy == SparseDataTransferStrategy::kZeroCopy; + // All input should be dense tensors. if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C)) return failure(); Value bufA = genTensorToMemref(rewriter, loc, A); - Value matA = genAllocCopy(rewriter, loc, bufA, tokens); Value bufB = genTensorToMemref(rewriter, loc, B); - Value matB = genAllocCopy(rewriter, loc, bufB, tokens); Value bufC = genTensorToMemref(rewriter, loc, C); + Value bufA_cast, bufB_cast, bufC_cast; + if (dataTransferStrategy != SparseDataTransferStrategy::kRegularDMA) { + bufA_cast = genHostRegisterMemref(rewriter, loc, bufA); + bufB_cast = genHostRegisterMemref(rewriter, loc, bufB); + bufC_cast = genHostRegisterMemref(rewriter, loc, bufC); + } + + Value matA = isZeroCopy ? bufA : genAllocCopy(rewriter, loc, bufA, tokens); + Value matB = isZeroCopy ? 
bufB : genAllocCopy(rewriter, loc, bufB, tokens); Value matC = genAllocCopy(rewriter, loc, bufC, tokens); genBlockingWait(rewriter, loc, tokens); tokens.clear(); @@ -753,26 +810,38 @@ token = genDeallocMemRef(rewriter, loc, buffer, token); token = genDeallocMemRef(rewriter, loc, buffer2, token); token = genDeallocMemRef(rewriter, loc, buffer3, token); - token = genDeallocMemRef(rewriter, loc, matA, token); - token = genDeallocMemRef(rewriter, loc, matB, token); + + if (!isZeroCopy) + token = genDeallocMemRef(rewriter, loc, matA, token); + if (!isZeroCopy) + token = genDeallocMemRef(rewriter, loc, matB, token); token = genCopyMemRef(rewriter, loc, bufC, matC, token); token = genDeallocMemRef(rewriter, loc, matC, token); tokens.push_back(token); genBlockingWait(rewriter, loc, tokens); + if (dataTransferStrategy != SparseDataTransferStrategy::kRegularDMA) { + genHostUnregisterMemref(rewriter, loc, bufA_cast); + genHostUnregisterMemref(rewriter, loc, bufB_cast); + genHostUnregisterMemref(rewriter, loc, bufC_cast); + } tokens.clear(); rewriter.replaceOpWithNewOp(op, bufC); return success(); } /// Match and rewrite SDDMM kernel. -static LogicalResult rewriteSDDMM(PatternRewriter &rewriter, - linalg::GenericOp op, bool enableRT) { +static LogicalResult +rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT, + SparseDataTransferStrategy dataTransferStrategy) { Location loc = op.getLoc(); Value a = op.getOperand(0); Value b = op.getOperand(1); Value c = op.getOperand(2); SmallVector tokens; + bool isZeroCopy = + dataTransferStrategy == SparseDataTransferStrategy::kZeroCopy; + // Only admissible sparse matrix format and dense matrices, no COO. bool isCOO = false; SparseTensorType aTp = getSparseTensorType(a); @@ -793,12 +862,23 @@ Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1); Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1); Value bufA = genTensorToMemref(rewriter, loc, a); - Value matA = genAllocCopy(rewriter, loc, bufA, tokens); Value bufB = genTensorToMemref(rewriter, loc, b); - Value matB = genAllocCopy(rewriter, loc, bufB, tokens); Value memR = genFirstPosOrCrds(rewriter, loc, c, isCOO, enableRT); Value memC = genSecondCrds(rewriter, loc, c, isCOO, enableRT); Value memV = genToValues(rewriter, loc, c); + + Value bufB_cast, bufA_cast, memR_cast, memC_cast, memV_cast; + if (dataTransferStrategy != SparseDataTransferStrategy::kRegularDMA) { + bufB_cast = genHostRegisterMemref(rewriter, loc, bufB); + bufA_cast = genHostRegisterMemref(rewriter, loc, bufA); + memR_cast = genHostRegisterMemref(rewriter, loc, memR); + if (memC) + memC_cast = genHostRegisterMemref(rewriter, loc, memC); + memV_cast = genHostRegisterMemref(rewriter, loc, memV); + } + + Value matA = isZeroCopy ? bufA : genAllocCopy(rewriter, loc, bufA, tokens); + Value matB = isZeroCopy ? bufB : genAllocCopy(rewriter, loc, bufB, tokens); Value rowC = genAllocCopy(rewriter, loc, memR, tokens); Value colC = memC ? 
genAllocCopy(rewriter, loc, memC, tokens) : Value(); Value valC = genAllocCopy(rewriter, loc, memV, tokens); @@ -849,8 +929,10 @@ token = rewriter.create(loc, tokenTp, token, spMatC) .getAsyncToken(); token = genDeallocMemRef(rewriter, loc, buffer, token); - token = genDeallocMemRef(rewriter, loc, matA, token); - token = genDeallocMemRef(rewriter, loc, matB, token); + if (!isZeroCopy) { + token = genDeallocMemRef(rewriter, loc, matA, token); + token = genDeallocMemRef(rewriter, loc, matB, token); + } token = genDeallocMemRef(rewriter, loc, rowC, token); if (colC) token = genDeallocMemRef(rewriter, loc, colC, token); @@ -858,6 +940,14 @@ token = genDeallocMemRef(rewriter, loc, valC, token); tokens.push_back(token); genBlockingWait(rewriter, loc, tokens); + if (dataTransferStrategy != SparseDataTransferStrategy::kRegularDMA) { + genHostUnregisterMemref(rewriter, loc, bufB_cast); + genHostUnregisterMemref(rewriter, loc, bufA_cast); + genHostUnregisterMemref(rewriter, loc, memR_cast); + if (memC) + genHostUnregisterMemref(rewriter, loc, memC_cast); + genHostUnregisterMemref(rewriter, loc, memV_cast); + } tokens.clear(); // Done. @@ -976,8 +1066,8 @@ struct LinalgOpRewriter : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LinalgOpRewriter(MLIRContext *context, bool rt) - : OpRewritePattern(context), enableRT(rt) {} + LinalgOpRewriter(MLIRContext *context, bool rt, SparseDataTransferStrategy t) + : OpRewritePattern(context), enableRT(rt), dataTransferStrategy(t) {} LogicalResult matchAndRewrite(linalg::GenericOp op, PatternRewriter &rewriter) const override { @@ -1003,7 +1093,7 @@ linalg::isReductionIterator(iteratorTypes[1]) && // TODO: add transposed {i, j} maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) { - return rewriteSpMV(rewriter, op, enableRT); + return rewriteSpMV(rewriter, op, enableRT, dataTransferStrategy); } // Recognize a SpMM kernel. @@ -1015,9 +1105,9 @@ // TODO: maybe add transposed {i, j} in future maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) { if (op->getAttr("DENSE24")) - return rewrite2To4SpMM(rewriter, op); + return rewrite2To4SpMM(rewriter, op, dataTransferStrategy); - return rewriteSpMM(rewriter, op, enableRT); + return rewriteSpMM(rewriter, op, enableRT, dataTransferStrategy); } // Recognize a SDDMM kernel. 
@@ -1029,7 +1119,7 @@ // TODO: maybe add transposed {i, j} in future maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumReductionOfMulUnary(op)) { - return rewriteSDDMM(rewriter, op, enableRT); + return rewriteSDDMM(rewriter, op, enableRT, dataTransferStrategy); } return failure(); @@ -1037,6 +1127,7 @@ private: bool enableRT; + SparseDataTransferStrategy dataTransferStrategy; }; } // namespace @@ -1056,7 +1147,9 @@ patterns.add(patterns.getContext(), numThreads); } -void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns, - bool enableRT) { - patterns.add(patterns.getContext(), enableRT); +void mlir::populateSparseGPULibgenPatterns( + RewritePatternSet &patterns, bool enableRT, + SparseDataTransferStrategy gpuDataTransfer) { + patterns.add(patterns.getContext(), enableRT, + gpuDataTransfer); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -65,6 +65,7 @@ SparsificationPass(const SparsificationPass &pass) = default; SparsificationPass(const SparsificationOptions &options) { parallelization = options.parallelizationStrategy; + gpuDataTransfer = options.dataTransferStrategy; enableIndexReduction = options.enableIndexReduction; enableGPULibgen = options.enableGPULibgen; enableRuntimeLibrary = options.enableRuntimeLibrary; @@ -73,12 +74,17 @@ void runOnOperation() override { auto *ctx = &getContext(); // Translate strategy flags to strategy options. - SparsificationOptions options(parallelization, enableIndexReduction, - enableGPULibgen, enableRuntimeLibrary); + SparsificationOptions options(parallelization, gpuDataTransfer, + enableIndexReduction, enableGPULibgen, + enableRuntimeLibrary); // Apply GPU libgen (if requested), sparsification, and cleanup rewriting. 
RewritePatternSet patterns(ctx); if (enableGPULibgen) { - populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary); + // TODO: Zero-copy is disabled due to correctness bugs. Tracker #64316. + assert(gpuDataTransfer != SparseDataTransferStrategy::kZeroCopy && + "zero-copy transfer not supported with GPU libgen"); + populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary, + gpuDataTransfer); } populateSparsificationPatterns(patterns, options); scf::ForOp::getCanonicalizationPatterns(patterns, ctx); diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib.mlir --- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib.mlir +++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib.mlir @@ -7,63 +7,63 @@ // Compute matrix matrix C = AB // // CHECK-LABEL: func.func @matmul( -// CHECK-SAME: %[[VAL_0:.*]]: tensor>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor, -// CHECK-SAME: %[[VAL_2:.*]]: tensor) -> tensor { -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.number_of_entries %[[VAL_0]] : tensor> -// CHECK-DAG: %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor> -// CHECK-DAG: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor> -// CHECK-DAG: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor -// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor> to memref -// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor> to memref -// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor> to memref -// CHECK: %[[VAL_12:.*]] = gpu.wait async -// CHECK: %[[VAL_13:.*]] = memref.dim %[[VAL_9]], %[[VAL_3]] : memref -// CHECK: %[[VAL_14:.*]], %[[VAL_15:.*]] = gpu.alloc async {{\[}}%[[VAL_12]]] (%[[VAL_13]]) : memref -// CHECK: %[[VAL_16:.*]] = gpu.memcpy async {{\[}}%[[VAL_15]]] %[[VAL_14]], %[[VAL_9]] : memref, memref -// CHECK: %[[VAL_17:.*]] = gpu.wait async -// CHECK: %[[VAL_18:.*]] = memref.dim %[[VAL_10]], %[[VAL_3]] : memref -// CHECK: %[[VAL_19:.*]], %[[VAL_20:.*]] = gpu.alloc async {{\[}}%[[VAL_17]]] (%[[VAL_18]]) : memref -// CHECK: %[[VAL_21:.*]] = gpu.memcpy async {{\[}}%[[VAL_20]]] %[[VAL_19]], %[[VAL_10]] : memref, memref -// CHECK: %[[VAL_22:.*]] = gpu.wait async -// CHECK: %[[VAL_23:.*]] = memref.dim %[[VAL_11]], %[[VAL_3]] : memref -// CHECK: %[[VAL_24:.*]], %[[VAL_25:.*]] = gpu.alloc async {{\[}}%[[VAL_22]]] (%[[VAL_23]]) : memref -// CHECK: %[[VAL_26:.*]] = gpu.memcpy async {{\[}}%[[VAL_25]]] %[[VAL_24]], %[[VAL_11]] : memref, memref -// CHECK: %[[VAL_27:.*]] = bufferization.to_memref %[[VAL_1]] : memref -// CHECK: %[[VAL_28:.*]] = gpu.wait async -// CHECK: %[[VAL_29:.*]] = memref.dim %[[VAL_27]], %[[VAL_3]] : memref -// CHECK: %[[VAL_30:.*]] = memref.dim %[[VAL_27]], %[[VAL_4]] : memref -// CHECK: %[[VAL_31:.*]], %[[VAL_32:.*]] = gpu.alloc async {{\[}}%[[VAL_28]]] (%[[VAL_29]], %[[VAL_30]]) : memref -// CHECK: %[[VAL_33:.*]] = gpu.memcpy async {{\[}}%[[VAL_32]]] %[[VAL_31]], %[[VAL_27]] : memref, memref -// CHECK: %[[VAL_34:.*]] = bufferization.to_memref %[[VAL_2]] : memref +// CHECK-SAME: %[[VAL_0:.*]]: tensor>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor, +// CHECK-SAME: %[[VAL_2:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_5:.*]] = sparse_tensor.number_of_entries %[[VAL_0]] : tensor> +// CHECK: %[[VAL_6:.*]] 
= tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor> +// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor> +// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor +// CHECK: %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor> to memref +// CHECK: %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor> to memref +// CHECK: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor> to memref +// CHECK: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : memref +// CHECK: %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref +// CHECK: %[[VAL_14:.*]] = gpu.wait async +// CHECK: %[[VAL_15:.*]] = memref.dim %[[VAL_9]], %[[VAL_3]] : memref +// CHECK: %[[VAL_16:.*]], %[[VAL_17:.*]] = gpu.alloc async {{\[}}%[[VAL_14]]] (%[[VAL_15]]) : memref +// CHECK: %[[VAL_18:.*]] = gpu.memcpy async {{\[}}%[[VAL_17]]] %[[VAL_16]], %[[VAL_9]] : memref, memref +// CHECK: %[[VAL_19:.*]] = gpu.wait async +// CHECK: %[[VAL_20:.*]] = memref.dim %[[VAL_10]], %[[VAL_3]] : memref +// CHECK: %[[VAL_21:.*]], %[[VAL_22:.*]] = gpu.alloc async {{\[}}%[[VAL_19]]] (%[[VAL_20]]) : memref +// CHECK: %[[VAL_23:.*]] = gpu.memcpy async {{\[}}%[[VAL_22]]] %[[VAL_21]], %[[VAL_10]] : memref, memref +// CHECK: %[[VAL_24:.*]] = gpu.wait async +// CHECK: %[[VAL_25:.*]] = memref.dim %[[VAL_11]], %[[VAL_3]] : memref +// CHECK: %[[VAL_26:.*]], %[[VAL_27:.*]] = gpu.alloc async {{\[}}%[[VAL_24]]] (%[[VAL_25]]) : memref +// CHECK: %[[VAL_28:.*]] = gpu.memcpy async {{\[}}%[[VAL_27]]] %[[VAL_26]], %[[VAL_11]] : memref, memref +// CHECK: %[[VAL_29:.*]] = gpu.wait async +// CHECK: %[[VAL_30:.*]] = memref.dim %[[VAL_12]], %[[VAL_3]] : memref +// CHECK: %[[VAL_31:.*]] = memref.dim %[[VAL_12]], %[[VAL_4]] : memref +// CHECK: %[[VAL_32:.*]], %[[VAL_33:.*]] = gpu.alloc async {{\[}}%[[VAL_29]]] (%[[VAL_30]], %[[VAL_31]]) : memref +// CHECK: %[[VAL_34:.*]] = gpu.memcpy async {{\[}}%[[VAL_33]]] %[[VAL_32]], %[[VAL_12]] : memref, memref // CHECK: %[[VAL_35:.*]] = gpu.wait async -// CHECK: %[[VAL_36:.*]] = memref.dim %[[VAL_34]], %[[VAL_3]] : memref -// CHECK: %[[VAL_37:.*]] = memref.dim %[[VAL_34]], %[[VAL_4]] : memref +// CHECK: %[[VAL_36:.*]] = memref.dim %[[VAL_13]], %[[VAL_3]] : memref +// CHECK: %[[VAL_37:.*]] = memref.dim %[[VAL_13]], %[[VAL_4]] : memref // CHECK: %[[VAL_38:.*]], %[[VAL_39:.*]] = gpu.alloc async {{\[}}%[[VAL_35]]] (%[[VAL_36]], %[[VAL_37]]) : memref -// CHECK: %[[VAL_40:.*]] = gpu.memcpy async {{\[}}%[[VAL_39]]] %[[VAL_38]], %[[VAL_34]] : memref, memref -// CHECK: gpu.wait {{\[}}%[[VAL_16]], %[[VAL_21]], %[[VAL_26]], %[[VAL_33]], %[[VAL_40]]] +// CHECK: %[[VAL_40:.*]] = gpu.memcpy async {{\[}}%[[VAL_39]]] %[[VAL_38]], %[[VAL_13]] : memref, memref +// CHECK: gpu.wait {{\[}}%[[VAL_18]], %[[VAL_23]], %[[VAL_28]], %[[VAL_34]], %[[VAL_40]]] // CHECK: %[[VAL_41:.*]] = gpu.wait async -// CHECK: %[[VAL_44:.*]], %[[VAL_45:.*]] = gpu.create_csr async {{\[}}%[[VAL_41]]] %[[VAL_6]], %[[VAL_7]], %[[VAL_5]], %[[VAL_14]], %[[VAL_19]], %[[VAL_24]] : memref, memref, memref -// CHECK: %[[VAL_46:.*]], %[[VAL_47:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_45]]] %[[VAL_31]], %[[VAL_7]], %[[VAL_8]] : index, index into memref -// CHECK: %[[VAL_48:.*]], %[[VAL_49:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_47]]] %[[VAL_38]], %[[VAL_6]], %[[VAL_8]] : index, index into memref -// CHECK: %[[VAL_50:.*]], %[[VAL_51:.*]] = gpu.spmm_buffer_size async {{\[}}%[[VAL_49]]] %[[VAL_44]], %[[VAL_46]], %[[VAL_48]] : index -// CHECK: %[[VAL_52:.*]], %[[VAL_53:.*]] = gpu.alloc 
async {{\[}}%[[VAL_51]]] (%[[VAL_50]]) : memref -// CHECK: %[[VAL_54:.*]] = gpu.spmm async {{\[}}%[[VAL_53]]] %[[VAL_44]], %[[VAL_46]], %[[VAL_48]], %[[VAL_52]] : memref -// CHECK: %[[VAL_55:.*]] = gpu.destroy_sp_mat async {{\[}}%[[VAL_54]]] %[[VAL_44]] -// CHECK: %[[VAL_56:.*]] = gpu.destroy_dn_tensor async {{\[}}%[[VAL_55]]] %[[VAL_46]] -// CHECK: %[[VAL_57:.*]] = gpu.destroy_dn_tensor async {{\[}}%[[VAL_56]]] %[[VAL_48]] -// CHECK: %[[VAL_59:.*]] = gpu.dealloc async {{\[}}%[[VAL_57]]] %[[VAL_14]] : memref -// CHECK: %[[VAL_60:.*]] = gpu.dealloc async {{\[}}%[[VAL_59]]] %[[VAL_19]] : memref -// CHECK: %[[VAL_61:.*]] = gpu.dealloc async {{\[}}%[[VAL_60]]] %[[VAL_24]] : memref -// CHECK: %[[VAL_62:.*]] = gpu.dealloc async {{\[}}%[[VAL_61]]] %[[VAL_52]] : memref -// CHECK: %[[VAL_63:.*]] = gpu.dealloc async {{\[}}%[[VAL_62]]] %[[VAL_31]] : memref -// CHECK: %[[VAL_64:.*]] = gpu.memcpy async {{\[}}%[[VAL_63]]] %[[VAL_34]], %[[VAL_38]] : memref, memref -// CHECK: %[[VAL_65:.*]] = gpu.dealloc async {{\[}}%[[VAL_64]]] %[[VAL_38]] : memref -// CHECK: gpu.wait {{\[}}%[[VAL_65]]] -// CHECK: %[[VAL_66:.*]] = bufferization.to_tensor %[[VAL_34]] : memref -// CHECK: return %[[VAL_66]] : tensor +// CHECK: %[[VAL_42:.*]], %[[VAL_43:.*]] = gpu.create_csr async {{\[}}%[[VAL_41]]] %[[VAL_6]], %[[VAL_7]], %[[VAL_5]], %[[VAL_16]], %[[VAL_21]], %[[VAL_26]] : memref, memref, memref +// CHECK: %[[VAL_44:.*]], %[[VAL_45:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_43]]] %[[VAL_32]], %[[VAL_7]], %[[VAL_8]] : index, index into memref +// CHECK: %[[VAL_46:.*]], %[[VAL_47:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_45]]] %[[VAL_38]], %[[VAL_6]], %[[VAL_8]] : index, index into memref +// CHECK: %[[VAL_48:.*]], %[[VAL_49:.*]] = gpu.spmm_buffer_size async {{\[}}%[[VAL_47]]] %[[VAL_42]], %[[VAL_44]], %[[VAL_46]] : index into f64 +// CHECK: %[[VAL_50:.*]], %[[VAL_51:.*]] = gpu.alloc async {{\[}}%[[VAL_49]]] (%[[VAL_48]]) : memref +// CHECK: %[[VAL_52:.*]] = gpu.spmm async {{\[}}%[[VAL_51]]] %[[VAL_42]], %[[VAL_44]], %[[VAL_46]], %[[VAL_50]] : memref into f64 +// CHECK: %[[VAL_53:.*]] = gpu.destroy_sp_mat async {{\[}}%[[VAL_52]]] %[[VAL_42]] +// CHECK: %[[VAL_54:.*]] = gpu.destroy_dn_tensor async {{\[}}%[[VAL_53]]] %[[VAL_44]] +// CHECK: %[[VAL_55:.*]] = gpu.destroy_dn_tensor async {{\[}}%[[VAL_54]]] %[[VAL_46]] +// CHECK: %[[VAL_56:.*]] = gpu.dealloc async {{\[}}%[[VAL_55]]] %[[VAL_16]] : memref +// CHECK: %[[VAL_57:.*]] = gpu.dealloc async {{\[}}%[[VAL_56]]] %[[VAL_21]] : memref +// CHECK: %[[VAL_58:.*]] = gpu.dealloc async {{\[}}%[[VAL_57]]] %[[VAL_26]] : memref +// CHECK: %[[VAL_59:.*]] = gpu.dealloc async {{\[}}%[[VAL_58]]] %[[VAL_50]] : memref +// CHECK: %[[VAL_60:.*]] = gpu.dealloc async {{\[}}%[[VAL_59]]] %[[VAL_32]] : memref +// CHECK: %[[VAL_61:.*]] = gpu.memcpy async {{\[}}%[[VAL_60]]] %[[VAL_13]], %[[VAL_38]] : memref, memref +// CHECK: %[[VAL_62:.*]] = gpu.dealloc async {{\[}}%[[VAL_61]]] %[[VAL_38]] : memref +// CHECK: gpu.wait {{\[}}%[[VAL_62]]] +// CHECK: %[[VAL_63:.*]] = bufferization.to_tensor %[[VAL_13]] : memref +// CHECK: return %[[VAL_63]] : tensor // CHECK: } func.func @matmul(%A: tensor, %B: tensor, %C_in: tensor) -> tensor { %C_out = linalg.matmul diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib_2to4.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib_2to4.mlir --- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib_2to4.mlir +++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib_2to4.mlir @@ -8,30 +8,30 @@ // CHECK: %[[VAL_3:.*]] = arith.constant 0 : index // CHECK: 
%[[VAL_4:.*]] = arith.constant 1 : index // CHECK: %[[VAL_5:.*]] = bufferization.to_memref %[[VAL_0]] : memref -// CHECK: %[[VAL_6:.*]] = gpu.wait async -// CHECK: %[[VAL_7:.*]] = memref.dim %[[VAL_5]], %[[VAL_3]] : memref -// CHECK: %[[VAL_8:.*]] = memref.dim %[[VAL_5]], %[[VAL_4]] : memref -// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]] = gpu.alloc async {{\[}}%[[VAL_6]]] (%[[VAL_7]], %[[VAL_8]]) : memref -// CHECK: %[[VAL_11:.*]] = gpu.memcpy async {{\[}}%[[VAL_10]]] %[[VAL_9]], %[[VAL_5]] : memref, memref -// CHECK: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : memref -// CHECK: %[[VAL_13:.*]] = gpu.wait async -// CHECK: %[[VAL_14:.*]] = memref.dim %[[VAL_12]], %[[VAL_3]] : memref -// CHECK: %[[VAL_15:.*]] = memref.dim %[[VAL_12]], %[[VAL_4]] : memref -// CHECK: %[[VAL_16:.*]], %[[VAL_17:.*]] = gpu.alloc async {{\[}}%[[VAL_13]]] (%[[VAL_14]], %[[VAL_15]]) : memref -// CHECK: %[[VAL_18:.*]] = gpu.memcpy async {{\[}}%[[VAL_17]]] %[[VAL_16]], %[[VAL_12]] : memref, memref -// CHECK: %[[VAL_19:.*]] = bufferization.to_memref %[[VAL_2]] : memref +// CHECK: %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_1]] : memref +// CHECK: %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_2]] : memref +// CHECK: %[[VAL_8:.*]] = gpu.wait async +// CHECK: %[[VAL_9:.*]] = memref.dim %[[VAL_5]], %[[VAL_3]] : memref +// CHECK: %[[VAL_10:.*]] = memref.dim %[[VAL_5]], %[[VAL_4]] : memref +// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]] = gpu.alloc async {{\[}}%[[VAL_8]]] (%[[VAL_9]], %[[VAL_10]]) : memref +// CHECK: %[[VAL_13:.*]] = gpu.memcpy async {{\[}}%[[VAL_12]]] %[[VAL_11]], %[[VAL_5]] : memref, memref +// CHECK: %[[VAL_14:.*]] = gpu.wait async +// CHECK: %[[VAL_15:.*]] = memref.dim %[[VAL_6]], %[[VAL_3]] : memref +// CHECK: %[[VAL_16:.*]] = memref.dim %[[VAL_6]], %[[VAL_4]] : memref +// CHECK: %[[VAL_17:.*]], %[[VAL_18:.*]] = gpu.alloc async {{\[}}%[[VAL_14]]] (%[[VAL_15]], %[[VAL_16]]) : memref +// CHECK: %[[VAL_19:.*]] = gpu.memcpy async {{\[}}%[[VAL_18]]] %[[VAL_17]], %[[VAL_6]] : memref, memref // CHECK: %[[VAL_20:.*]] = gpu.wait async -// CHECK: %[[VAL_21:.*]] = memref.dim %[[VAL_19]], %[[VAL_3]] : memref -// CHECK: %[[VAL_22:.*]] = memref.dim %[[VAL_19]], %[[VAL_4]] : memref +// CHECK: %[[VAL_21:.*]] = memref.dim %[[VAL_7]], %[[VAL_3]] : memref +// CHECK: %[[VAL_22:.*]] = memref.dim %[[VAL_7]], %[[VAL_4]] : memref // CHECK: %[[VAL_23:.*]], %[[VAL_24:.*]] = gpu.alloc async {{\[}}%[[VAL_20]]] (%[[VAL_21]], %[[VAL_22]]) : memref -// CHECK: %[[VAL_25:.*]] = gpu.memcpy async {{\[}}%[[VAL_24]]] %[[VAL_23]], %[[VAL_19]] : memref, memref -// CHECK: gpu.wait {{\[}}%[[VAL_11]], %[[VAL_18]], %[[VAL_25]]] -// CHECK: %[[VAL_26:.*]] = memref.dim %[[VAL_9]], %[[VAL_3]] : memref -// CHECK: %[[VAL_27:.*]] = memref.dim %[[VAL_16]], %[[VAL_3]] : memref +// CHECK: %[[VAL_25:.*]] = gpu.memcpy async {{\[}}%[[VAL_24]]] %[[VAL_23]], %[[VAL_7]] : memref, memref +// CHECK: gpu.wait {{\[}}%[[VAL_13]], %[[VAL_19]], %[[VAL_25]]] +// CHECK: %[[VAL_26:.*]] = memref.dim %[[VAL_11]], %[[VAL_3]] : memref +// CHECK: %[[VAL_27:.*]] = memref.dim %[[VAL_17]], %[[VAL_3]] : memref // CHECK: %[[VAL_28:.*]] = memref.dim %[[VAL_23]], %[[VAL_4]] : memref // CHECK: %[[VAL_29:.*]] = gpu.wait async -// CHECK: %[[VAL_30:.*]], %[[VAL_31:.*]] = gpu.create_2to4_spmat async {{\[}}%[[VAL_29]]] %[[VAL_26]], %[[VAL_27]], %[[VAL_9]] : memref -// CHECK: %[[VAL_32:.*]], %[[VAL_33:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_31]]] %[[VAL_16]], %[[VAL_27]], %[[VAL_28]] : index, index into memref +// CHECK: %[[VAL_30:.*]], %[[VAL_31:.*]] = gpu.create_2to4_spmat async 
{{\[}}%[[VAL_29]]] %[[VAL_26]], %[[VAL_27]], %[[VAL_11]] : memref +// CHECK: %[[VAL_32:.*]], %[[VAL_33:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_31]]] %[[VAL_17]], %[[VAL_27]], %[[VAL_28]] : index, index into memref // CHECK: %[[VAL_34:.*]], %[[VAL_35:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_33]]] %[[VAL_23]], %[[VAL_26]], %[[VAL_28]] : index, index into memref // CHECK: %[[VAL_36:.*]]:3, %[[VAL_37:.*]] = gpu.spmm_buffer_size async {{\[}}%[[VAL_35]]] %[[VAL_30]], %[[VAL_32]], %[[VAL_34]] : index, index, index into f16 // CHECK: %[[VAL_38:.*]], %[[VAL_39:.*]] = gpu.alloc async {{\[}}%[[VAL_37]]] (%[[VAL_36]]#0) : memref @@ -44,12 +44,12 @@ // CHECK: %[[VAL_48:.*]] = gpu.dealloc async {{\[}}%[[VAL_47]]] %[[VAL_38]] : memref // CHECK: %[[VAL_49:.*]] = gpu.dealloc async {{\[}}%[[VAL_48]]] %[[VAL_40]] : memref // CHECK: %[[VAL_50:.*]] = gpu.dealloc async {{\[}}%[[VAL_49]]] %[[VAL_42]] : memref -// CHECK: %[[VAL_51:.*]] = gpu.dealloc async {{\[}}%[[VAL_50]]] %[[VAL_9]] : memref -// CHECK: %[[VAL_52:.*]] = gpu.dealloc async {{\[}}%[[VAL_51]]] %[[VAL_16]] : memref -// CHECK: %[[VAL_53:.*]] = gpu.memcpy async {{\[}}%[[VAL_52]]] %[[VAL_19]], %[[VAL_23]] : memref, memref +// CHECK: %[[VAL_51:.*]] = gpu.dealloc async {{\[}}%[[VAL_50]]] %[[VAL_11]] : memref +// CHECK: %[[VAL_52:.*]] = gpu.dealloc async {{\[}}%[[VAL_51]]] %[[VAL_17]] : memref +// CHECK: %[[VAL_53:.*]] = gpu.memcpy async {{\[}}%[[VAL_52]]] %[[VAL_7]], %[[VAL_23]] : memref, memref // CHECK: %[[VAL_54:.*]] = gpu.dealloc async {{\[}}%[[VAL_53]]] %[[VAL_23]] : memref // CHECK: gpu.wait {{\[}}%[[VAL_54]]] -// CHECK: %[[VAL_55:.*]] = bufferization.to_tensor %[[VAL_19]] : memref +// CHECK: %[[VAL_55:.*]] = bufferization.to_tensor %[[VAL_7]] : memref // CHECK: return %[[VAL_55]] : tensor // CHECK: } diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec_lib.mlir --- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec_lib.mlir +++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec_lib.mlir @@ -11,57 +11,57 @@ // CHECK-SAME: %[[VAL_0:.*]]: tensor>, // CHECK-SAME: %[[VAL_1:.*]]: tensor, // CHECK-SAME: %[[VAL_2:.*]]: tensor) -> tensor { -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.number_of_entries %[[VAL_0]] : tensor> -// CHECK-DAG: %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor> -// CHECK-DAG: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor> -// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor> to memref> -// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor> to memref> -// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor> to memref -// CHECK: %[[VAL_11:.*]] = gpu.wait async -// CHECK: %[[VAL_12:.*]] = memref.dim %[[VAL_8]], %[[VAL_3]] : memref> -// CHECK: %[[VAL_13:.*]], %[[VAL_14:.*]] = gpu.alloc async {{\[}}%[[VAL_11]]] (%[[VAL_12]]) : memref -// CHECK: %[[VAL_15:.*]] = gpu.memcpy async {{\[}}%[[VAL_14]]] %[[VAL_13]], %[[VAL_8]] : memref, memref> -// CHECK: %[[VAL_16:.*]] = gpu.wait async -// CHECK: %[[VAL_17:.*]] = memref.dim %[[VAL_9]], %[[VAL_3]] : memref> -// CHECK: %[[VAL_18:.*]], %[[VAL_19:.*]] = gpu.alloc async {{\[}}%[[VAL_16]]] (%[[VAL_17]]) : memref -// CHECK: %[[VAL_20:.*]] = gpu.memcpy async {{\[}}%[[VAL_19]]] %[[VAL_18]], %[[VAL_9]] : memref, memref> -// CHECK: %[[VAL_21:.*]] = gpu.wait async -// CHECK: %[[VAL_22:.*]] = memref.dim 
%[[VAL_10]], %[[VAL_3]] : memref -// CHECK: %[[VAL_23:.*]], %[[VAL_24:.*]] = gpu.alloc async {{\[}}%[[VAL_21]]] (%[[VAL_22]]) : memref -// CHECK: %[[VAL_25:.*]] = gpu.memcpy async {{\[}}%[[VAL_24]]] %[[VAL_23]], %[[VAL_10]] : memref, memref -// CHECK: %[[VAL_26:.*]] = bufferization.to_memref %[[VAL_1]] : memref -// CHECK: %[[VAL_27:.*]] = gpu.wait async -// CHECK: %[[VAL_28:.*]] = memref.dim %[[VAL_26]], %[[VAL_3]] : memref -// CHECK: %[[VAL_29:.*]], %[[VAL_30:.*]] = gpu.alloc async {{\[}}%[[VAL_27]]] (%[[VAL_28]]) : memref -// CHECK: %[[VAL_31:.*]] = gpu.memcpy async {{\[}}%[[VAL_30]]] %[[VAL_29]], %[[VAL_26]] : memref, memref -// CHECK: %[[VAL_32:.*]] = bufferization.to_memref %[[VAL_2]] : memref +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_5:.*]] = sparse_tensor.number_of_entries %[[VAL_0]] : tensor> +// CHECK: %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor> +// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor> +// CHECK: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor> to memref> +// CHECK: %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor> to memref> +// CHECK: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor> to memref +// CHECK: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref +// CHECK: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref +// CHECK: %[[VAL_13:.*]] = gpu.wait async +// CHECK: %[[VAL_14:.*]] = memref.dim %[[VAL_8]], %[[VAL_3]] : memref> +// CHECK: %[[VAL_15:.*]], %[[VAL_16:.*]] = gpu.alloc async {{\[}}%[[VAL_13]]] (%[[VAL_14]]) : memref +// CHECK: %[[VAL_17:.*]] = gpu.memcpy async {{\[}}%[[VAL_16]]] %[[VAL_15]], %[[VAL_8]] : memref, memref> +// CHECK: %[[VAL_18:.*]] = gpu.wait async +// CHECK: %[[VAL_19:.*]] = memref.dim %[[VAL_9]], %[[VAL_3]] : memref> +// CHECK: %[[VAL_20:.*]], %[[VAL_21:.*]] = gpu.alloc async {{\[}}%[[VAL_18]]] (%[[VAL_19]]) : memref +// CHECK: %[[VAL_22:.*]] = gpu.memcpy async {{\[}}%[[VAL_21]]] %[[VAL_20]], %[[VAL_9]] : memref, memref> +// CHECK: %[[VAL_23:.*]] = gpu.wait async +// CHECK: %[[VAL_24:.*]] = memref.dim %[[VAL_10]], %[[VAL_3]] : memref +// CHECK: %[[VAL_25:.*]], %[[VAL_26:.*]] = gpu.alloc async {{\[}}%[[VAL_23]]] (%[[VAL_24]]) : memref +// CHECK: %[[VAL_27:.*]] = gpu.memcpy async {{\[}}%[[VAL_26]]] %[[VAL_25]], %[[VAL_10]] : memref, memref +// CHECK: %[[VAL_28:.*]] = gpu.wait async +// CHECK: %[[VAL_29:.*]] = memref.dim %[[VAL_11]], %[[VAL_3]] : memref +// CHECK: %[[VAL_30:.*]], %[[VAL_31:.*]] = gpu.alloc async {{\[}}%[[VAL_28]]] (%[[VAL_29]]) : memref +// CHECK: %[[VAL_32:.*]] = gpu.memcpy async {{\[}}%[[VAL_31]]] %[[VAL_30]], %[[VAL_11]] : memref, memref // CHECK: %[[VAL_33:.*]] = gpu.wait async -// CHECK: %[[VAL_34:.*]] = memref.dim %[[VAL_32]], %[[VAL_3]] : memref +// CHECK: %[[VAL_34:.*]] = memref.dim %[[VAL_12]], %[[VAL_3]] : memref // CHECK: %[[VAL_35:.*]], %[[VAL_36:.*]] = gpu.alloc async {{\[}}%[[VAL_33]]] (%[[VAL_34]]) : memref -// CHECK: %[[VAL_37:.*]] = gpu.memcpy async {{\[}}%[[VAL_36]]] %[[VAL_35]], %[[VAL_32]] : memref, memref -// CHECK: gpu.wait {{\[}}%[[VAL_15]], %[[VAL_20]], %[[VAL_25]], %[[VAL_31]], %[[VAL_37]]] +// CHECK: %[[VAL_37:.*]] = gpu.memcpy async {{\[}}%[[VAL_36]]] %[[VAL_35]], %[[VAL_12]] : memref, memref +// CHECK: gpu.wait {{\[}}%[[VAL_17]], %[[VAL_22]], %[[VAL_27]], %[[VAL_32]], %[[VAL_37]]] // CHECK: %[[VAL_38:.*]] = gpu.wait async -// CHECK: %[[VAL_41:.*]], %[[VAL_42:.*]] = gpu.create_coo async 
{{\[}}%[[VAL_38]]] %[[VAL_6]], %[[VAL_7]], %[[VAL_5]], %[[VAL_13]], %[[VAL_18]], %[[VAL_23]] : memref, memref, memref -// CHECK: %[[VAL_43:.*]], %[[VAL_44:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_42]]] %[[VAL_29]], %[[VAL_7]] : index into memref -// CHECK: %[[VAL_45:.*]], %[[VAL_46:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_44]]] %[[VAL_35]], %[[VAL_6]] : index into memref -// CHECK: %[[VAL_47:.*]], %[[VAL_48:.*]] = gpu.spmv_buffer_size async {{\[}}%[[VAL_46]]] %[[VAL_41]], %[[VAL_43]], %[[VAL_45]] -// CHECK: %[[VAL_49:.*]], %[[VAL_50:.*]] = gpu.alloc async {{\[}}%[[VAL_48]]] (%[[VAL_47]]) : memref -// CHECK: %[[VAL_51:.*]] = gpu.spmv async {{\[}}%[[VAL_50]]] %[[VAL_41]], %[[VAL_43]], %[[VAL_45]], %[[VAL_49]] : memref -// CHECK: %[[VAL_52:.*]] = gpu.destroy_sp_mat async {{\[}}%[[VAL_51]]] %[[VAL_41]] -// CHECK: %[[VAL_53:.*]] = gpu.destroy_dn_tensor async {{\[}}%[[VAL_52]]] %[[VAL_43]] -// CHECK: %[[VAL_54:.*]] = gpu.destroy_dn_tensor async {{\[}}%[[VAL_53]]] %[[VAL_45]] -// CHECK: %[[VAL_56:.*]] = gpu.dealloc async {{\[}}%[[VAL_54]]] %[[VAL_13]] : memref -// CHECK: %[[VAL_57:.*]] = gpu.dealloc async {{\[}}%[[VAL_56]]] %[[VAL_18]] : memref -// CHECK: %[[VAL_58:.*]] = gpu.dealloc async {{\[}}%[[VAL_57]]] %[[VAL_23]] : memref -// CHECK: %[[VAL_59:.*]] = gpu.dealloc async {{\[}}%[[VAL_58]]] %[[VAL_49]] : memref -// CHECK: %[[VAL_60:.*]] = gpu.dealloc async {{\[}}%[[VAL_59]]] %[[VAL_29]] : memref -// CHECK: %[[VAL_61:.*]] = gpu.memcpy async {{\[}}%[[VAL_60]]] %[[VAL_32]], %[[VAL_35]] : memref, memref -// CHECK: %[[VAL_62:.*]] = gpu.dealloc async {{\[}}%[[VAL_61]]] %[[VAL_35]] : memref -// CHECK: gpu.wait {{\[}}%[[VAL_62]]] -// CHECK: %[[VAL_63:.*]] = bufferization.to_tensor %[[VAL_32]] : memref -// CHECK: return %[[VAL_63]] : tensor +// CHECK: %[[VAL_39:.*]], %[[VAL_40:.*]] = gpu.create_coo async {{\[}}%[[VAL_38]]] %[[VAL_6]], %[[VAL_7]], %[[VAL_5]], %[[VAL_15]], %[[VAL_20]], %[[VAL_25]] : memref, memref, memref +// CHECK: %[[VAL_41:.*]], %[[VAL_42:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_40]]] %[[VAL_30]], %[[VAL_7]] : index into memref +// CHECK: %[[VAL_43:.*]], %[[VAL_44:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_42]]] %[[VAL_35]], %[[VAL_6]] : index into memref +// CHECK: %[[VAL_45:.*]], %[[VAL_46:.*]] = gpu.spmv_buffer_size async {{\[}}%[[VAL_44]]] %[[VAL_39]], %[[VAL_41]], %[[VAL_43]] into f64 +// CHECK: %[[VAL_47:.*]], %[[VAL_48:.*]] = gpu.alloc async {{\[}}%[[VAL_46]]] (%[[VAL_45]]) : memref +// CHECK: %[[VAL_49:.*]] = gpu.spmv async {{\[}}%[[VAL_48]]] %[[VAL_39]], %[[VAL_41]], %[[VAL_43]], %[[VAL_47]] : memref into f64 +// CHECK: %[[VAL_50:.*]] = gpu.destroy_sp_mat async {{\[}}%[[VAL_49]]] %[[VAL_39]] +// CHECK: %[[VAL_51:.*]] = gpu.destroy_dn_tensor async {{\[}}%[[VAL_50]]] %[[VAL_41]] +// CHECK: %[[VAL_52:.*]] = gpu.destroy_dn_tensor async {{\[}}%[[VAL_51]]] %[[VAL_43]] +// CHECK: %[[VAL_53:.*]] = gpu.dealloc async {{\[}}%[[VAL_52]]] %[[VAL_15]] : memref +// CHECK: %[[VAL_54:.*]] = gpu.dealloc async {{\[}}%[[VAL_53]]] %[[VAL_20]] : memref +// CHECK: %[[VAL_55:.*]] = gpu.dealloc async {{\[}}%[[VAL_54]]] %[[VAL_25]] : memref +// CHECK: %[[VAL_56:.*]] = gpu.dealloc async {{\[}}%[[VAL_55]]] %[[VAL_47]] : memref +// CHECK: %[[VAL_57:.*]] = gpu.dealloc async {{\[}}%[[VAL_56]]] %[[VAL_30]] : memref +// CHECK: %[[VAL_58:.*]] = gpu.memcpy async {{\[}}%[[VAL_57]]] %[[VAL_12]], %[[VAL_35]] : memref, memref +// CHECK: %[[VAL_59:.*]] = gpu.dealloc async {{\[}}%[[VAL_58]]] %[[VAL_35]] : memref +// CHECK: gpu.wait {{\[}}%[[VAL_59]]] +// CHECK: %[[VAL_60:.*]] = 
bufferization.to_tensor %[[VAL_12]] : memref +// CHECK: return %[[VAL_60]] : tensor // CHECK: } func.func @matvec(%A: tensor, %x: tensor, diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir --- a/mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir +++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir @@ -29,49 +29,49 @@ // CHECK: %[[VAL_4:.*]] = arith.constant 0 : index // CHECK: %[[VAL_5:.*]] = sparse_tensor.number_of_entries %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>> // CHECK: %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64> -// CHECK: %[[VAL_7:.*]] = gpu.wait async -// CHECK: %[[VAL_8:.*]], %[[VAL_9:.*]] = gpu.alloc async {{\[}}%[[VAL_7]]] () : memref<8x8xf64> -// CHECK: %[[VAL_10:.*]] = gpu.memcpy async {{\[}}%[[VAL_9]]] %[[VAL_8]], %[[VAL_6]] : memref<8x8xf64>, memref<8x8xf64> -// CHECK: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64> -// CHECK: %[[VAL_12:.*]] = gpu.wait async -// CHECK: %[[VAL_13:.*]], %[[VAL_14:.*]] = gpu.alloc async {{\[}}%[[VAL_12]]] () : memref<8x8xf64> -// CHECK: %[[VAL_15:.*]] = gpu.memcpy async {{\[}}%[[VAL_14]]] %[[VAL_13]], %[[VAL_11]] : memref<8x8xf64>, memref<8x8xf64> -// CHECK: %[[VAL_16:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>> to memref -// CHECK: %[[VAL_17:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>> to memref -// CHECK: %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>> to memref +// CHECK: %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64> +// CHECK: %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>> to memref +// CHECK: %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>> to memref +// CHECK: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>> to memref +// CHECK: %[[VAL_11:.*]] = gpu.wait async +// CHECK: %[[VAL_12:.*]], %[[VAL_13:.*]] = gpu.alloc async {{\[}}%[[VAL_11]]] () : memref<8x8xf64> +// CHECK: %[[VAL_14:.*]] = gpu.memcpy async {{\[}}%[[VAL_13]]] %[[VAL_12]], %[[VAL_6]] : memref<8x8xf64>, memref<8x8xf64> +// CHECK: %[[VAL_15:.*]] = gpu.wait async +// CHECK: %[[VAL_16:.*]], %[[VAL_17:.*]] = gpu.alloc async {{\[}}%[[VAL_15]]] () : memref<8x8xf64> +// CHECK: %[[VAL_18:.*]] = gpu.memcpy async {{\[}}%[[VAL_17]]] %[[VAL_16]], %[[VAL_7]] : memref<8x8xf64>, memref<8x8xf64> // CHECK: %[[VAL_19:.*]] = gpu.wait async -// CHECK: %[[VAL_20:.*]] = memref.dim %[[VAL_16]], %[[VAL_4]] : memref +// CHECK: %[[VAL_20:.*]] = memref.dim %[[VAL_8]], %[[VAL_4]] : memref // CHECK: %[[VAL_21:.*]], %[[VAL_22:.*]] = gpu.alloc async {{\[}}%[[VAL_19]]] (%[[VAL_20]]) : memref -// CHECK: %[[VAL_23:.*]] = gpu.memcpy async {{\[}}%[[VAL_22]]] %[[VAL_21]], %[[VAL_16]] : memref, memref +// CHECK: %[[VAL_23:.*]] = gpu.memcpy async {{\[}}%[[VAL_22]]] %[[VAL_21]], %[[VAL_8]] : memref, memref // CHECK: %[[VAL_24:.*]] = gpu.wait async -// CHECK: %[[VAL_25:.*]] = memref.dim 
%[[VAL_17]], %[[VAL_4]] : memref +// CHECK: %[[VAL_25:.*]] = memref.dim %[[VAL_9]], %[[VAL_4]] : memref // CHECK: %[[VAL_26:.*]], %[[VAL_27:.*]] = gpu.alloc async {{\[}}%[[VAL_24]]] (%[[VAL_25]]) : memref -// CHECK: %[[VAL_28:.*]] = gpu.memcpy async {{\[}}%[[VAL_27]]] %[[VAL_26]], %[[VAL_17]] : memref, memref +// CHECK: %[[VAL_28:.*]] = gpu.memcpy async {{\[}}%[[VAL_27]]] %[[VAL_26]], %[[VAL_9]] : memref, memref // CHECK: %[[VAL_29:.*]] = gpu.wait async -// CHECK: %[[VAL_30:.*]] = memref.dim %[[VAL_18]], %[[VAL_4]] : memref +// CHECK: %[[VAL_30:.*]] = memref.dim %[[VAL_10]], %[[VAL_4]] : memref // CHECK: %[[VAL_31:.*]], %[[VAL_32:.*]] = gpu.alloc async {{\[}}%[[VAL_29]]] (%[[VAL_30]]) : memref -// CHECK: %[[VAL_33:.*]] = gpu.memcpy async {{\[}}%[[VAL_32]]] %[[VAL_31]], %[[VAL_18]] : memref, memref -// CHECK: gpu.wait {{\[}}%[[VAL_10]], %[[VAL_15]], %[[VAL_23]], %[[VAL_28]], %[[VAL_33]]] +// CHECK: %[[VAL_33:.*]] = gpu.memcpy async {{\[}}%[[VAL_32]]] %[[VAL_31]], %[[VAL_10]] : memref, memref +// CHECK: gpu.wait {{\[}}%[[VAL_14]], %[[VAL_18]], %[[VAL_23]], %[[VAL_28]], %[[VAL_33]]] // CHECK: %[[VAL_34:.*]] = gpu.wait async -// CHECK: %[[VAL_37:.*]], %[[VAL_38:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_34]]] %[[VAL_8]], %[[VAL_3]], %[[VAL_3]] : index, index into memref<8x8xf64> -// CHECK: %[[VAL_39:.*]], %[[VAL_40:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_38]]] %[[VAL_13]], %[[VAL_3]], %[[VAL_3]] : index, index into memref<8x8xf64> -// CHECK: %[[VAL_41:.*]], %[[VAL_42:.*]] = gpu.create_csr async {{\[}}%[[VAL_40]]] %[[VAL_3]], %[[VAL_3]], %[[VAL_5]], %[[VAL_21]], %[[VAL_26]], %[[VAL_31]] : memref, memref, memref -// CHECK: %[[VAL_43:.*]], %[[VAL_44:.*]] = gpu.sddmm_buffer_size async {{\[}}%[[VAL_42]]] %[[VAL_37]], %[[VAL_39]], %[[VAL_41]] into f64 -// CHECK: %[[VAL_45:.*]], %[[VAL_46:.*]] = gpu.alloc async {{\[}}%[[VAL_44]]] (%[[VAL_43]]) : memref -// CHECK: %[[VAL_47:.*]] = gpu.sddmm async {{\[}}%[[VAL_46]]] %[[VAL_37]], %[[VAL_39]], %[[VAL_41]], %[[VAL_45]] : memref into f64 -// CHECK: %[[VAL_48:.*]] = gpu.destroy_dn_tensor async {{\[}}%[[VAL_47]]] %[[VAL_37]] -// CHECK: %[[VAL_49:.*]] = gpu.destroy_dn_tensor async {{\[}}%[[VAL_48]]] %[[VAL_39]] -// CHECK: %[[VAL_50:.*]] = gpu.destroy_sp_mat async {{\[}}%[[VAL_49]]] %[[VAL_41]] -// CHECK: %[[VAL_52:.*]] = gpu.dealloc async {{\[}}%[[VAL_50]]] %[[VAL_45]] : memref -// CHECK: %[[VAL_53:.*]] = gpu.dealloc async {{\[}}%[[VAL_52]]] %[[VAL_8]] : memref<8x8xf64> -// CHECK: %[[VAL_54:.*]] = gpu.dealloc async {{\[}}%[[VAL_53]]] %[[VAL_13]] : memref<8x8xf64> -// CHECK: %[[VAL_55:.*]] = gpu.dealloc async {{\[}}%[[VAL_54]]] %[[VAL_21]] : memref -// CHECK: %[[VAL_56:.*]] = gpu.dealloc async {{\[}}%[[VAL_55]]] %[[VAL_26]] : memref -// CHECK: %[[VAL_57:.*]] = gpu.memcpy async {{\[}}%[[VAL_56]]] %[[VAL_18]], %[[VAL_31]] : memref, memref -// CHECK: %[[VAL_58:.*]] = gpu.dealloc async {{\[}}%[[VAL_57]]] %[[VAL_31]] : memref -// CHECK: gpu.wait {{\[}}%[[VAL_58]]] -// CHECK: %[[VAL_59:.*]] = sparse_tensor.load %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>> -// CHECK: return %[[VAL_59]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>> +// CHECK: %[[VAL_35:.*]], %[[VAL_36:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_34]]] %[[VAL_12]], %[[VAL_3]], %[[VAL_3]] : index, index into memref<8x8xf64> +// CHECK: %[[VAL_37:.*]], %[[VAL_38:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_36]]] %[[VAL_16]], %[[VAL_3]], %[[VAL_3]] : index, index into memref<8x8xf64> +// CHECK: 
%[[VAL_39:.*]], %[[VAL_40:.*]] = gpu.create_csr async {{\[}}%[[VAL_38]]] %[[VAL_3]], %[[VAL_3]], %[[VAL_5]], %[[VAL_21]], %[[VAL_26]], %[[VAL_31]] : memref, memref, memref +// CHECK: %[[VAL_41:.*]], %[[VAL_42:.*]] = gpu.sddmm_buffer_size async {{\[}}%[[VAL_40]]] %[[VAL_35]], %[[VAL_37]], %[[VAL_39]] into f64 +// CHECK: %[[VAL_43:.*]], %[[VAL_44:.*]] = gpu.alloc async {{\[}}%[[VAL_42]]] (%[[VAL_41]]) : memref +// CHECK: %[[VAL_45:.*]] = gpu.sddmm async {{\[}}%[[VAL_44]]] %[[VAL_35]], %[[VAL_37]], %[[VAL_39]], %[[VAL_43]] : memref into f64 +// CHECK: %[[VAL_46:.*]] = gpu.destroy_dn_tensor async {{\[}}%[[VAL_45]]] %[[VAL_35]] +// CHECK: %[[VAL_47:.*]] = gpu.destroy_dn_tensor async {{\[}}%[[VAL_46]]] %[[VAL_37]] +// CHECK: %[[VAL_48:.*]] = gpu.destroy_sp_mat async {{\[}}%[[VAL_47]]] %[[VAL_39]] +// CHECK: %[[VAL_49:.*]] = gpu.dealloc async {{\[}}%[[VAL_48]]] %[[VAL_43]] : memref +// CHECK: %[[VAL_50:.*]] = gpu.dealloc async {{\[}}%[[VAL_49]]] %[[VAL_12]] : memref<8x8xf64> +// CHECK: %[[VAL_51:.*]] = gpu.dealloc async {{\[}}%[[VAL_50]]] %[[VAL_16]] : memref<8x8xf64> +// CHECK: %[[VAL_52:.*]] = gpu.dealloc async {{\[}}%[[VAL_51]]] %[[VAL_21]] : memref +// CHECK: %[[VAL_53:.*]] = gpu.dealloc async {{\[}}%[[VAL_52]]] %[[VAL_26]] : memref +// CHECK: %[[VAL_54:.*]] = gpu.memcpy async {{\[}}%[[VAL_53]]] %[[VAL_10]], %[[VAL_31]] : memref, memref +// CHECK: %[[VAL_55:.*]] = gpu.dealloc async {{\[}}%[[VAL_54]]] %[[VAL_31]] : memref +// CHECK: gpu.wait {{\[}}%[[VAL_55]]] +// CHECK: %[[VAL_56:.*]] = sparse_tensor.load %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>> +// CHECK: return %[[VAL_56]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>> // CHECK: } // // A kernel that computes a direct sampled matrix matrix multiplication diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib-from-linalg.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib-from-linalg.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib-from-linalg.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib-from-linalg.mlir @@ -1,13 +1,18 @@ // // NOTE: this test requires gpu-sm80 and cusparselt // -// RUN: mlir-opt %s \ -// RUN: --sparse-compiler="enable-runtime-library=true enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71" \ -// RUN: | mlir-cpu-runner \ -// RUN: --shared-libs=%mlir_cuda_runtime \ -// RUN: --shared-libs=%mlir_c_runner_utils \ -// RUN: --e main --entry-point-result=void \ -// RUN: | FileCheck %s +// DEFINE: %{compile} = mlir-opt %s \ +// DEFINE: --sparse-compiler="enable-runtime-library=true enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 +// DEFINE: %{run} = mlir-cpu-runner \ +// DEFINE: --shared-libs=%mlir_cuda_runtime \ +// DEFINE: --shared-libs=%mlir_c_runner_utils \ +// DEFINE: --e main --entry-point-result=void \ +// DEFINE: | FileCheck %s + +// RUN: %{compile}" | %{run} +// RUN: %{compile} gpu-data-transfer-strategy=pinned-dma" | %{run} +// Tracker #64316 +// RUNNOT: %{compile} gpu-data-transfer-strategy=zero-copy" | %{run} #map = affine_map<(d0, d1, d2) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir --- 
a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir @@ -1,14 +1,16 @@ // // NOTE: this test requires gpu-sm80 and cusparselt // -// RUN: mlir-opt --convert-scf-to-cf -convert-cf-to-llvm --convert-vector-to-llvm \ -// RUN: --convert-arith-to-llvm --gpu-to-llvm --reconcile-unrealized-casts \ -// RUN: %s \ -// RUN: | mlir-cpu-runner \ -// RUN: --shared-libs=%mlir_cuda_runtime \ -// RUN: --shared-libs=%mlir_c_runner_utils \ -// RUN: --e main --entry-point-result=void \ -// RUN: | FileCheck %s +// DEFINE: %{compile} = mlir-opt --convert-scf-to-cf -convert-cf-to-llvm --convert-vector-to-llvm \ +// DEFINE: --convert-arith-to-llvm --gpu-to-llvm --reconcile-unrealized-casts \ +// DEFINE: %s +// DEFINE: %{run} = mlir-cpu-runner \ +// DEFINE: --shared-libs=%mlir_cuda_runtime \ +// DEFINE: --shared-libs=%mlir_c_runner_utils \ +// DEFINE: --e main --entry-point-result=void \ +// DEFINE: | FileCheck %s + +// RUN: %{compile} | %{run} module { llvm.func @mgpuCreateSparseLtEnv() diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir @@ -1,25 +1,28 @@ // // NOTE: this test requires gpu-sm80 // +// DEFINE: %{compile} = mlir-opt %s \ +// DEFINE: --sparse-compiler="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 +// DEFINE: %{run} = mlir-cpu-runner \ +// DEFINE: --shared-libs=%mlir_cuda_runtime \ +// DEFINE: --shared-libs=%mlir_c_runner_utils \ +// DEFINE: --e main --entry-point-result=void \ +// DEFINE: | FileCheck %s +// +// // with RT lib (SoA COO): // -// RUN: mlir-opt %s \ -// RUN: --sparse-compiler="enable-runtime-library=true enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71" \ -// RUN: | mlir-cpu-runner \ -// RUN: --shared-libs=%mlir_cuda_runtime \ -// RUN: --shared-libs=%mlir_c_runner_utils \ -// RUN: --e main --entry-point-result=void \ -// RUN: | FileCheck %s +// RUN: %{compile} enable-runtime-library=true" | %{run} +// RUN: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=pinned-dma" | %{run} +// Tracker #64316 +// RUNNOT: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=zero-copy" | %{run} // // without RT lib (AoS COO): note, may fall back to CPU // -// RUN: mlir-opt %s \ -// RUN: --sparse-compiler="enable-runtime-library=false enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71" \ -// RUN: | mlir-cpu-runner \ -// RUN: --shared-libs=%mlir_cuda_runtime \ -// RUN: --shared-libs=%mlir_c_runner_utils \ -// RUN: --e main --entry-point-result=void \ -// RUN: | FileCheck %s +// RUN: %{compile} enable-runtime-library=false" | %{run} +// RUN: %{compile} enable-runtime-library=false gpu-data-transfer-strategy=pinned-dma" | %{run} +// Tracker #64316 +// RUNNOT: %{compile} enable-runtime-library=false gpu-data-transfer-strategy=zero-copy" | %{run} #SortedCOO = #sparse_tensor.encoding<{ lvlTypes = [ "compressed-nu", "singleton" ] diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-lib.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-lib.mlir +++ 
b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-lib.mlir @@ -1,25 +1,28 @@ // // NOTE: this test requires gpu-sm80 // +// DEFINE: %{compile} = mlir-opt %s \ +// DEFINE: --sparse-compiler="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 +// DEFINE: %{run} = mlir-cpu-runner \ +// DEFINE: --shared-libs=%mlir_cuda_runtime \ +// DEFINE: --shared-libs=%mlir_c_runner_utils \ +// DEFINE: --e main --entry-point-result=void \ +// DEFINE: | FileCheck %s +// // with RT lib (SoA COO): // -// RUN: mlir-opt %s \ -// RUN: --sparse-compiler="enable-runtime-library=true enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71" \ -// RUN: | mlir-cpu-runner \ -// RUN: --shared-libs=%mlir_cuda_runtime \ -// RUN: --shared-libs=%mlir_c_runner_utils \ -// RUN: --e main --entry-point-result=void \ -// RUN: | FileCheck %s +// RUN: %{compile} enable-runtime-library=true" | %{run} +// RUN: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=pinned-dma" | %{run} +// Tracker #64316 +// RUNNOT: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=zero-copy" | %{run} // // without RT lib (AoS COO): note, may fall back to CPU // -// RUN: mlir-opt %s \ -// RUN: --sparse-compiler="enable-runtime-library=false enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71" \ -// RUN: | mlir-cpu-runner \ -// RUN: --shared-libs=%mlir_cuda_runtime \ -// RUN: --shared-libs=%mlir_c_runner_utils \ -// RUN: --e main --entry-point-result=void \ -// RUN: | FileCheck %s +// RUN: %{compile} enable-runtime-library=false" | %{run} +// RUN: %{compile} enable-runtime-library=false gpu-data-transfer-strategy=pinned-dma" | %{run} +// Tracker #64316 +// RUNNOT: %{compile} enable-runtime-library=false gpu-data-transfer-strategy=zero-copy" | %{run} +// #SortedCOO = #sparse_tensor.encoding<{ lvlTypes = [ "compressed-nu", "singleton" ] diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir @@ -1,28 +1,29 @@ // // NOTE: this test requires gpu-sm80 // +// DEFINE: %{compile} = mlir-opt %s \ +// DEFINE: --sparse-compiler="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 +// DEFINE: %{run} = TENSOR0="%mlir_src_dir/test/Integration/data/test.mtx" \ +// DEFINE: mlir-cpu-runner \ +// DEFINE: --shared-libs=%mlir_cuda_runtime \ +// DEFINE: --shared-libs=%mlir_c_runner_utils \ +// DEFINE: --e entry --entry-point-result=void \ +// DEFINE: | FileCheck %s +// // with RT lib: // -// RUN: mlir-opt %s \ -// RUN: --sparse-compiler="enable-runtime-library=true enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71" \ -// RUN: | TENSOR0="%mlir_src_dir/test/Integration/data/test.mtx" \ -// RUN: mlir-cpu-runner \ -// RUN: --shared-libs=%mlir_cuda_runtime \ -// RUN: --shared-libs=%mlir_c_runner_utils \ -// RUN: --e entry --entry-point-result=void \ -// RUN: | FileCheck %s +// RUN: %{compile} enable-runtime-library=true" | %{run} +// RUN: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=pinned-dma" | %{run} +// Tracker #64316 +// RUNNOT: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=zero-copy" | %{run} // // without RT lib: // -// RUN: mlir-opt 
%s \ -// RUN: --sparse-compiler="enable-runtime-library=false enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71" \ -// RUN: | TENSOR0="%mlir_src_dir/test/Integration/data/test.mtx" \ -// RUN: mlir-cpu-runner \ -// RUN: --shared-libs=%mlir_cuda_runtime \ -// RUN: --shared-libs=%mlir_c_runner_utils \ -// RUN: --e entry --entry-point-result=void \ -// RUN: | FileCheck %s -// +// RUN: %{compile} enable-runtime-library=false" | %{run} +// RUN: %{compile} enable-runtime-library=false gpu-data-transfer-strategy=pinned-dma" | %{run} +// Tracker #64316 +// RUNNOT: %{compile} enable-runtime-library=false gpu-data-transfer-strategy=zero-copy" | %{run} +// !Filename = !llvm.ptr