diff --git a/mlir/integration_test/Sparse/CPU/sparse_sampled_matmul.mlir b/mlir/integration_test/Sparse/CPU/sparse_sampled_matmul.mlir new file mode 100644 --- /dev/null +++ b/mlir/integration_test/Sparse/CPU/sparse_sampled_matmul.mlir @@ -0,0 +1,142 @@ +// RUN: mlir-opt %s \ +// RUN: --test-sparsification="lower ptr-type=2 ind-type=2 fast-output" \ +// RUN: --convert-linalg-to-loops \ +// RUN: --func-bufferize --tensor-constant-bufferize --tensor-bufferize \ +// RUN: --std-bufferize --finalizing-bufferize \ +// RUN: --convert-scf-to-std --convert-vector-to-llvm --convert-std-to-llvm | \ +// RUN: TENSOR0="%mlir_integration_test_dir/data/test.mtx" \ +// RUN: mlir-cpu-runner \ +// RUN: -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +// +// Use descriptive names for opaque pointers. +// +!Filename = type !llvm.ptr +!SparseTensor = type !llvm.ptr + +#trait_sampled_dense_dense = { + indexing_maps = [ + affine_map<(i,j,k) -> (i,j)>, // S + affine_map<(i,j,k) -> (i,k)>, // A + affine_map<(i,j,k) -> (k,j)>, // B + affine_map<(i,j,k) -> (i,j)> // X (out) + ], + sparse = [ + [ "S", "S" ], // S + [ "D", "D" ], // A + [ "D", "D" ], // B + [ "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel", "reduction"], + doc = "X(i,j) += S(i,j) SUM_k A(i,k) B(k,j)" +} + +// +// Integration test that lowers a kernel annotated as sparse to +// actual sparse code, initializes a matching sparse storage scheme +// from file, and runs the resulting code with the JIT compiler. +// +module { + // + // The kernel expressed as an annotated Linalg op. The kernel + // computes a sampled matrix matrix multiplication. 
+ // + func @sampled_dense_dense(%argS: !SparseTensor, + %arga: tensor, + %argb: tensor, + %argx: tensor) -> tensor { + %args = linalg.sparse_tensor %argS : !SparseTensor to tensor + %0 = linalg.generic #trait_sampled_dense_dense + ins(%args, %arga, %argb: tensor, tensor, tensor) + outs(%argx: tensor) { + ^bb(%s: f32, %a: f32, %b: f32, %x: f32): + %0 = mulf %a, %b : f32 + %1 = mulf %s, %0 : f32 + %2 = addf %x, %1 : f32 + linalg.yield %2 : f32 + } -> tensor + return %0 : tensor + } + + // + // Runtime support library that is called directly from here. + // + func private @getTensorFilename(index) -> (!Filename) + func private @newSparseTensor(!Filename, memref, index, index, index) -> (!SparseTensor) + func private @delSparseTensor(!SparseTensor) -> () + func private @print_memref_f32(%ptr : tensor<*xf32>) + + // + // Main driver that reads matrix from file and calls the sparse kernel. + // + func @entry() { + %d0 = constant 0.0 : f32 + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c5 = constant 5 : index + %c10 = constant 10 : index + + // Mark both dimensions of the matrix as sparse and encode the + // storage scheme types (this must match the metadata in the + // trait and compiler switches). + %annotations = alloc(%c2) : memref + %sparse = constant true + store %sparse, %annotations[%c0] : memref + store %sparse, %annotations[%c1] : memref + %i32 = constant 3 : index + %f32 = constant 1 : index + + // Setup memory for the dense matrices and initialize. 
+ %adata = alloc(%c5, %c10) : memref + %bdata = alloc(%c10, %c5) : memref + %xdata = alloc(%c5, %c5) : memref + scf.for %i = %c0 to %c5 step %c1 { + scf.for %j = %c0 to %c5 step %c1 { + store %d0, %xdata[%i, %j] : memref + } + %p = addi %i, %c1 : index + %q = index_cast %p : index to i32 + %d = sitofp %q : i32 to f32 + scf.for %j = %c0 to %c10 step %c1 { + store %d, %adata[%i, %j] : memref + store %d, %bdata[%j, %i] : memref + } + } + %a = tensor_load %adata : memref + %b = tensor_load %bdata : memref + %x = tensor_load %xdata : memref + + // Read the sparse matrix from file, construct sparse storage + // according to in memory, and call the kernel. + %fileName = call @getTensorFilename(%c0) : (index) -> (!Filename) + %s = call @newSparseTensor(%fileName, %annotations, %i32, %i32, %f32) + : (!Filename, memref, index, index, index) -> (!SparseTensor) + %0 = call @sampled_dense_dense(%s, %a, %b, %x) + : (!SparseTensor, tensor, tensor, tensor) -> tensor + + // Print the result for verification. + // + // CHECK: ( 10, 0, 0, 56, 0 ) + // CHECK: ( 0, 80, 0, 0, 250 ) + // CHECK: ( 0, 0, 270, 0, 0 ) + // CHECK: ( 164, 0, 0, 640, 0 ) + // CHECK: ( 0, 520, 0, 0, 1250 ) + // + %r = tensor_to_memref %0 : memref + scf.for %i = %c0 to %c5 step %c1 { + %v = vector.transfer_read %r[%i, %c0], %d0: memref, vector<5xf32> + vector.print %v : vector<5xf32> + } + + // Release the resources. + call @delSparseTensor(%s) : (!SparseTensor) -> () + dealloc %adata : memref + dealloc %bdata : memref + dealloc %xdata : memref + + return + } +} diff --git a/mlir/integration_test/Sparse/CPU/sparse_sum.mlir b/mlir/integration_test/Sparse/CPU/sparse_sum.mlir --- a/mlir/integration_test/Sparse/CPU/sparse_sum.mlir +++ b/mlir/integration_test/Sparse/CPU/sparse_sum.mlir @@ -55,7 +55,7 @@ // Runtime support library that is called directly from here. 
// func private @getTensorFilename(index) -> (!Filename) - func private @newSparseTensor(!Filename, memref) -> (!SparseTensor) + func private @newSparseTensor(!Filename, memref, index, index, index) -> (!SparseTensor) func private @delSparseTensor(!SparseTensor) -> () func private @print_memref_f64(%ptr : tensor<*xf64>) @@ -68,12 +68,15 @@ %c1 = constant 1 : index %c2 = constant 2 : index - // Mark both dimensions of the matrix as sparse - // (this must match the annotation in the trait). + // Mark both dimensions of the matrix as sparse and encode the + // storage scheme types (this must match the metadata in the + // trait and compiler switches). %annotations = alloc(%c2) : memref %sparse = constant true store %sparse, %annotations[%c0] : memref store %sparse, %annotations[%c1] : memref + %i64 = constant 2 : index + %f64 = constant 0 : index // Setup memory for a single reduction scalar, // initialized to zero. @@ -84,8 +87,8 @@ // Read the sparse matrix from file, construct sparse storage // according to in memory, and call the kernel. 
%fileName = call @getTensorFilename(%c0) : (index) -> (!Filename) - %a = call @newSparseTensor(%fileName, %annotations) - : (!Filename, memref) -> (!SparseTensor) + %a = call @newSparseTensor(%fileName, %annotations, %i64, %i64, %f64) + : (!Filename, memref, index, index, index) -> (!SparseTensor) %0 = call @kernel_sum_reduce(%a, %x) : (!SparseTensor, tensor) -> tensor diff --git a/mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp b/mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp @@ -73,7 +73,9 @@ Type eltType = resType.cast().getElementType(); StringRef name; if (eltType.isIndex() || eltType.isInteger(64)) - name = "sparsePtrsI64"; + name = "sparsePointers64"; + else if (eltType.isInteger(32)) + name = "sparsePointers32"; else return failure(); rewriter.replaceOpWithNewOp( @@ -95,7 +97,9 @@ Type eltType = resType.cast().getElementType(); StringRef name; if (eltType.isIndex() || eltType.isInteger(64)) - name = "sparseIndxsI64"; + name = "sparseIndices64"; + else if (eltType.isInteger(32)) + name = "sparseIndices32"; else return failure(); rewriter.replaceOpWithNewOp( @@ -117,7 +121,9 @@ Type eltType = resType.cast().getElementType(); StringRef name; if (eltType.isF64()) - name = "sparseValsF64"; + name = "sparseValuesF64"; + else if (eltType.isF32()) + name = "sparseValuesF32"; else return failure(); rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp --- a/mlir/lib/ExecutionEngine/SparseUtils.cpp +++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp @@ -29,8 +29,17 @@ //===----------------------------------------------------------------------===// // -// Internal support for reading sparse tensors in one of the following -// external file formats: +// Internal support for storing and reading sparse tensors. 
+// +// The following memory-resident sparse storage schemes are supported: +// +// (a) A coordinate scheme for temporarily storing and lexicographically +// sorting a sparse tensor by index. +// +// (b) A "one-size-fits-all" sparse storage scheme defined by per-rank +// sparse/dense annotations to be used by generated MLIR code. +// +// The following external formats are supported: // // (1) Matrix Market Exchange (MME): *.mtx // https://math.nist.gov/MatrixMarket/formats.html @@ -65,20 +74,24 @@ : sizes(szs), pos(0) { elements.reserve(capacity); } - // Add element as indices and value. + /// Adds element as indices and value. void add(const std::vector &ind, double val) { assert(sizes.size() == ind.size()); for (int64_t r = 0, rank = sizes.size(); r < rank; r++) assert(ind[r] < sizes[r]); // within bounds elements.emplace_back(Element(ind, val)); } - // Sort elements lexicographically by index. + /// Sorts elements lexicographically by index. void sort() { std::sort(elements.begin(), elements.end(), lexOrder); } - // Primitive one-time iteration. + /// Primitive one-time iteration. const Element &next() { return elements[pos++]; } + /// Getter for sizes array. + const std::vector &getSizes() const { return sizes; } + /// Getter for elements array. + const std::vector &getElements() const { return elements; } private: - // Returns true if indices of e1 < indices of e2. + /// Returns true if indices of e1 < indices of e2. static bool lexOrder(const Element &e1, const Element &e2) { assert(e1.indices.size() == e2.indices.size()); for (int64_t r = 0, rank = e1.indices.size(); r < rank; r++) { @@ -88,13 +101,31 @@ } return false; } - -public: std::vector sizes; // per-rank dimension sizes std::vector elements; uint64_t pos; }; +/// Abstract base class of sparse tensor storage. Note that we use +/// function overloading to implement "partial" method specialization.
+class SparseTensorStorageBase { +public: + virtual uint64_t getDimSize(uint64_t) = 0; + virtual void getPointers(std::vector **, uint64_t) { fatal("p64"); } + virtual void getPointers(std::vector **, uint64_t) { fatal("p32"); } + virtual void getIndices(std::vector **, uint64_t) { fatal("i64"); } + virtual void getIndices(std::vector **, uint64_t) { fatal("i32"); } + virtual void getValues(std::vector **) { fatal("valf64"); } + virtual void getValues(std::vector **) { fatal("valf32"); } + virtual ~SparseTensorStorageBase() {} + +private: + void fatal(const char *tp) { + fprintf(stderr, "unsupported %s\n", tp); + exit(1); + } +}; + /// A memory-resident sparse tensor using a storage scheme based on per-rank /// annotations on dense/sparse. This data structure provides a bufferized /// form of an imaginary SparseTensorType, until such a type becomes a @@ -103,26 +134,38 @@ /// "one-size-fits-all" solution that simply takes an input tensor and /// annotations to implement all required setup in a general manner. template -class SparseTensorStorage { +class SparseTensorStorage : public SparseTensorStorageBase { public: /// Constructs sparse tensor storage scheme following the given /// per-rank dimension dense/sparse annotations. SparseTensorStorage(SparseTensor *tensor, bool *sparsity) - : sizes(tensor->sizes), positions(sizes.size()), indices(sizes.size()) { + : sizes(tensor->getSizes()), pointers(sizes.size()), + indices(sizes.size()) { // Provide hints on capacity. // TODO: needs fine-tuning based on sparsity - values.reserve(tensor->elements.size()); + values.reserve(tensor->getElements().size()); for (uint64_t d = 0, s = 1, rank = sizes.size(); d < rank; d++) { - s *= tensor->sizes[d]; + s *= tensor->getSizes()[d]; if (sparsity[d]) { - positions[d].reserve(s + 1); + pointers[d].reserve(s + 1); indices[d].reserve(s); s = 1; } } // Then setup the tensor. 
- traverse(tensor, sparsity, 0, tensor->elements.size(), 0); + traverse(tensor, sparsity, 0, tensor->getElements().size(), 0); + } + + virtual ~SparseTensorStorage() {} + + uint64_t getDimSize(uint64_t d) override { return sizes[d]; } + void getPointers(std::vector

**out, uint64_t d) override { + *out = &pointers[d]; + } + void getIndices(std::vector **out, uint64_t d) override { + *out = &indices[d]; } + void getValues(std::vector **out) override { *out = &values; } private: /// Initializes sparse tensor storage scheme from a memory-resident @@ -131,15 +174,15 @@ /// dense/sparse annotations. void traverse(SparseTensor *tensor, bool *sparsity, uint64_t lo, uint64_t hi, uint64_t d) { - const std::vector &elements = tensor->elements; + const std::vector &elements = tensor->getElements(); // Once dimensions are exhausted, insert the numerical values. if (d == sizes.size()) { values.push_back(lo < hi ? elements[lo].value : 0.0); return; } // Prepare a sparse pointer structure at this dimension. - if (sparsity[d] && positions[d].empty()) - positions[d].push_back(0); + if (sparsity[d] && pointers[d].empty()) + pointers[d].push_back(0); // Visit all elements in this interval. uint64_t full = 0; while (lo < hi) { @@ -162,22 +205,30 @@ } // Finalize the sparse pointer structure at this dimension. if (sparsity[d]) { - positions[d].push_back(indices[d].size()); + pointers[d].push_back(indices[d].size()); } else { - for (uint64_t sz = tensor->sizes[d]; full < sz; full++) + for (uint64_t sz = tensor->getSizes()[d]; full < sz; full++) traverse(tensor, sparsity, 0, 0, d + 1); // pass empty } } -public: +private: std::vector sizes; // per-rank dimension sizes - std::vector> positions; + std::vector> pointers; std::vector> indices; std::vector values; }; -typedef SparseTensorStorage - SparseTensorStorageU64U64F64; +/// Templated reader. +template +void *newSparseTensor(char *filename, bool *sparsity) { + uint64_t idata[64]; + SparseTensor *t = static_cast(openTensorC(filename, idata)); + SparseTensorStorageBase *tensor = + new SparseTensorStorage(t, sparsity); + delete t; + return tensor; +} /// Helper to convert string to lower case. 
static char *toLower(char *token) { @@ -292,24 +343,6 @@ extern "C" { -/// Cannot use templates with C linkage. - -struct MemRef1DU64 { - const uint64_t *base; - const uint64_t *data; - uint64_t off; - uint64_t sizes[1]; - uint64_t strides[1]; -}; - -struct MemRef1DF64 { - const double *base; - const double *data; - uint64_t off; - uint64_t sizes[1]; - uint64_t strides[1]; -}; - /// Reads in a sparse tensor with the given filename. The call yields a /// pointer to an opaque memory-resident sparse tensor object that is only /// understood by other methods in the sparse runtime support library. An @@ -398,51 +431,117 @@ return env; } -/// -/// Sparse primitives that support an opaque implementation of a bufferized -/// SparseTensor in MLIR. This could be replaced by actual codegen in MLIR. -/// +//===----------------------------------------------------------------------===// +// +// Public API of the sparse runtime support library that supports an opaque +// implementation of a bufferized SparseTensor in MLIR. This could be replaced +// by actual codegen in MLIR. +// +//===----------------------------------------------------------------------===// -void *newSparseTensorC(char *filename, bool *annotations) { - uint64_t idata[64]; - SparseTensor *t = static_cast(openTensorC(filename, idata)); - SparseTensorStorageU64U64F64 *tensor = - new SparseTensorStorageU64U64F64(t, annotations); - delete t; - return tensor; -} +// Cannot use templates with C linkage.
+ +struct MemRef1DU64 { + const uint64_t *base; + const uint64_t *data; + uint64_t off; + uint64_t sizes[1]; + uint64_t strides[1]; +}; + +struct MemRef1DU32 { + const uint32_t *base; + const uint32_t *data; + uint64_t off; + uint64_t sizes[1]; + uint64_t strides[1]; +}; + +struct MemRef1DF64 { + const double *base; + const double *data; + uint64_t off; + uint64_t sizes[1]; + uint64_t strides[1]; +}; + +struct MemRef1DF32 { + const float *base; + const float *data; + uint64_t off; + uint64_t sizes[1]; + uint64_t strides[1]; +}; + +enum TypeEnum : uint64_t { kF64 = 0, kF32 = 1, kU64 = 2, kU32 = 3 }; -/// "MLIRized" version. void *newSparseTensor(char *filename, bool *abase, bool *adata, uint64_t aoff, - uint64_t asize, uint64_t astride) { + uint64_t asize, uint64_t astride, uint64_t ptrTp, + uint64_t indTp, uint64_t valTp) { assert(astride == 1); - return newSparseTensorC(filename, abase + aoff); + bool *sparsity = abase + aoff; + if (ptrTp == kU64 && indTp == kU64 && valTp == kF64) + return newSparseTensor(filename, sparsity); + if (ptrTp == kU64 && indTp == kU64 && valTp == kF32) + return newSparseTensor(filename, sparsity); + if (ptrTp == kU64 && indTp == kU32 && valTp == kF64) + return newSparseTensor(filename, sparsity); + if (ptrTp == kU64 && indTp == kU32 && valTp == kF32) + return newSparseTensor(filename, sparsity); + if (ptrTp == kU32 && indTp == kU64 && valTp == kF64) + return newSparseTensor(filename, sparsity); + if (ptrTp == kU32 && indTp == kU64 && valTp == kF32) + return newSparseTensor(filename, sparsity); + if (ptrTp == kU32 && indTp == kU32 && valTp == kF64) + return newSparseTensor(filename, sparsity); + if (ptrTp == kU32 && indTp == kU32 && valTp == kF32) + return newSparseTensor(filename, sparsity); + fputs("unsupported combination of types\n", stderr); + exit(1); } uint64_t sparseDimSize(void *tensor, uint64_t d) { - return static_cast(tensor)->sizes[d]; + return static_cast(tensor)->getDimSize(d); +} + +MemRef1DU64 sparsePointers64(void 
*tensor, uint64_t d) { + std::vector *v; + static_cast(tensor)->getPointers(&v, d); + return {v->data(), v->data(), 0, {v->size()}, {1}}; +} + +MemRef1DU32 sparsePointers32(void *tensor, uint64_t d) { + std::vector *v; + static_cast(tensor)->getPointers(&v, d); + return {v->data(), v->data(), 0, {v->size()}, {1}}; +} + +MemRef1DU64 sparseIndices64(void *tensor, uint64_t d) { + std::vector *v; + static_cast(tensor)->getIndices(&v, d); + return {v->data(), v->data(), 0, {v->size()}, {1}}; } -MemRef1DU64 sparsePtrsI64(void *tensor, uint64_t d) { - const std::vector &v = - static_cast(tensor)->positions[d]; - return {v.data(), v.data(), 0, {v.size()}, {1}}; +MemRef1DU32 sparseIndices32(void *tensor, uint64_t d) { + std::vector *v; + static_cast(tensor)->getIndices(&v, d); + return {v->data(), v->data(), 0, {v->size()}, {1}}; } -MemRef1DU64 sparseIndxsI64(void *tensor, uint64_t d) { - const std::vector &v = - static_cast(tensor)->indices[d]; - return {v.data(), v.data(), 0, {v.size()}, {1}}; +MemRef1DF64 sparseValuesF64(void *tensor) { + std::vector *v; + static_cast(tensor)->getValues(&v); + return {v->data(), v->data(), 0, {v->size()}, {1}}; } -MemRef1DF64 sparseValsF64(void *tensor) { - const std::vector &v = - static_cast(tensor)->values; - return {v.data(), v.data(), 0, {v.size()}, {1}}; +MemRef1DF32 sparseValuesF32(void *tensor) { + std::vector *v; + static_cast(tensor)->getValues(&v); + return {v->data(), v->data(), 0, {v->size()}, {1}}; } void delSparseTensor(void *tensor) { - delete static_cast(tensor); + delete static_cast(tensor); } } // extern "C" diff --git a/mlir/test/Dialect/Linalg/sparse_lower.mlir b/mlir/test/Dialect/Linalg/sparse_lower.mlir --- a/mlir/test/Dialect/Linalg/sparse_lower.mlir +++ b/mlir/test/Dialect/Linalg/sparse_lower.mlir @@ -75,9 +75,9 @@ // CHECK-MIR: %[[VAL_3:.*]] = constant 64 : index // CHECK-MIR: %[[VAL_4:.*]] = constant 0 : index // CHECK-MIR: %[[VAL_5:.*]] = constant 1 : index -// CHECK-MIR: %[[VAL_6:.*]] = call 
@sparsePtrsI64(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref -// CHECK-MIR: %[[VAL_7:.*]] = call @sparseIndxsI64(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref -// CHECK-MIR: %[[VAL_8:.*]] = call @sparseValsF64(%[[VAL_0]]) : (!llvm.ptr) -> memref +// CHECK-MIR: %[[VAL_6:.*]] = call @sparsePointers64(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref +// CHECK-MIR: %[[VAL_7:.*]] = call @sparseIndices64(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref +// CHECK-MIR: %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref // CHECK-MIR: %[[VAL_9:.*]] = tensor_to_memref %[[VAL_1]] : memref<64xf64> // CHECK-MIR: %[[VAL_10:.*]] = tensor_to_memref %[[VAL_2]] : memref<64xf64> // CHECK-MIR: %[[VAL_11:.*]] = alloc() : memref<64xf64> @@ -111,9 +111,9 @@ // CHECK-LIR: %[[VAL_3:.*]] = constant 64 : index // CHECK-LIR: %[[VAL_4:.*]] = constant 0 : index // CHECK-LIR: %[[VAL_5:.*]] = constant 1 : index -// CHECK-LIR: %[[VAL_6:.*]] = call @sparsePtrsI64(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref -// CHECK-LIR: %[[VAL_7:.*]] = call @sparseIndxsI64(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref -// CHECK-LIR: %[[VAL_8:.*]] = call @sparseValsF64(%[[VAL_0]]) : (!llvm.ptr) -> memref +// CHECK-LIR: %[[VAL_6:.*]] = call @sparsePointers64(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref +// CHECK-LIR: %[[VAL_7:.*]] = call @sparseIndices64(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref +// CHECK-LIR: %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref // CHECK-LIR: %[[VAL_9:.*]] = alloc() : memref<64xf64> // CHECK-LIR: scf.for %[[VAL_10:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { // CHECK-LIR: %[[VAL_11:.*]] = load %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<64xf64> @@ -144,9 +144,9 @@ // CHECK-FAST: %[[VAL_3:.*]] = constant 64 : index // CHECK-FAST: %[[VAL_4:.*]] = constant 0 : index // CHECK-FAST: %[[VAL_5:.*]] = constant 1 : index -// CHECK-FAST: %[[VAL_6:.*]] = call 
@sparsePtrsI64(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref -// CHECK-FAST: %[[VAL_7:.*]] = call @sparseIndxsI64(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref -// CHECK-FAST: %[[VAL_8:.*]] = call @sparseValsF64(%[[VAL_0]]) : (!llvm.ptr) -> memref +// CHECK-FAST: %[[VAL_6:.*]] = call @sparsePointers64(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref +// CHECK-FAST: %[[VAL_7:.*]] = call @sparseIndices64(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref +// CHECK-FAST: %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref // CHECK-FAST: scf.for %[[VAL_9:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { // CHECK-FAST: %[[VAL_10:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_9]]] : memref // CHECK-FAST: %[[VAL_11:.*]] = addi %[[VAL_9]], %[[VAL_5]] : index diff --git a/mlir/test/Dialect/Linalg/sparse_lower_calls.mlir b/mlir/test/Dialect/Linalg/sparse_lower_calls.mlir --- a/mlir/test/Dialect/Linalg/sparse_lower_calls.mlir +++ b/mlir/test/Dialect/Linalg/sparse_lower_calls.mlir @@ -5,7 +5,7 @@ // CHECK-LABEL: func @sparse_pointers( // CHECK-SAME: %[[A:.*]]: !llvm.ptr) // CHECK: %[[C:.*]] = constant 1 : index -// CHECK: %[[T:.*]] = call @sparsePtrsI64(%[[A]], %[[C]]) : (!llvm.ptr, index) -> memref +// CHECK: %[[T:.*]] = call @sparsePointers64(%[[A]], %[[C]]) : (!llvm.ptr, index) -> memref // CHECK: return %[[T]] : memref func @sparse_pointers(%arg0: !SparseTensor) -> memref { %a = linalg.sparse_tensor %arg0 : !SparseTensor to tensor<128xf64> @@ -14,10 +14,22 @@ return %0 : memref } +// CHECK-LABEL: func @sparse_pointers32( +// CHECK-SAME: %[[A:.*]]: !llvm.ptr) +// CHECK: %[[C:.*]] = constant 1 : index +// CHECK: %[[T:.*]] = call @sparsePointers32(%[[A]], %[[C]]) : (!llvm.ptr, index) -> memref +// CHECK: return %[[T]] : memref +func @sparse_pointers32(%arg0: !SparseTensor) -> memref { + %a = linalg.sparse_tensor %arg0 : !SparseTensor to tensor<128xf64> + %c = constant 1 : index + %0 = linalg.sparse_pointers %a, %c : tensor<128xf64> to 
memref + return %0 : memref +} + // CHECK-LABEL: func @sparse_indices( // CHECK-SAME: %[[A:.*]]: !llvm.ptr) // CHECK: %[[C:.*]] = constant 1 : index -// CHECK: %[[T:.*]] = call @sparseIndxsI64(%[[A]], %[[C]]) : (!llvm.ptr, index) -> memref +// CHECK: %[[T:.*]] = call @sparseIndices64(%[[A]], %[[C]]) : (!llvm.ptr, index) -> memref // CHECK: return %[[T]] : memref func @sparse_indices(%arg0: !SparseTensor) -> memref { %a = linalg.sparse_tensor %arg0 : !SparseTensor to tensor<128xf64> @@ -26,12 +38,34 @@ return %0 : memref } -// CHECK-LABEL: func @sparse_values( +// CHECK-LABEL: func @sparse_indices32( // CHECK-SAME: %[[A:.*]]: !llvm.ptr) -// CHECK: %[[T:.*]] = call @sparseValsF64(%[[A]]) : (!llvm.ptr) -> memref +// CHECK: %[[C:.*]] = constant 1 : index +// CHECK: %[[T:.*]] = call @sparseIndices32(%[[A]], %[[C]]) : (!llvm.ptr, index) -> memref +// CHECK: return %[[T]] : memref +func @sparse_indices32(%arg0: !SparseTensor) -> memref { + %a = linalg.sparse_tensor %arg0 : !SparseTensor to tensor<128xf64> + %c = constant 1 : index + %0 = linalg.sparse_indices %a, %c : tensor<128xf64> to memref + return %0 : memref +} + +// CHECK-LABEL: func @sparse_valuesf64( +// CHECK-SAME: %[[A:.*]]: !llvm.ptr) +// CHECK: %[[T:.*]] = call @sparseValuesF64(%[[A]]) : (!llvm.ptr) -> memref // CHECK: return %[[T]] : memref -func @sparse_values(%arg0: !SparseTensor) -> memref { +func @sparse_valuesf64(%arg0: !SparseTensor) -> memref { %a = linalg.sparse_tensor %arg0 : !SparseTensor to tensor<128xf64> %0 = linalg.sparse_values %a : tensor<128xf64> to memref return %0 : memref } + +// CHECK-LABEL: func @sparse_valuesf32( +// CHECK-SAME: %[[A:.*]]: !llvm.ptr) +// CHECK: %[[T:.*]] = call @sparseValuesF32(%[[A]]) : (!llvm.ptr) -> memref +// CHECK: return %[[T]] : memref +func @sparse_valuesf32(%arg0: !SparseTensor) -> memref { + %a = linalg.sparse_tensor %arg0 : !SparseTensor to tensor<128xf32> + %0 = linalg.sparse_values %a : tensor<128xf32> to memref + return %0 : memref +}