diff --git a/mlir/docs/Dialects/Transform.md b/mlir/docs/Dialects/Transform.md --- a/mlir/docs/Dialects/Transform.md +++ b/mlir/docs/Dialects/Transform.md @@ -42,47 +42,61 @@ ```mlir %0 = transform.loop.find { size > 42 } : !transform.interface -%1:2 = transform.loop.tile %0 { tile_sizes = [2,3,4] } +%1 = transform.compute_trailing_tile_size %0 : !transform.param +%2:2 = transform.loop.tile %0 tile_sizes(1, 4, %1) : (!transform.interface) - -> (!transform.op, !transform.op) + -> (!transform.op, !transform.op) -transform.loop.unroll %1#1 : !transform.op +transform.loop.unroll %2#1 : !transform.op ``` -The values used in the Transform dialect, also referred to as *handles*, -correspond to (groups of) operations in the payload IR. In the example +The values used in the Transform dialect may correspond to either: + + * sets of operations in the payload IR; + + * sets of parameters (attributes) known at the execution time of the + transform dialect. + +The former kind of values is also referred to as *handles*. In the example above, `%0` corresponds to the set of loops found in the payload IR that -satisfy the condition, and `%1` correspond to groups of outer and inner -loops, respectively, produced by the tiling transformation. +satisfy the condition, and `%2` correspond to groups of outer and inner +loops, respectively, produced by the tiling transformation, whereas `%1` +corresponds to a list of tile sizes selected for each of the operations +that `%0` corresponds to. A transform handle such as `%0` may be associated with multiple payload -operations. This is conceptually a set of operations and no assumptions -should be made about the order of ops unless specified otherwise by the -operation. Most Transform IR ops support operand values that are mapped to -multiple operations. They usually apply the respective transformation for -every mapped op ("batched execution"). Deviations from this convention are -described in the documentation of Transform IR ops. - -The handle values have transform IR types.
These types describe properties -of payload IR operations associated with the value that are known to the -transform dialect, for example, all associated payload operations implement -a "TileableOp" interface, or have a specific "loop" kind. These properties -are used to statically indicate pre- and post-conditions of a -transformation connected to a Transform dialect operation. The conditions -are verified when payload IR operations are first associated with a -transform handle. By convention, Transform dialect operations are expected -to indicate narrow preconditions for their operands by enforcing operand -type constraints in the their definitions and verifiers. On the contrary, -operations are expected to have few constraints on their results. Specific -instances of a transform operation can then be created with a more -restricted result type than the constraint in the operation (e.g., the -"find" operation only constrains the result type to be a transform IR type -while its concrete instance can have a type with stricter constraints such -as implementing the "tilable" interface). The verification will then happen -at transform execution time. This approach allows one to capture payload IR -operation properties in the transform IR without resorting to excessive -use of type casts or coupling dialect extensions between themselves. It is -a trade-off between verbosity/complexity and static hardening, which can -be revised in the future. +operations. This is conceptually a set of operations and no assumptions should +be made about the order of ops unless specified otherwise by the operation. +Operations may take as operands and produce an arbitrary combination of values +representing handles and parameters. Most Transform IR ops support operand +values that are mapped to multiple operations. They usually apply the respective +transformation for every mapped op ("batched execution"). 
Deviations from this +convention are described in the documentation of Transform IR ops. + +The transform IR values have transform IR types, which implement either +[TransformTypeInterface](Transform.md#transformtypeinterface-transformtypeinterface) +or +[TransformParamTypeInterface](Transform.md#transformparamtypeinterface-transformparamtypeinterface). +The former interface verifies properties of payload IR operations associated +with the value that are known to the transform dialect, for example, all +associated payload operations implement a "TileableOp" interface, or have a +specific "loop" kind. Similarly, the latter interface verifies properties of +attributes associated with the parameter value. These properties are used to +statically indicate pre- and post-conditions of a transformation connected to a +Transform dialect operation. The conditions are verified when attributes or +payload IR operations are first associated with a transform IR value. By +convention, Transform dialect operations are expected to indicate narrow +preconditions for their operands by enforcing operand type constraints in +their definitions and verifiers. On the contrary, operations are expected to +have few constraints on their results. Specific instances of a transform +operation can then be created with a more restricted result type than the +constraint in the operation (e.g., the "find" operation only constrains the +result type to be a transform IR type while its concrete instance can have a +type with stricter constraints such as implementing the "tilable" interface). +The verification will then happen at transform execution time. This approach +allows one to capture payload IR operation properties in the transform IR +without resorting to excessive use of type casts or coupling dialect extensions +between themselves. It is a trade-off between verbosity/complexity and static +hardening, which can be revised in the future.
Overall, Transform IR ops are expected to be contained in a single top-level op. Such top-level ops specify how to apply the transformations described @@ -96,8 +110,8 @@ ```c++ LogicalResult transform::applyTransforms(Operation *payloadRoot, - TransformOpInterface transform, - const TransformOptions &options); + TransformOpInterface transform, + const TransformOptions &options); ``` that applies the transformations specified by the top-level `transform` to @@ -139,6 +153,12 @@ dialect is loaded to allow for those implementations to be supplied by separate dialect extensions if desired. +Similarly to operations, additional types can be injected into the dialect using +the same extension mechanism. The types must: + + * Implement exactly one of `TransformTypeInterface`, + `TransformParamTypeInterface`. + ## Side Effects The Transform dialect relies on MLIR side effect modelling to enable @@ -250,6 +270,8 @@ after it has been consumed, but does so abstractly, without processing the payload IR. +Values associated with parameters (non-handles) cannot be invalidated. + ## Intended Use and Integrations The transformation control infrastructure provided by this dialect is diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -312,45 +312,58 @@ /// The state maintained across applications of various ops implementing the /// TransformOpInterface. The operations implementing this interface and the /// surrounding structure are referred to as transform IR. The operations to -/// which transformations apply are referred to as payload IR. The state thus -/// contains the many-to-many mapping between values defined in the transform IR -/// ops and payload IR ops. 
The "expensive-checks" option can be passed to -/// the constructor at transformation execution time that transform IR values -/// used as operands by a transform IR operation are not associated with -/// dangling pointers to payload IR operations that are known to have been -/// erased by previous transformation through the same or a different transform -/// IR value. +/// which transformations apply are referred to as payload IR. Transform IR +/// operates on values that can be associated either with a list of payload IR +/// operations (such values are referred to as handles) or with a list of +/// parameters represented as attributes. The state thus contains the mapping +/// between values defined in the transform IR ops and either payload IR ops or +/// parameters. For payload ops, the mapping is many-to-many and the reverse +/// mapping is also stored. The "expensive-checks" option can be passed to the +/// constructor at transformation execution time that transform IR values used +/// as operands by a transform IR operation are not associated with dangling +/// pointers to payload IR operations that are known to have been erased by +/// previous transformation through the same or a different transform IR value. /// /// A reference to this class is passed as an argument to "apply" methods of the -/// transform op interface. Thus the "apply" method can call +/// transform op interface. Thus the "apply" method can call either /// `state.getPayloadOps( getSomeOperand() )` to obtain the list of operations -/// associated with its operand and subject to transformation. The method is -/// expected to populate the `TransformResults` class instance in order to -/// update the mapping. The `applyTransform` method takes care of propagating -/// the state of `TransformResults` into the instance of this class. +/// or `state.getParams( getSomeOperand() )` to obtain the list of parameters +/// associated with its operand. 
The method is expected to populate the +/// `TransformResults` class instance in order to update the mapping. The +/// `applyTransform` method takes care of propagating the state of +/// `TransformResults` into the instance of this class. /// /// When applying transform IR operations with regions, the client is expected -/// to create a RegionScope RAII object to create a new "stack frame" for +/// to create a `RegionScope` RAII object to create a new "stack frame" for /// values defined inside the region. The mappings from and to these values will /// be automatically dropped when the object goes out of scope, typically at the -/// end of the "apply" function of the parent operation. If a region contains +/// end of the `apply` function of the parent operation. If a region contains /// blocks with arguments, the client can map those arguments to payload IR ops -/// using "mapBlockArguments". +/// using `mapBlockArguments`. class TransformState { +public: + using Param = Attribute; + +private: /// Mapping between a Value in the transform IR and the corresponding set of /// operations in the payload IR. - using TransformOpMapping = DenseMap>; + using TransformOpMapping = DenseMap>; /// Mapping between a payload IR operation and the transform IR values it is /// associated with. using TransformOpReverseMapping = DenseMap>; - /// Bidirectional mappings between transform IR values and payload IR - /// operations. + /// Mapping between a Value in the transform IR and the corresponding list of + /// parameters. + using ParamMapping = DenseMap>; + + /// The bidirectional mappings between transform IR values and payload IR + /// operations, and the mapping between transform IR values and parameters. struct Mappings { TransformOpMapping direct; TransformOpReverseMapping reverse; + ParamMapping params; }; friend LogicalResult applyTransforms(Operation *payloadRoot, @@ -366,6 +379,10 @@ /// This is helpful for transformations that apply to a particular handle. 
ArrayRef getPayloadOps(Value value) const; + /// Returns the list of parameters that the given transform IR value + /// corresponds to. + ArrayRef getParams(Value value) const; + /// Populates `handles` with all handles pointing to the given Payload IR op. /// Returns success if such handles exist, failure otherwise. LogicalResult getHandlesForPayloadOp(Operation *op, @@ -590,9 +607,16 @@ /// that the associated payload operation may no longer exist. /// /// Returns failure if the payload does not satisfy the conditions associated - /// with the type of the handle value. + /// with the type of the handle value. The value is expected to have a type + /// implementing TransformTypeInterface. LogicalResult setPayloadOps(Value value, ArrayRef targets); + /// Sets the parameters associated with the given transform IR value. Returns + /// failure if the parameters do not satisfy the conditions associated with + /// the type of the value. The value is expected to have a type implementing + /// TransformParamTypeInterface. + LogicalResult setParams(Value value, ArrayRef params); + /// Forgets the payload IR ops associated with the given transform IR value. void removePayloadOps(Value value); @@ -661,26 +685,56 @@ public: /// Indicates that the result of the transform IR op at the given position /// corresponds to the given list of payload IR ops. Each result must be set - /// by the transformation exactly once. + /// by the transformation exactly once. The value must have a type + /// implementing TransformTypeInterface. void set(OpResult value, ArrayRef ops); + /// Indicates that the result of the transform IR op at the given position + /// corresponds to the given list of parameters. Each result must be set by + /// the transformation exactly once. The value must have a type implementing + /// TransformParamTypeInterface. 
+ void setParams(OpResult value, ArrayRef params); + private: /// Creates an instance of TransformResults that expects mappings for - /// `numSegments` values. + /// `numSegments` values, which may be associated with payload operations or + /// parameters. explicit TransformResults(unsigned numSegments); /// Gets the list of operations associated with the result identified by its - /// number in the list of operation results. + /// number in the list of operation results. The result must have been set to + /// be associated with payload IR operations. ArrayRef get(unsigned resultNumber) const; + /// Gets the list of parameters associated with the result identified by its + /// number in the list of operation results. The result must have been set to + /// be associated with parameters. + ArrayRef getParams(unsigned resultNumber) const; + + /// Returns `true` if the result identified by its number in the list of + /// operation results is associated with a list of parameters, `false` if it + /// is associated with the list of payload IR operations. + bool isParam(unsigned resultNumber) const; + /// Storage for pointers to payload IR ops that are associated with results of /// a transform IR op. `segments` contains as many entries as the transform IR - /// op has results. Each entry is a reference to a contiguous segment in - /// the `operations` list that contains the pointers to operations. This - /// allows for operations to be stored contiguously without nested vectors and - /// for different segments to be set in any order. + /// op has results, even if some of them are not associated with payload IR + /// operations. Each entry is a reference to a contiguous segment in the + /// `operations` list that contains the pointers to operations. This allows + /// for operations to be stored contiguously without nested vectors and for + /// different segments to be set in any order. 
SmallVector, 2> segments; SmallVector operations; + + /// Storage for parameters that are associated with results of the transform + /// IR op. `paramSegments` contains as many entries as the transform IR op has + /// results, even if some of them are not associated with parameters. Each + /// entry is a reference to a contiguous segment in the `params` list that + /// contains the actual parameters. This allows for parameters to be stored + /// contiguously without nested vectors and for different segments to be set + /// in any order. + SmallVector, 2> paramSegments; + SmallVector params; }; TransformState::RegionScope TransformState::make_region_scope(Region &region) { @@ -895,6 +949,39 @@ } }; +namespace detail { +/// Non-template implementation of ParamProducerTransformOpTrait::getEffects(). +void getParamProducerTransformOpTraitEffects( + Operation *op, SmallVectorImpl &effects); +/// Non-template implementation of ParamProducerTransformOpTrait::verifyTrait(). +LogicalResult verifyParamProducerTransformOpTrait(Operation *op); +} // namespace detail + +/// Trait implementing the MemoryEffectsOpInterface for operations that produce +/// transform dialect parameters. It marks all op results as produced by the +/// op, all operands as only read by the op and, if at least one of the +/// operands is a handle to payload ops, the entire payload as potentially +/// read. The op must only produce parameter-typed results (implementing +/// TransformParamTypeInterface). +template +class ParamProducerTransformOpTrait + : public OpTrait::TraitBase { +public: + /// Populates `effects` with effect instances described in the trait + /// documentation. + void getEffects(SmallVectorImpl &effects) { + detail::getParamProducerTransformOpTraitEffects(this->getOperation(), + effects); + } + + /// Checks that the op matches the expectation of this trait, i.e., that it + /// implements the MemoryEffectsOpInterface and only produces parameter-typed + /// results.
+ static LogicalResult verifyTrait(Operation *op) { + return detail::verifyParamProducerTransformOpTrait(op); + } +}; + } // namespace transform } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -103,27 +103,21 @@ }]; } -def TransformTypeInterface : TypeInterface<"TransformTypeInterface"> { - let description = [{ - Types that can be used for Transform dialect handle values. Such types - define the properties of Payload IR operations associated with the handle. - A user of such a handle can assume that these properties have been verified - for any Payload IR operation associated with it. - }]; - +class TransformTypeInterfaceBase + : TypeInterface { let cppNamespace = "::mlir::transform"; let methods = [ InterfaceMethod< /*desc=*/[{ - Checks if the given list of associated Payload IR operations satisfy - the conditions defined by this type. If not, produces a silenceable + Checks if the given associated objects (Payload IR operations or attributes) + satisfy the conditions defined by this type. If not, produces a silenceable error at the specified location. }], /*returnType=*/"::mlir::DiagnosedSilenceableFailure", /*name=*/"checkPayload", /*arguments=*/(ins "::mlir::Location":$loc, - "::mlir::ArrayRef<::mlir::Operation *>":$payload) + "::mlir::ArrayRef<" # cppObjectType # ">":$payload) > ]; @@ -135,6 +129,29 @@ }]; } +def TransformTypeInterface + : TransformTypeInterfaceBase<"TransformTypeInterface", + "::mlir::Operation *"> { + let description = [{ + Types that can be used for the Transform dialect handle values. Such types + define the properties of Payload IR operations associated with the handle. 
+ A user of such a handle can assume that these properties have been verified + for any Payload IR operation associated with it. + }]; +} + +def TransformParamTypeInterface + : TransformTypeInterfaceBase<"TransformParamTypeInterface", + "::mlir::Attribute"> { + let description = [{ + Types that can be used for the Transform dialect parameter values. Such types + define the structure of the parameters associated with the value, e.g., their + underlying type. A user of the value can assume that the parameter has been + verified. + }]; + +} + def FunctionalStyleTransformOpTrait : NativeOpTrait<"FunctionalStyleTransformOpTrait"> { let cppNamespace = "::mlir::transform"; @@ -148,4 +165,8 @@ let cppNamespace = "::mlir::transform"; } +def ParamProducerTransformOpTrait : NativeOpTrait<"ParamProducerTransformOpTrait"> { + let cppNamespace = "::mlir::transform"; +} + #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td @@ -36,6 +36,22 @@ let assemblyFormat = "`<` $operation_name `>`"; } +def Transform_ParamType : TypeDef]> { + let description = [{ + Transform IR value that can be associated with the list of parameters + of the given type. Types are currently limited to integers, but may be + extended in the future to other types values of which can be contained + in attributes. 
+ }]; + let mnemonic = "param"; + let parameters = (ins + TypeParameter<"::mlir::Type", "Underlying type of the parameter">:$type + ); + let assemblyFormat = "`<` $type `>`"; + let genVerifyDecl = 1; +} + class Transform_ConcreteOpType : Type()" diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -38,7 +38,11 @@ void transform::detail::checkImplementsTransformTypeInterface( TypeID typeID, MLIRContext *context) { const auto &abstractType = AbstractType::lookup(typeID, context); - assert(abstractType.hasInterface(TransformTypeInterface::getInterfaceID())); + assert( + (abstractType.hasInterface(TransformTypeInterface::getInterfaceID()) || + abstractType.hasInterface( + TransformParamTypeInterface::getInterfaceID())) && + "expected Transform dialect type to implement one of the two interfaces"); } #endif // NDEBUG diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -44,7 +44,16 @@ transform::TransformState::getPayloadOps(Value value) const { const TransformOpMapping &operationMapping = getMapping(value).direct; auto iter = operationMapping.find(value); - assert(iter != operationMapping.end() && "unknown handle"); + assert(iter != operationMapping.end() && + "cannot find mapping for payload handle (param handle provided?)"); + return iter->getSecond(); +} + +ArrayRef transform::TransformState::getParams(Value value) const { + const ParamMapping &mapping = getMapping(value).params; + auto iter = mapping.find(value); + assert(iter != mapping.end() && + "cannot find mapping for param handle (payload handle provided?)"); return iter->getSecond(); } @@ -67,6 +76,8 @@ ArrayRef targets) { assert(value != 
kTopLevelValue && "attempting to reset the transformation root"); + assert(!value.getType().isa() && + "cannot associate payload ops with a value of parameter type"); auto iface = value.getType().cast(); DiagnosedSilenceableFailure result = @@ -89,6 +100,26 @@ return success(); } +LogicalResult transform::TransformState::setParams(Value value, + ArrayRef params) { + assert(value != nullptr && "attempting to set params for a null value"); + + auto valueType = value.getType().dyn_cast(); + assert(valueType && + "cannot associate parameter with a value of non-parameter type"); + DiagnosedSilenceableFailure result = + valueType.checkPayload(value.getLoc(), params); + if (failed(result.checkAndReport())) + return failure(); + + Mappings &mappings = getMapping(value); + bool inserted = + mappings.params.insert({value, llvm::to_vector(params)}).second; + assert(inserted && "value is already associated with another list of params"); + (void)inserted; + return success(); +} + void transform::TransformState::dropReverseMapping(Mappings &mappings, Operation *op, Value value) { auto it = mappings.reverse.find(op); @@ -112,8 +143,8 @@ Mappings &mappings = getMapping(value); auto it = mappings.direct.find(value); assert(it != mappings.direct.end() && "unknown handle"); - SmallVector &association = it->getSecond(); - SmallVector updated; + SmallVector &association = it->getSecond(); + SmallVector updated; updated.reserve(association.size()); for (Operation *op : association) { @@ -269,8 +300,21 @@ assert(result.getDefiningOp() == transform.getOperation() && "payload IR association for a value other than the result of the " "current transform op"); - if (failed(setPayloadOps(result, results.get(result.getResultNumber())))) - return DiagnosedSilenceableFailure::definiteFailure(); + if (result.getType().isa()) { + assert(results.isParam(result.getResultNumber()) && + "expected parameters for the parameter-typed result"); + if (failed( + setParams(result,
results.getParams(result.getResultNumber())))) { + return DiagnosedSilenceableFailure::definiteFailure(); + } + } else { + assert(!results.isParam(result.getResultNumber()) && + "expected payload ops for the non-parameter typed result"); + if (failed( + setPayloadOps(result, results.get(result.getResultNumber())))) { + return DiagnosedSilenceableFailure::definiteFailure(); + } + } } printOnFailureRAII.release(); @@ -312,6 +356,8 @@ transform::TransformResults::TransformResults(unsigned numSegments) { segments.resize(numSegments, ArrayRef(nullptr, static_cast(0))); + paramSegments.resize(numSegments, ArrayRef( + nullptr, static_cast(0))); } void transform::TransformResults::set(OpResult value, @@ -325,14 +371,41 @@ segments[position] = makeArrayRef(operations).drop_front(start); } +void transform::TransformResults::setParams( + OpResult value, ArrayRef params) { + int64_t position = value.getResultNumber(); + assert(position < static_cast(paramSegments.size()) && + "setting params for a non-existent handle"); + assert(paramSegments[position].data() == nullptr && "params already set"); + size_t start = this->params.size(); + llvm::append_range(this->params, params); + paramSegments[position] = makeArrayRef(this->params).drop_front(start); +} + ArrayRef transform::TransformResults::get(unsigned resultNumber) const { assert(resultNumber < segments.size() && "querying results for a non-existent handle"); - assert(segments[resultNumber].data() != nullptr && "querying unset results"); + assert(segments[resultNumber].data() != nullptr && + "querying unset results (param expected?)"); return segments[resultNumber]; } +ArrayRef +transform::TransformResults::getParams(unsigned resultNumber) const { + assert(resultNumber < paramSegments.size() && + "querying params for a non-existent handle"); + assert(paramSegments[resultNumber].data() != nullptr && + "querying unset params (payload ops expected?)"); + return paramSegments[resultNumber]; +} + +bool 
transform::TransformResults::isParam(unsigned resultNumber) const { + assert(resultNumber < paramSegments.size() && + "querying association for a non-existent handle"); + return paramSegments[resultNumber].data() != nullptr; +} + //===----------------------------------------------------------------------===// // Utilities for PossibleTopLevelTransformOpTrait. //===----------------------------------------------------------------------===// @@ -386,6 +459,43 @@ return success(); } +//===----------------------------------------------------------------------===// +// Utilities for ParamProducerTransformOpTrait. +//===----------------------------------------------------------------------===// + +void transform::detail::getParamProducerTransformOpTraitEffects( + Operation *op, SmallVectorImpl &effects) { + producesHandle(op->getResults(), effects); + bool hasPayloadOperands = false; + for (Value operand : op->getOperands()) { + onlyReadsHandle(operand, effects); + if (operand.getType().isa()) + hasPayloadOperands = true; + } + if (hasPayloadOperands) + onlyReadsPayload(effects); +} + +LogicalResult +transform::detail::verifyParamProducerTransformOpTrait(Operation *op) { + // Interfaces can be attached dynamically, so this cannot be a static + // assert. + if (!op->getName().getInterface()) { + llvm::report_fatal_error( + Twine("ParamProducerTransformOpTrait must be attached to an op that " + "implements MemoryEffectsOpInterface, found on ") + + op->getName().getStringRef()); + } + for (Value result : op->getResults()) { + if (result.getType().isa()) + continue; + return op->emitOpError() + << "ParamProducerTransformOpTrait attached to this op expects " + "result types to implement TransformParamTypeInterface"; + } + return success(); +} + //===----------------------------------------------------------------------===// // Memory effects.
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Types.h" @@ -37,12 +38,20 @@ >(); } +//===----------------------------------------------------------------------===// +// transform::AnyOpType +//===----------------------------------------------------------------------===// + DiagnosedSilenceableFailure transform::AnyOpType::checkPayload(Location loc, ArrayRef payload) const { return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// transform::OperationType +//===----------------------------------------------------------------------===// + DiagnosedSilenceableFailure transform::OperationType::checkPayload(Location loc, ArrayRef payload) const { @@ -58,3 +67,35 @@ return DiagnosedSilenceableFailure::success(); } + +//===----------------------------------------------------------------------===// +// transform::ParamType +//===----------------------------------------------------------------------===// + +LogicalResult +transform::ParamType::verify(function_ref emitError, + Type type) { + IntegerType intType = type.dyn_cast(); + if (!intType || intType.getWidth() > 64) + return emitError() << "only supports integer types with width <=64"; + return success(); +} + +DiagnosedSilenceableFailure +transform::ParamType::checkPayload(Location loc, + ArrayRef payload) const { + for (Attribute attr : payload) { + auto integerAttr = attr.dyn_cast(); + if 
(!integerAttr) { + return emitSilenceableError(loc) + << "expected parameter to be an integer attribute, got " << attr; + } + if (integerAttr.getType() != getType()) { + return emitSilenceableError(loc) + << "expected the type of the parameter attribute (" + << integerAttr.getType() << ") to match the parameter type (" + << getType() << ")"; + } + } + return DiagnosedSilenceableFailure::success(); +} diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -920,6 +920,7 @@ } "test.some_op"() : () -> () + // ----- func.func @split_handles(%a: index, %b: index, %c: index) { @@ -937,3 +938,56 @@ /// propagate mode. yield %fun : !pdl.operation } + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + %0 = transform.test_produce_integer_param_with_type i32 : !transform.test_dialect_param + // expected-remark @below {{0 : i32}} + transform.test_print_param %0 : !transform.test_dialect_param +} + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + // expected-error @below {{expected the type of the parameter attribute ('i32') to match the parameter type ('i64')}} + transform.test_produce_integer_param_with_type i32 : !transform.param +} + +// ----- + + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + %0 = transform.test_add_to_param 40 + %1 = transform.test_add_to_param %0, 2 + // expected-remark @below {{42 : i32}} + transform.test_print_param %1 : !transform.test_dialect_param +} + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg0 + %1 = transform.test_produce_param_with_number_of_test_ops %0 : !pdl.operation + // expected-remark @below {{1 : i32, 3 : i32}} + transform.test_print_param %1 : 
!transform.test_dialect_param + %2 = transform.test_add_to_param %1, 100 + // expected-remark @below {{101 : i32, 103 : i32}} + transform.test_print_param %2 : !transform.test_dialect_param +} + +func.func private @one_test_op(%arg0: i32) { + "test.op_a"(%arg0) { attr = 0 : i32} : (i32) -> i32 + return +} + +func.func private @three_test_ops(%arg0: i32) { + "test.op_a"(%arg0) { attr = 0 : i32} : (i32) -> i32 + "test.op_a"(%arg0) { attr = 0 : i32} : (i32) -> i32 + "test.op_a"(%arg0) { attr = 0 : i32} : (i32) -> i32 + return +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h @@ -23,12 +23,12 @@ class DialectRegistry; } // namespace mlir -#define GET_OP_CLASSES -#include "TestTransformDialectExtension.h.inc" - #define GET_TYPEDEF_CLASSES #include "TestTransformDialectExtensionTypes.h.inc" +#define GET_OP_CLASSES +#include "TestTransformDialectExtension.h.inc" + namespace test { /// Registers the test extension to the Transform dialect. 
void registerTestTransformDialectExtension(::mlir::DialectRegistry &registry);
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -17,8 +17,10 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Compiler.h"
+#include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
 
