diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py --- a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py @@ -49,9 +49,6 @@ """All structured ops use the same mixin class.""" def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None): - if outputs and results: - raise ValueError( - "Structured ops must have outputs or results, but not both.") super().__init__( self.build_generic(results=list(results), operands=[list(inputs), list(outputs)], diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py @@ -2,4 +2,52 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# These are the backing OpView classes generated from the linalg tablegen +# definitions following these steps: +# DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py. from .._linalg_ops_gen import * + +# These are the ground truth functions defined as: +# ``` +# @linalg_structured_op +# def matmul(A=TensorDef(T1, S.M, S.K), +# B=TensorDef(T2, S.K, S.N), +# C=TensorDef(U, S.M, S.N, output=True)): +# ``` +# using the linalg-py eDSL. +# The linalg-py eDSL builds a python representation (PyRepr) that is +# used in following ways: +# 1. PyRepr -> YAML to generate the C++ and Python .td files. These +# then turn into the core C++ Op classes and Python OpView classes +# respectively (made available in _linalg_ops_gen). The generic OpView class +# mechanism makes the C++ classes available to python through the CAPI. +# PyRepr -> YAML currently occurs before compiler compile time. +# The other steps in this category occur at compiler compile time. +# 2. 
PyRepr -> linalg.core_named_ops calls: piggybacks on the +# _linalg_ops_gen classes and the OpView mechanism to build IR at +# runtime in python: +# a. by default, the Named Op Form is emitted, e.g.: +# `linalg.matmul(lhs, rhs, outs=[out])` creates the following IR: +# ``` +# %1 = linalg.matmul ins(%arg0, %arg1 : tensor<4x16xf32>, tensor<16x8xf32>) +# outs(%0 : tensor<4x8xf32>) +# -> tensor<4x8xf32> +# ``` +# b. by setting emit_generic=True, the Generic Op Form is emitted, e.g.: +# `linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)` creates the following IR: +# ``` +# %1 = linalg.generic {indexing_maps = [...], iterator_types = [...]} +# ins(%arg0, %arg1 : tensor<4x16xf32>, tensor<16x8xf32>) +# outs(%0 : tensor<4x8xf32>) { +# ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): +# ... +# linalg.yield %3 : f32 +# } -> tensor<4x8xf32> +# ``` +# 3. PyRepr -> Runtime Custom Op definitions: directly generates a +# linalg.generic form like in 2.b. +# !!!WARNING!!!: if one creates a runtime custom op with the same name +# as an existing core named op, step 2. will likely take precedence. +# TODO: guard against surprises and fail to create Runtime Custom Ops with +# the same name as existing Core Named Ops. 
+from .opdsl.ops.core_named_ops import * diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -359,16 +359,16 @@ """Metadata about the op (generally not behavior impacting).""" yaml_tag = "!LinalgOpMetadata" - def __init__(self, name: str, cpp_op_name: Optional[str], doc: Optional[str]): + def __init__(self, name: str, cpp_class_name: Optional[str], doc: Optional[str]): self.name = name - self.cpp_op_name = cpp_op_name if cpp_op_name is not None else name + self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name self.doc = doc self.implements = [] # type: List[OpInterfaceDef] def to_yaml_custom_dict(self): d = dict( name=self.name, - cpp_op_name=self.cpp_op_name, + cpp_class_name=self.cpp_class_name, doc=self.doc, ) if self.implements: @@ -381,9 +381,9 @@ def __init__(self, name: str, - cpp_op_name: Optional[str] = None, + cpp_class_name: Optional[str] = None, doc: Optional[str] = None): - self.metadata = OpMetadataDef(name=name, cpp_op_name=cpp_op_name, doc=doc) + self.metadata = OpMetadataDef(name=name, cpp_class_name=cpp_class_name, doc=doc) self.registered_tensors = dict() # type: Dict[str, TensorDef] self.comprehensions = list() # type: List[Comprehension] self._affine_state = AffineBuildState() @@ -413,7 +413,7 @@ def __repr__(self): lines = [ - f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_op_name}," + f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name}," ] for name, tensor in self.registered_tensors.items(): lines.append(f" {tensor}") diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py --- 
a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -44,7 +44,7 @@ self.op_name = op_name self.model = model - def __call__(self, *args, emit_generic: bool = True, **kwargs): + def __call__(self, *args, emit_generic: bool = False, **kwargs): """Emits the corresponding op definition as IR. Most arguments are passed through to the underlying emitter. The following @@ -61,14 +61,21 @@ raise NotImplementedError( f"Emission of composite linalg ops not supported: {op_configs}") + # TODO: this file should probably not be called dsl.py but rather is a client + # of the dsl.py. + from .... import linalg as linalg_ops + emit_generic = (emit_generic or + (not self.model.metadata.cpp_class_name in linalg_ops.__dict__.keys())) + op_config = op_configs[0] if op_config.structured_op: if emit_generic: return emit_generic_structured_op(op_config.structured_op, *args, **kwargs) else: - return emit_named_structured_op(op_config.structured_op, *args, - **kwargs) + return emit_named_structured_op( + op_config.structured_op, self.op_name, + self.model.metadata.cpp_class_name, *args, **kwargs) raise NotImplementedError( f"Emission of linalg op type not supported: {op_config}") @@ -91,7 +98,7 @@ op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op" tc_model = LinalgOpDef(name=op_name, - cpp_op_name=op_class_name, + cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func)) # Extract arguments and TensorDefs from the signature. 
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -17,9 +17,9 @@ ] -def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, - *ins: Value, - outs: Value = ()): +def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, + *ins: Value, + outs: Value): all_arg_defs = op_config.ordered_tensor_args in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"] out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "output"] @@ -49,6 +49,18 @@ [AffineMapAttr.get(am) for am in op_config.indexing_maps]) iterator_types_attr = ArrayAttr.get( [StringAttr.get(s) for s in op_config.iterator_types]) + + return (all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, + type_mapping, indexing_maps_attr, iterator_types_attr) + + +def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, + *ins: Value, + outs: Value = ()): + all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \ + type_mapping, indexing_maps_attr, iterator_types_attr = \ + prepare_common_structured_op(op_config, *ins, outs = outs) + generic_op = linalg.GenericOp( result_tensors=out_types, inputs=ins, @@ -77,10 +89,23 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, + op_name: str, + op_class_name: str, *ins: Value, outs: Value = ()): - raise NotImplementedError( - f"Emission of named structured ops is not supported: {op_config}") + all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \ + type_mapping, indexing_maps_attr, iterator_types_attr = \ + prepare_common_structured_op(op_config, *ins, outs = outs) + + if not op_class_name in linalg.__dict__.keys(): + raise NotImplementedError( + f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}") + + named_op = getattr(linalg, 
op_class_name)(ins, outs, out_types) + if len(out_arg_defs) == 1: + return named_op.result + else: + return named_op.results class _BodyBuilder: diff --git a/mlir/test/Bindings/Python/dialects/linalg/ops.py b/mlir/test/Bindings/Python/dialects/linalg/ops.py --- a/mlir/test/Bindings/Python/dialects/linalg/ops.py +++ b/mlir/test/Bindings/Python/dialects/linalg/ops.py @@ -75,8 +75,33 @@ results=[])) with InsertionPoint(func.add_entry_block()): lhs, rhs, result = func.entry_block.arguments + # TODO: properly hook up the region. linalg.MatmulOp([lhs, rhs], outputs=[result]) std.ReturnOp([]) # CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>) print(module) + +# CHECK-LABEL: TEST: testNamedStructuredOp +@run +def testNamedStructuredOp(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32), + RankedTensorType.get((16, 8), f32)) + def named_form(lhs, rhs): + init_result = linalg.InitTensorOp([4, 8], f32) + # CHECK: linalg.matmul + # TODO: properly hook up the region. + return linalg.matmul(lhs, rhs, outs=[init_result.result]) + + @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32), + RankedTensorType.get((16, 8), f32)) + def generic_form(lhs, rhs): + init_result = linalg.InitTensorOp([4, 8], f32) + # CHECK: linalg.generic + return linalg.matmul(lhs, rhs, outs=[init_result.result], emit_generic=True) + + print(module)