diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir
@@ -119,12 +119,31 @@
     return %r : tensor<3x4xi64, #SparseMatrix>
   }
 
+  func @add_outer_2d(%arg0: tensor<2x3xf32, #SparseMatrix>)
+      -> tensor<2x3xf32, #SparseMatrix> {
+    %c2 = arith.constant 2 : index
+    %c3 = arith.constant 3 : index
+    %0 = sparse_tensor.init[%c2, %c3] : tensor<2x3xf32, #SparseMatrix>
+    %1 = linalg.generic #trait_2d
+        ins(%arg0 : tensor<2x3xf32, #SparseMatrix>)
+        outs(%0 : tensor<2x3xf32, #SparseMatrix>) {
+      ^bb0(%arg1: f32, %arg2: f32):
+        %2 = linalg.index 0 : index
+        %3 = arith.index_cast %2 : index to i64
+        %4 = arith.uitofp %3 : i64 to f32
+        %5 = arith.addf %arg1, %4 : f32
+        linalg.yield %5 : f32
+    } -> tensor<2x3xf32, #SparseMatrix>
+    return %1 : tensor<2x3xf32, #SparseMatrix>
+  }
+
   //
   // Main driver.
   //
   func @entry() {
     %c0 = arith.constant 0 : index
     %du = arith.constant -1 : i64
+    %df = arith.constant -1.0 : f32
 
     // Setup input sparse vector.
     %v1 = arith.constant sparse<[[2], [4]], [ 10, 20]> : tensor<8xi64>
@@ -144,6 +163,10 @@
                                      [ 1, 1, 3, 4 ] ]> : tensor<3x4xi64>
     %dm = sparse_tensor.convert %m2 : tensor<3x4xi64> to tensor<3x4xi64, #SparseMatrix>
 
+    // Setup input sparse f32 matrix.
+    %mf32 = arith.constant sparse<[[0,1], [1,2]], [10.0, 41.0]> : tensor<2x3xf32>
+    %sf32 = sparse_tensor.convert %mf32 : tensor<2x3xf32> to tensor<2x3xf32, #SparseMatrix>
+
     // Call the kernels.
     %0 = call @sparse_index_1d_conj(%sv) : (tensor<8xi64, #SparseVector>)
                                         -> tensor<8xi64, #SparseVector>
@@ -213,6 +236,19 @@
     sparse_tensor.release %6 : tensor<3x4xi64, #SparseMatrix>
     sparse_tensor.release %7 : tensor<3x4xi64, #SparseMatrix>
 
+    //
+    // Call the f32 kernel, verify the result, release the resources.
+    //
+    // CHECK-NEXT: ( 0, 10, 0, 1, 1, 42 )
+    //
+    %100 = call @add_outer_2d(%sf32) : (tensor<2x3xf32, #SparseMatrix>)
+                                    -> tensor<2x3xf32, #SparseMatrix>
+    %101 = sparse_tensor.values %100 : tensor<2x3xf32, #SparseMatrix> to memref<?xf32>
+    %102 = vector.transfer_read %101[%c0], %df: memref<?xf32>, vector<6xf32>
+    vector.print %102 : vector<6xf32>
+    sparse_tensor.release %sf32 : tensor<2x3xf32, #SparseMatrix>
+    sparse_tensor.release %100 : tensor<2x3xf32, #SparseMatrix>
+
     return
   }
 }
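
Note: the new @add_outer_2d kernel reuses the #trait_2d indexing trait defined near the top of this test file, outside the hunks shown above. As a rough sketch only, assuming identity indexing maps over two parallel dimensions (the actual trait in the file may differ), such a trait looks like:

    // Illustrative sketch, not taken from this patch: a 2-D trait with
    // identity maps for the input and the output, both dimensions parallel.
    #trait_2d = {
      indexing_maps = [
        affine_map<(i,j) -> (i,j)>,  // A (input)
        affine_map<(i,j) -> (i,j)>   // X (output)
      ],
      iterator_types = ["parallel", "parallel"]
    }

The kernel adds the row index (cast from index to i64, then converted to f32) to each stored entry, which is why the stored values 10.0 and 41.0 print back as 10 and 42 in the CHECK line.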