diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -536,12 +536,10 @@
                                   {opaqueTp}, {fileName}, EmitCInterface::Off)
                        .getResult(0);
 
-    // Allocate a buffer for storing dimension sizes and indices.
+    // Allocate a temporary buffer for storing dimension sizes and indices.
     Type indexTp = rewriter.getIndexType();
-    auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
     uint64_t rank = dstTp.getRank();
-    Value dimSizes = rewriter.create<memref::AllocOp>(
-        loc, memTp, ValueRange{constantIndex(rewriter, loc, rank)});
+    Value dimSizes = genAlloca(rewriter, loc, rank, indexTp);
 
     // If the result tensor has dynamic dimensions, get the dynamic sizes from
     // the sparse tensor reader.
@@ -575,26 +573,27 @@
     Value nnz = createFuncCall(rewriter, loc, "getSparseTensorReaderNNZ",
                                {indexTp}, {reader}, EmitCInterface::Off)
                     .getResult(0);
+    Type eltTp = dstTp.getElementType();
+    Value value = genAllocaScalar(rewriter, loc, eltTp);
     scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, c0, nnz, c1);
     rewriter.setInsertionPointToStart(forOp.getBody());
-    Type eltTp = dstTp.getElementType();
     SmallString<18> getNextFuncName{"getSparseTensorReaderNext",
                                     primaryTypeFunctionSuffix(eltTp)};
     Value indices = dimSizes; // Reuse the indices memref to store indices.
-    Value value = createFuncCall(rewriter, loc, getNextFuncName, {eltTp},
-                                 {reader, indices}, EmitCInterface::On)
-                      .getResult(0);
+    createFuncCall(rewriter, loc, getNextFuncName, {eltTp},
+                   {reader, indices, value}, EmitCInterface::On)
+        .getResult(0);
     SmallVector<Value> indicesArray;
     for (uint64_t i = 0; i < rank; i++) {
       indicesArray.push_back(rewriter.create<memref::LoadOp>(
           loc, indices, constantIndex(rewriter, loc, i)));
     }
-    rewriter.create<InsertOp>(loc, value, cooBuffer, indicesArray);
+    Value v = rewriter.create<memref::LoadOp>(loc, value);
+    rewriter.create<InsertOp>(loc, v, cooBuffer, indicesArray);
     rewriter.setInsertionPointAfter(forOp);
-    // Release the indices buffer and the sparse tensor reader.
-    rewriter.create<memref::DeallocOp>(loc, indices);
+    // Release the sparse tensor reader.
     createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                    EmitCInterface::Off);
 
@@ -608,6 +607,70 @@
   }
 };
 
+struct OutRewriter : public OpRewritePattern<OutOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(OutOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Calculate NNZ.
+    Value src = op.getTensor();
+    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
+
+    // Allocate a temporary buffer for storing dimension sizes and indices.
+    auto srcTp = src.getType().template cast<RankedTensorType>();
+    uint64_t rank = srcTp.getRank();
+    Type indexTp = rewriter.getIndexType();
+    Value dimSizes = genAlloca(rewriter, loc, rank, indexTp);
+
+    // Generate code to calculate dimension size values and store the values to
+    // the buffer.
+    SmallVector<Value> dims;
+    sizesForTensor(rewriter, dims, loc, srcTp, src);
+    for (int64_t i = 0; i < rank; i++) {
+      rewriter.create<memref::StoreOp>(loc, dims[i], dimSizes,
+                                       constantIndex(rewriter, loc, i));
+    }
+
+    // Create a sparse tensor writer and output meta data.
+    Type opaqueTp = getOpaquePointerType(rewriter);
+    Value writer =
+        createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
+                       {op.getDest()}, EmitCInterface::Off)
+            .getResult(0);
+    Value rankValue = constantIndex(rewriter, loc, rank);
+    createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
+                   {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
+
+    Value indices = dimSizes; // Reuse the dimSizes buffer for indices.
+    Type eltTp = srcTp.getElementType();
+    SmallString<18> outNextFuncName{"outSparseTensorWriterNext",
+                                    primaryTypeFunctionSuffix(eltTp)};
+    Value value = genAllocaScalar(rewriter, loc, eltTp);
+    ModuleOp module = op->getParentOfType<ModuleOp>();
+    // For each element in the source tensor, output the element.
+    rewriter.create<ForeachOp>(
+        loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) {
+          for (int64_t i = 0; i < rank; i++) {
+            rewriter.create<memref::StoreOp>(loc, args[i], indices,
+                                             constantIndex(builder, loc, i));
+          }
+          rewriter.create<memref::StoreOp>(loc, args.back(), value);
+          SmallVector<Value> operands{writer, rankValue, indices, value};
+          FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
+                                         EmitCInterface::On);
+          builder.create<func::CallOp>(loc, TypeRange(), fn, operands);
+          builder.create<sparse_tensor::YieldOp>(loc);
+        });
+
+    // Release the writer.
+    createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
+                   EmitCInterface::Off);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 } // namespace
 
 //===---------------------------------------------------------------------===//
