Build the constant pointer buffer for sparse_tensor.pack from the encoding's
pointer type (enc.getPointerType()) instead of fixed 64-bit APInt values, and
add an integration test that packs with pointerBitWidth = 32 and
indexBitWidth = 32.

NOTE(review): this patch was reconstructed from a copy whose line breaks and
explicit template arguments had been lost. `ArrayRef<Attribute>` and
`rewriter.create<bufferization::ToMemrefOp>` were restored, and leading
indentation was re-derived from the surrounding style -- confirm both against
the tree before applying.

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1060,8 +1060,11 @@
             loc, tensorType,
             DenseElementsAttr::get(
                 tensorType,
-                {APInt(64, 0),
-                 APInt(64, op.getData().getType().getShape()[0])}));
+                ArrayRef<Attribute>{
+                    IntegerAttr::get(enc.getPointerType(), 0),
+                    IntegerAttr::get(
+                        enc.getPointerType(),
+                        op.getData().getType().getShape()[0])}));
         field = rewriter.create<bufferization::ToMemrefOp>(loc, memrefType,
                                                            cstPtr);
         break;
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -14,6 +14,12 @@
   dimLevelType = [ "compressed-nu", "singleton" ]
 }>
 
+#SortedCOOI32 = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed-nu", "singleton" ],
+  pointerBitWidth = 32,
+  indexBitWidth = 32
+}>
+
 module {
   //
   // Main driver.
@@ -32,6 +38,12 @@
         [ 7, 8]]
     > : tensor<3x2xindex>
 
+    %index32 = arith.constant dense<
+       [[ 1, 2],
+        [ 5, 6],
+        [ 7, 8]]
+    > : tensor<3x2xi32>
+
     %s4 = sparse_tensor.pack %data, %index : tensor<3xf64>, tensor<3x2xindex>
                                           to tensor<10x10xf64, #SortedCOO>
     // CHECK:1
@@ -51,6 +63,27 @@
         vector.print %2: index
         vector.print %v: f64
      }
+
+    %s5 = sparse_tensor.pack %data, %index32 : tensor<3xf64>, tensor<3x2xi32>
+                                           to tensor<10x10xf64, #SortedCOOI32>
+    // CHECK-NEXT:1
+    // CHECK-NEXT:2
+    // CHECK-NEXT:1
+    //
+    // CHECK-NEXT:5
+    // CHECK-NEXT:6
+    // CHECK-NEXT:2
+    //
+    // CHECK-NEXT:7
+    // CHECK-NEXT:8
+    // CHECK-NEXT:3
+    sparse_tensor.foreach in %s5 : tensor<10x10xf64, #SortedCOOI32> do {
+      ^bb0(%1: index, %2: index, %v: f64) :
+        vector.print %1: index
+        vector.print %2: index
+        vector.print %v: f64
+     }
+
     return
   }
 }