diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -36,15 +36,6 @@ OpView.__init__(self, op) linalgDialect = Context.current.get_dialect_descriptor("linalg") fill_builtin_region(linalgDialect, self.operation) - # TODO: self.result is None. When len(results) == 1 we expect it to be - # results[0] as per _linalg_ops_gen.py. This seems like an orthogonal bug - # in the generator of _linalg_ops_gen.py where we have: - # ``` - # def result(self): - # return self.operation.results[0] \ - # if len(self.operation.results) > 1 else None - # ``` - class InitTensorOp: """Extends the linalg.init_tensor op.""" diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -304,7 +304,7 @@ // CHECK: @builtins.property // CHECK: def optional(self): - // CHECK: return self.operation.operands[1] if len(self.operation.operands) > 2 else None + // CHECK: return None if len(self.operation.operands) < 2 else self.operation.operands[1] } diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -68,10 +68,7 @@ @builtin.FuncOp.from_py_func(RankedTensorType.get((12, -1), f32)) def fill_tensor(out): zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result - # TODO: FillOp.result is None. When len(results) == 1 we expect it to - # be results[0] as per _linalg_ops_gen.py. This seems like an - # orthogonal bug in the generator of _linalg_ops_gen.py. 
- return linalg.FillOp(output=out, value=zero).results[0] + return linalg.FillOp(output=out, value=zero).result # CHECK-LABEL: func @fill_buffer # CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<12x?xf32> diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -207,3 +207,21 @@ print(implied.flt.type) # CHECK: index print(implied.index.type) + + +# CHECK-LABEL: TEST: testOptionalOperandOp +@run +def testOptionalOperandOp(): + with Context() as ctx, Location.unknown(): + test.register_python_test_dialect(ctx) + + module = Module.create() + with InsertionPoint(module.body): + + op1 = test.OptionalOperandOp(None) + # CHECK: op1.input is None: True + print(f"op1.input is None: {op1.input is None}") + + op2 = test.OptionalOperandOp(op1) + # CHECK: op2.input is None: False + print(f"op2.input is None: {op2.input is None}") diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td --- a/mlir/test/python/python_test_ops.td +++ b/mlir/test/python/python_test_ops.td @@ -76,4 +76,9 @@ let results = (outs AnyType:$one, AnyType:$two, AnyType:$three); } +def OptionalOperandOp : TestOp<"optional_operand_op"> { + let arguments = (ins Optional<AnyType>:$input); + let results = (outs I32:$result); +} + #endif // PYTHON_TEST_OPS diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -109,10 +109,13 @@ /// {1} is either 'operand' or 'result'; /// {2} is the total number of element groups; /// {3} is the position of the current group in the group list. +/// This works if we have only one variable-length group (and it's the optional +/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is +/// smaller than the total number of groups. 
constexpr const char *opOneOptionalTemplate = R"Py( @builtins.property def {0}(self): - return self.operation.{1}s[{3}] if len(self.operation.{1}s) > {2} else None + return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}] )Py"; /// Template for the variadic group accessor in the single variadic group case: @@ -311,7 +314,7 @@ /// `operand` or `result` and is used verbatim in the emitted code. static void emitElementAccessors( const Operator &op, raw_ostream &os, const char *kind, - llvm::function_ref<unsigned(const Operator &)> getNumVariadic, + llvm::function_ref<unsigned(const Operator &)> getNumVariableLength, llvm::function_ref<int(const Operator &)> getNumElements, llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> getElement) { @@ -326,12 +329,12 @@ llvm::StringRef(kind).drop_front()); std::string attrSizedTrait = attrSizedTraitForKind(kind); - unsigned numVariadic = getNumVariadic(op); + unsigned numVariableLength = getNumVariableLength(op); - // If there is only one variadic element group, its size can be inferred from - // the total number of elements. If there are none, the generation is - // straightforward. - if (numVariadic <= 1) { + // If there is only one variable-length element group, its size can be + // inferred from the total number of elements. If there are none, the + // generation is straightforward. + if (numVariableLength <= 1) { bool seenVariableLength = false; for (int i = 0, e = getNumElements(op); i < e; ++i) { const NamedTypeConstraint &element = getElement(op, i); @@ -364,7 +367,7 @@ const NamedTypeConstraint &element = getElement(op, i); if (!element.name.empty()) { os << llvm::formatv(opVariadicEqualPrefixTemplate, - sanitizeName(element.name), kind, numVariadic, + sanitizeName(element.name), kind, numVariableLength, numPrecedingSimple, numPrecedingVariadic); os << llvm::formatv(element.isVariableLength() ? opVariadicEqualVariadicTemplate @@ -414,20 +417,20 @@ /// Emits accessors to Op operands. 
static void emitOperandAccessors(const Operator &op, raw_ostream &os) { - auto getNumVariadic = [](const Operator &oper) { + auto getNumVariableLengthOperands = [](const Operator &oper) { return oper.getNumVariableLengthOperands(); }; - emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands, - getOperand); + emitElementAccessors(op, os, "operand", getNumVariableLengthOperands, + getNumOperands, getOperand); } /// Emits accessors Op results. static void emitResultAccessors(const Operator &op, raw_ostream &os) { - auto getNumVariadic = [](const Operator &oper) { + auto getNumVariableLengthResults = [](const Operator &oper) { return oper.getNumVariableLengthResults(); }; - emitElementAccessors(op, os, "result", getNumVariadic, getNumResults, - getResult); + emitElementAccessors(op, os, "result", getNumVariableLengthResults, + getNumResults, getResult); } /// Emits accessors to Op attributes.