diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -153,6 +153,10 @@ /// values in the payload IR. Also works for reverse mappings. using ValueMapping = DenseMap>; + /// Mapping between a Value in the transform IR and a callback that emits + /// the error to report when that value is used. + using InvalidatedHandleMap = DenseMap>; + /// The bidirectional mappings between transform IR values and payload IR /// operations, and the mapping between transform IR values and parameters. struct Mappings { @@ -567,26 +571,85 @@ /// handle. LogicalResult replacePayloadValue(Value value, Value replacement); - /// If the operand is a handle consumed by the operation, i.e. has the "free" - /// memory effect associated with it, identifies other handles that are - /// pointing to payload IR operations nested in the operations pointed to by - /// the consumed handle. Marks all such handles as invalidated to trigger - /// errors if they are used. If `throughValue` is passed, record the fact that - /// an op handle was invalidated because a value handle associated with - /// results of the payload op or its block arguments was invalidated. + /// Records handle invalidation reporters into `newlyInvalidated`. + /// Specifically, + /// - `consumingHandle` is the op operand that consumes the handle, + /// - `potentialAncestors` is a list of ancestors of the payload operation + /// that the consumed handle is associated with, including itself, + /// - `throughValue` is the payload value whose handle is consumed, if a + /// value handle is being consumed, or null when the operation handle + /// is consumed directly.
+ /// Iterates over all known operation and value handles and records reporters + /// for any potential future use of `consumingHandle` or any other handle + /// that is invalidated by its consumption, i.e., any handle pointing to any + /// payload IR entity (operation or value) associated with the same payload IR + /// entity as the consumed handle, or any nested payload IR entity. If + /// `potentialAncestors` is empty, records a reporter for the consumed handle + /// itself; otherwise, does not override existing reporters. Must remain const + /// so it doesn't inadvertently mutate `invalidatedHandles` too early. void recordOpHandleInvalidation(OpOperand &consumingHandle, ArrayRef potentialAncestors, - Value throughValue = nullptr); - void recordOpHandleInvalidationOne(OpOperand &handle, - ArrayRef potentialAncestors, - Operation *payloadOp, Value otherHandle, - Value throughValue = nullptr); - + Value throughValue, + InvalidatedHandleMap &newlyInvalidated) const; + + /// Records handle invalidation reporters into `newlyInvalidated`. + /// Specifically, + /// - `consumingHandle` is the op operand that consumes the handle, + /// - `potentialAncestors` is a list of ancestors of the payload operation + /// that the consumed handle is associated with, including itself, + /// - `payloadOp` is a payload operation associated with `otherHandle`, + /// - `otherHandle` is another handle that may be invalidated by the + /// consumption of `consumingHandle`, + /// - `throughValue` is the payload value whose handle is consumed, if a + /// value handle is being consumed, or null when the operation handle + /// is consumed directly. + /// Looks at the payload operations associated with `otherHandle` and if any + /// of these operations has an ancestor (or is itself) listed in + /// `potentialAncestors`, records a reporter describing the use of the + /// invalidated handle. Does nothing if `otherHandle` already has a reporter + /// associated with it.
This must remain a const method so it doesn't + /// inadvertently mutate `invalidatedHandles` too early. + void recordOpHandleInvalidationOne( + OpOperand &consumingHandle, ArrayRef potentialAncestors, + Operation *payloadOp, Value otherHandle, Value throughValue, + InvalidatedHandleMap &newlyInvalidated) const; + + /// Records handle invalidation reporters into `newlyInvalidated`. + /// Specifically, + /// - `opHandle` is the op operand that consumes the handle; + /// - `potentialAncestors` is a list of ancestors of the payload operation + /// that the consumed handle is associated with, including itself; + /// - `payloadValue` is a payload value, an op result or a block argument, + /// that `valueHandle` is associated with; + /// - `valueHandle` is another handle that may be invalidated by the consumption. + /// Looks at the payload values associated with `valueHandle` and if any of + /// these values is defined, as op result or block argument, by an operation + /// whose ancestor (or the operation itself) is listed in + /// `potentialAncestors`, records a reporter describing the use of the + /// invalidated handle. Does nothing if `valueHandle` already has a reporter + /// associated with it. This must remain a const method so it doesn't + /// inadvertently mutate `invalidatedHandles` too early. void recordValueHandleInvalidationByOpHandleOne( OpOperand &opHandle, ArrayRef potentialAncestors, - Value payloadValue, Value valueHandle); - - void recordValueHandleInvalidation(OpOperand &valueHandle); + Value payloadValue, Value valueHandle, + InvalidatedHandleMap &newlyInvalidated) const; + + /// Records handle invalidation reporters into `newlyInvalidated`. + /// Specifically, + /// - `valueHandle` is the op operand that consumes the handle, + /// - `newlyInvalidated` is the map receiving the reporters, which + /// `checkAndRecordHandleInvalidation` later merges into + /// `invalidatedHandles`.
+ /// Iterates over all known operation and value handles and records reporters + /// for any potential future use of `valueHandle` or any other handle that + /// is invalidated by its consumption, i.e., any handle pointing to any + /// payload IR entity (operation or value) associated with the same payload + /// IR entity as the consumed handle, or any nested payload IR entity. Does + /// not override existing reporters. This must remain a const method so it + /// doesn't inadvertently mutate `invalidatedHandles` too early. + void + recordValueHandleInvalidation(OpOperand &valueHandle, + InvalidatedHandleMap &newlyInvalidated) const; /// Checks that the operation does not use invalidated handles as operands. /// Reports errors and returns failure if it does. Otherwise, invalidates the @@ -596,6 +659,13 @@ LogicalResult checkAndRecordHandleInvalidation(TransformOpInterface transform); + /// Implementation of `checkAndRecordHandleInvalidation` that records new + /// reporters into `newlyInvalidated`. This must remain a const method so it + /// doesn't inadvertently mutate `invalidatedHandles` too early. + LogicalResult checkAndRecordHandleInvalidationImpl( + transform::TransformOpInterface transform, + transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const; + /// Remove all nullptrs from op handles that were added by `replacePayloadOp`. void compactOpHandles(); @@ -628,7 +698,7 @@ /// describe when the handles were invalidated. Calling such a function emits /// a user-visible diagnostic with an additional note pointing to the given /// location. - DenseMap> invalidatedHandles; + InvalidatedHandleMap invalidatedHandles; #if LLVM_ENABLE_ABI_BREAKING_CHECKS /// A stack of nested regions that are being processed in the transform IR.
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -431,10 +431,13 @@ void transform::TransformState::recordOpHandleInvalidationOne( OpOperand &consumingHandle, ArrayRef potentialAncestors, - Operation *payloadOp, Value otherHandle, Value throughValue) { + Operation *payloadOp, Value otherHandle, Value throughValue, + transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { // If the op is associated with invalidated handle, skip the check as it - // may be reading invalid IR. - if (invalidatedHandles.count(otherHandle)) + // may be reading invalid IR. Also skipping handles already recorded in + // `newlyInvalidated` ensures we report the first invalidation, not the last. + if (invalidatedHandles.count(otherHandle) || + newlyInvalidated.count(otherHandle)) return; FULL_LDBG("--recordOpHandleInvalidationOne\n"); @@ -467,9 +470,9 @@ Location opLoc = payloadOp->getLoc(); std::optional throughValueLoc = throughValue ?
std::make_optional(throughValue.getLoc()) : std::nullopt; - invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo, - otherHandle, - throughValueLoc](Location currentLoc) { + newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo, + otherHandle, + throughValueLoc](Location currentLoc) { InFlightDiagnostic diag = emitError(currentLoc) << "op uses a handle invalidated by a " "previously executed transform op"; @@ -490,11 +493,14 @@ } void transform::TransformState::recordValueHandleInvalidationByOpHandleOne( - OpOperand &consumingHandle, ArrayRef potentialAncestors, - Value payloadValue, Value valueHandle) { + OpOperand &opHandle, ArrayRef potentialAncestors, + Value payloadValue, Value valueHandle, + transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { // If the op is associated with invalidated handle, skip the check as it - // may be reading invalid IR. - if (invalidatedHandles.count(valueHandle)) + // may be reading invalid IR. Also skipping handles already recorded in + // `newlyInvalidated` ensures we report the first invalidation, not the last.
+ if (invalidatedHandles.count(valueHandle) || + newlyInvalidated.count(valueHandle)) return; for (Operation *ancestor : potentialAncestors) { @@ -517,12 +523,12 @@ if (!ancestor->isAncestor(definingOp)) continue; - Operation *owner = consumingHandle.getOwner(); - unsigned operandNo = consumingHandle.getOperandNumber(); + Operation *owner = opHandle.getOwner(); + unsigned operandNo = opHandle.getOperandNumber(); Location ancestorLoc = ancestor->getLoc(); Location opLoc = definingOp->getLoc(); Location valueLoc = payloadValue.getLoc(); - invalidatedHandles[valueHandle] = + newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo, argumentNo, blockNo, regionNo, ancestorLoc, opLoc, valueLoc](Location currentLoc) { InFlightDiagnostic diag = emitError(currentLoc) @@ -551,7 +557,8 @@ void transform::TransformState::recordOpHandleInvalidation( OpOperand &handle, ArrayRef potentialAncestors, - Value throughValue) { + Value throughValue, + transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { if (potentialAncestors.empty()) { DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { @@ -561,7 +568,7 @@ Operation *owner = handle.getOwner(); unsigned operandNo = handle.getOperandNumber(); - invalidatedHandles[handle.get()] = [owner, operandNo](Location currentLoc) { + newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) { InFlightDiagnostic diag = emitError(currentLoc) << "op uses a handle associated with empty " "payload and invalidated by a " @@ -580,14 +587,16 @@ // number of IR objects (operations and values). Alternatively, we could walk // the IR nested in each payload op associated with the given handle and look // for handles associated with each operation and value.
- for (const Mappings &mapping : llvm::make_second_range(mappings)) { + for (const transform::TransformState::Mappings &mapping : + llvm::make_second_range(mappings)) { // Go over all op handle mappings and mark as invalidated any handle // pointing to any of the payload ops associated with the given handle or // any op nested in them. for (const auto &[payloadOp, otherHandles] : mapping.reverse) { for (Value otherHandle : otherHandles) recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp, - otherHandle, throughValue); + otherHandle, throughValue, + newlyInvalidated); } // Go over all value handle mappings and mark as invalidated any handle // pointing to any result of the payload op associated with the given handle @@ -597,13 +606,15 @@ for (const auto &[payloadValue, valueHandles] : mapping.reverseValues) { for (Value valueHandle : valueHandles) recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors, - payloadValue, valueHandle); + payloadValue, valueHandle, + newlyInvalidated); } } } void transform::TransformState::recordValueHandleInvalidation( - OpOperand &valueHandle) { + OpOperand &valueHandle, + transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { // Invalidate other handles to the same value.
for (Value payloadValue : getPayloadValues(valueHandle.get())) { SmallVector otherValueHandles; @@ -612,8 +623,8 @@ Operation *owner = valueHandle.getOwner(); unsigned operandNo = valueHandle.getOperandNumber(); Location valueLoc = payloadValue.getLoc(); - invalidatedHandles[otherHandle] = [otherHandle, owner, operandNo, - valueLoc](Location currentLoc) { + newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo, + valueLoc](Location currentLoc) { InFlightDiagnostic diag = emitError(currentLoc) << "op uses a handle invalidated by a " "previously executed transform op"; @@ -629,17 +640,24 @@ if (auto opResult = llvm::dyn_cast(payloadValue)) { Operation *payloadOp = opResult.getOwner(); - recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue); + recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue, + newlyInvalidated); } else { auto arg = llvm::dyn_cast(payloadValue); for (Operation &payloadOp : *arg.getOwner()) - recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue); + recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue, + newlyInvalidated); } } } -LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( - TransformOpInterface transform) { +/// Checks that the operation does not use invalidated handles as operands, +/// emitting an error and returning failure at the first such operand. Records +/// into `newlyInvalidated`, without touching `invalidatedHandles`, reporters +/// for the handles invalidated by this operation consuming its operands.
+LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl( + transform::TransformOpInterface transform, + transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { FULL_LDBG("--Start checkAndRecordHandleInvalidation\n"); auto memoryEffectsIface = cast(transform.getOperation()); @@ -651,13 +669,23 @@ DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { (DBGS() << "----iterate on handle: " << target.get() << "\n"); }); - // If the operand uses an invalidated handle, report it. + // If the operand uses an invalidated handle, report it. If the operation + // allows handles to point to repeated payload operations, only report + // pre-existing invalidation errors. Otherwise, also report invalidations + // caused by the current transform operation affecting its other operands. auto it = invalidatedHandles.find(target.get()); - if (!transform.allowsRepeatedHandleOperands() && - it != invalidatedHandles.end()) { - FULL_LDBG("--End checkAndRecordHandleInvalidation -> FAILURE\n"); + auto nit = newlyInvalidated.find(target.get()); + if (it != invalidatedHandles.end()) { + FULL_LDBG("--End checkAndRecordHandleInvalidation, found already " + "invalidated -> FAILURE\n"); return it->getSecond()(transform->getLoc()), failure(); } + if (!transform.allowsRepeatedHandleOperands() && + nit != newlyInvalidated.end()) { + FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly " + "invalidated (by this op) -> FAILURE\n"); + return nit->getSecond()(transform->getLoc()), failure(); + } // Invalidate handles pointing to the operations nested in the operation // associated with the handle consumed by this operation.
@@ -666,15 +694,18 @@ effect.getValue() == target.get(); }; if (llvm::any_of(effects, consumesTarget)) { - FULL_LDBG("----found consume effect -> SKIP\n"); - if (llvm::isa(target.get().getType())) { + FULL_LDBG("----found consume effect\n"); + if (llvm::isa( + target.get().getType())) { FULL_LDBG("----recordOpHandleInvalidation\n"); - ArrayRef payloadOps = getPayloadOpsView(target.get()); - recordOpHandleInvalidation(target, payloadOps); - } else if (llvm::isa( + SmallVector payloadOps = + llvm::to_vector(getPayloadOps(target.get())); + recordOpHandleInvalidation(target, payloadOps, nullptr, + newlyInvalidated); + } else if (llvm::isa( target.get().getType())) { FULL_LDBG("----recordValueHandleInvalidation\n"); - recordValueHandleInvalidation(target); + recordValueHandleInvalidation(target, newlyInvalidated); } else { FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n"); } @@ -687,6 +718,16 @@ return success(); } +LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( + transform::TransformOpInterface transform) { + InvalidatedHandleMap newlyInvalidated; + LogicalResult checkResult = + checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated); + invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()), + std::make_move_iterator(newlyInvalidated.end())); + return checkResult; +} + template DiagnosedSilenceableFailure checkRepeatedConsumptionInOperand(ArrayRef payload, diff --git a/mlir/test/Dialect/Transform/expensive-checks.mlir b/mlir/test/Dialect/Transform/expensive-checks.mlir --- a/mlir/test/Dialect/Transform/expensive-checks.mlir +++ b/mlir/test/Dialect/Transform/expensive-checks.mlir @@ -342,3 +342,25 @@ // expected-error @below {{uses a handle associated with empty payload and invalidated by a previously executed transform op}} transform.test_print_remark_at_operand %0, "remark" : !transform.any_op } + +// ----- + +// Make sure we properly report a use-after-consume error when repeated handles
+// are allowed in the consuming op. We still want to report handles invalidated +// by _previous_ operations, just not by this one. To bypass the quick static +// check of repeated consumption, create a handle to the transform op and +// consume the handle to the root module, thus invalidating all other handles. + +// expected-note @below {{ancestor payload op}} +module { + transform.sequence failures(propagate) { + ^bb0(%arg0: !transform.any_op): + // expected-note @below {{handle to invalidated ops}} + // expected-note @below {{nested payload op}} + %0 = transform.test_produce_self_handle_or_forward_operand : () -> !transform.any_op + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + transform.test_consume_operand %arg0 : !transform.any_op + // expected-error @below {{uses a handle invalidated by a previously executed transform op}} + transform.test_consume_operand %0 { allow_repeated_handles } : !transform.any_op + } +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -178,6 +178,10 @@ transform::onlyReadsPayload(effects); } +bool mlir::test::TestConsumeOperand::allowsRepeatedHandleOperands() { + return getAllowRepeatedHandles(); +} + DiagnosedSilenceableFailure mlir::test::TestConsumeOperand::apply(transform::TransformResults &results, transform::TransformState &state) { diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -97,11
+97,12 @@ } def TestConsumeOperand : Op, + [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let arguments = (ins Transform_AnyHandleOrParamType:$operand, - Optional:$second_operand); + Optional:$second_operand, + UnitAttr:$allow_repeated_handles); let assemblyFormat = "$operand (`,` $second_operand^)? attr-dict `:` type($operand)" "(`,` type($second_operand)^)?";