diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -467,20 +467,20 @@
 static void translateIndices(Location loc, ConversionPatternRewriter &rewriter,
                              ArrayRef<ReassociationIndices> reassociation,
                              TensorType dstTp, TensorType srcTp, Value dstIdx,
-                             Value srcIdx) {
+                             Value srcIdx, ArrayRef<Value> dstShape,
+                             ArrayRef<Value> srcShape) {
   unsigned dstRank = dstTp.getRank();
   unsigned srcRank = srcTp.getRank();
   unsigned start = 0;
   unsigned i = 0;
   bool isExpand = srcRank > dstRank;
-  ArrayRef<int64_t> shape = isExpand ? srcTp.getShape() : dstTp.getShape();
+  ArrayRef<Value> shape = isExpand ? srcShape : dstShape;
   // Iterate over reassociation map.
   for (const auto &map : llvm::enumerate(reassociation)) {
     // Prepare strides information in dimension slice.
-    uint64_t linear = 1;
+    Value linear = constantIndex(rewriter, loc, 1);
     for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
-      assert(!ShapedType::isDynamic(shape[j]));
-      linear *= shape[j];
+      linear = rewriter.create<arith::MulIOp>(loc, linear, shape[j]);
     }
     // Start collapse.
     Value idx = constantIndex(rewriter, loc, i++);
@@ -489,22 +489,17 @@
       val = rewriter.create<memref::LoadOp>(loc, srcIdx, idx);
     // Iterate over dimension slice.
     for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
-      linear /= shape[j];
-      Value stride = constantIndex(rewriter, loc, linear);
+      linear = rewriter.create<arith::DivUIOp>(loc, linear, shape[j]);
       Value jdx = constantIndex(rewriter, loc, j);
       if (isExpand) {
         Value old = rewriter.create<memref::LoadOp>(loc, srcIdx, jdx);
-        Value mul = linear == 1
-                        ? old
-                        : rewriter.create<arith::MulIOp>(loc, old, stride);
+        Value mul = rewriter.create<arith::MulIOp>(loc, old, linear);
         val = val ? rewriter.create<arith::AddIOp>(loc, val, mul) : mul;
       } else {
         Value old = val;
-        if (linear != 1)
-          val = rewriter.create<arith::DivUIOp>(loc, val, stride);
+        val = rewriter.create<arith::DivUIOp>(loc, val, linear);
         rewriter.create<memref::StoreOp>(loc, val, dstIdx, jdx);
-        if (linear != 1)
-          val = rewriter.create<arith::RemUIOp>(loc, old, stride);
+        val = rewriter.create<arith::RemUIOp>(loc, old, linear);
       }
     }
     // Finalize expansion.
@@ -516,6 +511,57 @@
   assert((isExpand && i == dstRank) || (!isExpand && i == srcRank));
 }
 
+/// Helper method to compute the shape of the destination tensor of a reshape
+/// operation. This is only used when the operands have dynamic shapes.
+static void genReshapeDstShape(Location loc,
+                               ConversionPatternRewriter &rewriter,
+                               SmallVectorImpl<Value> &dstShape,
+                               ArrayRef<Value> srcShape,
+                               ArrayRef<int64_t> staticDstShape,
+                               ArrayRef<ReassociationIndices> reassociation) {
+  // Collapse shape: each destination dimension is the product of the
+  // source dimensions in its reassociation group.
+  if (reassociation.size() < srcShape.size()) {
+    unsigned start = 0;
+    for (const auto &map : llvm::enumerate(reassociation)) {
+      Value dstDim = constantIndex(rewriter, loc, 1);
+      for (unsigned i = start; i < start + map.value().size(); i++) {
+        dstDim = rewriter.create<arith::MulIOp>(loc, dstDim, srcShape[i]);
+      }
+      dstShape.push_back(dstDim);
+      start = start + map.value().size();
+    }
+    assert(start == srcShape.size());
+    return;
+  }
+
+  // Expand shape.
+  assert(reassociation.size() == srcShape.size());
+  unsigned start = 0;
+  // Expand the i-th source dimension into the dimensions of its group.
+  for (unsigned i = 0; i < srcShape.size(); i++) {
+    const auto &map = reassociation[i];
+    auto srcDim = srcShape[i];
+    // Note that a reassociation group may contain at most one dynamic
+    // destination dimension.
+    for (unsigned j = start; j < start + map.size(); j++) {
+      if (staticDstShape[j] == ShapedType::kDynamicSize) {
+        // Compute the product of the static sizes in this group.
+        int64_t product = 1;
+        for (unsigned k = start; k < start + map.size(); k++) {
+          if (staticDstShape[k] != ShapedType::kDynamicSize) {
+            product *= staticDstShape[k];
+          }
+        }
+        // Compute the dynamic dimension size as srcDim / product.
+        Value productVal = constantIndex(rewriter, loc, product);
+        Value dynamicSize =
+            rewriter.create<arith::DivUIOp>(loc, srcDim, productVal);
+        dstShape.push_back(dynamicSize);
+      } else {
+        // The expanded dimension is statically known.
+        dstShape.push_back(constantIndex(rewriter, loc, staticDstShape[j]));
+      }
+    }
+    start = start + map.size();
+  }
+  assert(start == staticDstShape.size());
+}
+
 /// Generate code for a general sparse to sparse reshaping operation.
 /// Note that unlike dense reshaping (which can be done with a "cheap"
 /// change of view), sparse reshaping is currently done with actual
@@ -547,17 +593,23 @@
   auto noPerm = SparseTensorEncodingAttr::get(
       op->getContext(), encSrc.getDimLevelType(), AffineMap(),
       encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
-  SmallVector<Value> sizes;
+  SmallVector<Value> srcSizes;
   SmallVector<Value> params;
-  sizesFromPtr(rewriter, sizes, loc, noPerm, srcTp, src);
-  newParams(rewriter, params, loc, srcTp, noPerm, Action::kToIterator, sizes,
+  sizesFromPtr(rewriter, srcSizes, loc, noPerm, srcTp, src);
+  newParams(rewriter, params, loc, srcTp, noPerm, Action::kToIterator, srcSizes,
             src);
   Value iter = genNewCall(rewriter, loc, params);
   // Start a new COO for the destination tensor.
-  sizes.clear();
+  SmallVector<Value> dstSizes;
   params.clear();
-  sizesFromPtr(rewriter, sizes, loc, encDst, dstTp, src);
-  newParams(rewriter, params, loc, dstTp, encDst, Action::kEmptyCOO, sizes);
+  if (srcTp.hasStaticShape()) {
+    sizesFromPtr(rewriter, dstSizes, loc, encDst, dstTp, src);
+  } else {
+    ArrayRef<int64_t> dstShape = dstTp.getShape();
+    genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape,
+                       reassociation);
+  }
+  newParams(rewriter, params, loc, dstTp, encDst, Action::kEmptyCOO, dstSizes);
   Value coo = genNewCall(rewriter, loc, params);
   Value dstPerm = params[2];
   // Construct a while loop over the iterator.
@@ -575,7 +627,8 @@
   // not need to store the value in elemPtr, as the value is still there.
   Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
   rewriter.setInsertionPointToStart(after);
-  translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx);
+  translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx,
+                   dstSizes, srcSizes);
   genAddEltCall(rewriter, loc, elemTp, coo, elemPtr, dstIdx, dstPerm);
   rewriter.create<scf::YieldOp>(loc);
   // Final call to construct sparse tensor storage and free temporary resources.
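Note on the index arithmetic (editor's sketch, not part of the patch): the revised translateIndices no longer folds strides into compile-time constants; it emits arith.muli/arith.divui/arith.remui chains so the same linearize/delinearize logic works when dimension sizes are only known at runtime. The standalone C++ below mirrors that arithmetic on plain integers to make the stride logic easy to check; collapseGroup and expandGroup are invented names for illustration only.

#include <cassert>
#include <cstdint>
#include <vector>

// For one reassociation group, `linear` starts as the product of all sizes
// in the group and is divided down one dimension at a time, so that after
// the division it holds the stride of dimension j within the group.
static uint64_t collapseGroup(const std::vector<uint64_t> &srcIdx,
                              const std::vector<uint64_t> &srcShape,
                              unsigned start, unsigned size) {
  uint64_t linear = 1;
  for (unsigned j = start; j < start + size; j++)
    linear *= srcShape[j];
  uint64_t dstIdx = 0;
  for (unsigned j = start; j < start + size; j++) {
    linear /= srcShape[j]; // stride of dimension j
    dstIdx += srcIdx[j] * linear;
  }
  return dstIdx;
}

// The inverse: split one linear index across the dimensions of a group by
// repeated division and remainder, exactly the divui/remui chain that the
// CHECK-CONV lines in sparse_reshape.mlir verify.
static void expandGroup(uint64_t srcIdx, const std::vector<uint64_t> &dstShape,
                        unsigned start, unsigned size,
                        std::vector<uint64_t> &dstIdx) {
  uint64_t linear = 1;
  for (unsigned j = start; j < start + size; j++)
    linear *= dstShape[j];
  for (unsigned j = start; j < start + size; j++) {
    linear /= dstShape[j]; // stride of dimension j
    dstIdx[j] = srcIdx / linear;
    srcIdx %= linear;
  }
}

int main() {
  // Expanding linear index 57 into a 10x10 shape yields (5, 7).
  std::vector<uint64_t> out(2);
  expandGroup(57, {10, 10}, 0, 2, out);
  assert(out[0] == 5 && out[1] == 7);
  // Collapsing (5, 7) in a 10x10 shape yields 57 again.
  assert(collapseGroup({5, 7}, {10, 10}, 0, 2) == 57);
  return 0;
}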
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -24,11 +24,16 @@
 // CHECK-CONV: call @getNextF64
 // CHECK-CONV: scf.condition(%13)
 // CHECK-CONV: } do {
+// CHECK-CONV: %[[M1:.*]] = arith.muli %[[C1]], %[[C10]] : index
+// CHECK-CONV: %[[M2:.*]] = arith.muli %[[M1]], %[[C10]] : index
 // CHECK-CONV: %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref
-// CHECK-CONV: %[[D:.*]] = arith.divui %[[X]], %[[C10]] : index
-// CHECK-CONV: memref.store %[[D]], %{{.*}}[%[[C0]]] : memref
-// CHECK-CONV: %[[R:.*]] = arith.remui %[[X]], %[[C10]] : index
-// CHECK-CONV: memref.store %[[R]], %{{.*}}[%[[C1]]] : memref
+// CHECK-CONV: %[[S1:.*]] = arith.divui %[[M2]], %[[C10]] : index
+// CHECK-CONV: %[[D1:.*]] = arith.divui %[[X]], %[[S1]] : index
+// CHECK-CONV: memref.store %[[D1]], %{{.*}}[%[[C0]]] : memref
+// CHECK-CONV: %[[R:.*]] = arith.remui %[[X]], %[[S1]] : index
+// CHECK-CONV: %[[S2:.*]] = arith.divui %[[S1]], %[[C10]] : index
+// CHECK-CONV: %[[D2:.*]] = arith.divui %[[R]], %[[S2]] : index
+// CHECK-CONV: memref.store %[[D2]], %{{.*}}[%[[C1]]] : memref
 // CHECK-CONV: call @addEltF64
 // CHECK-CONV: scf.yield
 // CHECK-CONV: }
@@ -63,10 +68,15 @@
 // CHECK-CONV: call @getNextF64
 // CHECK-CONV: scf.condition(%13)
 // CHECK-CONV: } do {
+// CHECK-CONV: %[[M1:.*]] = arith.muli %[[C1]], %[[C10]] : index
+// CHECK-CONV: %[[M2:.*]] = arith.muli %[[M1]], %[[C10]] : index
+// CHECK-CONV: %[[D1:.*]] = arith.divui %[[M2]], %[[C10]] : index
 // CHECK-CONV: %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref
-// CHECK-CONV: %[[M:.*]] = arith.muli %[[X]], %[[C10]] : index
+// CHECK-CONV: %[[M3:.*]] = arith.muli %[[X]], %[[D1]] : index
+// CHECK-CONV: %[[D2:.*]] = arith.divui %[[D1]], %[[C10]] : index
 // CHECK-CONV: %[[Y:.*]] = memref.load %{{.*}}[%[[C1]]] : memref
-// CHECK-CONV: %[[A:.*]] = arith.addi %[[M]], %[[Y]] : index
+// CHECK-CONV: %[[M4:.*]] = arith.muli %[[Y]], %[[D2]] : index
+// CHECK-CONV: %[[A:.*]] = arith.addi %[[M3]], %[[M4]] : index
 // CHECK-CONV: memref.store %[[A]], %{{.*}}[%[[C0]]] : memref
 // CHECK-CONV: call @addEltF64
 // CHECK-CONV: scf.yield
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir
@@ -120,6 +120,11 @@
     return %0 : tensor<?x2x?xf64, #Sparse3dTensor>
   }
 
+  func.func @expand_sparse2sparse_dyn(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tensor<?x2x?xf64, #Sparse3dTensor> {
+    %0 = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<?x?xf64, #SparseMatrix> into tensor<?x2x?xf64, #Sparse3dTensor>
+    return %0 : tensor<?x2x?xf64, #Sparse3dTensor>
+  }
+
   func.func @collapse_dense_dyn(%arg0: tensor<?x?x?x?xf64>) -> tensor<?x?xf64> {
     %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] : tensor<?x?x?x?xf64> into tensor<?x?xf64>
     return %0 : tensor<?x?xf64>
   }
@@ -135,6 +140,11 @@
     return %0 : tensor<?x?xf64, #SparseMatrix>
   }
 
+  func.func @collapse_sparse2sparse_dyn(%arg0: tensor<?x?x?x?xf64, #Sparse4dTensor>) -> tensor<?x?xf64, #SparseMatrix> {
+    %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] : tensor<?x?x?x?xf64, #Sparse4dTensor> into tensor<?x?xf64, #SparseMatrix>
+    return %0 : tensor<?x?xf64, #SparseMatrix>
+  }
+
   //
   // Main driver.
  //
@@ -177,6 +187,7 @@
    %expand8 = call @expand_dense_dyn(%dm) : (tensor<?x?xf64>) -> tensor<?x2x?xf64>
    %expand9 = call @expand_from_sparse_dyn(%sdm) : (tensor<?x?xf64, #SparseMatrix>) -> tensor<?x2x?xf64>
    %expand10 = call @expand_to_sparse_dyn(%dm) : (tensor<?x?xf64>) -> tensor<?x2x?xf64, #Sparse3dTensor>
+    %expand11 = call @expand_sparse2sparse_dyn(%sdm) : (tensor<?x?xf64, #SparseMatrix>) -> tensor<?x2x?xf64, #Sparse3dTensor>
 
    %collapse0 = call @collapse_dense(%m) : (tensor<3x4xf64>) -> tensor<12xf64>
    %collapse1 = call @collapse_from_sparse(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64>
@@ -189,6 +200,7 @@
    %collapse8 = call @collapse_dense_dyn(%dn) : (tensor<?x?x?x?xf64>) -> tensor<?x?xf64>
    %collapse9 = call @collapse_from_sparse_dyn(%sdn) : (tensor<?x?x?x?xf64, #Sparse4dTensor>) -> tensor<?x?xf64>
    %collapse10 = call @collapse_to_sparse_dyn(%dn) : (tensor<?x?x?x?xf64>) -> tensor<?x?xf64, #SparseMatrix>
+    %collapse11 = call @collapse_sparse2sparse_dyn(%sdn) : (tensor<?x?x?x?xf64, #Sparse4dTensor>) -> tensor<?x?xf64, #SparseMatrix>
 
    //
    // Verify results of expand
@@ -204,6 +216,7 @@
    // CHECK-NEXT: ( ( ( 1.1, 1.2 ), ( 1.3, 1.4 ) ), ( ( 2.1, 2.2 ), ( 2.3, 2.4 ) ), ( ( 3.1, 3.2 ), ( 3.3, 3.4 ) ) )
    // CHECK-NEXT: ( ( ( 1.1, 1.2 ), ( 1.3, 1.4 ) ), ( ( 2.1, 2.2 ), ( 2.3, 2.4 ) ), ( ( 3.1, 3.2 ), ( 3.3, 3.4 ) ) )
    // CHECK-NEXT: ( 1.1, 1.2, 1.3, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 1.1, 1.2, 1.3, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4, -1, -1, -1, -1 )
    //
    %m0 = vector.transfer_read %expand0[%c0, %c0], %df: tensor<3x4xf64>, vector<3x4xf64>
@@ -235,6 +248,10 @@
    %a10 = sparse_tensor.values %expand10 : tensor<?x2x?xf64, #Sparse3dTensor> to memref<?xf64>
    %m10 = vector.transfer_read %a10[%c0], %df: memref<?xf64>, vector<16xf64>
    vector.print %m10 : vector<16xf64>
+    %a11 = sparse_tensor.values %expand11 : tensor<?x2x?xf64, #Sparse3dTensor> to memref<?xf64>
+    %m11 = vector.transfer_read %a11[%c0], %df: memref<?xf64>, vector<16xf64>
+    vector.print %m11 : vector<16xf64>
+
    //
    // Verify results of collapse
@@ -250,6 +267,7 @@
    // CHECK-NEXT: ( ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ), ( 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ), ( 21, 22, 23, 24, 25, 26, 27, 28, 29, 30 ), ( 31, 32, 33, 34, 35, 36, 37, 38, 39, 40 ), ( 41, 42, 43, 44, 45, 26, 47, 48, 49, 50 ), ( 51, 52, 53, 54, 55, 56, 57, 58, 59, 60 ) )
    // CHECK-NEXT: ( ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ), ( 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ), ( 21, 22, 23, 24, 25, 26, 27, 28, 29, 30 ), ( 31, 32, 33, 34, 35, 36, 37, 38, 39, 40 ), ( 41, 42, 43, 44, 45, 26, 47, 48, 49, 50 ), ( 51, 52, 53, 54, 55, 56, 57, 58, 59, 60 ) )
    // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 26, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 26, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, -1, -1, -1, -1 )
    //
    %v0 = vector.transfer_read %collapse0[%c0], %df: tensor<12xf64>, vector<12xf64>
@@ -281,6 +299,10 @@
    %b10 = sparse_tensor.values %collapse10 : tensor<?x?xf64, #SparseMatrix> to memref<?xf64>
    %v10 = vector.transfer_read %b10[%c0], %df: memref<?xf64>, vector<64xf64>
    vector.print %v10 : vector<64xf64>
+    %b11 = sparse_tensor.values %collapse11 : tensor<?x?xf64, #SparseMatrix> to memref<?xf64>
+    %v11 = vector.transfer_read %b11[%c0], %df: memref<?xf64>, vector<64xf64>
+    vector.print %v11 : vector<64xf64>
+
    // Release sparse resources.
    bufferization.dealloc_tensor %sv : tensor<12xf64, #SparseVector>
@@ -293,11 +315,13 @@
    bufferization.dealloc_tensor %expand6 : tensor<3x2x2xf64, #Sparse3dTensor>
    bufferization.dealloc_tensor %expand7 : tensor<3x2x2xf64, #Sparse3dTensor>
    bufferization.dealloc_tensor %expand10 : tensor<?x2x?xf64, #Sparse3dTensor>
+    bufferization.dealloc_tensor %expand11 : tensor<?x2x?xf64, #Sparse3dTensor>
    bufferization.dealloc_tensor %collapse2 : tensor<12xf64, #SparseVector>
    bufferization.dealloc_tensor %collapse3 : tensor<12xf64, #SparseVector>
    bufferization.dealloc_tensor %collapse6 : tensor<6x10xf64, #SparseMatrix>
    bufferization.dealloc_tensor %collapse7 : tensor<6x10xf64, #SparseMatrix>
    bufferization.dealloc_tensor %collapse10 : tensor<?x?xf64, #SparseMatrix>
+    bufferization.dealloc_tensor %collapse11 : tensor<?x?xf64, #SparseMatrix>
 
    // Release dense resources.
    bufferization.dealloc_tensor %expand1 : tensor<3x4xf64>
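Note on the dynamic shape computation (editor's sketch, not part of the patch): for the dynamic expand case exercised by @expand_sparse2sparse_dyn above, genReshapeDstShape derives each dynamic destination dimension by dividing the source dimension by the product of the static dimensions in its reassociation group; this relies on each group containing at most one dynamic dimension. A plain-integer rendering of that rule, with kDyn and expandDstShape invented for illustration:

#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for ShapedType::kDynamicSize.
constexpr int64_t kDyn = -1;

// For each source dimension i, emit the sizes of the destination dimensions
// in its reassociation group: static sizes are taken verbatim, and the one
// dynamic size is srcShape[i] divided by the product of the static sizes.
static std::vector<int64_t>
expandDstShape(const std::vector<int64_t> &srcShape,
               const std::vector<int64_t> &staticDstShape,
               const std::vector<std::vector<unsigned>> &groups) {
  std::vector<int64_t> dst;
  for (unsigned i = 0; i < srcShape.size(); i++) {
    for (unsigned j : groups[i]) {
      if (staticDstShape[j] == kDyn) {
        int64_t product = 1;
        for (unsigned k : groups[i])
          if (staticDstShape[k] != kDyn)
            product *= staticDstShape[k];
        dst.push_back(srcShape[i] / product);
      } else {
        dst.push_back(staticDstShape[j]);
      }
    }
  }
  return dst;
}

int main() {
  // The test expands a runtime 3x4 matrix into tensor<?x2x?xf64> with
  // reassociation [[0], [1, 2]]: dim 0 -> 3, dim 1 -> 2, dim 2 -> 4/2 = 2.
  auto dst = expandDstShape({3, 4}, {kDyn, 2, kDyn}, {{0}, {1, 2}});
  assert((dst == std::vector<int64_t>{3, 2, 2}));
  return 0;
}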