diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -459,6 +459,10 @@ /// is associated with the list of payload IR operations. bool isParam(unsigned resultNumber) const; + /// Returns `true` if the result identified by its number in the list of + /// operation results is associated with payload IR operations or parameters. + bool isSet(unsigned resultNumber) const; + /// Storage for pointers to payload IR ops that are associated with results of /// a transform IR op. `segments` contains as many entries as the transform IR /// op has results, even if some of them are not associated with payload IR @@ -849,16 +853,19 @@ Location specificOpLoc = specificOp->getLoc(); DiagnosedSilenceableFailure res = transformOp.applyToOne(specificOp, partialResults, state); - if (res.isDefiniteFailure() || - failed(detail::checkApplyToOne(transformOp, specificOpLoc, - partialResults))) { + if (res.isDefiniteFailure()) return DiagnosedSilenceableFailure::definiteFailure(); - } - if (res.isSilenceableFailure()) + if (res.isSilenceableFailure()) { res.takeDiagnostics(silenceableStack); - else - results.push_back(std::move(partialResults)); + continue; + } + + if (failed(detail::checkApplyToOne(transformOp, specificOpLoc, + partialResults))) { + return DiagnosedSilenceableFailure::definiteFailure(); + } + results.push_back(std::move(partialResults)); } if (!silenceableStack.empty()) { return DiagnosedSilenceableFailure::silenceableFailure( diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -298,7 +298,6 @@ auto transformOp = cast(getOperation()); if (!getGenerateGpuLaunch() && !gpuLaunch) { - 
results.push_back(target); DiagnosedSilenceableFailure diag = emitSilenceableError() << "Given target is not gpu.launch, set `generate_gpu_launch` " @@ -312,7 +311,6 @@ mlir::transform::gpu::findTopLevelForeachThreadOp( target, topLevelForeachThreadOp, transformOp); if (!diag.succeeded()) { - results.push_back(target); diag.attachNote(target->getLoc()) << "when applied to this payload op"; return diag; } @@ -325,7 +323,6 @@ DiagnosedSilenceableFailure diag = createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch); if (!diag.succeeded()) { - results.push_back(target); return diag; } rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front()); @@ -363,7 +360,7 @@ /// Searches `scf.foreach_thread` ops nested under `target` and maps each such /// op to GPU threads. Mapping is one-to-one and the induction variables of /// `scf.foreach_thread` are rewritten to gpu.thread_id according to the -/// thread_dim_apping attribute. Sibling `scf.foreach_thread` are supported in +/// thread_dim_mapping attribute. Sibling `scf.foreach_thread` are supported in /// which case, the union of the number of threads is computed and may result /// in predication. Dynamic, `scf.foreach_thread` trip counts are currently /// not supported. Dynamic block dim sizes are currently not supported. 
@@ -531,7 +528,6 @@ auto transformOp = cast(getOperation()); if (!gpuLaunch) { - results.push_back(target); return emitSilenceableError() << "Given target is not gpu.launch"; } @@ -542,7 +538,6 @@ checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt, blockDim[0], blockDim[1], blockDim[2]); if (diag.isSilenceableFailure()) { - results.push_back(target); diag.attachNote(getLoc()) << getBlockDimAttrName() << " is very large"; return diag; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -158,7 +158,6 @@ #undef DOWNSCALE_NORMAL #undef DOWNSCALE_CALL #undef DOWNSCALE - results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); } //===----------------------------------------------------------------------===// @@ -583,7 +582,6 @@ while (!remainingProducers.empty()) { auto nextProducer = getNextProducer(); if (failed(nextProducer)) { - results.set(getFusedOp().cast(), ArrayRef()); Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark); diag << "could not find next producer to fuse into container"; return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); @@ -628,7 +626,6 @@ fusedOps.push_back(cloned); continue; } - results.set(getFusedOp().cast(), ArrayRef()); return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } @@ -654,7 +651,6 @@ results.push_back(generic->getOperation()); return DiagnosedSilenceableFailure::success(); } - results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); } @@ -717,7 +713,6 @@ ArrayRef payloadOps = state.getPayloadOps(getTarget()); if (payloadOps.size() != 1) { - results.set(getResult().cast(), {}); return emitDefiniteFailure("requires exactly one target handle"); } @@ -802,8 +797,6 @@ TransformState &state) { if 
(getLowSize().getType().isa()) { if (target.hasDynamicShape()) { - results.assign( - ArrayRef({Attribute(), Attribute(), Attribute()})); auto diag = emitSilenceableError() << "cannot compute parametric tile sizes for dynamically " "shaped payload op"; @@ -814,8 +807,6 @@ FailureOr spec = computeStaticMultiTileSizes( target, getDimension(), getTargetSize(), getDivisor()); if (failed(spec)) { - results.assign( - ArrayRef({Attribute(), Attribute(), Attribute()})); return emitSilenceableError() << "failed to compute multi-size tiling sizes"; } @@ -1271,7 +1262,6 @@ return DiagnosedSilenceableFailure::success(); } - results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); } @@ -1491,11 +1481,8 @@ llvm::map_range(state.getParams(getDynamicSplitPoint()), [](Attribute attr) { return OpFoldResult(attr); })); } - if (diag.isSilenceableFailure()) { - results.set(getFirst().cast(), {}); - results.set(getSecond().cast(), {}); + if (diag.isSilenceableFailure()) return diag; - } if (splitPoints.size() != payload.size()) { return emitDefiniteFailure() @@ -1518,8 +1505,6 @@ if (!linalgOp) { auto diag = emitSilenceableError() << "only applies to structured ops"; diag.attachNote(target->getLoc()) << "target op"; - results.set(getFirst().cast(), {}); - results.set(getSecond().cast(), {}); return diag; } @@ -1527,8 +1512,6 @@ auto diag = emitSilenceableError() << "dimension " << getDimension() << " does not exist in target op"; diag.attachNote(target->getLoc()) << "target op"; - results.set(getFirst().cast(), {}); - results.set(getSecond().cast(), {}); return diag; } @@ -1552,8 +1535,6 @@ } if (second.size() != first.size() && !second.empty()) { - results.set(getFirst().cast(), {}); - results.set(getSecond().cast(), {}); auto diag = emitSilenceableError() << "splitting does not produce the second part for a subset of targets"; @@ -1781,7 +1762,6 @@ numThreads, tileSizes, getMapping()); if (failed(result)) { - results.assign(4, nullptr); auto diag = 
emitSilenceableError() << "could not tile reduction"; diag.attachNote(target.getLoc()) << "target operation"; return diag; @@ -1874,8 +1854,6 @@ }))); if (paramSizes.back().size() != targets.size()) { - for (OpResult r : getResults()) - transformResults.set(r, {}); DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected as many parameter values (" @@ -1891,8 +1869,6 @@ dynamicSizeProducers.push_back(state.getPayloadOps(transformValue)); if (dynamicSizeProducers.back().size() != targets.size()) { - for (OpResult r : getResults()) - transformResults.set(r, {}); DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected as many dynamic size-producing operations (" @@ -1907,8 +1883,6 @@ op->getResult(0).getType().isa()) continue; - for (OpResult r : getResults()) - transformResults.set(r, {}); DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected sizes to be produced by ops " "with a single index-type result"; @@ -2237,11 +2211,8 @@ rewriter, state, transformOp, targets, mixedNumThreads, mixedTileSizes, getMapping(), tileOps, tiledOps); - if (!diag.succeeded()) { - transformResults.set(getForeachThreadOp().cast(), {}); - transformResults.set(getTiledOp().cast(), {}); + if (!diag.succeeded()) return diag; - } transformResults.set(getForeachThreadOp().cast(), tileOps); transformResults.set(getTiledOp().cast(), tiledOps); diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -42,7 +42,6 @@ : scf::ForOp::getOperationName()) << "' parent"; diag.attachNote(target->getLoc()) << "target op"; - results.set(getResult().cast(), {}); return diag; } current = loop; @@ -96,7 +95,6 @@ DiagnosedSilenceableFailure diag = emitSilenceableError() << "failed to outline"; diag.attachNote(target->getLoc()) << "target op"; - results.set(getTransformed().cast(), 
{}); return diag; } func::CallOp call; @@ -200,7 +198,6 @@ results.push_back(*patternResult); return DiagnosedSilenceableFailure::success(); } - results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); } diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -294,6 +294,20 @@ if (result.isDefiniteFailure()) return result; + // If a silenceable failure was produced, some results may be unset; set them + // to empty lists. + if (result.isSilenceableFailure()) { + for (OpResult opResult : transform->getResults()) { + if (results.isSet(opResult.getResultNumber())) + continue; + + if (opResult.getType().isa()) + results.setParams(opResult, {}); + else + results.set(opResult, {}); + } + } + // Remove the mapping for the operand if it is consumed by the operation. This // allows us to catch use-after-free with assertions later on. auto memEffectInterface = @@ -421,6 +435,13 @@ return paramSegments[resultNumber].data() != nullptr; } +bool transform::TransformResults::isSet(unsigned resultNumber) const { + assert(resultNumber < paramSegments.size() && + "querying association for a non-existent handle"); + return paramSegments[resultNumber].data() != nullptr || + segments[resultNumber].data() != nullptr; +} + //===----------------------------------------------------------------------===// // Utilities for TransformEachOpTrait. //===----------------------------------------------------------------------===// @@ -432,57 +453,43 @@ Location transformOpLoc = transformOp->getLoc(); StringRef transformOpName = transformOp->getName().getStringRef(); unsigned expectedNumResults = transformOp->getNumResults(); - // TODO: encode this implicit must always produce `expectedNumResults` - // and nullptr is fine with a proper trait. + + // Emit an error at the transform op with a note pointing to the payload op.
+ auto emitDiag = [&]() { + auto diag = mlir::emitError(transformOpLoc); + diag.attachNote(payloadOpLoc) << "when applied to this op"; + return diag; + }; + if (partialResult.size() != expectedNumResults) { - auto diag = mlir::emitError(transformOpLoc, "applications of ") - << transformOpName << " expected to produce " - << expectedNumResults << " results (actually produced " - << partialResult.size() << ")."; + auto diag = emitDiag() << "application of " << transformOpName + << " expected to produce " << expectedNumResults + << " results (actually produced " + << partialResult.size() << ")."; diag.attachNote(transformOpLoc) - << "If you need variadic results, consider a generic `apply` " + << "if you need variadic results, consider a generic `apply` " << "instead of the specialized `applyToOne`."; - diag.attachNote(transformOpLoc) - << "Producing " << expectedNumResults << " null results is " - << "allowed if the use case warrants it."; - diag.attachNote(payloadOpLoc) << "when applied to this op"; - return failure(); - } - - // Check that all is null or none is null - // TODO: relax this behavior and encode with a proper trait. - if (llvm::any_of( - partialResult, - [](llvm::PointerUnion ptr) { return ptr; }) && - llvm::any_of(partialResult, - [](llvm::PointerUnion ptr) { - return !ptr; - })) { - auto diag = mlir::emitError(transformOpLoc, "unexpected application of ") - << transformOpName - << " produces both null and non null results."; - diag.attachNote(payloadOpLoc) << "when applied to this op"; return failure(); } // Check that the right kind of value was produced. 
for (const auto &[ptr, res] : llvm::zip(partialResult, transformOp->getResults())) { + if (ptr.isNull()) { + return emitDiag() << "null result #" << res.getResultNumber() + << " produced"; + } if (ptr.is() && !res.getType().template isa()) { - mlir::emitError(transformOpLoc) - << "applications of " << transformOpName - << " expected to produce an Attribute for result #" - << res.getResultNumber(); - return failure(); + return emitDiag() << "application of " << transformOpName + << " expected to produce an Attribute for result #" + << res.getResultNumber(); } if (ptr.is() && !res.getType().template isa()) { - mlir::emitError(transformOpLoc) - << "applications of " << transformOpName - << " expected to produce an Operation * for result #" - << res.getResultNumber(); - return failure(); + return emitDiag() << "application of " << transformOpName + << " expected to produce an Operation * for result #" + << res.getResultNumber(); } } return success(); diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -444,8 +444,6 @@ << "could not find a producer for operand number: " << operandNumber << " of " << *target; diag.attachNote(target->getLoc()) << "target op"; - results.set(getResult().cast(), - SmallVector{}); return diag; } producers.push_back(producer); @@ -518,10 +516,6 @@ getHandle() ? state.getPayloadOps(getHandle()).size() : 0; int64_t expectedNumResultHandles = getNumResultHandles(); if (numResultHandles != expectedNumResultHandles) { - // Failing case needs to propagate gracefully for both suppress and - // propagate modes. - for (int64_t idx = 0; idx < expectedNumResultHandles; ++idx) - results.set(getResults()[idx].cast(), {}); // Empty input handle corner case: always propagates empty handles in both // suppress and propagate modes. 
if (numResultHandles == 0) diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -409,9 +409,8 @@ transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation - // expected-error @below {{applications of transform.test_wrong_number_of_results expected to produce 3 results (actually produced 1).}} - // expected-note @below {{If you need variadic results, consider a generic `apply` instead of the specialized `applyToOne`.}} - // expected-note @below {{Producing 3 null results is allowed if the use case warrants it.}} + // expected-error @below {{application of transform.test_wrong_number_of_results expected to produce 3 results (actually produced 1).}} + // expected-note @below {{if you need variadic results, consider a generic `apply` instead of the specialized `applyToOne`.}} transform.test_wrong_number_of_results %0 } } @@ -437,9 +436,8 @@ transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation - // expected-error @below {{applications of transform.test_wrong_number_of_multi_results expected to produce 1 results (actually produced 0)}} - // expected-note @below {{If you need variadic results, consider a generic `apply` instead of the specialized `applyToOne`.}} - // expected-note @below {{Producing 1 null results is allowed if the use case warrants it.}} + // expected-error @below {{application of transform.test_wrong_number_of_multi_results expected to produce 1 results (actually produced 0)}} + // expected-note @below {{if you need variadic results, consider a generic `apply` instead of the specialized `applyToOne`.}} transform.test_wrong_number_of_multi_results %0 } } @@ -514,7 +512,7 @@ 
transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation - // expected-error @below {{unexpected application of transform.test_mixed_null_and_non_null_results produces both null and non null results.}} + // expected-error @below {{null result #0 produced}} transform.test_mixed_null_and_non_null_results %0 } } @@ -1041,12 +1039,15 @@ // ----- -transform.sequence failures(propagate) { -^bb0(%arg0: !transform.any_op): - // expected-error @below {{expected to produce an Operation * for result #0}} - transform.test_produce_transform_param_or_forward_operand %arg0 - { first_result_is_param } - : (!transform.any_op) -> (!transform.any_op, !transform.param) +// expected-note @below {{when applied to this op}} +module { + transform.sequence failures(propagate) { + ^bb0(%arg0: !transform.any_op): + // expected-error @below {{expected to produce an Operation * for result #0}} + transform.test_produce_transform_param_or_forward_operand %arg0 + { first_result_is_param } + : (!transform.any_op) -> (!transform.any_op, !transform.param) + } } // ----- @@ -1055,7 +1056,7 @@ module { transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): - // expected-error @below {{produces both null and non null results}} + // expected-error @below {{null result #0 produced}} transform.test_produce_transform_param_or_forward_operand %arg0 { first_result_is_null } : (!transform.any_op) -> (!transform.any_op, !transform.param) @@ -1064,12 +1065,15 @@ // ----- -transform.sequence failures(propagate) { -^bb0(%arg0: !transform.any_op): - // expected-error @below {{expected to produce an Attribute for result #1}} - transform.test_produce_transform_param_or_forward_operand %arg0 - { second_result_is_handle } - : (!transform.any_op) -> (!transform.any_op, !transform.param) +// expected-note @below {{when applied to this op}} +module { + transform.sequence failures(propagate) { + 
^bb0(%arg0: !transform.any_op): + // expected-error @below {{expected to produce an Attribute for result #1}} + transform.test_produce_transform_param_or_forward_operand %arg0 + { second_result_is_handle } + : (!transform.any_op) -> (!transform.any_op, !transform.param) + } } // -----