diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -159,7 +159,7 @@
 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
 /// dimensions or are both expanding dimensions.
 template <typename ReshapeOpTy>
-struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
+struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
   using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
                                 PatternRewriter &rewriter) const override {
@@ -180,21 +180,101 @@
   }
 };
 
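(Editorial note: the rename above does not change behavior; the `-`/`+` lines differ only in the struct name. For orientation, the rewrite `ComposeReassociativeReshapeOps` performs composes two like reshapes into one, e.g. (shapes illustrative only; the updated tests below exercise the same composition):

    %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
        : tensor<?x?x?xf32> into tensor<?x?xf32>
    %1 = tensor.collapse_shape %0 [[0, 1]]
        : tensor<?x?xf32> into tensor<?xf32>

    // composes into:
    %0 = tensor.collapse_shape %arg0 [[0, 1, 2]]
        : tensor<?x?x?xf32> into tensor<?xf32>
)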
-/// Pattern to collapse producer/consumer reshape ops that are both collapsing
-/// dimensions or are both expanding dimensions.
-template <typename ReshapeOpTy, typename InverseReshapeOpTy>
-struct CollapseMixedReshapeOps : public OpRewritePattern<ReshapeOpTy> {
-  using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
+/// Pattern to compose
+/// `collapse_shape(expand_shape(%src, reassociation_1), reassociation_2)`.
+/// In that case both `srcType` and `resultType` can be expressed as a function
+/// of `intermediateType`.
+/// To demonstrate the approach, assume that `rank(srcType) >
+/// rank(resultType)`, i.e. the resulting operation should be `collapse_shape`.
+/// In that case, we can iterate over every set of indices in `reassociation_2`
+/// and try to find the ids of the sets of indices in `reassociation_1` that
+/// cover it completely.
+///
+/// Example:
+///
+///   %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
+///     : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
+///   %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
+///     : tensor<?x?x?x1xi64> into tensor<?x?xi64>
+///
+/// can be canonicalized into
+///
+///   %0 = tensor.collapse_shape %arg [[0, 1], [2]]
+///     : tensor<?x?x?xi64> into tensor<?x?xi64>
+///
+/// because `[0]` and `[1]` from the `expand_shape` reassociation completely
+/// cover `[0, 1]` from the `collapse_shape` reassociation. If no such union
+/// of indices can be found, the pattern fails to apply.
+///
+/// When `rank(srcType) < rank(resultType)`, we simply swap `reassociation_1`
+/// and `reassociation_2` and produce `expand_shape` instead.
+template <typename CollapseOpTy, typename ExpandOpTy>
+struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
+  using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
                                 PatternRewriter &rewriter) const override {
-    auto srcReshapeOp =
-        reshapeOp.src().template getDefiningOp<ReshapeOpTy>();
-    if (!srcReshapeOp)
+    auto expandOp = collapseOp.src().template getDefiningOp<ExpandOpTy>();
+    if (!expandOp)
       return failure();
 
-    ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType();
-    ShapedType intermediateType = reshapeOp.getSrcType();
-    ShapedType resultType = reshapeOp.getResultType();
+    ShapedType srcType = expandOp.getSrcType();
+    ShapedType intermediateType = collapseOp.getSrcType();
+    ShapedType resultType = collapseOp.getResultType();
+
+    int64_t srcRank = srcType.getRank();
+    int64_t resultRank = resultType.getRank();
+    if (srcType == resultType)
+      return failure();
+
+    SmallVector<ReassociationIndices, 4> higherRankReassociation,
+        lowerRankReassociation;
+
+    bool isResultCollapsed = srcRank > resultRank;
+    if (isResultCollapsed) {
+      higherRankReassociation = expandOp.getReassociationIndices();
+      lowerRankReassociation = collapseOp.getReassociationIndices();
+    } else {
+      higherRankReassociation = collapseOp.getReassociationIndices();
+      lowerRankReassociation = expandOp.getReassociationIndices();
+    }
+
+    // Walk the higher-rank reassociation greedily: each lower-rank group must
+    // be covered by a run of consecutive higher-rank groups ending on the
+    // same rightmost dimension; otherwise the two maps are incompatible.
+    size_t higherRankIndicesID = 0;
+    SmallVector<ReassociationIndices, 4> composedReassociation;
+    for (const auto &lowerRankIndices : lowerRankReassociation) {
+      ReassociationIndices composedIndices;
+      while (higherRankIndicesID < higherRankReassociation.size()) {
+        auto rightmostIndex =
+            higherRankReassociation[higherRankIndicesID].back();
+        if (rightmostIndex > lowerRankIndices.back())
+          return failure();
+        composedIndices.push_back(higherRankIndicesID++);
+        if (rightmostIndex == lowerRankIndices.back())
+          break;
+      }
+      composedReassociation.push_back(composedIndices);
+    }
+    if (isResultCollapsed)
+      rewriter.replaceOpWithNewOp<CollapseOpTy>(
+          collapseOp, resultType, expandOp.src(), composedReassociation);
+    else
+      rewriter.replaceOpWithNewOp<ExpandOpTy>(
+          collapseOp, resultType, expandOp.src(), composedReassociation);
+    return success();
+  }
+};
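(Editorial note: to make the covering search concrete, here is a minimal standalone sketch of the same walk, using plain `std::vector` in place of `SmallVector`; the names `Indices` and `composeReassociation` are ad hoc and not part of the patch. It reproduces the example from the comment above, composing [[0], [1], [2, 3]] with [[0, 1], [2, 3]] into [[0, 1], [2]]:

    #include <cstdint>
    #include <iostream>
    #include <optional>
    #include <vector>

    using Indices = std::vector<int64_t>;

    // For each group in the lower-rank reassociation, greedily consume
    // higher-rank groups until their rightmost dimension lines up; bail out
    // if a higher-rank group overshoots the lower-rank group's boundary.
    std::optional<std::vector<Indices>>
    composeReassociation(const std::vector<Indices> &higherRank,
                         const std::vector<Indices> &lowerRank) {
      std::vector<Indices> composed;
      size_t higherRankID = 0;
      for (const Indices &lowerIndices : lowerRank) {
        Indices group;
        while (higherRankID < higherRank.size()) {
          int64_t rightmost = higherRank[higherRankID].back();
          if (rightmost > lowerIndices.back())
            return std::nullopt; // Group boundaries do not align.
          group.push_back(static_cast<int64_t>(higherRankID++));
          if (rightmost == lowerIndices.back())
            break;
        }
        composed.push_back(group);
      }
      return composed;
    }

    int main() {
      // expand_shape reassociation (rank 3 -> rank 4).
      std::vector<Indices> reassociation1 = {{0}, {1}, {2, 3}};
      // collapse_shape reassociation (rank 4 -> rank 2).
      std::vector<Indices> reassociation2 = {{0, 1}, {2, 3}};
      auto composed = composeReassociation(reassociation1, reassociation2);
      if (!composed) {
        std::cout << "incompatible\n";
        return 1;
      }
      // Prints: [0, 1] [2]
      for (const Indices &group : *composed) {
        std::cout << "[";
        for (size_t i = 0; i < group.size(); ++i)
          std::cout << group[i] << (i + 1 < group.size() ? ", " : "");
        std::cout << "] ";
      }
      std::cout << "\n";
    }

The greedy walk suffices because reassociation groups are contiguous and ordered: a lower-rank group is either exactly a union of consecutive higher-rank groups, or no composition exists.)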
+
+template <typename ExpandOpTy, typename CollapseOpTy>
+struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
+  using OpRewritePattern<ExpandOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ExpandOpTy expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseOp = expandOp.src().template getDefiningOp<CollapseOpTy>();
+    if (!collapseOp)
+      return failure();
+
+    ShapedType srcType = collapseOp.getSrcType();
+    ShapedType intermediateType = expandOp.getSrcType();
+    ShapedType resultType = expandOp.getResultType();
 
     // If the source reshape can be collapsed/expanded into the target reshape
     // they can still be folded. This can only be reasoned about statically
@@ -203,19 +283,17 @@
     // - The number of dynamic dimensions matches in the source of source and
     //   result with all other dimensions being 1.
     Optional<SmallVector<ReassociationIndices>> reassociationIndices =
-        getReassociationIndicesForReshape(srcReshapeSrcType, resultType);
+        getReassociationIndicesForReshape(srcType, resultType);
     if (!reassociationIndices)
       return failure();
-    bool originalOpExpands =
-        intermediateType.getRank() > srcReshapeSrcType.getRank();
-    bool resultingOpExpands =
-        resultType.getRank() > srcReshapeSrcType.getRank();
+    bool originalOpExpands = intermediateType.getRank() > srcType.getRank();
+    bool resultingOpExpands = resultType.getRank() > srcType.getRank();
     if (!(resultingOpExpands ^ originalOpExpands))
-      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
-          reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
+      rewriter.replaceOpWithNewOp<ExpandOpTy>(
+          expandOp, resultType, collapseOp.src(), *reassociationIndices);
     else
-      rewriter.replaceOpWithNewOp<InverseReshapeOpTy>(
-          reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
+      rewriter.replaceOpWithNewOp<CollapseOpTy>(
+          expandOp, resultType, collapseOp.src(), *reassociationIndices);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1708,8 +1708,9 @@
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<CollapseReshapeOps<ExpandShapeOp>,
-              CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
+  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
+              ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(
+      context);
 }
 
 LogicalResult CollapseShapeOp::verify() {
@@ -1748,8 +1749,8 @@
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<CollapseReshapeOps<CollapseShapeOp>,
-              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
+  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
+              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
               CollapseShapeOpMemRefCastFolder>(context);
 }
 
 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -890,16 +890,16 @@
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<CollapseReshapeOps<ExpandShapeOp>,
-              CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>,
+  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
+              ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
               FoldReshapeWithConstant<ExpandShapeOp>,
               FoldReshapeWithFromElements<ExpandShapeOp>>(context);
 }
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<CollapseReshapeOps<CollapseShapeOp>,
-              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
+  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
+              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
               FoldReshapeWithConstant<CollapseShapeOp>,
               FoldReshapeWithFromElements<CollapseShapeOp>>(context);
 }
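(Editorial note: the registrations above are the only library-side wiring; the composition runs as part of the ops' canonicalization patterns. A quick way to see it fire outside the test suite, assuming a built `mlir-opt` and flags matching the RUN lines of the updated test files, is:

    mlir-opt repro.mlir -canonicalize --split-input-file

with `repro.mlir` containing an expand/collapse pair such as the new `@compose_collapse_of_expand` test below.)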
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -302,20 +302,20 @@
 
 // -----
 
-func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
-    -> memref<f32> {
+func @compose_collapse_of_collapse_zero_dim(%arg0 : memref<1x1x1xf32>)
+    -> memref<f32> {
   %0 = memref.collapse_shape %arg0 [[0, 1, 2]]
       : memref<1x1x1xf32> into memref<1xf32>
   %1 = memref.collapse_shape %0 [] : memref<1xf32> into memref<f32>
   return %1 : memref<f32>
 }
-// CHECK-LABEL: collapsing_memref_reshapes_to_zero
+// CHECK-LABEL: func @compose_collapse_of_collapse_zero_dim
 // CHECK:         memref.collapse_shape %{{.*}} []
 // CHECK-SAME:      memref<1x1x1xf32> into memref<f32>
 
 // -----
 
-func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>)
+func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
     -> memref<?x?xf32> {
   %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
       : memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
@@ -323,13 +323,13 @@
       : memref<?x?x?xf32> into memref<?x?xf32>
   return %1 : memref<?x?xf32>
 }
-// CHECK-LABEL: collapsing_memref_reshapes
+// CHECK-LABEL: func @compose_collapse_of_collapse
 // CHECK:         memref.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
 // CHECK-NOT:     memref.collapse_shape
 
 // -----
 
-func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>)
+func @compose_expand_of_expand(%arg0 : memref<?x?xf32>)
     -> memref<?x6x4x?x5xf32> {
   %0 = memref.expand_shape %arg0 [[0, 1], [2]]
       : memref<?x?xf32> into memref<?x4x?xf32>
@@ -337,45 +337,46 @@
       : memref<?x4x?xf32> into memref<?x6x4x?x5xf32>
   return %1 : memref<?x6x4x?x5xf32>
 }
-// CHECK-LABEL: expanding_memref_reshapes
+// CHECK-LABEL: func @compose_expand_of_expand
 // CHECK:         memref.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
 // CHECK-NOT:     memref.expand_shape
 
 // -----
 
-func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
-    -> memref<1x1x1xf32> {
+func @compose_expand_of_expand_of_zero_dim(%arg0 : memref<f32>)
+    -> memref<1x1x1xf32> {
   %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1xf32>
   %1 = memref.expand_shape %0 [[0, 1, 2]]
       : memref<1xf32> into memref<1x1x1xf32>
   return %1 : memref<1x1x1xf32>
 }
-// CHECK-LABEL: expanding_memref_reshapes_to_zero
+// CHECK-LABEL: func @compose_expand_of_expand_of_zero_dim
 // CHECK:         memref.expand_shape %{{.*}} []
 // CHECK-SAME:      memref<f32> into memref<1x1x1xf32>
 
 // -----
 
-func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
+func @fold_collapse_of_expand(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
   %0 = memref.expand_shape %arg0 [[0, 1], [2]]
       : memref<12x4xf32> into memref<3x4x4xf32>
   %1 = memref.collapse_shape %0 [[0, 1], [2]]
       : memref<3x4x4xf32> into memref<12x4xf32>
   return %1 : memref<12x4xf32>
 }
-// CHECK-LABEL: @fold_memref_reshape
+// CHECK-LABEL: func @fold_collapse_of_expand
 // CHECK-NOT:     linalg.{{.*}}_shape
 
 // -----
 
-func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32> {
+func @fold_collapse_of_expand_dynamic(%arg0 : memref<?x?xf32>)
+    -> memref<?x?xf32> {
   %0 = memref.expand_shape %arg0 [[0, 1], [2]]
       : memref<?x?xf32> into memref<?x4x?xf32>
   %1 = memref.collapse_shape %0 [[0, 1], [2]]
       : memref<?x4x?xf32> into memref<?x?xf32>
   return %1 : memref<?x?xf32>
 }
-// CHECK-LABEL: @fold_memref_reshape_dynamic
+// CHECK-LABEL: @fold_collapse_of_expand_dynamic
 // CHECK-NOT:     linalg.{{.*}}_shape
 
 // -----
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -646,7 +646,7 @@
 
 // -----
 
-func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>)
+func @compose_expand_of_expand(%arg0 : tensor<?x?xf32>)
     -> tensor<?x6x4x?x5xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
       : tensor<?x?xf32> into tensor<?x4x?xf32>
@@ -654,49 +654,51 @@
       : tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
   return %1 : tensor<?x6x4x?x5xf32>
 }
-// CHECK-LABEL: expanding_tensor_reshapes
+// CHECK-LABEL: compose_expand_of_expand
 // CHECK:         tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
 // CHECK-NOT:     tensor.expand_shape
 
 // -----
 
-func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor<f32>)
+func @compose_expand_of_expand_of_zero_dim(%arg0 : tensor<f32>)
     -> tensor<1x1x1xf32> {
   %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
   %1 = tensor.expand_shape %0 [[0, 1, 2]]
       : tensor<1xf32> into tensor<1x1x1xf32>
   return %1 : tensor<1x1x1xf32>
 }
-// CHECK-LABEL: expanding_tensor_reshapes_to_zero
+// CHECK-LABEL: compose_expand_of_expand_of_zero_dim
 // CHECK:         tensor.expand_shape %{{.*}} []
 // CHECK-SAME:      tensor<f32> into tensor<1x1x1xf32>
 
 // -----
 
-func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
+func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
       : tensor<12x4xf32> into tensor<3x4x4xf32>
   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
       : tensor<3x4x4xf32> into tensor<12x4xf32>
   return %1 : tensor<12x4xf32>
 }
-// CHECK-LABEL: @fold_tensor_reshape
+// CHECK-LABEL: @fold_collapse_of_expand
 // CHECK-NOT:     linalg.{{.*}}shape
 
 // -----
 
-func @fold_tensor_reshape_dynamic(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
      : tensor<?x?xf32> into tensor<?x4x?xf32>
   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
      : tensor<?x4x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-// CHECK-LABEL: @fold_tensor_reshape_dynamic
+// CHECK-LABEL: @fold_collapse_of_expand_dynamic
 // CHECK-NOT:     linalg.{{.*}}_shape
 
 // -----
 
-func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
+
+func @compose_expand_of_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
     -> tensor<24x5x42x8xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
       : tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32>
@@ -704,7 +706,7 @@
       : tensor<40320xf32> into tensor<24x5x42x8xf32>
   return %1 : tensor<24x5x42x8xf32>
 }
-// CHECK:       func @reshape_collapse
+// CHECK:       func @compose_expand_of_collapse
 // CHECK-SAME:    %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8xf32>
 // CHECK:         %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
 // CHECK-SAME:      [0, 1, 2], [3], [4, 5], [6]
@@ -712,7 +714,7 @@
 
 // -----
 
-func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>)
+func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>)
     -> tensor<2x3x4x5x6x7x8xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3]]
       : tensor<24x5x42x8xf32> into tensor<40320xf32>
@@ -720,7 +722,7 @@
       : tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32>
   return %1 : tensor<2x3x4x5x6x7x8xf32>
 }
-// CHECK:       func @reshape_expand
+// CHECK:       func @compose_expand_of_collapse_7D
 // CHECK-SAME:    %[[ARG0:.+]]: tensor<24x5x42x8xf32>
 // CHECK:         %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]]
 // CHECK-SAME:      [0, 1, 2], [3], [4, 5], [6]
@@ -728,20 +730,37 @@
 
 // -----
 
-func @expand_reshape_1D(%arg0 : tensor<2048xf32>) -> tensor<4x512xf32> {
+func @compose_collapse_of_expand(%arg : tensor<?x?x?xi64>)
+    -> tensor<?x?xi64> {
+  %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
+    : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
+  %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
+    : tensor<?x?x?x1xi64> into tensor<?x?xi64>
+  return %1 : tensor<?x?xi64>
+}
+// CHECK-LABEL: func @compose_collapse_of_expand
+// CHECK:         (%[[ARG:.*]]: tensor<?x?x?xi64>)
+// CHECK-NEXT:      tensor.collapse_shape %[[ARG]]
+// CHECK-SAME:        [0, 1], [2]
+// CHECK-SAME:      : tensor<?x?x?xi64> into tensor<?x?xi64>
+
+// -----
+
+func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
+    -> tensor<4x512xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3]]
       : tensor<2048xf32> into tensor<1x4x1x512xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
       : tensor<1x4x1x512xf32> into tensor<4x512xf32>
   return %1 : tensor<4x512xf32>
 }
-// CHECK:       func @expand_reshape_1D
+// CHECK:       func @compose_collapse_of_expand_1D
 // CHECK:         tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
 // CHECK-SAME:      tensor<2048xf32> into tensor<4x512xf32>
 
 // -----
 
-// CHECK-LABEL: zero_rank_reshape_multi
+// CHECK-LABEL: func @zero_rank_reshape_multi
 func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0
   %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
@@ -752,7 +771,7 @@
 
 // -----
 
-func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>)
+func @compose_collapse_of_collapse(%arg0 : tensor<?x?x?x?x?xf32>)
     -> tensor<?x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
       : tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
@@ -760,39 +779,39 @@
       : tensor<?x?x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-// CHECK-LABEL: collapsing_tensor_reshapes
+// CHECK-LABEL: func @compose_collapse_of_collapse
 // CHECK:         tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
 // CHECK-NOT:     tensor.collapse_shape
 
 // -----
 
-func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>)
+func @compose_collapse_of_collapse_zero_dim(%arg0 : tensor<1x1x1xf32>)
     -> tensor<f32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2]]
       : tensor<1x1x1xf32> into tensor<1xf32>
   %1 = tensor.collapse_shape %0 [] : tensor<1xf32> into tensor<f32>
   return %1 : tensor<f32>
 }
-// CHECK-LABEL: collapsing_tensor_reshapes_to_zero
+// CHECK-LABEL: func @compose_collapse_of_collapse_zero_dim
 // CHECK:         tensor.collapse_shape %{{.*}} []
 // CHECK-SAME:      tensor<1x1x1xf32> into tensor<f32>
 
 // -----
 
-func @fold_reshape_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> {
+func @fold_collapse_of_expand_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2], [3]]
      : tensor<4x512xf32> into tensor<1x4x1x512xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
      : tensor<1x4x1x512xf32> into tensor<2048xf32>
   return %1 : tensor<2048xf32>
 }
-// CHECK:       func @fold_reshape_1D
+// CHECK:       func @fold_collapse_of_expand_1D
 // CHECK:         tensor.collapse_shape %{{.*}} {{\[}}[0, 1]]
 // CHECK-SAME:      tensor<4x512xf32> into tensor<2048xf32>
 
 // -----
 
-func @fold_reshape_unit_dims(%arg0 : tensor<2048x1x1xf32>)
+func @fold_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x1xf32>)
     -> tensor<4x512x1x1xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3], [4], [5]]
       : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32>
@@ -800,13 +819,13 @@
       : tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32>
   return %1 : tensor<4x512x1x1xf32>
 }
-// CHECK:       func @fold_reshape_unit_dims
+// CHECK:       func @fold_collapse_of_expand_unit_dims
 // CHECK:         tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]]
 // CHECK-SAME:      tensor<2048x1x1xf32> into tensor<4x512x1x1xf32>
 
 // -----
 
-func @expand_reshape_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
+func @compose_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
     -> tensor<4x512x1x512x4xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]]
       : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
@@ -814,69 +833,70 @@
       : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
   return %1 : tensor<4x512x1x512x4xf32>
 }
-// CHECK:       func @expand_reshape_unit_dims
+// CHECK:       func @compose_collapse_of_expand_unit_dims
 // CHECK:         tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]]
 // CHECK-SAME:      tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>
 
 // -----
 
-func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32> {
+func @compose_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
+    -> tensor<2x1xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2]]
       : tensor<2xf32> into tensor<2x1x1xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2]]
      : tensor<2x1x1xf32> into tensor<2x1xf32>
   return %1 : tensor<2x1xf32>
 }
-// CHECK:       func @fold_reshape_trailing_unit_dims
+// CHECK:       func @compose_collapse_of_expand_trailing_unit_dims
 // CHECK:         tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
 // CHECK-SAME:      tensor<2xf32> into tensor<2x1xf32>
 
 // -----
 
-func @collapse_reshape_unit_dims_dynamic(%arg0 : tensor<?x1x?x1x1x?x1x1x1xf32>)
-    -> tensor<?x?x?x1xf32> {
+func @compose_collapse_of_collapse_unit_dims_dynamic(
+    %arg0 : tensor<?x1x?x1x1x?x1x1x1xf32>) -> tensor<?x?x?x1xf32> {
   %0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]]
       : tensor<?x1x?x1x1x?x1x1x1xf32> into tensor<?x?x1x1x?x1xf32>
   %1 = tensor.collapse_shape %0 [[0], [1], [2, 3, 4], [5]]
      : tensor<?x?x1x1x?x1xf32> into tensor<?x?x?x1xf32>
   return %1 : tensor<?x?x?x1xf32>
 }
-// CHECK:       func @collapse_reshape_unit_dims_dynamic
+// CHECK:       func @compose_collapse_of_collapse_unit_dims_dynamic
 // CHECK:         tensor.collapse_shape
 // CHECK-SAME:      [0], [1, 2], [3, 4, 5], [6, 7, 8]
 // CHECK-SAME:      tensor<?x1x?x1x1x?x1x1x1xf32> into tensor<?x?x?x1xf32>
 
 // -----
 
-func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
-{
+func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
+    -> tensor<2x1xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2]]
      : tensor<2xf32> into tensor<2x1x1xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2]]
     : tensor<2x1x1xf32> into tensor<2x1xf32>
   return %1 : tensor<2x1xf32>
 }
-// CHECK:       func @fold_reshape_trailing_unit_dims
+// CHECK:       func @fold_collapse_of_expand_trailing_unit_dims
 // CHECK:         tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
 // CHECK-SAME:      tensor<2xf32> into tensor<2x1xf32>
 
 // -----
 
-func @fold_reshape_trailing_unit_dims_dynamic(%arg0: tensor<1x1x?x1x1x1xf32>)
-    -> tensor<?xf32> {
+func @fold_collapse_of_collapse_trailing_unit_dims_dynamic(
+    %arg0: tensor<1x1x?x1x1x1xf32>) -> tensor<?xf32> {
  %0 = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5]]
     : tensor<1x1x?x1x1x1xf32> into tensor<?x1x1x1xf32>
  %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
     : tensor<?x1x1x1xf32> into tensor<?xf32>
  return %1 : tensor<?xf32>
 }
-// CHECK:       func @fold_reshape_trailing_unit_dims_dynamic
+// CHECK:       func @fold_collapse_of_collapse_trailing_unit_dims_dynamic
 // CHECK:         tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]]
 // CHECK-SAME:      tensor<1x1x?x1x1x1xf32> into tensor<?xf32>
 
 // -----
 
-func @fold_reshape_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
+func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
     -> tensor<12x42xf32> {
   %0 = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]]
       : tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32>
@@ -884,27 +904,28 @@
       : tensor<12x42x1x1x1xf32> into tensor<12x42xf32>
   return %1 : tensor<12x42xf32>
 }
-// CHECK:       func @fold_reshape_trailing_unit_dims
+// CHECK:       func @fold_collapse_of_expand_trailing_unit_dims
 // CHECK:         tensor.collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]]
 // CHECK-SAME:      tensor<12x42x1x1xf32> into tensor<12x42xf32>
 
 // -----
 
-func @fold_reshapes_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32> {
+func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>)
+    -> tensor<?x?xf32> {
   %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
       : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]]
      : tensor<?x?x1x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-// CHECK-LABEL: func @fold_reshapes_unit_dims_in_middle
+// CHECK-LABEL: func @fold_collapse_of_expand_unit_dims_in_middle
 // CHECK-SAME:    (%[[ARG:.*]]: tensor<?x?x?xf32>
 // CHECK:         tensor.collapse_shape %[[ARG]] {{\[}}[0], [1, 2]]
 // CHECK-SAME:      tensor<?x?x?xf32> into tensor<?x?xf32>
 
 // -----
 
-func @no_fold_reshape_incompatible(%arg0 : tensor<4x6x8xf32>)
+func @no_fold_collapse_of_expand_incompatible(%arg0 : tensor<4x6x8xf32>)
     -> tensor<2x6x16xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]]
       : tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32>
@@ -912,20 +933,21 @@
       : tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32>
   return %1 : tensor<2x6x16xf32>
 }
-// CHECK-LABEL: func @no_fold_reshape_incompatible
+// CHECK-LABEL: func @no_fold_collapse_of_expand_incompatible
 // CHECK:         tensor.expand_shape
 // CHECK:         tensor.collapse_shape
 
 // -----
 
-func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> {
+func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
+    -> tensor<12x1xf32> {
   %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
      : tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
     : tensor<3x2x2x1xf32> into tensor<12x1xf32>
   return %1 : tensor<12x1xf32>
 }
-// CHECK:       func @no_fold_reshape_empty_expr
+// CHECK:       func @no_fold_collapse_of_expand_empty_expr
 // CHECK-SAME:    %[[ARG0:.+]]: tensor<3x2x2xf32>
 // CHECK:         %[[RARG0:.+]] = tensor.expand_shape %[[ARG0]]
 // CHECK-SAME:      [0], [1], [2, 3]
@@ -1002,11 +1024,11 @@
 
 // -----
 
-// CHECK-LABEL: func @pad_tensor_same_static_shape(
+// CHECK-LABEL: func @pad_same_static_shape(
 // CHECK-SAME:    %[[ARG0:.*]]: tensor<5x6xf32>
 // CHECK-NOT:     tensor.pad
 // CHECK:         return %[[ARG0]]
-func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
+func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
     -> tensor<5x6xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.pad %arg0 low[%a, 0] high[0, %a] {
@@ -1018,11 +1040,11 @@
 
 // -----
 
-// CHECK-LABEL: func @pad_tensor_nofold_same_static_shape(
+// CHECK-LABEL: func @pad_nofold_same_static_shape(
 // CHECK-SAME:    %[[ARG0:.*]]: tensor<5x6xf32>
 // CHECK:         %[[PAD:.*]] = tensor.pad
 // CHECK:         return %[[PAD]]
-func @pad_tensor_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
+func @pad_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
    -> tensor<5x6xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.pad %arg0 nofold low[%a, 0] high[0, %a] {
@@ -1034,7 +1056,7 @@
 
 // -----
 
-// CHECK-LABEL: func @pad_tensor_after_cast_different_shape(
+// CHECK-LABEL: func @pad_after_cast_different_shape(
 // CHECK-SAME:    %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
 // CHECK:         %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:         %[[PADDED:.*]] = tensor.pad %[[INPUT]]
@@ -1046,7 +1068,7 @@
 // CHECK-SAME:      tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
 // CHECK:         return %[[DYNAMIC]] : tensor<?x?x?x?xf32>
 // CHECK:       }
-func @pad_tensor_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
+func @pad_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
    -> tensor<?x?x?x?xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
@@ -1059,7 +1081,7 @@
 
 // -----
 
-// CHECK-LABEL: func @pad_tensor_after_cast_same_shape(
+// CHECK-LABEL: func @pad_after_cast_same_shape(
 // CHECK-SAME:    %[[INPUT:.*]]: tensor<?x64x?x?xf32>,
 // CHECK-SAME:    %[[PADDING:.*]]: index) -> tensor<?x?x?x?xf32> {
 // CHECK:         %[[CST:.*]] = arith.constant 0.000000e+00 : f32
@@ -1070,7 +1092,7 @@
 // CHECK:         } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
 // CHECK:         return %[[PADDED:.*]] : tensor<?x?x?x?xf32>
 // CHECK:       }
-func @pad_tensor_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
+func @pad_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
    -> tensor<?x?x?x?xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
@@ -1083,11 +1105,11 @@
 
 // -----
 
-// CHECK-LABEL: func @pad_tensor_of_cast(
+// CHECK-LABEL: func @pad_of_cast(
 // CHECK-NOT:     tensor.cast
 // CHECK:         tensor.pad
 // CHECK:         tensor<8x?xf32> to tensor<8x32xf32>
-func @pad_tensor_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
+func @pad_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.cast %t : tensor<8x?xf32> to tensor<?x?xf32>
@@ -1133,7 +1155,7 @@
 
 // -----
 
-func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+func @pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.0 : f32
  %0 = tensor.cast %arg0 : tensor<4x4xf32> to tensor<?x?xf32>
@@ -1143,17 +1165,17 @@
  } : tensor<?x?xf32> to tensor<4x4xf32>
  return %1 : tensor<4x4xf32>
 }
-// CHECK-LABEL: @tensor_pad_cast
+// CHECK-LABEL: @pad_cast
 // CHECK-SAME:    %[[ARG0:.+]]: tensor<4x4xf32>
 // CHECK:         return %[[ARG0]]
 
 // -----
 
-// CHECK-LABEL: func @fold_pad_tensor_source_cast(
+// CHECK-LABEL: func @fold_pad_source_cast(
 // CHECK-SAME:    %[[ARG0:.*]]: tensor<4x?xf32>
 // CHECK-NOT:     tensor.cast
 // CHECK:         %[[RESULT:.*]] = tensor.pad %[[ARG0]]
-func @fold_pad_tensor_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
+func @fold_pad_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
   %cst = arith.constant 0.0 : f32
   %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
   %1 = tensor.pad %0 low[0, 0] high[0, 1] {