diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -68,12 +68,6 @@
 Optional<SmallVector<ReassociationIndices>>
 getReassociationIndicesForReshape(ShapedType sourceType,
                                   ShapedType targetType);
 
-/// Returns the reassociation maps to collapse `sourceShape` to `targetShape` if
-/// possible.
-Optional<SmallVector<ReassociationIndices>>
-getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
-                                   ArrayRef<int64_t> targetShape);
-
 /// Return true if the reassociation specification is valid, false otherwise.
 /// When false, the `invalidIndex` integer pointer is optionally filled with the
 /// index of the offending reassociation map.
@@ -162,13 +156,10 @@
                                  op.getReassociationIndices(),
                                  isExpandingReshape);
 }
 
-/// Returns true iff the type is a MemRefType and has a non-identity layout.
-bool hasNonIdentityLayout(Type type);
-
 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
 /// dimensions or are both expanding dimensions.
 template <typename ReshapeOpTy>
-struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
+struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
   using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
                                 PatternRewriter &rewriter) const override {
@@ -177,12 +168,6 @@
       return failure();
 
     ShapedType resultType = reshapeOp.getResultType();
-
-    if (hasNonIdentityLayout(srcReshapeOp.src().getType()) ||
-        hasNonIdentityLayout(reshapeOp.src().getType()) ||
-        hasNonIdentityLayout(reshapeOp.result().getType()))
-      return failure();
-
     Optional<SmallVector<ReassociationIndices>> reassociationIndices =
         composeReassociationIndices(srcReshapeOp.getReassociationIndices(),
                                     reshapeOp.getReassociationIndices(),
@@ -195,180 +180,46 @@
   }
 };
 
-/// Pattern to compose
-/// `collapse_shape(expand_shape(%src, reassociation_1), reassociation_2)`.
-/// In that case both `srcType` and `resultType` can be expressed as a function
-/// of `intermediateType`.
-/// In order to demonstrate the approach, let's assume that `rank(srcType) >
-/// `rank(resultType)`, i.e. the resulting operation should be `collapse_shape`.
-/// In that case, we can iterate over every set of indices in `reassociation_2`
-/// and try to find ids of sets of indices in `reassociation_1` that cover it
-/// completely.
-///
-/// Example:
-///
-///   %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
-///     : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
-///   %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
-///     : tensor<?x?x?x1xi64> into tensor<?x?xi64>
-///
-/// can be canonicalized into
-///
-///   %0 = tensor.collapse_shape %arg [[0, 1], [2]]
-///     : tensor<?x?x?xi64> into tensor<?x?xi64>
-///
-/// because [0] and [1] from `expand_shape` reassociation cover completely
-/// `[0, 1]` from `collapse_shape`. If it is impossible to find such union of
-/// indices, then we fail.
-//
-/// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
-/// `reassociation_2` and produce `expand_shape`.
-template <typename CollapseOpTy, typename ExpandOpTy>
-struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
-  using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
+/// Pattern to fold producer/consumer reshape ops where one op collapses
+/// dimensions and the other one expands them.
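+///
+/// For example (this is the `fold_reshape_1D` test case further down in this
+/// patch):
+///
+///   %0 = tensor.expand_shape %arg0 [[0, 1, 2], [3]]
+///       : tensor<4x512xf32> into tensor<1x4x1x512xf32>
+///   %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
+///       : tensor<1x4x1x512xf32> into tensor<2048xf32>
+///
+/// folds into the single op
+///
+///   %0 = tensor.collapse_shape %arg0 [[0, 1]]
+///       : tensor<4x512xf32> into tensor<2048xf32>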
+template <typename ReshapeOpTy, typename InverseReshapeOpTy>
+struct CollapseMixedReshapeOps : public OpRewritePattern<ReshapeOpTy> {
+  using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
                                 PatternRewriter &rewriter) const override {
-    auto expandOp = collapseOp.src().template getDefiningOp<ExpandOpTy>();
-    if (!expandOp)
+    auto srcReshapeOp =
+        reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
+    if (!srcReshapeOp)
       return failure();
 
-    ShapedType srcType = expandOp.getSrcType();
-    ShapedType resultType = collapseOp.getResultType();
-
-    if (hasNonIdentityLayout(collapseOp.src().getType()) ||
-        hasNonIdentityLayout(expandOp.src().getType()) ||
-        hasNonIdentityLayout(expandOp.result().getType()))
-      return failure();
+    ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType();
+    ShapedType intermediateType = reshapeOp.getSrcType();
+    ShapedType resultType = reshapeOp.getResultType();
 
-    int64_t srcRank = srcType.getRank();
-    int64_t resultRank = resultType.getRank();
-    if (srcType == resultType)
+    // If the source reshape can be collapsed/expanded into the target reshape
+    // they can still be folded. This can only be reasoned about statically
+    // for cases where
+    // - either all shapes are static, or
+    // - the number of dynamic dimensions matches in the source of the source
+    //   reshape and in the result, with all other dimensions being 1.
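+    // E.g. in the `fold_reshapes_unit_dims_in_middle` test in this patch,
+    // expanding tensor<?x?x?xf32> by [[0], [1], [2, 3]] and collapsing the
+    // result by [[0], [1, 2, 3]] only introduces and removes a unit
+    // dimension, so the pair folds to a single collapse by [[0], [1, 2]].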
+    Optional<SmallVector<ReassociationIndices>> reassociationIndices =
+        getReassociationIndicesForReshape(srcReshapeSrcType, resultType);
+    if (!reassociationIndices)
       return failure();
-
-    SmallVector<ReassociationIndices> higherRankReassociation,
-        lowerRankReassociation;
-
-    bool isResultCollapsed = srcRank > resultRank;
-    if (isResultCollapsed) {
-      higherRankReassociation = expandOp.getReassociationIndices();
-      lowerRankReassociation = collapseOp.getReassociationIndices();
-    } else {
-      higherRankReassociation = collapseOp.getReassociationIndices();
-      lowerRankReassociation = expandOp.getReassociationIndices();
-    }
-
-    size_t higherRankIndicesID = 0;
-    SmallVector<ReassociationIndices> composedReassociation;
-    for (const auto &lowerRankIndices : lowerRankReassociation) {
-      ReassociationIndices composedIndices;
-      while (higherRankIndicesID < higherRankReassociation.size()) {
-        auto rightmostIndex =
-            higherRankReassociation[higherRankIndicesID].back();
-        if (rightmostIndex > lowerRankIndices.back())
-          return failure();
-        composedIndices.push_back(higherRankIndicesID++);
-        if (rightmostIndex == lowerRankIndices.back())
-          break;
-      }
-      composedReassociation.push_back(composedIndices);
-    }
-    if (isResultCollapsed)
-      rewriter.replaceOpWithNewOp<CollapseOpTy>(
-          collapseOp, resultType, expandOp.src(), composedReassociation);
+    bool originalOpExpands =
+        intermediateType.getRank() > srcReshapeSrcType.getRank();
+    bool resultingOpExpands =
+        resultType.getRank() > srcReshapeSrcType.getRank();
+    if (!(resultingOpExpands ^ originalOpExpands))
+      rewriter.replaceOpWithNewOp<InverseReshapeOpTy>(
+          reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
     else
-      rewriter.replaceOpWithNewOp<ExpandOpTy>(
-          collapseOp, resultType, expandOp.src(), composedReassociation);
+      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+          reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
     return success();
   }
 };
 
-template <typename ExpandOpTy, typename CollapseOpTy>
-struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
-  using OpRewritePattern<ExpandOpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(ExpandOpTy expandOp,
-                                PatternRewriter &rewriter) const override {
-    auto collapseOp = expandOp.src().template getDefiningOp<CollapseOpTy>();
-    if (!collapseOp)
-      return failure();
-
-    ShapedType srcType = collapseOp.getSrcType();
-    ShapedType resultType = expandOp.getResultType();
-
-    if (hasNonIdentityLayout(expandOp.src().getType()) ||
-        hasNonIdentityLayout(collapseOp.src().getType()) ||
-        hasNonIdentityLayout(collapseOp.result().getType()))
-      return failure();
-
-    int64_t srcRank = srcType.getRank();
-    int64_t resultRank = resultType.getRank();
-    if (srcType == resultType)
-      return failure();
-
-    auto srcReassociation = collapseOp.getReassociationIndices();
-    auto resultReassociation = expandOp.getReassociationIndices();
-    if (srcRank > resultRank) {
-      auto composedReassociation = findCollapsingReassociation(
-          srcReassociation, resultReassociation, srcType.getShape(),
-          resultType.getShape());
-      if (!composedReassociation.hasValue())
-        return failure();
-
-      rewriter.replaceOpWithNewOp<CollapseOpTy>(
-          expandOp, resultType, collapseOp.src(), *composedReassociation);
-      return success();
-    }
-    auto composedReassociation =
-        findCollapsingReassociation(resultReassociation, srcReassociation,
-                                    resultType.getShape(), srcType.getShape());
-    if (!composedReassociation.hasValue())
-      return failure();
-
-    rewriter.replaceOpWithNewOp<ExpandOpTy>(
-        expandOp, resultType, collapseOp.src(), *composedReassociation);
-    return success();
-  }
-
-private:
-  // Attempts to find a way to collapse `srcShape` to `resultShape` by
-  // collapsing subshapes defined by the reassociation indices.
-  Optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
-      ArrayRef<ReassociationIndices> srcReassociation,
-      ArrayRef<ReassociationIndices> resultReassociation,
-      ArrayRef<int64_t> srcShape, ArrayRef<int64_t> resultShape) const {
-    SmallVector<ReassociationIndices> composedReassociation;
-
-    for (auto item : llvm::zip(srcReassociation, resultReassociation)) {
-      auto &srcIndices = std::get<0>(item);
-      auto &resultIndices = std::get<1>(item);
-      auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
-      auto resultSubShape =
-          resultShape.slice(resultIndices.front(), resultIndices.size());
-
-      if (srcSubShape.size() == resultSubShape.size()) {
-        if (srcSubShape == resultSubShape)
-          composedReassociation.push_back(srcIndices);
-        else
-          return llvm::None;
-      }
-
-      // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
-      auto subShapeReassociation =
-          getReassociationIndicesForCollapse(srcSubShape, resultSubShape);
-      if (!subShapeReassociation.hasValue())
-        return llvm::None;
-
-      // Remap the subshape indices back to the original srcShape.
-      for (auto &subshape_indices : *subShapeReassociation) {
-        ReassociationIndices shape_indices;
-        for (int64_t index : subshape_indices)
-          shape_indices.push_back(srcIndices.front() + index);
-        composedReassociation.push_back(shape_indices);
-      }
-    }
-    return {std::move(composedReassociation)};
-  }
-};
-
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1793,9 +1793,8 @@
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
-              ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(
-      context);
+  results.add<CollapseReshapeOps<ExpandShapeOp>,
+              CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
 }
 
 /// Compute the layout map after collapsing a given source MemRef type with the
@@ -2000,8 +1999,8 @@
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+  results.add<CollapseReshapeOps<CollapseShapeOp>,
+              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
               CollapseShapeOpMemRefCastFolder>(context);
 }
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -890,16 +890,16 @@
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
-              ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+  results.add<CollapseReshapeOps<ExpandShapeOp>,
+              CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>,
               FoldReshapeWithConstant<ExpandShapeOp>,
               FoldReshapeWithFromElements<ExpandShapeOp>>(context);
 }
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+  results.add<CollapseReshapeOps<CollapseShapeOp>,
+              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
               FoldReshapeWithConstant<CollapseShapeOp>,
               FoldReshapeWithFromElements<CollapseShapeOp>>(context);
 }
 
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -18,23 +18,18 @@
 Optional<SmallVector<ReassociationIndices>>
 mlir::getReassociationIndicesForReshape(ShapedType sourceType,
                                         ShapedType targetType) {
-  if (sourceType.getRank() > targetType.getRank())
-    return getReassociationIndicesForCollapse(sourceType.getShape(),
-                                              targetType.getShape());
+  // Make the sourceType greater rank than the targetType. If they are the same
+  // rank, then it is an unsupported reshape op.
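+  // E.g. both collapsing memref<1x4x1x512xf32> into memref<4x512xf32> and
+  // expanding memref<4x512xf32> into memref<1x4x1x512xf32> are then handled
+  // with sourceShape = [1, 4, 1, 512] and targetShape = [4, 512].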
+  if (sourceType.getRank() == targetType.getRank())
+    return llvm::None;
   if (sourceType.getRank() < targetType.getRank())
-    return getReassociationIndicesForCollapse(targetType.getShape(),
-                                              sourceType.getShape());
-  return llvm::None;
-}
+    std::swap(sourceType, targetType);
 
-Optional<SmallVector<ReassociationIndices>>
-mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
-                                         ArrayRef<int64_t> targetShape) {
-  if (sourceShape.size() <= targetShape.size())
-    return llvm::None;
+  ArrayRef<int64_t> sourceShape = sourceType.getShape();
+  ArrayRef<int64_t> targetShape = targetType.getShape();
   unsigned sourceDim = 0;
   SmallVector<ReassociationIndices> reassociationMap;
-  reassociationMap.reserve(targetShape.size());
+  reassociationMap.reserve(targetType.getRank());
 
   ReassociationIndices currIndices;
   int64_t prodOfCollapsedDims = 1;
@@ -42,7 +37,7 @@
     unsigned targetDim = reassociationMap.size();
     // If we have mapped all the target dimensions stop and handle the remaining
    // tail of size-1 dimensions explicitly.
-    if (targetDim == targetShape.size())
+    if (targetDim == targetType.getRank())
      break;
 
    int64_t currTargetShape = targetShape[targetDim];
@@ -192,7 +187,6 @@
  }
  return maps;
 }
-
 bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
                                int *invalidIndex) {
  if (reassociation.empty())
@@ -264,9 +258,3 @@
  }
  return success();
 }
-
-bool mlir::hasNonIdentityLayout(Type type) {
-  if (auto memrefType = type.dyn_cast<MemRefType>())
-    return !memrefType.getLayout().isIdentity();
-  return false;
-}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -302,20 +302,20 @@
 
 // -----
 
-func @compose_collapse_of_collapse_zero_dim(%arg0 : memref<1x1x1xf32>)
-    -> memref<f32> {
+func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
+    -> memref<f32> {
   %0 = memref.collapse_shape %arg0 [[0, 1, 2]]
      : memref<1x1x1xf32> into memref<1xf32>
   %1 = memref.collapse_shape %0 [] : memref<1xf32> into memref<f32>
   return %1 : memref<f32>
 }
-// CHECK-LABEL: func @compose_collapse_of_collapse_zero_dim
+// CHECK-LABEL: collapsing_memref_reshapes_to_zero
 // CHECK: memref.collapse_shape %{{.*}} []
 // CHECK-SAME: memref<1x1x1xf32> into memref<f32>
 
 // -----
 
-func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
+func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>)
     -> memref<?x?xf32> {
   %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
      : memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
@@ -323,30 +323,13 @@
      : memref<?x?x?xf32> into memref<?x?xf32>
   return %1 : memref<?x?xf32>
 }
-// CHECK-LABEL: func @compose_collapse_of_collapse
+// CHECK-LABEL: collapsing_memref_reshapes
 // CHECK: memref.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
 // CHECK-NOT: memref.collapse_shape
 
 // -----
 
-func @do_not_compose_collapse_of_expand_non_identity_layout(
-    %arg0: memref<?x?xf32, offset : 0, strides : [?, 1]>)
-    -> memref<?xf32> {
-  %1 = memref.expand_shape %arg0 [[0, 1], [2]] :
-    memref<?x?xf32, offset : 0, strides : [?, 1]> into
-    memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]>
-  %2 = memref.collapse_shape %1 [[0, 1, 2]] :
-    memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]> into
-    memref<?xf32>
-  return %2 : memref<?xf32>
-}
-// CHECK-LABEL: func @do_not_compose_collapse_of_expand_non_identity_layout
-// CHECK: expand
-// CHECK: collapse
-
-// -----
-
-func @compose_expand_of_expand(%arg0 : memref<?x?xf32>)
+func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>)
     -> memref<?x6x4x?x5xf32> {
   %0 = memref.expand_shape %arg0 [[0, 1], [2]]
      : memref<?x?xf32> into memref<?x4x?xf32>
@@ -354,46 +337,45 @@
      : memref<?x4x?xf32> into memref<?x6x4x?x5xf32>
   return %1 : memref<?x6x4x?x5xf32>
 }
-// CHECK-LABEL: func @compose_expand_of_expand
+// CHECK-LABEL: expanding_memref_reshapes
 // CHECK: memref.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
 // CHECK-NOT: memref.expand_shape
 
 // -----
 
-func @compose_expand_of_expand_of_zero_dim(%arg0 : memref<f32>)
-    -> memref<1x1x1xf32> {
+func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
+    -> memref<1x1x1xf32> {
   %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1xf32>
   %1 = memref.expand_shape %0 [[0, 1, 2]]
      : memref<1xf32> into memref<1x1x1xf32>
   return %1 : memref<1x1x1xf32>
 }
-// CHECK-LABEL: func @compose_expand_of_expand_of_zero_dim
+// CHECK-LABEL: expanding_memref_reshapes_to_zero
 // CHECK: memref.expand_shape %{{.*}} []
 // CHECK-SAME: memref<f32> into memref<1x1x1xf32>
 
 // -----
 
-func @fold_collapse_of_expand(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
+func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
   %0 = memref.expand_shape %arg0 [[0, 1], [2]]
      : memref<12x4xf32> into memref<3x4x4xf32>
   %1 = memref.collapse_shape %0 [[0, 1], [2]]
      : memref<3x4x4xf32> into memref<12x4xf32>
   return %1 : memref<12x4xf32>
 }
-// CHECK-LABEL: func @fold_collapse_of_expand
+// CHECK-LABEL: @fold_memref_reshape
// CHECK-NOT: linalg.{{.*}}_shape
 
 // -----
 
-func @fold_collapse_collapse_of_expand(%arg0 : memref<?x?xf32>)
-    -> memref<?x?xf32> {
+func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32> {
   %0 = memref.expand_shape %arg0 [[0, 1], [2]]
      : memref<?x?xf32> into memref<?x4x?xf32>
   %1 = memref.collapse_shape %0 [[0, 1], [2]]
      : memref<?x4x?xf32> into memref<?x?xf32>
   return %1 : memref<?x?xf32>
 }
-// CHECK-LABEL: @fold_collapse_collapse_of_expand
+// CHECK-LABEL: @fold_memref_reshape_dynamic
 // CHECK-NOT: linalg.{{.*}}_shape
 
 // -----
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -646,7 +646,7 @@
 
 // -----
 
-func @compose_expand_of_expand(%arg0 : tensor<?x?xf32>)
+func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>)
     -> tensor<?x6x4x?x5xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
      : tensor<?x?xf32> into tensor<?x4x?xf32>
@@ -654,51 +654,49 @@
      : tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
   return %1 : tensor<?x6x4x?x5xf32>
 }
-// CHECK-LABEL: compose_expand_of_expand
+// CHECK-LABEL: expanding_tensor_reshapes
 // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
 // CHECK-NOT: tensor.expand_shape
 
 // -----
 
-func @compose_expand_of_expand_of_zero_dim(%arg0 : tensor<f32>)
+func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor<f32>)
     -> tensor<1x1x1xf32> {
   %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
   %1 = tensor.expand_shape %0 [[0, 1, 2]]
      : tensor<1xf32> into tensor<1x1x1xf32>
   return %1 : tensor<1x1x1xf32>
 }
-// CHECK-LABEL: compose_expand_of_expand_of_zero_dim
+// CHECK-LABEL: expanding_tensor_reshapes_to_zero
 // CHECK: tensor.expand_shape %{{.*}} []
 // CHECK-SAME: tensor<f32> into tensor<1x1x1xf32>
 
 // -----
 
-func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
+func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
      : tensor<12x4xf32> into tensor<3x4x4xf32>
   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
      : tensor<3x4x4xf32> into tensor<12x4xf32>
   return %1 : tensor<12x4xf32>
 }
-// CHECK-LABEL: @fold_collapse_of_expand
+// CHECK-LABEL: @fold_tensor_reshape
 // CHECK-NOT: linalg.{{.*}}shape
 
 // -----
 
-func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>)
-    -> tensor<?x?xf32> {
+func @fold_tensor_reshape_dynamic(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
      : tensor<?x?xf32> into tensor<?x4x?xf32>
   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
      : tensor<?x4x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-// CHECK-LABEL: @fold_collapse_of_expand_dynamic
+// CHECK-LABEL: @fold_tensor_reshape_dynamic
 // CHECK-NOT: linalg.{{.*}}_shape
 
 // -----
-
-func @compose_expand_of_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
+func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
     -> tensor<24x5x42x8xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
      : tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32>
@@ -706,7 +704,7 @@
      : tensor<40320xf32> into tensor<24x5x42x8xf32>
   return %1 : tensor<24x5x42x8xf32>
 }
-// CHECK: func @compose_expand_of_collapse
+// CHECK: func @reshape_collapse
 // CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8xf32>
 // CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
 // CHECK-SAME: [0, 1, 2], [3], [4, 5], [6]
@@ -714,7 +712,7 @@
 
 // -----
 
-func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>)
+func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>)
     -> tensor<2x3x4x5x6x7x8xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3]]
      : tensor<24x5x42x8xf32> into tensor<40320xf32>
@@ -722,7 +720,7 @@
      : tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32>
   return %1 : tensor<2x3x4x5x6x7x8xf32>
 }
-// CHECK: func @compose_expand_of_collapse_7D
+// CHECK: func @reshape_expand
 // CHECK-SAME: %[[ARG0:.+]]: tensor<24x5x42x8xf32>
 // CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]]
 // CHECK-SAME: [0, 1, 2], [3], [4, 5], [6]
@@ -730,37 +728,20 @@
 
 // -----
 
-func @compose_collapse_of_expand(%arg : tensor<?x?x?xi64>)
-    -> tensor<?x?xi64> {
-  %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
-     : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
-  %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
-     : tensor<?x?x?x1xi64> into tensor<?x?xi64>
-  return %1 : tensor<?x?xi64>
-}
-// CHECK-LABEL: func @compose_collapse_of_expand
-// CHECK: (%[[ARG:.*]]: tensor<?x?x?xi64>)
-// CHECK-NEXT: tensor.collapse_shape %[[ARG]]
-// CHECK-SAME: [0, 1], [2]
-// CHECK-SAME: : tensor<?x?x?xi64> into tensor<?x?xi64>
-
-// -----
-
-func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
-    -> tensor<4x512xf32> {
+func @expand_reshape_1D(%arg0 : tensor<2048xf32>) -> tensor<4x512xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3]]
     : tensor<2048xf32> into tensor<1x4x1x512xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
    : tensor<1x4x1x512xf32> into tensor<4x512xf32>
   return %1 : tensor<4x512xf32>
 }
-// CHECK: func @compose_collapse_of_expand_1D
+// CHECK: func @expand_reshape_1D
 // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
 // CHECK-SAME: tensor<2048xf32> into tensor<4x512xf32>
 
 // -----
 
-// CHECK-LABEL: func @zero_rank_reshape_multi
+// CHECK-LABEL: zero_rank_reshape_multi
 func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0
   %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
@@ -771,7 +752,7 @@
 
 // -----
 
-func @compose_collapse_of_collapse(%arg0 : tensor<?x?x?x?x?xf32>)
+func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>)
     -> tensor<?x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
      : tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
@@ -779,39 +760,39 @@
      : tensor<?x?x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-// CHECK-LABEL: func @compose_collapse_of_collapse
+// CHECK-LABEL: collapsing_tensor_reshapes
 // CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
 // CHECK-NOT: tensor.collapse_shape
 
 // -----
 
-func @compose_collapse_of_collapse_zero_dim(%arg0 : tensor<1x1x1xf32>)
+func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>)
     -> tensor<f32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2]]
      : tensor<1x1x1xf32> into tensor<1xf32>
   %1 = tensor.collapse_shape %0 [] : tensor<1xf32> into tensor<f32>
   return %1 : tensor<f32>
 }
-// CHECK-LABEL: func @compose_collapse_of_collapse_zero_dim
+// CHECK-LABEL: collapsing_tensor_reshapes_to_zero
 // CHECK: tensor.collapse_shape %{{.*}} []
 // CHECK-SAME: tensor<1x1x1xf32> into tensor<f32>
 
 // -----
 
-func @fold_collapse_of_expand_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> {
+func @fold_reshape_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2], [3]]
     : tensor<4x512xf32> into tensor<1x4x1x512xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
    : tensor<1x4x1x512xf32> into tensor<2048xf32>
   return %1 : tensor<2048xf32>
 }
-// CHECK: func @fold_collapse_of_expand_1D
+// CHECK: func @fold_reshape_1D
 // CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1]]
 // CHECK-SAME: tensor<4x512xf32> into tensor<2048xf32>
 
 // -----
 
-func @fold_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x1xf32>)
+func @fold_reshape_unit_dims(%arg0 : tensor<2048x1x1xf32>)
     -> tensor<4x512x1x1xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3], [4], [5]]
      : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32>
@@ -819,13 +800,13 @@
      : tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32>
   return %1 : tensor<4x512x1x1xf32>
 }
-// CHECK: func @fold_collapse_of_expand_unit_dims
+// CHECK: func @fold_reshape_unit_dims
 // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]]
 // CHECK-SAME: tensor<2048x1x1xf32> into tensor<4x512x1x1xf32>
 
 // -----
 
-func @compose_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
+func @expand_reshape_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
     -> tensor<4x512x1x512x4xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]]
     : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
@@ -833,70 +814,69 @@
    : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
   return %1 : tensor<4x512x1x512x4xf32>
 }
-// CHECK: func @compose_collapse_of_expand_unit_dims
+// CHECK: func @expand_reshape_unit_dims
 // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]]
 // CHECK-SAME: tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>
 
 // -----
 
-func @compose_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
-    -> tensor<2x1xf32> {
+func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2]]
     : tensor<2xf32> into tensor<2x1x1xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2]]
    : tensor<2x1x1xf32> into tensor<2x1xf32>
   return %1 : tensor<2x1xf32>
 }
-// CHECK: func @compose_collapse_of_expand_trailing_unit_dims
+// CHECK: func @fold_reshape_trailing_unit_dims
 // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
 // CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
 
 // -----
 
-func @compose_collapse_of_collapse_unit_dims_dynamic(
-    %arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>) -> tensor<?x?x?x?xf32> {
+func @collapse_reshape_unit_dims_dynamic(%arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>)
+    -> tensor<?x?x?x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]]
    : tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x1x1x?x?xf32>
   %1 = tensor.collapse_shape %0 [[0], [1], [2, 3, 4], [5]]
    : tensor<?x?x1x1x?x?xf32> into tensor<?x?x?x?xf32>
   return %1 : tensor<?x?x?x?xf32>
 }
-// CHECK: func @compose_collapse_of_collapse_unit_dims_dynamic
+// CHECK: func @collapse_reshape_unit_dims_dynamic
 // CHECK: tensor.collapse_shape
 // CHECK-SAME: [0], [1, 2], [3, 4, 5], [6, 7, 8]
 // CHECK-SAME: tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x?x?xf32>
 
 // -----
 
-func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
-    -> tensor<2x1xf32> {
+func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
+{
   %0 = tensor.expand_shape %arg0 [[0, 1, 2]]
    : tensor<2xf32> into tensor<2x1x1xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2]]
    : tensor<2x1x1xf32> into tensor<2x1xf32>
   return %1 : tensor<2x1xf32>
 }
-// CHECK: func @fold_collapse_of_expand_trailing_unit_dims
+// CHECK: func @fold_reshape_trailing_unit_dims
 // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
 // CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
 
 // -----
 
-func @fold_collapse_of_collapse_trailing_unit_dims_dynamic(
-    %arg0: tensor<1x1x?x1x1x1xf32>) -> tensor<?xf32> {
+func @fold_reshape_trailing_unit_dims_dynamic(%arg0: tensor<1x1x?x1x1x1xf32>)
+    -> tensor<?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5]]
    : tensor<1x1x?x1x1x1xf32> into tensor<?x1x1x1xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
   : tensor<?x1x1x1xf32> into tensor<?xf32>
   return %1 : tensor<?xf32>
 }
-// CHECK: func @fold_collapse_of_collapse_trailing_unit_dims_dynamic
+// CHECK: func @fold_reshape_trailing_unit_dims_dynamic
 // CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]]
 // CHECK-SAME: tensor<1x1x?x1x1x1xf32> into tensor<?xf32>
 
 // -----
 
-func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
+func @fold_reshape_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
     -> tensor<12x42xf32> {
   %0 = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]]
    : tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32>
@@ -904,28 +884,27 @@
       : tensor<12x42x1x1x1xf32> into tensor<12x42xf32>
   return %1 : tensor<12x42xf32>
 }
-// CHECK: func @fold_collapse_of_expand_trailing_unit_dims
+// CHECK: func @fold_reshape_trailing_unit_dims
 // CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]]
 // CHECK-SAME: tensor<12x42x1x1xf32> into tensor<12x42xf32>
 
 // -----
 
-func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>)
-    -> tensor<?x?xf32> {
+func @fold_reshapes_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32> {
   %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
    : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]]
    : tensor<?x?x1x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-// CHECK-LABEL: func @fold_collapse_of_expand_unit_dims_in_middle
+// CHECK-LABEL: func @fold_reshapes_unit_dims_in_middle
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>
 // CHECK: tensor.collapse_shape %[[ARG]] {{\[}}[0], [1, 2]]
 // CHECK-SAME: tensor<?x?x?xf32> into tensor<?x?xf32>
 
 // -----
 
-func @no_fold_collapse_of_expand_incompatible(%arg0 : tensor<4x6x8xf32>)
+func @no_fold_reshape_incompatible(%arg0 : tensor<4x6x8xf32>)
     -> tensor<2x6x16xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]]
    : tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32>
@@ -933,21 +912,20 @@
    : tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32>
   return %1 : tensor<2x6x16xf32>
 }
-// CHECK-LABEL: func @no_fold_collapse_of_expand_incompatible
+// CHECK-LABEL: func @no_fold_reshape_incompatible
 // CHECK: tensor.expand_shape
 // CHECK: tensor.collapse_shape
 
 // -----
 
-func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
-    -> tensor<12x1xf32> {
+func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> {
   %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
    : tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
    : tensor<3x2x2x1xf32> into tensor<12x1xf32>
   return %1 : tensor<12x1xf32>
 }
-// CHECK: func @no_fold_collapse_of_expand_empty_expr
+// CHECK: func @no_fold_reshape_empty_expr
 // CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x2xf32>
 // CHECK: %[[RARG0:.+]] = tensor.expand_shape %[[ARG0]]
 // CHECK-SAME: [0], [1], [2, 3]
@@ -1024,11 +1002,11 @@
 
 // -----
 
-// CHECK-LABEL: func @pad_same_static_shape(
+// CHECK-LABEL: func @pad_tensor_same_static_shape(
 // CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
 // CHECK-NOT: tensor.pad
 // CHECK: return %[[ARG0]]
-func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
+func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
     -> tensor<5x6xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.pad %arg0 low[%a, 0] high[0, %a] {
@@ -1040,11 +1018,11 @@
 
 // -----
 
-// CHECK-LABEL: func @pad_nofold_same_static_shape(
+// CHECK-LABEL: func @pad_tensor_nofold_same_static_shape(
 // CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
 // CHECK: %[[PAD:.*]] = tensor.pad
 // CHECK: return %[[PAD]]
-func @pad_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
+func @pad_tensor_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
    -> tensor<5x6xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.pad %arg0 nofold low[%a, 0] high[0, %a] {
@@ -1056,7 +1034,7 @@
 
 // -----
 
-// CHECK-LABEL: func @pad_after_cast_different_shape(
+// CHECK-LABEL: func @pad_tensor_after_cast_different_shape(
 // CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
 // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[PADDED:.*]] = tensor.pad %[[INPUT]]
@@ -1068,7 +1046,7 @@
 // CHECK-SAME: tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
 // CHECK: return %[[DYNAMIC]] : tensor<?x?x?x?xf32>
 // CHECK: }
-func @pad_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
+func @pad_tensor_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
    -> tensor<?x?x?x?xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
@@ -1081,7 +1059,7 @@
 
 // -----
 
-// CHECK-LABEL: func @pad_after_cast_same_shape(
+// CHECK-LABEL: func @pad_tensor_after_cast_same_shape(
 // CHECK-SAME: %[[INPUT:.*]]: tensor<?x?x?x?xf32>,
 // CHECK-SAME: %[[PADDING:.*]]: index) -> tensor<?x?x?x?xf32> {
 // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
@@ -1092,7 +1070,7 @@
 // CHECK: } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
 // CHECK: return %[[PADDED:.*]] : tensor<?x?x?x?xf32>
 // CHECK: }
-func @pad_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
+func @pad_tensor_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
    -> tensor<?x?x?x?xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
@@ -1105,11 +1083,11 @@
 
 // -----
 
-// CHECK-LABEL: func @pad_of_cast(
+// CHECK-LABEL: func @pad_tensor_of_cast(
 // CHECK-NOT: tensor.cast
 // CHECK: tensor.pad
 // CHECK: tensor<8x?xf32> to tensor<8x32xf32>
-func @pad_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
+func @pad_tensor_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.cast %t : tensor<8x?xf32> to tensor<?x?xf32>
@@ -1155,7 +1133,7 @@
 
 // -----
 
-func @pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.0 : f32
   %0 = tensor.cast %arg0 : tensor<4x4xf32> to tensor<?x?xf32>
@@ -1165,17 +1143,17 @@
   } : tensor<?x?xf32> to tensor<4x4xf32>
   return %1 : tensor<4x4xf32>
 }
-// CHECK-LABEL: @pad_cast
+// CHECK-LABEL: @tensor_pad_cast
 // CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32>
 // CHECK: return %[[ARG0]]
 
 // -----
 
-// CHECK-LABEL: func @fold_pad_source_cast(
+// CHECK-LABEL: func @fold_pad_tensor_source_cast(
 // CHECK-SAME: %[[ARG0:.*]]: tensor<4x?xf32>
 // CHECK-NOT: tensor.cast
 // CHECK: %[[RESULT:.*]] = tensor.pad %[[ARG0]]
-func @fold_pad_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
+func @fold_pad_tensor_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
   %cst = arith.constant 0.0 : f32
   %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
   %1 = tensor.pad %0 low[0, 0] high[0, 1] {