diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1678,13 +1678,24 @@ dy = rewriter.create(loc, resultElementTy, dy); } + Value xScaleNExt = xScaleN; + Value yScaleNExt = yScaleN; + + if (xScaleN.getType() != resultElementTy) + xScaleNExt = + rewriter.create(loc, resultElementTy, xScaleN); + + if (yScaleN.getType() != resultElementTy) + yScaleNExt = + rewriter.create(loc, resultElementTy, yScaleN); + Value topAcc, bottomAcc; if (imageW == 1) { - topAcc = rewriter.create(loc, y0x0, xScaleN); - bottomAcc = rewriter.create(loc, y1x0, xScaleN); + topAcc = rewriter.create(loc, y0x0, xScaleNExt); + bottomAcc = rewriter.create(loc, y1x0, xScaleNExt); } else { Value rightPart = dx; - Value leftPart = rewriter.create(loc, xScaleN, dx); + Value leftPart = rewriter.create(loc, xScaleNExt, dx); y0x0 = rewriter.create(loc, y0x0, leftPart); y0x1 = rewriter.create(loc, y0x1, rightPart); @@ -1697,10 +1708,10 @@ Value result; if (imageH == 1) { - result = rewriter.create(loc, topAcc, yScaleN); + result = rewriter.create(loc, topAcc, yScaleNExt); } else { Value bottomPart = dy; - Value topPart = rewriter.create(loc, yScaleN, dy); + Value topPart = rewriter.create(loc, yScaleNExt, dy); topAcc = rewriter.create(loc, topAcc, topPart); bottomAcc = rewriter.create(loc, bottomAcc, bottomPart);