diff --git a/mlir/docs/Dialects/Transform.md b/mlir/docs/Dialects/Transform.md --- a/mlir/docs/Dialects/Transform.md +++ b/mlir/docs/Dialects/Transform.md @@ -109,13 +109,19 @@ programmatically triggered by calling: ```c++ -LogicalResult transform::applyTransforms(Operation *payloadRoot, - TransformOpInterface transform, - const TransformOptions &options); +LogicalResult transform::applyTransforms( + Operation *payloadRoot, + TransformOpInterface transform, + ArrayRef<ArrayRef<MappedValue>> extraMappings, + const TransformOptions &options); ``` that applies the transformations specified by the top-level `transform` to -payload IR contained in `payloadRoot`. +payload IR contained in `payloadRoot`. The payload root operation is +associated with the first argument of the entry block of the top-level transform +op. This block may have additional arguments, either handles or parameters, +which are associated, in order, with the values provided in `extraMappings`. The +call reports an error and fails if the wrong number of mappings is provided. ## Dialect Extension Mechanism diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -42,6 +42,9 @@ bool expensiveChecksEnabled = true; }; +using Param = Attribute; +using MappedValue = llvm::PointerUnion; + /// Entry point to the Transform dialect infrastructure. Applies the /// transformation specified by `transform` to payload IR contained in /// `payloadRoot`. The `transform` operation may contain other operations that @@ -50,6 +53,9 @@ /// This function internally keeps track of the transformation state. +/// Entry block arguments of `transform` after the first one are bound, in +/// order, to the payload operations or parameters listed in `extraMapping`.
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, + ArrayRef> extraMapping = {}, const TransformOptions &options = TransformOptions()); /// The state maintained across applications of various ops implementing the @@ -85,7 +89,7 @@ /// using `mapBlockArguments`. class TransformState { public: - using Param = Attribute; + using Param = transform::Param; private: /// Mapping between a Value in the transform IR and the corresponding set of @@ -109,15 +113,23 @@ ParamMapping params; }; - friend LogicalResult applyTransforms(Operation *payloadRoot, - TransformOpInterface transform, - const TransformOptions &options); + friend LogicalResult applyTransforms(Operation *, TransformOpInterface, + ArrayRef>, + const TransformOptions &); public: /// Returns the op at which the transformation state is rooted. This is /// typically helpful for transformations that apply globally. Operation *getTopLevel() const; + /// Returns the number of extra mappings for the top-level operation. + size_t getNumTopLevelMappings() const { return topLevelMappedValues.size(); } + + /// Returns the `position`-th extra mapping for the top-level operation. + ArrayRef getTopLevelMapping(size_t position) const { + return topLevelMappedValues[position]; + } + /// Returns the list of ops that the given transform IR value corresponds to. /// This is helpful for transformations that apply to a particular handle. ArrayRef getPayloadOps(Value value) const; @@ -150,6 +162,10 @@ #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS return setPayloadOps(argument, operations); } + /// Maps `argument`, which must be of handle or parameter type, to the given + /// payload ops or parameters; emits an error if a value has the wrong kind. + LogicalResult mapBlockArgument(BlockArgument argument, + ArrayRef values); // Forward declarations to support limited visibility. class RegionScope; @@ -302,6 +316,7 @@ /// which may or may not contain the region with transform ops. Additional /// options can be provided through the trailing configuration object.
TransformState(Region *region, Operation *payloadRoot, + ArrayRef> extraMappings = {}, const TransformOptions &options = TransformOptions()); /// Returns the mappings frame for the reigon in which the value is defined. @@ -403,6 +418,16 @@ /// The top-level operation that contains all payload IR, typically a module. Operation *topLevel; + /// Storage for extra mapped values (payload operations or parameters) to be + /// associated with additional entry block arguments of the top-level + /// transform operation. Each entry in `topLevelMappedValues` is a reference + /// to a contiguous block in `topLevelMappedValueStorage`, so the storage must + /// not be reallocated once these references have been created. + // TODO: turn this into a proper named data structure, there are several more + // below. + SmallVector> topLevelMappedValues; + SmallVector topLevelMappedValueStorage; + /// Additional options controlling the transformation state behavior. TransformOptions options; diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -26,6 +26,11 @@ /// A builder function that populates the body of a SequenceOp.
using SequenceBodyBuilderFn = ::llvm::function_ref; +/// A builder function that populates the body of a SequenceOp with extra entry +/// block arguments; it receives the root argument followed by the extra ones. +using SequenceBodyBuilderArgsFn = + ::llvm::function_ref; } // namespace transform } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -384,7 +384,8 @@ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, OpAsmOpInterface, PossibleTopLevelTransformOpTrait, - SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> { + SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">, + AttrSizedOperandSegments]> { let summary = "Contains a sequence of other transform ops to apply"; let description = [{ The transformations indicated by the sequence are applied in order of their @@ -417,12 +418,14 @@ }]; let arguments = (ins FailurePropagationMode:$failure_propagation_mode, - Optional:$root); + Optional:$root, + Variadic:$extra_bindings); let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$body); let assemblyFormat = - "($root^ `:` type($root))? (`->` type($results)^)? `failures` `(` " + "custom($root, type($root), $extra_bindings, type($extra_bindings))" + " (`->` type($results)^)? `failures` `(` " "$failure_propagation_mode `)` attr-dict-with-keyword regions"; let builders = [ @@ -432,11 +435,25 @@ "::mlir::transform::FailurePropagationMode":$failure_propagation_mode, "::mlir::Value":$root, "SequenceBodyBuilderFn":$bodyBuilder)>, - // Build a sequence without a root but a certain bbArg type. + // Build a sequence with a root and additional arguments. + OpBuilder<(ins + "::mlir::TypeRange":$resultTypes, + "::mlir::transform::FailurePropagationMode":$failure_propagation_mode, + "::mlir::Value":$root, "::mlir::ValueRange":$extraBindings, + "SequenceBodyBuilderArgsFn":$bodyBuilder)>, + + // Build a top-level sequence (no root).
+ OpBuilder<(ins + "::mlir::TypeRange":$resultTypes, + "::mlir::transform::FailurePropagationMode":$failure_propagation_mode, + "::mlir::Type":$bbArgType, "SequenceBodyBuilderFn":$bodyBuilder)>, + + // Build a top-level sequence (no root) with extra arguments. OpBuilder<(ins "::mlir::TypeRange":$resultTypes, "::mlir::transform::FailurePropagationMode":$failure_propagation_mode, - "::mlir::Type":$bbArgType, "SequenceBodyBuilderFn":$bodyBuilder)> + "::mlir::Type":$bbArgType, "::mlir::TypeRange":$extraBindingTypes, + "SequenceBodyBuilderArgsFn":$bodyBuilder)> ]; let extraClassDeclaration = [{ diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -27,10 +27,27 @@ constexpr const Value transform::TransformState::kTopLevelValue; -transform::TransformState::TransformState(Region *region, - Operation *payloadRoot, - const TransformOptions &options) +transform::TransformState::TransformState( + Region *region, Operation *payloadRoot, + ArrayRef<ArrayRef<MappedValue>> extraMappings, + const TransformOptions &options) : topLevel(payloadRoot), options(options) { + // `topLevelMappedValues` holds ArrayRefs into `topLevelMappedValueStorage`. + // Reserve the full storage up front so that appending below never + // reallocates it, which would leave the earlier ArrayRefs dangling. + size_t numMappedValues = 0; + for (ArrayRef<MappedValue> mapping : extraMappings) + numMappedValues += mapping.size(); + topLevelMappedValueStorage.reserve(numMappedValues); + topLevelMappedValues.reserve(extraMappings.size()); + for (ArrayRef<MappedValue> mapping : extraMappings) { + size_t start = topLevelMappedValueStorage.size(); + llvm::append_range(topLevelMappedValueStorage, mapping); + topLevelMappedValues.push_back( + ArrayRef<MappedValue>(topLevelMappedValueStorage) + .slice(start, mapping.size())); + } + auto result = mappings.try_emplace(region); assert(result.second && "the region scope is already present"); (void)result; @@ -72,6 +82,38 @@ return success(found); } +LogicalResult +transform::TransformState::mapBlockArgument(BlockArgument argument, + ArrayRef values) { + if (argument.getType().isa()) { + SmallVector operations; + operations.reserve(values.size()); + for (MappedValue value : values) { + if (auto *op = value.dyn_cast()) {
+ operations.push_back(op); + continue; + } + return emitError(argument.getLoc()) + << "wrong kind of value provided for top-level operation handle"; + } + return setPayloadOps(argument, operations); + } + + assert(argument.getType().isa() && + "unsupported kind of block argument"); + SmallVector parameters; + parameters.reserve(values.size()); + for (MappedValue value : values) { + if (auto attr = value.dyn_cast()) { + parameters.push_back(attr); + continue; + } + return emitError(argument.getLoc()) + << "wrong kind of value provided for top-level parameter"; + } + return setParams(argument, parameters); +} + LogicalResult transform::TransformState::setPayloadOps(Value value, ArrayRef targets) { @@ -522,12 +564,43 @@ LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( TransformState &state, Operation *op, Region &region) { SmallVector targets; - if (op->getNumOperands() != 0) + SmallVector> extraMappings; + if (op->getNumOperands() != 0) { llvm::append_range(targets, state.getPayloadOps(op->getOperand(0))); - else + for (Value operand : op->getOperands().drop_front()) { + SmallVector &mapped = extraMappings.emplace_back(); + if (operand.getType().isa()) { + llvm::append_range(mapped, state.getPayloadOps(operand)); + } else { + assert(operand.getType().isa() && + "unsupported kind of transform dialect value"); + llvm::append_range(mapped, state.getParams(operand)); + } + } + } else { + if (state.getNumTopLevelMappings() != + region.front().getNumArguments() - 1) { + return emitError(op->getLoc()) + << "operation expects " << region.front().getNumArguments() - 1 + << " extra value bindings, but " << state.getNumTopLevelMappings() + << " were provided to the interpreter"; + } + targets.push_back(state.getTopLevel()); + for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i) + extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i))); + } + + if (failed(state.mapBlockArguments(region.front().getArgument(0), targets)))
+ return failure(); + + for (BlockArgument argument : region.front().getArguments().drop_front()) { + if (failed(state.mapBlockArgument( + argument, extraMappings[argument.getArgNumber() - 1]))) + return failure(); + } - return state.mapBlockArguments(region.front().getArgument(0), targets); + return success(); } LogicalResult @@ -547,19 +620,44 @@ return op->emitOpError() << "expects a single-block region"; Block *body = &bodyRegion->front(); - if (body->getNumArguments() != 1 || - !body->getArgumentTypes()[0].isa()) { + if (body->getNumArguments() == 0) { + return op->emitOpError() + << "expects the entry block to have at least one argument"; + } + if (!body->getArgument(0).getType().isa()) { return op->emitOpError() - << "expects the entry block to have one argument " - "of type implementing TransformHandleTypeInterface"; + << "expects the first entry block argument to be of type " + "implementing TransformHandleTypeInterface"; + } + // Operands, if any, are forwarded positionally to the entry block arguments, + // so every operand must have the same type as its block argument. + for (auto [operand, arg] : + llvm::zip(op->getOperands(), body->getArguments())) { + if (arg.getType() != operand.getType()) { + return op->emitOpError() + << "expects the type of the block argument to match " + "the type of the operand"; + } + } + for (BlockArgument arg : body->getArguments().drop_front()) { + if (arg.getType() + .isa()) + continue; + + InFlightDiagnostic diag = + op->emitOpError() + << "expects trailing entry block arguments to be of type implementing " + "TransformHandleTypeInterface or TransformParamTypeInterface"; + diag.attachNote() << "argument #" << arg.getArgNumber() << " does not"; + return diag; } if (auto *parent = op->getParentWithTrait()) { - if (op->getNumOperands() == 0) { + if (op->getNumOperands() != body->getNumArguments()) { InFlightDiagnostic diag = op->emitOpError() - << "expects the root operation to be provided for a nested op"; + << "expects operands to be provided for a nested op"; diag.attachNote(parent->getLoc()) << "nested in another possible top-level op"; return diag; @@ -717,9
+813,11 @@ // Entry point. //===----------------------------------------------------------------------===// -LogicalResult transform::applyTransforms(Operation *payloadRoot, - TransformOpInterface transform, - const TransformOptions &options) { +LogicalResult +transform::applyTransforms(Operation *payloadRoot, + TransformOpInterface transform, + ArrayRef> extraMapping, + const TransformOptions &options) { #ifndef NDEBUG if (!transform->hasTrait() || transform->getNumOperands() != 0) { @@ -730,7 +828,8 @@ } #endif // NDEBUG - TransformState state(transform->getParentRegion(), payloadRoot, options); + TransformState state(transform->getParentRegion(), payloadRoot, extraMapping, + options); return state.applyTransform(transform).checkAndReport(); } diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -26,6 +26,16 @@ using namespace mlir; +static ParseResult parseSequenceOpOperands( + OpAsmParser &parser, Optional &root, + Type &rootType, + SmallVectorImpl &extraBindings, + SmallVectorImpl &extraBindingTypes); +static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, + Value root, Type rootType, + ValueRange extraBindings, + TypeRange extraBindingTypes); + #define GET_OP_CLASSES #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" @@ -654,6 +664,81 @@ return DiagnosedSilenceableFailure::success(); } +/// Parses the optional operands of a SequenceOp: a root operand, optionally +/// followed by a comma-separated list of extra bindings, then `:` and the +/// operand types, optionally enclosed in (balanced) parentheses. +static ParseResult parseSequenceOpOperands( + OpAsmParser &parser, Optional &root, + Type &rootType, + SmallVectorImpl &extraBindings, + SmallVectorImpl &extraBindingTypes) { + OpAsmParser::UnresolvedOperand rootOperand; + OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand); + if (!hasRoot.has_value()) { + root = std::nullopt; + return success(); + } + if (failed(hasRoot.value())) + return failure(); + root = rootOperand; + + if (succeeded(parser.parseOptionalComma())) {
+ if (failed(parser.parseOperandList(extraBindings))) + return failure(); + } + if (failed(parser.parseColon())) + return failure(); + + // The parentheses around the type list are optional, but must be balanced. + bool hasLParen = succeeded(parser.parseOptionalLParen()); + + if (failed(parser.parseType(rootType))) { + return failure(); + } + + if (!extraBindings.empty()) { + if (parser.parseComma() || parser.parseTypeList(extraBindingTypes)) + return failure(); + } + + if (extraBindingTypes.size() != extraBindings.size()) { + return parser.emitError(parser.getNameLoc(), + "expected types to be provided for all operands"); + } + + if (hasLParen && failed(parser.parseRParen())) + return failure(); + return success(); +} + +/// Prints the operands of a SequenceOp in the form accepted by +/// parseSequenceOpOperands, parenthesizing the types only with extra bindings. +static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, + Value root, Type rootType, + ValueRange extraBindings, + TypeRange extraBindingTypes) { + if (!root) + return; + + printer << root; + bool hasExtras = !extraBindings.empty(); + if (hasExtras) { + printer << ", "; + printer.printOperands(extraBindings); + } + + printer << " : "; + if (hasExtras) + printer << "("; + + printer << rootType; + if (hasExtras) { + printer << ", "; + llvm::interleaveComma(extraBindingTypes, printer.getStream()); + printer << ")"; + } +} + /// Returns `true` if the given op operand may be consuming the handle value in /// the Transform IR. That is, if it may have a Free effect on it.
static bool isValueUsePotentialConsumer(OpOperand &use) { @@ -691,22 +771,22 @@ } LogicalResult transform::SequenceOp::verify() { - assert(getBodyBlock()->getNumArguments() == 1 && - "the number of arguments must have been verified to be 1 by " + assert(getBodyBlock()->getNumArguments() >= 1 && + "the number of arguments must have been verified to be at least 1 by " "PossibleTopLevelTransformOpTrait"); - BlockArgument arg = getBodyBlock()->getArgument(0); - if (getRoot()) { - if (arg.getType() != getRoot().getType()) { - return emitOpError() << "expects the type of the block argument to match " - "the type of the operand"; - } + if (!getRoot() && !getExtraBindings().empty()) { + return emitOpError() + << "does not expect extra operands when used as top-level"; } - // Check if the block argument has more than one consuming use. - if (failed(checkDoubleConsume( - arg, [this]() { return (emitOpError() << "block argument #0"); }))) { - return failure(); + // Check if a block argument has more than one consuming use. + for (BlockArgument arg : getBodyBlock()->getArguments()) { + if (failed(checkDoubleConsume(arg, [this, arg]() { + return (emitOpError() << "block argument #" << arg.getArgNumber()); + }))) { + return failure(); + } } // Check properties of the nested operations they cannot check themselves. @@ -740,26 +820,25 @@ return success(); } +/// Appends to `effects` the memory effect instances on `target` with the same +/// resource and effect as the ones the operation `iface` has on `source`.
+static void +remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target, + SmallVectorImpl &effects) { + SmallVector nestedEffects; + iface.getEffectsOnValue(source, nestedEffects); + for (const auto &effect : nestedEffects) + effects.emplace_back(effect.getEffect(), target, effect.getResource()); +} + void transform::SequenceOp::getEffects( SmallVectorImpl &effects) { - auto *mappingResource = TransformMappingResource::get(); - effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource); - - for (Value result : getResults()) { - effects.emplace_back(MemoryEffects::Allocate::get(), result, - mappingResource); - effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource); - } + onlyReadsHandle(getRoot(), effects); + onlyReadsHandle(getExtraBindings(), effects); + producesHandle(getResults(), effects); if (!getRoot()) { for (Operation &op : *getBodyBlock()) { - auto iface = dyn_cast(&op); - if (!iface) { - // TODO: fill all possible effects; or require ops to actually implement - // the memory effect interface always - assert(false); - } - + auto iface = cast(&op); - SmallVector nestedEffects; iface.getEffects(effects); } @@ -769,24 +849,20 @@ // Carry over all effects on the argument of the entry block as those on the // operand, this is the same value just remapped.
for (Operation &op : *getBodyBlock()) { - auto iface = dyn_cast(&op); - if (!iface) { - // TODO: fill all possible effects; or require ops to actually implement - // the memory effect interface always - assert(false); - } + auto iface = cast(&op); - SmallVector nestedEffects; - iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects); - for (const auto &effect : nestedEffects) - effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource()); + remapEffects(iface, getBodyBlock()->getArgument(0), getRoot(), effects); + for (auto [source, target] : llvm::zip( + getBodyBlock()->getArguments().drop_front(), getExtraBindings())) { + remapEffects(iface, source, target, effects); + } } } OperandRange transform::SequenceOp::getSuccessorEntryOperands( std::optional index) { assert(index && *index == 0 && "unexpected region index"); - if (getOperation()->getNumOperands() == 1) + if (getOperation()->getNumOperands() > 0) return getOperation()->getOperands(); return OperandRange(getOperation()->operand_end(), getOperation()->operand_end()); @@ -813,21 +889,57 @@ bounds.emplace_back(1, 1); } +/// Creates the entry block of the last region in `state` with a `bbArgType` +/// argument followed by one argument per `extraBindingTypes` entry, then +/// populates it by calling `bodyBuilder`. +template <typename FnTy> +static void buildSequenceBody(OpBuilder &builder, OperationState &state, + Type bbArgType, TypeRange extraBindingTypes, + FnTy bodyBuilder) { + SmallVector<Type> types; + types.reserve(1 + extraBindingTypes.size()); + types.push_back(bbArgType); + llvm::append_range(types, extraBindingTypes); + + OpBuilder::InsertionGuard guard(builder); + Region *region = state.regions.back().get(); + // The block takes the root argument followed by the extra bindings, and + // every block argument needs its own location. + Block *bodyBlock = + builder.createBlock(region, region->begin(), types, + SmallVector<Location>(types.size(), state.location)); + + // Populate body.
+ builder.setInsertionPointToStart(bodyBlock); + if constexpr (llvm::function_traits<FnTy>::num_args == 3) { + bodyBuilder(builder, state.location, bodyBlock->getArgument(0)); + } else { + bodyBuilder(builder, state.location, bodyBlock->getArgument(0), + bodyBlock->getArguments().drop_front()); + } +} + void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, TypeRange resultTypes, FailurePropagationMode failurePropagationMode, Value root, SequenceBodyBuilderFn bodyBuilder) { - build(builder, state, resultTypes, failurePropagationMode, root); - Region *region = state.regions.back().get(); + build(builder, state, resultTypes, failurePropagationMode, root, + /*extraBindings=*/ValueRange()); Type bbArgType = root.getType(); - OpBuilder::InsertionGuard guard(builder); - Block *bodyBlock = builder.createBlock( - region, region->begin(), TypeRange{bbArgType}, {state.location}); + buildSequenceBody(builder, state, bbArgType, + /*extraBindingTypes=*/TypeRange(), bodyBuilder); +} - // Populate body.
builder.setInsertionPointToStart(bodyBlock); - bodyBuilder(builder, state.location, bodyBlock->getArgument(0)); +void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, + TypeRange resultTypes, + FailurePropagationMode failurePropagationMode, + Value root, ValueRange extraBindings, + SequenceBodyBuilderArgsFn bodyBuilder) { + build(builder, state, resultTypes, failurePropagationMode, root, + extraBindings); + buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(), + bodyBuilder); } void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, @@ -835,15 +941,20 @@ FailurePropagationMode failurePropagationMode, Type bbArgType, SequenceBodyBuilderFn bodyBuilder) { - build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value()); - Region *region = state.regions.back().get(); - OpBuilder::InsertionGuard guard(builder); - Block *bodyBlock = builder.createBlock( - region, region->begin(), TypeRange{bbArgType}, {state.location}); + build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(), + /*extraBindings=*/ValueRange()); + buildSequenceBody(builder, state, bbArgType, + /*extraBindingTypes=*/TypeRange(), bodyBuilder); +} - // Populate body.
- builder.setInsertionPointToStart(bodyBlock); - bodyBuilder(builder, state.location, bodyBlock->getArgument(0)); +void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, + TypeRange resultTypes, + FailurePropagationMode failurePropagationMode, + Type bbArgType, TypeRange extraBindingTypes, + SequenceBodyBuilderArgsFn bodyBuilder) { + build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(), + /*extraBindings=*/ValueRange()); + buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder); } //===----------------------------------------------------------------------===// diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -89,7 +89,9 @@ class SequenceOp: def __init__(self, failure_propagation_mode, results: Sequence[Type], - target: Union[Operation, Value, Type]): + target: Union[Operation, Value, Type], + extra_bindings: Optional[Union[Sequence[Value], Sequence[Type], + Operation, OpView]] = None): root = _get_op_result_or_value(target) if isinstance( target, (Operation, Value)) else None root_type = root.type if not isinstance(target, Type) else target @@ -98,10 +100,26 @@ IntegerType.get_signless(32), failure_propagation_mode._as_int()) else: failure_propagation_mode = failure_propagation_mode + + if extra_bindings is None: + extra_bindings = [] + if isinstance(extra_bindings, (Operation, OpView)): + extra_bindings = _get_op_results_or_values(extra_bindings) + + extra_binding_types = [] + if len(extra_bindings) != 0: + if isinstance(extra_bindings[0], Type): + # Copy into a list so that it can be concatenated with `[root_type]`. + extra_binding_types = list(extra_bindings) + extra_bindings = [] + else: + extra_binding_types = [v.type for v in extra_bindings] + super().__init__(results_=results, failure_propagation_mode=failure_propagation_mode_attr, - root=root) - self.regions[0].blocks.append(root_type) + root=root,
+ extra_bindings=extra_bindings) + self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types)) @property def body(self) -> Block: @@ -111,6 +128,10 @@ def bodyTarget(self) -> Value: return self.body.arguments[0] + @property + def bodyExtraArgs(self) -> BlockArgumentList: + return self.body.arguments[1:] + class WithPDLPatternsOp: diff --git a/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir b/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir @@ -0,0 +1,71 @@ +// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-ops=func.func bind-second-extra-to-ops=func.return})' \ +// RUN: --split-input-file --verify-diagnostics + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op): + transform.test_print_remark_at_operand %arg1, "first extra" : !transform.any_op + transform.test_print_remark_at_operand %arg2, "second extra" : !transform.any_op +} + +// expected-remark @below {{first extra}} +func.func @foo() { + // expected-remark @below {{second extra}} + return +} + +// expected-remark @below {{first extra}} +func.func @bar(%arg0: i1) { + cf.cond_br %arg0, ^bb1, ^bb2 +^bb1: + // expected-remark @below {{second extra}} + return +^bb2: + // expected-remark @below {{second extra}} + return +} + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.param): + // expected-error @above {{wrong kind of value provided for top-level parameter}} +} + +func.func @foo() { + return +} + +// ----- + +// expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}} +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op): +} + +// ----- + +transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op): + transform.sequence %arg0, %arg1, %arg2 : !transform.any_op, !transform.any_op, !transform.any_op failures(propagate) { + ^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op): + transform.test_print_remark_at_operand %arg4, "first extra" : !transform.any_op + transform.test_print_remark_at_operand %arg5, "second extra" : !transform.any_op + } +} + +// expected-remark @below {{first extra}} +func.func @foo() { + // expected-remark @below {{second extra}} + return +} + +// expected-remark @below {{first extra}} +func.func @bar(%arg0: i1) { + cf.cond_br %arg0, ^bb1, ^bb2 +^bb1: + // expected-remark @below {{second extra}} + return +^bb2: + // expected-remark @below {{second extra}} + return +} diff --git a/mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir b/mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-params=1,2,3 bind-second-extra-to-params=42,45})' \ +// RUN: --split-input-file --verify-diagnostics + +transform.sequence failures(propagate) { +^bb0(%arg0: !pdl.operation, %arg1: !transform.param, %arg2: !transform.param): + // expected-remark @below {{1 : i64, 2 : i64, 3 : i64}} + transform.test_print_param %arg1 : !transform.param + // expected-remark @below {{42 : i64, 45 : i64}} + transform.test_print_param %arg2 : !transform.param +} + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !pdl.operation, %arg1: !transform.any_op, %arg2: !transform.param): + // expected-error @above {{wrong kind of value provided for top-level operation handle}} +} + +// ----- + +// expected-error @below {{operation expects 3 extra value bindings, but 2 were provided to the interpreter}} +transform.sequence
failures(propagate) { +^bb0(%arg0: !pdl.operation, %arg1: !transform.param, %arg2: !transform.param, %arg3: !transform.param): +} diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -1,15 +1,22 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics -// expected-error @below {{expects the entry block to have one argument of type implementing TransformHandleTypeInterface}} +// expected-error @below {{expects the entry block to have at least one argument}} transform.sequence failures(propagate) { } // ----- +// expected-error @below {{expects the first entry block argument to be of type implementing TransformHandleTypeInterface}} +transform.sequence failures(propagate) { +^bb0(%arg0: i64): +} + +// ----- + // expected-note @below {{nested in another possible top-level op}} transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): - // expected-error @below {{expects the root operation to be provided for a nested op}} + // expected-error @below {{expects operands to be provided for a nested op}} transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): } @@ -17,6 +24,14 @@ // ----- +// expected-error @below {{'transform.sequence' op expects trailing entry block arguments to be of type implementing TransformHandleTypeInterface or TransformParamTypeInterface}} +// expected-note @below {{argument #1 does not}} +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, %arg1: i64): +} + +// ----- + // expected-error @below {{expected children ops to implement TransformOpInterface}} transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): @@ -46,10 +61,29 @@ // ----- +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op): + // expected-error @below {{expected types to be provided for all
operands}} + transform.sequence %arg0, %arg1, %arg2 : (!transform.any_op, !transform.any_op) failures(propagate) { + ^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op): + } +} + +// ----- + +%0 = "test.generate_something"() : () -> !transform.any_op +// expected-error @below {{does not expect extra operands when used as top-level}} +"transform.sequence"(%0) ({ +^bb0(%arg0: !transform.any_op): + "transform.yield"() : () -> () +}) {failure_propagation_mode = 1 : i32, operand_segment_sizes = array} : (!transform.any_op) -> () + +// ----- + // expected-note @below {{nested in another possible top-level op}} transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - // expected-error @below {{expects the root operation to be provided for a nested op}} + // expected-error @below {{expects operands to be provided for a nested op}} transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): } @@ -190,7 +224,7 @@ // ----- -// expected-error @below {{expects the entry block to have one argument of type implementing TransformHandleTypeInterface}} +// expected-error @below {{expects the entry block to have at least one argument}} transform.alternatives { ^bb0: transform.yield diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir --- a/mlir/test/Dialect/Transform/ops.mlir +++ b/mlir/test/Dialect/Transform/ops.mlir @@ -50,6 +50,33 @@ } } +// CHECK: transform.sequence failures(propagate) +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op): + // CHECK: sequence %{{.*}}, %{{.*}}, %{{.*}} : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate) + transform.sequence %arg0, %arg1, %arg2 : !transform.any_op, !transform.any_op, !transform.any_op failures(propagate) { + ^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op): + } +} + +// CHECK: transform.sequence failures(propagate)
+transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op): + // CHECK: sequence %{{.*}}, %{{.*}}, %{{.*}} : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate) + transform.sequence %arg0, %arg1, %arg2 : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate) { + ^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op): + } +} + +// CHECK: transform.sequence failures(propagate) +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.op<"foo.bar">): + // CHECK: sequence %{{.*}}, %{{.*}}, %{{.*}} : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">) failures(propagate) + transform.sequence %arg0, %arg1, %arg2 : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">) failures(propagate) { + ^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.op<"foo.bar">): + } +} + // CHECK: transform.sequence // CHECK: foreach transform.sequence failures(propagate) { diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" @@ -39,12 +40,75 @@ return "apply transform dialect operations one by one"; } + ArrayRef + findOperationsByName(Operation *root, StringRef name, + SmallVectorImpl &storage) { + size_t start = storage.size(); + root->walk([&](Operation *op) { + if (op->getName().getStringRef() == name) { + storage.push_back(op); + } + }); + return ArrayRef(storage).drop_front(start);
+ } + + ArrayRef + createParameterMapping(MLIRContext &context, ArrayRef values, + SmallVectorImpl &storage) { + size_t start = storage.size(); + Builder b(&context); + llvm::append_range(storage, llvm::map_range(values, [&](int v) { + return transform::MappedValue(b.getI64IntegerAttr(v)); + })); + return ArrayRef(storage).drop_front(start); + } + void runOnOperation() override { + if (!bindFirstExtraToOps.empty() && !bindFirstExtraToParams.empty()) { + emitError(UnknownLoc::get(&getContext())) + << "cannot bind the first extra top-level argument to both " + "operations and parameters"; + return signalPassFailure(); + } + if (!bindSecondExtraToOps.empty() && !bindSecondExtraToParams.empty()) { + emitError(UnknownLoc::get(&getContext())) + << "cannot bind the second extra top-level argument to both " + "operations and parameters"; + return signalPassFailure(); + } + if ((!bindSecondExtraToOps.empty() || !bindSecondExtraToParams.empty()) && + bindFirstExtraToOps.empty() && bindFirstExtraToParams.empty()) { + emitError(UnknownLoc::get(&getContext())) + << "cannot bind the second extra top-level argument without binding " + "the first"; + return signalPassFailure(); + } + + // Keep separate storage per mapping: appending to a shared vector could + // reallocate it and leave the previously returned ArrayRef dangling. + SmallVector<transform::MappedValue> firstExtraStorage; + SmallVector<transform::MappedValue> secondExtraStorage; + SmallVector> extraMapping; + if (!bindFirstExtraToOps.empty()) { + extraMapping.push_back(findOperationsByName( + getOperation(), bindFirstExtraToOps, firstExtraStorage)); + } else if (!bindFirstExtraToParams.empty()) { + extraMapping.push_back(createParameterMapping( + getContext(), bindFirstExtraToParams, firstExtraStorage)); + } + if (!bindSecondExtraToOps.empty()) { + extraMapping.push_back(findOperationsByName( + getOperation(), bindSecondExtraToOps, secondExtraStorage)); + } else if (!bindSecondExtraToParams.empty()) { + extraMapping.push_back(createParameterMapping( + getContext(), bindSecondExtraToParams, secondExtraStorage)); + } + ModuleOp module = getOperation(); for (auto op : module.getBody()->getOps()) { if
(failed(transform::applyTransforms( - module, op, + module, op, extraMapping, transform::TransformOptions().enableExpensiveChecks( enableExpensiveChecks)))) return signalPassFailure(); @@ -55,6 +119,24 @@ *this, "enable-expensive-checks", llvm::cl::init(false), llvm::cl::desc("perform expensive checks to better report errors in the " "transform IR")}; + + Option bindFirstExtraToOps{ + *this, "bind-first-extra-to-ops", + llvm::cl::desc("bind the first extra argument of the top-level op to " + "payload operations of the given kind")}; + ListOption bindFirstExtraToParams{ + *this, "bind-first-extra-to-params", + llvm::cl::desc("bind the first extra argument of the top-level op to " + "the given integer parameters")}; + + Option bindSecondExtraToOps{ + *this, "bind-second-extra-to-ops", + llvm::cl::desc("bind the second extra argument of the top-level op to " + "payload operations of the given kind")}; + ListOption bindSecondExtraToParams{ + *this, "bind-second-extra-to-params", + llvm::cl::desc("bind the second extra argument of the top-level op to " + "the given integer parameters")}; }; struct TestTransformDialectEraseSchedulePass diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py --- a/mlir/test/python/dialects/transform.py +++ b/mlir/test/python/dialects/transform.py @@ -69,6 +69,38 @@ # CHECK: } +@run +def testSequenceOpWithExtras(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get(), + [transform.AnyOpType.get(), + transform.OperationType.get("foo.bar")]) + with InsertionPoint(sequence.body): + transform.YieldOp() + # CHECK-LABEL: TEST: testSequenceOpWithExtras + # CHECK: transform.sequence failures(propagate) + # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">): + + +@run +def testNestedSequenceOpWithExtras(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [],
transform.AnyOpType.get(), + [transform.AnyOpType.get(), + transform.OperationType.get("foo.bar")]) + with InsertionPoint(sequence.body): + nested = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, + [], sequence.bodyTarget, + sequence.bodyExtraArgs) + with InsertionPoint(nested.body): + transform.YieldOp() + transform.YieldOp() + # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras + # CHECK: transform.sequence failures(propagate) + # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">): + # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">) + + @run def testTransformPDLOps(): withPdl = transform.WithPDLPatternsOp(pdl.OperationType.get())