diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -351,4 +351,24 @@ let assemblyFormat = "$tensor attr-dict `:` type($tensor)"; } -#endif // SPARSETENSOR_OPS +def SparseTensor_OutOp : SparseTensor_Op<"out", []>, + Arguments<(ins AnyType:$tensor, AnyType:$dest)> { + string summary = "Outputs a sparse tensor to the given destination"; + string description = [{ + Outputs the contents of a sparse tensor to the destination defined by an + opaque pointer provided by `dest`. For targets that have access to a file + system, for example, this pointer specifies a filename (or file) for output. + The form of the operation is kept deliberately very general to allow for + alternative implementations in the future, such as sending the contents to + a buffer defined by a pointer. + + Example: + + ```mlir + sparse_tensor.out %t, %dest : tensor<1024x1024xf64, #CSR>, !Dest + ``` + }]; + let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)"; +} + +#endif // SPARSETENSOR_OPS { diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -329,6 +329,13 @@ return success(); } +/// Verifies that the tensor operand of the output operation has a sparse encoding. +static LogicalResult verify(OutOp op) { + if (!getSparseTensorEncoding(op.tensor().getType())) + return op.emitError("expected a sparse tensor for output"); + return success(); +} + //===----------------------------------------------------------------------===// // TensorDialect Methods.
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -13,7 +13,6 @@ // tensor storage schemes. // //===----------------------------------------------------------------------===// - #include "CodegenUtils.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -189,8 +188,8 @@ -/// computation. +/// computation. The primary type is the element type of `stp`. static void newParams(ConversionPatternRewriter &rewriter, SmallVector &params, Operation *op, - SparseTensorEncodingAttr &enc, Action action, - ValueRange szs, Value ptr = Value()) { + ShapedType stp, SparseTensorEncodingAttr &enc, + Action action, ValueRange szs, Value ptr = Value()) { Location loc = op->getLoc(); ArrayRef dlt = enc.getDimLevelType(); unsigned sz = dlt.size(); @@ -218,7 +217,7 @@ } params.push_back(genBuffer(rewriter, loc, rev)); // Secondary and primary types encoding. - Type elemTp = op->getResult(0).getType().cast().getElementType(); + Type elemTp = stp.getElementType(); params.push_back(constantPointerTypeEncoding(rewriter, loc, enc)); params.push_back(constantIndexTypeEncoding(rewriter, loc, enc)); params.push_back(constantPrimaryTypeEncoding(rewriter, loc, elemTp)); @@ -420,9 +419,10 @@ // inferred from the result type of the new operator.
SmallVector sizes; SmallVector params; - sizesFromType(rewriter, sizes, op.getLoc(), resType.cast()); + ShapedType stp = resType.cast(); + sizesFromType(rewriter, sizes, op.getLoc(), stp); Value ptr = adaptor.getOperands()[0]; - newParams(rewriter, params, op, enc, Action::kFromFile, sizes, ptr); + newParams(rewriter, params, op, stp, enc, Action::kFromFile, sizes, ptr); rewriter.replaceOp(op, genNewCall(rewriter, op, params)); return success(); } @@ -441,7 +441,9 @@ // Generate the call to construct empty tensor. The sizes are // explicitly defined by the arguments to the init operator. SmallVector params; - newParams(rewriter, params, op, enc, Action::kEmpty, adaptor.getOperands()); + ShapedType stp = resType.cast(); + newParams(rewriter, params, op, stp, enc, Action::kEmpty, + adaptor.getOperands()); rewriter.replaceOp(op, genNewCall(rewriter, op, params)); return success(); } @@ -472,15 +474,15 @@ } SmallVector sizes; SmallVector params; - sizesFromPtr(rewriter, sizes, op, encSrc, srcType.cast(), - src); + ShapedType stp = srcType.cast(); + sizesFromPtr(rewriter, sizes, op, encSrc, stp, src); // Set up encoding with right mix of src and dst so that the two // method calls can share most parameters, while still providing // the correct sparsity information to either of them.
auto enc = SparseTensorEncodingAttr::get( op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); - newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src); + newParams(rewriter, params, op, stp, enc, Action::kToCOO, sizes, src); Value coo = genNewCall(rewriter, op, params); params[3] = constantPointerTypeEncoding(rewriter, loc, encDst); params[4] = constantIndexTypeEncoding(rewriter, loc, encDst); @@ -512,7 +514,8 @@ SmallVector sizes; SmallVector params; sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src); - newParams(rewriter, params, op, encDst, Action::kToIterator, sizes, src); + newParams(rewriter, params, op, dstTensorTp, encDst, Action::kToIterator, + sizes, src); Value iter = genNewCall(rewriter, op, params); Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); Value elemPtr = genAllocaScalar(rewriter, loc, elemTp); @@ -567,7 +570,7 @@ SmallVector sizes; SmallVector params; sizesFromSrc(rewriter, sizes, loc, src); - newParams(rewriter, params, op, encDst, Action::kEmptyCOO, sizes); + newParams(rewriter, params, op, stp, encDst, Action::kEmptyCOO, sizes); Value ptr = genNewCall(rewriter, op, params); Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); Value perm = params[2]; @@ -771,6 +774,45 @@ } }; +class SparseTensorOutConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(OutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + ShapedType srcType = op.tensor().getType().cast(); + // Convert to a COO with the default (identity) dimension ordering.
+ Value src = adaptor.getOperands()[0]; + auto encSrc = getSparseTensorEncoding(srcType); + SmallVector sizes; + SmallVector params; + sizesFromPtr(rewriter, sizes, op, encSrc, srcType, src); + auto enc = SparseTensorEncodingAttr::get( + op->getContext(), encSrc.getDimLevelType(), AffineMap(), + encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); + newParams(rewriter, params, op, srcType, enc, Action::kToCOO, sizes, src); + Value coo = genNewCall(rewriter, op, params); + // Then output the tensor to an external file, with indices in the + // externally visible lexicographic index order. A sort is required if the + // source was not in that order yet (note that the sort can be dropped + // altogether if the external format does not care about the order at all, + // but here we assume it does). + bool sort = + encSrc.getDimOrdering() && !encSrc.getDimOrdering().isIdentity(); + params.clear(); + params.push_back(coo); + params.push_back(adaptor.getOperands()[1]); + params.push_back(constantI1(rewriter, loc, sort)); + Type eltType = srcType.getElementType(); + SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)}; + TypeRange noTp; + replaceOpWithFuncCall(rewriter, op, name, noTp, params, + EmitCInterface::Off); + return success(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -787,6 +829,6 @@ SparseTensorReleaseConverter, SparseTensorToPointersConverter, SparseTensorToIndicesConverter, SparseTensorToValuesConverter, SparseTensorLoadConverter, SparseTensorLexInsertConverter, - SparseTensorExpandConverter, SparseTensorCompressConverter>( - typeConverter, patterns.getContext()); + SparseTensorExpandConverter, SparseTensorCompressConverter, + SparseTensorOutConverter>(typeConverter, patterns.getContext()); } diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp --- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp +++
b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp @@ -26,6 +26,8 @@ #include #include #include +#include +#include #include #include @@ -686,6 +688,31 @@ return tensor; } +/// Writes the sparse tensor in extended FROSTT format (with 1-based indices). +template +void outSparseTensor(SparseTensorCOO *tensor, char *filename) { + auto &sizes = tensor->getSizes(); + auto &elements = tensor->getElements(); + uint64_t rank = tensor->getRank(); + uint64_t nnz = elements.size(); + std::fstream file; + file.open(filename, std::ios_base::out | std::ios_base::trunc); + assert(file.is_open()); + file << "; extended FROSTT format\n" << rank << " " << nnz << '\n'; + for (uint64_t r = 0; r < rank; r++) + file << (r == 0 ? "" : " ") << sizes[r]; + file << '\n'; + for (uint64_t i = 0; i < nnz; i++) { + auto &idx = elements[i].indices; + for (uint64_t r = 0; r < rank; r++) + file << (idx[r] + 1) << " "; + // Unary plus makes 8-bit integer values print as numbers, not characters. + file << +elements[i].value << '\n'; + } + file.close(); // Also flushes the buffered output. + assert(file.good()); +} + } // namespace extern "C" { @@ -818,6 +845,17 @@ cursor, values, filled, added, count); \ } +#define IMPL_OUT(NAME, V) \ + void NAME(void *tensor, void *dest, bool sort) { \ + assert(tensor && dest); \ + auto coo = static_cast *>(tensor); \ + if (sort) \ + coo->sort(); \ + char *filename = static_cast(dest); \ + outSparseTensor(coo, filename); \ + delete coo; \ + } + // Assume index_t is in fact uint64_t, so that _mlir_ciface_newSparseTensor // can safely rewrite kIndex to kU64. We make this assertion to guarantee // that this file cannot get out of sync with its header. @@ -999,6 +1037,14 @@ IMPL_EXPINSERT(expInsertI16, int16_t) IMPL_EXPINSERT(expInsertI8, int8_t) +/// Helper to output a sparse tensor, one per value type.
+IMPL_OUT(outSparseTensorF64, double) +IMPL_OUT(outSparseTensorF32, float) +IMPL_OUT(outSparseTensorI64, int64_t) +IMPL_OUT(outSparseTensorI32, int32_t) +IMPL_OUT(outSparseTensorI16, int16_t) +IMPL_OUT(outSparseTensorI8, int8_t) + #undef CASE #undef IMPL_SPARSEVALUES #undef IMPL_GETOVERHEAD @@ -1006,6 +1052,7 @@ #undef IMPL_GETNEXT #undef IMPL_LEXINSERT #undef IMPL_EXPINSERT +#undef IMPL_OUT //===----------------------------------------------------------------------===// // diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir --- a/mlir/test/Dialect/SparseTensor/conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/conversion.mlir @@ -468,3 +468,29 @@ : tensor<8x8xf64, #SparseMatrix>, memref, memref, memref, memref, index return } + +// No reordering is needed here, so the sort flag (third runtime argument) is false. +// CHECK-LABEL: func @sparse_out1( +// CHECK-SAME: %[[A:.*]]: !llvm.ptr, +// CHECK-SAME: %[[B:.*]]: !llvm.ptr) +// CHECK-DAG: %[[C:.*]] = arith.constant false +// CHECK: %[[T:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]]) +// CHECK: call @outSparseTensorF64(%[[T]], %[[B]], %[[C]]) : (!llvm.ptr, !llvm.ptr, i1) -> () +// CHECK: return +func @sparse_out1(%arg0: tensor, %arg1: !llvm.ptr) { + sparse_tensor.out %arg0, %arg1 : tensor, !llvm.ptr + return +} + +// A non-identity dimension ordering makes the sort flag (third runtime argument) true. +// CHECK-LABEL: func @sparse_out2( +// CHECK-SAME: %[[A:.*]]: !llvm.ptr, +// CHECK-SAME: %[[B:.*]]: !llvm.ptr) +// CHECK-DAG: %[[C:.*]] = arith.constant true +// CHECK: %[[T:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]]) +// CHECK: call @outSparseTensorF32(%[[T]], %[[B]], %[[C]]) : (!llvm.ptr, !llvm.ptr, i1) -> () +// CHECK: return +func @sparse_out2(%arg0: tensor, %arg1: !llvm.ptr) { + sparse_tensor.out %arg0, %arg1 : tensor, !llvm.ptr + return +} diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir
b/mlir/test/Dialect/SparseTensor/invalid.mlir --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -204,3 +204,12 @@ %0 = sparse_tensor.convert %arg0 : tensor<10x?xf32> to tensor<10x10xf32, #CSR> return %0 : tensor<10x10xf32, #CSR> } + +// ----- + +// The output operation rejects tensors without a sparse encoding. +func @invalid_out_dense(%arg0: tensor<10xf64>, %arg1: !llvm.ptr) { + // expected-error@+1 {{expected a sparse tensor for output}} + sparse_tensor.out %arg0, %arg1 : tensor<10xf64>, !llvm.ptr + return +} diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -179,3 +179,17 @@ : tensor<8x8xf64, #SparseMatrix>, memref, memref, memref, memref, index return } + +// ----- + +#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> + +// CHECK-LABEL: func @sparse_out( +// CHECK-SAME: %[[A:.*]]: tensor>, +// CHECK-SAME: %[[B:.*]]: !llvm.ptr) +// CHECK: sparse_tensor.out %[[A]], %[[B]] : tensor>, !llvm.ptr +// CHECK: return +func @sparse_out(%arg0: tensor, %arg1: !llvm.ptr) { + sparse_tensor.out %arg0, %arg1 : tensor, !llvm.ptr + return +}