diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
@@ -17,12 +17,11 @@
   //
   // Computes C = A x B with all matrices dense.
   //
-  func.func @matmul1(%A: tensor<4x8xf64>,
-                     %B: tensor<8x4xf64>) -> tensor<4x4xf64> {
-    %C = arith.constant dense<0.0> : tensor<4x4xf64>
+  func.func @matmul1(%A: tensor<4x8xf64>, %B: tensor<8x4xf64>,
+                     %C: tensor<4x4xf64>) -> tensor<4x4xf64> {
     %D = linalg.matmul
       ins(%A, %B: tensor<4x8xf64>, tensor<8x4xf64>)
-         outs(%C: tensor<4x4xf64>) -> tensor<4x4xf64>
+        outs(%C: tensor<4x4xf64>) -> tensor<4x4xf64>
     return %D: tensor<4x4xf64>
   }
 
@@ -30,7 +29,7 @@
   // Computes C = A x B with all matrices sparse (SpMSpM) in CSR.
   //
   func.func @matmul2(%A: tensor<4x8xf64, #CSR>,
-                %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
+                     %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
     %C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>
     %D = linalg.matmul
       ins(%A, %B: tensor<4x8xf64, #CSR>, tensor<8x4xf64, #CSR>)
@@ -42,7 +41,7 @@
   // Computes C = A x B with all matrices sparse (SpMSpM) in DCSR.
   //
   func.func @matmul3(%A: tensor<4x8xf64, #DCSR>,
-                %B: tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> {
+                     %B: tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> {
     %C = bufferization.alloc_tensor() : tensor<4x4xf64, #DCSR>
     %D = linalg.matmul
       ins(%A, %B: tensor<4x8xf64, #DCSR>, tensor<8x4xf64, #DCSR>)
@@ -91,6 +90,7 @@
         [ 0.0, 0.0, 6.0, 0.0 ],
         [ 0.0, 0.0, 7.0, 8.0 ]
     ]> : tensor<8x4xf64>
+    %zero = arith.constant dense<0.0> : tensor<4x4xf64>
 
     // Convert all these matrices to sparse format.
     %a1 = sparse_tensor.convert %da : tensor<4x8xf64> to tensor<4x8xf64, #CSR>
@@ -103,8 +103,8 @@
     %b4 = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #DCSR>
 
     // Call kernels with dense.
-    %0 = call @matmul1(%da, %db)
-       : (tensor<4x8xf64>, tensor<8x4xf64>) -> tensor<4x4xf64>
+    %0 = call @matmul1(%da, %db, %zero)
+       : (tensor<4x8xf64>, tensor<8x4xf64>, tensor<4x4xf64>) -> tensor<4x4xf64>
     %1 = call @matmul2(%a1, %b1)
        : (tensor<4x8xf64, #CSR>,
           tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
@@ -113,8 +113,8 @@
           tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>
 
     // Call kernels with one sparse.
-    %3 = call @matmul1(%sa, %db)
-       : (tensor<4x8xf64>, tensor<8x4xf64>) -> tensor<4x4xf64>
+    %3 = call @matmul1(%sa, %db, %zero)
+       : (tensor<4x8xf64>, tensor<8x4xf64>, tensor<4x4xf64>) -> tensor<4x4xf64>
     %4 = call @matmul2(%a3, %b1)
        : (tensor<4x8xf64, #CSR>,
           tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
@@ -123,8 +123,8 @@
           tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>
 
     // Call kernels with sparse.
-    %6 = call @matmul1(%sa, %sb)
-       : (tensor<4x8xf64>, tensor<8x4xf64>) -> tensor<4x4xf64>
+    %6 = call @matmul1(%sa, %sb, %zero)
+       : (tensor<4x8xf64>, tensor<8x4xf64>, tensor<4x4xf64>) -> tensor<4x4xf64>
     %7 = call @matmul2(%a3, %b3)
        : (tensor<4x8xf64, #CSR>,
           tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
@@ -239,13 +239,6 @@
     sparse_tensor.release %7 : tensor<4x4xf64, #CSR>
     sparse_tensor.release %8 : tensor<4x4xf64, #DCSR>
 
-    // TODO(springerm): needed?
-    %m0 = bufferization.to_memref %0 : memref<4x4xf64>
-    memref.dealloc %m0 : memref<4x4xf64>
-    %m3 = bufferization.to_memref %3 : memref<4x4xf64>
-    memref.dealloc %m3 : memref<4x4xf64>
-    %m6 = bufferization.to_memref %6 : memref<4x4xf64>
-    memref.dealloc %m6 : memref<4x4xf64>
     return
   }
 }