diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -388,20 +388,31 @@ static void translateIndices(Location loc, ConversionPatternRewriter &rewriter, ArrayRef reassociation, TensorType dstTp, TensorType srcTp, Value dstIdx, - Value srcIdx) { + Value srcIdx, SmallVector dstShape, + SmallVector srcShape) { unsigned dstRank = dstTp.getRank(); unsigned srcRank = srcTp.getRank(); unsigned start = 0; unsigned i = 0; bool isExpand = srcRank > dstRank; + bool isStatic = srcTp.hasStaticShape(); ArrayRef shape = isExpand ? srcTp.getShape() : dstTp.getShape(); + SmallVector shape_dyn = isExpand ? srcShape : dstShape; // Iterate over reassociation map. for (const auto &map : llvm::enumerate(reassociation)) { // Prepare strides information in dimension slice. uint64_t linear = 1; - for (unsigned j = start, end = start + map.value().size(); j < end; j++) { - assert(!ShapedType::isDynamic(shape[j])); - linear *= shape[j]; + Value linear_dyn; + if (isStatic) { + for (unsigned j = start, end = start + map.value().size(); j < end; j++) { + linear *= shape[j]; + } + } else { + linear_dyn = constantIndex(rewriter, loc, 1); + for (unsigned j = start, end = start + map.value().size(); j < end; j++) { + linear_dyn = rewriter.create( + loc, linear_dyn, shape_dyn[j]); + } } // Start collapse. Value idx = constantIndex(rewriter, loc, i++); @@ -410,21 +421,27 @@ val = rewriter.create(loc, srcIdx, idx); // Iterate over dimension slice. for (unsigned j = start, end = start + map.value().size(); j < end; j++) { - linear /= shape[j]; - Value stride = constantIndex(rewriter, loc, linear); + if (isStatic) { + linear /= shape[j]; + } else { + linear_dyn = rewriter.create( + loc, linear_dyn, shape_dyn[j]); + } + Value stride = isStatic ? constantIndex(rewriter, loc, linear) + : linear_dyn; Value jdx = constantIndex(rewriter, loc, j); if (isExpand) { Value old = rewriter.create(loc, srcIdx, jdx); - Value mul = linear == 1 + Value mul = isStatic && linear == 1 ? old : rewriter.create(loc, old, stride); val = val ? rewriter.create(loc, val, mul) : mul; } else { Value old = val; - if (linear != 1) + if (!isStatic || linear != 1) val = rewriter.create(loc, val, stride); rewriter.create(loc, val, dstIdx, jdx); - if (linear != 1) + if (!isStatic || linear != 1) val = rewriter.create(loc, old, stride); } } @@ -437,6 +454,30 @@ assert((isExpand && i == dstRank) || (!isExpand && i == srcRank)); } +/// Helper method to compute the shape of destination tensor of a reshape +/// operator. This is only used when operands have dynamic shape. +static SmallVector genReshapeDstShape( + Location loc, + ConversionPatternRewriter &rewriter, + SmallVector srcShape, + ArrayRef reassociation) { + // TODO: Support computing dst shape for expand_shape operator. + assert(reassociation.size() <= srcShape.size()); + SmallVector dstShape; + unsigned start = 0; + for (const auto &map : llvm::enumerate(reassociation)) { + auto dstDim = constantIndex(rewriter, loc, 1); + for (unsigned i = start; i < start + map.value().size(); i++) { + dstDim = rewriter.create(loc, dstDim, srcShape[i]); + } + dstShape.push_back(dstDim); + start = start + map.value().size(); + } + // Sanity. + assert(start == srcShape.size()); + return dstShape; +} + /// Generate code for a general sparse to sparse reshaping operation. /// Note that unlike dense reshaping (which can be done with a "cheap" /// change of view), sparse reshaping is currently done with actual @@ -468,17 +509,21 @@ auto noPerm = SparseTensorEncodingAttr::get( op->getContext(), encSrc.getDimLevelType(), AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); - SmallVector sizes; + SmallVector srcSizes; SmallVector params; - sizesFromPtr(rewriter, sizes, op, noPerm, srcTp, src); - newParams(rewriter, params, op, srcTp, noPerm, Action::kToIterator, sizes, + sizesFromPtr(rewriter, srcSizes, op, noPerm, srcTp, src); + newParams(rewriter, params, op, srcTp, noPerm, Action::kToIterator, srcSizes, src); Value iter = genNewCall(rewriter, op, params); // Start a new COO for the destination tensor. - sizes.clear(); + SmallVector dstSizes; params.clear(); - sizesFromPtr(rewriter, sizes, op, encDst, dstTp, src); - newParams(rewriter, params, op, dstTp, encDst, Action::kEmptyCOO, sizes); + if (srcTp.hasStaticShape()) { + sizesFromPtr(rewriter, dstSizes, op, encDst, dstTp, src); + } else { + dstSizes = genReshapeDstShape(loc, rewriter, srcSizes, reassociation); + } + newParams(rewriter, params, op, dstTp, encDst, Action::kEmptyCOO, dstSizes); Value coo = genNewCall(rewriter, op, params); Value dstPerm = params[2]; // Construct a while loop over the iterator. @@ -496,7 +541,8 @@ // not need to store the value in elemPtr, as the value is still there. Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes); rewriter.setInsertionPointToStart(after); - translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx); + translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx, + dstSizes, srcSizes); genAddEltCall(rewriter, op, elemTp, coo, elemPtr, dstIdx, dstPerm); rewriter.create(loc); // Final call to construct sparse tensor storage and free temporary resources. diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir @@ -135,6 +135,11 @@ return %0 : tensor } + func.func @collapse_sparse2sparse_dyn(%arg0: tensor) -> tensor { + %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] : tensor into tensor + return %0 : tensor + } + // // Main driver. // @@ -189,6 +194,7 @@ %collapse8 = call @collapse_dense_dyn(%dn) : (tensor) -> tensor %collapse9 = call @collapse_from_sparse_dyn(%sdn) : (tensor) -> tensor %collapse10 = call @collapse_to_sparse_dyn(%dn) : (tensor) -> tensor + %collapse11 = call @collapse_sparse2sparse_dyn(%sdn) : (tensor) -> tensor // // Verify results of expand @@ -250,6 +256,7 @@ // CHECK-NEXT: ( ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ), ( 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ), ( 21, 22, 23, 24, 25, 26, 27, 28, 29, 30 ), ( 31, 32, 33, 34, 35, 36, 37, 38, 39, 40 ), ( 41, 42, 43, 44, 45, 26, 47, 48, 49, 50 ), ( 51, 52, 53, 54, 55, 56, 57, 58, 59, 60 ) ) // CHECK-NEXT: ( ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ), ( 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ), ( 21, 22, 23, 24, 25, 26, 27, 28, 29, 30 ), ( 31, 32, 33, 34, 35, 36, 37, 38, 39, 40 ), ( 41, 42, 43, 44, 45, 26, 47, 48, 49, 50 ), ( 51, 52, 53, 54, 55, 56, 57, 58, 59, 60 ) ) // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 26, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, -1, -1, -1, -1 ) + // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 26, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, -1, -1, -1, -1 ) // %v0 = vector.transfer_read %collapse0[%c0], %df: tensor<12xf64>, vector<12xf64> @@ -281,6 +288,10 @@ %b10 = sparse_tensor.values %collapse10 : tensor to memref %v10 = vector.transfer_read %b10[%c0], %df: memref, vector<64xf64> vector.print %v10 : vector<64xf64> + %b11 = sparse_tensor.values %collapse11 : tensor to memref + %v11 = vector.transfer_read %b11[%c0], %df: memref, vector<64xf64> + vector.print %v11 : vector<64xf64> + // Release sparse resources. bufferization.dealloc_tensor %sv : tensor<12xf64, #SparseVector> @@ -298,6 +309,7 @@ bufferization.dealloc_tensor %collapse6 : tensor<6x10xf64, #SparseMatrix> bufferization.dealloc_tensor %collapse7 : tensor<6x10xf64, #SparseMatrix> bufferization.dealloc_tensor %collapse10 : tensor + bufferization.dealloc_tensor %collapse11 : tensor // Release dense resources. bufferization.dealloc_tensor %expand1 : tensor<3x4xf64>