diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -151,6 +151,37 @@
     transform operations can return _new_ handles that can be read or
     consumed by subsequent operations.
 
+    ## Handle Invalidation
+
+    The execution model of the transform dialect expects that a payload IR
+    operation is associated with _at most one_ transform IR handle. This avoids
+    the situation where a handle to an operation outlives the operation itself,
+    which can be erased by a transformation triggered through another handle.
+
+    Handles pointing to operations nested in each other are allowed to co-exist
+    in the transform IR. However, a transform IR operation that consumes such a
+    handle automatically _invalidates_ all the other handles that are associated
+    with operations nested in the operations associated with the consumed
+    handle. Any use of the invalidated handle results in undefined behavior
+    since the payload IR operations associated with it are likely to have been
+    mutated or erased. The mere fact of the handle being invalidated does _not_
+    trigger undefined behavior, only its appearance as an operand does.
+    Invalidation applies to the entire handle, even if some of the payload IR
+    operations associated with it are not nested in payload IR operations
+    associated with another, consumed handle.
+
+    Note: the restriction on two handles not pointing to the same operation may
+    be relaxed in the future to follow the invalidation model for nested
+    operations.
+
+    The Transform dialect infrastructure has the capability of checking whether
+    the transform IR op operand is invalidated before applying the
+    transformation. However, such a check is computationally expensive and
+    must be enabled explicitly through `TransformOptions`. Additionally, the
+    `transform-dialect-check-uses` pass emits warnings when a handle may be used
+    after it has been consumed, but does so abstractly, without processing the
+    payload IR.
+
     ## Intended Use and Integrations
 
     The transformation control infrastructure provided by this dialect is
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -18,6 +18,27 @@
 
 class TransformOpInterface;
 
+/// Options controlling the application of transform operations by the
+/// TransformState.
+class TransformOptions {
+public:
+  TransformOptions() = default;
+
+  /// Requests computationally expensive checks of the transform and payload IR
+  /// well-formedness to be performed before each transformation. In particular,
+  /// these ensure that the handles still point to valid operations when used.
+  TransformOptions &enableExpensiveChecks(bool enable = true) {
+    expensiveChecksEnabled = enable;
+    return *this;
+  }
+
+  /// Returns true if the expensive checks are requested (off by default).
+  bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; }
+
+private:
+  bool expensiveChecksEnabled = false;
+};
+
 /// The state maintained across applications of various ops implementing the
 /// TransformOpInterface. The operations implementing this interface and the
 /// surrounding structure are referred to as transform IR. The operations to
@@ -63,8 +84,10 @@
   /// Creates a state for transform ops living in the given region. The parent
   /// operation of the region. The second argument points to the root operation
   /// in the payload IR beind transformed, which may or may not contain the
-  /// region with transform ops.
-  TransformState(Region &region, Operation *root);
+  /// region with transform ops. Additional options can be provided through the
+  /// trailing configuration object.
+  TransformState(Region &region, Operation *root,
+                 const TransformOptions &options = TransformOptions());
 
   /// Returns the op at which the transformation state is rooted. This is
   /// typically helpful for transformations that apply globally.
@@ -296,6 +319,21 @@
   static LogicalResult tryEmplaceReverseMapping(Mappings &map, Operation *op,
                                                 Value handle);
 
+  /// If the operand is a handle consumed by the operation, i.e. has the "free"
+  /// memory effect associated with it, identifies other handles that are
+  /// pointing to payload IR operations nested in the operations pointed to by
+  /// the consumed handle. Marks all such handles as invalidated so that they
+  /// trigger errors if they are used.
+  void recordHandleInvalidation(OpOperand &handle);
+
+  /// Checks that the operation does not use invalidated handles as operands.
+  /// Reports errors and returns failure if it does. Otherwise, invalidates any
+  /// handles pointing to payload IR operations nested in the operations
+  /// associated with the handles consumed by the operation. The consumed
+  /// handles themselves have their mapping removed once the op is applied.
+  LogicalResult
+  checkAndRecordHandleInvalidation(TransformOpInterface transform);
+
   /// The mappings between transform IR values and payload IR ops, aggregated by
   /// the region in which the transform IR values are defined.
   llvm::SmallDenseMap<Region *, Mappings> mappings;
