diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1259,6 +1259,11 @@ auto outputTy = op.output().getType().cast(); unsigned rank = inputTy.getRank(); + // This is an illegal configuration. Terminate and log an error. + if (op.double_round() && !op.scale32()) + return rewriter.notifyMatchFailure( + op, "tosa.rescale requires scale32 for double_round to be true"); + if (!outputTy.hasStaticShape()) return rewriter.notifyMatchFailure( op, "tosa to linalg conversion expects statically shaped tensors"); diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp --- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp +++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp @@ -103,7 +103,8 @@ Value shiftThirty64 = rewriter.create( loc, rewriter.getI64Type(), shiftThirty32); - // Round value needs to with be added or sbustracted depending on + // Round value needs to be added or subtracted depending on the sign + // of the input value. Value roundAdd64 = rewriter.create(loc, round64, shiftThirty64); Value roundSub64 =