diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -540,6 +540,7 @@
       reshapeOp.getResultType().hasStaticShape() &&
       reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
     return reshapeSrcOp.src();
+  // Reshape of a constant can be replaced with a new constant.
   if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
     return elements.reshape(
         reshapeOp.getResult().getType().template cast<ShapedType>());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -353,12 +353,157 @@
 };
 } // namespace
 
+namespace {
+/// Pattern to fold pair of reshape ops where the intermediate has unit-dims for
+/// example:
+///
+///   %0 = linalg.tensor_reshape %arg0
+///     [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
+///     : tensor<2048xf32> into tensor<1x4x1x512xf32>
+///   %1 = linalg.tensor_reshape %0
+///     [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
+///      affine_map<(d0, d1, d2, d3) -> (d3)>]
+///     : tensor<1x4x1x512xf32> into tensor<4x512xf32>
+///
+/// can be replaced with
+///
+///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
+///     : tensor<2048xf32> into tensor<4x512xf32>
+///
+/// Similarly,
+///
+///   %0 = linalg.tensor_reshape %arg0
+///     [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
+///      affine_map<(d0, d1, d2, d3) -> (d3)>]
+///     : tensor<4x512xf32> into tensor<1x4x1x512xf32>
+///   %1 = linalg.tensor_reshape %0
+///     [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
+///     : tensor<1x4x1x512xf32> into tensor<2048xf32>
+///
+/// can be replaced with
+///
+///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
+///     : tensor<4x512xf32> into tensor<2048xf32>
+///
+/// The pattern only applies when all shapes are static, the intermediate
+/// tensor has the highest rank, and every dimension of the higher-rank outer
+/// type (the first reshape's source or the second reshape's result) can be
+/// matched, in order, to a dimension of the intermediate tensor. Otherwise it
+/// fails to match rather than build a reshape that does not fit its types.
+struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    // Check that the source operand is created from a reshape as well.
+    TensorReshapeOp parentReshapeOp =
+        reshapeOp.src().getDefiningOp<TensorReshapeOp>();
+    if (!parentReshapeOp)
+      return failure();
+
+    RankedTensorType srcType = reshapeOp.getSrcType(),
+                     dstType = reshapeOp.getResultType(),
+                     parentSrcType = parentReshapeOp.getSrcType();
+    // Only static shapes are handled, and the intermediate tensor must have
+    // the highest rank: one of the two reshapes then only adds or removes
+    // unit dimensions around it.
+    if (!srcType.hasStaticShape() || !dstType.hasStaticShape() ||
+        !parentSrcType.hasStaticShape() ||
+        srcType.getRank() < dstType.getRank() ||
+        srcType.getRank() < parentSrcType.getRank() ||
+        parentSrcType.getRank() == dstType.getRank())
+      return failure();
+
+    // Check if the result tensor_reshape after folding the reshapeOp and
+    // parentReshapeOp are combined.
+    // If the final tensor_reshape is folding, the parentReshapeOp is
+    // introducing unit-dims, and the reshapeOp does an actual reshape.
+    // If the final tensor_reshape op is expanding, the reshapeOp is introducing
+    // unit-dims, and the parentReshapeOp does an actual reshape.
+    bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
+    auto reassociationMaps = isFoldingPattern
+                                 ? reshapeOp.getReassociationMaps()
+                                 : parentReshapeOp.getReassociationMaps();
+    // Dimensions of the intermediate tensor that form a reassociation group
+    // on their own are passed through unchanged by the actual reshape.
+    DenseSet<unsigned> conservedDimensions;
+    for (AffineMap map : reassociationMaps) {
+      if (map.getNumResults() == 1) {
+        conservedDimensions.insert(
+            map.getResult(0).cast<AffineDimExpr>().getPosition());
+      }
+    }
+
+    // Map every retained dimension of the intermediate tensor to its position
+    // in the higher-rank outer type; unit dimensions that only one of the
+    // reshapes introduces are dropped.
+    ArrayRef<int64_t> nonUnitShape =
+        isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
+    int64_t nonUnitRank = nonUnitShape.size();
+    int64_t nonUnitDimPos = 0;
+    DenseMap<unsigned, unsigned> nonUnitSrcDims;
+    for (auto shape : enumerate(srcType.getShape())) {
+      // Case 1 : It is a conserved dimension. Its extent must still agree with
+      // the next unmatched dimension of the higher-rank shape.
+      if (conservedDimensions.count(shape.index())) {
+        if (nonUnitDimPos >= nonUnitRank ||
+            nonUnitShape[nonUnitDimPos] != shape.value())
+          return failure();
+        nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
+        continue;
+      }
+      // Case 2 : Dimensions don't match but the intermediate tensor is
+      // unit-dim.
+      if (shape.value() == 1)
+        continue;
+      // Case 3 : Dimensions match, treat it as a non-unit src dim.
+      if (nonUnitDimPos < nonUnitRank &&
+          nonUnitShape[nonUnitDimPos] == shape.value()) {
+        nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
+        continue;
+      }
+      return failure();
+    }
+    // Every dimension of the higher-rank shape must have been matched;
+    // otherwise the folded reassociation would not cover that shape's rank
+    // and the new reshape would be invalid.
+    if (nonUnitDimPos != nonUnitRank)
+      return failure();
+
+    // Compute reassociation maps for the final operation. Use the reassociation
+    // maps that is actually doing a reshape (and not just introducing
+    // unit-dims). From these maps, prune the unit-extent dimensions.
+    for (AffineMap &map : reassociationMaps) {
+      SmallVector<AffineExpr, 4> exprs;
+      exprs.reserve(map.getNumResults());
+      for (AffineExpr result : map.getResults()) {
+        unsigned dim = result.cast<AffineDimExpr>().getPosition();
+        auto it = nonUnitSrcDims.find(dim);
+        if (it != nonUnitSrcDims.end())
+          exprs.push_back(rewriter.getAffineDimExpr(it->second));
+      }
+      // A group made only of dropped unit dimensions would leave an empty
+      // map, which is not a valid reassociation.
+      if (exprs.empty())
+        return failure();
+      map = AffineMap::get(nonUnitSrcDims.size(), 0, exprs,
+                           rewriter.getContext());
+    }
+    rewriter.replaceOpWithNewOp<TensorReshapeOp>(
+        reshapeOp, dstType, parentReshapeOp.src(),
+        rewriter.getAffineMapArrayAttr(reassociationMaps));
+    return success();
+  }
+};
+} // namespace
+
 /// Patterns that are used to canonicalize the use of unit-extent dims for
 /// broadcasting.