@@ -307,6 +345,14 @@
   /// The top-level operation that contains all payload IR, typically a module.
   Operation *topLevel;
 
+  /// Additional options controlling the transformation state behavior.
+  TransformOptions options;
+
+  /// The mapping from invalidated handles to the error-reporting functions that
+  /// describe when the handles were invalidated. Calling such a function emits
+  /// a user-visible diagnostic.
+  DenseMap<Value, std::function<void()>> invalidatedHandles;
+
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
   /// A stack of nested regions that are being processed in the transform IR.
   /// Each region must be an ancestor of the following regions in this list.
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -21,8 +21,9 @@
 
 constexpr const Value transform::TransformState::kTopLevelValue;
 
-transform::TransformState::TransformState(Region &region, Operation *root)
-    : topLevel(root) {
+transform::TransformState::TransformState(Region &region, Operation *root,
+                                          const TransformOptions &options)
+    : topLevel(root), options(options) {
   auto result = mappings.try_emplace(&region);
   assert(result.second && "the region scope is already present");
   (void)result;
@@ -120,8 +121,78 @@
   return success();
 }
 
+void transform::TransformState::recordHandleInvalidation(OpOperand &handle) {
+  ArrayRef<Operation *> potentialAncestors = getPayloadOps(handle.get());
+  for (const Mappings &mapping : llvm::make_second_range(mappings)) {
+    for (const auto &kvp : mapping.reverse) {
+      // Skip ops associated with an already-invalidated handle, as checking
+      // them may read invalid IR.
+      Operation *op = kvp.first;
+      Value otherHandle = kvp.second;
+      if (invalidatedHandles.count(otherHandle))
+        continue;
+
+      for (Operation *ancestor : potentialAncestors) {
+        if (!ancestor->isProperAncestor(op))
+          continue;
+
+        // Make sure the error-reporting lambda doesn't capture anything
+        // by reference because it will go out of scope. Additionally, extract
+        // locations from payload IR ops because the ops themselves may be
+        // deleted before the lambda gets called.
+        Location ancestorLoc = ancestor->getLoc();
+        Location opLoc = op->getLoc();
+        Operation *owner = handle.getOwner();
+        unsigned operandNo = handle.getOperandNumber();
+        invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
+                                           otherHandle]() {
+          InFlightDiagnostic diag =
+              owner->emitOpError()
+              << "invalidated the handle to payload operations nested in the "
+                 "payload operation associated with its operand #"
+              << operandNo;
+          diag.attachNote(ancestorLoc) << "ancestor op";
+          diag.attachNote(opLoc) << "nested op";
+          diag.attachNote(otherHandle.getLoc()) << "other handle";
+        };
+      }
+    }
+  }
+}
+
+LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
+    TransformOpInterface transform) {
+  auto memoryEffectsIface =
+      cast<MemoryEffectOpInterface>(transform.getOperation());
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  memoryEffectsIface.getEffectsOnResource(
+      transform::TransformMappingResource::get(), effects);
+
+  for (OpOperand &target : transform->getOpOperands()) {
+    // If the operand uses an invalidated handle, report it.
+    auto it = invalidatedHandles.find(target.get());
+    if (it != invalidatedHandles.end())
+      return it->getSecond()(), failure();
+
+    // Invalidate handles pointing to the operations nested in the operation
+    // associated with the handle consumed by this operation.
+    auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
+      return isa<MemoryEffects::Free>(effect.getEffect()) &&
+             effect.getValue() == target.get();
+    };
+    if (llvm::any_of(effects, consumesTarget))
+      recordHandleInvalidation(target);
+  }
+  return success();
+}
+
 LogicalResult
 transform::TransformState::applyTransform(TransformOpInterface transform) {
+  if (options.getExpensiveChecksEnabled() &&
+      failed(checkAndRecordHandleInvalidation(transform))) {
+    return failure();
+  }
+
   transform::TransformResults results(transform->getNumResults());
   if (failed(transform.apply(results, *this)))
     return failure();
@@ -131,23 +202,23 @@
   auto memEffectInterface =
       cast<MemoryEffectOpInterface>(transform.getOperation());
   SmallVector<MemoryEffects::EffectInstance, 2> effects;
-  for (Value target : transform->getOperands()) {
+  for (OpOperand &target : transform->getOpOperands()) {
     effects.clear();
-    memEffectInterface.getEffectsOnValue(target, effects);
+    memEffectInterface.getEffectsOnValue(target.get(), effects);
     if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
           return isa<transform::TransformMappingResource>(
                      effect.getResource()) &&
                  isa<MemoryEffects::Free>(effect.getEffect());
         })) {
-      removePayloadOps(target);
+      removePayloadOps(target.get());
     }
   }
 
