diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -478,20 +478,20 @@ static void translateIndices(Location loc, ConversionPatternRewriter &rewriter, ArrayRef reassociation, TensorType dstTp, TensorType srcTp, Value dstIdx, - Value srcIdx) { + Value srcIdx, ArrayRef dstShape, + ArrayRef srcShape) { unsigned dstRank = dstTp.getRank(); unsigned srcRank = srcTp.getRank(); unsigned start = 0; unsigned i = 0; bool isExpand = srcRank > dstRank; - ArrayRef shape = isExpand ? srcTp.getShape() : dstTp.getShape(); + ArrayRef shape = isExpand ? srcShape : dstShape; // Iterate over reassociation map. for (const auto &map : llvm::enumerate(reassociation)) { // Prepare strides information in dimension slice. - uint64_t linear = 1; + Value linear = constantIndex(rewriter, loc, 1); for (unsigned j = start, end = start + map.value().size(); j < end; j++) { - assert(!ShapedType::isDynamic(shape[j])); - linear *= shape[j]; + linear = rewriter.create(loc, linear, shape[j]); } // Start collapse. Value idx = constantIndex(rewriter, loc, i++); @@ -500,22 +500,17 @@ val = rewriter.create(loc, srcIdx, idx); // Iterate over dimension slice. for (unsigned j = start, end = start + map.value().size(); j < end; j++) { - linear /= shape[j]; - Value stride = constantIndex(rewriter, loc, linear); + linear = rewriter.create(loc, linear, shape[j]); Value jdx = constantIndex(rewriter, loc, j); if (isExpand) { Value old = rewriter.create(loc, srcIdx, jdx); - Value mul = linear == 1 - ? old - : rewriter.create(loc, old, stride); + Value mul = rewriter.create(loc, old, linear); val = val ?
rewriter.create(loc, val, mul) : mul; } else { Value old = val; - if (linear != 1) - val = rewriter.create(loc, val, stride); + val = rewriter.create(loc, val, linear); rewriter.create(loc, val, dstIdx, jdx); - if (linear != 1) - val = rewriter.create(loc, old, stride); + val = rewriter.create(loc, old, linear); } } // Finalize expansion. @@ -527,6 +522,66 @@ assert((isExpand && i == dstRank) || (!isExpand && i == srcRank)); } +/// Helper method to compute the runtime sizes of the destination tensor of a +/// reshape operator. Only used when the destination type has a dynamic shape; +/// the computed sizes are appended to dstShape. +static void genReshapeDstShape(Location loc, + ConversionPatternRewriter &rewriter, + SmallVector &dstShape, + ArrayRef srcShape, + ArrayRef staticDstShape, + ArrayRef reassociation) { + // Collapse shape. + if (reassociation.size() < srcShape.size()) { + unsigned start = 0; + for (const auto &map : llvm::enumerate(reassociation)) { + auto dstDim = constantIndex(rewriter, loc, 1); + for (unsigned i = start; i < start + map.value().size(); i++) { + dstDim = rewriter.create(loc, dstDim, srcShape[i]); + } + dstShape.push_back(dstDim); + start = start + map.value().size(); + } + assert(start == srcShape.size()); + return; + } + + // Expand shape. + assert(reassociation.size() == srcShape.size()); + unsigned start = 0; + // Expand the i-th dimension in srcShape. + for (unsigned i = 0, size = srcShape.size(); i < size; i++) { + const auto &map = reassociation[i]; + auto srcDim = srcShape[i]; + // Iterate through dimensions expanded from the i-th dimension. + for (unsigned j = start; j < start + map.size(); j++) { + // There can be only one dynamic sized dimension among dimensions expanded + // from the i-th dimension in srcShape. For example, if srcDim = 8, then + // the expanded shape could be <2x?x2>, but not <2x?x?>. + if (staticDstShape[j] == ShapedType::kDynamicSize) { + // The expanded dimension has dynamic size.
We compute the dimension + // by dividing srcDim by the product of the static dimensions. + int64_t product = 1; + for (unsigned k = start; k < start + map.size(); k++) { + if (staticDstShape[k] != ShapedType::kDynamicSize) { + product *= staticDstShape[k]; + } + } + // Compute the dynamic dimension size. + Value productVal = constantIndex(rewriter, loc, product); + Value dynamicSize = + rewriter.create(loc, srcDim, productVal); + dstShape.push_back(dynamicSize); + } else { + // The expanded dimension is statically known. + dstShape.push_back(constantIndex(rewriter, loc, staticDstShape[j])); + } + } + start = start + map.size(); + } + assert(start == staticDstShape.size()); +} + /// Generate code for a general sparse to sparse reshaping operation. /// Note that unlike dense reshaping (which can be done with a "cheap" /// change of view), sparse reshaping is currently done with actual @@ -562,19 +617,23 @@ auto noPerm = SparseTensorEncodingAttr::get( op->getContext(), encSrc.getDimLevelType(), AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); - SmallVector sizes; + SmallVector srcSizes; SmallVector params; - sizesFromPtr(rewriter, sizes, loc, encSrc, srcTp, adaptor.getSrc()); - newParams(rewriter, params, loc, srcTp, noPerm, Action::kToIterator, sizes, + sizesFromPtr(rewriter, srcSizes, loc, encSrc, srcTp, adaptor.getSrc()); + newParams(rewriter, params, loc, srcTp, noPerm, Action::kToIterator, srcSizes, adaptor.getSrc()); Value iter = genNewCall(rewriter, loc, params); // Start a new COO for the destination tensor. - sizes.clear(); + SmallVector dstSizes; params.clear(); - // Fills sizes array using the sizes from destination type.
- assert(dstTp.hasStaticShape()); - sizesFromType(rewriter, sizes, loc, dstTp); - newParams(rewriter, params, loc, dstTp, encDst, Action::kEmptyCOO, sizes); + if (dstTp.hasStaticShape()) { + sizesFromType(rewriter, dstSizes, loc, dstTp); + } else { + ArrayRef dstShape = dstTp.getShape(); + genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape, + op.getReassociationIndices()); + } + newParams(rewriter, params, loc, dstTp, encDst, Action::kEmptyCOO, dstSizes); Value coo = genNewCall(rewriter, loc, params); Value dstPerm = params[2]; // Construct a while loop over the iterator. @@ -593,7 +652,7 @@ Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes); rewriter.setInsertionPointToStart(after); translateIndices(loc, rewriter, op.getReassociationIndices(), dstTp, srcTp, - dstIdx, srcIdx); + dstIdx, srcIdx, dstSizes, srcSizes); genAddEltCall(rewriter, loc, elemTp, coo, elemPtr, dstIdx, dstPerm); rewriter.create(loc); // Final call to construct sparse tensor storage and free temporary resources.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND -// RUN: mlir-opt %s --sparse-tensor-conversion --cse | FileCheck %s --check-prefix=CHECK-CONV +// RUN: mlir-opt %s --sparse-tensor-conversion --cse --canonicalize | FileCheck %s --check-prefix=CHECK-CONV #SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }> #SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> @@ -22,13 +22,13 @@ // CHECK-CONV-DAG: call @newSparseTensor // CHECK-CONV: scf.while : () -> () { // CHECK-CONV: call @getNextF64 -// CHECK-CONV: scf.condition(%13) +// CHECK-CONV: scf.condition(%{{.*}}) // CHECK-CONV: } do { -// CHECK-CONV: %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref +// CHECK-CONV: %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<1xindex> // CHECK-CONV: %[[D:.*]] = arith.divui %[[X]], %[[C10]] : index -// CHECK-CONV: memref.store %[[D]], %{{.*}}[%[[C0]]] : memref +// CHECK-CONV: memref.store %[[D]], %{{.*}}[%[[C0]]] : memref<2xindex> // CHECK-CONV: %[[R:.*]] = arith.remui %[[X]], %[[C10]] : index -// CHECK-CONV: memref.store %[[R]], %{{.*}}[%[[C1]]] : memref +// CHECK-CONV: memref.store %[[R]], %{{.*}}[%[[C1]]] : memref<2xindex> // CHECK-CONV: call @addEltF64 // CHECK-CONV: scf.yield // CHECK-CONV: } @@ -61,13 +61,13 @@ // CHECK-CONV-DAG: call @newSparseTensor // CHECK-CONV: scf.while : () -> () { // CHECK-CONV: call @getNextF64 -// CHECK-CONV: scf.condition(%13) +// CHECK-CONV: scf.condition(%{{.*}}) // CHECK-CONV: } do { -// CHECK-CONV: %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref +// CHECK-CONV: %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<2xindex> // CHECK-CONV: %[[M:.*]] = arith.muli %[[X]], %[[C10]] : index -// CHECK-CONV: %[[Y:.*]] = memref.load
%{{.*}}[%[[C1]]] : memref +// CHECK-CONV: %[[Y:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<2xindex> // CHECK-CONV: %[[A:.*]] = arith.addi %[[M]], %[[Y]] : index -// CHECK-CONV: memref.store %[[A]], %{{.*}}[%[[C0]]] : memref +// CHECK-CONV: memref.store %[[A]], %{{.*}}[%[[C0]]] : memref<1xindex> // CHECK-CONV: call @addEltF64 // CHECK-CONV: scf.yield // CHECK-CONV: } @@ -81,3 +81,89 @@ tensor<10x10xf64, #SparseMatrix> into tensor<100xf64, #SparseVector> return %0 : tensor<100xf64, #SparseVector> } + +// roundtrip: +// +// CHECK-ROUND-LABEL: func.func @dynamic_sparse_expand( +// CHECK-ROUND-SAME: %[[A:.*]]: tensor>) -> tensor> +// CHECK-ROUND: %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] : tensor> into tensor> +// CHECK-ROUND: return %[[E]] : tensor> +// +// conversion: +// +// CHECK-CONV-LABEL: func.func @dynamic_sparse_expand( +// CHECK-CONV-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-CONV-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-CONV-DAG: %[[C10:.*]] = arith.constant 10 : index +// CHECK-CONV-DAG: %[[D1:.*]] = arith.divui %{{.*}}, %[[C10]] : index +// CHECK-CONV-DAG: call @newSparseTensor +// CHECK-CONV-DAG: call @newSparseTensor +// CHECK-CONV: scf.while : () -> () { +// CHECK-CONV: call @getNextF64 +// CHECK-CONV: scf.condition(%{{.*}}) +// CHECK-CONV: } do { +// CHECK-CONV: %[[M:.*]] = arith.muli %[[D1]], %[[C10]] : index +// CHECK-CONV: %[[L:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<1xindex> +// CHECK-CONV: %[[D2:.*]] = arith.divui %[[M]], %[[D1]] : index +// CHECK-CONV: %[[D3:.*]] = arith.divui %[[L]], %[[D2]] : index +// CHECK-CONV: memref.store %[[D3]], %{{.*}}[%[[C0]]] : memref<2xindex> +// CHECK-CONV: %[[R:.*]] = arith.remui %[[L]], %[[D2]] : index +// CHECK-CONV: %[[D4:.*]] = arith.divui %[[D2]], %[[C10]] : index +// CHECK-CONV: %[[D5:.*]] = arith.divui %[[R]], %[[D4]] : index +// CHECK-CONV: memref.store %[[D5]], %{{.*}}[%[[C1]]] : memref<2xindex> +// CHECK-CONV: call @addEltF64 +// CHECK-CONV: scf.yield +// CHECK-CONV: }
+// CHECK-CONV: %[[N:.*]] = call @newSparseTensor +// CHECK-CONV: call @delSparseTensorCOOF64 +// CHECK-CONV: call @delSparseTensorCOOF64 +// CHECK-CONV: return %[[N]] : !llvm.ptr +// +func.func @dynamic_sparse_expand(%arg0: tensor) -> tensor { + %0 = tensor.expand_shape %arg0 [[0, 1]] : + tensor into tensor + return %0 : tensor +} + +// +// roundtrip: +// +// CHECK-ROUND-LABEL: func.func @dynamic_sparse_collapse( +// CHECK-ROUND-SAME: %[[A:.*]]: tensor<10x?xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor> +// CHECK-ROUND: %[[C:.*]] = tensor.collapse_shape %[[A]] {{\[\[}}0, 1]] : tensor<10x?xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor> +// CHECK-ROUND: return %[[C]] : tensor> +// +// conversion: +// +// CHECK-CONV-LABEL: func.func @dynamic_sparse_collapse( +// CHECK-CONV-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-CONV-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-CONV-DAG: %[[C10:.*]] = arith.constant 10 : index +// CHECK-CONV-DAG: %[[M1:.*]] = arith.muli %{{.*}}, %[[C10]] : index +// CHECK-CONV-DAG: call @newSparseTensor +// CHECK-CONV-DAG: call @newSparseTensor +// CHECK-CONV: scf.while : () -> () { +// CHECK-CONV: call @getNextF64 +// CHECK-CONV: scf.condition(%{{.*}}) +// CHECK-CONV: } do { +// CHECK-CONV: %[[D1:.*]] = arith.divui %[[M1]], %[[C10]] : index +// CHECK-CONV: %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<2xindex> +// CHECK-CONV: %[[M2:.*]] = arith.muli %[[X]], %[[D1]] : index +// CHECK-CONV: %[[D2:.*]] = arith.divui %[[D1]], %{{.*}} : index +// CHECK-CONV: %[[Y:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<2xindex> +// CHECK-CONV: %[[M3:.*]] = arith.muli %[[Y]], %[[D2]] : index +// CHECK-CONV: %[[A:.*]] = arith.addi %[[M2]], %[[M3]] : index +// CHECK-CONV: memref.store %[[A]], %{{.*}}[%[[C0]]] : memref<1xindex> +// CHECK-CONV: call @addEltF64 +// CHECK-CONV: scf.yield +// CHECK-CONV: } +// CHECK-CONV: %[[N:.*]] = call @newSparseTensor +// CHECK-CONV: call @delSparseTensorCOOF64 +// CHECK-CONV: call @delSparseTensorCOOF64
+// CHECK-CONV: return %[[N]] : !llvm.ptr +// +func.func @dynamic_sparse_collapse(%arg0: tensor<10x?xf64, #SparseMatrix>) -> tensor { + %0 = tensor.collapse_shape %arg0 [[0, 1]] : + tensor<10x?xf64, #SparseMatrix> into tensor + return %0 : tensor +} diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir @@ -120,6 +120,11 @@ return %0 : tensor } + func.func @expand_sparse2sparse_dyn(%arg0: tensor) -> tensor { + %0 = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor into tensor + return %0 : tensor + } + func.func @collapse_dense_dyn(%arg0: tensor) -> tensor { %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] : tensor into tensor return %0 : tensor @@ -135,6 +140,11 @@ return %0 : tensor } + func.func @collapse_sparse2sparse_dyn(%arg0: tensor) -> tensor { + %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] : tensor into tensor + return %0 : tensor + } + // // Main driver.
// @@ -177,6 +187,7 @@ %expand8 = call @expand_dense_dyn(%dm) : (tensor) -> tensor %expand9 = call @expand_from_sparse_dyn(%sdm) : (tensor) -> tensor %expand10 = call @expand_to_sparse_dyn(%dm) : (tensor) -> tensor + %expand11 = call @expand_sparse2sparse_dyn(%sdm) : (tensor) -> tensor %collapse0 = call @collapse_dense(%m) : (tensor<3x4xf64>) -> tensor<12xf64> %collapse1 = call @collapse_from_sparse(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64> @@ -189,6 +200,7 @@ %collapse8 = call @collapse_dense_dyn(%dn) : (tensor) -> tensor %collapse9 = call @collapse_from_sparse_dyn(%sdn) : (tensor) -> tensor %collapse10 = call @collapse_to_sparse_dyn(%dn) : (tensor) -> tensor + %collapse11 = call @collapse_sparse2sparse_dyn(%sdn) : (tensor) -> tensor // // Verify results of expand @@ -204,6 +216,7 @@ // CHECK-NEXT: ( ( ( 1.1, 1.2 ), ( 1.3, 1.4 ) ), ( ( 2.1, 2.2 ), ( 2.3, 2.4 ) ), ( ( 3.1, 3.2 ), ( 3.3, 3.4 ) ) ) // CHECK-NEXT: ( ( ( 1.1, 1.2 ), ( 1.3, 1.4 ) ), ( ( 2.1, 2.2 ), ( 2.3, 2.4 ) ), ( ( 3.1, 3.2 ), ( 3.3, 3.4 ) ) ) // CHECK-NEXT: ( 1.1, 1.2, 1.3, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4, -1, -1, -1, -1 ) + // CHECK-NEXT: ( 1.1, 1.2, 1.3, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4, -1, -1, -1, -1 ) // %m0 = vector.transfer_read %expand0[%c0, %c0], %df: tensor<3x4xf64>, vector<3x4xf64> @@ -235,6 +248,10 @@ %a10 = sparse_tensor.values %expand10 : tensor to memref %m10 = vector.transfer_read %a10[%c0], %df: memref, vector<16xf64> vector.print %m10 : vector<16xf64> + %a11 = sparse_tensor.values %expand11 : tensor to memref + %m11 = vector.transfer_read %a11[%c0], %df: memref, vector<16xf64> + vector.print %m11 : vector<16xf64> + // // Verify results of collapse @@ -250,6 +267,7 @@ // CHECK-NEXT: ( ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ), ( 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ), ( 21, 22, 23, 24, 25, 26, 27, 28, 29, 30 ), ( 31, 32, 33, 34, 35, 36, 37, 38, 39, 40 ), ( 41, 42, 43, 44, 45, 26, 47, 48, 49, 50 ), ( 51, 52, 53, 54, 55, 56, 57, 58, 59, 60 ) ) //
CHECK-NEXT: ( ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ), ( 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ), ( 21, 22, 23, 24, 25, 26, 27, 28, 29, 30 ), ( 31, 32, 33, 34, 35, 36, 37, 38, 39, 40 ), ( 41, 42, 43, 44, 45, 26, 47, 48, 49, 50 ), ( 51, 52, 53, 54, 55, 56, 57, 58, 59, 60 ) ) // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 26, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, -1, -1, -1, -1 ) + // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 26, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, -1, -1, -1, -1 ) // %v0 = vector.transfer_read %collapse0[%c0], %df: tensor<12xf64>, vector<12xf64> @@ -281,6 +299,10 @@ %b10 = sparse_tensor.values %collapse10 : tensor to memref %v10 = vector.transfer_read %b10[%c0], %df: memref, vector<64xf64> vector.print %v10 : vector<64xf64> + %b11 = sparse_tensor.values %collapse11 : tensor to memref + %v11 = vector.transfer_read %b11[%c0], %df: memref, vector<64xf64> + vector.print %v11 : vector<64xf64> + // Release sparse resources.
bufferization.dealloc_tensor %sv : tensor<12xf64, #SparseVector> @@ -293,11 +315,13 @@ bufferization.dealloc_tensor %expand6 : tensor<3x2x2xf64, #Sparse3dTensor> bufferization.dealloc_tensor %expand7 : tensor<3x2x2xf64, #Sparse3dTensor> bufferization.dealloc_tensor %expand10 : tensor + bufferization.dealloc_tensor %expand11 : tensor bufferization.dealloc_tensor %collapse2 : tensor<12xf64, #SparseVector> bufferization.dealloc_tensor %collapse3 : tensor<12xf64, #SparseVector> bufferization.dealloc_tensor %collapse6 : tensor<6x10xf64, #SparseMatrix> bufferization.dealloc_tensor %collapse7 : tensor<6x10xf64, #SparseMatrix> bufferization.dealloc_tensor %collapse10 : tensor + bufferization.dealloc_tensor %collapse11 : tensor // Release dense resources. bufferization.dealloc_tensor %expand1 : tensor<3x4xf64>