diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -946,6 +946,112 @@
   return success();
 }
 
+static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
+                                  ArrayRef<int64_t> rhsShape,
+                                  SmallVector<int64_t> &intermediateShape,
+                                  bool isDynamic) {
+  if (isDynamic) {
+    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
+    intermediateShape = {-1};
+    return true;
+  }
+
+  if (lhsShape.empty() || rhsShape.empty()) {
+    intermediateShape = {};
+    return true;
+  }
+
+  unsigned currLhsDim = 0, currRhsDim = 0;
+  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
+    int64_t rhsSize = rhsShape[currRhsDim];
+    int64_t lhsSize = lhsShape[currLhsDim];
+    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
+           currRhsDim < rhsShape.size()) {
+      if (lhsSize < rhsSize) {
+        currLhsDim++;
+        lhsSize *= lhsShape[currLhsDim];
+      } else {
+        currRhsDim++;
+        rhsSize *= rhsShape[currRhsDim];
+      }
+    }
+    if (lhsSize == rhsSize) {
+      intermediateShape.push_back(lhsSize);
+    }
+    currRhsDim++;
+    currLhsDim++;
+  }
+
+  // If the iterators didn't reach the end and their leftover dimensions are
+  // not equal to 1, an intermediate shape was not found.
+  while (currLhsDim < lhsShape.size()) {
+    if (lhsShape[currLhsDim++] != 1) {
+      return false;
+    }
+  }
+
+  while (currRhsDim < rhsShape.size()) {
+    if (rhsShape[currRhsDim++] != 1) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+static bool createReassociationMapsForCollapse(
+    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
+    ArrayRef<int64_t> dstShape,
+    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
+
+  // If the shape is dynamic, create a map for collapsing into one dimension.
+  if (isDynamic) {
+    SmallVector<AffineExpr, 2> exprs;
+    for (int i = 0, s = srcShape.size(); i < s; ++i)
+      exprs.push_back(rewriter.getAffineDimExpr(i));
+    reassociationMap = {exprs};
+    return true;
+  }
+
+  if (dstShape.empty()) {
+    reassociationMap = {};
+    return true;
+  }
+
+  reassociationMap.resize(dstShape.size());
+  unsigned currSrcDim = 0, currDstDim = 0;
+  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
+    int64_t dstSize = dstShape[currDstDim];
+    int64_t srcSize = srcShape[currSrcDim];
+    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
+      reassociationMap[currDstDim].push_back(
+          rewriter.getAffineDimExpr(currSrcDim++));
+      srcSize *= srcShape[currSrcDim];
+    }
+    if (srcSize == dstSize) {
+      reassociationMap[currDstDim].push_back(
+          rewriter.getAffineDimExpr(currSrcDim++));
+      // If the next dim in collapsedShape is not 1, treat subsequent dims in
+      // expandedShape which are 1 to be collapsed.
+      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
+        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
+          reassociationMap[currDstDim].push_back(
+              rewriter.getAffineDimExpr(currSrcDim++));
+        }
+      }
+    }
+    currDstDim++;
+  }
+
+  // If both iterators didn't reach the end, we have leftover dimensions,
+  // which implies that we have a mismatch in shape.
+  if (currSrcDim != srcShape.size() || currDstDim != dstShape.size()) {
+    return false;
+  }
+
+  return true;
+}
+
 namespace {
 
 template <typename SrcOp>
@@ -1534,7 +1640,7 @@
   }
 };
 
-class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
+class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
 public:
   using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
 
@@ -1543,103 +1649,116 @@
                   ConversionPatternRewriter &rewriter) const final {
     ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
     ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+    bool isDynamic = !operandTy.hasStaticShape();
+
+    if (isDynamic && resultTy.getRank() != 1) {
+      return rewriter.notifyMatchFailure(
+          reshape, "Cannot collapse dynamic dims to more than one dimension");
+    }
 
     if (operandTy == resultTy) {
       rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
       return success();
     }
 
-    if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
-      return failure();
+    SmallVector<ReassociationExprs, 4> reassociationMap;
+    if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
+                                            resultTy.getShape(),
+                                            reassociationMap, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape,
+          "tosa.reshape Attempting to collapse into an incompatible shape");
+    }
 
