diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -130,9 +130,16 @@
   PassOptions::Option<std::string> gpuFeatures{*this, "gpu-features",
                                                desc("GPU target features")};
 
+  /// This option is used to enable GPU library generation.
+  PassOptions::Option<bool> enableGPULibgen{
+      *this, "enable-gpu-libgen",
+      desc("Enables GPU acceleration by means of direct library calls (like "
+           "cuSPARSE)")};
+
   /// Projects out the options for `createSparsificationPass`.
   SparsificationOptions sparsificationOptions() const {
-    return SparsificationOptions(parallelization, enableIndexReduction);
+    return SparsificationOptions(parallelization, enableIndexReduction,
+                                 enableGPULibgen, enableRuntimeLibrary);
   }
 
   /// Projects out the options for `createSparseTensorConversionPass`.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -49,12 +49,17 @@
 
 /// Options for the Sparsification pass.
 struct SparsificationOptions {
-  SparsificationOptions(SparseParallelizationStrategy p, bool idxReduc)
-      : parallelizationStrategy(p), enableIndexReduction(idxReduc) {}
+  SparsificationOptions(SparseParallelizationStrategy p, bool idxReduc,
+                        bool gpuLibgen, bool enableRT)
+      : parallelizationStrategy(p), enableIndexReduction(idxReduc),
+        enableGPULibgen(gpuLibgen), enableRuntimeLibrary(enableRT) {}
   SparsificationOptions()
-      : SparsificationOptions(SparseParallelizationStrategy::kNone, false) {}
+      : SparsificationOptions(SparseParallelizationStrategy::kNone, false,
+                              false, true) {}
   SparseParallelizationStrategy parallelizationStrategy;
   bool enableIndexReduction;
+  bool enableGPULibgen;
+  bool enableRuntimeLibrary;
 };
 
 /// Sets up sparsification rewriting rules with the given options.
@@ -206,6 +211,9 @@
 void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
                                       unsigned numThreads);
 
+void populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
+                                     bool enableRT);
+
 std::unique_ptr<Pass> createSparseGPUCodegenPass();
 std::unique_ptr<Pass> createSparseGPUCodegenPass(unsigned numThreads);
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -73,6 +73,7 @@
     "affine::AffineDialect",
     "arith::ArithDialect",
     "bufferization::BufferizationDialect",
+    "gpu::GPUDialect",
    "LLVM::LLVMDialect",
    "linalg::LinalgDialect",
    "memref::MemRefDialect",
@@ -100,7 +101,12 @@
              "Enable dense parallelization for any loop."),
          clEnumValN(mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop,
                     "any-storage-any-loop",
-                    "Enable sparse parallelization for any storage and loop."))}]>
+                    "Enable sparse parallelization for any storage and loop."))}]>,
+    Option<"enableGPULibgen", "enable-gpu-libgen", "bool",
+           "false",
+           "Enable GPU acceleration by means of direct library calls (like cuSPARSE)">,
+    Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
+           "true", "Enable runtime library for manipulating sparse tensors">,
   ];
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -1,4 +1,4 @@
-//===- SparseGPUCodegen.cpp - Generates GPU code (using CUDA) -------------===//
+//===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -18,9 +18,12 @@
 
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
@@ -140,8 +143,7 @@
   SmallVector<Value> dynamicSizes;
   for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) {
     if (shape[r] == ShapedType::kDynamic) {
-      Value dim = constantIndex(builder, loc, r);
-      Value dimOp = builder.create<memref::DimOp>(loc, mem, dim);
+      Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r);
       dynamicSizes.push_back(dimOp);
     }
   }
@@ -149,6 +151,15 @@
                                       token, dynamicSizes, ValueRange());
 }
 
+// Allocates a void buffer on the device with given size.
+static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size,
+                                   Value token) {
+  const auto memTp =
+      MemRefType::get({ShapedType::kDynamic}, builder.getI8Type());
+  return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
+                                      token, size, ValueRange());
+}
+
 /// Deallocates memory from the device.
 static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem,
                               Value token) {
@@ -163,6 +174,26 @@
       .getAsyncToken();
 }
 
+/// Generates an alloc/copy pair.
+static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
+                          SmallVectorImpl<Value> &tokens) {
+  Value firstToken = genFirstWait(builder, loc);
+  auto alloc = genAllocMemRef(builder, loc, b, firstToken);
+  Value devMem = alloc.getResult(0);
+  Value depToken = alloc.getAsyncToken(); // copy-after-alloc
+  tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken));
+  return devMem;
+}
+
+/// Generates a memref from tensor operation.
+static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
+                               Value tensor) {
+  auto tensorType = tensor.getType().cast<ShapedType>();
+  auto memrefType =
+      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+  return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
+}
+
 /// Prepares the outlined arguments, passing scalars and buffers in. Here we
 /// assume that the first buffer is the one allocated for output. We create
 /// a set of properly chained asynchronous allocation/copy pairs to increase
@@ -186,12 +217,7 @@
       useHostRegistrationForOut = false;
       continue;
     }
-    Value firstToken = genFirstWait(builder, loc);
-    auto alloc = genAllocMemRef(builder, loc, b, firstToken);
-    Value devMem = alloc.getResult(0);
-    Value depToken = alloc.getAsyncToken(); // copy-after-alloc
-    args.push_back(devMem);
-    tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken));
+    args.push_back(genAllocCopy(builder, loc, b, tokens));
   }
   return out;
 }
@@ -272,10 +298,216 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Rewriting rules.
+// Library helper methods.
+//===----------------------------------------------------------------------===//
+
+/// Helper to detect a * b.
+static bool matchMulOfArgs(linalg::GenericOp op, Value val) {
+  if (auto *def = val.getDefiningOp()) {
+    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
+      Value a = op.getBlock()->getArguments()[0];
+      Value b = op.getBlock()->getArguments()[1];
+      return (def->getOperand(0) == a && def->getOperand(1) == b) ||
+             (def->getOperand(0) == b && def->getOperand(1) == a);
+    }
+  }
+  return false;
+}
+
+/// Helper to detect x = x + a * b
+static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
+  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
+  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
+    if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
+      Value x = op.getBlock()->getArguments()[2];
+      return (def->getOperand(0) == x &&
+              matchMulOfArgs(op, def->getOperand(1))) ||
+             (def->getOperand(1) == x &&
+              matchMulOfArgs(op, def->getOperand(0)));
+    }
+  }
+  return false;
+}
+
+/// Test for sorted COO with suitable data and coordinates types.
+static bool isAdmissibleCOO(SparseTensorType &aTp) {
+  return aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
+         aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
+         (aTp.getElementType().isF64() || aTp.getElementType().isF32()) &&
+         (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() == 32 ||
+          aTp.getCrdWidth() == 64);
+}
+
+/// Test for CSR with suitable data and coordinates types.
+static bool isAdmissibleCSR(SparseTensorType &aTp) {
+  return aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
+         aTp.isUniqueLvl(1) &&
+         (aTp.getElementType().isF64() || aTp.getElementType().isF32()) &&
+         (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() == 32 ||
+          aTp.getCrdWidth() == 64);
+}
+
+/// Generates the first positions/coordinates of a sparse matrix.
+static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
+                               bool isCOO, bool enableRT) {
+  if (isCOO) {
+    // Library uses SoA COO, direct IR uses AoS COO.
+    if (enableRT)
+      return genToCoordinates(builder, loc, a, 0, /*cooStart=*/0);
+    return genToCoordinatesBuffer(builder, loc, a);
+  }
+  // CSR uses positions.
+  return genToPositions(builder, loc, a, 1);
+}
+
+/// Generates the second coordinates of a sparse matrix.
+static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
+                           bool isCOO, bool enableRT) {
+  if (isCOO && !enableRT)
+    return Value(); // nothing needed
+  return genToCoordinates(builder, loc, a, 1, /*cooStart=*/0);
+}
+
+/// Generates the sparse matrix handle.
+static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
+                           Type tokenTp, Value token, Value szY, Value szX,
+                           Value nnzA, Value rowA, Value colA, Value valA,
+                           bool isCOO, bool enableRT) {
+  if (isCOO) {
+    // Library uses SoA COO, direct IR uses AoS COO.
+    if (enableRT)
+      return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token,
+                                              szY, szX, nnzA, rowA, colA, valA);
+    llvm_unreachable("gpu::CreateCooAoSOp is deprecated");
+  }
+  return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, szY,
+                                          szX, nnzA, rowA, colA, valA);
+}
+
+/// Match and rewrite SpMV kernel.
+static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
+                                 linalg::GenericOp op, bool enableRT) {
+  Location loc = op.getLoc();
+  Value a = op.getOperand(0);
+  Value x = op.getOperand(1);
+  Value y = op.getOperand(2); // we have y = Ax
+  SmallVector<Value> tokens;
+
+  // Only admissible sparse matrix format and dense vectors for now.
+  bool isCOO = false;
+  SparseTensorType aTp = getSparseTensorType(a);
+  SparseTensorType xTp = getSparseTensorType(x);
+  SparseTensorType yTp = getSparseTensorType(y);
+  if (xTp.hasEncoding() || yTp.hasEncoding())
+    return failure();
+  if (isAdmissibleCOO(aTp)) {
+    isCOO = true;
+    // TODO: CreateCooAoSOp was deprecated, find another way
+    if (!enableRT)
+      return failure();
+  } else if (isAdmissibleCSR(aTp)) {
+    isCOO = false;
+  } else {
+    return failure();
+  }
+
+  // Start sparse kernel and copy data from host to device.
+  //   a : memR/memC/memV -> rowA,colA,valA
+  //   x : memX           -> vecX
+  //   y : memY           -> vecY
+  Value nnzA = rewriter.create<NumberOfEntriesOp>(loc, a);
+  Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
+  Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
+  Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
+  Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
+  Value memV = genToValues(rewriter, loc, a);
+  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
+  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
+  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
+  Value memX = genTensorToMemref(rewriter, loc, x);
+  Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
+  Value memY = genTensorToMemref(rewriter, loc, y);
+  Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
+  genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
+
+  // Create sparse environment and sparse matrix/dense vector handles.
+  Type indexTp = rewriter.getIndexType();
+  Type handleTp = rewriter.getType<gpu::SparseHandleType>();
+  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
+  Value token = genFirstWait(rewriter, loc);
+  auto env =
+      rewriter.create<gpu::CreateSparseEnvOp>(loc, handleTp, tokenTp, token);
+  Value handle = env.getResult(0);
+  token = env.getAsyncToken();
+  Operation *spGenA = genSpMat(rewriter, loc, handleTp, tokenTp, token, szY,
+                               szX, nnzA, rowA, colA, valA, isCOO, enableRT);
+  Value spMatA = spGenA->getResult(0);
+  token = spGenA->getResult(1);
+  auto dvecX = rewriter.create<gpu::CreateDnVecOp>(loc, handleTp, tokenTp,
+                                                   token, vecX, szX);
+  Value dnX = dvecX.getResult(0);
+  token = dvecX.getAsyncToken();
+  auto dvecY = rewriter.create<gpu::CreateDnVecOp>(loc, handleTp, tokenTp,
+                                                   token, vecY, szY);
+  Value dnY = dvecY.getResult(0);
+  token = dvecY.getAsyncToken();
+
+  // Precompute buffersize for SpMV.
+  auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>(
+      loc, indexTp, tokenTp, token, handle, spMatA, dnX, dnY);
+  Value bufferSz = bufferComp.getResult(0);
+  token = bufferComp.getAsyncToken();
+  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
+  Value buffer = buf.getResult(0);
+  token = buf.getAsyncToken();
+
+  // Perform the SpMV.
+  auto spmvComp = rewriter.create<gpu::SpMVOp>(loc, tokenTp, token, handle,
+                                               spMatA, dnX, dnY, buffer);
+  token = spmvComp.getAsyncToken();
+
+  // Copy data back to host and free all the resources.
+  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroyDnVecOp>(loc, tokenTp, token, dnX)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroyDnVecOp>(loc, tokenTp, token, dnY)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroySparseEnvOp>(loc, tokenTp, token, handle)
+              .getAsyncToken();
+  tokens.push_back(token);
+  genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
+  token = genFirstWait(rewriter, loc);
+  token = genCopyMemRef(rewriter, loc, memY, vecY, token);
+  token = genDeallocMemRef(rewriter, loc, rowA, token);
+  if (colA)
+    token = genDeallocMemRef(rewriter, loc, colA, token);
+  token = genDeallocMemRef(rewriter, loc, valA, token);
+  token = genDeallocMemRef(rewriter, loc, buffer, token);
+  token = genDeallocMemRef(rewriter, loc, vecX, token);
+  token = genDeallocMemRef(rewriter, loc, vecY, token);
+  tokens.push_back(token);
+  genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
+
+  // Done.
+  rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
+  return success();
+}
+
+/// Match and rewrite SpMM kernel.
+static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
+                                 linalg::GenericOp op, bool enableRT) {
+  return failure(); // TODO: implement
+}
+
+//===----------------------------------------------------------------------===//
+// Rewriting rules for direct code generation.
 //===----------------------------------------------------------------------===//
 
-/// Proof-of-concept rewriter. This rule generates a CUDA implementation
+/// Proof-of-concept rewriter. This rule generates a GPU implementation
 /// for each outermost forall loop generated by the sparse compiler.
 /// TODO: right works with parallelization-strategy=dense-outer-loop
 /// but give this its own flags in the future
@@ -373,13 +605,77 @@
   unsigned numThreads;
 };
 
+//===----------------------------------------------------------------------===//
+// Rewriting rules for library recognition and code generation.
+//===----------------------------------------------------------------------===//
+
+/// Proof-of-concept rewriter. This rule recognizes certain math kernels
+/// and replaces these with corresponding calls into the sparse library.
+struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
+  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
+
+  LinalgOpRewriter(MLIRContext *context, bool rt)
+      : OpRewritePattern(context), enableRT(rt) {}
+
+  LogicalResult matchAndRewrite(linalg::GenericOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getNumDpsInits() != 1)
+      return failure(); // reject multi-output
+
+    const unsigned numLoops = op.getNumLoops();
+    const unsigned numTensors = op->getNumOperands();
+    const auto iteratorTypes = op.getIteratorTypesArray();
+    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
+
+    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+    AffineExpr i, j, k;
+    bindDims(getContext(), i, j, k);
+
+    // TODO: more robust patterns, transposed versions, more kernels...
+
+    // Recognize a SpMV kernel.
+    if (numLoops == 2 && numTensors == 3 &&
+        linalg::isParallelIterator(iteratorTypes[0]) &&
+        linalg::isReductionIterator(iteratorTypes[1]) &&
+        maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
+      return rewriteSpMV(rewriter, op, enableRT);
+    }
+
+    // Recognize a SpMM kernel.
+    if (numLoops == 3 && numTensors == 3 &&
+        linalg::isParallelIterator(iteratorTypes[0]) &&
+        linalg::isParallelIterator(iteratorTypes[1]) &&
+        linalg::isReductionIterator(iteratorTypes[2]) &&
+        maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
+      return rewriteSpMM(rewriter, op, enableRT);
+    }
+
+    return failure();
+  }
+
+private:
+  bool enableRT;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
 // Public method for populating GPU rewriting rules.
+//
+// Currently two sets of rewriting rules are made available. The first set
+// implements direct code generation, currently by means of converting the
+// outermost parallel loop into GPU threads. The second set implements
+// library recognition of a set of sparse operations. Eventually, the right
+// combination of these two approaches has to be found.
 //===----------------------------------------------------------------------===//
 
 void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
                                             unsigned numThreads) {
   patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
 }
+
+void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
+                                           bool enableRT) {
+  patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -66,14 +66,20 @@
   SparsificationPass(const SparsificationOptions &options) {
     parallelization = options.parallelizationStrategy;
     enableIndexReduction = options.enableIndexReduction;
+    enableGPULibgen = options.enableGPULibgen;
+    enableRuntimeLibrary = options.enableRuntimeLibrary;
   }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     // Translate strategy flags to strategy options.
-    SparsificationOptions options(parallelization, enableIndexReduction);
-    // Apply sparsification and cleanup rewriting.
+    SparsificationOptions options(parallelization, enableIndexReduction,
+                                  enableGPULibgen, enableRuntimeLibrary);
+    // Apply GPU libgen (if requested), sparsification, and cleanup rewriting.
     RewritePatternSet patterns(ctx);
+    if (enableGPULibgen) {
+      populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary);
+    }
     populateSparsificationPatterns(patterns, options);
     scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -98,6 +99,7 @@
   void getDependentDialects(::mlir::DialectRegistry &registry) const override {
     registry.insert<bufferization::BufferizationDialect>();
+    registry.insert<gpu::GPUDialect>();
     registry.insert<LLVM::LLVMDialect>();
   }
 
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec_lib.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec_lib.mlir
@@ -0,0 +1,78 @@
+// RUN: mlir-opt %s --linalg-generalize-named-ops \
+// RUN:             --sparsification="enable-gpu-libgen" | FileCheck %s
+
+#SortedCOO = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed-nu", "singleton" ]
+}>
+
+module {
+
+// CHECK-LABEL:   func.func @matvec(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<?xf64>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<?xf64>) -> tensor<?xf64> {
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.number_of_entries %[[VAL_0]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>>
+// CHECK-DAG:       %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>>
+// CHECK-DAG:       %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>>
+// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex, strided<[?], offset: ?>>
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xindex, strided<[?], offset: ?>>
+// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ] }>> to memref<?xf64>
+// CHECK:           %[[VAL_11:.*]] = gpu.wait async
+// CHECK:           %[[VAL_12:.*]] = memref.dim %[[VAL_8]], %[[VAL_3]] : memref<?xindex, strided<[?], offset: ?>>
+// CHECK:           %[[VAL_13:.*]], %[[VAL_14:.*]] = gpu.alloc async {{\[}}%[[VAL_11]]] (%[[VAL_12]]) : memref<?xindex>
+// CHECK:           %[[VAL_15:.*]] = gpu.memcpy async {{\[}}%[[VAL_14]]] %[[VAL_13]], %[[VAL_8]] : memref<?xindex>, memref<?xindex, strided<[?], offset: ?>>
+// CHECK:           %[[VAL_16:.*]] = gpu.wait async
+// CHECK:           %[[VAL_17:.*]] = memref.dim %[[VAL_9]], %[[VAL_3]] : memref<?xindex, strided<[?], offset: ?>>
+// CHECK:           %[[VAL_18:.*]], %[[VAL_19:.*]] = gpu.alloc async {{\[}}%[[VAL_16]]] (%[[VAL_17]]) : memref<?xindex>
+// CHECK:           %[[VAL_20:.*]] = gpu.memcpy async {{\[}}%[[VAL_19]]] %[[VAL_18]], %[[VAL_9]] : memref<?xindex>, memref<?xindex, strided<[?], offset: ?>>
+// CHECK:           %[[VAL_21:.*]] = gpu.wait async
+// CHECK:           %[[VAL_22:.*]] = memref.dim %[[VAL_10]], %[[VAL_3]] : memref<?xf64>
+// CHECK:           %[[VAL_23:.*]], %[[VAL_24:.*]] = gpu.alloc async {{\[}}%[[VAL_21]]] (%[[VAL_22]]) : memref<?xf64>
+// CHECK:           %[[VAL_25:.*]] = gpu.memcpy async {{\[}}%[[VAL_24]]] %[[VAL_23]], %[[VAL_10]] : memref<?xf64>, memref<?xf64>
+// CHECK:           %[[VAL_26:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?xf64>
+// CHECK:           %[[VAL_27:.*]] = gpu.wait async
+// CHECK:           %[[VAL_28:.*]] = memref.dim %[[VAL_26]], %[[VAL_3]] : memref<?xf64>
+// CHECK:           %[[VAL_29:.*]], %[[VAL_30:.*]] = gpu.alloc async {{\[}}%[[VAL_27]]] (%[[VAL_28]]) : memref<?xf64>
+// CHECK:           %[[VAL_31:.*]] = gpu.memcpy async {{\[}}%[[VAL_30]]] %[[VAL_29]], %[[VAL_26]] : memref<?xf64>, memref<?xf64>
+// CHECK:           %[[VAL_32:.*]] = bufferization.to_memref %[[VAL_2]] : memref<?xf64>
+// CHECK:           %[[VAL_33:.*]] = gpu.wait async
+// CHECK:           %[[VAL_34:.*]] = memref.dim %[[VAL_32]], %[[VAL_3]] : memref<?xf64>
+// CHECK:           %[[VAL_35:.*]], %[[VAL_36:.*]] = gpu.alloc async {{\[}}%[[VAL_33]]] (%[[VAL_34]]) : memref<?xf64>
+// CHECK:           %[[VAL_37:.*]] = gpu.memcpy async {{\[}}%[[VAL_36]]] %[[VAL_35]], %[[VAL_32]] : memref<?xf64>, memref<?xf64>
+// CHECK:           gpu.wait {{\[}}%[[VAL_15]], %[[VAL_20]], %[[VAL_25]], %[[VAL_31]], %[[VAL_37]]]
+// CHECK:           %[[VAL_38:.*]] = gpu.wait async
+// CHECK:           %[[VAL_39:.*]], %[[VAL_40:.*]] = gpu.create_sparse_env async {{\[}}%[[VAL_38]]]
+// CHECK:           %[[VAL_41:.*]], %[[VAL_42:.*]] = gpu.create_coo async {{\[}}%[[VAL_40]]] %[[VAL_6]], %[[VAL_7]], %[[VAL_5]], %[[VAL_13]], %[[VAL_18]], %[[VAL_23]] : memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK:           %[[VAL_43:.*]], %[[VAL_44:.*]] = gpu.create_dn_vec async {{\[}}%[[VAL_42]]] %[[VAL_29]], %[[VAL_7]] : memref<?xf64>
+// CHECK:           %[[VAL_45:.*]], %[[VAL_46:.*]] = gpu.create_dn_vec async {{\[}}%[[VAL_44]]] %[[VAL_35]], %[[VAL_6]] : memref<?xf64>
+// CHECK:           %[[VAL_47:.*]], %[[VAL_48:.*]] = gpu.spmv_buffer_size async {{\[}}%[[VAL_46]]] %[[VAL_39]], %[[VAL_41]], %[[VAL_43]], %[[VAL_45]]
+// CHECK:           %[[VAL_49:.*]], %[[VAL_50:.*]] = gpu.alloc async {{\[}}%[[VAL_48]]] (%[[VAL_47]]) : memref<?xi8>
+// CHECK:           %[[VAL_51:.*]] = gpu.spmv async {{\[}}%[[VAL_50]]] %[[VAL_39]], %[[VAL_41]], %[[VAL_43]], %[[VAL_45]], %[[VAL_49]] : memref<?xi8>
+// CHECK:           %[[VAL_52:.*]] = gpu.destroy_sp_mat async {{\[}}%[[VAL_51]]] %[[VAL_41]]
+// CHECK:           %[[VAL_53:.*]] = gpu.destroy_dn_vec async {{\[}}%[[VAL_52]]] %[[VAL_43]]
+// CHECK:           %[[VAL_54:.*]] = gpu.destroy_dn_vec async {{\[}}%[[VAL_53]]] %[[VAL_45]]
+// CHECK:           %[[VAL_55:.*]] = gpu.destroy_sparse_env async {{\[}}%[[VAL_54]]] %[[VAL_39]]
+// CHECK:           gpu.wait {{\[}}%[[VAL_55]]]
+// CHECK:           %[[VAL_56:.*]] = gpu.wait async
+// CHECK:           %[[VAL_57:.*]] = gpu.memcpy async {{\[}}%[[VAL_56]]] %[[VAL_32]], %[[VAL_35]] : memref<?xf64>, memref<?xf64>
+// CHECK:           %[[VAL_58:.*]] = gpu.dealloc async {{\[}}%[[VAL_57]]] %[[VAL_13]] : memref<?xindex>
+// CHECK:           %[[VAL_59:.*]] = gpu.dealloc async {{\[}}%[[VAL_58]]] %[[VAL_18]] : memref<?xindex>
+// CHECK:           %[[VAL_60:.*]] = gpu.dealloc async {{\[}}%[[VAL_59]]] %[[VAL_23]] : memref<?xf64>
+// CHECK:           %[[VAL_61:.*]] = gpu.dealloc async {{\[}}%[[VAL_60]]] %[[VAL_49]] : memref<?xi8>
+// CHECK:           %[[VAL_62:.*]] = gpu.dealloc async {{\[}}%[[VAL_61]]] %[[VAL_29]] : memref<?xf64>
+// CHECK:           %[[VAL_63:.*]] = gpu.dealloc async {{\[}}%[[VAL_62]]] %[[VAL_35]] : memref<?xf64>
+// CHECK:           gpu.wait {{\[}}%[[VAL_63]]]
+// CHECK:           return %[[VAL_2]] : tensor<?xf64>
+// CHECK:         }
+func.func @matvec(%A: tensor<?x?xf64, #SortedCOO>,
+                  %x: tensor<?xf64>,
+                  %y_in: tensor<?xf64>) -> tensor<?xf64> {
+  %y_out = linalg.matvec
+    ins(%A, %x: tensor<?x?xf64, #SortedCOO>, tensor<?xf64>)
+    outs(%y_in: tensor<?xf64>) -> tensor<?xf64>
+  return %y_out : tensor<?xf64>
+}
+
+}
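
Editor's note on the recognized kernel shape: after --linalg-generalize-named-ops, the linalg.matvec in the test above becomes a linalg.generic, which is the form LinalgOpRewriter matches (indexing maps (i,j)/(j)/(i), one parallel and one reduction iterator, and a y = y + a * b body, per matchSumOfMultOfArgs). The sketch below is illustrative only; it reuses %A, %x, %y_in, and #SortedCOO from the test and is not copied from the patch, so the exact printed attribute syntax may differ slightly from what mlir-opt emits.

  #map_A = affine_map<(i, j) -> (i, j)>
  #map_x = affine_map<(i, j) -> (j)>
  #map_y = affine_map<(i, j) -> (i)>

  // Generalized matvec form recognized as SpMV (names are illustrative).
  %y_out = linalg.generic {
             indexing_maps = [#map_A, #map_x, #map_y],
             iterator_types = ["parallel", "reduction"] }
           ins(%A, %x : tensor<?x?xf64, #SortedCOO>, tensor<?xf64>)
           outs(%y_in : tensor<?xf64>) {
           ^bb0(%a: f64, %b: f64, %acc: f64):
             // Body matched by matchSumOfMultOfArgs: acc = acc + a * b.
             %p = arith.mulf %a, %b : f64
             %s = arith.addf %acc, %p : f64
             linalg.yield %s : f64
           } -> tensor<?xf64>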
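
Beyond this unit test, the new enable-gpu-libgen flag is also projected into the end-to-end pipeline options (first hunk). A hypothetical invocation could look as follows; the pipeline name and the gpu-triple/gpu-chip spellings are assumptions based on the surrounding options (only gpu-features is visible in the hunk above), so treat this as a sketch rather than a verified command line.

  mlir-opt matvec.mlir --linalg-generalize-named-ops \
    --sparse-compiler="enable-runtime-library=true enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71"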