diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -832,36 +832,33 @@ SmallVector silenceableStack; unsigned expectedNumResults = transformOp->getNumResults(); for (Operation *target : targets) { - // Emplace back a placeholder for the returned new ops and params. - // This is filled with `expectedNumResults` if the op fails to apply. - ApplyToEachResultList placeholder; - placeholder.reserve(expectedNumResults); - results.push_back(std::move(placeholder)); - auto specificOp = dyn_cast(target); if (!specificOp) { Diagnostic diag(transformOp->getLoc(), DiagnosticSeverity::Error); diag << "transform applied to the wrong op kind"; diag.attachNote(target->getLoc()) << "when applied to this op"; - // Producing `expectedNumResults` nullptr is a silenceableFailure mode. - // TODO: encode this implicit `expectedNumResults` nullptr == - // silenceableFailure with a proper trait. 
- results.back().assign(expectedNumResults, nullptr); silenceableStack.push_back(std::move(diag)); continue; } + ApplyToEachResultList partialResults; + partialResults.reserve(expectedNumResults); Location specificOpLoc = specificOp->getLoc(); DiagnosedSilenceableFailure res = - transformOp.applyToOne(specificOp, results.back(), state); + transformOp.applyToOne(specificOp, partialResults, state); - if (res.isDefiniteFailure() || - failed(detail::checkApplyToOne(transformOp, specificOpLoc, - results.back()))) { - return DiagnosedSilenceableFailure::definiteFailure(); - } - if (res.isSilenceableFailure()) - res.takeDiagnostics(silenceableStack); + if (res.isDefiniteFailure()) + return DiagnosedSilenceableFailure::definiteFailure(); + // A silenceable failure discards this target's partial results, so + // they are neither validated by checkApplyToOne nor recorded. + if (res.isSilenceableFailure()) { + res.takeDiagnostics(silenceableStack); + continue; + } + if (failed(detail::checkApplyToOne(transformOp, specificOpLoc, + partialResults))) + return DiagnosedSilenceableFailure::definiteFailure(); + results.push_back(std::move(partialResults)); } if (!silenceableStack.empty()) { return DiagnosedSilenceableFailure::silenceableFailure( diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -80,6 +80,13 @@ assert(!value.getType().isa() && "cannot associate payload ops with a value of parameter type"); + for (Operation *target : targets) { + if (target) + continue; + return emitError(value.getLoc()) + << "attempting to assign a null payload op to this transform value"; + } + auto iface = value.getType().cast(); DiagnosedSilenceableFailure result = iface.checkPayload(value.getLoc(), targets); @@ -105,6 +112,13 @@ ArrayRef params) { assert(value != nullptr && "attempting to set params for a null value"); + for (Attribute attr : params) { + if (attr) + continue; + return emitError(value.getLoc()) + << "attempting to assign a null parameter to this transform value"; + } + auto valueType = value.getType().dyn_cast(); - assert(value && + assert(valueType && "cannot associate parameter with a value of non-parameter type"); diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- 
a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1024,3 +1024,19 @@ { second_result_is_handle } : (!transform.any_op) -> (!transform.any_op, !transform.param) } + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + // expected-error @below {{attempting to assign a null payload op to this transform value}} + %0 = transform.test_produce_null_payload : !transform.any_op +} + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + // expected-error @below {{attempting to assign a null parameter to this transform value}} + %0 = transform.test_produce_null_param : !transform.param +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -458,6 +458,30 @@ return DiagnosedSilenceableFailure::success(); } +void mlir::test::TestProduceNullPayloadOp::getEffects( + SmallVectorImpl &effects) { + transform::producesHandle(getOut(), effects); +} + +DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply( + transform::TransformResults &results, transform::TransformState &state) { + SmallVector null({nullptr}); + results.set(getOut().cast(), null); + return DiagnosedSilenceableFailure::success(); +} + +void mlir::test::TestProduceNullParamOp::getEffects( + SmallVectorImpl &effects) { + transform::producesHandle(getOut(), effects); +} + +DiagnosedSilenceableFailure +mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + results.setParams(getOut().cast(), Attribute()); + return DiagnosedSilenceableFailure::success(); +} + namespace { /// Test extension of the Transform dialect. 
Registers additional ops and /// declares PDL as dependent dialect since the additional ops are using PDL diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -334,4 +334,22 @@ }]; } +def TestProduceNullPayloadOp + : Op, + DeclareOpInterfaceMethods]> { + let results = (outs TransformHandleTypeInterface:$out); + let assemblyFormat = "attr-dict `:` type($out)"; + let cppNamespace = "::mlir::test"; +} + +def TestProduceNullParamOp + : Op, + DeclareOpInterfaceMethods]> { + let results = (outs TransformParamTypeInterface:$out); + let assemblyFormat = "attr-dict `:` type($out)"; + let cppNamespace = "::mlir::test"; +} + #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD