diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -428,12 +428,40 @@
   }
 
   if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
-    auto min = createConstFromIntAttribute<int32_t>(op, "min_int", elementTy,
-                                                    rewriter);
-    auto max = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
-                                                    rewriter);
-    return clampHelper<mlir::CmpIOp>(loc, args[0], min, max, CmpIPredicate::slt,
-                                     rewriter);
+    // Intersect the requested [min_int, max_int] range with the range that
+    // the element type can represent. The bounds are kept in int64_t: the
+    // attribute values are read with getSExtValue() (int64_t), and narrowing
+    // to int32_t would truncate them as well as the limits of 64-bit types
+    // (INT64_MIN -> 0, INT64_MAX -> -1).
+    auto intTy = elementTy.cast<IntegerType>();
+    unsigned bitWidth = intTy.getIntOrFloatBitWidth();
+    int64_t min =
+        op->getAttr("min_int").cast<IntegerAttr>().getValue().getSExtValue();
+    int64_t max =
+        op->getAttr("max_int").cast<IntegerAttr>().getValue().getSExtValue();
+
+    if (intTy.isUnsignedInteger()) {
+      // The unsigned maximum must be read zero-extended: getSExtValue() on an
+      // all-ones APInt returns -1. Maxima that exceed INT64_MAX (e.g. ui64)
+      // saturate to INT64_MAX.
+      uint64_t typeMax =
+          APInt::getMaxValue(bitWidth).getLimitedValue(INT64_MAX);
+      min = std::max<int64_t>(min, 0);
+      max = std::min<int64_t>(max, static_cast<int64_t>(typeMax));
+    } else {
+      min = std::max<int64_t>(
+          min, APInt::getSignedMinValue(bitWidth).getSExtValue());
+      max = std::min<int64_t>(
+          max, APInt::getSignedMaxValue(bitWidth).getSExtValue());
+    }
+
+    // NOTE(review): the comparison below uses a signed predicate; for
+    // unsigned element types with bounds above the signed maximum of the
+    // width this looks incorrect — confirm against clampHelper.
+    auto minVal = rewriter.create<ConstantIntOp>(loc, min, bitWidth);
+    auto maxVal = rewriter.create<ConstantIntOp>(loc, max, bitWidth);
+    return clampHelper<mlir::CmpIOp>(loc, args[0], minVal, maxVal,
+                                     CmpIPredicate::slt, rewriter);
   }
 
   // tosa::ReluNOp