-  for (auto &en : llvm::enumerate(transform->getResults())) {
-    assert(en.value().getDefiningOp() == transform.getOperation() &&
+  for (OpResult result : transform->getResults()) {
+    assert(result.getDefiningOp() == transform.getOperation() &&
            "payload IR association for a value other than the result of the "
            "current transform op");
-    if (failed(setPayloadOps(en.value(), results.get(en.index()))))
+    if (failed(setPayloadOps(result, results.get(result.getResultNumber()))))
       return failure();
   }
 
diff --git a/mlir/test/Dialect/Transform/expensive-checks.mlir b/mlir/test/Dialect/Transform/expensive-checks.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Transform/expensive-checks.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter='enable-expensive-checks=1' --split-input-file --verify-diagnostics %s
+
+// expected-note @below {{ancestor op}}
+func.func @func() {
+  // expected-note @below {{nested op}}
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @return : benefit(1) {
+    %0 = operands
+    %1 = types
+    %2 = operation "func.return"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    rewrite %2 with "transform.dialect"
+  }
+
+  sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    // expected-note @below {{other handle}}
+    %0 = pdl_match @return in %arg1
+    %1 = get_closest_isolated_parent %0
+    // expected-error @below {{invalidated the handle to payload operations nested in the payload operation associated with its operand #0}}
+    test_consume_operand %1
+    test_print_remark_at_operand %0, "remark"
+  }
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -119,6 +119,12 @@
   return success();
 }
 
+LogicalResult
+mlir::test::TestConsumeOperand::apply(transform::TransformResults &results,
+                                      transform::TransformState &state) {
+  return success();
+}
+
 LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -34,6 +34,15 @@
   let hasVerifier = 1;
 }
 
+def TestConsumeOperand : Op<Transform_Dialect, "test_consume_operand",
+     [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let arguments = (ins
+    Arg<PDL_Operation, "",
+        [TransformMappingRead, TransformMappingFree]>:$operand);
+  let assemblyFormat = "$operand attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
 def TestConsumeOperandIfMatchesParamOrFail
   : Op<Transform_Dialect, "test_consume_operand_if_matches_param_or_fail",
        [DeclareOpInterfaceMethods<TransformOpInterface>]> {
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -27,6 +27,10 @@
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
       TestTransformDialectInterpreterPass)
 
+  TestTransformDialectInterpreterPass() = default;
+  TestTransformDialectInterpreterPass(
+      const TestTransformDialectInterpreterPass &) {}
+
   StringRef getArgument() const override {
     return "test-transform-dialect-interpreter";
   }
@@ -37,13 +41,21 @@
 
   void runOnOperation() override {
     ModuleOp module = getOperation();
-    transform::TransformState state(module.getBodyRegion(), module);
+    transform::TransformState state(
+        module.getBodyRegion(), module,
+        transform::TransformOptions().enableExpensiveChecks(
+            enableExpensiveChecks));
     for (auto op :
          module.getBody()->getOps<transform::TransformOpInterface>()) {
       if (failed(state.applyTransform(op)))
         return signalPassFailure();
     }
   }
+
+  Option<bool> enableExpensiveChecks{
+      *this, "enable-expensive-checks", llvm::cl::init(false),
+      llvm::cl::desc("perform expensive checks to better report errors in the "
+                     "transform IR")};
 };
 } // namespace
 