diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -223,8 +223,6 @@
       if (shape1[d] != shape2[d])
         return op.emitError()
                << "unexpected conversion mismatch in dimension " << d;
-      if (shape1[d] == MemRefType::kDynamicSize)
-        return op.emitError("unexpected dynamic size");
     }
     return success();
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -14,8 +14,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -110,10 +112,11 @@
 }
 
 /// Generates a call into the "swiss army knife" method of the sparse runtime
-/// support library for materializing sparse tensors into the computation.
-static void genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
-                       SparseTensorEncodingAttr &enc, uint32_t action,
-                       Value ptr) {
+/// support library for materializing sparse tensors into the computation. The
+/// method returns the call value and assigns the permutation to 'perm'.
+static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
+                        SparseTensorEncodingAttr &enc, uint32_t action,
+                        Value &perm, Value ptr = Value()) {
   Location loc = op->getLoc();
   ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
   SmallVector<Value, 8> params;
@@ -136,17 +139,16 @@
   // Dimension order permutation array. This is the "identity" permutation by
   // default, or otherwise the "reverse" permutation of a given ordering, so
   // that indices can be mapped quickly to the right position.
-  SmallVector<APInt, 4> perm(sz);
-  AffineMap p = enc.getDimOrdering();
-  if (p) {
-    assert(p.isPermutation() && p.getNumResults() == sz);
+  SmallVector<APInt, 4> rev(sz);
+  if (AffineMap p = enc.getDimOrdering()) {
     for (unsigned i = 0; i < sz; i++)
-      perm[p.getDimPosition(i)] = APInt(64, i);
+      rev[p.getDimPosition(i)] = APInt(64, i);
   } else {
     for (unsigned i = 0; i < sz; i++)
-      perm[i] = APInt(64, i);
+      rev[i] = APInt(64, i);
   }
-  params.push_back(getTensor(rewriter, 64, loc, perm));
+  perm = getTensor(rewriter, 64, loc, rev);
+  params.push_back(perm);
   // Secondary and primary types encoding.
   unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
   unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
@@ -159,53 +161,54 @@
   params.push_back(
       rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
   // User action and pointer.
+  Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
+  if (!ptr)
+    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
   params.push_back(
       rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(action)));
   params.push_back(ptr);
   // Generate the call to create new tensor.
-  Type ptrType =
-      LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
   StringRef name = "newSparseTensor";
-  rewriter.replaceOpWithNewOp<CallOp>(
-      op, ptrType, getFunc(op, name, ptrType, params), params);
+  auto call =
+      rewriter.create<CallOp>(loc, pTp, getFunc(op, name, pTp, params), params);
+  return call.getResult(0);
 }
 
-/// Generates a call that exposes the data pointer as a void pointer.
-// TODO: probing the data pointer directly is a bit raw; we should replace
-// this with proper memref util calls once they become available.
-static bool genPtrCall(ConversionPatternRewriter &rewriter, Operation *op,
-                       Value val, Value &ptr) {
+/// Generates a call that adds one element to a coordinate scheme.
+static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
+                          Value ptr, Value tensor, Value ind, Value perm,
+                          ValueRange ivs) {
   Location loc = op->getLoc();
-  ShapedType sType = op->getResult(0).getType().cast<ShapedType>();
-  Type eltType = sType.getElementType();
-  // Specialize name for the data type. Even though the final buffferized
-  // version only operates on pointers, different names are required to
-  // ensure type correctness for all intermediate states.
   StringRef name;
+  Type eltType = tensor.getType().cast<ShapedType>().getElementType();
   if (eltType.isF64())
-    name = "getPtrF64";
+    name = "addEltF64";
   else if (eltType.isF32())
-    name = "getPtrF32";
+    name = "addEltF32";
   else if (eltType.isInteger(64))
-    name = "getPtrI64";
+    name = "addEltI64";
   else if (eltType.isInteger(32))
-    name = "getPtrI32";
+    name = "addEltI32";
   else if (eltType.isInteger(16))
-    name = "getPtrI16";
+    name = "addEltI16";
   else if (eltType.isInteger(8))
-    name = "getPtrI8";
+    name = "addEltI8";
   else
-    return false;
-  auto memRefTp = MemRefType::get(sType.getShape(), eltType);
-  auto unrankedTp = UnrankedMemRefType::get(eltType, 0);
-  Value c = rewriter.create<memref::BufferCastOp>(loc, memRefTp, val);
-  Value d = rewriter.create<memref::CastOp>(loc, unrankedTp, c);
-  Type ptrType =
-      LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
-  auto call =
-      rewriter.create<CallOp>(loc, ptrType, getFunc(op, name, ptrType, d), d);
-  ptr = call.getResult(0);
-  return true;
+    llvm_unreachable("Unknown element type");
+  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
+  // TODO: add if here?
+  unsigned i = 0;
+  for (auto iv : ivs) {
+    Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i++));
+    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
+  }
+  SmallVector<Value, 8> params;
+  params.push_back(ptr);
+  params.push_back(val);
+  params.push_back(ind);
+  params.push_back(perm);
+  Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
+  rewriter.create<CallOp>(loc, pTp, getFunc(op, name, pTp, params), params);
 }
 
 //===----------------------------------------------------------------------===//
@@ -273,7 +276,8 @@
     auto enc = getSparseTensorEncoding(resType);
     if (!enc)
       return failure();
-    genNewCall(rewriter, op, enc, 0, operands[0]);
+    Value perm;
+    rewriter.replaceOp(op, genNewCall(rewriter, op, enc, 0, perm, operands[0]));
     return success();
   }
 };
@@ -291,11 +295,46 @@
     //   and sparse => dense
     if (!encDst || encSrc)
       return failure();
-    // This is a dense => sparse conversion.
-    Value ptr;
-    if (!genPtrCall(rewriter, op, operands[0], ptr))
-      return failure();
-    genNewCall(rewriter, op, encDst, 1, ptr);
+    // This is a dense => sparse conversion, that is handled as follows:
+    //   t = newSparseCOO()
+    //   for i1 in dim1
+    //    ..
+    //     for ik in dimk
+    //       val = a[i1,..,ik]
+    //       if val != 0
+    //         t->add(val, [i1,..,ik], [p1,..,pk])
+    //   s = newSparseTensor(t)
+    // Note that the dense tensor traversal code is actually implemented
+    // using MLIR IR to avoid having to expose too much low-level
+    // memref traversal details to the runtime support library.
+    Location loc = op->getLoc();
+    ShapedType shape = resType.cast<ShapedType>();
+    auto memTp =
+        MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType());
+    Value perm;
+    Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
+    Value tensor = operands[0];
+    Value arg = rewriter.create<ConstantOp>(
+        loc, rewriter.getIndexAttr(shape.getRank()));
+    Value ind = rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
+    SmallVector<Value> lo;
+    SmallVector<Value> hi;
+    SmallVector<Value> st;
+    Value zero = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
+    Value one = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
+    for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
+      lo.push_back(zero);
+      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, tensor, i));
+      st.push_back(one);
+    }
+    scf::buildLoopNest(rewriter, op.getLoc(), lo, hi, st, {},
+                       [&](OpBuilder &builder, Location loc, ValueRange ivs,
+                           ValueRange args) -> scf::ValueVector {
+                         genAddEltCall(rewriter, op, ptr, tensor, ind, perm,
+                                       ivs);
+                         return {};
+                       });
+    rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, ptr));
     return success();
   }
 };
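The "reverse" permutation computed in genNewCall above is easy to get backwards, so here is a minimal standalone C++ sketch of the same index mapping, separate from the patch itself; the helper name reversePermutation and the vector-based representation of the dimension ordering are illustrative assumptions, not code from this change. The worked 3-D example matches the dense<[1, 2, 0]> permutation constant checked in @sparse_convert_3d further down.

// Standalone sketch (not part of the patch) of the "reverse" permutation that
// genNewCall passes to the runtime. Here ord[i] names the dimension selected
// by the i-th result of the dimension ordering; rev answers the inverse query.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<uint64_t> reversePermutation(const std::vector<uint64_t> &ord) {
  std::vector<uint64_t> rev(ord.size());
  for (uint64_t i = 0; i < ord.size(); i++)
    rev[ord[i]] = i; // same update as rev[p.getDimPosition(i)] = i above
  return rev;
}

int main() {
  // A 3-D ordering that stores dimensions in the order (d2, d0, d1),
  // i.e. ord = {2, 0, 1}, yields rev = {1, 2, 0} -- the dense<[1, 2, 0]>
  // constant that @sparse_convert_3d checks for in the test file below.
  std::vector<uint64_t> rev = reversePermutation({2, 0, 1});
  assert(rev[0] == 1 && rev[1] == 2 && rev[2] == 0);
  return 0;
}

In words: entry i of the ordering says which dimension the i-th stored level corresponds to, and the reverse array answers the opposite question (which stored level a given dimension belongs to), which is what lets incoming indices be mapped quickly to the right position.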
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -99,6 +99,9 @@
     ConversionTarget target(*ctx);
     target.addIllegalOp();
+    // All dynamic rules below accept new function, call, return, and dimop
+    // operations as legal output of the rewriting provided that all sparse
+    // tensor types have been fully rewritten.
     target.addDynamicallyLegalOp<FuncOp>(
         [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
     target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
@@ -106,8 +109,15 @@
     });
     target.addDynamicallyLegalOp<ReturnOp>(
        [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
-    target.addLegalOp();
+    target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
+      return converter.isLegal(op.getOperandTypes());
+    });
+    // The following operations and dialects may be introduced by the
+    // rewriting rules, and are therefore marked as legal.
+    target.addLegalOp<ConstantOp, tensor::CastOp>();
+    target.addLegalDialect<LLVM::LLVMDialect, memref::MemRefDialect,
+                           scf::SCFDialect>();
+    // Populate with rules and apply rewriting rules.
     populateFuncOpTypeConversionPattern(patterns, converter);
     populateCallOpTypeConversionPattern(patterns, converter);
     populateSparseTensorConversionPatterns(converter, patterns);
diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp
--- a/mlir/lib/ExecutionEngine/SparseUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp
@@ -18,6 +18,8 @@
 
 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
 
+#undef AART
+
 #include <algorithm>
 #include <cassert>
 #include <cctype>
@@ -36,7 +38,7 @@
 // (a) A coordinate scheme for temporarily storing and lexicographically
 //     sorting a sparse tensor by index.
 //
-// (b) A "one-size-fits-all" sparse storage scheme defined by per-rank
+// (b) A "one-size-fits-all" sparse tensor storage scheme defined by per-rank
 //     sparse/dense annnotations to be used by generated MLIR code.
 //
 // The following external formats are supported:
@@ -71,7 +73,7 @@
 template <typename V>
 struct SparseTensor {
 public:
-  SparseTensor(const std::vector<uint64_t> &szs, uint64_t capacity = 0)
+  SparseTensor(const std::vector<uint64_t> &szs, uint64_t capacity)
       : sizes(szs), pos(0) {
     if (capacity)
       elements.reserve(capacity);
@@ -94,6 +96,16 @@
   /// Getter for elements array.
   const std::vector<Element<V>> &getElements() const { return elements; }
 
+  /// Factory method.
+  static SparseTensor<V> *newSparseTensor(uint64_t size, uint64_t *sizes,
+                                          uint64_t *perm,
+                                          uint64_t capacity = 0) {
+    std::vector<uint64_t> indices(size);
+    for (uint64_t r = 0; r < size; r++)
+      indices[perm[r]] = sizes[r];
+    return new SparseTensor<V>(indices, capacity);
+  }
+
 private:
   /// Returns true if indices of e1 < indices of e2.
   static bool lexOrder(const Element<V> &e1, const Element<V> &e2) {
@@ -155,8 +167,9 @@
 template <typename P, typename I, typename V>
 class SparseTensorStorage : public SparseTensorStorageBase {
 public:
-  /// Constructs sparse tensor storage scheme following the given
-  /// per-rank dimension dense/sparse annotations.
+  /// Constructs a sparse tensor storage scheme from the given sparse
+  /// tensor in coordinate scheme following the given per-rank dimension
+  /// dense/sparse annotations.
   SparseTensorStorage(SparseTensor<V> *tensor, uint8_t *sparsity)
       : sizes(tensor->getSizes()), pointers(getRank()), indices(getRank()) {
     // Provide hints on capacity.
@@ -175,10 +188,86 @@
     }
     // Then setup the tensor.
     traverse(tensor, sparsity, 0, nnz, 0);
+#ifdef AART
+    dump();
+#endif
   }
 
   virtual ~SparseTensorStorage() {}
 
+#ifdef AART
+  void dump() const {
+    fprintf(stderr, "++++++++++ rank=%lu +++++++++++\n", sizes.size());
+    if constexpr (std::is_same_v<P, uint64_t>)
+      fprintf(stderr, "PTR64 ");
+    else if constexpr (std::is_same_v<P, uint32_t>)
+      fprintf(stderr, "PTR32 ");
+    else if constexpr (std::is_same_v<P, uint16_t>)
+      fprintf(stderr, "PTR16 ");
+    else if constexpr (std::is_same_v<P, uint8_t>)
+      fprintf(stderr, "PTR8 ");
+    if constexpr (std::is_same_v<I, uint64_t>)
+      fprintf(stderr, "INDX64 ");
+    else if constexpr (std::is_same_v<I, uint32_t>)
+      fprintf(stderr, "INDX32 ");
+    else if constexpr (std::is_same_v<I, uint16_t>)
+      fprintf(stderr, "INDX16 ");
+    else if constexpr (std::is_same_v<I, uint8_t>)
+      fprintf(stderr, "INDX8 ");
+    if constexpr (std::is_same_v<V, double>)
+      fprintf(stderr, "VALF64\n");
+    else if constexpr (std::is_same_v<V, float>)
+      fprintf(stderr, "VALF32\n");
+    else if constexpr (std::is_same_v<V, int64_t>)
+      fprintf(stderr, "VALI64\n");
+    else if constexpr (std::is_same_v<V, int32_t>)
+      fprintf(stderr, "VALI32\n");
+    else if constexpr (std::is_same_v<V, int16_t>)
+      fprintf(stderr, "VALI16\n");
+    else if constexpr (std::is_same_v<V, int8_t>)
+      fprintf(stderr, "VALI8\n");
+    for (uint64_t r = 0; r < sizes.size(); r++) {
+      fprintf(stderr, "dim %lu #%lu\n", r, sizes[r]);
+      fprintf(stderr, "  positions[%lu] #%lu :", r, pointers[r].size());
+      for (uint64_t i = 0; i < pointers[r].size(); i++)
+        if constexpr (std::is_same_v<P, uint64_t>)
+          fprintf(stderr, " %lu", pointers[r][i]);
+        else if constexpr (std::is_same_v<P, uint32_t>)
+          fprintf(stderr, " %u", pointers[r][i]);
+        else if constexpr (std::is_same_v<P, uint16_t>)
+          fprintf(stderr, " %u", pointers[r][i]);
+        else if constexpr (std::is_same_v<P, uint8_t>)
+          fprintf(stderr, " %u", pointers[r][i]);
+      fprintf(stderr, "\n  indices[%lu] #%lu :", r, indices[r].size());
+      for (uint64_t i = 0; i < indices[r].size(); i++)
+        if constexpr (std::is_same_v<I, uint64_t>)
+          fprintf(stderr, " %lu", indices[r][i]);
+        else if constexpr (std::is_same_v<I, uint32_t>)
+          fprintf(stderr, " %u", indices[r][i]);
+        else if constexpr (std::is_same_v<I, uint16_t>)
+          fprintf(stderr, " %u", indices[r][i]);
+        else if constexpr (std::is_same_v<I, uint8_t>)
+          fprintf(stderr, " %u", indices[r][i]);
+      fprintf(stderr, "\n");
+    }
+    fprintf(stderr, "values #%lu :", values.size());
"values #%lu :", values.size()); + for (uint64_t i = 0; i < values.size(); i++) + if constexpr (std::is_same_v) + fprintf(stderr, " %lf", values[i]); + else if constexpr (std::is_same_v) + fprintf(stderr, " %f", values[i]); + else if constexpr (std::is_same_v) + fprintf(stderr, " %ld", values[i]); + else if constexpr (std::is_same_v) + fprintf(stderr, " %d", values[i]); + else if constexpr (std::is_same_v) + fprintf(stderr, " %d", values[i]); + else if constexpr (std::is_same_v) + fprintf(stderr, " %d", values[i]); + fprintf(stderr, "\n+++++++++++++++++++++++++++++\n"); + } +#endif + uint64_t getRank() const { return sizes.size(); } uint64_t getDimSize(uint64_t d) override { return sizes[d]; } @@ -192,7 +281,7 @@ } void getValues(std::vector **out) override { *out = &values; } - // Factory method. + /// Factory method. static SparseTensorStorage *newSparseTensor(SparseTensor *t, uint8_t *s) { t->sort(); // sort lexicographically @@ -202,10 +291,9 @@ } private: - /// Initializes sparse tensor storage scheme from a memory-resident - /// representation of an external sparse tensor. This method prepares - /// the pointers and indices arrays under the given per-rank dimension - /// dense/sparse annotations. + /// Initializes sparse tensor storage scheme from a memory-resident sparse + /// tensor in coordinate scheme. This method prepares the pointers and indices + /// arrays under the given per-rank dimension dense/sparse annotations. void traverse(SparseTensor *tensor, uint8_t *sparsity, uint64_t lo, uint64_t hi, uint64_t d) { const std::vector> &elements = tensor->getElements(); @@ -335,6 +423,9 @@ template static SparseTensor *openTensor(char *filename, uint64_t size, uint64_t *sizes, uint64_t *perm) { +#ifdef AART + fprintf(stderr, "SPARSE SUPPORT LIB: OPEN FILE %s\n", filename); +#endif // Open the file. FILE *file = fopen(filename, "r"); if (!file) { @@ -355,14 +446,13 @@ // and the number of nonzeros as initial capacity. assert(size == idata[0] && "rank mismatch"); uint64_t nnz = idata[1]; + for (uint64_t r = 0; r < size; r++) + assert((sizes[r] == 0 || sizes[r] == idata[2 + r]) && + "dimension size mismatch"); + SparseTensor *tensor = + SparseTensor::newSparseTensor(size, idata + 2, perm, nnz); + // Read all nonzero elements. std::vector indices(size); - for (uint64_t r = 0; r < size; r++) { - uint64_t sz = idata[2 + r]; - assert((sizes[r] == 0 || sizes[r] == sz) && "dimension size mismatch"); - indices[perm[r]] = sz; - } - SparseTensor *tensor = new SparseTensor(indices, nnz); - // Read all nonzero elements. for (uint64_t k = 0; k < nnz; k++) { uint64_t idx = -1; for (uint64_t r = 0; r < size; r++) { @@ -384,39 +474,17 @@ } // Close the file and return tensor. fclose(file); - return tensor; -} - -/// Helper to copy a linearized dense tensor. 
-template <typename V>
-static V *copyTensorTraverse(SparseTensor<V> *tensor,
-                             std::vector<uint64_t> &indices, uint64_t r,
-                             uint64_t rank, uint64_t *sizes, uint64_t *perm,
-                             V *data) {
-  for (uint64_t i = 0, sz = sizes[r]; i < sz; i++) {
-    indices[perm[r]] = i;
-    if (r + 1 == rank) {
-      V d = *data++;
-      if (d)
-        tensor->add(indices, d);
-    } else {
-      data =
-          copyTensorTraverse(tensor, indices, r + 1, rank, sizes, perm, data);
-    }
+#ifdef AART
+  tensor->sort(); // sort lexicographically
+  const std::vector<Element<V>> &elements = tensor->getElements();
+  for (uint64_t k = 1; k < nnz; k++) {
+    uint64_t same = 0;
+    for (uint64_t r = 0; r < size; r++)
+      if (elements[k].indices[r] == elements[k - 1].indices[r])
+        same++;
+    assert(same < size && "duplicate element");
   }
-  return data;
-}
-
-/// Copies the nonzeros of a linearized dense tensor into a memory-resident
-/// sparse tensor in coordinate scheme.
-template <typename V>
-static SparseTensor<V> *copyTensor(uint64_t size, uint64_t *sizes,
-                                   uint64_t *perm, V *data) {
-  std::vector<uint64_t> indices(size);
-  for (uint64_t r = 0; r < size; r++)
-    indices[perm[r]] = sizes[r];
-  SparseTensor<V> *tensor = new SparseTensor<V>(indices);
-  copyTensorTraverse(tensor, indices, 0, size, sizes, perm, data);
+#endif
   return tensor;
 }
 
@@ -445,11 +513,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-struct UnrankedMemRef {
-  uint64_t rank;
-  void *descriptor;
-};
-
 #define TEMPLATE(NAME, TYPE)                                                   \
   struct NAME {                                                                \
     const TYPE *base;                                                          \
@@ -464,8 +527,10 @@
     SparseTensor<V> *tensor;                                                   \
     if (action == 0)                                                           \
      tensor = openTensor<V>(static_cast<char *>(ptr), asize, sizes, perm);    \
+    else if (action == 1)                                                      \
+      tensor = static_cast<SparseTensor<V> *>(ptr);                            \
     else                                                                       \
-      tensor = copyTensor<V>(asize, sizes, perm, static_cast<V *>(ptr));       \
+      return SparseTensor<V>::newSparseTensor(asize, sizes, perm);             \
    return SparseTensorStorage<P, I, V>::newSparseTensor(tensor, sparsity);    \
   }
@@ -483,8 +548,22 @@
     return {v->data(), v->data(), 0, {v->size()}, {1}};                        \
   }
 
-#define PTR(NAME)                                                              \
-  const void *NAME(int64_t sz, UnrankedMemRef *m) { return m->descriptor; }
+#define IMPL3(NAME, TYPE)                                                      \
+  void *NAME(void *tensor, TYPE value, uint64_t *ibase, uint64_t *idata,       \
+             uint64_t ioff, uint64_t isize, uint64_t istride, uint64_t *pbase, \
+             uint64_t *pdata, uint64_t poff, uint64_t psize,                   \
+             uint64_t pstride) {                                               \
+    assert(istride == 1 && pstride == 1 && isize == psize);                    \
+    uint64_t *indx = idata + ioff;                                             \
+    if (!value)                                                                \
+      return tensor;                                                           \
+    uint64_t *perm = pdata + poff;                                             \
+    std::vector<uint64_t> indices(isize);                                      \
+    for (uint64_t r = 0; r < isize; r++)                                       \
+      indices[perm[r]] = indx[r];                                              \
+    static_cast<SparseTensor<TYPE> *>(tensor)->add(indices, value);            \
+    return tensor;                                                             \
+  }
 
 TEMPLATE(MemRef1DU64, uint64_t);
 TEMPLATE(MemRef1DU32, uint32_t);
@@ -510,6 +589,10 @@
 
 /// Constructs a new sparse tensor. This is the "swiss army knife"
 /// method for materializing sparse tensors into the computation.
+/// action
+///   0 : ptr contains filename to read into storage
+///   1 : ptr contains coordinate scheme to assign to storage
+///   2 : returns coordinate scheme to fill (call back later with 1)
 void *newSparseTensor(uint8_t *abase, uint8_t *adata, uint64_t aoff,
                       uint64_t asize, uint64_t astride, uint64_t *sbase,
                       uint64_t *sdata, uint64_t soff, uint64_t ssize,
@@ -517,7 +600,11 @@
                       uint64_t poff, uint64_t psize, uint64_t pstride,
                       uint64_t ptrTp, uint64_t indTp, uint64_t valTp,
                       uint32_t action, void *ptr) {
+#ifdef AART
+  fprintf(stderr, "SPARSE SUPPORT LIB: swiss army knife %u\n", action);
+#endif
   assert(astride == 1 && sstride == 1 && pstride == 1);
+  assert(asize == ssize && ssize == psize);
   uint8_t *sparsity = adata + aoff;
   uint64_t *sizes = sdata + soff;
   uint64_t *perm = pdata + poff;
@@ -606,18 +693,19 @@
   delete static_cast<SparseTensorStorageBase *>(tensor);
 }
 
-/// Helper to get pointer, one per value type.
-PTR(getPtrF64)
-PTR(getPtrF32)
-PTR(getPtrI64)
-PTR(getPtrI32)
-PTR(getPtrI16)
-PTR(getPtrI8)
+/// Helper to add value to coordinate scheme, one per value type.
+IMPL3(addEltF64, double)
+IMPL3(addEltF32, float)
+IMPL3(addEltI64, int64_t)
+IMPL3(addEltI32, int32_t)
+IMPL3(addEltI16, int16_t)
+IMPL3(addEltI8, int8_t)
 
 #undef TEMPLATE
 #undef CASE
 #undef IMPL1
 #undef IMPL2
+#undef IMPL3
 
 } // extern "C"
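Taken together, the runtime additions above form a small dense-to-sparse protocol: obtain a coordinate scheme (action 2), add only the nonzero entries, then hand the filled scheme back (action 1) to build the compressed storage. A minimal standalone C++ sketch of that protocol, written directly against the SparseTensor / SparseTensorStorage classes from SparseUtils.cpp, follows; the driver function, its fixed 2x4 shape, the identity ordering, and the uint64_t/double instantiation are illustrative assumptions, not part of the patch (the generated MLIR instead goes through the C entry points newSparseTensor and addEltF64, as the tests below check).

// Standalone sketch (not part of the patch): drive the coordinate scheme and
// storage classes from SparseUtils.cpp directly. Names and values below are
// made up for illustration; error handling and lifetime management omitted.
#include <cstdint>
#include <vector>

void denseToSparse2x4(const double dense[2][4]) {
  uint64_t sizes[] = {2, 4};   // dimension sizes of the dense input
  uint64_t perm[] = {0, 1};    // identity dimension ordering
  uint8_t sparsity[] = {0, 1}; // per-dimension annotations, as in the 2-D test

  // Mirrors action 2: create an empty coordinate scheme (COO).
  SparseTensor<double> *coo =
      SparseTensor<double>::newSparseTensor(2, sizes, perm, /*capacity=*/0);

  // Mirrors the generated scf.for nest plus addEltF64: add nonzeros only.
  for (uint64_t i = 0; i < 2; i++) {
    for (uint64_t j = 0; j < 4; j++) {
      if (double v = dense[i][j]) {
        std::vector<uint64_t> ind = {i, j};
        coo->add(ind, v);
      }
    }
  }

  // Mirrors action 1: sort lexicographically and build the per-dimension
  // pointers/indices/values storage from the filled coordinate scheme.
  SparseTensorStorage<uint64_t, uint64_t, double> *storage =
      SparseTensorStorage<uint64_t, uint64_t, double>::newSparseTensor(
          coo, sparsity);
  (void)storage;
}

The design point mirrored here is that zero values never reach the coordinate scheme (addElt returns early on zero), and the lexicographic sort happens once inside the storage factory rather than on every insertion.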
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -112,24 +112,93 @@
   return %0 : tensor
 }
 
-// CHECK-LABEL: func @sparse_convert(
+// CHECK-LABEL: func @sparse_convert_1d(
+// CHECK-SAME: %[[A:.*]]: tensor<?xi32>) -> !llvm.ptr<i8>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[D0:.*]] = constant dense<0> : tensor<1xi64>
+// CHECK-DAG: %[[D1:.*]] = constant dense<1> : tensor<1xi8>
+// CHECK-DAG: %[[X:.*]] = tensor.cast %[[D1]] : tensor<1xi8> to tensor<?xi8>
+// CHECK-DAG: %[[Y:.*]] = tensor.cast %[[D0]] : tensor<1xi64> to tensor<?xi64>
+// CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Y]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
+// CHECK: %[[M:.*]] = memref.alloca() : memref<1xindex>
+// CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<1xindex> to memref<?xindex>
+// CHECK: %[[U:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?xi32>
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[U]] step %[[C1]] {
+// CHECK:   %[[E:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xi32>
+// CHECK:   memref.store %[[I]], %[[M]][%[[C0]]] : memref<1xindex>
+// CHECK:   call @addEltI32(%[[C]], %[[E]], %[[T]], %[[Y]])
+// CHECK: }
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Y]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+// CHECK: return %[[T]] : !llvm.ptr<i8>
+func @sparse_convert_1d(%arg0: tensor<?xi32>) -> tensor<?xi32, #SparseVector> {
+  %0 = sparse_tensor.convert %arg0 : tensor<?xi32> to tensor<?xi32, #SparseVector>
+  return %0 : tensor<?xi32, #SparseVector>
+}
+
+// CHECK-LABEL: func @sparse_convert_2d(
 // CHECK-SAME: %[[A:.*]]: tensor<2x4xf64>) -> !llvm.ptr<i8>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
 // CHECK-DAG: %[[U:.*]] = constant dense<[0, 1]> : tensor<2xi8>
 // CHECK-DAG: %[[V:.*]] = constant dense<[2, 4]> : tensor<2xi64>
 // CHECK-DAG: %[[W:.*]] = constant dense<[0, 1]> : tensor<2xi64>
-// CHECK: %[[C:.*]] = memref.buffer_cast %arg0 : memref<2x4xf64>
-// CHECK: %[[M:.*]] = memref.cast %[[C]] : memref<2x4xf64> to memref<*xf64>
-// CHECK: %[[C:.*]] = call @getPtrF64(%[[M]]) : (memref<*xf64>) -> !llvm.ptr<i8>
 // CHECK-DAG: %[[X:.*]] = tensor.cast %[[U]] : tensor<2xi8> to tensor<?xi8>
 // CHECK-DAG: %[[Y:.*]] = tensor.cast %[[V]] : tensor<2xi64> to tensor<?xi64>
 // CHECK-DAG: %[[Z:.*]] = tensor.cast %[[W]] : tensor<2xi64> to tensor<?xi64>
+// CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
+// CHECK: %[[M:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<2xindex> to memref<?xindex>
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %{{.*}} step %[[C1]] {
+// CHECK:   scf.for %[[J:.*]] = %[[C0]] to %{{.*}} step %[[C1]] {
+// CHECK:     %[[E:.*]] = tensor.extract %[[A]][%[[I]], %[[J]]] : tensor<2x4xf64>
+// CHECK:     memref.store %[[I]], %[[M]][%[[C0]]] : memref<2xindex>
+// CHECK:     memref.store %[[J]], %[[M]][%[[C1]]] : memref<2xindex>
+// CHECK:     call @addEltF64(%[[C]], %[[E]], %[[T]], %[[Z]])
+// CHECK:   }
+// CHECK: }
 // CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
 // CHECK: return %[[T]] : !llvm.ptr<i8>
-func @sparse_convert(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix> {
+func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #SparseMatrix>
   return %0 : tensor<2x4xf64, #SparseMatrix>
 }
 
+// CHECK-LABEL: func @sparse_convert_3d(
+// CHECK-SAME: %[[A:.*]]: tensor<?x?x?xf64>) -> !llvm.ptr<i8>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK-DAG: %[[U:.*]] = constant dense<[0, 1, 1]> : tensor<3xi8>
+// CHECK-DAG: %[[V:.*]] = constant dense<0> : tensor<3xi64>
+// CHECK-DAG: %[[W:.*]] = constant dense<[1, 2, 0]> : tensor<3xi64>
+// CHECK-DAG: %[[X:.*]] = tensor.cast %[[U]] : tensor<3xi8> to tensor<?xi8>
+// CHECK-DAG: %[[Y:.*]] = tensor.cast %[[V]] : tensor<3xi64> to tensor<?xi64>
+// CHECK-DAG: %[[Z:.*]] = tensor.cast %[[W]] : tensor<3xi64> to tensor<?xi64>
+// CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
+// CHECK: %[[M:.*]] = memref.alloca() : memref<3xindex>
+// CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<3xindex> to memref<?xindex>
+// CHECK: %[[U1:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?x?xf64>
+// CHECK: %[[U2:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?x?xf64>
+// CHECK: %[[U3:.*]] = tensor.dim %[[A]], %[[C2]] : tensor<?x?x?xf64>
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[U1]] step %[[C1]] {
+// CHECK:   scf.for %[[J:.*]] = %[[C0]] to %[[U2]] step %[[C1]] {
+// CHECK:     scf.for %[[K:.*]] = %[[C0]] to %[[U3]] step %[[C1]] {
+// CHECK:       %[[E:.*]] = tensor.extract %[[A]][%[[I]], %[[J]], %[[K]]] : tensor<?x?x?xf64>
+// CHECK:       memref.store %[[I]], %[[M]][%[[C0]]] : memref<3xindex>
+// CHECK:       memref.store %[[J]], %[[M]][%[[C1]]] : memref<3xindex>
+// CHECK:       memref.store %[[K]], %[[M]][%[[C2]]] : memref<3xindex>
+// CHECK:       call @addEltF64(%[[C]], %[[E]], %[[T]], %[[Z]])
+// CHECK:     }
+// CHECK:   }
+// CHECK: }
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+// CHECK: return %[[T]] : !llvm.ptr<i8>
+func @sparse_convert_3d(%arg0: tensor<?x?x?xf64>) -> tensor<?x?x?xf64, #SparseTensor> {
+  %0 = sparse_tensor.convert %arg0 : tensor<?x?x?xf64> to tensor<?x?x?xf64, #SparseTensor>
+  return %0 : tensor<?x?x?xf64, #SparseTensor>
+}
+
 // CHECK-LABEL: func @sparse_pointers(
 // CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 // CHECK: %[[C:.*]] = constant 0 : index