diff --git a/mlir/docs/Dialects/Transform.md b/mlir/docs/Dialects/Transform.md --- a/mlir/docs/Dialects/Transform.md +++ b/mlir/docs/Dialects/Transform.md @@ -431,6 +431,10 @@ [include "Dialects/MemRefTransformOps.md"] +## Structured (Linalg) Match Operations + +[include "Dialects/LinalgStructuredMatchOps.md"] + ## Structured (Linalg) Transform Operations [include "Dialects/LinalgStructuredTransformOps.md"] diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/TransformOps/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/CMakeLists.txt @@ -1,8 +1,17 @@ +set(LLVM_TARGET_DEFINITIONS LinalgMatchOps.td) +mlir_tablegen(LinalgMatchOps.h.inc -gen-op-decls) +mlir_tablegen(LinalgMatchOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRLinalgMatchOpsIncGen) + set(LLVM_TARGET_DEFINITIONS LinalgTransformOps.td) mlir_tablegen(LinalgTransformOps.h.inc -gen-op-decls) mlir_tablegen(LinalgTransformOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRLinalgTransformOpsIncGen) + +set(LLVM_TARGET_DEFINITIONS LinalgTransformEnums.td) mlir_tablegen(LinalgTransformOpsEnums.h.inc -gen-enum-decls) mlir_tablegen(LinalgTransformOpsEnums.cpp.inc -gen-enum-defs) -add_public_tablegen_target(MLIRLinalgTransformOpsIncGen) +add_public_tablegen_target(MLIRLinalgTransformEnumsIncGen) +add_mlir_doc(LinalgMatchOps LinalgStructuredMatchOps Dialects/ -gen-op-doc) add_mlir_doc(LinalgTransformOps LinalgStructuredTransformOps Dialects/ -gen-op-doc) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td @@ -0,0 +1,465 @@ +//===- LinalgMatchOps.td - Linalg transform matcher ops ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with 
LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LINALG_MATCH_OPS +#define LINALG_MATCH_OPS + +include "mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td" +include "mlir/Dialect/Transform/IR/MatchInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformAttrs.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +//===----------------------------------------------------------------------===// +// Structured match op and predicates usable inside it. +//===----------------------------------------------------------------------===// + +def MatchStructuredOp : Op, + SingleOpMatcher, + SingleBlockImplicitTerminator<"::mlir::transform::MatchStructuredYieldOp">]> { + let summary = + "Matches a structured (linalg) operation with additional conditions"; + let description = [{ + Checks if the payload operation associated with the operand handle is a + structured operation, that is, an operation that implements + `LinalgOpInterface`, and that all conditions listed in the body of this + operation are satisfied. Produces a silenceable failure if the payload + operation is not structured. + + The transform operations nested in the body region are applied one by one. + If any of them produces a failure, silenceable or definite, the following + operations are not applied. If the failure propagation mode is "propagate", + silenceable failures are forwarded as the result of this operation. If it is + "suppress", they are ignored and this operation immediately succeeds. + Definite failures are always propagated immediately. + + In case of success, the transform values produced by this operation are + associated with the same payload as the operands of the block terminator. 
If + any of the nested operations produced a silenceable failure, regardless of + the failure propagation mode, the transform values produced by this + operation that correspond to the already defined terminator operands are + associated with the same payload as the already defined terminator operands. + Other values produced by this operation are associated with empty payloads. + + If the failure propagation mode is not specified, it is considered + "propagate" by default. The "suppress" mode can be used to specify optional + matches. + + #### Return modes + + This operation only reads all operand handles and produces all resulting + handles. It succeeds in "propagate" mode if the payload operation is a + structured operation and if all the nested operations succeed. It succeeds + in "suppress" mode as long as the operand handle is associated with exactly + one payload operation. It produces a definite failure when the handle is + not associated with exactly one payload operation. + }]; + + let arguments = (ins TransformHandleTypeInterface:$current, + OptionalAttr:$failure_propagation_mode); + let results = (outs Variadic:$outputs); + + let regions = (region SizedRegion<1>:$body_region); + let assemblyFormat = + "(`failures` `(` $failure_propagation_mode^ `)`)?" + "$current `:` custom(type($current), type($outputs))" + "attr-dict-with-keyword regions"; + let hasVerifier = 1; + + let extraClassDeclaration = SingleOpMatcher.extraDeclaration # [{ + ::mlir::Value getOperandHandle() { return getCurrent(); } + }]; +} + +def StructuredPredicate : NativeOpTrait<"StructuredOpPredicateOpTrait"> { + let cppNamespace = "::mlir::transform"; + string extraDescription = [{ + This op can only appear immediately inside a `transform.match.structured` + op and apply to its first block argument because it assumes the payload + to have been already checked for being a single structured op. 
+ }]; +} + +def MatchStructuredBodyOp : Op { + let summary = + "Checks if the body of the structured op satisfies some criteria"; + let description = !strconcat([{ + Checks if the body of the structured payload op satisfies one of the + following mutually exclusive criteria specified by attributes: + + * `reduction_position`: the body of the structured payload op implements + a reduction of the `n`-th operand (`n` is the value of the attribute) + using a single combiner operation; + + * `passthrough`: the body of the structured payload op only forwards + inputs to the outputs (copy or broadcast). + + }], StructuredPredicate.extraDescription, [{ + + #### Return modes + + Succeeds if the operation body satisfies the specified criteria, produces a + silenceable failure otherwise. Produces a definite failure if the operand is + not associated with a single payload op. + }]); + let arguments = (ins TransformHandleTypeInterface:$operand_handle, + OptionalAttr:$reduction_position, + UnitAttr:$passthrough); + let assemblyFormat = "$operand_handle attr-dict `:` type($operand_handle)"; + let extraClassDeclaration = SingleOpMatcher.extraDeclaration; + let hasVerifier = 1; +} + +class StructuredDimDescription { + string description = !strconcat([{ + The following }], kind ,[{ specifications are supported: + + * `all`: all }], kind ,[{s are checked and captured; + * list of integers: the listed }], kind, [{s are checked and captured; + * `except(` list of integers `)`: all }], kind, [{s except the + specified ones are checked and captured. + + Negative indexes are interpreted by counting values from the last one + (similarly to Python). For example, `-1` means the last }], kind, [{ and + `except(-1)` means all }], kind, [{s but the last. Indexes must be unique, + including after interpretation of negative ones. + + Produces a silenceable failure in case of index overflow, including backward + counting. 
+ }]); +} + +def MatchStructuredDimOp : Op { + let summary = + "Checks if the dimensions of the structured op satisfy some criteria"; + let description = !strconcat([{ + Checks if the dimensions (loop ranges) of the structured payload op satisfy + the criteria specified as attributes. May capture the numeric value of the + dimension into a parameter that it returns. + + }], + StructuredDimDescription<"dimension">.description, + [{ + + The following mutually exclusive conditions are available as unit + attributes: + + * `parallel`: the dimension corresponds to a parallel loop; + * `reduction`: the dimension corresponds to a reduction loop. + + If the result type is specified, associates the parameter with the (static) + values of dimensions in the same order as listed and preserving the natural + order for `all` and `except`. Specifically, if `-1, -2` are specified, the + parameter will be associated with the value of the second-to-last dimension + followed by the last dimension. If the dimension is dynamic, the parameter + will contain a negative value corresponding to kDynamic in C++. + + }], StructuredPredicate.extraDescription, [{ + + #### Return modes + + Succeeds if the specified dimensions satisfy the specified criteria, + produces a silenceable failure otherwise. Produces a definite failure if + the operand is not associated with a single payload op. 
+ }]); + + let arguments = (ins TransformHandleTypeInterface:$operand_handle, + DenseI64ArrayAttr:$raw_dim_list, + UnitAttr:$is_inverted, + UnitAttr:$is_all, + UnitAttr:$parallel, + UnitAttr:$reduction); + + let results = (outs Optional:$result); + let assemblyFormat = + "$operand_handle `[`" + "custom($raw_dim_list, $is_inverted, $is_all)" + "`]` attr-dict `:` " + "custom(type($operand_handle), type($result))"; + + let extraClassDeclaration = SingleOpMatcher.extraDeclaration # [{ + ::mlir::DiagnosedSilenceableFailure getDimensionsFor( + ::mlir::linalg::LinalgOp op, + ::llvm::SmallVectorImpl &dims); + }]; + + let hasVerifier = 1; +} + +def MatchStructuredElementalBitwidthOp + : Op { + let summary = + "Captures the bitwidth of the value's elemental type as a parameter"; + let description = !strconcat([{ + Produces a transform dialect parameter associated with the bitwidth of the + elemental type of the payload value passed as the operand.}], + + StructuredPredicate.extraDescription, [{ + + #### Return modes + + Succeeds if the operand is associated with exactly one payload value of + `ShapedType`. Produces a silenceable failure otherwise. + }]); + let arguments = (ins TransformValueHandleTypeInterface:$operand_handle); + let results = (outs TransformParamTypeInterface:$result); + let assemblyFormat = + "$operand_handle attr-dict `:` functional-type(operands, results)"; + let extraClassDeclaration = SingleValueMatcher.extraDeclaration; +} + +class MatchStructuredOperandOp : Op { + + // TODO: consider an attribute controlling whether to fail or succeed on + // out-of-bounds accesses. + let arguments = (ins TransformHandleTypeInterface:$operand_handle, + DenseI64ArrayAttr:$raw_position_list, + UnitAttr:$is_inverted, + UnitAttr:$is_all, + UnitAttr:$permutation, + UnitAttr:$projected_permutation); + + // TODO: allow this to bind multiple inputs simultaneously after checking that + // `transform.foreach` works well in matches. 
+ let results = (outs Optional:$result); + let assemblyFormat = + "$operand_handle `[`" + "custom($raw_position_list, $is_inverted, $is_all)" + "`]` attr-dict " + "`:` custom(type($operand_handle), type($result))"; + + let extraClassDeclaration = SingleOpMatcher.extraDeclaration # [{ + ::mlir::DiagnosedSilenceableFailure getPositionsFor( + ::mlir::linalg::LinalgOp op, + ::llvm::SmallVectorImpl &positions); + }]; + + let hasVerifier = 1; +} + +def MatchStructuredInputOp : MatchStructuredOperandOp<"match.structured.input"> { + let summary = + "Captures input operand(s) of a structured operation in an op or value handle"; + let description = !strconcat([{ + Produces a transform dialect value handle associated with the payload value + supplied as input operand to the given structured payload operation, or an + operation handle to the structured payload operation producing said payload + value depending on the result type. + + }], + StructuredDimDescription<"input">.description, + [{ + + }], + StructuredPredicate.extraDescription, + [{ + + #### Return modes + + Succeeds if all input indexes are in bounds, produces a silenceable failure + otherwise. Additionally, when the result is an operation handle, produces a + silenceable failure if the input specification defines more than one input + or if the operand is not an operation result. + }]); +} + +def MatchStructuredInitOp : MatchStructuredOperandOp<"match.structured.init"> { + let summary = + "Captures init operand(s) of a structured operation in an op or value handle"; + let description = !strconcat([{ + Produces a transform dialect value handle associated with the payload value + supplied as init(outs) operand to the given structured payload operation, + or an operation handle to the structured payload operation producing said + payload value depending on the result type. 
+ + }], + StructuredDimDescription<"init">.description, + [{ + + }], + StructuredPredicate.extraDescription, + [{ + + #### Return modes + + Succeeds if all init(outs) indexes are in bounds, produces a silenceable + failure otherwise. Additionally, when the result is an operation handle, + produces a silenceable failure if the init(outs) specification defines + more than one init(outs) or if the operand is not an operation result. + }]); +} + + +def MatchStructuredNumInputsOp + : Op { + let summary = "Captures the number of input operands of a structured " + "operation as parameter"; + let description = !strconcat([{ + Produces a transform dialect parameter value associated with an integer + attribute containing the number of input operands of the payload operation + associated with the operand handle. + + }], StructuredPredicate.extraDescription, [{ + + #### Return modes + + Succeeds if the operand is associated with exactly one structured payload + operation. Produces a silenceable failure otherwise. + }]); + + let arguments = (ins TransformHandleTypeInterface:$operand_handle); + let results = (outs TransformParamTypeInterface:$result); + let assemblyFormat = + "$operand_handle attr-dict `:` functional-type(operands, results)"; + let extraClassDeclaration = SingleOpMatcher.extraDeclaration; +} + +def MatchStructuredNumInitsOp + : Op { + let summary = "Captures the number of init(outs) operands of a structured " + "operation as parameter"; + let description = !strconcat([{ + Produces a transform dialect parameter value associated with an integer + attribute containing the number of init(outs) operands of the payload + operation associated with the operand handle. + + }], StructuredPredicate.extraDescription, [{ + + #### Return modes + + Succeeds if the operand is associated with exactly one structured payload + operation. Produces a silenceable failure otherwise.
+ }]); + + let arguments = (ins TransformHandleTypeInterface:$operand_handle); + let results = (outs TransformParamTypeInterface:$result); + let assemblyFormat = + "$operand_handle attr-dict `:` functional-type(operands, results)"; + let extraClassDeclaration = SingleOpMatcher.extraDeclaration; +} + +def MatchStructuredRankOp : Op { + let summary = "Captures the rank of a structured operation as parameter"; + let description = !strconcat([{ + Produces a transform dialect parameter value associated with an integer + attribute containing the rank of the structured payload operation associated + with the operand handle. + + }], StructuredPredicate.extraDescription, [{ + + #### Return modes + + Succeeds if the operand is associated with exactly one structured payload + operation. Produces a silenceable failure otherwise. + }]); + + let arguments = (ins TransformHandleTypeInterface:$operand_handle); + let results = (outs TransformParamTypeInterface:$rank); + let assemblyFormat = + "$operand_handle attr-dict `:`" + "custom(type($operand_handle), type($rank))"; + + let extraClassDeclaration = SingleOpMatcher.extraDeclaration; +} + +def MatchStructuredResultOp : Op { + let summary = "Captures the result of a structured payload operation in an " + "op or value handle"; + let description = !strconcat([{ + Produces a transform dialect value handle associated with the payload value + defined as a result of the payload operation associated with the operand + handle, or an operation handle to an operation using the produced result + with additional constraints specified by the attributes as follows. + + * If `any` is specified, binds the resulting handle to any operation using + the result and succeeds. + * If `single` is specified, binds the resulting handle to the only + operation using the result or fails if there is more than one (or no) + such operation. + + The number of the result is specified as `position` attribute. It may take + positive and negative values. 
Negative values are interpreted as counting + results from backwards, e.g., `-1` means the last result and `-2` means the + second-to-last result. In any case, the position must be in bounds for the + given payload operation. A silenceable failure is produced for out-of-bounds + positions. + + }], StructuredPredicate.extraDescription, [{ + + #### Return modes + + Succeeds if the position is in bounds and if the user operation could be + found when requested. Produces a silenceable failure otherwise. + }]); + let arguments = (ins TransformHandleTypeInterface:$operand_handle, + I64Attr:$position, + UnitAttr:$any, + UnitAttr:$single); + let results = (outs TransformAnyHandle:$result); + let assemblyFormat = + "$operand_handle `[` $position `]` (`any` $any^)? (`single` $single^)?" + "attr-dict `:` functional-type(operands, results)"; + let hasVerifier = 1; + + let extraClassDeclaration = SingleOpMatcher.extraDeclaration # [{ + ::mlir::DiagnosedSilenceableFailure + getPositionFor(::mlir::linalg::LinalgOp op, int64_t &position); + }]; +} + +def MatchStructuredYieldOp : Op, + Terminator]> { + let summary = "Terminator for transform.match.structured blocks"; + let description = [{ + Forwards the payload association from the operands to the results of the + parent op. Always succeeds. 
+ }]; + let builders = [ + OpBuilder<(ins)> + ]; + + let arguments = (ins Variadic:$handles); + let assemblyFormat = "$handles attr-dict (`:` type($handles)^)?"; +} + +#endif // LINALG_MATCH_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td @@ -0,0 +1,9 @@ +include "mlir/IR/EnumAttr.td" + +def MatchInterfaceEnum : I32EnumAttr<"MatchInterfaceEnum", "An interface to match", + [ + I32EnumAttrCase<"LinalgOp", 0>, + I32EnumAttrCase<"TilingInterface", 1> + ]>{ + let cppNamespace = "mlir::transform"; +} diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h @@ -11,6 +11,9 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/Transform/IR/MatchInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformAttrs.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/OpImplementation.h" @@ -39,15 +42,6 @@ } // namespace transform } // namespace mlir -//===----------------------------------------------------------------------===// -// Linalg Transform Operations -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.h.inc" - -#define GET_OP_CLASSES -#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h.inc" - namespace mlir { class DialectRegistry; @@ -61,6 +55,25 @@ ArrayRef mixedTileSizes, std::optional mapping, SmallVector &tileOps, 
SmallVector &tiledOps); +namespace detail { +LogicalResult verifyStructuredOpPredicateOpTrait(Operation *op, + Value structuredOpHandle); +} // namespace detail + +template +class StructuredOpPredicateOpTrait + : public OpTrait::TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + static_assert( + OpTy::template hasTrait(), + "StructuredOpPredicateOpTrait requires SingleOpMatcherOpTrait"); + + return detail::verifyStructuredOpPredicateOpTrait( + op, cast(op).getOperandHandle()); + } +}; + } // namespace transform namespace linalg { @@ -68,4 +81,16 @@ } // namespace linalg } // namespace mlir +//===----------------------------------------------------------------------===// +// Linalg Transform Operations +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.h.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h.inc" + #endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -9,13 +9,14 @@ #ifndef LINALG_TRANSFORM_OPS #define LINALG_TRANSFORM_OPS +include "mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td" +include "mlir/Dialect/Transform/IR/TransformAttrs.td" include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" include "mlir/Dialect/Transform/IR/TransformTypes.td" include "mlir/Dialect/PDL/IR/PDLTypes.td" include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/IR/EnumAttr.td" 
include "mlir/IR/OpBase.td" include "mlir/IR/RegionKindInterface.td" @@ -344,14 +345,6 @@ // MatchOp //===----------------------------------------------------------------------===// -def MatchInterfaceEnum : I32EnumAttr<"MatchInterfaceEnum", "An interface to match", - [ - I32EnumAttrCase<"LinalgOp", 0>, - I32EnumAttrCase<"TilingInterface", 1> - ]>{ - let cppNamespace = "mlir::transform"; -} - def MatchOp : Op(this->getOperation()) .matchOperation(payload[0], results, state); } + + void getEffects(SmallVectorImpl &effects) { + onlyReadsHandle(this->getOperation()->getOperands(), effects); + producesHandle(this->getOperation()->getResults(), effects); + onlyReadsPayload(effects); + } +}; + +template +class SingleValueMatcherOpTrait + : public OpTrait::TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + // This must be a dynamic assert because interface registration is dynamic. + assert(isa(op) && + "SingleValueMatchOpTrait is only available on operations with " + "MatchOpInterface"); + + Value operandHandle = cast(op).getOperandHandle(); + if (!operandHandle.getType().isa()) { + return op->emitError() << "SingleValueMatchOpTrait requires an operand " + "of TransformValueHandleTypeInterface"; + } + + return success(); + } + + DiagnosedSilenceableFailure apply(TransformResults &results, + TransformState &state) { + Value operandHandle = cast(this->getOperation()).getOperandHandle(); + ValueRange payload = state.getPayloadValues(operandHandle); + if (payload.size() != 1) { + return emitDefiniteFailure(this->getOperation()->getLoc()) + << "SingleValueMatchOpTrait requires the value handle to point to " + "a single payload value"; + } + + return cast(this->getOperation()) + .matchValue(payload[0], results, state); + } + + void getEffects(SmallVectorImpl &effects) { + onlyReadsHandle(this->getOperation()->getOperands(), effects); + producesHandle(this->getOperation()->getResults(), effects); + onlyReadsPayload(effects); + } }; } // namespace 
transform diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td --- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td @@ -14,6 +14,12 @@ let cppNamespace = "::mlir::transform"; } +// Trait for "matcher" transform operations that apply to an operation handle +// associated with exactly one payload operation. Checks that it is indeed +// the case and produces a definite failure when it is not. The matching logic +// is implemented in the `matchOperation` function instead of `apply`. The op +// with this trait must provide a `Value getOperandHandle()` function that +// returns the handle to be used for matching. def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> { let cppNamespace = "::mlir::transform"; @@ -24,3 +30,20 @@ ::mlir::transform::TransformState &state); }]; } + +// Trait for "matcher" transform operations that apply to a value handle +// associated with exactly one payload value. Checks that it is indeed +// the case and produces a definite failure when it is not. The matching logic +// is implemented in the `matchValue` function instead of `apply`. The op +// with this trait must provide a `Value getOperandHandle()` function that +// returns the handle to be used for matching. 
+def SingleValueMatcher : NativeOpTrait<"SingleValueMatcherOpTrait"> { + let cppNamespace = "::mlir::transform"; + + string extraDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure matchValue( + ::mlir::Value current, + ::mlir::transform::TransformResults &results, + ::mlir::transform::TransformState &state); + }]; +} diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h @@ -0,0 +1,20 @@ +//===- TransformAttrs.h - Transform Dialect Attributes ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS_H +#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" + +#include +#include + +#include "mlir/Dialect/Transform/IR/TransformDialectEnums.h.inc" + +#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS_H diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td @@ -20,4 +20,17 @@ let cppNamespace = "::mlir::transform"; } +def MatchCmpIPredicateAttr : I32EnumAttr< + "MatchCmpIPredicate", "", + [ + I32EnumAttrCase<"eq", 0>, + I32EnumAttrCase<"ne", 1>, + I32EnumAttrCase<"lt", 2>, + I32EnumAttrCase<"le", 3>, + I32EnumAttrCase<"gt", 4>, + I32EnumAttrCase<"ge", 5>, + ]> { + let cppNamespace = "::mlir::transform"; +} + #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS diff --git
a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h @@ -243,6 +243,4 @@ } // namespace transform } // namespace mlir -#include "mlir/Dialect/Transform/IR/TransformDialectEnums.h.inc" - #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -617,6 +617,10 @@ /// operations. void setMappedValues(OpResult handle, ArrayRef values); + /// Sets the currently unset results to empty lists of the kind expected by + /// the corresponding results of the given `transform` op. + void setRemainingToEmpty(TransformOpInterface transform); + private: /// Creates an instance of TransformResults that expects mappings for /// `numSegments` values, which may be associated with payload operations or diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -11,6 +11,7 @@ #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformAttrs.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/FunctionInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -378,6 +378,46 @@ }]; } 
+def MatchOperationNameOp : TransformDialectOp<"match.operation_name", + [SingleOpMatcher, + MatchOpInterface, + MemoryEffectsOpInterface]> { + let summary = "Matches a single operation of one of the given kinds"; + let description = [{ + Succeeds if the operation associated with the operand handle has one of the + given operation names. Produces a silenceable failure otherwise. + + If more than one payload operation is associated with the operand handle, + produces a definite failure. + }]; + + let arguments = (ins TransformHandleTypeInterface:$operand_handle, + StrArrayAttr:$op_names); + let assemblyFormat = + "$operand_handle $op_names attr-dict `:` type($operand_handle)"; + let extraClassDeclaration = SingleOpMatcher.extraDeclaration; +} + +def MatchParamCmpIOp : Op, + MatchOpInterface, + DeclareOpInterfaceMethods, + SameTypeOperands]> { + let summary = + "Matches if two parameter lists satisfy the given comparison predicate"; + let description = [{ + Succeeds if all of the co-indexed values associated with the given + parameters relate as specified by the predicate (greater than, less than, + equal to, or their combinations). Comparison treats all values as signed. + Produces a silenceable failure otherwise. + }]; + let arguments = (ins TransformParamTypeInterface:$param, + TransformParamTypeInterface:$reference, + MatchCmpIPredicateAttr:$predicate); + let assemblyFormat = + "$predicate $param `,` $reference attr-dict `:` type($param)"; +} + def MergeHandlesOp : TransformDialectOp<"merge_handles", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, @@ -501,6 +541,24 @@ }]; } +def ParamConstantOp : Op, + MemoryEffectsOpInterface, + ParamProducerTransformOpTrait]> { + let summary = "Produces a new transform dialect parameter value associated " + "with the given attribute"; + let description = [{ + Produces a new transform dialect parameter associated with the singleton + list containing the given attribute.
The operation itself always succeeds, + but the general association check may fail if the parameter type does not + accept the given kind of attribute as valid. + }]; + let arguments = (ins AnyAttr:$value); + let results = (outs TransformParamTypeInterface:$param); + let assemblyFormat = "$value attr-dict `->` type($param)"; +} + def PDLMatchOp : TransformDialectOp<"pdl_match", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td @@ -68,4 +68,9 @@ "Transform IR handle to " # opname # " operations", "::mlir::transform::OperationType">; +def TransformAnyHandle : Type< + Or<[TransformHandleTypeInterface.predicate, + TransformValueHandleTypeInterface.predicate]>, + "transform operation or value handle">; + #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt @@ -1,11 +1,14 @@ add_mlir_dialect_library(MLIRLinalgTransformOps + LinalgMatchOps.cpp LinalgTransformOps.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg/TransformOps DEPENDS + MLIRLinalgMatchOpsIncGen MLIRLinalgTransformOpsIncGen + MLIRLinalgTransformEnumsIncGen LINK_LIBS PUBLIC MLIRAffineDialect diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp @@ -0,0 +1,826 @@ +//===- LinalgMatchOps.cpp - Implementation of Linalg match ops ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
+#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/FunctionImplementation.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "linalg-transforms"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredOp
+//===----------------------------------------------------------------------===//
+
+/// Applies the nested matchers to `current` if it is a structured (Linalg)
+/// operation. With failure propagation (the default mode), `current` not being
+/// a Linalg op or the first silenceable failure of a nested matcher is
+/// reported as a silenceable failure. When failures are suppressed, the op
+/// succeeds instead: results whose values are already known are set and all
+/// remaining results are set to empty lists.
+DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  // First, check if the payload operation is a structured Linalg operation.
+  if (!isa(current)) {
+    if (getFailurePropagationMode().value_or(
+            FailurePropagationMode::Propagate) ==
+        FailurePropagationMode::Propagate) {
+      return emitSilenceableError() << "expected a Linalg op";
+    }
+    // If errors are suppressed, succeed and set all results to empty lists.
+    LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op\n");
+    results.setRemainingToEmpty(cast(getOperation()));
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // Bind `current` to the block argument.
+  auto scope = state.make_region_scope(getBodyRegion());
+  if (failed(state.mapBlockArgument(getBody()->getArgument(0),
+                                    MappedValue(current)))) {
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  for (Operation &nested : getBody()->without_terminator()) {
+    DiagnosedSilenceableFailure diag =
+        state.applyTransform(cast(nested));
+    if (diag.isDefiniteFailure())
+      return diag;
+    if (diag.succeeded())
+      continue;
+
+    // If propagating errors, do this immediately.
+    assert(diag.isSilenceableFailure());
+    if (getFailurePropagationMode().value_or(
+            FailurePropagationMode::Propagate) ==
+        FailurePropagationMode::Propagate) {
+      return diag;
+    }
+
+    // If suppressing errors, print the message into the debug stream before
+    // silencing it. Then set all result values that are already known.
+    // Results come from the terminator operands, which may be defined in the
+    // (single) block of this operation or above it. When they are defined
+    // above, they are known to be mapped at this point per SSA dominance.
+    // When they are defined in this block, we additionally check if we have
+    // already applied the operation that defines them. If not, the
+    // corresponding results will be set to empty lists.
+    LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage()
+                      << "\n");
+    (void)diag.silence();
+    SmallVector undefinedOperands;
+    for (OpOperand &terminatorOperand :
+         getBody()->getTerminator()->getOpOperands()) {
+      Operation *definingOp = terminatorOperand.get().getDefiningOp();
+      if (!definingOp)
+        continue;
+      if (definingOp->getBlock() != getBody())
+        continue;
+      if (definingOp->isBeforeInBlock(&nested))
+        continue;
+
+      undefinedOperands.push_back(&terminatorOperand);
+    }
+
+    SmallVector> mappings;
+    auto filtered = llvm::make_filter_range(
+        getBody()->getTerminator()->getOpOperands(), [&](OpOperand &opOperand) {
+          return !llvm::is_contained(undefinedOperands, &opOperand);
+        });
+    SmallVector definedOperands = llvm::to_vector(llvm::map_range(
+        filtered, [](OpOperand &opOperand) { return opOperand.get(); }));
+    detail::prepareValueMappings(mappings, definedOperands, state);
+    for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
+      results.setMappedValues(getResults()[operand.getOperandNumber()],
+                              mapping);
+    }
+    results.setRemainingToEmpty(cast(getOperation()));
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // Set the results.
+  detail::forwardTerminatorOperands(getBody(), state, results);
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// The op only reads its operand handle and the payload, and produces the
+/// handles it returns.
+void transform::MatchStructuredOp::getEffects(
+    SmallVectorImpl &effects) {
+  onlyReadsHandle(getCurrent(), effects);
+  onlyReadsPayload(effects);
+  producesHandle(getOutputs(), effects);
+}
+
+/// Verifies that the body has exactly one argument, that the argument type
+/// implements TransformHandleTypeInterface, and that every nested
+/// non-terminator operation implements MatchOpInterface.
+LogicalResult transform::MatchStructuredOp::verify() {
+  if (getBody()->getNumArguments() != 1)
+    return emitOpError() << "expected one body argument";
+  if (!isa(getBody()->getArgument(0).getType())) {
+    return emitOpError() << "expected body argument to implement "
+                            "TransformHandleTypeInterface";
+  }
+  for (Operation &nested : getBody()->without_terminator()) {
+    if (isa(nested))
+      continue;
+    InFlightDiagnostic diag =
+        emitOpError()
+        << "expects nested operations to implement MatchOpInterface";
+    diag.attachNote(nested.getLoc()) << "offending operation";
+    return diag;
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// StructuredOpPredicateOpTrait
+//===----------------------------------------------------------------------===//
+
+/// Verifies that `op` is immediately nested in a MatchStructuredOp and that
+/// `structuredOpHandle` is the first block argument of that parent's body.
+LogicalResult transform::detail::verifyStructuredOpPredicateOpTrait(
+    Operation *op, Value structuredOpHandle) {
+  if (!isa_and_nonnull(op->getParentOp())) {
+    return op->emitOpError() << "expects parent op to be '"
+                             << MatchStructuredOp::getOperationName() << "'";
+  }
+
+  // Bail out here, let the verifier of the parent complain.
+  Operation *parent = op->getParentOp();
+  if (parent->getNumRegions() < 1 || parent->getRegion(0).empty() ||
+      parent->getRegion(0).front().getNumArguments() < 1)
+    return success();
+
+  if (structuredOpHandle != parent->getRegion(0).front().getArgument(0)) {
+    return op->emitOpError()
+           << "expected predicate to apply to the surrounding structured op";
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredBodyOp
+//===----------------------------------------------------------------------===//
+
+/// Checks the body of the payload Linalg op. With a reduction position, it
+/// succeeds if the reduction for that output is matched and has a single
+/// combiner op. With `passthrough`, it succeeds if the terminator yields the
+/// input block arguments unchanged.
+DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  auto linalgOp = cast(current);
+  if (std::optional position = getReductionPosition()) {
+    SmallVector combinerOps;
+    if (!matchReduction(linalgOp.getRegionOutputArgs(), *position,
+                        combinerOps)) {
+      return emitSilenceableError() << "could not match reduction";
+    }
+    if (combinerOps.size() != 1) {
+      return emitSilenceableError() << "reduction combiner is not a single op";
+    }
+    return DiagnosedSilenceableFailure::success();
+  }
+  if (getPassthrough()) {
+    Block &body = linalgOp->getRegion(0).front();
+    if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) {
+      return emitSilenceableError() << "not a passthrough";
+    }
+    return DiagnosedSilenceableFailure::success();
+  }
+  return emitDefiniteFailure() << "unknown body condition";
+}
+
+/// Verifies that the reduction and passthrough conditions are not both
+/// requested.
+LogicalResult transform::MatchStructuredBodyOp::verify() {
+  if (getReductionPosition() && getPassthrough()) {
+    return emitOpError() << "reduction position and passthrough conditions are "
+                            "mutually exclusive";
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Utilities for structured match predicates.
+//===----------------------------------------------------------------------===//
+
+/// Checks if all values from `list` are also contained in `reference`. Returns
+/// a silenceable error with the given message at the given location when it is
+/// not the case. The error message must contain the "{0}" placeholder that
+/// will be substituted with the value from `list` that is not contained in
+/// `reference`.
+static DiagnosedSilenceableFailure containsAll(ArrayRef reference,
+                                               ArrayRef list,
+                                               Location loc,
+                                               const char *message) {
+  for (int64_t value : list) {
+    if (llvm::any_of(reference, [&](unsigned ref) {
+          return static_cast(ref) == value;
+        })) {
+      continue;
+    }
+    return emitSilenceableFailure(loc) << llvm::formatv(message, value);
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Populates `result` with the positional identifiers relative to `maxNumber`.
+/// If `isAll` is set, the result will contain all numbers from `0` to
+/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
+/// values from `rawList` are interpreted as counting backwards from
+/// `maxNumber`, i.e., `-1` is interpreted as `maxNumber - 1`, while positive
+/// numbers remain as is. If `isInverted` is set, populates `result` with those
+/// values from the `0` to `maxNumber - 1` inclusive range that don't appear in
+/// `rawList`. If `rawList` contains values that are greater than or equal to
+/// `maxNumber` or less than `-maxNumber`, produces a silenceable error at the
+/// given location. `maxNumber` must be non-negative; when it is zero, `isAll`
+/// yields an empty `result` and any explicitly listed position is reported as
+/// out of range. If `rawList` contains duplicate numbers or numbers that
+/// become duplicate after negative value remapping, emits a silenceable error.
+static DiagnosedSilenceableFailure
+expandTargetSpecification(Location loc, bool isAll, bool isInverted,
+                          ArrayRef rawList, int64_t maxNumber,
+                          SmallVectorImpl &result) {
+  // Zero is a legitimate count here: e.g., a linalg.generic may have no DPS
+  // inputs, and a 0-d one has no loops. The code below handles it correctly.
+  assert(maxNumber >= 0 && "expected size to be non-negative");
+  assert(!(isAll && isInverted) && "cannot invert all");
+  if (isAll) {
+    result = llvm::to_vector(llvm::seq(0, maxNumber));
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  SmallVector expanded;
+  llvm::SmallDenseSet visited;
+  expanded.reserve(rawList.size());
+  // When inverting, collect the listed positions separately so that their
+  // complement can be computed below; otherwise write them out directly.
+  SmallVectorImpl &target = isInverted ? expanded : result;
+  for (int64_t raw : rawList) {
+    int64_t updated = raw < 0 ? maxNumber + raw : raw;
+    if (updated >= maxNumber) {
+      return emitSilenceableFailure(loc)
+             << "position overflow " << updated << " (updated from " << raw
+             << ") for maximum " << maxNumber;
+    }
+    if (updated < 0) {
+      return emitSilenceableFailure(loc) << "position underflow " << updated
+                                         << " (updated from " << raw << ")";
+    }
+    if (!visited.insert(updated).second) {
+      return emitSilenceableFailure(loc) << "repeated position " << updated
+                                         << " (updated from " << raw << ")";
+    }
+    target.push_back(updated);
+  }
+
+  if (!isInverted)
+    return DiagnosedSilenceableFailure::success();
+
+  // `visited` holds exactly the positions stored in `expanded`; query it for
+  // constant-time membership instead of scanning `expanded` per candidate.
+  result.reserve(result.size() + (maxNumber - expanded.size()));
+  for (int64_t candidate : llvm::seq(0, maxNumber)) {
+    if (visited.contains(candidate))
+      continue;
+    result.push_back(candidate);
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Checks if the positional specification defined is valid and reports errors
+/// otherwise.
+/// The specification is rejected if `all` is combined with `inverted` or with
+/// explicit values, if neither `all` nor explicit values are given, or if the
+/// explicit values are not unique. Only syntactic duplicates are detected
+/// here; duplicates arising from negative-index remapping are reported when
+/// matching.
+LogicalResult verifyStructuredTransformDimsOp(Operation *op,
+                                              ArrayRef raw,
+                                              bool inverted, bool all) {
+  if (all) {
+    if (inverted) {
+      return op->emitOpError()
+             << "cannot request both 'all' and 'inverted' values in the list";
+    }
+    if (!raw.empty()) {
+      return op->emitOpError()
+             << "cannot both request 'all' and specific values in the list";
+    }
+  }
+  if (!all && raw.empty()) {
+    return op->emitOpError() << "must request specific values in the list if "
+                                "'all' is not specified";
+  }
+  // `std::unique` only collapses *adjacent* duplicates, so sort the copy first
+  // to detect repeated values regardless of where they appear in the list.
+  SmallVector rawVector = llvm::to_vector(raw);
+  llvm::sort(rawVector);
+  auto *it = std::unique(rawVector.begin(), rawVector.end());
+  if (it != rawVector.end())
+    return op->emitOpError() << "expected the listed values to be unique";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredDimOp
+//===----------------------------------------------------------------------===//
+
+/// Checks the selected loop dimensions of the payload Linalg op against the
+/// optional parallel/reduction requirement and, if a result is requested,
+/// captures the values reported by `getStaticLoopRanges` for these dimensions
+/// as i64 integer attributes.
+DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  auto linalgOp = cast(current);
+  SmallVector dimensions;
+  DiagnosedSilenceableFailure diag = getDimensionsFor(linalgOp, dimensions);
+  if (!diag.succeeded())
+    return diag;
+
+  // If asked to check for the kind of dimension, perform the check.
+  if (getParallel() || getReduction()) {
+    SmallVector reference;
+    if (getParallel())
+      linalgOp.getParallelDims(reference);
+    else
+      linalgOp.getReductionDims(reference);
+
+    DiagnosedSilenceableFailure kindDiag =
+        containsAll(reference, dimensions, getLoc(),
+                    getParallel() ? "expects dimension #{0} to be parallel"
+                                  : "expects dimension #{0} to be reduction");
+    if (!kindDiag.succeeded())
+      return kindDiag;
+  }
+
+  // If not capturing, we are done here.
+  if (!getResult())
+    return DiagnosedSilenceableFailure::success();
+
+  SmallVector ranges = linalgOp.getStaticLoopRanges();
+  Builder builder(current);
+  SmallVector captured = llvm::to_vector(
+      llvm::map_range(dimensions, [&](int64_t dim) -> Attribute {
+        return builder.getI64IntegerAttr(ranges[dim]);
+      }));
+  results.setParams(cast(getResult()), captured);
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Expands the positional specification of this op into concrete loop
+/// dimensions of `op`, attaching a note pointing at the payload op on error.
+DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor(
+    linalg::LinalgOp op, SmallVectorImpl &dims) {
+  DiagnosedSilenceableFailure diag =
+      expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(),
+                                getRawDimList(), op.getNumLoops(), dims);
+  if (diag.isSilenceableFailure()) {
+    diag.attachNote(op->getLoc())
+        << "while considering dimensions of this payload operation";
+  }
+  return diag;
+}
+
+/// Verifies that parallel and reduction are not requested together, then
+/// verifies the positional specification.
+LogicalResult transform::MatchStructuredDimOp::verify() {
+  if (getParallel() && getReduction()) {
+    return emitOpError() << "cannot request the same dimension to be both "
+                            "parallel and reduction";
+  }
+  return verifyStructuredTransformDimsOp(getOperation(), getRawDimList(),
+                                         getIsInverted(), getIsAll());
+}
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredElementalBitwidthOp
+//===----------------------------------------------------------------------===//
+
+/// Captures the bitwidth of the payload value's type as an i64 attribute. For
+/// integer and float types, the type itself is used. For shaped types with
+/// integer or float elements, the element type is used. Any other type
+/// produces a silenceable failure.
+DiagnosedSilenceableFailure
+transform::MatchStructuredElementalBitwidthOp::matchValue(
+    Value current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  auto setupResult = [&](int64_t bitwidth) {
+    Attribute attr = Builder(current.getContext()).getI64IntegerAttr(bitwidth);
+    results.setParams(cast(getResult()), {attr});
+    return DiagnosedSilenceableFailure::success();
+  };
+
+  Type type = current.getType();
+  if (type.isIntOrFloat())
+    return setupResult(type.getIntOrFloatBitWidth());
+
+  if (auto shapedType = dyn_cast(type)) {
+    if (shapedType.getElementType().isIntOrFloat())
+      return setupResult(shapedType.getElementTypeBitWidth());
+  }
+  return emitSilenceableError()
+         << "unsupported type for bitwidth extraction: " << type;
+}
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredInputOp
+//===----------------------------------------------------------------------===//
+
+/// Checks the indexing maps of the selected DPS inputs of the payload Linalg
+/// op against the optional (projected) permutation requirement. If a result
+/// is requested, it captures either the input values or their defining
+/// operations, depending on the result type.
+DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  auto linalgOp = cast(current);
+  SmallVector positions;
+  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
+  if (!diag.succeeded())
+    return diag;
+
+  SmallVector operandMapping;
+  operandMapping.reserve(positions.size());
+  for (int64_t position : positions) {
+    AffineMap indexingMap =
+        linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position));
+    if (getPermutation() && !indexingMap.isPermutation()) {
+      return emitSilenceableError() << "the indexing map for input #"
+                                    << position << " is not a permutation";
+    }
+    if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
+      return emitSilenceableError()
+             << "the indexing map for input #" << position
+             << " is not a projected permutation";
+    }
+
+    // If capture not requested, skip it.
+    if (!getResult())
+      continue;
+
+    Value operand = linalgOp.getDpsInputOperand(position)->get();
+    if (isa(getResult().getType())) {
+      operandMapping.emplace_back(operand);
+      continue;
+    }
+
+    // Capturing an operation requires the operand to have a defining op.
+    Operation *operandProducer = operand.getDefiningOp();
+    if (!operandProducer) {
+      return emitSilenceableError()
+             << "input #" << position << " is not produced by an operation";
+    }
+    operandMapping.emplace_back(operandProducer);
+  }
+  if (getResult())
+    results.setMappedValues(cast(getResult()), operandMapping);
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Expands the positional specification of this op into concrete DPS input
+/// positions of `op`, attaching a note pointing at the payload op on error.
+DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor(
+    linalg::LinalgOp op, SmallVectorImpl &positions) {
+  DiagnosedSilenceableFailure diag = expandTargetSpecification(
+      getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
+      op.getNumDpsInputs(), positions);
+  if (diag.isSilenceableFailure()) {
+    diag.attachNote(op->getLoc())
+        << "while considering DPS inputs of this payload operation";
+  }
+  return diag;
+}
+
+/// Verifies a matcher op for structured input or output (init). It checks
+/// that the permutation attributes are mutually exclusive and that a result
+/// is not requested together with more than one explicitly listed position.
+template
+LogicalResult verifyStructuredOperandOp(OpTy op) {
+  if (op.getPermutation() && op.getProjectedPermutation()) {
+    return op.emitOpError()
+           << op.getPermutationAttrName() << " and "
+           << op.getProjectedPermutationAttrName() << " are mutually exclusive";
+  }
+  // Capturing is rejected when more than one position is listed explicitly.
+  if (op.getRawPositionList().size() > 1 && op.getResult()) {
+    return op.emitOpError()
+           << "cannot bind multiple inputs/inits to the same value";
+  }
+
+  return success();
+}
+
+/// Verifies the operand-related attributes, then the positional specification.
+LogicalResult transform::MatchStructuredInputOp::verify() {
+  if (failed(verifyStructuredOperandOp(*this)))
+    return failure();
+  return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
+                                         getIsInverted(), getIsAll());
+}
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredInitOp
+//===----------------------------------------------------------------------===//
+
+/// Checks the indexing maps of the selected DPS inits (outputs) of the payload
+/// Linalg op against the optional (projected) permutation requirement. If a
+/// result is requested, it captures either the init values or their defining
+/// operations, depending on the result type.
+DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  auto linalgOp = cast(current);
+  SmallVector positions;
+  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
+  if (!diag.succeeded())
+    return diag;
+
+  SmallVector operandMapping;
+  operandMapping.reserve(positions.size());
+  for (int64_t position : positions) {
+    AffineMap indexingMap =
+        linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position));
+    if (getPermutation() && !indexingMap.isPermutation()) {
+      return emitSilenceableError() << "the indexing map for output(init) #"
+                                    << position << " is not a permutation";
+    }
+    if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
+      return emitSilenceableError()
+             << "the indexing map for output(init) #" << position
+             << " is not a projected permutation";
+    }
+
+    // If capture not requested, skip it.
+    if (!getResult())
+      continue;
+
+    Value operand = linalgOp.getDpsInitOperand(position)->get();
+    if (isa(getResult().getType())) {
+      operandMapping.emplace_back(operand);
+      continue;
+    }
+
+    // Capturing an operation requires the operand to have a defining op.
+    Operation *operandProducer = operand.getDefiningOp();
+    if (!operandProducer) {
+      return emitSilenceableError() << "output(init) #" << position
+                                    << " is not produced by an operation";
+    }
+    operandMapping.emplace_back(operandProducer);
+  }
+  if (getResult())
+    results.setMappedValues(cast(getResult()), operandMapping);
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Expands the positional specification of this op into concrete DPS init
+/// positions of `op`, attaching a note pointing at the payload op on error.
+DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
+    linalg::LinalgOp op, SmallVectorImpl &positions) {
+  DiagnosedSilenceableFailure diag = expandTargetSpecification(
+      getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
+      op.getNumDpsInits(), positions);
+  if (diag.isSilenceableFailure()) {
+    diag.attachNote(op->getLoc())
+        << "while considering DPS inits (outputs) of this payload operation";
+  }
+  return diag;
+}
+
+/// Verifies the operand-related attributes, then the positional specification.
+LogicalResult transform::MatchStructuredInitOp::verify() {
+  if (failed(verifyStructuredOperandOp(*this)))
+    return failure();
+  return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
+                                         getIsInverted(), getIsAll());
+}
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredNumInputsOp
+//===----------------------------------------------------------------------===//
+
+/// Captures the number of DPS inputs of the payload Linalg op as an i64
+/// integer attribute.
+DiagnosedSilenceableFailure
+transform::MatchStructuredNumInputsOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  auto linalgOp = cast(current);
+  Attribute attr =
+      Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInputs());
+  results.setParams(cast(getResult()), {attr});
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredNumInitsOp
+//===----------------------------------------------------------------------===//
+
+/// Captures the number of DPS inits (outputs) of the payload Linalg op as an
+/// i64 integer attribute.
+DiagnosedSilenceableFailure
+transform::MatchStructuredNumInitsOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  auto linalgOp = cast(current);
+  Attribute attr =
+      Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInits());
+  results.setParams(cast(getResult()), {attr});
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredRankOp
+//===----------------------------------------------------------------------===//
+
+/// Captures the number of loops of the payload Linalg op as an i64 integer
+/// attribute.
+DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  auto linalgOp = cast(current);
+  int64_t numLoops = linalgOp.getNumLoops();
+  Attribute attr = Builder(linalgOp->getContext()).getI64IntegerAttr(numLoops);
+  results.setParams(cast(getRank()), {attr});
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredResultOp
+//===----------------------------------------------------------------------===//
+
+/// Matches the payload op result tied to the DPS init at the requested
+/// position. In the branch that calls `setValues`, the result value itself is
+/// captured. Otherwise one user of the result is captured: the first user
+/// with `any`, or the only user with `single`. A result without users, or
+/// several users with `single`, is a silenceable failure.
+DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
+    Operation *op, transform::TransformResults &results,
+    transform::TransformState &state) {
+  auto linalgOp = cast(op);
+  int64_t position;
+  DiagnosedSilenceableFailure diag = getPositionFor(linalgOp, position);
+  if (!diag.succeeded())
+    return diag;
+
+  Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
+  // Capture the result value itself; no user is needed in this case.
+  if (getResult().getType().isa()) {
+    results.setValues(cast(getResult()), result);
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  if (result.getUsers().empty()) {
+    return emitSilenceableError()
+           << "no users of the result #" << getPosition();
+  }
+  Operation *firstUser = *result.getUsers().begin();
+  if (getAny()) {
+    results.set(cast(getResult()), firstUser);
+    return DiagnosedSilenceableFailure::success();
+  }
+  if (getSingle()) {
+    if (!llvm::hasSingleElement(result.getUsers())) {
+      return emitSilenceableError()
+             << "more than one result user with single user requested";
+    }
+    results.set(cast(getResult()), firstUser);
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  return emitDefiniteFailure() << "unknown sub-predicate";
+}
+
+/// Normalizes the requested (possibly negative) position against the number
+/// of DPS inits of `op`. Negative values count from the end. An out-of-range
+/// position produces a silenceable error.
+DiagnosedSilenceableFailure
+transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
+                                                   int64_t &position) {
+  auto rawPosition = static_cast(getPosition());
+  position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition;
+  if (position >= op.getNumDpsInits() || position < 0) {
+    return emitSilenceableError()
+           << "position " << rawPosition
+           << " overflows the number of results(ints) of the payload operation";
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Verifies how the any/single keywords combine with the result type, and that
+/// `any` and `single` are not both set.
+LogicalResult transform::MatchStructuredResultOp::verify() {
+  if ((getAny() || getSingle()) ^
+      getResult().getType().isa()) {
+    return emitOpError() << "expects either the any/single keyword or the type "
+                            "value handle result type";
+  }
+  if (getAny() && getSingle()) {
+    return emitOpError() << "'any' and 'single' are mutually exclusive";
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MatchStructuredYieldOp
+//===----------------------------------------------------------------------===//
+
+/// The terminator only reads the handles it yields and the payload.
+void transform::MatchStructuredYieldOp::getEffects(
+    SmallVectorImpl &effects) {
+  onlyReadsHandle(getHandles(), effects);
+  onlyReadsPayload(effects);
+}
+
+/// Builds a terminator that yields no handles.
+void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
+                                              OperationState &state) {
+  build(builder, state, ValueRange());
+}
+
+//===----------------------------------------------------------------------===//
+// Printing and parsing for structured match ops.
+//===----------------------------------------------------------------------===//
+
+/// Keyword syntax for positional specification inversion.
+constexpr const static llvm::StringLiteral kDimExceptKeyword = "except";
+
+/// Keyword syntax for full inclusion in positional specification.
+constexpr const static llvm::StringLiteral kDimAllKeyword = "all";
+
+/// Parses a positional specification for structured transform operations. The
+/// following forms are accepted:
+///
+///   - `all`: sets `isAll` and returns;
+///   - comma-separated-integer-list: populates `rawDimList` with the values;
+///   - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
+///     with the values and sets `isInverted`.
+///
+/// Flags that do not apply to the parsed form are reset to null. With `all`,
+/// `rawDimList` is set to an empty array.
+static ParseResult parseStructuredTransformDims(OpAsmParser &parser,
+                                                DenseI64ArrayAttr &rawDimList,
+                                                UnitAttr &isInverted,
+                                                UnitAttr &isAll) {
+  Builder &builder = parser.getBuilder();
+  if (parser.parseOptionalKeyword(kDimAllKeyword).succeeded()) {
+    rawDimList = builder.getDenseI64ArrayAttr({});
+    isInverted = nullptr;
+    isAll = builder.getUnitAttr();
+    return success();
+  }
+
+  isAll = nullptr;
+  isInverted = nullptr;
+  if (parser.parseOptionalKeyword(kDimExceptKeyword).succeeded()) {
+    isInverted = builder.getUnitAttr();
+  }
+
+  // The inverted form wraps the list in parentheses.
+  if (isInverted) {
+    if (parser.parseLParen().failed())
+      return failure();
+  }
+
+  SmallVector values;
+  ParseResult listResult = parser.parseCommaSeparatedList(
+      [&]() { return parser.parseInteger(values.emplace_back()); });
+  if (listResult.failed())
+    return failure();
+
+  rawDimList = builder.getDenseI64ArrayAttr(values);
+
+  if (isInverted) {
+    if (parser.parseRParen().failed())
+      return failure();
+  }
+  return success();
+}
+
+/// Prints a positional specification for structured transform operations.
+static void printStructuredTransformDims(OpAsmPrinter &printer, Operation *op,
+                                         DenseI64ArrayAttr rawDimList,
+                                         UnitAttr isInverted, UnitAttr isAll) {
+  // `all` takes precedence; the raw list is not printed in that case.
+  if (isAll) {
+    printer << kDimAllKeyword;
+    return;
+  }
+  // An inverted list is printed as `except(...)`.
+  if (isInverted) {
+    printer << kDimExceptKeyword << "(";
+  }
+  llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
+                        [&](int64_t value) { printer << value; });
+  if (isInverted) {
+    printer << ")";
+  }
+}
+
+/// Parses a single non-function type or a function type with exactly one
+/// argument. This allows for the following syntax:
+///
+///   - type: just the argument type;
+///   - `(` type `)` `->` type: one argument and one result type;
+///   - `(` type `)` `->` `(` comma-separated-type-list `)`: one argument and
+///     multiple result types.
+///
+/// Unlike FunctionType, this allows and requires one to omit the parens around
+/// the argument type in absence of result types, and does not accept the
+/// trailing `-> ()` construct, which makes the syntax nicer for operations.
+static ParseResult parseSemiFunctionType(OpAsmParser &parser,
+                                         Type &argumentType, Type &resultType) {
+  argumentType = resultType = nullptr;
+  // A leading `(` selects the function-like form that carries a result type.
+  bool hasLParen = parser.parseOptionalLParen().succeeded();
+  if (parser.parseType(argumentType).failed())
+    return failure();
+  if (!hasLParen)
+    return success();
+
+  return failure(parser.parseRParen().failed() ||
+                 parser.parseArrow().failed() ||
+                 parser.parseType(resultType).failed());
+}
+
+/// Variant of the above that accepts any number of result types. A single
+/// result type may be written bare; a parenthesized list may hold several.
+/// If parsing the parenthesized result list fails, `resultTypes` is cleared.
+static ParseResult parseSemiFunctionType(OpAsmParser &parser,
+                                         Type &argumentType,
+                                         SmallVectorImpl &resultTypes) {
+  argumentType = nullptr;
+  bool hasLParen = parser.parseOptionalLParen().succeeded();
+  if (parser.parseType(argumentType).failed())
+    return failure();
+  if (!hasLParen)
+    return success();
+
+  if (parser.parseRParen().failed() || parser.parseArrow().failed())
+    return failure();
+
+  // A bare (unparenthesized) result is a single type.
+  if (parser.parseOptionalLParen().failed()) {
+    Type type;
+    if (parser.parseType(type).failed())
+      return failure();
+    resultTypes.push_back(type);
+    return success();
+  }
+  if (parser.parseTypeList(resultTypes).failed() ||
+      parser.parseRParen().failed()) {
+    resultTypes.clear();
+    return failure();
+  }
+  return success();
+}
+
+/// Prints argument and result types in a syntax similar to that of FunctionType
+/// but allowing and requiring one to omit the parens around the argument type
+/// in absence of result types, and without the trailing `-> ()`.
+static void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
+                                  Type argumentType, TypeRange resultType) {
+  if (!resultType.empty())
+    printer << "(";
+  printer << argumentType;
+  if (resultType.empty())
+    return;
+  printer << ") -> ";
+
+  // Several result types are parenthesized; a single one is printed bare.
+  if (resultType.size() > 1)
+    printer << "(";
+  llvm::interleaveComma(resultType, printer.getStream());
+  if (resultType.size() > 1)
+    printer << ")";
+}
+
+/// Single-result convenience overload. A null `resultType` is printed as the
+/// absence of results, i.e., as the argument type alone.
+static void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
+                                  Type argumentType, Type resultType) {
+  return printSemiFunctionType(printer, op, argumentType,
+                               resultType ? TypeRange(resultType)
+                                          : TypeRange());
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3258,6 +3258,10 @@
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
         >();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
+        >();
   }
 };
 } // namespace
