Changeset View
Changeset View
Standalone View
Standalone View
mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
Show First 20 Lines • Show All 614 Lines • ▼ Show 20 Lines | void forallElements(ElementConsumer<V> yield, uint64_t parentPos, | ||||
uint64_t d) { | uint64_t d) { | ||||
// Recover the `<P,I,V>` type parameters of `src`. | // Recover the `<P,I,V>` type parameters of `src`. | ||||
const auto &src = static_cast<const StorageImpl &>(this->src); | const auto &src = static_cast<const StorageImpl &>(this->src); | ||||
if (d == Base::getRank()) { | if (d == Base::getRank()) { | ||||
assert(parentPos < src.values.size() && | assert(parentPos < src.values.size() && | ||||
"Value position is out of bounds"); | "Value position is out of bounds"); | ||||
// TODO: <https://github.com/llvm/llvm-project/issues/54179> | // TODO: <https://github.com/llvm/llvm-project/issues/54179> | ||||
yield(this->cursor, src.values[parentPos]); | yield(this->cursor, src.values[parentPos]); | ||||
} else if (src.isCompressedDim(d)) { | return; | ||||
} | |||||
const auto dlt = src.getDimType(d); | |||||
if (isCompressedDLT(dlt)) { | |||||
// Look up the bounds of the `d`-level segment determined by the | // Look up the bounds of the `d`-level segment determined by the | ||||
// `d-1`-level position `parentPos`. | // `d-1`-level position `parentPos`. | ||||
const std::vector<P> &pointersD = src.pointers[d]; | const std::vector<P> &pointersD = src.pointers[d]; | ||||
assert(parentPos + 1 < pointersD.size() && | assert(parentPos + 1 < pointersD.size() && | ||||
"Parent pointer position is out of bounds"); | "Parent pointer position is out of bounds"); | ||||
const uint64_t pstart = static_cast<uint64_t>(pointersD[parentPos]); | const uint64_t pstart = static_cast<uint64_t>(pointersD[parentPos]); | ||||
const uint64_t pstop = static_cast<uint64_t>(pointersD[parentPos + 1]); | const uint64_t pstop = static_cast<uint64_t>(pointersD[parentPos + 1]); | ||||
// Loop-invariant code for looking up the `d`-level coordinates/indices. | // Loop-invariant code for looking up the `d`-level coordinates/indices. | ||||
const std::vector<I> &indicesD = src.indices[d]; | const std::vector<I> &indicesD = src.indices[d]; | ||||
assert(pstop <= indicesD.size() && "Index position is out of bounds"); | assert(pstop <= indicesD.size() && "Index position is out of bounds"); | ||||
uint64_t &cursorReordD = this->cursor[this->reord[d]]; | uint64_t &cursorReordD = this->cursor[this->reord[d]]; | ||||
for (uint64_t pos = pstart; pos < pstop; ++pos) { | for (uint64_t pos = pstart; pos < pstop; ++pos) { | ||||
cursorReordD = static_cast<uint64_t>(indicesD[pos]); | cursorReordD = static_cast<uint64_t>(indicesD[pos]); | ||||
forallElements(yield, pos, d + 1); | forallElements(yield, pos, d + 1); | ||||
} | } | ||||
} else if (src.isSingletonDim(d)) { | } else if (isSingletonDLT(dlt)) { | ||||
MLIR_SPARSETENSOR_FATAL("unsupported dimension level type"); | MLIR_SPARSETENSOR_FATAL("unsupported dimension level type: %d\n", | ||||
} else { // Dense dimension. | static_cast<uint8_t>(dlt)); | ||||
assert(src.isDenseDim(d)); // TODO: reuse the ASSERT_DENSE_DIM message | } else { | ||||
assert(isDenseDLT(dlt)); // TODO: reuse the ASSERT_DENSE_DIM message | |||||
const uint64_t sz = src.getDimSizes()[d]; | const uint64_t sz = src.getDimSizes()[d]; | ||||
const uint64_t pstart = parentPos * sz; | const uint64_t pstart = parentPos * sz; | ||||
uint64_t &cursorReordD = this->cursor[this->reord[d]]; | uint64_t &cursorReordD = this->cursor[this->reord[d]]; | ||||
for (uint64_t i = 0; i < sz; ++i) { | for (uint64_t i = 0; i < sz; ++i) { | ||||
cursorReordD = i; | cursorReordD = i; | ||||
forallElements(yield, pstart + i, d + 1); | forallElements(yield, pstart + i, d + 1); | ||||
} | } | ||||
} | } | ||||
▲ Show 20 Lines • Show All 261 Lines • Show Last 20 Lines |