diff --git a/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt @@ -28,6 +28,10 @@ add_mlir_interface(TransformInterfaces) add_mlir_doc(TransformInterfaces TransformOpInterfaces Dialects/ -gen-op-interface-docs) +add_mlir_interface(MatchInterfaces) +add_dependencies(MLIRMatchInterfacesIncGen MLIRTransformInterfacesIncGen) +add_mlir_doc(TransformInterfaces MatchOpInterfaces Dialects/ -gen-op-interface-docs) + set(LLVM_TARGET_DEFINITIONS TransformInterfaces.td) mlir_tablegen(TransformTypeInterfaces.h.inc -gen-type-interface-decls) mlir_tablegen(TransformTypeInterfaces.cpp.inc -gen-type-interface-defs) diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h @@ -0,0 +1,73 @@ +//===- MatchInterfaces.h - Transform Dialect Interfaces ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H +#define MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H + +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace transform { +class MatchOpInterface; + +template +class SingleOpMatcherOpTrait + : public OpTrait::TraitBase { + template + using has_get_operand_handle = + decltype(std::declval().getOperandHandle()); + template + using has_match_operation = decltype(std::declval().matchOperation( + std::declval(), std::declval(), + std::declval())); + +public: + static LogicalResult verifyTrait(Operation *op) { + static_assert(llvm::is_detected::value, + "SingleOpMatcherOpTrait expects operation type to have the " + "getOperandHandle() method"); + static_assert(llvm::is_detected::value, + "SingleOpMatcherOpTrait expected operation type to have the " + "matchOperation(Operation *, TransformResults &, " + "TransformState &) method"); + + // This must be a dynamic assert because interface registration is dynamic. + assert(isa(op) && + "SingleOpMatchOpTrait is only available on operations with " + "MatchOpInterface"); + Value operandHandle = cast(op).getOperandHandle(); + if (!operandHandle.getType().isa()) { + return op->emitError() << "SingleOpMatchOpTrait requires the op handle " + "to be of TransformHandleTypeInterface"; + } + + return success(); + } + + DiagnosedSilenceableFailure apply(TransformResults &results, + TransformState &state) { + Value operandHandle = cast(this->getOperation()).getOperandHandle(); + ArrayRef payload = state.getPayloadOps(operandHandle); + if (payload.size() != 1) { + return emitDefiniteFailure(this->getOperation()->getLoc()) + << "SingleOpMatchOpTrait requires the operand handle to point to " + "a single payload op"; + } + + return cast(this->getOperation()) + .matchOperation(payload[0], results, state); + } +}; + +} // namespace transform +} // namespace mlir + +#include "mlir/Dialect/Transform/IR/MatchInterfaces.h.inc" + +#endif // MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td @@ -0,0 +1,26 @@ +//===- MatchInterfaces.td - Transform dialect interfaces ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" + +def MatchOpInterface + : OpInterface<"MatchOpInterface", [TransformOpInterface]> { + let cppNamespace = "::mlir::transform"; +} + +def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> { + let cppNamespace = "::mlir::transform"; + + string extraDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure matchOperation( + ::mlir::Operation *current, + ::mlir::transform::TransformResults &results, + ::mlir::transform::TransformState &state); + }]; +} diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -17,9 +17,11 @@ #include "mlir/Support/LogicalResult.h" namespace mlir { + namespace transform { class TransformOpInterface; +class TransformResults; /// Options controlling the application of transform operations by the /// TransformState. @@ -400,6 +402,11 @@ return it->second; } + /// Updates the state to include the associations between op results and the + /// provided result of applying a transform op. + LogicalResult updateStateFromResults(const TransformResults &results, + ResultRange opResults); + /// Sets the payload IR ops associated with the given transform IR value /// (handle). A payload op may be associated multiple handles as long as /// at most one of them gets consumed by further transformations. @@ -690,6 +697,11 @@ void prepareValueMappings( SmallVectorImpl> &mappings, ValueRange values, const transform::TransformState &state); + +/// Populates `results` with payload associations that match exactly those of +/// the operands to `block`'s terminator. +void forwardTerminatorOperands(Block *block, transform::TransformState &state, + transform::TransformResults &results); } // namespace detail /// This trait is supposed to be attached to Transform dialect operations that diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -10,6 +10,7 @@ #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H #include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/Transform/IR/MatchInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/FunctionInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -17,6 +17,7 @@ include "mlir/IR/FunctionInterfaces.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" +include "mlir/Dialect/Transform/IR/MatchInterfaces.td" include "mlir/Dialect/Transform/IR/TransformAttrs.td" include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" @@ -116,6 +117,69 @@ }]; } +def ForeachMatchOp : TransformDialectOp<"foreach_match", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "Applies named sequences when a named matcher succeeds"; + let description = [{ + Given a pair of co-indexed lists of transform dialect symbols (such as + `transform.named_sequence`), walks the payload IR associated with the root + handle and interprets the symbols as matcher/action pairs by applying the + body of the corresponding symbol definition. The symbol from the first list + is the matcher part: if it results in a silenceable error, the error is + silenced and the next matcher is attempted. Definite failures from any + matcher stop the application immediately and are propagated unconditionally. + If none of the matchers succeeds, the next payload operation in walk order + (post-order at the moment of writing, double check `Operation::walk`) is + matched. If a matcher succeeds, the co-indexed action symbol is applied and + the following matchers are not applied to the same payload operation. If the + action succeeds, the next payload operation in walk order is matched. If it + fails, both silenceable and definite errors are propagated as the result of + this op. + + The matcher symbol must take one operand of a type that implements the same + transform dialect interface as the `root` operand (a check is performed at + application time to see if the associated payload satisfies the constraints + of the actual type). It must not consume the operand as multiple matchers + may be applied. The matcher may produce any number of results. The action + symbol paired with the matcher must take the same number of arguments as the + matcher has results, and these arguments must implement the same transform + dialect interfaces, but not necessarily have the exact same type (again, a + check is performed at application time to see if the associated payload + satisfies the constraints of actual types on both sides). The action symbol + may not have results. The actions are expected to only modify payload + operations nested in the `root` payload operations associated with the + operand of this transform operation. + + This operation consumes the operand and produces a new handle associated + with the same payload. This is necessary to trigger invalidation of handles + to any of the payload operations nested in the payload operations associated + with the operand, as those are likely to be modified by actions. Note that + the root payload operation associated with the operand are not matched. + + The operation succeeds if none of the matchers produced a definite failure + during application and if all of the applied actions produced success. Note + that it also succeeds if all the matchers failed on all payload operations, + i.e. failure to apply is not an error. The operation produces a silenceable + failure if any applied action produced a silenceable failure. In this case, + the resulting handle is associated with an empty payload. The operation + produces a definite failure if any of the applied matchers or actions + produced a definite failure. + }]; + + let arguments = (ins TransformHandleTypeInterface:$root, + SymbolRefArrayAttr:$matchers, + SymbolRefArrayAttr:$actions); + let results = (outs TransformHandleTypeInterface:$updated); + + let assemblyFormat = + "`in` $root custom($matchers, $actions) " + "attr-dict `:` functional-type($root, $updated)"; + + let hasVerifier = 1; +} + def ForeachOp : TransformDialectOp<"foreach", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, @@ -270,6 +334,7 @@ def IncludeOp : TransformDialectOp<"include", [CallOpInterface, + MatchOpInterface, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { diff --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt @@ -1,10 +1,12 @@ add_mlir_dialect_library(MLIRTransformDialect + MatchInterfaces.cpp TransformDialect.cpp TransformInterfaces.cpp TransformOps.cpp TransformTypes.cpp DEPENDS + MLIRMatchInterfacesIncGen MLIRTransformDialectIncGen MLIRTransformInterfacesIncGen diff --git a/mlir/lib/Dialect/Transform/IR/MatchInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/MatchInterfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Transform/IR/MatchInterfaces.cpp @@ -0,0 +1,17 @@ +//===- MatchInterfaces.cpp - Transform Dialect Interfaces -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/IR/MatchInterfaces.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Generated interface implementation. +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/IR/MatchInterfaces.cpp.inc" diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -920,40 +920,44 @@ } #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS - for (OpResult result : transform->getResults()) { - assert(result.getDefiningOp() == transform.getOperation() && - "payload IR association for a value other than the result of the " - "current transform op"); + if (failed(updateStateFromResults(results, transform->getResults()))) + return DiagnosedSilenceableFailure::definiteFailure(); + + printOnFailureRAII.release(); + DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, { + DBGS() << "Top-level payload:\n"; + getTopLevel()->print(llvm::dbgs()); + }); + return result; +} + +LogicalResult transform::TransformState::updateStateFromResults( + const TransformResults &results, ResultRange opResults) { + for (OpResult result : opResults) { if (result.getType().isa()) { assert(results.isParam(result.getResultNumber()) && "expected parameters for the parameter-typed result"); if (failed( setParams(result, results.getParams(result.getResultNumber())))) { - return DiagnosedSilenceableFailure::definiteFailure(); + return failure(); } } else if (result.getType().isa()) { assert(results.isValue(result.getResultNumber()) && "expected values for value-type-result"); if (failed(setPayloadValues( result, results.getValues(result.getResultNumber())))) { - return DiagnosedSilenceableFailure::definiteFailure(); + return failure(); } } else { assert(!results.isParam(result.getResultNumber()) && "expected payload ops for the non-parameter typed result"); if (failed( setPayloadOps(result, results.get(result.getResultNumber())))) { - return DiagnosedSilenceableFailure::definiteFailure(); + return failure(); } } } - - printOnFailureRAII.release(); - DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, { - DBGS() << "Top-level payload:\n"; - getTopLevel()->print(llvm::dbgs()); - }); - return result; + return success(); } //===----------------------------------------------------------------------===// @@ -1193,7 +1197,7 @@ } //===----------------------------------------------------------------------===// -// Utilities for PossibleTopLevelTransformOpTrait. +// Utilities for implementing transform ops with regions. //===----------------------------------------------------------------------===// void transform::detail::prepareValueMappings( @@ -1213,6 +1217,29 @@ } } +void transform::detail::forwardTerminatorOperands( + Block *block, transform::TransformState &state, + transform::TransformResults &results) { + for (auto &&[terminatorOperand, result] : + llvm::zip(block->getTerminator()->getOperands(), + block->getParentOp()->getOpResults())) { + if (result.getType().isa()) { + results.set(result, state.getPayloadOps(terminatorOperand)); + } else if (result.getType() + .isa()) { + results.setValues(result, state.getPayloadValues(terminatorOperand)); + } else { + assert(result.getType().isa() && + "unhandled transform type interface"); + results.setParams(result, state.getParams(terminatorOperand)); + } + } +} + +//===----------------------------------------------------------------------===// +// Utilities for PossibleTopLevelTransformOpTrait. +//===----------------------------------------------------------------------===// + LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( TransformState &state, Operation *op, Region ®ion) { SmallVector targets; diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -8,9 +8,11 @@ #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/PDL/IR/PDLOps.h" +#include "mlir/Dialect/Transform/IR/MatchInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" @@ -18,12 +20,17 @@ #include "mlir/Rewrite/PatternApplicator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" #include #define DEBUG_TYPE "transform-dialect" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") +#define DEBUG_TYPE_MATCHER "transform-matcher" +#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ") +#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x) + using namespace mlir; static ParseResult parseSequenceOpOperands( @@ -35,6 +42,11 @@ Value root, Type rootType, ValueRange extraBindings, TypeRange extraBindingTypes); +static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, + ArrayAttr matchers, ArrayAttr actions); +static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, + ArrayAttr &matchers, + ArrayAttr &actions); #define GET_OP_CLASSES #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" @@ -300,25 +312,6 @@ results.set(res, {}); } -static void forwardTerminatorOperands(Block *block, - transform::TransformState &state, - transform::TransformResults &results) { - for (auto &&[terminatorOperand, result] : - llvm::zip(block->getTerminator()->getOperands(), - block->getParentOp()->getOpResults())) { - if (result.getType().isa()) { - results.set(result, state.getPayloadOps(terminatorOperand)); - } else if (result.getType() - .isa()) { - results.setValues(result, state.getPayloadValues(terminatorOperand)); - } else { - assert(result.getType().isa() && - "unhandled transform type interface"); - results.setParams(result, state.getParams(terminatorOperand)); - } - } -} - DiagnosedSilenceableFailure transform::AlternativesOp::apply(transform::TransformResults &results, transform::TransformState &state) { @@ -388,7 +381,7 @@ clone); rewriter.replaceOp(original, clone->getResults()); } - forwardTerminatorOperands(®.front(), state, results); + detail::forwardTerminatorOperands(®.front(), state, results); return DiagnosedSilenceableFailure::success(); } } @@ -451,6 +444,339 @@ }); } +//===----------------------------------------------------------------------===// +// ForeachMatchOp +//===----------------------------------------------------------------------===// + +/// Applies matcher operations from the given `block` assigning `op` as the +/// payload of the block's first argument. Updates `state` accordingly. If any +/// of the matcher produces a silenceable failure, discards it (printing the +/// content to the debug output stream) and returns failure. If any of the +/// matchers produces a definite failure, reports it and returns failure. If all +/// matchers in the block succeed, populates `mappings` with the payload +/// entities associated with the block terminator operands. +static DiagnosedSilenceableFailure +matchBlock(Block &block, Operation *op, transform::TransformState &state, + SmallVectorImpl> &mappings) { + assert(block.getParent() && "cannot match using a detached block"); + auto matchScope = state.make_isolated_region_scope(*block.getParent()); + if (failed(state.mapBlockArgument(block.getArgument(0), {op}))) + return DiagnosedSilenceableFailure::definiteFailure(); + + for (Operation &match : block.without_terminator()) { + if (!isa(match)) { + return emitDefiniteFailure(match.getLoc()) + << "expected operations in the match part to " + "implement MatchOpInterface"; + } + DiagnosedSilenceableFailure diag = + state.applyTransform(cast(match)); + if (diag.succeeded()) + continue; + + return diag; + } + + // Remember the values mapped to the terminator operands so we can + // forward them to the action. + ValueRange yieldedValues = block.getTerminator()->getOperands(); + transform::detail::prepareValueMappings(mappings, yieldedValues, state); + return DiagnosedSilenceableFailure::success(); +} + +DiagnosedSilenceableFailure +transform::ForeachMatchOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + SmallVector> + matchActionPairs; + matchActionPairs.reserve(getMatchers().size()); + SymbolTableCollection symbolTable; + for (auto &&[matcher, action] : + llvm::zip_equal(getMatchers(), getActions())) { + auto matcherSymbol = + symbolTable.lookupNearestSymbolFrom( + getOperation(), cast(matcher)); + auto actionSymbol = + symbolTable.lookupNearestSymbolFrom( + getOperation(), cast(action)); + assert(matcherSymbol && actionSymbol && + "unresolved symbols not caught by the verifier"); + + if (matcherSymbol.isExternal()) + return emitDefiniteFailure() << "unresolved external symbol " << matcher; + if (actionSymbol.isExternal()) + return emitDefiniteFailure() << "unresolved external symbol " << action; + + matchActionPairs.emplace_back(matcherSymbol, actionSymbol); + } + + for (Operation *root : state.getPayloadOps(getRoot())) { + WalkResult walkResult = root->walk([&](Operation *op) { + // Skip over the root op itself so we don't invalidate it. + if (op == root) + return WalkResult::advance(); + + DEBUG_MATCHER({ + DBGS_MATCHER() << "matching "; + op->print(llvm::dbgs(), + OpPrintingFlags().assumeVerified().skipRegions()); + llvm::dbgs() << " @" << op << "\n"; + }); + + // Try all the match/action pairs until the first successful match. + for (auto [matcher, action] : matchActionPairs) { + SmallVector> mappings; + DiagnosedSilenceableFailure diag = + matchBlock(matcher.getFunctionBody().front(), op, state, mappings); + if (diag.isDefiniteFailure()) + return WalkResult::interrupt(); + if (diag.isSilenceableFailure()) { + DEBUG_MATCHER(DBGS_MATCHER() + << "matcher " << matcher.getName() << " failed\n"); + continue; + } + + auto scope = state.make_isolated_region_scope(action.getFunctionBody()); + for (auto &&[arg, map] : llvm::zip_equal( + action.getFunctionBody().front().getArguments(), mappings)) { + if (failed(state.mapBlockArgument(arg, map))) + return WalkResult::interrupt(); + } + + for (Operation &transform : + action.getFunctionBody().front().without_terminator()) { + DiagnosedSilenceableFailure result = + state.applyTransform(cast(transform)); + if (failed(result.checkAndReport())) + return WalkResult::interrupt(); + } + break; + } + return WalkResult::advance(); + }); + if (walkResult.wasInterrupted()) + return DiagnosedSilenceableFailure::definiteFailure(); + } + + // The root operation should not have been affected, so we can just reassign + // the payload to the result. Note that we need to consume the root handle to + // make sure any handles to operations inside, that could have been affected + // by actions, are invalidated. + results.set(getUpdated().cast(), state.getPayloadOps(getRoot())); + return DiagnosedSilenceableFailure::success(); +} + +void transform::ForeachMatchOp::getEffects( + SmallVectorImpl &effects) { + // Bail if invalid. + if (getOperation()->getNumOperands() < 1 || + getOperation()->getNumResults() < 1) { + return modifiesPayload(effects); + } + + consumesHandle(getRoot(), effects); + producesHandle(getUpdated(), effects); + modifiesPayload(effects); +} + +/// Parses the comma-separated list of symbol reference pairs of the format +/// `@matcher -> @action`. +static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, + ArrayAttr &matchers, + ArrayAttr &actions) { + StringAttr matcher; + StringAttr action; + SmallVector matcherList; + SmallVector actionList; + do { + if (parser.parseSymbolName(matcher) || parser.parseArrow() || + parser.parseSymbolName(action)) { + return failure(); + } + matcherList.push_back(SymbolRefAttr::get(matcher)); + actionList.push_back(SymbolRefAttr::get(action)); + } while (parser.parseOptionalComma().succeeded()); + + matchers = parser.getBuilder().getArrayAttr(matcherList); + actions = parser.getBuilder().getArrayAttr(actionList); + return success(); +} + +/// Prints the comma-separated list of symbol reference pairs of the format +/// `@matcher -> @action`. +static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, + ArrayAttr matchers, ArrayAttr actions) { + printer.increaseIndent(); + printer.increaseIndent(); + for (auto &&[matcher, action, idx] : llvm::zip_equal( + matchers, actions, llvm::seq(0, matchers.size()))) { + printer.printNewline(); + printer << cast(matcher) << " -> " + << cast(action); + if (idx != matchers.size() - 1) + printer << ", "; + } + printer.decreaseIndent(); + printer.decreaseIndent(); +} + +LogicalResult transform::ForeachMatchOp::verify() { + if (getMatchers().size() != getActions().size()) + return emitOpError() << "expected the same number of matchers and actions"; + if (getMatchers().empty()) + return emitOpError() << "expected at least one match/action pair"; + + llvm::SmallPtrSet matcherNames; + for (Attribute name : getMatchers()) { + if (matcherNames.insert(name).second) + continue; + emitWarning() << "matcher " << name + << " is used more than once, only the first match will apply"; + } + + return success(); +} + +/// Returns `true` if both types implement one of the interfaces provided as +/// template parameters. +template +static bool implementSameInterface(Type t1, Type t2) { + return ((isa(t1) && isa(t2)) || ... || false); +} + +/// Returns `true` if both types implement one of the transform dialect +/// interfaces. +static bool implementSameTransformInterface(Type t1, Type t2) { + return implementSameInterface( + t1, t2); +} + +/// Checks that the attributes of the function-like operation have correct +/// consumption effect annotations. If `alsoVerifyInternal`, checks for +/// annotations being present even if they can be inferred from the body. +static DiagnosedSilenceableFailure +verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, + bool alsoVerifyInternal = false) { + auto transformOp = cast(op.getOperation()); + llvm::SmallDenseSet consumedArguments; + if (!op.isExternal()) { + transform::getConsumedBlockArguments(op.getFunctionBody().front(), + consumedArguments); + } + for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) { + bool isConsumed = + op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) != + nullptr; + bool isReadOnly = + op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) != + nullptr; + if (isConsumed && isReadOnly) { + return transformOp.emitSilenceableError() + << "argument #" << i << " cannot be both readonly and consumed"; + } + if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) { + return transformOp.emitSilenceableError() + << "must provide consumed/readonly status for arguments of " + "external or called ops"; + } + if (op.isExternal()) + continue; + + if (consumedArguments.contains(i) && !isConsumed && isReadOnly) { + return transformOp.emitSilenceableError() + << "argument #" << i + << " is consumed in the body but is not marked as such"; + } + if (!consumedArguments.contains(i) && isConsumed) { + Diagnostic warning(op->getLoc(), DiagnosticSeverity::Warning); + warning << "argument #" << i + << " is not consumed in the body but is marked as consumed"; + return DiagnosedSilenceableFailure::silenceableFailure( + std::move(warning)); + } + } + return DiagnosedSilenceableFailure::success(); +} + +LogicalResult transform::ForeachMatchOp::verifySymbolUses( + SymbolTableCollection &symbolTable) { + assert(getMatchers().size() == getActions().size()); + auto consumedAttr = + StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName); + for (auto &&[matcher, action] : + llvm::zip_equal(getMatchers(), getActions())) { + auto matcherSymbol = dyn_cast_or_null( + symbolTable.lookupNearestSymbolFrom(getOperation(), + cast(matcher))); + auto actionSymbol = dyn_cast_or_null( + symbolTable.lookupNearestSymbolFrom(getOperation(), + cast(action))); + if (!matcherSymbol || + !isa(matcherSymbol.getOperation())) + return emitError() << "unresolved matcher symbol " << matcher; + if (!actionSymbol || + !isa(actionSymbol.getOperation())) + return emitError() << "unresolved action symbol " << action; + + if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol, + /*alsoVerifyInternal=*/true) + .checkAndReport())) { + return failure(); + } + if (failed(verifyFunctionLikeConsumeAnnotations(actionSymbol, + /*alsoVerifyInternal=*/true) + .checkAndReport())) { + return failure(); + } + + ArrayRef matcherResults = matcherSymbol.getResultTypes(); + ArrayRef actionArguments = actionSymbol.getArgumentTypes(); + if (matcherResults.size() != actionArguments.size()) { + return emitError() << "mismatching number of matcher results and " + "action arguments between " + << matcher << " (" << matcherResults.size() << ") and " + << action << " (" << actionArguments.size() << ")"; + } + for (auto &&[i, matcherType, actionType] : + llvm::enumerate(matcherResults, actionArguments)) { + if (implementSameTransformInterface(matcherType, actionType)) + continue; + + return emitError() << "mismatching type interfaces for matcher result " + "and action argument #" + << i; + } + + if (!actionSymbol.getResultTypes().empty()) { + InFlightDiagnostic diag = + emitError() << "action symbol is not expected to have results"; + diag.attachNote(actionSymbol->getLoc()) << "symbol declaration"; + return diag; + } + + if (matcherSymbol.getArgumentTypes().size() != 1 || + !implementSameTransformInterface(matcherSymbol.getArgumentTypes()[0], + getRoot().getType())) { + InFlightDiagnostic diag = + emitOpError() << "expects matcher symbol to have one argument with " + "the same transform interface as the first operand"; + diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; + return diag; + } + + if (matcherSymbol.getArgAttr(0, consumedAttr)) { + InFlightDiagnostic diag = + emitOpError() + << "does not expect matcher symbol to consume its operand"; + diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; + return diag; + } + } + return success(); +} + //===----------------------------------------------------------------------===// // ForeachOp //===----------------------------------------------------------------------===// @@ -690,7 +1016,7 @@ // Forward the operation mapping for values yielded from the sequence to the // values produced by the sequence op. - forwardTerminatorOperands(&block, state, results); + transform::detail::forwardTerminatorOperands(&block, state, results); return DiagnosedSilenceableFailure::success(); } @@ -801,57 +1127,6 @@ } } -template -static bool implementSameInterface(Type t1, Type t2) { - return ((isa(t1) && isa(t2)) || ... || false); -} - -/// Checks that the attributes of the named sequence operation have correct -/// consumption effect annotations. If `alsoVerifyInternal`, checks for -/// annotations being present even if they can be inferred from the body. -static DiagnosedSilenceableFailure -verifyNamedSequenceConsumeAnnotations(transform::NamedSequenceOp op, - bool alsoVerifyInternal = false) { - llvm::SmallDenseSet consumedArguments; - if (!op.isExternal()) { - transform::getConsumedBlockArguments(op.getBody().front(), - consumedArguments); - } - for (unsigned i = 0, e = op.getFunctionType().getNumInputs(); i < e; ++i) { - bool isConsumed = - op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) != - nullptr; - bool isReadOnly = - op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) != - nullptr; - if (isConsumed && isReadOnly) { - return op.emitSilenceableError() - << "argument #" << i << " cannot be both readonly and consumed"; - } - if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) { - return op.emitSilenceableError() - << "must provide consumed/readonly status for arguments of " - "external or called ops"; - } - if (op.isExternal()) - continue; - - if (consumedArguments.contains(i) && !isConsumed && isReadOnly) { - return op.emitSilenceableError() - << "argument #" << i - << " is consumed in the body but is not marked as such"; - } - if (!consumedArguments.contains(i) && isConsumed) { - Diagnostic warning(op->getLoc(), DiagnosticSeverity::Warning); - warning << "argument #" << i - << " is not consumed in the body but is marked as consumed"; - return DiagnosedSilenceableFailure::silenceableFailure( - std::move(warning)); - } - } - return DiagnosedSilenceableFailure::success(); -} - LogicalResult transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { // Access through indirection and do additional checking because this may be @@ -883,18 +1158,16 @@ for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { Type resultType = getResult(i).getType(); Type funcType = fnType.getResult(i); - if (!implementSameInterface(resultType, - funcType)) { + if (!implementSameTransformInterface(resultType, funcType)) { return emitOpError() << "type of result #" << i << " must implement the same transform dialect " "interface as the corresponding callee result"; } } - return verifyNamedSequenceConsumeAnnotations(target, - /*alsoVerifyInternal=*/true) + return verifyFunctionLikeConsumeAnnotations( + cast(*target), + /*alsoVerifyInternal=*/true) .checkAndReport(); } @@ -973,6 +1246,60 @@ getResAttrsAttrName()); } +/// Verifies that a symbol function-like transform dialect operation has the +/// signature and the terminator that have conforming types, i.e., types +/// implementing the same transform dialect type interface. If `allowExternal` +/// is set, allow external symbols (declarations) and don't check the terminator +/// as it may not exist. +static DiagnosedSilenceableFailure +verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) { + if (auto parent = op->getParentOfType()) { + DiagnosedSilenceableFailure diag = + emitSilenceableFailure(op) + << "cannot be defined inside another transform op"; + diag.attachNote(parent.getLoc()) << "ancestor transform op"; + return diag; + } + + if (op.isExternal() || op.getFunctionBody().empty()) { + if (allowExternal) + return DiagnosedSilenceableFailure::success(); + + return emitSilenceableFailure(op) << "cannot be external"; + } + + if (op.getFunctionBody().front().empty()) + return emitSilenceableFailure(op) << "expected a non-empty body block"; + + Operation *terminator = &op.getFunctionBody().front().back(); + if (!isa(terminator)) { + DiagnosedSilenceableFailure diag = emitSilenceableFailure(op) + << "expected '" + << transform::YieldOp::getOperationName() + << "' as terminator"; + diag.attachNote(terminator->getLoc()) << "terminator"; + return diag; + } + + if (terminator->getNumOperands() != op.getResultTypes().size()) { + return emitSilenceableFailure(terminator) + << "expected terminator to have as many operands as the parent op " + "has results"; + } + for (auto [i, operandType, resultType] : llvm::zip_equal( + llvm::seq(0, terminator->getNumOperands()), + terminator->getOperands().getType(), op.getResultTypes())) { + if (operandType == resultType) + continue; + return emitSilenceableFailure(terminator) + << "the type of the terminator operand #" << i + << " must match the type of the corresponding parent op result (" + << operandType << " vs " << resultType << ")"; + } + + return DiagnosedSilenceableFailure::success(); +} + /// Verification of a NamedSequenceOp. This does not report the error /// immediately, so it can be used to check for op's well-formedness before the /// verifier runs, e.g., during trait verification. @@ -1000,7 +1327,7 @@ } if (op.isExternal() || op.getBody().empty()) - return verifyNamedSequenceConsumeAnnotations(op); + return verifyFunctionLikeConsumeAnnotations(cast(*op)); if (op.getBody().front().empty()) return emitSilenceableFailure(op) << "expected a non-empty body block"; @@ -1032,7 +1359,14 @@ << operandType << " vs " << resultType << ")"; } - return verifyNamedSequenceConsumeAnnotations(op); + auto funcOp = cast(*op); + DiagnosedSilenceableFailure diag = + verifyFunctionLikeConsumeAnnotations(funcOp); + if (!diag.succeeded()) + return diag; + + return verifyYieldingSingleBlockOp(funcOp, + /*allowExternal=*/true); } LogicalResult transform::NamedSequenceOp::verify() { diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -562,3 +562,113 @@ transform.yield } } + +// ----- + +module attributes { transform.with_named_sequence } { + transform.sequence failures(propagate) { + ^bb0(%root: !transform.any_op): + // expected-error @below {{unresolved matcher symbol @foo}} + transform.foreach_match in %root + @foo -> @bar : (!transform.any_op) -> !transform.any_op + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + func.func private @foo() + + transform.sequence failures(propagate) { + ^bb0(%root: !transform.any_op): + // expected-error @below {{unresolved matcher symbol @foo}} + transform.foreach_match in %root + @foo -> @bar : (!transform.any_op) -> !transform.any_op + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @match() + + transform.sequence failures(propagate) { + ^bb0(%root: !transform.any_op): + // expected-error @below {{unresolved action symbol @bar}} + transform.foreach_match in %root + @match -> @bar : (!transform.any_op) -> !transform.any_op + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + func.func private @bar() + transform.named_sequence @match() + + transform.sequence failures(propagate) { + ^bb0(%root: !transform.any_op): + // expected-error @below {{unresolved action symbol @bar}} + transform.foreach_match in %root + @match -> @bar : (!transform.any_op) -> !transform.any_op + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @match() -> !transform.any_op + transform.named_sequence @action() + + transform.sequence failures(propagate) { + ^bb0(%root: !transform.any_op): + // expected-error @below {{mismatching number of matcher results and action arguments between @match (1) and @action (0)}} + transform.foreach_match in %root + @match -> @action : (!transform.any_op) -> !transform.any_op + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @match(!transform.any_op {transform.readonly}) + // expected-note @below {{symbol declaration}} + transform.named_sequence @action() -> !transform.any_op + + transform.sequence failures(propagate) { + ^bb0(%root: !transform.any_op): + // expected-error @below {{action symbol is not expected to have results}} + transform.foreach_match in %root + @match -> @action : (!transform.any_op) -> !transform.any_op + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + // expected-note @below {{symbol declaration}} + transform.named_sequence @match() + transform.named_sequence @action() + + transform.sequence failures(propagate) { + ^bb0(%root: !transform.any_op): + // expected-error @below {{expects matcher symbol to have one argument with the same transform interface as the first operand}} + transform.foreach_match in %root + @match -> @action : (!transform.any_op) -> !transform.any_op + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + // expected-note @below {{symbol declaration}} + transform.named_sequence @match(!transform.any_op {transform.consumed}) + transform.named_sequence @action() + + transform.sequence failures(propagate) { + ^bb0(%root: !transform.any_op): + // expected-error @below {{'transform.foreach_match' op does not expect matcher symbol to consume its operand}} + transform.foreach_match in %root + @match -> @action : (!transform.any_op) -> !transform.any_op + } +} diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1335,3 +1335,92 @@ transform.test_print_remark_at_operand_value %0#1, "value" : !transform.any_value } } + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @match1(%current: !transform.any_op {transform.readonly}) -> (!transform.any_op) { + transform.test_succeed_if_operand_of_op_kind %current, "test.some_op" : !transform.any_op + transform.yield %current : !transform.any_op + } + + transform.named_sequence @match2(%current: !transform.any_op {transform.readonly}) -> (!transform.any_op) { + transform.test_succeed_if_operand_of_op_kind %current, "func.func" : !transform.any_op + transform.yield %current : !transform.any_op + } + + transform.named_sequence @action1(%current: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %current, "matched1" : !transform.any_op + transform.yield + } + transform.named_sequence @action2(%current: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %current, "matched2" : !transform.any_op + transform.yield + } + + transform.sequence failures(propagate) { + ^bb0(%root: !transform.any_op): + transform.foreach_match in %root + @match1 -> @action1, + @match2 -> @action2 + : (!transform.any_op) -> (!transform.any_op) + transform.yield + } + + // expected-remark @below {{matched2}} + func.func private @foo() + // expected-remark @below {{matched2}} + func.func private @bar() + "test.testtest"() : () -> () + // expected-remark @below {{matched1}} + "test.some_op"() : () -> () +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @match(!transform.any_op {transform.readonly}) + transform.named_sequence @action() + + transform.sequence failures(propagate) { + ^bb0(%root: !transform.any_op): + // expected-error @below {{unresolved external symbol @match}} + transform.foreach_match in %root + @match -> @action : (!transform.any_op) -> !transform.any_op + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @match(%arg: !transform.any_op {transform.readonly}) { + transform.yield + } + transform.named_sequence @action() + + transform.sequence failures(propagate) { + ^bb0(%root: !transform.any_op): + // expected-error @below {{unresolved external symbol @action}} + transform.foreach_match in %root + @match -> @action : (!transform.any_op) -> !transform.any_op + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @match(%arg: !transform.any_op {transform.readonly}) { + // expected-error @below {{expected operations in the match part to implement MatchOpInterface}} + transform.test_print_remark_at_operand %arg, "remark" : !transform.any_op + transform.yield + } + transform.named_sequence @action() { + transform.yield + } + + transform.sequence failures(propagate) { + ^bb0(%root: !transform.any_op): + transform.foreach_match in %root + @match -> @action : (!transform.any_op) -> !transform.any_op + } +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h @@ -15,6 +15,7 @@ #define MLIR_TESTTRANSFORMDIALECTEXTENSION_H #include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/Transform/IR/MatchInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -209,6 +209,25 @@ transform::modifiesPayload(effects); } +DiagnosedSilenceableFailure +mlir::test::TestSucceedIfOperandOfOpKind::matchOperation( + Operation *op, transform::TransformResults &results, + transform::TransformState &state) { + if (op->getName().getStringRef() != getOpKind()) { + return emitSilenceableError() + << "op expected the operand to be associated with a payload op of " + "kind " + << getOpKind() << " got " << op->getName().getStringRef(); + } + return DiagnosedSilenceableFailure::success(); +} + +void mlir::test::TestSucceedIfOperandOfOpKind::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getOperand(), effects); + transform::onlyReadsPayload(effects); +} + DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply( transform::TransformResults &results, transform::TransformState &state) { ArrayRef payload = state.getPayloadOps(getOperand()); diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -17,6 +17,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/OpBase.td" +include "mlir/Dialect/Transform/IR/MatchInterfaces.td" include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" include "mlir/Dialect/PDL/IR/PDLTypes.td" @@ -115,6 +116,20 @@ let cppNamespace = "::mlir::test"; } +def TestSucceedIfOperandOfOpKind + : Op]> { + let arguments = (ins + TransformHandleTypeInterface:$operand_handle, + StrAttr:$op_kind); + let assemblyFormat = + "$operand_handle `,` $op_kind attr-dict `:` type($operand_handle)"; + let extraClassDeclaration = SingleOpMatcher.extraDeclaration; + let cppNamespace = "::mlir::test"; +} + def TestPrintRemarkAtOperandOp : Op, diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -9519,6 +9519,31 @@ deps = [":TransformDialectTdFiles"], ) +gentbl_cc_library( + name = "TransformDialectMatchInterfacesIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + [ + "-gen-op-interface-decls", + ], + "include/mlir/Dialect/Transform/IR/MatchInterfaces.h.inc", + ), + ( + [ + "-gen-op-interface-defs", + ], + "include/mlir/Dialect/Transform/IR/MatchInterfaces.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Transform/IR/MatchInterfaces.td", + deps = [ + ":TransformDialectTdFiles", + ":TransformDialectInterfacesIncGen", + ], +) + gentbl_cc_library( name = "TransformDialectIncGen", strip_include_prefix = "include", @@ -9598,6 +9623,7 @@ ":TransformDialectEnumsIncGen", ":TransformDialectIncGen", ":TransformDialectInterfacesIncGen", + ":TransformDialectMatchInterfacesIncGen", ":TransformDialectUtils", ":TransformOpsIncGen", ":TransformTypesIncGen",