diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -104,6 +104,7 @@ let constructor = "mlir::createSparseTensorConversionPass()"; let dependentDialects = [ "LLVM::LLVMDialect", + "linalg::LinalgDialect", "memref::MemRefDialect", "scf::SCFDialect", "sparse_tensor::SparseTensorDialect", diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -260,6 +260,41 @@ params); } +/// Generates a call to SparseTensorCOO::Iterator::getNext() on the opaque +/// iterator `iter`. If there is a next `Element`: the `indices` will be +/// filled from that element, and the returned `Value` will be its `V`. +/// If there is no next `Element`: the `indices` will be left in an +/// indeterminate state (in practice it'll be left unmodified), and +/// the returned `Value` is zero. Zero thus doubles as the end sentinel: +/// an explicitly stored zero element would be indistinguishable from the +/// end of iteration, so the runtime must not hand such elements back.
+static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op, + Type elemTp, Value iter, Value indices) { + Location loc = op->getLoc(); + StringRef name; + if (elemTp.isF64()) + name = "getNextF64"; + else if (elemTp.isF32()) + name = "getNextF32"; + else if (elemTp.isInteger(64)) + name = "getNextI64"; + else if (elemTp.isInteger(32)) + name = "getNextI32"; + else if (elemTp.isInteger(16)) + name = "getNextI16"; + else if (elemTp.isInteger(8)) + name = "getNextI8"; + else + llvm_unreachable("Unknown element type"); + SmallVector<Value, 2> params; + params.push_back(iter); + params.push_back(indices); + auto call = rewriter.create<CallOp>( + loc, elemTp, getFunc(op, name, elemTp, params, /*emitCInterface=*/true), + params); + return call.getResult(0); +} + /// If the tensor is a sparse constant, generates and returns the pair of /// the constants for the indices and the values. static Optional> @@ -307,6 +342,22 @@ return rewriter.create(loc, memTp, ValueRange{arg}); } +/// Generates code to allocate a memref with the shape and element-type +/// given by `shapeTp`, and zero initialize it. The shape must be fully +/// static, since no dynamic-size operands are passed to the allocation. +static Value allocDenseTensor(ConversionPatternRewriter &rewriter, Location loc, + ShapedType shapeTp) { + assert(shapeTp.hasStaticShape() && + "allocDenseTensor requires a statically shaped type"); + Type elemTp = shapeTp.getElementType(); + Value tensor = rewriter.create<memref::AllocOp>( + loc, MemRefType::get(shapeTp.getShape(), elemTp)); + // Zero-fill: positions the caller never writes afterwards read as zero. + rewriter.create<linalg::FillOp>(loc, getZero(rewriter, loc, elemTp), + tensor); + return tensor; +} + //===----------------------------------------------------------------------===// // Conversion rules.
//===----------------------------------------------------------------------===// @@ -393,8 +444,59 @@ rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, coo)); return success(); } - if (!encDst || encSrc) { - // TODO: sparse => dense + if (!encDst && encSrc) { + // This is sparse => dense conversion, which is handled as follows: + // dst = new MemRef(0); + // iter = src->toCOO()->getIterator(); + // while (elem = iter->getNext()) { + // dst[elem.indices] = elem.value; + // } + // While it would be more efficient to inline the iterator logic + // directly rather than allocating an object and calling methods, + // this is good enough for now. + ShapedType shapeTp = resType.cast<ShapedType>(); + // The dense destination is allocated with a static shape only. + // TODO: support dynamic sizes by querying them from `src`. + if (!shapeTp.hasStaticShape()) + return failure(); + Location loc = op->getLoc(); + Type elemTp = shapeTp.getElementType(); + unsigned rank = shapeTp.getRank(); + Value dst = allocDenseTensor(rewriter, loc, shapeTp); + Value indices = allocaIndices(rewriter, loc, rank); + Value perm; + Value iter = genNewCall(rewriter, op, encSrc, 4, perm, src); + // Generate the while-loop Op itself. It has no loop-carried state; the + // element value is forwarded from the "before" region to the "after" + // region as a block argument via scf.condition, since values defined in + // "before" do not dominate "after". + auto whileOp = + rewriter.create<scf::WhileOp>(loc, TypeRange{elemTp}, ValueRange{}); + Block *before = rewriter.createBlock(&whileOp.before(), {}, TypeRange{}); + Block *after = + rewriter.createBlock(&whileOp.after(), {}, TypeRange{elemTp}); + // Build the while-loop's "before" region. + rewriter.setInsertionPointToEnd(before); + Value elemVal = genGetNextCall(rewriter, op, elemTp, iter, indices); + Value cond = genIsNonzero(rewriter, loc, elemVal); + rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{elemVal}); + // Build the while-loop's "after" region: load the indices that getNext + // wrote into `indices`, store the forwarded value there, and yield. + rewriter.setInsertionPointToStart(after); + SmallVector<Value, 4> ivs; + for (unsigned r = 0; r < rank; r++) { + Value pos = rewriter.create<ConstantIndexOp>(loc, r); + ivs.push_back(rewriter.create<memref::LoadOp>(loc, indices, pos)); + } + rewriter.create<memref::StoreOp>(loc, after->getArgument(0), dst, ivs); + rewriter.create<scf::YieldOp>(loc); + // Finish up: the op yields a tensor, so wrap the filled memref.
+ rewriter.setInsertionPointAfter(whileOp); + rewriter.replaceOpWithNewOp<memref::TensorLoadOp>(op, resType, dst); + return success(); + } + if (!encDst && !encSrc) { + // dense => dense return failure(); } // This is a dense => sparse conversion or a sparse constant in COO => diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -115,9 +115,9 @@ // The following operations and dialects may be introduced by the // rewriting rules, and are therefore marked as legal. target.addLegalOp(); + tensor::ExtractOp, CmpFOp, CmpIOp, linalg::FillOp>(); target.addLegalDialect(); + memref::MemRefDialect, linalg::LinalgDialect>(); // Populate with rules and apply rewriting rules. populateFuncOpTypeConversionPattern(patterns, converter); populateCallOpTypeConversionPattern(patterns, converter); diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp --- a/mlir/lib/ExecutionEngine/SparseUtils.cpp +++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp @@ -107,6 +107,12 @@ /// Getter for elements array. const std::vector> &getElements() const { return elements; } + // Forward declaration of the class required by getIterator. We make + // it a nested class so that it can access the private fields. + class Iterator; + /// Returns an iterator over the elements of a SparseTensorCOO. + Iterator *getIterator() const { return new Iterator(*this); } + /// Factory method. Permutes the original dimensions according to /// the given ordering and expects subsequent add() calls to honor /// that same ordering for the given indices.
The result is a @@ -132,10 +138,55 @@ } return false; } - std::vector sizes; // per-rank dimension sizes + const std::vector<uint64_t> sizes; // per-rank dimension sizes std::vector> elements; }; +/// An iterator over the elements of a sparse tensor in coordinate-scheme +/// format. This iterator is not designed for use by C++ code itself, +/// but rather by generated MLIR code (i.e., by calls to `IMPL_COO_GETNEXT`); +/// hence why it may look idiosyncratic or unconventional compared to +/// conventional C++ iterators. +template <typename V> +class SparseTensorCOO<V>::Iterator { + // TODO(wrengr): really this class should be a thin wrapper/subclass + // of the std::vector, rather than needing to do a dereference every + // time a method is called; but we don't want to actually copy the whole + // contents of the underlying array(s) when this class is initialized. + // Maybe we should be a thin wrapper/subclass of SparseTensorCOO? + // Or have a variant of SparseTensorStorage::toCOO() to construct this + // iterator directly? + const std::vector<Element<V>> &elements; + uint64_t pos = 0; + // The coordinate scheme owned (and deleted) by this iterator, if any. + SparseTensorCOO<V> *owned = nullptr; + +public: + // TODO(wrengr): to guarantee safety we'd either need to consume the + // SparseTensorCOO (e.g., requiring an rvalue-reference) or get notified + // somehow whenever the SparseTensorCOO adds new elements, sorts, etc. + explicit Iterator(const SparseTensorCOO<V> &coo) : elements(coo.elements) {} + + /// Creates an iterator that takes ownership of `coo`, so that the + /// coordinate scheme is deleted together with the iterator. + static Iterator *newOwning(SparseTensorCOO<V> *coo) { + Iterator *iter = new Iterator(*coo); + iter->owned = coo; + return iter; + } + + ~Iterator() { delete owned; } + Iterator(const Iterator &) = delete; + Iterator &operator=(const Iterator &) = delete; + + /// Returns the next element, or nullptr once all elements were visited. + const Element<V> *getNext() { + if (pos < elements.size()) + return &(elements[pos++]); + return nullptr; + } +}; + /// Abstract base class of sparse tensor storage. Note that we use /// function overloading to implement "partial" method specialization.
class SparseTensorStorageBase { @@ -517,8 +568,12 @@ tensor = static_cast *>(ptr); \ else if (action == 2) \ return SparseTensorCOO::newSparseTensorCOO(size, sizes, perm); \ - else \ - return static_cast *>(ptr)->toCOO(perm); \ + else { \ + tensor = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm); \ + if (action == 4) \ + return SparseTensorCOO<V>::Iterator::newOwning(tensor); \ + return tensor; \ + } \ return SparseTensorStorage::newSparseTensor(tensor, sparsity, \ perm); \ } @@ -560,6 +615,34 @@ return tensor; \ } +/// Calls SparseTensorCOO::Iterator::getNext() with the following semantics. +/// If there is a next `Element` with a nonzero value: the `iref` will be +/// filled from that element, and the element's value is returned. +/// Elements holding an explicit zero are skipped, because zero is reserved +/// as the end-of-iteration sentinel; this is harmless for the sparse => +/// dense conversion, whose destination buffer is zero-initialized. +/// If there is no such `Element`: the `iref` is left unmodified, the +/// iterator (with any coordinate scheme it owns) is deleted, and zero is +/// returned; `ptr` must not be passed to this function again afterwards. +#define IMPL_COO_GETNEXT(NAME, V) \ + V _mlir_ciface_##NAME(void *ptr, StridedMemRefType<uint64_t, 1> *iref) { \ + assert(iref->strides[0] == 1); \ + uint64_t *indx = iref->data + iref->offset; \ + const uint64_t isize = iref->sizes[0]; \ + auto *iter = static_cast<SparseTensorCOO<V>::Iterator *>(ptr); \ + const Element<V> *elem = iter->getNext(); \ + while (elem != nullptr && elem->value == 0) \ + elem = iter->getNext(); \ + if (elem == nullptr) { \ + delete iter; \ + return 0; \ + } \ + assert(elem->indices.size() == isize); \ + for (uint64_t r = 0; r < isize; r++) \ + indx[r] = elem->indices[r]; \ + return elem->value; \ + } + enum OverheadTypeEnum : uint64_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 }; enum PrimaryTypeEnum : uint64_t { @@ -578,6 +661,7 @@ /// 1 : ptr contains coordinate scheme to assign to new storage /// 2 : returns empty coordinate scheme to fill (call back 1 to setup) /// 3 : returns coordinate scheme from storage in ptr (call back 1 to convert) +/// 4 : returns iterator from storage in ptr (freed by IMPL_COO_GETNEXT at end) void * _mlir_ciface_newSparseTensor(StridedMemRefType *aref, // NOLINT StridedMemRefType *sref, @@ -674,10 +758,19 @@
IMPL3(addEltI16, int16_t) IMPL3(addEltI8, int8_t) +/// C-API entry points enumerating a coordinate scheme, one per value type. +IMPL_COO_GETNEXT(getNextF64, double) +IMPL_COO_GETNEXT(getNextF32, float) +IMPL_COO_GETNEXT(getNextI64, int64_t) +IMPL_COO_GETNEXT(getNextI32, int32_t) +IMPL_COO_GETNEXT(getNextI16, int16_t) +IMPL_COO_GETNEXT(getNextI8, int8_t) + #undef CASE #undef IMPL1 #undef IMPL2 #undef IMPL3 +#undef IMPL_COO_GETNEXT //===----------------------------------------------------------------------===// // diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion.mlir @@ -1,11 +1,14 @@ // RUN: mlir-opt %s \ +// RUN: --linalg-generalize-named-ops --linalg-fuse-elementwise-ops \ // RUN: --sparsification --sparse-tensor-conversion \ +// RUN: --linalg-bufferize --convert-linalg-to-loops \ // RUN: --convert-vector-to-scf --convert-scf-to-std \ // RUN: --func-bufferize --tensor-constant-bufferize --tensor-bufferize \ -// RUN: --std-bufferize --finalizing-bufferize \ -// RUN: --convert-vector-to-llvm --convert-memref-to-llvm --convert-std-to-llvm --reconcile-unrealized-casts | \ +// RUN: --std-bufferize --finalizing-bufferize \ +// RUN: --convert-vector-to-llvm --convert-memref-to-llvm \ +// RUN: --convert-std-to-llvm --reconcile-unrealized-casts | \ // RUN: mlir-cpu-runner \ -// RUN: -e entry -entry-point-result=void \ +// RUN: -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s @@ -75,6 +78,25 @@ } return } + func @checkTensor(%arg0: tensor<2x3x4xf64>, %arg1: tensor<2x3x4xf64>) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c3 = constant 3 : index + %c4 = constant 4 : index + // Same content? Calls @exit(1) on the first mismatching element.
+ scf.for %i = %c0 to %c2 step %c1 { + scf.for %j = %c0 to %c3 step %c1 { + scf.for %k = %c0 to %c4 step %c1 { + %a = tensor.extract %arg0[%i, %j, %k] : tensor<2x3x4xf64> + %b = tensor.extract %arg1[%i, %j, %k] : tensor<2x3x4xf64> + %c = cmpf une, %a, %b : f64 + scf.if %c { + call @exit(%c1) : (index) -> () + } + }}} + return + } // // Output utility. @@ -132,6 +154,13 @@ %h = sparse_tensor.convert %2 : tensor<2x3x4xf64, #Tensor2> to tensor<2x3x4xf64, #Tensor3> %i = sparse_tensor.convert %3 : tensor<2x3x4xf64, #Tensor3> to tensor<2x3x4xf64, #Tensor3> + // + // Convert sparse tensors back to dense (exercises sparse => dense). + // + %x = sparse_tensor.convert %1 : tensor<2x3x4xf64, #Tensor1> to tensor<2x3x4xf64> + %y = sparse_tensor.convert %2 : tensor<2x3x4xf64, #Tensor2> to tensor<2x3x4xf64> + %z = sparse_tensor.convert %3 : tensor<2x3x4xf64, #Tensor3> to tensor<2x3x4xf64> + // // Check values equality. // @@ -160,6 +189,10 @@ call @checkf64(%v3, %hv) : (memref, memref) -> () call @checkf64(%v3, %iv) : (memref, memref) -> () + call @checkTensor(%t, %x) : (tensor<2x3x4xf64>, tensor<2x3x4xf64>) -> () + call @checkTensor(%t, %y) : (tensor<2x3x4xf64>, tensor<2x3x4xf64>) -> () + call @checkTensor(%t, %z) : (tensor<2x3x4xf64>, tensor<2x3x4xf64>) -> () + // // Check index equality.
// @@ -184,25 +217,25 @@ %c11 = sparse_tensor.indices %c, %c1 : tensor<2x3x4xf64, #Tensor1> to memref %c12 = sparse_tensor.indices %c, %c2 : tensor<2x3x4xf64, #Tensor1> to memref - %d10 = sparse_tensor.indices %d, %c0 : tensor<2x3x4xf64, #Tensor2> to memref - %d11 = sparse_tensor.indices %d, %c1 : tensor<2x3x4xf64, #Tensor2> to memref - %d12 = sparse_tensor.indices %d, %c2 : tensor<2x3x4xf64, #Tensor2> to memref - %e10 = sparse_tensor.indices %e, %c0 : tensor<2x3x4xf64, #Tensor2> to memref - %e11 = sparse_tensor.indices %e, %c1 : tensor<2x3x4xf64, #Tensor2> to memref - %e12 = sparse_tensor.indices %e, %c2 : tensor<2x3x4xf64, #Tensor2> to memref - %f10 = sparse_tensor.indices %f, %c0 : tensor<2x3x4xf64, #Tensor2> to memref - %f11 = sparse_tensor.indices %f, %c1 : tensor<2x3x4xf64, #Tensor2> to memref - %f12 = sparse_tensor.indices %f, %c2 : tensor<2x3x4xf64, #Tensor2> to memref + %d20 = sparse_tensor.indices %d, %c0 : tensor<2x3x4xf64, #Tensor2> to memref + %d21 = sparse_tensor.indices %d, %c1 : tensor<2x3x4xf64, #Tensor2> to memref + %d22 = sparse_tensor.indices %d, %c2 : tensor<2x3x4xf64, #Tensor2> to memref + %e20 = sparse_tensor.indices %e, %c0 : tensor<2x3x4xf64, #Tensor2> to memref + %e21 = sparse_tensor.indices %e, %c1 : tensor<2x3x4xf64, #Tensor2> to memref + %e22 = sparse_tensor.indices %e, %c2 : tensor<2x3x4xf64, #Tensor2> to memref + %f20 = sparse_tensor.indices %f, %c0 : tensor<2x3x4xf64, #Tensor2> to memref + %f21 = sparse_tensor.indices %f, %c1 : tensor<2x3x4xf64, #Tensor2> to memref + %f22 = sparse_tensor.indices %f, %c2 : tensor<2x3x4xf64, #Tensor2> to memref - %g10 = sparse_tensor.indices %g, %c0 : tensor<2x3x4xf64, #Tensor3> to memref - %g11 = sparse_tensor.indices %g, %c1 : tensor<2x3x4xf64, #Tensor3> to memref - %g12 = sparse_tensor.indices %g, %c2 : tensor<2x3x4xf64, #Tensor3> to memref - %h10 = sparse_tensor.indices %h, %c0 : tensor<2x3x4xf64, #Tensor3> to memref - %h11 = sparse_tensor.indices %h, %c1 : tensor<2x3x4xf64, #Tensor3> to memref - 
%h12 = sparse_tensor.indices %h, %c2 : tensor<2x3x4xf64, #Tensor3> to memref - %i10 = sparse_tensor.indices %i, %c0 : tensor<2x3x4xf64, #Tensor3> to memref - %i11 = sparse_tensor.indices %i, %c1 : tensor<2x3x4xf64, #Tensor3> to memref - %i12 = sparse_tensor.indices %i, %c2 : tensor<2x3x4xf64, #Tensor3> to memref + %g30 = sparse_tensor.indices %g, %c0 : tensor<2x3x4xf64, #Tensor3> to memref + %g31 = sparse_tensor.indices %g, %c1 : tensor<2x3x4xf64, #Tensor3> to memref + %g32 = sparse_tensor.indices %g, %c2 : tensor<2x3x4xf64, #Tensor3> to memref + %h30 = sparse_tensor.indices %h, %c0 : tensor<2x3x4xf64, #Tensor3> to memref + %h31 = sparse_tensor.indices %h, %c1 : tensor<2x3x4xf64, #Tensor3> to memref + %h32 = sparse_tensor.indices %h, %c2 : tensor<2x3x4xf64, #Tensor3> to memref + %i30 = sparse_tensor.indices %i, %c0 : tensor<2x3x4xf64, #Tensor3> to memref + %i31 = sparse_tensor.indices %i, %c1 : tensor<2x3x4xf64, #Tensor3> to memref + %i32 = sparse_tensor.indices %i, %c2 : tensor<2x3x4xf64, #Tensor3> to memref call @check(%v10, %a10) : (memref, memref) -> () call @check(%v11, %a11) : (memref, memref) -> () @@ -214,25 +247,25 @@ call @check(%v11, %c11) : (memref, memref) -> () call @check(%v12, %c12) : (memref, memref) -> () - call @check(%v20, %d10) : (memref, memref) -> () - call @check(%v21, %d11) : (memref, memref) -> () - call @check(%v22, %d12) : (memref, memref) -> () - call @check(%v20, %e10) : (memref, memref) -> () - call @check(%v21, %e11) : (memref, memref) -> () - call @check(%v22, %e12) : (memref, memref) -> () - call @check(%v20, %f10) : (memref, memref) -> () - call @check(%v21, %f11) : (memref, memref) -> () - call @check(%v22, %f12) : (memref, memref) -> () + call @check(%v20, %d20) : (memref, memref) -> () + call @check(%v21, %d21) : (memref, memref) -> () + call @check(%v22, %d22) : (memref, memref) -> () + call @check(%v20, %e20) : (memref, memref) -> () + call @check(%v21, %e21) : (memref, memref) -> () + call @check(%v22, %e22) : (memref, 
memref) -> () + call @check(%v20, %f20) : (memref, memref) -> () + call @check(%v21, %f21) : (memref, memref) -> () + call @check(%v22, %f22) : (memref, memref) -> () - call @check(%v30, %g10) : (memref, memref) -> () - call @check(%v31, %g11) : (memref, memref) -> () - call @check(%v32, %g12) : (memref, memref) -> () - call @check(%v30, %h10) : (memref, memref) -> () - call @check(%v31, %h11) : (memref, memref) -> () - call @check(%v32, %h12) : (memref, memref) -> () - call @check(%v30, %i10) : (memref, memref) -> () - call @check(%v31, %i11) : (memref, memref) -> () - call @check(%v32, %i12) : (memref, memref) -> () + call @check(%v30, %g30) : (memref, memref) -> () + call @check(%v31, %g31) : (memref, memref) -> () + call @check(%v32, %g32) : (memref, memref) -> () + call @check(%v30, %h30) : (memref, memref) -> () + call @check(%v31, %h31) : (memref, memref) -> () + call @check(%v32, %h32) : (memref, memref) -> () + call @check(%v30, %i30) : (memref, memref) -> () + call @check(%v31, %i31) : (memref, memref) -> () + call @check(%v32, %i32) : (memref, memref) -> () // // Sanity check direct results.