diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -541,14 +541,18 @@
 ///     coo->add(reshape(elem.indices), elem.value)
 ///   }
 ///   s = newSparseTensor(coo)
+template <typename ReshapeOp>
 static LogicalResult
-genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
-                        ArrayRef<ReassociationIndices> reassociation, Value src,
-                        RankedTensorType dstTp, RankedTensorType srcTp) {
-  Location loc = op->getLoc();
-  auto encDst = getSparseTensorEncoding(dstTp);
+genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
+                        ConversionPatternRewriter &rewriter) {
+  Location loc = op.getLoc();
+  auto srcTp = op.getSrc().getType().template cast<RankedTensorType>();
+  auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
   auto encSrc = getSparseTensorEncoding(srcTp);
-  assert(encDst && encSrc);
+  auto encDst = getSparseTensorEncoding(dstTp);
+  if (!encDst || !encSrc)
+    return failure();
+
   unsigned srcRank = srcTp.getRank();
   unsigned dstRank = dstTp.getRank();
   Type elemTp = srcTp.getElementType();
@@ -560,14 +564,16 @@
       encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
   SmallVector<Value> sizes;
   SmallVector<Value> params;
-  sizesFromPtr(rewriter, sizes, loc, noPerm, srcTp, src);
+  sizesFromSrc(rewriter, sizes, loc, op.getSrc());
   newParams(rewriter, params, loc, srcTp, noPerm, Action::kToIterator, sizes,
-            src);
+            adaptor.getSrc());
   Value iter = genNewCall(rewriter, loc, params);
   // Start a new COO for the destination tensor.
   sizes.clear();
   params.clear();
-  sizesFromPtr(rewriter, sizes, loc, encDst, dstTp, src);
+  // Fills sizes array using the sizes from destination type.
+  assert(dstTp.hasStaticShape());
+  sizesFromType(rewriter, sizes, loc, dstTp);
   newParams(rewriter, params, loc, dstTp, encDst, Action::kEmptyCOO, sizes);
   Value coo = genNewCall(rewriter, loc, params);
   Value dstPerm = params[2];
@@ -586,7 +592,8 @@
   // not need to store the value in elemPtr, as the value is still there.
   Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
   rewriter.setInsertionPointToStart(after);
-  translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx);
+  translateIndices(loc, rewriter, op.getReassociationIndices(), dstTp, srcTp,
+                   dstIdx, srcIdx);
   genAddEltCall(rewriter, loc, elemTp, coo, elemPtr, dstIdx, dstPerm);
   rewriter.create<scf::YieldOp>(loc);
   // Final call to construct sparse tensor storage and free temporary resources.
@@ -756,15 +763,7 @@
   LogicalResult
   matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type dstType = op.getResult().getType();
-    Type srcType = op.getSrc().getType();
-    auto encDst = getSparseTensorEncoding(dstType);
-    auto encSrc = getSparseTensorEncoding(srcType);
-    if (encDst && encSrc)
-      return genSparse2SparseReshape(
-          op, rewriter, op.getReassociationIndices(), adaptor.getOperands()[0],
-          dstType.cast<RankedTensorType>(), srcType.cast<RankedTensorType>());
-    return failure(); // handled elsewhere
+    return genSparse2SparseReshape(op, adaptor, rewriter);
   }
 };
 
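
Note (not part of the patch): the rewritten pattern still delegates per-element index remapping to `translateIndices`, which maps each source coordinate to a destination coordinate, group by group, according to the reassociation indices. The following is a minimal standalone C++ sketch of that idea for the expand-shape direction only, assuming row-major linearization with static destination sizes; the helper `translateExpand` and the driver are hypothetical illustrations, not the MLIR code path.

```cpp
// Sketch: per-element index translation for a sparse expand-shape reshape.
// Each source dimension maps to a group of destination dimensions; the source
// index is delinearized into that group's destination sizes (row-major).
// All helper names here are hypothetical.
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;
using Coord = std::vector<int64_t>;
using Group = std::vector<int64_t>; // destination dims in one reassociation group

static Coord translateExpand(const Coord &srcIdx, const Shape &dstShape,
                             const std::vector<Group> &reassociation) {
  Coord dstIdx(dstShape.size());
  for (size_t srcDim = 0; srcDim < reassociation.size(); ++srcDim) {
    int64_t linear = srcIdx[srcDim];
    const Group &group = reassociation[srcDim];
    // Delinearize 'linear' into this group's destination dims,
    // last dimension varying fastest.
    for (size_t i = group.size(); i-- > 0;) {
      int64_t dim = group[i];
      dstIdx[dim] = linear % dstShape[dim];
      linear /= dstShape[dim];
    }
    assert(linear == 0 && "index out of bounds for destination shape");
  }
  return dstIdx;
}

int main() {
  // Reshape a 2x6 tensor into 2x3x2: source dim 0 -> {0}, source dim 1 -> {1, 2}.
  Shape dstShape = {2, 3, 2};
  std::vector<Group> reassociation = {{0}, {1, 2}};
  Coord dst = translateExpand({1, 5}, dstShape, reassociation);
  assert(dst == (Coord{1, 2, 1})); // element (1, 5) lands at (1, 2, 1)
  return 0;
}
```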