@@ -624,7 +687,7 @@
 
   // TODO: If RT not enabled, rewrite concatenate ops, etc here.
   if (!enableRT)
-    patterns.add<NewRewriter,
+    patterns.add<NewRewriter, OutRewriter,
                  Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
                  Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>>(
         patterns.getContext());
diff --git a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparse-tensor-rewrite=enable-runtime-library=false | FileCheck %s
+// RUN: mlir-opt %s -sparse-tensor-rewrite=enable-runtime-library=false | FileCheck %s
 
 #CSR = #sparse_tensor.encoding<{
   dimLevelType = ["dense", "compressed"]
 }>
@@ -10,19 +10,20 @@
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[R:.*]] = call @createSparseTensorReader(%[[A]])
-// CHECK: %[[DS:.*]] = memref.alloc(%[[C2]]) : memref<?xindex>
+// CHECK: %[[DS:.*]] = memref.alloca(%[[C2]]) : memref<?xindex>
 // CHECK: call @getSparseTensorReaderDimSizes(%[[R]], %[[DS]])
 // CHECK: %[[D0:.*]] = memref.load %[[DS]]{{\[}}%[[C0]]]
 // CHECK: %[[D1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
 // CHECK: %[[T:.*]] = bufferization.alloc_tensor(%[[D0]], %[[D1]])
 // CHECK: %[[N:.*]] = call @getSparseTensorReaderNNZ(%[[R]])
+// CHECK: %[[VB:.*]] = memref.alloca()
 // CHECK: scf.for %{{.*}} = %[[C0]] to %[[N]] step %[[C1]] {
-// CHECK: %[[V:.*]] = func.call @getSparseTensorReaderNextF32(%[[R]], %[[DS]])
+// CHECK: func.call @getSparseTensorReaderNextF32(%[[R]], %[[DS]], %[[VB]])
 // CHECK: %[[E0:.*]] = memref.load %[[DS]]{{\[}}%[[C0]]]
 // CHECK: %[[E1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
+// CHECK: %[[V:.*]] = memref.load %[[VB]][]
 // CHECK: sparse_tensor.insert %[[V]] into %[[T]]{{\[}}%[[E0]], %[[E1]]]
 // CHECK: }
-// CHECK: memref.dealloc %[[DS]]
 // CHECK: call @delSparseTensorReader(%[[R]])
 // CHECK: %[[R:.*]] = sparse_tensor.convert %[[T]]
 // CHECK: bufferization.dealloc_tensor %[[T]]
@@ -32,3 +33,31 @@
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #CSR>
   return %0 : tensor<?x?xf32, #CSR>
 }
+
+// CHECK-LABEL: func.func @sparse_out(
+// CHECK-SAME: %[[A:.*]]: tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>,
+// CHECK-SAME: %[[B:.*]]: !llvm.ptr<i8>) {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index
+// CHECK: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[A]]
+// CHECK: %[[DS:.*]] = memref.alloca(%[[C2]]) : memref<?xindex>
+// CHECK: memref.store %[[C10]], %[[DS]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK: memref.store %[[C20]], %[[DS]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK: %[[W:.*]] = call @createSparseTensorWriter(%[[B]])
+// CHECK: call @outSparseTensorWriterMetaData(%[[W]], %[[C2]], %[[NNZ]], %[[DS]])
+// CHECK: %[[V:.*]] = memref.alloca() : memref<f32>
+// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C10]] step %[[C1]] {
+// CHECK: scf.for {{.*}} {
+// CHECK: func.call @outSparseTensorWriterNextF32(%[[W]], %[[C2]], %[[DS]], %[[V]])
+// CHECK: }
+// CHECK: }
+// CHECK: call @delSparseTensorWriter(%[[W]])
+// CHECK: return
+// CHECK: }
+func.func @sparse_out( %arg0: tensor<10x20xf32, #CSR>, %arg1: !llvm.ptr<i8>) -> () {
+  sparse_tensor.out %arg0, %arg1 : tensor<10x20xf32, #CSR>, !llvm.ptr<i8>
+  return
+}