diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -23,8 +23,19 @@ "::mlir::pdl_interp::PDLInterpDialect", ]; + let hasOperationAttrVerify = 1; let extraClassDeclaration = [{ + /// Name of the attribute attachable to the symbol table operation + /// containing named sequences. This is used to trigger verification. + constexpr const static llvm::StringLiteral + kWithNamedSequenceAttrName = "transform.with_named_sequence"; + + /// Name of the attribute attachable to an operation so it can be + /// identified as root by the default interpreter pass. + constexpr const static llvm::StringLiteral + kTargetTagAttrName = "transform.target_tag"; + /// Returns the named PDL constraint functions available in the dialect /// as a map from their name to the function. const ::llvm::StringMap<::mlir::PDLConstraintFunction> & diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -192,6 +192,12 @@ // class body to comply with visibility and full-declaration requirements. inline RegionScope make_region_scope(Region ®ion); + /// Creates a new region scope for the given isolated-from-above region. + /// Unlike the non-isolated counterpart, there is no nesting expectation. + // Implementation note: this method is inline but implemented outside of the + // class body to comply with visibility and full-declaration requirements. + inline RegionScope make_isolated_region_scope(Region ®ion); + /// A RAII object maintaining a "stack frame" for a transform IR region.
When /// applying a transform IR operation that contains a region, the caller is /// expected to create a RegionScope before applying the ops contained in the @@ -201,17 +207,23 @@ class RegionScope { public: /// Forgets the mapping from or to values defined in the associated - /// transform IR region. + /// transform IR region, and restores the mapping that existed before + /// entering this scope. ~RegionScope() { state.mappings.erase(region); + if (storedMappings.has_value()) + state.mappings.swap(*storedMappings); #if LLVM_ENABLE_ABI_BREAKING_CHECKS state.regionStack.pop_back(); #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS } private: + /// Tag structure for differentiating the constructor for isolated regions. + struct Isolated {}; + /// Creates a new scope for mappings between values defined in the given - /// transform IR region and payload IR operations. + /// transform IR region and payload IR objects. RegionScope(TransformState &state, Region ®ion) : state(state), region(®ion) { auto res = state.mappings.try_emplace(this->region); @@ -225,13 +237,33 @@ #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS } + /// Creates a new scope for mappings between values defined in the given + /// isolated-from-above transform IR region and payload IR objects. + RegionScope(TransformState &state, Region ®ion, Isolated) + : state(state), region(®ion) { + // Store the previous mapping stack locally. + storedMappings = llvm::SmallDenseMap(); + storedMappings->swap(state.mappings); + state.mappings.try_emplace(this->region); +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + state.regionStack.push_back(this->region); +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + } + /// Back-reference to the transform state. TransformState &state; /// The region this scope is associated with. Region *region; + /// Local copy of the mappings that existed before entering the current + /// region.
Used only when the current region is isolated so we don't + /// accidentally look up the values defined outside the isolated region. + std::optional> storedMappings = + std::nullopt; + friend RegionScope TransformState::make_region_scope(Region &); + friend RegionScope TransformState::make_isolated_region_scope(Region &); }; friend class RegionScope; @@ -551,6 +583,13 @@ /// TransformValueHandleTypeInterface. void setValues(OpResult handle, ValueRange values); + /// Indicates that the result of the transform IR op at the given position + /// corresponds to the given range of mapped values. All mapped values are + /// expected to be compatible with the type of the result, e.g., if the result + /// is an operation handle, all mapped values are expected to be payload + /// operations. + void setMappedValues(OpResult handle, ArrayRef values); + private: /// Creates an instance of TransformResults that expects mappings for /// `numSegments` values, which may be associated with payload operations or @@ -597,10 +636,21 @@ RaggedArray values; }; +/// Creates a RAII object the lifetime of which corresponds to the new mapping +/// for transform IR values defined in the given region. Values defined in +/// surrounding regions remain accessible. TransformState::RegionScope TransformState::make_region_scope(Region ®ion) { return RegionScope(*this, region); } +/// Creates a RAII object the lifetime of which corresponds to the new mapping +/// for transform IR values defined in the given isolated-from-above region. +/// Values defined in surrounding regions cannot be accessed. +TransformState::RegionScope +TransformState::make_isolated_region_scope(Region ®ion) { + return RegionScope(*this, region, RegionScope::Isolated()); +} + namespace detail { /// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait /// to either the list of operations associated with its operand or the root of @@ -614,6 +664,12 @@ /// Verification hook for TransformOpInterface.
LogicalResult verifyTransformOpInterface(Operation *op); + +/// Populates `mappings` with mapped values associated with the given transform +/// IR values in the given `state`. +void prepareValueMappings( + SmallVectorImpl> &mappings, + ValueRange values, const transform::TransformState &state); } // namespace detail /// This trait is supposed to be attached to Transform dialect operations that diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -12,9 +12,11 @@ #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -9,10 +9,12 @@ #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS +include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/FunctionInterfaces.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Dialect/Transform/IR/TransformAttrs.td" @@ -266,6 +268,51 @@ "functional-type(operands, results)"; } +def IncludeOp : TransformDialectOp<"include", + [CallOpInterface, +
DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "Includes a named transform sequence"; + let description = [{ + The application of this transform operation is equivalent to applying the + operations contained in the named transform sequence with operands being + remapped to block arguments. The behavior of the operation when a + transformation in the included named sequence produces a silenceable error + is controlled by the `failure_propagation_mode` attribute. When set to + `propagate`, the failure of any nested transformation in the sequence + implies immediate failure of the entire sequence with a silenceable error, + and no further transformation is attempted. When set to `suppress`, + silenceable errors in nested operations are ignored and further + transformations are applied. Beware that even silenceable errors may leave + the payload IR in a state unsuitable for further transformations. It is the + responsibility of the user to ensure the following transformations are + robust enough when errors are suppressed. Definite errors are propagated + immediately regardless of the mode. The objects associated with the results + of this operation are the same as those associated with the operands of the + `transform.yield` in the referenced named sequence.
+ }]; + + let arguments = (ins SymbolRefAttr:$target, + FailurePropagationMode:$failure_propagation_mode, + Variadic:$operands); + let results = (outs Variadic:$results); + + let assemblyFormat = + "$target `failures` `(` $failure_propagation_mode `)`" + "`(` $operands `)` attr-dict `:` functional-type($operands, $results)"; + + let extraClassDeclaration = [{ + ::mlir::CallInterfaceCallable getCallableForCallee() { + return getTarget(); + } + + ::mlir::Operation::operand_range getArgOperands() { + return getOperands(); + } + }]; +} + def MergeHandlesOp : TransformDialectOp<"merge_handles", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, @@ -289,6 +336,67 @@ let hasFolder = 1; } +def NamedSequenceOp : TransformDialectOp<"named_sequence", + [CallableOpInterface, + FunctionOpInterface, + IsolatedFromAbove, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "Named transform sequence that can be included elsewhere"; + let description = [{ + Defines a named (callable, function-like) sequence of other Transform + dialect operations that can be included using `transform.include` as part of + another Transform dialect construct. This sequence is not processed + immediately but rather dispatched to when the inclusion is processed. The + arguments and results can be used to communicate a subset of mapping into + the named sequence. The sequence must consist of a single block and end with + a `transform.yield` terminator. The operands of the terminator become the + results of the `transform.include`. + + When dispatched to, the operations in the named sequence are executed one by + one, similarly to the regular unnamed sequence. The failure propagation mode + is specified on the `transform.include`. Different inclusions may use + different failure propagation modes. This transform operation always + succeeds by itself, but the inclusion may fail if any of the operations + fail.
+ + Named sequences can only appear at the top-level of the Transform dialect + nesting structure. That is, they cannot be nested in other Transform dialect + operations. Furthermore, one of the ancestors must have the `SymbolTable` + trait and have the `transform.with_named_sequence` attribute attached. + + Named sequences may include other named sequences via `transform.include`, + but recursion is *not* allowed. + }]; + + let arguments = (ins + SymbolNameAttr:$sym_name, + TypeAttrBase<"::mlir::FunctionType", + "function type attribute">:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); + let regions = (region SizedRegion<1>:$body); + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + ::llvm::ArrayRef<::mlir::Type> getArgumentTypes() { + return getFunctionType().getInputs(); + } + ::llvm::ArrayRef<::mlir::Type> getResultTypes() { + return getFunctionType().getResults(); + } + + ::mlir::Region *getCallableRegion() { + return &getBody(); + } + ::llvm::ArrayRef<::mlir::Type> getCallableResults() { + return getFunctionType().getResults(); + } + }]; +} + def SplitHandlesOp : TransformDialectOp<"split_handles", [FunctionalStyleTransformOpTrait, DeclareOpInterfaceMethods, @@ -376,7 +484,6 @@ let assemblyFormat = "$target attr-dict (`:` type($target)^)?"; } - def ReplicateOp : TransformDialectOp<"replicate", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, @@ -426,21 +533,21 @@ let description = [{ The transformations indicated by the sequence are applied in order of their appearance. Each value produced by a transformation within the sequence - corresponds to an operation or a group of operations in the payload IR. - The behavior of the operation when a nested transformation produces a - silenceable error is controlled by the `failure_propagation_mode` attribute.
- When set to `propagate`, the failure of any nested transformation in the - sequence implies immediate failure of the entire sequence with a silenceable - error, and no further transformation is attempted. When set to `suppress`, + corresponds to a group of operations or values in the payload IR, or to a + group of parameters, depending on the type of the value. The behavior of the + operation when a nested transformation produces a silenceable error is + controlled by the `failure_propagation_mode` attribute. When set to + `propagate`, the failure of any nested transformation in the sequence + implies immediate failure of the entire sequence with a silenceable error, + and no further transformation is attempted. When set to `suppress`, silenceable errors in nested operations are ignored and further transformations are applied. Beware that even silenceable errors may leave - the payload IR in a state unsuitable for further transformations. It is - the responsibility of the caller to ensure the following transformations - are robust enough when errors are suppressed. Definite errors reported by - nested transformations abort the sequence regardless of the propagation - mode. The set of modes may be extended in the future, e.g., to collect - silenceable errors and report them after attempting all transformations in - the sequence. + the payload IR in a state unsuitable for further transformations. It is the + responsibility of the caller to ensure the following transformations are + robust enough when errors are suppressed. Definite errors reported by nested + transformations abort the sequence regardless of the propagation mode. The + set of modes may be extended in the future, e.g., to collect silenceable + errors and report them after attempting all transformations in the sequence.
The entry block of this operation has a single argument that maps to either the operand if provided or the top-level container operation of the payload @@ -565,7 +672,8 @@ }]; let arguments = (ins - Arg, "Operation handles yielded back to the parent" + Arg, + "Transform values yielded back to the parent" >:$operands); let assemblyFormat = "operands attr-dict (`:` type($operands)^)?"; diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -7,12 +7,14 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Analysis/CallGraph.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/SCCIterator.h" using namespace mlir; @@ -128,4 +130,53 @@ llvm::report_fatal_error(StringRef(buffer)); } +LogicalResult transform::TransformDialect::verifyOperationAttribute( + Operation *op, NamedAttribute attribute) { + if (attribute.getName().getValue() == kWithNamedSequenceAttrName) { + if (!op->hasTrait()) { + return emitError(op->getLoc()) << attribute.getName() + << " attribute can only be attached to " + "operations with symbol tables"; + } + + const mlir::CallGraph callgraph(op); + for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) { + if (!scc.hasCycle()) + continue; + + // Need to check this here additionally because this verification may run + // before we check the nested operations.
+ if ((*scc->begin())->isExternal()) + return op->emitOpError() << "contains a call to an external operation, " + "which is not allowed"; + + Operation *first = (*scc->begin())->getCallableRegion()->getParentOp(); + InFlightDiagnostic diag = emitError(first->getLoc()) + << "recursion not allowed in named sequences"; + for (auto it = std::next(scc->begin()); it != scc->end(); ++it) { + // Need to check this here additionally because this verification may + // run before we check the nested operations. + if ((*it)->isExternal()) { + return op->emitOpError() << "contains a call to an external " + "operation, which is not allowed"; + } + + Operation *current = (*it)->getCallableRegion()->getParentOp(); + diag.attachNote(current->getLoc()) << "operation on recursion stack"; + } + return diag; + } + return success(); + } + if (attribute.getName().getValue() == kTargetTagAttrName) { + if (!attribute.getValue().isa()) { + return op->emitError() + << attribute.getName() << " attribute must be a string"; + } + return success(); + } + return emitError(op->getLoc()) + << "unknown attribute: " << attribute.getName(); +} + #include "mlir/Dialect/Transform/IR/TransformDialectEnums.cpp.inc" diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -104,50 +104,77 @@ return success(found); } -LogicalResult -transform::TransformState::mapBlockArgument(BlockArgument argument, - ArrayRef values) { - if (argument.getType().isa()) { +/// Given a list of MappedValues, cast them to the value kind implied by the +/// interface of the handle type, and dispatch to one of the callbacks.
+static DiagnosedSilenceableFailure dispatchMappedValues( + Value handle, ArrayRef values, + function_ref)> operationsFn, + function_ref)> paramsFn, + function_ref valuesFn) { + if (handle.getType().isa()) { SmallVector operations; operations.reserve(values.size()); - for (MappedValue value : values) { + for (transform::MappedValue value : values) { if (auto *op = value.dyn_cast()) { operations.push_back(op); continue; } - return emitError(argument.getLoc()) + return emitSilenceableFailure(handle.getLoc()) << "wrong kind of value provided for top-level operation handle"; } - return setPayloadOps(argument, operations); + if (failed(operationsFn(operations))) + return DiagnosedSilenceableFailure::definiteFailure(); + return DiagnosedSilenceableFailure::success(); } - if (argument.getType().isa()) { + if (handle.getType().isa()) { SmallVector payloadValues; payloadValues.reserve(values.size()); - for (MappedValue value : values) { + for (transform::MappedValue value : values) { if (auto v = value.dyn_cast()) { payloadValues.push_back(v); continue; } - return emitError(argument.getLoc()) + return emitSilenceableFailure(handle.getLoc()) << "wrong kind of value provided for the top-level value handle"; } - return setPayloadValues(argument, payloadValues); + if (failed(valuesFn(payloadValues))) + return DiagnosedSilenceableFailure::definiteFailure(); + return DiagnosedSilenceableFailure::success(); } - assert(argument.getType().isa() && + assert(handle.getType().isa() && "unsupported kind of block argument"); - SmallVector parameters; + SmallVector parameters; parameters.reserve(values.size()); - for (MappedValue value : values) { + for (transform::MappedValue value : values) { if (auto attr = value.dyn_cast()) { parameters.push_back(attr); continue; } - return emitError(argument.getLoc()) + return emitSilenceableFailure(handle.getLoc()) << "wrong kind of value provided for top-level parameter"; } - return setParams(argument, parameters); + if
(failed(paramsFn(parameters))) + return DiagnosedSilenceableFailure::definiteFailure(); + return DiagnosedSilenceableFailure::success(); +} + +LogicalResult +transform::TransformState::mapBlockArgument(BlockArgument argument, + ArrayRef values) { + return dispatchMappedValues( + argument, values, + [&](ArrayRef operations) { + return setPayloadOps(argument, operations); + }, + [&](ArrayRef params) { + return setParams(argument, params); + }, + [&](ValueRange payloadValues) { + return setPayloadValues(argument, payloadValues); + }) + .checkAndReport(); } LogicalResult @@ -887,6 +914,27 @@ this->values.replace(position, values); } +void transform::TransformResults::setMappedValues( + OpResult handle, ArrayRef values) { + DiagnosedSilenceableFailure diag = dispatchMappedValues( + handle, values, + [&](ArrayRef operations) { + return set(handle, operations), success(); + }, + [&](ArrayRef params) { + return setParams(handle, params), success(); + }, + [&](ValueRange payloadValues) { + return setValues(handle, payloadValues), success(); + }); +#ifndef NDEBUG + if (!diag.succeeded()) + llvm::dbgs() << diag.getStatusString() << "\n"; + assert(diag.succeeded() && "incorrect mapping"); +#endif // NDEBUG + (void)diag.silence(); +} + ArrayRef transform::TransformResults::get(unsigned resultNumber) const { assert(resultNumber < operations.size() && @@ -1029,24 +1077,30 @@ // Utilities for PossibleTopLevelTransformOpTrait.
//===----------------------------------------------------------------------===// +void transform::detail::prepareValueMappings( + SmallVectorImpl> &mappings, + ValueRange values, const transform::TransformState &state) { + for (Value operand : values) { + SmallVector &mapped = mappings.emplace_back(); + if (operand.getType().isa()) { + llvm::append_range(mapped, state.getPayloadOps(operand)); + } else if (operand.getType().isa()) { + llvm::append_range(mapped, state.getPayloadValues(operand)); + } else { + assert(operand.getType().isa() && + "unsupported kind of transform dialect value"); + llvm::append_range(mapped, state.getParams(operand)); + } + } +} + LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( TransformState &state, Operation *op, Region &region) { SmallVector targets; SmallVector> extraMappings; if (op->getNumOperands() != 0) { llvm::append_range(targets, state.getPayloadOps(op->getOperand(0))); - for (Value operand : op->getOperands().drop_front()) { - SmallVector &mapped = extraMappings.emplace_back(); - if (operand.getType().isa()) { - llvm::append_range(mapped, state.getPayloadOps(operand)); - } else if (operand.getType().isa()) { - llvm::append_range(mapped, state.getPayloadValues(operand)); - } else { - assert(operand.getType().isa() && - "unsupported kind of transform dialect value"); - llvm::append_range(mapped, state.getParams(operand)); - } - } + prepareValueMappings(extraMappings, op->getOperands().drop_front(), state); } else { if (state.getNumTopLevelMappings() != region.front().getNumArguments() - 1) { diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include
"mlir/IR/FunctionImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" @@ -175,11 +176,19 @@ static void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results) { - for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(), - block->getParentOp()->getOpResults())) { - Value terminatorOperand = std::get<0>(pair); - OpResult result = std::get<1>(pair); - results.set(result, state.getPayloadOps(terminatorOperand)); + for (auto &&[terminatorOperand, result] : + llvm::zip(block->getTerminator()->getOperands(), + block->getParentOp()->getOpResults())) { + if (result.getType().isa()) { + results.set(result, state.getPayloadOps(terminatorOperand)); + } else if (result.getType() + .isa()) { + results.setValues(result, state.getPayloadValues(terminatorOperand)); + } else { + assert(result.getType().isa() && + "unhandled transform type interface"); + results.setParams(result, state.getParams(terminatorOperand)); + } } } @@ -524,6 +533,185 @@ return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// IncludeOp +//===----------------------------------------------------------------------===// + +/// Applies the transform ops contained in `block`. Maps the results of the +/// block's parent op in `results` to the values yielded by its terminator. +static DiagnosedSilenceableFailure +applySequenceBlock(Block &block, transform::FailurePropagationMode mode, + transform::TransformState &state, + transform::TransformResults &results) { + // Apply the sequenced ops one by one.
+ for (Operation &transform : block.without_terminator()) { + DiagnosedSilenceableFailure result = + state.applyTransform(cast(transform)); + if (result.isDefiniteFailure()) + return result; + + if (result.isSilenceableFailure()) { + if (mode == transform::FailurePropagationMode::Propagate) { + // Propagate empty results in case of early exit. + forwardEmptyOperands(&block, state, results); + return result; + } + (void)result.silence(); + } + } + + // Forward the operation mapping for values yielded from the sequence to the + // values produced by the sequence op. + forwardTerminatorOperands(&block, state, results); + return DiagnosedSilenceableFailure::success(); +} + +DiagnosedSilenceableFailure +transform::IncludeOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + auto callee = SymbolTable::lookupNearestSymbolFrom( + getOperation(), getTarget()); + assert(callee && "unverified reference to unknown symbol"); + + // Map operands to block arguments. + SmallVector> mappings; + detail::prepareValueMappings(mappings, getOperands(), state); + auto scope = state.make_isolated_region_scope(callee.getBody()); + for (auto &&[arg, map] : + llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) { + if (failed(state.mapBlockArgument(arg, map))) + return DiagnosedSilenceableFailure::definiteFailure(); + } + + DiagnosedSilenceableFailure result = applySequenceBlock( + callee.getBody().front(), getFailurePropagationMode(), state, results); + // Forward the values yielded by the named sequence to the results of this + // op. After an early exit, the yielded values may never have been mapped + // because their defining ops were not applied, so use empty lists instead. 
+ mappings.clear(); + if (result.succeeded()) { + detail::prepareValueMappings( + mappings, callee.getBody().front().getTerminator()->getOperands(), + state); + } else { + mappings.resize(getNumResults()); + } + for (auto &&[opResult, mapping] : llvm::zip_equal(getResults(), mappings)) + results.setMappedValues(opResult, mapping); + return result; +} + +/// Appends to `effects` the memory effect instances on `target` with the same +/// resource and effect as the ones the operation `iface` has on `source`.
+static void +remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target, + SmallVectorImpl &effects) { + SmallVector nestedEffects; + iface.getEffectsOnValue(source, nestedEffects); + for (const auto &effect : nestedEffects) + effects.emplace_back(effect.getEffect(), target, effect.getResource()); +} + +/// Appends to `effects` the same effects as the operations of `block` have on +/// block arguments but associated with `operands`. +static void +remapArgumentEffects(Block &block, ValueRange operands, + SmallVectorImpl &effects) { + for (Operation &op : block) { + auto iface = dyn_cast(&op); + if (!iface) + continue; + + for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) { + remapEffects(iface, source, target, effects); + } + + SmallVector nestedEffects; + iface.getEffectsOnResource(transform::PayloadIRResource::get(), + nestedEffects); + llvm::append_range(effects, nestedEffects); + } +} + +static DiagnosedSilenceableFailure +verifyNamedSequenceOp(transform::NamedSequenceOp op); + +void transform::IncludeOp::getEffects( + SmallVectorImpl &effects) { + // Bail if the callee is unknown. This may run as part of the verification + // process before we verified the validity of the callee or of this op. + auto target = + getOperation()->getAttrOfType(getTargetAttrName()); + if (!target) + return; + auto callee = SymbolTable::lookupNearestSymbolFrom( + getOperation(), getTarget()); + if (!callee) + return; + DiagnosedSilenceableFailure earlyVerifierResult = + verifyNamedSequenceOp(callee); + if (!earlyVerifierResult.succeeded()) { + (void)earlyVerifierResult.silence(); + return; + } + + // Carry over effects from the callee. + remapArgumentEffects(callee.getBody().front(), getOperands(), effects); + + // Proper effects. + onlyReadsHandle(getOperands(), effects); + producesHandle(getResults(), effects); +} + +template +static bool implementSameInterface(Type t1, Type t2) { + return ((isa(t1) && isa(t2)) || ...
|| false); +} + +LogicalResult +transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Access through indirection and do additional checking because this may be + // running before the main op verifier. + auto targetAttr = getOperation()->getAttrOfType("target"); + if (!targetAttr) + return emitOpError() << "expects a 'target' symbol reference attribute"; + + auto target = symbolTable.lookupNearestSymbolFrom( + *this, targetAttr); + if (!target) + return emitOpError() << "does not reference a named transform sequence"; + + FunctionType fnType = target.getFunctionType(); + if (fnType.getNumInputs() != getNumOperands()) + return emitError("incorrect number of operands for callee"); + + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { + if (getOperand(i).getType() != fnType.getInput(i)) { + return emitOpError("operand type mismatch: expected operand type ") + << fnType.getInput(i) << ", but provided " + << getOperand(i).getType() << " for operand number " << i; + } + } + + if (fnType.getNumResults() != getNumResults()) + return emitError("incorrect number of results for callee"); + + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { + Type resultType = getResult(i).getType(); + Type funcType = fnType.getResult(i); + if (!implementSameInterface(resultType, + funcType)) { + return emitOpError() << "type of result #" << i + << " must implement the same transform dialect " + "interface as the corresponding callee result"; + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // MergeHandlesOp //===----------------------------------------------------------------------===// @@ -567,6 +747,105 @@ return getHandles().front(); } +//===----------------------------------------------------------------------===// +// NamedSequenceOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure
+transform::NamedSequenceOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + // Nothing to do here; the body is applied by `transform.include`. + return DiagnosedSilenceableFailure::success(); +} + +void transform::NamedSequenceOp::getEffects( + SmallVectorImpl &effects) {} + +ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser, + OperationState &result) { + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), + [](Builder &builder, ArrayRef inputs, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(inputs, results); }, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void transform::NamedSequenceOp::print(OpAsmPrinter &printer) { + function_interface_impl::printFunctionOp( + printer, cast(getOperation()), /*isVariadic=*/false, + getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(), + getResAttrsAttrName()); +} + +/// Verification of a NamedSequenceOp. This does not report the error +/// immediately, so it can be used to check for op's well-formedness before the +/// verifier runs, e.g., during trait verification.
+static DiagnosedSilenceableFailure +verifyNamedSequenceOp(transform::NamedSequenceOp op) { + if (op.isExternal()) + return emitSilenceableFailure(op) << "cannot be empty"; + + if (Operation *parent = op->getParentWithTrait()) { + if (!parent->getAttr( + transform::TransformDialect::kWithNamedSequenceAttrName)) { + DiagnosedSilenceableFailure diag = + emitSilenceableFailure(op) + << "expects the parent symbol table to have the '" + << transform::TransformDialect::kWithNamedSequenceAttrName + << "' attribute"; + diag.attachNote(parent->getLoc()) << "symbol table operation"; + return diag; + } + } + + if (auto parent = op->getParentOfType()) { + DiagnosedSilenceableFailure diag = + emitSilenceableFailure(op) + << "cannot be defined inside another transform op"; + diag.attachNote(parent.getLoc()) << "ancestor transform op"; + return diag; + } + + if (op.getBody().front().empty()) + return emitSilenceableFailure(op) << "expected a non-empty body block"; + + Operation *terminator = &op.getBody().front().back(); + if (!isa(terminator)) { + DiagnosedSilenceableFailure diag = emitSilenceableFailure(op) + << "expected '" + << transform::YieldOp::getOperationName() + << "' as terminator"; + diag.attachNote(terminator->getLoc()) << "terminator"; + return diag; + } + + if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) { + return emitSilenceableFailure(terminator) + << "expected terminator to have as many operands as the parent op " + "has results"; + } + for (auto [i, operandType, resultType] : + llvm::zip_equal(llvm::seq(0, terminator->getNumOperands()), + terminator->getOperands().getType(), + op.getFunctionType().getResults())) { + if (operandType == resultType) + continue; + return emitSilenceableFailure(terminator) + << "the type of the terminator operand #" << i + << " must match the type of the corresponding parent op result (" + << operandType << " vs " << resultType << ")"; + } + + return DiagnosedSilenceableFailure::success(); +} +
+LogicalResult transform::NamedSequenceOp::verify() { + // Actual verification happens in a separate function for reusability. + return verifyNamedSequenceOp(*this).checkAndReport(); +} + //===----------------------------------------------------------------------===// // SplitHandlesOp //===----------------------------------------------------------------------===// @@ -692,27 +971,8 @@ if (failed(mapBlockArguments(state))) return DiagnosedSilenceableFailure::definiteFailure(); - // Apply the sequenced ops one by one. - for (Operation &transform : getBodyBlock()->without_terminator()) { - DiagnosedSilenceableFailure result = - state.applyTransform(cast(transform)); - if (result.isDefiniteFailure()) - return result; - - if (result.isSilenceableFailure()) { - if (getFailurePropagationMode() == FailurePropagationMode::Propagate) { - // Propagate empty results in case of early exit. - forwardEmptyOperands(getBodyBlock(), state, results); - return result; - } - (void)result.silence(); - } - } - - // Forward the operation mapping for values yielded from the sequence to the - // values produced by the sequence op. - forwardTerminatorOperands(getBodyBlock(), state, results); - return DiagnosedSilenceableFailure::success(); + return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state, + results); } static ParseResult parseSequenceOpOperands( @@ -871,22 +1131,6 @@ return success(); } -/// Appends to `effects` the memory effect instances on `target` with the same -/// resource and effect as the ones the operation `iface` having on `source`.
-static void -remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target, - SmallVectorImpl &effects) { - SmallVector nestedEffects; - iface.getEffectsOnValue(source, nestedEffects); - for (const auto &effect : nestedEffects) - effects.emplace_back(effect.getEffect(), target, effect.getResource()); -} - -namespace { -template -using has_get_extra_bindings = decltype(std::declval().getExtraBindings()); -} // namespace - /// Populate `effects` with transform dialect memory effects for the potential /// top-level operation. Such operations have recursive effects from nested /// operations. When they have an operand, we can additionally remap effects on @@ -911,26 +1155,8 @@ // Carry over all effects on arguments of the entry block as those on the // operands, this is the same value just remapped. - for (Operation &op : *operation.getBodyBlock()) { - auto iface = dyn_cast(&op); - if (!iface) - continue; - - remapEffects(iface, operation.getBodyBlock()->getArgument(0), - operation.getRoot(), effects); - if constexpr (llvm::is_detected::value) { - for (auto [source, target] : - llvm::zip(operation.getBodyBlock()->getArguments().drop_front(), - operation.getExtraBindings())) { - remapEffects(iface, source, target, effects); - } - } - - SmallVector nestedEffects; - iface.getEffectsOnResource(transform::PayloadIRResource::get(), - nestedEffects); - llvm::append_range(effects, nestedEffects); - } + remapArgumentEffects(*operation.getBodyBlock(), operation->getOperands(), + effects); } void transform::SequenceOp::getEffects( diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -83,6 +83,9 @@ ::mlir::transform::TransformOpInterface topLevelTransform = nullptr; WalkResult walkResult = root->walk( 
[&](::mlir::transform::TransformOpInterface transformOp) { + if (!transformOp + ->hasTrait()) + return WalkResult::skip(); if (!topLevelTransform) { topLevelTransform = transformOp; return WalkResult::skip(); diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -284,3 +284,184 @@ // expected-note @below {{no 'allocate' effect specified for result #0}} transform.test_required_memory_effects %arg0 {has_operand_effect, modifies_payload} : (!transform.any_op) -> !transform.any_op } + +// ----- + +// expected-error @below {{attribute can only be attached to operations with symbol tables}} +"test.unknown_container"() { transform.with_named_sequence } : () -> () + +// ----- + +module attributes { transform.with_named_sequence } { + // expected-error @below {{failed to verify constraint: region with 1 blocks}} + "transform.named_sequence"() ({}) { sym_name = "external_named_sequence", function_type = () -> () } : () -> () + + transform.sequence failures(propagate) { + ^bb0(%arg0: !transform.any_op): + transform.include @external_named_sequence failures(propagate) () : () -> () + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + // expected-error @below {{recursion not allowed in named sequences}} + transform.named_sequence @self_recursion() -> () { + transform.include @self_recursion failures(suppress) () : () -> () + } +} + +// ----- + +module @mutual_recursion attributes { transform.with_named_sequence } { + // expected-note @below {{operation on recursion stack}} + transform.named_sequence @foo(%arg0: !transform.any_op) -> () { + transform.include @bar failures(suppress) (%arg0) : (!transform.any_op) -> () + transform.yield + } + + // expected-error @below {{recursion not allowed in named sequences}} + transform.named_sequence @bar(%arg0: !transform.any_op) -> () { + transform.include @foo 
failures(propagate) (%arg0) : (!transform.any_op) -> () + transform.yield + } +} + +// ----- + +// expected-error @below {{unknown attribute: "transform.unknown_container"}} +module @unknown_attribute attributes { transform.unknown_container } {} + +// ----- + +module { + transform.sequence failures(suppress) { + ^bb0(%arg0: !transform.any_op): + // expected-error @below {{op does not reference a named transform sequence}} + transform.include @non_existent failures(propagate) () : () -> () + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.sequence failures(suppress) { + ^bb0(%arg0: !transform.any_op): + // expected-error @below {{requires attribute 'target'}} + "transform.include"() {failure_propagation_mode = 0} : () -> () + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @foo(%arg0: !transform.any_op) -> () { + transform.yield + } + + transform.sequence failures(suppress) { + ^bb0(%arg1: !transform.any_op): + // expected-error @below {{incorrect number of operands for callee}} + transform.include @foo failures(suppress) () : () -> () + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @foo(%arg0: !transform.any_op) -> () { + transform.yield + } + + transform.sequence failures(suppress) { + ^bb0(%arg1: !transform.op<"builtin.module">): + // expected-error @below {{operand type mismatch: expected operand type '!transform.any_op', but provided '!transform.op<"builtin.module">' for operand number 0}} + transform.include @foo failures(suppress) (%arg1) : (!transform.op<"builtin.module">) -> () + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @foo(%arg0: !transform.any_op) -> (!transform.any_op) { + transform.yield %arg0 : !transform.any_op + } + + transform.sequence failures(suppress) { + ^bb0(%arg1: !transform.any_op): + // expected-error @below {{incorrect number of results 
for callee}} + transform.include @foo failures(suppress) (%arg1) : (!transform.any_op) -> () + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @foo(%arg0: !transform.any_op) -> (!transform.any_op) { + transform.yield %arg0 : !transform.any_op + } + + transform.sequence failures(suppress) { + ^bb0(%arg1: !transform.any_op): + // expected-error @below {{type of result #0 must implement the same transform dialect interface as the corresponding callee result}} + transform.include @foo failures(suppress) (%arg1) : (!transform.any_op) -> (!transform.any_value) + } +} + +// ----- + +// expected-note @below {{symbol table operation}} +module { + // expected-error @below {{expects the parent symbol table to have the 'transform.with_named_sequence' attribute}} + transform.named_sequence @parent_has_no_attributes() { + transform.yield + } +} + +// ----- + +module attributes { transform.with_named_sequence} { + // expected-note @below {{ancestor transform op}} + transform.sequence failures(suppress) { + ^bb0(%arg0: !transform.any_op): + // expected-error @below {{cannot be defined inside another transform op}} + transform.named_sequence @nested() { + transform.yield + } + } +} + +// ----- + +module attributes { transform.with_named_sequence} { + func.func private @foo() + + // expected-error @below {{expected 'transform.yield' as terminator}} + transform.named_sequence @nested() { + // expected-note @below {{terminator}} + func.call @foo() : () -> () + } +} + + +// ----- + +module attributes { transform.with_named_sequence} { + func.func private @foo() + + transform.named_sequence @nested(%arg0: !transform.any_op) { + // expected-error @below {{expected terminator to have as many operands as the parent op has results}} + transform.yield %arg0 : !transform.any_op + } +} + +// ----- + +module attributes { transform.with_named_sequence} { + func.func private @foo() + + transform.named_sequence @nested(%arg0: !transform.any_op) 
-> !transform.op<"builtin.module"> { + // expected-error @below {{the type of the terminator operand #0 must match the type of the corresponding parent op result}} + transform.yield %arg0 : !transform.any_op + } +} diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1255,3 +1255,82 @@ %op = transform.get_defining_op %bbarg : (!transform.any_value) -> !transform.any_op transform.test_print_remark_at_operand %op, "matched" : !transform.any_op } + +// ----- + +module @named_inclusion attributes { transform.with_named_sequence } { + + transform.named_sequence @foo(%arg0: !transform.any_op) -> () { + // expected-remark @below {{applying transformation "a"}} + transform.test_transform_op "a" + transform.yield + } + + transform.sequence failures(propagate) { + ^bb0(%arg0: !transform.any_op): + include @foo failures(propagate) (%arg0) : (!transform.any_op) -> () + } +} + +// ----- + +module @named_inclusion_in_named attributes { transform.with_named_sequence } { + + transform.named_sequence @foo(%arg0: !transform.any_op) -> () { + // expected-remark @below {{applying transformation "a"}} + transform.test_transform_op "a" + transform.yield + } + + transform.named_sequence @bar(%arg0: !transform.any_op) -> () { + // expected-remark @below {{applying transformation "b"}} + transform.test_transform_op "b" + transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> () + transform.yield + } + + transform.sequence failures(propagate) { + ^bb0(%arg0: !transform.any_op): + transform.include @bar failures(suppress) (%arg0) : (!transform.any_op) -> () + } +} + +// ----- + +// expected-remark @below {{operation}} +module @named_operands attributes { transform.with_named_sequence } { + + transform.named_sequence @foo(%arg0: !transform.any_op, %arg1: !transform.any_value) -> () { + 
transform.test_print_remark_at_operand %arg0, "operation" : !transform.any_op + transform.test_print_remark_at_operand_value %arg1, "value" : !transform.any_value + transform.yield + } + + transform.sequence failures(propagate) { + // expected-remark @below {{value}} + // expected-note @below {{value handle points to a block argument #0 in block #0 in region #0}} + ^bb0(%arg0: !transform.any_op): + %0 = transform.test_produce_value_handle_to_self_operand %arg0 : (!transform.any_op) -> !transform.any_value + include @foo failures(propagate) (%arg0, %0) : (!transform.any_op, !transform.any_value) -> () + } +} + +// ----- + +// expected-remark @below {{operation}} +module @named_return attributes { transform.with_named_sequence } { + + // expected-remark @below {{value}} + // expected-note @below {{value handle points to a block argument #0 in block #0 in region #0}} + transform.named_sequence @foo(%arg0: !transform.any_op) -> (!transform.any_op, !transform.any_value) { + %0 = transform.test_produce_value_handle_to_self_operand %arg0 : (!transform.any_op) -> !transform.any_value + transform.yield %arg0, %0 : !transform.any_op, !transform.any_value + } + + transform.sequence failures(propagate) { + ^bb0(%arg0: !transform.any_op): + %0:2 = include @foo failures(propagate) (%arg0) : (!transform.any_op) -> (!transform.any_op, !transform.any_value) + transform.test_print_remark_at_operand %0#0, "operation" : !transform.any_op + transform.test_print_remark_at_operand_value %0#1, "value" : !transform.any_value + } +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -9316,7 +9316,10 @@ ], tblgen = ":mlir-tblgen", td_file = "include/mlir/Dialect/Transform/IR/TransformOps.td", - deps = [":TransformDialectTdFiles"], + deps = [ + ":CallInterfacesTdFiles", + ":TransformDialectTdFiles" + ], ) 
gentbl_cc_library( @@ -9342,6 +9345,7 @@ srcs = glob(["lib/Dialect/Transform/IR/*.cpp"]), hdrs = glob(["include/mlir/Dialect/Transform/IR/*.h"]), deps = [ + ":CallInterfaces", ":ControlFlowInterfaces", ":IR", ":PDLDialect",