@@ -317,15 +319,27 @@
 
   for (Operation *op : payload) {
     if (op->getName().getDialectNamespace() != "test") {
-      Diagnostic diag(loc, DiagnosticSeverity::Error);
-      diag << "expected the payload operation to belong to the 'test' dialect";
-      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+      return emitSilenceableError(loc) << "expected the payload operation to "
+                                          "belong to the 'test' dialect";
     }
   }
 
   return DiagnosedSilenceableFailure::success();
 }
 
+DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload(
+    Location loc, ArrayRef<Attribute> payload) const {
+  for (Attribute attr : payload) {
+    auto integerAttr = attr.dyn_cast<IntegerAttr>();
+    if (integerAttr && integerAttr.getType().isSignlessInteger(32))
+      continue;
+    return emitSilenceableError(loc)
+           << "expected the parameter to be a i32 integer attribute";
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
 void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   transform::onlyReadsHandle(getTarget(), effects);
@@ -346,6 +360,75 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+void mlir::test::TestPrintParamOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getParam(), effects);
+}
+
+DiagnosedSilenceableFailure
+mlir::test::TestPrintParamOp::apply(transform::TransformResults &results,
+                                    transform::TransformState &state) {
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  llvm::interleaveComma(state.getParams(getParam()), os);
+  emitRemark() << os.str();
+  return DiagnosedSilenceableFailure::success();
+}
+
+DiagnosedSilenceableFailure
+mlir::test::TestAddToParamOp::apply(transform::TransformResults &results,
+                                    transform::TransformState &state) {
+  SmallVector<uint32_t> values(/*Size=*/1, /*Value=*/0);
+  if (Value param = getParam()) {
+    values = llvm::to_vector(
+        llvm::map_range(state.getParams(param), [](Attribute attr) -> uint32_t {
+          return attr.cast<IntegerAttr>().getValue().getLimitedValue(
+              UINT32_MAX);
+        }));
+  }
+
+  Builder builder(getContext());
+  SmallVector<Attribute> result = llvm::to_vector(
+      llvm::map_range(values, [this, &builder](uint32_t value) -> Attribute {
+        return builder.getI32IntegerAttr(value + getAddendum());
+      }));
+  results.setParams(getResult().cast<OpResult>(), result);
+  return DiagnosedSilenceableFailure::success();
+}
+
+DiagnosedSilenceableFailure
+mlir::test::TestProduceParamWithNumberOfTestOps::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  Builder builder(getContext());
+  SmallVector<Attribute> result = llvm::to_vector(
+      llvm::map_range(state.getPayloadOps(getHandle()),
+                      [&builder](Operation *payload) -> Attribute {
+                        int32_t count = 0;
+                        payload->walk([&count](Operation *op) {
+                          if (op->getName().getDialectNamespace() == "test")
+                            ++count;
+                        });
+                        return builder.getI32IntegerAttr(count);
+                      }));
+  results.setParams(getResult().cast<OpResult>(), result);
+  return DiagnosedSilenceableFailure::success();
+}
+
+DiagnosedSilenceableFailure
+mlir::test::TestProduceIntegerParamWithTypeOp::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  Attribute zero = IntegerAttr::get(getType(), 0);
+  results.setParams(getResult().cast<OpResult>(), zero);
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult mlir::test::TestProduceIntegerParamWithTypeOp::verify() {
+  if (!getType().isa<IntegerType>()) {
+    return emitOpError() << "expects an integer type";
+  }
+  return success();
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL
@@ -371,9 +454,6 @@
 };
 } // namespace
 
