diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -106,6 +106,7 @@
   let dependentDialects = [
     "arith::ArithmeticDialect",
     "LLVM::LLVMDialect",
+    "linalg::LinalgDialect",
    "memref::MemRefDialect",
     "scf::SCFDialect",
     "sparse_tensor::SparseTensorDialect",
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -36,7 +36,8 @@
   kFromFile = 1,
   kFromCOO = 2,
   kEmptyCOO = 3,
-  kToCOO = 4
+  kToCOO = 4,
+  kToIter = 5
 };
 
 //===----------------------------------------------------------------------===//
@@ -202,14 +203,23 @@
     sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i]));
 }
 
-/// Generates a temporary buffer of the given size and type.
-static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
-                       unsigned sz, Type tp) {
+/// Generates an uninitialized temporary buffer of the given size and
+/// type, but returns it as type `memref<? x $tp>` (rather than as type
+/// `memref<$sz x $tp>`).
+inline static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
+                              unsigned sz, Type tp) {
   auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
   Value a = constantIndex(rewriter, loc, sz);
   return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a});
 }
 
+/// Generates an uninitialized temporary buffer with room for one value
+/// of the given type, and returns the `memref<$tp>`.
+inline static Value genAllocaScalar(ConversionPatternRewriter &rewriter,
+                                    Location loc, Type tp) {
+  return rewriter.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
+}
+
 /// Generates a temporary buffer of the given type and given contents.
 static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
                        ArrayRef<Value> values) {
@@ -345,6 +355,49 @@
   rewriter.create<CallOp>(loc, pTp, fn, params);
 }
 
+/// Generates a call to `SparseTensorCOO<V>::Iterator::getNext()`.
+/// To avoid needing to handle multiple outputs and avoid defining
+/// a bunch of new MLIR types for `Element<V>`, we instead have both
+/// `indices` and `elemPtr` serve as out-parameters and return a bool
+/// to indicate whether those out-parameters are filled or whether we
+/// have no more elements to iterate.
+///
+/// \param [in] iter A value of MLIR-type `!llvm.ptr<i8>` which we
+/// static-cast to C++-type `SparseTensorCOO<V>::Iterator*`.
+/// \param [out] ind A value of `memref<?xindex>` type, where the dynamic
+/// size matches the iterator/sparse-tensor.
+/// \param [out] elemPtr A value of MLIR-type `memref<V>`.
+///
+/// \returns `i1` indicating whether `ind` and `elemPtr` were filled.
+static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
+                            Value iter, Value ind, Value elemPtr) {
+  Location loc = op->getLoc();
+  Type elemTp = elemPtr.getType().cast<MemRefType>().getElementType();
+  StringRef name;
+  if (elemTp.isF64())
+    name = "getNextF64";
+  else if (elemTp.isF32())
+    name = "getNextF32";
+  else if (elemTp.isInteger(64))
+    name = "getNextI64";
+  else if (elemTp.isInteger(32))
+    name = "getNextI32";
+  else if (elemTp.isInteger(16))
+    name = "getNextI16";
+  else if (elemTp.isInteger(8))
+    name = "getNextI8";
+  else
+    llvm_unreachable("Unknown element type");
+  SmallVector<Value, 3> params;
+  params.push_back(iter);
+  params.push_back(ind);
+  params.push_back(elemPtr);
+  Type i1 = rewriter.getI1Type();
+  auto fn = getFunc(op, name, i1, params, /*emitCInterface=*/true);
+  auto call = rewriter.create<CallOp>(loc, i1, fn, params);
+  return call.getResult(0);
+}
+
 /// If the tensor is a sparse constant, generates and returns the pair of
 /// the constants for the indices and the values.
 static Optional<std::pair<Value, Value>>
@@ -379,6 +432,43 @@
   return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
 }
 
+/// Generates code to allocate a tensor of the given type, and zero-initialize
+/// it.  This function assumes the TensorType is fully specified (i.e.,
+/// has static rank and sizes).
+// TODO(wrengr): support dynamic sizes.
+static Value allocDenseTensor(ConversionPatternRewriter &rewriter, Location loc,
+                              RankedTensorType tensorTp) {
+  Type elemTp = tensorTp.getElementType();
+  auto memTp = MemRefType::get(tensorTp.getShape(), elemTp);
+  Value mem = rewriter.create<memref::AllocOp>(loc, memTp);
+  Value zero = constantZero(rewriter, loc, elemTp);
+  rewriter.create<linalg::FillOp>(loc, zero, mem).result();
+  return mem;
+}
+
+/// Inserts the element returned by genGetNextCall() into the tensor
+/// created by allocDenseTensor().
+///
+/// \param elemPtr The `memref<V>` filled by genGetNextCall().
+/// \param tensor The `memref<... x V>` returned by allocDenseTensor().
+/// \param rank The rank of the `tensor`, and length of `ind`.
+/// \param ind The `memref<?xindex>` filled by genGetNextCall().
+static void insertScalarIntoDenseTensor(ConversionPatternRewriter &rewriter,
+                                        Location loc, Value elemPtr,
+                                        Value tensor, unsigned rank,
+                                        Value ind) {
+  // Can't pass `Value ind` directly to memref::LoadOp::build();
+  // instead must explicitly convert it into a ValueRange `ivs`.
+  SmallVector<Value, 4> ivs;
+  ivs.reserve(rank);
+  for (unsigned i = 0; i < rank; i++) {
+    Value idx = constantIndex(rewriter, loc, i);
+    ivs.push_back(rewriter.create<memref::LoadOp>(loc, ind, idx));
+  }
+  Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr);
+  rewriter.create<memref::StoreOp>(loc, elemV, tensor, ivs);
+}
+
 //===----------------------------------------------------------------------===//
 // Conversion rules.
 //===----------------------------------------------------------------------===//
@@ -509,8 +599,75 @@
       rewriter.replaceOp(op, genNewCall(rewriter, op, params));
       return success();
     }
-    if (!encDst || encSrc) {
-      // TODO: sparse => dense
+    if (!encDst && encSrc) {
+      // This is sparse => dense conversion, which is handled as follows:
+      //   dst = new Tensor(0);
+      //   iter = src->toCOO()->getIterator();
+      //   while (elem = iter->getNext()) {
+      //     dst[elem.indices] = elem.value;
+      //   }
+      // While it would be more efficient to inline the iterator logic
+      // directly rather than allocating an object and calling methods,
+      // this is good enough for now.
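+      //
+      // For reference, for a 1-D i32 tensor the emitted IR is roughly of
+      // the following shape (see the conversion test added below; the
+      // value names here are illustrative only):
+      //
+      //   %iter = call @newSparseTensor(..., %actionToIter, %src)
+      //   scf.while : () -> () {
+      //     %cond = call @getNextI32(%iter, %ind, %elemPtr)
+      //     scf.condition(%cond)
+      //   } do {
+      //     // load the index/value out-params, memref.store into %dst
+      //     scf.yield
+      //   }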
+      Location loc = op->getLoc();
+      RankedTensorType tensorTp = resType.dyn_cast<RankedTensorType>();
+      if (!tensorTp) {
+        op.emitError() << "Result type is not a RankedTensorType";
+        return failure();
+      }
+      unsigned rank = tensorTp.getRank();
+      Type elemTp = tensorTp.getElementType();
+      Value dst = allocDenseTensor(rewriter, loc, tensorTp);
+      Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
+      Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
+      Value iter;
+      {
+        // Clone encSrc but removing the dimOrdering.
+        // The srcDimOrdering will already be applied during the
+        // conversion from `SparseTensorStorage src` to SparseTensorCOO
+        // (before that COO is converted to an iterator); so we don't
+        // want newParams() to apply it a second time.
+        //
+        // The dimLevelType is only actually used by the actions which
+        // return SparseTensorStorage (namely: kEmpty, kFromFile, and
+        // kFromCOO); so since we are using kToIter, the only operational
+        // requirement is that it has the right length.  Since the dst
+        // is a dense tensor, we choose to set dimLevelType to all-dense
+        // for semantic correctness.
+        encDst = SparseTensorEncodingAttr::get(
+            op->getContext(),
+            SmallVector<SparseTensorEncodingAttr::DimLevelType>(
+                rank, SparseTensorEncodingAttr::DimLevelType::Dense),
+            AffineMap(), encSrc.getPointerBitWidth(),
+            encSrc.getIndexBitWidth());
+        SmallVector<Value, 4> sizes;
+        SmallVector<Value, 8> params;
+        // TODO(wrengr): support dynamic sizes.
+        sizesFromType(rewriter, sizes, loc, tensorTp);
+        newParams(rewriter, params, op, encDst, kToIter, sizes, src);
+        iter = genNewCall(rewriter, op, params);
+      }
+      SmallVector<Value> noArgs;
+      SmallVector<Type> noTypes;
+      auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
+      {
+        Block *before = rewriter.createBlock(&whileOp.before(), {}, noTypes);
+        rewriter.setInsertionPointToEnd(before);
+        Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
+        rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
+      }
+      {
+        Block *after = rewriter.createBlock(&whileOp.after(), {}, noTypes);
+        rewriter.setInsertionPointToStart(after);
+        insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
+        rewriter.create<scf::YieldOp>(loc);
+      }
+      rewriter.setInsertionPointAfter(whileOp);
+      rewriter.replaceOpWithNewOp<memref::TensorLoadOp>(op, resType, dst);
+      return success();
+    }
+    if (!encDst && !encSrc) {
+      // dense => dense
       return failure();
     }
     // This is a dense => sparse conversion or a sparse constant in COO =>
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -118,7 +118,8 @@
   // The following operations and dialects may be introduced by the
   // rewriting rules, and are therefore marked as legal.
   target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp,
-                    arith::IndexCastOp, tensor::ExtractOp>();
+                    arith::IndexCastOp, linalg::FillOp, linalg::YieldOp,
+                    tensor::ExtractOp>();
   target.addLegalDialect<LLVM::LLVMDialect, memref::MemRefDialect,
                          scf::SCFDialect>();
   // Populate with rules and apply rewriting rules.
diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp
--- a/mlir/lib/ExecutionEngine/SparseUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp
@@ -108,6 +108,12 @@
   /// Getter for elements array.
   const std::vector<Element<V>> &getElements() const { return elements; }
 
+  // Forward declaration of the class required by getIterator.  We make
+  // it a nested class so that it can access the private fields.
+  class Iterator;
+  /// Returns an iterator over the elements of a SparseTensorCOO.
+  Iterator *getIterator() const { return new Iterator(*this); }
+
   /// Factory method. Permutes the original dimensions according to
   /// the given ordering and expects subsequent add() calls to honor
   /// that same ordering for the given indices. The result is a
@@ -134,10 +140,40 @@
     }
     return false;
   }
-  std::vector<uint64_t> sizes; // per-dimension sizes
+  const std::vector<uint64_t> sizes; // per-dimension sizes
   std::vector<Element<V>> elements;
 };
 
+/// An iterator over the elements of a sparse tensor in coordinate-scheme
+/// format.  This iterator is not designed for use by C++ code itself,
+/// but rather by generated MLIR code (i.e., by calls to `IMPL_COO_GETNEXT`);
+/// which is why it may look idiosyncratic or unconventional compared to
+/// conventional C++ iterators.
+template <typename V>
+class SparseTensorCOO<V>::Iterator {
+  // TODO(wrengr): really this class should be a thin wrapper/subclass
+  // of the std::vector, rather than needing to do a dereference every
+  // time a method is called; but we don't want to actually copy the whole
+  // contents of the underlying array(s) when this class is initialized.
+  // Maybe we should be a thin wrapper/subclass of SparseTensorCOO?
+  // Or have a variant of SparseTensorStorage::toCOO() to construct this
+  // iterator directly?
+  const std::vector<Element<V>> &elements;
+  unsigned pos;
+
+public:
+  // TODO(wrengr): to guarantee safety we'd either need to consume the
+  // SparseTensorCOO (e.g., requiring an rvalue-reference) or get notified
+  // somehow whenever the SparseTensorCOO adds new elements, sorts, etc.
+  Iterator(const SparseTensorCOO<V> &coo) : elements(coo.elements), pos(0) {}
+
+  const Element<V> *getNext() {
+    if (pos < elements.size())
+      return &(elements[pos++]);
+    return nullptr;
+  }
+};
+
 /// Abstract base class of sparse tensor storage. Note that we use
 /// function overloading to implement "partial" method specialization.
 class SparseTensorStorageBase {
@@ -539,25 +575,32 @@
   kFromFile = 1,
   kFromCOO = 2,
   kEmptyCOO = 3,
-  kToCOO = 4
+  kToCOO = 4,
+  kToIter = 5
 };
 
 #define CASE(p, i, v, P, I, V)                                               \
   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                        \
     SparseTensorCOO<V> *tensor = nullptr;                                    \
-    if (action == kFromFile)                                                 \
-      tensor =                                                               \
-          openSparseTensorCOO<V>(static_cast<char *>(ptr), rank, sizes, perm); \
-    else if (action == kFromCOO)                                             \
-      tensor = static_cast<SparseTensorCOO<V> *>(ptr);                       \
-    else if (action == kEmptyCOO)                                            \
+    if (action <= kFromCOO) {                                                \
+      if (action == kFromFile) {                                             \
+        char *filename = static_cast<char *>(ptr);                           \
+        tensor = openSparseTensorCOO<V>(filename, rank, sizes, perm);        \
+      } else if (action == kFromCOO)                                         \
+        tensor = static_cast<SparseTensorCOO<V> *>(ptr);                     \
+      return SparseTensorStorage<P, I, V>::newSparseTensor(rank, sizes, perm, \
+                                                           sparsity, tensor); \
+    } else if (action == kEmptyCOO)                                          \
       return SparseTensorCOO<V>::newSparseTensorCOO(rank, sizes, perm);      \
-    else if (action == kToCOO)                                               \
-      return static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);  \
-    else                                                                     \
-      assert(action == kEmpty);                                              \
-    return SparseTensorStorage<P, I, V>::newSparseTensor(rank, sizes, perm,  \
-                                                         sparsity, tensor);  \
+    else {                                                                   \
+      tensor = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm); \
+      if (action == kToCOO)                                                  \
+        return tensor;                                                       \
+      else {                                                                 \
+        assert(action == kToIter);                                           \
+        return tensor->getIterator();                                        \
+      }                                                                      \
+    }                                                                        \
   }
 
 #define IMPL1(NAME, TYPE, LIB)                                               \
@@ -604,6 +647,29 @@
     return tensor;                                                           \
   }
 
+/// Calls SparseTensorCOO<V>::Iterator::getNext() with the following semantics.
+/// To avoid needing to handle multiple outputs and avoid defining
+/// a bunch of new MLIR types for `Element<V>`, we instead have both
+/// `iref` and `value` serve as out-parameters and return a bool to
+/// indicate whether those out-parameters are filled or whether we have
+/// no more elements to iterate.
+#define IMPL_COO_GETNEXT(NAME, V)                                            \
+  bool _mlir_ciface_##NAME(void *ptr, StridedMemRefType<uint64_t, 1> *iref,  \
+                           StridedMemRefType<V, 0> *vref) {                  \
+    assert(iref->strides[0] == 1);                                           \
+    uint64_t *indx = iref->data + iref->offset;                              \
+    V *value = vref->data + vref->offset;                                    \
+    const uint64_t isize = iref->sizes[0];                                   \
+    auto iter = static_cast<typename SparseTensorCOO<V>::Iterator *>(ptr);   \
+    const Element<V> *elem = iter->getNext();                                \
+    if (elem == nullptr)                                                     \
+      return false;                                                          \
+    for (uint64_t r = 0; r < isize; r++)                                     \
+      indx[r] = elem->indices[r];                                            \
+    *value = elem->value;                                                    \
+    return true;                                                             \
+  }
+
 /// Constructs a new sparse tensor. This is the "swiss army knife"
 /// method for materializing sparse tensors into the computation.
 ///
@@ -613,6 +679,8 @@
 /// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
 /// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
 /// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
+/// kToIter = returns iterator from storage in ptr (call IMPL_COO_GETNEXT to
+///           use)
 void *
 _mlir_ciface_newSparseTensor(StridedMemRefType<uint8_t, 1> *aref, // NOLINT
                              StridedMemRefType<uint64_t, 1> *sref,
@@ -710,10 +778,30 @@
 IMPL3(addEltI16, int16_t)
 IMPL3(addEltI8, int8_t)
 
+/// Helper to enumerate elements of coordinate scheme, one per value type.
+IMPL_COO_GETNEXT(getNextF64, double)
+IMPL_COO_GETNEXT(getNextF32, float)
+IMPL_COO_GETNEXT(getNextI64, int64_t)
+IMPL_COO_GETNEXT(getNextI32, int32_t)
+IMPL_COO_GETNEXT(getNextI16, int16_t)
+IMPL_COO_GETNEXT(getNextI8, int8_t)
+
 #undef CASE
 #undef IMPL1
 #undef IMPL2
 #undef IMPL3
+#undef IMPL_COO_GETNEXT
+
+// TODO(wrengr): Either make this function more robust/usable, or figure
+// out how to avoid it.
+bool printInqualityF64(int64_t c, int64_t i, int64_t j, int64_t k,
+                       double expected, double found) {
+  if (expected == found)
+    return false;
+  fprintf(stdout, "%c[%ld,%ld,%ld] Expected: %lg; but found: %lg\n", (int)c, i,
+          j, k, expected, found);
+  return true;
+}
 
 //===----------------------------------------------------------------------===//
 //
diff --git a/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
@@ -0,0 +1,162 @@
+// RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"]
+}>
+
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = ["dense", "compressed"]
+}>
+
+#SparseTensor = #sparse_tensor.encoding<{
+  dimLevelType = ["dense", "compressed", "compressed"],
+  dimOrdering = affine_map<(i,j,k) -> (k,i,j)>
+}>
+
+// CHECK-LABEL: func @sparse_convert_1d(
+// CHECK-SAME: %[[Arg:.*]]: !llvm.ptr<i8>) -> tensor<13xi32>
+// CHECK-DAG: %[[I0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[I13:.*]] = arith.constant 13 : index
+//
+// CHECK-DAG: %[[M:.*]] = memref.alloc() : memref<13xi32>
+// CHECK-DAG: %[[E0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : i32, memref<13xi32>
+// CHECK-DAG: %[[IndS:.*]] = memref.alloca() : memref<1xindex>
+// CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<1xindex> to memref<?xindex>
+// CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref<i32>
+//
+// CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<1xi8>
+// CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<1xi8> to memref<?xi8>
+// CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8
+// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<1xi8>
+//
+// CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<1xindex>
+// CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<1xindex> to memref<?xindex>
+// CHECK-DAG: memref.store %[[I13]], %[[SizesS]][%[[I0]]] : memref<1xindex>
+//
+// CHECK-DAG: %[[PermS:.*]] = memref.alloca() : memref<1xindex>
+// CHECK-DAG: %[[PermD:.*]] = memref.cast %[[PermS]] : memref<1xindex> to memref<?xindex>
+// CHECK-DAG: memref.store %[[I0]], %[[PermS]][%[[I0]]] : memref<1xindex>
+//
+// CHECK-DAG: %[[SecTp:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[ElemTp:.*]] = arith.constant 4 : i32
+// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 5 : i32
+// CHECK: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %[[SecTp]], %[[SecTp]], %[[ElemTp]], %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
+// CHECK: scf.while : () -> () {
+// CHECK: %[[Cond:.*]] = call @getNextI32(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<i32>) -> i1
+// CHECK: scf.condition(%[[Cond]])
+// CHECK: } do {
+// CHECK: %[[Iv0:.*]] = memref.load %[[IndS]][%[[I0]]] : memref<1xindex>
+// CHECK: %[[ElemVal:.*]] = memref.load %[[ElemBuffer]][] : memref<i32>
+// CHECK: memref.store %[[ElemVal]], %[[M]][%[[Iv0]]] : memref<13xi32>
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: %[[T:.*]] = memref.tensor_load %[[M]] : memref<13xi32>
+// CHECK: return %[[T]] : tensor<13xi32>
+func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13xi32> {
+  %0 = sparse_tensor.convert %arg0 : tensor<13xi32, #SparseVector> to tensor<13xi32>
+  return %0 : tensor<13xi32>
+}
+
+// CHECK-LABEL: func @sparse_convert_2d(
+// CHECK-SAME: %[[Arg:.*]]: !llvm.ptr<i8>) -> tensor<2x4xf64>
+// CHECK-DAG: %[[I0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[I1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[I2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[I4:.*]] = arith.constant 4 : index
+//
+// CHECK-DAG: %[[M:.*]] = memref.alloc() : memref<2x4xf64>
+// CHECK-DAG: %[[E0:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : f64, memref<2x4xf64>
+// CHECK-DAG: %[[IndS:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref<f64>
+//
+// CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<2xi8>
+// CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<2xi8> to memref<?xi8>
+// CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8
+// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<2xi8>
+// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I1]]] : memref<2xi8>
+//
+// CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG: memref.store %[[I2]], %[[SizesS]][%[[I0]]] : memref<2xindex>
+// CHECK-DAG: memref.store %[[I4]], %[[SizesS]][%[[I1]]] : memref<2xindex>
+//
+// CHECK-DAG: %[[PermS:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG: %[[PermD:.*]] = memref.cast %[[PermS]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG: memref.store %[[I0]], %[[PermS]][%[[I0]]] : memref<2xindex>
+// CHECK-DAG: memref.store %[[I1]], %[[PermS]][%[[I1]]] : memref<2xindex>
+//
+// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 5 : i32
+// CHECK: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %{{.*}}, %{{.*}}, %{{.*}}, %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
+// CHECK: scf.while : () -> () {
+// CHECK: %[[Cond:.*]] = call @getNextF64(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> i1
+// CHECK: scf.condition(%[[Cond]])
+// CHECK: } do {
+// CHECK: %[[Iv0:.*]] = memref.load %[[IndS]][%[[I0]]] : memref<2xindex>
+// CHECK: %[[Iv1:.*]] = memref.load %[[IndS]][%[[I1]]] : memref<2xindex>
+// CHECK: %[[ElemVal:.*]] = memref.load %[[ElemBuffer]][] : memref<f64>
+// CHECK: memref.store %[[ElemVal]], %[[M]][%[[Iv0]], %[[Iv1]]] : memref<2x4xf64>
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: %[[T:.*]] = memref.tensor_load %[[M]] : memref<2x4xf64>
+// CHECK: return %[[T]] : tensor<2x4xf64>
+func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x4xf64> {
+  %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64, #SparseMatrix> to tensor<2x4xf64>
+  return %0 : tensor<2x4xf64>
+}
+
+// CHECK-LABEL: func @sparse_convert_3d(
+// CHECK-SAME: %[[Arg:.*]]: !llvm.ptr<i8>) -> tensor<2x3x4xf64>
+// CHECK-DAG: %[[I0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[I1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[I2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[I3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[I4:.*]] = arith.constant 4 : index
+//
+// CHECK-DAG: %[[M:.*]] = memref.alloc() : memref<2x3x4xf64>
+// CHECK-DAG: %[[E0:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : f64, memref<2x3x4xf64>
+// CHECK-DAG: %[[IndS:.*]] = memref.alloca() : memref<3xindex>
+// CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<3xindex> to memref<?xindex>
+// CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref<f64>
+//
+// CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<3xi8>
+// CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<3xi8> to memref<?xi8>
+// CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8
+// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<3xi8>
+// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I1]]] : memref<3xi8>
+// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I2]]] : memref<3xi8>
+//
+// CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<3xindex>
+// CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<3xindex> to memref<?xindex>
+// CHECK-DAG: memref.store %[[I2]], %[[SizesS]][%[[I0]]] : memref<3xindex>
+// CHECK-DAG: memref.store %[[I3]], %[[SizesS]][%[[I1]]] : memref<3xindex>
+// CHECK-DAG: memref.store %[[I4]], %[[SizesS]][%[[I2]]] : memref<3xindex>
+//
+// CHECK-DAG: %[[PermS:.*]] = memref.alloca() : memref<3xindex>
+// CHECK-DAG: %[[PermD:.*]] = memref.cast %[[PermS]] : memref<3xindex> to memref<?xindex>
+// CHECK-DAG: memref.store %[[I0]], %[[PermS]][%[[I0]]] : memref<3xindex>
+// CHECK-DAG: memref.store %[[I1]], %[[PermS]][%[[I1]]] : memref<3xindex>
+// CHECK-DAG: memref.store %[[I2]], %[[PermS]][%[[I2]]] : memref<3xindex>
+//
+// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 5 : i32
+// CHECK: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %{{.*}}, %{{.*}}, %{{.*}}, %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
+// CHECK: scf.while : () -> () {
+// CHECK: %[[Cond:.*]] = call @getNextF64(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> i1
+// CHECK: scf.condition(%[[Cond]])
+// CHECK: } do {
+// CHECK: %[[Iv0:.*]] = memref.load %[[IndS]][%[[I0]]] : memref<3xindex>
+// CHECK: %[[Iv1:.*]] = memref.load %[[IndS]][%[[I1]]] : memref<3xindex>
+// CHECK: %[[Iv2:.*]] = memref.load %[[IndS]][%[[I2]]] : memref<3xindex>
+// CHECK: %[[ElemVal:.*]] = memref.load %[[ElemBuffer]][] : memref<f64>
+// CHECK: memref.store %[[ElemVal]], %[[M]][%[[Iv0]], %[[Iv1]], %[[Iv2]]] : memref<2x3x4xf64>
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: %[[T:.*]] = memref.tensor_load %[[M]] : memref<2x3x4xf64>
+// CHECK: return %[[T]] : tensor<2x3x4xf64>
+func @sparse_convert_3d(%arg0: tensor<2x3x4xf64, #SparseTensor>) -> tensor<2x3x4xf64> {
+  %0 = sparse_tensor.convert %arg0 : tensor<2x3x4xf64, #SparseTensor> to tensor<2x3x4xf64>
+  return %0 : tensor<2x3x4xf64>
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2dense.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2dense.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_sparse2dense.mlir
@@ -0,0 +1,143 @@
+// RUN: mlir-opt %s \
+// RUN:   -sparsification -sparse-tensor-conversion \
+// RUN:   -linalg-bufferize -convert-linalg-to-loops \
+// RUN:   -convert-vector-to-scf -convert-scf-to-std \
+// RUN:   -func-bufferize -tensor-constant-bufferize -tensor-bufferize \
+// RUN:   -std-bufferize -finalizing-bufferize \
+// RUN:   -convert-vector-to-llvm -convert-memref-to-llvm -convert-std-to-llvm \
+// RUN:   -reconcile-unrealized-casts \
+// RUN: | \
+// RUN: mlir-cpu-runner \
+// RUN:   -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext
+
+#Tensor1 = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed", "compressed" ],
+  dimOrdering = affine_map<(i,j,k) -> (i,j,k)>
+}>
+
+#Tensor2 = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed", "compressed" ],
"compressed" ], + dimOrdering = affine_map<(i,j,k) -> (j,k,i)> +}> + +#Tensor3 = #sparse_tensor.encoding<{ + dimLevelType = [ "compressed", "compressed", "compressed" ], + dimOrdering = affine_map<(i,j,k) -> (k,i,j)> +}> + +#Tensor4 = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed", "compressed" ], + dimOrdering = affine_map<(i,j,k) -> (i,j,k)> +}> + +#Tensor5 = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed", "compressed" ], + dimOrdering = affine_map<(i,j,k) -> (j,k,i)> +}> + +#Tensor6 = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed", "compressed" ], + dimOrdering = affine_map<(i,j,k) -> (k,i,j)> +}> + +// +// Integration test that tests conversions between sparse tensors. +// +module { + // + // Verification utilities. + // + func private @exit(index) -> () + func private @printInqualityF64(index, index, index, index, f64, f64) -> i1 + func @checkTensor(%name: index, %arg0: tensor<2x3x4xf64>, %arg1: tensor<2x3x4xf64>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + scf.for %i = %c0 to %c2 step %c1 { + scf.for %j = %c0 to %c3 step %c1 { + scf.for %k = %c0 to %c4 step %c1 { + %a = tensor.extract %arg0[%i, %j, %k] : tensor<2x3x4xf64> + %b = tensor.extract %arg1[%i, %j, %k] : tensor<2x3x4xf64> + // TODO(wrengr): figure out a better way to print the + // diagnostics on failure. + %c = call @printInqualityF64(%name, %i, %j, %k, %a, %b) : (index, index, index, index, f64, f64) -> i1 + scf.if %c { + call @exit(%c1) : (index) -> () + } + } + } + } + return + } + + // + // Main driver. + // + func @entry() { + // + // Initialize a 3-dim dense tensor. + // + %t = arith.constant dense<[ + [ [ 1.0, 2.0, 3.0, 4.0 ], + [ 5.0, 6.0, 7.0, 8.0 ], + [ 9.0, 10.0, 11.0, 12.0 ] ], + [ [ 13.0, 14.0, 15.0, 16.0 ], + [ 17.0, 18.0, 19.0, 20.0 ], + [ 21.0, 22.0, 23.0, 24.0 ] ] + ]> : tensor<2x3x4xf64> + + // + // Convert dense tensor directly to various sparse tensors. + // tensor1: stored as 2x3x4 + // tensor2: stored as 3x4x2 + // tensor3: stored as 4x2x3 + // + %1 = sparse_tensor.convert %t : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor1> + %2 = sparse_tensor.convert %t : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor2> + %3 = sparse_tensor.convert %t : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor3> + %4 = sparse_tensor.convert %t : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor4> + %5 = sparse_tensor.convert %t : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor5> + %6 = sparse_tensor.convert %t : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor6> + + // + // Convert sparse tensor back to dense. + // + %a = sparse_tensor.convert %1 : tensor<2x3x4xf64, #Tensor1> to tensor<2x3x4xf64> + %b = sparse_tensor.convert %2 : tensor<2x3x4xf64, #Tensor2> to tensor<2x3x4xf64> + %c = sparse_tensor.convert %3 : tensor<2x3x4xf64, #Tensor3> to tensor<2x3x4xf64> + %d = sparse_tensor.convert %4 : tensor<2x3x4xf64, #Tensor4> to tensor<2x3x4xf64> + %e = sparse_tensor.convert %5 : tensor<2x3x4xf64, #Tensor5> to tensor<2x3x4xf64> + %f = sparse_tensor.convert %6 : tensor<2x3x4xf64, #Tensor6> to tensor<2x3x4xf64> + + // + // Check round-trip equality. 
+    //
+    %nameA = arith.constant 97 : index
+    %nameB = arith.constant 98 : index
+    %nameC = arith.constant 99 : index
+    %nameD = arith.constant 100 : index
+    %nameE = arith.constant 101 : index
+    %nameF = arith.constant 102 : index
+    call @checkTensor(%nameA, %t, %a) : (index, tensor<2x3x4xf64>, tensor<2x3x4xf64>) -> ()
+    call @checkTensor(%nameB, %t, %b) : (index, tensor<2x3x4xf64>, tensor<2x3x4xf64>) -> ()
+    call @checkTensor(%nameC, %t, %c) : (index, tensor<2x3x4xf64>, tensor<2x3x4xf64>) -> ()
+    call @checkTensor(%nameD, %t, %d) : (index, tensor<2x3x4xf64>, tensor<2x3x4xf64>) -> ()
+    call @checkTensor(%nameE, %t, %e) : (index, tensor<2x3x4xf64>, tensor<2x3x4xf64>) -> ()
+    call @checkTensor(%nameF, %t, %f) : (index, tensor<2x3x4xf64>, tensor<2x3x4xf64>) -> ()
+
+    // Release the resources.
+    // TODO(wrengr): what's the proper way to release a dense tensor?
+    // We can't just memref.dealloc since it's not a memref anymore...
+    sparse_tensor.release %1 : tensor<2x3x4xf64, #Tensor1>
+    sparse_tensor.release %2 : tensor<2x3x4xf64, #Tensor2>
+    sparse_tensor.release %3 : tensor<2x3x4xf64, #Tensor3>
+    sparse_tensor.release %4 : tensor<2x3x4xf64, #Tensor4>
+    sparse_tensor.release %5 : tensor<2x3x4xf64, #Tensor5>
+    sparse_tensor.release %6 : tensor<2x3x4xf64, #Tensor6>
+
+    return
+  }
+}
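
For readers who want to see the iterator protocol above in isolation, the following standalone C++ sketch (not part of the patch) mirrors the getIterator()/getNext() style that the generated scf.while loop drives at runtime: getNext() hands back one coordinate/value pair at a time and returns nullptr when exhausted, and the caller scatters each value into a zero-initialized dense buffer. The Element, SparseTensorCOO, and Iterator classes here are simplified stand-ins for the ones defined in SparseUtils.cpp, so the names and signatures are illustrative rather than authoritative.

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Simplified stand-in for the Element<V> struct in SparseUtils.cpp.
template <typename V>
struct Element {
  std::vector<uint64_t> indices;
  V value;
};

// Simplified stand-in for SparseTensorCOO<V> with the nested Iterator
// introduced by this patch.
template <typename V>
class SparseTensorCOO {
public:
  class Iterator {
  public:
    explicit Iterator(const SparseTensorCOO<V> &coo)
        : elements(coo.elements), pos(0) {}
    // Returns the next element, or nullptr once all elements are consumed.
    const Element<V> *getNext() {
      return pos < elements.size() ? &elements[pos++] : nullptr;
    }

  private:
    const std::vector<Element<V>> &elements;
    uint64_t pos;
  };

  void add(std::vector<uint64_t> ind, V val) {
    elements.push_back(Element<V>{std::move(ind), val});
  }
  Iterator *getIterator() const { return new Iterator(*this); }

private:
  std::vector<Element<V>> elements;
};

int main() {
  // A 2x4 matrix with two nonzeros, in coordinate scheme.
  SparseTensorCOO<double> coo;
  coo.add({0, 1}, 3.0);
  coo.add({1, 3}, 5.0);

  // Densify: the C++ analogue of the generated scf.while loop, i.e.
  // "while getNext() fills the out-params, store the value at its indices".
  double dense[2][4] = {};
  auto *iter = coo.getIterator();
  while (const Element<double> *elem = iter->getNext())
    dense[elem->indices[0]][elem->indices[1]] = elem->value;
  delete iter;

  // Print the densified matrix row by row.
  for (unsigned i = 0; i < 2; i++)
    for (unsigned j = 0; j < 4; j++)
      std::printf("%g%c", dense[i][j], j == 3 ? '\n' : ' ');
  return 0;
}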