diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1411,7 +1411,8 @@
           DimLevelType dlt, bool /*unused*/) {
         assert(ldx == env.merger().loop(b));
         Value clause;
-        if (isCompressedDLT(dlt) || isSingletonDLT(dlt) || isCompressedWithHiDLT(dlt)) {
+        if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
+            isCompressedWithHiDLT(dlt)) {
           assert(lvl.has_value());
           const Value crd = env.emitter().getCoords()[tid][*lvl];
           const Value lvar = env.getLoopVar(ldx);
@@ -1487,7 +1488,8 @@
     assert(env.merger().loop(b) == idx);
     if (isDenseDLT(dlt) || isUndefDLT(dlt))
       needsUniv = true;
-    if (isCompressedDLT(dlt) || isSingletonDLT(dlt) || isIdxReduc) {
+    if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
+        isCompressedWithHiDLT(dlt) || isIdxReduc) {
      // Only when this is a index reduction loop, can the dlt be undefined.
      assert(!isUndefDLT(dlt) || isIdxReduc);
      // sparse/singleton levels, or a dense/sparse index reduction loop.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -599,6 +599,110 @@
   return %0 : tensor<32x16xf32>
 }
 
+#BatchedVector = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed-hi" ],
+}>
+// CHECK-LABEL: func.func @sub_ss_batched(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>) -> tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_5:.*]] = bufferization.alloc_tensor() : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>
+// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>
+// CHECK: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] iter_args(%[[VAL_14:.*]] = %[[VAL_5]]) -> (tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK: %[[VAL_15:.*]] = arith.muli %[[VAL_13]], %[[VAL_2]] : index
+// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_4]] : index
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_17]]] : memref<?xindex>
+// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_13]], %[[VAL_2]] : index
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_4]] : index
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref<?xindex>
+// CHECK: %[[VAL_23:.*]]:3 = scf.while (%[[VAL_24:.*]] = %[[VAL_16]], %[[VAL_25:.*]] = %[[VAL_20]], %[[VAL_26:.*]] = %[[VAL_14]])
+// CHECK: %[[VAL_27:.*]] = arith.cmpi ult, %[[VAL_24]], %[[VAL_18]] : index
+// CHECK: %[[VAL_28:.*]] = arith.cmpi ult, %[[VAL_25]], %[[VAL_22]] : index
+// CHECK: %[[VAL_29:.*]] = arith.andi %[[VAL_27]], %[[VAL_28]] : i1
+// CHECK: scf.condition(%[[VAL_29]]) %[[VAL_24]], %[[VAL_25]], %[[VAL_26]] : index, index, tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_30:.*]]: index, %[[VAL_31:.*]]: index, %[[VAL_32:.*]]: tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>):
+// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_30]]] : memref<?xindex>
+// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_31]]] : memref<?xindex>
+// CHECK: %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_33]] : index
+// CHECK: %[[VAL_36:.*]] = arith.select %[[VAL_35]], %[[VAL_34]], %[[VAL_33]] : index
+// CHECK: %[[VAL_37:.*]] = arith.cmpi eq, %[[VAL_33]], %[[VAL_36]] : index
+// CHECK: %[[VAL_38:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_36]] : index
+// CHECK: %[[VAL_39:.*]] = arith.andi %[[VAL_37]], %[[VAL_38]] : i1
+// CHECK: %[[VAL_40:.*]] = scf.if %[[VAL_39]] -> (tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK: %[[VAL_41:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_30]]] : memref<?xf64>
+// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_31]]] : memref<?xf64>
+// CHECK: %[[VAL_43:.*]] = arith.subf %[[VAL_41]], %[[VAL_42]] : f64
+// CHECK: %[[VAL_44:.*]] = sparse_tensor.insert %[[VAL_43]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: scf.yield %[[VAL_44]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: } else {
+// CHECK: %[[VAL_45:.*]] = arith.cmpi eq, %[[VAL_33]], %[[VAL_36]] : index
+// CHECK: %[[VAL_46:.*]] = scf.if %[[VAL_45]] -> (tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_30]]] : memref<?xf64>
+// CHECK: %[[VAL_48:.*]] = sparse_tensor.insert %[[VAL_47]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: scf.yield %[[VAL_48]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: } else {
+// CHECK: %[[VAL_49:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_36]] : index
+// CHECK: %[[VAL_50:.*]] = scf.if %[[VAL_49]] -> (tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_31]]] : memref<?xf64>
+// CHECK: %[[VAL_52:.*]] = arith.negf %[[VAL_51]] : f64
+// CHECK: %[[VAL_53:.*]] = sparse_tensor.insert %[[VAL_52]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: scf.yield %[[VAL_53]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: } else {
+// CHECK: scf.yield %[[VAL_32]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: }
+// CHECK: scf.yield %[[VAL_54:.*]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: }
+// CHECK: scf.yield %[[VAL_55:.*]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: }
+// CHECK: %[[VAL_56:.*]] = arith.cmpi eq, %[[VAL_33]], %[[VAL_36]] : index
+// CHECK: %[[VAL_57:.*]] = arith.addi %[[VAL_30]], %[[VAL_4]] : index
+// CHECK: %[[VAL_58:.*]] = arith.select %[[VAL_56]], %[[VAL_57]], %[[VAL_30]] : index
+// CHECK: %[[VAL_59:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_36]] : index
+// CHECK: %[[VAL_60:.*]] = arith.addi %[[VAL_31]], %[[VAL_4]] : index
+// CHECK: %[[VAL_61:.*]] = arith.select %[[VAL_59]], %[[VAL_60]], %[[VAL_31]] : index
+// CHECK: scf.yield %[[VAL_58]], %[[VAL_61]], %[[VAL_62:.*]] : index, index, tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: } attributes {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_63:.*]] = scf.for %[[VAL_64:.*]] = %[[VAL_3]] to %[[VAL_18]] step %[[VAL_4]] iter_args(%[[VAL_65:.*]] = %[[VAL_66:.*]]#2)
+// CHECK: %[[VAL_67:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_64]]] : memref<?xindex>
+// CHECK: %[[VAL_68:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_64]]] : memref<?xf64>
+// CHECK: %[[VAL_69:.*]] = sparse_tensor.insert %[[VAL_68]] into %[[VAL_65]]{{\[}}%[[VAL_13]], %[[VAL_67]]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: scf.yield %[[VAL_69]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_70:.*]] = scf.for %[[VAL_71:.*]] = %[[VAL_3]] to %[[VAL_22]] step %[[VAL_4]] iter_args(%[[VAL_72:.*]] = %[[VAL_73:.*]])
+// CHECK: %[[VAL_74:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_71]]] : memref<?xindex>
+// CHECK: %[[VAL_75:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_71]]] : memref<?xf64>
+// CHECK: %[[VAL_76:.*]] = arith.negf %[[VAL_75]] : f64
+// CHECK: %[[VAL_77:.*]] = sparse_tensor.insert %[[VAL_76]] into %[[VAL_72]]{{\[}}%[[VAL_13]], %[[VAL_74]]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: scf.yield %[[VAL_77]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: scf.yield %[[VAL_78:.*]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_79:.*]] = sparse_tensor.load %[[VAL_80:.*]] hasInserts : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: return %[[VAL_79]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: }
+func.func @sub_ss_batched(%0: tensor<2x3xf64, #BatchedVector>, %1: tensor<2x3xf64, #BatchedVector>)
+  -> tensor<2x3xf64, #BatchedVector> {
+  %2 = bufferization.alloc_tensor() : tensor<2x3xf64, #BatchedVector>
+  %3 = linalg.generic #trait2
+    ins(%0, %1 : tensor<2x3xf64, #BatchedVector>, tensor<2x3xf64, #BatchedVector>)
+    outs(%2 : tensor<2x3xf64, #BatchedVector>) {
+    ^bb0(%in: f64, %in_0: f64, %out: f64):
+      %7 = arith.subf %in, %in_0 : f64
+      linalg.yield %7 : f64
+  } -> tensor<2x3xf64, #BatchedVector>
+  return %3 : tensor<2x3xf64, #BatchedVector>
+}
+
 // CHECK-LABEL: func @mul_ss_ss(
 // CHECK-SAME: %[[VAL_0:.*0]]: tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>,
 // CHECK-SAME: %[[VAL_1:.*1]]: tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>,
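
For context, a minimal standalone sketch of a kernel over the same "compressed-hi" batched-vector layout, not part of the patch: the names #BV, #trait_scale, and @scale_batched are illustrative only, and the invocation assumes the -sparsification pass flag used by the existing sparse tests. A "compressed-hi" level keeps an explicit [lo, hi) position pair per batch, which is why the generated code above multiplies the batch index by 2 and loads positions[2*i] and positions[2*i+1] to bound each inner loop.

// Hypothetical reduced example (assumed driver: mlir-opt example.mlir -sparsification).
#BV = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed-hi" ]
}>

#trait_scale = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,  // A (in)
    affine_map<(i,j) -> (i,j)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel"]
}

// Doubles every stored entry of a batched sparse vector; after sparsification,
// batch i scans its entries between positions[2*i] and positions[2*i+1].
func.func @scale_batched(%arga: tensor<2x3xf64, #BV>) -> tensor<2x3xf64, #BV> {
  %f2 = arith.constant 2.0 : f64
  %init = bufferization.alloc_tensor() : tensor<2x3xf64, #BV>
  %0 = linalg.generic #trait_scale
    ins(%arga : tensor<2x3xf64, #BV>)
    outs(%init : tensor<2x3xf64, #BV>) {
    ^bb0(%a: f64, %x: f64):
      %1 = arith.mulf %a, %f2 : f64
      linalg.yield %1 : f64
  } -> tensor<2x3xf64, #BV>
  return %0 : tensor<2x3xf64, #BV>
}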