diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -31,9 +31,7 @@
 namespace {
 
-// TODO: start using these when insertions are implemented
-// static constexpr uint64_t DimSizesIdx = 0;
-// static constexpr uint64_t DimCursorIdx = 1;
+static constexpr uint64_t DimSizesIdx = 0;
 static constexpr uint64_t MemSizesIdx = 2;
 static constexpr uint64_t FieldsIdx = 3;
 
@@ -88,11 +86,12 @@
   if (!ShapedType::isDynamic(shape[dim]))
     return constantIndex(rewriter, loc, shape[dim]);
 
-  // Any other query can consult the dimSizes array at field 0 using,
+  // Any other query can consult the dimSizes array at field DimSizesIdx,
   // accounting for the reordering applied to the sparse storage.
   auto tuple = getTuple(adaptedValue);
   Value idx = constantIndex(rewriter, loc, toStoredDim(tensorTp, dim));
-  return rewriter.create<memref::LoadOp>(loc, tuple.getInputs().front(), idx)
+  return rewriter
+      .create<memref::LoadOp>(loc, tuple.getInputs()[DimSizesIdx], idx)
       .getResult();
 }