diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -56,6 +56,12 @@
 using ReassociationIndicesRef = ArrayRef<int64_t>;
 using ReassociationExprs = SmallVector<AffineExpr, 2>;
 
+/// Return the reassociation maps to use to reshape from the source type to the
+/// target type, when possible. Return llvm::None when this computation fails.
+Optional<SmallVector<ReassociationIndices>>
+getReassociationIndicesForReshape(ShapedType sourceType,
+                                  ShapedType targetType);
+
 /// Returns the name mangled library call name to disambiguate between different
 /// overloads at the C level. The name mangling scheme is basic and uses MLIR
 /// type names:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1050,6 +1050,78 @@
 // ReshapeOp
 //===----------------------------------------------------------------------===//
 
+Optional<SmallVector<ReassociationIndices>>
+mlir::linalg::getReassociationIndicesForReshape(ShapedType sourceType,
+                                                ShapedType targetType) {
+  // Make sourceType the greater-rank type. If both types have the same rank,
+  // this is an unsupported reshape op.
+  if (sourceType.getRank() == targetType.getRank())
+    return llvm::None;
+  if (sourceType.getRank() < targetType.getRank())
+    std::swap(sourceType, targetType);
+
+  ArrayRef<int64_t> sourceShape = sourceType.getShape();
+  ArrayRef<int64_t> targetShape = targetType.getShape();
+  unsigned sourceDim = 0;
+  SmallVector<ReassociationIndices> reassociationMap;
+  reassociationMap.reserve(targetType.getRank());
+
+  ReassociationIndices currIndices;
+  int64_t prodOfCollapsedDims = 1;
+  while (sourceDim < sourceShape.size()) {
+    unsigned targetDim = reassociationMap.size();
+
+    // If all the dimensions of the targetShape are exhausted, the remaining
+    // dims in the source shape must all be 1s. For that case, use 1 as the
+    // target shape; the actual reassociation indices are handled later.
+    int64_t currTargetShape =
+        (targetDim < targetType.getRank() ? targetShape[targetDim] : 1);
+    while (sourceDim < sourceShape.size() &&
+           sourceShape[sourceDim] != ShapedType::kDynamicSize &&
+           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
+      prodOfCollapsedDims *= sourceShape[sourceDim];
+      currIndices.push_back(sourceDim++);
+    }
+
+    // If the current expanded dimension is dynamic, the collapsed dimension
+    // must also be dynamic, and the product of all expanded dimensions
+    // gathered so far must be 1.
+    if (sourceShape[sourceDim] == ShapedType::kDynamicSize &&
+        (currTargetShape != ShapedType::kDynamicSize ||
+         prodOfCollapsedDims != 1))
+      return llvm::None;
+
+    // If the collapsed dim is dynamic, the current expanded dim must also be
+    // dynamic.
+    if (currTargetShape == ShapedType::kDynamicSize &&
+        sourceShape[sourceDim] != ShapedType::kDynamicSize)
+      return llvm::None;
+
+    // For static shapes, the product of the expanded dimensions must match
+    // the collapsed dimension.
+    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
+      return llvm::None;
+
+    currIndices.push_back(sourceDim++);
+    // If there are no dimensions left in the target to match, append
+    // `currIndices` to the last element of the reassociationMap.
+    if (targetDim == targetShape.size()) {
+      reassociationMap.back().append(currIndices.begin(), currIndices.end());
+      // Break out of the loops. We should be done here.
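For reviewers tracing the loop above: the helper walks the higher-rank shape left to right, greedily multiplying consecutive static dimensions until they account for the current dimension of the lower-rank shape, rejects mismatched dynamic/static pairings, and folds trailing unit dimensions into the last group once the target is exhausted. Below is a minimal standalone sketch of the same control flow on plain vectors, with `-1` standing in for `ShapedType::kDynamicSize`; `inferReassociation` and the example shapes are illustrative, not part of the patch.

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

constexpr int64_t kDynamic = -1; // stand-in for ShapedType::kDynamicSize

// Greedily group consecutive dims of the higher-rank `source` shape so that
// each group multiplies out to the corresponding dim of the lower-rank
// `target` shape; trailing unit dims join the last group.
std::optional<std::vector<std::vector<int64_t>>>
inferReassociation(const std::vector<int64_t> &source,
                   const std::vector<int64_t> &target) {
  if (source.size() <= target.size())
    return std::nullopt; // only the collapsing direction is sketched here
  std::vector<std::vector<int64_t>> groups;
  std::vector<int64_t> curr;
  int64_t prod = 1;
  size_t s = 0;
  while (s < source.size()) {
    size_t t = groups.size();
    // Once the target is exhausted, the leftover dims must multiply to 1.
    int64_t want = t < target.size() ? target[t] : 1;
    while (s < source.size() && source[s] != kDynamic &&
           prod * source[s] < want) {
      prod *= source[s];
      curr.push_back(s++);
    }
    if (s == source.size())
      return std::nullopt; // ran out of source dims mid-group
    bool srcDyn = source[s] == kDynamic, tgtDyn = want == kDynamic;
    // Dynamic matches only dynamic, and only with no static remainder.
    if (srcDyn != tgtDyn || (srcDyn && prod != 1))
      return std::nullopt;
    if (!srcDyn && prod * source[s] != want)
      return std::nullopt;
    curr.push_back(s++);
    if (t == target.size()) { // append trailing unit dims to the last group
      if (groups.empty())
        return std::nullopt; // rank-0 target is not handled in this sketch
      groups.back().insert(groups.back().end(), curr.begin(), curr.end());
      break;
    }
    groups.push_back(curr);
    curr.clear();
    prod = 1;
  }
  if (groups.size() != target.size() || s != source.size())
    return std::nullopt;
  return groups;
}

int main() {
  // tensor<1x4x1x512xf32> -> tensor<4x512xf32>.
  auto groups = inferReassociation({1, 4, 1, 512}, {4, 512});
  if (!groups)
    return 1;
  for (const auto &g : *groups) {
    for (int64_t d : g)
      std::cout << d << ' ';
    std::cout << '\n';
  }
}
```

Compiled and run, this prints the groups `0 1` and `2 3`, i.e. `tensor<1x4x1x512xf32>` collapses into `tensor<4x512xf32>` by grouping the leading unit dim with the 4 and the second unit dim with the 512.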
+      break;
+    }
+    reassociationMap.emplace_back(ReassociationIndices{});
+    std::swap(reassociationMap.back(), currIndices);
+    prodOfCollapsedDims = 1;
+  }
+  // All the dimensions in the two shapes must have been processed.
+  if (reassociationMap.size() != targetShape.size() ||
+      sourceDim != sourceShape.size())
+    return llvm::None;
+  return reassociationMap;
+}
+
 /// Collapse reassociation maps that are used in pair of reshape ops where one
 /// is a producer and other is the consumer. Only valid to use this method when
 /// both the producer and consumer are collapsing dimensions or both are
@@ -1066,34 +1138,39 @@
 ///
 ///   result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
 ///             affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
-static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
-                                           ArrayRef<AffineMap> mapsConsumer,
-                                           MLIRContext *context) {
+static Optional<SmallVector<ReassociationIndices>>
+collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
+                          ArrayRef<AffineMap> mapsConsumer,
+                          MLIRContext *context) {
+  // Make the producer the larger-sized vector. If both are of the same size,
+  // the resulting reshape is not a supported reshape op.
+  if (mapsProducer.size() == mapsConsumer.size())
+    return llvm::None;
+  if (mapsProducer.size() < mapsConsumer.size())
+    std::swap(mapsProducer, mapsConsumer);
+
   // Handle the corner case of the result being a rank 0 shaped type. Return an
-  // emtpy ArrayAttr.
-  if (mapsConsumer.empty() && !mapsProducer.empty())
-    return ArrayAttr::get(context, ArrayRef<Attribute>());
-  if (mapsProducer.empty() || mapsConsumer.empty() ||
-      mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() ||
-      mapsProducer.size() != mapsConsumer[0].getNumDims())
-    return nullptr;
-  unsigned numLhsDims = mapsProducer[0].getNumDims();
+  // empty reassociation.
+  if (mapsConsumer.empty())
+    return SmallVector<ReassociationIndices>{};
+  if (mapsProducer.size() != mapsConsumer[0].getNumDims())
+    return llvm::None;
+
   unsigned currDim = 0;
-  SmallVector<AffineExpr, 4> reassociations;
-  SmallVector<Attribute, 4> reassociationMaps;
+  ReassociationIndices reassociations;
+  SmallVector<ReassociationIndices> reassociationMaps;
   for (AffineMap rhs : mapsConsumer) {
     for (AffineExpr rhsExpr : rhs.getResults()) {
       AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>();
       for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults();
            i < e; ++i) {
-        reassociations.push_back(getAffineDimExpr(currDim++, context));
+        reassociations.push_back(currDim++);
       }
     }
-    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
-        numLhsDims, /*numSymbols =*/0, reassociations, context)));
-    reassociations.clear();
+    reassociationMaps.emplace_back(ReassociationIndices{});
+    std::swap(reassociationMaps.back(), reassociations);
   }
-  return ArrayAttr::get(context, reassociationMaps);
+  return reassociationMaps;
 }
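Stated on index groups rather than `AffineMap`s, the composition above amounts to replacing each consumer group by the concatenation of the producer groups it refers to, renumbering dimensions in order. A minimal sketch (the names are illustrative, not MLIR API):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

using Group = std::vector<int64_t>;

// Compose two reassociations: `producer` groups dims of the largest shape
// into the intermediate shape, `consumer` groups intermediate dims into the
// final shape. Each consumer group becomes the union of the producer groups
// it refers to, with dimensions renumbered left to right.
std::vector<Group> composeReassociations(const std::vector<Group> &producer,
                                         const std::vector<Group> &consumer) {
  std::vector<Group> result;
  int64_t currDim = 0;
  for (const Group &consumerGroup : consumer) {
    Group composed;
    for (int64_t producerGroupIdx : consumerGroup)
      for (size_t i = 0, e = producer[producerGroupIdx].size(); i < e; ++i)
        composed.push_back(currDim++);
    result.push_back(std::move(composed));
  }
  return result;
}

int main() {
  // Producer: rank 5 -> rank 3 as {0,1}, {2}, {3,4}.
  // Consumer: rank 3 -> rank 2 as {0,1}, {2}.
  // Composition: rank 5 -> rank 2 as {0,1,2}, {3,4}.
  auto result = composeReassociations({{0, 1}, {2}, {3, 4}}, {{0, 1}, {2}});
  for (const Group &g : result) {
    for (int64_t d : g)
      std::cout << d << ' ';
    std::cout << '\n';
  }
}
```

With this producer/consumer pair the program prints `0 1 2` and `3 4`, matching the `result` maps shown in the doc comment.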
 
 namespace {
@@ -1108,33 +1185,43 @@
     if (!srcReshapeOp)
       return failure();
 
+    ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType();
+    ShapedType intermediateType = reshapeOp.getSrcType();
+    ShapedType resultType = reshapeOp.getResultType();
+
     auto areReshapeOpsFoldable = [](ShapedType largerType,
                                     ShapedType intermediateType,
                                     ShapedType smallerType) -> bool {
       return largerType.getRank() > intermediateType.getRank() &&
              intermediateType.getRank() > smallerType.getRank();
     };
-    // Check if producer and consumer are both expanding dims.
-    if (areReshapeOpsFoldable(reshapeOp.getResultType(), reshapeOp.getSrcType(),
-                              srcReshapeOp.getSrcType())) {
-      rewriter.replaceOpWithNewOp<TensorReshapeOp>(
-          reshapeOp, reshapeOp.getResultType(), srcReshapeOp.src(),
-          collapseReassociationMaps(reshapeOp.getReassociationMaps(),
-                                    srcReshapeOp.getReassociationMaps(),
-                                    rewriter.getContext()));
-      return success();
+    Optional<SmallVector<ReassociationIndices>> reassociationMaps = llvm::None;
+    // Check if the producer and consumer are both expanding dims or both
+    // collapsing dims. In this case, try to compose the affine maps. This
+    // works for dynamic shapes too.
+    if (areReshapeOpsFoldable(resultType, intermediateType,
+                              srcReshapeSrcType) ||
+        areReshapeOpsFoldable(srcReshapeSrcType, intermediateType,
+                              resultType)) {
+      reassociationMaps = collapseReassociationMaps(
+          srcReshapeOp.getReassociationMaps(), reshapeOp.getReassociationMaps(),
+          rewriter.getContext());
     }
-    // Check if producer and consumer are both collapsing dims.
-    if (areReshapeOpsFoldable(srcReshapeOp.getSrcType(), reshapeOp.getSrcType(),
-                              reshapeOp.getResultType())) {
-      rewriter.replaceOpWithNewOp<TensorReshapeOp>(
-          reshapeOp, reshapeOp.getResultType(), srcReshapeOp.src(),
-          collapseReassociationMaps(srcReshapeOp.getReassociationMaps(),
-                                    reshapeOp.getReassociationMaps(),
-                                    rewriter.getContext()));
-      return success();
+    if (!reassociationMaps) {
+      // If the source reshape can be collapsed/expanded into the target
+      // reshape, they can still be folded. This can only be reasoned about
+      // statically for cases where
+      // - either all shapes are static, or
+      // - the number of dynamic dimensions matches in the source of the
+      //   source and in the result, with all other dimensions being 1.
+      reassociationMaps =
+          getReassociationIndicesForReshape(srcReshapeSrcType, resultType);
     }
-    return failure();
+    if (!reassociationMaps)
+      return failure();
+    rewriter.replaceOpWithNewOp<TensorReshapeOp>(
+        reshapeOp, resultType, srcReshapeOp.src(), *reassociationMaps);
+    return success();
   }
 };
 } // namespace
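The rewritten pattern thus decides between two folding strategies: when the two reshapes move the rank monotonically (both expand or both collapse), the reassociation maps compose directly, which also covers dynamic shapes; otherwise it re-derives a reassociation from the outermost source and result shapes, which is what handles expand-then-collapse chains through unit dimensions. Reduced to ranks, the choice looks like this sketch (`chooseFoldPath` is an illustrative reduction, not code from this patch):

```cpp
#include <iostream>

// Reduced decision logic of the rewrite: compose the two reassociation maps
// when the reshapes move the rank monotonically; otherwise re-derive the
// reassociation from the outermost source/result shapes.
enum class FoldPath { ComposeMaps, InferFromShapes };

FoldPath chooseFoldPath(int srcSrcRank, int intermediateRank, int resultRank) {
  bool bothExpanding =
      resultRank > intermediateRank && intermediateRank > srcSrcRank;
  bool bothCollapsing =
      srcSrcRank > intermediateRank && intermediateRank > resultRank;
  return (bothExpanding || bothCollapsing) ? FoldPath::ComposeMaps
                                           : FoldPath::InferFromShapes;
}

int main() {
  // tensor<4x512xf32> -> tensor<1x4x1x512xf32> -> tensor<2048xf32>:
  // ranks 2 -> 4 -> 1 are not monotonic, so the shapes must decide.
  std::cout << (chooseFoldPath(2, 4, 1) == FoldPath::InferFromShapes
                    ? "infer from shapes\n"
                    : "compose maps\n");
}
```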
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -427,123 +427,6 @@
     return success();
   }
 };
-
-/// Pattern to fold pair of reshape ops where the intermediate has unit-dims
-/// for example:
-///
-///  %0 = linalg.tensor_reshape %arg0
-///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
-///    : tensor<2048xf32> into tensor<1x4x1x512xf32>
-///  %1 = linalg.tensor_reshape %0
-///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
-///     affine_map<(d0, d1, d2, d3) -> (d3)>]
-///    : tensor<1x4x1x512xf32> into tensor<4x512xf32>
-///
-/// can be replaced with
-///
-///  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
-///    : tensor<2048xf32> into tensor<4x512xf32>
-///
-/// Similarly,
-///
-///  %0 = linalg.tensor_reshape %arg0
-///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
-///     affine_map<(d0, d1, d2, d3) -> (d3)>]
-///    : tensor<4x512xf32> into tensor<1x4x1x512xf32>
-///  %1 = linalg.tensor_reshape %0
-///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
-///    : tensor<1x4x1x512xf32> into tensor<2048xf32>
-///
-/// can be replaced with
-///
-///  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
-///    : tensor<4x512xf32> into tensor<2048xf32>
-struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
-  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
-                                PatternRewriter &rewriter) const override {
-    // Check that the source operand is created from a reshape as well.
-    TensorReshapeOp parentReshapeOp =
-        reshapeOp.src().getDefiningOp<TensorReshapeOp>();
-    if (!parentReshapeOp)
-      return failure();
-
-    RankedTensorType srcType = reshapeOp.getSrcType(),
-                     dstType = reshapeOp.getResultType(),
-                     parentSrcType = parentReshapeOp.getSrcType();
-    if (!srcType.hasStaticShape() || !dstType.hasStaticShape() ||
-        !parentSrcType.hasStaticShape() ||
-        srcType.getRank() < dstType.getRank() ||
-        parentSrcType.getRank() == dstType.getRank())
-      return failure();
-
-    // Check if the result tensor_reshape is folding or expanding after folding
-    // the reshapeOp and parentReshapeOp are combined. If the final
-    // tensor_reshape is folding, the parentReshapeOp is introducing unit-dims,
-    // and the reshapeOp does an actual reshape. If the final tensor_reshape op
-    // is expanding, the reshapeOp is introducing unit-dims, and the
-    // parentReshapeOp does an actual reshape.
-    bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
-    ArrayRef<int64_t> expandedShape =
-        isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
-    ArrayRef<int64_t> foldedShape =
-        isFoldingPattern ? dstType.getShape() : parentSrcType.getShape();
-
-    unsigned expandedDim = 0, foldedDim = 0;
-    SmallVector<SmallVector<AffineExpr, 4>, 4> reassociationExprs(
-        foldedShape.size());
-    while (expandedDim < expandedShape.size() &&
-           foldedDim < foldedShape.size()) {
-      int64_t dstSize = foldedShape[foldedDim];
-      int64_t srcSize = expandedShape[expandedDim];
-      while (srcSize < dstSize && expandedDim < expandedShape.size()) {
-        reassociationExprs[foldedDim].push_back(
-            rewriter.getAffineDimExpr(expandedDim++));
-        srcSize *= expandedShape[expandedDim];
-      }
-      if (srcSize == dstSize) {
-        reassociationExprs[foldedDim].push_back(
-            rewriter.getAffineDimExpr(expandedDim++));
-        // If the next dim in foldedShape is not 1, treat subsequent dims in
-        // expandedShape which are 1 to be collapsed.
-        if (foldedDim == foldedShape.size() - 1 ||
-            foldedShape[foldedDim + 1] != 1) {
-          while (expandedDim < expandedShape.size() &&
-                 expandedShape[expandedDim] == 1) {
-            reassociationExprs[foldedDim].push_back(
-                rewriter.getAffineDimExpr(expandedDim++));
-          }
-        }
-      } else {
-        return failure();
-      }
-
-      foldedDim++;
-      // If inner most dims are folded there shouldn't be any leading 1 dims.
-      // otherwise these dims are not mapped and will lead into an illegal
-      // reshape.
-      if (expandedDim == expandedShape.size()) {
-        if (foldedDim < foldedShape.size() && foldedShape[foldedDim] == 1) {
-          return failure();
-        }
-      }
-    }
-    if (expandedDim != expandedShape.size())
-      return failure();
-
-    SmallVector<AffineMap, 4> reassociationMaps =
-        llvm::to_vector<4>(llvm::map_range(
-            reassociationExprs, [&](ArrayRef<AffineExpr> exprs) -> AffineMap {
-              return AffineMap::get(expandedShape.size(), 0, exprs,
-                                    rewriter.getContext());
-            }));
-    rewriter.replaceOpWithNewOp<TensorReshapeOp>(
-        reshapeOp, dstType, parentReshapeOp.src(),
-        rewriter.getAffineMapArrayAttr(reassociationMaps));
-    return success();
-  }
-};
 } // namespace
 
 /// Get the reassociation maps to convert a `type` to its rank-reduced version.
@@ -627,7 +510,6 @@
                UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(
           context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
-  patterns.add<FoldReshapeOpWithUnitExtent>(context);
 }
 
 namespace {
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -223,6 +223,272 @@
 
 // -----
 
+func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>) -> tensor<24x5x42x8xf32>
+{
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>]
+    : tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
+    : tensor<40320xf32> into tensor<24x5x42x8xf32>
+  return %1 : tensor<24x5x42x8xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d6)>
+// CHECK: func @reshape_collapse
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8xf32>
+// CHECK: %[[RESULT:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>) -> tensor<2x3x4x5x6x7x8xf32>
+{
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
+    : tensor<24x5x42x8xf32> into tensor<40320xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>]
+    : tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32>
+  return %1 : tensor<2x3x4x5x6x7x8xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d6)>
+// CHECK: func @reshape_expand
+// CHECK-SAME: %[[ARG0:.+]]: tensor<24x5x42x8xf32>
+// CHECK: %[[RESULT:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @expand_reshape_1D(%arg0 : tensor<2048xf32>) -> tensor<4x512xf32>
+{
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
+    : tensor<2048xf32> into tensor<1x4x1x512xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
+     affine_map<(d0, d1, d2, d3) -> (d3)>]
+    : tensor<1x4x1x512xf32> into tensor<4x512xf32>
+  return %1 : tensor<4x512xf32>
+}
+// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: func @expand_reshape_1D
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]]
+// CHECK-SAME: tensor<2048xf32> into tensor<4x512xf32>
+
+// -----
+
+func @fold_reshape_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32>
+{
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
+     affine_map<(d0, d1, d2, d3) -> (d3)>]
+    : tensor<4x512xf32> into tensor<1x4x1x512xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
+    : tensor<1x4x1x512xf32> into tensor<2048xf32>
+  return %1 : tensor<2048xf32>
+}
+// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: func @fold_reshape_1D
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]]
+// CHECK-SAME: tensor<4x512xf32> into tensor<2048xf32>
+
+// -----
+
+func @fold_reshape_unit_dims(%arg0 : tensor<2048x1x1xf32>) -> tensor<4x512x1x1xf32>
+{
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d4)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>]
+    : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d4)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>]
+    : tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32>
+  return %1 : tensor<4x512x1x1xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: func @fold_reshape_unit_dims
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: tensor<2048x1x1xf32> into tensor<4x512x1x1xf32>
+
+// -----
+
+func @expand_reshape_unit_dims(%arg0 : tensor<2048x1x2048xf32>) -> tensor<4x512x1x512x4xf32>
+{
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>,
+     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5)>,
+     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d6, d7, d8)>]
+    : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d3, d4)>,
+     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5)>,
+     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d6, d7)>,
+     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d8)>]
+    : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
+  return %1 : tensor<4x512x1x512x4xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
+// CHECK: func @expand_reshape_unit_dims
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>
+
+// -----
+
+func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
+{
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] : tensor<2xf32> into tensor<2x1x1xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2) -> (d0)>,
+     affine_map<(d0, d1, d2) -> (d1, d2)>]
+    : tensor<2x1x1xf32> into tensor<2x1xf32>
+  return %1 : tensor<2x1xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: func @fold_reshape_trailing_unit_dims
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]
+// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
+
+// -----
+
+func @collapse_reshape_unit_dims_dynamic(%arg0 : tensor<?x1x1x?x1x?x1x1x1xf32>) -> tensor<?x1x?x1xf32>
+{
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0)>,
+     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d3)>,
+     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d4)>,
+     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5)>,
+     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d6, d7, d8)>]
+    : tensor<?x1x1x?x1x?x1x1x1xf32> into tensor<?x1x?x1x?x1xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d1)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>]
+    : tensor<?x1x?x1x?x1xf32> into tensor<?x1x?x1xf32>
+  return %1 : tensor<?x1x?x1xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d3, d4, d5)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d6, d7, d8)>
+// CHECK: func @collapse_reshape_unit_dims_dynamic
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: tensor<?x1x1x?x1x?x1x1x1xf32> into tensor<?x1x?x1xf32>
+
+// -----
+
+func @fold_reshape_trailing_unit_dims_dynamic(%arg0: tensor<1x1x?x1x1x1xf32>) -> tensor<?xf32>
+{
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d4)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>]
+    : tensor<1x1x?x1x1x1xf32> into tensor<?x1x1x1xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
+    : tensor<?x1x1x1xf32> into tensor<?xf32>
+  return %1 : tensor<?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: func @fold_reshape_trailing_unit_dims_dynamic
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]
+// CHECK-SAME: tensor<1x1x?x1x1x1xf32> into tensor<?xf32>
+
+// -----
+
+func @no_fold_reshapes(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32>
+{
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3) -> (d0)>,
+     affine_map<(d0, d1, d2, d3) -> (d1)>,
+     affine_map<(d0, d1, d2, d3) -> (d2, d3)>]
+    : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3) -> (d0)>,
+     affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>]
+    : tensor<?x?x1x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @no_fold_reshapes
+// CHECK: linalg.tensor_reshape
+// CHECK: linalg.tensor_reshape
+
+// -----
+
+func @no_fold_reshape_incompatible(%arg0 : tensor<4x6x8xf32>) -> tensor<2x6x16xf32>
+{
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d4)>]
+    : tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3, d4) -> (d0)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
+    : tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32>
+  return %1 : tensor<2x6x16xf32>
+}
+// CHECK-LABEL: func @no_fold_reshape_incompatible
+// CHECK: linalg.tensor_reshape
+// CHECK: linalg.tensor_reshape
+
+// -----
+
+func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> {
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>] : tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d3)>] : tensor<3x2x2x1xf32> into tensor<12x1xf32>
+  return %1 : tensor<12x1xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: func @no_fold_reshape_empty_expr
+// CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x2xf32>
+// CHECK: %[[RARG0:.+]] = linalg.tensor_reshape %[[ARG0:.+]] [#[[MAP0]], #[[MAP1]], #[[MAP2]]
+// CHECK: %[[RES:.+]] = linalg.tensor_reshape %[[RARG0:.+]] [#[[MAP3]], #[[MAP4]]]
+// CHECK: return %[[RES:.+]] : tensor<12x1xf32>
+
+// -----
+
 #accesses = [
   affine_map<(i) -> (i)>
 ]
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -317,104 +317,6 @@
 
 // -----
 
-// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: func @fold_reshape
-// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]]
-// CHECK-SAME: tensor<2048xf32> into tensor<4x512xf32>
-func @fold_reshape(%arg0 : tensor<2048xf32>) -> tensor<4x512xf32>
-{
-  %0 = linalg.tensor_reshape %arg0
-    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
-    : tensor<2048xf32> into tensor<1x4x1x512xf32>
-  %1 = linalg.tensor_reshape %0
-    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
-     affine_map<(d0, d1, d2, d3) -> (d3)>]
-    : tensor<1x4x1x512xf32> into tensor<4x512xf32>
-  return %1 : tensor<4x512xf32>
-}
-
-// -----
-
-// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: func @fold_reshape
-// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]]
-// CHECK-SAME: tensor<4x512xf32> into tensor<2048xf32>
-func @fold_reshape(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32>
-{
-  %0 = linalg.tensor_reshape %arg0
-    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
-     affine_map<(d0, d1, d2, d3) -> (d3)>]
-    : tensor<4x512xf32> into tensor<1x4x1x512xf32>
-  %1 = linalg.tensor_reshape %0
-    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
-    : tensor<1x4x1x512xf32> into tensor<2048xf32>
-  return %1 : tensor<2048xf32>
-}
-
-// -----
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
-// CHECK: func @fold_reshape
-// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME: tensor<2048x1xf32> into tensor<4x512x1xf32>
-func @fold_reshape(%arg0 : tensor<2048x1xf32>) -> tensor<4x512x1xf32>
-{
-  %0 = linalg.tensor_reshape %arg0
-    [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
-     affine_map<(d0, d1, d2, d3, d4) -> (d4)>]
-    : tensor<2048x1xf32> into tensor<1x4x1x512x1xf32>
-  %1 = linalg.tensor_reshape %0
-    [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
-     affine_map<(d0, d1, d2, d3, d4) -> (d3)>,
-     affine_map<(d0, d1, d2, d3, d4) -> (d4)>]
-    : tensor<1x4x1x512x1xf32> into tensor<4x512x1xf32>
-  return %1 : tensor<4x512x1xf32>
-}
-
-// -----
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
-// CHECK: func @fold_reshape
-// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-// CHECK-SAME: tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>
-func @fold_reshape(%arg0 : tensor<2048x1x2048xf32>) -> tensor<4x512x1x512x4xf32>
-{
-  %0 = linalg.tensor_reshape %arg0
-    [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>,
-     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5)>,
-     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d6, d7, d8)>]
-    : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
-  %1 = linalg.tensor_reshape %0
-    [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2)>,
-     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d3, d4)>,
-     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5)>,
-     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d6, d7)>,
-     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d8)>]
-    : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
-  return %1 : tensor<4x512x1x512x4xf32>
-}
-
-// -----
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: func @fold_reshape
-// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]
-// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
-func @fold_reshape(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
-{
-  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] : tensor<2xf32> into tensor<2x1x1xf32>
-  %1 = linalg.tensor_reshape %0
-    [affine_map<(d0, d1, d2) -> (d0)>,
-     affine_map<(d0, d1, d2) -> (d1, d2)>]
-    : tensor<2x1x1xf32> into tensor<2x1xf32>
-  return %1 : tensor<2x1xf32>
-}
-
-// -----
-
 #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -612,21 +514,3 @@
 // CHECK-SAME: outs(%[[FILL]] : tensor)
 // CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] [#[[MAP2]]]
 // CHECK: return %[[RESULT_RESHAPE]]
-
-// -----
-
-func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> {
-  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>] : tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
-  %1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d3)>] : tensor<3x2x2x1xf32> into tensor<12x1xf32>
-  return %1 : tensor<12x1xf32>
-}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
-// CHECK: func @no_fold_reshape_empty_expr
-// CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x2xf32>
-// CHECK: %[[RARG0:.+]] = linalg.tensor_reshape %[[ARG0:.+]] [#[[MAP0]], #[[MAP1]], #[[MAP2]]
-// CHECK: %[[RES:.+]] = linalg.tensor_reshape %[[RARG0:.+]] [#[[MAP3]], #[[MAP4]]]
-// CHECK: return %[[RES:.+]] : tensor<12x1xf32>