diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1099,27 +1099,6 @@ input = applyPad(loc, input, pad, zeroAttr, rewriter); - // Broadcast the initial value to the output tensor before convolving. - SmallVector indexingMaps; - indexingMaps.push_back(AffineMap::get( - /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0, - {rewriter.getAffineDimExpr(3)}, rewriter.getContext())); - indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank())); - - Value initTensor = - rewriter.create(loc, resultShape, resultETy); - - Value biasBroadcast = - rewriter - .create( - loc, resultTy, bias, initTensor, indexingMaps, - getNParallelLoopsAttrs(resultTy.getRank()), - [&](OpBuilder &nestedBuilder, Location nestedLoc, - ValueRange args) { - nestedBuilder.create(nestedLoc, args[0]); - }) - .getResult(0); - // Extract the attributes for convolution. llvm::SmallVector stride, dilation; getValuesFromIntArrayAttribute(strideTosaAttr, stride); @@ -1135,28 +1114,71 @@ weightShape[2], weightShape[3]}, resultETy); - Value biasReshape = - rewriter.create(loc, linalgConvTy, biasBroadcast); - Value conv; + // Broadcast the initial value to the output tensor before convolving. + SmallVector indexingMaps; + indexingMaps.push_back(AffineMap::get( + /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0, + {rewriter.getAffineDimExpr(3)}, rewriter.getContext())); + indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank())); + indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank())); + + Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy); + Value initTensor = rewriter.create( + loc, linalgConvTy.getShape(), resultETy); + Value zero = rewriter.create(loc, resultZeroAttr); + Value zeroTensor = + rewriter.create(loc, zero, initTensor).getResult(0); + + Value biasInitTensor = rewriter.create( + loc, resultTy.getShape(), resultETy); if (!isQuantized) { - conv = rewriter - .create( - loc, linalgConvTy, ValueRange{input, weight}, - ValueRange{biasReshape}, strideAttr, dilationAttr) - .getResult(0); + Value conv = rewriter + .create( + loc, linalgConvTy, ValueRange{input, weight}, + ValueRange{zeroTensor}, strideAttr, dilationAttr) + .getResult(0); + Value conv_reshape = + rewriter.create(loc, resultTy, conv); + Value result = + rewriter + .create( + loc, resultTy, ValueRange({bias, conv_reshape}), + biasInitTensor, indexingMaps, + getNParallelLoopsAttrs(resultTy.getRank()), + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange args) { + Value added = + nestedBuilder.create(loc, args[0], args[1]); + nestedBuilder.create(nestedLoc, added); + }) + .getResult(0); + rewriter.replaceOp(op, result); } else { auto iZpVal = rewriter.create(loc, iZp); auto kZpVal = rewriter.create(loc, kZp); - conv = + Value conv = rewriter .create( loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal}, - ValueRange{biasReshape}, strideAttr, dilationAttr) + ValueRange{zeroTensor}, strideAttr, dilationAttr) + .getResult(0); + Value conv_reshape = + rewriter.create(loc, resultTy, conv); + Value result = + rewriter + .create( + loc, resultTy, ValueRange({bias, conv_reshape}), + biasInitTensor, indexingMaps, + getNParallelLoopsAttrs(resultTy.getRank()), + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange args) { + Value added = + nestedBuilder.create(loc, args[0], args[1]); + nestedBuilder.create(nestedLoc, added); + }) .getResult(0); + rewriter.replaceOp(op, result); } - - Value reshape = rewriter.create(loc, resultTy, conv); - rewriter.replaceOp(op, reshape); return success(); } }; diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1389,14 +1389,17 @@ // CHECK-LABEL: @depthwise_conv func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<33xf32>) -> () { - // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 5, 33] - // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<33xf32>) outs([[INIT]] : tensor<1x5x5x33xf32>) { - // CHECK: ^bb0(%arg3: f32, %arg4: f32): // no predecessors - // CHECK: linalg.yield %arg3 : f32 + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 5, 3, 11] + // CHECK: [[CST0:%.+]] = constant 0 + // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]]) + // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 5, 5, 33] + // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>) + // CHECK: [[COLLAPSED:%.+]] = linalg.tensor_collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]] + // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<33xf32>, tensor<1x5x5x33xf32>) outs([[OUT]] : tensor<1x5x5x33xf32>) { + // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + // CHECK: [[ADD:%.+]] = addf %arg3, %arg4 : f32 + // CHECK: linalg.yield [[ADD]] : f32 // CHECK: } -> tensor<1x5x5x33xf32> - // CHECK: [[DBIAS:%.+]] = linalg.tensor_expand_shape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]] - // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>) - // CHECK: linalg.tensor_collapse_shape %3 {{\[}}[0], [1], [2], [3, 4]] %2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> (tensor<1x5x5x33xf32>) return } @@ -1408,14 +1411,17 @@ // CHECK-LABEL: @depthwise_conv_strides func @depthwise_conv_strides(%arg0 : tensor<1x11x9x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<33xf32>) -> () { - // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 5, 33] - // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<33xf32>) outs([[INIT]] : tensor<1x5x5x33xf32>) { - // CHECK: ^bb0(%arg3: f32, %arg4: f32): // no predecessors - // CHECK: linalg.yield %arg3 : f32 + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 5, 3, 11] + // CHECK: [[CST0:%.+]] = constant 0 + // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]]) + // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 5, 5, 33] + // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>) + // CHECK: [[COLLAPSED:%.+]] = linalg.tensor_collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]] + // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<33xf32>, tensor<1x5x5x33xf32>) outs([[OUT]] : tensor<1x5x5x33xf32>) { + // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + // CHECK: [[ADD:%.+]] = addf %arg3, %arg4 : f32 + // CHECK: linalg.yield [[ADD]] : f32 // CHECK: } -> tensor<1x5x5x33xf32> - // CHECK: [[DBIAS:%.+]] = linalg.tensor_expand_shape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]] - // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>) - // CHECK: linalg.tensor_collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]] %2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = [0, 0, 0, 0], stride = [2, 2], dilation = [1, 1] } : (tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> (tensor<1x5x5x33xf32>) return } @@ -1427,20 +1433,23 @@ // CHECK-LABEL: @depthwise_conv_quant func @depthwise_conv_quant(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3x4x128xi8>, %arg2 : tensor<512xi32>) -> () { - // CHECK: %[[PADV:.+]] = constant -128 - // CHECK: %[[PAD:.+]] = linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] - // CHECK: linalg.yield %[[PADV]] - - // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 12, 12, 512] - // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<512xi32>) outs([[INIT]] : tensor<1x12x12x512xi32>) { - // CHECK: ^bb0(%arg3: i32, %arg4: i32): // no predecessors - // CHECK: linalg.yield %arg3 : i32 + // CHECK: [[PADV:%.+]] = constant -128 + // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] + // CHECK: linalg.yield [[PADV]] + + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 12, 12, 4, 128] + // CHECK: [[CST0:%.+]] = constant 0 + // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]]) + // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 12, 12, 512] + // CHECK: [[C128:%.+]] = constant -128 + // CHECK: [[C42:%.+]] = constant 42 + // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nhwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], %arg1, [[C128]], [[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[FILL]] : tensor<1x12x12x4x128xi32>) + // CHECK: [[COLLAPSED:%.+]] = linalg.tensor_collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]] + // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<512xi32>, tensor<1x12x12x512xi32>) outs([[OUT]] : tensor<1x12x12x512xi32>) { + // CHECK: ^bb0(%arg3: i32, %arg4: i32, %arg5: i32): // no predecessors + // CHECK: [[ADD:%.+]] = addi %arg3, %arg4 : i32 + // CHECK: linalg.yield [[ADD]] : i32 // CHECK: } -> tensor<1x12x12x512xi32> - // CHECK: %[[DBIAS:.+]] = linalg.tensor_expand_shape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]] - // CHECK: %[[C128:.+]] = constant -128 - // CHECK: %[[C42:.+]] = constant 42 - // CHECK: %[[DEPTH:.+]] = linalg.depthwise_conv2D_nhwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[PAD]], %arg1, %[[C128]], %[[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs(%[[DBIAS]] : tensor<1x12x12x4x128xi32>) - // CHECK: linalg.tensor_collapse_shape %[[DEPTH]] {{\[}}[0], [1], [2], [3, 4]] %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 1, 1], quantization_info = {input_zp = -128 : i32, weight_zp = 42 : i32}, stride = [1, 1], dilation = [1, 1] } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x12x12x512xi32> return } @@ -1452,16 +1461,19 @@ // CHECK-LABEL: @depthwise_conv_quant_dilations func @depthwise_conv_quant_dilations(%arg0 : tensor<1x14x14x4xi8>, %arg1 : tensor<3x3x4x128xi8>, %arg2 : tensor<512xi32>) -> () { - // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 10, 10, 512] - // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<512xi32>) outs([[INIT]] : tensor<1x10x10x512xi32>) { - // CHECK: ^bb0(%arg3: i32, %arg4: i32): // no predecessors - // CHECK: linalg.yield %arg3 : i32 + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 10, 10, 4, 128] + // CHECK: [[CST0:%.+]] = constant 0 + // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]]) + // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 10, 10, 512] + // CHECK: [[C128:%.+]] = constant -128 + // CHECK: [[C42:%.+]] = constant 42 + // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nhwc_q {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, [[C128]], [[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[FILL]] : tensor<1x10x10x4x128xi32>) + // CHECK: [[COLLAPSED:%.+]] = linalg.tensor_collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]] + // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<512xi32>, tensor<1x10x10x512xi32>) outs([[OUT]] : tensor<1x10x10x512xi32>) { + // CHECK: ^bb0(%arg3: i32, %arg4: i32, %arg5: i32): // no predecessors + // CHECK: [[ADD:%.+]] = addi %arg3, %arg4 : i32 + // CHECK: linalg.yield [[ADD]] : i32 // CHECK: } -> tensor<1x10x10x512xi32> - // CHECK: [[DBIAS:%.+]] = linalg.tensor_expand_shape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]] - // CHECK: %[[C128:.+]] = constant -128 - // CHECK: %[[C42:.+]] = constant 42 - // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nhwc_q {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %[[C128]], %[[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[DBIAS]] : tensor<1x10x10x4x128xi32>) - // CHECK: linalg.tensor_collapse_shape %3 {{\[}}[0], [1], [2], [3, 4]] %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, weight_zp = 42 : i32}, stride = [1, 1], dilation = [2, 2] } : (tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x10x10x512xi32> return }