diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -369,10 +369,21 @@
 
   // tosa::ClampOp
   if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
-    auto min = rewriter.create<arith::ConstantOp>(loc, elementTy,
-                                                  op->getAttr("min_fp"));
-    auto max = rewriter.create<arith::ConstantOp>(loc, elementTy,
-                                                  op->getAttr("max_fp"));
+    // min_fp/max_fp may carry a different float type than the tensor element
+    // type (e.g. f32 bounds on an f16 tensor). Round them to the element
+    // type's semantics so each constant's attribute type matches its result
+    // type. Precision lost to rounding (losesInfo) is deliberately ignored.
+    const llvm::fltSemantics &semantics =
+        elementTy.cast<FloatType>().getFloatSemantics();
+    bool losesInfo = false;
+    APFloat minApf = op->getAttr("min_fp").cast<FloatAttr>().getValue();
+    APFloat maxApf = op->getAttr("max_fp").cast<FloatAttr>().getValue();
+    minApf.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo);
+    maxApf.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo);
+    auto min = rewriter.create<arith::ConstantOp>(
+        loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
+    auto max = rewriter.create<arith::ConstantOp>(
+        loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
     return clampHelper<arith::CmpFOp>(loc, args[0], min, max,
                                       arith::CmpFPredicate::OLT, rewriter);
   }
@@ -410,8 +421,15 @@
   if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
     auto zero =
         rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
-    auto n = rewriter.create<arith::ConstantOp>(loc, elementTy,
-                                                op->getAttr("max_fp"));
+    // As with the clamp bounds, max_fp may use a different float type than
+    // the element type; round it to the element type's semantics so the
+    // constant's attribute type matches its result type (losesInfo ignored).
+    bool losesInfo = false;
+    APFloat maxApf = op->getAttr("max_fp").cast<FloatAttr>().getValue();
+    maxApf.convert(elementTy.cast<FloatType>().getFloatSemantics(),
+                   APFloat::rmNearestTiesToEven, &losesInfo);
+    auto n = rewriter.create<arith::ConstantOp>(
+        loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
     return clampHelper<arith::CmpFOp>(loc, args[0], zero, n,
                                       arith::CmpFPredicate::OLT, rewriter);
   }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -473,6 +473,22 @@
 
 // -----
 
+// CHECK-LABEL: @test_clamp_f16
+func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
+  // CHECK: linalg.generic
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0{{.*}} : f16
+  // CHECK-DAG: %[[C6:.+]] = arith.constant 6.0{{.*}} : f16
+  // CHECK-DAG: %[[CMP1:.+]] = arith.cmpf olt, %arg1, %[[C0]]
+  // CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C0]]
+  // CHECK-DAG: %[[CMP2:.+]] = arith.cmpf olt, %[[C6]], %arg1
+  // CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C6]], %[[SEL1]]
+  %0 = "tosa.clamp"(%arg0) {min_int = 0 : i64, max_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 6.0 : f32} : (tensor<1xf16>) -> tensor<1xf16>
+
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_bool
 func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
   // CHECK: linalg.generic