diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -664,6 +664,77 @@
             - !ScalarExpression
               scalar_arg: I
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: pooling_nhwc_min_poly
+  cpp_class_name: PoolingNhwcMinPolyOp
+  doc: |-
+    Performs min pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: K
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s4, s5)>
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s6, s7, s3)>
+  - !LinalgOperandDefConfig
+    name: strides
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s8, s9)>
+  - !LinalgOperandDefConfig
+    name: dilations
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s10, s11)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d0, d1 * s8 + d3 * s10, d2 * s9 + d4 * s11, d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d3, d4)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d0, d1, d2, d5)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - parallel
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: min
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          symbolic_cast:
+            type_var: U
+            operands:
+            - !ScalarExpression
+              scalar_arg: I
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: fill_rng_2d
   cpp_class_name: FillRng2DOp
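Before moving to the C++ and Python changes, it may help to spell out what the YAML above encodes: an accumulating min reduction over a strided, dilated NHWC window, where the input is read at (n, oh * SH + kh * DH, ow * SW + kw * DW, c), cast to the output element type, and folded into O. A minimal NumPy sketch of the same computation; the helper name and explicit shape arguments are illustrative and not part of the patch:

import numpy as np

def pooling_nhwc_min_reference(I, O, strides, dilations, kernel_shape):
  # Accumulates a strided, dilated NHWC min pooling into O, mirroring the
  # structured op's update O[n, oh, ow, c] = min(O[n, oh, ow, c], cast(I[...])).
  SH, SW = strides
  DH, DW = dilations
  KH, KW = kernel_shape
  N, OH, OW, C = O.shape
  for n in range(N):
    for oh in range(OH):
      for ow in range(OW):
        for kh in range(KH):
          for kw in range(KW):
            for c in range(C):
              # Numeric cast to the accumulator dtype, then fold into the min.
              val = O.dtype.type(I[n, oh * SH + kh * DH, ow * SW + kw * DW, c])
              O[n, oh, ow, c] = min(O[n, oh, ow, c], val)
  return O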
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -275,17 +275,18 @@
   }
 
   Value applyfn__max(Value lhs, Value rhs) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(lhs)) {
-      Value condition =
-          builder.create<CmpFOp>(lhs.getLoc(), CmpFPredicate::OGT, lhs, rhs);
-      return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
-    }
-    if (isInteger(lhs)) {
-      Value condition =
-          builder.create<CmpIOp>(lhs.getLoc(), CmpIPredicate::sgt, lhs, rhs);
-      return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
-    }
+    if (isFloatingPoint(lhs))
+      return emitCmpFAndSelect(lhs, rhs, CmpFPredicate::OGT);
+    if (isInteger(lhs))
+      return emitCmpIAndSelect(lhs, rhs, CmpIPredicate::sgt);
+    llvm_unreachable("unsupported non numeric type");
+  }
+
+  Value applyfn__min(Value lhs, Value rhs) {
+    if (isFloatingPoint(lhs))
+      return emitCmpFAndSelect(lhs, rhs, CmpFPredicate::OLT);
+    if (isInteger(lhs))
+      return emitCmpIAndSelect(lhs, rhs, CmpIPredicate::slt);
     llvm_unreachable("unsupported non numeric type");
   }
 
@@ -322,6 +323,17 @@
   MLIRContext *context;
   Block &block;
 
+  Value emitCmpFAndSelect(Value lhs, Value rhs, CmpFPredicate predicate) {
+    OpBuilder builder = getBuilder();
+    Value condition = builder.create<CmpFOp>(lhs.getLoc(), predicate, lhs, rhs);
+    return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
+  }
+  Value emitCmpIAndSelect(Value lhs, Value rhs, CmpIPredicate predicate) {
+    OpBuilder builder = getBuilder();
+    Value condition = builder.create<CmpIOp>(lhs.getLoc(), predicate, lhs, rhs);
+    return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
+  }
+
   bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
   bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -339,6 +339,7 @@
   log = PrimFnType("log")
   mul = PrimFnType("mul")
   max = PrimFnType("max")
+  min = PrimFnType("min")
   sub = PrimFnType("sub")
 
 
@@ -364,6 +365,7 @@
   add = PrimFn.add.reduce
   mul = PrimFn.mul.reduce
   max = PrimFn.max.reduce
+  min = PrimFn.min.reduce
 
 
 class PrimApply(TensorExpression):
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -308,17 +308,23 @@
     raise NotImplementedError("Unsupported 'mul' operand: {lhs}")
 
   def _eval_max(self, lhs: Value, rhs: Value) -> Value:
-    i1 = IntegerType.get_signless(1)
     if _is_floating_point_type(lhs.type):
       ogt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
-      cond = std.CmpFOp(i1, ogt_attr, lhs, rhs).result
-      return std.SelectOp(lhs.type, cond, lhs, rhs).result
+      return _emit_cmpf_and_select(lhs, rhs, ogt_attr)
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
       sgt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
-      cond = std.CmpIOp(i1, sgt_attr, lhs, rhs).result
-      return std.SelectOp(lhs.type, cond, lhs, rhs).result
+      return _emit_cmpi_and_select(lhs, rhs, sgt_attr)
     raise NotImplementedError("Unsupported 'max' operand: {lhs}")
 
+  def _eval_min(self, lhs: Value, rhs: Value) -> Value:
+    if _is_floating_point_type(lhs.type):
+      olt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
+      return _emit_cmpf_and_select(lhs, rhs, olt_attr)
+    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+      slt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
+      return _emit_cmpi_and_select(lhs, rhs, slt_attr)
+    raise NotImplementedError("Unsupported 'min' operand: {lhs}")
+
 
 def _infer_structured_outs(op_config: LinalgStructuredOpConfig,
                            in_arg_defs: Sequence[OperandDefConfig],
@@ -397,3 +403,13 @@
   if BF16Type.isinstance(t):
     return 16
   raise NotImplementedError(f"Unhandled floating point type switch {t}")
+
+
+def _emit_cmpf_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
+  cond = std.CmpFOp(IntegerType.get_signless(1), pred, lhs, rhs).result
+  return std.SelectOp(lhs.type, cond, lhs, rhs).result
+
+
+def _emit_cmpi_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
+  cond = std.CmpIOp(IntegerType.get_signless(1), pred, lhs, rhs).result
+  return std.SelectOp(lhs.type, cond, lhs, rhs).result
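Both reductions lower to a compare-and-select pair, and the bare integer attributes above are the predicate encodings expected by std.cmpf and std.cmpi. As a reading aid, the four values used in this patch are summarized below (a hedged summary derived from the hunks above, not an exhaustive enum listing; the dictionary names are illustrative):

# Predicate encodings as used by _eval_max/_eval_min above.
CMPF_PREDICATES = {"ogt": 2, "olt": 4}   # std.cmpf: ordered greater-than / less-than
CMPI_PREDICATES = {"sgt": 4, "slt": 2}   # std.cmpi: signed greater-than / less-than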
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -166,6 +166,24 @@
           D.c]))
 
 
+@linalg_structured_op
+def pooling_nhwc_min_poly(
+    I=TensorDef(T1, S.N, S.H, S.W, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  """Performs min pooling.
+
+  Numeric casting is performed on the input operand, promoting it to the same
+  data type as the accumulator/output.
+  """
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
+      cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+                D.c]))
+
+
 @linalg_structured_op
 def fill_rng_2d(
     min=ScalarDef(F64),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -90,6 +90,36 @@
 // -----
 
+func @generalize_pooling_nhwc_min_poly_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
+  %0 = linalg.pooling_nhwc_min_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
+    ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
+  return %0: tensor<1x2x4x1xf32>
+}
+
+// CHECK-LABEL: @generalize_pooling_nhwc_min_poly_f32
+// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
+// CHECK-NEXT: %[[COND:.+]] = cmpf olt, %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT: %[[MIN:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT: linalg.yield %[[MIN]] : f32
+// CHECK-NEXT: -> tensor<1x2x4x1xf32>
+
+// -----
+
+func @generalize_pooling_nhwc_min_poly_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
+  %0 = linalg.pooling_nhwc_min_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
+    ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
+  return %0: tensor<1x2x4x1xi32>
+}
+
+// CHECK-LABEL: @generalize_pooling_nhwc_min_poly_i32
+// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
+// CHECK-NEXT: %[[COND:.+]] = cmpi slt, %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT: %[[MIN:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT: linalg.yield %[[MIN]] : i32
+// CHECK-NEXT: -> tensor<1x2x4x1xi32>
+
+// -----
+
 func @generalize_pooling_nhwc_sum_poly_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
   %0 = linalg.pooling_nhwc_sum_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
     ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
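As the generalized bodies make explicit, the reduction folds the existing output value into the result (out = min(out, in)), so the output tensor doubles as the initial accumulator. A caller that wants a plain window minimum should initialize it with the identity of min; the integration test further below gets away with a zero fill only because the interesting stored value, -13.0, is negative. A minimal sketch of such an initialization (shapes and names are illustrative):

import numpy as np

# Identity of min: the largest representable value of the accumulator type.
init_f32 = np.full((1, 2, 4, 1), np.finfo(np.float32).max, dtype=np.float32)
init_i32 = np.full((1, 2, 4, 1), np.iinfo(np.int32).max, dtype=np.int32)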
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -43,7 +43,7 @@
 
 
 @linalg_structured_op
-def pooling_poly(
+def pooling_max_poly(
     I=TensorDef(T1, S.N, S.H, S.W, S.C),
     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
@@ -55,6 +55,19 @@
           D.c]))
 
 
+@linalg_structured_op
+def pooling_min_poly(
+    I=TensorDef(T1, S.N, S.H, S.W, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
+      cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+                D.c]))
+
+
 @linalg_structured_op
 def fill_rng_poly(
     min=ScalarDef(F64),
@@ -216,7 +229,7 @@
     return conv_poly(
         input, filter, outs=[init_result], strides=[2, 4], dilations=[1, 2])
 
-  # CHECK-LABEL: @test_f32i32_pooling
+  # CHECK-LABEL: @test_f32i32_max_pooling
   # CHECK: linalg.generic
   # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
   # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
@@ -229,11 +242,11 @@
   @builtin.FuncOp.from_py_func(
       RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
      RankedTensorType.get((2, 4), i32))
-  def test_f32i32_pooling(input, shape, init_result):
-    return pooling_poly(
+  def test_f32i32_max_pooling(input, shape, init_result):
+    return pooling_max_poly(
         input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
 
-  # CHECK-LABEL: @test_f32f32_pooling
+  # CHECK-LABEL: @test_f32f32_max_pooling
   # CHECK: linalg.generic
   # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
   # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
@@ -245,8 +258,26 @@
   @builtin.FuncOp.from_py_func(
       RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
       RankedTensorType.get((2, 4), f32))
-  def test_f32f32_pooling(input, shape, init_result):
-    return pooling_poly(
+  def test_f32f32_max_pooling(input, shape, init_result):
+    return pooling_max_poly(
+        input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+
+  # CHECK-LABEL: @test_f32i32_min_pooling
+  # CHECK: = cmpi slt,
+  @builtin.FuncOp.from_py_func(
+      RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
+      RankedTensorType.get((2, 4), i32))
+  def test_f32i32_min_pooling(input, shape, init_result):
+    return pooling_min_poly(
+        input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+
+  # CHECK-LABEL: @test_f32f32_min_pooling
+  # CHECK: = cmpf olt,
+  @builtin.FuncOp.from_py_func(
+      RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
+      RankedTensorType.get((2, 4), f32))
+  def test_f32f32_min_pooling(input, shape, init_result):
+    return pooling_min_poly(
         input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
 
   # CHECK-LABEL: @test_i32_fill_rng
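One detail worth noting in the two new CHECK lines: the comparison is chosen from the accumulator (output) element type, not from the input type. The f32 input is first cast to the i32 accumulator, so that test reduces with cmpi slt, while the all-f32 test reduces with cmpf olt. In NumPy terms (the concrete values are illustrative):

import numpy as np

x = np.float32(42.7)          # value read from the f32 input
acc = np.int32(0)             # current value of the i32 output
# Cast first (truncation toward zero, as in fptosi), then signed integer min.
acc = min(acc, np.int32(x))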
diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -86,6 +86,8 @@
 func @main() -> i32 attributes {llvm.emit_c_interface} {
   %v0 = constant 0 : i32
   %v42 = constant 42.0 : f64
+  %v77 = constant 77.0 : f64
+  %v-13 = constant -13.0 : f64
   %v1 = constant 1.0 : f64
 
   %input = memref.alloc() : memref<1x4x16x1xf64>
@@ -96,7 +98,11 @@
   linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>
 
   %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
   memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64>
+  memref.store %v77, %input[%c0, %c0, %c1, %c0] : memref<1x4x16x1xf64>
+  memref.store %v-13, %input[%c0, %c0, %c2, %c0] : memref<1x4x16x1xf64>
 
   call @pooling_on_buffers(%input, %shape, %output)
       : (memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()
@@ -301,7 +307,7 @@
   test_conv_generic()
 
 
-def test_pooling_builtin():
+def test_max_pooling_builtin():
   with Context() as ctx, Location.unknown():
     module = Module.create()
     f64 = F64Type.get()
@@ -325,13 +331,14 @@
     execution_engine.invoke("main", res)
 
     log("RESULT: ", res[0])
+    # 77 is not selected due to the dilation 2 in the second dimension.
     # CHECK: RESULT: 42
 
 
-test_pooling_builtin()
+test_max_pooling_builtin()
 
 
-def test_pooling_generic():
+def test_max_pooling_generic():
   with Context() as ctx, Location.unknown():
     module = Module.create()
     f64 = F64Type.get()
@@ -360,7 +367,73 @@
     execution_engine.invoke("main", res)
 
     log("RESULT: ", res[0])
+    # 77 is not selected due to the dilation 2 in the second dimension.
     # CHECK: RESULT: 42
 
 
-test_pooling_generic()
+test_max_pooling_generic()
+
+
+def test_min_pooling_builtin():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f64 = F64Type.get()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
+          MemRefType.get((1, 2, 4, 1), i32))
+      def pooling_on_buffers(input, shape, output):
+        linalg.pooling_nhwc_min_poly(
+            input, shape, outs=[output], strides=[2, 4], dilations=[1, 2])
+
+    execution_engine = ExecutionEngine(transform(module, pooling_boiler))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result i32.
+    # Arguments must be passed as pointers.
+    c_int_p = ctypes.c_int * 1
+    res = c_int_p(-1)
+    execution_engine.invoke("main", res)
+
+    log("RESULT: ", res[0])
+    # CHECK: RESULT: -13
+
+
+test_min_pooling_builtin()
+
+
+def test_min_pooling_generic():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f64 = F64Type.get()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
+          MemRefType.get((1, 2, 4, 1), i32))
+      def pooling_on_buffers(input, shape, output):
+        linalg.pooling_nhwc_min_poly(
+            input,
+            shape,
+            outs=[output],
+            strides=[2, 4],
+            dilations=[1, 2],
+            emit_generic=True)
+
+    execution_engine = ExecutionEngine(transform(module, pooling_boiler))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result i32.
+    # Arguments must be passed as pointers.
+    c_int_p = ctypes.c_int * 1
+    res = c_int_p(-1)
+    execution_engine.invoke("main", res)
+
+    log("RESULT: ", res[0])
+    # CHECK: RESULT: -13
+
+
+test_min_pooling_generic()
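A quick sanity check on the expected results, purely illustrative and not part of the test: with strides [2, 4] and dilations [1, 2], output element [0, 0, 0, 0] reads input rows {0, 1} and columns {0, 2}. It therefore sees the stored 42.0 and -13.0 but never the 77.0 at column 1, while the remaining window elements come from the boiler's fill of the input (not shown in these hunks). Min pooling thus prints -13, and max pooling keeps printing 42:

SH, SW = 2, 4   # strides
DH, DW = 1, 2   # dilations
# Input coordinates visited by output window (oh, ow) = (0, 0).
window = [(0 * SH + kh * DH, 0 * SW + kw * DW) for kh in range(2) for kw in range(2)]
print(window)   # [(0, 0), (0, 2), (1, 0), (1, 2)]: column 1 (the 77.0) is never read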