diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -10,6 +10,7 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e +from ._ods_common import get_op_result_or_value as _get_op_result_or_value def isa(cls: Type, ty: Type): try: @@ -26,11 +27,12 @@ results = [] if isa(RankedTensorType, output.type): results = [output.type] - op = self.build_generic(results=results, - operands=[value, output], - attributes=None, - loc=loc, - ip=ip) + op = self.build_generic( + results=results, + operands=[_get_op_result_or_value(o) for o in [value, output]], + attributes=None, + loc=loc, + ip=ip) OpView.__init__(self, op) linalgDialect = Context.current.get_dialect_descriptor("linalg") fill_builtin_region(linalgDialect, self.operation) diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -5,11 +5,14 @@ # Provide a convenient name for sub-packages to resolve the main C-extension # with a relative import. from .._mlir_libs import _mlir as _cext +from typing import Sequence as _Sequence, Union as _Union __all__ = [ "equally_sized_accessor", "extend_opview_class", "get_default_loc_context", + "get_op_result_or_value", + "get_op_results_or_values", "segmented_accessor", ] @@ -118,3 +121,38 @@ # Location.current raises ValueError if there is no current location. return _cext.ir.Location.current.context return location.context + + +def get_op_result_or_value( + arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value] +) -> _cext.ir.Value: + """Returns the given value or the single result of the given op. 
+ + This is useful to implement op constructors so that they can take other ops as + arguments instead of requiring the caller to extract results for every op. + Raises ValueError if provided with an op that doesn't have a single result. + """ + if isinstance(arg, _cext.ir.OpView): + return arg.operation.result + elif isinstance(arg, _cext.ir.Operation): + return arg.result + else: + assert isinstance(arg, _cext.ir.Value) + return arg + + +def get_op_results_or_values( + arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _Sequence[_cext.ir.Value]] +) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]: + """Returns the given sequence of values or the results of the given op. + + This is useful to implement op constructors so that they can take other ops as + lists of arguments instead of requiring the caller to extract results for + every op. + """ + if isinstance(arg, _cext.ir.OpView): + return arg.operation.results + elif isinstance(arg, _cext.ir.Operation): + return arg.results + else: + return arg diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py --- a/mlir/python/mlir/dialects/_scf_ops_ext.py +++ b/mlir/python/mlir/dialects/_scf_ops_ext.py @@ -7,8 +7,8 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e -from typing import Any, Sequence - +from typing import Any, Optional, Sequence, Union +from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values class ForOp: """Specialization for the SCF for op class.""" @@ -17,7 +17,8 @@ lower_bound, upper_bound, step, - iter_args: Sequence[Any] = [], + iter_args: Optional[Union[Operation, OpView, + Sequence[Value]]] = None, *, loc=None, ip=None): @@ -26,14 +27,22 @@ - `lower_bound` is the value to use as lower bound of the loop. - `upper_bound` is the value to use as upper bound of the loop. - `step` is the value to use as loop step. 
- - `iter_args` is a list of additional loop-carried arguments. + - `iter_args` is a list of additional loop-carried arguments or an operation + producing them as results. """ + if iter_args is None: + iter_args = [] + iter_args = _get_op_results_or_values(iter_args) + results = [arg.type for arg in iter_args] super().__init__( self.build_generic( regions=1, results=results, - operands=[lower_bound, upper_bound, step] + list(iter_args), + operands=[ + _get_op_result_or_value(o) + for o in [lower_bound, upper_bound, step] + ] + list(iter_args), loc=loc, ip=ip)) self.regions[0].blocks.append(IndexType.get(), *results) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -2,7 +2,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Dict, List +from typing import Dict, List, Sequence, Union from contextlib import contextmanager import functools @@ -10,12 +10,15 @@ import threading from ..... import ir +from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values from .comprehension import * from .config import * from .emitter import * _CONTEXT = threading.local() +StructuredOpOuts = Union[ir.Operation, ir.OpView, ir.OpResultList, + Sequence[Union[ir.Value, ir.Operation, ir.OpView]]] @contextmanager def bind_op_def(model: LinalgOpDef): @@ -37,6 +40,15 @@ "but none is set. 
Did you mean to call this in an op definition?") +def _prepare_structured_op_outs(outs: StructuredOpOuts) -> ValueList: + if isinstance(outs, (ir.Operation, ir.OpView)): + return _get_op_results_or_values(outs) + elif isinstance(outs, ir.OpResultList): + return outs + + return [_get_op_result_or_value(o) for o in outs] + + class DefinedOpCallable: """Callable that wraps any defined op function.""" @@ -44,7 +56,8 @@ self.op_name = op_name self.model = model - def __call__(self, *ins: ir.Value, outs: Sequence[ir.Value], **kwargs): + def __call__(self, *ins: Union[ir.Operation, ir.OpView, ir.Value], + outs: StructuredOpOuts, **kwargs): """Emits the corresponding op definition as IR. Most arguments are passed through to the underlying emitter. The following @@ -73,17 +86,19 @@ emit_generic or not ctx.is_registered_operation(fully_qualified_name)) op_config = op_configs[0] + out_values = _prepare_structured_op_outs(outs) + in_values = [_get_op_result_or_value(i) for i in ins] if op_config.structured_op: if emit_generic: return emit_generic_structured_op( - op_config.structured_op, *ins, outs=outs, **kwargs) + op_config.structured_op, *in_values, outs=out_values, **kwargs) else: return emit_named_structured_op( op_config.structured_op, self.op_name, self.model.metadata.cpp_class_name, - *ins, - outs=outs, + *in_values, + outs=out_values, **kwargs) raise NotImplementedError( diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -2,7 +2,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Dict, Sequence +from typing import Dict, List, Sequence, Tuple, Union from .....ir import * from ....._mlir_libs._mlir.dialects.linalg import fill_builtin_region @@ -10,6 +10,7 @@ from .... 
import linalg from .... import std from .... import math +from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values from .scalar_expr import * from .config import * @@ -18,8 +19,10 @@ __all__ = [ "emit_generic_structured_op", "emit_named_structured_op", + "ValueList", ] +ValueList = Union[Sequence[Value], OpResultList] def isa(cls: Type, ty: Type): try: @@ -30,17 +33,18 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, - *ins: Value, outs: Sequence[Value], + *ins: Value, outs: ValueList, **attrs: Sequence[int]): all_arg_defs = op_config.ordered_operands in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "InputOperand"] out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "OutputOperand"] attr_arg_defs = [arg for arg in all_arg_defs if arg.usage == "IndexAttribute"] - # Verify outs is a sequence. - if not isinstance(outs, Sequence): - raise ValueError(f"Expected named argument outs to have type Sequence " - f"but got {type(outs)}") + # Verify outs is a sequence or a list of results. + if not isinstance(outs, (Sequence, OpResultList)): + raise ValueError( + f"Expected named argument outs to have type Sequence or OpResultList but got {type(outs)}" + ) # Arity validation.
if len(ins) != len(in_arg_defs): @@ -122,7 +126,7 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, - outs: Sequence[Value], **attrs: Sequence[int]): + outs: ValueList, **attrs: Sequence[int]): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \ prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) @@ -153,8 +157,8 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str, - op_class_name: str, *ins: Value, - outs: Sequence[Value], **attrs: Sequence[int]): + op_class_name: str, *ins: Value, outs: ValueList, + **attrs: Sequence[int]): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \ prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) @@ -355,11 +359,11 @@ return std.MinUIOp(lhs.type, lhs, rhs).result raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}") -def _infer_structured_outs(op_config: LinalgStructuredOpConfig, - in_arg_defs: Sequence[OperandDefConfig], - ins: Sequence[Value], - out_arg_defs: Sequence[OperandDefConfig], - outs: Sequence[Value]): +def _infer_structured_outs( + op_config: LinalgStructuredOpConfig, + in_arg_defs: Sequence[OperandDefConfig], ins: Sequence[Value], + out_arg_defs: Sequence[OperandDefConfig], + outs: Union[Sequence[Value], OpResultList]) -> Tuple[ValueList, List[Type]]: """Infers implicit outs and output types. Respects existing contents of outs if not empty. 
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -24,9 +24,9 @@ // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} - // CHECK: operands.append(variadic1) - // CHECK: operands.append(non_variadic) - // CHECK: if variadic2 is not None: operands.append(variadic2) + // CHECK: operands.append(_get_op_results_or_values(variadic1)) + // CHECK: operands.append(_get_op_result_or_value(non_variadic)) + // CHECK: if variadic2 is not None: operands.append(_get_op_result_or_value(variadic2)) // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, @@ -150,8 +150,8 @@ // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} - // CHECK: operands.append(_gen_arg_0) - // CHECK: operands.append(_gen_arg_2) + // CHECK: operands.append(_get_op_result_or_value(_gen_arg_0)) + // CHECK: operands.append(_get_op_result_or_value(_gen_arg_2)) // CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get( // CHECK: _ods_get_default_loc_context(loc)) // CHECK: if is_ is not None: attributes["is"] = is_ @@ -197,9 +197,9 @@ // CHECK: results.append(i32) // CHECK: results.append(_gen_res_1) // CHECK: results.append(i64) - // CHECK: operands.append(_gen_arg_0) - // CHECK: operands.append(f32) - // CHECK: operands.append(_gen_arg_2) + // CHECK: operands.append(_get_op_result_or_value(_gen_arg_0)) + // CHECK: operands.append(_get_op_result_or_value(f32)) + // CHECK: operands.append(_get_op_result_or_value(_gen_arg_2)) // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, @@ -230,8 +230,8 @@ // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} - // CHECK: operands.append(non_variadic) - // CHECK: 
operands.extend(variadic) + // CHECK: operands.append(_get_op_result_or_value(non_variadic)) + // CHECK: operands.extend(_get_op_results_or_values(variadic)) // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, @@ -285,7 +285,7 @@ // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} - // CHECK: operands.append(in_) + // CHECK: operands.append(_get_op_result_or_value(in_)) // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, @@ -353,8 +353,8 @@ // CHECK: attributes = {} // CHECK: results.append(i64) // CHECK: results.append(f64) - // CHECK: operands.append(i32) - // CHECK: operands.append(f32) + // CHECK: operands.append(_get_op_result_or_value(i32)) + // CHECK: operands.append(_get_op_result_or_value(f32)) // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -185,3 +185,30 @@ return linalg.matmul(lhs, rhs, outs=[init_result.result], emit_generic=True) print(module) + + +# CHECK-LABEL: TEST: testOpResultFromOtherOp +@run +def testOpResultFromOtherOp(): + with Context(), Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + + @builtin.FuncOp.from_py_func( + RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), + f32)) + def pass_an_op_directly(arg0, arg1): + one = std.ConstantOp(F32Type.get(), 1.0) + # CHECK: %[[LHS:.*]] = linalg.fill + lhs = linalg.FillOp(arg0, one) + # CHECK: %[[RHS:.*]] = linalg.fill + rhs = linalg.FillOp(arg1, one) + # CHECK: %[[INIT:.*]] = linalg.init_tensor + init = linalg.InitTensorOp([4, 
8], f32) + # CHECK: linalg.matmul + # CHECK: ins(%[[LHS]], %[[RHS]] + # CHECK: outs(%[[INIT]] + return linalg.matmul(lhs, rhs, outs=init) + + print(module) diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py --- a/mlir/test/python/dialects/scf.py +++ b/mlir/test/python/dialects/scf.py @@ -2,53 +2,82 @@ from mlir.ir import * from mlir.dialects import scf +from mlir.dialects import std from mlir.dialects import builtin -def run(f): +def constructAndPrintInModule(f): print("\nTEST:", f.__name__) - f() + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + f() + print(module) return f # CHECK-LABEL: TEST: testSimpleLoop -@run +@constructAndPrintInModule def testSimpleLoop(): - with Context(), Location.unknown(): - module = Module.create() - index_type = IndexType.get() - with InsertionPoint(module.body): + index_type = IndexType.get() - @builtin.FuncOp.from_py_func(index_type, index_type, index_type) - def simple_loop(lb, ub, step): - loop = scf.ForOp(lb, ub, step, [lb, lb]) - with InsertionPoint(loop.body): - scf.YieldOp(loop.inner_iter_args) - return + @builtin.FuncOp.from_py_func(index_type, index_type, index_type) + def simple_loop(lb, ub, step): + loop = scf.ForOp(lb, ub, step, [lb, lb]) + with InsertionPoint(loop.body): + scf.YieldOp(loop.inner_iter_args) + return - # CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) - # CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]] - # CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]]) - # CHECK: scf.yield %[[I1]], %[[I2]] - print(module) + +# CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) +# CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]] +# CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]]) +# CHECK: scf.yield %[[I1]], %[[I2]] # CHECK-LABEL: TEST: testInductionVar -@run +@constructAndPrintInModule def 
testInductionVar(): - with Context(), Location.unknown(): - module = Module.create() - index_type = IndexType.get() - with InsertionPoint(module.body): + index_type = IndexType.get() + + @builtin.FuncOp.from_py_func(index_type, index_type, index_type) + def induction_var(lb, ub, step): + loop = scf.ForOp(lb, ub, step, [lb]) + with InsertionPoint(loop.body): + scf.YieldOp([loop.induction_variable]) + return + + +# CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) +# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]] +# CHECK: scf.yield %[[IV]] + + +@constructAndPrintInModule +def testOpsAsArguments(): + index_type = IndexType.get() + callee = builtin.FuncOp( + "callee", ([], [index_type, index_type]), visibility="private") + func = builtin.FuncOp("ops_as_arguments", ([], [])) + with InsertionPoint(func.add_entry_block()): + lb = std.ConstantOp.create_index(0) + ub = std.ConstantOp.create_index(42) + step = std.ConstantOp.create_index(2) + iter_args = std.CallOp(callee, []) + loop = scf.ForOp(lb, ub, step, iter_args) + with InsertionPoint(loop.body): + scf.YieldOp(loop.inner_iter_args) + std.ReturnOp([]) + - @builtin.FuncOp.from_py_func(index_type, index_type, index_type) - def induction_var(lb, ub, step): - loop = scf.ForOp(lb, ub, step, [lb]) - with InsertionPoint(loop.body): - scf.YieldOp([loop.induction_variable]) - return - - # CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) - # CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]] - # CHECK: scf.yield %[[IV]] - print(module) +# CHECK-LABEL: TEST: testOpsAsArguments +# CHECK: func private @callee() -> (index, index) +# CHECK: func @ops_as_arguments() { +# CHECK: %[[LB:.*]] = constant 0 +# CHECK: %[[UB:.*]] = constant 42 +# CHECK: %[[STEP:.*]] = constant 2 +# CHECK: %[[ARGS:.*]]:2 = call @callee() +# CHECK: scf.for %arg0 = %c0 to %c42 step %c2 +# CHECK: iter_args(%{{.*}} = %[[ARGS]]#0, %{{.*}} = 
%[[ARGS]]#1) +# CHECK: scf.yield %{{.*}}, %{{.*}} +# CHECK: return diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -28,7 +28,7 @@ # Autogenerated by mlir-tblgen; don't manually edit. from ._ods_common import _cext as _ods_cext -from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context +from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values _ods_ir = _ods_cext.ir try: @@ -489,20 +489,25 @@ )Py"; /// Template for appending a single element to the operand/result list. -/// {0} is either 'operand' or 'result'; -/// {1} is the field name. -constexpr const char *singleElementAppendTemplate = "{0}s.append({1})"; +/// {0} is the field name. +constexpr const char *singleOperandAppendTemplate = + "operands.append(_get_op_result_or_value({0}))"; +constexpr const char *singleResultAppendTemplate = "results.append({0})"; /// Template for appending an optional element to the operand/result list. -/// {0} is either 'operand' or 'result'; -/// {1} is the field name. -constexpr const char *optionalAppendTemplate = - "if {1} is not None: {0}s.append({1})"; - -/// Template for appending a a list of elements to the operand/result list. -/// {0} is either 'operand' or 'result'; -/// {1} is the field name. -constexpr const char *multiElementAppendTemplate = "{0}s.extend({1})"; +/// {0} is the field name. 
+constexpr const char *optionalAppendOperandTemplate = + "if {0} is not None: operands.append(_get_op_result_or_value({0}))"; +constexpr const char *optionalAppendResultTemplate = + "if {0} is not None: results.append({0})"; + +/// Template for appending a list of elements to the operand/result list. +/// {0} is the field name. +constexpr const char *multiOperandAppendTemplate = + "operands.extend(_get_op_results_or_values({0}))"; +constexpr const char *multiOperandAppendPackTemplate = + "operands.append(_get_op_results_or_values({0}))"; +constexpr const char *multiResultAppendTemplate = "results.extend({0})"; /// Template for setting an attribute in the operation builder. /// {0} is the attribute name; @@ -625,43 +630,70 @@ } /// Populates `builderLines` with additional lines that are required in the -/// builder. `kind` must be either "operand" or "result". `names` contains the -/// names of init arguments that correspond to the elements. -static void populateBuilderLines( - const Operator &op, const char *kind, llvm::ArrayRef<std::string> names, - llvm::SmallVectorImpl<std::string> &builderLines, - llvm::function_ref<int(const Operator &)> getNumElements, - llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> - getElement) { - bool sizedSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr; +/// builder to set up op operands. +static void +populateBuilderLinesOperand(const Operator &op, + llvm::ArrayRef<std::string> names, + llvm::SmallVectorImpl<std::string> &builderLines) { + bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr; // For each element, find or generate a name. - for (int i = 0, e = getNumElements(op); i < e; ++i) { - const NamedTypeConstraint &element = getElement(op, i); + for (int i = 0, e = op.getNumOperands(); i < e; ++i) { + const NamedTypeConstraint &element = op.getOperand(i); + std::string name = names[i]; + + // Choose the formatting string based on the element kind.
+ llvm::StringRef formatString; + if (!element.isVariableLength()) { + formatString = singleOperandAppendTemplate; + } else if (element.isOptional()) { + formatString = optionalAppendOperandTemplate; + } else { + assert(element.isVariadic() && "unhandled element group type"); + // If emitting with sizedSegments, then we add the actual list-typed + // element. Otherwise, we extend the actual operands. + if (sizedSegments) { + formatString = multiOperandAppendPackTemplate; + } else { + formatString = multiOperandAppendTemplate; + } + } + + builderLines.push_back(llvm::formatv(formatString.data(), name)); + } +} + +/// Populates `builderLines` with additional lines that are required in the +/// builder to set up op results. +static void +populateBuilderLinesResult(const Operator &op, + llvm::ArrayRef<std::string> names, + llvm::SmallVectorImpl<std::string> &builderLines) { + bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr; + + // For each element, find or generate a name. + for (int i = 0, e = op.getNumResults(); i < e; ++i) { + const NamedTypeConstraint &element = op.getResult(i); std::string name = names[i]; // Choose the formatting string based on the element kind. llvm::StringRef formatString; if (!element.isVariableLength()) { - formatString = singleElementAppendTemplate; + formatString = singleResultAppendTemplate; } else if (element.isOptional()) { - formatString = optionalAppendTemplate; + formatString = optionalAppendResultTemplate; } else { assert(element.isVariadic() && "unhandled element group type"); - // If emitting with sizedSegments, then we add the actual list typed - // element using the singleElementAppendTemplate. Otherwise, we extend - // the actual operands. + // If emitting with sizedSegments, then we add the actual list-typed + // element. Otherwise, we extend the actual results. if (sizedSegments) { - // Append the list as is.
- formatString = singleElementAppendTemplate; + formatString = singleResultAppendTemplate; } else { - // Append the list elements. - formatString = multiElementAppendTemplate; + formatString = multiResultAppendTemplate; } } - // Add the lines. - builderLines.push_back(llvm::formatv(formatString.data(), kind, name)); + builderLines.push_back(llvm::formatv(formatString.data(), name)); } } @@ -680,12 +712,10 @@ op.getNumNativeAttributes() + op.getNumSuccessors()); populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames); - populateBuilderLines( - op, "result", - llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()), - builderLines, getNumResults, getResult); - populateBuilderLines(op, "operand", operandArgNames, builderLines, - getNumOperands, getOperand); + populateBuilderLinesResult( + op, llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()), + builderLines); + populateBuilderLinesOperand(op, operandArgNames, builderLines); populateBuilderLinesAttr( op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()), builderLines);