diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -889,111 +889,17 @@
   if (!didEncounterError)
     return failure();
 
-  rewriter.replaceOpWithNewOp(op, resultTy,
-                              linalgOp.getResults());
-  return success();
-}
-
-static bool findIntermediateShape(ArrayRef lhsShape,
-                                  ArrayRef rhsShape,
-                                  SmallVector &intermediateShape,
-                                  bool isDynamic) {
-  if (isDynamic) {
-    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
-    intermediateShape = {-1};
-    return true;
-  }
-
-  if (lhsShape.empty() || rhsShape.empty()) {
-    intermediateShape = {};
-    return true;
-  }
-
-  unsigned currLhsDim = 0, currRhsDim = 0;
-  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
-    int64_t rhsSize = rhsShape[currRhsDim];
-    int64_t lhsSize = lhsShape[currLhsDim];
-    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
-           currRhsDim < rhsShape.size()) {
-      if (lhsSize < rhsSize) {
-        currLhsDim++;
-        lhsSize *= lhsShape[currLhsDim];
-      } else {
-        currRhsDim++;
-        rhsSize *= rhsShape[currRhsDim];
-      }
-    }
-    if (lhsSize == rhsSize) {
-      intermediateShape.push_back(lhsSize);
-    }
-    currRhsDim++;
-    currLhsDim++;
-  }
-
-  // If the iterators didn't reach the end and their leftover dimensions are not
-  // equal to 1 an intermediate shape was not found.
-  while (currLhsDim < lhsShape.size()) {
-    if (lhsShape[currLhsDim++] != 1) {
-      return false;
-    }
-  }
-
-  while (currRhsDim < rhsShape.size()) {
-    if (rhsShape[currRhsDim++] != 1) {
-      return false;
-    }
-  }
-
-  return true;
-}
-
-static bool createReassociationMapsForCollapse(
-    PatternRewriter &rewriter, ArrayRef srcShape,
-    ArrayRef dstShape,
-    SmallVector &reassociationMap, bool isDynamic) {
-
-  // If the shape is dynamic, create a map for collapsing into one dimension.
-  if (isDynamic) {
-    SmallVector exprs;
-    for (int i = 0, s = srcShape.size(); i < s; ++i)
-      exprs.push_back(rewriter.getAffineDimExpr(i));
-    reassociationMap = {exprs};
-    return true;
+  SmallVector reassociationMap;
+  int reduceRank = reduceShape.size();
+  reassociationMap.resize(reduceRank);
+  for (int i = 0; i < reduceRank; i++) {
+    reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
   }
-
-  if (dstShape.empty()) {
-    reassociationMap = {};
-    return true;
-  }
-
-  reassociationMap.resize(dstShape.size());
-  unsigned currSrcDim = 0, currDstDim = 0;
-  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
-    int64_t dstSize = dstShape[currDstDim];
-    int64_t srcSize = srcShape[currSrcDim];
-    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
-      reassociationMap[currDstDim].push_back(
-          rewriter.getAffineDimExpr(currSrcDim++));
-      srcSize *= srcShape[currSrcDim];
-    }
-    if (srcSize == dstSize) {
-      reassociationMap[currDstDim].push_back(
-          rewriter.getAffineDimExpr(currSrcDim++));
-      // If the next dim in collapsedShape is not 1, treat subsequent dims in
-      // expandedShape which are 1 to be collapsed.
-      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
-        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
-          reassociationMap[currDstDim].push_back(
-              rewriter.getAffineDimExpr(currSrcDim++));
-        }
-      }
-    }
-    currDstDim++;
-  }
-
-  // If both iterators didn't reach the end, we have leftover dimentions which
-  // implies that we have a mismatch in shape.
-  return !(currSrcDim != srcShape.size() || currDstDim != dstShape.size());
+  reassociationMap[reduceRank - 1].push_back(
+      rewriter.getAffineDimExpr(reduceRank));
+  rewriter.replaceOpWithNewOp(
+      op, resultTy, linalgOp.getResult(0), reassociationMap);
+  return success();
 }
 
 namespace {
@@ -1009,50 +915,7 @@
   }
 };
 
-class ReshapeConverterCollapse : public OpConversionPattern {
-public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    ShapedType operandTy = adaptor.input1().getType().cast();
-    ShapedType resultTy = reshape.getType().template cast();
-    bool isDynamic = !operandTy.hasStaticShape();
-
-    if (isDynamic && resultTy.getRank() != 1) {
-      return rewriter.notifyMatchFailure(
-          reshape, "Cannot collapse dynamic dims to more than one dimension");
-    }
-
-    if (operandTy == resultTy) {
-      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
-      return success();
-    }
-
-    SmallVector reassociationMap;
-    if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
-                                            resultTy.getShape(),
-                                            reassociationMap, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape,
-          "tosa.reshape Attempting to collapse into an incompatible shape");
-    }
-
-    SmallVector intermediateShape;
-    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                               intermediateShape, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape, "tosa.reshape Cannot collapse into given shape");
-    }
-
-    rewriter.replaceOpWithNewOp(
-        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
-    return success();
-  }
-};
-
-class ReshapeConverterExpand : public OpConversionPattern {
+class ReshapeConverter : public OpConversionPattern {
 public:
   using OpConversionPattern::OpConversionPattern;
 
@@ -1061,74 +924,29 @@
                   ConversionPatternRewriter &rewriter) const final {
     ShapedType operandTy = adaptor.input1().getType().cast();
     ShapedType resultTy = reshape.getType().template cast();
-    bool isDynamic = !operandTy.hasStaticShape();
+    int64_t resultRank = resultTy.getRank();
+    Location loc = reshape.getLoc();
 
     if (operandTy == resultTy) {
       rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
       return success();
     }
 
-    if (isDynamic && operandTy.getRank() != 1) {
-      return rewriter.notifyMatchFailure(
-          reshape, "Cannot expand dynamic dims from more than one dimension");
-    }
-
-    SmallVector reassociationMap;
-    if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
-                                            operandTy.getShape(),
-                                            reassociationMap, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape,
-          "tosa.reshape Attempting to expand into an incompatible shape");
-    }
-
-    SmallVector intermediateShape;
-    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                               intermediateShape, isDynamic) ||
-        intermediateShape != operandTy.getShape()) {
-      return rewriter.notifyMatchFailure(
-          reshape, "tosa.reshape Cannot expand into given shape");
-    }
-    rewriter.replaceOpWithNewOp(
-        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
-    return success();
-  }
-};
+    SmallVector outShape;
+    getValuesFromIntArrayAttribute(reshape.new_shape(), outShape);
 
-class ReshapeConverterCollapseExpand
-    : public OpConversionPattern {
-public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    ShapedType operandTy = adaptor.input1().getType().cast();
-    ShapedType resultTy = reshape.getType().template cast();
-    bool isDynamic = !operandTy.hasStaticShape();
-
-    if (operandTy == resultTy) {
-      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
-      return success();
-    }
-
-    SmallVector intermediateShape;
-    if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
-                               intermediateShape, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape, "tosa.reshape Cannot identify an intermediate shape between "
-                   "the given two shapes");
+    SmallVector shapeValues;
+    for (const int64_t dim : outShape) {
+      shapeValues.push_back(rewriter.create(loc, dim));
     }
 
-    Value collapse = rewriter.create(
+    Value outputShape = rewriter.create(
        reshape.getLoc(),
-        RankedTensorType::get(intermediateShape,
-                              reshape.getType().getElementType()),
-        adaptor.input1());
-    Value expand =
-        rewriter.create(reshape.getLoc(), resultTy, collapse);
-    rewriter.replaceOp(reshape, expand);
+        RankedTensorType::get({resultRank}, rewriter.getIndexType()),
+        shapeValues);
+    rewriter.replaceOpWithNewOp(
+        reshape, resultTy, adaptor.input1(), outputShape);
     return success();
   }
 };
@@ -2308,9 +2126,7 @@
       ConcatConverter,
       GatherConverter,
       PadConverter,
-      ReshapeConverterCollapse,
-      ReshapeConverterExpand,
-      ReshapeConverterCollapseExpand,
+      ReshapeConverter,
       RescaleConverter,
       ResizeConverter,
       ReverseConverter,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -144,6 +144,19 @@
   return filteredDims;
 }
 
+// Creates a map to collapse the last dimension of the Depthwise convolution op
+// due to a shape mismatch
+static void createDepthwiseConvCollapseMap(
+    int64_t outputRank, SmallVector &reassociationMap,
+    OpBuilder &rewriter) {
+  reassociationMap.resize(outputRank);
+  for (int i = 0; i < outputRank; i++) {
+    reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
+  }
+  reassociationMap[outputRank - 1].push_back(
+      rewriter.getAffineDimExpr(outputRank));
+}
+
 namespace {
 
 class ConvConverter : public OpConversionPattern {
@@ -331,6 +344,7 @@
     ShapedType weightTy = weight.getType().cast();
     ShapedType biasTy = bias.getType().cast();
     ShapedType resultTy = op->getResult(0).getType().cast();
+    int64_t resultRank = resultTy.getRank();
 
     Type inputETy = inputTy.getElementType();
     Type resultETy = resultTy.getElementType();
@@ -410,10 +424,10 @@
     // Broadcast the initial value to the output tensor before convolving.
     SmallVector indexingMaps;
     indexingMaps.push_back(AffineMap::get(
-        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
+        /*dimCount=*/resultRank, /*symbolCount=*/0,
         {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
 
     Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
     Value initTensor = rewriter.create(
@@ -432,14 +446,18 @@
                   loc, linalgConvTy, ValueRange{input, weight},
                   ValueRange{zeroTensor}, strideAttr, dilationAttr)
               .getResult(0);
-      Value convReshape = rewriter.create(
-          loc, resultTy, conv, rewriter.getI64ArrayAttr(resultTy.getShape()));
+
+      SmallVector reassociationMap;
+      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
+      Value convReshape = rewriter.create(
+          loc, resultTy, conv, reassociationMap);
+
       Value result =
           rewriter
              .create(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasInitTensor, indexingMaps,
-                  getNParallelLoopsAttrs(resultTy.getRank()),
+                  getNParallelLoopsAttrs(resultRank),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create(
@@ -457,14 +475,16 @@
                   loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
                   ValueRange{zeroTensor}, strideAttr, dilationAttr)
               .getResult(0);
-      Value convReshape = rewriter.create(
-          loc, resultTy, conv, rewriter.getI64ArrayAttr(resultTy.getShape()));
+      SmallVector reassociationMap;
+      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
+      Value convReshape = rewriter.create(
+          loc, resultTy, conv, reassociationMap);
       Value result =
           rewriter
              .create(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasInitTensor, indexingMaps,
-                  getNParallelLoopsAttrs(resultTy.getRank()),
+                  getNParallelLoopsAttrs(resultRank),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -477,7 +477,7 @@
   // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
   // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 5, 5, 33]
   // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>)
-  // CHECK: [[COLLAPSED:%.+]] = "tosa.reshape"([[DEPTH]]) {new_shape = [1, 5, 5, 33]}
+  // CHECK: [[COLLAPSED:%.+]] = tensor.collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
   // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<33xf32>, tensor<1x5x5x33xf32>) outs([[OUT]] : tensor<1x5x5x33xf32>) {
   // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
   // CHECK: [[ADD:%.+]] = arith.addf %arg3, %arg4 : f32
@@ -501,7 +501,7 @@
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[OUT:.+]] = linalg.init_tensor [%[[BATCH]], 5, 5, 33]
   // CHECK: %[[DEPTH:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor, tensor<3x1x3x11xf32>) outs(%[[FILL]] : tensor)
-  // CHECK: %[[COLLAPSED:.+]] = "tosa.reshape"(%[[DEPTH]]) {new_shape = [-1, 5, 5, 33]}
+  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
   // CHECK: %[[BIAS:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[COLLAPSED]] : tensor<33xf32>, tensor) outs(%[[OUT]] : tensor) {
   // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
   // CHECK: %[[ADD:.+]] = arith.addf %arg3, %arg4 : f32
@@ -523,7 +523,7 @@
   // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
   // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 5, 5, 33]
   // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>)
-  // CHECK: [[COLLAPSED:%.+]] = "tosa.reshape"([[DEPTH]]) {new_shape = [1, 5, 5, 33]}
+  // CHECK: [[COLLAPSED:%.+]] = tensor.collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
   // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<33xf32>, tensor<1x5x5x33xf32>) outs([[OUT]] : tensor<1x5x5x33xf32>) {
   // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
   // CHECK: [[ADD:%.+]] = arith.addf %arg3, %arg4 : f32
@@ -551,7 +551,7 @@
   // CHECK: [[C128:%.+]] = arith.constant -128
   // CHECK: [[C42:%.+]] = arith.constant 42
   // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], %arg1, [[C128]], [[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[FILL]] : tensor<1x12x12x4x128xi32>)
-  // CHECK: [[COLLAPSED:%.+]] = "tosa.reshape"([[DEPTH]]) {new_shape = [1, 12, 12, 512]}
+  // CHECK: [[COLLAPSED:%.+]] = tensor.collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
   // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<512xi32>, tensor<1x12x12x512xi32>) outs([[OUT]] : tensor<1x12x12x512xi32>) {
   // CHECK: ^bb0(%arg3: i32, %arg4: i32, %arg5: i32):
   // CHECK: [[ADD:%.+]] = arith.addi %arg3, %arg4 : i32
@@ -575,7 +575,7 @@
   // CHECK: [[C128:%.+]] = arith.constant -128
   // CHECK: [[C42:%.+]] = arith.constant 42
   // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, [[C128]], [[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[FILL]] : tensor<1x10x10x4x128xi32>)
-  // CHECK: [[COLLAPSED:%.+]] = "tosa.reshape"([[DEPTH]]) {new_shape = [1, 10, 10, 512]}
+  // CHECK: [[COLLAPSED:%.+]] = tensor.collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
   // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<512xi32>, tensor<1x10x10x512xi32>) outs([[OUT]] : tensor<1x10x10x512xi32>) {
   // CHECK: ^bb0(%arg3: i32, %arg4: i32, %arg5: i32):
   // CHECK: [[ADD:%.+]] = arith.addi %arg3, %arg4 : i32
@@ -596,7 +596,7 @@
   // CHECK: tensor.yield %cst : f32
   // CHECK: } : tensor<2x?x?x3xf32> to tensor<2x?x?x3xf32>
   // CHECK: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} ins(%[[PADDED]], %arg1 : tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>) outs(%22 : tensor<2x?x?x3x5xf32>) -> tensor<2x?x?x3x5xf32>
-  // CHECK: %[[RESHAPED:.+]] = "tosa.reshape"(%[[CONV]]) {new_shape = [2, -1, -1, 15]} : (tensor<2x?x?x3x5xf32>) -> tensor<2x?x?x15xf32>
+  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[CONV]] {{\[}}[0], [1], [2], [3, 4]]
   %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [1, 2, 3, 4], dilation = [2, 1], stride = [1, 2]} : (tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x?x?x15xf32>
   return
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -89,7 +89,8 @@
 // CHECK-LABEL: @test_broadcast
 func.func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32>
-  // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %arg0
+  // CHECK: [[TENSOR:%.+]] = tensor.from_elements : tensor<0xindex>
+  // CHECK: [[RESHAPE:%.+]] = tensor.reshape %arg0([[TENSOR]])
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins([[RESHAPE]], %arg1 : tensor, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
   // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
   // CHECK: [[ELEMENT:%.+]] = arith.addf %arg2, %arg3 : f32
@@ -107,7 +108,8 @@
 // CHECK-LABEL: @test_broadcast_swapped_args
 func.func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>) -> tensor<2xf32> {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32>
-  // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %arg1
+  // CHECK: [[TENSOR:%.+]] = tensor.from_elements : tensor<0xindex>
+  // CHECK: [[RESHAPE:%.+]] = tensor.reshape %arg1([[TENSOR]])
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[RESHAPE]] : tensor<2xf32>, tensor) outs([[INIT]] : tensor<2xf32>) {
   // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
   // CHECK: [[ELEMENT:%.+]] = arith.addf %arg2, %arg3 : f32
@@ -126,8 +128,12 @@
 // CHECK-LABEL: @test_multibroadcast
 func.func @test_multibroadcast(%arg0: tensor<1x3xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x3xf32> {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3] : tensor<2x3xf32>
-  // CHECK: [[RESHAPE1:%.+]] = tensor.collapse_shape %arg0 {{\[}}[0, 1]]
-  // CHECK: [[RESHAPE2:%.+]] = tensor.collapse_shape %arg1 {{\[}}[0, 1]]
+  // CHECK: [[C3:%.+]] = arith.constant 3 : index
+  // CHECK: [[TENSOR1:%.+]] = tensor.from_elements %c3 : tensor<1xindex>
+  // CHECK: [[RESHAPE1:%.+]] = tensor.reshape %arg0(%1) : (tensor<1x3xf32>, tensor<1xindex>) -> tensor<3xf32>
+  // CHECK: [[C2:%.+]] = arith.constant 2 : index
+  // CHECK: [[TENSOR2:%.+]] = tensor.from_elements %c2 : tensor<1xindex>
+  // CHECK: [[RESHAPE2:%.+]] = tensor.reshape %arg1(%3) : (tensor<2x1xf32>, tensor<1xindex>) -> tensor<2xf32>
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[RESHAPE1]], [[RESHAPE2]] : tensor<3xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2x3xf32>) {
   // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
   // CHECK: [[ELEMENT:%.+]] = arith.addf %arg2, %arg3 : f32
@@ -531,31 +537,14 @@
 
 // -----
 
-// CHECK-LABEL: @test_reshape_downrank
-func.func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
-  // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %arg0 {{\[}}[0, 1]]
-  %0 = "tosa.reshape"(%arg0) {new_shape = [6]} : (tensor<2x3xf32>) -> tensor<6xf32>
-  // CHECK: return [[RESHAPE]]
-  return %0 : tensor<6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @test_reshape_downrank_dyn
-func.func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor {
-  // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %arg0 {{\[}}[0, 1]]
-  %0 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<2x?xf32>) -> tensor
-  // CHECK: return [[RESHAPE]]
-  return %0 : tensor
-}
-
-// -----
-
 // CHECK-LABEL: @test_reshape_uprank
 func.func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
-  // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %arg0 {{\[}}[0, 1]]
+  // CHECK: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK: %[[C3:.+]] = arith.constant 3 : index
+  // CHECK: %[[TENSOR:.+]] = tensor.from_elements %[[C2]], %[[C3]] : tensor<2xindex>
+  // CHECK: %[[RESHAPE:.+]] = tensor.reshape %arg0(%[[TENSOR]])
   %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<6xf32>) -> tensor<2x3xf32>
-  // CHECK: return [[RESHAPE]]
+  // CHECK: return %[[RESHAPE]]
   return %0 : tensor<2x3xf32>
 }
 
@@ -563,9 +552,12 @@
 
 // CHECK-LABEL: @test_reshape_uprank_dyn
 func.func @test_reshape_uprank_dyn(%arg0: tensor) -> tensor<2x?xf32> {
-  // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %arg0 {{\[}}[0, 1]]
+  // CHECK: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK: %[[CM1:.+]] = arith.constant -1 : index
+  // CHECK: %[[TENSOR:.+]] = tensor.from_elements %[[C2]], %[[CM1]] : tensor<2xindex>
+  // CHECK: %[[RESHAPE:.+]] = tensor.reshape %arg0(%[[TENSOR]])
   %0 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor) -> tensor<2x?xf32>
-  // CHECK: return [[RESHAPE]]
+  // CHECK: return %[[RESHAPE]]
   return %0 : tensor<2x?xf32>
 }
 
@@ -573,11 +565,12 @@
 
 // CHECK-LABEL: @test_reshape_samerank
 func.func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
-  // CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
-  // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
-  // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
+  // CHECK: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK: %[[C3:.+]] = arith.constant 3 : index
+  // CHECK: %[[TENSOR:.+]] = tensor.from_elements %[[C2]], %[[C3]] : tensor<2xindex>
+  // CHECK: %[[RESHAPE:.+]] = tensor.reshape %arg0(%[[TENSOR]])
   %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<3x2xf32>) -> tensor<2x3xf32>
-  // CHECK-NEXT: return %[[RESHAPE2]]
+  // CHECK-NEXT: return %[[RESHAPE]]
   return %0 : tensor<2x3xf32>
 }
 
@@ -585,35 +578,76 @@
 
 // CHECK-LABEL: @test_reshape_samerank_dyn
 func.func @test_reshape_samerank_dyn(%arg0: tensor) -> tensor<2x?xf32> {
-  // CHECK-SAME: (%[[ARG0:.*]]: tensor)
-  // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
-  // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
+  // CHECK: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK: %[[CM1:.+]] = arith.constant -1 : index
+  // CHECK: %[[TENSOR:.+]] = tensor.from_elements %[[C2]], %[[CM1]] : tensor<2xindex>
+  // CHECK: %[[RESHAPE:.+]] = tensor.reshape %arg0(%[[TENSOR]])
   %0 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor) -> tensor<2x?xf32>
-  // CHECK-NEXT: return %[[RESHAPE2]]
+  // CHECK-NEXT: return %[[RESHAPE]]
   return %0 : tensor<2x?xf32>
 }
 
 // -----
 
-// CHECK-LABEL: @test_reshape_downrank_6D
-func.func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
-  // CHECK: tensor.collapse_shape %arg0 {{\[}}[0, 1, 2], [3], [4, 5]]
+// CHECK-LABEL: @test_reshape_downrank
+func.func @test_reshape_downrank(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
+  // CHECK: %[[C6:.+]] = arith.constant 6 : index
+  // CHECK: %[[C5:.+]] = arith.constant 5 : index
+  // CHECK: %[[C77:.+]] = arith.constant 77 : index
+  // CHECK: %[[TENSOR:.+]] = tensor.from_elements %[[C6]], %[[C5]], %[[C77]] : tensor<3xindex>
+  // CHECK: %[[RESHAPE:.+]] = tensor.reshape %arg0(%[[TENSOR]])
   %0 = "tosa.reshape"(%arg0) {new_shape = [6, 5, 77]} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
+  // CHECK-NEXT: return %[[RESHAPE]]
   return %0 : tensor<6x5x77xf32>
 }
 
 // -----
 
-// CHECK-LABEL: @test_reshape_downrank_6D_dyn
-func.func @test_reshape_downrank_6D_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor {
-  // CHECK: tensor.collapse_shape %arg0 {{\[}}[0, 1, 2, 3, 4, 5]]
-  // CHECK: tensor.expand_shape %0 {{\[}}[0, 1, 2]]
+// CHECK-LABEL: @test_reshape_downrank_dyn
+func.func @test_reshape_downrank_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor {
+  // CHECK: %[[CM1:.+]] = arith.constant -1 : index
+  // CHECK: %[[C5:.+]] = arith.constant 5 : index
+  // CHECK: %[[C77:.+]] = arith.constant 77 : index
+  // CHECK: %[[TENSOR:.+]] = tensor.from_elements %[[CM1]], %[[C5]], %[[C77]] : tensor<3xindex>
+  // CHECK: %[[RESHAPE:.+]] = tensor.reshape %arg0(%[[TENSOR]])
   %0 = "tosa.reshape"(%arg0) {new_shape = [-1, 5, 77]} : (tensor<1x2x?x5x7x11xf32>) -> tensor
+  // CHECK-NEXT: return %[[RESHAPE]]
   return %0 : tensor
 }
 
 // -----
 
+// CHECK-LABEL: @test_reshape_collapse_multiple_dyn_dims
+func @test_reshape_collapse_multiple_dyn_dims(%arg0: tensor<1x?x?x256x1xf32>) -> (tensor<1x?x?x256xf32>) {
+  // CHECK: %[[C1:.+]] = arith.constant 1 : index
+  // CHECK: %[[CM1:.+]] = arith.constant -1 : index
+  // CHECK: %[[CM1_0:.+]] = arith.constant -1 : index
+  // CHECK: %[[C256:.+]] = arith.constant 256 : index
+  // CHECK: %[[TENSOR:.+]] = tensor.from_elements %[[C1]], %[[CM1]], %[[CM1_0]], %[[C256]] : tensor<4xindex>
+  // CHECK: %[[RESHAPE:.+]] = tensor.reshape %arg0(%[[TENSOR]])
+  %0 = "tosa.reshape"(%arg0) {new_shape = [1, -1, -1, 256]} : (tensor<1x?x?x256x1xf32>) -> tensor<1x?x?x256xf32>
+  // CHECK-NEXT: return %[[RESHAPE]]
+  return %0 : tensor<1x?x?x256xf32>
+}
+
+// -----
+
+//CHECK-LABEL: @test_reshape_collapse_expand_multiple_dyn_dims
+func @test_reshape_collapse_expand_multiple_dyn_dims(%arg0: tensor<2x3x8x?x?xf32>) -> (tensor<6x2x?x4x?xf32>) {
+  // CHECK: %[[C6:.+]] = arith.constant 6 : index
+  // CHECK: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK: %[[CM1:.+]] = arith.constant -1 : index
+  // CHECK: %[[C4:.+]] = arith.constant 4 : index
+  // CHECK: %[[CM1_0:.+]] = arith.constant -1 : index
+  // CHECK: %[[TENSOR:.+]] = tensor.from_elements %[[C6]], %[[C2]], %[[CM1]], %[[C4]], %[[CM1_0]] : tensor<5xindex>
+  // CHECK: %[[RESHAPE:.+]] = tensor.reshape %arg0(%[[TENSOR]])
+  %0 = "tosa.reshape"(%arg0) {new_shape = [6, 2,-1, 4, -1]} : (tensor<2x3x8x?x?xf32>) -> tensor<6x2x?x4x?xf32>
+  // CHECK-NEXT: return %[[RESHAPE]]
+  return %0 : tensor<6x2x?x4x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_identity
 func.func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<1xf32>, tensor<1xi32>) {
   %0 = "tosa.identity"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
@@ -750,8 +784,7 @@
   // CHECK: ^bb0(%arg1: f32, %arg2: f32)
   // CHECK: %[[RES:.+]] = arith.addf %arg1, %arg2 : f32
   // CHECK: linalg.yield %[[RES]] : f32
-  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor into tensor
-  // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1, 2]] : tensor into tensor
+  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor into tensor
  %0 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor) -> tensor
  return
 }
@@ -772,8 +805,7 @@
   // CHECK: ^bb0(%arg1: f32, %arg2: f32)
   // CHECK: %[[RES:.+]] = arith.mulf %arg1, %arg2 : f32
   // CHECK: linalg.yield %[[RES]] : f32
-  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<5x?xf32> into tensor
-  // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1, 2]] : tensor into tensor<5x?x1xf32>
+  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<5x?xf32> into tensor<5x?x1xf32>
  %0 = "tosa.reduce_prod"(%arg0) {axis = 2 : i64} : (tensor<5x?x4xf32>) -> tensor<5x?x1xf32>
  return
 }
@@ -1174,19 +1206,19 @@
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 2, 1, 3]
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<2x2x1x3xi8>)
   // CHECK: linalg.yield %arg1 : i8
-  // CHECK: tensor.collapse_shape [[GENERIC]] {{\[}}[0, 1, 2], [3]]
+  // CHECK: tensor.reshape
   %0 = "tosa.tile"(%arg0) {multiples = [2, 1]} : (tensor<2x3xi8>) -> (tensor<4x3xi8>)
 
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 2, 2, 3]
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>)
   // CHECK: linalg.yield %arg1 : i8
-  // CHECK: tensor.collapse_shape [[GENERIC]] {{\[}}[0, 1], [2, 3]]
+  // CHECK: tensor.reshape
   %1 = "tosa.tile"(%arg0) {multiples = [1, 2]} : (tensor<2x3xi8>) -> (tensor<2x6xi8>)
 
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2, 7, 3]
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>)
   // CHECK: linalg.yield %arg1 : i8
-  // CHECK: tensor.collapse_shape [[GENERIC]] {{\[}}[0, 1], [2, 3]]
+  // CHECK: tensor.reshape
   %2 = "tosa.tile"(%arg0) {multiples = [5, 7]} : (tensor<2x3xi8>) -> (tensor<10x21xi8>)
 
   return
@@ -1204,8 +1236,10 @@
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, %[[DYN]], 1, 3]
   // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor) outs(%[[INIT]] : tensor<2x?x1x3xi8>)
   // CHECK: linalg.yield %arg1 : i8
-  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1, 2, 3]]
-  // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1]]
+  // CHECK: %[[CM1:.+]] = arith.constant -1 : index
+  // CHECK: %[[C3:.+]] = arith.constant 3 : index
+  // CHECK: %[[TENSOR:.+]] = tensor.from_elements %[[CM1]], %[[C3]]
+  // CHECK: tensor.reshape %[[GENERIC]](%[[TENSOR]]) : (tensor<2x?x1x3xi8>, tensor<2xindex>) -> tensor
 
   %0 = "tosa.tile"(%arg0) {multiples = [2, 1]} : (tensor) -> (tensor)
 
   return
@@ -1223,8 +1257,10 @@
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 2, %[[DYN]], 3]
   // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs(%[[INIT]] : tensor<2x2x?x3xi8>)
   // CHECK: linalg.yield %arg1 : i8
-  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1, 2, 3]]
-  // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1]]
+  // CHECK: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK: %[[CM1:.+]] = arith.constant -1 : index
+  // CHECK: %[[TENSOR:.+]] = tensor.from_elements %[[C2]], %[[CM1]]
+  // CHECK: tensor.reshape %[[GENERIC]](%[[TENSOR]]) : (tensor<2x2x?x3xi8>, tensor<2xindex>) -> tensor<2x?xi8>
 
   %0 = "tosa.tile"(%arg0) {multiples = [2, -1]} : (tensor<2x3xi8>) -> (tensor<2x?xi8>)
 
   return
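For reference, the new ReshapeConverter emits a single tensor.reshape whose target shape is materialized from the new_shape attribute with arith.constant and tensor.from_elements, replacing the old collapse_shape/expand_shape pairs. A minimal sketch of the expected lowering for a static case, mirroring the updated @test_reshape_uprank CHECK lines above (the SSA names %shape and %result are assumed for readability, not produced verbatim by the pattern):

  // Lowering of "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<6xf32>) -> tensor<2x3xf32>
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %shape = tensor.from_elements %c2, %c3 : tensor<2xindex>
  %result = tensor.reshape %arg0(%shape) : (tensor<6xf32>, tensor<2xindex>) -> tensor<2x3xf32>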