diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -162,12 +162,14 @@ primary = 1; else if (eltType.isF32()) primary = 2; - else if (eltType.isInteger(32)) + else if (eltType.isInteger(64)) primary = 3; - else if (eltType.isInteger(16)) + else if (eltType.isInteger(32)) primary = 4; - else if (eltType.isInteger(8)) + else if (eltType.isInteger(16)) primary = 5; + else if (eltType.isInteger(8)) + primary = 6; else return failure(); params.push_back( @@ -256,6 +258,8 @@ name = "sparseValuesF64"; else if (eltType.isF32()) name = "sparseValuesF32"; + else if (eltType.isInteger(64)) + name = "sparseValuesI64"; else if (eltType.isInteger(32)) name = "sparseValuesI32"; else if (eltType.isInteger(16)) diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp --- a/mlir/lib/ExecutionEngine/SparseUtils.cpp +++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp @@ -129,6 +129,7 @@ // Primary storage. 
virtual void getValues(std::vector<double> **) { fatal("valf64"); } virtual void getValues(std::vector<float> **) { fatal("valf32"); } + virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); } virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); } virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); } virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); } @@ -437,6 +438,7 @@ TEMPLATE(MemRef1DU32, uint32_t); TEMPLATE(MemRef1DU16, uint16_t); TEMPLATE(MemRef1DU8, uint8_t); +TEMPLATE(MemRef1DI64, int64_t); TEMPLATE(MemRef1DI32, int32_t); TEMPLATE(MemRef1DI16, int16_t); TEMPLATE(MemRef1DI8, int8_t); @@ -448,9 +450,10 @@ enum PrimaryTypeEnum : uint64_t { kF64 = 1, kF32 = 2, - kI32 = 3, - kI16 = 4, - kI8 = 5 + kI64 = 3, + kI32 = 4, + kI16 = 5, + kI8 = 6 }; void *newSparseTensor(char *filename, uint8_t *abase, uint8_t *adata, @@ -499,6 +502,7 @@ CASE(kU8, kU8, kF32, uint8_t, uint8_t, float); // Integral matrices with same overhead storage. + CASE(kU64, kU64, kI64, uint64_t, uint64_t, int64_t); CASE(kU64, kU64, kI32, uint64_t, uint64_t, int32_t); CASE(kU64, kU64, kI16, uint64_t, uint64_t, int16_t); CASE(kU64, kU64, kI8, uint64_t, uint64_t, int8_t); @@ -535,6 +539,7 @@ IMPL2(MemRef1DU8, sparseIndices8, uint8_t, getIndices) IMPL1(MemRef1DF64, sparseValuesF64, double, getValues) IMPL1(MemRef1DF32, sparseValuesF32, float, getValues) +IMPL1(MemRef1DI64, sparseValuesI64, int64_t, getValues) IMPL1(MemRef1DI32, sparseValuesI32, int32_t, getValues) IMPL1(MemRef1DI16, sparseValuesI16, int16_t, getValues) IMPL1(MemRef1DI8, sparseValuesI8, int8_t, getValues)