diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -199,6 +199,8 @@
         names.
       - attribute: the matched op must have all specified attributes (with their
         specified values).
+      - filter_result_type: the matched op must have exactly one result, and
+        that result must have this type.
 
     Note: Only ops that satisfy all specified constraints are matched.
 
@@ -219,7 +221,8 @@
   let arguments = (ins PDL_Operation:$target,
                        OptionalAttr<StrArrayAttr>:$ops,
                        OptionalAttr<MatchInterfaceEnum>:$interface,
-                       OptionalAttr<DictionaryAttr>:$op_attrs);
+                       OptionalAttr<DictionaryAttr>:$op_attrs,
+                       OptionalAttr<TypeAttr>:$filter_result_type);
   // TODO: variadic results when needed.
   let results = (outs PDL_Operation:$results);
 
@@ -227,6 +230,7 @@
     (`ops` `{` $ops^ `}`)?
     (`interface` `{` $interface^ `}`)?
     (`attributes` $op_attrs^)?
+    (`filter_result_type` `=` $filter_result_type^)?
     `in` $target attr-dict
   }];
 }
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -458,7 +458,7 @@
   SmallVector<Operation *> res;
   auto matchFun = [&](Operation *op) {
     if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
-      return WalkResult::advance();
+      return;
 
     // Interfaces cannot be matched by name, just by ID.
     // So we specifically encode the interfaces we care about for this op.
@@ -466,10 +466,10 @@
       auto iface = getInterface().value();
       if (iface == transform::MatchInterfaceEnum::LinalgOp &&
           !isa<linalg::LinalgOp>(op))
-        return WalkResult::advance();
+        return;
       if (iface == transform::MatchInterfaceEnum::TilingInterface &&
-          isa<TilingInterface>(op))
-        return WalkResult::advance();
+          !isa<TilingInterface>(op))
+        return;
     }
 
     // Check if all specified attributes match.
@@ -480,15 +480,21 @@
             attr.getName() == getOpsAttrName())
           continue;
         if (!op->hasAttr(attr.getName()))
-          return WalkResult::advance();
+          return;
         if (op->getAttr(attr.getName()) != attr.getValue())
-          return WalkResult::advance();
+          return;
       }
     }
 
+    // Only single-result ops whose result has exactly the requested type pass.
+    if (getFilterResultType().has_value()) {
+      Type t = getFilterResultType().value();
+      if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
+        return;
+    }
+
     // All constraints are satisfied.
     res.push_back(op);
-    return WalkResult::advance();
   };
 
   payloadOps.front()->walk(matchFun);
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -25,6 +25,26 @@
 
 // -----
 
+func.func @by_type() {
+  %0 = arith.constant 0 : i32
+  // expected-remark @below {{matched op name}}
+  %1 = arith.constant 1.0 : f32
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %match_name = transform.structured.match
+      ops{["arith.constant"]} filter_result_type = f32 in %arg1
+    transform.test_print_remark_at_operand %match_name, "matched op name"
+    transform.test_consume_operand %match_name
+  }
+}
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
 func.func @match_complex_attribute(%arg0: tensor<12x128x32xf32>)