void mlir::populateLinalgFoldUnitExtentDimsPatterns( MLIRContext *context, OwningRewritePatternList &patterns) { patterns.insert(context); TensorReshapeOp::getCanonicalizationPatterns(patterns, context); + patterns.insert(context); } namespace { diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -158,3 +158,85 @@ // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] // CHECK-SAME: iterator_types = ["parallel", "parallel"] // CHECK-SAME: %[[A]] + +// ----- + +// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: func @fold_reshape +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]] +// CHECK-SAME: tensor<2048xf32> into tensor<4x512xf32> +func @fold_reshape(%arg0 : tensor<2048xf32>) -> tensor<4x512xf32> +{ + %0 = linalg.tensor_reshape %arg0 + [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>] + : tensor<2048xf32> into tensor<1x4x1x512xf32> + %1 = linalg.tensor_reshape %0 + [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3) -> (d3)>] + : tensor<1x4x1x512xf32> into tensor<4x512xf32> + return %1 : tensor<4x512xf32> +} + +// ----- + +// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: func @fold_reshape +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]] +// CHECK-SAME: tensor<4x512xf32> into tensor<2048xf32> +func @fold_reshape(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> +{ + %0 = linalg.tensor_reshape %arg0 + [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3) -> (d3)>] + : tensor<4x512xf32> into tensor<1x4x1x512xf32> + %1 = linalg.tensor_reshape %0 + [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>] + : tensor<1x4x1x512xf32> into tensor<2048xf32> + return %1 : tensor<2048xf32> +} + +// ----- + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK-DAG: #[[MAP1:.+]] = 
affine_map<(d0, d1, d2) -> (d2)> +// CHECK: func @fold_reshape +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: tensor<2048x1xf32> into tensor<4x512x1xf32> +func @fold_reshape(%arg0 : tensor<2048x1xf32>) -> tensor<4x512x1xf32> +{ + %0 = linalg.tensor_reshape %arg0 + [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d4)>] + : tensor<2048x1xf32> into tensor<1x4x1x512x1xf32> + %1 = linalg.tensor_reshape %0 + [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d4)>] + : tensor<1x4x1x512x1xf32> into tensor<4x512x1xf32> + return %1 : tensor<4x512x1xf32> +} + +// ----- + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)> +// CHECK: func @fold_reshape +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32> +func @fold_reshape(%arg0 : tensor<2048x1x2048xf32>) -> tensor<4x512x1x512x4xf32> +{ + %0 = linalg.tensor_reshape %arg0 + [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5)>, + affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d6, d7, d8)>] + : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32> + %1 = linalg.tensor_reshape %0 + [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d3, d4)>, + affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5)>, + affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d6, d7)>, + affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d8)>] + : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32> + return %1 : tensor<4x512x1x512x4xf32> +}