diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -1242,10 +1242,11 @@ }); MutSparseTensorDescriptor desc(stt, fields); + Value c0 = constantIndex(rewriter, loc, 0); Value c1 = constantIndex(rewriter, loc, 1); Value c2 = constantIndex(rewriter, loc, 2); - Value posBack = c1; // index to the last value in the postion array - Value memSize = c2; // memory size for current array + Value posBack = c0; // index to the last value in the position array + Value memSize = c1; // memory size for current array Level trailCOOStart = getCOOStart(stt.getEncoding()); Level trailCOORank = stt.getLvlRank() - trailCOOStart; @@ -1266,7 +1267,7 @@ DimLevelType dlt = stt.getLvlType(lvl); // Simply forwards the position index when this is a dense level. if (isDenseDLT(dlt)) { - memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, posBack); + memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize); posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1); continue; } @@ -1276,6 +1277,10 @@ if (isCompressedWithHiDLT(dlt)) { memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2); posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1); + } else { + assert(isCompressedDLT(dlt)); + posBack = memSize; + memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1); } desc.setPosMemSize(rewriter, loc, lvl, memSize); // The last value in position array is the memory size for next level. 
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir @@ -164,6 +164,16 @@ vector.print %v: f64 } + %d_csr = tensor.empty() : tensor<3xf64> + %p_csr = tensor.empty() : tensor<3xi32> + %i_csr = tensor.empty() : tensor<3xi32> + %rd_csr, %rp_csr, %ri_csr = sparse_tensor.unpack %csr : tensor<2x2xf64, #CSR> + outs(%d_csr, %p_csr, %i_csr : tensor<3xf64>, tensor<3xi32>, tensor<3xi32>) + -> tensor<3xf64>, tensor<3xi32>, tensor<3xi32> + + // CHECK-NEXT: ( 1, 2, 3 ) + %vd_csr = vector.transfer_read %rd_csr[%c0], %f0 : tensor<3xf64>, vector<3xf64> + vector.print %vd_csr : vector<3xf64> // CHECK-NEXT:1 // CHECK-NEXT:2