diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -1015,6 +1015,8 @@
     for (OpResult r : this->getOperation()->getResults()) {
       if (r.getType().isa<TransformParamTypeInterface>())
         transformResults.setParams(r, emptyParams);
+      else if (r.getType().isa<TransformValueHandleTypeInterface>())
+        transformResults.setValues(r, ValueRange());
       else
         transformResults.set(r, emptyPayload);
     }
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -921,48 +921,64 @@
   // Check that the right kind of value was produced.
   for (const auto &[ptr, res] :
        llvm::zip(partialResult, transformOp->getResults())) {
-    if (ptr.isNull()) {
-      return emitDiag() << "null result #" << res.getResultNumber()
-                        << " produced";
+    if (ptr.isNull())
+      continue;
+    if (res.getType().template isa<TransformHandleTypeInterface>() &&
+        !ptr.is<Operation *>()) {
+      return emitDiag() << "application of " << transformOpName
+                        << " expected to produce an Operation * for result #"
+                        << res.getResultNumber();
     }
-    if (ptr.is<Attribute>() &&
-        !res.getType().template isa<TransformParamTypeInterface>()) {
+    if (res.getType().template isa<TransformParamTypeInterface>() &&
+        !ptr.is<Attribute>()) {
       return emitDiag() << "application of " << transformOpName
                         << " expected to produce an Attribute for result #"
                         << res.getResultNumber();
     }
-    if (ptr.is<Operation *>() &&
-        !res.getType().template isa<TransformHandleTypeInterface>()) {
+    if (res.getType().template isa<TransformValueHandleTypeInterface>() &&
+        !ptr.is<Value>()) {
       return emitDiag() << "application of " << transformOpName
-                        << " expected to produce an Operation * for result #"
+                        << " expected to produce a Value for result #"
                         << res.getResultNumber();
     }
   }
   return success();
 }
 
+/// Extracts the `T` alternative of every element of `range`. Every element is
+/// expected to hold a `T`.
+template <typename T>
+static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) {
+  return llvm::to_vector(llvm::map_range(
+      range, [](transform::MappedValue value) { return value.get<T>(); }));
+}
+
 void transform::detail::setApplyToOneResults(
     Operation *transformOp, TransformResults &transformResults,
     ArrayRef<ApplyToEachResultList> results) {
+  // Transpose the partial result lists (one per application) into one list per
+  // op result. Applications that produced any null result are skipped entirely.
+  SmallVector<SmallVector<MappedValue>> transposed;
+  transposed.resize(transformOp->getNumResults());
+  for (const ApplyToEachResultList &partialResults : results) {
+    assert(transformOp->getNumResults() == partialResults.size() &&
+           "expected as many partial results as op has results");
+    if (llvm::any_of(partialResults,
+                     [](MappedValue value) { return value.isNull(); }))
+      continue;
+    for (auto [i, value] : llvm::enumerate(partialResults))
+      transposed[i].push_back(value);
+  }
+
   for (OpResult r : transformOp->getResults()) {
+    unsigned position = r.getResultNumber();
     if (r.getType().isa<TransformParamTypeInterface>()) {
-      auto params = llvm::to_vector(
-          llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) {
-            return oneResult[r.getResultNumber()].get<Attribute>();
-          }));
-      transformResults.setParams(r, params);
+      transformResults.setParams(r,
+                                 castVector<Attribute>(transposed[position]));
     } else if (r.getType().isa<TransformValueHandleTypeInterface>()) {
-      auto values = llvm::to_vector(
-          llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) {
-            return oneResult[r.getResultNumber()].get<Value>();
-          }));
-      transformResults.setValues(r, values);
+      transformResults.setValues(r, castVector<Value>(transposed[position]));
     } else {
-      auto payloads = llvm::to_vector(
-          llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) {
-            return oneResult[r.getResultNumber()].get<Operation *>();
-          }));
-      transformResults.set(r, payloads);
+      transformResults.set(r, castVector<Operation *>(transposed[position]));
     }
   }
 }
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -495,8 +495,9 @@
 
 // -----
 
+// This should not fail.
+
 func.func @foo() {
-  // expected-note @below {{when applied to this op}}
   "op" () : () -> ()
   return
 }
@@ -513,7 +514,6 @@
   transform.sequence %arg0 : !pdl.operation failures(propagate) {
   ^bb0(%arg1: !pdl.operation):
     %0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation
-    // expected-error @below {{null result #0 produced}}
     transform.test_mixed_null_and_non_null_results %0
   }
 }
@@ -1053,11 +1053,11 @@
 
 // -----
 
-// expected-note @below {{when applied to this op}}
+// Should not fail.
+
 module {
   transform.sequence failures(propagate) {
   ^bb0(%arg0: !transform.any_op):
-    // expected-error @below {{null result #0 produced}}
     transform.test_produce_transform_param_or_forward_operand %arg0
       { first_result_is_null }
       : (!transform.any_op) -> (!transform.any_op, !transform.param)
@@ -1079,6 +1079,19 @@
 
 // -----
 
+// expected-note @below {{when applied to this op}}
+module {
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    // expected-error @below {{expected to produce a Value for result #0}}
+    transform.test_produce_transform_param_or_forward_operand %arg0
+      { second_result_is_handle }
+      : (!transform.any_op) -> (!transform.any_value, !transform.param)
+  }
+}
+
+// -----
+
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !transform.any_op):
   // expected-error @below {{attempting to assign a null payload op to this transform value}}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -371,7 +371,7 @@
                        UnitAttr:$first_result_is_param,
                        UnitAttr:$first_result_is_null,
                        UnitAttr:$second_result_is_handle);
-  let results = (outs TransformHandleTypeInterface:$out,
+  let results = (outs AnyType:$out,
                       TransformParamTypeInterface:$param);
   let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)";
   let cppNamespace = "::mlir::test";