diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -369,10 +369,17 @@ // tosa::ClampOp if (isa(op) && elementTy.isa()) { - auto min = rewriter.create(loc, elementTy, - op->getAttr("min_fp")); - auto max = rewriter.create(loc, elementTy, - op->getAttr("max_fp")); + bool losesInfo = false; + APFloat min_apf = op->getAttr("min_fp").cast().getValue(); + APFloat max_apf = op->getAttr("max_fp").cast().getValue(); + min_apf.convert(elementTy.cast().getFloatSemantics(), + APFloat::rmNearestTiesToEven, &losesInfo); + max_apf.convert(elementTy.cast().getFloatSemantics(), + APFloat::rmNearestTiesToEven, &losesInfo); + auto min = rewriter.create( + loc, elementTy, rewriter.getFloatAttr(elementTy, min_apf)); + auto max = rewriter.create( + loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf)); return clampHelper(loc, args[0], min, max, arith::CmpFPredicate::OLT, rewriter); } @@ -410,8 +417,12 @@ if (isa(op) && elementTy.isa()) { auto zero = rewriter.create(loc, FloatAttr::get(elementTy, 0)); - auto n = rewriter.create(loc, elementTy, - op->getAttr("max_fp")); + bool losesInfo = false; + APFloat max_apf = op->getAttr("max_fp").cast().getValue(); + max_apf.convert(elementTy.cast().getFloatSemantics(), + APFloat::rmNearestTiesToEven, &losesInfo); + auto n = rewriter.create( + loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf)); return clampHelper(loc, args[0], zero, n, arith::CmpFPredicate::OLT, rewriter); }