diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -195,6 +195,25 @@
                     static_cast<uint8_t>(dimLevelTypeEncoding(dlt)));
 }
 
+/// Computes the shape of the destination tensor of a reshape operator and
+/// appends it to `dstShape`, one `Value` per destination dimension. This is
+/// only used when operands have dynamic shape.
+///
+/// For a collapse (fewer reassociation groups than source dimensions), each
+/// destination dimension is the product of the source dimensions in its
+/// group. For an expand (one group per source dimension), static destination
+/// sizes are materialized as constants and a dynamic destination size is the
+/// source size divided by the product of the static sizes in its group; at
+/// most one dynamic size per group is supported.
+///
+/// Takes an `OpBuilder` so that any rewriter (e.g. `PatternRewriter` or
+/// `ConversionPatternRewriter`) or a plain builder can be passed.
+void genReshapeDstShape(Location loc, OpBuilder &builder,
+                        SmallVectorImpl<Value> &dstShape,
+                        ArrayRef<Value> srcShape,
+                        ArrayRef<int64_t> staticDstShape,
+                        ArrayRef<ReassociationIndices> reassociation);
+
 /// Helper method to translate indices during a reshaping operation.
 void translateIndicesArray(OpBuilder &builder, Location loc,
                            ArrayRef<ReassociationIndices> reassociation,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -200,6 +200,61 @@
   llvm_unreachable("Non-numeric type");
 }
 
+void mlir::sparse_tensor::genReshapeDstShape(
+    Location loc, OpBuilder &builder, SmallVectorImpl<Value> &dstShape,
+    ArrayRef<Value> srcShape, ArrayRef<int64_t> staticDstShape,
+    ArrayRef<ReassociationIndices> reassociation) {
+  // Collapse shape: each destination dimension is the product of the source
+  // dimensions in its reassociation group.
+  if (reassociation.size() < srcShape.size()) {
+    unsigned start = 0;
+    for (const ReassociationIndices &group : reassociation) {
+      Value dstDim = constantIndex(builder, loc, 1);
+      for (unsigned i = start, e = start + group.size(); i < e; i++)
+        dstDim = builder.create<arith::MulIOp>(loc, dstDim, srcShape[i]);
+      dstShape.push_back(dstDim);
+      start += group.size();
+    }
+    assert(start == srcShape.size());
+    return;
+  }
+
+  // Expand shape: the i-th source dimension is expanded into the destination
+  // dimensions listed in the i-th reassociation group.
+  assert(reassociation.size() == srcShape.size());
+  unsigned start = 0;
+  for (unsigned i = 0, size = srcShape.size(); i < size; i++) {
+    // Bind by reference to avoid copying the reassociation group on every
+    // iteration.
+    const ReassociationIndices &group = reassociation[i];
+    Value srcDim = srcShape[i];
+    unsigned end = start + group.size();
+    // Iterate through dimensions expanded from the i-th dimension.
+    for (unsigned j = start; j < end; j++) {
+      // There can be only one dynamic sized dimension among dimensions expanded
+      // from the i-th dimension in srcShape. For example, if srcDim = 8, then
+      // the expanded shape could be <2x?x2>, but not <2x?x?>.
+      if (staticDstShape[j] == ShapedType::kDynamicSize) {
+        // The expanded dimension has dynamic size. We compute the dimension
+        // by dividing srcDim by the product of the static dimensions.
+        int64_t product = 1;
+        for (unsigned k = start; k < end; k++)
+          if (staticDstShape[k] != ShapedType::kDynamicSize)
+            product *= staticDstShape[k];
+        Value productVal = constantIndex(builder, loc, product);
+        Value dynamicSize =
+            builder.create<arith::DivUIOp>(loc, srcDim, productVal);
+        dstShape.push_back(dynamicSize);
+      } else {
+        // The expanded dimension is statically known.
+        dstShape.push_back(constantIndex(builder, loc, staticDstShape[j]));
+      }
+    }
+    start = end;
+  }
+  assert(start == staticDstShape.size());
+}
+
 void mlir::sparse_tensor::translateIndicesArray(
     OpBuilder &builder, Location loc,
     ArrayRef<ReassociationIndices> reassociation, ValueRange srcIndices,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -492,65 +492,6 @@
                                      constantIndex(rewriter, loc, i));
 }
 
-/// Helper method to compute the shape of destination tensor of a reshape
-/// operator. This is only used when operands have dynamic shape. The shape of
-/// the destination is stored into dstShape.
-void genReshapeDstShape(Location loc, ConversionPatternRewriter &rewriter,
-                        SmallVector<Value, 4> &dstShape,
-                        ArrayRef<Value> srcShape,
-                        ArrayRef<int64_t> staticDstShape,
-                        ArrayRef<ReassociationIndices> reassociation) {
-  // Collapse shape.
-  if (reassociation.size() < srcShape.size()) {
-    unsigned start = 0;
-    for (const auto &map : llvm::enumerate(reassociation)) {
-      auto dstDim = constantIndex(rewriter, loc, 1);
-      for (unsigned i = start; i < start + map.value().size(); i++) {
-        dstDim = rewriter.create<arith::MulIOp>(loc, dstDim, srcShape[i]);
-      }
-      dstShape.push_back(dstDim);
-      start = start + map.value().size();
-    }
-    assert(start == srcShape.size());
-    return;
-  }
-
-  // Expand shape.
-  assert(reassociation.size() == srcShape.size());
-  unsigned start = 0;
-  // Expand the i-th dimension in srcShape.
-  for (unsigned i = 0, size = srcShape.size(); i < size; i++) {
-    auto map = reassociation[i];
-    auto srcDim = srcShape[i];
-    // Iterate through dimensions expanded from the i-th dimension.
-    for (unsigned j = start; j < start + map.size(); j++) {
-      // There can be only one dynamic sized dimension among dimensions expanded
-      // from the i-th dimension in srcShape. For example, if srcDim = 8, then
-      // the expanded shape could be <2x?x2>, but not <2x?x?>.
-      if (staticDstShape[j] == ShapedType::kDynamicSize) {
-        // The expanded dimension has dynamic size. We compute the dimension
-        // by dividing srcDim by the product of the static dimensions.
-        int64_t product = 1;
-        for (unsigned k = start; k < start + map.size(); k++) {
-          if (staticDstShape[k] != ShapedType::kDynamicSize) {
-            product *= staticDstShape[k];
-          }
-        }
-        // Compute the dynamic dimension size.
-        Value productVal = constantIndex(rewriter, loc, product);
-        Value dynamicSize =
-            rewriter.create<arith::DivUIOp>(loc, srcDim, productVal);
-        dstShape.push_back(dynamicSize);
-      } else {
-        // The expanded dimension is statically known.
-        dstShape.push_back(constantIndex(rewriter, loc, staticDstShape[j]));
-      }
-    }
-    start = start + map.size();
-  }
-  assert(start == staticDstShape.size());
-}
-
 /// Generate code for a general sparse to sparse reshaping operation.
 /// Note that unlike dense reshaping (which can be done with a "cheap"
 /// change of view), sparse reshaping is currently done with actual