-    // Compute the reassociation maps for the linalg operation.
-    ArrayRef<int64_t> expandedShape =
-        (operandTy.getRank() > resultTy.getRank() ? operandTy.getShape()
-                                                  : resultTy.getShape());
-    ArrayRef<int64_t> collapsedShape =
-        (operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
-                                                  : operandTy.getShape());
-    unsigned currSrcDim = 0, currDstDim = 0;
-    SmallVector<ReassociationExprs, 4> reassociationMap(collapsedShape.size());
-
-    // First scan all dimensions in the source shapes to see whether we have a
-    // perfect case where consecutive dimensions in source are collapsed. For
-    // such case we can just generate one single linalg.reshape.
-    bool isCollapsingSource = true;
-    while (currSrcDim < expandedShape.size() &&
-           currDstDim < collapsedShape.size()) {
-      int64_t dstSize = collapsedShape[currDstDim];
-      int64_t srcSize = expandedShape[currSrcDim];
-      while (srcSize < dstSize && currSrcDim < expandedShape.size()) {
-        reassociationMap[currDstDim].push_back(
-            rewriter.getAffineDimExpr(currSrcDim++));
-        srcSize *= expandedShape[currSrcDim];
-      }
-      if (srcSize == dstSize) {
-        reassociationMap[currDstDim].push_back(
-            rewriter.getAffineDimExpr(currSrcDim++));
-        // If the next dim in collapsedShape is not 1, treat subsequent dims in
-        // expandedShape which are 1 to be collapsed.
-        if (currDstDim == collapsedShape.size() - 1 ||
-            collapsedShape[currDstDim + 1] != 1) {
-          while (currSrcDim < expandedShape.size() &&
-                 expandedShape[currSrcDim] == 1) {
-            reassociationMap[currDstDim].push_back(
-                rewriter.getAffineDimExpr(currSrcDim++));
-          }
-        }
-      } else {
-        isCollapsingSource = false;
-        break;
-      }
-      currDstDim++;
+    SmallVector<int64_t> intermediateShape;
+    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+                               intermediateShape, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape, "tosa.reshape Cannot collapse into given shape");
     }
 
-    // Check if any remaining dimensions exist. If either is rank-0 we only
-    // require the directly lowering.
-    if (currSrcDim != expandedShape.size() ||
-        currDstDim != collapsedShape.size())
-      isCollapsingSource = collapsedShape.empty() || expandedShape.empty();
-
-    // Otherwise, we need to first reduce all source dimensions into one and
-    // then expand to the destination dimensions.
-    if (!isCollapsingSource) {
-      auto getIdentityExprs = [&rewriter](int n) {
-        SmallVector<AffineExpr, 4> exprs;
-        for (int i = 0; i < n; ++i)
-          exprs.push_back(rewriter.getAffineDimExpr(i));
-        return exprs;
-      };
-      Location loc = reshape.getLoc();
-      int64_t totalElems =
-          std::accumulate(expandedShape.begin(), expandedShape.end(), 1,
-                          std::multiplies<int64_t>());
-      auto elemTy = operandTy.getElementType();
-      SmallVector<ReassociationExprs, 4> collapsingMap = {
-          // Use operandTy here because we need to collapse all operands
-          // dimensions.
-          getIdentityExprs(operandTy.getShape().size())};
-      SmallVector<ReassociationExprs, 4> expandingMap = {
-          // Use resultTy here because we need to expand to all result
-          // dimensions.
-          getIdentityExprs(resultTy.getShape().size())};
-
-      auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
-      Value collapsedOp = rewriter.create<linalg::TensorCollapseShapeOp>(
-          loc, collapsedTy, adaptor.getOperands()[0], collapsingMap);
-      rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
-          reshape, resultTy, collapsedOp, expandingMap);
+    rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
+        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
+    return success();
+  }
+};
+
+class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
+public:
+  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
+    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+    bool isDynamic = !operandTy.hasStaticShape();
+    if (operandTy == resultTy) {
+      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
       return success();
     }
 
