diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -443,6 +443,9 @@ llvm::interleaveComma(potentialAncestors, DBGS() << "--ancestors: ", [](Operation *op) { llvm::dbgs() << *op; }); llvm::dbgs() << "\n"); + + Operation *owner = consumingHandle.getOwner(); + unsigned operandNo = consumingHandle.getOperandNumber(); for (Operation *ancestor : potentialAncestors) { // clang-format off DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, @@ -462,8 +465,6 @@ // deleted before the lambda gets called. Location ancestorLoc = ancestor->getLoc(); Location opLoc = payloadOp->getLoc(); - Operation *owner = consumingHandle.getOwner(); - unsigned operandNo = consumingHandle.getOperandNumber(); std::optional<Location> throughValueLoc = throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt; invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo, @@ -551,6 +552,27 @@ void transform::TransformState::recordOpHandleInvalidation( OpOperand &handle, ArrayRef<Operation *> potentialAncestors, Value throughValue) { + + if (potentialAncestors.empty()) { + DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { + (DBGS() << "----recording invalidation for empty handle: " << handle.get() + << "\n"); + }); + + Operation *owner = handle.getOwner(); + unsigned operandNo = handle.getOperandNumber(); + invalidatedHandles[handle.get()] = [owner, operandNo](Location currentLoc) { + InFlightDiagnostic diag = emitError(currentLoc) + << "op uses a handle associated with empty " + "payload and invalidated by a " + "previously executed transform op"; + diag.attachNote(owner->getLoc()) + << "invalidated by this transform op that consumes its operand #" + << operandNo; + }; + return; + } + // Iterate over the mapping and invalidate aliasing handles.
This is quite // expensive and only necessary for error reporting in case of transform // dialect misuse with dangling handles. Iteration over the handles is based diff --git a/mlir/test/Dialect/Transform/expensive-checks.mlir b/mlir/test/Dialect/Transform/expensive-checks.mlir --- a/mlir/test/Dialect/Transform/expensive-checks.mlir +++ b/mlir/test/Dialect/Transform/expensive-checks.mlir @@ -331,3 +331,14 @@ test_consume_operand %3 : !transform.any_value test_consume_operand %2 : !transform.any_op } + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + %0 = transform.test_produce_empty_payload : !transform.any_op + // expected-note @below {{invalidated by this transform op that consumes its operand #0}} + transform.test_consume_operand %0 : !transform.any_op + // expected-error @below {{uses a handle associated with empty payload and invalidated by a previously executed transform op}} + transform.test_print_remark_at_operand %0, "remark" : !transform.any_op +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -627,6 +627,12 @@ return DiagnosedSilenceableFailure::success(); } +DiagnosedSilenceableFailure mlir::test::TestProduceEmptyPayloadOp::apply( + transform::TransformResults &results, transform::TransformState &state) { + results.set(cast<OpResult>(getOut()), {}); + return DiagnosedSilenceableFailure::success(); +} + void mlir::test::TestProduceNullParamOp::getEffects( SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { transform::producesHandle(getOut(), effects); diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ 
b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -427,6 +427,15 @@ let cppNamespace = "::mlir::test"; } +def TestProduceEmptyPayloadOp + : Op<Transform_Dialect, "test_produce_empty_payload", + [DeclareOpInterfaceMethods<TransformOpInterface>, + MemoryEffectsOpInterface, FunctionalStyleTransformOpTrait]> { + let results = (outs TransformHandleTypeInterface:$out); + let assemblyFormat = "attr-dict `:` type($out)"; + let cppNamespace = "::mlir::test"; +} + def TestProduceNullParamOp : Op,