diff --git a/mlir/docs/Dialects/Transform.md b/mlir/docs/Dialects/Transform.md --- a/mlir/docs/Dialects/Transform.md +++ b/mlir/docs/Dialects/Transform.md @@ -25,20 +25,19 @@ transformed as payload IR, and to the IR guiding the transformation as transform IR. -The main use case for this dialect is orchestrating fine-grain -transformations on individual operations or sets thereof. For example, it -may involve finding loop-like operations with specific properties (e.g., -large size) in the payload IR, applying loop tiling to those and only those -operations, and then applying loop unrolling to the inner loops produced -by the previous transformations. As such, it is not intended as a -replacement for the pass infrastructure, nor for the pattern rewriting -infrastructure. In the most common case, the transform IR will be processed -and applied to the payload IR by a pass. Transformations expressed by the -transform dialect may be implemented using the pattern infrastructure or any -other relevant MLIR component. +The main use case for this dialect is orchestrating fine-grain transformations +on individual IR objects (operations or values) or sets thereof. For example, it +may involve finding loop-like operations with specific properties (e.g., large +size) in the payload IR, applying loop tiling to those and only those +operations, and then applying loop unrolling to the inner loops produced by the +previous transformations. As such, it is not intended as a replacement for the +pass infrastructure, nor for the pattern rewriting infrastructure. In the most +common case, the transform IR will be processed and applied to the payload IR by +a pass. Transformations expressed by the transform dialect may be implemented +using the pattern infrastructure or any other relevant MLIR component. 
The following IR gives a rough idea of what the operations in this dialect -may look like: +may look like, using hypothetical operations that do not actually exist: ```mlir %0 = transform.loop.find { size > 42 } : !transform.interface @@ -46,57 +45,70 @@ %2:2 = transform.loop.tile %0 tile_sizes(1, 4, %1) : (!transform.interface) -> (!transform.op, !transform.op) +%3 = transform.get_op_result [0] %2#0 : !transform.any_value +transform.assign_to_fast_memory %3 transform.loop.unroll %1#1 : !transform.op ``` -The values used in the Transform dialect may correspond to either: +The values used in the Transform dialect may correspond to: * sets of operations in the payload IR; + * sets of values in the payload IR; + * sets of parameters (attributes) known at the execution time of the transform dialect. -The former kind of values is also referred to as *handles*. In the example -above, `%0` corresponds to the set of loops found in the payload IR that -satisfy the condition, and `%2` correspond to groups of outer and inner -loops, respectively, produced by the tiling transformation, whereas `%1` -corresponds to a list of tile sizes selected for each of the operations -that `%0` corresponds to. +The former two kinds of values are also referred to as operation and value +*handles*, respectively. In the example above, `%0` corresponds to the set of +loops found in the payload IR that satisfy the condition, and `%2` correspond to +groups of outer and inner loops, respectively, produced by the tiling +transformation. `%3` corresponds to a set of values that are produced by the +outer loops after tiling. `%1` corresponds to a list of tile sizes selected for +each of the operations that `%0` corresponds to. -A transform handle such as `%0` may be associated with multiple payload +An operation handle such as `%0` may be associated with multiple payload operations.
This is conceptually a set of operations and no assumptions should be made about the order of ops unless specified otherwise by the operation. -Operations may take as operands and produce an arbitrary combination of values -representing handles and parameters. Most Transform IR ops support operand -values that are mapped to multiple operations. They usually apply the respective -transformation for every mapped op ("batched execution"). Deviations from this -convention are described in the documentation of Transform IR ops. - -The transform IR values have transform IR types, which implement either -[TransformHandleTypeInterface](Transform.md#transformhandletypeinterface-transformhandletypeinterface) -or -[TransformParamTypeInterface](Transform.md##transformparamtypeinterface-transformparamtypeinterface). -The former interface verifiers properties of payload IR operations associated -with the value that are known to the transform dialect, for example, all -associated payload operations implement a "TileableOp" interface, or have a -specific "loop" kind. Similarly, the latter interface verifies properties of -attributes associated with the parameter value. These properties are used to -statically indicate pre- and post-conditions of a transformation connected to a -Transform dialect operation. The conditions are verified when attributes or -payload IR operations are first associated with a transform handle. By -convention, Transform dialect operations are expected to indicate narrow -preconditions for their operands by enforcing operand type constraints in the -their definitions and verifiers. On the contrary, operations are expected to -have few constraints on their results. 
Specific instances of a transform -operation can then be created with a more restricted result type than the -constraint in the operation (e.g., the "find" operation only constrains the -result type to be a transform IR type while its concrete instance can have a -type with stricter constraints such as implementing the "tilable" interface). -The verification will then happen at transform execution time. This approach -allows one to capture payload IR operation properties in the transform IR -without resorting to excessive use of type casts or coupling dialect extensions -between themselves. It is a trade-off between verbosity/complexity and static -hardening, which can be revised in the future. +Similarly, a value handle such as `%3` may be associated with a set of payload +IR values. Transform dialect operations may take as operands and produce an +arbitrary combination of values representing handles and parameters. Most +Transform IR ops support operand values that are mapped to multiple payload +objects. They usually apply the respective transformation for every mapped +object ("batched execution"). Deviations from this convention are described in +the documentation of Transform IR ops. + +The transform IR values have transform IR types, which should implement exactly one of: + + * [TransformHandleTypeInterface](Transform.md#transformhandletypeinterface-transformhandletypeinterface), + + * [TransformValueHandleTypeInterface](Transform.md#transformvaluehandletypeinterface-transformvaluehandletypeinterface), + + * [TransformParamTypeInterface](Transform.md#transformparamtypeinterface-transformparamtypeinterface). + +The goal of these type interfaces, beyond providing a common base for accepted +types, is to verify the properties of the associated objects. For example, a +handle type interface implementation may check whether all associated payload IR +operations implement the "TileableOp" interface or have a specific "loop" kind.
+Similarly, a value handle type interface implementation may check if the +associated payload IR values are block arguments or have a specific type, or a +parameter type interface may check whether the associated attributes contain +non-negative integer values. These properties are used to statically indicate +pre- and post-conditions of a transformation connected to a Transform dialect +operation. The conditions are verified when payload objects are first +associated with a transform handle. By convention, Transform dialect operations +are expected to indicate narrow preconditions for their operands by enforcing +operand type constraints in their definitions and verifiers. On the +contrary, operations are expected to have few constraints on their results. +Specific instances of a transform operation can then be created with a more +restricted result type than the constraint in the operation (e.g., the "find" +operation only constrains the result type to be a transform IR type while its +concrete instance can have a type with stricter constraints such as implementing +the "tileable" interface). The verification will then happen at transform +execution time. This approach allows one to capture payload IR operation +properties in the transform IR without resorting to excessive use of type casts +or coupling dialect extensions between themselves. It is a trade-off between +verbosity/complexity and static hardening, which can be revised in the future. Overall, Transform IR ops are expected to be contained in a single top-level op. Such top-level ops specify how to apply the transformations described @@ -111,7 +123,7 @@ ```c++ LogicalResult transform::applyTransforms( Operation *payloadRoot, - ArrayRef> extraMappings, + const RaggedArray &extraMappings, TransformOpInterface transform, const TransformOptions &options); ``` @@ -163,7 +175,7 @@ the same extension mechanism.
The types must: * Implement exactly one of `TransformHandleTypeInterface`, - `TransformParamTypeInterface`. + `TransformValueHandleTypeInterface`, `TransformParamTypeInterface`. ## Side Effects @@ -255,18 +267,57 @@ ## Handle Invalidation -The execution model of the transform dialect allows a payload IR operation -to be associated with _multiple_ handles as well as nested payload IR -operations to be associated with different handles. A transform IR operation -that consumes a handle automatically _invalidates_ all the other handles -associated with the same payload IR operations, or with any of their -descendants, as the consumed handle. Note that the _entire_ handle is -invalidated, even if some of the payload IR operations associated with it -or their ancestors were not associated with the consumed handle. Any use of -the invalidated handle results in undefined behavior since the payload IR -operations associated with it are likely to have been mutated or erased. The -mere fact of the handle being invalidated does _not_ trigger undefined -behavior, only its appearance as an operand does. +The execution model of the transform dialect allows a payload IR operation to be +associated with _multiple_ handles as well as nested payload IR operations to be +associated with different handles. Similarly, a payload IR value may be +associated with multiple transform IR value handles. When a transform IR +operation consumes a handle, it usually indicates that the corresponding payload +IR object was destroyed and should no longer be referenced. Transform IR handles +that _may_ be pointing to an erased payload IR object are _invalidated_. The +mere presence of an invalidated handle in the transform IR is not a problem, but +_using_ it results in undefined behavior. Invalidated handles can be thought of +as dangling pointers. Note that the _entire_ handle is invalidated, even if some +of the payload IR objects associated with it remain live. 
+ +The following handle invalidation rules apply. + + * When an operation handle is consumed, the following handles are invalidated: + + - operation handles associated with one of the payload operations that the + consumed handle is associated with; + + - operation handles associated with one of the operations _nested_ in the + payload operations described above; + + - value handles associated with any result of any operation described above; + + - value handles associated with any argument of a block contained in a + region attached to any operation described above. + + * When a value handle is consumed, the following handles are invalidated: + + - operation handles associated with payload operations that produce as + result any value associated with the consumed handle (when the associated + value is an operation result); + + - operation handles associated with payload operations _nested_ in the + payload operations described above; + + - operation handles associated with payload operations (recursively) + _contained_ in the block that defines as argument any value associated + with the consumed handle (when the associated value is a block argument); + note that the adjacent blocks are not affected; + + - value handles associated with any result of any operation described above, + including all results of the operation defining as result the value + associated with the consumed handle; + + - value handles associated with any argument of a block contained in a + region attached to any operation described above. + +More intuitively, consuming a handle invalidates any handle that may be pointing +to an object defined or contained in the payload IR subtree rooted at the +closest operation or block.
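The invalidation rules above can be illustrated with a small, self-contained C++ model. This is a hypothetical sketch, not the `TransformState` implementation from this patch: payload operations, blocks and values are plain integer ids, handles are named sets of ids, and `PayloadIR`, `Handles`, `invalidatedBy`, `makeExampleIR` and `makeExampleHandles` are illustrative names. The sketch takes the subtree rooted at the closest operation (for a consumed operation handle or an op result) or the contents of the closest block (for a block argument) and reports every handle pointing into it. It also reports the consumed handle itself and, as an assumption not spelled out in the list above, any other handle to the consumed values.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>

// Hypothetical toy model of payload IR nesting, not the MLIR TransformState
// API: operations, blocks and values are identified by plain integer ids.
struct PayloadIR {
  std::map<int, int> opParentBlock; // op -> enclosing block (absent: top level)
  std::map<int, int> blockParentOp; // block -> op whose region contains it
  std::map<int, int> resultOwner;   // value -> defining op (op results)
  std::map<int, int> argOwner;      // value -> owning block (block arguments)

  // True if `op` is `root` itself or is transitively nested in `root`.
  bool isNestedIn(int op, int root) const {
    while (op != root) {
      auto b = opParentBlock.find(op);
      if (b == opParentBlock.end()) return false; // Reached a top-level op.
      op = blockParentOp.at(b->second);
    }
    return true;
  }

  // True if `op` is (transitively) contained in `block`.
  bool isInBlock(int op, int block) const {
    for (auto b = opParentBlock.find(op); b != opParentBlock.end();
         b = opParentBlock.find(blockParentOp.at(b->second)))
      if (b->second == block) return true;
    return false;
  }
};

// Hypothetical transform IR handles: names mapped to payload op or value ids.
struct Handles {
  std::map<std::string, std::set<int>> ops;
  std::map<std::string, std::set<int>> values;
};

// Returns every handle invalidated by consuming `consumed`, itself included.
std::set<std::string> invalidatedBy(const PayloadIR &ir, const Handles &h,
                                    const std::string &consumed) {
  std::set<int> rootOps, rootBlocks, consumedValues;
  if (auto it = h.ops.find(consumed); it != h.ops.end()) {
    rootOps = it->second; // Operation handle: subtrees rooted at its ops.
  } else {
    consumedValues = h.values.at(consumed);
    for (int v : consumedValues) {
      if (auto r = ir.resultOwner.find(v); r != ir.resultOwner.end())
        rootOps.insert(r->second); // Op result: subtree of the defining op.
      else
        rootBlocks.insert(ir.argOwner.at(v)); // Block argument: its block.
    }
  }
  // An op is affected if it lies in one of the invalidated subtrees.
  auto affected = [&](int op) {
    for (int root : rootOps)
      if (ir.isNestedIn(op, root)) return true;
    for (int block : rootBlocks)
      if (ir.isInBlock(op, block)) return true;
    return false;
  };
  std::set<std::string> result;
  for (const auto &[name, ops] : h.ops)
    for (int op : ops)
      if (affected(op)) {
        result.insert(name);
        break;
      }
  for (const auto &[name, values] : h.values)
    for (int v : values) {
      // A result dies with its defining op; a block argument dies with the op
      // whose region holds its block, or when it is itself consumed.
      auto r = ir.resultOwner.find(v);
      int owner = r != ir.resultOwner.end() ? r->second
                                            : ir.blockParentOp.at(ir.argOwner.at(v));
      if (consumedValues.count(v) || affected(owner)) {
        result.insert(name);
        break;
      }
    }
  return result;
}

// Sample payload: a func (op 1) whose body block 10 has argument 100 and holds
// a call (op 2, result 200) and a loop (op 3); the loop body block 11 has the
// induction variable 101 and holds op 4 with result 201.
PayloadIR makeExampleIR() {
  PayloadIR ir;
  ir.opParentBlock = {{2, 10}, {3, 10}, {4, 11}};
  ir.blockParentOp = {{10, 1}, {11, 3}};
  ir.resultOwner = {{200, 2}, {201, 4}};
  ir.argOwner = {{100, 10}, {101, 11}};
  return ir;
}

Handles makeExampleHandles() {
  Handles h;
  h.ops = {{"%func", {1}}, {"%call", {2}}, {"%loop", {3}}, {"%inner", {4}}};
  h.values = {{"%arg", {100}}, {"%callres", {200}},
              {"%iv", {101}}, {"%innerres", {201}}};
  return h;
}
```

With this sample payload, consuming `%loop` invalidates `%loop`, `%inner`, `%iv` and `%innerres`; consuming `%callres` invalidates only `%call` and `%callres`; and consuming the function argument `%arg` invalidates every handle except `%func`, since only the contents of the function body block are affected.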
The Transform dialect infrastructure has the capability of checking whether the transform IR op operand is invalidated before applying the diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -11,8 +11,8 @@ #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h" +#include "mlir/Dialect/Transform/Utils/RaggedArray.h" #include "mlir/IR/OpDefinition.h" - #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LogicalResult.h" @@ -45,7 +45,7 @@ }; using Param = Attribute; -using MappedValue = llvm::PointerUnion; +using MappedValue = llvm::PointerUnion; /// Entry point to the Transform dialect infrastructure. Applies the /// transformation specified by `transform` to payload IR contained in @@ -55,7 +55,7 @@ /// This function internally keeps track of the transformation state. LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, - ArrayRef> extraMapping = {}, + const RaggedArray &extraMapping = {}, const TransformOptions &options = TransformOptions()); /// The state maintained across applications of various ops implementing the @@ -107,16 +107,22 @@ /// parameters. using ParamMapping = DenseMap>; + /// Mapping between a Value in the transform IR and the corresponding list of + /// values in the payload IR. Also works for reverse mappings. + using ValueMapping = DenseMap>; + /// The bidirectional mappings between transform IR values and payload IR /// operations, and the mapping between transform IR values and parameters.
struct Mappings { TransformOpMapping direct; TransformOpReverseMapping reverse; ParamMapping params; + ValueMapping values; + ValueMapping reverseValues; }; friend LogicalResult applyTransforms(Operation *, TransformOpInterface, - ArrayRef>, + const RaggedArray &, const TransformOptions &); public: @@ -140,11 +146,21 @@ /// corresponds to. ArrayRef getParams(Value value) const; + /// Returns the list of payload IR values that the given transform IR value + /// corresponds to. + ArrayRef getPayloadValues(Value handleValue) const; + /// Populates `handles` with all handles pointing to the given Payload IR op. /// Returns success if such handles exist, failure otherwise. LogicalResult getHandlesForPayloadOp(Operation *op, SmallVectorImpl &handles) const; + /// Populates `handles` with all handles pointing to the given payload IR + /// value. Returns success if such handles exist, failure otherwise. + LogicalResult + getHandlesForPayloadValue(Value payloadValue, + SmallVectorImpl &handles) const; + /// Applies the transformation specified by the given transform op and updates /// the state accordingly. DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform); @@ -319,10 +335,10 @@ /// which may or may not contain the region with transform ops. Additional /// options can be provided through the trailing configuration object. TransformState(Region *region, Operation *payloadRoot, - ArrayRef> extraMappings = {}, + const RaggedArray &extraMappings = {}, const TransformOptions &options = TransformOptions()); - /// Returns the mappings frame for the reigon in which the value is defined. + /// Returns the mappings frame for the region in which the value is defined. const Mappings &getMapping(Value value) const { return const_cast(this)->getMapping(value); } @@ -344,10 +360,6 @@ return it->second; } - /// Removes the mapping between the given payload IR operation and the given - /// transform IR value. 
- void dropReverseMapping(Mappings &mappings, Operation *op, Value value); - /// Sets the payload IR ops associated with the given transform IR value /// (handle). A payload op may be associated multiple handles as long as /// at most one of them gets consumed by further transformations. @@ -367,40 +379,111 @@ /// by side effects. Practically, a transformation consuming a handle means /// that the associated payload operation may no longer exist. /// + /// Similarly, operation handles may be invalidated and should not be used + /// after a transform that consumed a value handle pointing to a payload value + /// defined by the operation as either block argument or op result. For + /// example, in the following sequence, the last transform operation rewrites + /// the callee to not return a specified result: + /// + /// %0 = transform.find_call "myfunc" + /// %1 = transform.find_results_of_calling "myfunc" + /// transform.drop_call_result_from_signature %1[0] + /// + /// which requires the call operations to be recreated. Therefore, the handle + /// %0 becomes associated with a dangling pointer and should not be used. + /// /// Returns failure if the payload does not satisfy the conditions associated /// with the type of the handle value. The value is expected to have a type /// implementing TransformHandleTypeInterface. LogicalResult setPayloadOps(Value value, ArrayRef targets); + /// Sets the payload IR values associated with the given transform IR value + /// (handle). A payload value may be associated with multiple handles as long + /// as at most one of them is consumed by further transformations. For + /// example, a hypothetical "get results of calls to function with the given + /// name" transform may be performed twice in a row producing handles pointing + /// to the same values: + /// + /// %0 = transform.find_results_of_calling "myfunc" + /// %1 = transform.find_results_of_calling "myfunc" + /// + /// which is valid by itself.
However, calling a hypothetical "erase value + /// producer" transform on both handles: + /// + /// transform.erase_value_producer %0 + /// transform.erase_value_producer %1 + /// + /// is invalid provided the transformation "consumes" the handle as expressed + /// by side effects (which themselves reflect the semantics of the transform + /// erasing the producer and making the handle dangling). Practically, a + /// transformation consuming a handle means the associated payload value may + /// no longer exist. + /// + /// Similarly, value handles are invalidated and should not be used after a + /// transform that consumed an operation handle pointing to the payload IR + /// operation defining the values associated with the value handle, as either + /// block arguments or op results, or any ancestor operation. For example, + /// + /// %0 = transform.find_call "myfunc" + /// %1 = transform.find_results_of_calling "myfunc" + /// transform.rewrite_and_rename %0 { new_name = "func" } + /// + /// makes %1 unusable after the last transformation if it consumes %0. When an + /// operation handle is consumed, it usually indicates that the operation was + /// destroyed or heavily modified, meaning that the values it defines may no + /// longer exist. + /// + /// Returns failure if the payload values do not satisfy the conditions + /// associated with the type of the handle value. The value is expected to + /// have a type implementing TransformValueHandleTypeInterface. + LogicalResult setPayloadValues(Value handle, ValueRange payloadValues); + /// Sets the parameters associated with the given transform IR value. Returns /// failure if the parameters do not satisfy the conditions associated with /// the type of the value. The value is expected to have a type implementing /// TransformParamTypeInterface. LogicalResult setParams(Value value, ArrayRef params); - /// Forgets the payload IR ops associated with the given transform IR value.
- void removePayloadOps(Value value); + /// Forgets the payload IR ops associated with the given transform IR value, + /// as well as any association between value handles and the results of said + /// payload IR op. + void forgetMapping(Value opHandle, ValueRange origOpFlatResults); + + void forgetValueMapping(Value valueHandle, + ArrayRef payloadOperations); /// Updates the payload IR ops associated with the given transform IR value. /// The callback function is called once per associated operation and is /// expected to return the modified operation or nullptr. In the latter case, /// the corresponding operation is no longer associated with the transform IR - /// value. + /// value. Value handles associated with the results of the operation are + /// also updated to be associated with the results of the new operation. For + /// this reason, the new operation must have the same number of results. /// /// Returns failure if the payload does not satisfy the conditions associated /// with the type of the handle value. - LogicalResult - updatePayloadOps(Value value, - function_ref callback); + LogicalResult replacePayloadOp(Operation *op, Operation *replacement); /// If the operand is a handle consumed by the operation, i.e. has the "free" /// memory effect associated with it, identifies other handles that are /// pointing to payload IR operations nested in the operations pointed to by /// the consumed handle. Marks all such handles as invalidated to trigger - /// errors if they are used. - void recordHandleInvalidation(OpOperand &handle); - void recordHandleInvalidationOne(OpOperand &handle, Operation *payloadOp, - Value otherHandle); + /// errors if they are used. If `throughValue` is passed, record the fact that + /// an op handle was invalidated because a value handle associated with + /// results of the payload op or its block arguments was invalidated. 
+ void recordOpHandleInvalidation(OpOperand &consumingHandle, + ArrayRef potentialAncestors, + Value throughValue = nullptr); + void recordOpHandleInvalidationOne(OpOperand &handle, + ArrayRef potentialAncestors, + Operation *payloadOp, Value otherHandle, + Value throughValue = nullptr); + + void recordValueHandleInvalidationByOpHandleOne( + OpOperand &opHandle, ArrayRef potentialAncestors, + Value payloadValue, Value valueHandle); + + void recordValueHandleInvalidation(OpOperand &valueHandle); /// Checks that the operation does not use invalidated handles as operands. /// Reports errors and returns failure if it does. Otherwise, invalidates the @@ -421,14 +504,10 @@ /// The top-level operation that contains all payload IR, typically a module. Operation *topLevel; - /// Storage for extra mapped values (payload operations or parameters) to be + /// Extra mapped values (payload operations, values or parameters) to be /// associated with additional entry block arguments of the top-level - /// transform operation. Each entry in `topLevelMappedValues` is a reference - /// to a contiguous block in `topLevelMappedValueStorage`. - // TODO: turn this into a proper named data structure, there are several more - // below. - SmallVector> topLevelMappedValues; - SmallVector topLevelMappedValueStorage; + /// transform operation. + RaggedArray topLevelMappedValues; /// Additional options controlling the transformation state behavior. TransformOptions options; @@ -455,16 +534,23 @@ public: /// Indicates that the result of the transform IR op at the given position /// corresponds to the given list of payload IR ops. Each result must be set - /// by the transformation exactly once. The value must have a type - /// implementing TransformHandleTypeInterface. + /// by the transformation exactly once if the transformation succeeds. + /// The value must have a type implementing TransformHandleTypeInterface.
void set(OpResult value, ArrayRef ops); /// Indicates that the result of the transform IR op at the given position /// corresponds to the given list of parameters. Each result must be set by - /// the transformation exactly once. The value must have a type implementing - /// TransformParamTypeInterface. + /// the transformation exactly once if the transformation succeeds. The + /// value must have a type implementing TransformParamTypeInterface. void setParams(OpResult value, ArrayRef params); + /// Indicates that the result of the transform IR op at the given position + /// corresponds to the given range of payload IR values. Each result must be + /// set by the transformation exactly once if the transformation + /// succeeds. The value must have a type implementing + /// TransformValueHandleTypeInterface. + void setValues(OpResult handle, ValueRange values); + private: /// Creates an instance of TransformResults that expects mappings for /// `numSegments` values, which may be associated with payload operations or @@ -481,34 +567,34 @@ /// be associated with parameters. ArrayRef getParams(unsigned resultNumber) const; + /// Gets the list of payload IR values associated with the result identified + /// by its number in the list of operation results. The result must have been + /// set to be associated with payload IR values. + ArrayRef getValues(unsigned resultNumber) const; + /// Returns `true` if the result identified by its number in the list of - /// operation results is associated with a list of parameters, `false` if it - /// is associated with the list of payload IR operations. + /// operation results is associated with a list of parameters, `false` + /// otherwise. bool isParam(unsigned resultNumber) const; + /// Returns `true` if the result identified by its number in the list of + /// operation results is associated with a list of payload IR values, `false` + /// otherwise.
+ bool isValue(unsigned resultNumber) const; + /// Returns `true` if the result identified by its number in the list of /// operation results is associated with something. bool isSet(unsigned resultNumber) const; - /// Storage for pointers to payload IR ops that are associated with results of - /// a transform IR op. `segments` contains as many entries as the transform IR - /// op has results, even if some of them are not associated with payload IR - /// operations. Each entry is a reference to a contiguous segment in the - /// `operations` list that contains the pointers to operations. This allows - /// for operations to be stored contiguously without nested vectors and for - /// different segments to be set in any order. - SmallVector, 2> segments; - SmallVector operations; - - /// Storage for parameters that are associated with results of the transform - /// IR op. `paramSegments` contains as many entries as the transform IR op has - /// results, even if some of them are not associated with parameters. Each - /// entry is a reference to a contiguous segment in the `params` list that - /// contains the actual parameters. This allows for parameters to be stored - /// contiguously without nested vectors and for different segments to be set - /// in any order. - SmallVector, 2> paramSegments; - SmallVector params; + /// Pointers to payload IR ops that are associated with results of a transform + /// IR op. + RaggedArray operations; + + /// Parameters that are associated with results of the transform IR op. + RaggedArray params; + + /// Payload IR values that are associated with results of a transform IR op. + RaggedArray values; }; TransformState::RegionScope TransformState::make_region_scope(Region ®ion) { @@ -625,14 +711,14 @@ /// Side effect resource corresponding to the mapping between Transform IR /// values and Payload IR operations. An Allocate effect from this resource /// means creating a new mapping entry, it is always accompanied by a Write -/// effet. 
A Read effect from this resource means accessing the mapping. A Free +/// effect. A Read effect from this resource means accessing the mapping. A Free /// effect on this resource indicates the removal of the mapping entry, /// typically after a transformation that modifies the Payload IR operations /// associated with one of the Transform IR operation's operands. It is always /// accompanied by a Read effect. Read-after-Free and double-Free are not /// allowed (they would be problematic with "regular" memory effects too) as /// they indicate an attempt to access Payload IR operations that have been -/// modified, potentially erased, by the previous tranfsormations. +/// modified, potentially erased, by the previous transformations. // TODO: consider custom effects if these are not enabling generic passes such // as CSE/DCE to work. struct TransformMappingResource @@ -769,7 +855,7 @@ /// A single result of applying a transform op with `ApplyEachOpTrait` to a /// single payload operation. -using ApplyToEachResult = llvm::PointerUnion; +using ApplyToEachResult = MappedValue; /// A list of results of applying a transform op with `ApplyEachOpTrait` to a /// single payload operation, co-indexed with the results of the transform op. @@ -793,6 +879,9 @@ if constexpr (std::is_convertible_v) { results.push_back(static_cast(element)); + } else if constexpr (std::is_convertible_v) { + results.push_back(element.template get()); } else { results.push_back(static_cast(element)); } @@ -800,8 +889,12 @@ } /// Appends an element to the list. + // Using ApplyToEachResult that can be implicitly constructed from a Value but + // not from a concrete Op that is implicitly convertible to a Value to avoid + // ambiguity. void push_back(Operation *op) { results.push_back(op); } void push_back(Attribute attr) { results.push_back(attr); } + void push_back(ApplyToEachResult r) { results.push_back(r); } /// Reserves space for `size` elements in the list. 
void reserve(unsigned size) { results.reserve(size); } diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -137,10 +137,10 @@ : TransformTypeInterfaceBase<"TransformHandleTypeInterface", "::mlir::Operation *"> { let description = [{ - Types that can be used for the Transform dialect handle values. Such types - define the properties of Payload IR operations associated with the handle. - A user of such a handle can assume that these properties have been verified - for any Payload IR operation associated with it. + Types that can be used for the Transform dialect operation handle values. + Such types define the properties of Payload IR operations associated with + the handle. A user of such a handle can assume that these properties have + been verified for any Payload IR operation associated with it. }]; } @@ -155,9 +155,21 @@ }]; } +def TransformValueHandleTypeInterface + : TransformTypeInterfaceBase<"TransformValueHandleTypeInterface", + "::mlir::Value"> { + let description = [{ + Types that can be used for the Transform dialect handle values pointing to + Payload IR values. Such types define the properties of Payload IR values + associated with the handle. Users of such a handle can assume that these + properties have been verified for any Payload IR value associated with it. 
+ }]; +} + def Transform_AnyHandleOrParamType : Type, + TransformHandleTypeInterface.predicate, + TransformValueHandleTypeInterface.predicate]>, "any transform handle or parameter">; def FunctionalStyleTransformOpTrait diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td @@ -52,6 +52,15 @@ let genVerifyDecl = 1; } +def Transform_AnyValue : TypeDef]> { + let description = [{ + Transform IR value that can be associated with a list of Payload IR values. + }]; + let mnemonic = "any_value"; + let assemblyFormat = ""; +} + class Transform_ConcreteOpType : Type()" diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h --- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h +++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h @@ -40,7 +40,7 @@ LogicalResult interpreterBaseRunOnOperationImpl( Operation *target, StringRef passName, const std::shared_ptr> &sharedTransformModule, - ArrayRef> extraMappings, + const RaggedArray &extraMappings, const TransformOptions &options, const Pass::Option &transformFileName, const Pass::Option &debugPayloadRootTag, diff --git a/mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h b/mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h @@ -0,0 +1,92 @@ +//===- RaggedArray.h - 2D array with different inner lengths ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +/// A 2D array where each row may have a different length. Elements of each row +/// are stored contiguously, but rows don't have a fixed order in the storage. +template +class RaggedArray { +public: + /// Returns the number of rows in the 2D array. + size_t size() const { return slices.size(); } + + /// Returns true if there are no rows in the 2D array. Note that an array with a + /// non-zero number of empty rows is *NOT* empty. + bool empty() const { return slices.empty(); } + + /// Accesses the `pos`-th row. + ArrayRef operator[](size_t pos) const { return at(pos); } + ArrayRef at(size_t pos) const { return slices[pos]; } + MutableArrayRef operator[](size_t pos) { return at(pos); } + MutableArrayRef at(size_t pos) { return slices[pos]; } + + /// Iterator over rows. + auto begin() { return slices.begin(); } + auto begin() const { return slices.begin(); } + auto end() { return slices.end(); } + auto end() const { return slices.end(); } + + /// Reserves space to store `size` rows with `nestedSize` elements each. + void reserve(size_t size, size_t nestedSize = 0) { + slices.reserve(size); + storage.reserve(size * nestedSize); + } + + /// Appends the given range of elements as a new row to the 2D array. May + /// invalidate the end iterator. + template + void push_back(Range &&elements) { + slices.push_back(appendToStorage(std::forward(elements))); + } + + /// Replaces the `pos`-th row in the 2D array with the given range of + /// elements. Invalidates iterators and references to the `pos`-th and all + /// succeeding rows.
+ template + void replace(size_t pos, Range &&elements) { + auto from = slices[pos].data(); + if (from != nullptr) { + auto to = std::next(from, slices[pos].size()); + auto newFrom = storage.erase(from, to); + // Update the array refs after the underlying storage was shifted. + for (size_t i = pos + 1, e = size(); i < e; ++i) { + slices[i] = MutableArrayRef(newFrom, slices[i].size()); + std::advance(newFrom, slices[i].size()); + } + } + slices[pos] = appendToStorage(std::forward(elements)); + } + + /// Appends `num` empty rows to the array. + void appendEmptyRows(size_t num) { slices.resize(slices.size() + num); } + +private: + /// Appends the given elements to the storage and returns an ArrayRef pointing + /// to them in the storage. + template + MutableArrayRef appendToStorage(Range &&elements) { + size_t start = storage.size(); + llvm::append_range(storage, std::forward(elements)); + return MutableArrayRef(storage).drop_front(start); + } + + /// Outer elements of the ragged array. Each entry is a reference to a + /// contiguous segment in the `storage` list that contains the actual + /// elements. This allows for elements to be stored contiguously without + /// nested vectors and for different segments to be set or replaced in any + /// order. + SmallVector> slices; + + /// Dense storage for ragged array elements. 
+ SmallVector storage; +}; +} // namespace mlir diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -38,12 +38,14 @@ void transform::detail::checkImplementsTransformHandleTypeInterface( TypeID typeID, MLIRContext *context) { const auto &abstractType = AbstractType::lookup(typeID, context); - assert( - (abstractType.hasInterface( - TransformHandleTypeInterface::getInterfaceID()) || - abstractType.hasInterface( - TransformParamTypeInterface::getInterfaceID())) && - "expected Transform dialect type to implement one of the two interfaces"); + assert((abstractType.hasInterface( + TransformHandleTypeInterface::getInterfaceID()) || + abstractType.hasInterface( + TransformParamTypeInterface::getInterfaceID()) || + abstractType.hasInterface( + TransformValueHandleTypeInterface::getInterfaceID())) && + "expected Transform dialect type to implement one of the three " + "interfaces"); } #endif // NDEBUG diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Operation.h" +#include "mlir/Support/LogicalResult.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Debug.h" @@ -29,17 +30,12 @@ transform::TransformState::TransformState( Region *region, Operation *payloadRoot, - ArrayRef> extraMappings, + const RaggedArray &extraMappings, const TransformOptions &options) : topLevel(payloadRoot), options(options) { topLevelMappedValues.reserve(extraMappings.size()); - for (ArrayRef mapping : extraMappings) { - size_t start = topLevelMappedValueStorage.size(); - 
llvm::append_range(topLevelMappedValueStorage, mapping); - topLevelMappedValues.push_back( - ArrayRef(topLevelMappedValueStorage) - .slice(start, mapping.size())); - } + for (ArrayRef mapping : extraMappings) + topLevelMappedValues.push_back(mapping); auto result = mappings.try_emplace(region); assert(result.second && "the region scope is already present"); @@ -55,16 +51,26 @@ transform::TransformState::getPayloadOps(Value value) const { const TransformOpMapping &operationMapping = getMapping(value).direct; auto iter = operationMapping.find(value); - assert(iter != operationMapping.end() && - "cannot find mapping for payload handle (param handle provided?)"); + assert( + iter != operationMapping.end() && + "cannot find mapping for payload handle (param/value handle provided?)"); return iter->getSecond(); } ArrayRef transform::TransformState::getParams(Value value) const { const ParamMapping &mapping = getMapping(value).params; auto iter = mapping.find(value); - assert(iter != mapping.end() && - "cannot find mapping for param handle (payload handle provided?)"); + assert(iter != mapping.end() && "cannot find mapping for param handle " + "(operation/value handle provided?)"); + return iter->getSecond(); +} + +ArrayRef +transform::TransformState::getPayloadValues(Value handleValue) const { + const ValueMapping &mapping = getMapping(handleValue).values; + auto iter = mapping.find(handleValue); + assert(iter != mapping.end() && "cannot find mapping for value handle " + "(param/operation handle provided?)"); return iter->getSecond(); } @@ -82,6 +88,20 @@ return success(found); } +LogicalResult transform::TransformState::getHandlesForPayloadValue( + Value payloadValue, SmallVectorImpl &handles) const { + bool found = false; + for (const Mappings &mapping : llvm::make_second_range(mappings)) { + auto iterator = mapping.reverseValues.find(payloadValue); + if (iterator != mapping.reverseValues.end()) { + llvm::append_range(handles, iterator->getSecond()); + found = true; + } 
+ } + + return success(found); +} + LogicalResult transform::TransformState::mapBlockArgument(BlockArgument argument, ArrayRef values) { @@ -99,6 +119,20 @@ return setPayloadOps(argument, operations); } + if (argument.getType().isa()) { + SmallVector payloadValues; + payloadValues.reserve(values.size()); + for (MappedValue value : values) { + if (auto v = value.dyn_cast()) { + payloadValues.push_back(v); + continue; + } + return emitError(argument.getLoc()) + << "wrong kind of value provided for the top-level value handle"; + } + return setPayloadValues(argument, payloadValues); + } + assert(argument.getType().isa() && "unsupported kind of block argument"); SmallVector parameters; @@ -119,8 +153,8 @@ ArrayRef targets) { assert(value != kTopLevelValue && "attempting to reset the transformation root"); - assert(!value.getType().isa() && - "cannot associate payload ops with a value of parameter type"); + assert(value.getType().isa() && + "wrong handle type"); for (Operation *target : targets) { if (target) @@ -150,6 +184,41 @@ return success(); } +LogicalResult +transform::TransformState::setPayloadValues(Value handle, + ValueRange payloadValues) { + assert(handle != nullptr && "attempting to set payload values for a null handle"); + assert(handle.getType().isa() && + "wrong handle type"); + + for (Value payload : payloadValues) { + if (payload) + continue; + return emitError(handle.getLoc()) << "attempting to assign a null payload " + "value to this transform handle"; + } + + auto iface = handle.getType().cast(); + SmallVector payloadValueVector = llvm::to_vector(payloadValues); + DiagnosedSilenceableFailure result = + iface.checkPayload(handle.getLoc(), payloadValueVector); + if (failed(result.checkAndReport())) + return failure(); + + Mappings &mappings = getMapping(handle); + bool inserted = + mappings.values.insert({handle, std::move(payloadValueVector)}).second; + assert( + inserted && + "value handle is already associated with another list of payload values");
(void)inserted; + + for (Value payload : payloadValues) + mappings.reverseValues[payload].push_back(handle); + + return success(); +} + LogicalResult transform::TransformState::setParams(Value value, ArrayRef params) { assert(value != nullptr && "attempting to set params for a null value"); @@ -177,54 +246,146 @@ return success(); } -void transform::TransformState::dropReverseMapping(Mappings &mappings, - Operation *op, Value value) { - auto it = mappings.reverse.find(op); - if (it == mappings.reverse.end()) +template +void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) { + auto it = mapping.find(key); + if (it == mapping.end()) return; - llvm::erase_value(it->getSecond(), value); + llvm::erase_value(it->getSecond(), mapped); if (it->getSecond().empty()) - mappings.reverse.erase(it); + mapping.erase(it); } -void transform::TransformState::removePayloadOps(Value value) { - Mappings &mappings = getMapping(value); - for (Operation *op : mappings.direct[value]) - dropReverseMapping(mappings, op, value); - mappings.direct.erase(value); +void transform::TransformState::forgetMapping(Value opHandle, + ValueRange origOpFlatResults) { + Mappings &mappings = getMapping(opHandle); + for (Operation *op : mappings.direct[opHandle]) + dropMappingEntry(mappings.reverse, op, opHandle); + mappings.direct.erase(opHandle); + + for (Value opResult : origOpFlatResults) { + SmallVector resultHandles; + (void)getHandlesForPayloadValue(opResult, resultHandles); + for (Value resultHandle : resultHandles) { + Mappings &localMappings = getMapping(resultHandle); + dropMappingEntry(localMappings.values, resultHandle, opResult); + dropMappingEntry(localMappings.reverseValues, opResult, resultHandle); + } + } } -LogicalResult transform::TransformState::updatePayloadOps( - Value value, function_ref callback) { - Mappings &mappings = getMapping(value); - auto it = mappings.direct.find(value); - assert(it != mappings.direct.end() && "unknown handle"); - SmallVector &association = 
it->getSecond(); - SmallVector updated; - updated.reserve(association.size()); +void transform::TransformState::forgetValueMapping( + Value valueHandle, ArrayRef payloadOperations) { + Mappings &mappings = getMapping(valueHandle); + for (Value payloadValue : mappings.values[valueHandle]) + dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle); + mappings.values.erase(valueHandle); + + for (Operation *payloadOp : payloadOperations) { + SmallVector opHandles; + (void)getHandlesForPayloadOp(payloadOp, opHandles); + for (Value opHandle : opHandles) { + Mappings &localMappings = getMapping(opHandle); + dropMappingEntry(localMappings.direct, opHandle, payloadOp); + dropMappingEntry(localMappings.reverse, payloadOp, opHandle); + } + } +} + +LogicalResult +transform::TransformState::replacePayloadOp(Operation *op, + Operation *replacement) { + // Drop the mapping between the op and all handles that point to it. Don't + // care if there are no such handles. + SmallVector opHandles; + (void)getHandlesForPayloadOp(op, opHandles); + for (Value handle : opHandles) { + Mappings &mappings = getMapping(handle); + dropMappingEntry(mappings.reverse, op, handle); + } - for (Operation *op : association) { - dropReverseMapping(mappings, op, value); - if (Operation *updatedOp = callback(op)) { - updated.push_back(updatedOp); - mappings.reverse[updatedOp].push_back(value); + // Drop the mapping between the op results and all value handles that point to + // them. Don't care if there are no such handles.
+ RaggedArray resultValueHandles; + for (Value opResult : op->getResults()) { + SmallVector valueHandles; + (void)getHandlesForPayloadValue(opResult, valueHandles); + for (Value handle : valueHandles) { + Mappings &localMappings = getMapping(handle); + dropMappingEntry(localMappings.reverseValues, opResult, handle); + } + resultValueHandles.push_back(std::move(valueHandles)); + } - auto iface = value.getType().cast(); - DiagnosedSilenceableFailure result = - iface.checkPayload(value.getLoc(), updated); - if (failed(result.checkAndReport())) - return failure(); + // TODO: consider invalidating the handles to nested objects here. + + // If the replacement is null, i.e., the op is being erased, drop the mapping + // between the handles and the IR objects and return. + if (!replacement) { + for (Value handle : opHandles) { + Mappings &mappings = getMapping(handle); + dropMappingEntry(mappings.direct, handle, op); + } + // The reverse value mappings were dropped above, so use the saved handles. + for (auto [opResult, valueHandles] : + llvm::zip(op->getResults(), resultValueHandles)) { + for (Value handle : valueHandles) { + Mappings &localMappings = getMapping(handle); + dropMappingEntry(localMappings.values, handle, opResult); + } + } + return success(); + } + + // Otherwise, replace the pointed-to object of all handles while preserving + // their relative order. + if (op->getNumResults() != replacement->getNumResults()) { + return emitError(op->getLoc()) + << "cannot replace an op with another op producing a different " + "number of results while tracking handles"; + } + + // Replace the mapped operation if present. + for (Value handle : opHandles) { + Mappings &mappings = getMapping(handle); + auto it = mappings.direct.find(handle); + if (it == mappings.direct.end()) + continue; + + SmallVector &association = it->getSecond(); + // Note that an operation may be associated with the handle more than once.
+ for (Operation *&mapped : association) { + if (mapped == op) + mapped = replacement; + } + mappings.reverse[replacement].push_back(handle); + } + + // Replace the mapped results of the operation. + for (auto [origResult, replacementResult, handleList] : llvm::zip( + op->getResults(), replacement->getResults(), resultValueHandles)) { + for (Value resultHandle : handleList) { + Mappings &mappings = getMapping(resultHandle); + auto it = mappings.values.find(resultHandle); + if (it == mappings.values.end()) + continue; + + SmallVector &association = it->getSecond(); + for (Value &mapped : association) { + if (mapped == origResult) + mapped = replacementResult; + } + mappings.reverseValues[replacementResult].push_back(resultHandle); + } + } - it->second = updated; return success(); } -void transform::TransformState::recordHandleInvalidationOne( - OpOperand &handle, Operation *payloadOp, Value otherHandle) { - ArrayRef potentialAncestors = getPayloadOps(handle.get()); +void transform::TransformState::recordOpHandleInvalidationOne( + OpOperand &consumingHandle, ArrayRef potentialAncestors, + Operation *payloadOp, Value otherHandle, Value throughValue) { // If the op is associated with invalidated handle, skip the check as it // may be reading invalid IR. if (invalidatedHandles.count(otherHandle)) @@ -240,10 +401,13 @@ // deleted before the lambda gets called. Location ancestorLoc = ancestor->getLoc(); Location opLoc = payloadOp->getLoc(); - Operation *owner = handle.getOwner(); - unsigned operandNo = handle.getOperandNumber(); + Operation *owner = consumingHandle.getOwner(); + unsigned operandNo = consumingHandle.getOperandNumber(); + std::optional throughValueLoc = + throughValue ? 
std::make_optional(throughValue.getLoc()) : std::nullopt; invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo, - otherHandle](Location currentLoc) { + otherHandle, + throughValueLoc](Location currentLoc) { InFlightDiagnostic diag = emitError(currentLoc) << "op uses a handle invalidated by a " "previously executed transform op"; @@ -251,19 +415,231 @@ diag.attachNote(owner->getLoc()) << "invalidated by this transform op that consumes its operand #" << operandNo - << " and invalidates handles to payload ops nested in payload " - "ops associated with the consumed handle"; + << " and invalidates all handles to payload IR entities associated " + "with this operand and entities nested in them"; diag.attachNote(ancestorLoc) << "ancestor payload op"; diag.attachNote(opLoc) << "nested payload op"; + if (throughValueLoc) { + diag.attachNote(*throughValueLoc) + << "consumed handle points to this payload value"; + } }; } } -void transform::TransformState::recordHandleInvalidation(OpOperand &handle) { - for (const Mappings &mapping : llvm::make_second_range(mappings)) - for (const auto &[payloadOp, otherHandles] : mapping.reverse) +// void transform::TransformState::recordValueHandleInvalidationByOpHandleOne( +// OpOperand &opHandle, Operation *payloadOp, Operation *ancestor) { +// if (invalidatedHandles.count(opHandle.get())) +// return; + +// Operation *owner = opHandle.getOwner(); +// unsigned operandNo = opHandle.getOperandNumber(); +// Location ancestorLoc = ancestor->getLoc(); +// Location opLoc = payloadOp->getLoc(); +// for (OpResult result : payloadOp->getResults()) { +// // Find all handles to "result", mark them as invalidated. 
+// SmallVector valueHandles; +// if (failed(getHandlesForPayloadValue(result, valueHandles))) +// continue; +// for (Value valueHandle : valueHandles) { +// unsigned resultNo = result.getResultNumber(); +// invalidatedHandles[valueHandle] = [valueHandle, owner, operandNo, +// ancestorLoc, resultNo, +// opLoc](Location currentLoc) { +// InFlightDiagnostic diag = emitError(currentLoc) +// << "op uses a handle invalidated by a " +// "previously executed transform op"; +// diag.attachNote(valueHandle.getLoc()) << "invalidated handle"; +// diag.attachNote(owner->getLoc()) +// << "invalidated by this transform op that consumes its operand #" +// << operandNo +// << " and invalidates handles to values defined by the " +// "associated payload ops or ops nested in those"; +// diag.attachNote(ancestorLoc) +// << "ancestor op associated with the consumed handle"; +// diag.attachNote(opLoc) +// << "op defining the value as result #" << resultNo; +// }; +// } +// } + +// SmallVector blockArguments; +// for (Region &region : payloadOp->getRegions()) { +// for (Block &block : region) { +// for (BlockArgument arg : block.getArguments()) { +// blockArguments.push_back(arg); +// } +// } +// } +// for (BlockArgument arg : blockArguments) { +// SmallVector valueHandles; +// if (failed(getHandlesForPayloadValue(arg, valueHandles))) +// continue; +// for (Value valueHandle : valueHandles) { +// unsigned argumentNo = arg.getArgNumber(); +// unsigned blockNo = std::distance(arg.getOwner()->getParent()->begin(), +// arg.getOwner()->getIterator()); +// unsigned regionNo = arg.getOwner()->getParent()->getRegionNumber(); +// Location argLoc = arg.getLoc(); +// invalidatedHandles[valueHandle] = [valueHandle, owner, operandNo, +// ancestorLoc, argumentNo, blockNo, +// regionNo, opLoc, argLoc](Location +// currentLoc) { +// InFlightDiagnostic diag = emitError(currentLoc) +// << "op uses a handle invalidated by a " +// "previously executed transform op"; +// diag.attachNote(valueHandle.getLoc()) <<
"invalidated handle"; +// diag.attachNote(owner->getLoc()) +// << "invalidated by this transform op that consumes its operand #" +// << operandNo +// << "and invalidates handles to values defined as block arguments +// " +// "by the associated payload ops or ops nested in those"; +// diag.attachNote(ancestorLoc) +// << "ancestor op associated with the consumed handle"; +// diag.attachNote(opLoc) +// << "op defining the value as block argument #" << argumentNo +// << " of block #" << blockNo << " in region #" << regionNo; +// diag.attachNote(argLoc) << "payload value"; +// }; +// } +// } +// } + +void transform::TransformState::recordValueHandleInvalidationByOpHandleOne( + OpOperand &consumingHandle, ArrayRef potentialAncestors, + Value payloadValue, Value valueHandle) { + // If the value handle has already been invalidated, skip the check as it + // may be reading invalid IR. + if (invalidatedHandles.count(valueHandle)) + return; + + for (Operation *ancestor : potentialAncestors) { + Operation *definingOp; + std::optional resultNo = std::nullopt; + unsigned argumentNo = 0, blockNo = 0, regionNo = 0; + if (auto opResult = payloadValue.dyn_cast()) { + definingOp = opResult.getOwner(); + resultNo = opResult.getResultNumber(); + } else { + auto arg = payloadValue.cast(); + definingOp = arg.getParentBlock()->getParentOp(); + argumentNo = arg.getArgNumber(); + blockNo = std::distance(arg.getOwner()->getParent()->begin(), + arg.getOwner()->getIterator()); + regionNo = arg.getOwner()->getParent()->getRegionNumber(); + } + assert(definingOp && "expected the value to be defined by an op as result " + "or block argument"); + if (!ancestor->isAncestor(definingOp)) + continue; + + Operation *owner = consumingHandle.getOwner(); + unsigned operandNo = consumingHandle.getOperandNumber(); + Location ancestorLoc = ancestor->getLoc(); + Location opLoc = definingOp->getLoc(); + Location valueLoc = payloadValue.getLoc(); + invalidatedHandles[valueHandle] = + [valueHandle, owner, operandNo,
resultNo, argumentNo, blockNo, regionNo, + ancestorLoc, opLoc, valueLoc](Location currentLoc) { + InFlightDiagnostic diag = emitError(currentLoc) + << "op uses a handle invalidated by a " + "previously executed transform op"; + diag.attachNote(valueHandle.getLoc()) << "invalidated handle"; + diag.attachNote(owner->getLoc()) + << "invalidated by this transform op that consumes its operand #" + << operandNo + << " and invalidates all handles to payload IR entities " + "associated with this operand and entities nested in them"; + diag.attachNote(ancestorLoc) + << "ancestor op associated with the consumed handle"; + if (resultNo) { + diag.attachNote(opLoc) + << "op defining the value as result #" << *resultNo; + } else { + diag.attachNote(opLoc) + << "op defining the value as block argument #" << argumentNo + << " of block #" << blockNo << " in region #" << regionNo; + } + diag.attachNote(valueLoc) << "payload value"; + }; + } +} + +void transform::TransformState::recordOpHandleInvalidation( + OpOperand &handle, ArrayRef potentialAncestors, + Value throughValue) { + // Iterate over the mapping and invalidate aliasing handles. This is quite + // expensive and only necessary for error reporting in case of transform + // dialect misuse with dangling handles. Iteration over the handles is based + // on the assumption that the number of handles is significantly less than the + // number of IR objects (operations and values). Alternatively, we could walk + // the IR nested in each payload op associated with the given handle and look + // for handles associated with each operation and value. + for (const Mappings &mapping : llvm::make_second_range(mappings)) { + // Go over all op handle mappings and mark as invalidated any handle + // pointing to any of the payload ops associated with the given handle or + // any op nested in them. 
+ for (const auto &[payloadOp, otherHandles] : mapping.reverse) { for (Value otherHandle : otherHandles) - recordHandleInvalidationOne(handle, payloadOp, otherHandle); + recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp, + otherHandle, throughValue); + } + // Go over all value handle mappings and mark as invalidated any handle + // pointing to any result of the payload ops associated with the given handle + // or of any op nested in them. Similarly invalidate handles to arguments of + // blocks belonging to any region of any payload op associated with the + // given handle or any op nested in them. + for (const auto &[payloadValue, valueHandles] : mapping.reverseValues) { + for (Value valueHandle : valueHandles) + recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors, + payloadValue, valueHandle); + } + } + + // const Mappings &mapping = getMapping(handle.get()); + // for (Operation *payloadOp : mapping.direct.lookup(handle.get())) { + // payloadOp->walk([payloadOp,this,&handle](Operation *nestedOp) { + // recordValueHandleInvalidationByOpHandleOne(handle, nestedOp, + // payloadOp); + // }); + // } +} + +void transform::TransformState::recordValueHandleInvalidation( + OpOperand &valueHandle) { + // Invalidate other handles to the same value.
+ for (Value payloadValue : getPayloadValues(valueHandle.get())) { + SmallVector otherValueHandles; + (void)getHandlesForPayloadValue(payloadValue, otherValueHandles); + for (Value otherHandle : otherValueHandles) { + Operation *owner = valueHandle.getOwner(); + unsigned operandNo = valueHandle.getOperandNumber(); + Location valueLoc = payloadValue.getLoc(); + invalidatedHandles[otherHandle] = [otherHandle, owner, operandNo, + valueLoc](Location currentLoc) { + InFlightDiagnostic diag = emitError(currentLoc) + << "op uses a handle invalidated by a " + "previously executed transform op"; + diag.attachNote(otherHandle.getLoc()) << "invalidated handle"; + diag.attachNote(owner->getLoc()) + << "invalidated by this transform op that consumes its operand #" + << operandNo + << " and invalidates handles to the same values as associated with " + "it"; + diag.attachNote(valueLoc) << "payload value"; + }; + } + + if (auto opResult = payloadValue.dyn_cast()) { + Operation *payloadOp = opResult.getOwner(); + recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue); + } else { + auto arg = payloadValue.cast(); + for (Operation &payloadOp : *arg.getOwner()) + recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue); + } + } } LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( @@ -287,13 +663,44 @@ return isa(effect.getEffect()) && effect.getValue() == target.get(); }; - if (llvm::any_of(effects, consumesTarget)) - recordHandleInvalidation(target); + if (llvm::any_of(effects, consumesTarget)) { + if (target.get().getType().isa()) { + ArrayRef payloadOps = getPayloadOps(target.get()); + recordOpHandleInvalidation(target, payloadOps); + } else if (target.get() + .getType() + .isa()) { + recordValueHandleInvalidation(target); + } + } } return success(); } +template +DiagnosedSilenceableFailure +checkRepeatedConsumptionInOperand(ArrayRef payload, + transform::TransformOpInterface transform, + unsigned operandNumber) { + DenseSet seen;
for (T p : payload) { + if (!seen.insert(p).second) { + DiagnosedSilenceableFailure diag = + transform.emitSilenceableError() + << "a handle passed as operand #" << operandNumber + << " and consumed by this operation points to a payload " + "entity more than once"; + if constexpr (std::is_pointer_v) + diag.attachNote(p->getLoc()) << "repeated target op"; + else + diag.attachNote(p.getLoc()) << "repeated target value"; + return diag; + } + } + return DiagnosedSilenceableFailure::success(); +} + DiagnosedSilenceableFailure transform::TransformState::applyTransform(TransformOpInterface transform) { LLVM_DEBUG(DBGS() << "applying: " << transform << "\n"); @@ -313,25 +720,82 @@ if (!isHandleConsumed(operand.get(), transform)) continue; - DenseSet seen; - for (Operation *op : getPayloadOps(operand.get())) { - if (!seen.insert(op).second) { - DiagnosedSilenceableFailure diag = - transform.emitSilenceableError() - << "a handle passed as operand #" << operand.getOperandNumber() - << " and consumed by this operation points to a payload " - "operation more than once"; - diag.attachNote(op->getLoc()) << "repeated target op"; - return diag; + Type operandType = operand.get().getType(); + if (operandType.isa()) { + DiagnosedSilenceableFailure check = + checkRepeatedConsumptionInOperand( + getPayloadOps(operand.get()), transform, + operand.getOperandNumber()); + if (!check.succeeded()) + return check; + } else if (operandType.isa()) { + DiagnosedSilenceableFailure check = + checkRepeatedConsumptionInOperand( + getPayloadValues(operand.get()), transform, + operand.getOperandNumber()); + if (!check.succeeded()) + return check; + } + } + } + + // Find which operands are consumed. 
+ DenseSet consumedOperands; + auto memEffectInterface = + cast(transform.getOperation()); + SmallVector effects; + for (OpOperand &target : transform->getOpOperands()) { + effects.clear(); + memEffectInterface.getEffectsOnValue(target.get(), effects); + if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { + return isa( + effect.getResource()) && + isa(effect.getEffect()); + })) { + consumedOperands.insert(target.getOperandNumber()); + } + } + + // Remember the results of the payload ops associated with the consumed + // op handles or the ops defining the value handles so we can drop the + // association with them later. This must happen here because the + // transformation may destroy or mutate them so we cannot traverse the payload + // IR after that. + SmallVector origOpFlatResults; + SmallVector origAssociatedOps; + for (unsigned index : consumedOperands) { + Value operand = transform->getOperand(index); + if (operand.getType().isa()) { + for (Operation *payloadOp : getPayloadOps(operand)) + llvm::append_range(origOpFlatResults, payloadOp->getResults()); + continue; + } + if (operand.getType().isa()) { + for (Value payloadValue : getPayloadValues(operand)) { + if (payloadValue.isa()) { + origAssociatedOps.push_back(payloadValue.getDefiningOp()); + continue; } + llvm::append_range( + origAssociatedOps, + llvm::map_range(*payloadValue.cast().getOwner(), + [](Operation &op) { return &op; })); } + continue; } + DiagnosedDefiniteFailure diag = + emitDefiniteFailure(transform->getLoc()) + << "unexpectedly consumed a value that is not a handle as operand #" + << index; + diag.attachNote(operand.getLoc()) + << "value defined here with type " << operand.getType(); + return diag; } - transform::TransformResults results(transform->getNumResults()); // Compute the result but do not short-circuit the silenceable failure case as // we still want the handles to propagate properly so the "suppress" mode can // proceed on a best effort basis. 
+ transform::TransformResults results(transform->getNumResults()); DiagnosedSilenceableFailure result(transform.apply(results, *this)); if (result.isDefiniteFailure()) return result; @@ -352,18 +816,12 @@ // Remove the mapping for the operand if it is consumed by the operation. This // allows us to catch use-after-free with assertions later on. - auto memEffectInterface = - cast(transform.getOperation()); - SmallVector effects; - for (OpOperand &target : transform->getOpOperands()) { - effects.clear(); - memEffectInterface.getEffectsOnValue(target.get(), effects); - if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { - return isa( - effect.getResource()) && - isa(effect.getEffect()); - })) { - removePayloadOps(target.get()); + for (unsigned index : consumedOperands) { + Value operand = transform->getOperand(index); + if (operand.getType().isa()) { + forgetMapping(operand, origOpFlatResults); + } else if (operand.getType().isa()) { + forgetValueMapping(operand, origAssociatedOps); } } @@ -378,6 +836,13 @@ setParams(result, results.getParams(result.getResultNumber())))) { return DiagnosedSilenceableFailure::definiteFailure(); } + } else if (result.getType().isa()) { + assert(results.isValue(result.getResultNumber()) && + "expected values for value-type-result"); + if (failed(setPayloadValues( + result, results.getValues(result.getResultNumber())))) { + return DiagnosedSilenceableFailure::definiteFailure(); + } } else { assert(!results.isParam(result.getResultNumber()) && "expected payload ops for the non-parameter typed result"); @@ -409,15 +874,9 @@ if (failed(state.getHandlesForPayloadOp(op, handles))) return failure(); - for (Value handle : handles) { - LogicalResult result = - state.updatePayloadOps(handle, [&](Operation *current) { - return current == op ? 
replacement : current; - }); - if (failed(result)) - return failure(); - } - return success(); + // TODO: we may need to invalidate handles to operations and values nested in + // the operation being replaced. + return state.replacePayloadOp(op, replacement); } //===----------------------------------------------------------------------===// @@ -425,63 +884,95 @@ //===----------------------------------------------------------------------===// transform::TransformResults::TransformResults(unsigned numSegments) { - segments.resize(numSegments, - ArrayRef(nullptr, static_cast(0))); - paramSegments.resize(numSegments, ArrayRef( - nullptr, static_cast(0))); + operations.appendEmptyRows(numSegments); + params.appendEmptyRows(numSegments); + values.appendEmptyRows(numSegments); } void transform::TransformResults::set(OpResult value, ArrayRef ops) { int64_t position = value.getResultNumber(); - assert(position < static_cast(segments.size()) && + assert(position < static_cast(operations.size()) && "setting results for a non-existent handle"); - assert(segments[position].data() == nullptr && "results already set"); - int64_t start = operations.size(); - llvm::append_range(operations, ops); - segments[position] = ArrayRef(operations).drop_front(start); + assert(operations[position].data() == nullptr && "results already set"); + assert(params[position].data() == nullptr && + "another kind of results already set"); + assert(values[position].data() == nullptr && + "another kind of results already set"); + operations.replace(position, ops); } void transform::TransformResults::setParams( OpResult value, ArrayRef params) { int64_t position = value.getResultNumber(); - assert(position < static_cast(paramSegments.size()) && + assert(position < static_cast(this->params.size()) && "setting params for a non-existent handle"); - assert(paramSegments[position].data() == nullptr && "params already set"); - size_t start = this->params.size(); - llvm::append_range(this->params, params); - 
paramSegments[position] = ArrayRef(this->params).drop_front(start); + assert(this->params[position].data() == nullptr && "params already set"); + assert(operations[position].data() == nullptr && + "another kind of results already set"); + assert(values[position].data() == nullptr && + "another kind of results already set"); + this->params.replace(position, params); +} + +void transform::TransformResults::setValues(OpResult handle, + ValueRange values) { + int64_t position = handle.getResultNumber(); + assert(position < static_cast(this->values.size()) && + "setting values for a non-existent handle"); + assert(this->values[position].data() == nullptr && "values already set"); + assert(operations[position].data() == nullptr && + "another kind of results already set"); + assert(params[position].data() == nullptr && + "another kind of results already set"); + this->values.replace(position, values); } ArrayRef transform::TransformResults::get(unsigned resultNumber) const { - assert(resultNumber < segments.size() && + assert(resultNumber < operations.size() && "querying results for a non-existent handle"); - assert(segments[resultNumber].data() != nullptr && - "querying unset results (param expected?)"); - return segments[resultNumber]; + assert(operations[resultNumber].data() != nullptr && + "querying unset results (values or params expected?)"); + return operations[resultNumber]; } ArrayRef transform::TransformResults::getParams(unsigned resultNumber) const { - assert(resultNumber < paramSegments.size() && + assert(resultNumber < params.size() && "querying params for a non-existent handle"); - assert(paramSegments[resultNumber].data() != nullptr && - "querying unset params (payload ops expected?)"); - return paramSegments[resultNumber]; + assert(params[resultNumber].data() != nullptr && + "querying unset params (ops or values expected?)"); + return params[resultNumber]; +} + +ArrayRef +transform::TransformResults::getValues(unsigned resultNumber) const { +
assert(resultNumber < values.size() && + "querying values for a non-existent handle"); + assert(values[resultNumber].data() != nullptr && + "querying unset values (ops or params expected?)"); + return values[resultNumber]; } bool transform::TransformResults::isParam(unsigned resultNumber) const { - assert(resultNumber < paramSegments.size() && + assert(resultNumber < params.size() && "querying association for a non-existent handle"); - return paramSegments[resultNumber].data() != nullptr; + return params[resultNumber].data() != nullptr; +} + +bool transform::TransformResults::isValue(unsigned resultNumber) const { + assert(resultNumber < values.size() && + "querying association for a non-existent handle"); + return values[resultNumber].data() != nullptr; } bool transform::TransformResults::isSet(unsigned resultNumber) const { - assert(resultNumber < paramSegments.size() && + assert(resultNumber < params.size() && "querying association for a non-existent handle"); - return paramSegments[resultNumber].data() != nullptr || - segments[resultNumber].data() != nullptr; + return params[resultNumber].data() != nullptr || + operations[resultNumber].data() != nullptr || + values[resultNumber].data() != nullptr; } //===----------------------------------------------------------------------===// @@ -547,6 +1038,12 @@ return oneResult[r.getResultNumber()].get(); })); transformResults.setParams(r, params); + } else if (r.getType().isa()) { + auto values = llvm::to_vector( + llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) { + return oneResult[r.getResultNumber()].get(); + })); + transformResults.setValues(r, values); } else { auto payloads = llvm::to_vector( llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) { @@ -571,6 +1068,8 @@ SmallVector &mapped = extraMappings.emplace_back(); if (operand.getType().isa()) { llvm::append_range(mapped, state.getPayloadOps(operand)); + } else if (operand.getType().isa()) { + llvm::append_range(mapped,
state.getPayloadValues(operand)); } else { assert(operand.getType().isa() && "unsupported kind of transform dialect value"); @@ -639,13 +1138,15 @@ } for (BlockArgument arg : body->getArguments().drop_front()) { if (arg.getType() - .isa()) + .isa()) continue; InFlightDiagnostic diag = op->emitOpError() << "expects trailing entry block arguments to be of type implementing " - "TransformHandleTypeInterface or TransformParamTypeInterface"; + "TransformHandleTypeInterface, TransformValueHandleTypeInterface or " + "TransformParamTypeInterface"; diag.attachNote() << "argument #" << arg.getArgNumber() << " does not"; return diag; } @@ -675,7 +1176,9 @@ bool hasPayloadOperands = false; for (Value operand : op->getOperands()) { onlyReadsHandle(operand, effects); - if (operand.getType().isa()) + if (operand.getType() + .isa()) hasPayloadOperands = true; } if (hasPayloadOperands) @@ -816,7 +1319,7 @@ LogicalResult transform::applyTransforms(Operation *payloadRoot, TransformOpInterface transform, - ArrayRef> extraMapping, + const RaggedArray &extraMapping, const TransformOptions &options) { #ifndef NDEBUG if (!transform->hasTrait() || diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -626,7 +626,7 @@ void transform::ReplicateOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getPattern(), effects); - consumesHandle(getHandles(), effects); + onlyReadsHandle(getHandles(), effects); producesHandle(getReplicated(), effects); } diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp @@ -99,3 +99,13 @@ } return DiagnosedSilenceableFailure::success(); } + +//===----------------------------------------------------------------------===// +// 
transform::AnyValueType +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::AnyValueType::checkPayload(Location loc, + ArrayRef payload) const { + return DiagnosedSilenceableFailure::success(); +} diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -278,7 +278,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( Operation *target, StringRef passName, const std::shared_ptr> &sharedTransformModule, - ArrayRef> extraMappings, + const RaggedArray &extraMappings, const TransformOptions &options, const Pass::Option &transformFileName, const Pass::Option &debugPayloadRootTag, diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir --- a/mlir/test/Dialect/Linalg/transform-op-match.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir @@ -13,11 +13,11 @@ ^bb1(%arg1: !pdl.operation): %match_name = transform.structured.match ops{["arith.constant"]} in %arg1 : (!pdl.operation) -> !pdl.operation transform.test_print_remark_at_operand %match_name, "matched op name" : !pdl.operation - transform.test_consume_operand %match_name + transform.test_consume_operand %match_name : !pdl.operation %match_attr = transform.structured.match ops{["arith.constant"]} attributes{my_attr} in %arg1 : (!pdl.operation) -> !pdl.operation transform.test_print_remark_at_operand %match_attr, "matched attr name" : !pdl.operation - transform.test_consume_operand %match_attr + transform.test_consume_operand %match_attr : !pdl.operation } // ----- @@ -34,7 +34,7 @@ %match_name = transform.structured.match ops{["arith.constant"]} filter_result_type = f32 in %arg1 : (!pdl.operation) -> !pdl.operation 
transform.test_print_remark_at_operand %match_name, "matched op name" : !pdl.operation - transform.test_consume_operand %match_name + transform.test_consume_operand %match_name : !pdl.operation } // ----- @@ -65,7 +65,7 @@ #linalg.iterator_type]} in %arg1 : (!pdl.operation) -> !pdl.operation transform.test_print_remark_at_operand %match_attr, "matched complex attr" : !pdl.operation - transform.test_consume_operand %match_attr + transform.test_consume_operand %match_attr : !pdl.operation %no_match = transform.structured.match attributes{iterator_types = [ diff --git a/mlir/test/Dialect/Transform/check-use-after-free.mlir b/mlir/test/Dialect/Transform/check-use-after-free.mlir --- a/mlir/test/Dialect/Transform/check-use-after-free.mlir +++ b/mlir/test/Dialect/Transform/check-use-after-free.mlir @@ -2,7 +2,7 @@ func.func @use_after_free_branching_control_flow() { // expected-note @below {{allocated here}} - %0 = transform.test_produce_param_or_forward_operand 42 + %0 = transform.test_produce_self_handle_or_forward_operand transform.test_transform_op_with_regions { "transform.test_branching_transform_op_terminator"() : () -> () }, @@ -11,7 +11,7 @@ "transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> () ^bb1: // expected-note @below {{freed here}} - transform.test_consume_operand_if_matches_param_or_fail %0[42] + transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" "transform.test_branching_transform_op_terminator"()[^bb3] : () -> () ^bb2: "transform.test_branching_transform_op_terminator"()[^bb3] : () -> () @@ -29,7 +29,7 @@ func.func @use_after_free_in_nested_op() { // expected-note @below {{allocated here}} - %0 = transform.test_produce_param_or_forward_operand 42 + %0 = transform.test_produce_self_handle_or_forward_operand // expected-note @below {{freed here}} transform.test_transform_op_with_regions { "transform.test_branching_transform_op_terminator"() : () -> () @@ -38,7 +38,7 @@ ^bb0: 
"transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> () ^bb1: - transform.test_consume_operand_if_matches_param_or_fail %0[42] + transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" "transform.test_branching_transform_op_terminator"()[^bb3] : () -> () ^bb2: "transform.test_branching_transform_op_terminator"()[^bb3] : () -> () @@ -74,7 +74,7 @@ // expected-note @below {{freed here}} transform.sequence %0 : !pdl.operation failures(propagate) attributes { ord = 4 } { ^bb4(%arg4: !pdl.operation): - test_consume_operand_if_matches_param_or_fail %0[42] + test_consume_operand_of_op_kind_or_fail %0, "transform.sequence" } // expected-warning @below {{operand #0 may be used after free}} transform.sequence %0 : !pdl.operation failures(propagate) attributes { ord = 5 } { @@ -102,7 +102,7 @@ } // expected-note @below {{freed here}} - test_consume_operand_if_matches_param_or_fail %0[42] + test_consume_operand_of_op_kind_or_fail %0, "transform.sequence" // expected-warning @below {{operand #0 may be used after free}} transform.sequence %0 : !pdl.operation failures(propagate) attributes { ord = 5 } { ^bb3(%arg3: !pdl.operation): @@ -118,7 +118,7 @@ // be reported as use-after-free. 
func.func @use_after_free_self_cycle() { // expected-note @below {{allocated here}} - %0 = transform.test_produce_param_or_forward_operand 42 + %0 = transform.test_produce_self_handle_or_forward_operand transform.test_transform_op_with_regions { "transform.test_branching_transform_op_terminator"() : () -> () }, @@ -132,7 +132,7 @@ } // expected-warning @below {{operand #0 may be used after free}} // expected-note @below {{freed here}} - transform.test_consume_operand_if_matches_param_or_fail %0[42] + transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" "transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> () ^bb2: "transform.test_branching_transform_op_terminator"() : () -> () @@ -147,7 +147,7 @@ // use-after-free. func.func @use_after_free_cycle() { // expected-note @below {{allocated here}} - %0 = transform.test_produce_param_or_forward_operand 42 + %0 = transform.test_produce_self_handle_or_forward_operand transform.test_transform_op_with_regions { "transform.test_branching_transform_op_terminator"() : () -> () }, @@ -157,7 +157,7 @@ ^bb1: // expected-warning @below {{operand #0 may be used after free}} // expected-note @below {{freed here}} - transform.test_consume_operand_if_matches_param_or_fail %0[42] + transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" "transform.test_branching_transform_op_terminator"()[^bb2, ^bb3] : () -> () ^bb2: "transform.test_branching_transform_op_terminator"()[^bb1] : () -> () diff --git a/mlir/test/Dialect/Transform/expensive-checks.mlir b/mlir/test/Dialect/Transform/expensive-checks.mlir --- a/mlir/test/Dialect/Transform/expensive-checks.mlir +++ b/mlir/test/Dialect/Transform/expensive-checks.mlir @@ -21,7 +21,7 @@ %0 = pdl_match @return in %arg1 : (!pdl.operation) -> !pdl.operation %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation // expected-note @below {{invalidated by 
this transform op that consumes its operand #0}} - test_consume_operand %1 + test_consume_operand %1 : !pdl.operation // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} test_print_remark_at_operand %0, "remark" : !pdl.operation } @@ -55,8 +55,8 @@ %0 = pdl_match @func in %arg1 : (!pdl.operation) -> !pdl.operation %1 = pdl_match @return in %arg1 : (!pdl.operation) -> !pdl.operation %2 = replicate num(%0) %1 : !pdl.operation, !pdl.operation - // expected-error @below {{a handle passed as operand #0 and consumed by this operation points to a payload operation more than once}} - test_consume_operand %2 + // expected-error @below {{a handle passed as operand #0 and consumed by this operation points to a payload entity more than once}} + test_consume_operand %2 : !pdl.operation test_print_remark_at_operand %0, "remark" : !pdl.operation } } @@ -74,9 +74,9 @@ // expected-note @below {{handle to invalidated ops}} %2 = transform.test_copy_payload %0 // expected-note @below {{invalidated by this transform op that consumes its operand #0}} - transform.test_consume_operand %1 + transform.test_consume_operand %1 : !pdl.operation // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} - transform.test_consume_operand %2 + transform.test_consume_operand %2 : !pdl.operation } } @@ -95,8 +95,8 @@ // to overlapping sets of payload IR ops. 
// // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} - // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates handles}} - transform.test_consume_operand %1, %2 + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities}} + transform.test_consume_operand %1, %2 : !pdl.operation } } @@ -113,3 +113,221 @@ transform.merge_handles %1, %2 { deduplicate } : !pdl.operation } } +// ----- + +// expected-note @below {{payload value}} +%0 = "test.match_anchor"() : () -> (i32) + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %2 = transform.structured.match ops{["test.match_anchor"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %3 = test_produce_value_handle_to_result %2, 0 : (!transform.any_op) -> !transform.any_value + // expected-note @below {{invalidated handle}} + %4 = test_produce_value_handle_to_result %2, 0 : (!transform.any_op) -> !transform.any_value + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates handles to the same values as associated with it}} + test_consume_operand %3 : !transform.any_value + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + test_consume_operand %4 : !transform.any_value +} + +// ----- + +// expected-note @below {{ancestor op associated with the consumed handle}} +// expected-note @below {{payload value}} +// expected-note @below {{op defining the value as result #0}} +%0 = "test.match_anchor"() : () -> (i32) + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %2 = transform.structured.match ops{["test.match_anchor"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-note @below {{invalidated handle}} + %3 = test_produce_value_handle_to_result %2, 0 : (!transform.any_op) -> 
!transform.any_value + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + test_consume_operand %2 : !transform.any_op + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + test_consume_operand %3 : !transform.any_value +} + +// ----- + +// expected-note @below {{ancestor op associated with the consumed handle}} +"test.match_anchor_1"() ({ +^bb0: + // expected-note @below {{op defining the value as result #0}} + // expected-note @below {{payload value}} + %0 = "test.match_anchor_2"() : () -> (i32) + "test.region_terminator"() : () -> () +}) : () -> () + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %1 = transform.structured.match ops{["test.match_anchor_1"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-note @below {{invalidated handle}} + %3 = test_produce_value_handle_to_result %2, 0 : (!transform.any_op) -> !transform.any_value + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + test_consume_operand %1 : !transform.any_op + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + test_consume_operand %3 : !transform.any_value +} + +// ----- + +// expected-note @below {{ancestor op associated with the consumed handle}} +// expected-note @below {{op defining the value as block argument #0 of block #0 in region #0}} +"test.match_anchor_1"() ({ +// expected-note @below {{payload value}} +^bb0(%arg0: i32): + %0 = "test.match_anchor_2"() : () -> (i32) + "test.region_terminator"() : () -> () +}) : () -> () + +transform.sequence 
failures(propagate) { +^bb1(%arg0: !transform.any_op): + %1 = transform.structured.match ops{["test.match_anchor_1"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-note @below {{invalidated handle}} + %3 = test_produce_value_handle_to_argument_of_parent_block %2, 0 : (!transform.any_op) -> !transform.any_value + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + test_consume_operand %1 : !transform.any_op + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + test_consume_operand %3 : !transform.any_value +} + +// ----- + +// expected-note @below {{ancestor op associated with the consumed handle}} +"test.match_anchor_1"() ({ +^bb: + // expected-note @below {{op defining the value as block argument #0 of block #0 in region #0}} + "test.op_with_regions"() ({ + // expected-note @below {{payload value}} + ^bb0(%arg0: i32): + %0 = "test.match_anchor_2"() : () -> (i32) + "test.region_terminator"() : () -> () + }): () -> () +}) : () -> () + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %1 = transform.structured.match ops{["test.match_anchor_1"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-note @below {{invalidated handle}} + %3 = test_produce_value_handle_to_argument_of_parent_block %2, 0 : (!transform.any_op) -> !transform.any_value + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + test_consume_operand %1 : !transform.any_op + // 
expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + test_consume_operand %3 : !transform.any_value +} + +// ----- + +// expected-note @below {{ancestor payload op}} +// expected-note @below {{nested payload op}} +// expected-note @below {{consumed handle points to this payload value}} +%0 = "test.match_anchor"() : () -> (i32) + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + // expected-note @below {{handle to invalidated ops}} + %2 = transform.structured.match ops{["test.match_anchor"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %3 = test_produce_value_handle_to_result %2, 0 : (!transform.any_op) -> !transform.any_value + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + test_consume_operand %3 : !transform.any_value + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + test_consume_operand %2 : !transform.any_op +} + +// ----- + +// expected-note @below {{ancestor payload op}} +// expected-note @below {{consumed handle points to this payload value}} +%0 = "test.match_anchor_1"() ({ +^bb0: + // expected-note @below {{nested payload op}} + "test.match_anchor_2"() : () -> () + "test.region_terminator"() : () -> () +}) : () -> (i32) + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %1 = transform.structured.match ops{["test.match_anchor_1"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-note @below {{handle to invalidated ops}} + %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %3 = test_produce_value_handle_to_result %1, 0 : (!transform.any_op) -> !transform.any_value + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all 
handles to payload IR entities associated with this operand and entities nested in them}} + test_consume_operand %3 : !transform.any_value + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + test_consume_operand %2 : !transform.any_op +} + + +// ----- + +"test.match_anchor_1"() ({ +// expected-note @below {{consumed handle points to this payload value}} +^bb0(%arg0: f32): + // expected-note @below {{ancestor payload op}} + // expected-note @below {{nested payload op}} + "test.match_anchor_2"() : () -> () + "test.region_terminator"() : () -> () +}) : () -> () + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + // expected-note @below {{handle to invalidated ops}} + %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %3 = test_produce_value_handle_to_argument_of_parent_block %2, 0 : (!transform.any_op) -> !transform.any_value + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + test_consume_operand %3 : !transform.any_value + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + test_consume_operand %2 : !transform.any_op +} + +// ----- + +"test.op_with_regions"() ({ +// expected-note @below {{consumed handle points to this payload value}} +^bb(%arg0: i32): + // expected-note @below {{ancestor payload op}} + "test.op_with_regions"() ({ + ^bb0: + // expected-note @below {{nested payload op}} + "test.match_anchor_2"() : () -> () + "test.region_terminator"() : () -> () + }): () -> () + "test.match_anchor_1"() : () -> () +}) : () -> () + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %1 = transform.structured.match ops{["test.match_anchor_1"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-note 
@below {{handle to invalidated ops}} + %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %3 = test_produce_value_handle_to_argument_of_parent_block %1, 0 : (!transform.any_op) -> !transform.any_value + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + test_consume_operand %3 : !transform.any_value + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + test_consume_operand %2 : !transform.any_op +} + +// ----- + +// Removing a block argument does not invalidate handles to operations in another block. +// Not expecting an error here. + +"test.op_with_regions"() ({ +^bb1(%arg0: i32): + "test.match_anchor_1"() : () -> () +^bb2: + "test.match_anchor_2"() : () -> () +}) : () -> () + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %1 = transform.structured.match ops{["test.match_anchor_1"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %3 = test_produce_value_handle_to_argument_of_parent_block %1, 0 : (!transform.any_op) -> !transform.any_value + test_consume_operand %3 : !transform.any_value + test_consume_operand %2 : !transform.any_op +} diff --git a/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir b/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir --- a/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir +++ b/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir @@ -37,6 +37,17 @@ // ----- +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_value): + // expected-error @above {{wrong kind of value provided for the top-level value handle}} +} + +func.func @foo() { + return 
+} + +// ----- + // expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}} transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op): diff --git a/mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir b/mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir @@ -0,0 +1,45 @@ +// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-results-of-ops=test.some_returning_op bind-second-extra-to-results-of-ops=test.some_other_returning_op})' \ +// RUN: --split-input-file --verify-diagnostics + +// Note that diagnostic checker will merge two diagnostics with the same message +// at the same location, so only check the remark once. +// +// expected-remark @below {{first extra}} +// expected-note @below {{value handle points to an op result #0}} +// expected-note @below {{value handle points to an op result #1}} +%0:2 = "test.some_returning_op"() : () -> (i32, i64) + +// expected-remark @below {{first extra}} +// expected-note @below {{value handle points to an op result #0}} +%1 = "test.some_returning_op"() : () -> index + +// Note that diagnostic checker will merge two diagnostics with the same message +// at the same location, so only check the remark once. 
+// +// expected-remark @below {{second extra}} +// expected-note @below {{value handle points to an op result #0}} +// expected-note @below {{value handle points to an op result #1}} +%2:2 = "test.some_other_returning_op"() : () -> (f32, f64) + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, %arg1: !transform.any_value, %arg2: !transform.any_value): + test_print_remark_at_operand_value %arg1, "first extra" : !transform.any_value + test_print_remark_at_operand_value %arg2, "second extra" : !transform.any_value +} + +// ----- + +%0:2 = "test.some_returning_op"() : () -> (i32, i64) +%1 = "test.some_returning_op"() : () -> index + +transform.sequence failures(propagate) { +// expected-error @below {{wrong kind of value provided for top-level operation handle}} +^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_value): +} + +// ----- + +// expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}} +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, %arg1: !transform.any_value): +} diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -24,7 +24,7 @@ // ----- -// expected-error @below {{'transform.sequence' op expects trailing entry block arguments to be of type implementing TransformHandleTypeInterface or TransformParamTypeInterface}} +// expected-error @below {{'transform.sequence' op expects trailing entry block arguments to be of type implementing TransformHandleTypeInterface, TransformValueHandleTypeInterface or TransformParamTypeInterface}} // expected-note @below {{argument #1 does not}} transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op, %arg1: i64): @@ -144,11 +144,11 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): // expected-error @below 
{{result #0 has more than one potential consumer}} - %0 = test_produce_param_or_forward_operand 42 + %0 = test_produce_self_handle_or_forward_operand // expected-note @below {{used here as operand #0}} - test_consume_operand_if_matches_param_or_fail %0[42] + test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" // expected-note @below {{used here as operand #0}} - test_consume_operand_if_matches_param_or_fail %0[42] + test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" } // ----- @@ -156,13 +156,13 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): // expected-error @below {{result #0 has more than one potential consumer}} - %0 = test_produce_param_or_forward_operand 42 + %0 = test_produce_self_handle_or_forward_operand // expected-note @below {{used here as operand #0}} - test_consume_operand_if_matches_param_or_fail %0[42] + test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" // expected-note @below {{used here as operand #0}} transform.sequence %0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): - test_consume_operand_if_matches_param_or_fail %arg1[42] + test_consume_operand_of_op_kind_or_fail %arg1, "transform.test_produce_self_handle_or_forward_operand" } } @@ -171,13 +171,13 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): // expected-error @below {{result #0 has more than one potential consumer}} - %0 = test_produce_param_or_forward_operand 42 + %0 = test_produce_self_handle_or_forward_operand // expected-note @below {{used here as operand #0}} - test_consume_operand_if_matches_param_or_fail %0[42] + test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" transform.sequence %0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): // expected-note @below {{used here as operand #0}} - 
test_consume_operand_if_matches_param_or_fail %0[42] + test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" } } @@ -186,15 +186,15 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): // expected-error @below {{result #0 has more than one potential consumer}} - %0 = test_produce_param_or_forward_operand 42 + %0 = test_produce_self_handle_or_forward_operand // expected-note @below {{used here as operand #0}} - test_consume_operand_if_matches_param_or_fail %0[42] + test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" // expected-note @below {{used here as operand #0}} transform.sequence %0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): transform.sequence %arg1 : !pdl.operation failures(propagate) { ^bb2(%arg2: !pdl.operation): - test_consume_operand_if_matches_param_or_fail %arg2[42] + test_consume_operand_of_op_kind_or_fail %arg2, "transform.test_produce_self_handle_or_forward_operand" } } } @@ -235,14 +235,14 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): // expected-error @below {{result #0 has more than one potential consumer}} - %0 = test_produce_param_or_forward_operand 42 + %0 = test_produce_self_handle_or_forward_operand // expected-note @below {{used here as operand #0}} transform.foreach %0 : !pdl.operation { ^bb1(%arg1: !pdl.operation): - transform.test_consume_operand %arg1 + transform.test_consume_operand %arg1 : !pdl.operation } // expected-note @below {{used here as operand #0}} - transform.test_consume_operand %0 + transform.test_consume_operand %0 : !pdl.operation } // ----- diff --git a/mlir/test/Dialect/Transform/test-dialect-injection.mlir b/mlir/test/Dialect/Transform/test-dialect-injection.mlir --- a/mlir/test/Dialect/Transform/test-dialect-injection.mlir +++ b/mlir/test/Dialect/Transform/test-dialect-injection.mlir @@ -6,11 +6,11 @@ // CHECK: transform.test_transform_op 
transform.test_transform_op -// CHECK: = transform.test_produce_param_or_forward_operand 42 {foo = "bar"} -%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" } +// CHECK: = transform.test_produce_self_handle_or_forward_operand {foo = "bar"} +%0 = transform.test_produce_self_handle_or_forward_operand { foo = "bar" } -// CHECK: transform.test_consume_operand_if_matches_param_or_fail %{{.*}}[42] -transform.test_consume_operand_if_matches_param_or_fail %0[42] +// CHECK: transform.test_consume_operand_of_op_kind_or_fail %{{.*}}, +transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" // Ensure that the extension type is roundtripped correctly. // CHECK: transform.cast %{{.*}} : !pdl.operation to !transform.test_dialect_op diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -10,18 +10,18 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): - %0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" } + %0 = transform.test_produce_self_handle_or_forward_operand { foo = "bar" } // expected-remark @below {{succeeded}} - transform.test_consume_operand_if_matches_param_or_fail %0[42] + transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" } // ----- transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): - %0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" } - // expected-error @below {{expected the operand to be associated with 21 got 42}} - transform.test_consume_operand_if_matches_param_or_fail %0[21] + %0 = transform.test_produce_self_handle_or_forward_operand { foo = "bar" } + // expected-error @below {{expected the operand to be associated a payload op of kind transform.sequence got 
transform.test_produce_self_handle_or_forward_operand}} + transform.test_consume_operand_of_op_kind_or_fail %0, "transform.sequence" } // ----- @@ -31,10 +31,10 @@ // to detect double-consumption. transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): - %0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" } + %0 = transform.test_produce_self_handle_or_forward_operand { foo = "bar" } %1 = transform.test_copy_payload %0 // expected-remark @below {{succeeded}} - transform.test_consume_operand_if_matches_param_or_fail %0[42] + transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" } // ----- @@ -60,11 +60,11 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): - %0 = test_produce_param_or_forward_operand 42 + %0 = test_produce_self_handle_or_forward_operand sequence %0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): // expected-remark @below {{succeeded}} - test_consume_operand_if_matches_param_or_fail %arg1[42] + test_consume_operand_of_op_kind_or_fail %arg1, "transform.test_produce_self_handle_or_forward_operand" } } @@ -74,11 +74,11 @@ ^bb0(%arg0: !pdl.operation): %0 = sequence %arg0 : !pdl.operation -> !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): - %1 = test_produce_param_or_forward_operand 42 + %1 = test_produce_self_handle_or_forward_operand yield %1 : !pdl.operation } // expected-remark @below {{succeeded}} - test_consume_operand_if_matches_param_or_fail %0[42] + test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" } // ----- @@ -163,15 +163,15 @@ %0 = pdl_match @match_func in %arg1 : (!pdl.operation) -> !pdl.operation transform.alternatives %0 : !pdl.operation { ^bb2(%arg2: !pdl.operation): - %1 = transform.test_produce_param_or_forward_operand 42 + %1 = transform.test_produce_self_handle_or_forward_operand // This operation fails, which triggers the next alternative 
without // reporting the error. - transform.test_consume_operand_if_matches_param_or_fail %1[43] + transform.test_consume_operand_of_op_kind_or_fail %1, "transform.sequence" }, { ^bb2(%arg2: !pdl.operation): - %1 = transform.test_produce_param_or_forward_operand 42 + %1 = transform.test_produce_self_handle_or_forward_operand // expected-remark @below {{succeeded}} - transform.test_consume_operand_if_matches_param_or_fail %1[42] + transform.test_consume_operand_of_op_kind_or_fail %1, "transform.test_produce_self_handle_or_forward_operand" } } } @@ -315,17 +315,18 @@ %3 = transform.pdl_match @match_call in %arg2 : (!pdl.operation) -> !pdl.operation // expected-remark @below {{applying}} transform.test_emit_remark_and_erase_operand %3, "applying" {fail_after_erase} - %4 = transform.test_produce_param_or_forward_operand 43 + %4 = transform.test_produce_self_handle_or_forward_operand %3 transform.yield %4 : !pdl.operation }, { ^bb2(%arg2: !pdl.operation): - %4 = transform.test_produce_param_or_forward_operand 42 + %4 = transform.test_produce_self_handle_or_forward_operand transform.yield %4 : !pdl.operation } // The first alternative failed, so the returned value is taken from the - // second alternative. + // second alternative, associated with test_produce_self_handle_or_forward_operand + // rather than with pdl_match.
// expected-remark @below {{succeeded}} - transform.test_consume_operand_if_matches_param_or_fail %2[42] + transform.test_consume_operand_of_op_kind_or_fail %2, "transform.test_produce_self_handle_or_forward_operand" } } @@ -349,12 +350,12 @@ // expected-error @below {{scope must not contain the transforms being applied}} transform.alternatives %arg1 : !pdl.operation { ^bb2(%arg2: !pdl.operation): - %0 = transform.test_produce_param_or_forward_operand 42 - transform.test_consume_operand_if_matches_param_or_fail %0[43] + %0 = transform.test_produce_self_handle_or_forward_operand + transform.test_consume_operand_of_op_kind_or_fail %0, "transform.sequence" }, { ^bb2(%arg2: !pdl.operation): - %0 = transform.test_produce_param_or_forward_operand 42 - transform.test_consume_operand_if_matches_param_or_fail %0[42] + %0 = transform.test_produce_self_handle_or_forward_operand + transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand" } } } @@ -1094,6 +1095,14 @@ // ----- +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + // expected-error @below {{attempting to assign a null payload value to this transform handle}} + %0 = transform.test_produce_null_value : !transform.any_value +} + +// ----- + // expected-error @below {{could not find a nested top-level transform op}} // expected-note @below {{use the 'transform-file-name' option to provide transform as external file}} module { @@ -1106,7 +1115,65 @@ ^bb0(%arg0: !transform.any_op): } -// expected-error @below {{ore than one top-level transform op}} +// expected-error @below {{more than one top-level transform op}} +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): +} + +// ----- + +transform.sequence failures(propagate) { +// expected-remark @below {{value handle}} +// expected-note @below {{value handle points to a block argument #0 in block #0 in region #0}} +^bb1(%arg0: !transform.any_op): + %0 = 
test_produce_value_handle_to_self_operand %arg0 : (!transform.any_op) -> !transform.any_value + test_print_remark_at_operand_value %0, "value handle" : !transform.any_value +} + +// ----- + +// expected-remark @below {{result handle}} +// expected-note @below {{value handle points to an op result #1}} +%0:2 = "test.get_two_results"() : () -> (i32, i32) +// expected-remark @below {{result handle}} +// expected-note @below {{value handle points to an op result #1}} +%1:3 = "test.get_three_results"() : () -> (i32, i32, f32) + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %2 = transform.structured.match ops{["test.get_two_results", "test.get_three_results"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %3 = test_produce_value_handle_to_result %2, 1 : (!transform.any_op) -> !transform.any_value + test_print_remark_at_operand_value %3, "result handle" : !transform.any_value +} + +// ----- + +"test.op_with_regions"() ({ +^bb0: + "test.regon_terminator"() : () -> () +}, { +^bb1: + "test.regon_terminator"() : () -> () +// expected-remark @below {{block argument handle}} +// expected-note @below {{value handle points to a block argument #2 in block #1 in region #1}} +^bb2(%arg0: i32, %arg1: f64, %arg3: index): + "test.match_anchor"() : () -> () + "test.regon_terminator"() : () -> () +}) : () -> () + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %2 = transform.structured.match ops{["test.match_anchor"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %3 = test_produce_value_handle_to_argument_of_parent_block %2, 2 : (!transform.any_op) -> !transform.any_value + test_print_remark_at_operand_value %3, "block argument handle" : !transform.any_value +} + +// ----- + transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): + // expected-note @below {{value defined here with type '!transform.test_dialect_param'}} + %0 = test_produce_param_with_number_of_test_ops %arg0 : !transform.any_op + 
// expected-error @below {{unexpectedly consumed a value that is not a handle as operand #0}} + test_consume_operand %0 : !transform.test_dialect_param } diff --git a/mlir/test/Dialect/Transform/transform-state-extension.mlir b/mlir/test/Dialect/Transform/transform-state-extension.mlir --- a/mlir/test/Dialect/Transform/transform-state-extension.mlir +++ b/mlir/test/Dialect/Transform/transform-state-extension.mlir @@ -45,6 +45,18 @@ } } +// ----- + +// expected-error @below {{cannot replace an op with another op producing a different number of results while tracking handles}} +module { + transform.sequence failures(propagate) { + ^bb0(%arg0: !pdl.operation): + test_add_test_extension "A" + %dummy = test_remap_operand_to_self %arg0 : !transform.any_op + } +} + + // ----- module { diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -106,22 +106,66 @@ } // namespace DiagnosedSilenceableFailure -mlir::test::TestProduceParamOrForwardOperandOp::apply( +mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply( transform::TransformResults &results, transform::TransformState &state) { if (getOperation()->getNumOperands() != 0) { results.set(getResult().cast<OpResult>(), getOperation()->getOperand(0).getDefiningOp()); } else { - results.set(getResult().cast<OpResult>(), - reinterpret_cast<Operation *>(*getParameter())); + results.set(getResult().cast<OpResult>(), getOperation()); } return DiagnosedSilenceableFailure::success(); } -LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() { - if (getParameter().has_value() ^ (getNumOperands() != 1)) - return emitOpError() << "expects either a parameter or an operand"; - return success(); +DiagnosedSilenceableFailure +mlir::test::TestProduceValueHandleToSelfOperand::apply( + transform::TransformResults &results,
transform::TransformState &state) { + results.setValues(getOut().cast<OpResult>(), getIn()); + return DiagnosedSilenceableFailure::success(); +} + +void mlir::test::TestProduceValueHandleToSelfOperand::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + transform::onlyReadsHandle(getIn(), effects); + transform::producesHandle(getOut(), effects); + transform::onlyReadsPayload(effects); +} + +DiagnosedSilenceableFailure +mlir::test::TestProduceValueHandleToResult::applyToOne( + Operation *target, transform::ApplyToEachResultList &results, + transform::TransformState &state) { + if (target->getNumResults() <= getNumber()) + return emitSilenceableError() << "payload has no result #" << getNumber(); + results.push_back(target->getResult(getNumber())); + return DiagnosedSilenceableFailure::success(); +} + +void mlir::test::TestProduceValueHandleToResult::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + transform::onlyReadsHandle(getIn(), effects); + transform::producesHandle(getOut(), effects); + transform::onlyReadsPayload(effects); +} + +DiagnosedSilenceableFailure +mlir::test::TestProduceValueHandleToArgumentOfParentBlock::applyToOne( + Operation *target, transform::ApplyToEachResultList &results, + transform::TransformState &state) { + if (!target->getBlock()) + return emitSilenceableError() << "payload has no parent block"; + if (target->getBlock()->getNumArguments() <= getNumber()) + return emitSilenceableError() + << "parent of the payload has no argument #" << getNumber(); + results.push_back(target->getBlock()->getArgument(getNumber())); + return DiagnosedSilenceableFailure::success(); +} + +void mlir::test::TestProduceValueHandleToArgumentOfParentBlock::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + transform::onlyReadsHandle(getIn(), effects); + transform::producesHandle(getOut(), effects); + transform::onlyReadsPayload(effects); } DiagnosedSilenceableFailure @@ -130,16 +174,14 @@ return DiagnosedSilenceableFailure::success(); } -DiagnosedSilenceableFailure
-mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply( +DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply( transform::TransformResults &results, transform::TransformState &state) { ArrayRef<Operation *> payload = state.getPayloadOps(getOperand()); assert(payload.size() == 1 && "expected a single target op"); - auto value = reinterpret_cast(payload[0]); - if (static_cast(value) != getParameter()) { + if (payload[0]->getName().getStringRef() != getOpKind()) { return emitSilenceableError() - << "op expected the operand to be associated with " << getParameter() - << " got " << value; + << "op expected the operand to be associated a payload op of kind " + << getOpKind() << " got " << payload[0]->getName().getStringRef(); } emitRemark() << "succeeded"; @@ -155,6 +197,32 @@ return DiagnosedSilenceableFailure::success(); } +DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandValue::apply( + transform::TransformResults &results, transform::TransformState &state) { + ArrayRef<Value> values = state.getPayloadValues(getIn()); + for (Value value : values) { + std::string note; + llvm::raw_string_ostream os(note); + if (auto arg = value.dyn_cast<BlockArgument>()) { + os << "a block argument #" << arg.getArgNumber() << " in block #" + << std::distance(arg.getOwner()->getParent()->begin(), + arg.getOwner()->getIterator()) + << " in region #" << arg.getOwner()->getParent()->getRegionNumber(); + } else { + os << "an op result #" << value.cast<OpResult>().getResultNumber(); + } + InFlightDiagnostic diag = ::emitRemark(value.getLoc()) << getMessage(); + diag.attachNote() << "value handle points to " << os.str(); + } + return DiagnosedSilenceableFailure::success(); +} + +void mlir::test::TestPrintRemarkAtOperandValue::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + transform::onlyReadsHandle(getIn(), effects); + transform::onlyReadsPayload(effects); +} + DiagnosedSilenceableFailure mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results, transform::TransformState
&state) { @@ -199,6 +267,13 @@ return DiagnosedSilenceableFailure::success(); } +void mlir::test::TestRemapOperandPayloadToSelfOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + transform::onlyReadsHandle(getOperand(), effects); + transform::producesHandle(getOut(), effects); + transform::onlyReadsPayload(effects); +} + DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply( transform::TransformResults &results, transform::TransformState &state) { state.removeExtension(); @@ -482,6 +557,18 @@ return DiagnosedSilenceableFailure::success(); } +void mlir::test::TestProduceNullValueOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + transform::producesHandle(getOut(), effects); +} + +DiagnosedSilenceableFailure +mlir::test::TestProduceNullValueOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + results.setValues(getOut().cast<OpResult>(), Value()); + return DiagnosedSilenceableFailure::success(); +} + void mlir::test::TestRequiredMemoryEffectsOp::getEffects( SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { if (getHasOperandEffect()) diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -39,39 +39,82 @@ let assemblyFormat = ""; } -def TestProduceParamOrForwardOperandOp - : Op]> { let arguments = (ins - Arg, "", [TransformMappingRead]>:$operand, - OptionalAttr:$parameter); + Arg, "", [TransformMappingRead]>:$operand); let results = (outs Res:$res); - let assemblyFormat = "(`from` $operand^)? ($parameter^)? attr-dict"; + let assemblyFormat = "($operand^)?
attr-dict"; let cppNamespace = "::mlir::test"; - let hasVerifier = 1; +} + +def TestProduceValueHandleToSelfOperand + : Op, + DeclareOpInterfaceMethods]> { + let arguments = (ins TransformHandleTypeInterface:$in); + let results = (outs TransformValueHandleTypeInterface:$out); + let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)"; + let cppNamespace = "::mlir::test"; + +} + +def TestProduceValueHandleToResult + : Op]> { + let arguments = (ins TransformHandleTypeInterface:$in, + I64Attr:$number); + let results = (outs TransformValueHandleTypeInterface:$out); + let assemblyFormat = "$in `,` $number attr-dict `:` functional-type(operands, results)"; + let cppNamespace = "::mlir::test"; + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +def TestProduceValueHandleToArgumentOfParentBlock + : Op]> { + let arguments = (ins TransformHandleTypeInterface:$in, + I64Attr:$number); + let results = (outs TransformValueHandleTypeInterface:$out); + let assemblyFormat = "$in `,` $number attr-dict `:` functional-type(operands, results)"; + let cppNamespace = "::mlir::test"; + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; } def TestConsumeOperand : Op]> { let arguments = (ins - Arg:$operand, Arg, "", [TransformMappingRead, TransformMappingFree]>:$second_operand); - let assemblyFormat = "$operand (`,` $second_operand^)? attr-dict"; + let assemblyFormat = "$operand (`,` $second_operand^)? 
attr-dict `:` type($operand)"; let cppNamespace = "::mlir::test"; } -def TestConsumeOperandIfMatchesParamOrFail - : Op]> { let arguments = (ins Arg:$operand, - I64Attr:$parameter); - let assemblyFormat = "$operand `[` $parameter `]` attr-dict"; + StrAttr:$op_kind); + let assemblyFormat = "$operand `,` $op_kind attr-dict"; let cppNamespace = "::mlir::test"; } @@ -87,6 +130,16 @@ let cppNamespace = "::mlir::test"; } +def TestPrintRemarkAtOperandValue + : Op, + DeclareOpInterfaceMethods]> { + let arguments = (ins TransformValueHandleTypeInterface:$in, + StrAttr:$message); + let assemblyFormat = "$in `,` $message attr-dict `:` type($in)"; + let cppNamespace = "::mlir::test"; +} + def TestAddTestExtensionOp : Op, @@ -107,11 +160,11 @@ def TestRemapOperandPayloadToSelfOp : Op]> { - let arguments = (ins - Arg:$operand); - let assemblyFormat = "$operand attr-dict"; + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let arguments = (ins PDL_Operation:$operand); + let results = (outs Optional:$out); + let assemblyFormat = "$operand attr-dict (`:` type($out)^)?"; let cppNamespace = "::mlir::test"; } @@ -352,6 +405,15 @@ let cppNamespace = "::mlir::test"; } +def TestProduceNullValueOp + : Op, + DeclareOpInterfaceMethods]> { + let results = (outs TransformValueHandleTypeInterface:$out); + let assemblyFormat = "attr-dict `:` type($out)"; + let cppNamespace = "::mlir::test"; +} + def TestRequiredMemoryEffectsOp : Op, diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp @@ -46,65 +46,93 @@ return "apply transform dialect operations one by one"; } - ArrayRef - findOperationsByName(Operation *root, StringRef name, - SmallVectorImpl &storage) { - size_t start = storage.size(); + void findOperationsByName(Operation *root, 
StringRef name, + SmallVectorImpl<Operation *> &operations) { root->walk([&](Operation *op) { if (op->getName().getStringRef() == name) { - storage.push_back(op); + operations.push_back(op); } }); - return ArrayRef(storage).drop_front(start); } - ArrayRef<transform::MappedValue> - createParameterMapping(MLIRContext &context, ArrayRef<int> values, - SmallVectorImpl<transform::MappedValue> &storage) { - size_t start = storage.size(); - llvm::append_range(storage, llvm::map_range(values, [&](int v) { - Builder b(&context); - return transform::MappedValue(b.getI64IntegerAttr(v)); - })); - return ArrayRef(storage).drop_front(start); + void createParameterMapping(MLIRContext &context, ArrayRef<int> values, + RaggedArray<transform::MappedValue> &result) { + SmallVector<transform::MappedValue> storage = + llvm::to_vector(llvm::map_range(values, [&](int v) { + Builder b(&context); + return transform::MappedValue(b.getI64IntegerAttr(v)); + })); + result.push_back(std::move(storage)); + } + + void + createOpResultMapping(Operation *root, StringRef name, + RaggedArray<transform::MappedValue> &extraMapping) { + SmallVector<Operation *> operations; + findOperationsByName(root, name, operations); + SmallVector<Value> results; + for (Operation *op : operations) + llvm::append_range(results, op->getResults()); + extraMapping.push_back(results); + } + + unsigned numberOfSetOptions(const Option<std::string> &ops, + const ListOption<int> &params, + const Option<std::string> &values) { + unsigned numSetValues = 0; + numSetValues += !ops.empty(); + numSetValues += !params.empty(); + numSetValues += !values.empty(); + return numSetValues; } void runOnOperation() override { - if (!bindFirstExtraToOps.empty() && !bindFirstExtraToParams.empty()) { - emitError(UnknownLoc::get(&getContext())) - << "cannot bind the first extra top-level argument to both " - "operations and parameters"; + unsigned firstSetOptions = + numberOfSetOptions(bindFirstExtraToOps, bindFirstExtraToParams, + bindFirstExtraToResultsOfOps); + unsigned secondSetOptions = + numberOfSetOptions(bindSecondExtraToOps, bindSecondExtraToParams, + bindSecondExtraToResultsOfOps); + auto loc = UnknownLoc::get(&getContext()); +
if (firstSetOptions > 1) { + emitError(loc) << "cannot bind the first extra top-level argument to " + "multiple entities"; return signalPassFailure(); } - if (!bindSecondExtraToOps.empty() && !bindSecondExtraToParams.empty()) { - emitError(UnknownLoc::get(&getContext())) - << "cannot bind the second extra top-level argument to both " - "operations and parameters"; + if (secondSetOptions > 1) { + emitError(loc) << "cannot bind the second extra top-level argument to " + "multiple entities"; return signalPassFailure(); } - if ((!bindSecondExtraToOps.empty() || !bindSecondExtraToParams.empty()) && - bindFirstExtraToOps.empty() && bindFirstExtraToParams.empty()) { - emitError(UnknownLoc::get(&getContext())) - << "cannot bind the second extra top-level argument without binding " - "the first"; - return signalPassFailure(); + if (firstSetOptions == 0 && secondSetOptions != 0) { + emitError(loc) << "cannot bind the second extra top-level argument " + "without binding the first"; + return signalPassFailure(); } - SmallVector<transform::MappedValue> extraMappingStorage; - SmallVector<ArrayRef<transform::MappedValue>> extraMapping; + RaggedArray<transform::MappedValue> extraMapping; if (!bindFirstExtraToOps.empty()) { - extraMapping.push_back(findOperationsByName( - getOperation(), bindFirstExtraToOps.getValue(), extraMappingStorage)); + SmallVector<Operation *> operations; + findOperationsByName(getOperation(), bindFirstExtraToOps.getValue(), + operations); + extraMapping.push_back(operations); } else if (!bindFirstExtraToParams.empty()) { - extraMapping.push_back(createParameterMapping( - getContext(), bindFirstExtraToParams, extraMappingStorage)); + createParameterMapping(getContext(), bindFirstExtraToParams, + extraMapping); + } else if (!bindFirstExtraToResultsOfOps.empty()) { + createOpResultMapping(getOperation(), bindFirstExtraToResultsOfOps, + extraMapping); } + if (!bindSecondExtraToOps.empty()) { - extraMapping.push_back(findOperationsByName( - getOperation(), bindSecondExtraToOps, extraMappingStorage)); + SmallVector<Operation *> operations; + findOperationsByName(getOperation(),
bindSecondExtraToOps, operations); + extraMapping.push_back(operations); } else if (!bindSecondExtraToParams.empty()) { - extraMapping.push_back(createParameterMapping( - getContext(), bindSecondExtraToParams, extraMappingStorage)); + createParameterMapping(getContext(), bindSecondExtraToParams, + extraMapping); + } else if (!bindSecondExtraToResultsOfOps.empty()) { + createOpResultMapping(getOperation(), bindSecondExtraToResultsOfOps, + extraMapping); } options = options.enableExpensiveChecks(enableExpensiveChecks); @@ -128,6 +156,10 @@ *this, "bind-first-extra-to-params", llvm::cl::desc("bind the first extra argument of the top-level op to " "the given integer parameters")}; + Option<std::string> bindFirstExtraToResultsOfOps{ + *this, "bind-first-extra-to-results-of-ops", + llvm::cl::desc("bind the first extra argument of the top-level op to " + "results of payload operations of the given kind")}; Option<std::string> bindSecondExtraToOps{ *this, "bind-second-extra-to-ops", @@ -137,6 +169,11 @@ *this, "bind-second-extra-to-params", llvm::cl::desc("bind the second extra argument of the top-level op to " "the given integer parameters")}; + Option<std::string> bindSecondExtraToResultsOfOps{ + *this, "bind-second-extra-to-results-of-ops", + llvm::cl::desc("bind the second extra argument of the top-level op to " + "results of payload operations of the given kind")}; + Option<std::string> transformFileName{ *this, "transform-file-name", llvm::cl::init(""), llvm::cl::desc(