diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -144,6 +144,19 @@
   return filteredDims;
 }
 
+// Creates a reassociation map that collapses the trailing channel/multiplier
+// dimensions of a depthwise convolution into one, matching the result type.
+static void createDepthwiseConvCollapseMap(
+    int64_t outputRank, SmallVector<ReassociationExprs, 4> &reassociationMap,
+    OpBuilder &rewriter) {
+  reassociationMap.resize(outputRank);
+  for (int i = 0; i < outputRank; i++) {
+    reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
+  }
+  reassociationMap[outputRank - 1].push_back(
+      rewriter.getAffineDimExpr(outputRank));
+}
+
 namespace {
 
 class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
@@ -331,6 +344,7 @@
     ShapedType weightTy = weight.getType().cast<ShapedType>();
     ShapedType biasTy = bias.getType().cast<ShapedType>();
     ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+    int64_t resultRank = resultTy.getRank();
 
     Type inputETy = inputTy.getElementType();
     Type resultETy = resultTy.getElementType();
@@ -410,10 +424,10 @@
     // Broadcast the initial value to the output tensor before convolving.
     SmallVector<AffineMap, 4> indexingMaps;
     indexingMaps.push_back(AffineMap::get(
-        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
+        /*dimCount=*/resultRank, /*symbolCount=*/0,
         {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
 
     Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
@@ -432,14 +446,18 @@
                   loc, linalgConvTy, ValueRange{input, weight},
                   ValueRange{zeroTensor}, strideAttr, dilationAttr)
               .getResult(0);
-      Value convReshape = rewriter.create<tosa::ReshapeOp>(
-          loc, resultTy, conv, rewriter.getI64ArrayAttr(resultTy.getShape()));
+
+      SmallVector<ReassociationExprs, 4> reassociationMap;
+      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
+      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
+          loc, resultTy, conv, reassociationMap);
+
       Value result =
           rewriter
               .create<linalg::GenericOp>(
                   loc, resultTy, ValueRange({bias, convReshape}),
                   biasInitTensor, indexingMaps,
-                  getNParallelLoopsAttrs(resultTy.getRank()),
+                  getNParallelLoopsAttrs(resultRank),
                   [&](OpBuilder &nestedBuilder, Location nestedLoc,
                       ValueRange args) {
                     Value added = nestedBuilder.create<arith::AddFOp>(
@@ -457,14 +475,16 @@
                   loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
                   ValueRange{zeroTensor}, strideAttr, dilationAttr)
               .getResult(0);
-      Value convReshape = rewriter.create<tosa::ReshapeOp>(
-          loc, resultTy, conv, rewriter.getI64ArrayAttr(resultTy.getShape()));
+      SmallVector<ReassociationExprs, 4> reassociationMap;
+      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
+      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
+          loc, resultTy, conv, reassociationMap);
       Value result =
           rewriter
               .create<linalg::GenericOp>(
                   loc, resultTy, ValueRange({bias, convReshape}),
                   biasInitTensor, indexingMaps,
-                  getNParallelLoopsAttrs(resultTy.getRank()),
+                  getNParallelLoopsAttrs(resultRank),
                   [&](OpBuilder &nestedBuilder, Location nestedLoc,
                       ValueRange args) {
                     Value added = nestedBuilder.create<arith::AddIOp>(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -477,7 +477,7 @@
   // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
   // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 5, 5, 33]
   // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>)
-  // CHECK: [[COLLAPSED:%.+]] = "tosa.reshape"([[DEPTH]]) {new_shape = [1, 5, 5, 33]}
+  // CHECK: [[COLLAPSED:%.+]] = tensor.collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
   // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<33xf32>, tensor<1x5x5x33xf32>) outs([[OUT]] : tensor<1x5x5x33xf32>) {
   // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
   // CHECK: [[ADD:%.+]] = arith.addf %arg3, %arg4 : f32
@@ -501,7 +501,7 @@
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[OUT:.+]] = linalg.init_tensor [%[[BATCH]], 5, 5, 33]
   // CHECK: %[[DEPTH:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x7x5x3xf32>, tensor<3x1x3x11xf32>) outs(%[[FILL]] : tensor<?x5x5x3x11xf32>)
-  // CHECK: %[[COLLAPSED:.+]] = "tosa.reshape"(%[[DEPTH]]) {new_shape = [-1, 5, 5, 33]}
+  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
   // CHECK: %[[BIAS:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[COLLAPSED]] : tensor<33xf32>, tensor<?x5x5x33xf32>) outs(%[[OUT]] : tensor<?x5x5x33xf32>) {
   // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
   // CHECK: %[[ADD:.+]] = arith.addf %arg3, %arg4 : f32
@@ -523,7 +523,7 @@
   // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
   // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 5, 5, 33]
   // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>)
-  // CHECK: [[COLLAPSED:%.+]] = "tosa.reshape"([[DEPTH]]) {new_shape = [1, 5, 5, 33]}
+  // CHECK: [[COLLAPSED:%.+]] = tensor.collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
   // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<33xf32>, tensor<1x5x5x33xf32>) outs([[OUT]] : tensor<1x5x5x33xf32>) {
   // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
   // CHECK: [[ADD:%.+]] = arith.addf %arg3, %arg4 : f32
@@ -551,7 +551,7 @@
   // CHECK: [[C128:%.+]] = arith.constant -128
   // CHECK: [[C42:%.+]] = arith.constant 42
   // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], %arg1, [[C128]], [[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[FILL]] : tensor<1x12x12x4x128xi32>)
-  // CHECK: [[COLLAPSED:%.+]] = "tosa.reshape"([[DEPTH]]) {new_shape = [1, 12, 12, 512]}
+  // CHECK: [[COLLAPSED:%.+]] = tensor.collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
   // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<512xi32>, tensor<1x12x12x512xi32>) outs([[OUT]] : tensor<1x12x12x512xi32>) {
   // CHECK: ^bb0(%arg3: i32, %arg4: i32, %arg5: i32):
   // CHECK: [[ADD:%.+]] = arith.addi %arg3, %arg4 : i32
@@ -575,7 +575,7 @@
   // CHECK: [[C128:%.+]] = arith.constant -128
   // CHECK: [[C42:%.+]] = arith.constant 42
   // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, [[C128]], [[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[FILL]] : tensor<1x10x10x4x128xi32>)
-  // CHECK: [[COLLAPSED:%.+]] = "tosa.reshape"([[DEPTH]]) {new_shape = [1, 10, 10, 512]}
+  // CHECK: [[COLLAPSED:%.+]] = tensor.collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
   // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<512xi32>, tensor<1x10x10x512xi32>) outs([[OUT]] : tensor<1x10x10x512xi32>) {
   // CHECK: ^bb0(%arg3: i32, %arg4: i32, %arg5: i32):
   // CHECK: [[ADD:%.+]] = arith.addi %arg3, %arg4 : i32
@@ -596,7 +596,7 @@
   // CHECK: tensor.yield %cst : f32
   // CHECK: } : tensor<2x?x?x3xf32> to tensor<2x?x?x3xf32>
   // CHECK: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} ins(%[[PADDED]], %arg1 : tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>) outs(%22 : tensor<2x?x?x3x5xf32>) -> tensor<2x?x?x3x5xf32>
-  // CHECK: %[[RESHAPED:.+]] = "tosa.reshape"(%[[CONV]]) {new_shape = [2, -1, -1, 15]} : (tensor<2x?x?x3x5xf32>) -> tensor<2x?x?x15xf32>
+  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[CONV]] {{\[}}[0], [1], [2], [3, 4]]
   %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [1, 2, 3, 4], dilation = [2, 1], stride = [1, 2]} : (tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x?x?x15xf32>
   return
 }
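
For reference, a minimal sketch of the collapse these CHECK lines expect: for a rank-4 TOSA result, createDepthwiseConvCollapseMap builds the reassociation [[0], [1], [2], [3, 4]], keeping the batch and spatial dimensions of the rank-5 linalg depthwise result and folding the channel/multiplier pair into a single axis. Using the shapes from the first float test above (%depth is a hypothetical name standing in for the [[DEPTH]] value):

  // Channel (3) and multiplier (11) dims merge into the 33-channel axis
  // of the TOSA result type.
  %collapsed = tensor.collapse_shape %depth [[0], [1], [2], [3, 4]]
      : tensor<1x5x5x3x11xf32> into tensor<1x5x5x33xf32>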