-    if (resultTy.getRank() <
-        adaptor.getOperands()[0].getType().cast<ShapedType>().getRank())
-      rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
-          reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
-    else
-      rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
-          reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
+    if (isDynamic && operandTy.getRank() != 1) {
+      return rewriter.notifyMatchFailure(
+          reshape, "Cannot expand dynamic dims from more than one dimension");
+    }
+
+    SmallVector<ReassociationExprs, 4> reassociationMap;
+    if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
+                                            operandTy.getShape(),
+                                            reassociationMap, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape,
+          "tosa.reshape Attempting to expand into an incompatible shape");
+    }
+
+    SmallVector<int64_t> intermediateShape;
+    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+                               intermediateShape, isDynamic) ||
+        intermediateShape != operandTy.getShape()) {
+      return rewriter.notifyMatchFailure(
+          reshape, "tosa.reshape Cannot expand into given shape");
+    }
+    rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
+        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
+    return success();
+  }
+};
+
+class ReshapeConverterCollapseExpand
+    : public OpConversionPattern<tosa::ReshapeOp> {
+public:
+  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
+    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+    bool isDynamic = !operandTy.hasStaticShape();
+
+    if (operandTy == resultTy) {
+      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
+      return success();
+    }
+
+    SmallVector<int64_t> intermediateShape;
+    if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
+                               intermediateShape, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape, "tosa.reshape Cannot identify an intermediate shape between "
" + "the given two shapes"); + } + + Value collapse = rewriter.create( + reshape.getLoc(), + RankedTensorType::get(intermediateShape, + reshape.getType().getElementType()), + adaptor.input1()); + Value expand = + rewriter.create(reshape.getLoc(), resultTy, collapse); + rewriter.replaceOp(reshape, expand); return success(); } @@ -3072,7 +3191,9 @@ TransposeConvConverter, GatherConverter, PadConverter, - ReshapeConverter, + ReshapeConverterCollapse, + ReshapeConverterExpand, + ReshapeConverterCollapseExpand, RescaleConverter, ResizeConverter, ReverseConverter, diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -541,6 +541,16 @@ // ----- +// CHECK-LABEL: @test_reshape_downrank_dyn +func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor { + // CHECK: [[RESHAPE:%.+]] = linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1]] + %0 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<2x?xf32>) -> tensor + // CHECK: return [[RESHAPE]] + return %0 : tensor +} + +// ----- + // CHECK-LABEL: @test_reshape_uprank func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> { // CHECK: [[RESHAPE:%.+]] = linalg.tensor_expand_shape %arg0 {{\[}}[0, 1]] @@ -551,6 +561,16 @@ // ----- +// CHECK-LABEL: @test_reshape_uprank_dyn +func @test_reshape_uprank_dyn(%arg0: tensor) -> tensor<2x?xf32> { + // CHECK: [[RESHAPE:%.+]] = linalg.tensor_expand_shape %arg0 {{\[}}[0, 1]] + %0 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor) -> tensor<2x?xf32> + // CHECK: return [[RESHAPE]] + return %0 : tensor<2x?xf32> +} + +// ----- + // CHECK-LABEL: @test_reshape_samerank func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> { // CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>) @@ -563,6 +583,18 @@ // ----- +// CHECK-LABEL: @test_reshape_samerank_dyn +func @test_reshape_samerank_dyn(%arg0: tensor) -> tensor<2x?xf32> { + // CHECK-SAME: (%[[ARG0:.*]]: tensor) + // CHECK-NEXT: %[[RESHAPE1:.*]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1]] + // CHECK-NEXT: %[[RESHAPE2:.*]] = linalg.tensor_expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] + %0 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor) -> tensor<2x?xf32> + // CHECK-NEXT: return %[[RESHAPE2]] + return %0 : tensor<2x?xf32> +} + +// ----- + // CHECK-LABEL: @test_reshape_downrank_6D func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> { // CHECK: linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2], [3], [4, 5]] @@ -572,6 +604,16 @@ // ----- +// CHECK-LABEL: @test_reshape_downrank_6D_dyn +func @test_reshape_downrank_6D_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor { + // CHECK: linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2, 3, 4, 5]] + // CHECK: linalg.tensor_expand_shape %0 {{\[}}[0, 1, 2]] + %0 = "tosa.reshape"(%arg0) {new_shape = [-1, 5, 77]} : (tensor<1x2x?x5x7x11xf32>) -> tensor + return %0 : tensor +} + +// ----- + // CHECK-LABEL: @test_identity func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<1xf32>, tensor<1xi32>) { %0 = "tosa.identity"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>