diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -295,6 +295,18 @@
   return rewriter.create<memref::LoadOp>(loc, values, ivs[0]);
 }
 
+/// Generates code to stack-allocate a `memref<?xindex>` where the `?`
+/// is the given `rank`. This array is intended to serve as a reusable
+/// buffer for storing the indices of a single tensor element, to avoid
+/// allocation in the body of loops.
+static Value allocaIndices(ConversionPatternRewriter &rewriter, Location loc,
+                           int64_t rank) {
+  auto indexTp = rewriter.getIndexType();
+  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
+  Value arg = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(rank));
+  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
+}
+
 //===----------------------------------------------------------------------===//
 // Conversion rules.
 //===----------------------------------------------------------------------===//
@@ -413,13 +425,9 @@
     // loop is generated by genAddElt().
     Location loc = op->getLoc();
     ShapedType shape = resType.cast<ShapedType>();
-    auto memTp =
-        MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType());
     Value perm;
     Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
-    Value arg = rewriter.create<ConstantOp>(
-        loc, rewriter.getIndexAttr(shape.getRank()));
-    Value ind = rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
+    Value ind = allocaIndices(rewriter, loc, shape.getRank());
     SmallVector<Value> lo;
     SmallVector<Value> hi;
     SmallVector<Value> st;