diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -269,6 +269,26 @@
   DO(C64, complex64)                                                           \
   DO(C32, complex32)
 
+// This x-macro calls its argument on every overhead type that has a
+// fixed width.  It excludes `index_type` because that type is often
+// handled specially (e.g., by translating it into the architecture-dependent
+// equivalent fixed-width overhead type).
+#define FOREVERY_FIXED_O(DO)                                                   \
+  DO(64, uint64_t)                                                             \
+  DO(32, uint32_t)                                                             \
+  DO(16, uint16_t)                                                             \
+  DO(8, uint8_t)
+
+// This x-macro calls its argument on every overhead type, including
+// `index_type`.  Our naming convention uses an empty suffix for
+// `index_type`, so the missing first argument when we call `DO`
+// gets resolved to the empty token which can be concatenated as
+// expected/intended; cf., C99 6.10.3p4 and 6.10.3.3p2, and C++11
+// [cpp.replace]p4 and [cpp.concat]p2 (empty macro arguments/placemarkers).
+#define FOREVERY_O(DO)                                                         \
+  FOREVERY_FIXED_O(DO)                                                         \
+  DO(, index_type)
+
 // Forward.
 template <typename V>
 class SparseTensorEnumeratorBase;
@@ -347,30 +367,18 @@
 #undef DECL_NEWENUMERATOR
 
   /// Overhead storage.
-  virtual void getPointers(std::vector<uint64_t> **, uint64_t) {
-    FATAL_PIV("p64");
-  }
-  virtual void getPointers(std::vector<uint32_t> **, uint64_t) {
-    FATAL_PIV("p32");
-  }
-  virtual void getPointers(std::vector<uint16_t> **, uint64_t) {
-    FATAL_PIV("p16");
-  }
-  virtual void getPointers(std::vector<uint8_t> **, uint64_t) {
-    FATAL_PIV("p8");
-  }
-  virtual void getIndices(std::vector<uint64_t> **, uint64_t) {
-    FATAL_PIV("i64");
-  }
-  virtual void getIndices(std::vector<uint32_t> **, uint64_t) {
-    FATAL_PIV("i32");
-  }
-  virtual void getIndices(std::vector<uint16_t> **, uint64_t) {
-    FATAL_PIV("i16");
-  }
-  virtual void getIndices(std::vector<uint8_t> **, uint64_t) {
-    FATAL_PIV("i8");
-  }
+#define DECL_GETPOINTERS(PNAME, P)                                             \
+  virtual void getPointers(std::vector<P> **, uint64_t) {                      \
+    FATAL_PIV("getPointers" #PNAME);                                           \
+  }
+  FOREVERY_FIXED_O(DECL_GETPOINTERS)
+#undef DECL_GETPOINTERS
+#define DECL_GETINDICES(INAME, I)                                              \
+  virtual void getIndices(std::vector<I> **, uint64_t) {                       \
+    FATAL_PIV("getIndices" #INAME);                                            \
+  }
+  FOREVERY_FIXED_O(DECL_GETINDICES)
+#undef DECL_GETINDICES
 
   /// Primary storage.
 #define DECL_GETVALUES(VNAME, V)                                               \
@@ -1576,18 +1584,16 @@
     ref->strides[0] = 1;                                                       \
   }
 /// Methods that provide direct access to pointers.
-IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
-IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
-IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
-IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
-IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
+#define IMPL_SPARSEPOINTERS(PNAME, P)                                          \
+  IMPL_GETOVERHEAD(sparsePointers##PNAME, P, getPointers)
+FOREVERY_O(IMPL_SPARSEPOINTERS)
+#undef IMPL_SPARSEPOINTERS
 
 /// Methods that provide direct access to indices.
-IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
-IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
-IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
-IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
-IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
+#define IMPL_SPARSEINDICES(INAME, I)                                           \
+  IMPL_GETOVERHEAD(sparseIndices##INAME, I, getIndices)
+FOREVERY_O(IMPL_SPARSEINDICES)
+#undef IMPL_SPARSEINDICES
 #undef IMPL_GETOVERHEAD
 
 /// Helper to add value to coordinate scheme, one per value type.