diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -106,7 +106,7 @@ SameOperandsAndResultType, ElementwiseMappable])> { - let results = (outs AnyType); + let results = (outs AnyType:$result); let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); diff --git a/mlir/test/Bindings/Python/dialects.py b/mlir/test/Bindings/Python/dialects.py --- a/mlir/test/Bindings/Python/dialects.py +++ b/mlir/test/Bindings/Python/dialects.py @@ -63,7 +63,7 @@ run(testUserDialectClass) -# XHECK-LABEL: TEST: testCustomOpView +# CHECK-LABEL: TEST: testCustomOpView # This test uses the standard dialect AddFOp as an example of a user op. # TODO: Op creation and access is still quite verbose: simplify this test as # additional capabilities come online. @@ -82,17 +82,17 @@ # Create via dialects context collection. input1 = createInput() input2 = createInput() - op1 = ctx.dialects.std.AddFOp(input1, input2) + op1 = ctx.dialects.std.AddFOp(input1.type, input1, input2) # Create via an import from mlir.dialects.std import AddFOp - AddFOp(input1, op1.result) + AddFOp(input1.type, input1, op1.result) - # XHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput" - # XHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput" - # XHECK: %[[R0:.*]] = addf %[[INPUT0]], %[[INPUT1]] : f32 - # XHECK: %[[R1:.*]] = addf %[[INPUT0]], %[[R0]] : f32 + # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput" + # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput" + # CHECK: %[[R0:.*]] = addf %[[INPUT0]], %[[INPUT1]] : f32 + # CHECK: %[[R1:.*]] = addf %[[INPUT0]], %[[R0]] : f32 m.operation.print() -# TODO: re-enable when constructs are generated again -# run(testCustomOpView) + +run(testCustomOpView) diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td --- a/mlir/test/mlir-tblgen/op-python-bindings.td 
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -18,6 +18,23 @@ // CHECK-LABEL: OPERATION_NAME = "test.attr_sized_operands" def AttrSizedOperandsOp : TestOp<"attr_sized_operands", [AttrSizedOperandSegments]> { + // CHECK: def __init__(self, variadic1, non_variadic, variadic2, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: attributes = {} + // CHECK: operand_segment_sizes = array.array('L') + // CHECK: operands += [*variadic1] + // CHECK: operand_segment_sizes.append(len(variadic1)) + // CHECK: operands.append(non_variadic) + // CHECK: operand_segment_sizes.append(1) + // CHECK: if variadic2 is not None: operands.append(variadic2) + // CHECK: operand_segment_sizes.append(0 if variadic2 is None else 1) + // CHECK: attributes["operand_segment_sizes"] = _ir.DenseElementsAttr.get(operand_segment_sizes, + // CHECK: context=Location.current.context if loc is None else loc.context) + // CHECK: super().__init__(_ir.Operation.create( + // CHECK: "test.attr_sized_operands", attributes=attributes, operands=operands, results=results, + // CHECK: loc=loc, ip=ip)) + // CHECK: @property // CHECK: def variadic1(self): // CHECK: operand_range = _segmented_accessor( @@ -47,6 +64,23 @@ // CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results" def AttrSizedResultsOp : TestOp<"attr_sized_results", [AttrSizedResultSegments]> { + // CHECK: def __init__(self, variadic1, non_variadic, variadic2, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: attributes = {} + // CHECK: result_segment_sizes = array.array('L') + // CHECK: if variadic1 is not None: results.append(variadic1) + // CHECK: result_segment_sizes.append(0 if variadic1 is None else 1) + // CHECK: results.append(non_variadic) + // CHECK: result_segment_sizes.append(1) # non_variadic + // CHECK: if variadic2 is not None: results.append(variadic2) + // CHECK: result_segment_sizes.append(0 if variadic2 is None else 1) + // CHECK: 
attributes["result_segment_sizes"] = _ir.DenseElementsAttr.get(result_segment_sizes, + // CHECK: context=Location.current.context if loc is None else loc.context) + // CHECK: super().__init__(_ir.Operation.create( + // CHECK: "test.attr_sized_results", attributes=attributes, operands=operands, results=results, + // CHECK: loc=loc, ip=ip)) + // CHECK: @property // CHECK: def variadic1(self): // CHECK: result_range = _segmented_accessor( @@ -75,11 +109,32 @@ // CHECK: class EmptyOp(_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.empty" def EmptyOp : TestOp<"empty">; + // CHECK: def __init__(self, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: attributes = {} + // CHECK: super().__init__(_ir.Operation.create( + // CHECK: "test.empty", attributes=attributes, operands=operands, results=results, + // CHECK: loc=loc, ip=ip)) // CHECK: @_cext.register_operation(_Dialect) // CHECK: class MissingNamesOp(_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.missing_names" def MissingNamesOp : TestOp<"missing_names"> { + // CHECK: def __init__(self, i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: attributes = {} + // CHECK: results.append(i32) + // CHECK: results.append(_gen_res_1) + // CHECK: results.append(i64) + // CHECK: operands.append(_gen_arg_0) + // CHECK: operands.append(f32) + // CHECK: operands.append(_gen_arg_2) + // CHECK: super().__init__(_ir.Operation.create( + // CHECK: "test.missing_names", attributes=attributes, operands=operands, results=results, + // CHECK: loc=loc, ip=ip)) + // CHECK: @property // CHECK: def f32(self): // CHECK: return self.operation.operands[1] @@ -99,6 +154,16 @@ // CHECK: class OneVariadicOperandOp(_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand" def OneVariadicOperandOp : TestOp<"one_variadic_operand"> { + // CHECK: def __init__(self, non_variadic, variadic, loc=None, ip=None): + // CHECK: 
operands = [] + // CHECK: results = [] + // CHECK: attributes = {} + // CHECK: operands.append(non_variadic) + // CHECK: operands += [*variadic] + // CHECK: super().__init__(_ir.Operation.create( + // CHECK: "test.one_variadic_operand", attributes=attributes, operands=operands, results=results, + // CHECK: loc=loc, ip=ip)) + // CHECK: @property // CHECK: def non_variadic(self): // CHECK: return self.operation.operands[0] @@ -114,6 +179,16 @@ // CHECK: class OneVariadicResultOp(_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.one_variadic_result" def OneVariadicResultOp : TestOp<"one_variadic_result"> { + // CHECK: def __init__(self, variadic, non_variadic, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: attributes = {} + // CHECK: results += [*variadic] + // CHECK: results.append(non_variadic) + // CHECK: super().__init__(_ir.Operation.create( + // CHECK: "test.one_variadic_result", attributes=attributes, operands=operands, results=results, + // CHECK: loc=loc, ip=ip)) + // CHECK: @property // CHECK: def variadic(self): // CHECK: variadic_group_length = len(self.operation.results) - 2 + 1 @@ -130,6 +205,15 @@ // CHECK: class PythonKeywordOp(_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.python_keyword" def PythonKeywordOp : TestOp<"python_keyword"> { + // CHECK: def __init__(self, in_, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: attributes = {} + // CHECK: operands.append(in_) + // CHECK: super().__init__(_ir.Operation.create( + // CHECK: "test.python_keyword", attributes=attributes, operands=operands, results=results, + // CHECK: loc=loc, ip=ip)) + // CHECK: @property // CHECK: def in_(self): // CHECK: return self.operation.operands[0] @@ -186,6 +270,18 @@ // CHECK: class SimpleOp(_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.simple" def SimpleOp : TestOp<"simple"> { + // CHECK: def __init__(self, i64, f64, i32, f32, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: 
results = [] + // CHECK: attributes = {} + // CHECK: results.append(i64) + // CHECK: results.append(f64) + // CHECK: operands.append(i32) + // CHECK: operands.append(f32) + // CHECK: super().__init__(_ir.Operation.create( + // CHECK: "test.simple", attributes=attributes, operands=operands, results=results, + // CHECK: loc=loc, ip=ip)) + // CHECK: @property // CHECK: def i32(self): // CHECK: return self.operation.operands[0] diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -26,6 +26,7 @@ constexpr const char *fileHeader = R"Py( # Autogenerated by mlir-tblgen; don't manually edit. +import array from . import _cext from . import _segmented_accessor, _equally_sized_accessor _ir = _cext.ir @@ -172,6 +173,12 @@ return name.str(); } +static std::string attrSizedTraitForKind(const char *kind) { + return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", + llvm::StringRef(kind).take_front().upper(), + llvm::StringRef(kind).drop_front()); +} + /// Emits accessors to "elements" of an Op definition. Currently, the supported /// elements are operands and results, indicated by `kind`, which must be either /// `operand` or `result` and is used verbatim in the emitted code. @@ -190,10 +197,7 @@ llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size", llvm::StringRef(kind).take_front().upper(), llvm::StringRef(kind).drop_front()); - std::string attrSizedTrait = - llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", - llvm::StringRef(kind).take_front().upper(), - llvm::StringRef(kind).drop_front()); + std::string attrSizedTrait = attrSizedTraitForKind(kind); unsigned numVariadic = getNumVariadic(op); @@ -271,20 +275,23 @@ llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure"); } +/// Free function helpers accessing Operator components. 
+static int getNumOperands(const Operator &op) { return op.getNumOperands(); } +static const NamedTypeConstraint &getOperand(const Operator &op, int i) { + return op.getOperand(i); +} +static int getNumResults(const Operator &op) { return op.getNumResults(); } +static const NamedTypeConstraint &getResult(const Operator &op, int i) { + return op.getResult(i); +} + /// Emits accessor to Op operands. static void emitOperandAccessors(const Operator &op, raw_ostream &os) { auto getNumVariadic = [](const Operator &oper) { return oper.getNumVariableLengthOperands(); }; - auto getNumElements = [](const Operator &oper) { - return oper.getNumOperands(); - }; - auto getElement = [](const Operator &oper, - int i) -> const NamedTypeConstraint & { - return oper.getOperand(i); - }; - emitElementAccessors(op, os, "operand", getNumVariadic, getNumElements, - getElement); + emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands, + getOperand); } /// Emits access or Op results. @@ -292,21 +299,152 @@ auto getNumVariadic = [](const Operator &oper) { return oper.getNumVariableLengthResults(); }; - auto getNumElements = [](const Operator &oper) { - return oper.getNumResults(); - }; - auto getElement = [](const Operator &oper, - int i) -> const NamedTypeConstraint & { - return oper.getResult(i); - }; - emitElementAccessors(op, os, "result", getNumVariadic, getNumElements, - getElement); + emitElementAccessors(op, os, "result", getNumVariadic, getNumResults, + getResult); +} + +/// Template for the default auto-generated builder. +/// {0} is the operation name; +/// {1} is a comma-separated list of builder arguments, including the trailing +/// `loc` and `ip`; +/// {2} is the code populating `operands`, `results` and `attributes` fields. 
+constexpr const char *initTemplate = R"Py( + def __init__(self, {1}): + operands = [] + results = [] + attributes = {{} + {2} + super().__init__(_ir.Operation.create( + "{0}", attributes=attributes, operands=operands, results=results, + loc=loc, ip=ip)) +)Py"; + +/// Template for appending a single element to the operand/result list. +/// {0} is either 'operand' or 'result'; +/// {1} is the field name. +constexpr const char *singleElementAppendTemplate = "{0}s.append({1})"; + +/// Template for appending an optional element to the operand/result list. +/// {0} is either 'operand' or 'result'; +/// {1} is the field name. +constexpr const char *optionalAppendTemplate = + "if {1} is not None: {0}s.append({1})"; + +/// Template for appending a variadic element to the operand/result list. +/// {0} is either 'operand' or 'result'; +/// {1} is the field name. +constexpr const char *variadicAppendTemplate = "{0}s += [*{1}]"; + +/// Template for setting up the {0} ('operand'/'result') segment sizes buffer. +constexpr const char *segmentDeclarationTemplate = + "{0}_segment_sizes = array.array('L')"; + +/// Template for attaching the {0} ('operand'/'result') segment sizes attribute. +constexpr const char *segmentAttributeTemplate = + R"Py(attributes["{0}_segment_sizes"] = _ir.DenseElementsAttr.get({0}_segment_sizes, + context=Location.current.context if loc is None else loc.context))Py"; + +/// Template for appending the unit size to the segment sizes. +/// {0} is either 'operand' or 'result'; +/// {1} is the field name. +constexpr const char *singleElementSegmentTemplate = + "{0}_segment_sizes.append(1) # {1}"; + +/// Template for appending 0/1 for an optional element to the segment sizes. +/// {0} is either 'operand' or 'result'; +/// {1} is the field name. +constexpr const char *optionalSegmentTemplate = + "{0}_segment_sizes.append(0 if {1} is None else 1)"; + +/// Template for appending the length of a variadic group to the segment sizes. +/// {0} is either 'operand' or 'result'; +/// {1} is the field name. 
+constexpr const char *variadicSegmentTemplate = + "{0}_segment_sizes.append(len({1}))"; + +/// Populates `builderArgs` with the list of `__init__` arguments that +/// correspond to either operands or results of `op`, and `builderLines` with +/// additional lines that are required in the builder. `kind` must be either +/// "operand" or "result". `unnamedTemplate` is used to generate names for +/// operands or results that don't have a name in ODS. +static void populateBuilderLines( + const Operator &op, const char *kind, const char *unnamedTemplate, + llvm::SmallVectorImpl &builderArgs, + llvm::SmallVectorImpl &builderLines, + llvm::function_ref getNumElements, + llvm::function_ref + getElement) { + // The segment sizes buffer only has to be populated if the attr-sized + // segments trait is present. + bool includeSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr; + if (includeSegments) + builderLines.push_back(llvm::formatv(segmentDeclarationTemplate, kind)); + + // For each element, find or generate a name. + for (int i = 0, e = getNumElements(op); i < e; ++i) { + const NamedTypeConstraint &element = getElement(op, i); + std::string name = element.name.str(); + if (name.empty()) + name = llvm::formatv(unnamedTemplate, i).str(); + name = sanitizeName(name); + builderArgs.push_back(name); + + // Choose the formatting string based on the element kind. + llvm::StringRef formatString, segmentFormatString; + if (!element.isVariableLength()) { + formatString = singleElementAppendTemplate; + segmentFormatString = singleElementSegmentTemplate; + } else if (element.isOptional()) { + formatString = optionalAppendTemplate; + segmentFormatString = optionalSegmentTemplate; + } else { + assert(element.isVariadic() && "unhandled element group type"); + formatString = variadicAppendTemplate; + segmentFormatString = variadicSegmentTemplate; + } + + // Add the append line and, with segments, the segment size line. 
+ builderLines.push_back(llvm::formatv(formatString.data(), kind, name)); + if (includeSegments) + builderLines.push_back( + llvm::formatv(segmentFormatString.data(), kind, name)); + } + + if (includeSegments) + builderLines.push_back(llvm::formatv(segmentAttributeTemplate, kind)); +} + +/// Emits a default builder constructing an operation from the list of its +/// result types, followed by a list of its operands. +static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { + // TODO: support attribute types; for now such ops get no default builder. + if (op.getNumNativeAttributes() != 0) + return; + + // If we are asked to skip default builders, comply. + if (op.skipDefaultBuilders()) + return; + + llvm::SmallVector builderArgs; + llvm::SmallVector builderLines; + builderArgs.reserve(op.getNumOperands() + op.getNumResults()); + populateBuilderLines(op, "result", "_gen_res_{0}", builderArgs, builderLines, + getNumResults, getResult); + populateBuilderLines(op, "operand", "_gen_arg_{0}", builderArgs, builderLines, + getNumOperands, getOperand); + + builderArgs.push_back("loc=None"); + builderArgs.push_back("ip=None"); + os << llvm::formatv(initTemplate, op.getOperationName(), + llvm::join(builderArgs, ", "), + llvm::join(builderLines, "\n ")); } /// Emits bindings for a specific Op to the given output stream. static void emitOpBindings(const Operator &op, raw_ostream &os) { os << llvm::formatv(opClassTemplate, op.getCppClassName(), op.getOperationName()); + emitDefaultOpBuilder(op, os); emitOperandAccessors(op, os); emitResultAccessors(op, os); }