diff --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRTransformDialect
   MatchInterfaces.cpp
+  TransformAttrs.cpp
   TransformDialect.cpp
   TransformInterfaces.cpp
   TransformOps.cpp
diff --git a/mlir/lib/Dialect/Transform/IR/TransformAttrs.cpp b/mlir/lib/Dialect/Transform/IR/TransformAttrs.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/IR/TransformAttrs.cpp
@@ -0,0 +1,12 @@
+//===- TransformAttrs.cpp - Transform Dialect Attribute Definitions -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/IR/TransformAttrs.h" +#include "mlir/IR/BuiltinTypes.h" + +#include "mlir/Dialect/Transform/IR/TransformDialectEnums.cpp.inc" diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -186,5 +186,3 @@ return emitError(op->getLoc()) << "unknown attribute: " << attribute.getName(); } - -#include "mlir/Dialect/Transform/IR/TransformDialectEnums.cpp.inc" diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -840,19 +840,8 @@ // If a silenceable failure was produced, some results may be unset, set them // to empty lists. - if (result.isSilenceableFailure()) { - for (OpResult opResult : transform->getResults()) { - if (results.isSet(opResult.getResultNumber())) - continue; - - if (opResult.getType().isa()) - results.setParams(opResult, {}); - else if (opResult.getType().isa()) - results.setValues(opResult, {}); - else - results.set(opResult, {}); - } - } + if (result.isSilenceableFailure()) + results.setRemainingToEmpty(transform); // Remove the mapping for the operand if it is consumed by the operation. This // allows us to catch use-after-free with assertions later on. 
@@ -1058,6 +1047,14 @@
  (void)diag.silence();
 }
 
