diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir @@ -28,9 +28,6 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s -// Interop between linalg/sparse leaves some issues to be revolved: -// UNSUPPORTED: asan - #SM = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> #trait_sampled_dense_dense = { @@ -72,10 +69,7 @@ func @sampled_dd(%args: tensor<8x8xf64, #SM>, %arga: tensor<8x8xf64>, %argb: tensor<8x8xf64>) -> tensor<8x8xf64> { - %d = arith.constant 0.0 : f64 - - %0 = linalg.init_tensor [8, 8] : tensor<8x8xf64> - %1 = linalg.fill(%d, %0) : f64, tensor<8x8xf64> -> tensor<8x8xf64> + %1 = arith.constant dense<0.0> : tensor<8x8xf64> %2 = linalg.generic #trait_sampled_dense_dense ins(%args, %arga, %argb: tensor<8x8xf64, #SM>, tensor<8x8xf64>, tensor<8x8xf64>) @@ -94,11 +88,8 @@ // func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>, %arga: tensor<8x8xf64>, - %argb: tensor<8x8xf64>) -> tensor<8x8xf64> { - %d = arith.constant 0.0 : f64 - - %0 = linalg.init_tensor [8, 8] : tensor<8x8xf64> - %1 = linalg.fill(%d, %0) : f64, tensor<8x8xf64> -> tensor<8x8xf64> + %argb: tensor<8x8xf64>) -> (tensor<8x8xf64>, tensor<8x8xf64>) { + %1 = arith.constant dense<0.0> : tensor<8x8xf64> %2 = linalg.generic #trait_matmul ins(%arga, %argb : tensor<8x8xf64>, tensor<8x8xf64>) outs(%1 : tensor<8x8xf64>) { @@ -108,17 +99,16 @@ linalg.yield %q : f64 } -> tensor<8x8xf64> - %3 = linalg.init_tensor [8, 8] : tensor<8x8xf64> - %4 = linalg.fill(%d, %3) : f64, tensor<8x8xf64> -> tensor<8x8xf64> - %5 = linalg.generic #trait_scale + %3 = arith.constant dense<0.0> : tensor<8x8xf64> + %4 = linalg.generic #trait_scale ins(%2, %args : tensor<8x8xf64>, tensor<8x8xf64, #SM>) - outs(%4 : tensor<8x8xf64>) { + outs(%3 : tensor<8x8xf64>) { ^bb0(%t: f64, %s: f64, %x: f64): %r = arith.mulf %t, %s : f64 linalg.yield %r : f64 } -> tensor<8x8xf64> - return %5 : tensor<8x8xf64> + return %4, %2 : tensor<8x8xf64>, tensor<8x8xf64> } // @@ -140,9 +130,9 @@ %0 = call @sampled_dd(%s, %a, %b) : (tensor<8x8xf64, #SM>, tensor<8x8xf64>, tensor<8x8xf64>) -> tensor<8x8xf64> - %1 = call @sampled_dd_unfused(%s, %a, %b) + %1, %2 = call @sampled_dd_unfused(%s, %a, %b) : (tensor<8x8xf64, #SM>, - tensor<8x8xf64>, tensor<8x8xf64>) -> tensor<8x8xf64> + tensor<8x8xf64>, tensor<8x8xf64>) -> (tensor<8x8xf64>, tensor<8x8xf64>) // Verify the outputs. // @@ -158,6 +148,7 @@ // %m0 = bufferization.to_memref %0 : memref<8x8xf64> %m1 = bufferization.to_memref %1 : memref<8x8xf64> + %m2 = bufferization.to_memref %2 : memref<8x8xf64> %v0 = vector.transfer_read %m0[%c0, %c0], %d0 : memref<8x8xf64>, vector<8x8xf64> %v1 = vector.transfer_read %m1[%c0, %c0], %d0 @@ -169,6 +160,7 @@ sparse_tensor.release %s : tensor<8x8xf64, #SM> memref.dealloc %m0 : memref<8x8xf64> memref.dealloc %m1 : memref<8x8xf64> + memref.dealloc %m2 : memref<8x8xf64> return }