// Patch summary (reviewer preamble; everything from the first "diff --git" header onward is unchanged).
//
// SparseTensorConversion.cpp
//   * Adds the static helper genAlloc(rewriter, loc, sz, tp). It builds a MemRefType with a single
//     dynamic dimension of element type `tp` and creates one op from (loc, memTp, ValueRange{sz}).
//     Its doc comment states the resulting buffer must be explicitly deallocated by the client.
//   * In the expand lowering hunk, the values / filled / indices buffers switch from genAlloca to
//     genAlloc; the in-patch comment gives the reason (the expanded size may be rather large).
//   * In the CompressOp matchAndRewrite hunk, after replaceOpWithFuncCall the code walks `parent`
//     upward while parent->getParentOp() satisfies any of four isa checks, sets the insertion point
//     after that outermost op, and creates three ops on adaptor.getOperands()[2], [3] and [4]
//     (described by the in-patch comment as deallocating the buffers on exit of the loop nest).
//
// test/Dialect/SparseTensor/conversion.mlir
//   * @sparse_expansion: the previously prefix-less comment lines become real CHECK / CHECK-DAG
//     lines expecting memref.alloc, and the function now returns %added.
//   * @sparse_compression: adds CHECK-SAME lines for the five arguments and CHECK-DAG lines for
//     three memref.dealloc ops.
//
// test/Dialect/SparseTensor/sparse_expand.mlir (new file)
//   * Two RUN lines: `-sparsification` (prefix CHECK-SPARSE) and
//     `-sparsification -sparse-tensor-conversion` (prefix CHECK-CONVERT), both over @kernel.
//
// NOTE(review): in this copy the hunk lines are collapsed onto a few physical lines, and
//   angle-bracket arguments appear to be missing (e.g. `rewriter.create(`, `isa(`, bare `memref`
//   types) -- presumably lost in extraction. Confirm against the upstream patch before applying.
// NOTE(review): in the visible hunks, deallocation is emitted only by the CompressOp lowering.
//   Whether every lowered expand is paired with a compress is not visible here (the
//   @sparse_expansion test has no dealloc CHECK) -- TODO confirm buffer ownership.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -160,6 +160,16 @@ return rewriter.create(loc, memTp, ValueRange{sz}); } +/// Generates an uninitialized buffer of the given size and type, +/// but returns it as type `memref` (rather than as type +/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack, +/// this buffer must be explicitly deallocated by client. +static Value genAlloc(ConversionPatternRewriter &rewriter, Location loc, + Value sz, Type tp) { + auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp); + return rewriter.create(loc, memTp, ValueRange{sz}); +} + /// Generates an uninitialized temporary buffer of the given size and /// type, but returns it as type `memref` (rather than as type /// `memref<$sz x $tp>`). @@ -761,15 +771,18 @@ auto enc = getSparseTensorEncoding(srcType); Value src = adaptor.getOperands()[0]; Value sz = genDimSizeCall(rewriter, op, enc, src, srcType.getRank() - 1); - // Allocate temporary stack buffers for values, filled-switch, and indices. - Value values = genAlloca(rewriter, loc, sz, eltType); - Value filled = genAlloca(rewriter, loc, sz, boolType); - Value indices = genAlloca(rewriter, loc, sz, idxType); + // Allocate temporary buffers for values, filled-switch, and indices. + // We do not use stack buffers for this, since the expanded size may + // be rather large (as it envelops a single expanded dense dimension). + Value values = genAlloc(rewriter, loc, sz, eltType); + Value filled = genAlloc(rewriter, loc, sz, boolType); + Value indices = genAlloc(rewriter, loc, sz, idxType); Value zero = constantZero(rewriter, loc, idxType); // Reset the values/filled-switch to all-zero/false. 
Note that this // introduces an O(N) operation into the computation, but this reset // operation is amortized over the innermost loops for the access - // pattern expansion. + // pattern expansion. As noted in the operation doc, we would like + // to amortize this setup cost even between kernels. rewriter.create( loc, ValueRange{constantZero(rewriter, loc, eltType)}, ValueRange{values}); @@ -789,6 +802,7 @@ LogicalResult matchAndRewrite(CompressOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); // Note that this method call resets the values/filled-switch back to // all-zero/false by only iterating over the set elements, so the // complexity remains proportional to the sparsity of the expanded @@ -798,6 +812,18 @@ TypeRange noTp; replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(), EmitCInterface::On); + // Deallocate the buffers on exit of the loop nest. + Operation *parent = op; + for (; isa(parent->getParentOp()) || + isa(parent->getParentOp()) || + isa(parent->getParentOp()) || + isa(parent->getParentOp()); + parent = parent->getParentOp()) + ; + rewriter.setInsertionPointAfter(parent); + rewriter.create(loc, adaptor.getOperands()[2]); + rewriter.create(loc, adaptor.getOperands()[3]); + rewriter.create(loc, adaptor.getOperands()[4]); return success(); } }; diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir --- a/mlir/test/Dialect/SparseTensor/conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/conversion.mlir @@ -461,24 +461,31 @@ } // CHECK-LABEL: func @sparse_expansion() -// %[[S:.*]] = call @sparseDimSize -// %[[V:.*]] = memref.alloca(%[[S]]) : memref -// %[[F:.*]] = memref.alloca(%[[S]]) : memref -// %[[A:.*]] = memref.alloca(%[[S]]) : memref -// linalg.fill ins(%{{.*}} : f64) outs(%[[V]] : memref) -// linalg.fill ins(%{{.*}} : i1) outs(%[[F]] : memref) -// CHECK: return -func @sparse_expansion() { +// CHECK: %[[S:.*]] = 
call @sparseDimSize +// CHECK: %[[A:.*]] = memref.alloc(%[[S]]) : memref +// CHECK: %[[B:.*]] = memref.alloc(%[[S]]) : memref +// CHECK: %[[C:.*]] = memref.alloc(%[[S]]) : memref +// CHECK-DAG: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref) +// CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref) +// CHECK: return %[[C]] : memref +func @sparse_expansion() -> memref { %c = arith.constant 8 : index %0 = sparse_tensor.init [%c, %c] : tensor<8x8xf64, #SparseMatrix> %values, %filled, %added, %count = sparse_tensor.expand %0 : tensor<8x8xf64, #SparseMatrix> to memref, memref, memref, index - return + return %added : memref } // CHECK-LABEL: func @sparse_compression( -// CHECK-SAME: %[[A:.*]]: !llvm.ptr, +// CHECK-SAME: %[[A:.*0]]: !llvm.ptr, +// CHECK-SAME: %[[B:.*1]]: memref, +// CHECK-SAME: %[[C:.*2]]: memref, +// CHECK-SAME: %[[D:.*3]]: memref, +// CHECK-SAME: %[[E:.*4]]: memref, // CHECK: call @expInsertF64(%[[A]], +// CHECK-DAG: memref.dealloc %[[C]] : memref +// CHECK-DAG: memref.dealloc %[[D]] : memref +// CHECK-DAG: memref.dealloc %[[E]] : memref // CHECK: return func @sparse_compression(%arg0: tensor<8x8xf64, #SparseMatrix>, %arg1: memref, %arg2: memref, %arg3: memref, diff --git a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir @@ -0,0 +1,61 @@ +// RUN: mlir-opt %s -sparsification | \ +// RUN: FileCheck %s --check-prefix=CHECK-SPARSE +// RUN: mlir-opt %s -sparsification -sparse-tensor-conversion | \ +// RUN: FileCheck %s --check-prefix=CHECK-CONVERT + +#DCSC = #sparse_tensor.encoding<{ + dimLevelType = [ "compressed", "compressed" ], + dimOrdering = affine_map<(i,j) -> (j,i)> +}> + +#SV = #sparse_tensor.encoding<{ + dimLevelType = [ "compressed" ] +}> + +#rowsum = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (i)> // x (out) + ], + iterator_types = ["parallel", 
"reduction"], + doc = "X(i) = SUM A(i,j)" +} + +// +// CHECK-SPARSE-LABEL: func @kernel( +// CHECK-SPARSE: %[[A:.*]], %[[B:.*]], %[[C:.*]], %{{.*}} = sparse_tensor.expand +// CHECK-SPARSE: scf.for +// CHECK-SPARSE: scf.for +// CHECK-SPARSE: sparse_tensor.compress %{{.*}}, %{{.*}}, %[[A]], %[[B]], %[[C]] +// CHECK-SPARSE: %[[RET:.*]] = sparse_tensor.load %{{.*}} hasInserts +// CHECK-SPARSE: return %[[RET]] +// +// CHECK-CONVERT-LABEL: func @kernel( +// CHECK-CONVERT: %{{.*}} = call @sparseDimSize +// CHECK-CONVERT: %[[S:.*]] = call @sparseDimSize +// CHECK-CONVERT: %[[A:.*]] = memref.alloc(%[[S]]) : memref +// CHECK-CONVERT: %[[B:.*]] = memref.alloc(%[[S]]) : memref +// CHECK-CONVERT: %[[C:.*]] = memref.alloc(%[[S]]) : memref +// CHECK-CONVERT: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref) +// CHECK-CONVERT: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref) +// CHECK-CONVERT: scf.for +// CHECK-CONVERT: scf.for +// CHECK-CONVERT: call @expInsertF64 +// CHECK-CONVERT: memref.dealloc %[[A]] : memref +// CHECK-CONVERT: memref.dealloc %[[B]] : memref +// CHECK-CONVERT: memref.dealloc %[[C]] : memref +// CHECK-CONVERT: call @endInsert +// +func @kernel(%arga: tensor) -> tensor { + %c0 = arith.constant 0 : index + %n = tensor.dim %arga, %c0 : tensor + %v = sparse_tensor.init [%n] : tensor + %0 = linalg.generic #rowsum + ins(%arga: tensor) + outs(%v: tensor) { + ^bb(%a: f64, %x: f64): + %1 = arith.addf %x, %a : f64 + linalg.yield %1 : f64 + } -> tensor + return %0 : tensor +}