diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py
--- a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py
@@ -49,9 +49,6 @@
   """All structured ops use the same mixin class."""

   def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
-    if outputs and results:
-      raise ValueError(
-          "Structured ops must have outputs or results, but not both.")
     super().__init__(
         self.build_generic(results=list(results),
                            operands=[list(inputs), list(outputs)],
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py
@@ -2,4 +2,50 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

+# These are the classes exposed by following the path:
+# DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py
 from .._linalg_ops_gen import *
+
+# These are the ground truth functions defined as:
+# ```
+#    @linalg_structured_op
+#    def matmul(A=TensorDef(T1, S.M, S.K),
+#               B=TensorDef(T2, S.K, S.N),
+#               C=TensorDef(U, S.M, S.N, output=True)):
+# ```
+# using the linalg-py eDSL.
+# The linalg-py eDSL builds a python representation (PyRepr) that is
+# used in the following ways:
+#  1. PyRepr -> YAML to generate the C++ and Python .td files. These
+#     then turn into the core C++ dialect ops and _linalg_ops_gen
+#     respectively. _linalg_ops_gen uses the generic OpView mechanism
+#     to make the C++ classes available to python through the CAPI.
+#     All this occurs at compiler compile time.
+#  2. PyRepr -> linalg.core_named_ops calls: piggybacks on the
+#     _linalg_ops_gen classes and the OpView mechanism to build IR at
+#     runtime in python:
+#       a. by default, the Named Op Form is emitted, e.g.:
+#          `linalg.matmul(lhs, rhs, outs=[out])` creates the following IR:
+#          ```
+#            %1 = linalg.matmul ins(%arg0, %arg1 : tensor<4x16xf32>, tensor<16x8xf32>)
+#                               outs(%0 : tensor<4x8xf32>)
+#                 -> tensor<4x8xf32>
+#          ```
+#       b. by setting emit_generic=True, the Generic Op Form is emitted, e.g.:
+#          `linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)` creates the following IR:
+#          ```
+#            %1 = linalg.generic {indexing_maps = [...], iterator_types = [...]}
+#                 ins(%arg0, %arg1 : tensor<4x16xf32>, tensor<16x8xf32>)
+#                 outs(%0 : tensor<4x8xf32>) {
+#              ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
+#                ...
+#                linalg.yield %3 : f32
+#            } -> tensor<4x8xf32>
+#          ```
+#  3. PyRepr -> Runtime Custom Op definitions: directly generates a
+#     linalg.generic form like in 2.b.
+#     !!!WARNING!!!: if one creates a runtime custom op with the same name
+#     as an existing core named op, step 2. will likely take precedence.
+#     TODO: guard against surprises and fail to create Runtime Custom Ops
+#     with the same name as existing Core Named Ops.
+from .opdsl.ops.core_named_ops import *
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -40,11 +40,12 @@
 class DefinedOpCallable:
   """Callable that wraps any defined op function."""

-  def __init__(self, op_name: str, model: LinalgOpDef):
+  def __init__(self, op_name: str, op_class_name: str, model: LinalgOpDef):
     self.op_name = op_name
+    self.op_class_name = op_class_name
     self.model = model

-  def __call__(self, *args, emit_generic: bool = True, **kwargs):
+  def __call__(self, *args, emit_generic: bool = False, **kwargs):
     """Emits the corresponding op definition as IR.

     Most arguments are passed through to the underlying emitter. The following
@@ -61,14 +62,23 @@
       raise NotImplementedError(
           f"Emission of composite linalg ops not supported: {op_configs}")

+    # If the op_class_name is not registered, then we must emit_generic.
+    # TODO: this is chicken-and-eggy and the import is ugly, but a better
+    # way has not been found yet.
+    # TODO: this file should probably not be called dsl.py; it is rather a
+    # client of the dsl.
+    import mlir
+    emit_generic = emit_generic or \
+        (self.op_class_name not in mlir.dialects.linalg.__dict__)
+
     op_config = op_configs[0]
     if op_config.structured_op:
       if emit_generic:
         return emit_generic_structured_op(op_config.structured_op, *args,
                                           **kwargs)
       else:
-        return emit_named_structured_op(op_config.structured_op, *args,
-                                        **kwargs)
+        return emit_named_structured_op(op_config.structured_op, self.op_name,
+                                        self.op_class_name, *args, **kwargs)

     raise NotImplementedError(
         f"Emission of linalg op type not supported: {op_config}")
@@ -111,7 +121,7 @@

   # TODO: The returned callable should be an IR emitter but that is not
   # upstreamed yet.
-  return DefinedOpCallable(op_name, tc_model)
+  return DefinedOpCallable(op_name, op_class_name, tc_model)


 def implements(*interfaces: OpInterfaceDef):
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -17,9 +17,9 @@
 ]


-def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
-                               *ins: Value,
-                               outs: Value = ()):
+def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
+                                 *ins: Value,
+                                 outs: Value):
   all_arg_defs = op_config.ordered_tensor_args
   in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"]
   out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "output"]
@@ -49,6 +49,18 @@
       [AffineMapAttr.get(am) for am in op_config.indexing_maps])
   iterator_types_attr = ArrayAttr.get(
       [StringAttr.get(s) for s in op_config.iterator_types])
+
+  return (all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types,
+          type_mapping, indexing_maps_attr, iterator_types_attr)
+
+
+def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
+                               *ins: Value,
+                               outs: Value = ()):
+  all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \
+      type_mapping, indexing_maps_attr, iterator_types_attr = \
+      prepare_common_structured_op(op_config, *ins, outs=outs)
+
   generic_op = linalg.GenericOp(
       result_tensors=out_types,
       inputs=ins,
@@ -77,10 +89,23 @@


 def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
+                             op_name: str,
+                             op_class_name: str,
                              *ins: Value,
                              outs: Value = ()):
-  raise NotImplementedError(
-      f"Emission of named structured ops is not supported: {op_config}")
+  all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \
+      type_mapping, indexing_maps_attr, iterator_types_attr = \
+      prepare_common_structured_op(op_config, *ins, outs=outs)
+
+  if op_class_name not in linalg.__dict__:
+    raise NotImplementedError(
+        f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}")
+
+  named_op = getattr(linalg, op_class_name)(ins, outs, out_types)
+  if len(out_arg_defs) == 1:
+    return named_op.result
+  else:
+    return named_op.results


 class _BodyBuilder:
diff --git a/mlir/test/Bindings/Python/dialects/linalg/ops.py b/mlir/test/Bindings/Python/dialects/linalg/ops.py
--- a/mlir/test/Bindings/Python/dialects/linalg/ops.py
+++ b/mlir/test/Bindings/Python/dialects/linalg/ops.py
@@ -80,3 +80,26 @@

   # CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>)
   print(module)
+
+# CHECK-LABEL: TEST: testNamedStructuredOp
+@run
+def testNamedStructuredOp():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    with InsertionPoint(module.body):
+      @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
+                                   RankedTensorType.get((16, 8), f32))
+      def named_form(lhs, rhs):
+        init_result = linalg.InitTensorOp([4, 8], f32)
+        # CHECK: linalg.matmul
+        return linalg.matmul(lhs, rhs, outs=[init_result.result])
+
+      @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
+                                   RankedTensorType.get((16, 8), f32))
+      def generic_form(lhs, rhs):
+        init_result = linalg.InitTensorOp([4, 8], f32)
+        # CHECK: linalg.generic
+        return linalg.matmul(lhs, rhs, outs=[init_result.result], emit_generic=True)
+
+    print(module)
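
As context for the `__init__.py` comment above: the quoted `matmul` definition elides the op body. A complete definition in the linalg-py eDSL looks roughly like the sketch below; the `implements(ContractionOpInterface)` line and the `cast`-based body follow opdsl conventions and are assumptions, not part of this diff.

```python
# A sketch of the full "ground truth" matmul definition assumed above.
# The exact body in core_named_ops.py may differ.
from mlir.dialects.linalg.opdsl.lang import *

@linalg_structured_op
def matmul(A=TensorDef(T1, S.M, S.K),
           B=TensorDef(T2, S.K, S.N),
           C=TensorDef(U, S.M, S.N, output=True)):
  """Performs a matrix multiplication of A and B into C."""
  implements(ContractionOpInterface)
  # Accumulate over the reduction dimension D.k, casting both operands
  # to the output element type U before multiplying.
  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
```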
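
The `emit_generic` fallback added to `DefinedOpCallable.__call__` is also what makes path 3 (runtime custom ops) work: an op whose class name is absent from `mlir.dialects.linalg.__dict__` is emitted in generic form even though `emit_generic` now defaults to `False`. A minimal sketch, assuming a hypothetical `row_sum` op with no core named class (the op name, shapes, and body are illustrative only):

```python
from mlir.ir import *
from mlir.dialects import builtin, linalg
from mlir.dialects.linalg.opdsl.lang import *

# Hypothetical runtime custom op: no corresponding named-op class is
# generated into _linalg_ops_gen, so DefinedOpCallable must fall back
# to emitting a linalg.generic.
@linalg_structured_op
def row_sum(A=TensorDef(T1, S.M, S.N),
            O=TensorDef(U, S.M, output=True)):
  O[D.m] += cast(U, A[D.m, D.n])

with Context(), Location.unknown():
  module = Module.create()
  f32 = F32Type.get()
  with InsertionPoint(module.body):

    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
                                 RankedTensorType.get((4,), f32))
    def custom_form(lhs, init):
      # Emits linalg.generic despite emit_generic defaulting to False,
      # because row_sum's class name is not registered on the linalg
      # python module.
      return row_sum(lhs, outs=[init])

  print(module)
```

The key point is the membership test against `linalg.__dict__`: when no named class exists, it, not the caller, decides between named and generic emission.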