diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -54,8 +54,8 @@ // Multiply a sparse matrix A with a dense vector b into a dense vector x. func.func @kernel_matvec(%arga: tensor<?x?xf64, #SparseMatrix>, - %argb: tensor<?xf64>, - %argx: tensor<?xf64>) -> tensor<?xf64> { + %argb: tensor<?xf64>, + %argx: tensor<?xf64>) -> tensor<?xf64> { %0 = linalg.generic #matvec ins(%arga, %argb: tensor<?x?xf64, #SparseMatrix>, tensor<?xf64>) outs(%argx: tensor<?xf64>) { @@ -142,17 +142,19 @@ ```mlir Before: - %c1 = arith.constant 1 : index - %0 = sparse_tensor.pointers %arg0, %c1 - : tensor<8x8xf32, #sparse_tensor.encoding<{ - dimLevelType = [ "dense", "compressed" ], - pointerBitWidth = 0, - indexBitWidth = 0 - }>> to memref<?xindex> + func.func @foo(%arg0: tensor<8x8xf32, #CSR>) -> memref<?xindex> { + %0 = sparse_tensor.pointers %arg0 {dimension = 1 : index} + : tensor<8x8xf32, #CSR> to memref<?xindex> + return %0 : memref<?xindex> + } After: - %c1 = arith.constant 1 : index - %0 = call @sparsePointers(%arg0, %c1) : (!llvm.ptr<i8>, index) -> memref<?xindex> + func.func @foo(%arg0: !llvm.ptr<i8>) -> memref<?xindex> { + %c1 = arith.constant 1 : index + %0 = call @sparsePointers0(%arg0, %c1) + : (!llvm.ptr<i8>, index) -> memref<?xindex> + return %0 : memref<?xindex> + } ``` }]; let constructor = "mlir::createSparseTensorConversionPass()"; @@ -186,7 +188,21 @@ Example of the conversion: ```mlir - TBD + Before: + func.func @foo(%arg0: tensor<8x8xf32, #CSR>) -> memref<?xindex> { + %0 = sparse_tensor.pointers %arg0 {dimension = 1 : index} + : tensor<8x8xf32, #CSR> to memref<?xindex> + return %0 : memref<?xindex> + } + + After: + func.func @foo(%arg0: memref<2xindex>, + %arg1: memref<3xindex>, + %arg2: memref<?xindex>, + %arg3: memref<?xindex>, + %arg4: memref<?xf32>) -> memref<?xindex> { + return %arg2 : memref<?xindex> + } ``` }]; let constructor = "mlir::createSparseTensorCodegenPass()";