diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -897,8 +897,8 @@
 static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
                                   ArrayRef<int64_t> rhsShape,
                                   SmallVector<int64_t> &intermediateShape,
-                                  bool isDynamic) {
-  if (isDynamic) {
+                                  int64_t dynamicDims) {
+  if (dynamicDims == 1) {
     // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
     intermediateShape = {-1};
     return true;
@@ -915,6 +915,14 @@
     int64_t lhsSize = lhsShape[currLhsDim];
     while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
            currRhsDim < rhsShape.size()) {
+      if ((lhsSize == -1) != (rhsSize == -1))
+        return false;
+      if (lhsSize == -1 && rhsSize == -1) {
+        intermediateShape.push_back(lhsSize);
+        currLhsDim++;
+        currRhsDim++;
+        continue;
+      }
       if (lhsSize < rhsSize) {
         currLhsDim++;
         lhsSize *= lhsShape[currLhsDim];
@@ -950,10 +958,16 @@
 static bool createReassociationMapsForCollapse(
     PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
     ArrayRef<int64_t> dstShape,
-    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
+    SmallVector<ReassociationExprs, 4> &reassociationMap, int64_t dynamicDims) {
+
+  if (dstShape.empty()) {
+    reassociationMap = {};
+    return true;
+  }
 
-  // If the shape is dynamic, create a map for collapsing into one dimension.
-  if (isDynamic) {
+  // If the shape is dynamic with one dynamic dim, create a map for collapsing
+  // into one dimension.
+  if (dynamicDims == 1) {
     SmallVector<AffineExpr, 2> exprs;
     for (int i = 0, s = srcShape.size(); i < s; ++i)
       exprs.push_back(rewriter.getAffineDimExpr(i));
@@ -961,26 +975,35 @@
     return true;
   }
 
-  if (dstShape.empty()) {
-    reassociationMap = {};
-    return true;
-  }
-
+  // If there are multiple dynamic dims, only reshape if all the other dims
+  // match up.
   reassociationMap.resize(dstShape.size());
+
   unsigned currSrcDim = 0, currDstDim = 0;
   while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
     int64_t dstSize = dstShape[currDstDim];
     int64_t srcSize = srcShape[currSrcDim];
+    if ((srcSize == -1) != (dstSize == -1))
+      return false;
+
+    if (srcSize == -1) {
+      reassociationMap[currDstDim++].push_back(
+          rewriter.getAffineDimExpr(currSrcDim++));
+      if (currSrcDim)
+        continue;
+    }
+
     while (srcSize < dstSize && currSrcDim < srcShape.size()) {
       reassociationMap[currDstDim].push_back(
          rewriter.getAffineDimExpr(currSrcDim++));
       srcSize *= srcShape[currSrcDim];
     }
-    if (srcSize == dstSize) {
+
+    if (srcSize == dstSize && srcSize != -1) {
       reassociationMap[currDstDim].push_back(
           rewriter.getAffineDimExpr(currSrcDim++));
-      // If the next dim in collapsedShape is not 1, treat subsequent dims in
-      // expandedShape which are 1 to be collapsed.
+      // If the next dim in collapsedShape is not 1, treat subsequent dims
+      // in expandedShape which are 1 to be collapsed.
       if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
         while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
           reassociationMap[currDstDim].push_back(
@@ -1018,22 +1041,25 @@
                   ConversionPatternRewriter &rewriter) const final {
     ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
     ShapedType resultTy = reshape.getType().template cast<ShapedType>();
-    bool isDynamic = !operandTy.hasStaticShape();
-
-    if (isDynamic && resultTy.getRank() != 1) {
-      return rewriter.notifyMatchFailure(
-          reshape, "Cannot collapse dynamic dims to more than one dimension");
-    }
+    int64_t dynamicDims = 0;
 
     if (operandTy == resultTy) {
       rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
       return success();
     }
 
+    if (operandTy.getNumDynamicDims() == resultTy.getNumDynamicDims()) {
+      dynamicDims = operandTy.getNumDynamicDims();
+    } else {
+      return rewriter.notifyMatchFailure(
+          reshape, "tosa.reshape must have the same number of dynamic "
+                   "dimensions for input and output");
+    }
+
     SmallVector<ReassociationExprs, 4> reassociationMap;
     if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
                                             resultTy.getShape(),
-                                            reassociationMap, isDynamic)) {
+                                            reassociationMap, dynamicDims)) {
       return rewriter.notifyMatchFailure(
           reshape,
           "tosa.reshape Attempting to collapse into an incompatible shape");
@@ -1041,7 +1067,7 @@
 
     SmallVector<int64_t> intermediateShape;
     if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                               intermediateShape, isDynamic)) {
+                               intermediateShape, dynamicDims)) {
       return rewriter.notifyMatchFailure(
           reshape, "tosa.reshape Cannot collapse into given shape");
     }
@@ -1061,22 +1087,25 @@
                   ConversionPatternRewriter &rewriter) const final {
     ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
     ShapedType resultTy = reshape.getType().template cast<ShapedType>();
-    bool isDynamic = !operandTy.hasStaticShape();
+    int64_t dynamicDims = 0;
 
     if (operandTy == resultTy) {
       rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
       return success();
     }
 
-    if (isDynamic && operandTy.getRank() != 1) {
+    if (operandTy.getNumDynamicDims() == resultTy.getNumDynamicDims()) {
+      dynamicDims = operandTy.getNumDynamicDims();
+    } else {
       return rewriter.notifyMatchFailure(
-          reshape, "Cannot expand dynamic dims from more than one dimension");
+          reshape, "tosa.reshape must have the same number of dynamic "
+                   "dimensions for input and output");
     }
 
     SmallVector<ReassociationExprs, 4> reassociationMap;
     if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
                                             operandTy.getShape(),
-                                            reassociationMap, isDynamic)) {
+                                            reassociationMap, dynamicDims)) {
       return rewriter.notifyMatchFailure(
           reshape,
           "tosa.reshape Attempting to expand into an incompatible shape");
@@ -1084,7 +1113,7 @@
 
     SmallVector<int64_t> intermediateShape;
     if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                               intermediateShape, isDynamic) ||
+                               intermediateShape, dynamicDims) ||
         intermediateShape != operandTy.getShape()) {
       return rewriter.notifyMatchFailure(
           reshape, "tosa.reshape Cannot expand into given shape");
@@ -1105,19 +1134,27 @@
                   ConversionPatternRewriter &rewriter) const final {
     ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
     ShapedType resultTy = reshape.getType().template cast<ShapedType>();
-    bool isDynamic = !operandTy.hasStaticShape();
+    int64_t dynamicDims = 0;
 
     if (operandTy == resultTy) {
      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
       return success();
     }
 
+    if (operandTy.getNumDynamicDims() == resultTy.getNumDynamicDims()) {
+      dynamicDims = operandTy.getNumDynamicDims();
+    } else {
+      return rewriter.notifyMatchFailure(
+          reshape, "tosa.reshape must have the same number of dynamic "
+                   "dimensions for input and output");
+    }
+
     SmallVector<int64_t> intermediateShape;
     if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
-                               intermediateShape, isDynamic)) {
+                               intermediateShape, dynamicDims)) {
       return rewriter.notifyMatchFailure(
-          reshape, "tosa.reshape Cannot identify an intermediate shape between "
-                   "the given two shapes");
+          reshape, "tosa.reshape Cannot identify an intermediate shape "
+                   "between the given two shapes");
     }
 
     Value collapse = rewriter.create<tosa::ReshapeOp>(
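Note, not part of the patch: the rule the hunks above implement is that dynamic dims (-1) must pair off one-to-one across the two shapes, while runs of static dims are multiplied until both sides reach the same product. Below is a minimal standalone sketch of that rule in plain C++17; the function name, the std::vector types, and the simplified control flow are mine rather than the patch's LLVM APIs, so treat it as an illustration, not the implementation.

// Standalone sketch of the pairing rule findIntermediateShape implements
// once multiple dynamic dims are allowed. Plain C++17, no MLIR/LLVM
// dependencies; -1 encodes a dynamic dim, as in the pass.
#include <cstdint>
#include <iostream>
#include <vector>

static bool findIntermediateShapeSketch(const std::vector<int64_t> &lhs,
                                        const std::vector<int64_t> &rhs,
                                        std::vector<int64_t> &intermediate) {
  size_t i = 0, j = 0;
  while (i < lhs.size() && j < rhs.size()) {
    int64_t l = lhs[i], r = rhs[j];
    // A dynamic dim may only be matched against another dynamic dim.
    if ((l == -1) != (r == -1))
      return false;
    if (l == -1) { // Both dynamic: the dim is carried through unchanged.
      intermediate.push_back(-1);
      ++i, ++j;
      continue;
    }
    // Static dims: multiply on the smaller side until the products agree.
    // Hitting a dynamic dim mid-product means the shapes are incompatible.
    while (l != r) {
      if (l < r) {
        if (++i == lhs.size() || lhs[i] == -1)
          return false;
        l *= lhs[i];
      } else {
        if (++j == rhs.size() || rhs[j] == -1)
          return false;
        r *= rhs[j];
      }
    }
    intermediate.push_back(l);
    ++i, ++j;
  }
  // Leftover dims on either side must all be 1 (they collapse away).
  for (; i < lhs.size(); ++i)
    if (lhs[i] != 1)
      return false;
  for (; j < rhs.size(); ++j)
    if (rhs[j] != 1)
      return false;
  return true;
}

int main() {
  // Shapes of the first new test: tensor<1x?x?x256x1xf32> -> tensor<1x?x?x256xf32>.
  std::vector<int64_t> inter;
  if (findIntermediateShapeSketch({1, -1, -1, 256, 1}, {1, -1, -1, 256}, inter)) {
    std::cout << "intermediate:";
    for (int64_t d : inter)
      std::cout << ' ' << d;
    std::cout << '\n'; // prints: intermediate: 1 -1 -1 256
  }
}

Built with, e.g., c++ -std=c++17 sketch.cpp, it accepts the shapes of the first new test below and prints the intermediate shape 1 -1 -1 256.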
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -614,6 +614,35 @@
 
 // -----
 
+// CHECK-LABEL: @test_reshape_collapse_multiple_dyn_dims
+func.func @test_reshape_collapse_multiple_dyn_dims(%arg0: tensor<1x?x?x256x1xf32>) -> (tensor<1x?x?x256xf32>) {
+  // CHECK: tensor.collapse_shape %arg0 {{\[}}[0], [1], [2], [3, 4]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [1, -1, -1, 256]} : (tensor<1x?x?x256x1xf32>) -> tensor<1x?x?x256xf32>
+  return %0 : tensor<1x?x?x256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_reshape_expand_multiple_dyn_dims
+func.func @test_reshape_expand_multiple_dyn_dims(%arg0: tensor<?x?x?x6x8xf32>) -> (tensor<?x?x?x2x3x2x4xf32>) {
+  // CHECK: tensor.expand_shape %arg0 {{\[}}[0], [1], [2], [3, 4], [5, 6]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [-1, -1, -1, 2, 3, 2, 4]} : (tensor<?x?x?x6x8xf32>) -> tensor<?x?x?x2x3x2x4xf32>
+  return %0 : tensor<?x?x?x2x3x2x4xf32>
+}
+
+
+// -----
+
+// CHECK-LABEL: @test_reshape_collapse_expand_multiple_dyn_dims
+func.func @test_reshape_collapse_expand_multiple_dyn_dims(%arg0: tensor<2x3x8x?x?xf32>) -> (tensor<6x2x4x?x?xf32>) {
+  // CHECK: tensor.collapse_shape %arg0 {{\[}}[0, 1], [2], [3], [4]]
+  // CHECK: tensor.expand_shape %0 {{\[}}[0], [1, 2], [3], [4]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [6, 2, 4, -1, -1]} : (tensor<2x3x8x?x?xf32>) -> tensor<6x2x4x?x?xf32>
+  return %0 : tensor<6x2x4x?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_identity
 func.func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<1xf32>, tensor<1xi32>) {
   %0 = "tosa.identity"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
@@ -1241,7 +1270,7 @@
   // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
   // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
   // CHECK: tensor.pad %arg0 low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
-  // CHECK:   ^bb0(%arg1: index, %arg2: index):  
+  // CHECK:   ^bb0(%arg1: index, %arg2: index):
   // CHECK:   tensor.yield [[CST]]
   // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
   %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>)
@@ -1277,7 +1306,7 @@
   // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
   // CHECK-DAG: [[CST:%.+]] = arith.constant 4.200000e+01 : f32
   // CHECK: tensor.pad %arg0 low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
-  // CHECK:   ^bb0(%arg1: index, %arg2: index):  
+  // CHECK:   ^bb0(%arg1: index, %arg2: index):
   // CHECK:   tensor.yield [[CST]]
   // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
   %1 = arith.constant dense<42.0> : tensor<f32>
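A companion sketch under the same assumptions as above (plain C++17, names mine, not the patch's APIs) of the grouping side, mirroring what createReassociationMapsForCollapse now builds when several dims are dynamic: each dynamic source dim maps onto exactly one dynamic destination dim, static runs are gathered until their products match, and trailing unit dims fold into the last group. On the shapes of @test_reshape_collapse_multiple_dyn_dims it reproduces the reassociation [[0], [1], [2], [3, 4]] that the CHECK line expects.

// Companion sketch of the grouping createReassociationMapsForCollapse
// builds: for each destination dim, the list of source dims that collapse
// into it. Plain C++17; -1 encodes a dynamic dim.
#include <cstdint>
#include <iostream>
#include <vector>

static bool buildReassociationSketch(const std::vector<int64_t> &src,
                                     const std::vector<int64_t> &dst,
                                     std::vector<std::vector<size_t>> &groups) {
  groups.assign(dst.size(), {});
  size_t s = 0, d = 0;
  while (s < src.size() && d < dst.size()) {
    // Dynamic dims must line up one-to-one across the two shapes.
    if ((src[s] == -1) != (dst[d] == -1))
      return false;
    if (src[s] == -1) {
      groups[d++].push_back(s++);
      continue;
    }
    // Gather static source dims until their product matches the dest dim.
    int64_t prod = src[s];
    groups[d].push_back(s++);
    while (prod < dst[d]) {
      if (s == src.size() || src[s] == -1)
        return false;
      prod *= src[s];
      groups[d].push_back(s++);
    }
    if (prod != dst[d])
      return false;
    // If the next dest dim is not 1, trailing unit source dims fold into
    // the current group, as the pass does.
    if (d + 1 == dst.size() || dst[d + 1] != 1)
      while (s < src.size() && src[s] == 1)
        groups[d].push_back(s++);
    ++d;
  }
  return s == src.size() && d == dst.size();
}

int main() {
  // Shapes of the first new test; the pass emits tensor.collapse_shape
  // with reassociation [[0], [1], [2], [3, 4]].
  std::vector<std::vector<size_t>> groups;
  if (buildReassociationSketch({1, -1, -1, 256, 1}, {1, -1, -1, 256}, groups)) {
    for (const auto &g : groups) {
      std::cout << '[';
      for (size_t k = 0; k < g.size(); ++k)
        std::cout << g[k] << (k + 1 < g.size() ? ", " : "");
      std::cout << "] ";
    }
    std::cout << '\n'; // prints: [0] [1] [2] [3, 4]
  }
}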