diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -437,6 +437,56 @@
   assert((isExpand && i == dstRank) || (!isExpand && i == srcRank));
 }
 
+/// Helper method to translate indices during a reshaping operation.
+/// Unlike the static-shape variant above, strides are computed with
+/// arithmetic on the runtime dimension sizes, so dynamic shapes are
+/// supported as well.
+/// TODO: provide as general utility to MLIR at large?
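+///
+/// For example, collapsing indices (i0, i1) of a dynamically sized d0 x d1
+/// source yields the destination index i0 * d1 + i1; e.g. (2, 3) in a 4x5
+/// source maps to 2 * 5 + 3 = 13. Expansion inverts this mapping with
+/// division and remainder.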
+static void translateIndicesDyn(Location loc,
+                                ConversionPatternRewriter &rewriter,
+                                ArrayRef<ReassociationIndices> reassociation,
+                                TensorType dstTp, TensorType srcTp,
+                                Value dstIdx, Value srcIdx,
+                                ValueRange srcSizes, ValueRange dstSizes) {
+  unsigned dstRank = dstTp.getRank();
+  unsigned srcRank = srcTp.getRank();
+  unsigned start = 0;
+  unsigned i = 0;
+  bool isExpand = srcRank < dstRank;
+  // Strides come from the side that owns the multi-dimensional index space.
+  ValueRange sizes = isExpand ? dstSizes : srcSizes;
+  // Iterate over reassociation map.
+  for (const auto &map : llvm::enumerate(reassociation)) {
+    // Prepare strides information in dimension slice.
+    Value linear = constantIndex(rewriter, loc, 1);
+    for (unsigned j = start, end = start + map.value().size(); j < end; j++)
+      linear = rewriter.create<arith::MulIOp>(loc, linear, sizes[j]);
+    // Start expansion.
+    Value idx = constantIndex(rewriter, loc, i++);
+    Value val;
+    if (isExpand)
+      val = rewriter.create<memref::LoadOp>(loc, srcIdx, idx);
+    // Iterate over dimension slice.
+    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
+      linear = rewriter.create<arith::DivUIOp>(loc, linear, sizes[j]);
+      Value jdx = constantIndex(rewriter, loc, j);
+      if (!isExpand) {
+        Value old = rewriter.create<memref::LoadOp>(loc, srcIdx, jdx);
+        Value mul = rewriter.create<arith::MulIOp>(loc, old, linear);
+        val = val ? rewriter.create<arith::AddIOp>(loc, val, mul) : mul;
+      } else {
+        Value old = val;
+        val = rewriter.create<arith::DivUIOp>(loc, val, linear);
+        rewriter.create<memref::StoreOp>(loc, val, dstIdx, jdx);
+        val = rewriter.create<arith::RemUIOp>(loc, old, linear);
+      }
+    }
+    // Finalize collapse.
+    if (!isExpand)
+      rewriter.create<memref::StoreOp>(loc, val, dstIdx, idx);
+    start += map.value().size();
+  }
+  // Sanity.
+  assert((!isExpand && i == dstRank) || (isExpand && i == srcRank));
+}
+
 /// Generate code for a general sparse to sparse reshaping operation.
 /// Note that unlike dense reshaping (which can be done with a "cheap"
 /// change of view), sparse reshaping is currently done with actual
@@ -468,17 +518,17 @@
   auto noPerm = SparseTensorEncodingAttr::get(
       op->getContext(), encSrc.getDimLevelType(), AffineMap(),
       encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
-  SmallVector<Value, 4> sizes;
+  SmallVector<Value, 4> srcSizes;
   SmallVector<Value, 8> params;
-  sizesFromPtr(rewriter, sizes, op, noPerm, srcTp, src);
-  newParams(rewriter, params, op, srcTp, noPerm, Action::kToIterator, sizes,
+  sizesFromPtr(rewriter, srcSizes, op, noPerm, srcTp, src);
+  newParams(rewriter, params, op, srcTp, noPerm, Action::kToIterator, srcSizes,
             src);
   Value iter = genNewCall(rewriter, op, params);
   // Start a new COO for the destination tensor.
-  sizes.clear();
+  SmallVector<Value, 4> dstSizes;
   params.clear();
-  sizesFromPtr(rewriter, sizes, op, encDst, dstTp, src);
-  newParams(rewriter, params, op, dstTp, encDst, Action::kEmptyCOO, sizes);
+  sizesFromPtr(rewriter, dstSizes, op, encDst, dstTp, src);
+  newParams(rewriter, params, op, dstTp, encDst, Action::kEmptyCOO, dstSizes);
   Value coo = genNewCall(rewriter, op, params);
   Value dstPerm = params[2];
   // Construct a while loop over the iterator.
@@ -496,7 +546,12 @@
   // not need to store the value in elemPtr, as the value is still there.
   Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
   rewriter.setInsertionPointToStart(after);
-  translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx);
+  if (srcTp.hasStaticShape() && dstTp.hasStaticShape()) {
+    translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx);
+  } else {
+    translateIndicesDyn(loc, rewriter, reassociation, dstTp, srcTp,
+                        dstIdx, srcIdx, srcSizes, dstSizes);
+  }
   genAddEltCall(rewriter, op, elemTp, coo, elemPtr, dstIdx, dstPerm);
   rewriter.create<scf::YieldOp>(loc);
   // Final call to construct sparse tensor storage and free temporary resources.
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape_dyn.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape_dyn.mlir
new file
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape_dyn.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext
+
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed", "compressed"]
+}>
+
+#SparseVector = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"]
+}>
+
+module {
+  func.func @collapse_dyn(%arg0 : tensor<?x?xf64, #SparseMatrix>) -> tensor<?xf64, #SparseVector> {
+    %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf64, #SparseMatrix> into tensor<?xf64, #SparseVector>
+    return %0 : tensor<?xf64, #SparseVector>
+  }
+
+  func.func @entry() {
+    %m = arith.constant dense<[
+        [ 1.0, 3.0 ]
+    ]> : tensor<1x2xf64>
+    %sm = sparse_tensor.convert %m : tensor<1x2xf64> to tensor<1x2xf64, #SparseMatrix>
+    %dsm = tensor.cast %sm : tensor<1x2xf64, #SparseMatrix> to tensor<?x?xf64, #SparseMatrix>
+    %c = call @collapse_dyn(%dsm) : (tensor<?x?xf64, #SparseMatrix>) -> tensor<?xf64, #SparseVector>
+
+    bufferization.dealloc_tensor %sm : tensor<1x2xf64, #SparseMatrix>
+    bufferization.dealloc_tensor %c : tensor<?xf64, #SparseVector>
+    return
+  }
+}
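For intuition, the index translation above is ordinary row-major linearization and de-linearization, just emitted as arith ops on runtime sizes instead of folded constants. A minimal standalone sketch of the same arithmetic follows (plain C++; the helper names collapseGroup/expandGroup are hypothetical and not part of the patch):

#include <cstdint>
#include <vector>

// Collapse one reassociation group: fold the group's source indices into a
// single linear destination index using row-major strides.
static uint64_t collapseGroup(const std::vector<uint64_t> &srcIdx,
                              const std::vector<uint64_t> &sizes) {
  uint64_t linear = 1;
  for (uint64_t s : sizes)
    linear *= s; // product of all dimension sizes in the group
  uint64_t val = 0;
  for (size_t j = 0; j < sizes.size(); ++j) {
    linear /= sizes[j];        // stride of dimension j
    val += srcIdx[j] * linear; // accumulate this dimension's contribution
  }
  return val;
}

// Expand one reassociation group: invert the linearization with division
// and remainder, yielding one destination index per dimension.
static std::vector<uint64_t> expandGroup(uint64_t srcIdx,
                                         const std::vector<uint64_t> &sizes) {
  uint64_t linear = 1;
  for (uint64_t s : sizes)
    linear *= s;
  std::vector<uint64_t> dst(sizes.size());
  for (size_t j = 0; j < sizes.size(); ++j) {
    linear /= sizes[j];
    dst[j] = srcIdx / linear; // quotient selects the index for dimension j
    srcIdx %= linear;         // remainder carries to the next dimension
  }
  return dst;
}

For instance, collapseGroup({2, 3}, {4, 5}) returns 13 and expandGroup(13, {4, 5}) returns {2, 3}, mirroring the muli/addi and divui/remui sequences the patch emits per reassociation group.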