Index: mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp =================================================================== --- mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -335,9 +335,7 @@ // tosa::MaximumOp if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create( - loc, arith::CmpFPredicate::OGT, args[0], args[1]); - return rewriter.create(loc, predicate, args[0], args[1]); + return rewriter.create(loc, args[0], args[1]); } if (isa(op) && elementTy.isSignlessInteger()) { @@ -348,9 +346,7 @@ // tosa::MinimumOp if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create( - loc, arith::CmpFPredicate::OLT, args[0], args[1]); - return rewriter.create(loc, predicate, args[0], args[1]); + return rewriter.create(loc, args[0], args[1]); } if (isa(op) && elementTy.isSignlessInteger()) { @@ -751,9 +747,7 @@ } if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create( - loc, arith::CmpFPredicate::OLT, args[0], args[1]); - return rewriter.create(loc, predicate, args[0], args[1]); + return rewriter.create(loc, args[0], args[1]); } if (isa(op) && elementTy.isa()) { @@ -763,9 +757,7 @@ } if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create( - loc, arith::CmpFPredicate::OGT, args[0], args[1]); - return rewriter.create(loc, predicate, args[0], args[1]); + return rewriter.create(loc, args[0], args[1]); } if (isa(op) && elementTy.isa()) { Index: mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir =================================================================== --- mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -198,13 +198,11 @@ %13 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic - // CHECK: arith.cmpf - // CHECK: select + // CHECK: arith.maxf %14 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic - // CHECK: arith.cmpf - // CHECK: select + // CHECK: arith.minf %15 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic @@ -732,15 +730,13 @@ // CHECK: arith.constant 3.40282347E+38 : f32 // CHECK: linalg.fill // CHECK: linalg.generic - // CHECK: arith.cmpf olt - // CHECK: select + // CHECK: arith.minf %3 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32> // CHECK: arith.constant -3.40282347E+38 : f32 // CHECK: linalg.fill // CHECK: linalg.generic - // CHECK: arith.cmpf ogt - // CHECK: select + // CHECK: arith.maxf %4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32> return } @@ -803,9 +799,8 @@ // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CMIN]]{{.*}}outs(%[[INIT]] // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor) outs(%[[FILL]] : tensor) // CHECK: ^bb0(%arg1: f32, %arg2: f32) - // CHECK: %[[CMP:.+]] = arith.cmpf ogt, %arg1, %arg2 : f32 - // CHECK: %[[RES:.+]] = arith.select %[[CMP]], %arg1, %arg2 : f32 - // CHECK: linalg.yield %[[RES]] : f32 + // CHECK: %[[MAX:.+]] = arith.maxf %arg1, %arg2 : f32 + // CHECK: linalg.yield %[[MAX]] : f32 // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor into tensor %0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor) -> tensor return