diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -193,10 +193,15 @@ The following constraints are supported: - interface: an optional MatchInterfaceEnum specifying an enum - representation for an interface to target. - - ops: an optional StrArrayAttr specifying the concrete name of an op. + representation for an interface to target. + - ops: an optional StrArrayAttr specifying the concrete name of an op. + Multiple names can be specified. Matched ops must have one of the + specified names. + - attribute: an optional StrAttr specifying the name of an attribute + that matched ops must have. - Note: either `ops` or `interface` must be specified. + Note: Only ops that satisfy all specified constraints are matched. If no + constraint is specified, the target and all ops nested in it are matched. TODO: Extend with regions to allow a limited form of constraints. @@ -214,12 +219,17 @@ let arguments = (ins PDL_Operation:$target, OptionalAttr:$ops, - OptionalAttr:$interface); + OptionalAttr:$interface, + OptionalAttr:$attribute); // TODO: variadic results when needed. let results = (outs PDL_Operation:$results); - let hasCustomAssemblyFormat = 1; - let hasVerifier = 1; + let assemblyFormat = [{ + (`ops` `{` $ops^ `}`)? + (`interface` `{` $interface^ `}`)? + (`attribute` `{` $attribute^ `}`)?
+ `in` $target attr-dict + }]; } def MultiTileSizesOp : OpemitOpError( - "requires a either a match_op or a match_interface attribute (but not " - "both)"); - return success(); -} - DiagnosedSilenceableFailure transform::MatchOp::apply(transform::TransformResults &results, transform::TransformState &state) { @@ -453,21 +444,28 @@ this->emitOpError("requires exactly one target handle")); SmallVector res; - auto matchFun = [&](Operation *op) { - if (strs.contains(op->getName().getStringRef())) - res.push_back(op); + if (getOps().hasValue() && !strs.contains(op->getName().getStringRef())) + return WalkResult::advance(); + // Interfaces cannot be matched by name, just by ID. // So we specifically encode the interfaces we care about for this op. if (getInterface().hasValue()) { auto iface = getInterface().getValue(); if (iface == transform::MatchInterfaceEnum::LinalgOp && - isa(op)) - res.push_back(op); + !isa(op)) + return WalkResult::advance(); if (iface == transform::MatchInterfaceEnum::TilingInterface && - isa(op)) - res.push_back(op); + !isa(op)) + return WalkResult::advance(); } + + if (getAttribute().hasValue() && !op->hasAttr(getAttribute().getValue())) + return WalkResult::advance(); + + // All constraints are satisfied. + res.push_back(op); + return WalkResult::advance(); }; payloadOps.front()->walk(matchFun); @@ -475,65 +473,6 @@ return DiagnosedSilenceableFailure(success()); } -ParseResult transform::MatchOp::parse(OpAsmParser &parser, - OperationState &result) { - // Parse 'match_op' or 'interface' clause.
- if (succeeded(parser.parseOptionalKeyword("ops"))) { - ArrayAttr opsAttr; - if (parser.parseLBrace() || - parser.parseCustomAttributeWithFallback( - opsAttr, parser.getBuilder().getType(), "ops", - result.attributes) || - parser.parseRBrace()) - return failure(); - } else if (succeeded(parser.parseOptionalKeyword("interface"))) { - if (parser.parseLBrace()) - return failure(); - StringRef attrStr; - auto loc = parser.getCurrentLocation(); - if (parser.parseKeyword(&attrStr)) - return failure(); - auto interfaceEnum = transform::symbolizeMatchInterfaceEnum(attrStr); - if (!interfaceEnum) - return parser.emitError(loc, "invalid ") - << "match_interface attribute specification: \"" << attrStr << '"'; - transform::MatchInterfaceEnumAttr match_interfaceAttr = - transform::MatchInterfaceEnumAttr::get(parser.getBuilder().getContext(), - interfaceEnum.value()); - result.addAttribute("interface", match_interfaceAttr); - if (parser.parseRBrace()) - return failure(); - } else { - auto loc = parser.getCurrentLocation(); - return parser.emitError(loc, "expected ops or interface"); - } - - OpAsmParser::UnresolvedOperand targetRawOperands[1]; - ArrayRef targetOperands(targetRawOperands); - if (parser.parseKeyword("in") || parser.parseOperand(targetRawOperands[0]) || - parser.parseOptionalAttrDict(result.attributes)) - return failure(); - Type pdlOpType = parser.getBuilder().getType(); - result.addTypes(pdlOpType); - if (parser.resolveOperands(targetOperands, pdlOpType, result.operands)) - return failure(); - return success(); -} - -void transform::MatchOp::print(OpAsmPrinter &p) { - if ((*this)->getAttr("ops")) { - p << " ops{"; - p.printAttributeWithoutType(getOpsAttr()); - p << "}"; - } - if ((*this)->getAttr("interface")) { - p << " interface{" << stringifyMatchInterfaceEnum(*getInterface()) << "}"; - } - p << " in " << getTarget(); - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"ops", "interface"}); -} -
//===---------------------------------------------------------------------===// // MultiTileSizesOp //===---------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir @@ -0,0 +1,43 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics + +func.func @bar() { + // expected-remark @below {{matched op name}} + // expected-remark @below {{matched attr name}} + %0 = arith.constant {my_attr} 0: i32 + // expected-remark @below {{matched op name}} + %1 = arith.constant 1 : i32 + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %match_name = transform.structured.match ops{["arith.constant"]} in %arg1 + transform.test_print_remark_at_operand %match_name, "matched op name" + transform.test_consume_operand %match_name + + %match_attr = transform.structured.match ops{["arith.constant"]} attribute{"my_attr"} in %arg1 + transform.test_print_remark_at_operand %match_attr, "matched attr name" + transform.test_consume_operand %match_attr + } +} + +// ----- + +func.func @no_tiling_interface() { + %0 = arith.constant 0 : i32 + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + // No payload op implements TilingInterface: the handle must be empty and + // no remark may be emitted. + %match_iface = transform.structured.match interface{TilingInterface} in %arg1 + transform.test_print_remark_at_operand %match_iface, "matched tiling interface" + transform.test_consume_operand %match_iface + } +}