diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -117,6 +117,9 @@
   void updateReduc(Value val);
   Value getReduc() const { return redVal; }
   Value endReduc();
+  void setValidLexInsert(Value val);
+  void clearValidLexInsert();
+  Value getValidLexInsert() const { return redValidLexInsert; }
 
   void startCustomReduc(unsigned exp);
   bool isCustomReduc() const { return redCustom != -1u; }
@@ -156,6 +159,11 @@
   Value redVal;
   unsigned redExp;
   unsigned redCustom;
+
+  // Bookkeeping for lex insertion during reductions. Holds the runtime boolean
+  // value of whether any reduction occurred. This is only set during a
+  // reduction and cleared once the reduction is finished.
+  Value redValidLexInsert;
 };
 
 } // namespace sparse_tensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -23,7 +23,7 @@
       latticeMerger(numTensors, numLoops, numFilterLoops), loopEmitter(),
       topSort(), sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
       expFilled(), expAdded(), expCount(), redVal(), redExp(-1u),
-      redCustom(-1u) {}
+      redCustom(-1u), redValidLexInsert() {}
 
 void CodegenEnv::startEmit(OpOperand *so, unsigned lv) {
   assert(sparseOut == nullptr && insChain == nullptr &&
@@ -49,16 +49,24 @@
     function_ref<std::optional<Operation *>(MutableArrayRef<Value> parameters)>
        callback) {
   SmallVector<Value> params;
-  if (isReduc())
+  if (isReduc()) {
     params.push_back(redVal);
+    if (redValidLexInsert)
+      params.push_back(redValidLexInsert);
+  } else {
+    assert(!redValidLexInsert);
+  }
   if (isExpand())
     params.push_back(expCount);
   if (insChain != nullptr)
     params.push_back(insChain);
   auto r = callback(params); // may update parameters
   unsigned i = 0;
-  if (isReduc())
+  if (isReduc()) {
     updateReduc(params[i++]);
+    if (redValidLexInsert)
+      setValidLexInsert(params[i++]);
+  }
   if (isExpand())
     updateExpandCount(params[i++]);
   if (insChain != nullptr)
@@ -139,6 +147,16 @@
   return val;
 }
 
+void CodegenEnv::setValidLexInsert(Value val) {
+  assert(isReduc() && val);
+  redValidLexInsert = val;
+}
+
+void CodegenEnv::clearValidLexInsert() {
+  assert(!isReduc());
+  redValidLexInsert = Value();
+}
+
 void CodegenEnv::startCustomReduc(unsigned exp) {
   assert(redCustom == -1u && exp != -1u);
   redCustom = exp;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -726,8 +726,31 @@
       indices.push_back(env.emitter().getLoopIV(i));
     }
     Value chain = env.getInsertionChain();
-    env.updateInsertionChain(
-        builder.create<InsertOp>(loc, rhs, chain, indices));
+    if (!env.getValidLexInsert()) {
+      env.updateInsertionChain(
+          builder.create<InsertOp>(loc, rhs, chain, indices));
+    } else {
+      // Generates runtime check for a valid lex during reduction,
+      // to avoid inserting the identity value for empty reductions.
+      // if (validLexInsert) then
+      //   insert(rhs) into chain
+      //   return updated chain
+      // else
+      //   return unmodified chain
+      scf::IfOp ifValidLexInsert = builder.create<scf::IfOp>(
+          loc, chain.getType(), env.getValidLexInsert(),
+          /*else=*/true);
+      // True branch.
+      builder.setInsertionPointToStart(ifValidLexInsert.thenBlock());
+      Value res = builder.create<InsertOp>(loc, rhs, chain, indices);
+      builder.create<scf::YieldOp>(loc, res);
+      // False branch.
+      builder.setInsertionPointToStart(ifValidLexInsert.elseBlock());
+      builder.create<scf::YieldOp>(loc, chain);
+      // Value assignment.
+      builder.setInsertionPointAfter(ifValidLexInsert);
+      env.updateInsertionChain(ifValidLexInsert.getResult(0));
+    }
     return;
   }
   // Generates insertion code along expanded access pattern.
@@ -932,13 +955,16 @@
     return;
   OpOperand *lhs = op.getDpsInitOperand(0);
   if (lhs == &t) {
-    // Start or end a scalarized reduction
+    // Start or end a scalarized reduction.
     if (atStart) {
       Value load = env.isCustomReduc() ? env.getCustomRedId()
                                        : genTensorLoad(env, builder, exp);
       env.startReduc(exp, load);
+      if (env.hasSparseOutput())
+        env.setValidLexInsert(constantI1(builder, env.op().getLoc(), false));
     } else {
       genTensorStore(env, builder, exp, env.endReduc());
+      env.clearValidLexInsert();
     }
   } else {
     // Start or end loop invariant hoisting of a tensor load.
@@ -1106,6 +1132,10 @@
   if (env.isReduc()) {
     yields.push_back(env.getReduc());
     env.updateReduc(ifOp.getResult(y++));
+    if (env.getValidLexInsert()) {
+      yields.push_back(env.getValidLexInsert());
+      env.setValidLexInsert(ifOp.getResult(y++));
+    }
   }
   if (env.isExpand()) {
     yields.push_back(env.getExpandCount());
@@ -1148,8 +1178,11 @@
     }
     cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
   }
-  if (env.isReduc())
+  if (env.isReduc()) {
     types.push_back(env.getReduc().getType());
+    if (env.getValidLexInsert())
+      types.push_back(env.getValidLexInsert().getType());
+  }
   if (env.isExpand())
     types.push_back(builder.getIndexType());
   if (env.getInsertionChain())
@@ -1167,6 +1200,9 @@
   if (env.isReduc()) {
     operands.push_back(env.getReduc());
     env.updateReduc(redInput);
+    if (env.getValidLexInsert())
+      // Any overlapping indices during a reduction create a valid lex insert.
+      operands.push_back(constantI1(builder, env.op().getLoc(), true));
   }
   if (env.isExpand()) {
     operands.push_back(env.getExpandCount());
@@ -1392,6 +1428,10 @@
   // End a while-loop.
   if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
     finalizeWhileOp(env, rewriter, idx, needsUniv, env.lat(li).bits, whileOp);
+  } else if (auto forOp = dyn_cast<scf::ForOp>(loop)) {
+    // Any iteration of a reduction for-loop creates a valid lex insert.
+    if (env.isReduc() && env.getValidLexInsert())
+      env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
   } else {
     needsUniv = false;
   }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
@@ -306,12 +306,14 @@
 // CHECK-LABEL: func.func @mul_affine_sparse2d(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME: %[[VAL_1:.*]]: tensor<34x19xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> {
-// CHECK: %[[VAL_2:.*]] = arith.constant 32 : index
-// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_5:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK: %[[VAL_7:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 32 : index
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[VAL_TRUE:.*]] = arith.constant true
+// CHECK-DAG: %[[VAL_FALSE:.*]] = arith.constant false
 // CHECK: %[[VAL_8:.*]] = bufferization.alloc_tensor() : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref
 // CHECK: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref
@@ -330,22 +332,27 @@
 // CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_18]]] : memref
 // CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_18]], %[[VAL_4]] : index
 // CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_28]]] : memref
-// CHECK: %[[VAL_30:.*]]:2 = scf.for %[[VAL_31:.*]] = %[[VAL_27]] to %[[VAL_29]] step %[[VAL_4]] iter_args(%[[VAL_32:.*]] = %[[VAL_6]], %[[VAL_33:.*]] = %[[VAL_24]]) -> (f64, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
+// CHECK: %[[VAL_30:.*]]:3 = scf.for %[[VAL_31:.*]] = %[[VAL_27]] to %[[VAL_29]] step %[[VAL_4]] iter_args(%[[VAL_32:.*]] = %[[VAL_6]], %[[VAL_200:.*]] = %[[VAL_FALSE]], %[[VAL_33:.*]] = %[[VAL_24]]) -> (f64, i1, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
 // CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_31]]] : memref
 // CHECK: %[[VAL_35:.*]] = arith.addi %[[VAL_25]], %[[VAL_7]] : index
 // CHECK: %[[VAL_36:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_35]] : index
-// CHECK: %[[VAL_37:.*]]:2 = scf.if %[[VAL_36]] -> (f64, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
+// CHECK: %[[VAL_37:.*]]:3 = scf.if %[[VAL_36]] -> (f64, i1, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
 // CHECK: %[[VAL_38:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_31]]] : memref
 // CHECK: %[[VAL_39:.*]] = arith.mulf %[[VAL_26]], %[[VAL_38]] : f64
 // CHECK: %[[VAL_40:.*]] = arith.addf %[[VAL_32]], %[[VAL_39]] : f64
-// CHECK: scf.yield %[[VAL_40]], %[[VAL_33]] : f64, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: scf.yield %[[VAL_40]], %[[VAL_TRUE]], %[[VAL_33]] : f64, i1, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK: } else {
-// CHECK: scf.yield %[[VAL_32]], %[[VAL_33]] : f64, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: scf.yield %[[VAL_32]], %[[VAL_200]], %[[VAL_33]] : f64, i1, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK: }
-// CHECK: scf.yield %[[VAL_41:.*]]#0, %[[VAL_41]]#1 : f64, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: scf.yield %[[VAL_41:.*]]#0, %[[VAL_41]]#1, %[[VAL_41]]#2 : f64, i1, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: }
+// CHECK: %[[VAL_201:.*]] = scf.if %[[VAL_30]]#1 -> (tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
+// CHECK: %[[VAL_42:.*]] = sparse_tensor.insert %[[VAL_30]]#0 into %[[VAL_30]]#2{{\[}}%[[VAL_16]], %[[VAL_25]]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: scf.yield %[[VAL_42]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: } else {
+// CHECK: scf.yield %[[VAL_30]]#2 : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK: }
-// CHECK: %[[VAL_42:.*]] = sparse_tensor.insert %[[VAL_43:.*]]#0 into %[[VAL_43]]#1{{\[}}%[[VAL_16]], %[[VAL_25]]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: scf.yield %[[VAL_42]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: scf.yield %[[VAL_201]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK: }
 // CHECK: scf.yield %[[VAL_44:.*]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK: }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -153,6 +153,8 @@
 // CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[VAL_FALSE:.*]] = arith.constant false
+// CHECK-DAG: %[[VAL_TRUE:.*]] = arith.constant true
 // CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor>
 // CHECK: %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor>
 // CHECK: %[[VAL_7:.*]] = bufferization.alloc_tensor(%[[VAL_5]], %[[VAL_6]]) : tensor>
@@ -216,13 +218,13 @@
 // CHECK: %[[VAL_71:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_58]]] : memref
 // CHECK: %[[VAL_72:.*]] = arith.addi %[[VAL_58]], %[[VAL_3]] : index
 // CHECK: %[[VAL_73:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_72]]] : memref
-// CHECK: %[[VAL_74:.*]]:4 = scf.while (%[[VAL_75:.*]] = %[[VAL_68]], %[[VAL_76:.*]] = %[[VAL_71]], %[[VAL_77:.*]] = %[[VAL_4]], %[[VAL_78:.*]] = %[[VAL_59]]) : (index, index, i32, tensor>) -> (index, index, i32, tensor>) {
+// CHECK: %[[VAL_74:.*]]:5 = scf.while (%[[VAL_75:.*]] = %[[VAL_68]], %[[VAL_76:.*]] = %[[VAL_71]], %[[VAL_77:.*]] = %[[VAL_4]], %[[VAL_200:.*]] = %[[VAL_FALSE]], %[[VAL_78:.*]] = %[[VAL_59]]) : (index, index, i32, i1, tensor>) -> (index, index, i32, i1, tensor>) {
 // CHECK: %[[VAL_79:.*]] = arith.cmpi ult, %[[VAL_75]], %[[VAL_70]] : index
 // CHECK: %[[VAL_80:.*]] = arith.cmpi ult, %[[VAL_76]], %[[VAL_73]] : index
 // CHECK: %[[VAL_81:.*]] = arith.andi %[[VAL_79]], %[[VAL_80]] : i1
-// CHECK: scf.condition(%[[VAL_81]]) %[[VAL_75]], %[[VAL_76]], %[[VAL_77]], %[[VAL_78]] : index, index, i32, tensor>
+// CHECK: scf.condition(%[[VAL_81]]) %[[VAL_75]], %[[VAL_76]], %[[VAL_77]], %[[VAL_200]], %[[VAL_78]] : index, index, i32, i1, tensor>
 // CHECK: } do {
-// CHECK: ^bb0(%[[VAL_82:.*]]: index, %[[VAL_83:.*]]: index, %[[VAL_84:.*]]: i32, %[[VAL_85:.*]]: tensor>):
+// CHECK: ^bb0(%[[VAL_82:.*]]: index, %[[VAL_83:.*]]: index, %[[VAL_84:.*]]: i32, %[[VAL_201:.*]]: i1, %[[VAL_85:.*]]: tensor>):
 // CHECK: %[[VAL_86:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_82]]] : memref
 // CHECK: %[[VAL_87:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_83]]] : memref
 // CHECK: %[[VAL_88:.*]] = arith.cmpi ult, %[[VAL_87]], %[[VAL_86]] : index
@@ -230,14 +232,14 @@
 // CHECK: %[[VAL_90:.*]] = arith.cmpi eq, %[[VAL_86]], %[[VAL_89]] : index
 // CHECK: %[[VAL_91:.*]] = arith.cmpi eq, %[[VAL_87]], %[[VAL_89]] : index
 // CHECK: %[[VAL_92:.*]] = arith.andi %[[VAL_90]], %[[VAL_91]] : i1
-// CHECK: %[[VAL_93:.*]]:2 = scf.if %[[VAL_92]] -> (i32, tensor>) {
+// CHECK: %[[VAL_93:.*]]:3 = scf.if %[[VAL_92]] -> (i32, i1, tensor>) {
 // CHECK: %[[VAL_94:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_82]]] : memref
 // CHECK: %[[VAL_95:.*]] = memref.load %[[VAL_21]]{{\[}}%[[VAL_83]]] : memref
 // CHECK: %[[VAL_96:.*]] = arith.muli %[[VAL_94]], %[[VAL_95]] : i32
 // CHECK: %[[VAL_97:.*]] = arith.addi %[[VAL_84]], %[[VAL_96]] : i32
-// CHECK: scf.yield %[[VAL_97]], %[[VAL_85]] : i32, tensor>
+// CHECK: scf.yield %[[VAL_97]], %[[VAL_TRUE]], %[[VAL_85]] : i32, i1, tensor>
 // CHECK: } else {
-// CHECK: scf.yield %[[VAL_84]], %[[VAL_85]] : i32, tensor>
+// CHECK: scf.yield %[[VAL_84]], %[[VAL_201]], %[[VAL_85]] : i32, i1, tensor>
 // CHECK: }
 // CHECK: %[[VAL_98:.*]] = arith.cmpi eq, %[[VAL_86]], %[[VAL_89]] : index
 // CHECK: %[[VAL_99:.*]] = arith.addi %[[VAL_82]], %[[VAL_3]] : index
@@ -245,10 +247,15 @@
 // CHECK: %[[VAL_101:.*]] = arith.cmpi eq, %[[VAL_87]], %[[VAL_89]] : index
 // CHECK: %[[VAL_102:.*]] = arith.addi %[[VAL_83]], %[[VAL_3]] : index
 // CHECK: %[[VAL_103:.*]] = arith.select %[[VAL_101]], %[[VAL_102]], %[[VAL_83]] : index
-// CHECK: scf.yield %[[VAL_100]], %[[VAL_103]], %[[VAL_104:.*]]#0, %[[VAL_104]]#1 : index, index, i32, tensor>
+// CHECK: scf.yield %[[VAL_100]], %[[VAL_103]], %[[VAL_104:.*]]#0, %[[VAL_104]]#1, %[[VAL_104]]#2 : index, index, i32, i1, tensor>
 // CHECK: }
-// CHECK: %[[VAL_105:.*]] = sparse_tensor.insert %[[VAL_106:.*]]#2 into %[[VAL_106]]#3{{\[}}%[[VAL_39]], %[[VAL_63]]] : tensor>
-// CHECK: scf.yield %[[VAL_105]] : tensor>
+// CHECK: %[[VAL_202:.*]] = scf.if %[[VAL_74]]#3 -> (tensor>) {
+// CHECK: %[[VAL_105:.*]] = sparse_tensor.insert %[[VAL_74]]#2 into %[[VAL_74]]#4{{\[}}%[[VAL_39]], %[[VAL_63]]] : tensor>
+// CHECK: scf.yield %[[VAL_105]] : tensor>
+// CHECK: } else {
+// CHECK: scf.yield %[[VAL_74]]#4 : tensor>
+// CHECK: }
+// CHECK: scf.yield %[[VAL_202]] : tensor>
 // CHECK: } else {
 // CHECK: scf.yield %[[VAL_59]] : tensor>
 // CHECK: }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir
@@ -221,9 +221,8 @@
 // CHECK-NEXT: ( 6, 5, 12, 2, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
 // CHECK-NEXT: ( 7, 7, 9, 8, 7, 7, 12, 11, 11, 0, 0, 0, 0, 0, 0, 0 )
 // CHECK-NEXT: ( ( 7, 0, 0, 7, 0 ), ( 9, 0, 0, 0, 0 ), ( 8, 7, 0, 7, 0 ), ( 12, 11, 0, 11, 0 ), ( 0, 0, 0, 0, 0 ) )
-// TODO: Update once identity values are no longer inserted for non-overlapping dot product
-// CHECK-NEXT: ( 7, inf, inf, 7, 9, inf, inf, inf, 8, 7, inf, 7, 12, 11, inf, 11 )
-// CHECK-NEXT: ( ( 7, inf, inf, 7, 0 ), ( 9, inf, inf, inf, 0 ), ( 8, 7, inf, 7, 0 ), ( 12, 11, inf, 11, 0 ), ( 0, 0, 0, 0, 0 ) )
+// CHECK-NEXT: ( 7, 7, 9, 8, 7, 7, 12, 11, 11, 0, 0, 0, 0, 0, 0, 0 )
+// CHECK-NEXT: ( ( 7, 0, 0, 7, 0 ), ( 9, 0, 0, 0, 0 ), ( 8, 7, 0, 7, 0 ), ( 12, 11, 0, 11, 0 ), ( 0, 0, 0, 0, 0 ) )
 //
 call @dump_mat(%sm1) : (tensor) -> ()
 call @dump_mat(%sm2r) : (tensor) -> ()
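
Editorial note (illustrative only, not part of the patch): for a reduction into a sparse output, the change threads an extra i1 "valid lex insert" flag through the reduction loop and guards the final sparse_tensor.insert with it, so the reduction identity is never materialized for an empty (non-overlapping) reduction. The authoritative IR shapes are the CHECK lines above; the following is a simplified hand-written sketch in which %identity, %false, %true, %lo, %hi, %c1, %coord, %want, %contrib, %chain0, %iv0, %iv1, and #SparseMatrix are placeholder names, and the insertion chain is not shown being carried through the loop:

  %r:2 = scf.for %k = %lo to %hi step %c1
      iter_args(%acc = %identity, %valid = %false) -> (f64, i1) {
    %hit = arith.cmpi eq, %coord, %want : index
    %s:2 = scf.if %hit -> (f64, i1) {
      %new = arith.addf %acc, %contrib : f64
      // An overlap updates the partial sum and marks the insert as valid.
      scf.yield %new, %true : f64, i1
    } else {
      scf.yield %acc, %valid : f64, i1
    }
    scf.yield %s#0, %s#1 : f64, i1
  }
  // Insert the reduced value only if at least one overlap occurred;
  // otherwise return the insertion chain unchanged.
  %chain = scf.if %r#1 -> (tensor<32x16xf64, #SparseMatrix>) {
    %t = sparse_tensor.insert %r#0 into %chain0[%iv0, %iv1]
        : tensor<32x16xf64, #SparseMatrix>
    scf.yield %t : tensor<32x16xf64, #SparseMatrix>
  } else {
    scf.yield %chain0 : tensor<32x16xf64, #SparseMatrix>
  }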