diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h --- a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h @@ -27,17 +27,15 @@ // Takes a vector of values and condenses them to a vector with no gaps. SmallVector condenseValues(const SmallVector &values); -// Takes the parameters for a clamp and turns it into a series of ops. -template -arith::SelectOp clampHelper(Location loc, Value arg, arith::ConstantOp min, - arith::ConstantOp max, P pred, - OpBuilder &rewriter) { - auto smallerThanMin = rewriter.create(loc, pred, arg, min); - auto minOrArg = - rewriter.create(loc, smallerThanMin, min, arg); - auto largerThanMax = rewriter.create(loc, pred, max, arg); - return rewriter.create(loc, largerThanMax, max, minOrArg); -} +// Takes the parameters for a clamp and turns it into a series of ops for float +// inputs. +Value clampFloatHelper(Location loc, Value arg, arith::ConstantOp min, + arith::ConstantOp max, OpBuilder &rewriter); + +// Takes the parameters for a clamp and turns it into a series of ops for +// integer inputs. +Value clampIntHelper(Location loc, Value arg, arith::ConstantOp min, + arith::ConstantOp max, OpBuilder &rewriter); // Returns the values in an attribute as an array of values. template diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -182,8 +182,7 @@ auto max = rewriter.create( loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(), intermediateType); - auto clamp = clampHelper( - loc, sub, min, max, arith::CmpIPredicate::slt, rewriter); + auto clamp = clampIntHelper(loc, sub, min, max, rewriter); // Truncate to the final value. 
return rewriter.create(loc, elementTy, clamp); @@ -335,9 +334,7 @@ // tosa::MaximumOp if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create( - loc, arith::CmpFPredicate::OGT, args[0], args[1]); - return rewriter.create(loc, predicate, args[0], args[1]); + return rewriter.create(loc, args[0], args[1]); } if (isa(op) && elementTy.isSignlessInteger()) { @@ -348,9 +345,7 @@ // tosa::MinimumOp if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create( - loc, arith::CmpFPredicate::OLT, args[0], args[1]); - return rewriter.create(loc, predicate, args[0], args[1]); + return rewriter.create(loc, args[0], args[1]); } if (isa(op) && elementTy.isSignlessInteger()) { @@ -380,8 +375,7 @@ loc, elementTy, rewriter.getFloatAttr(elementTy, min_apf)); auto max = rewriter.create( loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf)); - return clampHelper(loc, args[0], min, max, - arith::CmpFPredicate::OLT, rewriter); + return clampFloatHelper(loc, args[0], min, max, rewriter); } if (isa(op) && elementTy.isa()) { @@ -409,8 +403,7 @@ loc, min, intTy.getIntOrFloatBitWidth()); auto maxVal = rewriter.create( loc, max, intTy.getIntOrFloatBitWidth()); - return clampHelper(loc, args[0], minVal, maxVal, - arith::CmpIPredicate::slt, rewriter); + return clampIntHelper(loc, args[0], minVal, maxVal, rewriter); } // tosa::ReluNOp @@ -423,8 +416,7 @@ APFloat::rmNearestTiesToEven, &losesInfo); auto n = rewriter.create( loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf)); - return clampHelper(loc, args[0], zero, n, - arith::CmpFPredicate::OLT, rewriter); + return clampFloatHelper(loc, args[0], zero, n, rewriter); } if (isa(op) && elementTy.isa()) { @@ -432,8 +424,7 @@ rewriter.create(loc, IntegerAttr::get(elementTy, 0)); auto n = createConstFromIntAttribute(op, "max_int", elementTy, rewriter); - return clampHelper(loc, args[0], zero, n, - arith::CmpIPredicate::slt, rewriter); + return clampIntHelper(loc, args[0], zero, n, rewriter); } // tosa::SigmoidOp @@ -521,8 
+512,7 @@ auto rounded = rewriter.create(loc, negative, subbed, added); - auto clamped = clampHelper( - loc, rounded, intMin, intMax, arith::CmpFPredicate::OLT, rewriter); + auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter); return rewriter.create(loc, dstTy, clamped); } @@ -553,8 +543,7 @@ .getSExtValue(), srcTy.getIntOrFloatBitWidth()); - auto clamped = clampHelper( - loc, args[0], intMin, intMax, arith::CmpIPredicate::slt, rewriter); + auto clamped = clampIntHelper(loc, args[0], intMin, intMax, rewriter); return rewriter.create(loc, dstTy, clamped); } } @@ -751,9 +740,7 @@ } if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create( - loc, arith::CmpFPredicate::OLT, args[0], args[1]); - return rewriter.create(loc, predicate, args[0], args[1]); + return rewriter.create(loc, args[0], args[1]); } if (isa(op) && elementTy.isa()) { @@ -763,9 +750,7 @@ } if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create( - loc, arith::CmpFPredicate::OGT, args[0], args[1]); - return rewriter.create(loc, predicate, args[0], args[1]); + return rewriter.create(loc, args[0], args[1]); } if (isa(op) && elementTy.isa()) { @@ -1314,9 +1299,8 @@ auto intMaxVal = nestedBuilder.create( loc, nestedBuilder.getI32IntegerAttr(intMax)); - value = clampHelper( - nestedLoc, value, intMinVal, intMaxVal, arith::CmpIPredicate::slt, - nestedBuilder); + value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal, + nestedBuilder); if (outIntType.getWidth() < 32) { value = nestedBuilder.create( @@ -1497,10 +1481,8 @@ // Clamp the to be within the bounds of the input image. - iy = clampHelper(loc, iy, hwMin, hMax, - arith::CmpIPredicate::slt, rewriter); - ix = clampHelper(loc, ix, hwMin, wMax, - arith::CmpIPredicate::slt, rewriter); + iy = clampIntHelper(loc, iy, hwMin, hMax, rewriter); + ix = clampIntHelper(loc, ix, hwMin, wMax, rewriter); // Read the value from the input array. 
iy = @@ -1525,15 +1507,11 @@ Value y1 = rewriter.create(loc, y0, oneVal); Value x1 = rewriter.create(loc, x0, oneVal); - y0 = clampHelper(loc, y0, hwMin, hMax, - arith::CmpIPredicate::slt, rewriter); - y1 = clampHelper(loc, y1, hwMin, hMax, - arith::CmpIPredicate::slt, rewriter); + y0 = clampIntHelper(loc, y0, hwMin, hMax, rewriter); + y1 = clampIntHelper(loc, y1, hwMin, hMax, rewriter); - x0 = clampHelper(loc, x0, hwMin, wMax, - arith::CmpIPredicate::slt, rewriter); - x1 = clampHelper(loc, x1, hwMin, wMax, - arith::CmpIPredicate::slt, rewriter); + x0 = clampIntHelper(loc, x0, hwMin, wMax, rewriter); + x1 = clampIntHelper(loc, x1, hwMin, wMax, rewriter); y0 = rewriter.create(loc, rewriter.getIndexType(), y0); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -943,8 +943,7 @@ auto max = rewriter.create( loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(), accETy); - auto clamp = clampHelper( - loc, scaled, min, max, arith::CmpIPredicate::slt, rewriter); + auto clamp = clampIntHelper(loc, scaled, min, max, rewriter); poolVal = clamp; // Convert type. 
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -28,3 +28,23 @@ condensedValues.push_back(value); return condensedValues; } + +Value mlir::tosa::clampFloatHelper(Location loc, Value arg, + arith::ConstantOp min, arith::ConstantOp max, + OpBuilder &rewriter) { + // Clamp as max(min(arg, max), min): cap at the upper bound first, then + // raise to the lower bound, so the result always lies in [min, max]. + Value minValue = rewriter.create(loc, arg, max); + return rewriter.create(loc, minValue, min); +} + +Value mlir::tosa::clampIntHelper(Location loc, Value arg, arith::ConstantOp min, + arith::ConstantOp max, OpBuilder &rewriter) { + auto smallerThanMin = + rewriter.create(loc, arith::CmpIPredicate::slt, arg, min); + auto minOrArg = + rewriter.create(loc, smallerThanMin, min, arg); + auto largerThanMax = + rewriter.create(loc, arith::CmpIPredicate::slt, max, arg); + return rewriter.create(loc, largerThanMax, max, minOrArg); +} diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -198,13 +198,11 @@ %13 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic - // CHECK: arith.cmpf - // CHECK: select + // CHECK: arith.maxf %14 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic - // CHECK: arith.cmpf - // CHECK: select + // CHECK: arith.minf %15 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic @@ -216,13 +214,13 @@ %17 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic - // CHECK: arith.cmpf - // CHECK: select + // CHECK: arith.minf + // CHECK: arith.maxf %18 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : 
(tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic - // CHECK: arith.cmpf - // CHECK: select + // CHECK: arith.minf + // CHECK: arith.maxf %19 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic @@ -241,10 +239,8 @@ // CHECK: arith.subf // CHECK: arith.cmpf olt // CHECK: select - // CHECK: arith.cmpf olt - // CHECK: select - // CHECK: arith.cmpf olt - // CHECK: select + // CHECK: arith.minf + // CHECK: arith.maxf // CHECK: arith.fptosi %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32> @@ -451,20 +447,22 @@ // CHECK-LABEL: @test_i8 func.func @test_i8(%arg0: tensor<1xi8>) -> () { // CHECK: linalg.generic + // CHECK: ^bb0(%[[ARG1:.+]]: i8, // CHECK-DAG: %[[C127:.+]] = arith.constant -127 // CHECK-DAG: %[[C126:.+]] = arith.constant 126 - // CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %arg1, %[[C127]] + // CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %[[ARG1]], %[[C127]] // CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C127]] - // CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C126]], %arg1 + // CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C126]], %[[ARG1]] // CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C126]], %[[SEL1]] %0 = "tosa.clamp"(%arg0) {min_int = -127 : i64, max_int = 126 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8> // CHECK: linalg.generic + // CHECK: ^bb0(%[[ARG1:.+]]: i8, // CHECK-DAG: %[[C128:.+]] = arith.constant -128 // CHECK-DAG: %[[C127:.+]] = arith.constant 127 - // CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %arg1, %[[C128]] + // CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %[[ARG1]], %[[C128]] // CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C128]] - // CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C127]], %arg1 + // CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C127]], %[[ARG1]] // CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C127]], %[[SEL1]] %1 = "tosa.clamp"(%arg0) {min_int = -130 : i64, max_int = 130 : i64, min_fp = 0.0 : 
f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8> @@ -476,12 +474,11 @@ // CHECK-LABEL: @test_clamp_f16 func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () { // CHECK: linalg.generic + // CHECK: ^bb0(%[[ARG1:.+]]: f16, // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0 // CHECK-DAG: %[[C6:.+]] = arith.constant 6.0 - // CHECK-DAG: %[[CMP1:.+]] = arith.cmpf olt, %arg1, %[[C0]] - // CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C0]] - // CHECK-DAG: %[[CMP2:.+]] = arith.cmpf olt, %[[C6]], %arg1 - // CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C6]], %[[SEL1]] + // CHECK-DAG: %[[MIN:.+]] = arith.minf %[[ARG1]], %[[C6]] + // CHECK-DAG: %[[MAX:.+]] = arith.maxf %[[MIN]], %[[C0]] %0 = "tosa.clamp"(%arg0) {min_int = 0 : i64, max_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 6.0 : f32} : (tensor<1xf16>) -> tensor<1xf16> return @@ -732,15 +729,13 @@ // CHECK: arith.constant 3.40282347E+38 : f32 // CHECK: linalg.fill // CHECK: linalg.generic - // CHECK: arith.cmpf olt - // CHECK: select + // CHECK: arith.minf %3 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32> // CHECK: arith.constant -3.40282347E+38 : f32 // CHECK: linalg.fill // CHECK: linalg.generic - // CHECK: arith.cmpf ogt - // CHECK: select + // CHECK: arith.maxf %4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32> return } @@ -803,9 +798,8 @@ // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CMIN]]{{.*}}outs(%[[INIT]] // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor) outs(%[[FILL]] : tensor) // CHECK: ^bb0(%arg1: f32, %arg2: f32) - // CHECK: %[[CMP:.+]] = arith.cmpf ogt, %arg1, %arg2 : f32 - // CHECK: %[[RES:.+]] = arith.select %[[CMP]], %arg1, %arg2 : f32 - // CHECK: linalg.yield %[[RES]] : f32 + // CHECK: %[[MAX:.+]] = arith.maxf %arg1, %arg2 : f32 + // CHECK: linalg.yield %[[MAX]] : f32 // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1]] : 
tensor into tensor %0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor) -> tensor return