diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -66,11 +66,11 @@ A Transform IR value such as `%0` may be associated with multiple payload operations. This is conceptually a set of operations and no assumptions - should be made about the order of ops. Most Transform IR ops support - operand values that are mapped to multiple operations. They usually apply - the respective transformation for every mapped op ("batched execution"). - Deviations from this convention are described in the documentation of - Transform IR ops. + should be made about the order of ops unless specified otherwise by the + operation. Most Transform IR ops support operand values that are mapped to + multiple operations. They usually apply the respective transformation for + every mapped op ("batched execution"). Deviations from this convention are + described in the documentation of Transform IR ops. Overall, Transform IR ops are expected to be contained in a single top-level op. Such top-level ops specify how to apply the transformations described @@ -161,17 +161,18 @@ ## Execution Model - The transformation starts at the specifed top-level transform IR operation - and applies to some payload IR scope, identified by the payload IR op that - contains the IR to transform. It is the responsibility of the user to - properly select the scope and/or to avoid the transformations to modify the - IR outside of the given scope. The top-level transform IR operation may - contain further transform operations and execute them in the desired order. + The transformation starts at the user-specified top-level transform IR + operation and applies to some user-specified payload IR scope, identified by + the payload IR op that contains the IR to transform. 
It is the + responsibility of the user to properly select the scope and/or to prevent the + transformations from modifying the IR outside of the given scope. The top-level + transform IR operation may contain further transform operations and execute + them in the desired order. Transformation application functions produce a tri-state status: - success; - - recoverable (silencable) failure; + - recoverable (silenceable) failure; - irrecoverable failure. Transformation container operations may intercept recoverable failures and @@ -180,9 +181,9 @@ failures, the diagnostics are emitted immediately whereas their emission is postponed for recoverable failures. Transformation container operations may also fail to recover from a theoretically recoverable failure, in which case - they are expected to emit the diagnostic and turn the failure into an - irrecoverable one. A recoverable failure produced by applying the top-level - transform IR operation is considered irrecoverable. + they can either propagate it to their parent or emit the diagnostic and turn + the failure into an irrecoverable one. A recoverable failure produced by + applying the top-level transform IR operation is considered irrecoverable. Transformation container operations are allowed to "step over" some nested operations if the application of some previous operation produced a failure. @@ -193,26 +194,18 @@ ## Handle Invalidation - The execution model of the transform dialect expects that a payload IR - operation is associated with _at most one_ transform IR handle. This avoids - the situation when a handle to an operation outlives the operation itself - that can be erased during a transformation triggered through another handle. - - Handles pointing to operations nested in each other are allowed to co-exist - in the transform IR.
However, a transform IR operation that consumes such a - handle automatically _invalidates_ all the other handles that are associated - with operations nested in the operations associated with the consumed - handle. Any use of the invalidated handle results in undefined behavior - since the payload IR operations associated with it are likely to have been - mutated or erased. The mere fact of the handle being invalidated does _not_ - trigger undefined behavior, only its appearance as an operand does. - Invalidation applies to the entire handle, even if some of the payload IR - operations associated with it are not nested in payload IR operations - associated with another, consumed handle. - - Note: the restriction on two handles not pointing to the same operation may - be relaxed in the future to follow the invalidation model for nested - operation. + The execution model of the transform dialect allows a payload IR operation + to be associated with _multiple_ handles as well as nested payload IR + operations to be associated with different handles. A transform IR operation + that consumes a handle automatically _invalidates_ all the other handles + associated with the same payload IR operations, or with any of their + descendants, as the consumed handle. Note that the _entire_ handle is + invalidated, even if some of the payload IR operations associated with it + or their ancestors were not associated with the consumed handle. Any use of + the invalidated handle results in undefined behavior since the payload IR + operations associated with it are likely to have been mutated or erased. The + mere fact of the handle being invalidated does _not_ trigger undefined + behavior, only its appearance as an operand does. 
The Transform dialect infrastructure has the capability of checking whether the transform IR op operand is invalidated before applying the diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -209,11 +209,13 @@ /// TransformOpInterface. The operations implementing this interface and the /// surrounding structure are referred to as transform IR. The operations to /// which transformations apply are referred to as payload IR. The state thus -/// contains the mapping between values defined in the transform IR ops and -/// payload IR ops. It assumes that each value in the transform IR can be used -/// at most once (since transformations are likely to change the payload IR ops -/// the value corresponds to). Checks that transform IR values correspond to -/// disjoint sets of payload IR ops throughout the transformation. +/// contains the many-to-many mapping between values defined in the transform IR +/// ops and payload IR ops. The "expensive-checks" option can be passed to +/// the constructor to check at transformation execution time that transform IR +/// values used as operands by a transform IR operation are not associated with +/// dangling pointers to payload IR operations that are known to have been +/// erased by a previous transformation through the same or a different +/// transform IR value. /// /// A reference to this class is passed as an argument to "apply" methods of the /// transform op interface. Thus the "apply" method can call @@ -235,9 +237,10 @@ /// operations in the payload IR. using TransformOpMapping = DenseMap<Value, SmallVector<Operation *>>; - /// Mapping between a payload IR operation and the transform IR value it is - /// currently associated with.
- using TransformOpReverseMapping = DenseMap<Operation *, Value>; + /// Mapping between a payload IR operation and the transform IR values it is + /// associated with. + using TransformOpReverseMapping = + DenseMap<Operation *, SmallVector<Value>>; /// Bidirectional mappings between transform IR values and payload IR /// operations. @@ -249,7 +252,7 @@ public: /// Creates a state for transform ops living in the given region. The parent /// operation of the region. The second argument points to the root operation - /// in the payload IR beind transformed, which may or may not contain the + /// in the payload IR being transformed, which may or may not contain the /// region with transform ops. Additional options can be provided through the /// trailing configuration object. TransformState(Region &region, Operation *root, @@ -263,9 +266,10 @@ /// This is helpful for transformations that apply to a particular handle. ArrayRef<Operation *> getPayloadOps(Value value) const; - /// Returns the Transform IR handle for the given Payload IR op if it exists - /// in the state, null otherwise. - Value getHandleForPayloadOp(Operation *op) const; + /// Populates `handles` with all handles pointing to the given Payload IR op. + /// Returns success if such handles exist, failure otherwise. + LogicalResult getHandlesForPayloadOp(Operation *op, + SmallVectorImpl<Value> &handles) const; /// Applies the transformation specified by the given transform op and updates /// the state accordingly. @@ -275,13 +279,13 @@ /// list of operations in the payload IR. The arguments must be defined in /// blocks of the currently processed transform IR region, typically after a /// region scope is defined.
- LogicalResult mapBlockArguments(BlockArgument argument, - ArrayRef<Operation *> operations) { + void mapBlockArguments(BlockArgument argument, + ArrayRef<Operation *> operations) { #if LLVM_ENABLE_ABI_BREAKING_CHECKS assert(argument.getParentRegion() == regionStack.back() && "mapping block arguments from a region other than the active one"); #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS - return setPayloadOps(argument, operations); + setPayloadOps(argument, operations); } // Forward declarations to support limited visibility. @@ -379,7 +383,8 @@ const TransformState &getTransformState() const { return state; } /// Replaces the given payload op with another op. If the replacement op is - /// null, removes the association of the payload op with its handle. + /// null, removes the association of the payload op with its handles. Returns + /// failure if the op is not associated with any handle. LogicalResult replacePayloadOp(Operation *op, Operation *replacement); private: @@ -451,20 +456,29 @@ return it->second; } - /// Sets the payload IR ops associated with the given transform IR value. - /// Fails if this would result in multiple transform IR values with uses - /// corresponding to the same payload IR ops. For example, a hypothetical - /// "find function by name" transform op would (indirectly) call this - /// function for its result. Having two such calls in a row with for different - /// values, e.g. coming from different ops: + /// Removes the mapping between the given payload IR operation and the given + /// transform IR value. + void dropReverseMapping(Mappings &mappings, Operation *op, Value value); + + /// Sets the payload IR ops associated with the given transform IR value + /// (handle). A payload op may be associated with multiple handles as long as + /// at most one of them gets consumed by further transformations.
+ /// For example, a hypothetical "find function by name" transform op may be + /// called twice in a row to produce two handles pointing to the same function: /// /// %0 = transform.find_func_by_name { name = "myfunc" } /// %1 = transform.find_func_by_name { name = "myfunc" } /// - /// would lead to both values pointing to the same operation. The second call - /// to setPayloadOps will fail, unless the association with the %0 value is - /// removed first by calling update/removePayloadOps. - LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets); + /// which is valid by itself. However, calling a hypothetical "rewrite and + /// rename function" transform on both handles: + /// + /// transform.rewrite_and_rename %0 { new_name = "func" } + /// transform.rewrite_and_rename %1 { new_name = "func" } + /// + /// is invalid because the transformation "consumes" the handle as expressed + /// by side effects. Practically, a transformation consuming a handle means + /// that the associated payload operation may no longer exist. + void setPayloadOps(Value value, ArrayRef<Operation *> targets); /// Forgets the payload IR ops associated with the given transform IR value. void removePayloadOps(Value value); @@ -473,24 +487,18 @@ /// The callback function is called once per associated operation and is /// expected to return the modified operation or nullptr. In the latter case, /// the corresponding operation is no longer associated with the transform IR - /// value. May fail if the operation produced by the update callback is - /// already associated with a different Transform IR handle value. - LogicalResult - updatePayloadOps(Value value, - function_ref<Operation *(Operation *)> callback); - - /// Attempts to record the mapping between the given Payload IR operation and - /// the given Transform IR handle. Fails and reports an error if the operation - /// is already tracked by another handle. - static LogicalResult tryEmplaceReverseMapping(Mappings &map, Operation *op, - Value handle); + /// value.
+ void updatePayloadOps(Value value, + function_ref<Operation *(Operation *)> callback); /// If the operand is a handle consumed by the operation, i.e. has the "free" /// memory effect associated with it, identifies other handles that are /// pointing to payload IR operations nested in the operations pointed to by - /// the consumed handle. Marks all such handles as invalidated so trigger + /// the consumed handle. Marks all such handles as invalidated to trigger /// errors if they are used. void recordHandleInvalidation(OpOperand &handle); + void recordHandleInvalidationOne(OpOperand &handle, Operation *payloadOp, + Value otherHandle); /// Checks that the operation does not use invalidated handles as operands. /// Reports errors and returns failure if it does. Otherwise, invalidates the @@ -566,9 +574,9 @@ /// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait /// to either the list of operations associated with its operand or the root of /// the payload IR, depending on what is available in the context. -LogicalResult -mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, - Operation *op, Region &region); +void mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, + Operation *op, + Region &region); /// Verification hook for PossibleTopLevelTransformOpTrait. LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op); @@ -605,18 +613,17 @@ /// Sets up the mapping between the entry block of the given region of this op /// and the relevant list of Payload IR operations in the given state. The /// state is expected to be already scoped at the region of this operation. - /// Returns failure if the mapping failed, e.g., the value is already mapped.
- LogicalResult mapBlockArguments(TransformState &state, Region &region) { + void mapBlockArguments(TransformState &state, Region &region) { assert(region.getParentOp() == this->getOperation() && "op comes from the wrong region"); - return detail::mapPossibleTopLevelTransformOpBlockArguments( + detail::mapPossibleTopLevelTransformOpBlockArguments( state, this->getOperation(), region); } - LogicalResult mapBlockArguments(TransformState &state) { + void mapBlockArguments(TransformState &state) { assert( this->getOperation()->getNumRegions() == 1 && "must indicate the region to map if the operation has more than one"); - return mapBlockArguments(state, this->getOperation()->getRegion(0)); + mapBlockArguments(state, this->getOperation()->getRegion(0)); } }; diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Operation.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "transform-dialect" @@ -45,35 +46,28 @@ return iter->getSecond(); } -Value transform::TransformState::getHandleForPayloadOp(Operation *op) const { +LogicalResult transform::TransformState::getHandlesForPayloadOp( + Operation *op, SmallVectorImpl<Value> &handles) const { + bool found = false; for (const Mappings &mapping : llvm::make_second_range(mappings)) { - if (Value handle = mapping.reverse.lookup(op)) - return handle; + auto iterator = mapping.reverse.find(op); + if (iterator != mapping.reverse.end()) { + llvm::append_range(handles, iterator->getSecond()); + found = true; + } } - return Value(); -} -LogicalResult transform::TransformState::tryEmplaceReverseMapping( - Mappings &map, Operation *operation, Value handle) { - auto insertionResult = map.reverse.insert({operation,
handle}); - if (!insertionResult.second && insertionResult.first->second != handle) { - InFlightDiagnostic diag = operation->emitError() - << "operation tracked by two handles"; - diag.attachNote(handle.getLoc()) << "handle"; - diag.attachNote(insertionResult.first->second.getLoc()) << "handle"; - return diag; - } - return success(); + return success(found); } -LogicalResult -transform::TransformState::setPayloadOps(Value value, - ArrayRef<Operation *> targets) { +void transform::TransformState::setPayloadOps(Value value, + ArrayRef<Operation *> targets) { assert(value != kTopLevelValue && "attempting to reset the transformation root"); + // TODO: this may go now if (value.use_empty()) - return success(); + return; // Setting new payload for the value without cleaning it first is a misuse of // the API, assert here. @@ -84,25 +78,29 @@ assert(inserted && "value is already associated with another list"); (void)inserted; - // Having multiple handles to the same operation is an error in the transform - // expressed using the dialect and may be constructed by valid API calls from - // valid IR. Emit an error here.
- for (Operation *op : targets) { - if (failed(tryEmplaceReverseMapping(mappings, op, value))) - return failure(); - } + for (Operation *op : targets) + mappings.reverse[op].push_back(value); +} - return success(); +void transform::TransformState::dropReverseMapping(Mappings &mappings, + Operation *op, Value value) { + auto it = mappings.reverse.find(op); + if (it == mappings.reverse.end()) + return; + + llvm::erase_value(it->getSecond(), value); + if (it->getSecond().empty()) + mappings.reverse.erase(it); } void transform::TransformState::removePayloadOps(Value value) { Mappings &mappings = getMapping(value); for (Operation *op : mappings.direct[value]) - mappings.reverse.erase(op); + dropReverseMapping(mappings, op, value); mappings.direct.erase(value); } -LogicalResult transform::TransformState::updatePayloadOps( +void transform::TransformState::updatePayloadOps( Value value, function_ref<Operation *(Operation *)> callback) { Mappings &mappings = getMapping(value); auto it = mappings.direct.find(value); @@ -112,60 +110,60 @@ updated.reserve(association.size()); for (Operation *op : association) { - mappings.reverse.erase(op); + dropReverseMapping(mappings, op, value); if (Operation *updatedOp = callback(op)) { updated.push_back(updatedOp); - if (failed(tryEmplaceReverseMapping(mappings, updatedOp, value))) - return failure(); + mappings.reverse[updatedOp].push_back(value); } } std::swap(association, updated); - return success(); } -void transform::TransformState::recordHandleInvalidation(OpOperand &handle) { +void transform::TransformState::recordHandleInvalidationOne( + OpOperand &handle, Operation *payloadOp, Value otherHandle) { ArrayRef<Operation *> potentialAncestors = getPayloadOps(handle.get()); - for (const Mappings &mapping : llvm::make_second_range(mappings)) { - for (const auto &kvp : mapping.reverse) { - // If the op is associated with invalidated handle, skip the check as it - // may be reading invalid IR.
- Operation *op = kvp.first; - Value otherHandle = kvp.second; - if (invalidatedHandles.count(otherHandle)) - continue; - - for (Operation *ancestor : potentialAncestors) { - if (!ancestor->isProperAncestor(op)) - continue; - - // Make sure the error-reporting lambda doesn't capture anything - // by-reference because it will go out of scope. Additionally, extract - // location from Payload IR ops because the ops themselves may be - // deleted before the lambda gets called. - Location ancestorLoc = ancestor->getLoc(); - Location opLoc = op->getLoc(); - Operation *owner = handle.getOwner(); - unsigned operandNo = handle.getOperandNumber(); - invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo, - otherHandle](Location currentLoc) { - InFlightDiagnostic diag = emitError(currentLoc) - << "op uses a handle invalidated by a " - "previously executed transform op"; - diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops"; - diag.attachNote(owner->getLoc()) - << "invalidated by this transform op that consumes its operand #" - << operandNo - << " and invalidates handles to payload ops nested in payload " - "ops associated with the consumed handle"; - diag.attachNote(ancestorLoc) << "ancestor payload op"; - diag.attachNote(opLoc) << "nested payload op"; - }; - } - } + // If the op is associated with invalidated handle, skip the check as it + // may be reading invalid IR. + if (invalidatedHandles.count(otherHandle)) + return; + + for (Operation *ancestor : potentialAncestors) { + if (!ancestor->isAncestor(payloadOp)) + continue; + + // Make sure the error-reporting lambda doesn't capture anything + // by-reference because it will go out of scope. Additionally, extract + // location from Payload IR ops because the ops themselves may be + // deleted before the lambda gets called. 
+ Location ancestorLoc = ancestor->getLoc(); + Location opLoc = payloadOp->getLoc(); + Operation *owner = handle.getOwner(); + unsigned operandNo = handle.getOperandNumber(); + invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo, + otherHandle](Location currentLoc) { + InFlightDiagnostic diag = emitError(currentLoc) + << "op uses a handle invalidated by a " + "previously executed transform op"; + diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops"; + diag.attachNote(owner->getLoc()) + << "invalidated by this transform op that consumes its operand #" + << operandNo + << " and invalidates handles to payload ops nested in payload " + "ops associated with the consumed handle"; + diag.attachNote(ancestorLoc) << "ancestor payload op"; + diag.attachNote(opLoc) << "nested payload op"; + }; } } +void transform::TransformState::recordHandleInvalidation(OpOperand &handle) { + for (const Mappings &mapping : llvm::make_second_range(mappings)) + for (const auto &[payloadOp, otherHandles] : mapping.reverse) + for (Value otherHandle : otherHandles) + recordHandleInvalidationOne(handle, payloadOp, otherHandle); +} + LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( TransformOpInterface transform) { auto memoryEffectsIface = @@ -252,8 +250,7 @@ assert(result.getDefiningOp() == transform.getOperation() && "payload IR association for a value other than the result of the " "current transform op"); - if (failed(setPayloadOps(result, results.get(result.getResultNumber())))) - return DiagnosedSilenceableFailure::definiteFailure(); + setPayloadOps(result, results.get(result.getResultNumber())); } printOnFailureRAII.release(); @@ -273,10 +270,16 @@ LogicalResult transform::TransformState::Extension::replacePayloadOp(Operation *op, Operation *replacement) { - return state.updatePayloadOps(state.getHandleForPayloadOp(op), - [&](Operation *current) { - return current == op ? 
replacement : current; - }); + SmallVector<Value> handles; + if (failed(state.getHandlesForPayloadOp(op, handles))) + return failure(); + + for (Value handle : handles) { + state.updatePayloadOps(handle, [&](Operation *current) { + return current == op ? replacement : current; + }); + } + return success(); } //===----------------------------------------------------------------------===// @@ -311,7 +314,7 @@ // Utilities for PossibleTopLevelTransformOpTrait. //===----------------------------------------------------------------------===// -LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( +void transform::detail::mapPossibleTopLevelTransformOpBlockArguments( TransformState &state, Operation *op, Region &region) { SmallVector<Operation *> targets; if (op->getNumOperands() != 0) @@ -319,7 +322,7 @@ else targets.push_back(state.getTopLevel()); - return state.mapBlockArguments(region.front().getArgument(0), targets); + state.mapBlockArguments(region.front().getArgument(0), targets); } LogicalResult diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -211,8 +211,7 @@ for (Operation *clone : clones) clone->erase(); }); - if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones))) - return DiagnosedSilenceableFailure::definiteFailure(); + state.mapBlockArguments(reg.front().getArgument(0), clones); bool failed = false; for (Operation &transform : reg.front().without_terminator()) { @@ -285,8 +284,7 @@ for (Operation *op : payloadOps) { auto scope = state.make_region_scope(getBody()); - if (failed(state.mapBlockArguments(getIterationVariable(), {op}))) - return DiagnosedSilenceableFailure::definiteFailure(); + state.mapBlockArguments(getIterationVariable(), {op}); // Execute loop body.
for (Operation &transform : getBody().front().without_terminator()) { @@ -512,8 +510,7 @@ transform::TransformState &state) { // Map the entry block argument to the list of operations. auto scope = state.make_region_scope(*getBodyBlock()->getParent()); - if (failed(mapBlockArguments(state))) - return DiagnosedSilenceableFailure::definiteFailure(); + mapBlockArguments(state); // Apply the sequenced ops one by one. for (Operation &transform : getBodyBlock()->without_terminator()) { @@ -707,8 +704,7 @@ [&]() { state.removeExtension(); }); auto scope = state.make_region_scope(getBody()); - if (failed(mapBlockArguments(state))) - return DiagnosedSilenceableFailure::definiteFailure(); + mapBlockArguments(state); return state.applyTransform(transformOp); } diff --git a/mlir/test/Dialect/Transform/expensive-checks.mlir b/mlir/test/Dialect/Transform/expensive-checks.mlir --- a/mlir/test/Dialect/Transform/expensive-checks.mlir +++ b/mlir/test/Dialect/Transform/expensive-checks.mlir @@ -60,3 +60,42 @@ test_print_remark_at_operand %0, "remark" } } + + +// ----- + +// expected-note @below {{ancestor payload op}} +// expected-note @below {{nested payload op}} +module { + + transform.sequence failures(propagate) { + ^bb0(%0: !pdl.operation): + %1 = transform.test_copy_payload %0 + // expected-note @below {{handle to invalidated ops}} + %2 = transform.test_copy_payload %0 + // expected-note @below {{invalidated by this transform op that consumes its operand #0}} + transform.test_consume_operand %1 + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + transform.test_consume_operand %2 + } +} + +// ----- + +// expected-note @below {{ancestor payload op}} +// expected-note @below {{nested payload op}} +module { + + transform.sequence failures(propagate) { + ^bb0(%0: !pdl.operation): + %1 = transform.test_copy_payload %0 + // expected-note @below {{handle to invalidated ops}} + %2 = transform.test_copy_payload %0 + // Consuming two handles 
in the same operation is invalid if they point + // to overlapping sets of payload IR ops. + // + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates handles}} + transform.test_consume_operand %1, %2 + } +} diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -17,14 +17,13 @@ // ----- -// expected-error @below {{operation tracked by two handles}} -%0 = transform.test_produce_param_or_forward_operand 42 -// expected-note @below {{handle}} -%1 = transform.test_produce_param_or_forward_operand from %0 -// expected-note @below {{handle}} -%2 = transform.test_produce_param_or_forward_operand from %0 -transform.test_consume_operand_if_matches_param_or_fail %1[42] -transform.test_consume_operand_if_matches_param_or_fail %2[42] +// It is okay to have multiple handles to the same payload op as long +// as only one of them is consumed. The expensive checks mode is necessary +// to detect double-consumption. 
+%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" } +%1 = transform.test_copy_payload %0 +// expected-remark @below {{succeeded}} +transform.test_consume_operand_if_matches_param_or_fail %0[42] // ----- diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -171,9 +171,13 @@ << "extension present, " << extension->getMessage(); for (Operation *payload : state.getPayloadOps(getOperand())) { diag.attachNote(payload->getLoc()) << "associated payload op"; - assert(state.getHandleForPayloadOp(payload) == getOperand() && +#ifndef NDEBUG + SmallVector<Value> handles; + assert(succeeded(state.getHandlesForPayloadOp(payload, handles))); + assert(llvm::is_contained(handles, getOperand()) && "inconsistent mapping between transform IR handles and payload IR " "operations"); +#endif // NDEBUG } return DiagnosedSilenceableFailure::success(); @@ -297,6 +301,13 @@ transform::onlyReadsHandle(getHandle(), effects); } +DiagnosedSilenceableFailure +mlir::test::TestCopyPayloadOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + results.set(getCopy().cast<OpResult>(), state.getPayloadOps(getHandle())); + return DiagnosedSilenceableFailure::success(); +} + namespace { /// Test extension of the Transform dialect.
Registers additional ops and /// declares PDL as dependent dialect since the additional ops are using PDL diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -38,8 +38,10 @@ [DeclareOpInterfaceMethods]> { let arguments = (ins Arg:$operand); - let assemblyFormat = "$operand attr-dict"; + [TransformMappingRead, TransformMappingFree]>:$operand, + Arg, "", + [TransformMappingRead, TransformMappingFree]>:$second_operand); + let assemblyFormat = "$operand (`,` $second_operand^)? attr-dict"; let cppNamespace = "::mlir::test"; } @@ -231,4 +233,14 @@ let cppNamespace = "::mlir::test"; } +def TestCopyPayloadOp + : Op]> { + let arguments = (ins Arg:$handle); + let results = (outs Res:$copy); + let cppNamespace = "::mlir::test"; + let assemblyFormat = "$handle attr-dict"; +} + #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
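The bookkeeping this patch switches to can be illustrated with a small, self-contained model. This is a minimal sketch under stated assumptions, not the MLIR `TransformState` API. It uses only the C++ standard library (C++17, for the structured binding). Payload ops are modeled as a hypothetical `PayloadOp` struct with a parent pointer, and handles are plain strings. `HandleState`, `consume` and `isInvalidated` are illustrative names only. The model shows four behaviors:

- A payload op can map to several handles.
- `dropReverseMapping` returns early only when the op has no reverse entry.
- `replacePayloadOp` updates every handle that points to the op.
- Consuming a handle invalidates the other handles associated with the same payload ops or with any of their descendants.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a payload IR op: only the parent link matters here.
struct PayloadOp {
  const PayloadOp *parent; // nullptr for a top-level op
};

// True if `ancestor` is `op` itself or encloses it (inclusive check).
bool isAncestor(const PayloadOp *ancestor, const PayloadOp *op) {
  for (; op != nullptr; op = op->parent)
    if (op == ancestor)
      return true;
  return false;
}

// Hypothetical model of the handle <-> payload op bookkeeping; not the MLIR API.
class HandleState {
public:
  // Several handles may point to the same payload op.
  void setPayloadOps(const std::string &handle,
                     std::vector<const PayloadOp *> ops) {
    for (const PayloadOp *op : ops)
      reverse_[op].push_back(handle);
    direct_[handle] = std::move(ops);
  }

  // Removes only the (op, handle) pair; does nothing if `op` is untracked.
  void dropReverseMapping(const PayloadOp *op, const std::string &handle) {
    auto it = reverse_.find(op);
    if (it == reverse_.end())
      return;
    std::vector<std::string> &handles = it->second;
    handles.erase(std::remove(handles.begin(), handles.end(), handle),
                  handles.end());
    if (handles.empty())
      reverse_.erase(it);
  }

  void removePayloadOps(const std::string &handle) {
    for (const PayloadOp *op : direct_[handle])
      dropReverseMapping(op, handle);
    direct_.erase(handle);
  }

  std::vector<std::string> getHandlesForPayloadOp(const PayloadOp *op) const {
    auto it = reverse_.find(op);
    return it == reverse_.end() ? std::vector<std::string>{} : it->second;
  }

  // Updates every handle pointing to `op`; a null replacement drops `op`.
  // Returns false if no handle tracks `op`.
  bool replacePayloadOp(const PayloadOp *op, const PayloadOp *replacement) {
    std::vector<std::string> handles = getHandlesForPayloadOp(op);
    if (handles.empty())
      return false;
    for (const std::string &handle : handles) {
      std::vector<const PayloadOp *> updated;
      for (const PayloadOp *current : direct_[handle]) {
        dropReverseMapping(current, handle);
        const PayloadOp *next = current == op ? replacement : current;
        if (next != nullptr) {
          updated.push_back(next);
          reverse_[next].push_back(handle);
        }
      }
      direct_[handle] = std::move(updated);
    }
    return true;
  }

  // Consuming `handle` invalidates every other handle associated with the
  // same payload ops or with any of their descendants.
  void consume(const std::string &handle) {
    const std::vector<const PayloadOp *> &consumed = direct_[handle];
    for (const auto &[op, handles] : reverse_)
      for (const std::string &other : handles)
        for (const PayloadOp *ancestor : consumed)
          if (other != handle && isAncestor(ancestor, op))
            invalidated_.insert(other);
  }

  bool isInvalidated(const std::string &handle) const {
    return invalidated_.count(handle) != 0;
  }

private:
  std::map<std::string, std::vector<const PayloadOp *>> direct_;
  std::map<const PayloadOp *, std::vector<std::string>> reverse_;
  std::set<std::string> invalidated_;
};
```

In this model, suppose handles `%0` and `%1` both point to a root op and `%2` points to an op nested inside it. Consuming `%1` then marks `%0` and `%2` as invalidated, while a handle to an unrelated op stays valid. This mirrors the double-consumption cases in `expensive-checks.mlir`. Following the documentation's wording ("all the other handles"), the sketch records only the handles other than the consumed one.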