diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -313,8 +313,16 @@
     /// Replaces the given payload op with another op. If the replacement op is
     /// null, removes the association of the payload op with its handle. Returns
     /// failure if the op is not associated with any handle.
+    ///
+    /// Note: This function does not update value handles. None of the original
+    /// op's results are allowed to be mapped to any value handle.
     LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
 
+    /// Replaces the given payload value with another value. If the replacement
+    /// value is null, removes the association of the payload value with its
+    /// handle. Returns failure if the value is not associated with any handle.
+    LogicalResult replacePayloadValue(Value value, Value replacement);
+
   private:
     /// Back-reference to the state that is being extended.
     TransformState &state;
@@ -484,18 +492,18 @@
   void forgetValueMapping(Value valueHandle,
                           ArrayRef<Operation *> payloadOperations);
 
-  /// Updates the payload IR ops associated with the given transform IR value.
-  /// The callback function is called once per associated operation and is
-  /// expected to return the modified operation or nullptr. In the latter case,
-  /// the corresponding operation is no longer associated with the transform IR
-  /// value. Value handles associated with the results of the operation are
-  /// also updated to be associated with the results of the new operation. For
-  /// this reason, the new operation must have the same number of results.
+  /// Replaces the given payload op with another op. If the replacement op is
+  /// null, removes the association of the payload op with its handle.
   ///
-  /// Returns failure if the payload does not satisfy the conditions associated
-  /// with the type of the handle value.
+  /// Note: This function does not update value handles. None of the original
+  /// op's results are allowed to be mapped to any value handle.
   LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
 
+  /// Replaces the given payload value with another value. If the replacement
+  /// value is null, removes the association of the payload value with its
+  /// handle. Values that are not mapped to any handle are left untouched.
+  LogicalResult replacePayloadValue(Value value, Value replacement);
+
   /// If the operand is a handle consumed by the operation, i.e. has the "free"
   /// memory effect associated with it, identifies other handles that are
   /// pointing to payload IR operations nested in the operations pointed to by
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -336,18 +336,13 @@
     dropMappingEntry(mappings.reverse, op, handle);
   }
 
-  // Drop the mapping between the op results and all value handles that point to
-  // them. Don't care if there are no such handles.
-  RaggedArray<Value> resultValueHandles;
+#ifndef NDEBUG
   for (Value opResult : op->getResults()) {
     SmallVector<Value> valueHandles;
     (void)getHandlesForPayloadValue(opResult, valueHandles);
-    for (Value handle : valueHandles) {
-      Mappings &localMappings = getMapping(handle);
-      dropMappingEntry(localMappings.reverseValues, opResult, handle);
-    }
-    resultValueHandles.push_back(std::move(valueHandles));
+    assert(valueHandles.empty() && "expected no mapping to old results");
   }
+#endif // NDEBUG
 
   // TODO: consider invalidating the handles to nested objects here.
 
@@ -358,14 +353,6 @@
       Mappings &mappings = getMapping(handle);
       dropMappingEntry(mappings.direct, handle, op);
     }
-    for (Value opResult : op->getResults()) {
-      SmallVector<Value> valueHandles;
-      (void)getHandlesForPayloadValue(opResult, valueHandles);
-      for (Value handle : valueHandles) {
-        Mappings &localMappings = getMapping(handle);
-        dropMappingEntry(localMappings.values, handle, opResult);
-      }
-    }
     return success();
   }
 
@@ -386,33 +373,33 @@
     mappings.reverse[replacement].push_back(handle);
   }
 
-  // Second, replace the mapped results of the operation.
-  for (auto [origResult, handleList] :
-       llvm::zip(op->getResults(), resultValueHandles)) {
-    // No handles to the value, skip even if there is no replacement.
-    if (handleList.empty())
-      continue;
+  return success();
+}
 
