diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -188,6 +188,8 @@
     is solely defined by side-effects and not SSA values. The semantics
     may be refined over time as our sparse abstractions evolve.
 
+    Example:
+
     ```mlir
     sparse_tensor.lex_insert %tensor, %indices, %val
       : tensor<1024x1024xf64, #CSR>, memref<?xindex>, f64
@@ -385,7 +387,8 @@
     would be equivalent to a union operation where non-overlapping values
     in the inputs are copied to the output unchanged.
 
-    Example of isEqual applied to intersecting elements only.
+    Example of isEqual applied to intersecting elements only:
+
     ```mlir
     %C = bufferization.alloc_tensor...
     %0 = linalg.generic #trait
@@ -405,8 +408,8 @@
     } -> tensor<?xi8, #SparseVector>
     ```
 
-    Example of A+B in upper triangle, A-B in lower triangle
-    (not working yet, but construct will be available soon).
+    Example of A+B in upper triangle, A-B in lower triangle:
+
     ```mlir
     %C = bufferization.alloc_tensor...
     %1 = linalg.generic #trait
@@ -438,7 +441,8 @@
 
     Example of set difference. Returns a copy of A where its sparse structure
     is *not* overlapped by B. The element type of B can be different than A
-    because we never use its values, only its sparse structure.
+    because we never use its values, only its sparse structure:
+
     ```mlir
     %C = bufferization.alloc_tensor...
     %2 = linalg.generic #trait
@@ -486,6 +490,7 @@
     region does not contribute to the output.
 
     Example of A+1, restricted to existing elements:
+
     ```mlir
     %C = bufferization.alloc_tensor...
     %0 = linalg.generic #trait
@@ -546,6 +551,7 @@
     Yields a value from within a `binary` or `unary` block.
 
     Example:
