diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/SparseTensor/Pipelines/Passes.h"
 
 #include "mlir/Conversion/Passes.h"
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
@@ -73,6 +74,8 @@
   pm.addPass(createConvertVectorToLLVMPass(options.lowerVectorToLLVMOptions()));
   pm.addPass(createMemRefToLLVMPass());
   pm.addNestedPass<func::FuncOp>(createConvertComplexToStandardPass());
+  pm.addNestedPass<func::FuncOp>(
+      mlir::arith::createArithmeticExpandOpsPass());
   pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
   pm.addPass(createConvertMathToLibmPass());
   pm.addPass(createConvertComplexToLibmPass());
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir
@@ -43,27 +43,26 @@
 module {
   // Creates a new sparse vector using the minimum values from two input sparse vectors.
   // When there is no overlap, include the present value in the output.
-  func.func @vector_min(%arga: tensor<?xf64, #SparseVector>,
-                        %argb: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
+  func.func @vector_min(%arga: tensor<?xi32, #SparseVector>,
+                        %argb: tensor<?xi32, #SparseVector>) -> tensor<?xi32, #SparseVector> {
     %c = arith.constant 0 : index
-    %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
-    %xv = bufferization.alloc_tensor(%d) : tensor<?xf64, #SparseVector>
+    %d = tensor.dim %arga, %c : tensor<?xi32, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xi32, #SparseVector>
     %0 = linalg.generic #trait_vec_op
-       ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
-        outs(%xv: tensor<?xf64, #SparseVector>) {
-        ^bb(%a: f64, %b: f64, %x: f64):
-          %1 = sparse_tensor.binary %a, %b : f64, f64 to f64
+       ins(%arga, %argb: tensor<?xi32, #SparseVector>, tensor<?xi32, #SparseVector>)
+        outs(%xv: tensor<?xi32, #SparseVector>) {
+        ^bb(%a: i32, %b: i32, %x: i32):
+          %1 = sparse_tensor.binary %a, %b : i32, i32 to i32
             overlap={
-              ^bb0(%a0: f64, %b0: f64):
-                %cmp = arith.cmpf "olt", %a0, %b0 : f64
-                %2 = arith.select %cmp, %a0, %b0: f64
-                sparse_tensor.yield %2 : f64
+              ^bb0(%a0: i32, %b0: i32):
+                %2 = arith.minsi %a0, %b0: i32
+                sparse_tensor.yield %2 : i32
             }
             left=identity
             right=identity
-          linalg.yield %1 : f64
-      } -> tensor<?xf64, #SparseVector>
-    return %0 : tensor<?xf64, #SparseVector>
+          linalg.yield %1 : i32
+      } -> tensor<?xi32, #SparseVector>
+    return %0 : tensor<?xi32, #SparseVector>
   }
 
   // Creates a new sparse vector by multiplying a sparse vector with a dense vector.
@@ -428,8 +427,13 @@
        [0., 1., 2., 3., 4., 5., 6., 7., 8., 9.,
        0., 1., 2., 3., 4., 5., 6., 7., 8., 9.,
        0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 0., 1.]
    > : tensor<32xf64>
+    %v1_si = arith.fptosi %v1 : tensor<32xf64> to tensor<32xi32>
+    %v2_si = arith.fptosi %v2 : tensor<32xf64> to tensor<32xi32>
+
     %sv1 = sparse_tensor.convert %v1 : tensor<32xf64> to tensor<?xf64, #SparseVector>
     %sv2 = sparse_tensor.convert %v2 : tensor<32xf64> to tensor<?xf64, #SparseVector>
+    %sv1_si = sparse_tensor.convert %v1_si : tensor<32xi32> to tensor<?xi32, #SparseVector>
+    %sv2_si = sparse_tensor.convert %v2_si : tensor<32xi32> to tensor<?xi32, #SparseVector>
     %dv3 = tensor.cast %v3 : tensor<32xf64> to tensor<?xf64>
     // Setup sparse matrices.
@@ -459,9 +463,9 @@
     %sm4 = sparse_tensor.convert %m4 : tensor<4x4xf64> to tensor<4x4xf64, #DCSR>
 
     // Call sparse vector kernels.
-    %0 = call @vector_min(%sv1, %sv2)
-       : (tensor<?xf64, #SparseVector>,
-          tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
+    %0 = call @vector_min(%sv1_si, %sv2_si)
+       : (tensor<?xi32, #SparseVector>,
+          tensor<?xi32, #SparseVector>) -> tensor<?xi32, #SparseVector>
     %1 = call @vector_mul(%sv1, %dv3)
       : (tensor<?xf64, #SparseVector>,
         tensor<?xf64>) -> tensor<?xf64, #SparseVector>
@@ -494,7 +498,7 @@
     // CHECK-NEXT: ( 1, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 4, 0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7, 8, 0, 9 )
     // CHECK-NEXT: ( 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, -1, -1, -1, -1, -1, -1 )
     // CHECK-NEXT: ( 0, 11, 0, 12, 13, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 15, 0, 16, 0, 0, 17, 0, 0, 0, 0, 0, 0, 18, 19, 0, 20 )
-    // CHECK-NEXT: ( 1, 11, 2, 13, 14, 3, 15, 4, 16, 5, 6, 7, 8, 9, -1, -1 )
+    // CHECK-NEXT: ( 1, 11, 2, 13, 14, 3, 15, 4, 16, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
     // CHECK-NEXT: ( 1, 11, 0, 2, 13, 0, 0, 0, 0, 0, 14, 3, 0, 0, 0, 0, 15, 4, 16, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7, 8, 0, 9 )
     // CHECK-NEXT: ( 0, 6, 3, 28, 0, 6, 56, 72, 9, -1, -1, -1, -1, -1, -1, -1 )
     // CHECK-NEXT: ( 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 28, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 56, 72, 0, 9 )
@@ -518,7 +522,7 @@
     //
     call @dump_vec(%sv1) : (tensor<?xf64, #SparseVector>) -> ()
     call @dump_vec(%sv2) : (tensor<?xf64, #SparseVector>) -> ()
-    call @dump_vec(%0) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dump_vec_i32(%0) : (tensor<?xi32, #SparseVector>) -> ()
     call @dump_vec(%1) : (tensor<?xf64, #SparseVector>) -> ()
     call @dump_vec(%2) : (tensor<?xf64, #SparseVector>) -> ()
     call @dump_vec_i32(%3) : (tensor<?xi32, #SparseVector>) -> ()
@@ -537,7 +541,7 @@
     bufferization.dealloc_tensor %sm2 : tensor<?x?xf64, #DCSR>
     bufferization.dealloc_tensor %sm3 : tensor<4x4xf64, #DCSR>
     bufferization.dealloc_tensor %sm4 : tensor<4x4xf64, #DCSR>
-    bufferization.dealloc_tensor %0 : tensor<?xf64, #SparseVector>
+    bufferization.dealloc_tensor %0 : tensor<?xi32, #SparseVector>
     bufferization.dealloc_tensor %1 : tensor<?xf64, #SparseVector>
     bufferization.dealloc_tensor %2 : tensor<?xf64, #SparseVector>
     bufferization.dealloc_tensor %3 : tensor<?xi32, #SparseVector>
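
Note: the two changes go together. The test's @vector_min kernel now uses arith.minsi on i32 values (replacing the hand-written arith.cmpf "olt" / arith.select pattern on f64), and the sparse-compiler pipeline gains mlir::arith::createArithmeticExpandOpsPass(), presumably so that integer min/max ops are rewritten into forms the remaining LLVM conversion passes already handle. A minimal MLIR sketch of the assumed expansion, mirroring the compare-and-select idiom the old f64 kernel spelled out by hand (the function name @expanded_min is illustrative and not taken from this patch):

    func.func @expanded_min(%a0: i32, %b0: i32) -> i32 {
      // Assumed result of expanding `arith.minsi %a0, %b0 : i32`:
      // a signed compare followed by a select.
      %cmp = arith.cmpi slt, %a0, %b0 : i32
      %min = arith.select %cmp, %a0, %b0 : i32
      return %min : i32
    }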