diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -42,10 +42,13 @@
 }
 
 // Helper to detect a sparse tensor type operand.
-static bool isSparseTensor(OpOperand *op) {
-  auto enc = getSparseTensorEncoding(op->get().getType());
-  return enc && llvm::is_contained(enc.getLvlTypes(), DimLevelType::Compressed);
+static bool isSparseTensor(Value v) {
+  auto enc = getSparseTensorEncoding(v.getType());
+  return enc && !llvm::all_of(enc.getLvlTypes(), [](auto dlt) {
+           return dlt == DimLevelType::Dense;
+         });
 }
+static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
 
 // Helper method to find zero/uninitialized allocation.
 static bool isAlloc(OpOperand *op, bool isZero) {
@@ -387,6 +390,137 @@
 }
 };
 
+/// Rewrites a sequence of operations for sparse tensor selections into
+/// semi-ring operations such that they can be compiled correctly by the sparse
+/// compiler. E.g., transforming the following sequence
+///
+///   %sel = arith.select %cond, %sp1, %sp2
+///
+/// to
+///
+///   %sel = binary %sp1, %sp2:
+///            both  (%l, %r) {yield select %cond, %l, %r}
+///            left  (%l)     {yield select %cond, %l, 0}
+///            right (%r)     {yield select %cond, 0, %r}
+///
+/// TODO: We require the tensor used for extracting the condition to be dense
+/// to sparsify the code. To support a sparse condition tensor, we need a
+/// ternary operation.
+struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
+public:
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp op,
+                                PatternRewriter &rewriter) const override {
+    // Rejects non-sparse kernels.
+    if (!op.hasTensorSemantics() || !hasAnySparseOperand(op))
+      return failure();
+
+    Location loc = op.getLoc();
+    SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
+    for (Operation &inst : *op.getBody()) {
+      // Matches the pattern.
+      auto matched = isRewritablePattern(op, &inst);
+      if (!matched.has_value())
+        continue;
+
+      rewriter.setInsertionPoint(&inst);
+      auto [c, t, f] = matched.value();
+      assert(t.getType() == f.getType());
+      auto selTp = t.getType();
+      auto c0 = constantZero(rewriter, loc, selTp);
+      auto binOp = rewriter.create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
+      // Initializes all the blocks.
+      rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
+                           {t.getLoc(), f.getLoc()});
+      rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
+      rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
+
+      for (auto *r : binOp.getRegions()) {
+        Block *b = &r->front();
+        rewriter.setInsertionPointToStart(b);
+
+        IRMapping irMap;
+        // Clones the cmp operations into the region to make the binary op
+        // admissible.
+        Value newC = c;
+        if (auto *def = c.getDefiningOp())
+          newC = rewriter.clone(*def, irMap)->getResult(0);
+
+        irMap.map(c, newC);
+        if (r == &binOp.getLeftRegion()) {
+          irMap.map(t, b->getArgument(0));
+          irMap.map(f, c0);
+        } else if (r == &binOp.getRightRegion()) {
+          irMap.map(t, c0);
+          irMap.map(f, b->getArgument(0));
+        } else {
+          irMap.map(t, b->getArgument(0));
+          irMap.map(f, b->getArgument(1));
+        }
+        auto y = rewriter.clone(inst, irMap)->getResult(0);
+        rewriter.create<sparse_tensor::YieldOp>(loc, y);
+      }
+
+      // We successfully rewrote the operation. We cannot do the replacement
+      // here because it would invalidate the iterator for the loop that is
+      // traversing the instructions.
+      semiRings.emplace_back(&inst, binOp);
+    }
+
+    // Finalizes the replacement.
+    for (auto [sel, semi] : semiRings)
+      rewriter.replaceOp(sel, semi->getResults());
+
+    return success(!semiRings.empty());
+  }
+
+private:
+  static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
+  isRewritablePattern(GenericOp op, Operation *v) {
+    auto sel = dyn_cast<arith::SelectOp>(v);
+    if (!sel)
+      return std::nullopt;
+
+    auto tVal = sel.getTrueValue().dyn_cast<BlockArgument>();
+    auto fVal = sel.getFalseValue().dyn_cast<BlockArgument>();
+    // TODO: For simplicity, we only handle cases where both the true/false
+    // values are directly loaded from the input tensors. We can probably
+    // admit more cases in theory.
+    if (!tVal || !fVal)
+      return std::nullopt;
+
+    // Helper lambda to determine whether the value is loaded from a dense
+    // input or is a loop invariant.
+    auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
+      if (auto bArg = v.dyn_cast<BlockArgument>();
+          bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
+        return true;
+      // If the value is defined outside the loop, it is a loop invariant.
+      return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
+    };
+
+    // If the condition value is loaded directly from a dense tensor or is a
+    // loop invariant, we can sparsify the kernel.
+    auto cond = sel.getCondition();
+    if (isValFromDenseInputOrInvariant(cond))
+      return std::make_tuple(cond, tVal, fVal);
+
+    Value cmpL, cmpR;
+    if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
+                                               matchers::m_Any(&cmpR))) ||
+        matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
+                                               matchers::m_Any(&cmpR)))) {
+      // TODO: we can recurse to check whether all the leaf values are loaded
+      // from dense tensors or are loop invariants.
+      if (isValFromDenseInputOrInvariant(cmpL) ||
+          isValFromDenseInputOrInvariant(cmpR))
+        return std::make_tuple(cond, tVal, fVal);
+    }
+
+    return std::nullopt;
+  }
+};
+
 /// Rewrites a sparse reduction that would not sparsify directly since
 /// doing so would only iterate over the stored elements, ignoring the
 /// implicit zeros, into a semi-ring.
 /// Applies to all prod/and/min/max
@@ -1348,7 +1482,7 @@
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
   patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
-               GenSemiRingReduction>(patterns.getContext());
+               GenSemiRingReduction, GenSemiRingSelect>(patterns.getContext());
 }
 
 void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1211,7 +1211,8 @@
   if (ee && (kind == TensorExp::Kind::kUnary ||
              kind == TensorExp::Kind::kBinary ||
              kind == TensorExp::Kind::kBinaryBranch ||
-             kind == TensorExp::Kind::kReduce || kind == TensorExp::Kind::kSelect)) {
+             kind == TensorExp::Kind::kReduce ||
+             kind == TensorExp::Kind::kSelect)) {
     OpBuilder::InsertionGuard guard(rewriter);
     ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
   }
diff --git a/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir b/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
--- a/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
@@ -8,11 +8,25 @@
   lvlTypes = [ "compressed-nu", "singleton" ]
 }>
 
+#DCSR = #sparse_tensor.encoding<{
+  lvlTypes = ["compressed", "compressed"]
+}>
+
 #Slice = #sparse_tensor.encoding<{
   lvlTypes = [ "compressed-nu", "singleton" ],
   dimSlices = [ (?, 1, 1), (?, 3, 1) ]
 }>
 
+#sel_trait = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // C (in)
+    affine_map<(i,j) -> (i,j)>,  // L (in)
+    affine_map<(i,j) -> (i,j)>,  // R (in)
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"]
+}
+
 // CHECK-LABEL: func @sparse_nop_cast(
 // CHECK-SAME: %[[A:.*]]: tensor>)
 // CHECK: return %[[A]] : tensor>
@@ -43,3 +57,46 @@
   %0 = sparse_tensor.convert %cast : tensor<1x3xi64, #Slice> to tensor<1x3xi64, #SortedCOO>
   return %0 : tensor<1x3xi64, #SortedCOO>
 }
+
+// CHECK-LABEL: func.func @sparse_select(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x4xi1>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>) -> tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[VAL_4:.*]] = bufferization.alloc_tensor() : tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK-NEXT: %[[VAL_5:.*]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+// CHECK-NEXT: ^bb0(%[[VAL_6:.*]]: i1, %[[VAL_7:.*]]: f64, %[[VAL_8:.*]]: f64, %[[VAL_9:.*]]: f64):
+// CHECK-NEXT:   %[[VAL_10:.*]] = sparse_tensor.binary %[[VAL_7]], %[[VAL_8]] : f64, f64 to f64
+// CHECK-NEXT:   overlap = {
+// CHECK-NEXT:   ^bb0(%[[VAL_11:.*]]: f64, %[[VAL_12:.*]]: f64):
+// CHECK-NEXT:     %[[VAL_13:.*]] = arith.select %[[VAL_6]], %[[VAL_11]], %[[VAL_12]] : f64
+// CHECK-NEXT:     sparse_tensor.yield %[[VAL_13]] : f64
+// CHECK-NEXT:   }
+// CHECK-NEXT:   left = {
+// CHECK-NEXT:   ^bb0(%[[VAL_14:.*]]: f64):
+// CHECK-NEXT:     %[[VAL_15:.*]] = arith.select %[[VAL_6]], %[[VAL_14]], %[[VAL_3]] : f64
+// CHECK-NEXT:     sparse_tensor.yield %[[VAL_15]] : f64
+// CHECK-NEXT:   }
+// CHECK-NEXT:   right = {
+// CHECK-NEXT:   ^bb0(%[[VAL_16:.*]]: f64):
+// CHECK-NEXT:     %[[VAL_17:.*]] = arith.select %[[VAL_6]], %[[VAL_3]], %[[VAL_16]] : f64
+// CHECK-NEXT:     sparse_tensor.yield %[[VAL_17]] : f64
+// CHECK-NEXT:   }
+// CHECK-NEXT:   linalg.yield %[[VAL_10]] : f64
+// CHECK-NEXT: } -> tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK-NEXT: return %[[VAL_18:.*]] : tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK-NEXT: }
+func.func @sparse_select(%cond: tensor<4x4xi1>,
+                         %arga: tensor<4x4xf64, #DCSR>,
+                         %argb: tensor<4x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> {
+  %xv = bufferization.alloc_tensor() : tensor<4x4xf64, #DCSR>
+  %0 = linalg.generic #sel_trait
+     ins(%cond, %arga, %argb: tensor<4x4xi1>, tensor<4x4xf64, #DCSR>, tensor<4x4xf64, #DCSR>)
+      outs(%xv: tensor<4x4xf64, #DCSR>) {
+      ^bb(%c: i1, %a: f64, %b: f64, %x: f64):
+        %1 = arith.select %c, %a, %b : f64
+        linalg.yield %1 : f64
+  } -> tensor<4x4xf64, #DCSR>
+  return %0 : tensor<4x4xf64, #DCSR>
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_semiring_select.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_semiring_select.mlir
new file
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_semiring_select.mlir
@@ -0,0 +1,97 @@
+// DEFINE: %{option} = enable-runtime-library=true
+// DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option}
+// DEFINE: %{run} = mlir-cpu-runner \
+// DEFINE:   -e entry -entry-point-result=void \
+// DEFINE:   -shared-libs=%mlir_c_runner_utils | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{compile} | %{run}
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true"
+// RUN: %{compile} | %{run}
+
+// Do the same run, but now with direct IR generation and, if available, VLA
+// vectorization.
+// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true vl=4 enable-arm-sve=%ENABLE_VLA"
+// REDEFINE: %{run} = %lli_host_or_aarch64_cmd \
+// REDEFINE:   --entry-function=entry_lli \
+// REDEFINE:   --extra-module=%S/Inputs/main_for_lli.ll \
+// REDEFINE:   %VLA_ARCH_ATTR_OPTIONS \
+// REDEFINE:   --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \
+// REDEFINE: FileCheck %s
+// RUN: %{compile} | mlir-translate -mlir-to-llvmir | %{run}
+
+#DCSR = #sparse_tensor.encoding<{
+  lvlTypes = ["compressed", "compressed"]
+}>
+
+#sel_trait = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // C (in)
+    affine_map<(i,j) -> (i,j)>,  // L (in)
+    affine_map<(i,j) -> (i,j)>,  // R (in)
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"]
+}
+
+module {
+  func.func @sparse_select(%cond: tensor<5x5xi1>,
+                           %arga: tensor<5x5xf64, #DCSR>,
+                           %argb: tensor<5x5xf64, #DCSR>) -> tensor<5x5xf64, #DCSR> {
+    %xv = bufferization.alloc_tensor() : tensor<5x5xf64, #DCSR>
+    %0 = linalg.generic #sel_trait
+       ins(%cond, %arga, %argb: tensor<5x5xi1>, tensor<5x5xf64, #DCSR>, tensor<5x5xf64, #DCSR>)
+        outs(%xv: tensor<5x5xf64, #DCSR>) {
+        ^bb(%c: i1, %a: f64, %b: f64, %x: f64):
+          %1 = arith.select %c, %a, %b : f64
+          linalg.yield %1 : f64
+    } -> tensor<5x5xf64, #DCSR>
+    return %0 : tensor<5x5xf64, #DCSR>
+  }
+
+  // Driver method to call and verify the select kernel.
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+    %f0 = arith.constant 0.0 : f64
+
+    %cond = arith.constant sparse<
+        [ [0, 0], [1, 1], [2, 2], [3, 3], [4, 4] ],
+        [ 1, 1, 1, 1, 1 ]
+    > : tensor<5x5xi1>
+    %lhs = arith.constant sparse<
+        [ [0, 0], [1, 1], [2, 2], [3, 3], [4, 4] ],
+        [ 0.1, 1.1, 2.1, 3.1, 4.1 ]
+    > : tensor<5x5xf64>
+    %rhs = arith.constant sparse<
+        [ [0, 1], [1, 2], [2, 3], [3, 4], [4, 4] ],
+        [ 1.1, 2.2, 3.3, 4.4, 5.5 ]
+    > : tensor<5x5xf64>
+
+    %sl = sparse_tensor.convert %lhs : tensor<5x5xf64> to tensor<5x5xf64, #DCSR>
+    %sr = sparse_tensor.convert %rhs : tensor<5x5xf64> to tensor<5x5xf64, #DCSR>
+
+    // Call sparse matrix kernels.
+    %1 = call @sparse_select(%cond, %sl, %sr) : (tensor<5x5xi1>,
+                                                 tensor<5x5xf64, #DCSR>,
+                                                 tensor<5x5xf64, #DCSR>) -> tensor<5x5xf64, #DCSR>
+
+    // CHECK: ( ( 0.1, 1.1, 0, 0, 0 ),
+    // CHECK-SAME: ( 0, 1.1, 2.2, 0, 0 ),
+    // CHECK-SAME: ( 0, 0, 2.1, 3.3, 0 ),
+    // CHECK-SAME: ( 0, 0, 0, 3.1, 4.4 ),
+    // CHECK-SAME: ( 0, 0, 0, 0, 4.1 ) )
+    %r = sparse_tensor.convert %1 : tensor<5x5xf64, #DCSR> to tensor<5x5xf64>
+    %v2 = vector.transfer_read %r[%c0, %c0], %f0 : tensor<5x5xf64>, vector<5x5xf64>
+    vector.print %v2 : vector<5x5xf64>
+
+    // Release the resources.
+    bufferization.dealloc_tensor %sl: tensor<5x5xf64, #DCSR>
+    bufferization.dealloc_tensor %sr: tensor<5x5xf64, #DCSR>
+    bufferization.dealloc_tensor %1: tensor<5x5xf64, #DCSR>
+
+    return
+  }
+}
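
For reference, the rewrite implemented by GenSemiRingSelect replaces the arith.select inside the linalg.generic body with a sparse_tensor.binary whose three regions re-apply the select against the condition and a zero constant defined in the surrounding body. The sketch below is assembled from the CHECK lines added to pre_rewriting.mlir; the SSA names (%a, %b, %cond, %zero, %sel) are illustrative placeholders, not names that appear in the patch.

    // %a, %b: scalar f64 values loaded from the two sparse inputs
    // %cond:  i1 loaded from the dense condition tensor
    // %zero:  f64 constant 0.0 defined in the generic body
    %sel = sparse_tensor.binary %a, %b : f64, f64 to f64
      overlap = {
        ^bb0(%l: f64, %r: f64):
          %s0 = arith.select %cond, %l, %r : f64
          sparse_tensor.yield %s0 : f64
      }
      left = {
        ^bb0(%l0: f64):
          %s1 = arith.select %cond, %l0, %zero : f64
          sparse_tensor.yield %s1 : f64
      }
      right = {
        ^bb0(%r0: f64):
          %s2 = arith.select %cond, %zero, %r0 : f64
          sparse_tensor.yield %s2 : f64
      }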