diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -74,6 +74,10 @@ /// This is helpful for transformations that apply to a particular handle. ArrayRef getPayloadOps(Value value) const; + /// Returns the Transform IR handle for the given Payload IR op if it exists + /// in the state, null otherwise. + Value getHandleForPayloadOp(Operation *op) const; + /// Applies the transformation specified by the given transform op and updates /// the state accordingly. LogicalResult applyTransform(TransformOpInterface transform); @@ -185,6 +189,10 @@ /// Provides read-only access to the parent TransformState object. const TransformState &getTransformState() const { return state; } + /// Replaces the given payload op with another op, or drops its association + /// when the replacement is null. Fails if the payload op is not tracked. + LogicalResult replacePayloadOp(Operation *op, Operation *replacement); + private: /// Back-reference to the state that is being extended. TransformState &state; @@ -276,9 +284,17 @@ /// The callback function is called once per associated operation and is /// expected to return the modified operation or nullptr. In the latter case, /// the corresponding operation is no longer associated with the transform IR - /// value. - void updatePayloadOps(Value value, - function_ref callback); + /// value. May fail if the operation produced by the update callback is + /// already associated with a different Transform IR handle value. + LogicalResult + updatePayloadOps(Value value, + function_ref callback); + + /// Attempts to record the mapping between the given Payload IR operation and + /// the given Transform IR handle. Fails and reports an error if the operation + /// is already tracked by another handle. 
+ static LogicalResult tryEmplaceReverseMapping(Mappings &map, Operation *op, + Value handle); /// The mappings between transform IR values and payload IR ops, aggregated by /// the region in which the transform IR values are defined. diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -41,6 +41,27 @@ return iter->getSecond(); } +Value transform::TransformState::getHandleForPayloadOp(Operation *op) const { + for (const Mappings &mapping : llvm::make_second_range(mappings)) { + if (Value handle = mapping.reverse.lookup(op)) + return handle; + } + return Value(); +} + +LogicalResult transform::TransformState::tryEmplaceReverseMapping( + Mappings &map, Operation *operation, Value handle) { + auto insertionResult = map.reverse.insert({operation, handle}); + if (!insertionResult.second) { + InFlightDiagnostic diag = operation->emitError() + << "operation tracked by two handles"; + diag.attachNote(handle.getLoc()) << "handle"; + diag.attachNote(insertionResult.first->second.getLoc()) << "handle"; + return diag; + } + return success(); +} + LogicalResult transform::TransformState::setPayloadOps(Value value, ArrayRef targets) { @@ -63,14 +84,8 @@ // expressed using the dialect and may be constructed by valid API calls from // valid IR. Emit an error here. 
for (Operation *op : targets) { - auto insertionResult = mappings.reverse.insert({op, value}); - if (!insertionResult.second) { - InFlightDiagnostic diag = op->emitError() - << "operation tracked by two handles"; - diag.attachNote(value.getLoc()) << "handle"; - diag.attachNote(insertionResult.first->second.getLoc()) << "handle"; - return diag; - } + if (failed(tryEmplaceReverseMapping(mappings, op, value))) + return failure(); } return success(); @@ -83,19 +98,26 @@ mappings.direct.erase(value); } -void transform::TransformState::updatePayloadOps( +LogicalResult transform::TransformState::updatePayloadOps( Value value, function_ref callback) { - auto it = getMapping(value).direct.find(value); - assert(it != getMapping(value).direct.end() && "unknown handle"); + Mappings &mappings = getMapping(value); + auto it = mappings.direct.find(value); + assert(it != mappings.direct.end() && "unknown handle"); SmallVector &association = it->getSecond(); SmallVector updated; updated.reserve(association.size()); - for (Operation *op : association) - if (Operation *updatedOp = callback(op)) + for (Operation *op : association) { + mappings.reverse.erase(op); + if (Operation *updatedOp = callback(op)) { updated.push_back(updatedOp); + if (failed(tryEmplaceReverseMapping(mappings, updatedOp, value))) + return failure(); + } + } std::swap(association, updated); + return success(); } LogicalResult @@ -132,8 +154,25 @@ return success(); } +//===----------------------------------------------------------------------===// +// TransformState::Extension +//===----------------------------------------------------------------------===// + transform::TransformState::Extension::~Extension() = default; +LogicalResult +transform::TransformState::Extension::replacePayloadOp(Operation *op, + Operation *replacement) { + // An op that is not tracked has no handle; updatePayloadOps requires a + // known handle, so report failure instead of forwarding a null value. 
+ Value handle = state.getHandleForPayloadOp(op); + if (!handle) + return failure(); + return state.updatePayloadOps(handle, [&](Operation *current) { + return current == op ? 
replacement : current; + }); +} + //===----------------------------------------------------------------------===// // TransformResults //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Transform/transform-state-extension.mlir b/mlir/test/Dialect/Transform/transform-state-extension.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Transform/transform-state-extension.mlir @@ -0,0 +1,46 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -split-input-file + +// expected-note @below {{associated payload op}} +module { + transform.sequence { + ^bb0(%arg0: !pdl.operation): + // expected-remark @below {{extension absent}} + test_check_if_test_extension_present %arg0 + test_add_test_extension "A" + // expected-remark @below {{extension present, A}} + test_check_if_test_extension_present %arg0 + test_remove_test_extension + // expected-remark @below {{extension absent}} + test_check_if_test_extension_present %arg0 + } +} + +// ----- + +// expected-note @below {{associated payload op}} +module { + transform.sequence { + ^bb0(%arg0: !pdl.operation): + test_add_test_extension "A" + test_remove_test_extension + test_add_test_extension "B" + // expected-remark @below {{extension present, B}} + test_check_if_test_extension_present %arg0 + } +} + +// ----- + +// expected-note @below {{associated payload op}} +module { + transform.sequence { + ^bb0(%arg0: !pdl.operation): + test_add_test_extension "A" + // expected-remark @below {{extension present, A}} + test_check_if_test_extension_present %arg0 + // expected-note @below {{associated payload op}} + test_remap_operand_to_self %arg0 + // expected-remark @below {{extension present, A}} + test_check_if_test_extension_present %arg0 + } +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- 
a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -12,10 +12,10 @@ //===----------------------------------------------------------------------===// #include "TestTransformDialectExtension.h" +#include "TestTransformStateExtension.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" using namespace mlir; @@ -142,6 +142,49 @@ return success(); } +LogicalResult +mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + state.addExtension(getMessageAttr()); + return success(); +} + +LogicalResult mlir::test::TestCheckIfTestExtensionPresentOp::apply( + transform::TransformResults &results, transform::TransformState &state) { + auto *extension = state.getExtension(); + if (!extension) { + emitRemark() << "extension absent"; + return success(); + } + + InFlightDiagnostic diag = emitRemark() + << "extension present, " << extension->getMessage(); + for (Operation *payload : state.getPayloadOps(getOperand())) { + diag.attachNote(payload->getLoc()) << "associated payload op"; + assert(state.getHandleForPayloadOp(payload) == getOperand() && + "inconsistent mapping between transform IR handles and payload IR " + "operations"); + } + + return success(); +} + +LogicalResult mlir::test::TestRemapOperandPayloadToSelfOp::apply( + transform::TransformResults &results, transform::TransformState &state) { + auto *extension = state.getExtension(); + if (!extension) + return emitError() << "TestTransformStateExtension missing"; + + return extension->updateMapping(state.getPayloadOps(getOperand()).front(), + getOperation()); +} + +LogicalResult mlir::test::TestRemoveTestExtensionOp::apply( + transform::TransformResults &results, transform::TransformState &state) { + 
state.removeExtension(); + return success(); +} + namespace { /// Test extension of the Transform dialect. Registers additional ops and /// declares PDL as dependent dialect since the additional ops are using PDL diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -56,4 +56,41 @@ let cppNamespace = "::mlir::test"; } +def TestAddTestExtensionOp + : Op, + NoSideEffect]> { + let arguments = (ins StrAttr:$message); + let assemblyFormat = "$message attr-dict"; + let cppNamespace = "::mlir::test"; +} + +def TestCheckIfTestExtensionPresentOp + : Op]> { + let arguments = (ins + Arg:$operand); + let assemblyFormat = "$operand attr-dict"; + let cppNamespace = "::mlir::test"; +} + +def TestRemapOperandPayloadToSelfOp + : Op]> { + let arguments = (ins + Arg:$operand); + let assemblyFormat = "$operand attr-dict"; + let cppNamespace = "::mlir::test"; +} + +def TestRemoveTestExtensionOp + : Op, + NoSideEffect]> { + let assemblyFormat = "attr-dict"; + let cppNamespace = "::mlir::test"; +} + + #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD diff --git a/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h @@ -0,0 +1,42 @@ +//===- TestTransformStateExtension.h - Test Utility -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a TransformState extension for the purpose of testing the +// relevant APIs. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H +#define MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H + +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" + +using namespace mlir; + +namespace mlir { +namespace test { +class TestTransformStateExtension + : public transform::TransformState::Extension { +public: + TestTransformStateExtension(transform::TransformState &state, + StringAttr message) + : Extension(state), message(message) {} + + StringRef getMessage() const { return message.getValue(); } + + LogicalResult updateMapping(Operation *previous, Operation *updated) { + return replacePayloadOp(previous, updated); + } + +private: + StringAttr message; +}; +} // namespace test +} // namespace mlir + +#endif // MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H