diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -1414,20 +1414,18 @@ createOrFoldDimCall(rewriter, loc, srcTp, adaptedOp, concatDim); offset = rewriter.create(loc, offset, curDim); } - if (dstTp.hasEncoding()) { - if (!allDense) { - // In sparse output case, the destination holds the COO. - Value coo = dst; - dst = params.genNewCall(Action::kFromCOO, coo); - // Release resources. - genDelCOOCall(rewriter, loc, elemTp, coo); - } else { - dst = dstTensor; - } - rewriter.replaceOp(op, dst); - } else { + if (!dstTp.hasEncoding()) { rewriter.replaceOpWithNewOp( op, dstTp.getRankedTensorType(), dst); + } else if (allDense) { + rewriter.replaceOp(op, dstTensor); + } else { + // In sparse output case, the destination holds the COO. + Value coo = dst; + dst = params.genNewCall(Action::kFromCOO, coo); + // Release resources. + genDelCOOCall(rewriter, loc, elemTp, coo); + rewriter.replaceOp(op, dst); } return success(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -513,13 +513,11 @@ } needTmpCOO = !allDense && !allOrdered; + const RankedTensorType tp = needTmpCOO ? getUnorderedCOOFromType(dstTp) + : dstTp.getRankedTensorType(); + encDst = needTmpCOO ? getSparseTensorEncoding(tp) : encDst; SmallVector dynSizes; getDynamicSizes(dstTp, sizes, dynSizes); - RankedTensorType tp = dstTp; - if (needTmpCOO) { - tp = getUnorderedCOOFromType(dstTp); - encDst = getSparseTensorEncoding(tp); - } dst = rewriter.create(loc, tp, dynSizes).getResult(); if (allDense) { // Create a view of the values buffer to match the unannotated dense @@ -592,21 +590,20 @@ // Temp variable to avoid needing to call `getRankedTensorType` // in the three use-sites below. const RankedTensorType dstRTT = dstTp; - if (encDst) { - if (!allDense) { - dst = rewriter.create(loc, dst, true); - if (needTmpCOO) { - Value tmpCoo = dst; - dst = rewriter.create(loc, dstRTT, tmpCoo).getResult(); - rewriter.create(loc, tmpCoo); - } - } else { - dst = rewriter.create(loc, dstRTT, annotatedDenseDst) - .getResult(); + if (!encDst) { + rewriter.replaceOpWithNewOp(op, dstRTT, dst); + } else if (allDense) { + rewriter.replaceOp( + op, rewriter.create(loc, dstRTT, annotatedDenseDst) + .getResult()); + } else { + dst = rewriter.create(loc, dst, true); + if (needTmpCOO) { + Value tmpCoo = dst; + dst = rewriter.create(loc, dstRTT, tmpCoo).getResult(); + rewriter.create(loc, tmpCoo); } rewriter.replaceOp(op, dst); - } else { - rewriter.replaceOpWithNewOp(op, dstRTT, dst); } return success(); }