diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -230,6 +230,15 @@
   unsigned iteratorPos = 0;
 };
 
+// See
+#define FOREVERY_V(DO) \
+  DO(F64, double) \
+  DO(F32, float) \
+  DO(I64, int64_t) \
+  DO(I32, int32_t) \
+  DO(I16, int16_t) \
+  DO(I8, int8_t)
+
 // Forward.
 template <typename V>
 class SparseTensorEnumeratorBase;
@@ -294,30 +303,13 @@
   }
 
   /// Allocate a new enumerator.
-  virtual void newEnumerator(SparseTensorEnumeratorBase<double> **, uint64_t,
-                             const uint64_t *) const {
-    fatal("enumf64");
-  }
-  virtual void newEnumerator(SparseTensorEnumeratorBase<float> **, uint64_t,
-                             const uint64_t *) const {
-    fatal("enumf32");
-  }
-  virtual void newEnumerator(SparseTensorEnumeratorBase<int64_t> **, uint64_t,
-                             const uint64_t *) const {
-    fatal("enumi64");
-  }
-  virtual void newEnumerator(SparseTensorEnumeratorBase<int32_t> **, uint64_t,
-                             const uint64_t *) const {
-    fatal("enumi32");
-  }
-  virtual void newEnumerator(SparseTensorEnumeratorBase<int16_t> **, uint64_t,
-                             const uint64_t *) const {
-    fatal("enumi16");
-  }
-  virtual void newEnumerator(SparseTensorEnumeratorBase<int8_t> **, uint64_t,
-                             const uint64_t *) const {
-    fatal("enumi8");
+#define DECL_NEWENUMERATOR(VNAME, V) \
+  virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t, \
+                             const uint64_t *) const { \
+    fatal("newEnumerator" #VNAME); \
   }
+  FOREVERY_V(DECL_NEWENUMERATOR)
+#undef DECL_NEWENUMERATOR
 
   /// Overhead storage.
   virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
@@ -330,40 +322,24 @@
   virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }
 
   /// Primary storage.
-  virtual void getValues(std::vector<double> **) { fatal("valf64"); }
-  virtual void getValues(std::vector<float> **) { fatal("valf32"); }
-  virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
-  virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
-  virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
-  virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }
+#define DECL_GETVALUES(VNAME, V) \
+  virtual void getValues(std::vector<V> **) { fatal("getValues" #VNAME); }
+  FOREVERY_V(DECL_GETVALUES)
+#undef DECL_GETVALUES
 
   /// Element-wise insertion in lexicographic index order.
-  virtual void lexInsert(const uint64_t *, double) { fatal("insf64"); }
-  virtual void lexInsert(const uint64_t *, float) { fatal("insf32"); }
-  virtual void lexInsert(const uint64_t *, int64_t) { fatal("insi64"); }
-  virtual void lexInsert(const uint64_t *, int32_t) { fatal("insi32"); }
-  virtual void lexInsert(const uint64_t *, int16_t) { fatal("ins16"); }
-  virtual void lexInsert(const uint64_t *, int8_t) { fatal("insi8"); }
+#define DECL_LEXINSERT(VNAME, V) \
+  virtual void lexInsert(const uint64_t *, V) { fatal("lexInsert" #VNAME); }
+  FOREVERY_V(DECL_LEXINSERT)
+#undef DECL_LEXINSERT
 
   /// Expanded insertion.
-  virtual void expInsert(uint64_t *, double *, bool *, uint64_t *, uint64_t) {
-    fatal("expf64");
-  }
-  virtual void expInsert(uint64_t *, float *, bool *, uint64_t *, uint64_t) {
-    fatal("expf32");
-  }
-  virtual void expInsert(uint64_t *, int64_t *, bool *, uint64_t *, uint64_t) {
-    fatal("expi64");
-  }
-  virtual void expInsert(uint64_t *, int32_t *, bool *, uint64_t *, uint64_t) {
-    fatal("expi32");
-  }
-  virtual void expInsert(uint64_t *, int16_t *, bool *, uint64_t *, uint64_t) {
-    fatal("expi16");
-  }
-  virtual void expInsert(uint64_t *, int8_t *, bool *, uint64_t *, uint64_t) {
-    fatal("expi8");
+#define DECL_EXPINSERT(VNAME, V) \
+  virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) { \
+    fatal("expInsert" #VNAME); \
   }
+  FOREVERY_V(DECL_EXPINSERT)
+#undef DECL_EXPINSERT
 
   /// Finishes insertion.
   virtual void endInsert() = 0;
@@ -1415,17 +1391,23 @@
   }
 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
 
+// TODO(wrengr): move `_mlir_ciface_newSparseTensor` closer to these
+// macro definitions, but as a separate change so as not to muddy the diff.
-#define IMPL_SPARSEVALUES(NAME, TYPE, LIB) \
-  void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) { \
+/// Methods that provide direct access to values.
+#define IMPL_SPARSEVALUES(VNAME, V) \
+  void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref, \
+                                        void *tensor) { \
     assert(ref &&tensor); \
-    std::vector<TYPE> *v; \
-    static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v); \
+    std::vector<V> *v; \
+    static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v); \
     ref->basePtr = ref->data = v->data(); \
     ref->offset = 0; \
     ref->sizes[0] = v->size(); \
     ref->strides[0] = 1; \
   }
+FOREVERY_V(IMPL_SPARSEVALUES)
+#undef IMPL_SPARSEVALUES
 
 #define IMPL_GETOVERHEAD(NAME, TYPE, LIB) \
   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor, \
@@ -1438,12 +1420,27 @@
     ref->sizes[0] = v->size(); \
     ref->strides[0] = 1; \
   }
+/// Methods that provide direct access to pointers.
+IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
+IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
+IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
+IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
+IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
 
-#define IMPL_ADDELT(NAME, TYPE) \
-  void *_mlir_ciface_##NAME(void *tensor, TYPE value, \
-                            StridedMemRefType<index_type, 1> *iref, \
-                            StridedMemRefType<index_type, 1> *pref) { \
-    assert(tensor &&iref &&pref); \
+/// Methods that provide direct access to indices.
+IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
+IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
+IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
+IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
+IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
+#undef IMPL_GETOVERHEAD
+
+/// Helper to add value to coordinate scheme, one per value type.
+#define IMPL_ADDELT(VNAME, V) \
+  void *_mlir_ciface_addElt##VNAME(void *coo, V value, \
+                                   StridedMemRefType<index_type, 1> *iref, \
+                                   StridedMemRefType<index_type, 1> *pref) { \
+    assert(coo &&iref &&pref); \
     assert(iref->strides[0] == 1 && pref->strides[0] == 1); \
     assert(iref->sizes[0] == pref->sizes[0]); \
     const index_type *indx = iref->data + iref->offset; \
@@ -1452,21 +1449,23 @@
     std::vector<index_type> indices(isize); \
     for (uint64_t r = 0; r < isize; r++) \
       indices[perm[r]] = indx[r]; \
-    static_cast<SparseTensorCOO<TYPE> *>(tensor)->add(indices, value); \
-    return tensor; \
+    static_cast<SparseTensorCOO<V> *>(coo)->add(indices, value); \
+    return coo; \
   }
+FOREVERY_V(IMPL_ADDELT)
+#undef IMPL_ADDELT
 
-#define IMPL_GETNEXT(NAME, V) \
-  bool _mlir_ciface_##NAME(void *tensor, \
-                           StridedMemRefType<index_type, 1> *iref, \
-                           StridedMemRefType<V, 1> *vref) { \
-    assert(tensor &&iref &&vref); \
+/// Helper to enumerate elements of coordinate scheme, one per value type.
+#define IMPL_GETNEXT(VNAME, V) \
+  bool _mlir_ciface_getNext##VNAME(void *coo, \
+                                   StridedMemRefType<index_type, 1> *iref, \
+                                   StridedMemRefType<V, 1> *vref) { \
+    assert(coo &&iref &&vref); \
     assert(iref->strides[0] == 1); \
     index_type *indx = iref->data + iref->offset; \
     V *value = vref->data + vref->offset; \
     const uint64_t isize = iref->sizes[0]; \
-    auto iter = static_cast<SparseTensorCOO<V> *>(tensor); \
-    const Element<V> *elem = iter->getNext(); \
+    const Element<V> *elem = static_cast<SparseTensorCOO<V> *>(coo)->getNext(); \
     if (elem == nullptr) \
       return false; \
     for (uint64_t r = 0; r < isize; r++) \
@@ -1474,19 +1473,26 @@
     *value = elem->value; \
     return true; \
   }
+FOREVERY_V(IMPL_GETNEXT)
+#undef IMPL_GETNEXT
 
-#define IMPL_LEXINSERT(NAME, V) \
-  void _mlir_ciface_##NAME(void *tensor, \
-                           StridedMemRefType<index_type, 1> *cref, V val) { \
+/// Insert elements in lexicographical index order, one per value type.
+#define IMPL_LEXINSERT(VNAME, V) \
+  void _mlir_ciface_lexInsert##VNAME(void *tensor, \
+                                     StridedMemRefType<index_type, 1> *cref, \
+                                     V val) { \
     assert(tensor &&cref); \
    assert(cref->strides[0] == 1); \
     index_type *cursor = cref->data + cref->offset; \
     assert(cursor); \
     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val); \
   }
+FOREVERY_V(IMPL_LEXINSERT)
+#undef IMPL_LEXINSERT
 
-#define IMPL_EXPINSERT(NAME, V) \
-  void _mlir_ciface_##NAME( \
+/// Insert using expansion, one per value type.
+#define IMPL_EXPINSERT(VNAME, V) \
+  void _mlir_ciface_expInsert##VNAME( \
       void *tensor, StridedMemRefType<index_type, 1> *cref, \
       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref, \
      StridedMemRefType<index_type, 1> *aref, index_type count) { \
@@ -1503,6 +1509,8 @@
     static_cast<SparseTensorStorageBase *>(tensor)->expInsert( \
         cursor, values, filled, added, count); \
   }
+FOREVERY_V(IMPL_EXPINSERT)
+#undef IMPL_EXPINSERT
 
 // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
 // can safely rewrite kIndex to kU64. We make this assertion to guarantee
@@ -1629,88 +1637,16 @@
   fputs("unsupported combination of types\n", stderr);
   exit(1);
 }
-
-/// Methods that provide direct access to pointers.
-IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
-IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
-IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
-IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
-IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
-
-/// Methods that provide direct access to indices.
-IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
-IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
-IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
-IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
-IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
-
-/// Methods that provide direct access to values.
-IMPL_SPARSEVALUES(sparseValuesF64, double, getValues)
-IMPL_SPARSEVALUES(sparseValuesF32, float, getValues)
-IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues)
-IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
-IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
-IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)
-
-/// Helper to add value to coordinate scheme, one per value type.
-IMPL_ADDELT(addEltF64, double)
-IMPL_ADDELT(addEltF32, float)
-IMPL_ADDELT(addEltI64, int64_t)
-IMPL_ADDELT(addEltI32, int32_t)
-IMPL_ADDELT(addEltI16, int16_t)
-IMPL_ADDELT(addEltI8, int8_t)
-
-/// Helper to enumerate elements of coordinate scheme, one per value type.
-IMPL_GETNEXT(getNextF64, double)
-IMPL_GETNEXT(getNextF32, float)
-IMPL_GETNEXT(getNextI64, int64_t)
-IMPL_GETNEXT(getNextI32, int32_t)
-IMPL_GETNEXT(getNextI16, int16_t)
-IMPL_GETNEXT(getNextI8, int8_t)
-
-/// Insert elements in lexicographical index order, one per value type.
-IMPL_LEXINSERT(lexInsertF64, double)
-IMPL_LEXINSERT(lexInsertF32, float)
-IMPL_LEXINSERT(lexInsertI64, int64_t)
-IMPL_LEXINSERT(lexInsertI32, int32_t)
-IMPL_LEXINSERT(lexInsertI16, int16_t)
-IMPL_LEXINSERT(lexInsertI8, int8_t)
-
-/// Insert using expansion, one per value type.
-IMPL_EXPINSERT(expInsertF64, double)
-IMPL_EXPINSERT(expInsertF32, float)
-IMPL_EXPINSERT(expInsertI64, int64_t)
-IMPL_EXPINSERT(expInsertI32, int32_t)
-IMPL_EXPINSERT(expInsertI16, int16_t)
-IMPL_EXPINSERT(expInsertI8, int8_t)
-
 #undef CASE
-#undef IMPL_SPARSEVALUES
-#undef IMPL_GETOVERHEAD
-#undef IMPL_ADDELT
-#undef IMPL_GETNEXT
-#undef IMPL_LEXINSERT
-#undef IMPL_EXPINSERT
+#undef CASE_SECSAME
 
 /// Output a sparse tensor, one per value type.
-void outSparseTensorF64(void *tensor, void *dest, bool sort) {
-  return outSparseTensor<double>(tensor, dest, sort);
-}
-void outSparseTensorF32(void *tensor, void *dest, bool sort) {
-  return outSparseTensor<float>(tensor, dest, sort);
-}
-void outSparseTensorI64(void *tensor, void *dest, bool sort) {
-  return outSparseTensor<int64_t>(tensor, dest, sort);
-}
-void outSparseTensorI32(void *tensor, void *dest, bool sort) {
-  return outSparseTensor<int32_t>(tensor, dest, sort);
-}
-void outSparseTensorI16(void *tensor, void *dest, bool sort) {
-  return outSparseTensor<int16_t>(tensor, dest, sort);
-}
-void outSparseTensorI8(void *tensor, void *dest, bool sort) {
-  return outSparseTensor<int8_t>(tensor, dest, sort);
-}
+#define IMPL_OUTSPARSETENSOR(VNAME, V) \
+  void outSparseTensor##VNAME(void *coo, void *dest, bool sort) { \
+    return outSparseTensor<V>(coo, dest, sort); \
+  }
+FOREVERY_V(IMPL_OUTSPARSETENSOR)
+#undef IMPL_OUTSPARSETENSOR
 
 //===----------------------------------------------------------------------===//
 //
@@ -1754,12 +1690,7 @@
   void delSparseTensorCOO##VNAME(void *coo) { \
     delete static_cast<SparseTensorCOO<V> *>(coo); \
   }
-IMPL_DELCOO(F64, double)
-IMPL_DELCOO(F32, float)
-IMPL_DELCOO(I64, int64_t)
-IMPL_DELCOO(I32, int32_t)
-IMPL_DELCOO(I16, int16_t)
-IMPL_DELCOO(I8, int8_t)
+FOREVERY_V(IMPL_DELCOO)
 #undef IMPL_DELCOO
 
 /// Initializes sparse tensor from a COO-flavored format expressed using C-style
@@ -1785,42 +1716,16 @@
 //
 // TODO: generalize beyond 64-bit indices.
 //
-void *convertToMLIRSparseTensorF64(uint64_t rank, uint64_t nse, uint64_t *shape,
-                                   double *values, uint64_t *indices,
-                                   uint64_t *perm, uint8_t *sparse) {
-  return toMLIRSparseTensor<double>(rank, nse, shape, values, indices, perm,
-                                    sparse);
-}
-void *convertToMLIRSparseTensorF32(uint64_t rank, uint64_t nse, uint64_t *shape,
-                                   float *values, uint64_t *indices,
-                                   uint64_t *perm, uint8_t *sparse) {
-  return toMLIRSparseTensor<float>(rank, nse, shape, values, indices, perm,
-                                   sparse);
-}
-void *convertToMLIRSparseTensorI64(uint64_t rank, uint64_t nse, uint64_t *shape,
-                                   int64_t *values, uint64_t *indices,
-                                   uint64_t *perm, uint8_t *sparse) {
-  return toMLIRSparseTensor<int64_t>(rank, nse, shape, values, indices, perm,
-                                     sparse);
-}
-void *convertToMLIRSparseTensorI32(uint64_t rank, uint64_t nse, uint64_t *shape,
-                                   int32_t *values, uint64_t *indices,
-                                   uint64_t *perm, uint8_t *sparse) {
-  return toMLIRSparseTensor<int32_t>(rank, nse, shape, values, indices, perm,
-                                     sparse);
-}
-void *convertToMLIRSparseTensorI16(uint64_t rank, uint64_t nse, uint64_t *shape,
-                                   int16_t *values, uint64_t *indices,
-                                   uint64_t *perm, uint8_t *sparse) {
-  return toMLIRSparseTensor<int16_t>(rank, nse, shape, values, indices, perm,
-                                     sparse);
-}
-void *convertToMLIRSparseTensorI8(uint64_t rank, uint64_t nse, uint64_t *shape,
-                                  int8_t *values, uint64_t *indices,
-                                  uint64_t *perm, uint8_t *sparse) {
-  return toMLIRSparseTensor<int8_t>(rank, nse, shape, values, indices, perm,
-                                    sparse);
-}
+#define IMPL_CONVERTTOMLIRSPARSETENSOR(VNAME, V) \
+  void *convertToMLIRSparseTensor##VNAME(uint64_t rank, uint64_t nse, \
+                                         uint64_t *shape, V *values, \
+                                         uint64_t *indices, uint64_t *perm, \
+                                         uint8_t *sparse) { \
+    return toMLIRSparseTensor<V>(rank, nse, shape, values, indices, perm, \
+                                 sparse); \
+}
+FOREVERY_V(IMPL_CONVERTTOMLIRSPARSETENSOR)
+#undef IMPL_CONVERTTOMLIRSPARSETENSOR
 
 /// Converts a sparse tensor to COO-flavored format expressed using C-style
 /// data structures. The expected output parameters are pointers for these
@@ -1842,36 +1747,14 @@
 // TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
 // compressed
 //
-void convertFromMLIRSparseTensorF64(void *tensor, uint64_t *pRank,
-                                    uint64_t *pNse, uint64_t **pShape,
-                                    double **pValues, uint64_t **pIndices) {
-  fromMLIRSparseTensor<double>(tensor, pRank, pNse, pShape, pValues, pIndices);
-}
-void convertFromMLIRSparseTensorF32(void *tensor, uint64_t *pRank,
-                                    uint64_t *pNse, uint64_t **pShape,
-                                    float **pValues, uint64_t **pIndices) {
-  fromMLIRSparseTensor<float>(tensor, pRank, pNse, pShape, pValues, pIndices);
-}
-void convertFromMLIRSparseTensorI64(void *tensor, uint64_t *pRank,
-                                    uint64_t *pNse, uint64_t **pShape,
-                                    int64_t **pValues, uint64_t **pIndices) {
-  fromMLIRSparseTensor<int64_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
-}
-void convertFromMLIRSparseTensorI32(void *tensor, uint64_t *pRank,
-                                    uint64_t *pNse, uint64_t **pShape,
-                                    int32_t **pValues, uint64_t **pIndices) {
-  fromMLIRSparseTensor<int32_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
-}
-void convertFromMLIRSparseTensorI16(void *tensor, uint64_t *pRank,
-                                    uint64_t *pNse, uint64_t **pShape,
-                                    int16_t **pValues, uint64_t **pIndices) {
-  fromMLIRSparseTensor<int16_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
-}
-void convertFromMLIRSparseTensorI8(void *tensor, uint64_t *pRank,
-                                   uint64_t *pNse, uint64_t **pShape,
-                                   int8_t **pValues, uint64_t **pIndices) {
-  fromMLIRSparseTensor<int8_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
-}
+#define IMPL_CONVERTFROMMLIRSPARSETENSOR(VNAME, V) \
+  void convertFromMLIRSparseTensor##VNAME(void *tensor, uint64_t *pRank, \
+                                          uint64_t *pNse, uint64_t **pShape, \
+                                          V **pValues, uint64_t **pIndices) { \
+    fromMLIRSparseTensor<V>(tensor, pRank, pNse, pShape, pValues, pIndices); \
+  }
+FOREVERY_V(IMPL_CONVERTFROMMLIRSPARSETENSOR)
+#undef IMPL_CONVERTFROMMLIRSPARSETENSOR
 
 } // extern "C"
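Note on the pattern used throughout this diff: FOREVERY_V(DO) is an "x-macro" — it applies the macro DO to every supported (name-suffix, C++ type) pair, so each per-type family of functions is now generated from a single definition instead of six hand-written copies. A minimal standalone sketch of the idiom follows; FOREVERY_V_DEMO, DECL_PRINT, printF64, and printI32 are hypothetical names for illustration only, not part of the patch.

#include <cstdint>
#include <cstdio>

// Apply DO(VNAME, V) to every (name-suffix, C++ type) pair we support.
#define FOREVERY_V_DEMO(DO) \
  DO(F64, double) \
  DO(I32, int32_t)

// One "template" that the x-macro stamps out once per value type,
// yielding printF64(double) and printI32(int32_t).
#define DECL_PRINT(VNAME, V) \
  void print##VNAME(V value) { \
    std::printf("print" #VNAME ": %f\n", static_cast<double>(value)); \
  }
FOREVERY_V_DEMO(DECL_PRINT)
#undef DECL_PRINT

int main() {
  printF64(3.5); // expands from DO(F64, double)
  printI32(42);  // expands from DO(I32, int32_t)
  return 0;
}

With this arrangement, supporting an additional value type means extending the FOREVERY_V list (plus whatever switches on the runtime type enum), rather than editing every per-type function family in the file.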
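For context on the _mlir_ciface_ wrappers that the IMPL_* macros generate: each one hands a std::vector owned by the sparse-tensor runtime back to MLIR-compiled code by filling in the fields of a rank-1 StridedMemRefType descriptor (basePtr, data, offset, sizes[0], strides[0]), as the macro bodies in the diff show. A rough self-contained illustration of that shape, using a hand-rolled MemRef1D stand-in and an exposeValues helper that are hypothetical and not the patch's types:

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for a rank-1 strided memref descriptor.
template <typename V>
struct MemRef1D {
  V *basePtr;         // allocated pointer
  V *data;            // aligned pointer
  int64_t offset;     // element offset into data
  int64_t sizes[1];   // number of elements
  int64_t strides[1]; // stride between elements
};

// Mirrors what the generated _mlir_ciface_sparseValues* wrappers do:
// expose a non-owning view of the runtime's value vector.
template <typename V>
void exposeValues(std::vector<V> &values, MemRef1D<V> *ref) {
  ref->basePtr = ref->data = values.data();
  ref->offset = 0;
  ref->sizes[0] = static_cast<int64_t>(values.size());
  ref->strides[0] = 1;
}

int main() {
  std::vector<double> values = {1.0, 2.0, 3.0};
  MemRef1D<double> ref;
  exposeValues(values, &ref);
  for (int64_t i = 0; i < ref.sizes[0]; ++i)
    std::printf("%g\n", ref.data[ref.offset + i * ref.strides[0]]);
  return 0;
}

The descriptor does not own the memory: the underlying std::vector stays owned by the sparse-tensor storage object, which is why the wrappers can simply alias values.data().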