diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
@@ -204,6 +204,66 @@
   SmallVector<Value> tiledSizes;
 };
 
+/// Tries to simplify a `tensor.collapse_shape` operation by inserting a single
+/// rank-reducing `tensor.extract_slice` operation. The `extract_slice` op will
+/// either take the place of the source, allowing for a new, simpler
+/// `collapse_shape` op to replace `op`, or the `collapse_shape` op will be
+/// completely replaced by the `extract_slice` result. Either way, `op` is
+/// replaced and the new op is returned.
+///
+/// ### Example:
+/// ```
+/// %result = tensor.collapse_shape %0 [[0, 1], [2, 3]]
+///     : tensor<1x?x30x30xf32> into tensor<?x900xf32>
+/// ```
+/// can be transformed to
+///
+/// ```
+/// %tmp = tensor.extract_slice %0 [0, 0, 0, 0]
+///                                [1, %dim1, 30, 30]
+///                                [1, 1, 1, 1]
+///     : tensor<1x?x30x30xf32> to tensor<?x30x30xf32>
+/// %result = tensor.collapse_shape %tmp [[0], [1, 2]]
+///     : tensor<?x30x30xf32> into tensor<?x900xf32>
+/// ```
+///
+/// ### Example:
+///
+/// ```
+/// %result = tensor.collapse_shape %1 [[0, 1], [2]]
+///     : tensor<?x1x30xf32> into tensor<?x30xf32>
+/// ```
+/// can be transformed to
+/// ```
+/// %result = tensor.extract_slice %1 [0, 0, 0]
+///                                   [%dim2, 1, 30]
+///                                   [1, 1, 1]
+///     : tensor<?x1x30xf32> to tensor<?x30xf32>
+/// ```
+///
+/// ### Unsupported cases:
+///
+/// This transform doesn't yet support reducing the rank of the reassociation
+/// indices, which would require inserting a `tensor.expand_shape` op similar
+/// to the following example:
+/// ```
+/// %result = tensor.collapse_shape %0 [[0, 1], [2, 3]]
+///     : tensor<1x1x30x10xf32> into tensor<1x300xf32>
+/// ```
+/// can be transformed to
+/// ```
+/// %tmp = tensor.extract_slice %0 [0, 0, 0, 0]
+///                                [1, 1, 30, 10]
+///                                [1, 1, 1, 1]
+///     : tensor<1x1x30x10xf32> to tensor<30x10xf32>
+/// %result0 = tensor.collapse_shape %tmp [[0, 1]]
+///     : tensor<30x10xf32> into tensor<300xf32>
+/// %result1 = tensor.expand_shape %result0 [[0, 1]]
+///     : tensor<300xf32> into tensor<1x300xf32>
+/// ```
+///
+FailureOr<Operation *>
+simplifyCollapseShapeWithRankReducingExtractSlice(tensor::CollapseShapeOp op,
+                                                  RewriterBase &rewriter);
 
 } // namespace tensor
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -460,6 +460,58 @@
   llvm::SmallBitVector linearizedDimensions;
   llvm::SmallBitVector slicedDimensions;
 };
+
+/// Parameters required to simplify a collapsing reshape op with a
+/// rank-reducing slice operation. See
+/// `getSimplifyCollapseShapeWithRankReducingSliceInfo`.
+struct CollapseShapeRankReducingSliceSimplificationInfo {
+  /// The type of the output of the rank-reducing slice.
+  RankedTensorType sliceResultType;
+  /// The reassociation indices for the new collapse shape op, if required. If
+  /// `None`, the slice should replace the collapse shape op.
+  Optional<SmallVector<ReassociationIndices>> newReassociationIndices;
+};
+
+/// A collapsing reshape operation can sometimes be simplified or eliminated by
+/// inserting a single rank-reducing slice operation between it and the source
+/// tensor. The slice op will either take the place of the source, allowing for
+/// a new, simpler reshape op to replace the original, or the reshape op will
+/// be completely replaced by the slice result.
+///
+/// This function returns the parameters required to implement this pattern.
+/// If the pattern is not applicable, failure is returned.
+///
+/// ### Example:
+/// ```
+/// %result = tensor.collapse_shape %0 [[0, 1], [2, 3]]
+///     : tensor<1x?x30x30xf32> into tensor<?x900xf32>
+/// ```
+/// can be transformed to
+/// ```
+/// %tmp = tensor.extract_slice %0 [0, 0, 0, 0]
+///                                [1, %dim1, 30, 30]
+///                                [1, 1, 1, 1]
+///     : tensor<1x?x30x30xf32> to tensor<?x30x30xf32>
+/// %result = tensor.collapse_shape %tmp [[0], [1, 2]]
+///     : tensor<?x30x30xf32> into tensor<?x900xf32>
+/// ```
+///
+/// ### Example:
+/// ```
+/// %result = tensor.collapse_shape %1 [[0, 1], [2]]
+///     : tensor<?x1x30xf32> into tensor<?x30xf32>
+/// ```
+/// can be transformed to
+/// ```
+/// %result = tensor.extract_slice %1 [0, 0, 0]
+///                                   [%dim2, 1, 30]
+///                                   [1, 1, 1]
+///     : tensor<?x1x30xf32> to tensor<?x30xf32>
+/// ```
+FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
+getSimplifyCollapseShapeWithRankReducingSliceInfo(
+    RankedTensorType sourceType,
+    ArrayRef<ReassociationIndices> reassociationIndices);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
@@ -26,8 +26,8 @@
 using namespace mlir::tensor;
 
-/// Get the dimension size of a value of RankedTensor type at the
-OpFoldResult getShapeDimSize(OpBuilder &b, Location loc, Value rankedTensor,
-                             int64_t dimIdx) {
+/// Get the dimension size of a value of RankedTensor type at the given
+/// `dimIdx`.
+static OpFoldResult getShapeDimSize(OpBuilder &b, Location loc,
+                                    Value rankedTensor, int64_t dimIdx) {
   RankedTensorType tensorType = rankedTensor.getType().cast<RankedTensorType>();
   if (!tensorType.isDynamicDim(dimIdx)) {
     return b.getIndexAttr(tensorType.getDimSize(dimIdx));
@@ -103,6 +103,11 @@
 tensor::ExtractSliceFromCollapseHelper::create(OpBuilder &b,
                                                tensor::CollapseShapeOp op,
                                                ArrayRef<Range> sliceParams) {
+  // Don't apply this transformation if the collapse op can be simplified by
+  // a rank-reducing extract slice.
+  if (succeeded(mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
+          op.getSrcType(), op.getReassociationIndices())))
+    return failure();
   // Materialize the output shape of the collapse_shape operation. This will
   // create IR describing the output shape in terms of the input shape.
@@ -125,9 +130,6 @@
   auto collapseShapeInputShape = getShapeDimSizes(b, op.getLoc(), op.getSrc());
 
-  SmallVector<OpFoldResult> srcShape =
-      getShapeDimSizes(b, op->getLoc(), op.getSrc());
-
   SmallVector<Value> tileSizes;
   for (unsigned i = 0; i < sliceParams.size(); i++) {
     if (slicedDimensions[i] && linearizedDimensions[i])
@@ -178,3 +180,36 @@
       loc, subTileResult, reassociationIndices);
   return std::make_pair(collapsedResult, insertParams);
 }
+
+FailureOr<Operation *>
+tensor::simplifyCollapseShapeWithRankReducingExtractSlice(
+    tensor::CollapseShapeOp op, RewriterBase &rewriter) {
+  SmallVector<ReassociationIndices> reassociationIndices =
+      op.getReassociationIndices();
+  RankedTensorType sourceType = op.getSrcType();
+  FailureOr<CollapseShapeRankReducingSliceSimplificationInfo> info =
+      getSimplifyCollapseShapeWithRankReducingSliceInfo(sourceType,
+                                                        reassociationIndices);
+  if (failed(info))
+    return failure();
+
+  // Create the rank-reducing extract slice op.
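+  // The slice reads the full extent of every dimension (zero offsets, full
+  // sizes, unit strides); it is "rank-reducing" only in that its result type
+  // omits the unit dimensions of the simplified groups.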
+  auto zero = rewriter.getIndexAttr(0);
+  auto one = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult> offsets(sourceType.getRank(), zero);
+  SmallVector<OpFoldResult> sizes =
+      getShapeDimSizes(rewriter, op.getLoc(), op.getSrc());
+  SmallVector<OpFoldResult> strides(sourceType.getRank(), one);
+  auto sliceOp = rewriter.create<tensor::ExtractSliceOp>(
+      op.getLoc(), info->sliceResultType, op.getSrc(), offsets, sizes,
+      strides);
+
+  if (!info->newReassociationIndices.has_value()) {
+    rewriter.replaceOp(op, sliceOp.getResult());
+    return sliceOp.getOperation();
+  }
+
+  return rewriter
+      .replaceOpWithNewOp<tensor::CollapseShapeOp>(
+          op, sliceOp.getResult(), info->newReassociationIndices.value())
+      .getOperation();
+}
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -352,3 +352,99 @@
   }
   return insertParams;
 }
+
+/// Returns the index of the only non-unit dimension among `indices` of
+/// `shape`, if such a dimension exists and `indices` has more than one
+/// element. Otherwise, returns None.
+static Optional<int64_t> getUniqueNonUnitDim(ArrayRef<int64_t> indices,
+                                             ArrayRef<int64_t> shape) {
+  // Return None if more than one of the dimensions in this group is not 1.
+  Optional<int64_t> dimIndex = None;
+  if (indices.size() < 2)
+    return None;
+  for (int64_t idx : indices) {
+    if (shape[idx] != 1) {
+      if (dimIndex != None)
+        return None;
+      dimIndex = idx;
+    }
+  }
+  return dimIndex;
+}
+
+// For each segment in the reassociation indices, check whether we can
+// simplify that segment with a rank-reducing extract slice. We can do this if
+// all but (exactly) one of the corresponding source dims are 1.
+static SmallVector<Optional<int64_t>> getCollapseShapeTrivialSegments(
+    RankedTensorType sourceType,
+    ArrayRef<ReassociationIndices> reassociationIndices) {
+  SmallVector<Optional<int64_t>> trivialSegments;
+  for (const auto &indices : reassociationIndices)
+    trivialSegments.push_back(
+        getUniqueNonUnitDim(indices, sourceType.getShape()));
+  return trivialSegments;
+}
+
+/// Returns the trivial segment information if any segment of the
+/// reassociation indices for a collapsing reshape can be simplified using a
+/// rank-reducing slice; otherwise, returns failure.
+static FailureOr<SmallVector<Optional<int64_t>>>
+canCollapseShapeBeSimplifiedByRankReducingSlice(
+    RankedTensorType sourceType,
+    ArrayRef<ReassociationIndices> reassociationIndices) {
+  SmallVector<Optional<int64_t>> trivialSegments =
+      getCollapseShapeTrivialSegments(sourceType, reassociationIndices);
+  if (!llvm::any_of(trivialSegments, [](const Optional<int64_t> &idx) {
+        return idx.has_value();
+      }))
+    return failure();
+  return trivialSegments;
+}
+
+FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
+mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
+    RankedTensorType sourceType,
+    ArrayRef<ReassociationIndices> reassociationIndices) {
+  FailureOr<SmallVector<Optional<int64_t>>> trivialSegments =
+      canCollapseShapeBeSimplifiedByRankReducingSlice(sourceType,
+                                                      reassociationIndices);
+  if (failed(trivialSegments))
+    return failure();
+
+  // Create the expected result shape of the rank-reducing slice.
+  SmallVector<int64_t> sliceShape;
+  for (const auto &[nonUnitDim, indices] :
+       llvm::zip(*trivialSegments, reassociationIndices)) {
+    if (nonUnitDim) {
+      sliceShape.push_back(sourceType.getDimSize(nonUnitDim.value()));
+      continue;
+    }
+    llvm::append_range(sliceShape, llvm::map_range(indices, [&](int64_t idx) {
+                         return sourceType.getDimSize(idx);
+                       }));
+  }
+  auto sliceType =
+      RankedTensorType::get(sliceShape, sourceType.getElementType());
+
+  // If the rank-reducing slice simplified every segment, then we are done.
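+  // In that case the slice result has exactly one dimension per
+  // reassociation group, so no residual collapse_shape op is required.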
+  if (sliceShape.size() == reassociationIndices.size())
+    return CollapseShapeRankReducingSliceSimplificationInfo{sliceType, None};
+
+  // Otherwise, we need to create a new collapse_shape op for the segments
+  // that weren't covered by the slice. By design, the new reassociation
+  // indices have the same number of groups as the old reassociation indices.
+  SmallVector<ReassociationIndices> newReassociationIndices;
+  ReassociationIndices reassociation;
+  int64_t groupIdx = 0;
+  for (int64_t dimIdx = 0; dimIdx < sliceType.getRank(); dimIdx++) {
+    reassociation.push_back(dimIdx);
+    if ((*trivialSegments)[groupIdx] ||
+        reassociation.size() == reassociationIndices[groupIdx].size()) {
+      newReassociationIndices.emplace_back();
+      std::swap(newReassociationIndices.back(), reassociation);
+      groupIdx++;
+    }
+  }
+
+  return CollapseShapeRankReducingSliceSimplificationInfo{
+      sliceType, newReassociationIndices};
+}
diff --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
--- a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
+++ b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
@@ -177,3 +177,65 @@
   // CHECK: return %[[res]]
   return %slice : tensor<330x?xf32>
 }
+
+// -----
+
+// The tests below verify that a dimension which is the result of collapsing at
+// most one non-unit dim is handled properly.
+
+// CHECK: @collapse_and_slice_unit_dim(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index
+func.func @collapse_and_slice_unit_dim(%input: tensor<1x11x100xf32>, %offt: index, %size: index) -> tensor<?x100xf32> {
+  %collapsed = tensor.collapse_shape %input [[0, 1], [2]] : tensor<1x11x100xf32> into tensor<11x100xf32>
+  %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 100] [1, 1] : tensor<11x100xf32> to tensor<?x100xf32>
+  // CHECK-NOT: scf.for
+  // CHECK: %[[e:.+]] = tensor.extract_slice %[[arg0]][0, 0, 0] [1, 11, 100] [1, 1, 1]
+  // CHECK-SAME: tensor<1x11x100xf32> to tensor<11x100xf32>
+  // CHECK: %[[e1:.+]] = tensor.extract_slice %[[e]][%[[arg1]], 0] [%[[arg2]], 100] [1, 1]
+  // CHECK-SAME: tensor<11x100xf32> to tensor<?x100xf32>
+  return %slice : tensor<?x100xf32>
+}
+
+// CHECK: @collapse_and_slice_multiple_unit_dim_dynamic(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index
+func.func @collapse_and_slice_multiple_unit_dim_dynamic(%input: tensor<1x?x1x100xf32>, %offt: index, %size: index) -> tensor<?x100xf32> {
+  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<1x?x1x100xf32> into tensor<?x100xf32>
+  %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 100] [1, 1] : tensor<?x100xf32> to tensor<?x100xf32>
+  // CHECK-NOT: scf.for
+  // CHECK: %[[c1:.+]] = arith.constant 1 : index
+  // CHECK: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c1]] :
+  // CHECK: %[[e:.+]] = tensor.extract_slice %[[arg0]][0, 0, 0, 0] [1, %[[dim]], 1, 100] [1, 1, 1, 1]
+  // CHECK-SAME: tensor<1x?x1x100xf32> to tensor<?x100xf32>
+  // CHECK: %[[e1:.+]] = tensor.extract_slice %[[e]][%[[arg1]], 0] [%[[arg2]], 100] [1, 1]
+  // CHECK-SAME: tensor<?x100xf32> to tensor<?x100xf32>
+  return %slice : tensor<?x100xf32>
+}
+
+// CHECK: @collapse_and_slice_multiple_unit_dim_mixed(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index
+func.func @collapse_and_slice_multiple_unit_dim_mixed(%input: tensor<1x?x1x100x10xf32>, %offt: index, %size: index) -> tensor<?x?xf32> {
+  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<1x?x1x100x10xf32> into tensor<?x1000xf32>
+  %slice = tensor.extract_slice %collapsed [%offt, %offt] [%size, %size] [1, 1] : tensor<?x1000xf32> to tensor<?x?xf32>
+  // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+  // CHECK: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c1]]
+  // CHECK: %[[rank_reduced:.+]] = tensor.extract_slice %[[arg0]][0, 0, 0, 0, 0] [1, %[[dim]], 1, 100, 10] [1, 1, 1, 1, 1]
+  // CHECK: %[[empty:.+]] = tensor.empty
+  // CHECK: %[[result:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[arg2]] step %[[c1]] iter_args(%[[ia:.+]] = %[[empty]])
+  // CHECK:   %[[idx:.+]] = affine.apply
+  // CHECK:   %[[multi_index:.+]] = affine.delinearize_index %[[idx]] into
+  // CHECK:   %[[collapsed:.+]] = tensor.collapse_shape
+  // CHECK:   %[[updated:.+]] = tensor.insert_slice
+  // CHECK:   scf.yield %[[updated]]
+  // CHECK: return %[[result]]
+  return %slice : tensor<?x?xf32>
+}
+
+// Edge case where all collapsed dims are unit dims. This pattern can't
+// eliminate the collapse shape; that case should be handled by
+// `linalg-fold-unit-extent-dims`.
+
+// CHECK: @collapse_and_slice_multiple_all_unit_dim(%[[arg0:.+]]: tensor<{{.*}}>)
+func.func @collapse_and_slice_multiple_all_unit_dim(%input: tensor<1x1x1x100xf32>) -> tensor<1x100xf32> {
+  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<1x1x1x100xf32> into tensor<1x100xf32>
+  %slice = tensor.extract_slice %collapsed [0, 0] [1, 100] [1, 1] : tensor<1x100xf32> to tensor<1x100xf32>
+  return %slice : tensor<1x100xf32>
+  // CHECK: %[[collapse:.+]] = tensor.collapse_shape %[[arg0]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x1x100xf32> into tensor<1x100xf32>
+  // CHECK: return %[[collapse]]
+}
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -128,7 +128,22 @@
     return rewriter.notifyMatchFailure(
         op, "producer is not a tensor.collapse_shape op");
 
-  // Materialize the output shape values of the slice operation.a
+  // Try to simplify the collapse shape using a rank-reducing slice, if
+  // possible.
+  FailureOr<Operation *> simplifiedCollapseShapeResult =
+      tensor::simplifyCollapseShapeWithRankReducingExtractSlice(collapseOp,
+                                                                rewriter);
+  if (succeeded(simplifiedCollapseShapeResult)) {
+    auto newCollapseOp =
+        dyn_cast<tensor::CollapseShapeOp>(*simplifiedCollapseShapeResult);
+    // The collapse shape op might have been simplified away entirely, in
+    // which case we can just return.
+    if (!newCollapseOp)
+      return success();
+    collapseOp = newCollapseOp;
+  }
+
+  // Materialize the output shape values of the slice operation.
   ReifiedRankedShapedTypeDims reifiedShapes;
   if (failed(op.reifyResultShapes(rewriter, reifiedShapes)))
     return rewriter.notifyMatchFailure(op, "failed to reify result shapes");
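For reference, a minimal sketch of how the new `simplifyCollapseShapeWithRankReducingExtractSlice` entry point could be driven from a standalone rewrite pattern, mirroring what the `TestTensorTransforms.cpp` change above does inline. The pattern name and failure message are hypothetical; the utility call and its replace-or-fail contract are the ones added by this patch.

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
#include "mlir/IR/PatternMatch.h"

namespace {
/// Hypothetical wrapper pattern: applies the rank-reducing-slice
/// simplification to every tensor.collapse_shape it matches.
struct SimplifyCollapseShapeWithSlicePattern
    : public mlir::OpRewritePattern<mlir::tensor::CollapseShapeOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::tensor::CollapseShapeOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // On success the utility has already replaced `op` (with either a lone
    // rank-reducing extract_slice or a slice plus a simpler collapse_shape),
    // so the pattern only needs to report whether anything changed.
    if (mlir::failed(
            mlir::tensor::simplifyCollapseShapeWithRankReducingExtractSlice(
                op, rewriter)))
      return rewriter.notifyMatchFailure(
          op, "no reassociation group collapses to a single non-unit dim");
    return mlir::success();
  }
};
} // namespace
```

Such a pattern would be added to a `RewritePatternSet` in the usual way; whether it deserves a public `populate*` entry point is left open here.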
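The regrouping logic in `getSimplifyCollapseShapeWithRankReducingSliceInfo` can be hard to follow in the abstract. The standalone program below (plain C++17, no MLIR dependency; all names are local stand-ins for the helpers in the patch) traces it on the `collapse_and_slice_multiple_unit_dim_mixed` test case above: source shape `1x?x1x100x10` with reassociation `[[0, 1, 2], [3, 4]]`.

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

int main() {
  // Source shape 1x?x1x100x10 (kDynamic stands in for '?') collapsed with
  // reassociation [[0, 1, 2], [3, 4]].
  constexpr int64_t kDynamic = -1;
  std::vector<int64_t> shape = {1, kDynamic, 1, 100, 10};
  std::vector<std::vector<int64_t>> groups = {{0, 1, 2}, {3, 4}};

  // Per group, find the unique non-unit dim (nullopt if the group has fewer
  // than two dims or more than one non-unit dim); mirrors getUniqueNonUnitDim.
  std::vector<std::optional<int64_t>> trivial;
  for (const auto &g : groups) {
    std::optional<int64_t> dim;
    bool ok = g.size() >= 2;
    for (int64_t idx : g) {
      if (shape[idx] != 1) {
        if (dim) {
          ok = false;
          break;
        }
        dim = idx;
      }
    }
    trivial.push_back(ok ? dim : std::nullopt);
  }

  // Build the rank-reduced slice shape: a trivial group contributes only its
  // non-unit dim; other groups are copied through unchanged.
  std::vector<int64_t> sliceShape;
  for (std::size_t i = 0; i < groups.size(); ++i) {
    if (trivial[i])
      sliceShape.push_back(shape[*trivial[i]]);
    else
      for (int64_t idx : groups[i])
        sliceShape.push_back(shape[idx]);
  }

  // Regroup the slice dims: trivial groups become singletons, other groups
  // keep their original arity; mirrors the final loop in
  // getSimplifyCollapseShapeWithRankReducingSliceInfo.
  std::vector<std::vector<int64_t>> newGroups;
  std::vector<int64_t> cur;
  std::size_t groupIdx = 0;
  for (int64_t d = 0; d < (int64_t)sliceShape.size(); ++d) {
    cur.push_back(d);
    if (trivial[groupIdx] || cur.size() == groups[groupIdx].size()) {
      newGroups.push_back(cur);
      cur.clear();
      ++groupIdx;
    }
  }

  // Prints "-1 100 10" (the ?x100x10 slice shape) and "[0] [1 2]" (the new
  // reassociation), matching the tensor<?x100x10xf32> extract_slice and the
  // residual collapse_shape the FileCheck lines expect.
  for (int64_t s : sliceShape)
    std::cout << s << ' ';
  std::cout << '\n';
  for (const auto &g : newGroups) {
    std::cout << '[';
    for (std::size_t i = 0; i < g.size(); ++i)
      std::cout << g[i] << (i + 1 < g.size() ? " " : "");
    std::cout << "] ";
  }
  std::cout << '\n';
}
```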