diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -358,6 +358,15 @@
   }
 };
 
+/// Generates an alloca of a `memref<?xindex>` buffer with `rank` elements.
+static Value allocaIndices(ConversionPatternRewriter &rewriter, Location loc,
+                           int64_t rank) {
+  auto indexTp = rewriter.getIndexType();
+  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
+  Value arg = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(rank));
+  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
+}
+
 /// Sparse conversion rule for the convert operator.
 class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -412,13 +421,9 @@
     // loop is generated by genAddElt().
     Location loc = op->getLoc();
     ShapedType shape = resType.cast<ShapedType>();
-    auto memTp =
-        MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType());
     Value perm;
     Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
-    Value arg = rewriter.create<ConstantOp>(
-        loc, rewriter.getIndexAttr(shape.getRank()));
-    Value ind = rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
+    Value ind = allocaIndices(rewriter, loc, shape.getRank());
     SmallVector<Value, 4> lo;
     SmallVector<Value, 4> hi;
     SmallVector<Value, 4> st;