diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -63,27 +63,25 @@
                                         highIndices, padValue);
 }
 
-static mlir::Value makeIntBiasAdd(PatternRewriter &rewriter, Location loc,
-                                  ShapedType resultTy, Value bias, Value conv,
-                                  Value result,
-                                  ArrayRef<AffineMap> indexingMaps) {
-  result = rewriter
-               .create<linalg::GenericOp>(
-                   loc, resultTy, ValueRange({bias, conv}), result,
-                   indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
-                   [](OpBuilder &builder, Location loc, ValueRange args) {
-                     Value biasVal = args[0];
-                     Type resType = args[1].getType();
-                     if (resType != biasVal.getType()) {
-                       biasVal = builder.create<arith::ExtSIOp>(loc, resType,
-                                                                biasVal);
-                     }
-                     Value added =
-                         builder.create<arith::AddIOp>(loc, biasVal, args[1]);
-                     builder.create<linalg::YieldOp>(loc, added);
-                   })
-               .getResult(0);
-  return result;
+static mlir::Value
+linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
+                           Value conv, Value result,
+                           ArrayRef<AffineMap> indexingMaps) {
+  ShapedType resultTy = conv.getType().cast<ShapedType>();
+  return rewriter
+      .create<linalg::GenericOp>(
+          loc, resultTy, ValueRange({bias, conv}), result, indexingMaps,
+          getNParallelLoopsAttrs(resultTy.getRank()),
+          [](OpBuilder &builder, Location loc, ValueRange args) {
+            Value biasVal = args[0];
+            Type resType = args[1].getType();
+            if (resType != biasVal.getType()) {
+              biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
+            }
+            Value added = builder.create<arith::AddIOp>(loc, biasVal, args[1]);
+            builder.create<linalg::YieldOp>(loc, added);
+          })
+      .getResult(0);
 }
 
 static mlir::Value reifyConstantDim(int64_t attr,
@@ -313,8 +311,8 @@
                 loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
                 ValueRange{zeroTensor}, strideAttr, dilationAttr)
             ->getResult(0);
-    Value result = makeIntBiasAdd(rewriter, loc, resultTy, bias, conv,
-                                  biasEmptyTensor, indexingMaps);
+    Value result = linalgIntBroadcastExtSIAdd(rewriter, loc, bias, conv,
+                                              biasEmptyTensor, indexingMaps);
     rewriter.replaceOp(op, result);
     return success();
   }
@@ -491,8 +489,8 @@
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);
-      Value result = makeIntBiasAdd(rewriter, loc, resultTy, bias, convReshape,
-                                    biasEmptyTensor, indexingMaps);
+      Value result = linalgIntBroadcastExtSIAdd(
+          rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
      rewriter.replaceOp(op, result);
    }
    return success();
@@ -663,8 +661,8 @@
                 ValueRange{input, transposedWeight, inputZp, outputZp},
                 zeroTensor)
             ->getResult(0);
-    Value result = makeIntBiasAdd(rewriter, loc, outputTy, bias, matmul,
-                                  biasEmptyTensor, indexingMaps);
+    Value result = linalgIntBroadcastExtSIAdd(rewriter, loc, bias, matmul,
+                                              biasEmptyTensor, indexingMaps);
     rewriter.replaceOp(op, result);
     return success();
   }
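
For context (not taken from the patch): a rough, hand-written sketch of the kind of linalg.generic the renamed helper builds when broadcasting a quantized bias into a convolution accumulator. The tensor shapes, the @bias_add wrapper, and the affine maps are illustrative assumptions; only the extsi/addi/yield body mirrors the region constructed above.

#bias_map = affine_map<(d0, d1, d2, d3) -> (d3)>
#id_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

func.func @bias_add(%bias: tensor<16xi8>, %conv: tensor<1x112x112x16xi32>,
                    %init: tensor<1x112x112x16xi32>) -> tensor<1x112x112x16xi32> {
  // The bias is broadcast along the channel dimension, sign-extended to the
  // accumulator element type, and added element-wise to the conv result.
  %0 = linalg.generic
         {indexing_maps = [#bias_map, #id_map, #id_map],
          iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
         ins(%bias, %conv : tensor<16xi8>, tensor<1x112x112x16xi32>)
         outs(%init : tensor<1x112x112x16xi32>) {
  ^bb0(%b: i8, %acc: i32, %out: i32):
    %ext = arith.extsi %b : i8 to i32
    %sum = arith.addi %ext, %acc : i32
    linalg.yield %sum : i32
  } -> tensor<1x112x112x16xi32>
  return %0 : tensor<1x112x112x16xi32>
}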