-    unsigned resultNumber = origResult.getResultNumber();
-    if (resultNumber >= replacement->getNumResults()) {
-      return emitError(op->getLoc())
-             << "cannot replace an op with another op producing less results "
-                "while tracking handles";
-    }
+LogicalResult
+transform::TransformState::replacePayloadValue(Value value, Value replacement) {
+  SmallVector<Value> valueHandles;
+  (void)getHandlesForPayloadValue(value, valueHandles);
+
+  for (Value handle : valueHandles) {
+    Mappings &mappings = getMapping(handle);
+    dropMappingEntry(mappings.reverseValues, value, handle);
 
-    Value replacementResult = replacement->getResult(resultNumber);
-    for (Value resultHandle : handleList) {
-      Mappings &mappings = getMapping(resultHandle);
-      auto it = mappings.values.find(resultHandle);
+    // If replacing with null, that is erasing the mapping, drop the mapping
+    // between the handles and the IR objects.
+    if (!replacement) {
+      dropMappingEntry(mappings.values, handle, value);
+    } else {
+      auto it = mappings.values.find(handle);
       if (it == mappings.values.end())
         continue;
 
       SmallVector<Value> &association = it->getSecond();
       for (Value &mapped : association) {
-        if (mapped == origResult)
-          mapped = replacementResult;
+        if (mapped == value)
+          mapped = replacement;
       }
-      mappings.reverseValues[replacementResult].push_back(resultHandle);
+      mappings.reverseValues[replacement].push_back(handle);
     }
   }
 
@@ -867,6 +854,16 @@
   return state.replacePayloadOp(op, replacement);
 }
 
+LogicalResult
+transform::TransformState::Extension::replacePayloadValue(Value value,
+                                                          Value replacement) {
+  SmallVector<Value> handles;
+  if (failed(state.getHandlesForPayloadValue(value, handles)))
+    return failure();
+
+  return state.replacePayloadValue(value, replacement);
+}
+
 //===----------------------------------------------------------------------===//
 // TransformResults
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/transform-state-extension.mlir b/mlir/test/Dialect/Transform/transform-state-extension.mlir
--- a/mlir/test/Dialect/Transform/transform-state-extension.mlir
+++ b/mlir/test/Dialect/Transform/transform-state-extension.mlir
@@ -50,7 +50,7 @@
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
   test_add_test_extension "A"
-  // This is okay because we are replacing the top-level module opeation
+  // This is okay because we are replacing the top-level module operation
   // (0 results) with this operation that has _more_ (1) results.
   %dummy = test_remap_operand_to_self %arg0 : !pdl.operation
 }
@@ -72,7 +72,7 @@
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
   test_add_test_extension "A"
-  // expected-error @below {{cannot replace an op with another op producing less results while tracking handles}}
+  // expected-error @below {{cannot replace an op with another op producing fewer results while tracking handles}}
   %dummy = test_remap_operand_to_self %arg0 : !pdl.operation
   %valuehandle = transform.get_result %dummy[0] : (!pdl.operation) -> !transform.any_value
   test_remap_operand_to_self %dummy
diff --git a/mlir/test/lib/Dialect/Transform/CMakeLists.txt b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
--- a/mlir/test/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
@@ -8,6 +8,7 @@
 add_mlir_library(MLIRTestTransformDialect
   TestTransformDialectExtension.cpp
   TestTransformDialectInterpreter.cpp
+  TestTransformStateExtension.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
--- a/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
+++ b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
@@ -29,9 +29,10 @@
 
   StringRef getMessage() const { return message.getValue(); }
 
-  LogicalResult updateMapping(Operation *previous, Operation *updated) {
-    return replacePayloadOp(previous, updated);
-  }
+  /// Replaces `previous` with `updated` in all op handles and remaps value
+  /// handles of its results to the results of `updated` at the same index.
+  /// Emits an error if a dropped result of `previous` is tracked by a handle.
+  LogicalResult updateMapping(Operation *previous, Operation *updated);
 
 private:
   StringAttr message;
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.cpp
@@ -0,0 +1,43 @@
+//===- TestTransformStateExtension.cpp ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestTransformStateExtension.h"
+
+using namespace mlir;
+
+LogicalResult
+test::TestTransformStateExtension::updateMapping(Operation *previous,
+                                                 Operation *updated) {
+  // Update value handles. The replacement op should have at least as many
+  // results as the original op. Fewer results are acceptable, if the dropped
+  // results are not mapped to any handle.
+  for (auto r = updated->getNumResults(); r < previous->getNumResults(); ++r) {
+    SmallVector<Value> handles;
+    (void)getTransformState().getHandlesForPayloadValue(previous->getResult(r),
+                                                        handles);
+    if (!handles.empty())
+      return emitError(previous->getLoc())
+             << "cannot replace an op with another op producing fewer results "
+                "while tracking handles";
+  }
+
+  for (auto [oldValue, newValue] :
+       llvm::zip(previous->getResults(), updated->getResults())) {
+    // Results that no value handle refers to need no update. Skip them, since
+    // replacePayloadValue reports failure for untracked values.
+    SmallVector<Value> handles;
+    (void)getTransformState().getHandlesForPayloadValue(oldValue, handles);
+    if (handles.empty())
+      continue;
+    if (failed(replacePayloadValue(oldValue, newValue)))
+      return failure();
+  }
+
+  // Update op handle.
+  return replacePayloadOp(previous, updated);
+}