diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -42,10 +42,13 @@
 }
 
 // Helper to detect a sparse tensor type operand.
-static bool isSparseTensor(OpOperand *op) {
-  auto enc = getSparseTensorEncoding(op->get().getType());
-  return enc && llvm::is_contained(enc.getLvlTypes(), DimLevelType::Compressed);
+static bool isSparseTensor(Value v) {
+  auto enc = getSparseTensorEncoding(v.getType());
+  return enc && !llvm::all_of(enc.getLvlTypes(), [](auto dlt) {
+    return dlt == DimLevelType::Dense;
+  });
 }
+static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
 
 // Helper method to find zero/uninitialized allocation.
 static bool isAlloc(OpOperand *op, bool isZero) {
@@ -387,6 +390,139 @@
   }
 };
 
+/// Rewrites a sequence of operations for sparse tensor selections into
+/// semi-ring operations such that they can be compiled correctly by the
+/// sparse compiler. E.g., transforming the following sequence
+///
+///   %sel = arith.select %cond, %sp1, %sp2
+///
+/// to
+///
+///   %sel = binary %sp1, %sp2:
+///            both  {yield select %cond, %l, %r}
+///            left  {yield select %cond, %l, 0}
+///            right {yield select %cond, 0, %r}
+///
+/// TODO: We require the tensor used for extracting conditions to be dense
+/// to sparsify the code. To support a sparse condition tensor, we need a
+/// ternary operation.
+struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
+public:
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp op,
+                                PatternRewriter &rewriter) const override {
+    // Rejects non-sparse kernels.
+    if (!op.hasTensorSemantics() || !hasAnySparseOperand(op))
+      return failure();
+
+    Location loc = op.getLoc();
+    SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
+    for (Operation &inst : *op.getBody()) {
+      // Matches the pattern.
+      auto matched = isRewritablePattern(op, &inst);
+      if (!matched.has_value())
+        continue;
+      rewriter.setInsertionPoint(&inst);
+
+      auto [c, t, f] = matched.value();
+      auto c0 = constantZero(rewriter, loc, t.getType());
+      auto binOp =
+          rewriter.create<sparse_tensor::BinaryOp>(loc, t.getType(), t, f);
+      // Initializes all the blocks.
+      rewriter.createBlock(&binOp.getOverlapRegion(), {},
+                           {t.getType(), f.getType()},
+                           {t.getLoc(), f.getLoc()});
+      rewriter.createBlock(&binOp.getRightRegion(), {}, f.getType(),
+                           f.getLoc());
+      rewriter.createBlock(&binOp.getLeftRegion(), {}, t.getType(), t.getLoc());
+
+      for (auto *r : binOp.getRegions()) {
+        Block *b = &r->front();
+        rewriter.setInsertionPointToStart(b);
+
+        IRMapping irMap;
+        // Clones the cmp operations into the region to make the binary op
+        // admissible.
+        Value newC = c;
+        if (auto *def = c.getDefiningOp())
+          newC = rewriter.clone(*def, irMap)->getResult(0);
+
+        irMap.map(c, newC);
+        if (r == &binOp.getLeftRegion()) {
+          irMap.map(t, b->getArgument(0));
+          irMap.map(f, c0);
+        } else if (r == &binOp.getRightRegion()) {
+          irMap.map(t, c0);
+          irMap.map(f, b->getArgument(0));
+        } else {
+          irMap.map(t, b->getArgument(0));
+          irMap.map(f, b->getArgument(1));
+        }
+        auto y = rewriter.clone(inst, irMap)->getResult(0);
+        rewriter.create<sparse_tensor::YieldOp>(loc, y);
+      }
+
+      // We successfully rewrote the operation. The replacement cannot be done
+      // here because it would invalidate the iterator of the loop that
+      // traverses the instructions.
+      semiRings.emplace_back(&inst, binOp);
+    }
+
+    // Finalizes the replacement.
+    for (auto [sel, semi] : semiRings)
+      rewriter.replaceOp(sel, semi->getResults());
+
+    return success(!semiRings.empty());
+  }
+
+private:
+  static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
+  isRewritablePattern(GenericOp op, Operation *v) {
+    auto sel = dyn_cast<arith::SelectOp>(v);
+    if (!sel)
+      return std::nullopt;
+
+    auto tVal = sel.getTrueValue().dyn_cast<BlockArgument>();
+    auto fVal = sel.getFalseValue().dyn_cast<BlockArgument>();
+    // TODO: For simplicity, we only handle cases where both the true and the
+    // false value are directly loaded from the input tensors. We can probably
+    // admit more cases in theory.
+    if (!tVal || !fVal)
+      return std::nullopt;
+
+    // Helper lambda to determine whether the value is loaded from a dense
+    // input or is a loop invariant.
+    auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
+      if (auto bArg = v.dyn_cast<BlockArgument>();
+          bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
+        return true;
+      // The value is defined outside the loop and is thus a loop invariant.
+      return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
+    };
+
+    // If the condition value is loaded directly from a dense tensor or is a
+    // loop invariant, we can sparsify the kernel.
+    auto cond = sel.getCondition();
+    if (isValFromDenseInputOrInvariant(cond))
+      return std::make_tuple(cond, tVal, fVal);
+
+    Value cmpL, cmpR;
+    if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
+                                               matchers::m_Any(&cmpR))) ||
+        matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
+                                               matchers::m_Any(&cmpR)))) {
+      // TODO: We can do this recursively to check whether all the leaf values
+      // are loaded from dense tensors or are loop invariants.
+      if (isValFromDenseInputOrInvariant(cmpL) ||
+          isValFromDenseInputOrInvariant(cmpR))
+        return std::make_tuple(cond, tVal, fVal);
+    }
+
+    return std::nullopt;
+  };
+};
+
 /// Rewrites a sparse reduction that would not sparsify directly since
 /// doing so would only iterate over the stored elements, ignoring the
 /// implicit zeros, into a semi-ring. Applies to all prod/and/min/max