-#define GET_OP_CLASSES
-#include "TestTransformDialectExtension.cpp.inc"
-
 // These are automatically generated by ODS but are not used as the Transform
 // dialect uses a different dispatch mechanism to support dialect extensions.
 LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
@@ -384,6 +464,9 @@
 #define GET_TYPEDEF_CLASSES
 #include "TestTransformDialectExtensionTypes.cpp.inc"
 
+#define GET_OP_CLASSES
+#include "TestTransformDialectExtension.cpp.inc"
+
 void ::test::registerTestTransformDialectExtension(DialectRegistry &registry) {
   registry.addExtensions<TestTransformDialectExtension>();
 }
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -29,6 +29,16 @@
   let assemblyFormat = "";
 }
 
+def TestTransformTestDialectParamType
+  : TypeDef<Transform_Dialect, "TestDialectParam",
+      [DeclareTypeInterfaceMethods<TransformParamTypeInterface>]> {
+  let description = [{
+    Parameter associated with an i32 attribute for testing purposes.
+  }];
+  let mnemonic = "test_dialect_param";
+  let assemblyFormat = "";
+}
+
 def TestProduceParamOrForwardOperandOp
   : Op<Transform_Dialect, "test_produce_param_or_forward_operand",
        [DeclareOpInterfaceMethods<TransformOpInterface>]> {
@@ -262,4 +272,63 @@
   let cppNamespace = "::mlir::test";
 }
 
