diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -276,18 +276,20 @@ } Value applyfn__max(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) - return emitCmpFAndSelect(lhs, rhs, CmpFPredicate::OGT); + return builder.create<MaxFOp>(lhs.getLoc(), lhs, rhs); if (isInteger(lhs)) - return emitCmpIAndSelect(lhs, rhs, CmpIPredicate::sgt); + return builder.create<MaxSIOp>(lhs.getLoc(), lhs, rhs); llvm_unreachable("unsupported non numeric type"); } Value applyfn__min(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) - return emitCmpFAndSelect(lhs, rhs, CmpFPredicate::OLT); + return builder.create<MinFOp>(lhs.getLoc(), lhs, rhs); if (isInteger(lhs)) - return emitCmpIAndSelect(lhs, rhs, CmpIPredicate::slt); + return builder.create<MinSIOp>(lhs.getLoc(), lhs, rhs); llvm_unreachable("unsupported non numeric type"); } @@ -324,17 +326,6 @@ MLIRContext *context; Block &block; - Value emitCmpFAndSelect(Value lhs, Value rhs, CmpFPredicate predicate) { - OpBuilder builder = getBuilder(); - Value condition = builder.create<CmpFOp>(lhs.getLoc(), predicate, lhs, rhs); - return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs); - } - Value emitCmpIAndSelect(Value lhs, Value rhs, CmpIPredicate predicate) { - OpBuilder builder = getBuilder(); - Value condition = builder.create<CmpIOp>(lhs.getLoc(), predicate, lhs, rhs); - return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs); - } - bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); } bool isInteger(Value value) { return value.getType().isa<IntegerType>(); } diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -319,20 +319,16 @@ def _eval_max(self, lhs: Value, rhs: Value) -> Value: if 
_is_floating_point_type(lhs.type): - ogt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2) - return _emit_cmpf_and_select(lhs, rhs, ogt_attr) + return std.MaxFOp(lhs.type, lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - sgt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4) - return _emit_cmpi_and_select(lhs, rhs, sgt_attr) + return std.MaxSIOp(lhs.type, lhs, rhs).result raise NotImplementedError("Unsupported 'max' operand: {lhs}") def _eval_min(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - olt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4) - return _emit_cmpf_and_select(lhs, rhs, olt_attr) + return std.MinFOp(lhs.type, lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - slt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2) - return _emit_cmpi_and_select(lhs, rhs, slt_attr) + return std.MinSIOp(lhs.type, lhs, rhs).result raise NotImplementedError("Unsupported 'min' operand: {lhs}") @@ -413,13 +409,3 @@ if BF16Type.isinstance(t): return 16 raise NotImplementedError(f"Unhandled floating point type switch {t}") - - -def _emit_cmpf_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value: - cond = std.CmpFOp(IntegerType.get_signless(1), pred, lhs, rhs).result - return std.SelectOp(lhs.type, cond, lhs, rhs).result - - -def _emit_cmpi_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value: - cond = std.CmpIOp(IntegerType.get_signless(1), pred, lhs, rhs).result - return std.SelectOp(lhs.type, cond, lhs, rhs).result diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py @@ -242,8 +242,7 @@ # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"] # CHECK: 
^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32) # CHECK-NEXT: %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32 - # CHECK-NEXT: %[[COND:.+]] = cmpi sgt, %[[OUT]], %[[IN_CAST:.+]] : i32 - # CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT]], %[[IN_CAST:.+]] : i32 + # CHECK-NEXT: %[[MAX:.+]] = maxsi %[[OUT]], %[[IN_CAST:.+]] : i32 # CHECK-NEXT: linalg.yield %[[MAX]] : i32 # CHECK-NEXT: -> tensor<2x4xi32> @builtin.FuncOp.from_py_func( @@ -258,8 +257,7 @@ # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]] # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"] # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: f32) - # CHECK-NEXT: %[[COND:.+]] = cmpf ogt, %[[OUT]], %[[IN:.+]] : f32 - # CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT]], %[[IN:.+]] : f32 + # CHECK-NEXT: %[[MAX:.+]] = maxf %[[OUT]], %[[IN:.+]] : f32 # CHECK-NEXT: linalg.yield %[[MAX]] : f32 # CHECK-NEXT: -> tensor<2x4xf32> @builtin.FuncOp.from_py_func( @@ -270,7 +268,7 @@ input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]) # CHECK-LABEL: @test_f32i32_min_pooling - # CHECK: = cmpi slt, + # CHECK: = minsi @builtin.FuncOp.from_py_func( RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32), RankedTensorType.get((2, 4), i32)) @@ -279,7 +277,7 @@ input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]) # CHECK-LABEL: @test_f32f32_min_pooling - # CHECK: = cmpf olt, + # CHECK: = minf @builtin.FuncOp.from_py_func( RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32), RankedTensorType.get((2, 4), f32)) diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py --- a/mlir/test/python/integration/dialects/linalg/opsrun.py +++ b/mlir/test/python/integration/dialects/linalg/opsrun.py @@ -118,7 +118,7 @@ def transform(module, boilerplate): import mlir.conversions - import 
mlir.dialects.linalg.passes + import mlir.all_passes_registration import mlir.transforms # TODO: Allow cloning functions from one module to another. @@ -128,8 +128,8 @@ boilerplate) pm = PassManager.parse( "builtin.func(convert-linalg-to-loops, lower-affine, " + - "convert-scf-to-std), convert-vector-to-llvm," + - "convert-memref-to-llvm,convert-std-to-llvm," + + "convert-scf-to-std, std-expand), convert-vector-to-llvm," + + "convert-memref-to-llvm, convert-std-to-llvm," + "reconcile-unrealized-casts") pm.run(mod) return mod