+void transform::TransformResults::setRemainingToEmpty(
+    transform::TransformOpInterface transform) {
+  for (OpResult opResult : transform->getResults()) {
+    if (!isSet(opResult.getResultNumber()))
+      setMappedValues(opResult, {});
+  }
+}
+
 ArrayRef<Operation *>
 transform::TransformResults::get(unsigned resultNumber) const {
   assert(resultNumber < operations.size() &&
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/PDL/IR/PDLOps.h"
 #include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
@@ -531,8 +532,9 @@
         if (diag.isDefiniteFailure())
           return WalkResult::interrupt();
         if (diag.isSilenceableFailure()) {
-          DEBUG_MATCHER(DBGS_MATCHER()
-                        << "matcher " << matcher.getName() << " failed\n");
+          DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
+                                       << " failed: " << diag.getMessage()
+                                       << "\n");
           continue;
         }
 
@@ -1171,6 +1173,118 @@
       .checkAndReport();
 }
 
+//===----------------------------------------------------------------------===//
+// MatchOperationNameOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  StringRef currentOpName = current->getName().getStringRef();
+  for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
+    if (acceptedAttr.getValue() == currentOpName)
+      return DiagnosedSilenceableFailure::success();
+  }
+  return emitSilenceableError() << "wrong operation name";
+}
+
+//===----------------------------------------------------------------------===//
+// MatchParamCmpIOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::MatchParamCmpIOp::apply(transform::TransformResults &results,
+                                   transform::TransformState &state) {
+  auto signedAPIntAsString = [](const APInt &value) {
+    std::string str;
+    llvm::raw_string_ostream os(str);
+    value.print(os, /*isSigned=*/true);
+    return os.str();
+  };
+
+  ArrayRef<Attribute> params = state.getParams(getParam());
+  ArrayRef<Attribute> references = state.getParams(getReference());
+
+  if (params.size() != references.size()) {
+    return emitSilenceableError()
+           << "parameters have different payload lengths (" << params.size()
+           << " vs " << references.size() << ")";
+  }
+
+  for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
+    auto intAttr = param.dyn_cast<IntegerAttr>();
+    auto refAttr = reference.dyn_cast<IntegerAttr>();
+    if (!intAttr || !refAttr) {
+      return emitDefiniteFailure()
+             << "non-integer parameter value not expected";
+    }
+    if (intAttr.getType() != refAttr.getType()) {
+      return emitDefiniteFailure()
+             << "mismatching integer attribute types in parameter #" << i;
+    }
+    APInt value = intAttr.getValue();
+    APInt refValue = refAttr.getValue();
+
+    // TODO: this copy will not be necessary in C++20.
+    int64_t position = i;
+    auto reportError = [&](StringRef direction) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError() << "expected parameter to be " << direction
+                                 << " " << signedAPIntAsString(refValue)
+                                 << ", got " << signedAPIntAsString(value);
+      diag.attachNote(getParam().getLoc())
+          << "value # " << position
+          << " associated with the parameter defined here";
+      return diag;
+    };
+
+    switch (getPredicate()) {
+    case MatchCmpIPredicate::eq:
+      if (value.eq(refValue))
+        break;
+      return reportError("equal to");
+    case MatchCmpIPredicate::ne:
+      if (value.ne(refValue))
+        break;
+      return reportError("not equal to");
+    case MatchCmpIPredicate::lt:
+      if (value.slt(refValue))
+        break;
+      return reportError("less than");
+    case MatchCmpIPredicate::le:
+      if (value.sle(refValue))
+        break;
+      return reportError("less than or equal to");
+    case MatchCmpIPredicate::gt:
+      if (value.sgt(refValue))
+        break;
+      return reportError("greater than");
+    case MatchCmpIPredicate::ge:
+      if (value.sge(refValue))
+        break;
+      return reportError("greater than or equal to");
+    }
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::MatchParamCmpIOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getParam(), effects);
+  onlyReadsHandle(getReference(), effects);
+}
+
+//===----------------------------------------------------------------------===//
+// ParamConstantOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ParamConstantOp::apply(transform::TransformResults &results,
+                                  transform::TransformState &state) {
+  results.setParams(cast<OpResult>(getParam()), {getValue()});
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // MergeHandlesOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir @@ -0,0 +1,754 @@ +// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{debug-payload-root-tag=start_here})" --split-input-file --verify-diagnostics + +module attributes { transform.with_named_sequence } { + transform.named_sequence @print_structured(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "structured" : !transform.any_op + transform.yield + } + + transform.named_sequence @match_structured_empty(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + %0 = transform.match.structured %arg0 : (!transform.any_op) -> !transform.any_op { + ^bb0(%arg1: !transform.any_op): + transform.match.structured.yield %arg1 : !transform.any_op + } + transform.yield %0 : !transform.any_op + } + + // Entry point. Match any structured operation and emit a remark.
+ transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { + ^bb0(%arg0: !transform.any_op): + transform.foreach_match in %arg0 + @match_structured_empty -> @print_structured + : (!transform.any_op) -> !transform.any_op + } + + func.func @payload() attributes { transform.target_tag = "start_here" } { + %preA = tensor.empty() : tensor<2x3xf32> + %cA = arith.constant 1.0 : f32 + // expected-remark @below {{structured}} + %A = linalg.fill ins(%cA : f32) outs(%preA : tensor<2x3xf32>) -> tensor<2x3xf32> + + %B = arith.constant dense<1.0> : tensor<3x4xf32> + %C = arith.constant dense<1000.0> : tensor<2x4xf32> + // expected-remark @below {{structured}} + %D = linalg.matmul ins(%A, %B: tensor<2x3xf32>, tensor<3x4xf32>) + outs(%C: tensor<2x4xf32>) -> tensor<2x4xf32> + + %E = arith.constant dense<2.0> : tensor<2x4xf32> + // expected-remark @below {{structured}} + linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%D : tensor<2x4xf32>) outs(%E : tensor<2x4xf32>) { + ^bb0(%arg0: f32, %arg1: f32): + linalg.yield %arg0 : f32 + } -> tensor<2x4xf32> + + return + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @do_nothing(%arg0: !transform.any_op {transform.readonly}) { + transform.yield + } + + // Entry point. Match any structured operation and emit a remark. Also emit + // a different remark at all considered operations. When it fails, the + // failure is suppressed and the resulting handle is associated with an empty + // list, hence nothing is printed. Both remark printing operations happen + // after the check in the sequence, so they only apply if the check operation + // produced success (due to failure suppression or not).
+ transform.named_sequence @match_structured_suppress(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + %0 = transform.match.structured failures(suppress) %arg0 : (!transform.any_op) -> !transform.any_op { + ^bb0(%arg1: !transform.any_op): + transform.match.structured.yield %arg1 : !transform.any_op + } + transform.test_print_remark_at_operand %0, "structured" : !transform.any_op + transform.test_print_remark_at_operand %arg0, "other" : !transform.any_op + transform.yield %0 : !transform.any_op + } + + transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { + ^bb0(%arg0: !transform.any_op): + transform.foreach_match in %arg0 + @match_structured_suppress -> @do_nothing + : (!transform.any_op) -> !transform.any_op + } + + func.func @payload() attributes { transform.target_tag = "start_here" } { + // expected-remark @below {{other}} + %D = arith.constant dense<1.0> : tensor<2x4xf32> + // expected-remark @below {{other}} + %E = arith.constant dense<2.0> : tensor<2x4xf32> + // expected-remark @below {{structured}} + // expected-remark @below {{other}} + linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%D : tensor<2x4xf32>) outs(%E : tensor<2x4xf32>) { + ^bb0(%arg0: f32, %arg1: f32): + // expected-remark @below {{other}} + linalg.yield %arg0 : f32 + } -> tensor<2x4xf32> + + // expected-remark @below {{other}} + return + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @print_passthrough(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "passthrough" : !transform.any_op + transform.yield + } + + transform.named_sequence @match_structured_body_passthrough(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> 
!transform.any_op { + ^bb0(%arg1: !transform.any_op): + transform.match.structured.body %arg1 { passthrough } : !transform.any_op + transform.match.structured.yield %arg1 : !transform.any_op + } + transform.yield %0 : !transform.any_op + } + + transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { + ^bb0(%arg0: !transform.any_op): + transform.foreach_match in %arg0 + @match_structured_body_passthrough -> @print_passthrough + : (!transform.any_op) -> !transform.any_op + } + + func.func @payload(%in: tensor<2xf32>, %out: tensor<2xf32>) attributes { transform.target_tag = "start_here" } { + // expected-remark @below {{passthrough}} + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%in : tensor<2xf32>) outs(%out : tensor<2xf32>) { + ^bb0(%arg0: f32, %arg1: f32): + linalg.yield %arg0 : f32 + } -> tensor<2xf32> + + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%in : tensor<2xf32>) outs(%out : tensor<2xf32>) { + ^bb0(%arg0: f32, %arg1: f32): + %0 = arith.mulf %arg0, %arg1 : f32 + linalg.yield %0 : f32 + } -> tensor<2xf32> + + // expected-remark @below {{passthrough}} + linalg.copy ins(%in : tensor<2xf32>) outs(%out : tensor<2xf32>) -> tensor<2xf32> + + return + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @print_reduction(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "reduction" : !transform.any_op + transform.yield + } + + transform.named_sequence @match_structured_body_reduction(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op { + ^bb0(%arg1: !transform.any_op): + transform.match.structured.body %arg1 { reduction_position = 0 } : !transform.any_op 
+ transform.match.structured.yield %arg1 : !transform.any_op + } + transform.yield %0 : !transform.any_op + } + + transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { + ^bb0(%arg0: !transform.any_op): + transform.foreach_match in %arg0 + @match_structured_body_reduction -> @print_reduction + : (!transform.any_op) -> !transform.any_op + } + + func.func @payload(%lhs: tensor<2x4xf32>, %rhs: tensor<4x3xf32>, %out: tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } { + // expected-remark @below {{reduction}} + linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%lhs, %rhs: tensor<2x4xf32>, tensor<4x3xf32>) outs(%out: tensor<2x3xf32>) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %0 = arith.mulf %arg0, %arg1 : f32 + %1 = arith.addf %0, %arg2 : f32 + linalg.yield %1 : f32 + } -> tensor<2x3xf32> + + %r = tensor.empty() : tensor<2x3xf32> + linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%lhs, %rhs: tensor<2x4xf32>, tensor<4x3xf32>) outs(%out, %r: tensor<2x3xf32>, tensor<2x3xf32>) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32): + %0 = arith.mulf %arg0, %arg1 : f32 + %1 = arith.cmpf olt, %0, %arg2 : f32 + %2 = arith.select %1, %0, %arg2 : f32 + %3 = arith.select %1, %arg3, %0 : f32 + linalg.yield %2, %3 : f32, f32 + } -> (tensor<2x3xf32>, tensor<2x3xf32>) + + // expected-remark @below {{reduction}} + linalg.matmul ins(%lhs, %rhs: tensor<2x4xf32>, tensor<4x3xf32>) outs(%out: tensor<2x3xf32>) -> tensor<2x3xf32> + + %e = tensor.empty() : tensor<2x4xf32> + linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> 
(d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%lhs: tensor<2x4xf32>) outs(%e: tensor<2x4xf32>) { + ^bb0(%arg0: f32, %arg1: f32): + linalg.yield %arg0 : f32 + } -> tensor<2x4xf32> + + return + } +} + + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @do_nothing(%arg0: !transform.any_op {transform.readonly}) { + transform.yield + } + + transform.named_sequence @match_dimension_capture(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + // Capture multiple dimension values. Suppress failures so we can print them anyway after the capture. + %0:9 = transform.match.structured failures(suppress) %arg0 + : (!transform.any_op) -> (!transform.any_op, !transform.param, !transform.param, !transform.param, + !transform.param, !transform.param, !transform.param, !transform.param, !transform.param) { + ^bb0(%arg1: !transform.any_op): + // This also tests the positional specification used by other ops, which may not test it again. 
+ %1 = transform.match.structured.dim %arg1[all] : (!transform.any_op) -> !transform.param + %2 = transform.match.structured.dim %arg1[0] : (!transform.any_op) -> !transform.param + %3 = transform.match.structured.dim %arg1[-1] : (!transform.any_op) -> !transform.param + %4 = transform.match.structured.dim %arg1[0, 2] : (!transform.any_op) -> !transform.param + %5 = transform.match.structured.dim %arg1[0, -1] : (!transform.any_op) -> !transform.param + %6 = transform.match.structured.dim %arg1[except(-1)] : (!transform.any_op) -> !transform.param + %7 = transform.match.structured.dim %arg1[except(0, -2)] : (!transform.any_op) -> !transform.param + %8 = transform.match.structured.dim %arg1[0, -3] : (!transform.any_op) -> !transform.param + transform.match.structured.yield %arg1, %1, %2, %3, %4, %5, %6, %7, %8 + : !transform.any_op, !transform.param, !transform.param, !transform.param, + !transform.param, !transform.param, !transform.param, !transform.param, !transform.param + } + transform.test_print_param %0#1, "dimensions all:" at %0#0 : !transform.param, !transform.any_op + transform.test_print_param %0#2, "dimension 0:" at %0#0 : !transform.param, !transform.any_op + transform.test_print_param %0#3, "dimension -1:" at %0#0 : !transform.param, !transform.any_op + transform.test_print_param %0#4, "dimensions 0, 2:" at %0#0 : !transform.param, !transform.any_op + transform.test_print_param %0#5, "dimensions 0, -1:" at %0#0 : !transform.param, !transform.any_op + transform.test_print_param %0#6, "dimensions except -1:" at %0#0 : !transform.param, !transform.any_op + transform.test_print_param %0#7, "dimensions except 0, -2:" at %0#0 : !transform.param, !transform.any_op + transform.test_print_param %0#8, "dimensions 0, -3:" at %0#0 : !transform.param, !transform.any_op + transform.yield %0#0 : !transform.any_op + } + + transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { + ^bb0(%arg0: !transform.any_op): + 
transform.foreach_match in %arg0 @match_dimension_capture -> @do_nothing : (!transform.any_op) -> !transform.any_op + } + + func.func @payload(%lhs: tensor<2x4xf32>, %rhs: tensor<4x3xf32>, %out: tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } { + // The last does not emit anything because it fails to match + // due to 0 and -3 being the same dimension in the 3D case. + // expected-remark @below {{dimensions all: 2 : i64, 3 : i64, 4 : i64}} + // expected-remark @below {{dimension 0: 2 : i64}} + // expected-remark @below {{dimension -1: 4 : i64}} + // expected-remark @below {{dimensions 0, 2: 2 : i64, 4 : i64}} + // expected-remark @below {{dimensions 0, -1: 2 : i64, 4 : i64}} + // expected-remark @below {{dimensions except -1: 2 : i64, 3 : i64}} + // expected-remark @below {{dimensions except 0, -2: 4 : i64}} + // expected-remark @below {{dimensions 0, -3:}} + linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%lhs, %rhs: tensor<2x4xf32>, tensor<4x3xf32>) outs(%out: tensor<2x3xf32>) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %0 = arith.mulf %arg0, %arg1 : f32 + %1 = arith.addf %0, %arg2 : f32 + linalg.yield %1 : f32 + } -> tensor<2x3xf32> + + return + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @print_all_reduction(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "all reduction" : !transform.any_op + transform.yield + } + transform.named_sequence @print_all_parallel(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "all parallel" : !transform.any_op + transform.yield + } + transform.named_sequence @print_last_reduction(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "last reduction" 
: !transform.any_op + transform.yield + } + transform.named_sequence @print_parallel_except_last(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "parallel except last" : !transform.any_op + transform.yield + } + + transform.named_sequence @match_all_reduction(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + transform.match.structured failures(propagate) %arg0 : !transform.any_op { + ^bb0(%arg1: !transform.any_op): + transform.match.structured.dim %arg1[all] { reduction } : !transform.any_op + transform.match.structured.yield + } + transform.yield %arg0 : !transform.any_op + } + transform.named_sequence @match_all_parallel(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + transform.match.structured failures(propagate) %arg0 : !transform.any_op { + ^bb0(%arg1: !transform.any_op): + transform.match.structured.dim %arg1[all] { parallel } : !transform.any_op + transform.match.structured.yield + } + transform.yield %arg0 : !transform.any_op + } + transform.named_sequence @match_last_reduction(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + transform.match.structured failures(propagate) %arg0 : !transform.any_op { + ^bb0(%arg1: !transform.any_op): + transform.match.structured.dim %arg1[-1] { reduction } : !transform.any_op + transform.match.structured.yield + } + transform.yield %arg0 : !transform.any_op + } + transform.named_sequence @match_parallel_except_last(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + transform.match.structured failures(propagate) %arg0 : !transform.any_op { + ^bb0(%arg1: !transform.any_op): + transform.match.structured.dim %arg1[except(-1)] { parallel } : !transform.any_op + transform.match.structured.yield + } + transform.yield %arg0 : !transform.any_op + } + + transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { + ^bb0(%arg0: !transform.any_op): + %0 = 
transform.foreach_match in %arg0 @match_all_reduction -> @print_all_reduction : (!transform.any_op) -> !transform.any_op + %1 = transform.foreach_match in %0 @match_all_parallel -> @print_all_parallel : (!transform.any_op) -> !transform.any_op + %2 = transform.foreach_match in %1 @match_last_reduction -> @print_last_reduction : (!transform.any_op) -> !transform.any_op + %3 = transform.foreach_match in %2 @match_parallel_except_last -> @print_parallel_except_last : (!transform.any_op) -> !transform.any_op + } + + func.func @payload(%lhs: tensor<2x4xf32>, %rhs: tensor<4x3xf32>, %out: tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } { + // expected-remark @below {{last reduction}} + // expected-remark @below {{parallel except last}} + linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%lhs, %rhs: tensor<2x4xf32>, tensor<4x3xf32>) outs(%out: tensor<2x3xf32>) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %0 = arith.mulf %arg0, %arg1 : f32 + %1 = arith.addf %0, %arg2 : f32 + linalg.yield %1 : f32 + } -> tensor<2x3xf32> + + // expected-remark @below {{last reduction}} + // expected-remark @below {{parallel except last}} + linalg.matmul ins(%lhs, %rhs : tensor<2x4xf32>, tensor<4x3xf32>) outs(%out : tensor<2x3xf32>) -> tensor<2x3xf32> + + %cst = arith.constant 1.0 : f32 + // expected-remark @below {{all parallel}} + // expected-remark @below {{parallel except last}} + linalg.fill ins(%cst : f32) outs(%out: tensor<2x3xf32>) -> tensor<2x3xf32> + + return + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @match_bitwidth(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.param) { + %bw = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.param { + ^bb0(%arg1: 
!transform.any_op): + %0 = transform.match.structured.init %arg1 [0] : (!transform.any_op) -> !transform.any_value + %1 = transform.match.structured.elemental_bitwidth %0 : (!transform.any_value) -> !transform.param + transform.match.structured.yield %1 : !transform.param + } + transform.yield %arg0, %bw : !transform.any_op, !transform.param + } + + transform.named_sequence @print_bitwidth(%arg0: !transform.any_op {transform.readonly}, %arg1: !transform.param {transform.readonly}) { + transform.test_print_param %arg1, "bitwidth:" at %arg0 : !transform.param, !transform.any_op + transform.yield + } + + transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { + ^bb0(%arg0: !transform.any_op): + transform.foreach_match in %arg0 @match_bitwidth -> @print_bitwidth : (!transform.any_op) -> !transform.any_op + transform.yield + } + + func.func @payload(%f32: f32, %tf32: tensor, + %index: index, %tindex: tensor) + attributes { transform.target_tag = "start_here" } { + // expected-remark @below {{bitwidth: 32}} + linalg.fill ins(%f32: f32) outs(%tf32: tensor) -> tensor + linalg.fill ins(%index: index) outs(%tindex: tensor) -> tensor + return + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @match_init(%arg0: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_value, !transform.any_value, !transform.any_op) { + %outs:3 = transform.match.structured failures(suppress) %arg0 + : (!transform.any_op) -> (!transform.any_value, !transform.any_value, !transform.any_op) { + ^bb0(%arg1: !transform.any_op): + %0 = transform.match.structured.init %arg1 [0] : (!transform.any_op) -> !transform.any_value + %1 = transform.match.structured.init %arg1 [all] : (!transform.any_op) -> !transform.any_value + %2 = transform.match.structured.init %arg1 [0] : (!transform.any_op) -> !transform.any_op + transform.match.structured.yield %0, %1, %2 : !transform.any_value, 
!transform.any_value, !transform.any_op + } + transform.yield %arg0, %outs#0, %outs#1, %outs#2 : !transform.any_op, !transform.any_value, !transform.any_value, !transform.any_op + } + + transform.named_sequence @print_init(%arg0: !transform.any_op {transform.readonly}, + %arg1: !transform.any_value {transform.readonly}, + %arg2: !transform.any_value {transform.readonly}, + %arg3: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand_value %arg1, "output 0" : !transform.any_value + transform.test_print_remark_at_operand %arg3, "output producer" : !transform.any_op + transform.test_print_remark_at_operand_value %arg2, "all output" : !transform.any_value + transform.yield + } + + transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { + ^bb0(%arg0: !transform.any_op): + transform.foreach_match in %arg0 @match_init -> @print_init : (!transform.any_op) -> !transform.any_op + transform.yield + } + + + func.func @payload(%f32: f32, + // expected-remark @below {{output 0}} + // expected-remark @below {{all output}} + // expected-note @below {{value handle points to a block argument #1 in block #0 in region #0}} + %tf32: tensor, + // expected-remark @below {{all output}} + // expected-note @below {{value handle points to a block argument #2 in block #0 in region #0}} + %tf32_2: tensor) + attributes { transform.target_tag = "start_here" } { + // expected-remark @below {{output 0}} + // expected-remark @below {{output producer}} + // expected-remark @below {{all output}} + // expected-note @below {{value handle points to an op result #0}} + %0 = linalg.fill ins(%f32: f32) outs(%tf32: tensor) -> tensor + + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%tf32: tensor) outs(%0, %tf32_2: tensor, tensor) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + linalg.yield %arg0, %arg0 : f32, f32 + } -> (tensor, tensor) 
+ return + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @match_init_0_permutation(%arg0: !transform.any_op {transform.readonly}) + -> !transform.any_op { + %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op { + ^bb0(%arg1: !transform.any_op): + transform.match.structured.init %arg1[0] { permutation }: !transform.any_op + transform.match.structured.yield %arg1 : !transform.any_op + } + transform.yield %0 : !transform.any_op + } + transform.named_sequence @match_init_1_permutation(%arg0: !transform.any_op {transform.readonly}) + -> !transform.any_op { + %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op { + ^bb0(%arg1: !transform.any_op): + transform.match.structured.init %arg1[1] { permutation }: !transform.any_op + transform.match.structured.yield %arg1 : !transform.any_op + } + transform.yield %0 : !transform.any_op + } + transform.named_sequence @match_init_2_projected_permutation(%arg0: !transform.any_op {transform.readonly}) + -> !transform.any_op { + %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op { + ^bb0(%arg1: !transform.any_op): + transform.match.structured.init %arg1[2] { projected_permutation }: !transform.any_op + transform.match.structured.yield %arg1 : !transform.any_op + } + transform.yield %0 : !transform.any_op + } + + transform.named_sequence @print_init_0_permutation(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "matched output 0 permutation" : !transform.any_op + transform.yield + } + transform.named_sequence @print_init_1_permutation(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "matched output 1 permutation" : !transform.any_op + transform.yield + } + transform.named_sequence @print_init_2_projected_permutation(%arg0: !transform.any_op 
{transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "matched output 2 projected permutation" : !transform.any_op + transform.yield + } + + transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { + ^bb0(%arg0: !transform.any_op): + %0 = transform.foreach_match in %arg0 @match_init_0_permutation -> @print_init_0_permutation : (!transform.any_op) -> !transform.any_op + %1 = transform.foreach_match in %0 @match_init_1_permutation -> @print_init_1_permutation : (!transform.any_op) -> !transform.any_op + %2 = transform.foreach_match in %1 @match_init_2_projected_permutation -> @print_init_2_projected_permutation : (!transform.any_op) -> !transform.any_op + transform.yield + } + + func.func @payload(%f32: f32, + %oned: tensor, + %oned2: tensor, + %twod: tensor) + attributes { transform.target_tag = "start_here" } { + // expected-remark @below {{matched output 2 projected permutation}} + linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0 + d1)>, + affine_map<(d0, d1) -> (d1)>, + affine_map<(d0, d1) -> (d1, d0)>], + iterator_types = ["parallel", "parallel"] + } ins(%oned: tensor) outs(%oned, %oned2, %twod: tensor, tensor, tensor) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32): + linalg.yield %arg0, %arg0, %arg0 : f32, f32, f32 + } -> (tensor, tensor, tensor) + + // expected-remark @below {{matched output 2 projected permutation}} + // expected-remark @below {{matched output 1 permutation}} + linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0 + d1)>, + affine_map<(d0, d1) -> (d1, d0)>, + affine_map<(d0, d1) -> (d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%oned: tensor) outs(%oned, %twod, %oned2: tensor, tensor, tensor) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32): + linalg.yield %arg0, %arg0, %arg0 : f32, f32, f32 + } -> (tensor, tensor, tensor) + return + } +} + +// ----- + + + +module attributes 
{ transform.with_named_sequence } { + transform.named_sequence @match_num_io(%arg0: !transform.any_op {transform.readonly}) + -> (!transform.param, !transform.param, !transform.any_op) { + %0:3 = transform.match.structured failures(propagate) %arg0 + : (!transform.any_op) -> (!transform.param, !transform.param, !transform.any_op) { + ^bb0(%arg1: !transform.any_op): + %1 = transform.match.structured.num_inputs %arg1 : (!transform.any_op) -> !transform.param + %2 = transform.match.structured.num_inits %arg1 : (!transform.any_op) -> !transform.param + transform.match.structured.yield %1, %2, %arg1 : !transform.param, !transform.param, !transform.any_op + } + transform.yield %0#0, %0#1, %0#2 : !transform.param, !transform.param, !transform.any_op + } + + + transform.named_sequence @print_num_io( + %arg0: !transform.param {transform.readonly}, + %arg1: !transform.param {transform.readonly}, + %arg2: !transform.any_op {transform.readonly}) { + transform.test_print_param %arg0, "inputs" at %arg2 : !transform.param, !transform.any_op + transform.test_print_param %arg1, "outputs" at %arg2 : !transform.param, !transform.any_op + transform.yield + } + + + transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { + ^bb0(%arg0: !transform.any_op): + %0 = transform.foreach_match in %arg0 @match_num_io -> @print_num_io : (!transform.any_op) -> !transform.any_op + transform.yield + } + + func.func @payload(%f32: f32, + %oned: tensor, + %oned2: tensor, + %twod: tensor) + attributes { transform.target_tag = "start_here" } { + // expected-remark @below {{inputs 1}} + // expected-remark @below {{outputs 3}} + linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0 + d1)>, + affine_map<(d0, d1) -> (d1)>, + affine_map<(d0, d1) -> (d1, d0)>], + iterator_types = ["parallel", "parallel"] + } ins(%oned: tensor) outs(%oned, %oned2, %twod: tensor, tensor, tensor) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32): + 
linalg.yield %arg0, %arg0, %arg0 : f32, f32, f32 + } -> (tensor, tensor, tensor) + + // expected-remark @below {{inputs 2}} + // expected-remark @below {{outputs 2}} + linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d1, d0)>, + affine_map<(d0, d1) -> (d0 + d1)>, + affine_map<(d0, d1) -> (d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%oned, %twod: tensor, tensor) outs(%oned, %oned2: tensor, tensor) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32): + linalg.yield %arg0, %arg0 : f32, f32 + } -> (tensor, tensor) + return + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @match_rank(%arg0: !transform.any_op {transform.readonly}) + -> (!transform.param, !transform.any_op) { + %0:2 = transform.match.structured failures(propagate) %arg0 + : (!transform.any_op) -> (!transform.param, !transform.any_op) { + ^bb0(%arg1: !transform.any_op): + %1 = transform.match.structured.rank %arg1 : (!transform.any_op) -> !transform.param + transform.match.structured.yield %1, %arg1 : !transform.param, !transform.any_op + } + transform.yield %0#0, %0#1 : !transform.param, !transform.any_op + } + + + transform.named_sequence @print_rank(%arg0: !transform.param {transform.readonly}, + %arg2: !transform.any_op {transform.readonly}) { + transform.test_print_param %arg0, "rank" at %arg2 : !transform.param, !transform.any_op + transform.yield + } + + transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { + ^bb0(%arg0: !transform.any_op): + %0 = transform.foreach_match in %arg0 @match_rank -> @print_rank : (!transform.any_op) -> !transform.any_op + transform.yield + } + + func.func @payload(%f32: f32, + %twod: tensor<42x42xf32>) + attributes { transform.target_tag = "start_here" } { + %0 = tensor.empty() : tensor<42x42xf32> + // expected-remark @below {{rank 2}} + %1 = linalg.fill ins(%f32 : f32) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32> 
+ // expected-remark @below {{rank 3}} + linalg.matmul ins(%twod, %twod : tensor<42x42xf32>, tensor<42x42xf32>) + outs(%1 : tensor<42x42xf32>) -> tensor<42x42xf32> + return + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @match_single_result(%arg0: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_op) { + %0:2 = transform.match.structured failures(propagate) %arg0 + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) { + ^bb0(%arg1: !transform.any_op): + %1 = transform.match.structured.result %arg1[0] { single } : (!transform.any_op) -> !transform.any_op + transform.match.structured.yield %1, %arg1 : !transform.any_op, !transform.any_op + } + transform.yield %0#0, %0#1 : !transform.any_op, !transform.any_op + } + transform.named_sequence @match_result_value(%arg0: !transform.any_op {transform.readonly}) + -> (!transform.any_value, !transform.any_op) { + %0:2 = transform.match.structured failures(propagate) %arg0 + : (!transform.any_op) -> (!transform.any_value, !transform.any_op) { + ^bb0(%arg1: !transform.any_op): + %1 = transform.match.structured.result %arg1[0] : (!transform.any_op) -> !transform.any_value + transform.match.structured.yield %1, %arg1 : !transform.any_value, !transform.any_op + } + transform.yield %0#0, %0#1 : !transform.any_value, !transform.any_op + } + transform.named_sequence @match_any_result(%arg0: !transform.any_op {transform.readonly}) + -> (!transform.any_op) { + %0 = transform.match.structured failures(propagate) %arg0 + : (!transform.any_op) -> !transform.any_op { + ^bb0(%arg1: !transform.any_op): + %1 = transform.match.structured.result %arg1[-1] { any } : (!transform.any_op) -> !transform.any_op + transform.match.structured.yield %arg1 : !transform.any_op + } + transform.yield %0 : !transform.any_op + } + + transform.named_sequence @print_single_result(%arg0: !transform.any_op {transform.readonly}, + %arg2: !transform.any_op 
{transform.readonly}) { + transform.test_print_remark_at_operand %arg2, "matched single result" : !transform.any_op + transform.test_print_remark_at_operand %arg0, "single user" : !transform.any_op + transform.yield + } + transform.named_sequence @print_result_value(%arg0: !transform.any_value {transform.readonly}, + %arg1: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg1, "matched result value" : !transform.any_op + transform.test_print_remark_at_operand_value %arg0, "op result" : !transform.any_value + transform.yield + } + transform.named_sequence @print_any_result(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "matched any result" : !transform.any_op + transform.yield + } + + transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { + ^bb0(%arg0: !transform.any_op): + %0 = transform.foreach_match in %arg0 @match_single_result -> @print_single_result : (!transform.any_op) -> !transform.any_op + %1 = transform.foreach_match in %0 @match_result_value -> @print_result_value : (!transform.any_op) -> !transform.any_op + %2 = transform.foreach_match in %1 @match_any_result -> @print_any_result : (!transform.any_op) -> !transform.any_op + transform.yield + } + + func.func @payload(%f32: f32, %f322: f32, %f323: f32, + %twod: tensor<42x42xf32>) + attributes { transform.target_tag = "start_here" } { + %0 = tensor.empty() : tensor<42x42xf32> + + // expected-remark @below {{matched result value}} + // expected-remark @below {{op result}} + // expected-note @below {{value handle points to an op result #0}} + %1 = linalg.fill ins(%f32 : f32) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32> + // expected-remark @below {{matched result value}} + // expected-remark @below {{op result}} + // expected-note @below {{value handle points to an op result #0}} + // expected-remark @below {{matched single result}} + // expected-remark @below {{matched any result}} 
+ %2 = linalg.fill ins(%f322 : f32) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32> + // expected-remark @below {{matched result value}} + // expected-remark @below {{op result}} + // expected-note @below {{value handle points to an op result #0}} + // expected-remark @below {{matched any result}} + %3 = linalg.fill ins(%f323 : f32) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32> + + // expected-remark @below {{matched result value}} + // expected-remark @below {{op result}} + // expected-note @below {{value handle points to an op result #0}} + // expected-remark @below {{single user}} + linalg.elemwise_unary {fun = #linalg.unary_fn} ins(%2 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32> + // expected-remark @below {{matched result value}} + // expected-remark @below {{op result}} + // expected-note @below {{value handle points to an op result #0}} + linalg.elemwise_unary {fun = #linalg.unary_fn} ins(%3 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32> + // expected-remark @below {{matched result value}} + // expected-remark @below {{op result}} + // expected-note @below {{value handle points to an op result #0}} + linalg.elemwise_unary {fun = #linalg.unary_fn} ins(%3 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32> + return + } +} diff --git a/mlir/test/Dialect/Linalg/match-ops-invalid.mlir b/mlir/test/Dialect/Linalg/match-ops-invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/match-ops-invalid.mlir @@ -0,0 +1,225 @@ +// RUN: mlir-opt %s --split-input-file --verify-diagnostics + +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op): + // expected-error @below {{expected one body argument}} + transform.match.structured %arg0 : !transform.any_op { + ^bb1: + transform.match.structured.yield + } + transform.yield +} + +// ----- + +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op): + // expected-error @below {{expected body argument 
to implement TransformHandleTypeInterface}} + transform.match.structured %arg0 : !transform.any_op { + ^bb1(%arg1: i32): + transform.match.structured.yield + } + transform.yield +} + +// ----- + +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op): + // expected-error @below {{expects nested operations to implement MatchOpInterface}} + transform.match.structured %arg0 : !transform.any_op { + ^bb1(%arg1: !transform.any_op): + // expected-note @below {{offending operation}} + transform.test_consume_operand %arg1 : !transform.any_op + transform.match.structured.yield + } + transform.yield +} +// ----- + +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op): + // expected-error @below {{expects parent op to be 'transform.match.structured'}} + transform.match.structured.body %arg0 { passthrough } : !transform.any_op + transform.yield +} + + +// ----- + +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op): + transform.match.structured %arg0 : !transform.any_op { + ^bb1(%arg1: !transform.any_op): + // expected-error @below {{expected predicate to apply to the surrounding structured op}} + transform.match.structured.body %arg0 { passthrough } : !transform.any_op + transform.match.structured.yield + } + transform.yield +} + +// ----- + +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op): + transform.match.structured %arg0 : !transform.any_op { + ^bb1(%arg1: !transform.any_op): + // expected-error @below {{reduction position and passthrough conditions are mutually exclusive}} + transform.match.structured.body %arg1 { passthrough, reduction_position = 0 } : !transform.any_op + transform.match.structured.yield + } + transform.yield +} + +// ----- + +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op): + transform.match.structured %arg0 : !transform.any_op { + ^bb1(%arg1: !transform.any_op): + // expected-error @below {{cannot request both 'all' and 'inverted' values in the 
list}} + "transform.match.structured.dim"(%arg1) { is_all, is_inverted, raw_dim_list = array } : (!transform.any_op) -> () + transform.match.structured.yield + } + transform.yield +} + +// ----- + +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op): + transform.match.structured %arg0 : !transform.any_op { + ^bb1(%arg1: !transform.any_op): + // expected-error @below {{cannot both request 'all' and specific values in the list}} + "transform.match.structured.dim"(%arg1) { is_all, raw_dim_list = array } : (!transform.any_op) -> () + transform.match.structured.yield + } + transform.yield +} +// ----- + +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op): + transform.match.structured %arg0 : !transform.any_op { + ^bb1(%arg1: !transform.any_op): + // expected-error @below {{must request specific values in the list if 'all' is not specified}} + "transform.match.structured.dim"(%arg1) { raw_dim_list = array } : (!transform.any_op) -> () + transform.match.structured.yield + } + transform.yield +} + +// ----- + +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op): + transform.match.structured %arg0 : !transform.any_op { + ^bb1(%arg1: !transform.any_op): + // expected-error @below {{op expected the listed values to be unique}} + "transform.match.structured.dim"(%arg1) { raw_dim_list = array } : (!transform.any_op) -> () + transform.match.structured.yield + } + transform.yield +} + +// ----- + +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op): + transform.match.structured %arg0 : !transform.any_op { + ^bb1(%arg1: !transform.any_op): + // expected-error @below {{cannot request the same dimension to be both parallel and reduction}} + "transform.match.structured.dim"(%arg1) { is_all, parallel, reduction, raw_dim_list = array } : (!transform.any_op) -> () + transform.match.structured.yield + } + transform.yield +} + +// ----- + +transform.sequence failures(suppress) { +^bb0(%arg0: 
!transform.any_op): + transform.match.structured %arg0 : !transform.any_op { + ^bb1(%arg1: !transform.any_op): + // expected-error @below {{"permutation" and "projected_permutation" are mutually exclusive}} + transform.match.structured.input %arg1[all] { permutation, projected_permutation } : !transform.any_op + transform.match.structured.yield + } + transform.yield +} +// ----- + +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op): + transform.match.structured %arg0 : !transform.any_op { + ^bb1(%arg1: !transform.any_op): + // expected-error @below {{cannot bind multiple inputs/inits to the same value}} + transform.match.structured.input %arg1[0, 1] : (!transform.any_op) -> !transform.any_op + transform.match.structured.yield + } + transform.yield +} + +// ----- + +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op): + transform.match.structured %arg0 : !transform.any_op { + ^bb1(%arg1: !transform.any_op): + // expected-error @below {{"permutation" and "projected_permutation" are mutually exclusive}} + transform.match.structured.init %arg1[all] { permutation, projected_permutation } : !transform.any_op + transform.match.structured.yield + } + transform.yield +} +// ----- + +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op): + transform.match.structured %arg0 : !transform.any_op { + ^bb1(%arg1: !transform.any_op): + // expected-error @below {{cannot bind multiple inputs/inits to the same value}} + transform.match.structured.init %arg1[0, 1] : (!transform.any_op) -> !transform.any_op + transform.match.structured.yield + } + transform.yield +} + +// ----- + +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op): + transform.match.structured %arg0 : !transform.any_op { + ^bb1(%arg1: !transform.any_op): + // expected-error @below {{expects either the any/single keyword or the type value handle result type}} + transform.match.structured.result %arg1[0] : (!transform.any_op) -> 
!transform.any_op + transform.match.structured.yield + } + transform.yield +} + +// ----- + +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op): + transform.match.structured %arg0 : !transform.any_op { + ^bb1(%arg1: !transform.any_op): + // expected-error @below {{expects either the any/single keyword or the type value handle result type}} + transform.match.structured.result %arg1[0] {any} : (!transform.any_op) -> !transform.any_value + transform.match.structured.yield + } + transform.yield +} + +// ----- + +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op): + transform.match.structured %arg0 : !transform.any_op { + ^bb1(%arg1: !transform.any_op): + // expected-error @below {{'any' and 'single' are mutually exclusive}} + transform.match.structured.result %arg1[0] {any, single} : (!transform.any_op) -> !transform.any_op + transform.match.structured.yield + } + transform.yield +} diff --git a/mlir/test/Dialect/Linalg/match-ops.mlir b/mlir/test/Dialect/Linalg/match-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/match-ops.mlir @@ -0,0 +1,40 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + transform.match.structured %arg0 : !transform.any_op { + ^bb1(%arg1: !transform.any_op): + // Checking the syntax of positional specifications. + // CHECK: dim %{{.*}}[all] + transform.match.structured.dim %arg1[all] : !transform.any_op + // CHECK: dim %{{.*}}[0] + transform.match.structured.dim %arg1[0] : !transform.any_op + // CHECK: dim %{{.*}}[0, 1, -2] + transform.match.structured.dim %arg1[0, 1, -2] : !transform.any_op + // CHECK: dim %{{.*}}[except(0)] + transform.match.structured.dim %arg1[except(0)] : !transform.any_op + // CHECK: dim %{{.*}}[except(0, -1, 2)] + transform.match.structured.dim %arg1[except(0, -1, 2)] : !transform.any_op + + transform.match.structured.yield + } + + // Checking the syntax of trailing types. 
+ // CHECK: structured %{{.*}} : !transform.any_op + transform.match.structured %arg0 : !transform.any_op { + ^bb1(%arg1: !transform.any_op): + transform.match.structured.yield + } + // CHECK: structured %{{.*}} : (!transform.any_op) -> !transform.any_op + transform.match.structured %arg0 : (!transform.any_op) -> !transform.any_op { + ^bb1(%arg1: !transform.any_op): + transform.match.structured.yield %arg1 : !transform.any_op + } + // CHECK: structured %{{.*}} : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.match.structured %arg0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) { + ^bb1(%arg1: !transform.any_op): + transform.match.structured.yield %arg1, %arg1 : !transform.any_op, !transform.any_op + } + + transform.yield +} diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1411,7 +1411,7 @@ module attributes { transform.with_named_sequence } { transform.named_sequence @match(%arg: !transform.any_op {transform.readonly}) { // expected-error @below {{expected operations in the match part to implement MatchOpInterface}} - transform.test_print_remark_at_operand %arg, "remark" : !transform.any_op + "test.unknown_op"() : () -> () transform.yield } transform.named_sequence @action() { @@ -1424,3 +1424,128 @@ @match -> @action : (!transform.any_op) -> !transform.any_op } } + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @match_func(%arg0: !transform.any_op {transform.readonly}) + -> !transform.any_op { + transform.match.operation_name %arg0 ["func.func"] : !transform.any_op + transform.yield %arg0 : !transform.any_op + } + + transform.named_sequence @print_func(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "matched func" : !transform.any_op + 
transform.yield + } + + transform.sequence failures(propagate) { + ^bb(%arg0: !transform.any_op): + transform.foreach_match in %arg0 @match_func -> @print_func : (!transform.any_op) -> !transform.any_op + transform.yield + } + + // expected-remark @below {{matched func}} + func.func @payload() { + return + } + + // expected-remark @below {{matched func}} + func.func private @declaration() + + "test.something_else"() : () -> () +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @eq_1(%arg0: !transform.any_op {transform.readonly}) + -> !transform.any_op { + transform.match.operation_name %arg0 ["func.func"] : !transform.any_op + %0 = transform.test_produce_param_with_number_of_test_ops %arg0 : !transform.any_op + %1 = transform.param.constant 1 : i32 -> !transform.test_dialect_param + transform.match.param.cmpi eq %0, %1 : !transform.test_dialect_param + transform.test_print_remark_at_operand %arg0, "matched == 1" : !transform.any_op + transform.yield %arg0 : !transform.any_op + } + + transform.named_sequence @ne_0(%arg0: !transform.any_op {transform.readonly}) + -> !transform.any_op { + transform.match.operation_name %arg0 ["func.func"] : !transform.any_op + %0 = transform.test_produce_param_with_number_of_test_ops %arg0 : !transform.any_op + %1 = transform.param.constant 0 : i32 -> !transform.test_dialect_param + transform.match.param.cmpi ne %0, %1 : !transform.test_dialect_param + transform.test_print_remark_at_operand %arg0, "matched != 0" : !transform.any_op + transform.yield %arg0 : !transform.any_op + } + + transform.named_sequence @gt_m1(%arg0: !transform.any_op {transform.readonly}) + -> !transform.any_op { + transform.match.operation_name %arg0 ["func.func"] : !transform.any_op + %0 = transform.test_produce_param_with_number_of_test_ops %arg0 : !transform.any_op + %1 = transform.param.constant -1 : i32 -> !transform.test_dialect_param + transform.match.param.cmpi gt %0, %1 : !transform.test_dialect_param + 
transform.test_print_remark_at_operand %arg0, "matched > -1" : !transform.any_op + transform.yield %arg0 : !transform.any_op + } + + transform.named_sequence @ge_1(%arg0: !transform.any_op {transform.readonly}) + -> !transform.any_op { + transform.match.operation_name %arg0 ["func.func"] : !transform.any_op + %0 = transform.test_produce_param_with_number_of_test_ops %arg0 : !transform.any_op + %1 = transform.param.constant 1 : i32 -> !transform.test_dialect_param + transform.match.param.cmpi ge %0, %1 : !transform.test_dialect_param + transform.test_print_remark_at_operand %arg0, "matched >= 1" : !transform.any_op + transform.yield %arg0 : !transform.any_op + } + + transform.named_sequence @lt_1(%arg0: !transform.any_op {transform.readonly}) + -> !transform.any_op { + transform.match.operation_name %arg0 ["func.func"] : !transform.any_op + %0 = transform.test_produce_param_with_number_of_test_ops %arg0 : !transform.any_op + %1 = transform.param.constant 1 : i32 -> !transform.test_dialect_param + transform.match.param.cmpi lt %0, %1 : !transform.test_dialect_param + transform.test_print_remark_at_operand %arg0, "matched < 1" : !transform.any_op + transform.yield %arg0 : !transform.any_op + } + + transform.named_sequence @le_1(%arg0: !transform.any_op {transform.readonly}) + -> !transform.any_op { + transform.match.operation_name %arg0 ["func.func"] : !transform.any_op + %0 = transform.test_produce_param_with_number_of_test_ops %arg0 : !transform.any_op + %1 = transform.param.constant 1 : i32 -> !transform.test_dialect_param + transform.match.param.cmpi le %0, %1 : !transform.test_dialect_param + transform.test_print_remark_at_operand %arg0, "matched <= 1" : !transform.any_op + transform.yield %arg0 : !transform.any_op + } + + transform.named_sequence @do_nothing(%arg0: !transform.any_op {transform.readonly}) { + transform.yield + } + + transform.sequence failures(propagate) { + ^bb(%arg0: !transform.any_op): + %0 = transform.foreach_match in %arg0 @eq_1 -> 
@do_nothing : (!transform.any_op) -> !transform.any_op + %1 = transform.foreach_match in %0 @ne_0 -> @do_nothing : (!transform.any_op) -> !transform.any_op + %2 = transform.foreach_match in %1 @gt_m1 -> @do_nothing : (!transform.any_op) -> !transform.any_op + %3 = transform.foreach_match in %2 @ge_1 -> @do_nothing : (!transform.any_op) -> !transform.any_op + %4 = transform.foreach_match in %3 @lt_1 -> @do_nothing : (!transform.any_op) -> !transform.any_op + %5 = transform.foreach_match in %4 @le_1 -> @do_nothing : (!transform.any_op) -> !transform.any_op + transform.yield + } + + // expected-remark @below {{matched > -1}} + // expected-remark @below {{matched < 1}} + // expected-remark @below {{matched <= 1}} + func.func private @declaration() + + // expected-remark @below {{matched == 1}} + // expected-remark @below {{matched != 0}} + // expected-remark @below {{matched > -1}} + // expected-remark @below {{matched >= 1}} + // expected-remark @below {{matched <= 1}} + func.func @definition() { + "test.something"() : () -> () + return + } +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -505,6 +505,9 @@ void mlir::test::TestPrintParamOp::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getParam(), effects); + if (getAnchor()) + transform::onlyReadsHandle(getAnchor(), effects); + transform::onlyReadsPayload(effects); } DiagnosedSilenceableFailure @@ -512,8 +515,15 @@ transform::TransformState &state) { std::string str; llvm::raw_string_ostream os(str); + if (getMessage()) + os << *getMessage() << " "; llvm::interleaveComma(state.getParams(getParam()), os); - auto diag = emitRemark() << os.str(); + if (!getAnchor()) { + emitRemark() << os.str(); + return DiagnosedSilenceableFailure::success(); + } + for 
(Operation *payload : state.getPayloadOps(getAnchor())) + ::mlir::emitRemark(payload->getLoc()) << os.str(); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -132,7 +132,8 @@ def TestPrintRemarkAtOperandOp : Op, + [MatchOpInterface, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let arguments = (ins TransformHandleTypeInterface:$operand, @@ -340,16 +341,22 @@ def TestPrintParamOp : Op, + [MatchOpInterface, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { - let arguments = (ins TransformParamTypeInterface:$param); - let assemblyFormat = "$param attr-dict `:` type($param)"; + let arguments = (ins TransformParamTypeInterface:$param, + Optional:$anchor, + OptionalAttr:$message); + let assemblyFormat = "$param (`,` $message^)? (`at` $anchor^)?" 
+ "attr-dict `:` type($param) (`,` type($anchor)^)?"; let cppNamespace = "::mlir::test"; } def TestAddToParamOp : Op]> { let arguments = (ins Optional:$param, I32Attr:$addendum); @@ -360,7 +367,9 @@ def TestProduceParamWithNumberOfTestOps : Op]> { let arguments = (ins TransformHandleTypeInterface:$handle); let results = (outs TestTransformTestDialectParamType:$result); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2386,15 +2386,14 @@ ":MemRefTransforms", ":NVVMDialect", ":Pass", + ":SerializeToCubin", ":SparseTensorDialect", ":SparseTensorTransforms", ":TensorTransforms", ":Transforms", ":VectorToLLVM", ":VectorTransforms", - ] + if_cuda_available([ - ":SerializeToCubin", - ]), + ], ) ##---------------------------------------------------------------------------## @@ -3039,7 +3038,6 @@ ":SCFPassIncGen", ":Support", ":TensorDialect", - ":ValueBoundsOpInterface", ":ViewLikeInterface", "//llvm:Support", ], @@ -5721,7 +5719,6 @@ ":TensorDialect", ":TensorUtils", ":TilingInterface", - ":ValueBoundsOpInterface", "//llvm:Support", ], ) @@ -6530,7 +6527,6 @@ ":LLVMDialect", ":MathDialect", ":Pass", - ":SCFDialect", ":Transforms", ":VectorDialect", ":VectorUtils", @@ -7131,25 +7127,6 @@ ], ) -cc_library( - name = "PluginsLib", - srcs = [ - "lib/Tools/Plugins/DialectPlugin.cpp", - "lib/Tools/Plugins/PassPlugin.cpp", - ], - hdrs = [ - "include/mlir/Tools/Plugins/DialectPlugin.h", - "include/mlir/Tools/Plugins/PassPlugin.h", - ], - includes = ["include"], - deps = [ - ":IR", - ":Pass", - ":Support", - "//llvm:Support", - ], -) - cc_library( name = "MlirOptLib", srcs = [ @@ -7166,7 +7143,6 @@ ":Observers", ":Parser", ":Pass", - ":PluginsLib", ":Support", "//llvm:Support", ], @@ -8361,9 +8337,9 @@ td_library( name = "LinalgTransformOpsTdFiles", - srcs = [ - 
"include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td", - ], + srcs = glob([ + "include/mlir/Dialect/Linalg/TransformOps/*.td", + ]), includes = ["include"], deps = [ ":PDLDialectTdFiles", @@ -8431,17 +8407,32 @@ ) gentbl_cc_library( - name = "LinalgTransformOpsIncGen", + name = "LinalgMatchOpsIncGen", strip_include_prefix = "include", tbl_outs = [ ( ["-gen-op-decls"], - "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h.inc", + "include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h.inc", ), ( ["-gen-op-defs"], - "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc", + "include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc", ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td", + deps = [ + ":LinalgTransformEnumsIncGen", + ":LinalgTransformOpsIncGen", + ":LinalgTransformOpsTdFiles", + ":SCFDeviceMappingInterfacesIncGen", + ], +) + +gentbl_cc_library( + name = "LinalgTransformEnumsIncGen", + strip_include_prefix = "include", + tbl_outs = [ ( ["-gen-enum-decls"], "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.h.inc", @@ -8452,8 +8443,30 @@ ), ], tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td", + deps = [ + ":LinalgTransformOpsTdFiles", + ":SCFDeviceMappingInterfacesIncGen", + ], +) + +gentbl_cc_library( + name = "LinalgTransformOpsIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-op-decls"], + "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h.inc", + ), + ( + ["-gen-op-defs"], + "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", td_file = "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td", deps = [ + ":LinalgTransformEnumsIncGen", ":LinalgTransformOpsTdFiles", ":SCFDeviceMappingInterfacesIncGen", ], @@ -8677,7 +8690,6 @@ ":Support", ":TensorDialect", ":TilingInterface", - 
":ValueBoundsOpInterface", ":ViewLikeInterface", "//llvm:Support", ], @@ -8685,30 +8697,30 @@ cc_library( name = "LinalgTransformOps", - srcs = [ - "lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp", - ], + srcs = glob([ + "lib/Dialect/Linalg/TransformOps/*.cpp", + ]), hdrs = [ "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h", ], includes = ["include"], deps = [ ":AffineDialect", + ":Analysis", ":ArithDialect", ":AsmParser", - ":ControlFlowDialect", ":DialectUtils", ":FuncDialect", ":GPUDialect", ":IR", ":LinalgDialect", + ":LinalgMatchOpsIncGen", + ":LinalgTransformEnumsIncGen", ":LinalgTransformOpsIncGen", ":LinalgTransforms", ":LinalgUtils", ":PDLDialect", - ":Parser", ":SCFTransforms", - ":SideEffectInterfaces", ":Support", ":TensorDialect", ":TensorUtils", @@ -8796,7 +8808,6 @@ deps = [ ":AffineAnalysis", ":AffineDialect", - ":AffineTransforms", ":AffineUtils", ":Analysis", ":ArithDialect", @@ -8834,7 +8845,6 @@ ":TilingInterface", ":TransformUtils", ":Transforms", - ":ValueBoundsOpInterface", ":VectorDialect", ":VectorToSCF", ":VectorTransforms", @@ -9486,62 +9496,62 @@ ) gentbl_cc_library( - name = "TransformDialectInterfacesIncGen", + name = "TransformDialectMatchInterfacesIncGen", strip_include_prefix = "include", tbl_outs = [ ( [ "-gen-op-interface-decls", ], - "include/mlir/Dialect/Transform/IR/TransformInterfaces.h.inc", + "include/mlir/Dialect/Transform/IR/MatchInterfaces.h.inc", ), ( [ "-gen-op-interface-defs", ], - "include/mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc", - ), - ( - [ - "-gen-type-interface-decls", - ], - "include/mlir/Dialect/Transform/IR/TransformTypeInterfaces.h.inc", - ), - ( - [ - "-gen-type-interface-defs", - ], - "include/mlir/Dialect/Transform/IR/TransformTypeInterfaces.cpp.inc", + "include/mlir/Dialect/Transform/IR/MatchInterfaces.cpp.inc", ), ], tblgen = ":mlir-tblgen", - td_file = "include/mlir/Dialect/Transform/IR/TransformInterfaces.td", - deps = [":TransformDialectTdFiles"], + td_file = 
"include/mlir/Dialect/Transform/IR/MatchInterfaces.td", + deps = [ + ":TransformDialectInterfacesIncGen", + ":TransformDialectTdFiles", + ], ) gentbl_cc_library( - name = "TransformDialectMatchInterfacesIncGen", + name = "TransformDialectInterfacesIncGen", strip_include_prefix = "include", tbl_outs = [ ( [ "-gen-op-interface-decls", ], - "include/mlir/Dialect/Transform/IR/MatchInterfaces.h.inc", + "include/mlir/Dialect/Transform/IR/TransformInterfaces.h.inc", ), ( [ "-gen-op-interface-defs", ], - "include/mlir/Dialect/Transform/IR/MatchInterfaces.cpp.inc", + "include/mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc", + ), + ( + [ + "-gen-type-interface-decls", + ], + "include/mlir/Dialect/Transform/IR/TransformTypeInterfaces.h.inc", + ), + ( + [ + "-gen-type-interface-defs", + ], + "include/mlir/Dialect/Transform/IR/TransformTypeInterfaces.cpp.inc", ), ], tblgen = ":mlir-tblgen", - td_file = "include/mlir/Dialect/Transform/IR/MatchInterfaces.td", - deps = [ - ":TransformDialectTdFiles", - ":TransformDialectInterfacesIncGen", - ], + td_file = "include/mlir/Dialect/Transform/IR/TransformInterfaces.td", + deps = [":TransformDialectTdFiles"], ) gentbl_cc_library(