diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -102,18 +102,21 @@ let hasVerifier = 1; } -def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure]>, +def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure, SameVariadicResultSize]>, Arguments<(ins AnySparseTensor:$tensor, TensorOf<[AnyType]>:$out_values, Variadic>:$out_levels)>, Results<(outs TensorOf<[AnyType]>:$ret_values, - Variadic>:$ret_levels)> { + Variadic>:$ret_levels, + Index:$val_len, + Variadic:$lvl_lens)> { let summary = "Returns the (values, coordinates) pair unpacked from the input tensor"; let description = [{ The unpack operation is the inverse of `sparse_tensor::pack`. It returns the values and per-level position and coordinate array to the user - from the sparse tensor. This operation can be used for returning an + from the sparse tensor along with the actual length of the memory used in + each returned buffer. This operation can be used for returning an unpacked MLIR sparse tensor to frontend; e.g., returning two numpy arrays to Python. Disclaimer: This is the user's responsibility to allocate large enough buffers @@ -128,18 +131,22 @@ // input COO format |1.1, 0.0, 0.0, 0.0| // of 3x4 matrix |0.0, 0.0, 2.2, 3.3| // |0.0, 0.0, 0.0, 0.0| - %values, %pos, %coords = sparse_tensor.unpack %sp : tensor<3x4xf64, #SparseVector> - outs(%od, %op, %oi : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>) - -> tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex> - // %values = arith.constant dense<[ 1.1, 2.2, 3.3 ]> : tensor<3xf64> - // %pos = arith.constant dense<[ 0, 3 ]> : tensor<2xindex> - // %coordinates = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex> + %v, %p, %c, %v_len, %p_len, %c_len = sparse_tensor.unpack %sp : tensor<3x4xf64, #SparseVector> + outs(%od, %op, %oi : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>) + -> tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index) + // %v = arith.constant dense<[ 1.1, 2.2, 3.3 ]> : tensor<3xf64> + // %p = arith.constant dense<[ 0, 3 ]> : tensor<2xindex> + // %c = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex> + // %v_len = 3 + // %p_len = 2 + // %c_len = 6 (3x2) ``` }]; let assemblyFormat = - "$tensor `:` type($tensor) `outs` `(` $out_values `,` $out_levels `:` type($out_values) `,` type($out_levels) `)`" - "attr-dict `->` type($ret_values) `,` type($ret_levels)"; + "$tensor `:` type($tensor) " + "`outs` `(` $out_values `,` $out_levels `:` type($out_values) `,` type($out_levels) `)` attr-dict" + "`->` type($ret_values) `,` `(` type($ret_levels) `)` `,` type($val_len) `,` `(` type($lvl_lens) `)`"; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -166,13 +166,13 @@ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { // We write into the output operand. - assert(op->getNumOperands() == op->getNumResults() + 1); + assert(2 * (op->getNumOperands() - 1) == op->getNumResults()); return opOperand.getOperandNumber() > 0; } AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - assert(op->getNumOperands() == op->getNumResults() + 1); + assert(2 * (op->getNumOperands() - 1) == op->getNumResults()); if (opOperand.getOperandNumber() == 0) return {}; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -1311,7 +1311,8 @@ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); Location loc = op.getLoc(); SmallVector retMem; - desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem]( + SmallVector retLen; + desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem, &retLen]( FieldIndex fid, SparseTensorFieldKind fKind, Level lvl, DimLevelType dlt) -> bool { @@ -1329,6 +1330,7 @@ // TODO: maybe change unpack/pack operation instead to be // consistent. retMem.insert(retMem.begin(), dst); + retLen.insert(retLen.begin(), sz); } else { assert(fKind == SparseTensorFieldKind::PosMemRef || fKind == SparseTensorFieldKind::CrdMemRef); @@ -1339,6 +1341,7 @@ src = desc.getMemRefField(fid); dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]); retMem.push_back(dst); + retLen.push_back(sz); } Value flatOut = dst; if (dst.getType().getRank() != 1) { @@ -1352,12 +1355,13 @@ }); // Converts MemRefs back to Tensors. - SmallVector retTensor = llvm::to_vector( + SmallVector retValues = llvm::to_vector( llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value { return rewriter.create(loc, v); })); - - rewriter.replaceOp(op, retTensor); + // Appends the actual memory length used in each buffer returned. + retValues.append(retLen.begin(), retLen.end()); + rewriter.replaceOp(op, retValues); return success(); } }; diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -60,9 +60,9 @@ func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>, %values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x1xi32>) { // expected-error@+1 {{input/output element-types don't match}} - %rv, %rp, %rc = sparse_tensor.unpack %sp : tensor<100xf32, #SparseVector> + %rv, %rp, %rc, %vl, %pl, %cl = sparse_tensor.unpack %sp : tensor<100xf32, #SparseVector> outs(%values, %pos, %coordinates : tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32>) - -> tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32> + -> tensor<6xf64>, (tensor<2xi32>, tensor<6x1xi32>), index, (index, index) return } @@ -72,9 +72,9 @@ func.func @invalid_unpack_type(%sp: tensor<100x2xf64, #SparseVector>, %values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x3xi32>) { // expected-error@+1 {{input/output trailing COO level-ranks don't match}} - %rv, %rp, %rc = sparse_tensor.unpack %sp : tensor<100x2xf64, #SparseVector> + %rv, %rp, %rc, %vl, %pl, %cl = sparse_tensor.unpack %sp : tensor<100x2xf64, #SparseVector> outs(%values, %pos, %coordinates : tensor<6xf64>, tensor<2xi32>, tensor<6x3xi32>) - -> tensor<6xf64>, tensor<2xi32>, tensor<6x3xi32> + -> tensor<6xf64>, (tensor<2xi32>, tensor<6x3xi32>), index, (index, index) return } @@ -84,9 +84,9 @@ func.func @invalid_unpack_mis_position(%sp: tensor<2x100xf64, #CSR>, %values: tensor<6xf64>, %coordinates: tensor<6xi32>) { // expected-error@+1 {{inconsistent number of fields between input/output}} - %rv, %rc = sparse_tensor.unpack %sp : tensor<2x100xf64, #CSR> + %rv, %rc, %vl, %pl = sparse_tensor.unpack %sp : tensor<2x100xf64, #CSR> outs(%values, %coordinates : tensor<6xf64>, tensor<6xi32>) - -> tensor<6xf64>, tensor<6xi32> + -> tensor<6xf64>, (tensor<6xi32>), index, (index) return } diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -36,16 +36,16 @@ // CHECK-SAME: %[[OD:.*]]: tensor<6xf64> // CHECK-SAME: %[[OP:.*]]: tensor<2xindex> // CHECK-SAME: %[[OI:.*]]: tensor<6x1xi32> -// CHECK: %[[D:.*]], %[[P:.*]]:2 = sparse_tensor.unpack %[[T]] +// CHECK: %[[D:.*]], %[[P:.*]]:2, %[[DL:.*]], %[[PL:.*]]:2 = sparse_tensor.unpack %[[T]] // CHECK: return %[[D]], %[[P]]#0, %[[P]]#1 func.func @sparse_unpack(%sp : tensor<100xf64, #SparseVector>, %od : tensor<6xf64>, %op : tensor<2xindex>, %oi : tensor<6x1xi32>) -> (tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>) { - %rd, %rp, %ri = sparse_tensor.unpack %sp : tensor<100xf64, #SparseVector> + %rd, %rp, %ri, %vl, %pl, %cl = sparse_tensor.unpack %sp : tensor<100xf64, #SparseVector> outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>) - -> tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32> + -> tensor<6xf64>, (tensor<2xindex>, tensor<6x1xi32>), index, (index, index) return %rd, %rp, %ri : tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32> } diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir @@ -70,8 +70,8 @@ %op : tensor<2xindex>, %oi : tensor<6x2xi32>) -> (tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>) { - %rd, %rp, %ri = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO> - outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>) - -> tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32> + %rd, %rp, %ri, %dl, %pl, %il = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO> + outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>) + -> tensor<6xf64>, (tensor<2xindex>, tensor<6x2xi32>), index, (index, index) return %rd, %rp, %ri : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32> } diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir @@ -171,9 +171,9 @@ %d_csr = tensor.empty() : tensor<4xf64> %p_csr = tensor.empty() : tensor<3xi32> %i_csr = tensor.empty() : tensor<3xi32> - %rd_csr, %rp_csr, %ri_csr = sparse_tensor.unpack %csr : tensor<2x2xf64, #CSR> + %rd_csr, %rp_csr, %ri_csr, %ld_csr, %lp_csr, %li_csr = sparse_tensor.unpack %csr : tensor<2x2xf64, #CSR> outs(%d_csr, %p_csr, %i_csr : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>) - -> tensor<4xf64>, tensor<3xi32>, tensor<3xi32> + -> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (index, index) // CHECK-NEXT: ( 1, 2, 3, {{.*}} ) %vd_csr = vector.transfer_read %rd_csr[%c0], %f0 : tensor<4xf64>, vector<4xf64> @@ -196,9 +196,9 @@ %od = tensor.empty() : tensor<3xf64> %op = tensor.empty() : tensor<2xi32> %oi = tensor.empty() : tensor<3x2xi32> - %d, %p, %i = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32> + %d, %p, %i, %dl, %pl, %il = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32> outs(%od, %op, %oi : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>) - -> tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32> + -> tensor<3xf64>, (tensor<2xi32>, tensor<3x2xi32>), index, (index, index) // CHECK-NEXT: ( 1, 2, 3 ) %vd = vector.transfer_read %d[%c0], %f0 : tensor<3xf64>, vector<3xf64> @@ -212,17 +212,21 @@ %bod = tensor.empty() : tensor<6xf64> %bop = tensor.empty() : tensor<4xindex> %boi = tensor.empty() : tensor<6x2xindex> - %bd, %bp, %bi = sparse_tensor.unpack %bs : tensor<2x10x10xf64, #BCOO> + %bd, %bp, %bi, %ld, %lp, %li = sparse_tensor.unpack %bs : tensor<2x10x10xf64, #BCOO> outs(%bod, %bop, %boi : tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>) - -> tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex> + -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (index, index) // CHECK-NEXT: ( 1, 2, 3, 4, 5, {{.*}} ) %vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<6xf64> vector.print %vbd : vector<6xf64> + // CHECK-NEXT: 5 + vector.print %ld : index // CHECK-NEXT: ( ( 1, 2 ), ( 5, 6 ), ( 7, 8 ), ( 2, 3 ), ( 4, 2 ), ( {{.*}}, {{.*}} ) ) %vbi = vector.transfer_read %bi[%c0, %c0], %c0 : tensor<6x2xindex>, vector<6x2xindex> vector.print %vbi : vector<6x2xindex> + // CHECK-NEXT: 10 + vector.print %li : index return }