diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -428,12 +428,41 @@
   }
 
   if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
-    auto min = createConstFromIntAttribute<int32_t>(op, "min_int", elementTy,
-                                                    rewriter);
-    auto max = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
-                                                    rewriter);
-    return clampHelper<mlir::CmpIOp>(loc, args[0], min, max, CmpIPredicate::slt,
-                                     rewriter);
+    auto intTy = elementTy.cast<IntegerType>();
+    unsigned bitWidth = intTy.getIntOrFloatBitWidth();
+
+    // Read the bounds as 64-bit values; truncating them to int32_t would
+    // corrupt the bounds of i64 clamps.
+    int64_t min =
+        op->getAttr("min_int").cast<IntegerAttr>().getValue().getSExtValue();
+    int64_t max =
+        op->getAttr("max_int").cast<IntegerAttr>().getValue().getSExtValue();
+
+    // Narrow the bounds to the element type's range so the constants below
+    // do not wrap (e.g. min_int = -130 on an i8 tensor).
+    bool isUnsigned = intTy.isUnsignedInteger();
+    if (isUnsigned) {
+      // The all-ones maximum must be read zero-extended: getSExtValue() would
+      // return -1 and collapse the upper bound. Cap at INT64_MAX so the value
+      // still fits in the int64_t bound.
+      int64_t typeMax = static_cast<int64_t>(std::min<uint64_t>(
+          APInt::getMaxValue(bitWidth).getZExtValue(), INT64_MAX));
+      min = std::max<int64_t>(min, 0);
+      max = std::min<int64_t>(max, typeMax);
+    } else {
+      min = std::max<int64_t>(
+          min, APInt::getSignedMinValue(bitWidth).getSExtValue());
+      max = std::min<int64_t>(
+          max, APInt::getSignedMaxValue(bitWidth).getSExtValue());
+    }
+
+    auto minVal = rewriter.create<ConstantIntOp>(loc, min, bitWidth);
+    auto maxVal = rewriter.create<ConstantIntOp>(loc, max, bitWidth);
+    // Unsigned bounds must be compared unsigned, or values above the signed
+    // maximum of the width would be treated as negative.
+    return clampHelper<mlir::CmpIOp>(
+        loc, args[0], minVal, maxVal,
+        isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt, rewriter);
   }
 
   // tosa::ReluNOp
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -404,6 +404,31 @@
 
 // -----
 
+// CHECK-LABEL: @test_i8
+func @test_i8(%arg0: tensor<1xi8>) -> () {
+  // CHECK: linalg.generic
+  // CHECK-DAG: %[[C127:.+]] = constant -127
+  // CHECK-DAG: %[[C126:.+]] = constant 126
+  // CHECK-DAG: %[[CMP1:.+]] = cmpi slt, %arg1, %[[C127]]
+  // CHECK-DAG: %[[SEL1:.+]] = select %[[CMP1]], %[[C127]]
+  // CHECK-DAG: %[[CMP2:.+]] = cmpi slt, %[[C126]], %arg1
+  // CHECK: %[[SEL2:.+]] = select %[[CMP2]], %[[C126]], %[[SEL1]]
+  %0 = "tosa.clamp"(%arg0) {min_int = -127 : i64, max_int = 126 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
+
+  // CHECK: linalg.generic
+  // CHECK-DAG: %[[C128:.+]] = constant -128
+  // CHECK-DAG: %[[C127:.+]] = constant 127
+  // CHECK-DAG: %[[CMP1:.+]] = cmpi slt, %arg1, %[[C128]]
+  // CHECK-DAG: %[[SEL1:.+]] = select %[[CMP1]], %[[C128]]
+  // CHECK-DAG: %[[CMP2:.+]] = cmpi slt, %[[C127]], %arg1
+  // CHECK: %[[SEL2:.+]] = select %[[CMP2]], %[[C127]], %[[SEL1]]
+  %1 = "tosa.clamp"(%arg0) {min_int = -130 : i64, max_int = 130 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
+
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_bool
 func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
   // CHECK: linalg.generic