diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -12,6 +12,7 @@
 #include "IRModule.h"
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/Interfaces.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace py = pybind11;
 
@@ -183,9 +184,11 @@
   }
 
   /// Given the arguments required to build an operation, attempts to infer its
-  /// return types. Throws value_error on faliure.
+  /// return types. Each element of `operandList` must be a Value or a sequence
+  /// of Values; sequences are flattened in order (e.g. to supply variadic
+  /// operands). Throws value_error on failure.
   std::vector<PyType>
-  inferReturnTypes(std::optional<std::vector<PyValue>> operands,
+  inferReturnTypes(std::optional<py::list> operandList,
                    std::optional<PyAttribute> attributes,
                    std::optional<std::vector<PyRegion>> regions,
                    DefaultingPyMlirContext context,
@@ -193,10 +196,45 @@
     llvm::SmallVector<MlirValue> mlirOperands;
     llvm::SmallVector<MlirRegion> mlirRegions;
 
-    if (operands) {
-      mlirOperands.reserve(operands->size());
-      for (PyValue &value : *operands) {
-        mlirOperands.push_back(value);
+    if (operandList && !operandList->empty()) {
+      // Note: as the list may contain other lists this may not be final size.
+      mlirOperands.reserve(operandList->size());
+      for (const auto &it : llvm::enumerate(*operandList)) {
+        PyValue *val = nullptr;
+        try {
+          val = py::cast<PyValue *>(it.value());
+          if (!val)
+            throw py::cast_error();
+          mlirOperands.push_back(val->get());
+          continue;
+        } catch (py::cast_error &) {
+          // Not a single Value; fall through and try a sequence of Values.
+        }
+
+        try {
+          auto vals = py::cast<py::sequence>(it.value());
+          for (py::object v : vals) {
+            try {
+              val = py::cast<PyValue *>(v);
+              if (!val)
+                throw py::cast_error();
+              mlirOperands.push_back(val->get());
+            } catch (py::cast_error &err) {
+              // py::value_error is not a py::cast_error, so the outer handler
+              // below does not intercept this.
+              throw py::value_error(
+                  (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+                   " must be a Value or Sequence of Values (" + err.what() +
+                   ")")
+                      .str());
+            }
+          }
+        } catch (py::cast_error &err) {
+          throw py::value_error(
+              (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+               " must be a Value or Sequence of Values (" + err.what() + ")")
+                  .str());
+        }
       }
     }
 
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -24,8 +24,8 @@
                                                    ValueRange values) {
   // Check static and dynamic offsets/sizes/strides does not overflow type.
   if (staticVals.size() != numElements)
-    return op->emitError("expected ")
-           << numElements << " " << name << " values";
+    return op->emitError("expected ") << numElements << " " << name
+                                      << " values, got " << staticVals.size();
   unsigned expectedNumDynamicEntries =
       llvm::count_if(staticVals, [&](int64_t staticVal) {
         return ShapedType::isDynamic(staticVal);
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -667,7 +667,7 @@
 
 class InferTypeOpInterface:
     def __init__(self, object: object, context: Optional[Context] = None) -> None: ...
-    def inferReturnTypes(self, operands: Optional[List[Value]] = None, attributes: Optional[Attribute] = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ...
+    def inferReturnTypes(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ...
     @property
     def operation(self) -> Operation: ...
     @property
diff --git a/mlir/test/python/dialects/tensor.py b/mlir/test/python/dialects/tensor.py
--- a/mlir/test/python/dialects/tensor.py
+++ b/mlir/test/python/dialects/tensor.py
@@ -74,3 +74,30 @@
         return tensor.EmptyOp([], f32)
 
     print(module)
+
+
+# CHECK-LABEL: TEST: testInferTypesInsertSlice
+@run
+def testInferTypesInsertSlice():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32Type = F32Type.get()
+    indexType = IndexType.get()
+    with InsertionPoint(module.body):
+
+      @func.FuncOp.from_py_func(
+          RankedTensorType.get((1, 1), f32Type),
+          RankedTensorType.get((1, 1), f32Type))
+      # CHECK: func @f
+      # CHECK: tensor.insert_slice %arg0 into %arg1[0, 0] [1, 1] [0, 0] :
+      # CHECK-SAME: tensor<1x1xf32> into tensor<1x1xf32>
+      def f(source, dest):
+        c0 = arith.ConstantOp(indexType, 0)
+        c1 = arith.ConstantOp(indexType, 1)
+        d0 = tensor.InsertSliceOp(source, dest, [], [], [],
+                                  DenseI64ArrayAttr.get([0, 0]),
+                                  DenseI64ArrayAttr.get([1, 1]),
+                                  DenseI64ArrayAttr.get([0, 0]))
+        return [d0.result]
+
+    print(module)