diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -103,6 +103,13 @@
   }
 };
 
+/// The type of callback functions which receive an element. We avoid
+/// packaging the coordinates and value together as an `Element` object
+/// because this helps keep code somewhat cleaner.
+template <typename V>
+using ElementConsumer =
+    const std::function<void(const std::vector<uint64_t> &, V)> &;
+
 /// A memory-resident sparse tensor in coordinate scheme (collection of
 /// elements). This data structure is used to read a sparse tensor from
 /// any external format into memory and sort the elements lexicographically
@@ -381,24 +388,19 @@
   /// Returns the target tensor's dimension sizes.
   inline const std::vector<uint64_t> &permutedSizes() const { return permsz; }
 
-  /// The type of callback functions which receive an element (in target
-  /// order). We avoid packaging the coordinates and value together
-  /// as an `Element` object because this helps keep code somewhat cleaner.
-  typedef const std::function<void(const std::vector<uint64_t> &, V)>
-      &ElementConsumer;
-
   /// Enumerates all elements of the source tensor, permutes their
   /// indices, and passes the permuted element to the callback.
   /// The callback must not store the cursor reference directly, since
   /// this function reuses the storage. Instead, the callback must copy
   /// it if they want to keep it.
-  inline void forallElements(ElementConsumer yield) {
+  inline void forallElements(ElementConsumer<V> yield) {
     forallElements(yield, 0, 0);
   }
 
 private:
   /// The recursive component of the public `forallElements`.
-  void forallElements(ElementConsumer yield, uint64_t parentPos, uint64_t d) {
+  void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
+                      uint64_t d) {
     if (d == getRank()) {
       // TODO:
       yield(cursor, src.getValue(parentPos));