+
     ```
     %0 = sparse_tensor.unary %a : i64 to i64 {
       ^bb0(%arg0: i64):
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -881,9 +881,8 @@
 }
 
 /// Generates an index value.
-static Value genIndexValue(Merger &merger, CodeGen &codegen, OpBuilder &builder,
-                           unsigned exp, unsigned ldx) {
-  unsigned idx = merger.exp(exp).index;
+static Value genIndexValue(CodeGen &codegen, OpBuilder &builder, unsigned idx,
+                           unsigned ldx) {
   Value ival = codegen.loops[idx];
   Type itype = ival.getType();
   // During vectorization, we either encounter:
@@ -913,6 +912,25 @@
   return ival;
 }
 
+/// Semi-ring branches are simply inlined by the sparse compiler. Prior
+/// analysis has verified that all computations are "local" to the inlined
+/// branch or otherwise invariantly defined outside the loop nest, with the
+/// exception of index computations, which need to be relinked to actual
+/// inlined cloned code.
+static Value relinkBranch(CodeGen &codegen, RewriterBase &rewriter,
+                          Block *block, Value e, unsigned ldx) {
+  if (Operation *def = e.getDefiningOp()) {
+    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
+      return genIndexValue(codegen, rewriter, indexOp.dim(), ldx);
+    if (def->getBlock() == block) {
+      for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
+        def->setOperand(
+            i, relinkBranch(codegen, rewriter, block, def->getOperand(i), ldx));
+    }
+  }
+  return e;
+}
+
 /// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter, linalg::GenericOp op, unsigned exp, unsigned ldx) { @@ -924,12 +942,17 @@ if (merger.exp(exp).kind == Kind::kInvariant) return genInvariantValue(merger, codegen, rewriter, exp); if (merger.exp(exp).kind == Kind::kIndex) - return genIndexValue(merger, codegen, rewriter, exp, ldx); + return genIndexValue(codegen, rewriter, merger.exp(exp).index, ldx); Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx); Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1, ldx); - return merger.buildExp(rewriter, loc, exp, v0, v1); + Value ee = merger.buildExp(rewriter, loc, exp, v0, v1); + if (ee && (merger.exp(exp).kind == Kind::kUnary || + merger.exp(exp).kind == Kind::kBinary || + merger.exp(exp).kind == Kind::kBinaryBranch)) + ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx); + return ee; } /// Determines if affine expression is invariant. diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -798,7 +798,9 @@ } Optional Merger::buildTensorExpFromLinalg(linalg::GenericOp op) { + // Build the linalg semantics backward from yield. Operation *yield = op.region().front().getTerminator(); + assert(isa(yield)); return buildTensorExp(op, yield->getOperand(0)); } @@ -832,6 +834,37 @@ return dtp; } +/// Ensures that sparse compiler can generate code for expression. +static bool isAdmissableBranchExp(Operation *op, Block *block, Value v) { + // Arguments are always admissable. + if (auto arg = v.dyn_cast()) + return true; + // Accept index anywhere. + Operation *def = v.getDefiningOp(); + if (isa(def)) + return true; + // Operation defined outside branch. + if (def->getBlock() != block) { + return def->getBlock() != op->getBlock(); // invariant? 
+ } + // Operation defined within branch. Anything is accepted, + // as long as all subexpressions are admissable. + for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) + if (!isAdmissableBranchExp(op, block, def->getOperand(i))) + return false; + return true; +} + +/// Ensures that sparse compiler can generate code for branch. +static bool isAdmissableBranch(Operation *op, Region ®ion) { + if (region.empty()) + return true; + // Build the semi-ring branch semantics backward from yield. + Operation *yield = region.front().getTerminator(); + assert(isa(yield)); + return isAdmissableBranchExp(op, ®ion.front(), yield->getOperand(0)); +} + Optional Merger::buildTensorExp(linalg::GenericOp op, Value v) { if (auto arg = v.dyn_cast()) { unsigned argN = arg.getArgNumber(); @@ -920,8 +953,11 @@ return addExp(kCRe, e); if (isa(def)) return addExp(kBitCast, e, v); - if (isa(def)) - return addExp(kUnary, e, Value(), def); + if (auto unop = dyn_cast(def)) { + if (isAdmissableBranch(unop, unop.presentRegion()) && + isAdmissableBranch(unop, unop.absentRegion())) + return addExp(kUnary, e, Value(), def); + } } } // Construct binary operations if subexpressions can be built. @@ -971,8 +1007,14 @@ return addExp(kShrU, e0, e1); if (isa(def) && isInvariant(e1)) return addExp(kShlI, e0, e1); - if (isa(def)) - return addExp(kBinary, e0, e1, Value(), def); + if (auto binop = dyn_cast(def)) { + if (isAdmissableBranch(binop, binop.overlapRegion()) && + (binop.left_identity() || + isAdmissableBranch(binop, binop.leftRegion())) && + (binop.right_identity() || + isAdmissableBranch(binop, binop.rightRegion()))) + return addExp(kBinary, e0, e1, Value(), def); + } } } // Cannot build. 
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_triangular_bin.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_triangular_bin.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_triangular_bin.mlir
@@ -0,0 +1,95 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void  \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+
+#trait_op = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A
+    affine_map<(i,j) -> (i,j)>,  // B
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel","parallel"],
+  doc = "X(i,j) = A(i,j) OP B(i,j)"
+}
+
+module {
+  // Performs triangular add/sub operation (using semi-ring binary op).
+  func.func @triangular(%A: tensor<4x4xf64, #SparseMatrix>,
+                        %B: tensor<4x4xf64, #SparseMatrix>) -> tensor<4x4xf64, #SparseMatrix> {
+    %C = bufferization.alloc_tensor() : tensor<4x4xf64, #SparseMatrix>
+    %0 = linalg.generic #trait_op
+      ins(%A, %B: tensor<4x4xf64, #SparseMatrix>,
+                  tensor<4x4xf64, #SparseMatrix>)
+      outs(%C: tensor<4x4xf64, #SparseMatrix>) {
+        ^bb0(%a: f64, %b: f64, %c: f64) :
+          %row = linalg.index 0 : index
+          %col = linalg.index 1 : index
+          %result = sparse_tensor.binary %a, %b : f64, f64 to f64
+            overlap={
+              ^bb0(%x: f64, %y: f64):
+                %cmp = arith.cmpi "uge", %col, %row : index
+                %upperTriangleResult = arith.addf %x, %y : f64
+                %lowerTriangleResult = arith.subf %x, %y : f64
+                %ret = arith.select %cmp, %upperTriangleResult, %lowerTriangleResult : f64
+                sparse_tensor.yield %ret : f64
+            }
+            left=identity
+            right={
+              ^bb0(%y: f64):
+                %cmp = arith.cmpi "uge", %col, %row : index
+                %lowerTriangleResult = arith.negf %y : f64
+                %ret = arith.select %cmp, %y, %lowerTriangleResult : f64
+                sparse_tensor.yield %ret : f64
+            }
+          linalg.yield %result : f64
+    } -> tensor<4x4xf64, #SparseMatrix>
+    return %0 : tensor<4x4xf64, #SparseMatrix>
+  }
+
+  // Driver method to call and verify triangular kernel.
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+    %du = arith.constant -1.0 : f64
+
+    %am = arith.constant dense<
+      [ [ 1.0, 0.0, 3.0, 0.0],
+        [ 0.0, 2.0, 0.0, 0.0],
+        [ 0.0, 0.0, 0.0, 4.0],
+        [ 3.0, 4.0, 0.0, 0.0] ]> : tensor<4x4xf64>
+    %bm = arith.constant dense<
+      [ [ 1.0, 0.0, 1.0, 1.0],
+        [ 0.0, 0.5, 0.0, 0.0],
+        [ 1.0, 5.0, 2.0, 0.0],
+        [ 2.0, 0.0, 0.0, 0.0] ]> : tensor<4x4xf64>
+
+    %a = sparse_tensor.convert %am : tensor<4x4xf64> to tensor<4x4xf64, #SparseMatrix>
+    %b = sparse_tensor.convert %bm : tensor<4x4xf64> to tensor<4x4xf64, #SparseMatrix>
+    %0 = call @triangular(%a, %b) : (tensor<4x4xf64, #SparseMatrix>,
+                                     tensor<4x4xf64, #SparseMatrix>) -> tensor<4x4xf64, #SparseMatrix>
+
+    //
+    // Verify the results.
+    //
+    // CHECK:      ( ( 2, 0, 4, 1 ), ( 0, 2.5, 0, 0 ), ( -1, -5, 2, 4 ), ( 1, 4, 0, 0 ) )
+    // CHECK-NEXT: ( 2, 4, 1, 2.5, -1, -5, 2, 4, 1, 4, -1, -1, -1, -1, -1, -1 )
+    //
+    %c = sparse_tensor.convert %0 : tensor<4x4xf64, #SparseMatrix> to tensor<4x4xf64>
+    %m = bufferization.to_memref %c : memref<4x4xf64>
+    %v = vector.transfer_read %m[%c0, %c0], %du: memref<4x4xf64>, vector<4x4xf64>
+    vector.print %v : vector<4x4xf64>
+    %1 = sparse_tensor.values %0 : tensor<4x4xf64, #SparseMatrix> to memref<?xf64>
+    %2 = vector.transfer_read %1[%c0], %du: memref<?xf64>, vector<16xf64>
+    vector.print %2 : vector<16xf64>
+
+    // Release the resources.
+    memref.dealloc %m : memref<4x4xf64>
+    sparse_tensor.release %a : tensor<4x4xf64, #SparseMatrix>
+    sparse_tensor.release %b : tensor<4x4xf64, #SparseMatrix>
+    sparse_tensor.release %0 : tensor<4x4xf64, #SparseMatrix>
+    return
+  }
+}