@@ -1348,7 +1484,7 @@
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
   patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
-               GenSemiRingReduction>(patterns.getContext());
+               GenSemiRingReduction, GenSemiRingSelect>(patterns.getContext());
 }
 
 void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
diff --git a/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir b/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
--- a/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
@@ -8,11 +8,25 @@
   lvlTypes = [ "compressed-nu", "singleton" ]
 }>
 
+#DCSR = #sparse_tensor.encoding<{
+  lvlTypes = ["compressed", "compressed"]
+}>
+
 #Slice = #sparse_tensor.encoding<{
   lvlTypes = [ "compressed-nu", "singleton" ],
   dimSlices = [ (?, 1, 1), (?, 3, 1) ]
 }>
 
+#sel_trait = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // C (in)
+    affine_map<(i,j) -> (i,j)>,  // L (in)
+    affine_map<(i,j) -> (i,j)>,  // R (in)
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"]
+}
+
 // CHECK-LABEL: func @sparse_nop_cast(
 // CHECK-SAME: %[[A:.*]]: tensor<?xf32, #sparse_tensor.encoding<{{.*}}>>)
 // CHECK: return %[[A]] : tensor<?xf32, #sparse_tensor.encoding<{{.*}}>>
@@ -43,3 +57,46 @@
   %0 = sparse_tensor.convert %cast : tensor<1x3xi64, #Slice> to tensor<1x3xi64, #SortedCOO>
   return %0 : tensor<1x3xi64, #SortedCOO>
 }
+
+// CHECK-LABEL: func.func @sparse_select(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x4xi1>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>) -> tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[VAL_4:.*]] = bufferization.alloc_tensor() : tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK-NEXT: %[[VAL_5:.*]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+// CHECK-NEXT: ^bb0(%[[VAL_6:.*]]: i1, %[[VAL_7:.*]]: f64, %[[VAL_8:.*]]: f64, %[[VAL_9:.*]]: f64):
+// CHECK-NEXT: %[[VAL_10:.*]] = sparse_tensor.binary %[[VAL_7]], %[[VAL_8]] : f64, f64 to f64
+// CHECK-NEXT: overlap = {
+// CHECK-NEXT: ^bb0(%[[VAL_11:.*]]: f64, %[[VAL_12:.*]]: f64):
+// CHECK-NEXT: %[[VAL_13:.*]] = arith.select %[[VAL_6]], %[[VAL_11]], %[[VAL_12]] : f64
+// CHECK-NEXT: sparse_tensor.yield %[[VAL_13]] : f64
+// CHECK-NEXT: }
+// CHECK-NEXT: left = {
+// CHECK-NEXT: ^bb0(%[[VAL_14:.*]]: f64):
+// CHECK-NEXT: %[[VAL_15:.*]] = arith.select %[[VAL_6]], %[[VAL_14]], %[[VAL_3]] : f64
+// CHECK-NEXT: sparse_tensor.yield %[[VAL_15]] : f64
+// CHECK-NEXT: }
+// CHECK-NEXT: right = {
+// CHECK-NEXT: ^bb0(%[[VAL_16:.*]]: f64):
+// CHECK-NEXT: %[[VAL_17:.*]] = arith.select %[[VAL_6]], %[[VAL_3]], %[[VAL_16]] : f64
+// CHECK-NEXT: sparse_tensor.yield %[[VAL_17]] : f64
+// CHECK-NEXT: }
+// CHECK-NEXT: linalg.yield %[[VAL_10]] : f64
+// CHECK-NEXT: } -> tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK-NEXT: return %[[VAL_18:.*]] : tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK-NEXT: }
+func.func @sparse_select(%cond: tensor<4x4xi1>,
+                         %arga: tensor<4x4xf64, #DCSR>,
+                         %argb: tensor<4x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> {
+  %xv = bufferization.alloc_tensor() : tensor<4x4xf64, #DCSR>
+  %0 = linalg.generic #sel_trait
+     ins(%cond, %arga, %argb: tensor<4x4xi1>, tensor<4x4xf64, #DCSR>, tensor<4x4xf64, #DCSR>)
+     outs(%xv: tensor<4x4xf64, #DCSR>) {
+      ^bb(%c: i1, %a: f64, %b: f64, %x: f64):
+        %1 = arith.select %c, %a, %b : f64
+        linalg.yield %1 : f64
+  } -> tensor<4x4xf64, #DCSR>
+  return %0 : tensor<4x4xf64, #DCSR>
+}