+def TestPrintParamOp
+  : Op<Transform_Dialect, "test_print_param",
+       [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Emits a remark listing the parameters associated with the operand,
+    separated by commas.
+  }];
+  let arguments = (ins TransformParamTypeInterface:$param);
+  let assemblyFormat = "$param attr-dict `:` type($param)";
+  let cppNamespace = "::mlir::test";
+}
+
+def TestAddToParamOp
+  : Op<Transform_Dialect, "test_add_to_param",
+       [MemoryEffectsOpInterface, ParamProducerTransformOpTrait,
+        DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Produces one i32 parameter per parameter of the optional operand, each
+    equal to the operand value plus `addendum`. Without an operand, produces a
+    single parameter equal to `addendum`.
+  }];
+  let arguments = (ins Optional<TestTransformTestDialectParamType>:$param,
+                       I32Attr:$addendum);
+  let results = (outs TestTransformTestDialectParamType:$result);
+  let assemblyFormat = "($param^ `,`)? $addendum attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
+def TestProduceParamWithNumberOfTestOps
+  : Op<Transform_Dialect, "test_produce_param_with_number_of_test_ops",
+       [MemoryEffectsOpInterface, ParamProducerTransformOpTrait,
+        DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Produces one i32 parameter per payload operation associated with the
+    handle, holding the number of operations from the 'test' dialect nested
+    in it, including the payload operation itself.
+  }];
+  let arguments = (ins TransformTypeInterface:$handle);
+  let results = (outs TestTransformTestDialectParamType:$result);
+  let assemblyFormat = "$handle attr-dict `:` type($handle)";
+  let cppNamespace = "::mlir::test";
+}
+
+def TestProduceIntegerParamWithTypeOp
+  : Op<Transform_Dialect, "test_produce_integer_param_with_type",
+       [MemoryEffectsOpInterface, ParamProducerTransformOpTrait,
+        DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Produces a single parameter holding the zero value of the given integer
+    type.
+  }];
+  let arguments = (ins TypeAttr:$type);
+  let results = (outs TransformParamTypeInterface:$result);
+  let assemblyFormat = "$type attr-dict `:` type($result)";
+  let cppNamespace = "::mlir::test";
+  let hasVerifier = 1;
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD