diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1153,7 +1153,7 @@ throw py::value_error((llvm::Twine("Operation \"") + name + "\" requires " + llvm::Twine(resultSegmentSpec.size()) + - "result segments but was provided " + + " result segments but was provided " + llvm::Twine(resultTypeList.size())) .str()); } @@ -1164,7 +1164,7 @@ if (segmentSpec == 1 || segmentSpec == 0) { // Unpack unary element. try { - auto resultType = py::cast<PyType *>(std::get<0>(it.value())); + auto *resultType = py::cast<PyType *>(std::get<0>(it.value())); if (resultType) { resultTypes.push_back(resultType); resultSegmentLengths.push_back(1); diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -18,7 +18,7 @@ // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttrSizedOperandsOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.attr_sized_operands" -// CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,-1,] +// CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,0,] def AttrSizedOperandsOp : TestOp<"attr_sized_operands", [AttrSizedOperandSegments]> { // CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None): @@ -28,7 +28,7 @@ // CHECK: regions = None // CHECK: operands.append(_get_op_results_or_values(variadic1)) // CHECK: operands.append(_get_op_result_or_value(non_variadic)) - // CHECK: if variadic2 is not None: operands.append(_get_op_result_or_value(variadic2)) + // CHECK: operands.append(_get_op_result_or_value(variadic2) if variadic2 is not None else None) // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, @@ -40,6 +40,7 @@ // CHECK: self.operation.operands, // CHECK:
self.operation.attributes["operand_segment_sizes"], 0) // CHECK: return operand_range + // CHECK-NOT: if len(operand_range) // // CHECK: @builtins.property // CHECK: def non_variadic(self): @@ -61,7 +62,7 @@ // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttrSizedResultsOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results" -// CHECK: _ODS_RESULT_SEGMENTS = [-1,1,-1,] +// CHECK: _ODS_RESULT_SEGMENTS = [0,1,-1,] def AttrSizedResultsOp : TestOp<"attr_sized_results", [AttrSizedResultSegments]> { // CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None): @@ -71,7 +72,7 @@ // CHECK: regions = None // CHECK: if variadic1 is not None: results.append(variadic1) // CHECK: results.append(non_variadic) - // CHECK: if variadic2 is not None: results.append(variadic2) + // CHECK: results.append(variadic2) // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, @@ -97,8 +98,9 @@ // CHECK: self.operation.results, // CHECK: self.operation.attributes["result_segment_sizes"], 2) // CHECK: return result_range + // CHECK-NOT: if len(result_range) let results = (outs Optional<AnyType>:$variadic1, AnyType:$non_variadic, - Optional<AnyType>:$variadic2); + Variadic<AnyType>:$variadic2); } @@ -277,6 +279,35 @@ let results = (outs I32:$i32, AnyFloat, I64:$i64); } +// CHECK: @_ods_cext.register_operation(_Dialect) +// CHECK: class OneOptionalOperandOp(_ods_ir.OpView): +// CHECK-LABEL: OPERATION_NAME = "test.one_optional_operand" +// CHECK-NOT: _ODS_OPERAND_SEGMENTS +// CHECK-NOT: _ODS_RESULT_SEGMENTS +def OneOptionalOperandOp : TestOp<"one_optional_operand"> { + let arguments = (ins AnyType:$non_optional, Optional<AnyType>:$optional); + // CHECK: def __init__(self, non_optional, optional, *, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: attributes = {} + // CHECK: regions = None + // CHECK:
operands.append(_get_op_result_or_value(non_optional)) + // CHECK: if optional is not None: operands.append(_get_op_result_or_value(optional)) + // CHECK: _ods_successors = None + // CHECK: super().__init__(self.build_generic( + // CHECK: attributes=attributes, results=results, operands=operands, + // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) + + // CHECK: @builtins.property + // CHECK: def non_optional(self): + // CHECK: return self.operation.operands[0] + + // CHECK: @builtins.property + // CHECK: def optional(self): + // CHECK: return self.operation.operands[1] if len(self.operation.operands) > 2 else None + +} + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class OneVariadicOperandOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand" diff --git a/mlir/test/python/dialects/vector.py b/mlir/test/python/dialects/vector.py --- a/mlir/test/python/dialects/vector.py +++ b/mlir/test/python/dialects/vector.py @@ -2,25 +2,58 @@ from mlir.ir import * import mlir.dialects.builtin as builtin +import mlir.dialects.std as std import mlir.dialects.vector as vector def run(f): print("\nTEST:", f.__name__) - f() + with Context(), Location.unknown(): + f() + return f # CHECK-LABEL: TEST: testPrintOp @run def testPrintOp(): - with Context() as ctx, Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): - @builtin.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get())) - def print_vector(arg): - return vector.PrintOp(arg) - - # CHECK-LABEL: func @print_vector( - # CHECK-SAME: %[[ARG:.*]]: vector<12x5xf32>) { - # CHECK: vector.print %[[ARG]] : vector<12x5xf32> - # CHECK: return - # CHECK: } - print(module) + module = Module.create() + with InsertionPoint(module.body): + + @builtin.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get())) + def print_vector(arg): + return vector.PrintOp(arg) + + # CHECK-LABEL: func @print_vector( + # CHECK-SAME: %[[ARG:.*]]: vector<12x5xf32>) { + #
CHECK: vector.print %[[ARG]] : vector<12x5xf32> + # CHECK: return + # CHECK: } + print(module) + + +# CHECK-LABEL: TEST: testTransferReadOp +@run +def testTransferReadOp(): + module = Module.create() + with InsertionPoint(module.body): + vector_type = VectorType.get([2, 3], F32Type.get()) + memref_type = MemRefType.get([-1, -1], F32Type.get()) + index_type = IndexType.get() + mask_type = VectorType.get(vector_type.shape, IntegerType.get_signless(1)) + identity_map = AffineMap.get_identity(vector_type.rank) + identity_map_attr = AffineMapAttr.get(identity_map) + func = builtin.FuncOp("transfer_read", + ([memref_type, index_type, + F32Type.get(), mask_type], [])) + with InsertionPoint(func.add_entry_block()): + A, zero, padding, mask = func.arguments + vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr, + padding, mask, None) + vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr, + padding, None, None) + std.ReturnOp([]) + + # CHECK: @transfer_read(%[[MEM:.*]]: memref<?x?xf32>, %[[IDX:.*]]: index, + # CHECK: %[[PAD:.*]]: f32, %[[MASK:.*]]: vector<2x3xi1>) + # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]] + # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]] + # CHECK-NOT: %[[MASK]] + print(module) diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -67,6 +67,7 @@ /// Each segment spec is either None (default) or an array of integers /// where: /// 1 = single element (expect non sequence operand/result) +/// 0 = optional element (expect a value or None) /// -1 = operand/result is a sequence corresponding to a variadic constexpr const char *opClassSizedSegmentsTemplate = R"Py( _ODS_{0}_SEGMENTS = {1} @@ -505,6 +506,9 @@ /// {0} is the field name.
constexpr const char *optionalAppendOperandTemplate = "if {0} is not None: operands.append(_get_op_result_or_value({0}))"; +constexpr const char *optionalAppendAttrSizedOperandsTemplate = + "operands.append(_get_op_result_or_value({0}) if {0} is not None else " + "None)"; constexpr const char *optionalAppendResultTemplate = "if {0} is not None: results.append({0})"; @@ -693,7 +697,11 @@ if (!element.isVariableLength()) { formatString = singleOperandAppendTemplate; } else if (element.isOptional()) { - formatString = optionalAppendOperandTemplate; + if (sizedSegments) { + formatString = optionalAppendAttrSizedOperandsTemplate; + } else { + formatString = optionalAppendOperandTemplate; + } } else { assert(element.isVariadic() && "unhandled element group type"); // If emitting with sizedSegments, then we add the actual list-typed @@ -882,10 +890,10 @@ std::string segmentSpec("["); for (int i = 0, e = getNumElements(op); i < e; ++i) { const NamedTypeConstraint &element = getElement(op, i); - if (element.isVariableLength()) { - segmentSpec.append("-1,"); - } else if (element.isOptional()) { + if (element.isOptional()) { segmentSpec.append("0,"); + } else if (element.isVariadic()) { + segmentSpec.append("-1,"); } else { segmentSpec.append("1,"); }