diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
@@ -3,6 +3,8 @@
 // RUN: --linalg-generalize-named-ops --linalg-fuse-elementwise-ops \
 // RUN: --sparsification | FileCheck %s
 
+#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
+
 #DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
 
 // CHECK-LABEL: func @matmul1(
@@ -255,3 +257,66 @@
 outs(%output : tensor<5x6xi64>) -> tensor<5x6xi64>
 return %0: tensor<5x6xi64>
 }
+
+// Dot product of two sparse vectors: the kernel co-iterates both compressed
+// index arrays in one scf.while loop and multiply-adds only on matching indices.
+// CHECK-LABEL: func @sparse_dot(
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0:.*]], %[[VAL_3]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
+// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1:.*]], %[[VAL_3]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
+// CHECK-DAG: %[[VAL_11:.*]] = memref.alloc() : memref<f32>
+// CHECK-DAG: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2:.*]] : memref<f32>
+// CHECK-DAG: memref.copy %[[VAL_12]], %[[VAL_11]] : memref<f32> to memref<f32>
+// CHECK-DAG: %[[VAL_13:.*]] = memref.load %[[VAL_11]][] : memref<f32>
+// CHECK-DAG: %[[VAL_14:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK-DAG: %[[VAL_15:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK-DAG: %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK-DAG: %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_18:.*]]:3 = scf.while (%[[VAL_19:.*]] = %[[VAL_14]], %[[VAL_20:.*]] = %[[VAL_16]], %[[VAL_21:.*]] = %[[VAL_13]]) : (index, index, f32) -> (index, index, f32) {
+// CHECK: %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_19]], %[[VAL_15]] : index
+// CHECK: %[[VAL_23:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_17]] : index
+// CHECK: %[[VAL_24:.*]] = arith.andi %[[VAL_22]], %[[VAL_23]] : i1
+// CHECK: scf.condition(%[[VAL_24]]) %[[VAL_19]], %[[VAL_20]], %[[VAL_21]] : index, index, f32
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_25:.*]]: index, %[[VAL_26:.*]]: index, %[[VAL_27:.*]]: f32):
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_25]]] : memref<?xindex>
+// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_26]]] : memref<?xindex>
+// CHECK: %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_28]] : index
+// CHECK: %[[VAL_31:.*]] = arith.select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index
+// CHECK: %[[VAL_32:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
+// CHECK: %[[VAL_33:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
+// CHECK: %[[VAL_34:.*]] = arith.andi %[[VAL_32]], %[[VAL_33]] : i1
+// CHECK: %[[VAL_35:.*]] = scf.if %[[VAL_34]] -> (f32) {
+// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_25]]] : memref<?xf32>
+// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_26]]] : memref<?xf32>
+// CHECK: %[[VAL_38:.*]] = arith.mulf %[[VAL_36]], %[[VAL_37]] : f32
+// CHECK: %[[VAL_39:.*]] = arith.addf %[[VAL_27]], %[[VAL_38]] : f32
+// CHECK: scf.yield %[[VAL_39]] : f32
+// CHECK: } else {
+// CHECK: scf.yield %[[VAL_27]] : f32
+// CHECK: }
+// CHECK: %[[VAL_40:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
+// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_25]], %[[VAL_4]] : index
+// CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_40]], %[[VAL_41]], %[[VAL_25]] : index
+// CHECK: %[[VAL_43:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
+// CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_26]], %[[VAL_4]] : index
+// CHECK: %[[VAL_45:.*]] = arith.select %[[VAL_43]], %[[VAL_44]], %[[VAL_26]] : index
+// CHECK: scf.yield %[[VAL_42]], %[[VAL_45]], %[[VAL_46:.*]] : index, index, f32
+// CHECK: }
+// CHECK: memref.store %[[VAL_47:.*]]#2, %[[VAL_11]][] : memref<f32>
+// CHECK: %[[VAL_48:.*]] = bufferization.to_tensor %[[VAL_11]] : memref<f32>
+// CHECK: return %[[VAL_48]] : tensor<f32>
+// CHECK: }
+func @sparse_dot(%a: tensor<1024xf32, #SparseVector>,
+                 %b: tensor<1024xf32, #SparseVector>,
+                 %x: tensor<f32>) -> tensor<f32> {
+  %dot = linalg.dot ins(%a, %b: tensor<1024xf32, #SparseVector>,
+                                tensor<1024xf32, #SparseVector>)
+                    outs(%x: tensor<f32>) -> tensor<f32>
+  return %dot : tensor<f32>
+}