[mlir][sparse] Flatten result-replacement branches (NFC)

Restructure the final "replace the op" logic in SparseTensorConversion.cpp
and SparseTensorRewriting.cpp into a single if / else-if / else chain, so
each destination kind (unencoded dense, annotated all-dense, sparse built
through a temporary COO) is handled in exactly one branch. In
SparseTensorRewriting.cpp the buffer type `tp` is now computed once as a
const value instead of being patched after declaration. Each new branch
chain selects the same calls, in the same order, as the nested code it
replaces, so no functional change is intended.

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -1413,20 +1413,18 @@
           createOrFoldDimCall(rewriter, loc, srcTp, adaptedOp, concatDim);
       offset = rewriter.create<arith::AddIOp>(loc, offset, curDim);
     }
-    if (dstTp.hasEncoding()) {
-      if (!allDense) {
-        // In sparse output case, the destination holds the COO.
-        Value coo = dst;
-        dst = params.genNewCall(Action::kFromCOO, coo);
-        // Release resources.
-        genDelCOOCall(rewriter, loc, elemTp, coo);
-      } else {
-        dst = dstTensor;
-      }
+    if (!dstTp.hasEncoding()) {
+      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(
+          op, dstTp.getRankedTensorType(), dst);
+    } else if (allDense) {
+      rewriter.replaceOp(op, dstTensor);
+    } else {
+      // In sparse output case, the destination holds the COO.
+      Value coo = dst;
+      dst = params.genNewCall(Action::kFromCOO, coo);
+      // Release resources.
+      genDelCOOCall(rewriter, loc, elemTp, coo);
       rewriter.replaceOp(op, dst);
-    } else {
-      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(
-          op, dstTp.getRankedTensorType(), dst);
     }
     return success();
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -519,13 +519,12 @@
       }
 
       needTmpCOO = !allDense && !allOrdered;
+      const RankedTensorType tp = needTmpCOO
+                                      ? getUnorderedCOOFromType(dstTp)
+                                      : static_cast<RankedTensorType>(dstTp);
+      encDst = needTmpCOO ? getSparseTensorEncoding(tp) : encDst;
       SmallVector<Value> dynSizes;
       getDynamicSizes(dstTp, sizes, dynSizes);
-      RankedTensorType tp = dstTp;
-      if (needTmpCOO) {
-        tp = getUnorderedCOOFromType(dstTp);
-        encDst = getSparseTensorEncoding(tp);
-      }
       dst = rewriter.create<AllocTensorOp>(loc, tp, dynSizes).getResult();
       if (allDense) {
         // Create a view of the values buffer to match the unannotated dense
@@ -598,21 +597,20 @@
     // Temp variable to avoid needing to call `getRankedTensorType`
     // in the three use-sites below.
     const RankedTensorType dstRTT = dstTp;
-    if (encDst) {
-      if (!allDense) {
-        dst = rewriter.create<LoadOp>(loc, dst, true);
-        if (needTmpCOO) {
-          Value tmpCoo = dst;
-          dst = rewriter.create<ConvertOp>(loc, dstRTT, tmpCoo).getResult();
-          rewriter.create<DeallocTensorOp>(loc, tmpCoo);
-        }
-      } else {
-        dst = rewriter.create<ConvertOp>(loc, dstRTT, annotatedDenseDst)
-                  .getResult();
-      }
-      rewriter.replaceOp(op, dst);
-    } else {
+    if (!encDst) {
       rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstRTT, dst);
+    } else if (allDense) {
+      rewriter.replaceOp(
+          op, rewriter.create<ConvertOp>(loc, dstRTT, annotatedDenseDst)
+                  .getResult());
+    } else {
+      dst = rewriter.create<LoadOp>(loc, dst, true);
+      if (needTmpCOO) {
+        Value tmpCoo = dst;
+        dst = rewriter.create<ConvertOp>(loc, dstRTT, tmpCoo).getResult();
+        rewriter.create<DeallocTensorOp>(loc, tmpCoo);
+      }
+      rewriter.replaceOp(op, dst);
     }
     return success();
   }