diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -1410,8 +1410,138 @@
   }
 
 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
-// TODO(D125432): move `_mlir_ciface_newSparseTensor` closer to these
-// macro definitions, but as a separate change so as not to muddy the diff.
+
+// Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
+// can safely rewrite kIndex to kU64. We make this assertion to guarantee
+// that this file cannot get out of sync with its header.
+static_assert(std::is_same<index_type, uint64_t>::value,
+              "Expected index_type == uint64_t");
+
+/// Constructs a new sparse tensor. This is the "swiss army knife"
+/// method for materializing sparse tensors into the computation.
+///
+/// Action:
+/// kEmpty = returns empty storage to fill later
+/// kFromFile = returns storage, where ptr contains filename to read
+/// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
+/// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
+/// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
+/// kToIterator = returns iterator from storage in ptr (call getNext() to use)
+void *
+_mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
+                             StridedMemRefType<index_type, 1> *sref,
+                             StridedMemRefType<index_type, 1> *pref,
+                             OverheadType ptrTp, OverheadType indTp,
+                             PrimaryType valTp, Action action, void *ptr) {
+  assert(aref && sref && pref);
+  assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
+         pref->strides[0] == 1);
+  assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
+  const DimLevelType *sparsity = aref->data + aref->offset;
+  const index_type *shape = sref->data + sref->offset;
+  const index_type *perm = pref->data + pref->offset;
+  uint64_t rank = aref->sizes[0];
+
+  // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
+  // This is safe because of the static_assert above.
+  if (ptrTp == OverheadType::kIndex)
+    ptrTp = OverheadType::kU64;
+  if (indTp == OverheadType::kIndex)
+    indTp = OverheadType::kU64;
+
+  // Double matrices with all combinations of overhead storage.
+  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
+       uint64_t, double);
+  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
+       uint32_t, double);
+  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
+       uint16_t, double);
+  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
+       uint8_t, double);
+  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
+       uint64_t, double);
+  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
+       uint32_t, double);
+  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
+       uint16_t, double);
+  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
+       uint8_t, double);
+  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
+       uint64_t, double);
+  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
+       uint32_t, double);
+  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
+       uint16_t, double);
+  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
+       uint8_t, double);
+  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
+       uint64_t, double);
+  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
+       uint32_t, double);
+  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
+       uint16_t, double);
+  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
+       uint8_t, double);
+
+  // Float matrices with all combinations of overhead storage.
+  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
+       uint64_t, float);
+  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
+       uint32_t, float);
+  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
+       uint16_t, float);
+  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
+       uint8_t, float);
+  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
+       uint64_t, float);
+  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
+       uint32_t, float);
+  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
+       uint16_t, float);
+  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
+       uint8_t, float);
+  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
+       uint64_t, float);
+  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
+       uint32_t, float);
+  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
+       uint16_t, float);
+  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
+       uint8_t, float);
+  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
+       uint64_t, float);
+  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
+       uint32_t, float);
+  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
+       uint16_t, float);
+  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
+       uint8_t, float);
+
+  // Integral matrices with both overheads of the same type.
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
+  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
+  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
+  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
+  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
+  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
+  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
+  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
+  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
+  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
+
+  // Complex matrices with wide overhead.
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
+
+  // Unsupported case (add above if needed).
+  fputs("unsupported combination of types\n", stderr);
+  exit(1);
+}
+#undef CASE
+#undef CASE_SECSAME
 
 /// Methods that provide direct access to values.
 #define IMPL_SPARSEVALUES(VNAME, V)                                            \
@@ -1549,138 +1679,6 @@
 FOREVERY_V(IMPL_EXPINSERT)
 #undef IMPL_EXPINSERT
 
-// Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
-// can safely rewrite kIndex to kU64. We make this assertion to guarantee
-// that this file cannot get out of sync with its header.
-static_assert(std::is_same<index_type, uint64_t>::value,
-              "Expected index_type == uint64_t");
-
-/// Constructs a new sparse tensor. This is the "swiss army knife"
-/// method for materializing sparse tensors into the computation.
-///
-/// Action:
-/// kEmpty = returns empty storage to fill later
-/// kFromFile = returns storage, where ptr contains filename to read
-/// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
-/// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
-/// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
-/// kToIterator = returns iterator from storage in ptr (call getNext() to use)
-void *
-_mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
-                             StridedMemRefType<index_type, 1> *sref,
-                             StridedMemRefType<index_type, 1> *pref,
-                             OverheadType ptrTp, OverheadType indTp,
-                             PrimaryType valTp, Action action, void *ptr) {
-  assert(aref && sref && pref);
-  assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
-         pref->strides[0] == 1);
-  assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
-  const DimLevelType *sparsity = aref->data + aref->offset;
-  const index_type *shape = sref->data + sref->offset;
-  const index_type *perm = pref->data + pref->offset;
-  uint64_t rank = aref->sizes[0];
-
-  // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
-  // This is safe because of the static_assert above.
-  if (ptrTp == OverheadType::kIndex)
-    ptrTp = OverheadType::kU64;
-  if (indTp == OverheadType::kIndex)
-    indTp = OverheadType::kU64;
-
-  // Double matrices with all combinations of overhead storage.
-  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
-       uint64_t, double);
-  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
-       uint32_t, double);
-  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
-       uint16_t, double);
-  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
-       uint8_t, double);
-  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
-       uint64_t, double);
-  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
-       uint32_t, double);
-  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
-       uint16_t, double);
-  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
-       uint8_t, double);
-  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
-       uint64_t, double);
-  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
-       uint32_t, double);
-  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
-       uint16_t, double);
-  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
-       uint8_t, double);
-  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
-       uint64_t, double);
-  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
-       uint32_t, double);
-  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
-       uint16_t, double);
-  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
-       uint8_t, double);
-
-  // Float matrices with all combinations of overhead storage.
-  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
-       uint64_t, float);
-  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
-       uint32_t, float);
-  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
-       uint16_t, float);
-  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
-       uint8_t, float);
-  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
-       uint64_t, float);
-  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
-       uint32_t, float);
-  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
-       uint16_t, float);
-  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
-       uint8_t, float);
-  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
-       uint64_t, float);
-  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
-       uint32_t, float);
-  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
-       uint16_t, float);
-  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
-       uint8_t, float);
-  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
-       uint64_t, float);
-  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
-       uint32_t, float);
-  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
-       uint16_t, float);
-  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
-       uint8_t, float);
-
-  // Integral matrices with both overheads of the same type.
-  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
-  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
-  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
-  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
-  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
-  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
-  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
-  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
-  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
-  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
-  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
-  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
-  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
-
-  // Complex matrices with wide overhead.
-  CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
-  CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
-
-  // Unsupported case (add above if needed).
-  fputs("unsupported combination of types\n", stderr);
-  exit(1);
-}
-#undef CASE
-#undef CASE_SECSAME
-
 /// Output a sparse tensor, one per value type.
 #define IMPL_OUTSPARSETENSOR(VNAME, V)                                         \
   void outSparseTensor##VNAME(void *coo, void *dest, bool sort) {              \