Python bindings: skip None entries when collecting InferTypeOpInterface operands.

* IRInterfaces.cpp: entries of the operand list that are None are skipped
  before the PyValue cast is attempted, so they are not added to
  mlirOperands.
* python_test_ops.td: adds InferResultsVariadicInputsOp, a test op with two
  optional operands whose inferred result type is i32 for one operand and f32
  for two.
* python_test.py: adds testInferTypeOpInterface, which builds the op with
  doubled=None and with both operands and checks the printed result types.

NOTE(review): this patch was rebuilt from a copy whose line breaks and
angle-bracket arguments had been stripped. Indentation, `py::cast<PyValue *>`,
`Optional<I64>` and the `I32OrF32` predicate are reconstructions -- confirm
them against the tree before applying.

diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -53,6 +53,10 @@
       // Note: as the list may contain other lists this may not be final size.
       mlirOperands.reserve(operandList->size());
       for (const auto &&it : llvm::enumerate(*operandList)) {
+        // An optional operand that was not provided is passed as None; skip it.
+        if (it.value().is_none())
+          continue;
+
         PyValue *val;
         try {
           val = py::cast<PyValue *>(it.value());
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -4,6 +4,7 @@
 import mlir.dialects.func as func
 import mlir.dialects.python_test as test
 import mlir.dialects.tensor as tensor
+import mlir.dialects.arith as arith
 
 def run(f):
   print("\nTEST:", f.__name__)
@@ -418,3 +419,27 @@
       print("rank:", shaped_type_components.rank)
       print("element type:", shaped_type_components.element_type)
       print("shape:", shaped_type_components.shape)
+
+
+
+# CHECK-LABEL: TEST: testInferTypeOpInterface
+@run
+def testInferTypeOpInterface():
+  with Context() as ctx, Location.unknown(ctx):
+    test.register_python_test_dialect(ctx)
+    module = Module.create()
+    with InsertionPoint(module.body):
+      i64 = IntegerType.get_signless(64)
+      zero = arith.ConstantOp(i64, 0)
+
+      one_operand = test.InferResultsVariadicInputsOp(
+        single=zero, doubled=None
+      )
+      # CHECK: i32
+      print(one_operand.result.type)
+
+      two_operands = test.InferResultsVariadicInputsOp(
+        single=zero, doubled=zero
+      )
+      # CHECK: f32
+      print(two_operands.result.type)
\ No newline at end of file
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -101,6 +101,31 @@
   }];
 }
 
+def I32OrF32 : TypeConstraint<Or<[I32.predicate, F32.predicate]>,
+                                 "i32 or f32">;
+
+def InferResultsVariadicInputsOp : TestOp<"infer_results_variadic_inputs_op",
+    [InferTypeOpInterface, AttrSizedOperandSegments]> {
+  let arguments = (ins Optional<I64>:$single, Optional<I64>:$doubled);
+  let results = (outs I32OrF32:$res);
+
+  let extraClassDeclaration = [{
+    static ::mlir::LogicalResult inferReturnTypes(
+      ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
+      ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+      ::mlir::OpaqueProperties,
+      ::mlir::RegionRange regions,
+      ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+      ::mlir::Builder b(context);
+      if (operands.size() == 1)
+        inferredReturnTypes.push_back(b.getI32Type());
+      else if (operands.size() == 2)
+        inferredReturnTypes.push_back(b.getF32Type());
+      return ::mlir::success();
+    }
+  }];
+}
+
 // If all result types are buildable, the InferTypeOpInterface is implied and is
 // autogenerated by C++ ODS.
 def InferResultsImpliedOp : TestOp<"infer_results_implied_op"> {