diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -74,6 +74,7 @@
   kCIm, // complex.im
   kCRe, // complex.re
   kBitCast,
+  kSelect, // custom selection criteria
   kBinaryBranch, // semiring unary branch created from a binary op
   kUnary,  // semiring unary op
   // Binary operations.
@@ -129,8 +130,8 @@
   /// this field may be used to cache "hoisted" loop invariant tensor loads.
   Value val;
-  /// Code blocks used by semirings. For the case of kUnary, kBinary, and
-  /// kReduce, this holds the original operation with all regions. For
+  /// Code blocks used by semirings. For the case of kUnary, kBinary, kReduce,
+  /// and kSelect, this holds the original operation with all regions. For
   /// kBinaryBranch, this holds the YieldOp for the left or right half
   /// to be merged into a nested scf loop.
   Operation *op;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -883,7 +883,18 @@
       // to indicate missing output.
       assert(merger.exp(exp).kind == kUnary || merger.exp(exp).kind == kBinary);
     } else {
-      genInsertionStore(codegen, builder, op, t, rhs);
+      if (merger.exp(exp).kind == kSelect) {
+        scf::IfOp ifOp = builder.create<scf::IfOp>(loc, rhs);
+        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+        // Existing value was saved previously.
+        Value v0 = merger.exp(exp).val;
+        genInsertionStore(codegen, builder, op, t, v0);
+        // Drop the saved value now that it has been used.
+        merger.exp(exp).val = Value();
+        builder.setInsertionPointAfter(ifOp);
+      } else {
+        genInsertionStore(codegen, builder, op, t, rhs);
+      }
     }
     return;
   }
@@ -1041,7 +1052,8 @@
     if (ee && (merger.exp(exp).kind == Kind::kUnary ||
                merger.exp(exp).kind == Kind::kBinary ||
                merger.exp(exp).kind == Kind::kBinaryBranch ||
-               merger.exp(exp).kind == Kind::kReduce))
+               merger.exp(exp).kind == Kind::kReduce ||
+               merger.exp(exp).kind == Kind::kSelect))
       ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx);
   if (merger.exp(exp).kind == Kind::kReduce) {
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -77,6 +77,7 @@
     children.e0 = x;
     children.e1 = y;
     break;
+  case kSelect:
   case kBinaryBranch:
     assert(x != -1u && y == -1u && !v && o);
     children.e0 = x;
@@ -265,9 +266,8 @@
   BitVector simple = latPoints[p0].bits;
   bool reset = isSingleton && hasAnySparse(simple);
   for (unsigned b = 0, be = simple.size(); b < be; b++) {
-    if (simple[b] &&
-        (!isDimLevelType(b, DimLvlType::kCompressed) &&
-         !isDimLevelType(b, DimLvlType::kSingleton))) {
+    if (simple[b] && (!isDimLevelType(b, DimLvlType::kCompressed) &&
+                      !isDimLevelType(b, DimLvlType::kSingleton))) {
       if (reset)
         simple.reset(b);
       reset = true;
@@ -336,6 +336,7 @@
   case kCRe:
   case kBitCast:
     return isSingleCondition(t, tensorExps[e].children.e0);
+  case kSelect:
   case kBinaryBranch:
   case kUnary:
     return false;
@@ -445,6 +446,8 @@
     return "complex.re";
   case kBitCast:
     return "cast";
+  case kSelect:
+    return "select";
   case kBinaryBranch:
     return "binary_branch";
   case kUnary:
@@ -535,6 +538,7 @@
   case kCIm:
   case kCRe:
   case kBitCast:
+  case kSelect:
   case kBinaryBranch:
   case kUnary:
     llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
@@ -683,6 +687,7 @@
    //  | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
+  case kSelect:
   case kBinaryBranch:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
@@ -978,6 +983,10 @@
            isAdmissableBranch(unop, unop.getAbsentRegion()))
          return addExp(kUnary, e, Value(), def);
      }
+      if (auto selop = dyn_cast<SelectOp>(def)) {
+        if (isAdmissableBranch(selop, selop.getRegion()))
+          return addExp(kSelect, e, Value(), def);
+      }
    }
  }
  // Construct binary operations if subexpressions can be built.
@@ -1223,6 +1232,10 @@
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
+  case kSelect:
+    tensorExps[e].val = v0; // save value to be used if the select criterion succeeds
+    return insertYieldOp(rewriter, loc,
+                         cast<SelectOp>(tensorExps[e].op).getRegion(), {v0});
  case kBinaryBranch: // semi-ring ops with custom logic.
    return insertYieldOp(rewriter, loc,
                         *tensorExps[e].op->getBlock()->getParent(), {v0});
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_select.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_select.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_select.mlir
@@ -0,0 +1,148 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void \
+// RUN:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+#CSC = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (j,i)>
+}>
+
+//
+// Traits for tensor operations.
+//
+#trait_vec_select = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>, // A (in)
+    affine_map<(i) -> (i)>  // C (out)
+  ],
+  iterator_types = ["parallel"]
+}
+
+#trait_mat_select = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>, // A (in)
+    affine_map<(i,j) -> (i,j)>  // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"]
+}
+
+module {
+  func.func @vecSelect(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %cf1 = arith.constant 1.0 : f64
+    %d0 = tensor.dim %arga, %c0 : tensor<?xf64, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d0) : tensor<?xf64, #SparseVector>
+    %0 = linalg.generic #trait_vec_select
+      ins(%arga: tensor<?xf64, #SparseVector>)
+      outs(%xv: tensor<?xf64, #SparseVector>) {
+      ^bb(%a: f64, %b: f64):
+        %1 = sparse_tensor.select %a : f64 {
+          ^bb0(%x: f64):
+            %keep = arith.cmpf "oge", %x, %cf1 : f64
+            sparse_tensor.yield %keep : i1
+        }
+        linalg.yield %1 : f64
+    } -> tensor<?xf64, #SparseVector>
+    return %0 : tensor<?xf64, #SparseVector>
+  }
+
+  func.func @matUpperTriangle(%arga: tensor<?x?xf64, #CSR>) -> tensor<?x?xf64, #CSR> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #CSR>
+    %d1 = tensor.dim %arga, %c1 : tensor<?x?xf64, #CSR>
+    %xv = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #CSR>
+    %0 = linalg.generic #trait_mat_select
+      ins(%arga: tensor<?x?xf64, #CSR>)
+      outs(%xv: tensor<?x?xf64, #CSR>) {
+      ^bb(%a: f64, %b: f64):
+        %row = linalg.index 0 : index
+        %col = linalg.index 1 : index
+        %1 = sparse_tensor.select %a : f64 {
+          ^bb0(%x: f64):
+            %keep = arith.cmpi "ugt", %col, %row : index
+            sparse_tensor.yield %keep : i1
+        }
+        linalg.yield %1 : f64
+    } -> tensor<?x?xf64, #CSR>
+    return %0 : tensor<?x?xf64, #CSR>
+  }
+
+  // Dump a sparse vector of type f64.
+  func.func @dump_vec(%arg0: tensor<?xf64, #SparseVector>) {
+    // Dump the values array to verify only sparse contents are stored.
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1.0 : f64
+    %0 = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf64>
+    %1 = vector.transfer_read %0[%c0], %d0 : memref<?xf64>, vector<8xf64>
+    vector.print %1 : vector<8xf64>
+    // Dump the dense vector to verify structure is correct.
+    %dv = sparse_tensor.convert %arg0 : tensor<?xf64, #SparseVector> to tensor<?xf64>
+    %2 = vector.transfer_read %dv[%c0], %d0 : tensor<?xf64>, vector<16xf64>
+    vector.print %2 : vector<16xf64>
+    return
+  }
+
+  // Dump a sparse matrix.
+  func.func @dump_mat(%arg0: tensor<?x?xf64, #CSR>) {
+    // Dump the values array to verify only sparse contents are stored.
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1.0 : f64
+    %0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #CSR> to memref<?xf64>
+    %1 = vector.transfer_read %0[%c0], %d0 : memref<?xf64>, vector<16xf64>
+    vector.print %1 : vector<16xf64>
+    %dm = sparse_tensor.convert %arg0 : tensor<?x?xf64, #CSR> to tensor<?x?xf64>
+    %2 = vector.transfer_read %dm[%c0, %c0], %d0 : tensor<?x?xf64>, vector<5x5xf64>
+    vector.print %2 : vector<5x5xf64>
+    return
+  }
+
+  // Driver method to call and verify the vector and matrix kernels.
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+
+    // Set up the sparse vector and matrix.
+    %v1 = arith.constant sparse<
+       [ [1], [3], [5], [7], [9] ],
+       [ 1.0, 2.0, -4.0, 0.0, 5.0 ]
+    > : tensor<10xf64>
+    %m1 = arith.constant sparse<
+       [ [0, 3], [1, 4], [2, 1], [2, 3], [3, 3], [3, 4], [4, 2] ],
+       [ 1., 2., 3., 4., 5., 6., 7.]
+    > : tensor<5x5xf64>
+    %sv1 = sparse_tensor.convert %v1 : tensor<10xf64> to tensor<?xf64, #SparseVector>
+    %sm1 = sparse_tensor.convert %m1 : tensor<5x5xf64> to tensor<?x?xf64, #CSR>
+
+    // Call the sparse vector and matrix kernels.
+    %1 = call @vecSelect(%sv1) : (tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
+    %2 = call @matUpperTriangle(%sm1) : (tensor<?x?xf64, #CSR>) -> tensor<?x?xf64, #CSR>
+
+    //
+    // Verify the results.
+    //
+    // CHECK:      ( 1, 2, -4, 0, 5, -1, -1, -1 )
+    // CHECK-NEXT: ( 0, 1, 0, 2, 0, -4, 0, 0, 0, 5, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( ( 0, 0, 0, 1, 0 ), ( 0, 0, 0, 0, 2 ), ( 0, 3, 0, 4, 0 ), ( 0, 0, 0, 5, 6 ), ( 0, 0, 7, 0, 0 ) )
+    // CHECK-NEXT: ( 1, 2, 5, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 0, 1, 0, 2, 0, 0, 0, 0, 0, 5, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 1, 2, 4, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( ( 0, 0, 0, 1, 0 ), ( 0, 0, 0, 0, 2 ), ( 0, 0, 0, 4, 0 ), ( 0, 0, 0, 0, 6 ), ( 0, 0, 0, 0, 0 ) )
+    //
+    call @dump_vec(%sv1) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dump_mat(%sm1) : (tensor<?x?xf64, #CSR>) -> ()
+    call @dump_vec(%1) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dump_mat(%2) : (tensor<?x?xf64, #CSR>) -> ()
+
+    // Release the resources.
+    bufferization.dealloc_tensor %sv1 : tensor<?xf64, #SparseVector>
+    bufferization.dealloc_tensor %sm1 : tensor<?x?xf64, #CSR>
+    bufferization.dealloc_tensor %1 : tensor<?xf64, #SparseVector>
+    bufferization.dealloc_tensor %2 : tensor<?x?xf64, #CSR>
+    return
+  }
+}
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -259,6 +259,7 @@
   case kCIm:
   case kCRe:
   case kBitCast:
+  case kSelect:
   case kBinaryBranch:
   case kUnary:
     return compareExpression(tensorExp.children.e0, pattern->e0);
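
For reference, a minimal sketch (not part of the patch) of what the new kSelect
path produces: genTensorStore inlines the boolean region of the
sparse_tensor.select and guards the insertion of the previously saved value
with an scf.if. For the @vecSelect kernel in the test above, the generated loop
is conceptually of the following shape. The names %values, %indices, %out, and
%nnz are assumed buffers and bounds for illustration only; the real codegen
performs sparse insertion chains rather than a dense store:

  scf.for %i = %c0 to %nnz step %c1 {
    // Load the next stored value and its coordinate (assumed buffers).
    %v = memref.load %values[%i] : memref<?xf64>
    %j = memref.load %indices[%i] : memref<?xindex>
    // Inlined select region: keep values >= 1.0.
    %keep = arith.cmpf "oge", %v, %cf1 : f64
    // kSelect guard: materialize the saved value only when the criterion holds.
    scf.if %keep {
      memref.store %v, %out[%j] : memref<?xf64>
    }
  }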