diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -514,6 +515,97 @@
   }
 };
 
+/// Sparse rewriting rule for the new operator.
+struct NewRewriter : public OpRewritePattern<NewOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(NewOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
+    SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
+    if (!encDst) {
+      return failure();
+    }
+
+    // Create a sparse tensor reader.
+    Value fileName = op.getSource();
+    Type opaqueTp = getOpaquePointerType(rewriter);
+    Value reader = createFuncCall(rewriter, loc, "createSparseTensorReader",
+                                  {opaqueTp}, {fileName}, EmitCInterface::Off)
+                       .getResult(0);
+
+    // Allocate a buffer for storing dimension sizes and indices.
+    Type indexTp = rewriter.getIndexType();
+    auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
+    uint64_t rank = dstTp.getRank();
+    Value dimSizes = rewriter.create<memref::AllocOp>(
+        loc, memTp, ValueRange{constantIndex(rewriter, loc, rank)});
+
+    // If the result tensor has dynamic dimensions, get the dynamic sizes from
+    // the sparse tensor reader.
+    SmallVector<Value> dynSizesArray;
+    if (!dstTp.hasStaticShape()) {
+      createFuncCall(rewriter, loc, "getSparseTensorReaderDimSizes", {},
+                     {reader, dimSizes}, EmitCInterface::On)
+          .getResult(0);
+      ArrayRef<int64_t> dstShape = dstTp.getShape();
+      for (auto &d : llvm::enumerate(dstShape)) {
+        if (d.value() == ShapedType::kDynamicSize) {
+          dynSizesArray.push_back(rewriter.create<memref::LoadOp>(
+              loc, dimSizes, constantIndex(rewriter, loc, d.index())));
+        }
+      }
+    }
+
+    // Implement the NewOp as follows:
+    //   %tmp = bufferization.alloc_tensor : an unordered COO with identity
+    //                                       storage ordering
+    //   for i = 0 to nnz
+    //     get the next element from the input file
+    //     insert the element to %tmp
+    //   %t = sparse_tensor.ConvertOp %tmp
+    RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
+    auto cooBuffer =
+        rewriter.create<AllocTensorOp>(loc, cooTp, dynSizesArray).getResult();
+
+    Value c0 = constantIndex(rewriter, loc, 0);
+    Value c1 = constantIndex(rewriter, loc, 1);
+    Value nnz = createFuncCall(rewriter, loc, "getSparseTensorReaderNNZ",
+                               {indexTp}, {reader}, EmitCInterface::Off)
+                    .getResult(0);
+    scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, c0, nnz, c1);
+    rewriter.setInsertionPointToStart(forOp.getBody());
+
+    Type eltTp = dstTp.getElementType();
+    SmallString<18> getNextFuncName{"getSparseTensorReaderNext",
+                                    primaryTypeFunctionSuffix(eltTp)};
+    Value indices = dimSizes; // Reuse the indices memref to store indices.
+    Value value = createFuncCall(rewriter, loc, getNextFuncName, {eltTp},
+                                 {reader, indices}, EmitCInterface::On)
+                      .getResult(0);
+    SmallVector<Value> indicesArray;
+    for (int64_t i = 0; i < rank; i++) {
+      indicesArray.push_back(rewriter.create<memref::LoadOp>(
+          loc, indices, constantIndex(rewriter, loc, i)));
+    }
+    rewriter.create<InsertOp>(loc, value, cooBuffer, indicesArray);
+    rewriter.setInsertionPointAfter(forOp);
+
+    // Release the indices buffer and the sparse tensor reader.
+    rewriter.create<memref::DeallocOp>(loc, indices);
+    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
+                   EmitCInterface::Off);
+
+    Value newOp = rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, cooBuffer);
+
+    // Release the unordered COO tensor buffer.
+    rewriter.setInsertionPointAfterValue(newOp);
+    rewriter.create<DeallocTensorOp>(loc, cooBuffer);
+
+    return success();
+  }
+};
+
 } // namespace
 
 //===---------------------------------------------------------------------===//
@@ -527,7 +619,7 @@ }
       patterns.getContext());
   // TODO: If RT not enabled, rewrite concatenate ops, etc here.
   if (!enableRT)
-    patterns.add<ConcatenateRewriter,
+    patterns.add<ConcatenateRewriter, NewRewriter,
                  Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
                  Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>>(
         patterns.getContext());
diff --git a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s -sparse-tensor-rewrite=enable-runtime-library=false | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{
+  dimLevelType = ["dense", "compressed"]
+}>
+
+// CHECK-LABEL: func.func @sparse_new(
+// CHECK-SAME:  %[[A:.*]]: !llvm.ptr<i8>) -> tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> {
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:       %[[R:.*]] = call @createSparseTensorReader(%[[A]])
+// CHECK:       %[[DS:.*]] = memref.alloc(%[[C2]]) : memref<?xindex>
+// CHECK:       call @getSparseTensorReaderDimSizes(%[[R]], %[[DS]])
+// CHECK:       %[[D0:.*]] = memref.load %[[DS]]{{\[}}%[[C0]]]
+// CHECK:       %[[D1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
+// CHECK:       %[[T:.*]] = bufferization.alloc_tensor(%[[D0]], %[[D1]])
+// CHECK:       %[[N:.*]] = call @getSparseTensorReaderNNZ(%[[R]])
+// CHECK:       scf.for %{{.*}} = %[[C0]] to %[[N]] step %[[C1]] {
+// CHECK:         %[[V:.*]] = func.call @getSparseTensorReaderNextF32(%[[R]], %[[DS]])
+// CHECK:         %[[E0:.*]] = memref.load %[[DS]]{{\[}}%[[C0]]]
+// CHECK:         %[[E1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
+// CHECK:         sparse_tensor.insert %[[V]] into %[[T]]{{\[}}%[[E0]], %[[E1]]]
+// CHECK:       }
+// CHECK:       memref.dealloc %[[DS]]
+// CHECK:       call @delSparseTensorReader(%[[R]])
+// CHECK:       %[[R:.*]] = sparse_tensor.convert %[[T]]
+// CHECK:       bufferization.dealloc_tensor %[[T]]
+// CHECK:       return %[[R]]
+// CHECK:     }
+func.func @sparse_new(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
+  %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #CSR>
+  return %0 : tensor<?x?xf32, #CSR>
+}
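
For reference, the IR this rewrite produces for the @sparse_new test case looks roughly like the sketch below. It is a hand-written expansion of the CHECK lines above, not verbatim compiler output: the #COO encoding stands in for whatever getUnorderedCOOFromType derives for the CSR target, the SSA names and the private runtime declarations are illustrative assumptions inferred from the calls emitted by the pattern, and the exact calling convention of the reader functions is whatever the sparse runtime library declares.

// Illustrative sketch only; encodings, SSA names, and runtime signatures are assumed.
#CSR = #sparse_tensor.encoding<{ dimLevelType = ["dense", "compressed"] }>
#COO = #sparse_tensor.encoding<{ dimLevelType = ["compressed-nu", "singleton"] }>

// Assumed declarations for the sparse tensor reader runtime entry points.
func.func private @createSparseTensorReader(!llvm.ptr<i8>) -> !llvm.ptr<i8>
func.func private @getSparseTensorReaderDimSizes(!llvm.ptr<i8>, memref<?xindex>)
func.func private @getSparseTensorReaderNNZ(!llvm.ptr<i8>) -> index
func.func private @getSparseTensorReaderNextF32(!llvm.ptr<i8>, memref<?xindex>) -> f32
func.func private @delSparseTensorReader(!llvm.ptr<i8>)

func.func @sparse_new(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  // Open the file, then query the dimension sizes and the number of entries.
  %reader = call @createSparseTensorReader(%arg0) : (!llvm.ptr<i8>) -> !llvm.ptr<i8>
  %dims = memref.alloc(%c2) : memref<?xindex>
  call @getSparseTensorReaderDimSizes(%reader, %dims) : (!llvm.ptr<i8>, memref<?xindex>) -> ()
  %d0 = memref.load %dims[%c0] : memref<?xindex>
  %d1 = memref.load %dims[%c1] : memref<?xindex>
  %nnz = call @getSparseTensorReaderNNZ(%reader) : (!llvm.ptr<i8>) -> index
  // Insert every entry into an unordered COO buffer; the indices memref is
  // reused to receive the coordinates of each entry.
  %coo = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf32, #COO>
  scf.for %i = %c0 to %nnz step %c1 {
    %v = func.call @getSparseTensorReaderNextF32(%reader, %dims)
        : (!llvm.ptr<i8>, memref<?xindex>) -> f32
    %i0 = memref.load %dims[%c0] : memref<?xindex>
    %i1 = memref.load %dims[%c1] : memref<?xindex>
    sparse_tensor.insert %v into %coo[%i0, %i1] : tensor<?x?xf32, #COO>
  }
  // Release the temporary buffers and convert the COO tensor to CSR.
  memref.dealloc %dims : memref<?xindex>
  call @delSparseTensorReader(%reader) : (!llvm.ptr<i8>) -> ()
  %res = sparse_tensor.convert %coo : tensor<?x?xf32, #COO> to tensor<?x?xf32, #CSR>
  bufferization.dealloc_tensor %coo : tensor<?x?xf32, #COO>
  return %res : tensor<?x?xf32, #CSR>
}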