diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -385,6 +385,106 @@ } }; +/// Sparse rewriting rule for the sparse-to-sparse tensor.reshape operator. +struct TensorReshapeRewriter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ReshapeOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value srcTensor = op.getSource(); + const auto srcTp = getSparseTensorType(srcTensor); + const auto dstTp = getSparseTensorType(op.getResult()); + + if (!srcTp.hasEncoding() || !dstTp.hasEncoding() || + !dstTp.hasStaticDimShape()) + return failure(); + + SmallVector srcSizes; + sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor); + SmallVector dstSizes; + for (Dimension d : dstTp.getDimShape()) + dstSizes.push_back(constantIndex(rewriter, loc, d)); + + Value nnz = rewriter.create(loc, srcTensor); + // Only need an unordered COO buffer if input and output are not sorted + // in the same way. + Type bufferTp = + srcTp.isAllOrdered() && srcTp.isIdentity() && dstTp.isIdentity() + ? dstTp.getRankedTensorType() + : getUnorderedCOOFromType(dstTp); + SmallVector dynSizes; + Value buffer = rewriter + .create(loc, bufferTp, dynSizes, Value(), + nnz, Attribute()) + .getResult(); + + // Convert src coordinates to dst coordinates by first collapsing them to 1D + // and then expanding them to match the rank of the destination tensor. 
+ // Implemented as follows: + // foreach srcCoords %srcTensor + // collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank]) + // expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank]) + // insert expandedCoords, %buffer + // + // followed by an optional + // %t = sparse_tensor.cast %tmp + // depending on whether the input/output are sorted in the same way. + const auto encSrc = srcTp.getEncoding(); + ForeachOp foreachOp = rewriter.create( + loc, srcTensor, buffer, + [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v, + ValueRange reduc) { + const Dimension srcRank = srcTp.getDimRank(); + SmallVector srcDcvs; + srcDcvs.reserve(srcRank); + for (Dimension d = 0; d < srcRank; d++) { + // FIXME: `toStoredDim` is deprecated + Level lvl = toStoredDim(encSrc, d); + srcDcvs.push_back(srcLcvs[lvl]); + } + + Value collapsed_size = constantIndex(builder, loc, 1); + for (Dimension d = 0; d < srcRank; d++) + collapsed_size = + builder.create(loc, collapsed_size, srcSizes[d]); + SmallVector collapsedSizes = {collapsed_size}; + + ReassociationIndices collapse_indices; + for (Dimension i = 0; i < srcRank; i++) + collapse_indices.push_back(i); + SmallVector collapse_reassociation = { + collapse_indices}; + SmallVector collapsedDcvs; + reshapeCvs(builder, loc, collapse_reassociation, srcSizes, srcDcvs, + collapsedSizes, collapsedDcvs); + + ReassociationIndices expand_indices; + for (Dimension i = 0; i < dstTp.getDimRank(); i++) + expand_indices.push_back(i); + SmallVector expand_reassociation = { + expand_indices}; + SmallVector dstDcvs; + reshapeCvs(builder, loc, expand_reassociation, collapsedSizes, + collapsedDcvs, dstSizes, dstDcvs); + + auto t = builder.create(loc, v, reduc.front(), dstDcvs); + builder.create(loc, t); + }); + + Value t = rewriter.create(loc, foreachOp.getResult(0), true); + if (bufferTp != dstTp) { + auto dstRTT = dstTp.getRankedTensorType(); + Value converted = rewriter.create(loc, dstRTT, t).getResult(); + rewriter.create(loc, t); + 
t = converted; + } + rewriter.replaceOp(op, t); + return success(); + } +}; + /// Sparse rewriting rule for sparse-to-sparse reshape operator. template struct Sparse2SparseReshapeRewriter : public OpRewritePattern { @@ -1169,7 +1269,8 @@ bool enableForeach, bool enableConvert) { patterns.add, - ReshapeRewriter>(patterns.getContext()); + ReshapeRewriter, TensorReshapeRewriter>( + patterns.getContext()); if (enableForeach) patterns.add(patterns.getContext()); // TODO: If RT not enabled, rewrite concatenate ops, etc here. diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir @@ -0,0 +1,98 @@ +// DEFINE: %{option} = enable-runtime-library=true +// DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option} +// DEFINE: %{run} = mlir-cpu-runner \ +// DEFINE: -e entry -entry-point-result=void \ +// DEFINE: -shared-libs=%mlir_c_runner_utils | \ +// DEFINE: FileCheck %s +// +// RUN: %{compile} | %{run} +// +// Do the same run, but now with direct IR generation. +// REDEFINE: %{option} = enable-runtime-library=false +// RUN: %{compile} | %{run} +// +// Do the same run, but now with direct IR generation and vectorization. +// REDEFINE: %{option} = "enable-runtime-library=false vl=2 reassociate-fp-reductions=true enable-index-optimizations=true" +// RUN: %{compile} | %{run} + +// Do the same run, but now with direct IR generation and, if available, VLA +// vectorization. 
+// REDEFINE: %{option} = "enable-runtime-library=false vl=4 enable-arm-sve=%ENABLE_VLA" +// REDEFINE: %{run} = %lli_host_or_aarch64_cmd \ +// REDEFINE: --entry-function=entry_lli \ +// REDEFINE: --extra-module=%S/Inputs/main_for_lli.ll \ +// REDEFINE: %VLA_ARCH_ATTR_OPTIONS \ +// REDEFINE: --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \ +// REDEFINE: FileCheck %s +// RUN: %{compile} | mlir-translate -mlir-to-llvmir | %{run} + +#SparseVector = #sparse_tensor.encoding<{ + dimLevelType = ["compressed"] +}> + +#SparseMatrix = #sparse_tensor.encoding<{ + dimLevelType = ["compressed", "compressed"] +}> + +#Sparse3dTensor = #sparse_tensor.encoding<{ + dimLevelType = ["compressed", "compressed", "compressed"] +}> + +module { + + func.func @reshape0(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<2x6xf64, #SparseMatrix> { + %shape = arith.constant dense <[ 2, 6 ]> : tensor<2xi32> + %0 = tensor.reshape %arg0(%shape) : (tensor<3x4xf64, #SparseMatrix>, tensor<2xi32>) -> tensor<2x6xf64, #SparseMatrix> + return %0 : tensor<2x6xf64, #SparseMatrix> + } + + func.func @reshape1(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector> { + %shape = arith.constant dense <[ 12 ]> : tensor<1xi32> + %0 = tensor.reshape %arg0(%shape) : (tensor<3x4xf64, #SparseMatrix>, tensor<1xi32>) -> tensor<12xf64, #SparseVector> + return %0 : tensor<12xf64, #SparseVector> + } + + func.func @reshape2(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<2x3x2xf64, #Sparse3dTensor> { + %shape = arith.constant dense <[ 2, 3, 2 ]> : tensor<3xi32> + %0 = tensor.reshape %arg0(%shape) : (tensor<3x4xf64, #SparseMatrix>, tensor<3xi32>) -> tensor<2x3x2xf64, #Sparse3dTensor> + return %0 : tensor<2x3x2xf64, #Sparse3dTensor> + } + + + func.func @entry() { + %m = arith.constant dense <[ [ 1.1, 0.0, 1.3, 0.0 ], + [ 2.1, 0.0, 2.3, 0.0 ], + [ 3.1, 0.0, 3.3, 0.0 ]]> : tensor<3x4xf64> + %sm = sparse_tensor.convert %m : tensor<3x4xf64> to tensor<3x4xf64, #SparseMatrix> + + %reshaped0 = 
call @reshape0(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<2x6xf64, #SparseMatrix> + %reshaped1 = call @reshape1(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector> + %reshaped2 = call @reshape2(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<2x3x2xf64, #Sparse3dTensor> + + %c0 = arith.constant 0 : index + %df = arith.constant -1.0 : f64 + + // CHECK: ( 1.1, 1.3, 2.1, 2.3, 3.1, 3.3 + %b0 = sparse_tensor.values %reshaped0: tensor<2x6xf64, #SparseMatrix> to memref + %v0 = vector.transfer_read %b0[%c0], %df: memref, vector<12xf64> + vector.print %v0 : vector<12xf64> + + // CHECK: ( 1.1, 1.3, 2.1, 2.3, 3.1, 3.3 + %b1 = sparse_tensor.values %reshaped1: tensor<12xf64, #SparseVector> to memref + %v1 = vector.transfer_read %b1[%c0], %df: memref, vector<12xf64> + vector.print %v1 : vector<12xf64> + + // CHECK: ( 1.1, 1.3, 2.1, 2.3, 3.1, 3.3 + %b2 = sparse_tensor.values %reshaped2: tensor<2x3x2xf64, #Sparse3dTensor> to memref + %v2 = vector.transfer_read %b2[%c0], %df: memref, vector<12xf64> + vector.print %v2: vector<12xf64> + + bufferization.dealloc_tensor %sm : tensor<3x4xf64, #SparseMatrix> + bufferization.dealloc_tensor %reshaped0 : tensor<2x6xf64, #SparseMatrix> + bufferization.dealloc_tensor %reshaped1 : tensor<12xf64, #SparseVector> + bufferization.dealloc_tensor %reshaped2 : tensor<2x3x2xf64, #Sparse3dTensor> + + return + } + +} \ No newline at end of file