diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -53,6 +53,10 @@
   // Note: as the list may contain other lists this may not be final size.
   mlirOperands.reserve(operandList->size());
   for (const auto &&it : llvm::enumerate(*operandList)) {
+    // `None` stands for an omitted optional operand; skip it instead of
+    // letting the PyValue cast below fail on it.
+    if (it.value().is_none())
+      continue;
+
     PyValue *val;
     try {
       val = py::cast<PyValue *>(it.value());
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -4,12 +4,15 @@
 import mlir.dialects.func as func
 import mlir.dialects.python_test as test
 import mlir.dialects.tensor as tensor
+import mlir.dialects.arith as arith
+
 
 def run(f):
   print("\nTEST:", f.__name__)
   f()
   return f
 
+
 # CHECK-LABEL: TEST: testAttributes
 @run
 def testAttributes():
@@ -131,6 +134,7 @@
     del op.unit
     print(f"Unit: {op.unit}")
 
+
 # CHECK-LABEL: TEST: attrBuilder
 @run
 def attrBuilder():
@@ -216,8 +220,8 @@
     print(first_type_attr.one.type)
     print(first_type_attr.two.type)
 
-    first_attr = test.FirstAttrDeriveAttrOp(
-        FloatAttr.get(F32Type.get(), 3.14))
+    first_attr = test.FirstAttrDeriveAttrOp(FloatAttr.get(
+        F32Type.get(), 3.14))
     # CHECK-COUNT-3: f32
     print(first_attr.one.type)
     print(first_attr.two.type)
@@ -344,6 +348,7 @@
     i8 = IntegerType.get_signless(8)
 
     class Tensor(test.TestTensorValue):
+
       def __str__(self):
         return super().__str__().replace("Value", "Tensor")
 
@@ -371,50 +376,65 @@
 # CHECK-LABEL: TEST: inferReturnTypeComponents
 @run
 def inferReturnTypeComponents():
-    with Context() as ctx, Location.unknown(ctx):
-        test.register_python_test_dialect(ctx)
-        module = Module.create()
-        i32 = IntegerType.get_signless(32)
-        with InsertionPoint(module.body):
-            resultType = UnrankedTensorType.get(i32)
-            operandTypes = [
-                RankedTensorType.get([1, 3, 10, 10], i32),
-                UnrankedTensorType.get(i32),
-            ]
-            f = func.FuncOp(
-                "test_inferReturnTypeComponents", (operandTypes, [resultType])
-            )
-            entry_block = Block.create_at_start(f.operation.regions[0], operandTypes)
-            with InsertionPoint(entry_block):
-                ranked_op = test.InferShapedTypeComponentsOp(
-                    resultType, entry_block.arguments[0]
-                )
-                unranked_op = test.InferShapedTypeComponentsOp(
-                    resultType, entry_block.arguments[1]
-                )
-
-            # CHECK: has rank: True
-            # CHECK: rank: 4
-            # CHECK: element type: i32
-            # CHECK: shape: [1, 3, 10, 10]
-            iface = InferShapedTypeOpInterface(ranked_op)
-            shaped_type_components = iface.inferReturnTypeComponents(
-                operands=[ranked_op.operand]
-            )[0]
-            print("has rank:", shaped_type_components.has_rank)
-            print("rank:", shaped_type_components.rank)
-            print("element type:", shaped_type_components.element_type)
-            print("shape:", shaped_type_components.shape)
-
-            # CHECK: has rank: False
-            # CHECK: rank: None
-            # CHECK: element type: i32
-            # CHECK: shape: None
-            iface = InferShapedTypeOpInterface(unranked_op)
-            shaped_type_components = iface.inferReturnTypeComponents(
-                operands=[unranked_op.operand]
-            )[0]
-            print("has rank:", shaped_type_components.has_rank)
-            print("rank:", shaped_type_components.rank)
-            print("element type:", shaped_type_components.element_type)
-            print("shape:", shaped_type_components.shape)
+  with Context() as ctx, Location.unknown(ctx):
+    test.register_python_test_dialect(ctx)
+    module = Module.create()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+      resultType = UnrankedTensorType.get(i32)
+      operandTypes = [
+          RankedTensorType.get([1, 3, 10, 10], i32),
+          UnrankedTensorType.get(i32),
+      ]
+      f = func.FuncOp("test_inferReturnTypeComponents",
+                      (operandTypes, [resultType]))
+      entry_block = Block.create_at_start(f.operation.regions[0], operandTypes)
+      with InsertionPoint(entry_block):
+        ranked_op = test.InferShapedTypeComponentsOp(resultType,
+                                                     entry_block.arguments[0])
+        unranked_op = test.InferShapedTypeComponentsOp(resultType,
+                                                       entry_block.arguments[1])
+
+      # CHECK: has rank: True
+      # CHECK: rank: 4
+      # CHECK: element type: i32
+      # CHECK: shape: [1, 3, 10, 10]
+      iface = InferShapedTypeOpInterface(ranked_op)
+      shaped_type_components = iface.inferReturnTypeComponents(
+          operands=[ranked_op.operand])[0]
+      print("has rank:", shaped_type_components.has_rank)
+      print("rank:", shaped_type_components.rank)
+      print("element type:", shaped_type_components.element_type)
+      print("shape:", shaped_type_components.shape)
+
+      # CHECK: has rank: False
+      # CHECK: rank: None
+      # CHECK: element type: i32
+      # CHECK: shape: None
+      iface = InferShapedTypeOpInterface(unranked_op)
+      shaped_type_components = iface.inferReturnTypeComponents(
+          operands=[unranked_op.operand])[0]
+      print("has rank:", shaped_type_components.has_rank)
+      print("rank:", shaped_type_components.rank)
+      print("element type:", shaped_type_components.element_type)
+      print("shape:", shaped_type_components.shape)
+
+
+# CHECK-LABEL: TEST: testInferTypeOpInterface
+@run
+def testInferTypeOpInterface():
+  with Context() as ctx, Location.unknown(ctx):
+    test.register_python_test_dialect(ctx)
+    module = Module.create()
+    with InsertionPoint(module.body):
+      i64 = IntegerType.get_signless(64)
+      zero = arith.ConstantOp(i64, 0)
+
+      one_operand = test.InferResultsVariadicInputsOp(single=zero, doubled=None)
+      # CHECK: i32
+      print(one_operand.result.type)
+
+      two_operands = test.InferResultsVariadicInputsOp(single=zero,
+                                                       doubled=zero)
+      # CHECK: f32
+      print(two_operands.result.type)
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -101,6 +101,36 @@
   }];
 }
 
