diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -68,6 +68,12 @@
 Optional<SmallVector<ReassociationIndices>>
 getReassociationIndicesForReshape(ShapedType sourceType,
                                   ShapedType targetType);
 
+/// Returns the reassociation maps to collapse `sourceShape` to `targetShape`
+/// if possible.
+Optional<SmallVector<ReassociationIndices>>
+getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
+                                   ArrayRef<int64_t> targetShape);
+
 /// Return true if the reassociation specification is valid, false otherwise.
 /// When false, the `invalidIndex` integer pointer is optionally filled with the
 /// index of the offending reassociation map.
@@ -156,10 +162,13 @@
       op.getReassociationIndices(), isExpandingReshape);
 }
 
+/// Returns true iff the type is a MemRefType and has a non-identity layout.
+bool hasNonIdentityLayout(Type type);
+
 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
 /// dimensions or are both expanding dimensions.
 template <typename ReshapeOpTy>
-struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
+struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
   using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
                                 PatternRewriter &rewriter) const override {
@@ -167,7 +176,15 @@
     if (!srcReshapeOp)
       return failure();
 
+    ShapedType srcType = srcReshapeOp.getSrcType();
+    ShapedType intermediateType = reshapeOp.getSrcType();
     ShapedType resultType = reshapeOp.getResultType();
+
+    if (hasNonIdentityLayout(srcReshapeOp.src().getType()) ||
+        hasNonIdentityLayout(reshapeOp.src().getType()) ||
+        hasNonIdentityLayout(reshapeOp.result().getType()))
+      return failure();
+
     Optional<SmallVector<ReassociationIndices>> reassociationIndices =
         composeReassociationIndices(srcReshapeOp.getReassociationIndices(),
                                     reshapeOp.getReassociationIndices(),
@@ -180,44 +197,180 @@
   }
 };
 
-/// Pattern to collapse producer/consumer reshape ops that are both collapsing
-/// dimensions or are both expanding dimensions.
-template <typename ReshapeOpTy, typename InverseReshapeOpTy>
-struct CollapseMixedReshapeOps : public OpRewritePattern<ReshapeOpTy> {
-  using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
+/// Pattern to compose
+/// `collapse_shape(expand_shape(%src, reassociation_1), reassociation_2)`.
+/// In that case both `srcType` and `resultType` can be expressed as a function
+/// of `intermediateType`.
+/// In order to demonstrate the approach, let's assume that `rank(srcType) >
+/// rank(resultType)`, i.e. the resulting operation should be `collapse_shape`.
+/// In that case, we can iterate over every set of indices in `reassociation_2`
+/// and try to find ids of sets of indices in `reassociation_1` that cover it
+/// completely.
+///
+/// Example:
+///
+///   %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
+///       : tensor<?x?x?xf32> into tensor<?x?x?x1xf32>
+///   %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
+///       : tensor<?x?x?x1xf32> into tensor<?x?xf32>
+///
+/// can be canonicalized into
+///
+///   %0 = tensor.collapse_shape %arg [[0, 1], [2]]
+///       : tensor<?x?x?xf32> into tensor<?x?xf32>
+///
+/// because `[0]` and `[1]` from the `expand_shape` reassociation completely
+/// cover `[0, 1]` from the `collapse_shape` reassociation. If no such union
+/// of indices can be found, the pattern fails.
+///
+/// When `rank(srcType) < rank(resultType)`, we simply swap `reassociation_1`
+/// and `reassociation_2` and produce `expand_shape` instead.
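+///
+/// Tracing the example: the lower-rank reassociation is `[[0, 1], [2, 3]]`
+/// (from `collapse_shape`) and the higher-rank one is `[[0], [1], [2, 3]]`
+/// (from `expand_shape`). For the group `[0, 1]` we consume the sets `[0]`
+/// and `[1]` (ids 0 and 1), the last of which ends exactly at index 1, so the
+/// composed group is `[0, 1]`. For the group `[2, 3]` the single set `[2, 3]`
+/// (id 2) ends exactly at index 3, so the composed group is `[2]`. The
+/// composed reassociation is therefore `[[0, 1], [2]]`.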
+template <typename CollapseOpTy, typename ExpandOpTy>
+struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
+  using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
                                 PatternRewriter &rewriter) const override {
-    auto srcReshapeOp =
-        reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
-    if (!srcReshapeOp)
+    auto expandOp = collapseOp.src().template getDefiningOp<ExpandOpTy>();
+    if (!expandOp)
       return failure();
 
-    ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType();
-    ShapedType intermediateType = reshapeOp.getSrcType();
-    ShapedType resultType = reshapeOp.getResultType();
+    ShapedType srcType = expandOp.getSrcType();
+    ShapedType intermediateType = collapseOp.getSrcType();
+    ShapedType resultType = collapseOp.getResultType();
 
-    // If the source reshape can be collapsed/expanded into the target reshape
-    // they can still be folded. This can only be reasoned about statically
-    // for cases where
-    // - either all shapes are static, or
-    // - The number of dynamic dimensions matches in the source of source and
-    //   result with all other dimensions being 1.
-    Optional<SmallVector<ReassociationIndices>> reassociationIndices =
-        getReassociationIndicesForReshape(srcReshapeSrcType, resultType);
-    if (!reassociationIndices)
+    if (hasNonIdentityLayout(collapseOp.src().getType()) ||
+        hasNonIdentityLayout(expandOp.src().getType()) ||
+        hasNonIdentityLayout(expandOp.result().getType()))
+      return failure();
+
+    int64_t srcRank = srcType.getRank();
+    int64_t resultRank = resultType.getRank();
+    if (srcType == resultType)
       return failure();
-    bool originalOpExpands =
-        intermediateType.getRank() > srcReshapeSrcType.getRank();
-    bool resultingOpExpands =
-        resultType.getRank() > srcReshapeSrcType.getRank();
-    if (!(resultingOpExpands ^ originalOpExpands))
-      rewriter.replaceOpWithNewOp<InverseReshapeOpTy>(
-          reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
+
+    SmallVector<ReassociationIndices, 4> higherRankReassociation,
+        lowerRankReassociation;
+
+    bool isResultCollapsed = srcRank > resultRank;
+    if (isResultCollapsed) {
+      higherRankReassociation = expandOp.getReassociationIndices();
+      lowerRankReassociation = collapseOp.getReassociationIndices();
+    } else {
+      higherRankReassociation = collapseOp.getReassociationIndices();
+      lowerRankReassociation = expandOp.getReassociationIndices();
+    }
+
+    size_t higherRankIndicesID = 0;
+    SmallVector<ReassociationIndices, 4> composedReassociation;
+    for (const auto &lowerRankIndices : lowerRankReassociation) {
+      ReassociationIndices composedIndices;
+      while (higherRankIndicesID < higherRankReassociation.size()) {
+        auto rightmostIndex =
+            higherRankReassociation[higherRankIndicesID].back();
+        if (rightmostIndex > lowerRankIndices.back())
+          return failure();
+        composedIndices.push_back(higherRankIndicesID++);
+        if (rightmostIndex == lowerRankIndices.back())
+          break;
+      }
+      composedReassociation.push_back(composedIndices);
+    }
+    if (isResultCollapsed)
+      rewriter.replaceOpWithNewOp<CollapseOpTy>(
+          collapseOp, resultType, expandOp.src(), composedReassociation);
     else
-      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
-          reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
+      rewriter.replaceOpWithNewOp<ExpandOpTy>(
+          collapseOp, resultType, expandOp.src(), composedReassociation);
     return success();
   }
 };
+
+template <typename ExpandOpTy, typename CollapseOpTy>
+struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
+  using OpRewritePattern<ExpandOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ExpandOpTy expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseOp = expandOp.src().template getDefiningOp<CollapseOpTy>();
+    if (!collapseOp)
+      return failure();
+
+    ShapedType srcType = collapseOp.getSrcType();
+    ShapedType intermediateType = expandOp.getSrcType();
+    ShapedType resultType = expandOp.getResultType();
+
+    if (hasNonIdentityLayout(expandOp.src().getType()) ||
+        hasNonIdentityLayout(collapseOp.src().getType()) ||
+        hasNonIdentityLayout(collapseOp.result().getType()))
+      return failure();
+
+    int64_t srcRank = srcType.getRank();
+    int64_t resultRank = resultType.getRank();
+    if (srcType == resultType)
+      return failure();
+
+    auto srcReassociation = collapseOp.getReassociationIndices();
+    auto resultReassociation = expandOp.getReassociationIndices();
+    if (srcRank > resultRank) {
+      auto composedReassociation = findCollapsingReassociation(
+          srcReassociation, resultReassociation, srcType.getShape(),
+          resultType.getShape());
+      if (!composedReassociation.hasValue())
+        return failure();
+
+      rewriter.replaceOpWithNewOp<CollapseOpTy>(
+          expandOp, resultType, collapseOp.src(), *composedReassociation);
+      return success();
+    }
+    auto composedReassociation =
+        findCollapsingReassociation(resultReassociation, srcReassociation,
+                                    resultType.getShape(), srcType.getShape());
+    if (!composedReassociation.hasValue())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ExpandOpTy>(
+        expandOp, resultType, collapseOp.src(), *composedReassociation);
+    return success();
+  }
+
+private:
+  // Attempts to find a way to collapse `srcShape` to `resultShape` by
+  // collapsing subshapes defined by the reassociation indices.
+  Optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
+      ArrayRef<ReassociationIndices> srcReassociation,
+      ArrayRef<ReassociationIndices> resultReassociation,
+      ArrayRef<int64_t> srcShape, ArrayRef<int64_t> resultShape) const {
+    SmallVector<ReassociationIndices, 4> composedReassociation;
+
+    for (auto item : llvm::zip(srcReassociation, resultReassociation)) {
+      auto &srcIndices = std::get<0>(item);
+      auto &resultIndices = std::get<1>(item);
+      auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
+      auto resultSubShape =
+          resultShape.slice(resultIndices.front(), resultIndices.size());
+
+      if (srcSubShape.size() == resultSubShape.size()) {
+        if (srcSubShape == resultSubShape)
+          composedReassociation.push_back(srcIndices);
+        else
+          return llvm::None;
+        // This group already matches 1-1; go on to the next one. Falling
+        // through would always fail, because collapsing requires the source
+        // subshape to have strictly greater rank.
+        continue;
+      }
+
+      // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
+      auto subShapeReassociation =
+          getReassociationIndicesForCollapse(srcSubShape, resultSubShape);
+      if (!subShapeReassociation.hasValue())
+        return llvm::None;
+
+      // Remap the subshape indices back to the original srcShape.
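+      // E.g. with `srcIndices` = [2, 3, 4] and a subshape reassociation of
+      // [[0, 1], [2]], the remapped groups are [[2, 3], [4]].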
+      for (auto &subshape_indices : *subShapeReassociation) {
+        ReassociationIndices shape_indices;
+        for (int64_t index : subshape_indices)
+          shape_indices.push_back(srcIndices.front() + index);
+        composedReassociation.push_back(shape_indices);
+      }
+    }
+    return {std::move(composedReassociation)};
+  }
+};
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1793,8 +1793,9 @@
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<CollapseReshapeOps<ExpandShapeOp>,
-              CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
+  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
+              ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(
+      context);
 }
 
 /// Compute the layout map after collapsing a given source MemRef type with the
@@ -1999,8 +2000,8 @@
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<CollapseReshapeOps<CollapseShapeOp>,
-              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
+  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
+              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
               CollapseShapeOpMemRefCastFolder>(context);
 }
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -890,16 +890,16 @@
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<CollapseReshapeOps<ExpandShapeOp>,
-              CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>,
+  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
+              ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
               FoldReshapeWithConstant<ExpandShapeOp>,
               FoldReshapeWithFromElements<ExpandShapeOp>>(context);
 }
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<CollapseReshapeOps<CollapseShapeOp>,
-              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
+  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
+              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
               FoldReshapeWithConstant<CollapseShapeOp>,
               FoldReshapeWithFromElements<CollapseShapeOp>>(context);
 }
 
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -18,18 +18,23 @@
 Optional<SmallVector<ReassociationIndices>>
 mlir::getReassociationIndicesForReshape(ShapedType sourceType,
                                         ShapedType targetType) {
-  // Make the sourceType greater rank than the targetType. If they are same
-  // rank, then its an unsupported reshape op.
-  if (sourceType.getRank() == targetType.getRank())
-    return llvm::None;
+  if (sourceType.getRank() > targetType.getRank())
+    return getReassociationIndicesForCollapse(sourceType.getShape(),
+                                              targetType.getShape());
   if (sourceType.getRank() < targetType.getRank())
-    std::swap(sourceType, targetType);
+    return getReassociationIndicesForCollapse(targetType.getShape(),
+                                              sourceType.getShape());
+  return llvm::None;
+}
 
-  ArrayRef<int64_t> sourceShape = sourceType.getShape();
-  ArrayRef<int64_t> targetShape = targetType.getShape();
+Optional<SmallVector<ReassociationIndices>>
+mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
+                                         ArrayRef<int64_t> targetShape) {
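+  // For example, collapsing sourceShape = [1, 4, 1, 512] into
+  // targetShape = [4, 512] produces the reassociation [[0, 1], [2, 3]].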
+  if (sourceShape.size() <= targetShape.size())
+    return llvm::None;
   unsigned sourceDim = 0;
   SmallVector<ReassociationIndices> reassociationMap;
-  reassociationMap.reserve(targetType.getRank());
+  reassociationMap.reserve(targetShape.size());
 
   ReassociationIndices currIndices;
   int64_t prodOfCollapsedDims = 1;
@@ -37,7 +42,7 @@
     unsigned targetDim = reassociationMap.size();
     // If we have mapped all the target dimensions stop and handle the
     // remaining tail of size-1 dimensions explicitly.
-    if (targetDim == targetType.getRank())
+    if (targetDim == targetShape.size())
       break;
 
     int64_t currTargetShape = targetShape[targetDim];
@@ -187,6 +192,7 @@
   }
   return maps;
 }
+
 bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
                                 int *invalidIndex) {
   if (reassociation.empty())
@@ -258,3 +264,9 @@
   }
   return success();
 }
+
+bool mlir::hasNonIdentityLayout(Type type) {
+  if (auto memrefType = type.dyn_cast<MemRefType>())
+    return !memrefType.getLayout().isIdentity();
+  return false;
+}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -302,20 +302,20 @@
 
 // -----
 
-func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
-    -> memref<f32> {
+func @compose_collapse_of_collapse_zero_dim(%arg0 : memref<1x1x1xf32>)
+    -> memref<f32> {
   %0 = memref.collapse_shape %arg0 [[0, 1, 2]]
       : memref<1x1x1xf32> into memref<1xf32>
   %1 = memref.collapse_shape %0 [] : memref<1xf32> into memref<f32>
   return %1 : memref<f32>
 }
-// CHECK-LABEL: collapsing_memref_reshapes_to_zero
+// CHECK-LABEL: func @compose_collapse_of_collapse_zero_dim
// CHECK: memref.collapse_shape %{{.*}} []
// CHECK-SAME: memref<1x1x1xf32> into memref<f32>

// -----

-func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>)
+func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
     -> memref<?x?xf32> {
   %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
       : memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
@@ -323,13 +323,30 @@
       : memref<?x?x?xf32> into memref<?x?xf32>
   return %1 : memref<?x?xf32>
 }
-// CHECK-LABEL: collapsing_memref_reshapes
+// CHECK-LABEL: func @compose_collapse_of_collapse
// CHECK: memref.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: memref.collapse_shape

// -----

-func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>)
+func @do_not_compose_collapse_of_expand_non_identity_layout(
+    %arg0: memref<?x?xf32, offset : 0, strides : [?, 1]>)
+    -> memref<?xf32> {
+  %1 = memref.expand_shape %arg0 [[0, 1], [2]] :
+    memref<?x?xf32, offset : 0, strides : [?, 1]> into
+    memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]>
+  %2 = memref.collapse_shape %1 [[0, 1, 2]] :
+    memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]> into
+    memref<?xf32>
+  return %2 : memref<?xf32>
+}
+// CHECK-LABEL: func @do_not_compose_collapse_of_expand_non_identity_layout
+// CHECK: expand
+// CHECK: collapse
+
+// -----
+
+func @compose_expand_of_expand(%arg0 : memref<?x?xf32>)
     -> memref<?x6x4x?x5xf32> {
   %0 = memref.expand_shape %arg0 [[0, 1], [2]]
       : memref<?x?xf32> into memref<?x4x?xf32>
@@ -337,45 +354,46 @@
       : memref<?x4x?xf32> into memref<?x6x4x?x5xf32>
   return %1 : memref<?x6x4x?x5xf32>
 }
-// CHECK-LABEL: expanding_memref_reshapes
+// CHECK-LABEL: func @compose_expand_of_expand
// CHECK: memref.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: memref.expand_shape

// -----

-func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
-    -> memref<1x1x1xf32> {
+func @compose_expand_of_expand_of_zero_dim(%arg0 : memref<f32>)
+    -> memref<1x1x1xf32> {
   %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1xf32>
   %1 = memref.expand_shape %0 [[0, 1, 2]]
       : memref<1xf32> into memref<1x1x1xf32>
   return %1 : memref<1x1x1xf32>
 }
-// CHECK-LABEL: expanding_memref_reshapes_to_zero
+// CHECK-LABEL: func @compose_expand_of_expand_of_zero_dim
// CHECK: memref.expand_shape %{{.*}} []
// CHECK-SAME: memref<f32> into memref<1x1x1xf32>

// -----

-func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
+func @fold_collapse_of_expand(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
   %0 = memref.expand_shape %arg0 [[0, 1], [2]]
       : memref<12x4xf32> into memref<3x4x4xf32>
   %1 = memref.collapse_shape %0 [[0, 1], [2]]
       : memref<3x4x4xf32> into memref<12x4xf32>
   return %1 : memref<12x4xf32>
 }
-// CHECK-LABEL: @fold_memref_reshape
+// CHECK-LABEL: func @fold_collapse_of_expand
// CHECK-NOT: linalg.{{.*}}_shape

// -----

-func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32> {
+func @fold_collapse_of_expand_dynamic(%arg0 : memref<?x?xf32>)
+    -> memref<?x?xf32> {
   %0 = memref.expand_shape %arg0 [[0, 1], [2]]
       : memref<?x?xf32> into memref<?x4x?xf32>
   %1 = memref.collapse_shape %0 [[0, 1], [2]]
       : memref<?x4x?xf32> into memref<?x?xf32>
   return %1 : memref<?x?xf32>
 }
-// CHECK-LABEL: @fold_memref_reshape_dynamic
+// CHECK-LABEL: @fold_collapse_of_expand_dynamic
// CHECK-NOT: linalg.{{.*}}_shape

// -----
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -646,7 +646,7 @@
 
 // -----
 
-func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>)
+func @compose_expand_of_expand(%arg0 : tensor<?x?xf32>)
     -> tensor<?x6x4x?x5xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
       : tensor<?x?xf32> into tensor<?x4x?xf32>
@@ -654,49 +654,51 @@
       : tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
   return %1 : tensor<?x6x4x?x5xf32>
 }
-// CHECK-LABEL: expanding_tensor_reshapes
+// CHECK-LABEL: compose_expand_of_expand
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: tensor.expand_shape

// -----

-func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor<f32>)
+func @compose_expand_of_expand_of_zero_dim(%arg0 : tensor<f32>)
     -> tensor<1x1x1xf32> {
   %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
   %1 = tensor.expand_shape %0 [[0, 1, 2]]
       : tensor<1xf32> into tensor<1x1x1xf32>
   return %1 : tensor<1x1x1xf32>
 }
-// CHECK-LABEL: expanding_tensor_reshapes_to_zero
+// CHECK-LABEL: compose_expand_of_expand_of_zero_dim
// CHECK: tensor.expand_shape %{{.*}} []
// CHECK-SAME: tensor<f32> into tensor<1x1x1xf32>

// -----

-func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
+func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
       : tensor<12x4xf32> into tensor<3x4x4xf32>
   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
       : tensor<3x4x4xf32> into tensor<12x4xf32>
   return %1 : tensor<12x4xf32>
 }
-// CHECK-LABEL: @fold_tensor_reshape
+// CHECK-LABEL: @fold_collapse_of_expand
// CHECK-NOT: linalg.{{.*}}shape

// -----

-func @fold_tensor_reshape_dynamic(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
       : tensor<?x?xf32> into tensor<?x4x?xf32>
   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
       : tensor<?x4x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-// CHECK-LABEL: @fold_tensor_reshape_dynamic
+// CHECK-LABEL: @fold_collapse_of_expand_dynamic
// CHECK-NOT: linalg.{{.*}}_shape

// -----

-func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
+
+func @compose_expand_of_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
     -> tensor<24x5x42x8xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
       : tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32>
@@ -704,7 +706,7 @@
       : tensor<40320xf32> into tensor<24x5x42x8xf32>
   return %1 : tensor<24x5x42x8xf32>
 }
-// CHECK: func @reshape_collapse
+// CHECK: func @compose_expand_of_collapse
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8xf32>
// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK-SAME: [0, 1, 2], [3], [4, 5], [6]
@@ -712,7 +714,7 @@

// -----

-func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>)
+func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>)
     -> tensor<2x3x4x5x6x7x8xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3]]
       : tensor<24x5x42x8xf32> into tensor<40320xf32>
@@ -720,7 +722,7 @@
       : tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32>
   return %1 : tensor<2x3x4x5x6x7x8xf32>
 }
-// CHECK: func @reshape_expand
+// CHECK: func @compose_expand_of_collapse_7D
// CHECK-SAME: %[[ARG0:.+]]: tensor<24x5x42x8xf32>
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: [0, 1, 2], [3], [4, 5], [6]
@@ -728,20 +730,37 @@

// -----

-func @expand_reshape_1D(%arg0 : tensor<2048xf32>) -> tensor<4x512xf32> {
+func @compose_collapse_of_expand(%arg : tensor<?x?x?xf32>)
+    -> tensor<?x?xf32> {
+  %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
+      : tensor<?x?x?xf32> into tensor<?x?x?x1xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
+      : tensor<?x?x?x1xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @compose_collapse_of_expand
+// CHECK: (%[[ARG:.*]]: tensor<?x?x?xf32>)
+// CHECK-NEXT: tensor.collapse_shape %[[ARG]]
+// CHECK-SAME: [0, 1], [2]
+// CHECK-SAME: : tensor<?x?x?xf32> into tensor<?x?xf32>
+
+// -----
+
+func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
+    -> tensor<4x512xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3]]
       : tensor<2048xf32> into tensor<1x4x1x512xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
       : tensor<1x4x1x512xf32> into tensor<4x512xf32>
   return %1 : tensor<4x512xf32>
 }
-// CHECK: func @expand_reshape_1D
+// CHECK: func @compose_collapse_of_expand_1D
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: tensor<2048xf32> into tensor<4x512xf32>

// -----

-// CHECK-LABEL: zero_rank_reshape_multi
+// CHECK-LABEL: func @zero_rank_reshape_multi
 func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0
   %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
@@ -752,7 +771,7 @@

// -----

-func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>)
+func @compose_collapse_of_collapse(%arg0 : tensor<?x?x?x?x?xf32>)
     -> tensor<?x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
       : tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
@@ -760,39 +779,39 @@
       : tensor<?x?x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-// CHECK-LABEL: collapsing_tensor_reshapes
+// CHECK-LABEL: func @compose_collapse_of_collapse
// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: tensor.collapse_shape

// -----

-func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>)
+func @compose_collapse_of_collapse_zero_dim(%arg0 : tensor<1x1x1xf32>)
     -> tensor<f32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2]]
       : tensor<1x1x1xf32> into tensor<1xf32>
   %1 = tensor.collapse_shape %0 [] : tensor<1xf32> into tensor<f32>
   return %1 : tensor<f32>
 }
-// CHECK-LABEL: collapsing_tensor_reshapes_to_zero
+// CHECK-LABEL: func @compose_collapse_of_collapse_zero_dim
// CHECK: tensor.collapse_shape %{{.*}} []
// CHECK-SAME: tensor<1x1x1xf32> into tensor<f32>

// -----

-func @fold_reshape_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> {
+func @fold_collapse_of_expand_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2], [3]]
       : tensor<4x512xf32> into tensor<1x4x1x512xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
       : tensor<1x4x1x512xf32> into tensor<2048xf32>
   return %1 : tensor<2048xf32>
 }
-// CHECK: func @fold_reshape_1D
+// CHECK: func @fold_collapse_of_expand_1D
// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: tensor<4x512xf32> into tensor<2048xf32>

// -----

-func @fold_reshape_unit_dims(%arg0 : tensor<2048x1x1xf32>)
+func @fold_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x1xf32>)
     -> tensor<4x512x1x1xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3], [4], [5]]
       : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32>
@@ -800,13 +819,13 @@
       : tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32>
   return %1 : tensor<4x512x1x1xf32>
 }
-// CHECK: func @fold_reshape_unit_dims
+// CHECK: func @fold_collapse_of_expand_unit_dims
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]]
// CHECK-SAME: tensor<2048x1x1xf32> into tensor<4x512x1x1xf32>

// -----

-func @expand_reshape_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
+func @compose_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
     -> tensor<4x512x1x512x4xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]]
       : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
@@ -814,69 +833,70 @@
       : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
   return %1 : tensor<4x512x1x512x4xf32>
 }
-// CHECK: func @expand_reshape_unit_dims
+// CHECK: func @compose_collapse_of_expand_unit_dims
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]]
// CHECK-SAME: tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>

// -----

-func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32> {
+func @compose_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
+    -> tensor<2x1xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2]]
       : tensor<2xf32> into tensor<2x1x1xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2]]
       : tensor<2x1x1xf32> into tensor<2x1xf32>
   return %1 : tensor<2x1xf32>
 }
-// CHECK: func @fold_reshape_trailing_unit_dims
+// CHECK: func @compose_collapse_of_expand_trailing_unit_dims
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>

// -----

-func @collapse_reshape_unit_dims_dynamic(%arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>)
-    -> tensor<?x?x?x?xf32> {
+func @compose_collapse_of_collapse_unit_dims_dynamic(
+    %arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>) -> tensor<?x?x?x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]]
     : tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x1x1x?x?xf32>
   %1 = tensor.collapse_shape %0 [[0], [1], [2, 3, 4], [5]]
     : tensor<?x?x1x1x?x?xf32> into tensor<?x?x?x?xf32>
   return %1 : tensor<?x?x?x?xf32>
 }
-// CHECK: func @collapse_reshape_unit_dims_dynamic
+// CHECK: func @compose_collapse_of_collapse_unit_dims_dynamic
// CHECK: tensor.collapse_shape
// CHECK-SAME: [0], [1, 2], [3, 4, 5], [6, 7, 8]
// CHECK-SAME: tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x?x?xf32>

// -----

-func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
-{
+func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
+    -> tensor<2x1xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2]]
       : tensor<2xf32> into tensor<2x1x1xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2]]
       : tensor<2x1x1xf32> into tensor<2x1xf32>
   return %1 : tensor<2x1xf32>
 }
-// CHECK: func @fold_reshape_trailing_unit_dims
+// CHECK: func @fold_collapse_of_expand_trailing_unit_dims
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>

// -----

-func @fold_reshape_trailing_unit_dims_dynamic(%arg0: tensor<1x1x?x1x1x1xf32>)
-    -> tensor<?xf32> {
+func @fold_collapse_of_collapse_trailing_unit_dims_dynamic(
+    %arg0: tensor<1x1x?x1x1x1xf32>) -> tensor<?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5]]
       : tensor<1x1x?x1x1x1xf32> into tensor<?x1x1x1xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
       : tensor<?x1x1x1xf32> into tensor<?xf32>
   return %1 : tensor<?xf32>
 }
-// CHECK: func @fold_reshape_trailing_unit_dims_dynamic
+// CHECK: func @fold_collapse_of_collapse_trailing_unit_dims_dynamic
// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]]
// CHECK-SAME: tensor<1x1x?x1x1x1xf32> into tensor<?xf32>

// -----

-func @fold_reshape_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
+func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
     -> tensor<12x42xf32> {
   %0 = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]]
       : tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32>
@@ -884,27 +904,28 @@
       : tensor<12x42x1x1x1xf32> into tensor<12x42xf32>
   return %1 : tensor<12x42xf32>
 }
-// CHECK: func @fold_reshape_trailing_unit_dims
+// CHECK: func @fold_collapse_of_expand_trailing_unit_dims
// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]]
// CHECK-SAME: tensor<12x42x1x1xf32> into tensor<12x42xf32>

// -----

-func @fold_reshapes_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32> {
+func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>)
+    -> tensor<?x?xf32> {
   %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
       : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]]
       : tensor<?x?x1x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-// CHECK-LABEL: func @fold_reshapes_unit_dims_in_middle
+// CHECK-LABEL: func @fold_collapse_of_expand_unit_dims_in_middle
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>
// CHECK: tensor.collapse_shape %[[ARG]] {{\[}}[0], [1, 2]]
// CHECK-SAME: tensor<?x?x?xf32> into tensor<?x?xf32>

// -----

-func @no_fold_reshape_incompatible(%arg0 : tensor<4x6x8xf32>)
+func @no_fold_collapse_of_expand_incompatible(%arg0 : tensor<4x6x8xf32>)
     -> tensor<2x6x16xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]]
       : tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32>
@@ -912,20 +933,21 @@
       : tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32>
   return %1 : tensor<2x6x16xf32>
 }
-// CHECK-LABEL: func @no_fold_reshape_incompatible
+// CHECK-LABEL: func @no_fold_collapse_of_expand_incompatible
// CHECK: tensor.expand_shape
// CHECK: tensor.collapse_shape

// -----

-func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> {
+func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
+    -> tensor<12x1xf32> {
   %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
       : tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
       : tensor<3x2x2x1xf32> into tensor<12x1xf32>
   return %1 : tensor<12x1xf32>
 }
-// CHECK: func @no_fold_reshape_empty_expr
+// CHECK: func @no_fold_collapse_of_expand_empty_expr
// CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x2xf32>
// CHECK: %[[RARG0:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: [0], [1], [2, 3]
@@ -1002,11 +1024,11 @@

// -----

-// CHECK-LABEL: func @pad_tensor_same_static_shape(
+// CHECK-LABEL: func @pad_same_static_shape(
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
// CHECK-NOT: tensor.pad
// CHECK: return %[[ARG0]]
-func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
+func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
     -> tensor<5x6xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.pad %arg0 low[%a, 0] high[0, %a] {
@@ -1018,11 +1040,11 @@

// -----

-// CHECK-LABEL: func @pad_tensor_nofold_same_static_shape(
+// CHECK-LABEL: func @pad_nofold_same_static_shape(
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
// CHECK: %[[PAD:.*]] = tensor.pad
// CHECK: return %[[PAD]]
-func @pad_tensor_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
+func @pad_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
     -> tensor<5x6xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.pad %arg0 nofold low[%a, 0] high[0, %a] {
@@ -1034,7 +1056,7 @@

// -----

-// CHECK-LABEL: func @pad_tensor_after_cast_different_shape(
+// CHECK-LABEL: func @pad_after_cast_different_shape(
// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x1xf32>) -> tensor<?x?x?x1xf32> {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[PADDED:.*]] = tensor.pad %[[INPUT]]
@@ -1046,7 +1068,7 @@
// CHECK-SAME: tensor<?x66x?x3xf32> to tensor<?x?x?x1xf32>
// CHECK: return %[[DYNAMIC]] : tensor<?x?x?x1xf32>
// CHECK: }
-func @pad_tensor_after_cast_different_shape(%arg0: tensor<?x64x?x1xf32>)
+func @pad_after_cast_different_shape(%arg0: tensor<?x64x?x1xf32>)
     -> tensor<?x?x?x1xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %dynamic = tensor.cast %arg0 : tensor<?x64x?x1xf32> to tensor<?x?x?x?xf32>
@@ -1059,7 +1081,7 @@

// -----

-// CHECK-LABEL: func @pad_tensor_after_cast_same_shape(
+// CHECK-LABEL: func @pad_after_cast_same_shape(
// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x1xf32>,
// CHECK-SAME: %[[PADDING:.*]]: index) -> tensor<?x?x?x?xf32> {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
@@ -1070,7 +1092,7 @@
// CHECK: } : tensor<?x64x?x1xf32> to tensor<?x?x?x?xf32>
// CHECK: return %[[PADDED:.*]] : tensor<?x?x?x?xf32>
// CHECK: }
-func @pad_tensor_after_cast_same_shape(%arg0: tensor<?x64x?x1xf32>, %padding : index)
+func @pad_after_cast_same_shape(%arg0: tensor<?x64x?x1xf32>, %padding : index)
     -> tensor<?x?x?x?xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %dynamic = tensor.cast %arg0 : tensor<?x64x?x1xf32> to tensor<?x?x?x?xf32>
@@ -1083,11 +1105,11 @@

// -----

-// CHECK-LABEL: func @pad_tensor_of_cast(
+// CHECK-LABEL: func @pad_of_cast(
// CHECK-NOT: tensor.cast
// CHECK: tensor.pad
// CHECK: tensor<8x?xf32> to tensor<8x32xf32>
-func @pad_tensor_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
+func @pad_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.cast %t : tensor<8x?xf32> to tensor<?x?xf32>
@@ -1133,7 +1155,7 @@

// -----

-func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+func @pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.0 : f32
   %0 = tensor.cast %arg0 : tensor<4x4xf32> to tensor<?x?xf32>
@@ -1143,17 +1165,17 @@
   } : tensor<?x?xf32> to tensor<4x4xf32>
   return %1 : tensor<4x4xf32>
 }
-// CHECK-LABEL: @tensor_pad_cast
+// CHECK-LABEL: @pad_cast
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32>
// CHECK: return %[[ARG0]]

// -----

-// CHECK-LABEL: func @fold_pad_tensor_source_cast(
+// CHECK-LABEL: func @fold_pad_source_cast(
// CHECK-SAME: %[[ARG0:.*]]: tensor<4x?xf32>
// CHECK-NOT: tensor.cast
// CHECK: %[[RESULT:.*]] = tensor.pad %[[ARG0]]
-func @fold_pad_tensor_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
+func @fold_pad_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
   %cst = arith.constant 0.0 : f32
   %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
   %1 = tensor.pad %0 low[0, 0] high[0, 1] {