diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -1242,10 +1242,11 @@ }); MutSparseTensorDescriptor desc(stt, fields); + Value c0 = constantIndex(rewriter, loc, 0); Value c1 = constantIndex(rewriter, loc, 1); Value c2 = constantIndex(rewriter, loc, 2); - Value posBack = c1; // index to the last value in the postion array - Value memSize = c2; // memory size for current array + Value posBack = c0; // index to the last value in the position array + Value memSize = c1; // memory size for current array Level trailCOOStart = getCOOStart(stt.getEncoding()); Level trailCOORank = stt.getLvlRank() - trailCOOStart; @@ -1266,7 +1267,7 @@ DimLevelType dlt = stt.getLvlType(lvl); // Simply forwards the position index when this is a dense level. if (isDenseDLT(dlt)) { - memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, posBack); + memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize); posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1); continue; } @@ -1276,6 +1277,10 @@ if (isCompressedWithHiDLT(dlt)) { memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2); posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1); + } else { + assert(isCompressedDLT(dlt)); + posBack = memSize; + memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1); } desc.setPosMemSize(rewriter, loc, lvl, memSize); // The last value in position array is the memory size for next level. 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --canonicalize --post-sparsification-rewrite="enable-runtime-library=false" --sparse-tensor-codegen -cse | FileCheck %s +// RUN: mlir-opt %s --canonicalize --post-sparsification-rewrite="enable-runtime-library=false" --sparse-tensor-codegen -cse --canonicalize | FileCheck %s #COO = #sparse_tensor.encoding<{ lvlTypes = ["compressed-nu", "singleton"], @@ -9,25 +9,25 @@ // CHECK-SAME: %[[VAL_0:.*]]: tensor<6xf64>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<2xindex>, // CHECK-SAME: %[[VAL_2:.*]]: tensor<6x2xi32>) -// CHECK-DAG: %[[VAL_3:.*]] = bufferization.to_memref %[[VAL_1]] : memref<2xindex> -// CHECK-DAG: %[[VAL_4:.*]] = memref.cast %[[VAL_3]] : memref<2xindex> to memref<?xindex> -// CHECK-DAG: %[[VAL_5:.*]] = bufferization.to_memref %[[VAL_2]] : memref<6x2xi32> -// CHECK-DAG: %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_5]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32> -// CHECK-DAG: %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<12xi32> to memref<?xi32> -// CHECK-DAG: %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<6xf64> -// CHECK-DAG: %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<6xf64> to memref<?xf64> -// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.storage_specifier.init -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 100 : index -// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]] lvl_sz at 0 with %[[VAL_13]] -// CHECK: %[[VAL_15:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]] pos_mem_sz at 0 with %[[VAL_12]] -// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_11]]] : memref<?xindex> -// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_16]], %[[VAL_12]] : index +// CHECK-DAG: 
%[[VAL_3:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 100 : index +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_1]] : memref<2xindex> +// CHECK-DAG: %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<2xindex> to memref<?xindex> +// CHECK-DAG: %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_2]] : memref<6x2xi32> +// CHECK-DAG: %[[VAL_9:.*]] = memref.collapse_shape %[[VAL_8]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32> +// CHECK-DAG: %[[VAL_10:.*]] = memref.cast %[[VAL_9]] : memref<12xi32> to memref<?xi32> +// CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_0]] : memref<6xf64> +// CHECK-DAG: %[[VAL_12:.*]] = memref.cast %[[VAL_11]] : memref<6xf64> to memref<?xf64> +// CHECK: %[[VAL_13:.*]] = sparse_tensor.storage_specifier.init +// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_13]] lvl_sz at 0 with %[[VAL_4]] +// CHECK: %[[VAL_15:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]] pos_mem_sz at 0 with %[[VAL_3]] +// CHECK: %[[VAL_16:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_5]]] : tensor<2xindex> +// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_16]], %[[VAL_3]] : index // CHECK: %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_15]] crd_mem_sz at 0 with %[[VAL_17]] -// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] lvl_sz at 1 with %[[VAL_13]] +// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] lvl_sz at 1 with %[[VAL_4]] // CHECK: %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]] val_mem_sz with %[[VAL_16]] -// CHECK: return %[[VAL_4]], %[[VAL_7]], %[[VAL_9]], %[[VAL_20]] +// CHECK: return %[[VAL_7]], %[[VAL_10]], %[[VAL_12]], %[[VAL_20]] // CHECK: } func.func @sparse_pack(%values: tensor<6xf64>, %pos:tensor<2xindex>, %coordinates: tensor<6x2xi32>) -> tensor<100x100xf64, #COO> { diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir 
b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir @@ -81,6 +81,10 @@ %s5= sparse_tensor.pack %data, %pos32, %index32 : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32> to tensor<10x10xf64, #SortedCOOI32> + %csr_data = arith.constant dense< + [ 1.0, 2.0, 3.0, 4.0] + > : tensor<4xf64> + %csr_pos32 = arith.constant dense< [0, 1, 3] > : tensor<3xi32> @@ -88,7 +92,7 @@ %csr_index32 = arith.constant dense< [1, 0, 1] > : tensor<3xi32> - %csr= sparse_tensor.pack %data, %csr_pos32, %csr_index32 : tensor<3xf64>, tensor<3xi32>, tensor<3xi32> + %csr= sparse_tensor.pack %csr_data, %csr_pos32, %csr_index32 : tensor<4xf64>, tensor<3xi32>, tensor<3xi32> to tensor<2x2xf64, #CSR> %bdata = arith.constant dense< @@ -164,6 +168,16 @@ vector.print %v: f64 } + %d_csr = tensor.empty() : tensor<4xf64> + %p_csr = tensor.empty() : tensor<3xi32> + %i_csr = tensor.empty() : tensor<3xi32> + %rd_csr, %rp_csr, %ri_csr = sparse_tensor.unpack %csr : tensor<2x2xf64, #CSR> + outs(%d_csr, %p_csr, %i_csr : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>) + -> tensor<4xf64>, tensor<3xi32>, tensor<3xi32> + + // CHECK-NEXT: ( 1, 2, 3, {{.*}} ) + %vd_csr = vector.transfer_read %rd_csr[%c0], %f0 : tensor<4xf64>, vector<4xf64> + vector.print %vd_csr : vector<4xf64> // CHECK-NEXT:1 // CHECK-NEXT:2