+def I32OrF32 : TypeConstraint<Or<[I32.predicate, F32.predicate]>,
+                              "i32 or f32">;
+
+def InferResultsVariadicInputsOp : TestOp<"infer_results_variadic_inputs_op",
+    [InferTypeOpInterface, AttrSizedOperandSegments]> {
+  let arguments = (ins Optional<I64>:$single, Optional<I64>:$doubled);
+  let results = (outs I32OrF32:$res);
+
+  let extraClassDeclaration = [{
+    static ::mlir::LogicalResult inferReturnTypes(
+      ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
+      ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+      ::mlir::OpaqueProperties,
+      ::mlir::RegionRange regions,
+      ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+      // The result type encodes how many optional operands were provided:
+      // one operand -> i32, two operands -> f32; anything else is rejected.
+      ::mlir::Builder b(context);
+      if (operands.size() == 1)
+        inferredReturnTypes.push_back(b.getI32Type());
+      else if (operands.size() == 2)
+        inferredReturnTypes.push_back(b.getF32Type());
+      else
+        return ::mlir::emitOptionalError(
+            location, "expected 1 or 2 operands, but got ", operands.size());
+      return ::mlir::success();
+    }
+  }];
+}
+
 // If all result types are buildable, the InferTypeOpInterface is implied and is
 // autogenerated by C++ ODS.
 def InferResultsImpliedOp : TestOp<"infer_results_implied_op"> {