diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -63,6 +63,32 @@
       highIndices, padValue);
 }
 
+// Adds the integer `bias` to the integer accumulator `conv` with a
+// linalg.generic, writing into the init tensor `result`. `indexingMaps` must
+// hold one map per operand, in the order {bias, conv, result}. When the bias
+// element type differs from the accumulator element type, the bias is
+// sign-extended first; this assumes the bias is the narrower integer type.
+static mlir::Value makeIntBiasAdd(PatternRewriter &rewriter, Location loc,
+                                  ShapedType resultTy, Value bias, Value conv,
+                                  Value result,
+                                  ArrayRef<AffineMap> indexingMaps) {
+  return rewriter
+      .create<linalg::GenericOp>(
+          loc, resultTy, ValueRange({bias, conv}), result, indexingMaps,
+          getNParallelLoopsAttrs(resultTy.getRank()),
+          [](OpBuilder &builder, Location nestedLoc, ValueRange args) {
+            Value biasVal = args[0];
+            Type accType = args[1].getType();
+            if (biasVal.getType() != accType)
+              biasVal =
+                  builder.create<arith::ExtSIOp>(nestedLoc, accType, biasVal);
+            Value added =
+                builder.create<arith::AddIOp>(nestedLoc, biasVal, args[1]);
+            builder.create<linalg::YieldOp>(nestedLoc, added);
+          })
+      .getResult(0);
+}
+
 static mlir::Value reifyConstantDim(int64_t attr,
                                     ImplicitLocOpBuilder &builder) {
   return builder.createOrFold<arith::IndexCastOp>(
@@ -290,19 +316,8 @@
                   loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
                   ValueRange{zeroTensor}, strideAttr, dilationAttr)
               ->getResult(0);
-
-      Value result =
-          rewriter
-              .create<linalg::GenericOp>(
-                  loc, resultTy, ValueRange({bias, conv}), biasEmptyTensor,
-                  indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
-                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
-                      ValueRange args) {
-                    Value added = nestedBuilder.create<arith::AddIOp>(
-                        loc, args[0], args[1]);
-                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
-                  })
-              .getResult(0);
+      Value result = makeIntBiasAdd(rewriter, loc, resultTy, bias, conv,
+                                    biasEmptyTensor, indexingMaps);
       rewriter.replaceOp(op, result);
       return success();
     }
@@ -479,19 +494,8 @@
       createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
       Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
           loc, resultTy, conv, reassociationMap);
-      Value result =
-          rewriter
-              .create<linalg::GenericOp>(
-                  loc, resultTy, ValueRange({bias, convReshape}),
-                  biasEmptyTensor, indexingMaps,
-                  getNParallelLoopsAttrs(resultRank),
-                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
-                      ValueRange args) {
-                    Value added = nestedBuilder.create<arith::AddIOp>(
-                        loc, args[0], args[1]);
-                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
-                  })
-              .getResult(0);
+      Value result = makeIntBiasAdd(rewriter, loc, resultTy, bias, convReshape,
+                                    biasEmptyTensor, indexingMaps);
       rewriter.replaceOp(op, result);
     }
     return success();
@@ -624,11 +628,8 @@
     Value transposedWeight = rewriter.create<tosa::TransposeOp>(
         loc, newWeightTy, weight, permutationValue);
 
-    auto biasEmptyTensor =
-        rewriter
-            .create<tensor::EmptyOp>(loc, outputTy.getShape(), outputETy,
-                                     filteredDims)
-            ->getResults();
+    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
+        loc, outputTy.getShape(), outputETy, filteredDims);
 
     if (!op.getQuantizationInfo()) {
       Value matmul = rewriter
@@ -665,18 +666,8 @@
                 ValueRange{input, transposedWeight, inputZp, outputZp},
                 zeroTensor)
             ->getResult(0);
-    Value result =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, outputTy, ValueRange({bias, matmul}), biasEmptyTensor,
-                indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
-                [&](OpBuilder &nestedBuilder, Location nestedLoc,
-                    ValueRange args) {
-                  Value added = nestedBuilder.create<arith::AddIOp>(
-                      loc, args[0], args[1]);
-                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
-                })
-            .getResult(0);
+    Value result = makeIntBiasAdd(rewriter, loc, outputTy, bias, matmul,
+                                  biasEmptyTensor, indexingMaps);
     rewriter.replaceOp(op, result);
     return success();
   }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -361,6 +361,28 @@
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 
+// CHECK-LABEL: @conv2d_i8
+func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
+  // CHECK: %[[W:.+]] = "tosa.transpose"(%arg1, %[[PERM]])
+  // CHECK: %[[M_IN:.+]] = tensor.empty()
+  // CHECK: %[[CST:.+]] = arith.constant 0
+  // CHECK: %[[FILL:.+]] = linalg.fill
+  // CHECK: %[[B_IN:.+]] = tensor.empty()
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]], %{{.+}}, %{{.+}} : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%[[FILL]] : tensor<1x49x42x28xi32>) -> tensor<1x49x42x28xi32>
+  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x49x42x28xi32>) outs(%[[B_IN]] : tensor<1x49x42x28xi32>)
+  // CHECK: arith.extsi
+  // CHECK: arith.addi
+  // CHECK: linalg.yield
+  %0 = "tosa.conv2d"(%input, %weights, %bias) {dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>) -> (tensor<1x49x42x28xi32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
 // CHECK-LABEL: @conv2d_f32
 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>